@@ -71,8 +71,9 @@ static int llama_sample_top_p_top_k(
         int top_k,
         float top_p,
         float temp,
-        float repeat_penalty) {
-    auto logits = llama_get_logits(ctx);
+        float repeat_penalty,
+        int32_t pos) {
+    auto logits = llama_get_logits_ith(ctx, pos);
     auto n_vocab = llama_n_vocab(llama_get_model(ctx));
     // Populate initial list of all candidates
     std::vector<llama_token_data> candidates;
@@ -274,26 +275,30 @@ LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
     return llama_sample_top_p_top_k(d_ptr->ctx,
         promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
         n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
-        promptCtx.repeat_penalty);
+        promptCtx.repeat_penalty, promptCtx.n_last_batch_tokens - 1);
 }
 
 bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
 {
     llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
 
     batch.n_tokens = tokens.size();
+    ctx.n_last_batch_tokens = tokens.size();
 
     for (int32_t i = 0; i < batch.n_tokens; i++) {
-        batch.token [i] = tokens[i];
-        batch.pos   [i] = ctx.n_past + i;
-        batch.seq_id[i] = 0;
-        batch.logits[i] = false;
+        batch.token   [i] = tokens[i];
+        batch.pos     [i] = ctx.n_past + i;
+        batch.n_seq_id[i] = 1;
+        batch.seq_id  [i][0] = 0;
+        batch.logits  [i] = false;
     }
 
     // llama_decode will output logits only for the last token of the prompt
     batch.logits[batch.n_tokens - 1] = true;
 
-    return llama_decode(d_ptr->ctx, batch) == 0;
+    int res = llama_decode(d_ptr->ctx, batch);
+    llama_batch_free(batch);
+    return res == 0;
 }
 
 int32_t LLamaModel::contextLength() const
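
For readers unfamiliar with the llama.cpp batch API this diff migrates to, below is a minimal standalone sketch of the same fill/decode/free pattern. It only uses calls and fields that appear in the hunk above (llama_batch_init, llama_decode, llama_get_logits_ith, llama_batch_free and the llama_batch members); the helper name decode_tokens and its error handling are assumptions for illustration, not part of this change.

// Minimal sketch (assumed helper, not part of this PR): evaluate a span of
// tokens with the multi-sequence llama_batch API and return the logits of
// the last token. Mirrors the pattern used in LLamaModel::evalTokens above.
#include <vector>
#include "llama.h"

static const float *decode_tokens(llama_context *ctx,
                                  const std::vector<llama_token> &tokens,
                                  int32_t n_past) {
    // One sequence, no embedding input, room for tokens.size() tokens.
    llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
    batch.n_tokens = tokens.size();

    for (int32_t i = 0; i < batch.n_tokens; i++) {
        batch.token   [i] = tokens[i];
        batch.pos     [i] = n_past + i;      // absolute position in the context
        batch.n_seq_id[i] = 1;               // each token belongs to one sequence
        batch.seq_id  [i][0] = 0;            // ... namely sequence 0
        batch.logits  [i] = false;           // no logits for prompt tokens
    }
    batch.logits[batch.n_tokens - 1] = true; // only the last token needs logits

    const float *out = nullptr;
    if (llama_decode(ctx, batch) == 0) {
        // Index is relative to this batch; the last token is n_tokens - 1.
        out = llama_get_logits_ith(ctx, batch.n_tokens - 1);
    }
    llama_batch_free(batch);                 // the batch owns its arrays; free them
    return out;                              // logits live in ctx, still valid here
}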