/++
Compute normalized text embeddings and print their pairwise cosine similarity.

Usage: `embedding -m embed-model.gguf [-p text]... [-ngl n_gpu_layers] [-b n_batch] [text...]`

The model must be an embedding model (e.g. nomic-embed-text, bge-*).
`-p` may be repeated. The first argument that is not a recognized option,
and every argument after it, is embedded as well. If no prompt is given,
two built-in sentences are embedded and compared.
+/
module embedding;

import llama;
import std.stdio : writefln, writeln, stderr;
import std.conv : ConvException, to;
import std.math : sqrt;
import core.stdc.locale : setlocale, LC_NUMERIC;
import core.stdc.stdio : printf;

/++
Entry point: parses arguments, loads the model, embeds each prompt, then
prints a preview of every embedding and the cosine similarity of each pair.

Returns: 0 on success; 1 after printing a message on any argument, model
load, context creation, tokenization, decode or embedding-retrieval error.
+/
int main(string[] args)
{
    // Keep printf's decimal separator independent of the user's locale.
    setlocale(LC_NUMERIC, "C");

    Options opts;
    if (!parseArgs(args, opts))
        return 1;

    loadAllBackends();

    auto model = LlamaModel.loadFromFile(opts.modelPath, opts.ngl);
    if (!model)
    {
        stderr.writefln("error: failed to load model '%s'", opts.modelPath);
        return 1;
    }

    const vocab = model.vocab;
    int nEmbd = model.nEmbd;

    // Enable embedding extraction; required for generative models.
    // Both contextParams arguments are nBatch (NOTE(review): presumably the
    // context and batch sizes — confirm against the llama module); each prompt
    // is decoded as a single batch of at most nBatch tokens below.
    auto ctxp = contextParams(cast(uint) opts.nBatch, cast(uint) opts.nBatch);
    ctxp.embeddings = true;
    auto ctx = LlamaContext.fromModel(model, ctxp);
    if (!ctx)
    {
        stderr.writeln("error: failed to create context");
        return 1;
    }

    // 0 is LLAMA_POOLING_TYPE_NONE: there is no pooled vector, so a token's
    // embedding is read instead.
    int poolType = ctx.poolingType;
    writefln("pooling type : %d (%s)",
             poolType,
             poolType == 0 ? "none — using token embeddings" : "sequence pooling");
    writefln("n_embd : %d", nEmbd);
    writeln();

    // Every row is fully overwritten below (normalized values or zeros), so
    // D's NaN default for float never leaks into the output.
    auto embeddings = new float[][](opts.prompts.length, nEmbd);

    // Embed each prompt independently (simple, non-batched path).
    foreach (si, prompt; opts.prompts)
    {
        auto tokens = tokenize(vocab, prompt, /*addSpecial=*/true);
        if (tokens is null)
        {
            stderr.writefln("error: tokenization failed for prompt %d", si);
            return 1;
        }
        // An empty batch has no token whose embedding could be read
        // (n_tokens - 1 would be -1).
        if (tokens.length == 0)
        {
            stderr.writefln("error: prompt %d produced no tokens", si);
            return 1;
        }
        // The whole prompt goes into one batch, so it must fit in nBatch.
        if (tokens.length > cast(size_t) opts.nBatch)
        {
            stderr.writefln("error: prompt %d has %d tokens, more than the batch size %d; increase it with -b",
                            si, tokens.length, opts.nBatch);
            return 1;
        }

        // Clear the KV cache between prompts.
        ctx.memoryClear(/*data=*/true);

        // Always use seq_id=0: memory is cleared between prompts.
        auto ob = allocBatch(cast(int) tokens.length);
        foreach (j, tok; tokens)
            batchAdd(ob.get(), tok, cast(llama_pos) j, 0, /*logits=*/true);

        if (ctx.decode(ob.get()))
        {
            stderr.writefln("error: decode failed for prompt %d", si);
            return 1;
        }

        // Retrieve embeddings: the last token's when unpooled, otherwise the
        // pooled vector of sequence 0.
        const(float)[] raw;
        if (poolType == 0) // LLAMA_POOLING_TYPE_NONE
            raw = ctx.getEmbeddingsIth(ob.get().n_tokens - 1);
        else
            raw = ctx.getEmbeddingsSeq(0);

        // Covers both a null slice and an empty non-null one.
        if (raw.length == 0)
        {
            stderr.writefln("error: no embeddings returned for prompt %d", si);
            return 1;
        }
        // A vector shorter than n_embd would otherwise be indexed out of bounds.
        if (raw.length < cast(size_t) nEmbd)
        {
            stderr.writefln("error: prompt %d returned %d embedding values, expected %d",
                            si, raw.length, nEmbd);
            return 1;
        }

        // L2-normalize into the output buffer.
        if (!l2NormalizeInto(raw[0 .. nEmbd], embeddings[si]))
            stderr.writefln("warning: prompt %d has a zero-norm embedding; using zeros", si);
    }

    // Print a compact summary (first 8 components) — @trusted wrapper avoids
    // -preview=safer restrictions on ptr and printf varargs.
    foreach (i, prompt; opts.prompts)
        printEmbeddingRow(prompt, i, embeddings[i], nEmbd);

    // Cosine similarity of every pair. Rows are unit-length (or all zeros,
    // which yields 0), so the dot product is the cosine.
    foreach (a; 0 .. embeddings.length)
        foreach (b; a + 1 .. embeddings.length)
            printCosine(dotProduct(embeddings[a], embeddings[b]), a, b);

    return 0;
}

