Commit eef0ada

cyx-6yzh119 authored and committed
wrapper
1 parent a79f9d4 commit eef0ada

2 files changed (+265, -9 lines)


flashinfer/prefill.py

+251 -3
@@ -72,6 +72,7 @@ def get_fmha_module(
     pos_encoding_mode: PosEncodingMode,
     use_sliding_window: bool,
     use_logits_soft_cap: bool,
+    use_fp16_qk_reduction: bool = False,
 ):
     if is_sm100a_supported(torch.device("cuda")):
         return gen_fmha_cutlass_sm100a_module(
@@ -2366,9 +2367,12 @@ def plan(
             logits_soft_cap > 0,  # use_logits_soft_cap
             use_fp16_qk_reduction,
         )
-        self._cached_module = get_batch_prefill_module(self._backend)(
-            *get_module_args
-        )
+        if self._backend == "cutlass":
+            self._cached_module = get_cutlass_mha_module()(*get_module_args)
+        else:
+            self._cached_module = get_batch_prefill_module(self._backend)(
+                *get_module_args
+            )
 
         self._plan_info = self._cached_module.plan(
             self._float_workspace_buffer,
@@ -2727,3 +2731,247 @@ def fmha_varlen(
     lse = lse_padded
 
     return out, lse
+
+
+def get_cutlass_mha_module():
+    def backend_module(*args):
+        modules_dict = _batch_prefill_modules
+
+        if args not in modules_dict:
+            uri = get_batch_prefill_uri("cutlass", *args)
+            module = get_fmha_module(*args)
+
+            @register_custom_op(
+                f"flashinfer::{uri}_ragged_run",
+                mutates_args=(
+                    "float_workspace_buffer",
+                    "int_workspace_buffer",
+                    "o",
+                    "maybe_lse",
+                ),
+            )
+            def ragged_run(
+                float_workspace_buffer: torch.Tensor,
+                int_workspace_buffer: torch.Tensor,
+                plan_info_vec: List[int],
+                q: torch.Tensor,
+                k: torch.Tensor,
+                v: torch.Tensor,
+                qo_indptr: torch.Tensor,
+                kv_indptr: torch.Tensor,
+                o: torch.Tensor,
+                maybe_lse: Optional[torch.Tensor],
+                mask_mode: int,
+                layout: int,
+                window_left: int,
+                maybe_custom_mask: Optional[torch.Tensor],
+                maybe_mask_indptr: Optional[torch.Tensor],
+                maybe_alibi_slopes: Optional[torch.Tensor],
+                logits_soft_cap: float,
+                sm_scale: float,
+                rope_scale: float,
+                rope_theta: float,
+            ) -> None:
+                nnz_qo, num_qo_heads, head_dim_qk = q.shape
+                nnz_kv, num_kv_heads, head_dim_vo = v.shape
+
+                sm_scale = 1.0 / math.sqrt(head_dim_qk)
+
+                qo_lens = qo_indptr[1:] - qo_indptr[:-1]
+                kv_lens = kv_indptr[1:] - kv_indptr[:-1]
+                batch_size = qo_lens.shape[0]
+                max_qo_len = qo_lens.max()
+                max_kv_len = kv_lens.max()
+
+                q_padded = torch.cat(
+                    [
+                        torch.zeros(
+                            max(max_qo_len, 128),
+                            q.shape[1],
+                            q.shape[2],
+                            device=q.device,
+                            dtype=q.dtype,
+                        ),
+                        q,
+                    ],
+                    dim=0,
+                )[max(max_qo_len, 128) :]
+
+                qo_total_len = nnz_qo
+
+                k_padded = torch.cat(
+                    [
+                        torch.zeros(
+                            max(max_kv_len, 128),
+                            k.shape[1],
+                            k.shape[2],
+                            device=k.device,
+                            dtype=k.dtype,
+                        ),
+                        k,
+                    ],
+                    dim=0,
+                )[max(max_kv_len, 128) :]
+                v_padded = torch.cat(
+                    [
+                        torch.zeros(
+                            max(max_kv_len, 128),
+                            v.shape[1],
+                            v.shape[2],
+                            device=v.device,
+                            dtype=v.dtype,
+                        ),
+                        v,
+                    ],
+                    dim=0,
+                )[max(max_kv_len, 128) :]
+
+                if o is None:
+                    out_padded = torch.empty(
+                        qo_total_len + max(max_qo_len, 128),
+                        num_qo_heads,
+                        head_dim_vo,
+                        device=q.device,
+                        dtype=q.dtype,
+                    )[max(max_qo_len, 128) :]
+                else:
+                    out_padded = o
+
+                if maybe_lse is None:
+                    lse_padded = torch.empty(
+                        qo_total_len, num_qo_heads, device=q.device, dtype=torch.float32
+                    )
+                else:
+                    lse_padded = maybe_lse
+
+                module.run(
+                    q_padded,
+                    k_padded,
+                    v_padded,
+                    qo_lens,
+                    kv_lens,
+                    qo_indptr,
+                    kv_indptr,
+                    out_padded,
+                    lse_padded,
+                    mask_mode,
+                    sm_scale,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim_qk,
+                    batch_size,
+                    nnz_qo,
+                    nnz_kv,
+                    max_qo_len,
+                    max_kv_len,
+                )
+
+                o = out_padded
+                maybe_lse = lse_padded
+
+                return o, maybe_lse
+
+            @register_custom_op(
+                f"flashinfer::{uri}_paged_run",
+                mutates_args=(
+                    "float_workspace_buffer",
+                    "int_workspace_buffer",
+                    "paged_k_cache",
+                    "paged_v_cache",
+                    "o",
+                    "maybe_lse",
+                ),
+            )
+            def paged_run(
+                float_workspace_buffer: torch.Tensor,
+                int_workspace_buffer: torch.Tensor,
+                plan_info_vec: List[int],
+                q: torch.Tensor,
+                paged_k_cache: torch.Tensor,
+                paged_v_cache: torch.Tensor,
+                qo_indptr: torch.Tensor,
+                paged_kv_indptr: torch.Tensor,
+                paged_kv_indices: torch.Tensor,
+                paged_kv_last_page_len: torch.Tensor,
+                o: torch.Tensor,
+                maybe_lse: Optional[torch.Tensor],
+                mask_mode: int,
+                layout: int,
+                window_left: int,
+                maybe_custom_mask: Optional[torch.Tensor],
+                maybe_mask_indptr: Optional[torch.Tensor],
+                maybe_alibi_slopes: Optional[torch.Tensor],
+                logits_soft_cap: float,
+                sm_scale: float,
+                rope_scale: float,
+                rope_theta: float,
+            ) -> None:
+                pass
+
+            @register_fake_op(f"flashinfer::{uri}_ragged_run")
+            def _fake_ragged_run(
+                float_workspace_buffer: torch.Tensor,
+                int_workspace_buffer: torch.Tensor,
+                plan_info_vec: List[int],
+                q: torch.Tensor,
+                k: torch.Tensor,
+                v: torch.Tensor,
+                qo_indptr: torch.Tensor,
+                kv_indptr: torch.Tensor,
+                o: torch.Tensor,
+                maybe_lse: Optional[torch.Tensor],
+                mask_mode: int,
+                layout: int,
+                window_left: int,
+                maybe_custom_mask: Optional[torch.Tensor],
+                maybe_mask_indptr: Optional[torch.Tensor],
+                maybe_alibi_slopes: Optional[torch.Tensor],
+                logits_soft_cap: float,
+                sm_scale: float,
+                rope_scale: float,
+                rope_theta: float,
+            ) -> None:
+                pass
+
+            @register_fake_op(f"flashinfer::{uri}_paged_run")
+            def _fake_paged_run(
+                float_workspace_buffer: torch.Tensor,
+                int_workspace_buffer: torch.Tensor,
+                plan_info_vec: List[int],
+                q: torch.Tensor,
+                paged_k_cache: torch.Tensor,
+                paged_v_cache: torch.Tensor,
+                qo_indptr: torch.Tensor,
+                paged_kv_indptr: torch.Tensor,
+                paged_kv_indices: torch.Tensor,
+                paged_kv_last_page_len: torch.Tensor,
+                o: torch.Tensor,
+                maybe_lse: Optional[torch.Tensor],
+                mask_mode: int,
+                layout: int,
+                window_left: int,
+                maybe_custom_mask: Optional[torch.Tensor],
+                maybe_mask_indptr: Optional[torch.Tensor],
+                maybe_alibi_slopes: Optional[torch.Tensor],
+                logits_soft_cap: float,
+                sm_scale: float,
+                rope_scale: float,
+                rope_theta: float,
+            ) -> None:
+                pass
+
+            def plan(*args):
+                pass
+
+            # Register the module.
+            #
+            # Note that plan is not part of model logic. It should not be included in
+            # Cuda Graph or torch.compile. So, we don't provide a torch library for plan.
+            modules_dict[args] = SimpleNamespace(
+                plan=plan,
+                ragged_run=ragged_run,
+                paged_run=paged_run,
+            )
+        return modules_dict[args]
+
+    return backend_module
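
A note on the padding pattern in `ragged_run` above: each of q, k, and v is concatenated behind a zero block of at least 128 rows and that block is immediately sliced off again, so the tensor handed to `module.run` is a view whose backing storage extends `max(max_len, 128)` rows before the first visible row. The standalone sketch below reproduces the same pattern on toy shapes; the shapes are illustrative, and the reading that this headroom exists for the CUTLASS kernel's benefit is an assumption, not something the commit states.

import torch

# Illustrative shapes only; the commit applies this to the real ragged q/k/v.
nnz, num_heads, head_dim = 200, 8, 128
q = torch.randn(nnz, num_heads, head_dim)

pad = max(nnz, 128)  # in the commit: max(max_qo_len, 128)
q_padded = torch.cat(
    [torch.zeros(pad, num_heads, head_dim, dtype=q.dtype), q],
    dim=0,
)[pad:]  # slice the zero block back off

# Same logical contents as q ...
assert torch.equal(q_padded, q)
# ... but backed by the larger concatenated buffer, so there are `pad` rows of
# valid storage in front of the first visible row (assumed kernel headroom).
assert q_padded.untyped_storage().nbytes() > q.untyped_storage().nbytes()

Because the zero block and the data come from a single torch.cat, the visible slice and its headroom share one contiguous allocation rather than two separate buffers.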

tests/test_blackwell_fmha.py

+14 -6
@@ -82,8 +82,8 @@ def test_blackwell_cutlass_fmha(
     q = torch.randn(
         batch_size * qo_len, num_qo_heads, head_dim, dtype=dtype, device="cuda"
     )
-    qo_segment_offsets = (
-        torch.arange(batch_size + 1, device="cuda", dtype=torch.int32) * qo_len
+    qo_indptr = (
+        torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qo_len
     )
 
     k = torch.randn(
@@ -92,14 +92,21 @@ def test_blackwell_cutlass_fmha(
     v = torch.randn(
         batch_size * kv_len, num_kv_heads, head_dim, dtype=dtype, device="cuda"
     )
-    kv_segment_offsets = (
-        torch.arange(batch_size + 1, device="cuda", dtype=torch.int32) * kv_len
+    kv_indptr = (
+        torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * kv_len
     )
 
-    o, lse = flashinfer.prefill.fmha_varlen(
-        q, k, v, qo_segment_offsets, kv_segment_offsets, causal=causal
+    workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
+    wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
+        workspace_buffer, kv_layout, backend="cutlass"
     )
 
+    wrapper.plan(
+        qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim, causal=causal
+    )
+
+    o, lse = wrapper.run(q, k, v, return_lse=True)
+
     sm_scale = 1.0 / (head_dim**0.5)
     gqa_group_ratio = num_qo_heads // num_kv_heads
     k_repeated = torch.repeat_interleave(k, gqa_group_ratio, dim=1)
@@ -128,6 +135,7 @@ def test_blackwell_cutlass_fmha(
         17,
         17,
         1,
+        1,
         128,
         True,
         torch.half,
