module llama.ctx;

import llama.llama;
import llama.model : LlamaModel;
import llama.owned;

/++
    Context params with the given window and batch size.

    Starts from `llama_context_default_params()` and overrides only `n_ctx`,
    `n_batch` and `no_perf`.

    Params:
        nCtx   = stored in `n_ctx`; `0` uses the model's training length.
        nBatch = stored in `n_batch`.
        noPerf = stored in `no_perf`.
    Returns: the adjusted parameter struct.
+/
llama_context_params contextParams(uint nCtx = 0, uint nBatch = 512, bool noPerf = false) @nogc nothrow
{
    auto p = llama_context_default_params();
    p.n_ctx = nCtx;
    p.n_batch = nBatch;
    p.no_perf = noPerf;
    return p;
}

/+
    Views `len` elements starting at `p` as a slice.
    Returns `null` when `p` is null or `len` is negative, so callers never
    receive a slice that pairs a null pointer with a non-zero length.
    The caller is responsible for `p` really addressing `len` elements.
+/
private T[] spanOrNull(T)(T* p, long len) @nogc nothrow pure
{
    if (p is null || len < 0)
        return null;
    return p[0 .. cast(size_t) len];
}

/// A `llama_context` that frees itself on destruction.
struct LlamaContext
{
    mixin Owned!(llama_context, llama_free);

    /// Create from explicit params.
    static LlamaContext fromModel(ref LlamaModel model, llama_context_params params) @nogc nothrow
    {
        return LlamaContext(llama_init_from_model(model.ptr, params));
    }

    /// Create from a window size and batch size (see `contextParams`).
    static LlamaContext fromModel(ref LlamaModel model, uint nCtx, uint nBatch = 512) @nogc nothrow
    {
        return LlamaContext(llama_init_from_model(model.ptr, contextParams(nCtx, nBatch)));
    }

    /// Decodes a token batch; returns 0 on success.
    int decode(llama_batch batch) @nogc nothrow
    {
        return llama_decode(_ptr, batch);
    }

    /// Encodes a batch (encoder-decoder models); returns 0 on success.
    int encode(llama_batch batch) @nogc nothrow
    {
        return llama_encode(_ptr, batch);
    }

    /++
        Logits at output position `idx` (-1 = last). Valid until the next decode call.

        The slice length is the vocabulary size reported by `llama_vocab_n_tokens`.
        Returns `null` when `llama_get_logits_ith` yields no data for `idx`.
    +/
    const(float)[] getLogits(int idx = -1) @trusted @nogc nothrow
    {
        auto model = llama_get_model(_ptr);
        auto vocab = llama_model_get_vocab(model);
        int nVocab = llama_vocab_n_tokens(cast(llama_vocab*) vocab);
        // Guard the null case instead of slicing a null pointer.
        return spanOrNull(llama_get_logits_ith(_ptr, idx), nVocab);
    }

    /// Context window size in tokens.
    @property uint nCtx() @nogc nothrow
    {
        return llama_n_ctx(_ptr);
    }

    /// Active pooling type as an int (compare to `LLAMA_POOLING_TYPE_*` constants).
    @property int poolingType() @nogc nothrow
    {
        return cast(int) llama_pooling_type(_ptr);
    }

    /++ Raw memory handle. Use for sequence management (copy, remove, shift, etc.). +/
    @property llama_memory_t memory() @nogc nothrow
    {
        return llama_get_memory(_ptr);
    }

    /// Clear the KV cache. Pass `data = true` to also zero-fill memory.
    void memoryClear(bool data = false) @nogc nothrow
    {
        llama_memory_clear(llama_get_memory(_ptr), data);
    }

    /++
        Output embeddings, packed contiguously; valid after `decode`.

        The underlying buffer holds `n_outputs * nEmbd` floats. This wrapper has
        no way to query `n_outputs`, so the caller supplies it.

        Params:
            nOutputs = number of output rows to expose. The default of 1 yields
                       only the first row (`nEmbd` floats). Passing more rows
                       than the last decode produced reads out of bounds.
        Returns: a slice of `nOutputs * nEmbd` floats, or `null` when
                 `llama_get_embeddings` returns null.
    +/
    float[] getEmbeddings(uint nOutputs = 1) @trusted @nogc nothrow
    {
        auto model = llama_get_model(_ptr);
        int nEmbd = llama_model_n_embd(model);
        // Widen before multiplying so the element count cannot wrap in 32 bits.
        return spanOrNull(llama_get_embeddings(_ptr), cast(long) nEmbd * nOutputs);
    }