/// Settings collected from the command line by `parseArgs`.
private struct Options
{
    string modelPath;   /// GGUF embedding model to load (`-m`, required).
    string[] prompts;   /// Texts to embed, in order.
    int ngl = 99;       /// Layers to offload to the GPU (`-ngl`).
    int nBatch = 512;   /// Batch size for the context (`-b`); caps tokens per prompt.
}

/++
Parses the command line into `opts`.

`-p` may be repeated, each occurrence adding one prompt. The first argument
that is not a recognized option ends option parsing; it and every argument
after it are added as prompts too. If no prompt is given at all, two
built-in sentences are used.

Returns: `true` on success; `false` after printing an error and/or the usage
text when an option lacks its value, an integer value is malformed, `-b` is
not positive, or `-m` was not given.
+/
private bool parseArgs(string[] args, ref Options opts)
{
    const prog = args.length ? args[0] : "embedding";
    string[] given;
    size_t i = 1;

    // Prints the usage text and signals failure.
    bool fail()
    {
        printUsage(prog);
        return false;
    }

    // Advances past the current option and yields its value; false if absent.
    bool takeValue(out string value)
    {
        if (i + 1 >= args.length)
            return false;
        value = args[++i];
        return true;
    }

    // takeValue plus int conversion; reports malformed numbers on stderr.
    // `result` keeps its previous value on failure.
    bool takeInt(string name, ref int result)
    {
        string text;
        if (!takeValue(text))
            return false;
        try
        {
            result = text.to!int;
        }
        catch (ConvException)
        {
            stderr.writefln("error: %s expects an integer, got '%s'", name, text);
            return false;
        }
        return true;
    }

parsing:
    for (; i < args.length; ++i)
    {
        // Declared outside the switch so no case skips its declaration.
        string value;
        switch (args[i])
        {
            case "-m":
                if (!takeValue(value))
                    return fail();
                opts.modelPath = value;
                break;
            case "-p":
                if (!takeValue(value))
                    return fail();
                given ~= value;
                break;
            case "-ngl":
                if (!takeInt("-ngl", opts.ngl))
                    return fail();
                break;
            case "-b":
                if (!takeInt("-b", opts.nBatch))
                    return fail();
                break;
            default:
                // First non-option argument: it and everything after are prompts.
                given ~= args[i .. $];
                break parsing;
        }
    }

    if (opts.modelPath.length == 0)
        return fail();
    if (opts.nBatch < 1)
    {
        stderr.writefln("error: -b must be positive, got %d", opts.nBatch);
        return false;
    }

    if (given.length > 0)
        opts.prompts = given;
    else
        opts.prompts = [
            "Hello, world!",
            "The quick brown fox jumps over the lazy dog.",
        ];
    return true;
}

