module llama.sampling;

import llama.llama;
import llama.ctx : LlamaContext;
import llama.owned;

/++
A sampler chain you configure then use to pick the next token.

Every builder method appends one sampler to the chain with
`llama_sampler_chain_add` and returns `this` by reference, so calls can be
chained:
---
auto chain = SamplerChain.create();
chain.topK(40).temp(0.8f).dist();
---

NOTE(review): the `Owned` mixin is declared in `llama.owned`; it presumably
provides the `_ptr` member used below and releases it with
`llama_sampler_free` — confirm against that module.
+/
struct SamplerChain
{
    mixin Owned!(llama_sampler, llama_sampler_free);

    /++
    Create a new, empty sampler chain.

    Params:
        noPerf = value stored in the `no_perf` field of the default chain
                 parameters before the chain is initialised.
    Returns: a `SamplerChain` wrapping the handle from `llama_sampler_chain_init`.
    +/
    static SamplerChain create(bool noPerf = false) @nogc nothrow
    {
        auto params = llama_sampler_chain_default_params();
        params.no_perf = noPerf;
        return SamplerChain(llama_sampler_chain_init(params));
    }

    /// Adds greedy (argmax) sampling.
    ref SamplerChain greedy() @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_greedy());
        return this;
    }

    /// Adds temperature scaling with temperature `t`.
    ref SamplerChain temp(float t) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_temp(t));
        return this;
    }

    /// Adds top-K filtering with cut-off `k`.
    ref SamplerChain topK(int k) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_top_k(k));
        return this;
    }

    /// Adds top-P (nucleus) sampling; `minKeep` is forwarded unchanged to llama.cpp.
    ref SamplerChain topP(float p, size_t minKeep = 1) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_top_p(p, minKeep));
        return this;
    }

    /// Adds min-P sampling; `minKeep` is forwarded unchanged to llama.cpp.
    ref SamplerChain minP(float p, size_t minKeep = 1) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_min_p(p, minKeep));
        return this;
    }

    /// Adds stochastic (dist) sampling with an optional seed.
    ref SamplerChain dist(uint seed = LLAMA_DEFAULT_SEED) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_dist(seed));
        return this;
    }

    /++
    Adds repetition / frequency / presence penalties.
    `penaltyLastN = -1` uses the full context; `penaltyLastN = 0` disables the penalty.
    +/
    ref SamplerChain penalties(int penaltyLastN = 64,
            float penaltyRepeat = 1.0f,
            float penaltyFreq = 0.0f,
            float penaltyPresent = 0.0f) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr,
                llama_sampler_init_penalties(penaltyLastN, penaltyRepeat,
                    penaltyFreq, penaltyPresent));
        return this;
    }

    /// Adds typical-P sampling; `minKeep` is forwarded unchanged to llama.cpp.
    ref SamplerChain typical(float p, size_t minKeep = 1) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_typical(p, minKeep));
        return this;
    }

    /++ Adds temperature sampling with dynamic range extension (`delta` and `exponent`). +/
    ref SamplerChain tempExt(float t, float delta, float exponent) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_temp_ext(t, delta, exponent));
        return this;
    }

    /++ Adds top-N-sigma sampling (keeps tokens within `n` sigma of the top logit). +/
    ref SamplerChain topNSigma(float n) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_top_n_sigma(n));
        return this;
    }

    /++ Adds XTC (exclude top choices) sampling. +/
    ref SamplerChain xtc(float p, float t, size_t minKeep = 1,
            uint seed = LLAMA_DEFAULT_SEED) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_xtc(p, t, minKeep, seed));
        return this;
    }

    /++
    Adds Mirostat v2 sampling (adaptive entropy targeting).

    Note that the wrapper takes `(tau, eta, seed)` while the underlying call
    is made as `(seed, tau, eta)`; the reordering below is intentional.
    +/
    ref SamplerChain mirostatV2(float tau = 5.0f, float eta = 0.1f,
            uint seed = LLAMA_DEFAULT_SEED) @nogc nothrow return
    {
        llama_sampler_chain_add(_ptr, llama_sampler_init_mirostat_v2(seed, tau, eta));
        return this;
    }

    /++
    Adds grammar-constrained sampling.
    `grammarStr` is a GBNF grammar; `grammarRoot` is the root rule name (usually `"root"`).

    Throws: `Exception` if `llama_sampler_init_grammar` returns `null`
    (llama.cpp does so when the grammar cannot be built, e.g. a parse
    error or an unknown root rule). Nothing is added to the chain in that case.
    +/
    ref SamplerChain grammar(const(llama_vocab)* vocab,
            string grammarStr, string grammarRoot = "root") @trusted return
    {
        import std.string : toStringz;

        auto grammarSampler = llama_sampler_init_grammar(cast(llama_vocab*) vocab,
                grammarStr.toStringz, grammarRoot.toStringz);

        // Never hand a null sampler to the chain: it would only blow up later,
        // inside native code, far away from the bad grammar that caused it.
        if (grammarSampler is null)
            throw new Exception(
                    "llama_sampler_init_grammar failed: could not build grammar with root rule '"
                    ~ grammarRoot ~ "'");

        llama_sampler_chain_add(_ptr, grammarSampler);
        return this;
    }

    /++
    Adds DRY (Don't Repeat Yourself) sampling.
    `seqBreakers` lists token strings that reset the repetition check (e.g. `["\n"]`).
    Pass `seqBreakers = []` to use no breakers.

    NOTE(review): the default `multiplier` of `0.0f` presumably leaves DRY
    inactive until the caller passes a non-zero value — confirm against llama.h.
    +/
    ref SamplerChain dry(const(llama_vocab)* vocab,
            int nCtxTrain,
            float multiplier = 0.0f,
            float base = 1.75f,
            int allowedLength = 2,
            int penaltyLastN = -1,
            string[] seqBreakers = []) @trusted return
    {
        import std.string : toStringz;

        // Build a GC-allocated array of NUL-terminated C strings for the call.
        // NOTE(review): assumes llama.cpp copies the breakers during init and
        // does not keep these pointers — confirm.
        auto cptrs = new const(char)*[](seqBreakers.length);
        foreach (i, s; seqBreakers)
            cptrs[i] = s.toStringz;

        llama_sampler_chain_add(_ptr,
                llama_sampler_init_dry(cast(llama_vocab*) vocab, nCtxTrain, multiplier, base,
                    allowedLength, penaltyLastN,
                    cptrs.ptr, cptrs.length));
        return this;
    }

    /++
    Adds per-token logit bias adjustments.
    Each entry in `biases` is a `{token, bias}` pair; `bias > 0` increases probability, `bias < 0` decreases it.

    The entry count is passed to llama.cpp as an `int`; in builds with
    assertions enabled, a slice longer than `int.max` is rejected instead of
    being silently truncated.
    +/
    ref SamplerChain logitBias(int nVocab, scope const(llama_logit_bias)[] biases) @trusted @nogc nothrow return
    {
        assert(biases.length <= int.max, "logitBias: too many bias entries");
        llama_sampler_chain_add(_ptr,
                llama_sampler_init_logit_bias(nVocab, cast(int) biases.length,
                    cast(llama_logit_bias*) biases.ptr));
        return this;
    }

    /// Sample the next token. `batchIdx = -1` uses the last output position.
    llama_token sample(llama_context* ctx, int batchIdx = -1) @nogc nothrow
    {
        return llama_sampler_sample(_ptr, ctx, batchIdx);
    }

    /// Sample the next token from a `LlamaContext` (forwards `ctx.ptr` to the raw overload's C call).
    llama_token sample(ref LlamaContext ctx, int batchIdx = -1) @nogc nothrow
    {
        return llama_sampler_sample(_ptr, ctx.ptr, batchIdx);
    }

    /// Feed the accepted token back (needed for repetition penalties and similar).
    void accept(llama_token token) @nogc nothrow
    {
        llama_sampler_accept(_ptr, token);
    }

    /// Print this chain's performance data via `llama_perf_sampler_print`.
    void printPerf() @nogc nothrow
    {
        llama_perf_sampler_print(_ptr);
    }
}