    /++ Embeddings for the `i`th output token (-1 = last). Returns `null` for invalid index. +/
    float[] getEmbeddingsIth(int i) @trusted @nogc nothrow
    {
        auto model = llama_get_model(_ptr);
        int nEmbd = llama_model_n_embd(model);
        return spanOrNull(llama_get_embeddings_ith(_ptr, i), nEmbd);
    }

    /++
        Pooled embeddings for a sequence. Returns `null` when pooling is `LLAMA_POOLING_TYPE_NONE`.

        NOTE(review): the slice is always `nEmbd` long; llama.h appears to document a
        different result length for rank pooling. Confirm before relying on the
        length with `LLAMA_POOLING_TYPE_RANK`.
    +/
    float[] getEmbeddingsSeq(llama_seq_id seqId) @trusted @nogc nothrow
    {
        auto model = llama_get_model(_ptr);
        int nEmbd = llama_model_n_embd(model);
        return spanOrNull(llama_get_embeddings_seq(_ptr, seqId), nEmbd);
    }

    // ── State (session) save / load ──────────────────────────────────────────

    /// Byte count of the current state. Use this to size a buffer before `stateGetData`.
    size_t stateGetSize() @nogc nothrow
    {
        return llama_state_get_size(_ptr);
    }

    /++ Copy the current state into `dst`. Returns the number of bytes written. +/
    size_t stateGetData(ubyte[] dst) @trusted @nogc nothrow
    {
        return llama_state_get_data(_ptr, dst.ptr, dst.length);
    }

    /++ Restore the state from `src`. Returns the number of bytes consumed. +/
    size_t stateSetData(const(ubyte)[] src) @trusted @nogc nothrow
    {
        return llama_state_set_data(_ptr, src.ptr, src.length);
    }

    /++
        Save the state to a session file, recording `tokens` as the session prompt.
        Returns `true` on success.

        Not `@nogc`: `path` is converted with `toStringz`.
    +/
    bool stateSaveFile(string path, const(llama_token)[] tokens) @trusted nothrow
    {
        import std.string : toStringz;

        return llama_state_save_file(_ptr, path.toStringz, tokens.ptr, tokens.length);
    }

    /++
        Load state from a session file. On success `tokensOut` is filled and
        `tokenCount` holds the number of tokens read; returns `true`.

        `tokensOut.length` is passed as the token capacity. Returns `false`
        without touching the context when `tokenCount` is null, since the count
        has nowhere to be stored.
    +/
    bool stateLoadFile(string path, llama_token[] tokensOut, scope size_t* tokenCount) @trusted nothrow
    {
        import std.string : toStringz;

        if (tokenCount is null)
            return false;
        return llama_state_load_file(_ptr, path.toStringz,
            tokensOut.ptr, tokensOut.length, tokenCount);
    }

    // ── Per-sequence KV state ────────────────────────────────────────────────

    /// Byte count required to snapshot sequence `seqId`.
    size_t stateSeqGetSize(llama_seq_id seqId) @nogc nothrow
    {
        return llama_state_seq_get_size(_ptr, seqId);
    }

    /++ Copy sequence `seqId`'s KV cache into `dst`. Returns bytes written. +/
    size_t stateSeqGetData(ubyte[] dst, llama_seq_id seqId) @trusted @nogc nothrow
    {
        return llama_state_seq_get_data(_ptr, dst.ptr, dst.length, seqId);
    }

    /++
        Restore a KV snapshot from `src` into sequence `destSeqId`.
        Returns bytes consumed; 0 means failure.
    +/
    size_t stateSeqSetData(const(ubyte)[] src, llama_seq_id destSeqId) @trusted @nogc nothrow
    {
        return llama_state_seq_set_data(_ptr, src.ptr, src.length, destSeqId);
    }

    /// Forwards to `llama_perf_context_print` for this context.
    void printPerf() @nogc nothrow
    {
        llama_perf_context_print(_ptr);
    }
}