Skip to content

Commit 5a8674d

Browse files
cebtenzzre and ggerganov
authored and committed
llama : sanity checks for access to logits (ggml-org#4274)
Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 46eef5b commit 5a8674d

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

llama.cpp

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1523,6 +1523,10 @@ struct llama_context {
15231523

15241524
// decode output (2-dimensional array: [n_tokens][n_vocab])
15251525
std::vector<float> logits;
1526+
#ifndef NDEBUG
1527+
// guard against access to unset logits
1528+
std::vector<bool> logits_valid;
1529+
#endif
15261530
bool logits_all = false;
15271531

15281532
// input embedding (1-dimensional array: [n_embd])
@@ -6216,20 +6220,37 @@ static int llama_decode_internal(
62166220
{
62176221
auto & logits_out = lctx.logits;
62186222

6223+
#ifndef NDEBUG
6224+
auto & logits_valid = lctx.logits_valid;
6225+
logits_valid.clear();
6226+
logits_valid.resize(n_tokens);
6227+
6228+
logits_out.clear();
6229+
#endif
6230+
62196231
if (batch.logits) {
62206232
logits_out.resize(n_vocab * n_tokens);
62216233
for (uint32_t i = 0; i < n_tokens; i++) {
62226234
if (batch.logits[i] == 0) {
62236235
continue;
62246236
}
62256237
memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab);
6238+
#ifndef NDEBUG
6239+
logits_valid[i] = true;
6240+
#endif
62266241
}
62276242
} else if (lctx.logits_all) {
62286243
logits_out.resize(n_vocab * n_tokens);
62296244
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
6245+
#ifndef NDEBUG
6246+
std::fill(logits_valid.begin(), logits_valid.end(), true);
6247+
#endif
62306248
} else {
62316249
logits_out.resize(n_vocab);
62326250
memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab);
6251+
#ifndef NDEBUG
6252+
logits_valid[n_tokens - 1] = true;
6253+
#endif
62336254
}
62346255
}
62356256

@@ -10118,6 +10139,7 @@ float * llama_get_logits(struct llama_context * ctx) {
1011810139
}
1011910140

1012010141
// Returns a pointer into ctx->logits at offset i * n_vocab, i.e. the start of
// row i when the buffer is laid out as [n_tokens][n_vocab].
//
// When NDEBUG is not defined, the assert below checks ctx->logits_valid.at(i)
// first, so asking for a row that the last decode did not mark as written
// fails the assertion (and an out-of-range i makes at() throw). With NDEBUG
// defined the assert compiles away and no check is performed.
//
// NOTE(review): in llama_decode_internal, the branch taken when batch.logits
// is null and logits_all is false resizes logits to a single n_vocab row but
// sets logits_valid[n_tokens - 1]. With n_tokens > 1 that flag index does not
// match the offset computed here (the stored row is at index 0) -- confirm
// whether logits_valid[0] was intended.
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
10142+
assert(ctx->logits_valid.at(i));
1012110143
return ctx->logits.data() + i*ctx->model.hparams.n_vocab;
1012210144
}
1012310145

0 commit comments

Comments
 (0)