add multi-item scoring #1015

Merged

41 commits merged on Apr 30, 2025

Commits
4bc79e6
add multi-item scoring
arde171 Apr 10, 2025
70bd358
fix precommit errors
arde171 Apr 11, 2025
1f79f0d
Merge pull request #1 from arde171/arde/mis
arde171 Apr 11, 2025
46fbae2
Merge branch 'flashinfer-ai:main' into main
arde171 Apr 11, 2025
3e8f974
fix clang
arde171 Apr 18, 2025
059fffc
fix unit test
arde171 Apr 19, 2025
3f76f38
update
arde171 Apr 20, 2025
6164dfe
additional params
arde171 Apr 25, 2025
8084962
fix
arde171 Apr 25, 2025
b15dafd
fixt
arde171 Apr 25, 2025
b3c5deb
fix
arde171 Apr 25, 2025
53de0e7
fix
arde171 Apr 25, 2025
a134677
fix
arde171 Apr 26, 2025
f0a4458
fix
arde171 Apr 26, 2025
6ea1d2a
fix
arde171 Apr 26, 2025
f0ee31d
fix
arde171 Apr 26, 2025
4ef78e5
fix
arde171 Apr 26, 2025
6c2f25e
fix
arde171 Apr 27, 2025
485e2bb
revert
arde171 Apr 27, 2025
68d87bd
fix
arde171 Apr 27, 2025
c959fcf
revert
arde171 Apr 27, 2025
32b8ea8
fix pybind
arde171 Apr 28, 2025
7e1d7a2
fix
arde171 Apr 28, 2025
c348fda
revert
arde171 Apr 28, 2025
33ff636
revert
arde171 Apr 28, 2025
5fbb4b2
incorporate review comments
arde171 Apr 28, 2025
0858050
Merge branch 'main' into main
arde171 Apr 28, 2025
bf2c2f3
typo
arde171 Apr 28, 2025
6225af8
fix
arde171 Apr 28, 2025
a1535a6
Merge remote-tracking branch 'origin/main' into arde171/main
yzh119 Apr 30, 2025
079c7fe
upd
yzh119 Apr 30, 2025
b917c64
else branch should be protected by constexpr
yzh119 Apr 30, 2025
e77eca5
upd
yzh119 Apr 30, 2025
24bfece
bugfix
yzh119 Apr 30, 2025
cef3255
fix conflicts with fp8 hopper
yzh119 Apr 30, 2025
4c19a0c
fix fp8
yzh119 Apr 30, 2025
b256c8f
bugfix
yzh119 Apr 30, 2025
f4629f4
bugfix
yzh119 Apr 30, 2025
06e11dd
fix decode
yzh119 Apr 30, 2025
e929fc4
remove lineinfo to fix binary size overflow
arde171 Apr 30, 2025
398b7a0
remove 3 from aot mask mode
yzh119 Apr 30, 2025
26 changes: 19 additions & 7 deletions aot_build_utils/generate_aot_default_additional_params_header.py
@@ -125,23 +125,35 @@ def get_aot_default_additional_params_header_str() -> str:

ret += generate_macro_entry(
"BATCH_PREFILL",
["maybe_custom_mask", "maybe_mask_indptr", "maybe_alibi_slopes"],
["uint8_t", "int32_t", "float"],
[
"maybe_custom_mask",
"maybe_mask_indptr",
"maybe_alibi_slopes",
"maybe_prefix_len_ptr",
"maybe_token_pos_in_items_ptr",
"maybe_max_item_len_ptr",
],
["uint8_t", "int32_t", "float", "uint32_t", "uint16_t", "uint16_t"],
[
"logits_soft_cap",
"sm_scale",
"rope_rcp_scale",
"rope_rcp_theta",
"token_pos_in_items_len",
],
["double", "double", "double", "double"],
["double", "double", "double", "double", "int64_t"],
)

ret += generate_macro_entry(
"BATCH_PREFILL_SM90",
[],
[],
["logits_soft_cap", "sm_scale"],
["double", "double"],
[
"maybe_prefix_len_ptr",
"maybe_token_pos_in_items_ptr",
"maybe_max_item_len_ptr",
],
["uint32_t", "uint16_t", "uint16_t"],
["logits_soft_cap", "sm_scale", "token_pos_in_items_len"],
["double", "double", "int64_t"],
is_sm90_template=True,
)

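Note: the three new tensor parameters and the token_pos_in_items_len scalar carry per-request metadata for multi-item scoring, where one sequence holds a shared prefix followed by several candidate items that should not attend to each other. The sketch below is illustrative only; the dtypes match the macro entry above (uint32 prefix lengths, uint16 per-token positions and per-request max item lengths), but the exact layout conventions are assumptions, not taken from this diff.

# Illustrative sketch (assumed layout), not code from this PR.
import numpy as np

# Toy batch: request 0 = 4-token shared prefix + items of length 2 and 3;
# request 1 = 5-token shared prefix + one item of length 2.
prefix_len = np.array([4, 5], dtype=np.uint32)        # maybe_prefix_len_ptr
item_lens = [[2, 3], [2]]

# Position of each post-prefix token inside its item, padded per request to a
# common row length (token_pos_in_items_len); 0 is used here as padding.
token_pos_in_items_len = 8
rows = []
for lens in item_lens:
    row = [p for n in lens for p in range(1, n + 1)]
    rows.append(row + [0] * (token_pos_in_items_len - len(row)))
token_pos_in_items = np.array(rows, dtype=np.uint16)  # maybe_token_pos_in_items_ptr

# Longest item per request.
max_item_len = np.array([max(lens) for lens in item_lens], dtype=np.uint16)  # maybe_max_item_len_ptr

print(prefix_len)
print(token_pos_in_items)
print(max_item_len)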
1 change: 1 addition & 0 deletions aot_build_utils/literal_map.py
@@ -18,6 +18,7 @@
0: "MaskMode::kNone",
1: "MaskMode::kCausal",
2: "MaskMode::kCustom",
3: "MaskMode::kMultiItemScoring",
}

pos_encoding_mode_literal = {
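Note: the new entry maps the integer mask mode 3 to the C++ enum literal emitted into generated sources. A minimal sketch of how such a map is typically consumed during code generation follows; the render helper is hypothetical and not part of this diff.

# Hypothetical illustration of turning an integer mask mode into a C++ literal.
mask_mode_literal = {
    0: "MaskMode::kNone",
    1: "MaskMode::kCausal",
    2: "MaskMode::kCustom",
    3: "MaskMode::kMultiItemScoring",  # added in this PR
}

def render_kernel_instantiation(mask_mode: int) -> str:
    return f"constexpr MaskMode MASK_MODE = {mask_mode_literal[mask_mode]};"

print(render_kernel_instantiation(3))
# constexpr MaskMode MASK_MODE = MaskMode::kMultiItemScoring;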
4 changes: 4 additions & 0 deletions csrc/batch_prefill_sm90.cu
@@ -141,6 +141,8 @@ void BatchPrefillWithRaggedKVCacheSM90Run(at::Tensor float_workspace_buffer,
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
params.work_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
params.batch_indices =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.batch_indices_offset);

ADDITIONAL_PARAMS_SETTER

@@ -238,6 +240,8 @@ void BatchPrefillWithPagedKVCacheSM90Run(
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
params.work_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
params.batch_indices =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.batch_indices_offset);
params.kv_indices = static_cast<IdType*>(paged_kv_indices.data_ptr());

ADDITIONAL_PARAMS_SETTER
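Note: params.batch_indices lets each unit of scheduled work recover the request it came from, which is what makes the per-request multi-item metadata (prefix length, token positions, max item length) addressable inside the kernel. Below is a simplified sketch of that mapping under an assumed scheduling model; the real planner in flashinfer is more involved.

# Illustrative only (assumed scheduling model): each request may be split into
# several work tiles, and batch_indices records the owning request per tile.
def make_batch_indices(tiles_per_request):
    batch_indices = []
    for req_id, num_tiles in enumerate(tiles_per_request):
        batch_indices.extend([req_id] * num_tiles)
    return batch_indices

print(make_batch_indices([2, 3, 1]))  # [0, 0, 1, 1, 1, 2]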
2 changes: 2 additions & 0 deletions csrc/batch_prefill_sm90_customize_config.jinja
@@ -43,6 +43,7 @@ struct RaggedParams {
IdType* kv_lens;
IdType* head_indices;
IdType* work_indptr;
IdType* batch_indices;
Collaborator:
Thanks for doing this, yes we have to add it.


struct AdditionalParams {
{{ additional_params_decl }}
@@ -88,6 +89,7 @@ struct PagedParams {
IdType* kv_lens;
IdType* head_indices;
IdType* work_indptr;
IdType* batch_indices;

struct AdditionalParams {
{{ additional_params_decl }}
4 changes: 4 additions & 0 deletions flashinfer/decode.py
@@ -1170,13 +1170,17 @@ def run(
None, # packed_custom_mask
None, # mask_indptr_buf
_get_cache_alibi_slopes_buf(q.shape[1], q.device),
None, # maybe_prefix_len_ptr
None, # maybe_token_pos_in_items_ptr
None, # maybe_max_item_len_ptr
logits_soft_cap,
sm_scale,
None, # scale_q, not supported yet
None, # scale_k
None, # scale_v
rope_scale,
rope_theta,
0, # token_pos_in_items_len
]

self._cached_module.paged_run(*run_args)
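Note: decode never uses multi-item scoring, so its run-argument list only grows by disabled placeholders: None for the three metadata pointers and 0 for token_pos_in_items_len. The sketch below contrasts those defaults with hypothetical enabled values a batch-prefill caller might supply; the enabled names and values are assumptions for illustration.

# Disabled defaults, mirroring the decode path in the diff above.
disabled_multi_item_args = dict(
    maybe_prefix_len_ptr=None,
    maybe_token_pos_in_items_ptr=None,
    maybe_max_item_len_ptr=None,
    token_pos_in_items_len=0,
)

# Hypothetical enabled counterpart for a prefill call (tensors as in the
# earlier metadata sketch; not an actual API shown in this diff):
# enabled_multi_item_args = dict(
#     maybe_prefix_len_ptr=prefix_len,                  # uint32, one entry per request
#     maybe_token_pos_in_items_ptr=token_pos_in_items,  # uint16, padded rows
#     maybe_max_item_len_ptr=max_item_len,              # uint16, one entry per request
#     token_pos_in_items_len=8,                         # padded row length
# )
print(disabled_multi_item_args)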
39 changes: 27 additions & 12 deletions flashinfer/jit/attention/pytorch.py
@@ -645,8 +645,8 @@ def gen_customize_pod_module(

source_paths = []

for mask_mode_p in [0, 1, 2]:
for mask_mode_d in [0, 1, 2]:
for mask_mode_p in [0, 1, 2, 3]:
for mask_mode_d in [0, 1, 2, 3]:
kwargs["mask_mode_p"] = mask_mode_literal[mask_mode_p]
kwargs["mask_mode_d"] = mask_mode_literal[mask_mode_d]

@@ -759,27 +759,42 @@ def gen_batch_prefill_module(
"maybe_custom_mask",
"maybe_mask_indptr",
"maybe_alibi_slopes",
"maybe_prefix_len_ptr",
"maybe_token_pos_in_items_ptr",
"maybe_max_item_len_ptr",
]
additional_tensor_dtypes = [
"uint8_t",
"int32_t",
"float",
"uint32_t",
"uint16_t",
"uint16_t",
] # NOTE(Zihao): int32_t should follow dtype_idx
additional_scalar_names = [
"logits_soft_cap",
"sm_scale",
"rope_rcp_scale",
"rope_rcp_theta",
"token_pos_in_items_len",
]
additional_scalar_dtypes = ["double", "double", "double", "double"]
additional_scalar_dtypes = ["double", "double", "double", "double", "int64_t"]
variant_name = f"DefaultAttention<use_custom_mask, {str(use_sliding_window).lower()}, {str(use_logits_soft_cap).lower()}, {str(pos_encoding_mode == 2).lower()}>"
variant_decl = f"#include<flashinfer/attention/variants.cuh>"
variant_decl = "#include<flashinfer/attention/variants.cuh>"
else:
if not fp8_enabled:
additional_tensor_names = []
additional_tensor_dtypes = []
additional_scalar_names = ["logits_soft_cap", "sm_scale"]
additional_scalar_dtypes = ["double", "double"]
additional_tensor_names = [
"maybe_prefix_len_ptr",
"maybe_token_pos_in_items_ptr",
"maybe_max_item_len_ptr",
]
additional_tensor_dtypes = ["uint32_t", "uint16_t", "uint16_t"]
additional_scalar_names = [
"logits_soft_cap",
"sm_scale",
"token_pos_in_items_len",
]
additional_scalar_dtypes = ["double", "double", "int64_t"]
variant_name = f"DefaultAttention<{str(use_logits_soft_cap).lower()}>"
variant_decl = f"#include<flashinfer/attention/hopper/variants.cuh>"
else:
@@ -961,7 +976,7 @@ def gen_customize_single_prefill_module(
os.makedirs(gen_directory, exist_ok=True)

source_paths = []
for mask_mode in [0, 1, 2]:
for mask_mode in [0, 1, 2, 3]:
filename = f"single_prefill_kernel_mask_{mask_mode}.cu"
dest_path = gen_directory / filename
source_paths.append(dest_path)
@@ -1025,7 +1040,7 @@ def gen_customize_single_prefill_module(
os.makedirs(gen_directory, exist_ok=True)

source_paths = []
for mask_mode in [0, 1, 2]:
for mask_mode in [0, 1, 2, 3]:
filename = f"single_prefill_sm90_kernel_mask_{mask_mode}.cu"
dest_path = gen_directory / filename
source_paths.append(dest_path)
@@ -1209,7 +1224,7 @@ def gen_customize_batch_prefill_module(
os.makedirs(gen_directory, exist_ok=True)

source_paths = []
for mask_mode in [0, 1, 2]:
for mask_mode in [0, 1, 2, 3]:
dest_path = (
gen_directory / f"batch_prefill_paged_kernel_mask_{mask_mode}.cu"
)
@@ -1286,7 +1301,7 @@ def gen_customize_batch_prefill_module(
generated_inc_str = config_templ.render(**kwargs)

source_paths = []
for mask_mode in [0, 1, 2]:
for mask_mode in [0, 1, 2, 3]:
filename = f"batch_prefill_paged_sm90_kernel_mask_{mask_mode}.cu"
dest_path = gen_directory / filename
source_paths.append(dest_path)
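Note: every kernel-generation loop now includes mask mode 3, so a kernel variant is emitted for multi-item scoring alongside the existing none/causal/custom modes. Below is a hypothetical runtime dispatch that this enables; the helper is an illustration, not code from this PR.

# Hypothetical illustration: pick the integer mask mode the generated kernels
# are keyed on. Mode 3 is only selectable because kernels are now generated
# for it.
def choose_mask_mode(causal, custom_mask, prefix_len_ptr):
    if prefix_len_ptr is not None:
        return 3  # MaskMode::kMultiItemScoring
    if custom_mask is not None:
        return 2  # MaskMode::kCustom
    return 1 if causal else 0  # kCausal / kNone

assert choose_mask_mode(True, None, object()) == 3
assert choose_mask_mode(True, None, None) == 1
assert choose_mask_mode(False, None, None) == 0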
1 change: 1 addition & 0 deletions flashinfer/jit/utils.py
@@ -98,4 +98,5 @@ def wrapper(func, args):
0: "MaskMode::kNone",
1: "MaskMode::kCausal",
2: "MaskMode::kCustom",
3: "MaskMode::kMultiItemScoring",
}