
Commit c642bc0

kv-cache : separate recurrent vs non-recurrent impl (#12799)
* kv-cache : separate recurrent vs non-recurrent impl (wip) ggml-ci
* kv-cache : init -> constructor + add llama_memory_params ggml-ci
* kv-cache : fix callback reference ggml-ci
* context : llama_kv_cache -> llama_memory_i ggml-ci
* context : move memory creation logic to model ggml-ci
* llama : remove reference of memory during encode ggml-ci
* kv-cache : hide padding details in the implementation ggml-ci
* kv-cache : add ubatch_next() ggml-ci
* context : simplify sbatch logic ggml-ci
* kv-cache : hide defrag logic in the implementation ggml-ci
* context : hide kv cache details in implementation ggml-ci
* build : fix ggml-ci
* cont : another fix ggml-ci
* kv-cache : simplify interface (wip) ggml-ci
* kv-cache : use separate KV cell structs for unified/recurrent ggml-ci
* kv-cache : clean-up ggml-ci
* model : better llama_model::create_model() signature ggml-ci
* kv-cache : fix recurrent seq_rm() ggml-ci
* kv-cache : replace `struct callbacks` with `llama_model &` ggml-ci
* kv-cache : replace `struct graph_params` with `llama_context &` ggml-ci
* kv-cache : fix offload check ggml-ci
* context : avoid passing unique_ptr ggml-ci
* kv-cache : avoid using the backends from the llama_context (ref #13113) ggml-ci
* kv-cache : more consistent debug logs [no ci]
* kv-cache : do not pass the full llama_context for kv graphs ggml-ci
* kv-cache : remove comment
* kv-cache : ggml_rope_ext_inplace -> ggml_rope_ext ggml-ci
* kv-cache : fix recurrent multi-user case ggml-ci
* memory : remove comments [no ci]
1 parent cb06a3c · commit c642bc0
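
The commit message describes moving the context from a concrete llama_kv_cache to an abstract memory interface (llama_memory_i), with the non-recurrent (unified) and recurrent caches implemented separately behind it. A rough structural sketch of that shape follows; the class names are taken from the commit message, but the method set shown is hypothetical and does not reproduce the actual llama.cpp interface.

// Illustrative sketch only, not code from this commit.
#include <cstdint>

struct llama_memory_i {
    virtual ~llama_memory_i() = default;

    // clear all cached state
    virtual void clear() = 0;

    // remove positions [p0, p1) of a sequence; returns false if unsupported
    virtual bool seq_rm(int32_t seq_id, int32_t p0, int32_t p1) = 0;
};

// non-recurrent path: a conventional attention KV cache
struct llama_kv_cache_unified : llama_memory_i {
    void clear() override { /* drop cached K/V cells */ }
    bool seq_rm(int32_t /*seq_id*/, int32_t /*p0*/, int32_t /*p1*/) override { return true; }
};

// recurrent path: fixed-size per-sequence states (e.g. Mamba/RWKV-style models)
struct llama_kv_cache_recurrent : llama_memory_i {
    void clear() override { /* reset recurrent states */ }
    bool seq_rm(int32_t /*seq_id*/, int32_t /*p0*/, int32_t /*p1*/) override { return true; }
};

Keeping the two implementations behind one interface lets the context code stay agnostic about which kind of memory a model uses, which is what allows the padding, defrag, and graph-building details mentioned above to be hidden inside the implementations.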

11 files changed: +1964 −1052 lines

src/llama-batch.cpp (+5 −1)
@@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
     return ubatch;
 }
 
-void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
+llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
     GGML_ASSERT(batch.n_tokens >= 0);
     this->batch = &batch;
     this->n_embd = n_embd;
@@ -203,6 +203,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
     for (size_t i = 0; i < n_tokens; ++i) {
         ids[i] = i;
     }
+
     if (simple_split) {
         seq.resize(1);
         llama_sbatch_seq & s = seq[0];
@@ -212,6 +213,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
         s.length = n_tokens;
         return;
     }
+
     std::sort(ids.begin(), ids.end(),
         [&batch](size_t a, size_t b) {
             int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
@@ -239,6 +241,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
             return n_seq_a > n_seq_b;
         }
     );
+
     // init seq
     llama_sbatch_seq * last_seq = nullptr;
 
@@ -262,6 +265,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
         seq.push_back(new_seq);
         last_seq = &seq.back();
     }
+
     // keep shared prompts first at the end, then sort by length descending.
     std::sort(seq.begin(), seq.end(),
         [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
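
The comparator visible in the hunks above orders token indices so that tokens belonging to more sequences (shared prompt tokens) come first. A standalone sketch of that ordering idea, using made-up per-token sequence counts rather than a real llama_batch (not llama.cpp code):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    // hypothetical number of sequences each token belongs to
    std::vector<int> n_seq_id = { 1, 3, 1, 2 };

    // index array, analogous to `ids` in the constructor above
    std::vector<size_t> ids(n_seq_id.size());
    for (size_t i = 0; i < ids.size(); ++i) {
        ids[i] = i;
    }

    // tokens shared by more sequences sort earlier
    std::sort(ids.begin(), ids.end(),
        [&n_seq_id](size_t a, size_t b) {
            return n_seq_id[a] > n_seq_id[b];
        });

    for (size_t i : ids) {
        printf("token %zu (n_seq_id = %d)\n", i, n_seq_id[i]);
    }
    return 0;
}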

src/llama-batch.h (+2 −1)
@@ -70,7 +70,8 @@ struct llama_sbatch {
     // sequence-wise split
     llama_ubatch split_seq(size_t n_ubatch);
 
-    void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
+    llama_sbatch() = default;
+    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
 };
 
 // temporary allocate memory for the input batch if needed
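
With from_batch() replaced by a constructor (plus a defaulted no-argument constructor), callers move from two-phase initialization to direct construction. A hedged sketch of the call-site change; the surrounding variables are illustrative and not taken from this commit:

// Illustrative only: assumes "llama-batch.h" is included and that `batch`
// (a llama_batch) and `n_embd` come from the surrounding context.

// before this commit (two-phase init):
//   llama_sbatch sbatch;
//   sbatch.from_batch(batch, n_embd, /*simple_split=*/true, /*logits_all=*/false);

// after this commit (direct construction):
llama_sbatch sbatch(batch, n_embd, /*simple_split=*/true, /*logits_all=*/false);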
