
Commit 7e4b545

kv-cache : replace struct callbacks with llama_model &
ggml-ci
1 parent 65cde6d commit 7e4b545
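
In short: instead of handing the KV cache a struct of std::function callbacks (get_rope_factors, get_buft), the cache now stores a const llama_model & plus an explicit offload flag and queries the model directly. A standalone C++ sketch of the pattern only, with made-up provider/consumer names (this is not llama.cpp code):

    #include <cstdio>
    #include <functional>

    struct provider {
        int rope_factor(int il) const { return 10 * il; } // stand-in for get_rope_factors
    };

    // before: indirection through a small callbacks struct
    struct consumer_v1 {
        struct callbacks {
            std::function<int(int)> get_rope_factor;
        } cbs;
        int use(int il) const { return cbs.get_rope_factor(il); }
    };

    // after: the consumer holds a reference and calls the provider directly
    struct consumer_v2 {
        const provider & prov;
        int use(int il) const { return prov.rope_factor(il); }
    };

    int main() {
        provider p;
        consumer_v1 c1{ { [&p](int il) { return p.rope_factor(il); } } };
        consumer_v2 c2{ p };
        std::printf("%d %d\n", c1.use(3), c2.use(3)); // both print 30
    }

The diffs below apply this to llama_kv_cache_unified and llama_kv_cache_recurrent: the callbacks struct and the graph_params arch field go away, and the caches pick up hparams, rope factors, and per-layer buffer types from llama_model itself.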

File tree (5 files changed: +80 -95 lines)

  src/llama-context.cpp
  src/llama-kv-cache.cpp
  src/llama-kv-cache.h
  src/llama-model.cpp
  src/llama-model.h

src/llama-context.cpp

-1
@@ -440,7 +440,6 @@ void llama_context::kv_self_update() {
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
     need_reserve = kv_self->update({
-        /*.arch     =*/ model.arch,
         /*.cparams  =*/ cparams,
         /*.sched    =*/ sched.get(),
         /*.backends =*/ backends,

src/llama-kv-cache.cpp

+38 -17

@@ -22,13 +22,13 @@ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
 }
 
 llama_kv_cache_unified::llama_kv_cache_unified(
-        const llama_hparams & hparams,
-        callbacks cbs,
-        ggml_type type_k,
-        ggml_type type_v,
-        bool v_trans,
-        uint32_t kv_size,
-        uint32_t padding) : cbs(std::move(cbs)), hparams(hparams), v_trans(v_trans), padding(padding) {
+        const llama_model & model,
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        bool offload,
+        uint32_t kv_size,
+        uint32_t padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
     const int32_t n_layer = hparams.n_layer;
 
     has_shift = false;
@@ -81,7 +81,18 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
-        ggml_backend_buffer_type_t buft = this->cbs.get_buft(i);
+        const char * dev_name = "CPU";
+
+        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
+
+        if (offload) {
+            auto * dev = model.dev_layer(i);
+            buft = ggml_backend_dev_buffer_type(dev);
+
+            dev_name = ggml_backend_dev_name(dev);
+        }
+
+        LLAMA_LOG_DEBUG("layer %3d: dev = %s\n", i, dev_name);
 
         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
@@ -588,7 +599,6 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
         float freq_base,
         float freq_scale,
         ggml_backend_buffer * bbuf) const {
-    const auto & arch     = params.arch;
     const auto & cparams  = params.cparams;
     const auto & backends = params.backends;
     const auto & sched    = params.sched;

@@ -604,7 +614,7 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
 
     // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
     // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float yarn_attn_factor = arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
+    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
 
     ggml_tensor * tmp;
 

@@ -697,7 +707,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
         const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
         const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
 
-        ggml_tensor * rope_factors = cbs.get_rope_factors(n_ctx_per_seq, il);
+        ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
         ggml_tensor * k =
             ggml_view_3d(ctx, k_l[il],
@@ -1377,11 +1387,11 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 //
 
 llama_kv_cache_recurrent::llama_kv_cache_recurrent(
-        const llama_hparams & hparams,
-        callbacks cbs,
-        ggml_type type_k,
-        ggml_type type_v,
-        uint32_t kv_size) : cbs(std::move(cbs)), hparams(hparams) {
+        const llama_model & model,
+        ggml_type type_k,
+        ggml_type type_v,
+        bool offload,
+        uint32_t kv_size) : hparams(model.hparams) {
     const int32_t n_layer = hparams.n_layer;
 
     LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
@@ -1429,7 +1439,18 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
-        ggml_backend_buffer_type_t buft = this->cbs.get_buft(i);
+        const char * dev_name = "CPU";
+
+        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
+
+        if (offload) {
+            auto * dev = model.dev_layer(i);
+            buft = ggml_backend_dev_buffer_type(dev);
+
+            dev_name = ggml_backend_dev_name(dev);
+        }
+
+        LLAMA_LOG_DEBUG("layer %3d: dev = %s\n", i, dev_name);
 
         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {

src/llama-kv-cache.h

+15 -26

@@ -15,19 +15,10 @@ struct llama_cparams;
 struct llama_hparams;
 struct llama_ubatch;
 struct llama_sbatch;
+struct llama_model;
 
 struct llama_kv_cache : public llama_memory_i {
-    // can be used to query data from the model if needed
-    struct callbacks {
-        std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
-
-        // get the buffer type of layer il, can be used to offload KV cache layers to a different device
-        std::function<ggml_backend_buffer_type_t (int il)> get_buft;
-    };
-
     struct graph_params {
-        const llm_arch arch;
-
         const llama_cparams & cparams;
 
         const ggml_backend_sched_t & sched;

@@ -139,13 +130,13 @@ class llama_kv_cache_unified : public llama_kv_cache {
     static uint32_t get_padding(const llama_cparams & cparams);
 
     llama_kv_cache_unified(
-            const llama_hparams & hparams,
-            callbacks cbs,
-            ggml_type type_k,
-            ggml_type type_v,
-            bool v_trans,
-            uint32_t kv_size,
-            uint32_t padding);
+            const llama_model & model,
+            ggml_type type_k,
+            ggml_type type_v,
+            bool v_trans,
+            bool offload,
+            uint32_t kv_size,
+            uint32_t padding);
 
     ~llama_kv_cache_unified() = default;
 
@@ -208,14 +199,13 @@ class llama_kv_cache_unified : public llama_kv_cache {
     // computed before each graph build
     uint32_t n = 0;
 
-    callbacks cbs;
-
     std::vector<kv_cell> cells;
 
     std::vector<ggml_tensor *> k_l; // per layer
     std::vector<ggml_tensor *> v_l;
 
 private:
+    const llama_model & model;
     const llama_hparams & hparams;
 
     bool has_shift = false;

@@ -312,11 +302,11 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
     };
 
     llama_kv_cache_recurrent(
-            const llama_hparams & hparams,
-            callbacks cbs,
-            ggml_type type_k,
-            ggml_type type_v,
-            uint32_t kv_size);
+            const llama_model & model,
+            ggml_type type_k,
+            ggml_type type_v,
+            bool offload,
+            uint32_t kv_size);
 
     ~llama_kv_cache_recurrent() = default;
 
@@ -370,8 +360,6 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
     void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
 
-    callbacks cbs;
-
     // Note: The value of head isn't only used to optimize searching
     // for a free KV slot. llama_decode_impl also uses it, so it
     // cannot be freely changed after a slot has been allocated.

@@ -388,6 +376,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
     std::vector<ggml_tensor *> v_l;
 
 private:
+    //const llama_model & model;
     const llama_hparams & hparams;
 
     // commit/restore cache
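
With the callbacks gone, constructing a cache takes the model reference plus an explicit offload flag. As a usage sketch, the unified-cache construction condensed from the llama_model::create_memory hunk further down (surrounding switch and padding computation elided; comments map the arguments to the new parameter list above):

    res = new llama_kv_cache_unified(
            *this,                 // const llama_model & model
            params.type_k,
            params.type_v,
            !cparams.flash_attn,   // v_trans
            cparams.offload_kqv,   // offload
            cparams.n_ctx,         // kv_size
            padding);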

src/llama-model.cpp

+25 -51

@@ -4416,6 +4416,19 @@ const ggml_tensor * llama_model::get_tensor(const char * name) const {
     return it->second;
 }
 
+ggml_tensor * llama_model::get_rope_factors(uint32_t n_ctx_per_seq, int il) const {
+    // choose long/short freq factors based on the context size
+    if (layers[il].rope_freqs != nullptr) {
+        return layers[il].rope_freqs;
+    }
+
+    if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
+        return layers[il].rope_long;
+    }
+
+    return layers[il].rope_short;
+}
+
 struct llm_build_llama : public llm_graph_context {
     llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -4456,7 +4469,7 @@ struct llm_build_llama : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);

@@ -4681,7 +4694,7 @@ struct llm_build_deci : public llm_graph_context {
             } else if (n_head > 0) {
                 // self-attention
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);

@@ -7141,7 +7154,7 @@ struct llm_build_phi3 : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for 128k context
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 ggml_tensor* attn_norm_output = build_norm(inpL,
                     model.layers[il].attn_norm,

@@ -7893,7 +7906,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
-            ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+            ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
             // norm
             cur = build_norm(inpL,

@@ -8961,7 +8974,7 @@ struct llm_build_cohere2 : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for 128k context
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);

@@ -9899,7 +9912,7 @@ struct llm_build_deepseek : public llm_graph_context {
             // self-attention
            {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);

@@ -11264,7 +11277,7 @@ struct llm_build_exaone : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);

@@ -12645,7 +12658,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -12768,28 +12781,6 @@ struct llm_build_bailingmoe : public llm_graph_context {
 llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
     llama_memory_i * res;
 
-    const bool offload = cparams.offload_kqv;
-
-    auto get_buft = [this, offload](int il) {
-        const char * dev_name = "CPU";
-
-        ggml_backend_buffer_type_t buft;
-        if (offload) {
-            auto * dev = dev_layer(il);
-            buft = ggml_backend_dev_buffer_type(dev);
-
-            dev_name = ggml_backend_dev_name(dev);
-        } else {
-            buft = ggml_backend_cpu_buffer_type();
-        }
-
-        LLAMA_LOG_DEBUG("layer %3d: dev = %s\n", il, dev_name);
-
-        return buft;
-    };
-
-    LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-
     switch (arch) {
         case LLM_ARCH_MAMBA:
         case LLM_ARCH_RWKV6:

@@ -12798,13 +12789,10 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_ARWKV7:
             {
                 res = new llama_kv_cache_recurrent(
-                        hparams,
-                        {
-                            /*.get_rope_factors =*/ nullptr,
-                            /*.get_buft         =*/ get_buft,
-                        },
+                        *this,
                         GGML_TYPE_F32,
                         GGML_TYPE_F32,
+                        cparams.offload_kqv,
                         std::max((uint32_t) 1, cparams.n_seq_max));
             } break;
         default:

@@ -12816,25 +12804,11 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                 LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
                 res = new llama_kv_cache_unified(
-                        hparams,
-                        {
-                            /*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) {
-                                // choose long/short freq factors based on the context size
-                                if (layers[il].rope_freqs != nullptr) {
-                                    return layers[il].rope_freqs;
-                                }
-
-                                if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
-                                    return layers[il].rope_long;
-                                }
-
-                                return layers[il].rope_short;
-                            },
-                            /*.get_buft =*/ get_buft,
-                        },
+                        *this,
                         params.type_k,
                         params.type_v,
                         !cparams.flash_attn,
+                        cparams.offload_kqv,
                         cparams.n_ctx,
                         padding);
             }

src/llama-model.h

+2

@@ -390,6 +390,8 @@ struct llama_model {
 
     const struct ggml_tensor * get_tensor(const char * name) const;
 
+    ggml_tensor * get_rope_factors(uint32_t n_ctx_per_seq, int il) const;
+
     // note: can mutate `cparams`
     // TODO: move this to new llm_arch_model_i interface
     llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
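
The declaration above replaces the old cbs.get_rope_factors std::function; graph builders and llama_kv_cache_unified::build_graph_shift now call it on the model directly, per layer, as in the hunks above (sketch):

    // may return nullptr, e.g. for llama2-style models without rope freq factors
    ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);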

0 commit comments