/++
Demonstrate context-state save and load for reproducible generation.

Encodes a prompt into context 1, saves the state to a file, then restores
it in context 2 and generates the same token sequence — verifying that both
runs match.

Usage: `save-load-state -m model.gguf [-n n_predict] [-ngl n_gpu_layers]
                        [-s state_file] [prompt]`
+/
module save_load_state;

import llama;
import std.stdio : writef, writefln, writeln, stderr;
import std.conv : to, ConvException;
import core.stdc.locale : setlocale, LC_NUMERIC;
import core.stdc.stdio : printf;

/++
Program entry point.

Parses the command line, loads the model, then:
  1. decodes all prompt tokens except the last in context 1 and saves the
     state to `stateFile`;
  2. decodes the last prompt token and samples up to `nPredict` tokens;
  3. restores the saved state into a fresh context 2 and repeats step 2
     with an identically seeded sampler;
  4. compares the text produced by both runs.

Params:
    args = command-line arguments; `args[0]` is the program name.

Returns: 0 when both runs produce identical text, 1 on any error, on a
mismatch, or when usage is printed.
+/
int main(string[] args)
{
    setlocale(LC_NUMERIC, "C");

    string modelPath;
    string prompt = "The quick brown fox";
    string stateFile = "dump_state.bin";
    int ngl = 0;
    int nPredict = 16;

    const int argc = cast(int) args.length;
    for (int i = 1; i < argc; i++)
    {
        switch (args[i])
        {
        case "-m":
            if (++i >= argc) return printUsage(args[0]);
            modelPath = args[i];
            break;
        case "-n":
            // A malformed number prints a diagnostic and usage instead of
            // escaping as an uncaught ConvException.
            if (++i >= argc || !parseIntArg("-n", args[i], nPredict))
                return printUsage(args[0]);
            break;
        case "-ngl":
            if (++i >= argc || !parseIntArg("-ngl", args[i], ngl))
                return printUsage(args[0]);
            break;
        case "-s":
            if (++i >= argc) return printUsage(args[0]);
            stateFile = args[i];
            break;
        default:
            // Any non-option argument replaces the prompt (last one wins).
            prompt = args[i];
            break;
        }
    }
    if (modelPath.length == 0) return printUsage(args[0]);

    loadAllBackends();

    auto model = LlamaModel.loadFromFile(modelPath, ngl);
    if (!model)
    {
        stderr.writefln("error: failed to load model '%s'", modelPath);
        return 1;
    }

    const vocab = model.vocab;
    int nBatch = 512;

    // ── Context 1: encode prompt + save state ────────────────────────────────
    // kv_unified required for recurrent (SSM/Mamba) models when n_seq_max=1.
    auto cp1 = contextParams(cast(uint)(nBatch + nPredict + 1), cast(uint) nBatch);
    cp1.kv_unified = true;
    auto ctx1 = LlamaContext.fromModel(model, cp1);
    if (!ctx1)
    {
        stderr.writeln("error: failed to create context 1");
        return 1;
    }

    auto tokens = tokenize(vocab, prompt, /*addSpecial=*/true);
    if (tokens is null)
    {
        stderr.writeln("error: tokenization failed");
        return 1;
    }
    // An empty (but non-null) token array would make the `$ - 1` slices and
    // index below underflow and raise a RangeError, so reject it up front.
    if (tokens.length == 0)
    {
        stderr.writeln("error: prompt produced no tokens");
        return 1;
    }

    // Decode all but the last token; save state BEFORE the final token so that
    // both runs can decode it from the same recurrent state (important for
    // hybrid SSM/Mamba models where replaying at a new position changes state).
    if (tokens.length > 1)
    {
        auto ob = allocBatch(cast(int)(tokens.length - 1));
        foreach (i, tok; tokens[0 .. $ - 1])
            batchAdd(ob.get(), tok, cast(llama_pos) i, 0, false);

        if (ctx1.decode(ob.get()))
        {
            stderr.writeln("error: prompt pre-decode failed");
            return 1;
        }
    }

    if (!ctx1.stateSaveFile(stateFile, tokens[0 .. $ - 1]))
    {
        stderr.writeln("error: stateSaveFile failed");
        return 1;
    }
    stderr.writefln("state saved to '%s'", stateFile);

    // Decode the final prompt token to get generation-ready logits.
    {
        auto ob = allocBatch(1);
        batchAdd(ob.get(), tokens[$ - 1], cast(llama_pos)(tokens.length - 1), 0, true);
        if (ctx1.decode(ob.get()))
        {
            stderr.writeln("error: last-token decode failed");
            return 1;
        }
    }

    // ── First generation run (ctx1) ──────────────────────────────────────────
    auto smpl1 = SamplerChain.create();
    smpl1.dist(1234u);

    string result1;
    {
        writef("\nfirst run: %s", prompt);

        // nPast = tokens.length because we decoded all prompt tokens above.
        llama_pos nPast = cast(llama_pos) tokens.length;
        auto ob = allocBatch(1);

        for (int g; g < nPredict; g++)
        {
            auto tok = smpl1.sample(ctx1);
            if (isEog(vocab, tok)) break;

            string piece = tokenToString(vocab, tok);
            writef("%s", piece);
            result1 ~= piece;

            // Feed the sampled token back in at the next position.
            smpl1.accept(tok);
            batchClear(ob.get());
            batchAdd(ob.get(), tok, nPast, 0, true);

            if (ctx1.decode(ob.get()))
            {
                stderr.writeln("\nerror: decode failed");
                return 1;
            }
            nPast++;
        }
        writeln("\n");
    }

    // ── Context 2: restore state + re-generate ───────────────────────────────
    auto cp2 = contextParams(cast(uint)(nBatch + nPredict + 1), cast(uint) nBatch);
    cp2.kv_unified = true;
    auto ctx2 = LlamaContext.fromModel(model, cp2);
    if (!ctx2)
    {
        stderr.writeln("error: failed to create context 2");
        return 1;
    }

    // Same seed as the first run so both samplers start identically seeded.
    auto smpl2 = SamplerChain.create();
    smpl2.dist(1234u);

    string result2;
    {
        auto sessionToks = new llama_token[](tokens.length);
        size_t nLoaded;

        if (!ctx2.stateLoadFile(stateFile, sessionToks, &nLoaded))
        {
            stderr.writeln("error: stateLoadFile failed");
            return 1;
        }
        stderr.writefln("loaded state: %d tokens", nLoaded);

        // Decode the same final prompt token from the same base state that
        // run 1 used. This guarantees identical logits for hybrid SSM models.
        llama_pos nPast = cast(llama_pos) nLoaded;
        {
            auto ob = allocBatch(1);
            batchAdd(ob.get(), tokens[$ - 1], nPast, 0, true);
            if (ctx2.decode(ob.get()))
            {
                stderr.writeln("error: last-token decode failed (ctx2)");
                return 1;
            }
            nPast++;
        }

        writef("second run: %s", prompt);
        auto ob = allocBatch(1);

        for (int g; g < nPredict; g++)
        {
            auto tok = smpl2.sample(ctx2);
            if (isEog(vocab, tok)) break;

            string piece = tokenToString(vocab, tok);
            writef("%s", piece);
            result2 ~= piece;

            smpl2.accept(tok);
            batchClear(ob.get());
            batchAdd(ob.get(), tok, nPast, 0, true);

            if (ctx2.decode(ob.get()))
            {
                stderr.writeln("\nerror: decode failed");
                return 1;
            }
            nPast++;
        }
        writeln("\n");
    }

    if (result1 != result2)
    {
        stderr.writeln("FAIL: the two runs produced different output");
        stderr.writefln("  run 1: %s", result1);
        stderr.writefln("  run 2: %s", result2);
        return 1;
    }

    stderr.writeln("OK: both runs match");
    return 0;
}

