Commit a70c8a0

kv-cache : use ggml_set_rows (#14285)
* kv-cache : use ggml_set_rows ggml-ci
* graph : separate k and v indices ggml-ci
* cont : remove redundant ifs ggml-ci
* kv-cache : improve find_slot impl
* kv-cache : bounds-check when accessing slot_info indices
* kv-cache : add comments ggml-ci
* ggml : add TODOs for adding GGML_OP_SET_ROWS support in the backends ggml-ci
1 parent 9067487 commit a70c8a0

13 files changed: +450 −142 lines changed
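
The gist of the change: instead of copying each ubatch's K/V into a contiguous view of the cache at a single head offset, the graph now scatters rows with GGML_OP_SET_ROWS through explicit per-token destination indices, so the slots found for a ubatch no longer have to be contiguous. A minimal sketch, assuming the ggml_set_rows(ctx, dst, src, idxs) API from PR #14274 (rows of src overwrite rows of dst at the I64 positions in idxs); the helper name is illustrative, not from this commit:

#include "ggml.h"

// scatter the current keys into arbitrary cache slots in one graph op
// k_cache: [n_embd_k, kv_size] cache for one layer, one row per token slot
// k_cur:   [n_embd_k, n_tokens] keys computed for the current ubatch
// k_idxs:  I64 [n_tokens] destination slots chosen by find_slot
static ggml_tensor * scatter_k(ggml_context * ctx,
                               ggml_tensor  * k_cache,
                               ggml_tensor  * k_cur,
                               ggml_tensor  * k_idxs) {
    return ggml_set_rows(ctx, k_cache, k_cur, k_idxs);
}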

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 6 additions & 0 deletions
@@ -2086,6 +2086,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                 return false;
             }
         } break;
+        case GGML_OP_SET_ROWS:
+            {
+                // TODO: add support
+                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
+                return false;
+            } break;
         case GGML_OP_CPY: {
             ggml_tensor *src = op->src[0];
             if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 6 additions & 0 deletions
@@ -2222,6 +2222,12 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             default:
                 return false;
         }
+        case GGML_OP_SET_ROWS:
+            {
+                // TODO: add support
+                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
+                return false;
+            } break;
         case GGML_OP_CPY:
         case GGML_OP_DUP:
         case GGML_OP_CONT:

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 6 additions & 0 deletions
@@ -4285,6 +4285,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 return false;
             }
         }
+        case GGML_OP_SET_ROWS:
+            {
+                // TODO: add support
+                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
+                return false;
+            } break;
         case GGML_OP_CPY:
             {
                 ggml_type src0_type = op->src[0]->type;

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 6 additions & 0 deletions
@@ -10339,6 +10339,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 return false;
             }
         } break;
+        case GGML_OP_SET_ROWS:
+            {
+                // TODO: add support
+                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
+                return false;
+            } break;
         case GGML_OP_CONT:
         case GGML_OP_CPY:
         case GGML_OP_DUP:
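
All four backends above decline GGML_OP_SET_ROWS for now, so the graph scheduler assigns those nodes to a backend that does support the op (the CPU backend at minimum). A small sketch of querying the same gate through the public device API; report_op_support is a hypothetical helper, not part of this commit:

#include <stdio.h>
#include "ggml-backend.h"

// list which registered devices claim native support for an op node;
// the scheduler performs the same supports_op query when splitting the graph
static void report_op_support(const struct ggml_tensor * op) {
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        printf("%-16s %s: %s\n",
                ggml_backend_dev_name(dev),
                ggml_op_name(op->op),
                ggml_backend_dev_supports_op(dev, op) ? "native" : "fallback");
    }
}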

src/llama-graph.cpp

Lines changed: 46 additions & 22 deletions
@@ -281,19 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask) {
-        mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-    }
+    mctx->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->set_input_v_idxs(self_v_idxs, ubatch);
+
+    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask) {
-        mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-    }
+    mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
 
-    if (self_kq_mask_swa) {
-        mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
-    }
+    mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+
+    mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
+    mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
+
+    mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 }
 
 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {

@@ -333,9 +336,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask) {
-        mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-    }
+    mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
+
+    mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 
     const int64_t n_rs = mctx->get_recr()->get_n_rs();
 
@@ -350,7 +354,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_one::set_input(const llama_ubatch *) {
+void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
     GGML_ASSERT(one && ggml_nelements(one) == 1);
     float f_one = 1.0f;
     ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));

@@ -997,6 +1002,9 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 
     const auto n_kv = inp->mctx->get_attn()->get_n_kv();
 
+    inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
+
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);

@@ -1198,8 +1206,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     const auto n_kv = mctx_cur->get_n_kv();
 
+    inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
+
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;

@@ -1230,8 +1240,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
    }
 
     const auto & kq_mask = inp->get_kq_mask();

@@ -1290,11 +1303,15 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // optionally store to KV cache
     if (k_cur) {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
     }
 
     if (v_cur) {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();

@@ -1398,8 +1415,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();

@@ -1434,8 +1454,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
+        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
+
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;

@@ -1446,8 +1468,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
+        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
+
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
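
On the host side, the new self_k_idxs/self_v_idxs inputs have to be filled once per ubatch before the graph runs. A hedged sketch of what set_input_k_idxs amounts to (the slot vector and helper name are assumptions; the real implementation reads from the slot_info chosen by find_slot):

#include <cstdint>
#include <vector>
#include "ggml-backend.h"

// copy the destination slot indices for the current ubatch into the
// I64 graph input that ggml_set_rows will consume
static void set_k_idxs_sketch(ggml_tensor * dst, const std::vector<int64_t> & slots) {
    GGML_ASSERT(dst->type == GGML_TYPE_I64);
    GGML_ASSERT((size_t) ggml_nelements(dst) == slots.size());
    ggml_backend_tensor_set(dst, slots.data(), 0, slots.size()*sizeof(int64_t));
}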

src/llama-graph.h

Lines changed: 23 additions & 1 deletion
@@ -249,8 +249,14 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+
     ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
 

@@ -274,9 +280,19 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_k_idxs()     const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs()     const { return self_v_idxs; }
+    ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
+    ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
+
     ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
+    ggml_tensor * self_k_idxs     = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch]
+    ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
+
     ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
     ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]

@@ -319,8 +335,14 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+
     ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
 

@@ -336,7 +358,7 @@ class llm_graph_input_one : public llm_graph_input_i {
     llm_graph_input_one() {}
     virtual ~llm_graph_input_one() = default;
 
-    void set_input(const llama_ubatch *) override;
+    void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * one = nullptr; // F32
 };
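
The self_*_idxs members declared above are ordinary graph inputs. A minimal sketch of how such an I64 [n_batch] input can be created, matching the declared shape (build_input_k_idxs itself is defined on the KV cache and is not shown in this diff; the helper name here is illustrative):

#include "ggml.h"

// allocate the per-ubatch index input; its contents are written later by
// set_input(), once find_slot has picked the destination cells
static ggml_tensor * build_k_idxs_sketch(ggml_context * ctx, int64_t n_tokens) {
    ggml_tensor * idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
    ggml_set_input(idxs);
    return idxs;
}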

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 16 additions & 16 deletions
@@ -113,20 +113,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads_base = kv_base->prepare(ubatches);
-        if (heads_base.empty()) {
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
             break;
         }
 
-        auto heads_swa = kv_swa->prepare(ubatches);
-        if (heads_swa.empty()) {
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
             break;
         }
 
-        assert(heads_base.size() == heads_swa.size());
+        assert(sinfos_base.size() == sinfos_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_context>(
-            this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+            this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // if it fails, try equal split

@@ -144,20 +144,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads_base = kv_base->prepare(ubatches);
-        if (heads_base.empty()) {
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
             break;
         }
 
-        auto heads_swa = kv_swa->prepare(ubatches);
-        if (heads_swa.empty()) {
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
             break;
         }
 
-        assert(heads_base.size() == heads_swa.size());
+        assert(sinfos_base.size() == sinfos_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_context>(
-            this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+            this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies

@@ -220,13 +220,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
 
 llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv,
-        std::vector<uint32_t> heads_base,
-        std::vector<uint32_t> heads_swa,
+        slot_info_vec_t sinfos_base,
+        slot_info_vec_t sinfos_swa,
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
-    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
+    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
+    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
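
The heads_* to sinfos_* rename reflects that prepare() now returns one slot_info per ubatch instead of a single head offset. A rough sketch of the shape such a structure could have (hypothetical; the real slot_info is defined in llama_kv_cache_unified): carrying a full list of cells is what lets ggml_set_rows place a ubatch in non-contiguous slots.

#include <cstdint>
#include <vector>

struct slot_info_sketch {
    std::vector<uint32_t> idxs; // one destination cell per token of the ubatch

    bool empty() const { return idxs.empty(); }
};

using slot_info_vec_sketch_t = std::vector<slot_info_sketch>; // one per ubatch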

src/llama-kv-cache-unified-iswa.h

Lines changed: 4 additions & 2 deletions
@@ -74,6 +74,8 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
 
 class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
 public:
+    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+
     // used for errors
     llama_kv_cache_unified_iswa_context(llama_memory_status status);
 

@@ -90,8 +92,8 @@ class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
     // used to create a batch processing context from a batch
     llama_kv_cache_unified_iswa_context(
             llama_kv_cache_unified_iswa * kv,
-            std::vector<uint32_t> heads_base,
-            std::vector<uint32_t> heads_swa,
+            slot_info_vec_t sinfos_base,
+            slot_info_vec_t sinfos_swa,
             std::vector<llama_ubatch> ubatches);
 
     virtual ~llama_kv_cache_unified_iswa_context();
