
Commit c68e881

llamamodel: fix incorrect use of batch API
I filed an upstream PR to discuss this: ggml-org/llama.cpp#4274. Also, make sure to free the batch when we're done with it.
1 parent caba345 commit c68e881
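
For context, a minimal self-contained sketch of the batch usage this commit moves to, assuming the llama.cpp batch API of this period (llama_batch_init, llama_decode, llama_batch_free, and the per-token n_seq_id/seq_id fields). The decode_tokens wrapper and its parameters are illustrative only, not part of the gpt4all code:

    // Sketch only: fill every per-token field of llama_batch (including n_seq_id and
    // the seq_id sub-array), request logits for the last token, and free the batch.
    #include <vector>
    #include "llama.h"

    static bool decode_tokens(llama_context *ctx, const std::vector<llama_token> &tokens, int32_t n_past)
    {
        llama_batch batch = llama_batch_init(tokens.size(), 0, 1); // token input (no embeddings), 1 sequence

        batch.n_tokens = tokens.size();
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            batch.token   [i] = tokens[i];
            batch.pos     [i] = n_past + i;      // absolute position in the context
            batch.n_seq_id[i] = 1;               // each token belongs to one sequence
            batch.seq_id  [i][0] = 0;            // ... namely sequence 0
            batch.logits  [i] = false;
        }
        batch.logits[batch.n_tokens - 1] = true; // only the last token needs logits

        int res = llama_decode(ctx, batch);
        llama_batch_free(batch);                 // release the buffers allocated by llama_batch_init
        return res == 0;
    }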

File tree

2 files changed (+15, -10 lines)


gpt4all-backend/llamamodel.cpp

Lines changed: 13 additions & 8 deletions
@@ -71,8 +71,9 @@ static int llama_sample_top_p_top_k(
         int top_k,
         float top_p,
         float temp,
-        float repeat_penalty) {
-    auto logits = llama_get_logits(ctx);
+        float repeat_penalty,
+        int32_t pos) {
+    auto logits = llama_get_logits_ith(ctx, pos);
     auto n_vocab = llama_n_vocab(llama_get_model(ctx));
     // Populate initial list of all candidates
     std::vector<llama_token_data> candidates;
@@ -274,26 +275,30 @@ LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
     return llama_sample_top_p_top_k(d_ptr->ctx,
         promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
         n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
-        promptCtx.repeat_penalty);
+        promptCtx.repeat_penalty, promptCtx.n_last_batch_tokens - 1);
 }
 
 bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
 {
     llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
 
     batch.n_tokens = tokens.size();
+    ctx.n_last_batch_tokens = tokens.size();
 
     for (int32_t i = 0; i < batch.n_tokens; i++) {
-        batch.token[i] = tokens[i];
-        batch.pos[i] = ctx.n_past + i;
-        batch.seq_id[i] = 0;
-        batch.logits[i] = false;
+        batch.token   [i] = tokens[i];
+        batch.pos     [i] = ctx.n_past + i;
+        batch.n_seq_id[i] = 1;
+        batch.seq_id  [i][0] = 0;
+        batch.logits  [i] = false;
     }
 
     // llama_decode will output logits only for the last token of the prompt
     batch.logits[batch.n_tokens - 1] = true;
 
-    return llama_decode(d_ptr->ctx, batch) == 0;
+    int res = llama_decode(d_ptr->ctx, batch);
+    llama_batch_free(batch);
+    return res == 0;
 }
 
 int32_t LLamaModel::contextLength() const
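
A hedged sketch of the sampling side, assuming llama_get_logits_ith, llama_n_vocab, and llama_get_model as in the llama.cpp API used here. The pick_last_token helper and its greedy loop are illustrative only; the real code runs a top-k/top-p sampler over the same logits:

    // Sketch only: with the batch API, logits are indexed by position inside the last
    // batch, so the caller passes n_last_batch_tokens - 1 instead of reading
    // llama_get_logits(ctx) directly.
    #include "llama.h"

    static llama_token pick_last_token(llama_context *ctx, int32_t n_last_batch_tokens)
    {
        const float *logits = llama_get_logits_ith(ctx, n_last_batch_tokens - 1);
        const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx));

        // greedy argmax, standing in for the real top-k/top-p sampling
        llama_token best = 0;
        for (llama_token tok = 1; tok < n_vocab; tok++) {
            if (logits[tok] > logits[best])
                best = tok;
        }
        return best;
    }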

gpt4all-backend/llmodel.h

Lines changed: 2 additions & 2 deletions
@@ -54,8 +54,8 @@ class LLModel {
         int32_t n_batch = 9;
         float repeat_penalty = 1.10f;
         int32_t repeat_last_n = 64; // last n tokens to penalize
-        float contextErase = 0.75f; // percent of context to erase if we exceed the context
-                                    // window
+        float contextErase = 0.75f; // percent of context to erase if we exceed the context window
+        int32_t n_last_batch_tokens = 0;
     };
 
     struct GPUDevice {

0 commit comments