module llama.batch;

import llama.llama;

/++
    A `llama_batch` that frees itself (via `llama_batch_free`) when it goes out
    of scope, provided it owns the batch.

    Copying is disabled (`@disable this(this)`) so two instances can never both
    free the same batch; the default constructor is disabled so instances are
    created through `allocBatch`.
+/
struct OwnedBatch
{
    private llama_batch _batch;
    /// When false, the destructor leaves `_batch` alone.
    private bool _owned;

    @disable this();
    @disable this(this);

    private this(llama_batch b, bool owned) @nogc nothrow
    {
        _batch = b;
        _owned = owned;
    }

    ~this() @nogc nothrow
    {
        if (_owned)
        {
            llama_batch_free(_batch);
            // Clear the flag so a repeated destructor call cannot double-free.
            _owned = false;
        }
    }

    /++
        Underlying batch.

        The `return` attribute ties the returned reference to this object's
        lifetime, so the compiler can check it in `@safe` code instead of the
        function having to be `@trusted`.
    +/
    ref llama_batch get() return @safe @nogc nothrow
    {
        return _batch;
    }
}

/++
    Allocates a batch with room for up to `nTokensMax` tokens.

    Params:
        nTokensMax = capacity of the batch in tokens; must be positive
        embd       = pass `embd > 0` for embedding batches (NOTE(review): such
                     batches presumably have no token-id array, so they cannot
                     be filled with `batchAdd` - confirm against llama.h)
        nSeqMax    = maximum number of sequence ids per token (default 1,
                     which is what `batchAdd` uses)

    Returns: an `OwnedBatch` that calls `llama_batch_free` when destroyed.
+/
OwnedBatch allocBatch(int nTokensMax, int embd = 0, int nSeqMax = 1) @nogc nothrow
{
    assert(nTokensMax > 0, "allocBatch: nTokensMax must be positive");
    assert(nSeqMax > 0, "allocBatch: nSeqMax must be positive");
    return OwnedBatch(llama_batch_init(nTokensMax, embd, nSeqMax), true);
}

/++
    Wraps a token slice into a batch without copying.

    The returned batch points into `tokens`, so the slice must outlive it.
    `return scope` expresses exactly that, letting the compiler (with DIP1000)
    reject callers whose slice would die before the batch.

    Params:
        tokens = token ids; length must fit in an `int`
+/
llama_batch batchGetOne(return scope const(llama_token)[] tokens) @trusted @nogc nothrow
{
    assert(tokens.length <= int.max, "batchGetOne: too many tokens for an int count");
    // NOTE(review): const is cast away because the C function takes a
    // non-const pointer; this assumes llama.cpp never writes through it - confirm.
    return llama_batch_get_one(cast(llama_token*) tokens.ptr, cast(int) tokens.length);
}

/++
    Wraps a raw token pointer into a batch; for C interop.

    Params:
        tokens  = pointer to `nTokens` token ids; must outlive the returned batch
        nTokens = number of tokens pointed to by `tokens`
+/
llama_batch batchGetOne(llama_token* tokens, int nTokens) @nogc nothrow
{
    return llama_batch_get_one(tokens, nTokens);
}

/// Resets a batch's token count to zero. The allocated arrays are kept for reuse.
void batchClear(ref llama_batch batch) @nogc nothrow
{
    batch.n_tokens = 0;
}

/++
    Appends one token to a pre-allocated batch (created via `allocBatch`).

    `llama_batch` does not record its capacity, so this function cannot
    bounds-check. The caller must not append more tokens than the
    `nTokensMax` the batch was allocated with.

    Params:
        batch  = target batch; must have been allocated with `allocBatch`
                 (with `embd == 0`), so every per-token array is present
        id     = token id
        pos    = position in the sequence
        seqId  = sequence this token belongs to (stored as the token's only id)
        logits = request logit output for this position
+/
void batchAdd(ref llama_batch batch,
              llama_token id,
              llama_pos pos,
              llama_seq_id seqId,
              bool logits) @trusted @nogc nothrow
{
    // Catch batches without per-token arrays, e.g. embedding batches or
    // batches wrapped by batchGetOne, before writing through a null pointer.
    assert(batch.token !is null && batch.pos !is null && batch.n_seq_id !is null
           && batch.seq_id !is null && batch.logits !is null,
           "batchAdd: batch lacks per-token arrays; allocate it with allocBatch(n)");
    assert(batch.n_tokens >= 0, "batchAdd: negative n_tokens");

    immutable n = batch.n_tokens;
    assert(batch.seq_id[n] !is null, "batchAdd: missing seq_id slot");

    batch.token[n]     = id;
    batch.pos[n]       = pos;
    batch.n_seq_id[n]  = 1;
    batch.seq_id[n][0] = seqId;
    batch.logits[n]    = logits;
    batch.n_tokens     = n + 1;
}