1 module llama.ctx;
2 
3 import llama.llama;
4 import llama.model : LlamaModel;
5 import llama.owned;
6 
/++
Build a `llama_context_params` from the library defaults, overriding only the
context window, the batch size and the perf-counter flag.

Params:
    nCtx   = context window in tokens; `0` uses the model's training length.
    nBatch = value stored into `n_batch`.
    noPerf = value stored into `no_perf`.
Returns: the parameter struct; every other field keeps its default.
+/
llama_context_params contextParams(uint nCtx = 0, uint nBatch = 512, bool noPerf = false) @nogc nothrow
{
    llama_context_params params = llama_context_default_params();
    with (params)
    {
        n_ctx   = nCtx;
        n_batch = nBatch;
        no_perf = noPerf;
    }
    return params;
}
16 
/++
A `llama_context` that frees itself on destruction. Ownership and cleanup come
from the `Owned` mixin, instantiated here with `llama_free` as the deleter.
+/
struct LlamaContext
{
    mixin Owned!(llama_context, llama_free);

    // ── Construction ─────────────────────────────────────────────────────────

    /// Create from explicit params.
    static LlamaContext fromModel(ref LlamaModel model, llama_context_params params) @nogc nothrow
    {
        return LlamaContext(llama_init_from_model(model.ptr, params));
    }

    /++
    Create from a window size and batch size. All other fields come from
    `contextParams` (library defaults, `no_perf = false`).
    +/
    static LlamaContext fromModel(ref LlamaModel model, uint nCtx, uint nBatch = 512) @nogc nothrow
    {
        return fromModel(model, contextParams(nCtx, nBatch));
    }

    // ── Inference ────────────────────────────────────────────────────────────

    /// Decodes a token batch; returns 0 on success.
    int decode(llama_batch batch) @nogc nothrow { return llama_decode(_ptr, batch); }

    /// Encodes a batch (encoder-decoder models); returns 0 on success.
    int encode(llama_batch batch) @nogc nothrow { return llama_encode(_ptr, batch); }

    /++
    Logits at output position `idx` (-1 = last), one float per vocabulary token.
    The slice aliases context-owned memory and is valid until the next decode call.
    Returns `null` when llama.cpp rejects `idx` (it returns a null pointer).
    +/
    const(float)[] getLogits(int idx = -1) @trusted @nogc nothrow
    {
        const(float)* p = llama_get_logits_ith(_ptr, idx);
        // Slicing a null pointer with a non-zero length yields a bogus slice
        // that faults on first access; surface the failure as `null` instead.
        if (p is null)
            return null;
        auto vocab = llama_model_get_vocab(llama_get_model(_ptr));
        int nVocab = llama_vocab_n_tokens(cast(llama_vocab*) vocab);
        return p[0 .. nVocab];
    }

    // ── Properties ───────────────────────────────────────────────────────────

    /// Context window size in tokens.
    @property uint nCtx() @nogc nothrow { return llama_n_ctx(_ptr); }

    /// Active pooling type as an int (compare to `LLAMA_POOLING_TYPE_*` constants).
    @property int poolingType() @nogc nothrow { return cast(int) llama_pooling_type(_ptr); }

    /++ Raw memory handle. Use for sequence management (copy, remove, shift, etc.). +/
    @property llama_memory_t memory() @nogc nothrow { return llama_get_memory(_ptr); }

    /++
    Clear the KV cache. Pass `data = true` to also clear the data buffers
    (not only the metadata).
    +/
    void memoryClear(bool data = false) @nogc nothrow
    {
        llama_memory_clear(memory, data);
    }

    // ── Embeddings ───────────────────────────────────────────────────────────

    /// Embedding width (`n_embd`) of the model this context was created from.
    private int embdWidth() @nogc nothrow
    {
        return llama_model_n_embd(llama_get_model(_ptr));
    }

    /++
    All output embeddings packed contiguously, as returned by `llama_get_embeddings`.

    Per llama.h, this buffer is populated when pooling is
    `LLAMA_POOLING_TYPE_NONE` or the model is generative. It holds one `nEmbd`
    row per batch token that requested output, in batch order. Otherwise the
    pointer is null. Use `getEmbeddingsSeq` for pooled embeddings.

