Commit 7cd099a

Author: baowending.bwd
Commit message: feat - support mla kvache store
1 parent: 3b07839

File tree: 8 files changed (+496, -0 lines)
New file (+133 lines)

@@ -0,0 +1,133 @@
import argparse
import dataclasses
from typing import Tuple, cast

import torch
from triton.testing import do_bench

import flashinfer


@dataclasses.dataclass(kw_only=True)
class ModelConfig:
    num_layers: int
    ckv_dim: int = 512
    kpe_dim: int = 64


MODELS = {
    "deepseek_r1": ModelConfig(num_layers=61),
    "deepseek_v2_lite": ModelConfig(num_layers=27),
}


@torch.inference_mode()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--seqlen", type=int, default=5000)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--page-len", type=int, default=16)
    parser.add_argument("--dtype", type=str, default="float16")
    args = parser.parse_args()

    seqlens_ = [
        [1] * args.batch_size,
        [args.seqlen - args.batch_size + 1] + [1] * (args.batch_size - 1),
        [args.seqlen],
        [args.seqlen // args.batch_size] * args.batch_size,
    ]
    seqlen_strlen = max(len(str(seqlens)) for seqlens in seqlens_)
    page_len = int(args.page_len)
    dtype = getattr(torch, args.dtype)
    assert isinstance(dtype, torch.dtype)
    device = torch.device("cuda:0")
    total_pages = int(25600 / page_len)

    torch.cuda.profiler.start()

    for model_name, model in MODELS.items():
        ckv_page_shape = (page_len, model.ckv_dim)
        kpe_page_shape = (page_len, model.kpe_dim)
        ckv_layer_buf = torch.empty(
            (total_pages,) + ckv_page_shape, dtype=dtype, device=device
        )
        kpe_layer_buf = torch.empty(
            (total_pages,) + kpe_page_shape, dtype=dtype, device=device
        )
        for seqlens in seqlens_:
            ckv = torch.rand(
                (sum(seqlens), model.ckv_dim),
                dtype=dtype,
                device=device,
            )
            kpe = torch.rand(
                (sum(seqlens), model.kpe_dim),
                dtype=dtype,
                device=device,
            )
            x_indptr = torch.tensor([0] + seqlens, device=device, dtype=torch.int32)
            x_indptr = torch.cumsum(x_indptr, 0, dtype=torch.int32)
            kv_indices_host = []
            kv_indptr_host = [0]
            next_page_id = 0
            for seqlen in seqlens:
                npages = (seqlen + page_len - 1) // page_len
                kv_indices_host.extend(range(next_page_id, next_page_id + npages))
                next_page_id += npages
                kv_indptr_host.append(len(kv_indices_host))
            kv_indices = torch.tensor(kv_indices_host, device=device, dtype=torch.int32)
            kv_indptr = torch.tensor(kv_indptr_host, device=device, dtype=torch.int32)
            kv_last_page_len = torch.tensor(
                [(seqlen - 1) % page_len + 1 for seqlen in seqlens],
                device=device,
                dtype=torch.int32,
            )

            @torch.cuda.nvtx.range(f"convert model={model_name}, seqlens={seqlens}")
            def fn_convert() -> Tuple[torch.Tensor, torch.Tensor]:
                return flashinfer.get_batch_indices_positions(
                    x_indptr,
                    flashinfer.get_seq_lens(kv_indptr, kv_last_page_len, page_len),
                    ckv.shape[0],
                )

            batch_indices, positions = fn_convert()
            convert_latency_ms = cast(float, do_bench(fn_convert))

            @torch.cuda.nvtx.range(f"append model={model_name}, seqlens={seqlens}")
            def fn() -> None:
                flashinfer.append_paged_mla_kv_cache(
                    ckv,
                    kpe,
                    batch_indices,
                    positions,
                    ckv_layer_buf,
                    kpe_layer_buf,
                    kv_indices,
                    kv_indptr,
                    kv_last_page_len,
                )

            latency_ms = cast(float, do_bench(fn))
            all_layers_latency_ms = convert_latency_ms + latency_ms * model.num_layers
            throughput = (
                (ckv.numel() + kpe.numel())
                * ckv.element_size()
                * sum(1 for _ in ["read", "write"])
                / (latency_ms * 1e-3)
            )
            print(
                f"model: {model_name:8}",
                f"seqlens: {seqlens!r:{seqlen_strlen}}",
                f"convert: {convert_latency_ms*1e3:2.0f}us",
                f"1layer: {latency_ms*1e3:2.0f}us",
                f"{model.num_layers}layers: {all_layers_latency_ms*1e3:3.0f}us",
                f"throughput: {throughput*1e-9:8.3f}GB/s",
            )
        print("---")

    torch.cuda.profiler.stop()


if __name__ == "__main__":
    main()
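
For reference, the kv_indptr / kv_indices / kv_last_page_len triple built in the benchmark's inner loop is a CSR-style page table: request i owns pages kv_indices[kv_indptr[i]:kv_indptr[i+1]], and kv_last_page_len[i] records how many slots of its last page are occupied. A minimal CPU-only sketch with toy values (page_len=16 and seqlens=[5, 40] are illustrative, not the benchmark defaults):

page_len = 16
seqlens = [5, 40]

kv_indices, kv_indptr, next_page_id = [], [0], 0
for seqlen in seqlens:
    npages = (seqlen + page_len - 1) // page_len  # ceiling division: pages per request
    kv_indices.extend(range(next_page_id, next_page_id + npages))
    next_page_id += npages
    kv_indptr.append(len(kv_indices))
kv_last_page_len = [(s - 1) % page_len + 1 for s in seqlens]

print(kv_indices)        # [0, 1, 2, 3]  -> request 0 owns page 0, request 1 owns pages 1-3
print(kv_indptr)         # [0, 1, 4]     -> request i's pages are kv_indices[kv_indptr[i]:kv_indptr[i+1]]
print(kv_last_page_len)  # [5, 8]        -> 5 of 16 slots used in page 0; 8 of 16 in page 3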

csrc/flashinfer_ops.cu (+7 lines)
@@ -85,6 +85,11 @@ void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::T
                            at::Tensor kv_indices, at::Tensor kv_indptr, at::Tensor kv_last_page_len,
                            int64_t layout, int64_t cuda_stream);
 
+void append_paged_mla_kv_cache(at::Tensor append_ckv, at::Tensor append_kpe,
+                               at::Tensor batch_indices, at::Tensor positions, at::Tensor ckv_cache,
+                               at::Tensor kpe_cache, at::Tensor kv_indices, at::Tensor kv_indptr,
+                               at::Tensor kv_last_page_len, int64_t cuda_stream);
+
 void block_sparse_indices_to_vector_sparse_offsets(at::Tensor block_sparse_indices,
                                                    at::Tensor block_sparse_indptr,
                                                    at::Tensor vector_sparse_offsets,
@@ -246,6 +251,8 @@ TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
   // page
   // Append paged KV-Cache operator
   m.def("append_paged_kv_cache", append_paged_kv_cache);
+  // Append paged MLA KV-Cache operator
+  m.def("append_paged_mla_kv_cache", append_paged_mla_kv_cache);
   // Precompute block sparse offsets
   m.def("block_sparse_indices_to_vector_sparse_offsets",
         block_sparse_indices_to_vector_sparse_offsets);

csrc/flashinfer_page_ops.cu (+7 lines)
@@ -20,6 +20,11 @@ void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::T
                            at::Tensor kv_indices, at::Tensor kv_indptr, at::Tensor kv_last_page_len,
                            int64_t layout, int64_t cuda_stream);
 
+void append_paged_mla_kv_cache(at::Tensor append_ckv, at::Tensor append_kpe,
+                               at::Tensor batch_indices, at::Tensor positions, at::Tensor ckv_cache,
+                               at::Tensor kpe_cache, at::Tensor kv_indices, at::Tensor kv_indptr,
+                               at::Tensor kv_last_page_len, int64_t cuda_stream);
+
 void block_sparse_indices_to_vector_sparse_offsets(at::Tensor block_sparse_indices,
                                                    at::Tensor block_sparse_indptr,
                                                    at::Tensor vector_sparse_offsets,
@@ -31,6 +36,8 @@ void block_sparse_indices_to_vector_sparse_offsets(at::Tensor block_sparse_indic
 TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
   // "Append paged KV-Cache operator"
   m.def("append_paged_kv_cache", append_paged_kv_cache);
+  // "Append paged MLA KV-Cache operator"
+  m.def("append_paged_mla_kv_cache", append_paged_mla_kv_cache);
   // "Precompute block sparse offsets"
   m.def("block_sparse_indices_to_vector_sparse_offsets",
         block_sparse_indices_to_vector_sparse_offsets);

csrc/page.cu (+77 lines)
@@ -137,3 +137,80 @@ void block_sparse_indices_to_vector_sparse_offsets(at::Tensor block_sparse_indic
   TORCH_CHECK(status == cudaSuccess, "BlockSparseIndicesToVectorSparseOffset failed with error: ",
               cudaGetErrorString(status));
 }
+
+void append_paged_mla_kv_cache(at::Tensor append_ckv, at::Tensor append_kpe,
+                               at::Tensor batch_indices, at::Tensor positions, at::Tensor ckv_cache,
+                               at::Tensor kpe_cache, at::Tensor kv_indices, at::Tensor kv_indptr,
+                               at::Tensor kv_last_page_len, int64_t cuda_stream) {
+  CHECK_LAST_DIM_CONTIGUOUS(append_ckv);
+  CHECK_LAST_DIM_CONTIGUOUS(append_kpe);
+  CHECK_INPUT(batch_indices);
+  CHECK_INPUT(positions);
+  // NOTE(Zihao): doesn't have to be contiguous
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(ckv_cache);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(kpe_cache);
+  CHECK_INPUT(kv_indices);
+  CHECK_INPUT(kv_indptr);
+  CHECK_INPUT(kv_last_page_len);
+  CHECK_DIM(2, append_ckv);
+  CHECK_DIM(2, append_kpe);
+  CHECK_DIM(1, batch_indices);
+  CHECK_DIM(1, positions);
+  CHECK_DIM(3, ckv_cache);
+  CHECK_DIM(3, kpe_cache);
+  CHECK_DIM(1, kv_indices);
+  CHECK_DIM(1, kv_indptr);
+  CHECK_DIM(1, kv_last_page_len);
+  unsigned int nnz = append_ckv.size(0);
+  unsigned int batch_size = kv_last_page_len.size(0);
+  CHECK_EQ(kv_indptr.size(0), batch_size + 1);
+  CHECK_EQ(batch_indices.size(0), nnz);
+  CHECK_EQ(positions.size(0), nnz);
+  auto device = append_ckv.device();
+  CHECK_EQ(append_ckv.device(), device);
+  CHECK_EQ(append_kpe.device(), device);
+  CHECK_EQ(ckv_cache.device(), device);
+
+  CHECK_EQ(kv_indices.device(), device);
+  CHECK_EQ(kv_indptr.device(), device);
+  CHECK_EQ(kv_last_page_len.device(), device);
+
+  unsigned int page_size, ckv_dim, kpe_dim;
+  page_size = ckv_cache.size(1);
+  ckv_dim = ckv_cache.size(2);
+  kpe_dim = kpe_cache.size(2);
+
+  // get kv_cache_strides
+  const int64_t* ckv_strides = ckv_cache.strides().data();
+  const int64_t* kpe_strides = kpe_cache.strides().data();
+
+  auto append_ckv_strides = append_ckv.strides();
+  auto append_ckv_stride_n = append_ckv_strides[0];
+  auto append_kpe_strides = append_kpe.strides();
+  auto append_kpe_stride_n = append_kpe_strides[0];
+
+  CHECK_EQ(append_ckv.size(1), ckv_dim);
+  CHECK_EQ(append_kpe.size(1), kpe_dim);
+
+  auto kv_scalar_dtype = ckv_cache.scalar_type();
+
+  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(kv_scalar_dtype, c_type, [&] {
+    paged_kv_mla_t<c_type, int32_t> paged_mla_kv(
+        page_size, ckv_dim, kpe_dim, batch_size, static_cast<c_type*>(ckv_cache.data_ptr()),
+        ckv_strides, static_cast<c_type*>(kpe_cache.data_ptr()), kpe_strides,
+        static_cast<int32_t*>(kv_indices.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
+        static_cast<int32_t*>(kv_last_page_len.data_ptr()));
+    cudaError_t status =
+        AppendPagedKVMlaCache(paged_mla_kv, static_cast<c_type*>(append_ckv.data_ptr()),
+                              static_cast<c_type*>(append_kpe.data_ptr()),
+                              static_cast<int32_t*>(batch_indices.data_ptr()),
+                              static_cast<int32_t*>(positions.data_ptr()), nnz, append_ckv_stride_n,
+                              append_kpe_stride_n, stream);
+    TORCH_CHECK(status == cudaSuccess,
+                "AppendPagedKVMlaCache failed with error: ", cudaGetErrorString(status));
+    return true;
+  });
+
+  TORCH_CHECK(success, "AppendPagedKVMlaCache failed to dispatch with dtype ", kv_scalar_dtype);
+}
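
The NOTE in the checks above (caches need only last-dim contiguity, with explicit strides handed to the kernel) matters when per-layer caches are views into one combined buffer. A hedged sketch of one such layout (the sizes and the pages-then-layers arrangement are illustrative, not mandated by this commit):

import torch

num_layers, total_pages, page_size, ckv_dim = 61, 1600, 16, 512
all_layers = torch.empty(total_pages, num_layers, page_size, ckv_dim, device="cuda")
ckv_cache_layer0 = all_layers[:, 0]  # view of shape [total_pages, page_size, ckv_dim]

print(ckv_cache_layer0.is_contiguous())  # False: pages are strided across layers
print(ckv_cache_layer0.stride(-1))       # 1: last dim stays contiguous, so the checks pass
print(ckv_cache_layer0.stride())         # these strides are what the C++ code forwards to the kernel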

flashinfer/__init__.py (+1 line)
@@ -45,6 +45,7 @@
 from .norm import gemma_rmsnorm as gemma_rmsnorm
 from .norm import rmsnorm as rmsnorm
 from .page import append_paged_kv_cache as append_paged_kv_cache
+from .page import append_paged_mla_kv_cache as append_paged_mla_kv_cache
 from .page import get_batch_indices_positions as get_batch_indices_positions
 from .page import get_seq_lens as get_seq_lens
 from .prefill import (

flashinfer/page.py (+84 lines)
@@ -87,6 +87,41 @@ def block_sparse_indices_to_vector_sparse_offsets(
     return vector_sparse_offsets
 
 
+@register_custom_op(
+    "flashinfer::append_paged_mla_kv_cache",
+    mutates_args=("ckv_cache", "kpe_cache"),
+)
+def _append_paged_mla_kv_cache_kernel(
+    append_ckv: torch.Tensor,
+    append_kpe: torch.Tensor,
+    batch_indices: torch.Tensor,
+    positions: torch.Tensor,
+    ckv_cache: Optional[torch.Tensor],
+    kpe_cache: Optional[torch.Tensor],
+    kv_indices: torch.Tensor,
+    kv_indptr: torch.Tensor,
+    kv_last_page_len: torch.Tensor,
+) -> None:
+    with append_ckv.device as device:
+        batch_indices = batch_indices.int()
+        positions = positions.int()
+        kv_indices = kv_indices.int()
+        kv_indptr = kv_indptr.int()
+        kv_last_page_len = kv_last_page_len.int()
+        get_page_module().append_paged_mla_kv_cache(
+            append_ckv,
+            append_kpe,
+            batch_indices,
+            positions,
+            ckv_cache,
+            kpe_cache,
+            kv_indices,
+            kv_indptr,
+            kv_last_page_len,
+            get_cuda_stream(device),
+        )
+
+
 @register_custom_op(
     "flashinfer::append_paged_kv_cache",
     mutates_args=("paged_k_cache", "paged_v_cache"),
@@ -221,6 +256,55 @@ def get_seq_lens(
     )
 
 
+def append_paged_mla_kv_cache(
+    append_ckv: torch.Tensor,
+    append_kpe: torch.Tensor,
+    batch_indices: torch.Tensor,
+    positions: torch.Tensor,
+    ckv_cache: Optional[torch.Tensor],
+    kpe_cache: Optional[torch.Tensor],
+    kv_indices: torch.Tensor,
+    kv_indptr: torch.Tensor,
+    kv_last_page_len: torch.Tensor,
+) -> None:
+    r"""Append a batch of compressed key-value pairs to a paged MLA KV cache.
+    Note: currently only ``ckv_dim=512`` and ``kpe_dim=64`` are supported.
+
+    Parameters
+    ----------
+    append_ckv : torch.Tensor
+        The compressed KV tensor to append in ragged tensor format, shape:
+        ``[append_indptr[-1], ckv_dim]``.
+    append_kpe : torch.Tensor
+        The key positional-embedding tensor to append in ragged tensor format, shape:
+        ``[append_indptr[-1], kpe_dim]``.
+    batch_indices : torch.Tensor
+        The batch index of each entry in the appended key-value pairs, shape: ``[append_indptr[-1]]``.
+    positions : torch.Tensor
+        The position of each entry in the appended key-value pairs, shape: ``[append_indptr[-1]]``.
+    ckv_cache : torch.Tensor, cache for compressed KV, shape: ``[page_num, page_size, ckv_dim]``
+    kpe_cache : torch.Tensor, cache for key positional embeddings, shape: ``[page_num, page_size, kpe_dim]``
+    kv_indices : torch.Tensor
+        The page indices of the paged kv-cache, shape: ``[kv_indptr[-1]]``.
+    kv_indptr : torch.Tensor
+        The indptr of the paged kv-cache, shape: ``[batch_size + 1]``.
+    kv_last_page_len : torch.Tensor
+        The number of entries in the last page of each request in the paged kv cache,
+        shape: ``[batch_size]``.
+    """
+    _append_paged_mla_kv_cache_kernel(
+        append_ckv,
+        append_kpe,
+        batch_indices,
+        positions,
+        ckv_cache,
+        kpe_cache,
+        kv_indices,
+        kv_indptr,
+        kv_last_page_len,
+    )
+
+
 def append_paged_kv_cache(
     append_key: torch.Tensor,
     append_value: torch.Tensor,
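
Putting the pieces together, a hedged end-to-end usage sketch of the new Python API (shapes and seqlens are illustrative; the index math mirrors the benchmark script at the top of this commit):

import torch
import flashinfer

device = torch.device("cuda:0")
page_len, ckv_dim, kpe_dim = 16, 512, 64  # the only dims the docstring says are supported
seqlens = [5, 40]                         # tokens to append per request (toy values)
nnz = sum(seqlens)

ckv = torch.rand(nnz, ckv_dim, dtype=torch.float16, device=device)
kpe = torch.rand(nnz, kpe_dim, dtype=torch.float16, device=device)
ckv_cache = torch.empty(64, page_len, ckv_dim, dtype=torch.float16, device=device)
kpe_cache = torch.empty(64, page_len, kpe_dim, dtype=torch.float16, device=device)

# Page-table metadata for seqlens=[5, 40] (see the worked example above).
kv_indptr = torch.tensor([0, 1, 4], dtype=torch.int32, device=device)
kv_indices = torch.tensor([0, 1, 2, 3], dtype=torch.int32, device=device)
kv_last_page_len = torch.tensor([5, 8], dtype=torch.int32, device=device)
x_indptr = torch.tensor([0, 5, 45], dtype=torch.int32, device=device)

batch_indices, positions = flashinfer.get_batch_indices_positions(
    x_indptr, flashinfer.get_seq_lens(kv_indptr, kv_last_page_len, page_len), nnz
)
flashinfer.append_paged_mla_kv_cache(
    ckv, kpe, batch_indices, positions,
    ckv_cache, kpe_cache, kv_indices, kv_indptr, kv_last_page_len,
)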
