
Commit 26ebac7

[NVIDIA] Add Cutlass MLA backend (#1031)
This PR adds a `cutlass` backend to the flashinfer `BatchMLAPagedAttentionWrapper`. cc @yzh119 @kushanam
1 parent cb5462d commit 26ebac7
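A minimal usage sketch of the new backend (illustrative, not part of this commit): the wrapper is constructed with `backend="cutlass"`, the cutlass branch of `run()` does not consult any `plan()` state, and `kv_len`/`page_table` are passed directly to `run()`. The shapes follow the checks added in `flashinfer/mla.py` (128 heads, 512 latent + 64 rope head dims, fp16/bf16 inputs, int32 `kv_len`/`page_table`); the workspace size and the concrete batch/page sizes are placeholders, and an SM100 (`sm_100a`) GPU with CUDA 12.9+ is assumed.

    import torch
    from flashinfer.mla import BatchMLAPagedAttentionWrapper

    # Placeholder sizes; 128 heads and 512 + 64 head dims are required by the
    # shape checks this PR adds. pages_per_seq must be a multiple of 128 / page_size.
    batch_size, num_heads, page_size, pages_per_seq = 4, 128, 64, 2
    num_pages = batch_size * pages_per_seq

    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
    mla = BatchMLAPagedAttentionWrapper(workspace, backend="cutlass")

    q_nope = torch.randn(batch_size, num_heads, 512, dtype=torch.half, device="cuda")
    q_pe = torch.randn(batch_size, num_heads, 64, dtype=torch.half, device="cuda")
    ckv_cache = torch.randn(num_pages, page_size, 512, dtype=torch.half, device="cuda")
    kpe_cache = torch.randn(num_pages, page_size, 64, dtype=torch.half, device="cuda")
    kv_len = torch.full((batch_size,), 100, dtype=torch.int32, device="cuda")
    page_table = torch.arange(num_pages, dtype=torch.int32, device="cuda").view(
        batch_size, pages_per_seq
    )

    # The cutlass branch concatenates q_nope/q_pe and ckv/kpe internally and
    # calls the cutlass_mla_paged_attention op; the output has q_nope's shape.
    out = mla.run(q_nope, q_pe, ckv_cache, kpe_cache, kv_len=kv_len, page_table=page_table)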

File tree

6 files changed: +417, -3 lines


csrc/cutlass_mla.cu (+45)
@@ -0,0 +1,45 @@
/*
 * Copyright (c) 2024 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <flashinfer/attention/cutlass_mla.cuh>

#include "pytorch_extension_utils.h"

using namespace flashinfer;
using namespace flashinfer::attention;

void CutlassMLAPagedAttention(at::Tensor workspace, at::Tensor out, at::Tensor lse,
                              at::Tensor q_nope_pe, at::Tensor ckv_kpe_cache, at::Tensor kv_lens,
                              at::Tensor page_table) {
  const c10::cuda::OptionalCUDAGuard device_guard(q_nope_pe.device());
  auto stream = at::cuda::getCurrentCUDAStream();

  int device_index = q_nope_pe.device().index();
  int batches = q_nope_pe.sizes()[0];
  int page_count_per_seq = page_table.sizes()[1];
  int page_count_total = ckv_kpe_cache.sizes()[0];
  int page_size = ckv_kpe_cache.sizes()[1];

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_nope_pe.scalar_type(), c_type, [&] {
    using cutlass_t = cutlass_dtype_t<c_type>;
    auto status = runMla<cutlass_t>(
        workspace.data_ptr(), out.data_ptr(), lse.data_ptr(), q_nope_pe.data_ptr(),
        ckv_kpe_cache.data_ptr(), kv_lens.data_ptr(), page_table.data_ptr(), batches,
        page_count_per_seq, page_count_total, page_size, device_index, stream);
    TORCH_CHECK(status == cudaSuccess,
                "Failed to run CutlassMLAPagedAttention: ", cudaGetErrorString(status));
    return true;
  });
}
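Note that the binding derives all problem geometry from the tensor shapes rather than taking explicit size arguments. For reference, a hypothetical Python mirror of that bookkeeping (not part of the commit):

    def cutlass_mla_input_geometry(q_nope_pe, ckv_kpe_cache, kv_lens, page_table):
        """Mirror the sizes CutlassMLAPagedAttention reads via .sizes() in C++."""
        batches = q_nope_pe.shape[0]               # q_nope_pe:     [batches, 128, 576]
        page_count_per_seq = page_table.shape[1]   # page_table:    [batches, pages_per_seq]
        page_count_total = ckv_kpe_cache.shape[0]  # ckv_kpe_cache: [pages, page_size, 576]
        page_size = ckv_kpe_cache.shape[1]
        return batches, page_count_per_seq, page_count_total, page_size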

csrc/flashinfer_mla_ops.cu (+25)
@@ -0,0 +1,25 @@
/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pytorch_extension_utils.h"

void CutlassMLAPagedAttention(at::Tensor workspace, at::Tensor out, at::Tensor lse,
                              at::Tensor q_nope_pe, at::Tensor ckv_kpe_cache, at::Tensor kv_lens,
                              at::Tensor page_table);

TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
  // "Cutlass MLA Paged Attention"
  m.def("cutlass_mla_paged_attention", CutlassMLAPagedAttention);
}

flashinfer/mla.py (+96, -2)
@@ -20,7 +20,8 @@
 
 import torch
 
-from .jit import gen_batch_mla_module
+from .jit import FLASHINFER_CSRC_DIR, gen_batch_mla_module, load_cuda_ops
+from .jit.env import CUTLASS_INCLUDE_DIRS as CUTLASS_INCLUDE_DIRS
 from .utils import (
     MaskMode,
     _check_shape_dtype_device,
@@ -29,6 +30,57 @@
     register_fake_op,
 )
 
+
+def _check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table):
+    if q_nope_pe.ndim != 3:
+        raise ValueError(f"Expected q_nope_pe.ndim == 3, got {q_nope_pe.ndim}")
+    if ckv_kpe_cache.ndim != 3:
+        raise ValueError(f"Expected ckv_kpe_cache.ndim == 3, got {ckv_kpe_cache.ndim}")
+    if kv_len.ndim != 1:
+        raise ValueError(f"Expected kv_len.ndim == 1, got {kv_len.ndim}")
+    if page_table.ndim != 2:
+        raise ValueError(f"Expected page_table.ndim == 2, got {page_table.ndim}")
+    B_q, H, D_q = q_nope_pe.shape
+    D_ckv = ckv_kpe_cache.shape[2]
+    if H != 128:
+        raise ValueError(f"Expected 128 heads for q_nope_pe, got {H}")
+    if D_q != D_ckv or D_q != 576:
+        raise ValueError(
+            f"Expected head dim 576 for q_nope_pe and ckv_kpe_cache, got {D_q} and {D_ckv}"
+        )
+    B_block_table, block_num = page_table.shape
+    block_size = ckv_kpe_cache.shape[1]
+    if B_q != B_block_table:
+        raise ValueError(
+            f"Expected batch size {B_q} for q_nope_pe and block_table, got {B_q} and {B_block_table}"
+        )
+    if block_num % (128 / block_size) != 0:
+        raise ValueError(
+            f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}"
+        )
+
+
+_mla_module = None
+
+
+def get_mla_module():
+    global _mla_module
+    if _mla_module is None:
+        _mla_module = load_cuda_ops(
+            "mla",
+            [
+                FLASHINFER_CSRC_DIR / "cutlass_mla.cu",
+                FLASHINFER_CSRC_DIR / "flashinfer_mla_ops.cu",
+            ],
+            extra_include_paths=[
+                CUTLASS_INCLUDE_DIRS[0] / ".." / "examples" / "77_blackwell_fmha",
+                CUTLASS_INCLUDE_DIRS[0] / ".." / "examples" / "common",
+            ],
+            extra_cuda_cflags=["-gencode", "arch=compute_100a,code=sm_100a"],
+        )
+    return _mla_module
+
+
 _batch_mla_modules = {}
 _batch_mla_sm90_modules = {}
 
@@ -152,10 +204,17 @@ def __init__(
         backend : str
            The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``.
            If set to ``auto``, the function will automatically choose the backend based on the
-           device architecture and kernel availability.
+           device architecture and kernel availability. If set to ``cutlass``, the MLA kernels
+           are generated by CUTLASS; only ``float_workspace_buffer`` is required and the other
+           arguments are ignored.
         """
         self._float_workspace_buffer = float_workspace_buffer
         self.device = float_workspace_buffer.device
+
+        if backend == "cutlass":
+            self._backend = backend
+            return
+
         self._int_workspace_buffer = torch.empty(
             (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
         )
@@ -294,6 +353,8 @@ def run(
         lse: Optional[torch.Tensor] = None,
         return_lse: bool = False,
         profiler_buffer: Optional[torch.Tensor] = None,
+        kv_len: Optional[torch.Tensor] = None,
+        page_table: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         r"""Run the MLA attention computation.
 
@@ -317,7 +378,40 @@
            Whether to return the log-sum-exp value, default is False.
         profiler_buffer : Optional[torch.Tensor]
            The buffer to store the profiler data.
+        kv_len : Optional[torch.Tensor]
+           The kv-cache length of each request, shape: ``[batch_size]``. Required when ``backend`` is ``cutlass``.
+        page_table : Optional[torch.Tensor]
+           The page table of the paged kv-cache, shape: ``[batch_size, num_pages]``. Required when ``backend`` is ``cutlass``.
         """
+        if self._backend == "cutlass":
+            if return_lse:
+                raise ValueError("return_lse does not support cutlass backend for now.")
+            if profiler_buffer is not None:
+                raise ValueError(
+                    "profiler_buffer does not support cutlass backend for now."
+                )
+            self._cached_module = get_mla_module()
+            if out is None:
+                out = torch.empty_like(q_nope)
+            else:
+                _check_shape_dtype_device(
+                    out, q_nope.shape, q_nope.dtype, q_nope.device, "out"
+                )
+            q_nope_pe = torch.cat([q_nope, q_pe], dim=-1)
+            ckv_kpe_cache = torch.cat([ckv_cache, kpe_cache], dim=-1)
+            _check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table)
+            lse = torch.empty(0, dtype=torch.float32, device=self.device)
+            self._cached_module.cutlass_mla_paged_attention.default(
+                self._float_workspace_buffer,
+                out,
+                lse,
+                q_nope_pe,
+                ckv_kpe_cache,
+                kv_len,
+                page_table,
+            )
+            return out
+
         if profiler_buffer is None:
             if self._use_profiler:
                 raise ValueError(
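`_check_cutlass_shape` requires the per-sequence page count (`block_num`) to be a multiple of `128 / page_size`, so page tables built for other backends may need padding before the cutlass path accepts them. A hypothetical helper, assuming `page_size` divides 128 and that padding slots past each request's `kv_len` are never read:

    import torch


    def pad_page_table(page_ids, page_size, device="cuda"):
        """Build an int32 [batch_size, num_pages] page table whose width is a
        multiple of 128 // page_size; unused slots are zero-filled and assumed
        to be ignored beyond each request's kv_len."""
        multiple = 128 // page_size
        width = max(len(ids) for ids in page_ids)
        width = ((width + multiple - 1) // multiple) * multiple  # round up
        table = torch.zeros(len(page_ids), width, dtype=torch.int32, device=device)
        for i, ids in enumerate(page_ids):
            table[i, : len(ids)] = torch.tensor(ids, dtype=torch.int32)
        return table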

flashinfer/utils.py (+5)
@@ -359,6 +359,11 @@ def is_sm90a_supported(device: torch.device) -> bool:
     return major == 9 and torch.version.cuda >= "12.3"
 
 
+def is_sm100a_supported(device: torch.device) -> bool:
+    major, minor = get_compute_capability(device)
+    return major == 10 and minor == 0 and torch.version.cuda >= "12.9"
+
+
 def determine_mla_backend(device: torch.device) -> str:
     return "fa3" if is_sm90a_supported(device) else "fa2"
 
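The new helper is not wired into `determine_mla_backend`, which still returns only `fa2`/`fa3`, so the cutlass backend remains an explicit opt-in. A caller-side guard might look like this (illustrative):

    import torch
    from flashinfer.utils import is_sm100a_supported

    device = torch.device("cuda:0")
    backend = "cutlass" if is_sm100a_supported(device) else "auto"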

include/flashinfer/attention/cutlass_mla.cuh (+163)
@@ -0,0 +1,163 @@
/*
 * Copyright (c) 2024 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FLASHINFER_ATTENTION_CUTLASS_MLA_CUH_
#define FLASHINFER_ATTENTION_CUTLASS_MLA_CUH_
#include <sstream>

#include "../cutlass_utils.cuh"
#include "../exception.h"
#include "cutlass/kernel_hardware_info.h"

// From 3rdparty/cutlass/examples/77_blackwell_fmha
#include "device/sm100_mla.hpp"
#include "kernel/sm100_mla_tile_scheduler.hpp"

namespace flashinfer {

namespace attention {

using namespace cute;
using namespace cutlass::fmha::kernel;

#define CUTLASS_CHECK(cmd)                                                            \
  do {                                                                                \
    auto status = cmd;                                                                \
    if (status != cutlass::Status::kSuccess) {                                        \
      std::ostringstream err_msg;                                                     \
      err_msg << "cutlass " << #cmd << " failed: " << cutlassGetStatusString(status); \
      FLASHINFER_ERROR(err_msg.str());                                                \
    }                                                                                 \
  } while (0)

template <bool v>
struct IsPersistent {
  static const bool value = v;
};

template <typename T, typename PersistenceOption = IsPersistent<true>>
struct MlaSm100 {
  using Element = T;
  using ElementAcc = float;
  using ElementOut = T;

  using TileShape = Shape<_128, _128, Shape<_512, _64>>;
  using TileShapeH = cute::tuple_element_t<0, TileShape>;
  using TileShapeD = cute::tuple_element_t<2, TileShape>;

  // H K (D_latent D_rope) B
  using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;

  using StrideQ = cute::tuple<int64_t, _1, int64_t>;  // H D B
  using StrideK = cute::tuple<int64_t, _1, int64_t>;  // K D B
  using StrideO = StrideK;                            // H D B
  using StrideLSE = cute::tuple<_1, int>;             // H B

  using TileScheduler =
      std::conditional_t<PersistenceOption::value, Sm100MlaPersistentTileScheduler,
                         Sm100MlaIndividualTileScheduler>;

  using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
      TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler, /*kIsCpAsync=*/true>;
  using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
};

template <typename T>
typename T::Fmha::Arguments args_from_options(void* out_ptr, void* lse_ptr, void* q_absorbed_ptr,
                                              void* ckv_kpe_cache_ptr, void* seq_lens_ptr,
                                              void* page_table_ptr, int batches,
                                              int page_count_per_seq, int page_count_total,
                                              int page_size, int device_index) {
  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = device_index;
  hw_info.sm_count =
      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

  int max_seq_len = page_size * page_count_per_seq;
  using TileShapeH = typename T::TileShapeH;
  using TileShapeD = typename T::TileShapeD;
  auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);

  auto [H, K, D, B] = problem_shape;
  auto [D_latent, D_rope] = D;

  // the scale is based on the non-absorbed sizes, change as appropriate
  // we can't determine this parameter from the info we have, it's an input
  int D_non_latent = 128;
  float scale = 1.0 / sqrt(1.0 * (D_non_latent + D_rope));

  using StrideQ = typename T::StrideQ;
  using StrideK = typename T::StrideK;
  using StrideO = typename T::StrideO;
  using StrideLSE = typename T::StrideLSE;

  StrideQ stride_Q = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{},
                                      static_cast<int64_t>(H * (0 + D_latent + D_rope)));
  StrideK stride_C = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{},
                                      static_cast<int64_t>(page_size * (D_latent + D_rope)));
  StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
  StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
  StrideO stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{},
                                      static_cast<int64_t>(0 + H * D_latent));

  using Element = typename T::Element;
  using ElementOut = typename T::ElementOut;
  using ElementAcc = typename T::ElementAcc;
  auto Q_ptr = reinterpret_cast<Element*>(q_absorbed_ptr);
  auto C_ptr = reinterpret_cast<Element*>(ckv_kpe_cache_ptr);
  typename T::Fmha::Arguments arguments{
      problem_shape,
      {scale, Q_ptr, stride_Q, Q_ptr + D_latent, stride_Q, C_ptr, stride_C, C_ptr + D_latent,
       stride_C, reinterpret_cast<int*>(seq_lens_ptr), reinterpret_cast<int*>(page_table_ptr),
       stride_PT, page_count_total, page_size},
      {reinterpret_cast<ElementOut*>(out_ptr), stride_O,
       // static_cast<ElementAcc*>(lse.data_ptr()), stride_LSE},
       static_cast<ElementAcc*>(nullptr), stride_LSE},
      hw_info,
      -1,       // split_kv
      nullptr,  // is_var_split_kv=false
  };
  // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
  // split_kv automatically based on batch size and sequence length to balance
  // workload across available SMs. Consider using var_split_kv for manual
  // control if needed.
  T::Fmha::set_split_kv(arguments);
  return arguments;
}

template <typename Element>
cudaError_t runMla(void* workspace_ptr, void* out_ptr, void* lse_ptr, void* q_absorbed_ptr,
                   void* ckv_kpe_cache_ptr, void* seq_lens_ptr, void* page_table_ptr, int batches,
                   int page_count_per_seq, int page_count_total, int page_size, int device_index,
                   cudaStream_t stream) {
  using MlaSm100Type = MlaSm100<Element>;
  typename MlaSm100Type::Fmha fmha;
  auto arguments = args_from_options<MlaSm100Type>(
      out_ptr, lse_ptr, q_absorbed_ptr, ckv_kpe_cache_ptr, seq_lens_ptr, page_table_ptr, batches,
      page_count_per_seq, page_count_total, page_size, device_index);

  CUTLASS_CHECK(fmha.can_implement(arguments));

  CUTLASS_CHECK(fmha.initialize(arguments, workspace_ptr, stream));

  CUTLASS_CHECK(fmha.run(arguments, workspace_ptr, stream));

  return cudaSuccess;
}

}  // namespace attention

}  // namespace flashinfer
#endif  // FLASHINFER_ATTENTION_CUTLASS_MLA_CUH_
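Two derived quantities in `args_from_options` are worth spelling out: the softmax scale comes from the non-absorbed head dims (the hard-coded `D_non_latent = 128` plus `D_rope = 64`), and the kernel's problem size uses `max_seq_len = page_size * page_count_per_seq`, with actual per-request lengths supplied via `seq_lens`. A small numeric check with illustrative paging values:

    import math

    D_latent, D_rope = 512, 64    # fixed by MlaSm100's TileShape
    D_non_latent = 128            # hard-coded in args_from_options
    scale = 1.0 / math.sqrt(D_non_latent + D_rope)  # 1/sqrt(192), about 0.0722

    page_size, page_count_per_seq = 64, 2           # example paging layout
    max_seq_len = page_size * page_count_per_seq    # problem-shape K = 128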
