
Commit e37f112

kv-cache : replace struct callbacks with llama_model &
ggml-ci
1 parent 65cde6d commit e37f112

4 files changed, +79 -15 additions and -90 deletions


src/llama-kv-cache.cpp (+37 -15)
@@ -22,13 +22,13 @@ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
 }

 llama_kv_cache_unified::llama_kv_cache_unified(
-        const llama_hparams & hparams,
-                    callbacks cbs,
-                    ggml_type type_k,
-                    ggml_type type_v,
-                         bool v_trans,
-                     uint32_t kv_size,
-                     uint32_t padding) : cbs(std::move(cbs)), hparams(hparams), v_trans(v_trans), padding(padding) {
+          const llama_model & model,
+                    ggml_type type_k,
+                    ggml_type type_v,
+                         bool v_trans,
+                         bool offload,
+                     uint32_t kv_size,
+                     uint32_t padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
     const int32_t n_layer = hparams.n_layer;

     has_shift = false;
@@ -81,7 +81,18 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();

-        ggml_backend_buffer_type_t buft = this->cbs.get_buft(i);
+        const char * dev_name = "CPU";
+
+        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
+
+        if (offload) {
+            auto * dev = model.dev_layer(i);
+            buft = ggml_backend_dev_buffer_type(dev);
+
+            dev_name = ggml_backend_dev_name(dev);
+        }
+
+        LLAMA_LOG_DEBUG("layer %3d: dev = %s\n", i, dev_name);

         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
@@ -697,7 +708,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
         const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
         const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;

-        ggml_tensor * rope_factors = cbs.get_rope_factors(n_ctx_per_seq, il);
+        ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

         ggml_tensor * k =
             ggml_view_3d(ctx, k_l[il],
@@ -1377,11 +1388,11 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 //

 llama_kv_cache_recurrent::llama_kv_cache_recurrent(
-        const llama_hparams & hparams,
-                    callbacks cbs,
-                    ggml_type type_k,
-                    ggml_type type_v,
-                     uint32_t kv_size) : cbs(std::move(cbs)), hparams(hparams) {
+          const llama_model & model,
+                    ggml_type type_k,
+                    ggml_type type_v,
+                         bool offload,
+                     uint32_t kv_size) : model(model), hparams(model.hparams) {
     const int32_t n_layer = hparams.n_layer;

     LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
@@ -1429,7 +1440,18 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();

-        ggml_backend_buffer_type_t buft = this->cbs.get_buft(i);
+        const char * dev_name = "CPU";
+
+        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
+
+        if (offload) {
+            auto * dev = model.dev_layer(i);
+            buft = ggml_backend_dev_buffer_type(dev);
+
+            dev_name = ggml_backend_dev_name(dev);
+        }
+
+        LLAMA_LOG_DEBUG("layer %3d: dev = %s\n", i, dev_name);

         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
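With the callbacks struct gone, both KV-cache constructors select the buffer type for each layer themselves: the model's per-layer device when offload is set, the CPU buffer type otherwise, which is the same decision the removed get_buft lambda used to make in create_memory. A toy, self-contained sketch of that selection follows; model_stub, buft_t and the device names are invented stand-ins, not the ggml or llama.cpp API.

    // Toy sketch of the per-layer buffer-type choice (stand-in types, not the real API):
    // device buffer when offload is set, CPU buffer otherwise, with a debug line per layer.
    #include <cstdio>
    #include <string>

    enum class buft_t { CPU, DEVICE };                 // stand-in for ggml_backend_buffer_type_t

    struct model_stub {
        // stand-in for llama_model::dev_layer(); device names are made up
        std::string dev_layer(int il) const { return il % 2 == 0 ? "CUDA0" : "CUDA1"; }
    };

    int main() {
        const model_stub model;
        const bool offload = true;                     // would come from cparams.offload_kqv
        const int  n_layer = 4;

        for (int il = 0; il < n_layer; ++il) {
            std::string dev_name = "CPU";
            buft_t      buft     = buft_t::CPU;

            if (offload) {                             // same decision the constructors now make inline
                buft     = buft_t::DEVICE;
                dev_name = model.dev_layer(il);
            }

            std::printf("layer %3d: dev = %s\n", il, dev_name.c_str());
            (void) buft;                               // a real cache would allocate k_l/v_l in this buffer type
        }
        return 0;
    }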

src/llama-kv-cache.h (+15 -24)
@@ -15,16 +15,9 @@ struct llama_cparams;
 struct llama_hparams;
 struct llama_ubatch;
 struct llama_sbatch;
+struct llama_model;

 struct llama_kv_cache : public llama_memory_i {
-    // can be used to query data from the model if needed
-    struct callbacks {
-        std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
-
-        // get the buffer type of layer il, can be used to offload KV cache layers to a different device
-        std::function<ggml_backend_buffer_type_t (int il)> get_buft;
-    };
-
     struct graph_params {
         const llm_arch arch;

@@ -139,13 +132,13 @@ class llama_kv_cache_unified : public llama_kv_cache {
     static uint32_t get_padding(const llama_cparams & cparams);

     llama_kv_cache_unified(
-            const llama_hparams & hparams,
-                        callbacks cbs,
-                        ggml_type type_k,
-                        ggml_type type_v,
-                             bool v_trans,
-                         uint32_t kv_size,
-                         uint32_t padding);
+              const llama_model & model,
+                        ggml_type type_k,
+                        ggml_type type_v,
+                             bool v_trans,
+                             bool offload,
+                         uint32_t kv_size,
+                         uint32_t padding);

     ~llama_kv_cache_unified() = default;

@@ -208,14 +201,13 @@ class llama_kv_cache_unified : public llama_kv_cache {
     // computed before each graph build
     uint32_t n = 0;

-    callbacks cbs;
-
     std::vector<kv_cell> cells;

     std::vector<ggml_tensor *> k_l; // per layer
     std::vector<ggml_tensor *> v_l;

 private:
+    const llama_model & model;
     const llama_hparams & hparams;

     bool has_shift = false;
@@ -312,11 +304,11 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
     };

     llama_kv_cache_recurrent(
-            const llama_hparams & hparams,
-                        callbacks cbs,
-                        ggml_type type_k,
-                        ggml_type type_v,
-                         uint32_t kv_size);
+              const llama_model & model,
+                        ggml_type type_k,
+                        ggml_type type_v,
+                             bool offload,
+                         uint32_t kv_size);

     ~llama_kv_cache_recurrent() = default;

@@ -370,8 +362,6 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
     void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;

-    callbacks cbs;
-
     // Note: The value of head isn't only used to optimize searching
     // for a free KV slot. llama_decode_impl also uses it, so it
     // cannot be freely changed after a slot has been allocated.
@@ -388,6 +378,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
     std::vector<ggml_tensor *> v_l;

 private:
+    const llama_model & model;
     const llama_hparams & hparams;

     // commit/restore cache
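The header change is the core of the commit: the nested callbacks struct with its two std::function members is deleted, and each cache class instead stores a const llama_model & plus an offload flag passed at construction. The self-contained sketch below illustrates that pattern in isolation; toy_model, cache_v1 and cache_v2 are invented stand-ins, not llama.cpp types.

    // Toy illustration of the refactor: a struct of std::function callbacks replaced by
    // a const reference to the object that already owns the data (stand-in types only).
    #include <cstdio>
    #include <functional>

    struct toy_model {
        int rope_factor_id(int il) const { return 100 + il; }   // stand-in for get_rope_factors()
    };

    // before: the cache carried forwarding callbacks that the caller had to wire up
    struct cache_v1 {
        struct callbacks {
            std::function<int(int)> get_rope_factor_id;
        } cbs;
    };

    // after: the cache holds the model reference and calls it directly
    struct cache_v2 {
        const toy_model & model;
        int rope_factor_id(int il) const { return model.rope_factor_id(il); }
    };

    int main() {
        toy_model model;

        cache_v1 a{ { [&model](int il) { return model.rope_factor_id(il); } } };  // glue lambda
        cache_v2 b{ model };                                                      // no glue needed

        std::printf("%d %d\n", a.cbs.get_rope_factor_id(3), b.rope_factor_id(3)); // prints: 103 103
        return 0;
    }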

src/llama-model.cpp (+25 -51)
@@ -4416,6 +4416,19 @@ const ggml_tensor * llama_model::get_tensor(const char * name) const {
     return it->second;
 }

+ggml_tensor * llama_model::get_rope_factors(uint32_t n_ctx_per_seq, int il) const {
+    // choose long/short freq factors based on the context size
+    if (layers[il].rope_freqs != nullptr) {
+        return layers[il].rope_freqs;
+    }
+
+    if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
+        return layers[il].rope_long;
+    }
+
+    return layers[il].rope_short;
+}
+
 struct llm_build_llama : public llm_graph_context {
     llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -4456,7 +4469,7 @@ struct llm_build_llama : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -4681,7 +4694,7 @@ struct llm_build_deci : public llm_graph_context {
             } else if (n_head > 0) {
                 // self-attention
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -7141,7 +7154,7 @@ struct llm_build_phi3 : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for 128k context
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

                 ggml_tensor* attn_norm_output = build_norm(inpL,
                     model.layers[il].attn_norm,
@@ -7893,7 +7906,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;

-            ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+            ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

             // norm
             cur = build_norm(inpL,
@@ -8961,7 +8974,7 @@ struct llm_build_cohere2 : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for 128k context
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -9899,7 +9912,7 @@ struct llm_build_deepseek : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -11264,7 +11277,7 @@ struct llm_build_exaone : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -12645,7 +12658,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);

                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -12768,28 +12781,6 @@ struct llm_build_bailingmoe : public llm_graph_context {
 llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
     llama_memory_i * res;

-    const bool offload = cparams.offload_kqv;
-
-    auto get_buft = [this, offload](int il) {
-        const char * dev_name = "CPU";
-
-        ggml_backend_buffer_type_t buft;
-        if (offload) {
-            auto * dev = dev_layer(il);
-            buft = ggml_backend_dev_buffer_type(dev);
-
-            dev_name = ggml_backend_dev_name(dev);
-        } else {
-            buft = ggml_backend_cpu_buffer_type();
-        }
-
-        LLAMA_LOG_DEBUG("layer %3d: dev = %s\n", il, dev_name);
-
-        return buft;
-    };
-
-    LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-
     switch (arch) {
         case LLM_ARCH_MAMBA:
         case LLM_ARCH_RWKV6:
@@ -12798,13 +12789,10 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_ARWKV7:
             {
                 res = new llama_kv_cache_recurrent(
-                        hparams,
-                        {
-                            /*.get_rope_factors =*/ nullptr,
-                            /*.get_buft =*/ get_buft,
-                        },
+                        *this,
                         GGML_TYPE_F32,
                         GGML_TYPE_F32,
+                        cparams.offload_kqv,
                         std::max((uint32_t) 1, cparams.n_seq_max));
             } break;
         default:
@@ -12816,25 +12804,11 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                 LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);

                 res = new llama_kv_cache_unified(
-                        hparams,
-                        {
-                            /*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) {
-                                // choose long/short freq factors based on the context size
-                                if (layers[il].rope_freqs != nullptr) {
-                                    return layers[il].rope_freqs;
-                                }
-
-                                if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
-                                    return layers[il].rope_long;
-                                }
-
-                                return layers[il].rope_short;
-                            },
-                            /*.get_buft =*/ get_buft,
-                        },
+                        *this,
                         params.type_k,
                         params.type_v,
                         !cparams.flash_attn,
+                        cparams.offload_kqv,
                         cparams.n_ctx,
                         padding);
             }
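The rope-factor lambda that used to be wired into the callbacks in create_memory is now the public helper llama_model::get_rope_factors added in the first hunk above. Below is a self-contained sketch of the same selection logic; the *_stub types are invented, only the field names (rope_freqs, rope_long, rope_short, n_ctx_orig_yarn) and the control flow follow the diff.

    // Standalone sketch of the long/short RoPE frequency-factor selection
    // (stand-in structs; field names mirror the diff, values are made up).
    #include <cstdio>
    #include <cstdint>
    #include <vector>

    struct tensor_stub { const char * name; };

    struct layer_stub {
        tensor_stub * rope_freqs = nullptr;   // explicit per-layer factors, if the model ships them
        tensor_stub * rope_long  = nullptr;   // factors trained for long contexts
        tensor_stub * rope_short = nullptr;   // factors trained for short contexts
    };

    struct hparams_stub { uint32_t n_ctx_orig_yarn = 4096; };

    // mirrors llama_model::get_rope_factors(): explicit factors win, otherwise the
    // per-sequence context size relative to n_ctx_orig_yarn picks the long or short set
    tensor_stub * get_rope_factors(const std::vector<layer_stub> & layers,
                                   const hparams_stub & hparams,
                                   uint32_t n_ctx_per_seq, int il) {
        if (layers[il].rope_freqs != nullptr) {
            return layers[il].rope_freqs;
        }
        if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
            return layers[il].rope_long;
        }
        return layers[il].rope_short;
    }

    int main() {
        tensor_stub t_long{"rope_long"}, t_short{"rope_short"};
        std::vector<layer_stub> layers(1);
        layers[0].rope_long  = &t_long;
        layers[0].rope_short = &t_short;
        hparams_stub hp;

        std::printf("%s\n", get_rope_factors(layers, hp,  2048, 0)->name);  // rope_short
        std::printf("%s\n", get_rope_factors(layers, hp, 16384, 0)->name);  // rope_long
        return 0;
    }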

src/llama-model.h (+2 -0)
@@ -390,6 +390,8 @@ struct llama_model {

     const struct ggml_tensor * get_tensor(const char * name) const;

+    ggml_tensor * get_rope_factors(uint32_t n_ctx_per_seq, int il) const;
+
     // note: can mutate `cparams`
     // TODO: move this to new llm_arch_model_i interface
     llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;

0 commit comments
