@@ -1523,6 +1523,10 @@ struct llama_context {
1523
1523
1524
1524
// decode output (2-dimensional array: [n_tokens][n_vocab])
1525
1525
std::vector<float > logits;
1526
+ #ifndef NDEBUG
1527
+ // guard against access to unset logits
1528
+ std::vector<bool > logits_valid;
1529
+ #endif
1526
1530
bool logits_all = false ;
1527
1531
1528
1532
// input embedding (1-dimensional array: [n_embd])
@@ -6216,20 +6220,37 @@ static int llama_decode_internal(
6216
6220
{
6217
6221
auto & logits_out = lctx.logits ;
6218
6222
6223
+ #ifndef NDEBUG
6224
+ auto & logits_valid = lctx.logits_valid ;
6225
+ logits_valid.clear ();
6226
+ logits_valid.resize (n_tokens);
6227
+
6228
+ logits_out.clear ();
6229
+ #endif
6230
+
6219
6231
if (batch.logits ) {
6220
6232
logits_out.resize (n_vocab * n_tokens);
6221
6233
for (uint32_t i = 0 ; i < n_tokens; i++) {
6222
6234
if (batch.logits [i] == 0 ) {
6223
6235
continue ;
6224
6236
}
6225
6237
memcpy (logits_out.data () + (n_vocab*i), (float *) ggml_get_data (res) + (n_vocab*i), sizeof (float )*n_vocab);
6238
+ #ifndef NDEBUG
6239
+ logits_valid[i] = true ;
6240
+ #endif
6226
6241
}
6227
6242
} else if (lctx.logits_all ) {
6228
6243
logits_out.resize (n_vocab * n_tokens);
6229
6244
memcpy (logits_out.data (), (float *) ggml_get_data (res), sizeof (float )*n_vocab*n_tokens);
6245
+ #ifndef NDEBUG
6246
+ std::fill (logits_valid.begin (), logits_valid.end (), true );
6247
+ #endif
6230
6248
} else {
6231
6249
logits_out.resize (n_vocab);
6232
6250
memcpy (logits_out.data (), (float *) ggml_get_data (res) + (n_vocab*(n_tokens - 1 )), sizeof (float )*n_vocab);
6251
+ #ifndef NDEBUG
6252
+ logits_valid[n_tokens - 1 ] = true ;
6253
+ #endif
6233
6254
}
6234
6255
}
6235
6256
@@ -10118,6 +10139,7 @@ float * llama_get_logits(struct llama_context * ctx) {
10118
10139
}
10119
10140
10120
10141
float * llama_get_logits_ith (struct llama_context * ctx, int32_t i) {
10142
+ assert (ctx->logits_valid .at (i));
10121
10143
return ctx->logits .data () + i*ctx->model .hparams .n_vocab ;
10122
10144
}
10123
10145
0 commit comments