
Commit 6c6f1a5

arde171, qingquansong, zianglih, and yzh119 authored
add multi-item scoring (#1015)
Co-authored with Qingquan Song (@qingquansong) and Ziang Li (@zianglih).

**Multi-item scoring**

1. Concatenate all ranking candidates of the same member into one prompt, with delimiter separation:
   `<member prefix (profile & history)> + <delimiter> + <item 1> + <delimiter> + <item 2> + ... + <item N> + <delimiter>`
2. Extract the logits of the hidden states of the token before each delimiter, and take the log probabilities of the given label tokens. For each prompt, the returned output is a 2D list of shape N * K, where N is the number of candidate items in the prompt and K is the number of choices provided to the server engine (e.g., 2 for ["Yes", "No"]). This is mainly done in the logit processor.

![image](https://github.com/user-attachments/assets/837ea448-99e0-4e75-b6ec-5d7bf89b53dc)

This PR optimizes multi-item scoring attention by passing four new arguments and using them to check the masking condition. The new arguments are:

```
prefix_len_ptr : Optional[torch.Tensor]
    Prefix length. A uint32 1D tensor indicating the prefix length of each prompt. The tensor size equals the batch size.
token_pos_in_items_ptr : Optional[torch.Tensor]
    A uint16 1D tensor (it will be converted to uint16 inside flashinfer) giving the token position within each item, starting from 0 (the delimiter) for every item. E.g., for 3 items of lengths 3, 2, and 4 in one prompt, this vector is `[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]`, with the 4 delimiters indexed as 0. For batch size > 1, the per-prompt vectors are concatenated into a single 1D tensor with zero padding so that each prompt's segment has the same length; the padding length is `token_pos_in_items_len` minus the length of the raw `token_pos_in_items_ptr` for that prompt.
token_pos_in_items_len : Optional[int]
    Zero-padded length of `token_pos_in_items_ptr` per prompt, used to handle the batch size > 1 case. Using the 3, 2, 4 example above with `token_pos_in_items_len = 20`, the vector becomes `[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0]` with 7 padded zeros (there are 8 trailing zeros in total; the first of them is the delimiter token at the end of the prompt).
max_item_len_ptr : Optional[torch.Tensor]
    A uint16 vector containing the maximum item token length among all items of each prompt.
```

**Optimizations**

1. Implement an efficient multi-item scoring mask for FA2 and FA3.
2. Enhance FA3 to support batch-idx for the multi-item scoring mask.
3. Implement skip tiles for FA2 and FA3 multi-item scoring.
4. Optimize the mask by preloading it into L1 cache for thread registers.

---------

Co-authored-by: qingquansong <[email protected]>
Co-authored-by: zianglih <[email protected]>
Co-authored-by: Zihao Ye <[email protected]>
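For readers wiring this up from Python, here is a minimal host-side sketch of how the four auxiliary inputs could be assembled for a batch of prompts, following the 3/2/4-item example above. The helper name `build_multi_item_scoring_args`, the int32 construction dtypes, and the padding policy (padding every prompt to the longest per-prompt vector) are illustrative assumptions, not part of the flashinfer API.

```python
import torch


def build_multi_item_scoring_args(prefix_lens, item_lens_per_prompt):
    """Hypothetical helper (not part of flashinfer): assemble the four
    auxiliary inputs described above for a batch of multi-item prompts."""
    # prefix_len_ptr: one prefix length per prompt (described as uint32 on the
    # kernel side; built as int32 here for portability across torch versions).
    prefix_len_ptr = torch.tensor(prefix_lens, dtype=torch.int32)

    # Per-prompt token positions: 0 marks each delimiter, then 1..len(item) for
    # the item's tokens, with a final 0 for the delimiter closing the last item.
    per_prompt = []
    for item_lens in item_lens_per_prompt:
        pos = []
        for n in item_lens:
            pos.extend(range(n + 1))  # delimiter (0) followed by positions 1..n
        pos.append(0)  # trailing delimiter at the end of the prompt
        per_prompt.append(pos)

    # token_pos_in_items_len: shared padded length so the batch concatenates
    # into a single 1D tensor (here simply the longest per-prompt vector).
    token_pos_in_items_len = max(len(p) for p in per_prompt)
    padded = [p + [0] * (token_pos_in_items_len - len(p)) for p in per_prompt]
    token_pos_in_items_ptr = torch.tensor(
        [x for p in padded for x in p], dtype=torch.int32
    )  # converted to uint16 inside flashinfer, per the description above

    # max_item_len_ptr: maximum item token length per prompt (uint16 kernel-side).
    max_item_len_ptr = torch.tensor(
        [max(item_lens) for item_lens in item_lens_per_prompt], dtype=torch.int32
    )
    return prefix_len_ptr, token_pos_in_items_ptr, token_pos_in_items_len, max_item_len_ptr


# One prompt with items of lengths 3, 2, 4 reproduces the example vector
# [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0] from the description above.
print(build_multi_item_scoring_args([128], [[3, 2, 4]]))
```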
1 parent 116d97d commit 6c6f1a5

33 files changed: +1132 −159 lines

aot_build_utils/generate_aot_default_additional_params_header.py

Lines changed: 19 additions & 7 deletions
```diff
@@ -125,23 +125,35 @@ def get_aot_default_additional_params_header_str() -> str:
 
     ret += generate_macro_entry(
         "BATCH_PREFILL",
-        ["maybe_custom_mask", "maybe_mask_indptr", "maybe_alibi_slopes"],
-        ["uint8_t", "int32_t", "float"],
+        [
+            "maybe_custom_mask",
+            "maybe_mask_indptr",
+            "maybe_alibi_slopes",
+            "maybe_prefix_len_ptr",
+            "maybe_token_pos_in_items_ptr",
+            "maybe_max_item_len_ptr",
+        ],
+        ["uint8_t", "int32_t", "float", "uint32_t", "uint16_t", "uint16_t"],
         [
             "logits_soft_cap",
             "sm_scale",
             "rope_rcp_scale",
             "rope_rcp_theta",
+            "token_pos_in_items_len",
         ],
-        ["double", "double", "double", "double"],
+        ["double", "double", "double", "double", "int64_t"],
     )
 
     ret += generate_macro_entry(
         "BATCH_PREFILL_SM90",
-        [],
-        [],
-        ["logits_soft_cap", "sm_scale"],
-        ["double", "double"],
+        [
+            "maybe_prefix_len_ptr",
+            "maybe_token_pos_in_items_ptr",
+            "maybe_max_item_len_ptr",
+        ],
+        ["uint32_t", "uint16_t", "uint16_t"],
+        ["logits_soft_cap", "sm_scale", "token_pos_in_items_len"],
+        ["double", "double", "int64_t"],
         is_sm90_template=True,
     )
```

aot_build_utils/literal_map.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -18,6 +18,7 @@
     0: "MaskMode::kNone",
     1: "MaskMode::kCausal",
     2: "MaskMode::kCustom",
+    3: "MaskMode::kMultiItemScoring",
 }
 
 pos_encoding_mode_literal = {
```

csrc/batch_prefill_sm90.cu

Lines changed: 4 additions & 0 deletions
```diff
@@ -141,6 +141,8 @@ void BatchPrefillWithRaggedKVCacheSM90Run(at::Tensor float_workspace_buffer,
       GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
   params.work_indptr =
       GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
+  params.batch_indices =
+      GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.batch_indices_offset);
 
   ADDITIONAL_PARAMS_SETTER
 
@@ -238,6 +240,8 @@ void BatchPrefillWithPagedKVCacheSM90Run(
       GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
   params.work_indptr =
       GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
+  params.batch_indices =
+      GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.batch_indices_offset);
   params.kv_indices = static_cast<IdType*>(paged_kv_indices.data_ptr());
 
   ADDITIONAL_PARAMS_SETTER
```

csrc/batch_prefill_sm90_customize_config.jinja

Lines changed: 2 additions & 0 deletions
```diff
@@ -43,6 +43,7 @@ struct RaggedParams {
   IdType* kv_lens;
   IdType* head_indices;
   IdType* work_indptr;
+  IdType* batch_indices;
 
   struct AdditionalParams {
     {{ additional_params_decl }}
@@ -88,6 +89,7 @@ struct PagedParams {
   IdType* kv_lens;
   IdType* head_indices;
   IdType* work_indptr;
+  IdType* batch_indices;
 
   struct AdditionalParams {
     {{ additional_params_decl }}
```

flashinfer/decode.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -1170,13 +1170,17 @@ def run(
             None,  # packed_custom_mask
             None,  # mask_indptr_buf
             _get_cache_alibi_slopes_buf(q.shape[1], q.device),
+            None,  # maybe_prefix_len_ptr
+            None,  # maybe_token_pos_in_items_ptr
+            None,  # maybe_max_item_len_ptr
             logits_soft_cap,
             sm_scale,
             None,  # scale_q, not supported yet
             None,  # scale_k
             None,  # scale_v
             rope_scale,
             rope_theta,
+            0,  # token_pos_in_items_len
         ]
 
         self._cached_module.paged_run(*run_args)
```

flashinfer/jit/attention/pytorch.py

Lines changed: 27 additions & 12 deletions
```diff
@@ -645,8 +645,8 @@ def gen_customize_pod_module(
 
     source_paths = []
 
-    for mask_mode_p in [0, 1, 2]:
-        for mask_mode_d in [0, 1, 2]:
+    for mask_mode_p in [0, 1, 2, 3]:
+        for mask_mode_d in [0, 1, 2, 3]:
             kwargs["mask_mode_p"] = mask_mode_literal[mask_mode_p]
             kwargs["mask_mode_d"] = mask_mode_literal[mask_mode_d]
 
@@ -759,27 +759,42 @@ def gen_batch_prefill_module(
             "maybe_custom_mask",
             "maybe_mask_indptr",
             "maybe_alibi_slopes",
+            "maybe_prefix_len_ptr",
+            "maybe_token_pos_in_items_ptr",
+            "maybe_max_item_len_ptr",
         ]
         additional_tensor_dtypes = [
             "uint8_t",
             "int32_t",
             "float",
+            "uint32_t",
+            "uint16_t",
+            "uint16_t",
         ]  # NOTE(Zihao): int32_t should follow dtype_idx
         additional_scalar_names = [
             "logits_soft_cap",
             "sm_scale",
             "rope_rcp_scale",
             "rope_rcp_theta",
+            "token_pos_in_items_len",
         ]
-        additional_scalar_dtypes = ["double", "double", "double", "double"]
+        additional_scalar_dtypes = ["double", "double", "double", "double", "int64_t"]
         variant_name = f"DefaultAttention<use_custom_mask, {str(use_sliding_window).lower()}, {str(use_logits_soft_cap).lower()}, {str(pos_encoding_mode == 2).lower()}>"
-        variant_decl = f"#include<flashinfer/attention/variants.cuh>"
+        variant_decl = "#include<flashinfer/attention/variants.cuh>"
     else:
         if not fp8_enabled:
-            additional_tensor_names = []
-            additional_tensor_dtypes = []
-            additional_scalar_names = ["logits_soft_cap", "sm_scale"]
-            additional_scalar_dtypes = ["double", "double"]
+            additional_tensor_names = [
+                "maybe_prefix_len_ptr",
+                "maybe_token_pos_in_items_ptr",
+                "maybe_max_item_len_ptr",
+            ]
+            additional_tensor_dtypes = ["uint32_t", "uint16_t", "uint16_t"]
+            additional_scalar_names = [
+                "logits_soft_cap",
+                "sm_scale",
+                "token_pos_in_items_len",
+            ]
+            additional_scalar_dtypes = ["double", "double", "int64_t"]
             variant_name = f"DefaultAttention<{str(use_logits_soft_cap).lower()}>"
             variant_decl = f"#include<flashinfer/attention/hopper/variants.cuh>"
         else:
@@ -961,7 +976,7 @@ def gen_customize_single_prefill_module(
     os.makedirs(gen_directory, exist_ok=True)
 
     source_paths = []
-    for mask_mode in [0, 1, 2]:
+    for mask_mode in [0, 1, 2, 3]:
         filename = f"single_prefill_kernel_mask_{mask_mode}.cu"
         dest_path = gen_directory / filename
         source_paths.append(dest_path)
@@ -1025,7 +1040,7 @@ def gen_customize_single_prefill_module(
     os.makedirs(gen_directory, exist_ok=True)
 
     source_paths = []
-    for mask_mode in [0, 1, 2]:
+    for mask_mode in [0, 1, 2, 3]:
         filename = f"single_prefill_sm90_kernel_mask_{mask_mode}.cu"
         dest_path = gen_directory / filename
         source_paths.append(dest_path)
@@ -1209,7 +1224,7 @@ def gen_customize_batch_prefill_module(
     os.makedirs(gen_directory, exist_ok=True)
 
     source_paths = []
-    for mask_mode in [0, 1, 2]:
+    for mask_mode in [0, 1, 2, 3]:
         dest_path = (
             gen_directory / f"batch_prefill_paged_kernel_mask_{mask_mode}.cu"
         )
@@ -1286,7 +1301,7 @@ def gen_customize_batch_prefill_module(
     generated_inc_str = config_templ.render(**kwargs)
 
     source_paths = []
-    for mask_mode in [0, 1, 2]:
+    for mask_mode in [0, 1, 2, 3]:
         filename = f"batch_prefill_paged_sm90_kernel_mask_{mask_mode}.cu"
         dest_path = gen_directory / filename
         source_paths.append(dest_path)
```

flashinfer/jit/utils.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -98,4 +98,5 @@ def wrapper(func, args):
     0: "MaskMode::kNone",
     1: "MaskMode::kCausal",
     2: "MaskMode::kCustom",
+    3: "MaskMode::kMultiItemScoring",
 }
```
