Commit 1f79f0d

Merge pull request #1 from arde171/arde/mis
add multi-item scoring
2 parents f579ca2 + 70bd358 commit 1f79f0d

33 files changed: +1272 −166 lines

aot_build_utils/literal_map.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
     0: "MaskMode::kNone",
     1: "MaskMode::kCausal",
     2: "MaskMode::kCustom",
+    3: "MaskMode::kMultiItemScoring",
 }

 pos_encoding_mode_literal = {
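
Note: MaskMode::kMultiItemScoring is the literal the JIT layer splices into generated kernel sources for the new mask mode 3. As orientation only, here is a rough Python sketch of the visibility rule this mode plausibly implements, inferred from the parameter names introduced in this PR (prefix_len_ptr, token_pos_in_items_ptr); the encoding of token_pos_in_items (1-based offset within each item) is an assumption, and the authoritative rule lives in the CUDA mask logic, not here.

import numpy as np

def multi_item_mask(prefix_len: int, token_pos_in_items: np.ndarray) -> np.ndarray:
    """Boolean [n, n] mask: True where query token i may attend to key token j.

    Assumed layout: [shared prefix | item 0 | item 1 | ...], with
    token_pos_in_items[t] the offset of post-prefix token t inside its item,
    so t - token_pos_in_items[t] identifies the start of t's item.
    """
    n = prefix_len + len(token_pos_in_items)
    pos = np.zeros(n, dtype=np.int64)
    pos[prefix_len:] = token_pos_in_items
    item_start = np.arange(n) - pos   # equal for all tokens of the same item
    item_start[:prefix_len] = -1      # the prefix acts as one shared block
    i = np.arange(n)[:, None]         # query index
    j = np.arange(n)[None, :]         # key index
    causal = j <= i
    sees_prefix = j < prefix_len
    same_item = item_start[i] == item_start[j]
    return causal & (sees_prefix | same_item)

# Example: a 2-token prefix plus two 2-token items. Token 4 (the first token
# of item 1) sees the prefix and itself, but nothing of item 0.
print(multi_item_mask(2, np.array([1, 2, 1, 2])))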

csrc/batch_prefill.cu

Lines changed: 28 additions & 2 deletions
@@ -45,7 +45,9 @@ at::Tensor BatchPrefillWithKVCachePlan(
     at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
     at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
     int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
-    int64_t head_dim_vo, bool causal) {
+    int64_t head_dim_vo, bool causal, std::optional<at::Tensor> prefix_len_ptr,
+    std::optional<at::Tensor> token_pos_in_items_ptr, std::optional<int64_t> token_pos_in_items_len,
+    std::optional<at::Tensor> max_item_len_ptr) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
   size_t int_workspace_size_in_bytes =
@@ -55,12 +57,20 @@ at::Tensor BatchPrefillWithKVCachePlan(

   const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
   const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
+  // Check if the optional values have a value before accessing them
+  auto* prefix_len_p = prefix_len_ptr.has_value() ? prefix_len_ptr->data_ptr() : nullptr;
+  auto* token_pos_in_items_p =
+      token_pos_in_items_ptr.has_value() ? token_pos_in_items_ptr->data_ptr() : nullptr;
+  auto token_pos_in_items_v =
+      token_pos_in_items_len.has_value() ? token_pos_in_items_len.value() : 0;
+  auto* max_item_len_p = max_item_len_ptr.has_value() ? max_item_len_ptr->data_ptr() : nullptr;
   cudaError_t status = PrefillPlan<IdType>(
       float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
       int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
       int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(),
       kv_indptr.data_ptr<IdType>(), total_num_rows, batch_size, num_qo_heads, num_kv_heads,
-      head_dim_qk, head_dim_vo, page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream);
+      head_dim_qk, head_dim_vo, page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream,
+      prefix_len_p, token_pos_in_items_p, token_pos_in_items_v, max_item_len_p);

   TORCH_CHECK(status == cudaSuccess,
               "Failed to plan prefill with error: ", cudaGetErrorString(status));
@@ -174,6 +184,14 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
   }
   params.padded_batch_size = plan_info.padded_batch_size;
   params.max_total_num_rows = plan_info.total_num_rows;
+
+  // set the prefix_len_ptr, token_pos_in_items_ptr, token_pos_in_items_len, max_item_len_ptr
+  params.prefix_len_ptr = reinterpret_cast<uint32_t*>(plan_info.prefix_len_ptr);
+  params.token_pos_in_items_ptr =
+      reinterpret_cast<uint16_t*>(plan_info.token_pos_in_items_ptr);
+  params.token_pos_in_items_len = static_cast<uint32_t>(plan_info.token_pos_in_items_len);
+  params.max_item_len_ptr = reinterpret_cast<uint16_t*>(plan_info.max_item_len_ptr);
+
   if (plan_info.enable_cuda_graph) {
     params.total_num_rows =
         GetPtrFromBaseOffset<uint32_t>(int_buffer_ptr, plan_info.total_num_rows_offset);
@@ -308,6 +326,14 @@ void BatchPrefillWithPagedKVCacheRun(at::Tensor float_workspace_buffer,
   }
   params.padded_batch_size = plan_info.padded_batch_size;
   params.max_total_num_rows = plan_info.total_num_rows;
+
+  // set the prefix_len_ptr, token_pos_in_items_ptr, token_pos_in_items_len, max_item_len_ptr
+  params.prefix_len_ptr = reinterpret_cast<uint32_t*>(plan_info.prefix_len_ptr);
+  params.token_pos_in_items_ptr =
+      reinterpret_cast<uint16_t*>(plan_info.token_pos_in_items_ptr);
+  params.token_pos_in_items_len = static_cast<uint32_t>(plan_info.token_pos_in_items_len);
+  params.max_item_len_ptr = reinterpret_cast<uint16_t*>(plan_info.max_item_len_ptr);
+
   if (plan_info.enable_cuda_graph) {
     params.total_num_rows =
         GetPtrFromBaseOffset<uint32_t>(int_buffer_ptr, plan_info.total_num_rows_offset);
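
All four new plan arguments are std::optional, so existing callers are unaffected: absent values collapse to nullptr / 0 before reaching PrefillPlan. Below is a hedged host-side sketch of what a multi-item caller might hand over; the dtypes are assumptions inferred from the reinterpret_casts in the run functions (4-byte prefix lengths, 2-byte in-item positions), and the binding name plus the elided plan arguments are placeholders, not the real signature.

import torch

device = "cuda"

# One request: a 64-token shared prefix followed by two 16-token items.
prefix_len = torch.tensor([64], dtype=torch.int32, device=device)  # read as uint32* on the C++ side
token_pos_in_items = torch.cat([
    torch.arange(1, 17, dtype=torch.int16),   # positions inside item 0
    torch.arange(1, 17, dtype=torch.int16),   # positions inside item 1
]).to(device)                                 # read as uint16* on the C++ side
max_item_len = torch.tensor([16], dtype=torch.int16, device=device)

# Hypothetical call shape ("..." stands for the unchanged plan arguments):
# module.BatchPrefillWithKVCachePlan(
#     ..., causal,
#     prefix_len, token_pos_in_items, token_pos_in_items.numel(), max_item_len)
# Passing None for all four recovers the old single-item behavior.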

csrc/batch_prefill_customize_config.jinja

Lines changed: 8 additions & 0 deletions
@@ -68,6 +68,10 @@ struct RaggedParams {
   uint32_t* total_num_rows;
   uint32_t padded_batch_size;
   bool partition_kv;
+  uint32_t* prefix_len_ptr;
+  uint16_t* token_pos_in_items_ptr;
+  uint32_t token_pos_in_items_len;
+  uint16_t* max_item_len_ptr;

   __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const {
     return q_indptr[batch_idx + 1] - q_indptr[batch_idx];
@@ -108,6 +112,10 @@ struct PagedParams {
   uint32_t* total_num_rows;
   uint32_t padded_batch_size;
   bool partition_kv;
+  uint32_t* prefix_len_ptr;
+  uint16_t* token_pos_in_items_ptr;
+  uint32_t token_pos_in_items_len;
+  uint16_t* max_item_len_ptr;

   __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const {
     return q_indptr[batch_idx + 1] - q_indptr[batch_idx];

csrc/batch_prefill_jit_pybind.cu

Lines changed: 3 additions & 1 deletion
@@ -21,7 +21,9 @@ at::Tensor BatchPrefillWithKVCachePlan(
     at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
     at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
     int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
-    int64_t head_dim_vo, bool causal);
+    int64_t head_dim_vo, bool causal, std::optional<at::Tensor> prefix_len_ptr,
+    std::optional<at::Tensor> token_pos_in_items_ptr, std::optional<int64_t> token_pos_in_items_len,
+    std::optional<at::Tensor> max_item_len_ptr);

 void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
                                       at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,

csrc/batch_prefill_sm90.cu

Lines changed: 30 additions & 2 deletions
@@ -43,7 +43,9 @@ at::Tensor BatchPrefillWithKVCacheSM90Plan(
     at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
     at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
     int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
-    int64_t head_dim_vo, bool causal) {
+    int64_t head_dim_vo, bool causal, std::optional<at::Tensor> prefix_len_ptr,
+    std::optional<at::Tensor> token_pos_in_items_ptr, std::optional<int64_t> token_pos_in_items_len,
+    std::optional<at::Tensor> max_item_len_ptr) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
   size_t int_workspace_size_in_bytes =
@@ -53,14 +55,22 @@ at::Tensor BatchPrefillWithKVCacheSM90Plan(

   const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
   cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
+  // Check if the optional values have a value before accessing them
+  auto* prefix_len_p = prefix_len_ptr.has_value() ? prefix_len_ptr->data_ptr() : nullptr;
+  auto* token_pos_in_items_p =
+      token_pos_in_items_ptr.has_value() ? token_pos_in_items_ptr->data_ptr() : nullptr;
+  auto token_pos_in_items_v =
+      token_pos_in_items_len.has_value() ? token_pos_in_items_len.value() : 0;
+  auto* max_item_len_p = max_item_len_ptr.has_value() ? max_item_len_ptr->data_ptr() : nullptr;

   cudaError_t status =
       PrefillSM90Plan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
                       int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
                       int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(),
                       kv_indptr.data_ptr<IdType>(), kv_len_arr.data_ptr<IdType>(), total_num_rows,
                       batch_size, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, page_size,
-                      causal, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream);
+                      causal, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream, prefix_len_p,
+                      token_pos_in_items_p, token_pos_in_items_v, max_item_len_p);

   TORCH_CHECK(status == cudaSuccess,
               "PrefillSM90Plan failed with error: ", cudaGetErrorString(status));
@@ -141,6 +151,15 @@ void BatchPrefillWithRaggedKVCacheSM90Run(at::Tensor float_workspace_buffer,
       GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
   params.work_indptr =
       GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
+  params.batch_indices =
+      GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.batch_indices_offset);
+
+  // Set the multi-item scoring parameters
+  params.prefix_len_ptr = reinterpret_cast<uint32_t*>(plan_info.prefix_len_ptr);
+  params.token_pos_in_items_ptr =
+      reinterpret_cast<uint16_t*>(plan_info.token_pos_in_items_ptr);
+  params.token_pos_in_items_len = static_cast<uint32_t>(plan_info.token_pos_in_items_len);
+  params.max_item_len_ptr = reinterpret_cast<uint16_t*>(plan_info.max_item_len_ptr);

   ADDITIONAL_PARAMS_SETTER

@@ -238,8 +257,17 @@ void BatchPrefillWithPagedKVCacheSM90Run(
       GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
   params.work_indptr =
       GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
+  params.batch_indices =
+      GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.batch_indices_offset);
   params.kv_indices = static_cast<IdType*>(paged_kv_indices.data_ptr());

+  // Set the multi-item scoring parameters
+  params.prefix_len_ptr = reinterpret_cast<uint32_t*>(plan_info.prefix_len_ptr);
+  params.token_pos_in_items_ptr =
+      reinterpret_cast<uint16_t*>(plan_info.token_pos_in_items_ptr);
+  params.token_pos_in_items_len = static_cast<uint32_t>(plan_info.token_pos_in_items_len);
+  params.max_item_len_ptr = reinterpret_cast<uint16_t*>(plan_info.max_item_len_ptr);
+
   ADDITIONAL_PARAMS_SETTER

   bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads;
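
Beyond mirroring the plan changes from csrc/batch_prefill.cu, the SM90 run paths also wire params.batch_indices. A plausible reading (not confirmed by this diff alone): the SM90 scheduler distributes work tiles rather than whole requests, so each tile must first map itself back to a request index before it can read that request's prefix_len / max_item_len entries. A toy Python sketch of that indirection:

def tile_multi_item_metadata(work_idx, batch_indices, prefix_len, max_item_len):
    # Map a scheduled work tile back to the request it belongs to...
    b = batch_indices[work_idx]
    # ...then fetch that request's per-batch multi-item scoring metadata,
    # mirroring the params fields set above.
    return prefix_len[b], max_item_len[b]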

csrc/batch_prefill_sm90_customize_config.jinja

Lines changed: 12 additions & 0 deletions
@@ -43,6 +43,7 @@ struct RaggedParams {
   IdType* kv_lens;
   IdType* head_indices;
   IdType* work_indptr;
+  IdType* batch_indices;

   struct AdditionalParams {
     {{ additional_params_decl }}
@@ -66,6 +67,11 @@ struct RaggedParams {
   int window_left;

   bool causal;
+
+  uint32_t* prefix_len_ptr;
+  uint16_t* token_pos_in_items_ptr;
+  uint32_t token_pos_in_items_len;
+  uint16_t* max_item_len_ptr;
 };

 struct PagedParams {
@@ -88,6 +94,7 @@ struct PagedParams {
   IdType* kv_lens;
   IdType* head_indices;
   IdType* work_indptr;
+  IdType* batch_indices;

   struct AdditionalParams {
     {{ additional_params_decl }}
@@ -111,6 +118,11 @@ struct PagedParams {
   int window_left;

   bool causal;
+
+  uint32_t* prefix_len_ptr;
+  uint16_t* token_pos_in_items_ptr;
+  uint32_t token_pos_in_items_len;
+  uint16_t* max_item_len_ptr;
 };

 {{ variant_decl }}

csrc/batch_prefill_sm90_jit_pybind.cu

Lines changed: 3 additions & 1 deletion
@@ -21,7 +21,9 @@ at::Tensor BatchPrefillWithKVCacheSM90Plan(
     at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
     at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
     int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
-    int64_t head_dim_vo, bool causal);
+    int64_t head_dim_vo, bool causal, std::optional<at::Tensor> prefix_len_ptr,
+    std::optional<at::Tensor> token_pos_in_items_ptr, std::optional<int64_t> token_pos_in_items_len,
+    std::optional<at::Tensor> max_item_len_ptr);

 void BatchPrefillWithRaggedKVCacheSM90Run(at::Tensor float_workspace_buffer,
                                           at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,

csrc/flashinfer_ops.cu

Lines changed: 3 additions & 1 deletion
@@ -105,7 +105,9 @@ at::Tensor BatchPrefillWithKVCachePlan(
     at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
     at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
     int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
-    int64_t head_dim_vo, bool causal);
+    int64_t head_dim_vo, bool causal, std::optional<at::Tensor> prefix_len_ptr,
+    std::optional<at::Tensor> token_pos_in_items_ptr, std::optional<int64_t> token_pos_in_items_len,
+    std::optional<at::Tensor> max_item_len_ptr);

 void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
                                       at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,

csrc/flashinfer_ops_sm90.cu

Lines changed: 3 additions & 1 deletion
@@ -32,7 +32,9 @@ at::Tensor BatchPrefillWithKVCacheSM90Plan(
     at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
     at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
     int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
-    int64_t head_dim_vo, bool causal);
+    int64_t head_dim_vo, bool causal, std::optional<at::Tensor> prefix_len_ptr,
+    std::optional<at::Tensor> token_pos_in_items_ptr, std::optional<int64_t> token_pos_in_items_len,
+    std::optional<at::Tensor> max_item_len_ptr);

 void BatchPrefillWithRaggedKVCacheSM90Run(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,

flashinfer/jit/attention/pytorch.py

Lines changed: 6 additions & 6 deletions
@@ -629,8 +629,8 @@ def gen_customize_pod_module(

     source_paths = []

-    for mask_mode_p in [0, 1, 2]:
-        for mask_mode_d in [0, 1, 2]:
+    for mask_mode_p in [0, 1, 2, 3]:
+        for mask_mode_d in [0, 1, 2, 3]:
             kwargs["mask_mode_p"] = mask_mode_literal[mask_mode_p]
             kwargs["mask_mode_d"] = mask_mode_literal[mask_mode_d]

@@ -929,7 +929,7 @@ def gen_customize_single_prefill_module(
     os.makedirs(gen_directory, exist_ok=True)

     source_paths = []
-    for mask_mode in [0, 1, 2]:
+    for mask_mode in [0, 1, 2, 3]:
         filename = f"single_prefill_kernel_mask_{mask_mode}.cu"
         dest_path = gen_directory / filename
         source_paths.append(dest_path)
@@ -987,7 +987,7 @@ def gen_customize_single_prefill_module(
     os.makedirs(gen_directory, exist_ok=True)

     source_paths = []
-    for mask_mode in [0, 1, 2]:
+    for mask_mode in [0, 1, 2, 3]:
         filename = f"single_prefill_sm90_kernel_mask_{mask_mode}.cu"
         dest_path = gen_directory / filename
         source_paths.append(dest_path)
@@ -1170,7 +1170,7 @@ def gen_customize_batch_prefill_module(
     os.makedirs(gen_directory, exist_ok=True)

     source_paths = []
-    for mask_mode in [0, 1, 2]:
+    for mask_mode in [0, 1, 2, 3]:
         dest_path = (
             gen_directory / f"batch_prefill_paged_kernel_mask_{mask_mode}.cu"
         )
@@ -1243,7 +1243,7 @@ def gen_customize_batch_prefill_module(
     generated_inc_str = config_templ.render(**kwargs)

     source_paths = []
-    for mask_mode in [0, 1, 2]:
+    for mask_mode in [0, 1, 2, 3]:
         filename = f"batch_prefill_paged_sm90_kernel_mask_{mask_mode}.cu"
         dest_path = gen_directory / filename
         source_paths.append(dest_path)
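
Widening the mask-mode loops from [0, 1, 2] to [0, 1, 2, 3] means every generator now emits and compiles one extra kernel source per module; the runtime then selects a kernel by the integer mode. A toy dispatch sketch over the filenames visible above (the selector function is illustrative, not FlashInfer's actual API):

MASK_MODES = {0: "kNone", 1: "kCausal", 2: "kCustom", 3: "kMultiItemScoring"}

kernel_sources = {m: f"batch_prefill_paged_kernel_mask_{m}.cu" for m in MASK_MODES}

def pick_mask_mode(causal: bool, custom_mask: bool, multi_item: bool) -> int:
    # Hypothetical precedence: multi-item scoring, then an explicit custom
    # mask, then plain causal masking, else no mask at all.
    if multi_item:
        return 3
    if custom_mask:
        return 2
    return 1 if causal else 0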

flashinfer/jit/utils.py

Lines changed: 1 addition & 0 deletions
@@ -98,4 +98,5 @@ def wrapper(func, args):
     0: "MaskMode::kNone",
     1: "MaskMode::kCausal",
     2: "MaskMode::kCustom",
+    3: "MaskMode::kMultiItemScoring",
 }