/++
Parse an integer option value without letting a conversion error escape.

Params:
    flag  = option name, used only in the diagnostic message.
    text  = raw argument text to convert.
    value = receives the parsed number; left untouched on failure.

Returns: `true` on success; `false` after writing a diagnostic to stderr when
`text` is not a valid `int`.
+/
private bool parseIntArg(string flag, string text, ref int value)
{
    try
    {
        value = text.to!int;
        return true;
    }
    catch (ConvException)
    {
        stderr.writefln("error: invalid integer '%s' for %s", text, flag);
        return false;
    }
}

/++
Print the command-line usage text via C `printf`.

Params:
    prog = program name shown in the usage line.

Returns: always 1, so callers can write `return printUsage(...)`.
+/
int printUsage(string prog) @trusted nothrow
{
    // D slices are not guaranteed to be NUL-terminated, so hand printf an
    // explicit length via "%.*s" rather than relying on "%s" with a raw .ptr.
    printf(
        "\nusage: %.*s -m model.gguf [-n n_predict] [-ngl n_gpu_layers]\n"
        ~ "       [-s state_file] [prompt]\n\n"
        ~ "  -m    model path (GGUF)\n"
        ~ "  -n    tokens to predict    (default: 16)\n"
        ~ "  -ngl  GPU layers           (default: 0)\n"
        ~ "  -s    state file           (default: dump_state.bin)\n\n",
        cast(int) prog.length, prog.ptr);
    return 1;
}