Commit 4bc79e6

arde171 committed: add multi-item scoring
Co-authored-by: qingquansong <[email protected]>
Co-authored-by: zianglih <[email protected]>

1 parent f579ca2

33 files changed: +1101 -106
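This commit threads a new mask mode, MaskMode::kMultiItemScoring, through the prefill planning and run paths, along with four pieces of per-request metadata: prefix_len_ptr, token_pos_in_items_ptr, token_pos_in_items_len, and max_item_len_ptr. The kernel changes themselves are outside this excerpt, so the following is only a rough sketch, under assumed semantics, of the mask such metadata can express: every query sees a shared prefix, and item tokens additionally see earlier tokens of their own item but nothing from other items.

```cpp
#include <cstdint>

// Hypothetical sketch of a multi-item scoring mask predicate (the real kernel
// logic is not part of this diff). Assumptions: positions < prefix_len form a
// prefix shared by all items, and token_pos_in_items[p - prefix_len] is token
// p's 0-based offset within its item, so two tokens belong to the same item
// exactly when they share an item start position.
bool multi_item_mask(uint32_t q_pos, uint32_t kv_pos, uint32_t prefix_len,
                     const uint16_t* token_pos_in_items) {
  if (kv_pos > q_pos) return false;      // still causal overall
  if (kv_pos < prefix_len) return true;  // the shared prefix is visible to all
  // Both positions are past the prefix here (since kv_pos <= q_pos).
  uint32_t q_item_start = q_pos - token_pos_in_items[q_pos - prefix_len];
  uint32_t kv_item_start = kv_pos - token_pos_in_items[kv_pos - prefix_len];
  return q_item_start == kv_item_start;  // attend within the same item only
}
```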

aot_build_utils/literal_map.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
     0: "MaskMode::kNone",
     1: "MaskMode::kCausal",
     2: "MaskMode::kCustom",
+    3: "MaskMode::kMultiItemScoring"
 }

 pos_encoding_mode_literal = {
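These string literals name members of the MaskMode enum on the C++ side; the enum definition itself is not part of this diff, but given the mapping above it presumably now reads roughly:

```cpp
// Assumed shape of the C++ enum behind mask_mode_literal (the actual
// definition lives in FlashInfer's headers and is not shown in this commit).
enum class MaskMode {
  kNone = 0,              // no mask
  kCausal = 1,            // standard causal mask
  kCustom = 2,            // user-provided custom mask
  kMultiItemScoring = 3,  // new in this commit
};
```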

csrc/batch_prefill.cu

Lines changed: 23 additions & 2 deletions
@@ -45,7 +45,8 @@ at::Tensor BatchPrefillWithKVCachePlan(
     at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
     at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
     int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
-    int64_t head_dim_vo, bool causal) {
+    int64_t head_dim_vo, bool causal, std::optional<at::Tensor> prefix_len_ptr, std::optional<at::Tensor> token_pos_in_items_ptr,
+    std::optional<int64_t> token_pos_in_items_len, std::optional<at::Tensor> max_item_len_ptr) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
   size_t int_workspace_size_in_bytes =
@@ -55,12 +56,18 @@ at::Tensor BatchPrefillWithKVCachePlan(

   const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
   const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
+  // Check if the optional values have a value before accessing them
+  auto* prefix_len_p = prefix_len_ptr.has_value() ? prefix_len_ptr->data_ptr() : nullptr;
+  auto* token_pos_in_items_p = token_pos_in_items_ptr.has_value() ? token_pos_in_items_ptr->data_ptr() : nullptr;
+  auto token_pos_in_items_v = token_pos_in_items_len.has_value() ? token_pos_in_items_len.value() : 0;
+  auto* max_item_len_p = max_item_len_ptr.has_value() ? max_item_len_ptr->data_ptr() : nullptr;
   cudaError_t status = PrefillPlan<IdType>(
       float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
       int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
       int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(),
       kv_indptr.data_ptr<IdType>(), total_num_rows, batch_size, num_qo_heads, num_kv_heads,
-      head_dim_qk, head_dim_vo, page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream);
+      head_dim_qk, head_dim_vo, page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream,
+      prefix_len_p, token_pos_in_items_p, token_pos_in_items_v, max_item_len_p);

   TORCH_CHECK(status == cudaSuccess,
               "Failed to plan prefill with error: ", cudaGetErrorString(status));
@@ -174,6 +181,13 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
   }
   params.padded_batch_size = plan_info.padded_batch_size;
   params.max_total_num_rows = plan_info.total_num_rows;
+
+  // set the prefix_len_ptr, token_pos_in_items_ptr, token_pos_in_items_len, max_item_len_ptr
+  params.prefix_len_ptr = reinterpret_cast<uint32_t*>(plan_info.prefix_len_ptr);
+  params.token_pos_in_items_ptr = reinterpret_cast<uint16_t*>(plan_info.token_pos_in_items_ptr);
+  params.token_pos_in_items_len = static_cast<uint32_t>(plan_info.token_pos_in_items_len);
+  params.max_item_len_ptr = reinterpret_cast<uint16_t*>(plan_info.max_item_len_ptr);
+
   if (plan_info.enable_cuda_graph) {
     params.total_num_rows =
         GetPtrFromBaseOffset<uint32_t>(int_buffer_ptr, plan_info.total_num_rows_offset);
@@ -308,6 +322,13 @@ void BatchPrefillWithPagedKVCacheRun(at::Tensor float_workspace_buffer,
   }
   params.padded_batch_size = plan_info.padded_batch_size;
   params.max_total_num_rows = plan_info.total_num_rows;
+
+  // set the prefix_len_ptr, token_pos_in_items_ptr, token_pos_in_items_len, max_item_len_ptr
+  params.prefix_len_ptr = reinterpret_cast<uint32_t*>(plan_info.prefix_len_ptr);
+  params.token_pos_in_items_ptr = reinterpret_cast<uint16_t*>(plan_info.token_pos_in_items_ptr);
+  params.token_pos_in_items_len = static_cast<uint32_t>(plan_info.token_pos_in_items_len);
+  params.max_item_len_ptr = reinterpret_cast<uint16_t*>(plan_info.max_item_len_ptr);
+
   if (plan_info.enable_cuda_graph) {
     params.total_num_rows =
         GetPtrFromBaseOffset<uint32_t>(int_buffer_ptr, plan_info.total_num_rows_offset);
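Both plan entry points (here and in the SM90 variant below) unwrap the new optional arguments with the same guard-or-default pattern before forwarding raw pointers into the planner. A self-contained sketch of that pattern, with at::Tensor swapped for a plain pointer so it compiles outside the extension:

```cpp
#include <cstdint>
#include <iostream>
#include <optional>

// Standalone illustration of the unwrap-or-default pattern used above: each
// std::optional collapses to nullptr / 0 when multi-item scoring is off, so
// existing callers that pass std::nullopt take the unchanged legacy path.
void PlanSketch(std::optional<uint32_t*> prefix_len_ptr,
                std::optional<int64_t> token_pos_in_items_len) {
  auto* prefix_len_p = prefix_len_ptr.has_value() ? prefix_len_ptr.value() : nullptr;
  auto token_pos_in_items_v =
      token_pos_in_items_len.has_value() ? token_pos_in_items_len.value() : 0;
  std::cout << (prefix_len_p != nullptr ? "multi-item scoring on" : "off")
            << ", token_pos_in_items_len = " << token_pos_in_items_v << '\n';
}

int main() {
  PlanSketch(std::nullopt, std::nullopt);  // legacy call: defaults kick in
  uint32_t prefix_len = 128;
  PlanSketch(&prefix_len, 7);              // multi-item scoring call
}
```

std::optional::value_or(0) would express the same default more compactly; the explicit has_value() form above simply mirrors the diff.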

csrc/batch_prefill_customize_config.jinja

Lines changed: 8 additions & 0 deletions
@@ -68,6 +68,10 @@ struct RaggedParams {
   uint32_t* total_num_rows;
   uint32_t padded_batch_size;
   bool partition_kv;
+  uint32_t* prefix_len_ptr;
+  uint16_t* token_pos_in_items_ptr;
+  uint32_t token_pos_in_items_len;
+  uint16_t* max_item_len_ptr;

   __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const {
     return q_indptr[batch_idx + 1] - q_indptr[batch_idx];
@@ -108,6 +112,10 @@ struct PagedParams {
   uint32_t* total_num_rows;
   uint32_t padded_batch_size;
   bool partition_kv;
+  uint32_t* prefix_len_ptr;
+  uint16_t* token_pos_in_items_ptr;
+  uint32_t token_pos_in_items_len;
+  uint16_t* max_item_len_ptr;

   __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const {
     return q_indptr[batch_idx + 1] - q_indptr[batch_idx];

csrc/batch_prefill_jit_pybind.cu

Lines changed: 2 additions & 1 deletion
@@ -21,7 +21,8 @@ at::Tensor BatchPrefillWithKVCachePlan(
     at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
     at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
     int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
-    int64_t head_dim_vo, bool causal);
+    int64_t head_dim_vo, bool causal, std::optional<at::Tensor> prefix_len_ptr, std::optional<at::Tensor> token_pos_in_items_ptr,
+    std::optional<int64_t> token_pos_in_items_len, std::optional<at::Tensor> max_item_len_ptr);

 void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
                                       at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,

csrc/batch_prefill_sm90.cu

Lines changed: 25 additions & 2 deletions
@@ -43,7 +43,8 @@ at::Tensor BatchPrefillWithKVCacheSM90Plan(
     at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
     at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
     int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
-    int64_t head_dim_vo, bool causal) {
+    int64_t head_dim_vo, bool causal, std::optional<at::Tensor> prefix_len_ptr, std::optional<at::Tensor> token_pos_in_items_ptr,
+    std::optional<int64_t> token_pos_in_items_len, std::optional<at::Tensor> max_item_len_ptr) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
   size_t int_workspace_size_in_bytes =
@@ -53,14 +54,20 @@ at::Tensor BatchPrefillWithKVCacheSM90Plan(

   const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
   cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
+  // Check if the optional values have a value before accessing them
+  auto* prefix_len_p = prefix_len_ptr.has_value() ? prefix_len_ptr->data_ptr() : nullptr;
+  auto* token_pos_in_items_p = token_pos_in_items_ptr.has_value() ? token_pos_in_items_ptr->data_ptr() : nullptr;
+  auto token_pos_in_items_v = token_pos_in_items_len.has_value() ? token_pos_in_items_len.value() : 0;
+  auto* max_item_len_p = max_item_len_ptr.has_value() ? max_item_len_ptr->data_ptr() : nullptr;

   cudaError_t status =
       PrefillSM90Plan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
                       int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
                       int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(),
                       kv_indptr.data_ptr<IdType>(), kv_len_arr.data_ptr<IdType>(), total_num_rows,
                       batch_size, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, page_size,
-                      causal, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream);
+                      causal, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream,
+                      prefix_len_p, token_pos_in_items_p, token_pos_in_items_v, max_item_len_p);

   TORCH_CHECK(status == cudaSuccess,
               "PrefillSM90Plan failed with error: ", cudaGetErrorString(status));
@@ -141,6 +148,14 @@ void BatchPrefillWithRaggedKVCacheSM90Run(at::Tensor float_workspace_buffer,
        GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
    params.work_indptr =
        GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
+   params.batch_indices =
+       GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.batch_indices_offset);
+
+   // Set the multi-item scoring parameters
+   params.prefix_len_ptr = reinterpret_cast<uint32_t*>(plan_info.prefix_len_ptr);
+   params.token_pos_in_items_ptr = reinterpret_cast<uint16_t*>(plan_info.token_pos_in_items_ptr);
+   params.token_pos_in_items_len = static_cast<uint32_t>(plan_info.token_pos_in_items_len);
+   params.max_item_len_ptr = reinterpret_cast<uint16_t*>(plan_info.max_item_len_ptr);

    ADDITIONAL_PARAMS_SETTER

@@ -238,8 +253,16 @@
        GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
    params.work_indptr =
        GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
+   params.batch_indices =
+       GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.batch_indices_offset);
    params.kv_indices = static_cast<IdType*>(paged_kv_indices.data_ptr());

+   // Set the multi-item scoring parameters
+   params.prefix_len_ptr = reinterpret_cast<uint32_t*>(plan_info.prefix_len_ptr);
+   params.token_pos_in_items_ptr = reinterpret_cast<uint16_t*>(plan_info.token_pos_in_items_ptr);
+   params.token_pos_in_items_len = static_cast<uint32_t>(plan_info.token_pos_in_items_len);
+   params.max_item_len_ptr = reinterpret_cast<uint16_t*>(plan_info.max_item_len_ptr);
+
    ADDITIONAL_PARAMS_SETTER

    bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads;
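On the run side, the pattern in both this file and csrc/batch_prefill.cu is the same: the plan step stashes the device addresses of the multi-item scoring buffers in plan_info, and each run entry point reinterpret_casts them back into the typed pointers the params struct expects. A sketch of that handoff, assuming the plan-info fields are stored as plain integers (their exact declaration is not in this excerpt):

```cpp
#include <cstdint>

// Assumed layout: the plan info carries buffer addresses as integers so the
// plan can be serialized between the plan and run calls.
struct PrefillPlanInfoSketch {
  int64_t prefix_len_ptr;          // device address of per-request prefix lengths
  int64_t token_pos_in_items_ptr;  // device address of per-token item offsets
  int64_t token_pos_in_items_len;  // a length, not an address
  int64_t max_item_len_ptr;        // device address of per-request max item lengths
};

struct ParamsSketch {
  uint32_t* prefix_len_ptr;
  uint16_t* token_pos_in_items_ptr;
  uint32_t token_pos_in_items_len;
  uint16_t* max_item_len_ptr;
};

// Mirrors the four assignments repeated in each *Run entry point above.
void BindMultiItemScoring(ParamsSketch& params, const PrefillPlanInfoSketch& plan_info) {
  params.prefix_len_ptr = reinterpret_cast<uint32_t*>(plan_info.prefix_len_ptr);
  params.token_pos_in_items_ptr =
      reinterpret_cast<uint16_t*>(plan_info.token_pos_in_items_ptr);
  params.token_pos_in_items_len = static_cast<uint32_t>(plan_info.token_pos_in_items_len);
  params.max_item_len_ptr = reinterpret_cast<uint16_t*>(plan_info.max_item_len_ptr);
}
```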

csrc/batch_prefill_sm90_customize_config.jinja

Lines changed: 12 additions & 0 deletions
@@ -43,6 +43,7 @@ struct RaggedParams {
   IdType* kv_lens;
   IdType* head_indices;
   IdType* work_indptr;
+  IdType* batch_indices;

   struct AdditionalParams {
     {{ additional_params_decl }}
@@ -66,6 +67,11 @@ struct RaggedParams {
   int window_left;

   bool causal;
+
+  uint32_t* prefix_len_ptr;
+  uint16_t* token_pos_in_items_ptr;
+  uint32_t token_pos_in_items_len;
+  uint16_t* max_item_len_ptr;
 };

 struct PagedParams {
@@ -88,6 +94,7 @@ struct PagedParams {
   IdType* kv_lens;
   IdType* head_indices;
   IdType* work_indptr;
+  IdType* batch_indices;

   struct AdditionalParams {
     {{ additional_params_decl }}
@@ -111,6 +118,11 @@ struct PagedParams {
   int window_left;

   bool causal;
+
+  uint32_t* prefix_len_ptr;
+  uint16_t* token_pos_in_items_ptr;
+  uint32_t token_pos_in_items_len;
+  uint16_t* max_item_len_ptr;
 };

 {{ variant_decl }}

csrc/batch_prefill_sm90_jit_pybind.cu

Lines changed: 2 additions & 1 deletion
@@ -21,7 +21,8 @@ at::Tensor BatchPrefillWithKVCacheSM90Plan(
     at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
     at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
     int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
-    int64_t head_dim_vo, bool causal);
+    int64_t head_dim_vo, bool causal, std::optional<at::Tensor> prefix_len_ptr, std::optional<at::Tensor> token_pos_in_items_ptr,
+    std::optional<int64_t> token_pos_in_items_len, std::optional<at::Tensor> max_item_len_ptr);

 void BatchPrefillWithRaggedKVCacheSM90Run(at::Tensor float_workspace_buffer,
                                           at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,

csrc/flashinfer_ops.cu

Lines changed: 2 additions & 1 deletion
@@ -105,7 +105,8 @@ at::Tensor BatchPrefillWithKVCachePlan(
     at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
     at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
     int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
-    int64_t head_dim_vo, bool causal);
+    int64_t head_dim_vo, bool causal, std::optional<at::Tensor> prefix_len_ptr, std::optional<at::Tensor> token_pos_in_items_ptr,
+    std::optional<int64_t> token_pos_in_items_len, std::optional<at::Tensor> max_item_len_ptr);

 void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
                                       at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,

csrc/flashinfer_ops_sm90.cu

Lines changed: 2 additions & 1 deletion
@@ -32,7 +32,8 @@ at::Tensor BatchPrefillWithKVCacheSM90Plan(
     at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
     at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
     int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
-    int64_t head_dim_vo, bool causal);
+    int64_t head_dim_vo, bool causal, std::optional<at::Tensor> prefix_len_ptr, std::optional<at::Tensor> token_pos_in_items_ptr,
+    std::optional<int64_t> token_pos_in_items_len, std::optional<at::Tensor> max_item_len_ptr);

 void BatchPrefillWithRaggedKVCacheSM90Run(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,

flashinfer/jit/attention/pytorch.py

Lines changed: 6 additions & 6 deletions
@@ -629,8 +629,8 @@ def gen_customize_pod_module(

     source_paths = []

-    for mask_mode_p in [0, 1, 2]:
-        for mask_mode_d in [0, 1, 2]:
+    for mask_mode_p in [0, 1, 2, 3]:
+        for mask_mode_d in [0, 1, 2, 3]:
             kwargs["mask_mode_p"] = mask_mode_literal[mask_mode_p]
             kwargs["mask_mode_d"] = mask_mode_literal[mask_mode_d]

@@ -929,7 +929,7 @@ def gen_customize_single_prefill_module(
     os.makedirs(gen_directory, exist_ok=True)

     source_paths = []
-    for mask_mode in [0, 1, 2]:
+    for mask_mode in [0, 1, 2, 3]:
         filename = f"single_prefill_kernel_mask_{mask_mode}.cu"
         dest_path = gen_directory / filename
         source_paths.append(dest_path)
@@ -987,7 +987,7 @@ def gen_customize_single_prefill_module(
     os.makedirs(gen_directory, exist_ok=True)

     source_paths = []
-    for mask_mode in [0, 1, 2]:
+    for mask_mode in [0, 1, 2, 3]:
         filename = f"single_prefill_sm90_kernel_mask_{mask_mode}.cu"
         dest_path = gen_directory / filename
         source_paths.append(dest_path)
@@ -1170,7 +1170,7 @@ def gen_customize_batch_prefill_module(
     os.makedirs(gen_directory, exist_ok=True)

     source_paths = []
-    for mask_mode in [0, 1, 2]:
+    for mask_mode in [0, 1, 2, 3]:
         dest_path = (
             gen_directory / f"batch_prefill_paged_kernel_mask_{mask_mode}.cu"
         )
@@ -1243,7 +1243,7 @@ def gen_customize_batch_prefill_module(
     generated_inc_str = config_templ.render(**kwargs)

     source_paths = []
-    for mask_mode in [0, 1, 2]:
+    for mask_mode in [0, 1, 2, 3]:
         filename = f"batch_prefill_paged_sm90_kernel_mask_{mask_mode}.cu"
         dest_path = gen_directory / filename
         source_paths.append(dest_path)

flashinfer/jit/utils.py

Lines changed: 1 addition & 0 deletions
@@ -98,4 +98,5 @@ def wrapper(func, args):
     0: "MaskMode::kNone",
     1: "MaskMode::kCausal",
     2: "MaskMode::kCustom",
+    3: "MaskMode::kMultiItemScoring",
 }