    Params:
        nOutputs = number of rows to expose; must not exceed the number of
                   outputs produced by the last `decode` (the slice is unchecked).
                   Defaults to 1, i.e. only the first row.
    Returns: `nOutputs * nEmbd` floats, or `null` if no embeddings are available.
    +/
    float[] getEmbeddings(size_t nOutputs = 1) @trusted @nogc nothrow
    {
        float* p = llama_get_embeddings(_ptr);
        return p ? p[0 .. nOutputs * embdWidth()] : null;
    }

    /++ Embeddings for the `i`th output token (-1 = last), `nEmbd` floats. Returns `null` for invalid index. +/
    float[] getEmbeddingsIth(int i) @trusted @nogc nothrow
    {
        float* p = llama_get_embeddings_ith(_ptr, i);
        return p ? p[0 .. embdWidth()] : null;
    }

    /++
    Pooled embeddings for sequence `seqId`. Returns `null` when pooling is
    `LLAMA_POOLING_TYPE_NONE`.

    Params:
        seqId = sequence whose pooled vector is returned.
        len   = number of floats to expose; `0` (the default) means `nEmbd`.
                llama.h documents that with `LLAMA_POOLING_TYPE_RANK` the buffer
                holds `n_cls_out` floats instead. Reranker callers must pass
                that count, because the `nEmbd` default would read past the end.
    +/
    float[] getEmbeddingsSeq(llama_seq_id seqId, size_t len = 0) @trusted @nogc nothrow
    {
        float* p = llama_get_embeddings_seq(_ptr, seqId);
        if (p is null)
            return null;
        const size_t n = len != 0 ? len : cast(size_t) embdWidth();
        return p[0 .. n];
    }

    // ── State (session) save / load ──────────────────────────────────────────

    /// Byte count of the current state. Use this to size a buffer before `stateGetData`.
    size_t stateGetSize() @nogc nothrow { return llama_state_get_size(_ptr); }

    /++ Copy the current state into `dst`. Returns the number of bytes written. +/
    size_t stateGetData(ubyte[] dst) @trusted @nogc nothrow
    {
        return llama_state_get_data(_ptr, dst.ptr, dst.length);
    }

    /++ Restore the state from `src`. Returns the number of bytes consumed. +/
    size_t stateSetData(const(ubyte)[] src) @trusted @nogc nothrow
    {
        return llama_state_set_data(_ptr, src.ptr, src.length);
    }

    /++
    Save the state to a session file, recording `tokens` as the session prompt.
    Returns `true` on success. Not `@nogc`: the path is converted with `toStringz`.
    +/
    bool stateSaveFile(string path, const(llama_token)[] tokens) @trusted nothrow
    {
        import std.string : toStringz;
        return llama_state_save_file(_ptr, path.toStringz, tokens.ptr, tokens.length);
    }

    /++
    Load state from a session file. On success `tokensOut` is filled (up to its
    length) and `*tokenCount` holds the number of tokens read; returns `true`.
    `tokenCount` may be `null` when the caller does not need the count.
    +/
    bool stateLoadFile(string path, llama_token[] tokensOut, scope size_t* tokenCount) @trusted nothrow
    {
        import std.string : toStringz;
        // llama.cpp writes the token count through this pointer, so never pass null.
        size_t ignored;
        size_t* countOut = tokenCount !is null ? tokenCount : &ignored;
        return llama_state_load_file(_ptr, path.toStringz,
                                     tokensOut.ptr, tokensOut.length, countOut);
    }

    // ── Per-sequence KV state ────────────────────────────────────────────────

    /// Byte count required to snapshot sequence `seqId`.
    size_t stateSeqGetSize(llama_seq_id seqId) @nogc nothrow
    {
        return llama_state_seq_get_size(_ptr, seqId);
    }

    /++ Copy sequence `seqId`'s KV cache into `dst`. Returns bytes written. +/
    size_t stateSeqGetData(ubyte[] dst, llama_seq_id seqId) @trusted @nogc nothrow
    {
        return llama_state_seq_get_data(_ptr, dst.ptr, dst.length, seqId);
    }

    /++
    Restore a KV snapshot from `src` into sequence `destSeqId`.
    Returns bytes consumed; 0 means failure.
    +/
    size_t stateSeqSetData(const(ubyte)[] src, llama_seq_id destSeqId) @trusted @nogc nothrow
    {
        return llama_state_seq_set_data(_ptr, src.ptr, src.length, destSeqId);
    }

    /// Print this context's performance counters via `llama_perf_context_print`.
    void printPerf() @nogc nothrow { llama_perf_context_print(_ptr); }
}