1 module llama.batch;
2 
3 import llama.llama;
4 
/++
Scope-bound owner of a `llama_batch`.

When the owning flag is set, the destructor hands the wrapped batch to
`llama_batch_free`. Default construction and copying are disabled, so values
are only created through the module-private constructor.
+/
struct OwnedBatch
{
    private llama_batch _batch;
    private bool _owned;

    @disable this();
    @disable this(this);

    /// Stores `handle`; `takeOwnership` decides whether the destructor frees it.
    private this(llama_batch handle, bool takeOwnership) @nogc nothrow
    {
        this._batch = handle;
        this._owned = takeOwnership;
    }

    /// Frees the wrapped batch if this value owns it, then drops the ownership flag.
    ~this() @nogc nothrow
    {
        if (!_owned)
            return;

        llama_batch_free(_batch);
        _owned = false;
    }

    /// Reference to the wrapped batch, for passing to the llama API.
    ref llama_batch get() @trusted @nogc nothrow
    {
        return _batch;
    }
}
28 
/++
Allocates a batch for up to `nTokensMax` tokens and returns it wrapped in an
`OwnedBatch`, which frees it on destruction.

Params:
    nTokensMax = maximum number of tokens the batch can hold
    embd       = pass `embd > 0` for embedding batches; `0` for token batches
    nSeqMax    = maximum number of sequence ids per token; the default of 1
                 matches `batchAdd`, which writes exactly one sequence id

Returns: an owning wrapper around the batch from `llama_batch_init`.
+/
OwnedBatch allocBatch(int nTokensMax, int embd = 0, int nSeqMax = 1) @nogc nothrow
{
    // Negative sizes would be forwarded to the C allocator unchecked; catch
    // them in debug builds.
    assert(nTokensMax >= 0, "allocBatch: nTokensMax must not be negative");
    assert(embd >= 0, "allocBatch: embd must not be negative");
    assert(nSeqMax >= 1, "allocBatch: nSeqMax must be at least 1");

    return OwnedBatch(llama_batch_init(nTokensMax, embd, nSeqMax), true);
}
34 
/++
Wraps a token slice into a batch via `llama_batch_get_one`.

The slice's pointer is handed to the C API and is carried by the returned
batch, so the slice must outlive the returned batch. The parameter is
therefore `return scope`: the result is tied to the lifetime of `tokens`.

Params:
    tokens = tokens to wrap; its length must fit in an `int`

Returns: the batch produced by `llama_batch_get_one` for `tokens`.
+/
llama_batch batchGetOne(return scope const(llama_token)[] tokens) @trusted @nogc nothrow
{
    // The C API takes an int count; refuse silent truncation of huge slices.
    assert(tokens.length <= int.max, "batchGetOne: token slice too long for llama_batch");

    // The C signature takes a mutable pointer, so const is cast away here.
    // NOTE(review): presumably the tokens are only read through it — confirm against llama.h.
    return llama_batch_get_one(cast(llama_token*) tokens.ptr, cast(int) tokens.length);
}
40 
/++
C-interop overload: forwards a raw token pointer and count straight to
`llama_batch_get_one` and returns the resulting batch.
+/
llama_batch batchGetOne(llama_token* tokens, int nTokens) @nogc nothrow
{
    llama_batch wrapped = llama_batch_get_one(tokens, nTokens);
    return wrapped;
}
46 
/++
Empties `batch` by setting its token count back to zero.
Only `n_tokens` is written; the batch's arrays are left untouched.
+/
void batchClear(ref llama_batch batch) @nogc nothrow
{
    enum int emptyCount = 0;
    batch.n_tokens = emptyCount;
}
52 
/++
Append one token to a pre-allocated token batch (created via `allocBatch`
with `embd == 0`).

The token is written at index `batch.n_tokens`, which is then incremented.
Exactly one sequence id is stored for the token.

The batch does not record its own capacity, so this function cannot verify
that there is room for another token; the caller must not add more tokens
than the batch was allocated for. In debug builds, asserts reject batches
whose per-token arrays are null (for example a batch that was not allocated
with per-token storage, or, per llama.h, an embedding batch, which has no
token array) instead of dereferencing a null pointer.

Params:
    batch   = target batch; must have been allocated with `allocBatch`
    id      = token id
    pos     = position in the sequence
    seqId   = sequence this token belongs to
    logits  = request logit output for this position
+/
void batchAdd(ref llama_batch batch,
              llama_token     id,
              llama_pos       pos,
              llama_seq_id    seqId,
              bool            logits) @trusted @nogc nothrow
{
    const int n = batch.n_tokens;

    // Every array written below must exist; a null here means the batch was
    // not allocated as a token batch.
    assert(n >= 0, "batchAdd: negative token count");
    assert(batch.token !is null, "batchAdd: batch has no token array");
    assert(batch.pos !is null, "batchAdd: batch has no pos array");
    assert(batch.n_seq_id !is null, "batchAdd: batch has no n_seq_id array");
    assert(batch.seq_id !is null, "batchAdd: batch has no seq_id array");
    assert(batch.logits !is null, "batchAdd: batch has no logits array");

    batch.token   [n]    = id;
    batch.pos     [n]    = pos;
    batch.n_seq_id[n]    = 1;       // exactly one sequence id per token
    batch.seq_id  [n][0] = seqId;
    batch.logits  [n]    = logits;
    batch.n_tokens       = n + 1;
}