add multi-item scoring #1

Merged · 2 commits · Apr 11, 2025
1 change: 1 addition & 0 deletions aot_build_utils/literal_map.py
@@ -18,6 +18,7 @@
0: "MaskMode::kNone",
1: "MaskMode::kCausal",
2: "MaskMode::kCustom",
3: "MaskMode::kMultiItemScoring",
}

pos_encoding_mode_literal = {
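For context, the new entry maps the integer code 3 onto the MaskMode::kMultiItemScoring literal used in the generated C++ sources. A hedged sketch of the enum these literals refer to; the exact definition lives in the FlashInfer attention headers and may differ, and only the value-3 entry is new in this PR.

```cpp
// Hedged sketch of the MaskMode enum that mask_mode_literal maps integer codes onto;
// the authoritative definition is in the FlashInfer headers.
enum class MaskMode {
  kNone = 0,              // no attention mask
  kCausal = 1,            // causal (lower-triangular) mask
  kCustom = 2,            // user-supplied custom mask
  kMultiItemScoring = 3,  // new in this PR: multi-item scoring mask
};
```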
30 changes: 28 additions & 2 deletions csrc/batch_prefill.cu
@@ -45,7 +45,9 @@ at::Tensor BatchPrefillWithKVCachePlan(
at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
int64_t head_dim_vo, bool causal) {
int64_t head_dim_vo, bool causal, std::optional<at::Tensor> prefix_len_ptr,
std::optional<at::Tensor> token_pos_in_items_ptr, std::optional<int64_t> token_pos_in_items_len,
std::optional<at::Tensor> max_item_len_ptr) {
size_t float_workspace_size_in_bytes =
float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
size_t int_workspace_size_in_bytes =
@@ -55,12 +57,20 @@ at::Tensor BatchPrefillWithKVCachePlan(

const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
// Check if the optional values have a value before accessing them
auto* prefix_len_p = prefix_len_ptr.has_value() ? prefix_len_ptr->data_ptr() : nullptr;
auto* token_pos_in_items_p =
token_pos_in_items_ptr.has_value() ? token_pos_in_items_ptr->data_ptr() : nullptr;
auto token_pos_in_items_v =
token_pos_in_items_len.has_value() ? token_pos_in_items_len.value() : 0;
auto* max_item_len_p = max_item_len_ptr.has_value() ? max_item_len_ptr->data_ptr() : nullptr;
cudaError_t status = PrefillPlan<IdType>(
float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(),
kv_indptr.data_ptr<IdType>(), total_num_rows, batch_size, num_qo_heads, num_kv_heads,
head_dim_qk, head_dim_vo, page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream);
head_dim_qk, head_dim_vo, page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream,
prefix_len_p, token_pos_in_items_p, token_pos_in_items_v, max_item_len_p);

TORCH_CHECK(status == cudaSuccess,
"Failed to plan prefill with error: ", cudaGetErrorString(status));
@@ -174,6 +184,14 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
}
params.padded_batch_size = plan_info.padded_batch_size;
params.max_total_num_rows = plan_info.total_num_rows;

// set the prefix_len_ptr, token_pos_in_items_ptr, token_pos_in_items_len, max_item_len_ptr
params.prefix_len_ptr = reinterpret_cast<uint32_t*>(plan_info.prefix_len_ptr);
params.token_pos_in_items_ptr =
reinterpret_cast<uint16_t*>(plan_info.token_pos_in_items_ptr);
params.token_pos_in_items_len = static_cast<uint32_t>(plan_info.token_pos_in_items_len);
params.max_item_len_ptr = reinterpret_cast<uint16_t*>(plan_info.max_item_len_ptr);

if (plan_info.enable_cuda_graph) {
params.total_num_rows =
GetPtrFromBaseOffset<uint32_t>(int_buffer_ptr, plan_info.total_num_rows_offset);
@@ -308,6 +326,14 @@ void BatchPrefillWithPagedKVCacheRun(at::Tensor float_workspace_buffer,
}
params.padded_batch_size = plan_info.padded_batch_size;
params.max_total_num_rows = plan_info.total_num_rows;

// set the prefix_len_ptr, token_pos_in_items_ptr, token_pos_in_items_len, max_item_len_ptr
params.prefix_len_ptr = reinterpret_cast<uint32_t*>(plan_info.prefix_len_ptr);
params.token_pos_in_items_ptr =
reinterpret_cast<uint16_t*>(plan_info.token_pos_in_items_ptr);
params.token_pos_in_items_len = static_cast<uint32_t>(plan_info.token_pos_in_items_len);
params.max_item_len_ptr = reinterpret_cast<uint16_t*>(plan_info.max_item_len_ptr);

if (plan_info.enable_cuda_graph) {
params.total_num_rows =
GetPtrFromBaseOffset<uint32_t>(int_buffer_ptr, plan_info.total_num_rows_offset);
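Because the four new plan arguments are std::optional, callers that do not use multi-item scoring can pass empty values; the plan function then forwards nullptr / 0 to PrefillPlan as shown above. A hedged sketch of such a call site follows; the surrounding variable names are illustrative, not taken from this PR.

```cpp
// Hedged sketch: a caller without multi-item scoring passes std::nullopt for the new
// arguments; only the multi-item scoring path supplies the four metadata tensors.
at::Tensor plan_info_vec = BatchPrefillWithKVCachePlan(
    float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer,
    qo_indptr, kv_indptr, kv_len_arr, total_num_rows, batch_size, num_qo_heads,
    num_kv_heads, page_size, enable_cuda_graph, head_dim_qk, head_dim_vo, causal,
    /*prefix_len_ptr=*/std::nullopt,
    /*token_pos_in_items_ptr=*/std::nullopt,
    /*token_pos_in_items_len=*/std::nullopt,
    /*max_item_len_ptr=*/std::nullopt);
```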
8 changes: 8 additions & 0 deletions csrc/batch_prefill_customize_config.jinja
@@ -68,6 +68,10 @@ struct RaggedParams {
uint32_t* total_num_rows;
uint32_t padded_batch_size;
bool partition_kv;
uint32_t* prefix_len_ptr;
uint16_t* token_pos_in_items_ptr;
uint32_t token_pos_in_items_len;
uint16_t* max_item_len_ptr;

__host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const {
return q_indptr[batch_idx + 1] - q_indptr[batch_idx];
@@ -108,6 +112,10 @@ struct PagedParams {
uint32_t* total_num_rows;
uint32_t padded_batch_size;
bool partition_kv;
uint32_t* prefix_len_ptr;
uint16_t* token_pos_in_items_ptr;
uint32_t token_pos_in_items_len;
uint16_t* max_item_len_ptr;

__host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const {
return q_indptr[batch_idx + 1] - q_indptr[batch_idx];
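The same four fields are added to both RaggedParams and PagedParams. A hedged sketch of one plausible reading of what each field carries; the names come from the structs above, but the semantics are an interpretation, since the kernel-side mask logic is not part of this diff view.

```cpp
// Hedged sketch (interpretation only): zero-initialized fields would mean multi-item
// scoring is disabled and the kernels behave as before.
struct MultiItemScoringFields {
  uint32_t* prefix_len_ptr = nullptr;          // per-request shared-prefix length
  uint16_t* token_pos_in_items_ptr = nullptr;  // per-token position within its item
  uint32_t token_pos_in_items_len = 0;         // length of the token_pos_in_items array
  uint16_t* max_item_len_ptr = nullptr;        // per-request maximum item length
};
```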
4 changes: 3 additions & 1 deletion csrc/batch_prefill_jit_pybind.cu
@@ -21,7 +21,9 @@ at::Tensor BatchPrefillWithKVCachePlan(
at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
int64_t head_dim_vo, bool causal);
int64_t head_dim_vo, bool causal, std::optional<at::Tensor> prefix_len_ptr,
std::optional<at::Tensor> token_pos_in_items_ptr, std::optional<int64_t> token_pos_in_items_len,
std::optional<at::Tensor> max_item_len_ptr);

void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
32 changes: 30 additions & 2 deletions csrc/batch_prefill_sm90.cu
@@ -43,7 +43,9 @@ at::Tensor BatchPrefillWithKVCacheSM90Plan(
at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
int64_t head_dim_vo, bool causal) {
int64_t head_dim_vo, bool causal, std::optional<at::Tensor> prefix_len_ptr,
std::optional<at::Tensor> token_pos_in_items_ptr, std::optional<int64_t> token_pos_in_items_len,
std::optional<at::Tensor> max_item_len_ptr) {
size_t float_workspace_size_in_bytes =
float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
size_t int_workspace_size_in_bytes =
@@ -53,14 +55,22 @@ at::Tensor BatchPrefillWithKVCacheSM90Plan(

const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
// Check if the optional values have a value before accessing them
auto* prefix_len_p = prefix_len_ptr.has_value() ? prefix_len_ptr->data_ptr() : nullptr;
auto* token_pos_in_items_p =
token_pos_in_items_ptr.has_value() ? token_pos_in_items_ptr->data_ptr() : nullptr;
auto token_pos_in_items_v =
token_pos_in_items_len.has_value() ? token_pos_in_items_len.value() : 0;
auto* max_item_len_p = max_item_len_ptr.has_value() ? max_item_len_ptr->data_ptr() : nullptr;

cudaError_t status =
PrefillSM90Plan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(),
kv_indptr.data_ptr<IdType>(), kv_len_arr.data_ptr<IdType>(), total_num_rows,
batch_size, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, page_size,
causal, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream);
causal, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream, prefix_len_p,
token_pos_in_items_p, token_pos_in_items_v, max_item_len_p);

TORCH_CHECK(status == cudaSuccess,
"PrefillSM90Plan failed with error: ", cudaGetErrorString(status));
@@ -141,6 +151,15 @@ void BatchPrefillWithRaggedKVCacheSM90Run(at::Tensor float_workspace_buffer,
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
params.work_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
params.batch_indices =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.batch_indices_offset);

// Set the multi-item scoring parameters
params.prefix_len_ptr = reinterpret_cast<uint32_t*>(plan_info.prefix_len_ptr);
params.token_pos_in_items_ptr =
reinterpret_cast<uint16_t*>(plan_info.token_pos_in_items_ptr);
params.token_pos_in_items_len = static_cast<uint32_t>(plan_info.token_pos_in_items_len);
params.max_item_len_ptr = reinterpret_cast<uint16_t*>(plan_info.max_item_len_ptr);

ADDITIONAL_PARAMS_SETTER

@@ -238,8 +257,17 @@ void BatchPrefillWithPagedKVCacheSM90Run(
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
params.work_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
params.batch_indices =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.batch_indices_offset);
params.kv_indices = static_cast<IdType*>(paged_kv_indices.data_ptr());

// Set the multi-item scoring parameters
params.prefix_len_ptr = reinterpret_cast<uint32_t*>(plan_info.prefix_len_ptr);
params.token_pos_in_items_ptr =
reinterpret_cast<uint16_t*>(plan_info.token_pos_in_items_ptr);
params.token_pos_in_items_len = static_cast<uint32_t>(plan_info.token_pos_in_items_len);
params.max_item_len_ptr = reinterpret_cast<uint16_t*>(plan_info.max_item_len_ptr);

ADDITIONAL_PARAMS_SETTER

bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads;
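The SM90 run paths recover work_indptr, the new batch_indices array, and the plan-time multi-item scoring pointers from offsets stored in plan_info. A minimal sketch of the offset-to-pointer helper this relies on, assuming the usual byte-offset convention; the real helper lives in FlashInfer's scheduler headers and may differ in detail.

```cpp
#include <cstdint>

// Hedged sketch: treat `offset` as a byte offset into the integer workspace buffer
// and return a typed pointer, mirroring calls such as
// GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.batch_indices_offset) above.
template <typename T>
T* GetPtrFromBaseOffset(void* base_ptr, int64_t offset) {
  return reinterpret_cast<T*>(static_cast<char*>(base_ptr) + offset);
}
```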
12 changes: 12 additions & 0 deletions csrc/batch_prefill_sm90_customize_config.jinja
@@ -43,6 +43,7 @@ struct RaggedParams {
IdType* kv_lens;
IdType* head_indices;
IdType* work_indptr;
IdType* batch_indices;

struct AdditionalParams {
{{ additional_params_decl }}
@@ -66,6 +67,11 @@ struct RaggedParams {
int window_left;

bool causal;

uint32_t* prefix_len_ptr;
uint16_t* token_pos_in_items_ptr;
uint32_t token_pos_in_items_len;
uint16_t* max_item_len_ptr;
};

struct PagedParams {
@@ -88,6 +94,7 @@ struct PagedParams {
IdType* kv_lens;
IdType* head_indices;
IdType* work_indptr;
IdType* batch_indices;

struct AdditionalParams {
{{ additional_params_decl }}
@@ -111,6 +118,11 @@ struct PagedParams {
int window_left;

bool causal;

uint32_t* prefix_len_ptr;
uint16_t* token_pos_in_items_ptr;
uint32_t token_pos_in_items_len;
uint16_t* max_item_len_ptr;
};

{{ variant_decl }}
4 changes: 3 additions & 1 deletion csrc/batch_prefill_sm90_jit_pybind.cu
@@ -21,7 +21,9 @@ at::Tensor BatchPrefillWithKVCacheSM90Plan(
at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
int64_t head_dim_vo, bool causal);
int64_t head_dim_vo, bool causal, std::optional<at::Tensor> prefix_len_ptr,
std::optional<at::Tensor> token_pos_in_items_ptr, std::optional<int64_t> token_pos_in_items_len,
std::optional<at::Tensor> max_item_len_ptr);

void BatchPrefillWithRaggedKVCacheSM90Run(at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
4 changes: 3 additions & 1 deletion csrc/flashinfer_ops.cu
@@ -105,7 +105,9 @@ at::Tensor BatchPrefillWithKVCachePlan(
at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
int64_t head_dim_vo, bool causal);
int64_t head_dim_vo, bool causal, std::optional<at::Tensor> prefix_len_ptr,
std::optional<at::Tensor> token_pos_in_items_ptr, std::optional<int64_t> token_pos_in_items_len,
std::optional<at::Tensor> max_item_len_ptr);

void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
4 changes: 3 additions & 1 deletion csrc/flashinfer_ops_sm90.cu
@@ -32,7 +32,9 @@ at::Tensor BatchPrefillWithKVCacheSM90Plan(
at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
int64_t head_dim_vo, bool causal);
int64_t head_dim_vo, bool causal, std::optional<at::Tensor> prefix_len_ptr,
std::optional<at::Tensor> token_pos_in_items_ptr, std::optional<int64_t> token_pos_in_items_len,
std::optional<at::Tensor> max_item_len_ptr);

void BatchPrefillWithRaggedKVCacheSM90Run(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
12 changes: 6 additions & 6 deletions flashinfer/jit/attention/pytorch.py
@@ -629,8 +629,8 @@ def gen_customize_pod_module(

source_paths = []

for mask_mode_p in [0, 1, 2]:
for mask_mode_d in [0, 1, 2]:
for mask_mode_p in [0, 1, 2, 3]:
for mask_mode_d in [0, 1, 2, 3]:
kwargs["mask_mode_p"] = mask_mode_literal[mask_mode_p]
kwargs["mask_mode_d"] = mask_mode_literal[mask_mode_d]

@@ -929,7 +929,7 @@ def gen_customize_single_prefill_module(
os.makedirs(gen_directory, exist_ok=True)

source_paths = []
for mask_mode in [0, 1, 2]:
for mask_mode in [0, 1, 2, 3]:
filename = f"single_prefill_kernel_mask_{mask_mode}.cu"
dest_path = gen_directory / filename
source_paths.append(dest_path)
@@ -987,7 +987,7 @@ def gen_customize_single_prefill_module(
os.makedirs(gen_directory, exist_ok=True)

source_paths = []
for mask_mode in [0, 1, 2]:
for mask_mode in [0, 1, 2, 3]:
filename = f"single_prefill_sm90_kernel_mask_{mask_mode}.cu"
dest_path = gen_directory / filename
source_paths.append(dest_path)
@@ -1170,7 +1170,7 @@ def gen_customize_batch_prefill_module(
os.makedirs(gen_directory, exist_ok=True)

source_paths = []
for mask_mode in [0, 1, 2]:
for mask_mode in [0, 1, 2, 3]:
dest_path = (
gen_directory / f"batch_prefill_paged_kernel_mask_{mask_mode}.cu"
)
@@ -1243,7 +1243,7 @@ def gen_customize_batch_prefill_module(
generated_inc_str = config_templ.render(**kwargs)

source_paths = []
for mask_mode in [0, 1, 2]:
for mask_mode in [0, 1, 2, 3]:
filename = f"batch_prefill_paged_sm90_kernel_mask_{mask_mode}.cu"
dest_path = gen_directory / filename
source_paths.append(dest_path)
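Extending the generation loops from [0, 1, 2] to [0, 1, 2, 3] means every kernel family gains a fourth variant (e.g. batch_prefill_paged_kernel_mask_3.cu) whose template is rendered with the MaskMode::kMultiItemScoring literal. A hedged sketch of the integer-to-literal mapping the generator relies on, written here in C++ for illustration; the real mapping is the Python mask_mode_literal dict shown in this PR.

```cpp
#include <cstdint>
#include <string>

// Hedged sketch mirroring mask_mode_literal: the generator substitutes this string
// into the jinja templates when rendering each kernel variant.
std::string MaskModeLiteral(int64_t mask_mode) {
  switch (mask_mode) {
    case 0: return "MaskMode::kNone";
    case 1: return "MaskMode::kCausal";
    case 2: return "MaskMode::kCustom";
    case 3: return "MaskMode::kMultiItemScoring";  // added in this PR
    default: return "MaskMode::kNone";             // fallback for illustration only
  }
}
```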
1 change: 1 addition & 0 deletions flashinfer/jit/utils.py
@@ -98,4 +98,5 @@ def wrapper(func, args):
0: "MaskMode::kNone",
1: "MaskMode::kCausal",
2: "MaskMode::kCustom",
3: "MaskMode::kMultiItemScoring",
}