1 module llama.sampling;
2 
3 import llama.llama;
4 import llama.ctx : LlamaContext;
5 import llama.owned;
6 
7 /// A sampler chain you configure then use to pick the next token.
/++
A llama.cpp sampler chain (`llama_sampler`) that you configure, then use to pick the next token.

Create it with `create`, append samplers with the builder methods below (each returns
`ref this`, so calls can be chained), then call `sample` for each generated token.
The handle is managed by the `Owned` mixin, instantiated with `llama_sampler_free`
(presumably the destructor frees it — see `llama.owned`).
+/
struct SamplerChain
{
    mixin Owned!(llama_sampler, llama_sampler_free);

    /++
    Creates a new, empty sampler chain.

    Params:
        noPerf = copied into `llama_sampler_chain_params.no_perf`; `true` turns off
                 the chain's performance timing measurements.
    +/
    static SamplerChain create(bool noPerf = false) @nogc nothrow
    {
        auto p = llama_sampler_chain_default_params();
        p.no_perf = noPerf;
        return SamplerChain(llama_sampler_chain_init(p));
    }

    /// Adds greedy (argmax) sampling.
    ref SamplerChain greedy() @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_greedy());
        return this;
    }

    /// Adds temperature scaling with temperature `t`.
    ref SamplerChain temp(float t) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_temp(t));
        return this;
    }

    /// Adds top-K filtering (keep the `k` most likely tokens).
    ref SamplerChain topK(int k) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_top_k(k));
        return this;
    }

    /// Adds top-P (nucleus) sampling; `minKeep` is the minimum number of candidates kept.
    ref SamplerChain topP(float p, size_t minKeep = 1) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_top_p(p, minKeep));
        return this;
    }

    /// Adds min-P sampling; `minKeep` is the minimum number of candidates kept.
    ref SamplerChain minP(float p, size_t minKeep = 1) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_min_p(p, minKeep));
        return this;
    }

    /// Adds stochastic (dist) sampling with an optional seed.
    ref SamplerChain dist(uint seed = LLAMA_DEFAULT_SEED) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_dist(seed));
        return this;
    }

    /++
    Adds repetition / frequency / presence penalties.

    Params:
        penaltyLastN   = number of most recent tokens the penalties look at; `0` disables them.
                         `-1` means "the whole context" only when `nCtx` is also given (see note).
        penaltyRepeat  = repetition penalty (`1.0` = no repetition penalty).
        penaltyFreq    = frequency penalty (`0.0` = none).
        penaltyPresent = presence penalty (`0.0` = none).
        nCtx           = context size substituted for `penaltyLastN == -1`. When `nCtx <= 0`
                         (the default) `penaltyLastN` is forwarded unchanged.

    NOTE: `llama_sampler_init_penalties` does not interpret `-1` itself. Current llama.cpp
    clamps a negative window to 0, which disables the penalties. The "-1 = context size"
    convention is implemented in llama.cpp's `common` helpers, not in the C API, so pass
    `nCtx` (e.g. the context's `n_ctx`) to get that behaviour here.
    +/
    ref SamplerChain penalties(int penaltyLastN = 64,
                               float penaltyRepeat  = 1.0f,
                               float penaltyFreq    = 0.0f,
                               float penaltyPresent = 0.0f,
                               int nCtx             = 0) @nogc nothrow return
    {
        // Resolve "-1 = whole context" here; forwarding -1 as-is would silently disable the penalty.
        const lastN = (penaltyLastN == -1 && nCtx > 0) ? nCtx : penaltyLastN;
        llama_sampler_chain_add(_ptr,
            llama_sampler_init_penalties(lastN, penaltyRepeat,
                                         penaltyFreq, penaltyPresent));
        return this;
    }

    /// Adds locally-typical sampling; `minKeep` is the minimum number of candidates kept.
    ref SamplerChain typical(float p, size_t minKeep = 1) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_typical(p, minKeep));
        return this;
    }

    /++
    Adds llama.cpp's extended temperature sampler (`llama_sampler_init_temp_ext`). It is
    documented upstream as dynamic, entropy-based temperature: `t` is the base temperature,
    `delta` the range around it and `exponent` shapes the entropy mapping.
    +/
    ref SamplerChain tempExt(float t, float delta, float exponent) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_temp_ext(t, delta, exponent));
        return this;
    }

    /++ Adds top-N-sigma sampling (keeps tokens within `n` sigma of the top logit). +/
    ref SamplerChain topNSigma(float n) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_top_n_sigma(n));
        return this;
    }

    /++
    Adds XTC (exclude top choices) sampling.
    `p` and `t` are forwarded as llama.cpp's XTC probability and threshold; `minKeep` is the
    minimum number of candidates kept; `seed` seeds the sampler's RNG.
    +/
    ref SamplerChain xtc(float p, float t, size_t minKeep = 1,
                         uint seed = LLAMA_DEFAULT_SEED) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_xtc(p, t, minKeep, seed));
        return this;
    }

    /++
    Adds Mirostat v2 sampling (adaptive entropy targeting).
    `tau` is the target value and `eta` the learning rate. Note that the C function takes
    the seed first; this wrapper reorders the arguments.
    +/
    ref SamplerChain mirostatV2(float tau = 5.0f, float eta = 0.1f,
                                uint seed = LLAMA_DEFAULT_SEED) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_mirostat_v2(seed, tau, eta));
        return this;
    }

    /++
    Adds grammar-constrained sampling.

    Params:
        vocab       = model vocabulary forwarded to `llama_sampler_init_grammar`.
        grammarStr  = a GBNF grammar.
        grammarRoot = name of the root rule (usually `"root"`).

    Throws: `Exception` if llama.cpp rejects the grammar (`llama_sampler_init_grammar`
            returned null). Without this check, the null sampler would be added to the
            chain and only fail later, inside `sample`.
    +/
    ref SamplerChain grammar(const(llama_vocab)* vocab,
                             string grammarStr, string grammarRoot = "root") @trusted return
    {
        import std.string : toStringz;

        auto smpl = llama_sampler_init_grammar(cast(llama_vocab*) vocab,
                                               grammarStr.toStringz, grammarRoot.toStringz);
        if (smpl is null)
            throw new Exception("llama_sampler_init_grammar failed: invalid GBNF grammar or root rule");
        llama_sampler_chain_add(_ptr, smpl);
        return this;
    }

    /++
    Adds DRY (Don't Repeat Yourself) sampling.

    Params:
        vocab         = model vocabulary forwarded to `llama_sampler_init_dry`.
        nCtxTrain     = context size forwarded to llama.cpp, which uses it to resolve
                        `penaltyLastN = -1`.
        multiplier    = penalty strength. The default `0.0` leaves DRY inactive in llama.cpp;
                        set a positive value to enable it.
        base          = base of the exponential penalty.
        allowedLength = repeated sequences up to this length are not penalized.
        penaltyLastN  = how many recent tokens to scan (`-1` = whole context, `0` = disabled).
        seqBreakers   = strings that reset the repetition check (e.g. `["\n"]`); empty for none.
                        `const`/`immutable` arrays are accepted.
    +/
    ref SamplerChain dry(const(llama_vocab)* vocab,
                         int nCtxTrain,
                         float multiplier    = 0.0f,
                         float base          = 1.75f,
                         int allowedLength   = 2,
                         int penaltyLastN    = -1,
                         const(string)[] seqBreakers = []) @trusted return
    {
        import std.string : toStringz;

        // The C API expects NUL-terminated C strings; D strings carry no such guarantee.
        auto cptrs = new const(char)*[](seqBreakers.length);
        foreach (i, s; seqBreakers)
            cptrs[i] = s.toStringz;
        llama_sampler_chain_add(_ptr,
            llama_sampler_init_dry(cast(llama_vocab*) vocab, nCtxTrain, multiplier, base,
                                   allowedLength, penaltyLastN,
                                   cptrs.ptr, cptrs.length));
        return this;
    }

    /++
    Adds per-token logit bias adjustments.

    Each entry in `biases` is a `{token, bias}` pair: `bias > 0` increases the token's
    probability and `bias < 0` decreases it. The array is only borrowed for the duration of
    the call (`scope`); llama.cpp copies the entries into the sampler.
    +/
    ref SamplerChain logitBias(int nVocab, scope const(llama_logit_bias)[] biases) @trusted @nogc nothrow return
    {
        // The C API takes the count as a 32-bit int; refuse to truncate it silently.
        assert(biases.length <= int.max, "too many entries for llama_sampler_init_logit_bias");
        llama_sampler_chain_add(_ptr,
            llama_sampler_init_logit_bias(nVocab, cast(int) biases.length,
                                          cast(llama_logit_bias*) biases.ptr));
        return this;
    }

    /++
    Samples the next token from `ctx`. `batchIdx = -1` uses the last output position.

    NOTE: `llama_sampler_sample` both samples *and accepts* the chosen token, so do not
    call `accept` again for the token returned here.
    +/
    llama_token sample(llama_context* ctx, int batchIdx = -1) @nogc nothrow
    {
        return llama_sampler_sample(_ptr, ctx, batchIdx);
    }

    /// Samples the next token from a `LlamaContext` (same semantics as the raw-pointer overload).
    llama_token sample(ref LlamaContext ctx, int batchIdx = -1) @nogc nothrow
    {
        return sample(ctx.ptr, batchIdx);
    }

    /++
    Feeds `token` to the chain's accept hook, updating stateful samplers such as penalties.

    Use this for tokens that did not come from `sample` (e.g. prompt tokens). Tokens returned
    by `sample` have already been accepted.
    +/
    void accept(llama_token token) @nogc nothrow { llama_sampler_accept(_ptr, token); }

    /// Prints this chain's sampling performance statistics via `llama_perf_sampler_print`.
    void printPerf() @nogc nothrow { llama_perf_sampler_print(_ptr); }
}