1 /++
2 Demonstrate context-state save and load for reproducible generation.
3 
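Decodes all but the final prompt token in context 1, saves that state to a
file, then decodes the final token and generates text from context 1. The saved
state is restored into context 2, which decodes the same final token and
regenerates with an identically seeded sampler; the program verifies that both
runs produce the same text.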
7 
8 Usage: `save-load-state -m model.gguf [-n n_predict] [-ngl n_gpu_layers]
9                         [-s state_file] [prompt]`
10 +/
11 module save_load_state;
12 
13 import llama;
14 import std.stdio  : writef, writefln, writeln, stderr;
15 import std.conv   : to;
16 import core.stdc.locale : setlocale, LC_NUMERIC;
17 import core.stdc.stdio  : printf;
18 
/++
Decode a single token at position `pos` with logits requested, so the context
is ready to be sampled from afterwards.

Returns: `true` when `ctx.decode` reported success (a zero result).
+/
private bool decodeFinalToken(Ctx)(ref Ctx ctx, llama_token tok, llama_pos pos)
{
    auto ob = allocBatch(1);
    batchAdd(ob.get(), tok, pos, 0, true);
    return !ctx.decode(ob.get());
}

/++
Sample up to `nPredict` tokens from `ctx`, echoing each piece to stdout and
appending it to `result`. Stops early when an end-of-generation token is
sampled.

Both generation runs go through this one function, so any difference between
their outputs can only come from the context state, not from diverging loops.

Params:
    ctx      = context whose last decode produced the logits to sample from.
    smpl     = sampler chain; each sampled token is passed to `accept`.
    vocab    = vocabulary used for EOG detection and detokenization.
    nPast    = position assigned to the first generated token.
    nPredict = maximum number of tokens to generate.
    result   = receives the concatenated text pieces.
Returns: `false` if a decode call failed (an error has been written to stderr).
+/
private bool runGeneration(Ctx, Smpl, Vocab)(ref Ctx ctx, ref Smpl smpl, Vocab vocab,
                                             llama_pos nPast, int nPredict,
                                             ref string result)
{
    auto ob = allocBatch(1);

    foreach (g; 0 .. nPredict)
    {
        auto tok = smpl.sample(ctx);
        if (isEog(vocab, tok)) break;

        string piece = tokenToString(vocab, tok);
        writef("%s", piece);
        result ~= piece;

        smpl.accept(tok);
        batchClear(ob.get());
        batchAdd(ob.get(), tok, nPast, 0, true);

        if (ctx.decode(ob.get()))
        {
            stderr.writeln("\nerror: decode failed");
            return false;
        }
        nPast++;
    }
    return true;
}

/++
Entry point.

The program runs a save / generate / restore / regenerate round-trip. It
returns 0 only when the generation from the original context and the one from
the restored context produce identical text.

Returns: 0 on success; 1 on a usage error, a load/tokenize/decode/state-file
failure, or an output mismatch between the two runs.
+/
int main(string[] args)
{
    import std.conv : ConvException;

    setlocale(LC_NUMERIC, "C");

    // argv may legitimately be empty when the process is exec'd without it.
    const string prog = args.length ? args[0] : "save-load-state";

    string modelPath;
    string prompt    = "The quick brown fox";
    string stateFile = "dump_state.bin";
    int    ngl       = 0;
    int    nPredict  = 16;

    // `to!int` throws on malformed or out-of-range input; report that as a
    // usage error instead of letting the exception escape main.
    try
    {
        for (size_t i = 1; i < args.length; i++)
        {
            switch (args[i])
            {
                case "-m":
                    if (++i < args.length) modelPath = args[i];
                    else return printUsage(prog);
                    break;
                case "-n":
                    if (++i < args.length) nPredict = args[i].to!int;
                    else return printUsage(prog);
                    break;
                case "-ngl":
                    if (++i < args.length) ngl = args[i].to!int;
                    else return printUsage(prog);
                    break;
                case "-s":
                    if (++i < args.length) stateFile = args[i];
                    else return printUsage(prog);
                    break;
                default:
                    prompt = args[i];
                    break;
            }
        }
    }
    catch (ConvException e)
    {
        stderr.writefln("error: invalid numeric argument: %s", e.msg);
        return printUsage(prog);
    }

    // A negative prediction count is meaningless and would also corrupt the
    // context-size arithmetic below.
    if (modelPath.length == 0 || nPredict < 0) return printUsage(prog);

    loadAllBackends();

    auto model = LlamaModel.loadFromFile(modelPath, ngl);
    if (!model)
    {
        stderr.writefln("error: failed to load model '%s'", modelPath);
        return 1;
    }

    const vocab = model.vocab;

    // Tokenize before creating any context so batch/context sizes can be
    // derived from the actual prompt length.
    auto tokens = tokenize(vocab, prompt, /*addSpecial=*/true);
    // Reject both null and zero-length results: everything below slices
    // `tokens[0 .. $ - 1]` and indexes `tokens[$ - 1]`.
    if (tokens.length == 0)
    {
        stderr.writeln("error: tokenization failed");
        return 1;
    }

    // The pre-decode submits tokens.length - 1 tokens in one batch, so the
    // batch must be at least that large; 512 remains the floor. The context
    // must hold the whole prompt plus every generated token.
    const int  nBatch = tokens.length > 512 ? cast(int) tokens.length : 512;
    const uint nCtx   = cast(uint)(nBatch + nPredict + 1);

    // Number of prompt tokens captured in the saved state (all but the last).
    const size_t nSaved = tokens.length - 1;

    // ── Context 1: encode prompt + save state ────────────────────────────────
    // kv_unified required for recurrent (SSM/Mamba) models when n_seq_max=1.
    auto cp1       = contextParams(nCtx, cast(uint) nBatch);
    cp1.kv_unified = true;
    auto ctx1 = LlamaContext.fromModel(model, cp1);
    if (!ctx1)
    {
        stderr.writeln("error: failed to create context 1");
        return 1;
    }

    // Decode all but the last token; save state BEFORE the final token so that
    // both runs can decode it from the same recurrent state (important for
    // hybrid SSM/Mamba models where replaying at a new position changes state).
    if (nSaved > 0)
    {
        auto ob = allocBatch(cast(int) nSaved);
        foreach (i, tok; tokens[0 .. nSaved])
            batchAdd(ob.get(), tok, cast(llama_pos) i, 0, false);

        if (ctx1.decode(ob.get()))
        {
            stderr.writeln("error: prompt pre-decode failed");
            return 1;
        }
    }

    if (!ctx1.stateSaveFile(stateFile, tokens[0 .. nSaved]))
    {
        stderr.writeln("error: stateSaveFile failed");
        return 1;
    }
    stderr.writefln("state saved to '%s'", stateFile);

    // Decode the final prompt token to get generation-ready logits.
    if (!decodeFinalToken(ctx1, tokens[$ - 1], cast(llama_pos) nSaved))
    {
        stderr.writeln("error: last-token decode failed");
        return 1;
    }

    // ── First generation run (ctx1) ──────────────────────────────────────────
    auto smpl1 = SamplerChain.create();
    smpl1.dist(1234u);

    string result1;
    writef("\nfirst run: %s", prompt);
    // Every prompt token is now in ctx1, so generation starts at tokens.length.
    if (!runGeneration(ctx1, smpl1, vocab, cast(llama_pos) tokens.length, nPredict, result1))
        return 1;
    writeln("\n");

    // ── Context 2: restore state + re-generate ───────────────────────────────
    auto cp2       = contextParams(nCtx, cast(uint) nBatch);
    cp2.kv_unified = true;
    auto ctx2 = LlamaContext.fromModel(model, cp2);
    if (!ctx2)
    {
        stderr.writeln("error: failed to create context 2");
        return 1;
    }

    auto smpl2 = SamplerChain.create();
    smpl2.dist(1234u);

    auto sessionToks = new llama_token[](tokens.length);
    size_t nLoaded;

    if (!ctx2.stateLoadFile(stateFile, sessionToks, &nLoaded))
    {
        stderr.writeln("error: stateLoadFile failed");
        return 1;
    }
    stderr.writefln("loaded state: %d tokens", nLoaded);

    // The restored state must cover exactly the tokens saved above; otherwise
    // the final prompt token would be decoded at the wrong position.
    if (nLoaded != nSaved)
    {
        stderr.writefln("error: expected %d tokens in state file, loaded %d", nSaved, nLoaded);
        return 1;
    }

    // Decode the same final prompt token from the same base state that
    // run 1 used. This guarantees identical logits for hybrid SSM models.
    if (!decodeFinalToken(ctx2, tokens[$ - 1], cast(llama_pos) nLoaded))
    {
        stderr.writeln("error: last-token decode failed (ctx2)");
        return 1;
    }

    string result2;
    writef("second run: %s", prompt);
    if (!runGeneration(ctx2, smpl2, vocab, cast(llama_pos)(nLoaded + 1), nPredict, result2))
        return 1;
    writeln("\n");

    if (result1 != result2)
    {
        stderr.writeln("FAIL: the two runs produced different output");
        stderr.writefln("  run 1: %s", result1);
        stderr.writefln("  run 2: %s", result2);
        return 1;
    }

    stderr.writeln("OK: both runs match");
    return 0;
}
231 
/++
Print command-line usage to stdout.

Params:
    prog = program name shown in the usage line; "save-load-state" is used
           when it is empty.
Returns: always 1, so callers can `return printUsage(...)` as an exit status.
+/
int printUsage(string prog) @trusted nothrow
{
    // D strings are not guaranteed to be NUL-terminated, so pass an explicit
    // length with `%.*s` rather than handing `prog.ptr` to a bare `%s`.
    const string name = prog.length ? prog : "save-load-state";
    printf(
        "\nusage: %.*s -m model.gguf [-n n_predict] [-ngl n_gpu_layers]\n"
        ~ "                       [-s state_file] [prompt]\n\n"
        ~ "  -m  model path (GGUF)\n"
        ~ "  -n  tokens to predict (default: 16)\n"
        ~ "  -ngl GPU layers (default: 0)\n"
        ~ "  -s  state file (default: dump_state.bin)\n\n",
        cast(int) name.length, name.ptr);
    return 1;
}