
Commit 9a44330

yzh119 authored and Edenzzzz committed
bugfix: fix fp8 attention kernels aot compilation issue (flashinfer-ai#1087)
## 📌 Description

- `batch_indices_offset` (introduced in flashinfer-ai#1015) is not passed to the fp8 attention kernels; this PR fixes the issue.
- Add fp8 kernels to the AOT generators.

## 🔍 Related Issues

flashinfer-ai#1064

---

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

---

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

---

## Reviewer Notes

cc @abcdabcd987
1 parent fc28711 · commit 9a44330

File tree: 3 files changed, +48 −22 lines


csrc/batch_prefill_fp8_sm90.cu

Lines changed: 2 additions & 0 deletions
@@ -151,6 +151,8 @@ void BatchPrefillWithPagedKVCacheSM90Run(
   params.kv_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_indptr_offset);
   params.qo_lens = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_len_offset);
   params.kv_lens = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_len_offset);
+  params.batch_indices =
+      GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.batch_indices_offset);
   params.head_indices =
       GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
   params.work_indptr =

flashinfer/aot.py

Lines changed: 37 additions & 18 deletions
@@ -40,6 +40,8 @@ def gen_fa2(
 ) -> List[JitSpec]:
     if dtype_qo.itemsize == dtype_kv.itemsize and dtype_qo != dtype_kv:
         return []
+    if dtype_qo.itemsize == 1:
+        return []  # fp8 tensor cores not supported in fa2
     return [
         gen_single_prefill_module(
             backend="fa2",
@@ -91,24 +93,30 @@ def gen_fa2(


 def gen_fa3(
-    dtype_qo: torch.dtype,
+    dtype_q: torch.dtype,
     dtype_kv: torch.dtype,
+    dtype_o: torch.dtype,
     head_dim_qk: int,
     head_dim_vo: int,
     use_sliding_window: bool,
     use_logits_soft_cap: bool,
 ) -> List[JitSpec]:
-    if dtype_qo.itemsize == dtype_kv.itemsize and dtype_qo != dtype_kv:
-        return []
+    if dtype_q != dtype_kv:
+        return []  # fa3 template do not support mixed precision
+    if dtype_q.itemsize == 2:
+        if dtype_q != dtype_o:
+            return []  # for fp16, dtype_o must be the same as dtype_q/dtype_kv
+
     if dtype_kv.itemsize == 1:
-        # fp8 kv not supported in FA3
-        return []
+        if head_dim_qk == 192 or head_dim_qk == 64:
+            return []  # (192, 128) & (64, 64) not supported for fp8 yet.
+
     return [
         gen_single_prefill_module(
             backend="fa3",
-            dtype_q=dtype_qo,
+            dtype_q=dtype_q,
             dtype_kv=dtype_kv,
-            dtype_o=dtype_qo,
+            dtype_o=dtype_o,
             head_dim_qk=head_dim_qk,
             head_dim_vo=head_dim_vo,
             pos_encoding_mode=0,
@@ -118,9 +126,9 @@ def gen_fa3(
         ),
         gen_batch_prefill_module(
             backend="fa3",
-            dtype_q=dtype_qo,
+            dtype_q=dtype_q,
             dtype_kv=dtype_kv,
-            dtype_o=dtype_qo,
+            dtype_o=dtype_o,
             dtype_idx=torch.int32,
             head_dim_qk=head_dim_qk,
             head_dim_vo=head_dim_vo,
@@ -174,20 +182,21 @@ def gen_attention(
     if has_sm90:
         for (
             (head_dim_qk, head_dim_vo),
-            dtype_qo,
-            dtype_kv,
+            dtype_qkv,
+            dtype_o,
             use_sliding_window,
             use_logits_soft_cap,
         ) in product(
             fa3_head_dim_,
-            f16_dtype_,
             f16_dtype_ + f8_dtype_,
+            f16_dtype_,
             use_sliding_window_,
             use_logits_soft_cap_,
         ):
             jit_specs += gen_fa3(
-                dtype_qo=dtype_qo,
-                dtype_kv=dtype_kv,
+                dtype_q=dtype_qkv,
+                dtype_kv=dtype_qkv,
+                dtype_o=dtype_o,
                 head_dim_qk=head_dim_qk,
                 head_dim_vo=head_dim_vo,
                 use_sliding_window=use_sliding_window,
@@ -203,7 +212,7 @@ def gen_attention(
         ) in product(
             f16_dtype_,
             f16_dtype_ + f8_dtype_,
-            [(False, False), (True, True)],
+            [(True, True)],
         ):
             jit_specs += gen_fa2(
                 dtype_qo=dtype_qo,
@@ -213,10 +222,20 @@ def gen_attention(
                 use_sliding_window=use_sliding_window,
                 use_logits_soft_cap=use_logits_soft_cap,
             )
-        if has_sm90:
+    if has_sm90:
+        for (
+            dtype_qkv,
+            dtype_o,
+            (use_sliding_window, use_logits_soft_cap),
+        ) in product(
+            f16_dtype_ + f8_dtype_,
+            f16_dtype_,
+            [(True, True)],
+        ):
             jit_specs += gen_fa3(
-                dtype_qo=dtype_qo,
-                dtype_kv=dtype_kv,
+                dtype_q=dtype_qkv,
+                dtype_kv=dtype_qkv,
+                dtype_o=dtype_o,
                 head_dim_qk=256,
                 head_dim_vo=256,
                 use_sliding_window=use_sliding_window,
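To make the effect of the new dtype gating concrete, here is a small, self-contained Python sketch (not flashinfer code) that mirrors the early-return checks added to `gen_fa3` above and enumerates which `(dtype_q/kv, dtype_o, head_dim)` combinations survive. The contents of `f16_dtype_`, `f8_dtype_`, and `fa3_head_dim_`, as well as the helper `fa3_combo_is_emitted`, are illustrative assumptions, not the actual values in `flashinfer/aot.py`.

```python
# Minimal sketch mirroring the dtype gating added to gen_fa3 above.
# The dtype/head-dim lists and the helper name are assumptions for
# illustration only.
from itertools import product

import torch

f16_dtype_ = [torch.float16, torch.bfloat16]          # assumed contents
f8_dtype_ = [torch.float8_e4m3fn, torch.float8_e5m2]  # assumed contents
fa3_head_dim_ = [(192, 128), (128, 128), (64, 64)]    # assumed contents


def fa3_combo_is_emitted(dtype_q, dtype_kv, dtype_o, head_dim_qk):
    """Reproduce the early-return checks of the patched gen_fa3."""
    if dtype_q != dtype_kv:
        return False  # fa3 template does not support mixed q/kv precision
    if dtype_q.itemsize == 2 and dtype_q != dtype_o:
        return False  # for fp16/bf16, output dtype must match q/kv
    if dtype_kv.itemsize == 1 and head_dim_qk in (192, 64):
        return False  # (192, 128) and (64, 64) not supported for fp8 yet
    return True


# Enumerate the same kind of product the AOT generator now iterates over:
# q/kv dtype drawn from fp16 + fp8, output dtype from fp16 only.
for (head_dim_qk, head_dim_vo), dtype_qkv, dtype_o in product(
    fa3_head_dim_, f16_dtype_ + f8_dtype_, f16_dtype_
):
    if fa3_combo_is_emitted(dtype_qkv, dtype_qkv, dtype_o, head_dim_qk):
        print(f"fa3 spec: q/kv={dtype_qkv}, o={dtype_o}, head_dim_qk={head_dim_qk}")
```

Under these assumptions, fp8 q/kv specs are emitted with fp16/bf16 outputs for head dims other than 192 and 64, while fp16/bf16 inputs keep the requirement that the output dtype match q/kv.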

include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh

Lines changed: 9 additions & 4 deletions
@@ -351,10 +351,15 @@ cudaError_t BatchFP8PrefillWithPagedKVCacheKernelTraitsDispatched(Params& params
   });

   typename Scheduler::Arguments scheduler_args = {
-      params.work_indptr, params.head_indices,
-      params.qo_tile_indices, params.qo_indptr,
-      params.kv_indptr, params.qo_lens,
-      params.kv_lens, cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads),
+      params.work_indptr,
+      params.head_indices,
+      params.qo_tile_indices,
+      params.qo_indptr,
+      params.kv_indptr,
+      params.qo_lens,
+      params.kv_lens,
+      params.batch_indices,
+      cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads),
       params.num_qo_heads};
   typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args);