1 /++
2 Compute normalized text embeddings and print cosine similarity.
3 
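Usage: `embedding -m embed-model.gguf [-p text] [-ngl n_gpu_layers] [text ...]`

The model must be an embedding model (e.g. nomic-embed-text, bge-*).
Prompts come from `-p` (a single text). Alternatively, the first argument that
is not `-m`, `-p` or `-ngl` and every argument after it are embedded as
separate prompts. If neither is given, two built-in sentences are embedded.
When there are at least two prompts, their cosine similarity is printed.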
8 +/
9 module embedding;
10 
11 import llama;
12 import std.stdio  : writefln, writeln, stderr;
13 import std.conv   : to;
14 import std.math   : sqrt;
15 import core.stdc.locale : setlocale, LC_NUMERIC;
16 import core.stdc.stdio  : printf;
17 
/++
Entry point: parse arguments, load the embedding model, embed every prompt,
L2-normalize the results, print a preview of each vector and the cosine
similarity of every pair of prompts.

Returns: 0 on success, 1 on a usage, load, tokenization or decode error.
+/
int main(string[] args)
{
    // Force the "C" numeric locale ('.' as decimal separator) for C-level
    // number formatting.
    setlocale(LC_NUMERIC, "C");

    string   modelPath;
    string[] prompts = [
        "Hello, world!",
        "The quick brown fox jumps over the lazy dog.",
    ];
    int ngl    = 99;
    int nBatch = 512;

    for (int i = 1; i < cast(int) args.length; i++)
    {
        switch (args[i])
        {
            case "-m":
                if (++i < cast(int) args.length) modelPath = args[i];
                else return printUsage(args[0]);
                break;
            case "-p":
                if (++i < cast(int) args.length) prompts = [args[i]];
                else return printUsage(args[0]);
                break;
            case "-ngl":
                if (++i >= cast(int) args.length) return printUsage(args[0]);
                // `to!int` throws on non-numeric or out-of-range text; report
                // it instead of letting the exception escape main().
                try
                {
                    ngl = args[i].to!int;
                }
                catch (Exception)
                {
                    stderr.writefln("error: invalid value for -ngl: '%s'", args[i]);
                    return printUsage(args[0]);
                }
                break;
            default:
                // The first unrecognized argument starts the positional prompt
                // list: it and everything after it become separate prompts.
                prompts = args[i .. $];
                i = cast(int) args.length;
                break;
        }
    }
    if (modelPath.length == 0) return printUsage(args[0]);

    loadAllBackends();

    auto model = LlamaModel.loadFromFile(modelPath, ngl);
    if (!model)
    {
        stderr.writefln("error: failed to load model '%s'", modelPath);
        return 1;
    }

    const vocab = model.vocab;
    int   nEmbd = model.nEmbd;

    // Enable embedding extraction; required for generative models.
    auto ctxp        = contextParams(cast(uint) nBatch, cast(uint) nBatch);
    ctxp.embeddings  = true;
    auto ctx = LlamaContext.fromModel(model, ctxp);
    if (!ctx)
    {
        stderr.writeln("error: failed to create context");
        return 1;
    }

    int poolType = ctx.poolingType;
    writefln("pooling type : %d  (%s)",
             poolType,
             poolType == 0 ? "none — using token embeddings" : "sequence pooling");
    writefln("n_embd       : %d", nEmbd);
    writeln();

    // One row per prompt; every row is fully written by l2Normalize below
    // (D initializes float arrays to NaN, so no row may be left untouched).
    auto embeddings = new float[][](prompts.length, nEmbd);

    // Embed each prompt independently (simple, non-batched path).
    foreach (si, prompt; prompts)
    {
        auto tokens = tokenize(vocab, prompt, /*addSpecial=*/true);
        if (tokens is null)
        {
            stderr.writefln("error: tokenization failed for prompt %d", si);
            return 1;
        }
        // Nothing to decode, and the last-token index below would be -1.
        if (tokens.length == 0)
        {
            stderr.writefln("error: prompt %d produced no tokens", si);
            return 1;
        }
        // The context was created from nBatch; a longer prompt cannot be
        // decoded in the single batch used below.
        if (tokens.length > cast(size_t) nBatch)
        {
            stderr.writefln("error: prompt %d has %d tokens, more than the batch size %d",
                            si, tokens.length, nBatch);
            return 1;
        }

        // Clear the KV cache between prompts.
        ctx.memoryClear(/*data=*/true);

        // Always use seq_id=0: memory is cleared between prompts.
        auto ob = allocBatch(cast(int) tokens.length);
        foreach (j, tok; tokens)
            batchAdd(ob.get(), tok, cast(llama_pos) j, 0, /*logits=*/true);

        if (ctx.decode(ob.get()))
        {
            stderr.writefln("error: decode failed for prompt %d", si);
            return 1;
        }

        // Retrieve embeddings: pooled (mean/CLS) or last-token.
        const(float)[] raw;
        if (poolType == 0) // LLAMA_POOLING_TYPE_NONE
            raw = ctx.getEmbeddingsIth(ob.get().n_tokens - 1);
        else
            raw = ctx.getEmbeddingsSeq(0);

        if (raw is null)
        {
            stderr.writefln("error: no embeddings returned for prompt %d", si);
            return 1;
        }
        // Only the first nEmbd values form the vector; guard against a
        // shorter slice before indexing it.
        if (raw.length < cast(size_t) nEmbd)
        {
            stderr.writefln("error: expected %d embedding values for prompt %d, got %d",
                            nEmbd, si, raw.length);
            return 1;
        }

        l2Normalize(raw[0 .. nEmbd], embeddings[si]);
    }

    // Print a compact summary (first 8 components) — @trusted wrapper avoids
    // -preview=safer restrictions on ptr and printf varargs.
    foreach (i, prompt; prompts)
        printEmbeddingRow(prompt, i, embeddings[i], nEmbd);

    // Cosine similarity for every pair of prompts. The vectors are unit-length
    // (or all-zero), so the dot product is the cosine.
    foreach (a; 0 .. prompts.length)
    {
        foreach (b; a + 1 .. prompts.length)
        {
            double dot = 0;
            foreach (j; 0 .. nEmbd) dot += embeddings[a][j] * embeddings[b][j];
            writefln("cosine similarity [%d, %d]: %.6f", a, b, dot);
        }
    }

    return 0;
}

/++
L2-normalize `src` into `dst` (which must be at least as long as `src`).

A vector whose norm is not above 1e-9 has no usable direction. It is written
as all zeros rather than left untouched, because D float arrays start out as NaN.
+/
private void l2Normalize(const(float)[] src, float[] dst) @safe pure nothrow @nogc
{
    double sumSq = 0;
    foreach (v; src) sumSq += cast(double) v * v;
    const norm = sqrt(sumSq);

    if (norm > 1e-9)
        foreach (j, v; src) dst[j] = cast(float)(v / norm);
    else
        dst[0 .. src.length] = 0.0f;
}
145 
146 // @trusted helpers — isolate all raw-pointer / printf-vararg operations.
147 
/++
Print one prompt and a preview (at most the first 8 components) of its embedding.

Params:
    prompt = prompt text, echoed between double quotes.
    idx    = index shown as `prompt[idx]`.
    embd   = embedding vector to preview.
    nEmbd  = nominal embedding width. The preview never reads past
             `embd.length`, even if nEmbd is larger.
+/
void printEmbeddingRow(string prompt, size_t idx,
                       const(float)[] embd, int nEmbd) @trusted nothrow
{
    enum size_t maxShown = 8;

    printf("prompt[%zu]: \"", idx);
    // Emit the text byte by byte: a D slice is not guaranteed to be
    // NUL-terminated, so it cannot be handed to %s directly.
    foreach (c; prompt) printf("%c", cast(int) c);
    printf("\"\n");

    // Clamp to what the slice actually holds. Under -release, bounds checks
    // are compiled out of @trusted code, so an nEmbd larger than embd.length
    // would otherwise read out of bounds.
    size_t avail = nEmbd > 0 ? cast(size_t) nEmbd : 0;
    if (avail > embd.length) avail = embd.length;
    const size_t show = avail < maxShown ? avail : maxShown;

    printf("embedding  : [");
    foreach (j; 0 .. show)
    {
        if (j > 0) printf(", ");
        printf("%.6f", cast(double) embd[j]);
    }
    if (avail > maxShown) printf(", ...");
    printf("]\n\n");
}
164 
/++
Print a cosine-similarity value with six decimal places.

The `[0, 1]` pair label is hard-coded in the format string.
+/
void printCosine(double sim) @trusted nothrow
{
    static immutable string cosineFmt = "cosine similarity [0, 1]: %.6f\n";
    printf(cosineFmt.ptr, sim);
}
169 
/++
Print the command-line usage message to stdout.

Params:
    prog = program name shown in the message (normally `args[0]`). It does
           not need to be NUL-terminated.
Returns: always 1, so callers can write `return printUsage(args[0]);`.
+/
int printUsage(string prog) @trusted nothrow
{
    // `%.*s` limits the read to prog.length. A plain `%s` with prog.ptr would
    // run past the end of a slice that is not NUL-terminated.
    printf(
        "\nusage: %.*s -m embed-model.gguf [-p text] [-ngl n_gpu_layers]\n\n"
        ~ "  -m    embedding model (GGUF)\n"
        ~ "  -p    text to embed (default: two built-in sentences)\n"
        ~ "  -ngl  GPU layers (default: 99)\n\n",
        cast(int) prog.length, prog.ptr);
    return 1;
}