
Commit 88e3dee

perf: MLA decode kernel implemented by CuTe targeted to SM80 (#844)
Hi @yzh119, this is a follow-up to #766. An interesting idea came to mind today and I couldn't resist changing a few lines to verify it: we can use an asymmetric warp configuration to work around the register-file size limit. The fix is simply to use 8 warps for the output MMA stage and keep the other stages unchanged, because the limit applies to the register count per CUDA block rather than to the whole SM, and the 64K 32-bit registers per SM are enough to hold the f32 output of 64 heads. We now have 4 warps for the attention MMA stage, 2 warps for the softmax stage, 8 warps for the output MMA stage, and 4 warps for the data-load stage. The updated diagram:

![image](https://github.com/user-attachments/assets/2af8c5d9-d5a5-47e6-bd63-7e6b4305a529)

After the change the output MMA stage does more work, so the benchmark drops a little as expected, but it still looks good:

![image](https://github.com/user-attachments/assets/470ec576-ba91-4e71-9604-fcd6f0a9d691)

According to #814, this CuTe implementation performs slightly better than the current FA2 implementation:

![image](https://github.com/user-attachments/assets/9f61e2ff-4bb6-4581-a199-bb6176173192)

So I think the CuTe implementation still has value. Given its interesting scheduling design and better performance, we can treat it as an ad hoc implementation for the decode-only / 128 q-heads / SM80 case, and the JIT logic can accommodate this kernel.
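A rough back-of-the-envelope check of the register argument above, as a minimal sketch. The numbers rest on my own assumptions, not on the kernel source: head_dim_ckv = 512, a 64-head qo tile per CUDA block, and one f32 register per output element.

```python
# Back-of-the-envelope check of the register budget described above.
# Assumptions (mine, not taken from the kernel): head_dim_ckv = 512,
# a 64-head qo tile per CUDA block, one f32 register per output element.
HEAD_DIM_CKV = 512
QO_TILE_HEADS = 64
acc_regs = QO_TILE_HEADS * HEAD_DIM_CKV      # 32768 f32 accumulators

REGFILE_PER_SM = 64 * 1024                   # 64K 32-bit registers per SM on SM80
assert acc_regs <= REGFILE_PER_SM            # the whole output tile fits in one SM's register file

# Per-thread pressure of the output MMA stage:
print(acc_regs // (8 * 32))                  # 8 warps -> 128 regs/thread, under the 255-per-thread cap
print(acc_regs // (4 * 32))                  # 4 warps -> 256 regs/thread, already over the cap
```

Under these assumptions, going from 4 to 8 warps on the output MMA stage is exactly what brings the per-thread accumulator count back under the hardware limit.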
1 parent b19ad91 commit 88e3dee

File tree

7 files changed: +790 −14 lines


csrc/batch_decode_mla_config.jinja

+2
@@ -14,6 +14,8 @@ constexpr bool USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }};
 constexpr int HEAD_DIM_CKV = {{ head_dim_ckv }};
 constexpr int HEAD_DIM_KPE = {{ head_dim_kpe }};
 
+constexpr int QO_TILE_LEN = {{ qo_tile_len }};
+
 using Params = BatchDecodeParamsMLA<DTypeQ, DTypeKV, DTypeO, IdType>;
 using AttentionVariant =
     DefaultAttention</*use_custom_mask=*/false, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, /*use_alibi*/false>;
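For illustration, a minimal sketch of how a config like this is rendered into compile-time constants. The template string below is a trimmed stand-in, not the real batch_decode_mla_config.jinja, and the values (512 / 64 / 64) mirror the SM80 defaults chosen by the JIT logic further down.

```python
import jinja2

# Trimmed stand-in for batch_decode_mla_config.jinja (illustrative only).
config_templ = jinja2.Template(
    "constexpr int HEAD_DIM_CKV = {{ head_dim_ckv }};\n"
    "constexpr int HEAD_DIM_KPE = {{ head_dim_kpe }};\n"
    "constexpr int QO_TILE_LEN = {{ qo_tile_len }};\n"
)

# SM80 path: 64-head qo tile; SM90 would render qo_tile_len=128 instead.
print(config_templ.render(head_dim_ckv=512, head_dim_kpe=64, qo_tile_len=64))
```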

csrc/batch_decode_mla_cute_sm80.cu

+107
@@ -0,0 +1,107 @@ (new file)

#include <optional>

#include "pytorch_extension_utils.h"

#include "mla_config.inc"

#include <flashinfer/attention/decode_mla_cute_sm80.cuh>
#include <flashinfer/attention/scheduler.cuh>

using namespace flashinfer;

std::vector<int64_t> BatchDecodeWithPagedKVCachePlanMLA(
    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
    at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size,
    unsigned int num_qo_heads, unsigned int page_size, bool enable_cuda_graph,
    int64_t cuda_stream) {
  size_t float_workspace_size_in_bytes =
      float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
  size_t int_workspace_size_in_bytes =
      int_workspace_buffer.size(0) * int_workspace_buffer.element_size();

  DecodePlanInfo plan_info;
  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);

  auto work_estimation_func =
      BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMlaCuteSM80<HEAD_DIM_CKV, HEAD_DIM_KPE, QO_TILE_LEN,
                                                                     AttentionVariant, Params>;
  cudaError_t status =
      DecodePlan<HEAD_DIM_CKV, flashinfer::PosEncodingMode::kNone, AttentionVariant, Params>(
          static_cast<void*>(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes,
          static_cast<void*>(int_workspace_buffer.data_ptr()),
          static_cast<void*>(page_locked_int_workspace_buffer.data_ptr()),
          int_workspace_size_in_bytes, plan_info, static_cast<IdType*>(indptr.data_ptr()),
          batch_size, num_qo_heads, page_size, enable_cuda_graph, /*stream=*/stream,
          work_estimation_func);

  TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCachePlanMLA failed with error ",
              cudaGetErrorString(status));

  return plan_info.ToVector();
}

void BatchDecodeWithPagedKVCacheRunMLA(
    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
    std::vector<int64_t> plan_info_vec, at::Tensor q_nope, at::Tensor q_pe,
    at::Tensor paged_ckv_cache, at::Tensor paged_kpe_cache, at::Tensor paged_kv_indptr,
    at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, float sm_scale,
    int window_left, float logits_soft_cap, float rope_scale, float rope_theta,
    std::optional<at::Tensor> maybe_lse, int64_t cuda_stream) {
  DecodePlanInfo plan_info;
  plan_info.FromVector(plan_info_vec);

  auto device = q_nope.device();
  int64_t batch_size = q_nope.size(0);
  int64_t num_qo_heads = q_nope.size(1);
  int64_t page_size = paged_ckv_cache.size(1);

  if (maybe_lse) {
    const auto& lse = *maybe_lse;
    TORCH_CHECK(lse.size(0) == batch_size, lse.size(0), q_nope.size(0));
    TORCH_CHECK(lse.size(1) == num_qo_heads, lse.size(1), q_nope.size(1));
  }

  TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");

  void* float_buffer = static_cast<void*>(float_workspace_buffer.data_ptr());
  void* int_buffer = static_cast<void*>(int_workspace_buffer.data_ptr());

  paged_kv_mla_t<DTypeKV, IdType> paged_kv(
      page_size, HEAD_DIM_CKV, HEAD_DIM_KPE, batch_size,
      static_cast<DTypeKV*>(paged_ckv_cache.data_ptr()), paged_ckv_cache.strides().data(),
      static_cast<DTypeKV*>(paged_kpe_cache.data_ptr()), paged_kpe_cache.strides().data(),
      static_cast<IdType*>(paged_kv_indices.data_ptr()),
      static_cast<IdType*>(paged_kv_indptr.data_ptr()),
      static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
  Params params(static_cast<DTypeQ*>(q_nope.data_ptr()), static_cast<DTypeQ*>(q_pe.data_ptr()),
                /*q_offset=*/nullptr, paged_kv, static_cast<DTypeO*>(o.data_ptr()),
                /*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
                num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta);

  DTypeO* tmp_v = nullptr;
  float* tmp_s = nullptr;
  params.request_indices =
      GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.request_indices_offset);
  params.kv_tile_indices =
      GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.kv_tile_indices_offset);
  params.o_indptr = GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.o_indptr_offset);
  params.kv_chunk_size_ptr =
      GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.kv_chunk_size_ptr_offset);
  if (plan_info.split_kv) {
    tmp_v = GetPtrFromBaseOffset<DTypeO>(float_buffer, plan_info.v_offset);
    tmp_s = GetPtrFromBaseOffset<float>(float_buffer, plan_info.s_offset);
    if (plan_info.enable_cuda_graph) {
      params.block_valid_mask =
          GetPtrFromBaseOffset<bool>(int_buffer, plan_info.block_valid_mask_offset);
    }
  }
  params.padded_batch_size = plan_info.padded_batch_size;

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  cudaError_t status =
      BatchDecodeWithPagedKVCacheDispatchedMlaCuteSM80<HEAD_DIM_CKV, HEAD_DIM_KPE, QO_TILE_LEN,
                                                       Params>(params, tmp_v, tmp_s, /*stream=*/stream);
  TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
              cudaGetErrorString(status));
}
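As context for the paged_kv_indptr / paged_kv_indices / paged_kv_last_page_len arguments consumed by the run function above, a small self-contained sketch of the paged-KV bookkeeping for a toy batch. The CSR-style layout follows flashinfer's paged KV cache convention; the concrete page ids and lengths are made up.

```python
import torch

# Toy batch: request 0 holds 5 kv tokens, request 1 holds 3, with page_size = 4.
page_size = 4
kv_lens = [5, 3]
pages_per_req = [(n + page_size - 1) // page_size for n in kv_lens]   # [2, 1]

# indptr partitions the flat page-id list per request (CSR style).
paged_kv_indptr = torch.tensor([0, 2, 3], dtype=torch.int32)
# Page ids in the global page pool assigned to each request (arbitrary here).
paged_kv_indices = torch.tensor([7, 2, 5], dtype=torch.int32)
# Valid token count in each request's last page, in [1, page_size].
paged_kv_last_page_len = torch.tensor([1, 3], dtype=torch.int32)

for i, n in enumerate(kv_lens):
    pages = paged_kv_indices[paged_kv_indptr[i] : paged_kv_indptr[i + 1]]
    assert (len(pages) - 1) * page_size + paged_kv_last_page_len[i] == n
```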

flashinfer/decode.py

+13-1
@@ -1252,6 +1252,7 @@ def __init__(
         self,
         float_workspace_buffer: torch.Tensor,
         use_cuda_graph: bool = False,
+        use_tensor_cores: bool = False,
         paged_kv_indptr_buffer: Optional[torch.Tensor] = None,
         paged_kv_indices_buffer: Optional[torch.Tensor] = None,
         paged_kv_last_page_len_buffer: Optional[torch.Tensor] = None,

@@ -1269,7 +1270,11 @@ def __init__(
             Whether to enable CUDAGraph for batch decode attention, if enabled, the
             auxiliary data structures will be stored as the provided buffers. The ``batch_size``
             cannot change during the lifecycle of this wrapper when CUDAGraph is enabled.
-
+
+        use_tensor_cores : bool
+            Whether to use tensor cores for the computation. Will be faster for large group
+            size in grouped query attention. Defaults to ``False``.
+
         paged_kv_indptr_buffer : Optional[torch.Tensor]
             The user reserved buffer on GPU to store the indptr of the paged kv cache, the size
             of the buffer should be ``[batch_size + 1]``.

@@ -1319,6 +1324,7 @@ def __init__(
         else:
             self._fixed_batch_size = 0

+        self._use_tensor_cores = use_tensor_cores
         self._paged_kv_indptr_buf = paged_kv_indptr_buffer
         self._paged_kv_indices_buf = paged_kv_indices_buffer
         self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buffer

@@ -1328,6 +1334,10 @@
     def is_cuda_graph_enabled(self) -> bool:
         return self._use_cuda_graph

+    @property
+    def use_tensor_cores(self) -> bool:
+        return self._use_tensor_cores
+
     def reset_workspace_buffer(
         self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor
     ) -> None:

@@ -1445,8 +1455,10 @@ def plan(
             q_data_type,
             indptr.dtype,
             head_dim_compressed_kv,
+            num_qo_heads,
             window_left != -1,  # use_sliding_window
             logits_soft_cap > 0,  # use_logits_soft_cap
+            self._use_tensor_cores,
         )
         with self.device as device:
             self._plan_info = self._cached_module.plan(
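A hedged usage sketch for the new use_tensor_cores flag, assuming the MLA decode wrapper in flashinfer/decode.py is BatchDecodeMlaWithPagedKVCacheWrapper. The plan()/run() argument lists are not part of this diff, so they are only indicated in comments and should be checked against the wrapper's docstrings.

```python
import torch
import flashinfer

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

# Opt in to the tensor-core path; the JIT layer still falls back to the
# CUDA-core MLA kernel when the dtype / head-count / arch conditions are not met.
wrapper = flashinfer.BatchDecodeMlaWithPagedKVCacheWrapper(
    workspace, use_tensor_cores=True
)
assert wrapper.use_tensor_cores

# Typical flow (arguments elided; exact signatures are assumptions here):
#   wrapper.plan(kv_indptr, kv_indices, kv_last_page_len, num_qo_heads=128,
#                head_dim_compressed_kv=512, page_size=..., sm_scale=..., ...)
#   out = wrapper.run(q_nope, q_pe, paged_ckv_cache, paged_kpe_cache)
```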

flashinfer/jit/attention.py

+43-12
@@ -22,7 +22,7 @@
 import jinja2
 import torch

-from .core import load_cuda_ops, sm90a_nvcc_flags
+from .core import logger, load_cuda_ops, sm90a_nvcc_flags
 from .env import FLASHINFER_CSRC_DIR, FLASHINFER_GEN_SRC_DIR
 from .utils import (
     dtype_map,

@@ -216,20 +216,20 @@ def get_batch_decode_mla_uri(
     dtype_kv: torch.dtype,
     dtype_o: torch.dtype,
     dtype_idx: torch.dtype,
-    head_dim_qk: int,
-    head_dim_vo: int,
+    head_dim_ckv: int,
     use_sliding_window: bool,
     use_logits_soft_cap: bool,
+    arc: str,
 ) -> str:
     return (
         f"batch_decode_mla_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_"
         f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_"
         f"dtype_o_{filename_safe_dtype_map[dtype_o]}_"
         f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_"
-        f"head_dim_qk_{head_dim_qk}_"
-        f"head_dim_vo_{head_dim_vo}_"
+        f"head_dim_ckv{head_dim_ckv}_"
         f"use_swa_{use_sliding_window}_"
-        f"use_logits_cap_{use_logits_soft_cap}"
+        f"use_logits_cap_{use_logits_soft_cap}_"
+        f"arc_{arc}"
     )

@@ -239,18 +239,39 @@ def gen_batch_decode_mla_module(
     dtype_o: torch.dtype,
     dtype_idx: torch.dtype,
     head_dim: int,
+    num_qo_heads: int,
     use_sliding_window: bool,
     use_logits_soft_cap: bool,
+    use_tensor_cores: bool,
 ):
+    cuda_arch_major = torch.cuda.get_device_properties(0).major
+
+    if cuda_arch_major >= 9:  # smem size of SM90 can accommodate all 128 qo-heads data
+        qo_tile_len = 128
+    else:
+        qo_tile_len = 64
+
+    if (
+        use_tensor_cores and
+        cuda_arch_major >= 8 and num_qo_heads % qo_tile_len == 0 and
+        dtype_q == torch.float16 and dtype_kv == torch.float16 and
+        dtype_o == torch.float16
+    ):
+        logger.info(f"Use tensor-core SM80 version of MLA decode kernel.")
+        arc = "sm80"
+    else:
+        logger.info(f"Fall back to cuda-core version of MLA decode kernel.")
+        arc = "cuda_core"
+
     uri = get_batch_decode_mla_uri(
         dtype_q,
         dtype_kv,
         dtype_o,
         dtype_idx,
         head_dim,
-        head_dim,
         use_sliding_window,
         use_logits_soft_cap,
+        arc,
     )
     gen_directory = FLASHINFER_GEN_SRC_DIR / uri
     os.makedirs(gen_directory, exist_ok=True)

@@ -267,17 +288,27 @@
             dtype_idx=dtype_map[dtype_idx],
             head_dim_ckv=head_dim,
             head_dim_kpe=head_dim // 8,
+            qo_tile_len=qo_tile_len,
             use_sliding_window=str(use_sliding_window).lower(),
             use_logits_soft_cap=str(use_logits_soft_cap).lower(),
         ),
     )
+
+    filenames = []
+    if arc == "sm80":
+        filenames = [
+            "batch_decode_mla_cute_sm80.cu",
+            "batch_decode_mla_pybind.cu",
+        ]
+    else:
+        filenames = [
+            "batch_decode_mla_plan.cu",
+            "batch_decode_mla_run.cu",
+            "batch_decode_mla_pybind.cu",
+        ]

     source_paths = []
-    for filename in [
-        "batch_decode_mla_plan.cu",
-        "batch_decode_mla_run.cu",
-        "batch_decode_mla_pybind.cu",
-    ]:
+    for filename in filenames:
         src_path = FLASHINFER_CSRC_DIR / filename
         dest_path = gen_directory / filename
         source_paths.append(dest_path)
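To make the new dispatch easy to scan, here is a standalone paraphrase of the arc / qo_tile_len selection above as a pure function. It mirrors the condition in gen_batch_decode_mla_module but is illustrative only; the function name is not part of the library.

```python
import torch

def mla_decode_backend(num_qo_heads: int,
                       dtype_q: torch.dtype,
                       dtype_kv: torch.dtype,
                       dtype_o: torch.dtype,
                       use_tensor_cores: bool,
                       cuda_arch_major: int) -> str:
    """Illustrative paraphrase of the arc/qo_tile_len selection in the diff above."""
    # SM90 shared memory can hold all 128 qo-heads; older parts tile by 64.
    qo_tile_len = 128 if cuda_arch_major >= 9 else 64
    if (use_tensor_cores
            and cuda_arch_major >= 8
            and num_qo_heads % qo_tile_len == 0
            and dtype_q == dtype_kv == dtype_o == torch.float16):
        return "sm80"       # CuTe tensor-core kernel (batch_decode_mla_cute_sm80.cu)
    return "cuda_core"      # existing plan/run CUDA-core kernels

# DeepSeek-style MLA decode with 128 fp16 q-heads on SM80 hits the CuTe path:
assert mla_decode_backend(128, torch.float16, torch.float16, torch.float16, True, 8) == "sm80"
# With 64 heads on SM90, 64 % 128 != 0, so it falls back to the CUDA-core kernels:
assert mla_decode_backend(64, torch.float16, torch.float16, torch.float16, True, 9) == "cuda_core"
```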