/++
Writes the L2-normalized copy of `src` into `dst`.

Squares are accumulated in double precision. When the norm is at most 1e-9,
`dst` is filled with zeros instead, so later dot products stay finite.

Params:
    src = vector to normalize; must have at least `dst.length` elements.
    dst = output buffer; every element is overwritten.
Returns: `true` if `dst` holds a unit vector, `false` if it was zero-filled.
+/
private bool l2NormalizeInto(const(float)[] src, float[] dst) @safe pure nothrow @nogc
{
    double sumSq = 0;
    foreach (v; src[0 .. dst.length])
        sumSq += cast(double) v * v;

    const norm = sqrt(sumSq);
    if (norm <= 1e-9)
    {
        dst[] = 0.0f;
        return false;
    }
    foreach (j, ref d; dst)
        d = cast(float)(src[j] / norm);
    return true;
}

/// Dot product of `a` and `b[0 .. a.length]`, accumulated in double precision.
private double dotProduct(const(float)[] a, const(float)[] b) @safe pure nothrow @nogc
{
    double acc = 0;
    foreach (j, x; a)
        acc += cast(double) x * b[j];
    return acc;
}

// @trusted helpers — isolate all raw-pointer / printf-vararg operations.

/++
Prints one prompt and a preview (at most the first 8 components) of its
embedding to C stdout.

Params:
    prompt = prompt text, printed between double quotes.
    idx    = position of the prompt, shown as `prompt[idx]`.
    embd   = embedding values; must hold at least min(nEmbd, 8) elements.
    nEmbd  = embedding width; `...` is appended when it exceeds 8.
+/
void printEmbeddingRow(string prompt, size_t idx,
                       const(float)[] embd, int nEmbd) @trusted nothrow
{
    enum maxShown = 8;

    // %.*s takes an explicit length: D strings are not NUL-terminated.
    printf("prompt[%zu]: \"%.*s\"\n", idx, cast(int) prompt.length, prompt.ptr);
    printf("embedding : [");
    int show = nEmbd < maxShown ? nEmbd : maxShown;
    foreach (j; 0 .. show)
    {
        if (j > 0) printf(", ");
        printf("%.6f", cast(double) embd[j]);
    }
    if (nEmbd > maxShown) printf(", ...");
    printf("]\n\n");
}

/++
Prints the cosine similarity of prompts `a` and `b` to C stdout.
The defaults keep the original single-pair output `[0, 1]`.
+/
void printCosine(double sim, size_t a = 0, size_t b = 1) @trusted nothrow
{
    printf("cosine similarity [%zu, %zu]: %.6f\n", a, b, sim);
}

/++
Prints the usage text for program name `prog` to C stdout.

Returns: always 1, so callers can `return printUsage(...)` as an exit code.
+/
int printUsage(string prog) @trusted nothrow
{
    // %.*s: `prog` is a D slice and need not be NUL-terminated.
    printf(
        "\nusage: %.*s -m embed-model.gguf [-p text]... [-ngl n_gpu_layers] [-b n_batch] [text...]\n\n"
        ~ " -m    embedding model (GGUF)\n"
        ~ " -p    text to embed; may be repeated (default: two built-in sentences)\n"
        ~ " -ngl  GPU layers (default: 99)\n"
        ~ " -b    batch size, the maximum tokens per prompt (default: 512)\n\n",
        cast(int) prog.length, prog.ptr);
    return 1;
}