diff --git a/aot_build_utils/generate_aot_default_additional_params_header.py b/aot_build_utils/generate_aot_default_additional_params_header.py index 2e832bb36..285c5a6e2 100644 --- a/aot_build_utils/generate_aot_default_additional_params_header.py +++ b/aot_build_utils/generate_aot_default_additional_params_header.py @@ -125,23 +125,35 @@ def get_aot_default_additional_params_header_str() -> str: ret += generate_macro_entry( "BATCH_PREFILL", - ["maybe_custom_mask", "maybe_mask_indptr", "maybe_alibi_slopes"], - ["uint8_t", "int32_t", "float"], + [ + "maybe_custom_mask", + "maybe_mask_indptr", + "maybe_alibi_slopes", + "maybe_prefix_len_ptr", + "maybe_token_pos_in_items_ptr", + "maybe_max_item_len_ptr", + ], + ["uint8_t", "int32_t", "float", "uint32_t", "uint16_t", "uint16_t"], [ "logits_soft_cap", "sm_scale", "rope_rcp_scale", "rope_rcp_theta", + "token_pos_in_items_len", ], - ["double", "double", "double", "double"], + ["double", "double", "double", "double", "int64_t"], ) ret += generate_macro_entry( "BATCH_PREFILL_SM90", - [], - [], - ["logits_soft_cap", "sm_scale"], - ["double", "double"], + [ + "maybe_prefix_len_ptr", + "maybe_token_pos_in_items_ptr", + "maybe_max_item_len_ptr", + ], + ["uint32_t", "uint16_t", "uint16_t"], + ["logits_soft_cap", "sm_scale", "token_pos_in_items_len"], + ["double", "double", "int64_t"], is_sm90_template=True, ) diff --git a/aot_build_utils/literal_map.py b/aot_build_utils/literal_map.py index 9001fda53..1a47da346 100644 --- a/aot_build_utils/literal_map.py +++ b/aot_build_utils/literal_map.py @@ -18,6 +18,7 @@ 0: "MaskMode::kNone", 1: "MaskMode::kCausal", 2: "MaskMode::kCustom", + 3: "MaskMode::kMultiItemScoring", } pos_encoding_mode_literal = { diff --git a/csrc/batch_prefill_sm90.cu b/csrc/batch_prefill_sm90.cu index 66a68b73f..fb70874a3 100644 --- a/csrc/batch_prefill_sm90.cu +++ b/csrc/batch_prefill_sm90.cu @@ -141,6 +141,8 @@ void BatchPrefillWithRaggedKVCacheSM90Run(at::Tensor float_workspace_buffer, GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset); params.work_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset); + params.batch_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.batch_indices_offset); ADDITIONAL_PARAMS_SETTER @@ -238,6 +240,8 @@ void BatchPrefillWithPagedKVCacheSM90Run( GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset); params.work_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset); + params.batch_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.batch_indices_offset); params.kv_indices = static_cast(paged_kv_indices.data_ptr()); ADDITIONAL_PARAMS_SETTER diff --git a/csrc/batch_prefill_sm90_customize_config.jinja b/csrc/batch_prefill_sm90_customize_config.jinja index 5b10355fc..aa2389349 100644 --- a/csrc/batch_prefill_sm90_customize_config.jinja +++ b/csrc/batch_prefill_sm90_customize_config.jinja @@ -43,6 +43,7 @@ struct RaggedParams { IdType* kv_lens; IdType* head_indices; IdType* work_indptr; + IdType* batch_indices; struct AdditionalParams { {{ additional_params_decl }} @@ -88,6 +89,7 @@ struct PagedParams { IdType* kv_lens; IdType* head_indices; IdType* work_indptr; + IdType* batch_indices; struct AdditionalParams { {{ additional_params_decl }} diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 32cda4c91..b9c952a99 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1170,6 +1170,9 @@ def run( None, # packed_custom_mask None, # mask_indptr_buf _get_cache_alibi_slopes_buf(q.shape[1], q.device), + None, 
# maybe_prefix_len_ptr + None, # maybe_token_pos_in_items_ptr + None, # maybe_max_item_len_ptr logits_soft_cap, sm_scale, None, # scale_q, not supported yet @@ -1177,6 +1180,7 @@ def run( None, # scale_v rope_scale, rope_theta, + 0, # token_pos_in_items_len ] self._cached_module.paged_run(*run_args) diff --git a/flashinfer/jit/attention/pytorch.py b/flashinfer/jit/attention/pytorch.py index 5f4a8a9bc..524ea19e5 100644 --- a/flashinfer/jit/attention/pytorch.py +++ b/flashinfer/jit/attention/pytorch.py @@ -645,8 +645,8 @@ def gen_customize_pod_module( source_paths = [] - for mask_mode_p in [0, 1, 2]: - for mask_mode_d in [0, 1, 2]: + for mask_mode_p in [0, 1, 2, 3]: + for mask_mode_d in [0, 1, 2, 3]: kwargs["mask_mode_p"] = mask_mode_literal[mask_mode_p] kwargs["mask_mode_d"] = mask_mode_literal[mask_mode_d] @@ -759,27 +759,42 @@ def gen_batch_prefill_module( "maybe_custom_mask", "maybe_mask_indptr", "maybe_alibi_slopes", + "maybe_prefix_len_ptr", + "maybe_token_pos_in_items_ptr", + "maybe_max_item_len_ptr", ] additional_tensor_dtypes = [ "uint8_t", "int32_t", "float", + "uint32_t", + "uint16_t", + "uint16_t", ] # NOTE(Zihao): int32_t should follow dtype_idx additional_scalar_names = [ "logits_soft_cap", "sm_scale", "rope_rcp_scale", "rope_rcp_theta", + "token_pos_in_items_len", ] - additional_scalar_dtypes = ["double", "double", "double", "double"] + additional_scalar_dtypes = ["double", "double", "double", "double", "int64_t"] variant_name = f"DefaultAttention" - variant_decl = f"#include" + variant_decl = "#include" else: if not fp8_enabled: - additional_tensor_names = [] - additional_tensor_dtypes = [] - additional_scalar_names = ["logits_soft_cap", "sm_scale"] - additional_scalar_dtypes = ["double", "double"] + additional_tensor_names = [ + "maybe_prefix_len_ptr", + "maybe_token_pos_in_items_ptr", + "maybe_max_item_len_ptr", + ] + additional_tensor_dtypes = ["uint32_t", "uint16_t", "uint16_t"] + additional_scalar_names = [ + "logits_soft_cap", + "sm_scale", + "token_pos_in_items_len", + ] + additional_scalar_dtypes = ["double", "double", "int64_t"] variant_name = f"DefaultAttention<{str(use_logits_soft_cap).lower()}>" variant_decl = f"#include" else: @@ -961,7 +976,7 @@ def gen_customize_single_prefill_module( os.makedirs(gen_directory, exist_ok=True) source_paths = [] - for mask_mode in [0, 1, 2]: + for mask_mode in [0, 1, 2, 3]: filename = f"single_prefill_kernel_mask_{mask_mode}.cu" dest_path = gen_directory / filename source_paths.append(dest_path) @@ -1025,7 +1040,7 @@ def gen_customize_single_prefill_module( os.makedirs(gen_directory, exist_ok=True) source_paths = [] - for mask_mode in [0, 1, 2]: + for mask_mode in [0, 1, 2, 3]: filename = f"single_prefill_sm90_kernel_mask_{mask_mode}.cu" dest_path = gen_directory / filename source_paths.append(dest_path) @@ -1209,7 +1224,7 @@ def gen_customize_batch_prefill_module( os.makedirs(gen_directory, exist_ok=True) source_paths = [] - for mask_mode in [0, 1, 2]: + for mask_mode in [0, 1, 2, 3]: dest_path = ( gen_directory / f"batch_prefill_paged_kernel_mask_{mask_mode}.cu" ) @@ -1286,7 +1301,7 @@ def gen_customize_batch_prefill_module( generated_inc_str = config_templ.render(**kwargs) source_paths = [] - for mask_mode in [0, 1, 2]: + for mask_mode in [0, 1, 2, 3]: filename = f"batch_prefill_paged_sm90_kernel_mask_{mask_mode}.cu" dest_path = gen_directory / filename source_paths.append(dest_path) diff --git a/flashinfer/jit/utils.py b/flashinfer/jit/utils.py index 01a698706..17fc4d356 100644 --- a/flashinfer/jit/utils.py +++ 
b/flashinfer/jit/utils.py @@ -98,4 +98,5 @@ def wrapper(func, args): 0: "MaskMode::kNone", 1: "MaskMode::kCausal", 2: "MaskMode::kCustom", + 3: "MaskMode::kMultiItemScoring", } diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 36e7e235e..88a773a54 100644 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -250,10 +250,14 @@ def ragged_run( maybe_custom_mask: Optional[torch.Tensor], maybe_mask_indptr: Optional[torch.Tensor], maybe_alibi_slopes: Optional[torch.Tensor], + maybe_prefix_len_ptr: Optional[torch.Tensor], + maybe_token_pos_in_items_ptr: Optional[torch.Tensor], + maybe_max_item_len_ptr: Optional[torch.Tensor], logits_soft_cap: float, sm_scale: float, rope_scale: float, rope_theta: float, + token_pos_in_items_len: int, ) -> None: if backend == "fa2": ragged_run_func( @@ -273,10 +277,14 @@ def ragged_run( maybe_custom_mask, maybe_mask_indptr, maybe_alibi_slopes, + maybe_prefix_len_ptr, + maybe_token_pos_in_items_ptr, + maybe_max_item_len_ptr, logits_soft_cap, sm_scale, 1.0 / rope_scale, # rope_rcp_scale 1.0 / rope_theta, # rope_rcp_theta + token_pos_in_items_len, ) else: ragged_run_func( @@ -293,8 +301,12 @@ def ragged_run( mask_mode, layout, window_left, + maybe_prefix_len_ptr, + maybe_token_pos_in_items_ptr, + maybe_max_item_len_ptr, logits_soft_cap, sm_scale, + token_pos_in_items_len, ) return o @@ -317,10 +329,14 @@ def _fake_ragged_run( maybe_custom_mask: Optional[torch.Tensor], maybe_mask_indptr: Optional[torch.Tensor], maybe_alibi_slopes: Optional[torch.Tensor], + maybe_prefix_len_ptr: Optional[torch.Tensor], + maybe_token_pos_in_items_ptr: Optional[torch.Tensor], + maybe_max_item_len_ptr: Optional[torch.Tensor], logits_soft_cap: float, sm_scale: float, rope_scale: float, rope_theta: float, + token_pos_in_items_len: int, ) -> None: pass @@ -356,6 +372,9 @@ def paged_run( maybe_custom_mask: Optional[torch.Tensor], maybe_mask_indptr: Optional[torch.Tensor], maybe_alibi_slopes: Optional[torch.Tensor], + maybe_prefix_len_ptr: Optional[torch.Tensor], + maybe_token_pos_in_items_ptr: Optional[torch.Tensor], + maybe_max_item_len_ptr: Optional[torch.Tensor], logits_soft_cap: float, sm_scale: float, scale_q: Optional[torch.Tensor], @@ -363,6 +382,7 @@ def paged_run( scale_v: Optional[torch.Tensor], rope_scale: float, rope_theta: float, + token_pos_in_items_len: int, ) -> None: if backend == "fa2": assert not is_float8(q) @@ -385,10 +405,14 @@ def paged_run( maybe_custom_mask, maybe_mask_indptr, maybe_alibi_slopes, + maybe_prefix_len_ptr, + maybe_token_pos_in_items_ptr, + maybe_max_item_len_ptr, logits_soft_cap, sm_scale, 1.0 / rope_scale, # rope_rcp_scale 1.0 / rope_theta, # rope_rcp_theta + token_pos_in_items_len, ) else: if not is_float8(q): @@ -408,8 +432,12 @@ def paged_run( mask_mode, layout, window_left, + maybe_prefix_len_ptr, + maybe_token_pos_in_items_ptr, + maybe_max_item_len_ptr, logits_soft_cap, sm_scale, + token_pos_in_items_len, ) else: paged_run_func( @@ -455,10 +483,14 @@ def _fake_paged_run( maybe_custom_mask: Optional[torch.Tensor], maybe_mask_indptr: Optional[torch.Tensor], maybe_alibi_slopes: Optional[torch.Tensor], + maybe_prefix_len_ptr: Optional[torch.Tensor], + maybe_token_pos_in_items_ptr: Optional[torch.Tensor], + maybe_max_item_len_ptr: Optional[torch.Tensor], logits_soft_cap: float, sm_scale: float, rope_scale: float, rope_theta: float, + token_pos_in_items_len: int, ) -> None: pass @@ -1280,6 +1312,10 @@ def plan( q_data_type: Union[str, torch.dtype] = "float16", kv_data_type: Optional[Union[str, torch.dtype]] = None, non_blocking: bool 
= True, + prefix_len_ptr: Optional[torch.Tensor] = None, + token_pos_in_items_ptr: Optional[torch.Tensor] = None, + token_pos_in_items_len: int = 0, + max_item_len_ptr: Optional[torch.Tensor] = None, ) -> None: r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification. @@ -1354,6 +1390,20 @@ def plan( The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`. non_blocking : bool Whether to copy the input tensors to the device asynchronously, defaults to ``True``. + prefix_len_ptr :Optional[torch.Tensor] + prefix length. A uint32 1D tensor indicating the prefix length of each prompt. The tensor size is equal to the batch size. + token_pos_in_items_ptr : Optional[float] + A uint16 1D tensor (it will be converted to uint16 in flashinfer) indicating the token position of each item and started from 0 (delimiter) + for each item. E.g., if we have 3 items of length 3, 2, 4 respectively for this member. This vector will be looking like + `[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]` with 4 delimiters indexed as 0. For batch size > 1, + we will concat them as 1D with zero paddings to make sure each has the same length, the padding length is defined by + `token_pos_in_items_len` - length of the raw `token_pos_in_items_ptr` for each prompt. + token_pos_in_items_len : int + zero padding length for `token_pos_in_items_ptr` to better handle the bsz > 1 case. Still using the above 3,2,4 example. + If we set `token_pos_in_items_len` to be 20, it will be `[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0]` + with 7 padded zeros. (note there're 8 zeros in the end where the first one is the delimiter token 0 in the end of the prompt) + max_item_len_ptr : Optional[float] + a uint16 vector contains the max token length of all items for each prompt Note ---- @@ -1393,6 +1443,11 @@ def plan( bitorder="little", ) + self._prefix_len_ptr = prefix_len_ptr + self._token_pos_in_items_ptr = token_pos_in_items_ptr + self._token_pos_in_items_len = token_pos_in_items_len + self._max_item_len_ptr = max_item_len_ptr + # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors qo_indptr_host = qo_indptr.to("cpu") paged_kv_indptr_host = paged_kv_indptr.to("cpu") @@ -1714,6 +1769,9 @@ def run( else: mask_mode = MaskMode.NON_CAUSAL.value + if self._prefix_len_ptr is not None: + mask_mode = MaskMode.MULTIITEMSCORING.value + if self._backend == "fa3": # NOTE(Zihao): we divide both stride_block and stride_n by stride_n # because we will multiply stride_n back in the kernel @@ -1757,6 +1815,9 @@ def run( self._custom_mask_buf, self._mask_indptr_buf, _get_cache_alibi_slopes_buf(q.shape[1], q.device), + self._prefix_len_ptr, + self._token_pos_in_items_ptr, + self._max_item_len_ptr, logits_soft_cap, sm_scale, None, # scale_q, not supported yet @@ -1764,6 +1825,7 @@ def run( None, # scale_v rope_scale, rope_theta, + self._token_pos_in_items_len, ] self._cached_module.paged_run(*run_args) @@ -2066,6 +2128,10 @@ def plan( q_data_type: Union[str, torch.dtype] = "float16", kv_data_type: Optional[Union[str, torch.dtype]] = None, non_blocking: bool = True, + prefix_len_ptr: Optional[torch.Tensor] = None, + token_pos_in_items_ptr: Optional[torch.Tensor] = None, + token_pos_in_items_len: int = 0, + max_item_len_ptr: Optional[torch.Tensor] = None, ) -> None: r"""Plan batch prefill/append attention on Ragged KV-Cache for given problem specification. @@ -2135,6 +2201,20 @@ def plan( The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`. 
non_blocking : bool Whether to copy the input tensors to the device asynchronously, defaults to ``True``. + prefix_len_ptr :Optional[torch.Tensor] + prefix length. A uint32 1D tensor indicating the prefix length of each prompt. The tensor size is equal to the batch size. + token_pos_in_items_ptr : Optional[float] + A uint16 1D tensor (it will be converted to uint16 in flashinfer) indicating the token position of each item and started from 0 (delimiter) + for each item. E.g., if we have 3 items of length 3, 2, 4 respectively for this member. This vector will be looking like + `[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]` with 4 delimiters indexed as 0. For batch size > 1, + we will concat them as 1D with zero paddings to make sure each has the same length, the padding length is defined by + `token_pos_in_items_len` - length of the raw `token_pos_in_items_ptr` for each prompt. + token_pos_in_items_len : int + zero padding length for `token_pos_in_items_ptr` to better handle the bsz > 1 case. Still using the above 3,2,4 example. + If we set `token_pos_in_items_len` to be 20, it will be `[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0]` + with 7 padded zeros. (note there're 8 zeros in the end where the first one is the delimiter token 0 in the end of the prompt) + max_item_len_ptr : Optional[float] + a uint16 vector contains the max token length of all items for each prompt Note ---- @@ -2225,6 +2305,11 @@ def plan( self._cached_kv_data_type = kv_data_type kv_len_arr = kv_indptr_host[1:] - kv_indptr_host[:-1] + self._prefix_len_ptr = prefix_len_ptr + self._token_pos_in_items_ptr = token_pos_in_items_ptr + self._token_pos_in_items_len = token_pos_in_items_len + self._max_item_len_ptr = max_item_len_ptr + if self._jit_module is not None: self._cached_module = self._jit_module else: @@ -2445,10 +2530,14 @@ def run( self._custom_mask_buf, self._mask_indptr_buf, _get_cache_alibi_slopes_buf(q.shape[1], self.device), + self._prefix_len_ptr, + self._token_pos_in_items_ptr, + self._max_item_len_ptr, logits_soft_cap, sm_scale, rope_scale, rope_theta, + self._token_pos_in_items_len, ] self._cached_module.ragged_run(*run_args) diff --git a/flashinfer/sparse.py b/flashinfer/sparse.py index 44e1bb23b..6893f9821 100644 --- a/flashinfer/sparse.py +++ b/flashinfer/sparse.py @@ -615,6 +615,9 @@ def run( self._packed_mask_buf, self._mask_indptr_buf, _get_cache_alibi_slopes_buf(q.shape[1], self.device), + None, # maybe_prefix_len_ptr + None, # maybe_token_pos_in_items_ptr + None, # maybe_max_item_len_ptr logits_soft_cap, sm_scale, scale_q, @@ -622,6 +625,7 @@ def run( scale_v, rope_scale, rope_theta, + 0, # token_pos_in_items_len ) else: self._cached_module.run( diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 7f22791f6..0f5696a67 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -37,6 +37,7 @@ class MaskMode(Enum): NON_CAUSAL = 0 CAUSAL = 1 CUSTOM = 2 + MULTIITEMSCORING = 3 class TensorLayout(Enum): diff --git a/include/flashinfer/attention/default_prefill_params.cuh b/include/flashinfer/attention/default_prefill_params.cuh index b10ebec03..2e857fcc7 100644 --- a/include/flashinfer/attention/default_prefill_params.cuh +++ b/include/flashinfer/attention/default_prefill_params.cuh @@ -172,6 +172,10 @@ struct BatchPrefillRaggedParams { uint32_t* total_num_rows; uint32_t padded_batch_size; bool partition_kv; + uint32_t* maybe_prefix_len_ptr; + uint16_t* maybe_token_pos_in_items_ptr; + uint32_t token_pos_in_items_len; + uint16_t* maybe_max_item_len_ptr; __host__ 
BatchPrefillRaggedParams() : q(nullptr), @@ -210,7 +214,11 @@ struct BatchPrefillRaggedParams { max_total_num_rows(0), total_num_rows(nullptr), padded_batch_size(0), - partition_kv(false) {} + partition_kv(false), + maybe_prefix_len_ptr(nullptr), + maybe_token_pos_in_items_ptr(nullptr), + token_pos_in_items_len(0), + maybe_max_item_len_ptr(nullptr) {} __host__ BatchPrefillRaggedParams(DTypeQ* q, DTypeKV* k, DTypeKV* v, uint8_t* maybe_custom_mask, IdType* q_indptr, IdType* kv_indptr, IdType* maybe_mask_indptr, @@ -257,7 +265,11 @@ struct BatchPrefillRaggedParams { max_total_num_rows(0), total_num_rows(nullptr), padded_batch_size(0), - partition_kv(false) {} + partition_kv(false), + maybe_prefix_len_ptr(nullptr), + maybe_token_pos_in_items_ptr(nullptr), + token_pos_in_items_len(0), + maybe_max_item_len_ptr(nullptr) {} __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { return q_indptr[batch_idx + 1] - q_indptr[batch_idx]; @@ -305,6 +317,10 @@ struct BatchPrefillPagedParams { uint32_t* total_num_rows; uint32_t padded_batch_size; bool partition_kv; + uint32_t* maybe_prefix_len_ptr; + uint16_t* maybe_token_pos_in_items_ptr; + uint32_t token_pos_in_items_len; + uint16_t* maybe_max_item_len_ptr; __host__ BatchPrefillPagedParams() : q(nullptr), @@ -335,7 +351,11 @@ struct BatchPrefillPagedParams { max_total_num_rows(0), total_num_rows(nullptr), padded_batch_size(0), - partition_kv(false) {} + partition_kv(false), + maybe_prefix_len_ptr(nullptr), + maybe_token_pos_in_items_ptr(nullptr), + token_pos_in_items_len(0), + maybe_max_item_len_ptr(nullptr) {} __host__ BatchPrefillPagedParams(DTypeQ* q, paged_kv_t paged_kv, uint8_t* maybe_custom_mask, IdType* q_indptr, @@ -372,7 +392,11 @@ struct BatchPrefillPagedParams { max_total_num_rows(0), total_num_rows(nullptr), padded_batch_size(0), - partition_kv(false) {} + partition_kv(false), + maybe_prefix_len_ptr(nullptr), + maybe_token_pos_in_items_ptr(nullptr), + token_pos_in_items_len(0), + maybe_max_item_len_ptr(nullptr) {} __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { return q_indptr[batch_idx + 1] - q_indptr[batch_idx]; diff --git a/include/flashinfer/attention/hopper/default_params.cuh b/include/flashinfer/attention/hopper/default_params.cuh index 293db6601..f2b9d2e33 100644 --- a/include/flashinfer/attention/hopper/default_params.cuh +++ b/include/flashinfer/attention/hopper/default_params.cuh @@ -82,10 +82,15 @@ struct BatchPrefillRaggedParams { IdType* kv_lens; IdType* head_indices; IdType* work_indptr; + IdType* batch_indices; struct AdditionalParams { float logits_soft_cap; float sm_scale; + uint32_t* maybe_prefix_len_ptr; + uint16_t* maybe_token_pos_in_items_ptr; + uint32_t token_pos_in_items_len; + uint16_t* maybe_max_item_len_ptr; } additional_params; int64_t q_stride_n; @@ -128,10 +133,15 @@ struct BatchPrefillPagedParams { IdType* kv_lens; IdType* head_indices; IdType* work_indptr; + IdType* batch_indices; struct AdditionalParams { float logits_soft_cap; float sm_scale; + uint32_t* maybe_prefix_len_ptr; + uint16_t* maybe_token_pos_in_items_ptr; + uint32_t token_pos_in_items_len; + uint16_t* maybe_max_item_len_ptr; } additional_params; int64_t q_stride_n; diff --git a/include/flashinfer/attention/hopper/epilogue.cuh b/include/flashinfer/attention/hopper/epilogue.cuh index c30049ca6..81e43bd9a 100644 --- a/include/flashinfer/attention/hopper/epilogue.cuh +++ b/include/flashinfer/attention/hopper/epilogue.cuh @@ -153,7 +153,7 @@ struct CollectiveEpilogue { CUTLASS_DEVICE void 
store(Params const& epilogue_params, FrgTensorO const& tOrO, FrgTensorLSE const& lse, SharedStorage& shared_storage, TiledMma tiled_mma, int thread_idx, BlockCoord const& block_coord) { - auto [qo_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = + auto [qo_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len, batch_idx] = block_coord; Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma); @@ -213,7 +213,7 @@ struct CollectiveEpilogue { template CUTLASS_DEVICE void store_zero(Params const& epilogue_params, SharedStorage& shared_storage, int thread_idx, BlockCoord const& block_coord) { - auto [qo_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = + auto [qo_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len, batch_idx] = block_coord; Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.O_ptr), epilogue_params.layout_O); Tensor gO = get_local_tile_tensor(mO, select<0, 1>(TileShape_PDV{}), qo_head_idx, qo_indptr, diff --git a/include/flashinfer/attention/hopper/mainloop.cuh b/include/flashinfer/attention/hopper/mainloop.cuh index 2a8f93620..e5bf4ffb9 100644 --- a/include/flashinfer/attention/hopper/mainloop.cuh +++ b/include/flashinfer/attention/hopper/mainloop.cuh @@ -167,7 +167,8 @@ struct CollectiveMainloop { Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape()); Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape()); - auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = block_coord; + auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len, batch_idx] = + block_coord; // Prepare the TMA loads Tensor gQ = get_local_tile_tensor(mQ, select<0, 2>(TileShape_QKD{}), qo_head_idx, qo_indptr, diff --git a/include/flashinfer/attention/hopper/mainloop_mma.cuh b/include/flashinfer/attention/hopper/mainloop_mma.cuh index d784e0a70..26e0d1cd5 100644 --- a/include/flashinfer/attention/hopper/mainloop_mma.cuh +++ b/include/flashinfer/attention/hopper/mainloop_mma.cuh @@ -14,19 +14,19 @@ namespace flashinfer { -template -CUTLASS_DEVICE void mma_f16(const Params& mainloop_params, AttentionVariant& variant, - MainloopPipeline pipeline_k, MainloopPipeline pipeline_v, - PipelineState& smem_pipe_read_k, PipelineState& smem_pipe_read_v, - FrgTensorO& tOrO, AttentionUpdater& attention_updater, - int kv_tile_idx_count, int swa_begin_kv_tile_idx, - int swa_end_kv_tile_idx, int thread_idx, int work_idx, int q_tile_idx, - SharedStorage& shared_storage, const int32_t qo_len, - const int32_t kv_len, const int32_t qo_head_idx, - const int32_t kv_head_idx) { +template +CUTLASS_DEVICE void mma_f16( + const Params& mainloop_params, AttentionVariant& variant, MainloopPipeline pipeline_k, + MainloopPipeline pipeline_v, PipelineState& smem_pipe_read_k, PipelineState& smem_pipe_read_v, + FrgTensorO& tOrO, AttentionUpdater& attention_updater, int kv_tile_idx_count, + int swa_begin_kv_tile_idx, int swa_end_kv_tile_idx, int thread_idx, int work_idx, + int q_tile_idx, SharedStorage& shared_storage, const int32_t qo_len, const int32_t kv_len, + const int32_t qo_head_idx, const int32_t kv_head_idx, const uint32_t prefix_len, + uint16_t* token_pos_in_items, const int num_kv_tiles_outside_items_window = 0, + const int num_kv_tiles_prefix = 0) { using DTypeQ = typename Ktraits::DTypeQ; using DTypeKV = typename 
Ktraits::DTypeKV; using IdType = typename Ktraits::IdType; @@ -93,6 +93,45 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params, AttentionVariant& var auto col_limit_left = [&](int qo_idx) { return qo_idx + kv_len - qo_len - mainloop_params.window_left; }; + auto mask_multi_item_scoring = [&](decltype(tSrS)& tSrS, int i, int qo_idx, int kv_idx) { + const uint32_t idx_in_original_seq = qo_idx + kv_len - qo_len; + const bool out_of_boundary = + kv_idx > idx_in_original_seq || (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))); + const bool is_prefix = idx_in_original_seq < prefix_len; + uint16_t token_pos_in_items_regs = + __ldca(token_pos_in_items + idx_in_original_seq - prefix_len); + if (out_of_boundary || is_prefix) { + tSrS(i) = out_of_boundary ? (AttentionUpdater::fill_value) : tSrS(i); + } else { + tSrS(i) = (kv_idx < prefix_len | (idx_in_original_seq < kv_idx + token_pos_in_items_regs)) + ? tSrS(i) + : (AttentionUpdater::fill_value); + } + }; + auto mask_multi_item_scoring_assume_in_bound = [&](decltype(tSrS)& tSrS, int i, int qo_idx, + int kv_idx) { + const uint32_t idx_in_original_seq = qo_idx + kv_len - qo_len; + const bool is_prefix = idx_in_original_seq < prefix_len; + if (is_prefix) { + tSrS(i) = AttentionUpdater::fill_value; + } else { + uint16_t token_pos_in_items_regs = + __ldca(token_pos_in_items + idx_in_original_seq - prefix_len); + tSrS(i) = (kv_idx < prefix_len | (idx_in_original_seq < kv_idx + token_pos_in_items_regs)) + ? tSrS(i) + : (AttentionUpdater::fill_value); + } + }; + auto kv_tile_idx_decrement = [&](int kv_tile_idx) { + int result = kv_tile_idx - 1; + if constexpr (MULTIITEMSCORING) { + if ((kv_tile_idx == num_kv_tiles_outside_items_window - 1) & + (kv_tile_idx >= num_kv_tiles_prefix)) { + result = num_kv_tiles_prefix - 1; + } + } + return result; + }; { Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{})); Tensor tScS = threadMmaQK.partition_C(cS); @@ -102,7 +141,9 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params, AttentionVariant& var int kv_idx = get<1>(tScS(i)) + kv_tile_idx * CTA_KV; tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, qo_head_idx, kv_head_idx); - if constexpr (!CAUSAL) { // Just masking based on col + if constexpr (MULTIITEMSCORING) { + mask_multi_item_scoring(tSrS, i, qo_idx, kv_idx); + } else if constexpr (!CAUSAL) { // Just masking based on col if (kv_idx >= kv_len) { tSrS(i) = AttentionUpdater::fill_value; } @@ -123,11 +164,13 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params, AttentionVariant& var Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout())); - constexpr int n_masking_steps = CAUSAL ? cute::ceil_div(CTA_Q, CTA_KV) : 0; + constexpr int n_masking_steps = MULTIITEMSCORING ? (cute::ceil_div(CTA_Q, CTA_KV) + 1) + : (CAUSAL ? 
cute::ceil_div(CTA_Q, CTA_KV) : 0); // masking loops + // ziangl@nvidia.com: for multi item scoring, we use this loop only to mask along the diagonal #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps && kv_tile_idx > swa_begin_kv_tile_idx; - ++masking_step, --kv_tile_idx) { + ++masking_step, kv_tile_idx = kv_tile_idx_decrement(kv_tile_idx)) { Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{})); consumer_wait(pipeline_k, smem_pipe_read_k); WarpScheduler::barrier_sync(); @@ -147,11 +190,19 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params, AttentionVariant& var #pragma unroll for (int i = 0; i < size(tSrS); ++i) { int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; - int kv_idx = get<1>(tScS(i)) + (kv_tile_idx - 1) * CTA_KV; + int kv_idx = get<1>(tScS(i)) + kv_tile_idx_decrement(kv_tile_idx) * CTA_KV; tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, qo_head_idx, kv_head_idx); - if (kv_idx >= col_limit_right(qo_idx)) { - tSrS(i) = AttentionUpdater::fill_value; + if (MULTIITEMSCORING) { + if (masking_step == n_masking_steps - 1) { + mask_multi_item_scoring_assume_in_bound(tSrS, i, qo_idx, kv_idx); + } else { + mask_multi_item_scoring(tSrS, i, qo_idx, kv_idx); + } + } else { + if (kv_idx >= col_limit_right(qo_idx)) { + tSrS(i) = AttentionUpdater::fill_value; + } } if constexpr (LEFT_SLIDING_WINDOW) { if (kv_idx < col_limit_left(qo_idx)) { @@ -170,7 +221,7 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params, AttentionVariant& var } #pragma unroll 1 - for (; kv_tile_idx > swa_end_kv_tile_idx + 1; --kv_tile_idx) { + for (; kv_tile_idx > swa_end_kv_tile_idx + 1; kv_tile_idx = kv_tile_idx_decrement(kv_tile_idx)) { Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{})); consumer_wait(pipeline_k, smem_pipe_read_k); WarpScheduler::barrier_sync(); @@ -189,10 +240,22 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params, AttentionVariant& var #pragma unroll for (int i = 0; i < size(tSrS); ++i) { int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; - int kv_idx = get<1>(tScS(i)) + (kv_tile_idx - 1) * CTA_KV; + int kv_idx = get<1>(tScS(i)) + kv_tile_idx_decrement(kv_tile_idx) * CTA_KV; tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, qo_head_idx, kv_head_idx); } + if constexpr (MULTIITEMSCORING) { + // auto nums_tiles_outside_causal_diagonal = kv_tile_idx_count - cute::ceil_div(CTA_Q, + // CTA_KV); + if (kv_tile_idx >= num_kv_tiles_prefix - 1) { +#pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; + int kv_idx = get<1>(tScS(i)) + kv_tile_idx_decrement(kv_tile_idx) * CTA_KV; + mask_multi_item_scoring_assume_in_bound(tSrS, i, qo_idx, kv_idx); + } + } + } attention_updater.update(tSrS); warpgroup_wait<0>(); pipeline_v.consumer_release(smem_pipe_read_v); // release V diff --git a/include/flashinfer/attention/hopper/prefill_sm90.cuh b/include/flashinfer/attention/hopper/prefill_sm90.cuh index a921b6c10..46e9f1664 100644 --- a/include/flashinfer/attention/hopper/prefill_sm90.cuh +++ b/include/flashinfer/attention/hopper/prefill_sm90.cuh @@ -35,8 +35,14 @@ namespace flashinfer { using namespace cute; +DEFINE_HAS_MEMBER(maybe_prefix_len_ptr) +DEFINE_HAS_MEMBER(maybe_token_pos_in_items_ptr) +DEFINE_HAS_MEMBER(token_pos_in_items_len) +DEFINE_HAS_MEMBER(maybe_max_item_len_ptr) + template + bool LEFT_SLIDING_WINDOW, bool CAUSAL, typename TileScheduler, + bool MULTIITEMSCORING = 
false> __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp, 1) PrefillWithKVCacheKernel(CUTE_GRID_CONSTANT typename CollectiveMainloop::Params const mainloop_params, @@ -122,6 +128,23 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp // blocks in the Cluster __syncthreads(); + uint32_t* maybe_prefix_len_ptr = nullptr; + if constexpr (has_maybe_prefix_len_ptr_v) { + maybe_prefix_len_ptr = mainloop_params.additional_params.maybe_prefix_len_ptr; + } + uint16_t* maybe_token_pos_in_items_ptr = nullptr; + if constexpr (has_maybe_token_pos_in_items_ptr_v) { + maybe_token_pos_in_items_ptr = mainloop_params.additional_params.maybe_token_pos_in_items_ptr; + } + uint32_t token_pos_in_items_len = 0; + if constexpr (has_token_pos_in_items_len_v) { + token_pos_in_items_len = mainloop_params.additional_params.token_pos_in_items_len; + } + uint16_t* maybe_max_item_len_ptr = nullptr; + if constexpr (has_maybe_max_item_len_ptr_v) { + maybe_max_item_len_ptr = mainloop_params.additional_params.maybe_max_item_len_ptr; + } + if (warp_group_idx == 0) { // Producer if constexpr (use_tma_load_kv) { cutlass::arch::warpgroup_reg_dealloc(); @@ -142,8 +165,8 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp work_tile_info = scheduler.template get_next_work( scheduler_params, work_tile_info)) { auto block_coord = work_tile_info.get_block_coord(scheduler_params); - auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = - block_coord; + auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len, + batch_idx] = block_coord; if (q_tile_idx * CTA_Q >= qo_len) { continue; @@ -155,9 +178,26 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp scheduler.broadcast_next_work(work_tile_info); continue; } - collective_mainloop.load( - mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v, - shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx); + int num_kv_tiles_outside_items_window = 0; + int num_kv_tiles_prefix = 0; + if constexpr (MULTIITEMSCORING) { + auto prefix_len = __ldg(maybe_prefix_len_ptr + batch_idx); + auto max_item_len = __ldg(maybe_max_item_len_ptr + batch_idx); + auto valid_items_window_len = + std::max(0, (q_tile_idx + 1) * CTA_Q + kv_len - qo_len - max_item_len); + num_kv_tiles_outside_items_window = cute::ceil_div(valid_items_window_len, CTA_KV); + num_kv_tiles_prefix = cute::ceil_div(prefix_len, CTA_KV); + } + if constexpr (MULTIITEMSCORING) { + collective_mainloop.load( + mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v, + shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx, + num_kv_tiles_outside_items_window, num_kv_tiles_prefix); + } else { + collective_mainloop.load( + mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v, + shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx); + } ++work_idx; } collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v); @@ -190,7 +230,7 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{})); auto block_coord = work_tile_info.get_block_coord(scheduler_params); - auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = + auto [q_tile_idx, qo_head_idx, kv_head_idx, 
qo_indptr, kv_indptr, qo_len, kv_len, batch_idx] = block_coord; AttentionVariant variant(mainloop_params, block_coord); @@ -217,12 +257,29 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp q_tile_idx, qo_len, kv_len); } - mma_f16( mainloop_params, variant, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v, tOrO, attention_updater, num_kv_tiles, swa_begin_kv_tile_idx, swa_end_kv_tile_idx, threadIdx.x - NUM_COPY_THREADS, work_idx, q_tile_idx, shared_storage, qo_len, kv_len, - qo_head_idx, kv_head_idx); + qo_head_idx, kv_head_idx, prefix_len, token_pos_in_items, + num_kv_tiles_outside_items_window, num_kv_tiles_prefix); collective_epilogue.store(epilogue_params, tOrO, attention_updater.get_lse(), shared_storage, tiled_mma_pv, threadIdx.x - NUM_COPY_THREADS, block_coord); @@ -295,7 +352,7 @@ cudaError_t SinglePrefillWithKVCacheKernelTraitsDispatched(Params& params, cudaS } template + bool SAME_SCHEDULE_FOR_ALL_HEADS, typename Params, bool MULTIITEMSCORING = false> cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched(Params& params, cudaStream_t stream) { using DTypeQ = typename KernelTraits::DTypeQ; @@ -303,8 +360,8 @@ cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched(Params& params, using DTypeO = typename KernelTraits::DTypeO; using IdType = typename KernelTraits::IdType; - using CollectiveMainloop = - SparseCollectiveMainloop; + using CollectiveMainloop = SparseCollectiveMainloop; using CollectiveEpilogue = CollectiveEpilogue; using Scheduler = std::conditional_t, @@ -333,17 +390,22 @@ cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched(Params& params, }); typename Scheduler::Arguments scheduler_args = { - params.work_indptr, params.head_indices, - params.qo_tile_indices, params.qo_indptr, - params.kv_indptr, params.qo_lens, - params.kv_lens, cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads), + params.work_indptr, + params.head_indices, + params.qo_tile_indices, + params.qo_indptr, + params.kv_indptr, + params.qo_lens, + params.kv_lens, + params.batch_indices, + cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads), params.num_qo_heads}; typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args); // Get the ptr to kernel function. auto kernel = (void*)PrefillWithKVCacheKernel; + LEFT_SLIDING_WINDOW, CAUSAL, Scheduler, MULTIITEMSCORING>; int smem_size = sizeof(typename KernelTraits::SharedStorage); FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -403,10 +465,15 @@ cudaError_t BatchPrefillWithRaggedKVCacheKernelTraitsDispatched(Params& params, // NOTE(Zihao): add support for kv head-major later typename Scheduler::Arguments scheduler_args = { - params.work_indptr, params.head_indices, - params.qo_tile_indices, params.qo_indptr, - params.kv_indptr, params.qo_lens, - params.kv_lens, cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads), + params.work_indptr, + params.head_indices, + params.qo_tile_indices, + params.qo_indptr, + params.kv_indptr, + params.qo_lens, + params.kv_lens, + params.batch_indices, + cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads), params.num_qo_heads}; typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args); @@ -501,6 +568,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params& params, cudaStream_t return cudaErrorNotSupported; // Not supported yet. 
} constexpr bool CAUSAL = MASK_MODE == MaskMode::kCausal; + constexpr bool MULTIITEMSCORING = MASK_MODE == MaskMode::kMultiItemScoring; if constexpr (HEAD_DIM_QK == HEAD_DIM_VO) { if constexpr (HEAD_DIM_VO == 64) { // NOTE(Zihao): CTA_KV not tuned for HEAD_DIM == 64, need to optimize later @@ -511,7 +579,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params& params, cudaStream_t /*NUM_STAGES_=*/2, typename Params::DTypeQ, typename Params::DTypeKV, typename Params::DTypeO, typename Params::IdType, AttentionVariant>, - LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS, Params, MULTIITEMSCORING>( + params, stream); } else if constexpr (HEAD_DIM_VO == 128) { BatchPrefillWithPagedKVCacheKernelTraitsDispatched< AttentionKernelTraits, - LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS, Params, MULTIITEMSCORING>( + params, stream); } else { // HEAD_DIM == 256; // NOTE(Zihao): CTA_KV not tuned for HEAD_DIM == 256, need to optimize later @@ -531,7 +601,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params& params, cudaStream_t /*NUM_STAGES_=*/2, typename Params::DTypeQ, typename Params::DTypeKV, typename Params::DTypeO, typename Params::IdType, AttentionVariant>, - LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS, Params, MULTIITEMSCORING>( + params, stream); } } else { return cudaErrorNotSupported; diff --git a/include/flashinfer/attention/hopper/quantization/epilogue.cuh b/include/flashinfer/attention/hopper/quantization/epilogue.cuh index 0d4883fc5..8bf5098d4 100644 --- a/include/flashinfer/attention/hopper/quantization/epilogue.cuh +++ b/include/flashinfer/attention/hopper/quantization/epilogue.cuh @@ -110,7 +110,7 @@ struct FP8CollectiveEpilogue { CUTLASS_DEVICE void store(Params const& epilogue_params, FrgTensorO const& tOrO, FrgTensorLSE const& lse, SharedStorage& shared_storage, TiledMma tiled_mma, int thread_idx, BlockCoord const& block_coord) { - auto [qo_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = + auto [qo_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len, batch_idx] = block_coord; // No need for FP8 column permutation @@ -170,7 +170,7 @@ struct FP8CollectiveEpilogue { template CUTLASS_DEVICE void store_zero(Params const& epilogue_params, SharedStorage& shared_storage, int thread_idx, BlockCoord const& block_coord) { - auto [qo_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = + auto [qo_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len, batch_idx] = block_coord; Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.O_ptr), epilogue_params.layout_O); Tensor gO = get_local_tile_tensor(mO, select<0, 2>(TileShape_QKD{}), qo_head_idx, qo_indptr, diff --git a/include/flashinfer/attention/hopper/quantization/mainloop_load.cuh b/include/flashinfer/attention/hopper/quantization/mainloop_load.cuh index 0fe91760f..988f7e9ac 100644 --- a/include/flashinfer/attention/hopper/quantization/mainloop_load.cuh +++ b/include/flashinfer/attention/hopper/quantization/mainloop_load.cuh @@ -177,7 +177,8 @@ struct FP8CollectiveMainloop { make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), SmemLayoutVtTransposeTgt{})); auto v_tranposer = SmemTransposeFP8_64x64(); - auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, 
qo_len, kv_len] = block_coord; + auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len, batch_idx] = + block_coord; // Prepare the TMA loads Tensor gQ = get_local_tile_tensor(mQ, select<0, 2>(TileShape_QKD{}), qo_head_idx, qo_indptr, diff --git a/include/flashinfer/attention/hopper/quantization/mainloop_mma.cuh b/include/flashinfer/attention/hopper/quantization/mainloop_mma.cuh index 0b96aaa74..9720af575 100644 --- a/include/flashinfer/attention/hopper/quantization/mainloop_mma.cuh +++ b/include/flashinfer/attention/hopper/quantization/mainloop_mma.cuh @@ -26,7 +26,7 @@ CUTLASS_DEVICE void mma_fp8(const Params& mainloop_params, AttentionVariant& var int swa_end_kv_tile_idx, int thread_idx, int work_idx, int q_tile_idx, SharedStorage& shared_storage, const int32_t qo_len, const int32_t kv_len, const int32_t qo_head_idx, - const int32_t kv_head_idx) { + const int32_t kv_head_idx, const int32_t batch_idx) { using DTypeQ = typename Ktraits::DTypeQ; using DTypeKV = typename Ktraits::DTypeKV; using IdType = typename Ktraits::IdType; @@ -100,8 +100,8 @@ CUTLASS_DEVICE void mma_fp8(const Params& mainloop_params, AttentionVariant& var for (int i = 0; i < size(tSrS); ++i) { int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; int kv_idx = get<1>(tScS(i)) + kv_tile_idx * CTA_KV; - tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, - qo_head_idx, kv_head_idx); + tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/batch_idx, qo_idx, + kv_idx, qo_head_idx, kv_head_idx); if constexpr (!CAUSAL) { // Just masking based on col if (kv_idx >= kv_len) { tSrS(i) = AttentionUpdater::fill_value; @@ -153,8 +153,8 @@ CUTLASS_DEVICE void mma_fp8(const Params& mainloop_params, AttentionVariant& var for (int i = 0; i < size(tSrS); ++i) { int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; int kv_idx = get<1>(tScS(i)) + (kv_tile_idx - 1) * CTA_KV; - tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, - qo_head_idx, kv_head_idx); + tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/batch_idx, qo_idx, + kv_idx, qo_head_idx, kv_head_idx); if (kv_idx >= col_limit_right(qo_idx)) { tSrS(i) = AttentionUpdater::fill_value; } @@ -199,8 +199,8 @@ CUTLASS_DEVICE void mma_fp8(const Params& mainloop_params, AttentionVariant& var for (int i = 0; i < size(tSrS); ++i) { int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; int kv_idx = get<1>(tScS(i)) + (kv_tile_idx - 1) * CTA_KV; - tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, - qo_head_idx, kv_head_idx); + tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/batch_idx, qo_idx, + kv_idx, qo_head_idx, kv_head_idx); } attention_updater.update(tSrS); diff --git a/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh b/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh index 49ab44590..1b1784244 100644 --- a/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh +++ b/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh @@ -185,7 +185,8 @@ struct FP8SparseCollectiveMainloop { auto v_tranposer = SmemTransposeFP8_64x64(); /* ----- V Transpose ---- */ - auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = block_coord; + auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len, batch_idx] = + block_coord; // Prepare the TMA loads Tensor gQ = 
get_local_tile_tensor(mQ, select<0, 2>(TileShape_QKD{}), qo_head_idx, qo_indptr, diff --git a/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh b/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh index 522183f6e..5081f1dd1 100644 --- a/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh +++ b/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh @@ -157,7 +157,7 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp work_tile_info = scheduler.template get_next_work(scheduler_params, work_tile_info)) { auto block_coord = work_tile_info.get_block_coord(scheduler_params); - auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = + auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len, batch_idx] = block_coord; if (q_tile_idx * CTA_Q >= qo_len) { @@ -206,7 +206,7 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp clear(tOrO); auto block_coord = work_tile_info.get_block_coord(scheduler_params); - auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = + auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len, batch_idx] = block_coord; AttentionVariant variant(mainloop_params, block_coord); @@ -238,7 +238,7 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp mainloop_params, variant, pipeline_k, pipeline_vt, smem_pipe_read_k, smem_pipe_read_v, tOrO, attention_updater, num_kv_tiles, swa_begin_kv_tile_idx, swa_end_kv_tile_idx, threadIdx.x - NUM_COPY_THREADS, work_idx, q_tile_idx, shared_storage, qo_len, kv_len, - qo_head_idx, kv_head_idx); + qo_head_idx, kv_head_idx, batch_idx); collective_epilogue.store(epilogue_params, tOrO, attention_updater.get_lse(), shared_storage, tiled_mma_pv, threadIdx.x - NUM_COPY_THREADS, block_coord); diff --git a/include/flashinfer/attention/hopper/sparse_mainloop.cuh b/include/flashinfer/attention/hopper/sparse_mainloop.cuh index 2e1154431..52fb54d7d 100644 --- a/include/flashinfer/attention/hopper/sparse_mainloop.cuh +++ b/include/flashinfer/attention/hopper/sparse_mainloop.cuh @@ -33,7 +33,7 @@ namespace flashinfer { using namespace cute; -template +template struct SparseCollectiveMainloop { using DTypeQ = typename Ktraits::DTypeQ; using DTypeKV = typename Ktraits::DTypeKV; @@ -153,6 +153,10 @@ struct SparseCollectiveMainloop { num_kv_tiles = std::min(num_kv_tiles, cute::ceil_div((q_tile_idx + 1) * CTA_Q + kv_len - qo_len, CTA_KV)); } + if constexpr (MULTIITEMSCORING) { + num_kv_tiles = std::min(num_kv_tiles, + cute::ceil_div((q_tile_idx + 1) * CTA_Q + kv_len - qo_len, CTA_KV)); + } return num_kv_tiles; } @@ -164,7 +168,9 @@ struct SparseCollectiveMainloop { PipelineState& smem_pipe_write_v, SharedStorage& shared_storage, Scheduler& scheduler, typename Scheduler::Params const& scheduler_params, typename Scheduler::WorkTileInfo& work_tile_info, - BlockCoord const& block_coord, int work_idx) { + BlockCoord const& block_coord, int work_idx, + const int num_kv_tiles_outside_items_window = 0, + const int num_kv_tiles_prefix = 0) { int thread_idx = threadIdx.x; int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (thread_idx / 32) % 4, 0); Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); @@ -173,7 +179,8 @@ struct SparseCollectiveMainloop { Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape()); - auto [q_tile_idx, qo_head_idx, kv_head_idx, 
qo_indptr, kv_indptr, qo_len, kv_len] = block_coord; + auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len, batch_idx] = + block_coord; // Prepare the TMA loads Tensor gQ = get_local_tile_tensor(mQ, select<0, 2>(TileShape_QKD{}), qo_head_idx, qo_indptr, @@ -235,6 +242,16 @@ struct SparseCollectiveMainloop { auto s_coords = tVcVGroup(_0{}, coords); return elem_less(get<0>(s_coords), valid_last_kv_tile_size); }; + auto kv_tile_idx_decrement = [&](int kv_tile_idx) { + int result = kv_tile_idx - 1; + if constexpr (MULTIITEMSCORING) { + if ((kv_tile_idx == num_kv_tiles_outside_items_window - 1) & + (kv_tile_idx >= num_kv_tiles_prefix)) { + result = num_kv_tiles_prefix - 1; + } + } + return result; + }; // load last k-tile { @@ -278,8 +295,8 @@ struct SparseCollectiveMainloop { } else { // load second last k-tile and last v-tile pipeline_k.producer_acquire(smem_pipe_write_k); - Tensor tKgKi = tKgK(_, _, _, kv_tile_idx - 1); // (CPY, CPY_KV, CPY_D) - Tensor tKsKi = tKsK(_, _, _, smem_pipe_write_k.index()); // (CPY, CPY_KV, CPY_D) + Tensor tKgKi = tKgK(_, _, _, kv_tile_idx_decrement(kv_tile_idx)); // (CPY, CPY_KV, CPY_D) + Tensor tKsKi = tKsK(_, _, _, smem_pipe_write_k.index()); // (CPY, CPY_KV, CPY_D) copy(gmem_tiled_copy_k, tKgKi, tKsKi); pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive); @@ -292,16 +309,17 @@ struct SparseCollectiveMainloop { copy_if(gmem_tiled_copy_v, v_predicate_fn, tVgViGroup, tVsViGroup); pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); - --kv_tile_idx; + kv_tile_idx = kv_tile_idx_decrement(kv_tile_idx); ++smem_pipe_write_v; // load remaining k/v tiles #pragma unroll 2 - for (; kv_tile_idx > swa_begin_kv_tile_idx; --kv_tile_idx) { + for (; kv_tile_idx > swa_begin_kv_tile_idx; + kv_tile_idx = kv_tile_idx_decrement(kv_tile_idx)) { pipeline_k.producer_acquire(smem_pipe_write_k); - Tensor tKgKi = tKgK(_, _, _, kv_tile_idx - 1); // (CPY, CPY_KV, CPY_D) - Tensor tKsKi = tKsK(_, _, _, smem_pipe_write_k.index()); // (CPY, CPY_KV, CPY_D) + Tensor tKgKi = tKgK(_, _, _, kv_tile_idx_decrement(kv_tile_idx)); // (CPY, CPY_KV, CPY_D) + Tensor tKsKi = tKsK(_, _, _, smem_pipe_write_k.index()); // (CPY, CPY_KV, CPY_D) copy(gmem_tiled_copy_k, tKgKi, tKsKi); pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive); diff --git a/include/flashinfer/attention/hopper/tile_scheduler.cuh b/include/flashinfer/attention/hopper/tile_scheduler.cuh index 4637d3cff..51a322346 100644 --- a/include/flashinfer/attention/hopper/tile_scheduler.cuh +++ b/include/flashinfer/attention/hopper/tile_scheduler.cuh @@ -46,8 +46,8 @@ struct SingleTileScheduler { CUTLASS_DEVICE auto get_block_coord(Params const& params) const { - return cute::tuple{q_tile_idx, qo_head_idx, kv_head_idx, /*qo_indptr=*/0, - /*kv_indptr=*/0, params.qo_len, params.kv_len}; + return cute::tuple{q_tile_idx, qo_head_idx, kv_head_idx, /*qo_indptr=*/0, + /*kv_indptr=*/0, params.qo_len, params.kv_len, /*batch_idx=*/0}; } }; @@ -83,7 +83,7 @@ struct BatchPrefillPersistentTileScheduler { // Host side kernel arguments struct Arguments { IdType *work_indptr, *head_indices, *qo_tile_indices, *qo_indptr, *kv_indptr, *qo_lens, - *kv_lens; + *kv_lens, *batch_indices; cutlass::FastDivmod group_size_fastdiv; int num_qo_heads; // placeholder }; @@ -91,13 +91,14 @@ struct BatchPrefillPersistentTileScheduler { // Device side kernel params struct Params { IdType *work_indptr, *head_indices, *qo_tile_indices, *qo_indptr, *kv_indptr, 
*qo_lens, - *kv_lens; + *kv_lens, *batch_indices; cutlass::FastDivmod group_size_fastdiv; }; static Params to_underlying_arguments(Arguments const& args) { - return {args.work_indptr, args.head_indices, args.qo_tile_indices, args.qo_indptr, - args.kv_indptr, args.qo_lens, args.kv_lens, args.group_size_fastdiv}; + return {args.work_indptr, args.head_indices, args.qo_tile_indices, + args.qo_indptr, args.kv_indptr, args.qo_lens, + args.kv_lens, args.batch_indices, args.group_size_fastdiv}; } static dim3 get_grid_dim(Arguments const& args, int num_sm) { return {(unsigned)num_sm}; } @@ -110,6 +111,7 @@ struct BatchPrefillPersistentTileScheduler { int kv_indptr = 0; int qo_len = 0; int kv_len = 0; + int batch_idx = 0; int counter = 0; int ptr_begin = 0; int ptr_end = 0; @@ -120,7 +122,7 @@ struct BatchPrefillPersistentTileScheduler { CUTLASS_DEVICE auto get_block_coord(Params const& params) const { return cute::tuple{q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, - kv_indptr, qo_len, kv_len}; + kv_indptr, qo_len, kv_len, batch_idx}; } }; @@ -142,11 +144,12 @@ struct BatchPrefillPersistentTileScheduler { params.kv_indptr[work_idx], params.qo_lens[work_idx], params.kv_lens[work_idx], + params.batch_indices[work_idx], /*counter=*/0, ptr_begin, ptr_end}; } else { - return {-1, -1, -1, -1, -1, -1, 0, ptr_begin, ptr_end}; + return {-1, -1, -1, -1, -1, -1, -1, 0, ptr_begin, ptr_end}; } } @@ -173,6 +176,7 @@ struct BatchPrefillPersistentTileScheduler { params.kv_indptr[work_idx], params.qo_lens[work_idx], params.kv_lens[work_idx], + params.batch_indices[work_idx], current_work.counter + 1, current_work.ptr_begin, current_work.ptr_end}; @@ -183,6 +187,7 @@ struct BatchPrefillPersistentTileScheduler { -1, -1, -1, + -1, current_work.counter + 1, current_work.ptr_begin, current_work.ptr_end}; @@ -199,21 +204,23 @@ struct BatchPrefillTileScheduler { // Host side kernel arguments struct Arguments { IdType *work_indptr, *head_indices, *qo_tile_indices, *qo_indptr, *kv_indptr, *qo_lens, - *kv_lens; // head_indices is a placeholder + *kv_lens, *batch_indices; // head_indices is a placeholder cutlass::FastDivmod group_size_fastdiv; int num_qo_heads; }; // Device side kernel params struct Params { - IdType *work_indptr, *qo_tile_indices, *qo_indptr, *kv_indptr, *qo_lens, *kv_lens; + IdType *work_indptr, *qo_tile_indices, *qo_indptr, *kv_indptr, *qo_lens, *kv_lens, + *batch_indices; cutlass::FastDivmod group_size_fastdiv; int num_qo_heads; }; static Params to_underlying_arguments(Arguments const& args) { - return {args.work_indptr, args.qo_tile_indices, args.qo_indptr, args.kv_indptr, - args.qo_lens, args.kv_lens, args.group_size_fastdiv, args.num_qo_heads}; + return {args.work_indptr, args.qo_tile_indices, args.qo_indptr, args.kv_indptr, + args.qo_lens, args.kv_lens, args.batch_indices, args.group_size_fastdiv, + args.num_qo_heads}; } static dim3 get_grid_dim(Arguments const& args, int num_sm) { @@ -228,6 +235,7 @@ struct BatchPrefillTileScheduler { int kv_indptr = 0; int qo_len = 0; int kv_len = 0; + int batch_idx = 0; int counter = 0; int ptr_begin = 0; int ptr_end = 0; @@ -238,7 +246,7 @@ struct BatchPrefillTileScheduler { CUTLASS_DEVICE auto get_block_coord(Params const& params) const { return cute::tuple{q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, - kv_indptr, qo_len, kv_len}; + kv_indptr, qo_len, kv_len, batch_idx}; } }; @@ -260,11 +268,12 @@ struct BatchPrefillTileScheduler { params.kv_indptr[work_idx], params.qo_lens[work_idx], params.kv_lens[work_idx], + params.batch_indices[work_idx], 
/*counter=*/0, ptr_begin, ptr_end}; } else { - return {-1, -1, -1, -1, -1, -1, 0, ptr_begin, ptr_end}; + return {-1, -1, -1, -1, -1, -1, -1, 0, ptr_begin, ptr_end}; } } @@ -282,11 +291,17 @@ struct BatchPrefillTileScheduler { WorkTileInfo const& current_work) const { int work_idx = current_work.ptr_begin + current_work.counter + 1; if (work_idx < current_work.ptr_end) { - return {params.qo_tile_indices[work_idx], current_work.qo_head_idx, - current_work.kv_head_idx, params.qo_indptr[work_idx], - params.kv_indptr[work_idx], params.qo_lens[work_idx], - params.kv_lens[work_idx], current_work.counter + 1, - current_work.ptr_begin, current_work.ptr_end}; + return {params.qo_tile_indices[work_idx], + current_work.qo_head_idx, + current_work.kv_head_idx, + params.qo_indptr[work_idx], + params.kv_indptr[work_idx], + params.qo_lens[work_idx], + params.kv_lens[work_idx], + params.batch_indices[work_idx], + current_work.counter + 1, + current_work.ptr_begin, + current_work.ptr_end}; } else { return {-1, -1, @@ -294,6 +309,7 @@ struct BatchPrefillTileScheduler { -1, -1, -1, + -1, current_work.counter + 1, current_work.ptr_begin, current_work.ptr_end}; diff --git a/include/flashinfer/attention/hopper/variants.cuh b/include/flashinfer/attention/hopper/variants.cuh index ae6c53377..1a199e2d3 100644 --- a/include/flashinfer/attention/hopper/variants.cuh +++ b/include/flashinfer/attention/hopper/variants.cuh @@ -66,7 +66,8 @@ struct StandardFP8Attention { template __device__ StandardFP8Attention(const MainloopParams& params, const BlockCoord& block_coord) { - auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = block_coord; + auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len, batch_idx] = + block_coord; // 448 for e4m3; 57344 for e5m2 p_scale = std::numeric_limits::max(); scale_pv = params.additional_params.scale_v[kv_head_idx] / p_scale; diff --git a/include/flashinfer/attention/mask.cuh b/include/flashinfer/attention/mask.cuh index 771c2f476..6692b0cf3 100644 --- a/include/flashinfer/attention/mask.cuh +++ b/include/flashinfer/attention/mask.cuh @@ -22,6 +22,7 @@ enum class MaskMode { kNone = 0U, // No mask kCausal = 1U, // Causal mask kCustom = 2U, // Custom mask + kMultiItemScoring = 3U, }; } // namespace flashinfer diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index f3336b9bd..fb0d260dc 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -41,6 +41,10 @@ namespace flashinfer { DEFINE_HAS_MEMBER(maybe_q_rope_offset) DEFINE_HAS_MEMBER(maybe_k_rope_offset) +DEFINE_HAS_MEMBER(maybe_prefix_len_ptr) +DEFINE_HAS_MEMBER(maybe_token_pos_in_items_ptr) +DEFINE_HAS_MEMBER(token_pos_in_items_len) +DEFINE_HAS_MEMBER(maybe_max_item_len_ptr) namespace cg = cooperative_groups; using cp_async::SharedMemFillMode; @@ -786,6 +790,73 @@ __device__ __forceinline__ void logits_mask( } } +template +__device__ __forceinline__ void logits_mask_multi_item_scoring( + const Params& params, typename KTraits::AttentionVariant variant, const uint32_t batch_idx, + const uint32_t qo_packed_idx_base, const uint32_t kv_idx_base, const uint32_t qo_len, + const uint32_t kv_len, const uint32_t window_left, const uint32_t chunk_end, + const uint_fastdiv group_size, typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8], + // new arguments for compact description of mask + const uint32_t prefix_len, uint16_t* token_pos_in_items) { + const uint32_t lane_idx = threadIdx.x, 
kv_head_idx = blockIdx.z; + constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; + constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + using DTypeQKAccum = typename KTraits::DTypeQKAccum; + uint32_t q[NUM_MMA_Q][2], r[NUM_MMA_Q][2]; + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / 4 + 8 * j, q[mma_q][j], + r[mma_q][j]); + } + } + // prefetching global memory to registers + uint16_t token_pos_in_items_regs[NUM_MMA_Q][(4 / 2)]; +#pragma unroll + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t eff_reg_id = 0; eff_reg_id < (4 / 2); ++eff_reg_id) { + const uint32_t q_idx = q[mma_q][eff_reg_id]; + // use __ldca to hint compiler to cache in L1 for further reuse by other tiles + const int idx_in_original_seq = q_idx + kv_len - qo_len; + if (idx_in_original_seq >= prefix_len & idx_in_original_seq < kv_len) { + token_pos_in_items_regs[mma_q][eff_reg_id] = + __ldca(token_pos_in_items + idx_in_original_seq - prefix_len); + } + } + } + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2], kv_idx = kv_idx_base + mma_kv * 16 + + 2 * (lane_idx % 4) + + 8 * (reg_id / 4) + reg_id % 2; + const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; + const uint32_t idx_in_original_seq = q_idx + kv_len - qo_len; + const bool out_of_boundary = kv_idx > idx_in_original_seq || (kv_idx >= chunk_end) || + kv_idx + window_left < idx_in_original_seq; + const bool is_prefix = idx_in_original_seq < prefix_len; + if (out_of_boundary || is_prefix) { + s_frag[mma_q][mma_kv][reg_id] = + out_of_boundary ? (KTraits::MaskFillValue) : s_frag[mma_q][mma_kv][reg_id]; + } else { + s_frag[mma_q][mma_kv][reg_id] = + (kv_idx < prefix_len | + (idx_in_original_seq < kv_idx + token_pos_in_items_regs[mma_q][((reg_id % 4) / 2)])) + ? s_frag[mma_q][mma_kv][reg_id] + : (KTraits::MaskFillValue); + } + } + } + } +} + template __device__ __forceinline__ void update_mdo_states( typename KTraits::AttentionVariant variant, @@ -1980,6 +2051,23 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( const int32_t maybe_window_left = params.window_left; const uint_fastdiv& group_size = params.group_size; + uint32_t* maybe_prefix_len_ptr = nullptr; + if constexpr (has_maybe_prefix_len_ptr_v) { + maybe_prefix_len_ptr = params.maybe_prefix_len_ptr; + } + uint16_t* maybe_token_pos_in_items_ptr = nullptr; + if constexpr (has_maybe_token_pos_in_items_ptr_v) { + maybe_token_pos_in_items_ptr = params.maybe_token_pos_in_items_ptr; + } + uint32_t token_pos_in_items_len = 0; + if constexpr (has_token_pos_in_items_len_v) { + token_pos_in_items_len = params.token_pos_in_items_len; + } + uint16_t* maybe_max_item_len_ptr = nullptr; + if constexpr (has_maybe_max_item_len_ptr_v) { + maybe_max_item_len_ptr = params.maybe_max_item_len_ptr; + } + static_assert(sizeof(DTypeQ) == 2); auto block = cg::this_thread_block(); const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); @@ -2093,13 +2181,43 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( chunk_size, tid); cp_async::commit_group(); - const uint32_t num_iterations = ceil_div( - (MASK_MODE == MaskMode::kCausal - ? 
min(chunk_size, sub_if_greater_or_zero( - kv_len - qo_len + ((qo_tile_idx + 1) * CTA_TILE_Q) / group_size, - chunk_start)) - : chunk_size), - CTA_TILE_KV); + uint32_t num_iterations_prefix; + uint32_t num_iterations_mask; + uint32_t num_iterations = 0; + + if constexpr (MASK_MODE != MaskMode::kMultiItemScoring) { + num_iterations = + ceil_div((MASK_MODE == MaskMode::kCausal + ? min(chunk_size, + sub_if_greater_or_zero( + kv_len - qo_len + ((qo_tile_idx + 1) * CTA_TILE_Q) / group_size, + chunk_start)) + : chunk_size), + CTA_TILE_KV); + } else if constexpr (MASK_MODE == MaskMode::kMultiItemScoring) { + num_iterations_prefix = ceil_div( + min(min(chunk_size, sub_if_greater_or_zero( + kv_len - qo_len + ((qo_tile_idx + 1) * CTA_TILE_Q) / group_size, + chunk_start)), + sub_if_greater_or_zero(__ldg(maybe_prefix_len_ptr + request_idx), chunk_start)), + CTA_TILE_KV); + num_iterations_mask = max( + min(chunk_size, + sub_if_greater_or_zero( + sub_if_greater_or_zero(kv_len - qo_len + (qo_tile_idx * CTA_TILE_Q) / group_size, + __ldg(maybe_max_item_len_ptr + request_idx)), + chunk_start)) / + (CTA_TILE_KV), + num_iterations_prefix); + + num_iterations = max( + num_iterations_mask, + ceil_div( + min(chunk_size, sub_if_greater_or_zero( + kv_len - qo_len + ((qo_tile_idx + 1) * CTA_TILE_Q) / group_size, + chunk_start)), + CTA_TILE_KV)); + } const uint32_t window_iteration = ceil_div(sub_if_greater_or_zero(kv_len + (qo_tile_idx + 1) * CTA_TILE_Q / group_size, @@ -2115,8 +2233,16 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( CTA_TILE_KV; #pragma unroll 1 - for (uint32_t iter = 0; iter < num_iterations; ++iter) { - packed_page_iter_base += CTA_TILE_KV; + for (uint32_t iter = 0; iter < num_iterations; + iter = (MASK_MODE == MaskMode::kMultiItemScoring) + ? ((iter + 1 == num_iterations_prefix) ? num_iterations_mask : (iter + 1)) + : (iter + 1)) { + const uint32_t prefetch_skip_step = + (MASK_MODE == MaskMode::kMultiItemScoring) + ? ((iter + 1 == num_iterations_prefix) ? (num_iterations_mask - num_iterations_prefix) + : 0) + : 0; + packed_page_iter_base += (1 + prefetch_skip_step) * CTA_TILE_KV; #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * (SWIZZLE_MODE_KV == SwizzleMode::k128B ? 
4 : 2) / NUM_WARPS_Q; ++i) { @@ -2149,11 +2275,37 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); // apply mask - if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { + if (MASK_MODE == MaskMode::kCustom) { logits_mask( params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, - chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, - qo_len, kv_len, chunk_end, group_size, s_frag, tid, kv_head_idx); + chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv()) * NUM_MMA_KV * 16, + qo_len, kv_len, chunk_end, group_size, s_frag); + } else { + if constexpr (MASK_MODE != MaskMode::kMultiItemScoring) { + if (iter >= mask_iteration || iter < window_iteration) { + logits_mask( + params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, + chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv()) * NUM_MMA_KV * 16, + qo_len, kv_len, chunk_end, group_size, s_frag); + } + } else if constexpr (MASK_MODE == MaskMode::kMultiItemScoring) { + if (iter + 1 >= num_iterations_prefix) { + logits_mask_multi_item_scoring( + params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, + chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv()) * NUM_MMA_KV * 16, + qo_len, kv_len, window_left, chunk_end, group_size, s_frag, + __ldg(maybe_prefix_len_ptr + request_idx), + maybe_token_pos_in_items_ptr + request_idx * token_pos_in_items_len); + } else { + if (iter >= mask_iteration || iter < window_iteration) { + logits_mask( + params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, + chunk_start + + (iter * NUM_WARPS_KV + get_warp_idx_kv()) * NUM_MMA_KV * 16, + qo_len, kv_len, chunk_end, group_size, s_frag); + } + } + } } // compute m,d states in online softmax diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index ae943a719..89eb6a082 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -789,6 +789,7 @@ struct PrefillPlanSM90Info { int64_t kv_len_offset; int64_t head_indices_offset; int64_t work_indptr_offset; + int64_t batch_indices_offset; bool same_schedule_for_all_heads; PrefillPlanSM90Info() @@ -799,21 +800,21 @@ struct PrefillPlanSM90Info { kv_len_offset(0), head_indices_offset(0), work_indptr_offset(0), + batch_indices_offset(0), same_schedule_for_all_heads(false) {} // convert PrefillPlanSM90Info to std::vector std::vector ToVector() const { - return {qo_tile_indices_offset, qo_indptr_offset, - kv_indptr_offset, qo_len_offset, - kv_len_offset, head_indices_offset, - work_indptr_offset, same_schedule_for_all_heads}; + return {qo_tile_indices_offset, qo_indptr_offset, kv_indptr_offset, + qo_len_offset, kv_len_offset, head_indices_offset, + work_indptr_offset, batch_indices_offset, same_schedule_for_all_heads}; } // From std::vector to PrefillPlanSM90Info void FromVector(const std::vector& vec) { - if (vec.size() != 8) { + if (vec.size() != 9) { std::ostringstream err_msg; - err_msg << "PrefillPlanSM90Info::FromVector: vec.size() should be 8, but got " << vec.size(); + err_msg << "PrefillPlanSM90Info::FromVector: vec.size() should be 9, but got " << vec.size(); FLASHINFER_ERROR(err_msg.str()); } qo_tile_indices_offset = vec[0]; @@ -823,7 +824,8 @@ struct PrefillPlanSM90Info { kv_len_offset = vec[4]; head_indices_offset = vec[5]; work_indptr_offset = vec[6]; - same_schedule_for_all_heads = vec[7]; + batch_indices_offset = vec[7]; + 
same_schedule_for_all_heads = vec[8]; } }; @@ -879,7 +881,8 @@ inline cudaError_t PrefillSM90Plan( cta_kv_indptr(num_sm90_ctas, std::vector()), cta_qo_len(num_sm90_ctas, std::vector()), cta_kv_len(num_sm90_ctas, std::vector()), - cta_head_indices(num_sm90_ctas, std::vector()); + cta_head_indices(num_sm90_ctas, std::vector()), + cta_batch_indices(num_sm90_ctas, std::vector()); int max_num_works_per_head = ceil_div(total_num_rows, cta_tile_q) + batch_size - 1; plan_info.same_schedule_for_all_heads = max_num_works_per_head > 4096; @@ -903,6 +906,7 @@ inline cudaError_t PrefillSM90Plan( cta_kv_indptr[cta_idx].push_back(kv_indptr_h[i]); cta_kv_len[cta_idx].push_back(kv_len); cta_head_indices[cta_idx].push_back(qo_head_idx); + cta_batch_indices[cta_idx].push_back(i); } } } @@ -918,6 +922,7 @@ inline cudaError_t PrefillSM90Plan( auto qo_len_vec = flatten(cta_qo_len, total_num_works); auto kv_len_vec = flatten(cta_kv_len, total_num_works); auto head_indices_vec = flatten(cta_head_indices, total_num_works); + auto batch_indices_vec = flatten(cta_batch_indices, total_num_works); AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); int max_total_num_works; @@ -944,6 +949,8 @@ inline cudaError_t PrefillSM90Plan( sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_head_indices"); plan_info.work_indptr_offset = int_allocator.aligned_alloc_offset( sizeof(IdType) * (num_sm90_ctas + 1), 16, "batch_prefill_sm90_work_indptr"); + plan_info.batch_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_batch_indices"); IdType* qo_tile_indices_h = GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.qo_tile_indices_offset); @@ -957,6 +964,8 @@ inline cudaError_t PrefillSM90Plan( GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.head_indices_offset); IdType* work_indptr_h = GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.work_indptr_offset); + IdType* batch_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.batch_indices_offset); std::copy(qo_tile_indices_vec.begin(), qo_tile_indices_vec.end(), qo_tile_indices_h); std::copy(qo_indptr_vec.begin(), qo_indptr_vec.end(), qo_offset_h); @@ -965,6 +974,7 @@ inline cudaError_t PrefillSM90Plan( std::copy(kv_len_vec.begin(), kv_len_vec.end(), kv_len_h); std::copy(head_indices_vec.begin(), head_indices_vec.end(), head_indices_h); std::copy(work_indptr_vec.begin(), work_indptr_vec.end(), work_indptr_h); + std::copy(batch_indices_vec.begin(), batch_indices_vec.end(), batch_indices_h); size_t num_bytes_to_copy = int_allocator.num_allocated_bytes(); FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy, diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index cef905be5..8fbb1eaac 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -142,28 +142,33 @@ FLASHINFER_ERROR(err_msg.str()); \ } -#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) 
\
-  switch (mask_mode) {                                   \
-    case MaskMode::kNone: {                              \
-      constexpr MaskMode MASK_MODE = MaskMode::kNone;    \
-      __VA_ARGS__                                        \
-      break;                                             \
-    }                                                    \
-    case MaskMode::kCausal: {                            \
-      constexpr MaskMode MASK_MODE = MaskMode::kCausal;  \
-      __VA_ARGS__                                        \
-      break;                                             \
-    }                                                    \
-    case MaskMode::kCustom: {                            \
-      constexpr MaskMode MASK_MODE = MaskMode::kCustom;  \
-      __VA_ARGS__                                        \
-      break;                                             \
-    }                                                    \
-    default: {                                           \
-      std::ostringstream err_msg;                        \
-      err_msg << "Unsupported mask_mode: " << int(mask_mode); \
-      FLASHINFER_ERROR(err_msg.str());                   \
-    }                                                    \
+#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...)                    \
+  switch (mask_mode) {                                                   \
+    case MaskMode::kNone: {                                              \
+      constexpr MaskMode MASK_MODE = MaskMode::kNone;                    \
+      __VA_ARGS__                                                        \
+      break;                                                             \
+    }                                                                    \
+    case MaskMode::kCausal: {                                            \
+      constexpr MaskMode MASK_MODE = MaskMode::kCausal;                  \
+      __VA_ARGS__                                                        \
+      break;                                                             \
+    }                                                                    \
+    case MaskMode::kCustom: {                                            \
+      constexpr MaskMode MASK_MODE = MaskMode::kCustom;                  \
+      __VA_ARGS__                                                        \
+      break;                                                             \
+    }                                                                    \
+    case MaskMode::kMultiItemScoring: {                                  \
+      constexpr MaskMode MASK_MODE = MaskMode::kMultiItemScoring;        \
+      __VA_ARGS__                                                        \
+      break;                                                             \
+    }                                                                    \
+    default: {                                                           \
+      std::ostringstream err_msg;                                        \
+      err_msg << "Unsupported mask_mode: " << int(mask_mode);            \
+      FLASHINFER_ERROR(err_msg.str());                                   \
+    }                                                                    \
   }
 
 // convert head_dim to compile-time constant
diff --git a/setup.py b/setup.py
index c24555bf7..dd3760da6 100644
--- a/setup.py
+++ b/setup.py
@@ -35,6 +35,7 @@
     [(k, v) for k, v in SM90_ALLOWED_HEAD_DIMS if k != v]
 )  # Always enable (192,128)
 
+# NOTE(Zihao): exclude 3 (multi-item scoring) from AOT wheel
 mask_modes = [0, 1, 2]
 
 enable_aot = os.environ.get("FLASHINFER_ENABLE_AOT", "0") == "1"
diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh
index a3c3ebaf5..bcd17222c 100644
--- a/src/flashinfer_ops.cuh
+++ b/src/flashinfer_ops.cuh
@@ -435,11 +435,14 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper(
     paged_kv_t paged_kv, DTypeO* o, float* lse, uint32_t num_qo_heads,
     bool causal = true, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
     bool use_fp16_qk_reduction = false, std::optional maybe_sm_scale = std::nullopt,
-    float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) {
+    float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr,
+    uint32_t* maybe_prefix_len_ptr = nullptr, uint16_t* maybe_token_pos_in_items_ptr = nullptr,
+    uint32_t token_pos_in_items_len = 0, uint16_t* maybe_max_item_len_ptr = nullptr) {
   const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim)));
   const uint32_t num_kv_heads = paged_kv.num_heads;
   const uint32_t head_dim = paged_kv.head_dim;
-  const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone;
+  MaskMode mask_mode = causal ? 
MaskMode::kCausal : MaskMode::kNone; + if (maybe_prefix_len_ptr != nullptr) mask_mode = MaskMode::kMultiItemScoring; auto plan_info = handler->GetPlanInfo(); DISPATCH_head_dim( head_dim, HEAD_DIM, @@ -470,6 +473,10 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper( params.max_total_num_rows = plan_info.total_num_rows; params.total_num_rows = handler->GetTotalNumRows(); params.padded_batch_size = plan_info.padded_batch_size; + params.maybe_prefix_len_ptr = maybe_prefix_len_ptr; + params.maybe_token_pos_in_items_ptr = maybe_token_pos_in_items_ptr; + params.token_pos_in_items_len = token_pos_in_items_len; + params.maybe_max_item_len_ptr = maybe_max_item_len_ptr; DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, { return BatchPrefillWithPagedKVCacheDispatched< CTA_TILE_Q, HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, diff --git a/tests/test_batch_prefill_kernels.py b/tests/test_batch_prefill_kernels.py index 51b260c82..84234a02b 100644 --- a/tests/test_batch_prefill_kernels.py +++ b/tests/test_batch_prefill_kernels.py @@ -14,6 +14,7 @@ limitations under the License. """ +import numpy import pytest import torch from jit_utils import jit_prefill_attention_func_args @@ -803,6 +804,219 @@ def test_batch_prefill_with_ragged_kv_cache_custom_mask( torch.testing.assert_close(o_custom, o_causal, rtol=1e-3, atol=1e-3) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize( + "kv_len, qo_len, prefix_len_ptr, token_pos_in_items_ptr, token_pos_in_items_len, max_item_len_ptr", + [ + (54, 37, 17, list(range(17)) + list(range(19)) + [0], 100, [18]), + (97, 81, 16, list(range(80)) + [0], 97, [79]), + ], +) +@pytest.mark.parametrize("page_size", [1, 5, 16]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("num_qo_heads", [4, 32]) +@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("kv_layout", ["NHD"]) +@pytest.mark.parametrize("pos_encoding_mode", ["ROPE_LLAMA"]) +@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0]) +@pytest.mark.parametrize("return_lse", [True, False]) +def test_batch_prefill_with_paged_kv_cache_multi_item_scoring( + batch_size, + kv_len, + qo_len, + prefix_len_ptr, + token_pos_in_items_ptr, + token_pos_in_items_len, + max_item_len_ptr, + page_size, + num_kv_heads, + num_qo_heads, + head_dim, + causal, + kv_layout, + pos_encoding_mode, + logits_soft_cap, + return_lse, +): + q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half() + q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + kv_data = ( + torch.randn(total_num_pages, 2, num_kv_heads, page_size, head_dim).to(0).half() + if kv_layout == "HND" + else torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim) + .to(0) + .half() + ) + kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq + kv_indices_cpu = torch.arange(0, total_num_pages).int() + kv_last_page_len_cpu = torch.full( + (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + ) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) + q_indptr_gpu = q_indptr_cpu.to(0) + kv_indptr_gpu = kv_indptr_cpu.to(0) + kv_indices_gpu = kv_indices_cpu.to(0) + kv_last_page_len_gpu = kv_last_page_len_cpu.to(0) + wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout + ) + wrapper.plan( + q_indptr_gpu, + kv_indptr_gpu, + 
kv_indices_gpu, + kv_last_page_len_gpu, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + causal=causal, + pos_encoding_mode=pos_encoding_mode, + logits_soft_cap=logits_soft_cap, + prefix_len_ptr=torch.tensor(prefix_len_ptr).to(dtype=torch.uint32).to(0), + token_pos_in_items_ptr=torch.tensor(token_pos_in_items_ptr) + .to(dtype=torch.uint16) + .to(0), + token_pos_in_items_len=torch.tensor(token_pos_in_items_len) + .to(dtype=torch.uint32) + .to(0), + max_item_len_ptr=torch.tensor(max_item_len_ptr).to(dtype=torch.uint16).to(0), + ) + if return_lse: + o, _ = wrapper.run_return_lse(q, kv_data) + else: + o = wrapper.run(q, kv_data) + + for i in range(batch_size): + perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3] + perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2] + qi = q[q_indptr_cpu[i] : q_indptr_cpu[i + 1]] + ki = torch.cat( + [ + kv_data[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 0] + .permute(*perm_dims) + .reshape(-1, num_kv_heads, head_dim), + ( + kv_data[kv_indptr_cpu[i + 1] - 1, 0, :, : kv_last_page_len_cpu[i]] + if kv_layout == "HND" + else kv_data[ + kv_indptr_cpu[i + 1] - 1, 0, : kv_last_page_len_cpu[i], : + ] + ) + .permute(*perm_dims_last) + .reshape(-1, num_kv_heads, head_dim), + ], + dim=0, + ) + vi = torch.cat( + [ + kv_data[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 1] + .permute(*perm_dims) + .reshape(-1, num_kv_heads, head_dim), + ( + kv_data[kv_indptr_cpu[i + 1] - 1, 1, :, : kv_last_page_len_cpu[i]] + if kv_layout == "HND" + else kv_data[ + kv_indptr_cpu[i + 1] - 1, 1, : kv_last_page_len_cpu[i], : + ] + ) + .permute(*perm_dims_last) + .reshape(-1, num_kv_heads, head_dim), + ], + dim=0, + ) + + def create_2D_multi_item_mask_dense( + is_delimiter, sliding_window_size=-1, prefix_cache_len=None + ): + # Function to create custom_mask for multi-item scoring + # + # Note, sliding window implementation assumes that candidate_i_size < sliding_window_size < prefix_size + # Args: + # is_delimiter: a boolen torch vec to indicate the delimiter position for creating custom attnetion mask in multi-item scoring + # currently assume qo len and kv len are the same and 1D (bsz=1) case + # sliding_window_size: the window size for sliding window attention, -1 means no sliding window attention + delimiter_idx = is_delimiter.nonzero(as_tuple=True)[0] + if len(delimiter_idx) == 0: + return None + else: + first_delimiter_pos = delimiter_idx[0] + seq_len = len(is_delimiter) + pos = torch.arange(seq_len, device=is_delimiter.device) + + group_ids = torch.cumsum(is_delimiter, 0) + # Get mask for within-group causal attention + within_group_causal = (group_ids.unsqueeze(1) == group_ids.unsqueeze(0)) & ( + pos.unsqueeze(0) <= pos.unsqueeze(1) + ) + # Combine all conditions + attention_mask = ( + ( + within_group_causal + | ( + (pos >= first_delimiter_pos).unsqueeze(1) + & (pos < first_delimiter_pos).unsqueeze(0) + ) # Prefix attention + ) + & ~is_delimiter.unsqueeze(0) + & ~is_delimiter.unsqueeze(1) + ) # No delimiter attention + + if sliding_window_size > 0 and sliding_window_size < len(is_delimiter): + # Calculate how many positions from right of prefix each token can attend to + + group_size = torch.sum( + within_group_causal & ~is_delimiter.unsqueeze(0), dim=1 + ) + + # For prefix: after sliding_window_size position, can see window_size tokens + # For candidate items: can see (sliding_window_size - group_size) tokens from prefix end + prefix_window = torch.where( + pos >= first_delimiter_pos, + sliding_window_size - group_size, + torch.where( + pos < 
sliding_window_size, + first_delimiter_pos, + sliding_window_size, + ), + ) + + # Starting index of attention window relative to token position for candidate item/group + prefix_start = first_delimiter_pos - prefix_window.unsqueeze(1) + + attention_mask = attention_mask & (pos >= prefix_start) + if prefix_cache_len: + patch = torch.ones( + seq_len, + prefix_cache_len, + device=is_delimiter.device, + dtype=torch.bool, + ) + attention_mask = torch.concat([patch, attention_mask], dim=1) + return attention_mask.unsqueeze(0).reshape(-1) + + custom_mask = create_2D_multi_item_mask_dense( + is_delimiter=torch.tensor(token_pos_in_items_ptr).to(0) == 0, + sliding_window_size=-1, + prefix_cache_len=prefix_len_ptr, + ) + o_ref_i = flashinfer.prefill.single_prefill_with_kv_cache( + qi, + ki, + vi, + causal=causal, + pos_encoding_mode=pos_encoding_mode, + logits_soft_cap=logits_soft_cap, + custom_mask=custom_mask, + ) + o_i_np = o[q_indptr_cpu[i] : q_indptr_cpu[i + 1]].cpu().numpy() + o_ref_i_np = o_ref_i.cpu().numpy() + numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) + + if __name__ == "__main__": test_batch_prefill_with_paged_kv_cache( 12, 54, 37, 16, 8, 8, 128, True, "HND", "NONE", True, 0.0, False, True diff --git a/tests/test_hopper.py b/tests/test_hopper.py index 266a5077c..d954a3e98 100644 --- a/tests/test_hopper.py +++ b/tests/test_hopper.py @@ -296,6 +296,250 @@ def test_batch_paged_prefill( torch.testing.assert_close(o_sm80, o_sm90, rtol=1e-3, atol=1e-3) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize( + "kv_len, qo_len, prefix_len_ptr, token_pos_in_items_ptr, token_pos_in_items_len, max_item_len_ptr", + [ + (54, 37, 17, list(range(17)) + list(range(19)) + [0], 100, [18]), + (97, 81, 16, list(range(80)) + [0], 97, [79]), + ], +) +@pytest.mark.parametrize("page_size", [1, 5, 16]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("num_qo_heads", [4, 32]) +@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("kv_layout", ["NHD"]) +@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0]) +@pytest.mark.parametrize("return_lse", [True, False]) +def test_batch_prefill_with_paged_kv_cache_multi_item_scoring_fa3( + batch_size, + kv_len, + qo_len, + prefix_len_ptr, + token_pos_in_items_ptr, + token_pos_in_items_len, + max_item_len_ptr, + page_size, + num_kv_heads, + num_qo_heads, + head_dim, + causal, + kv_layout, + logits_soft_cap, + return_lse, +): + + q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half() + q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + kv_data = ( + torch.randn(total_num_pages, 2, num_kv_heads, page_size, head_dim).to(0).half() + if kv_layout == "HND" + else torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim) + .to(0) + .half() + ) + kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq + kv_indices_cpu = torch.arange(0, total_num_pages).int() + kv_last_page_len_cpu = torch.full( + (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + ) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) + q_indptr_gpu = q_indptr_cpu.to(0) + kv_indptr_gpu = kv_indptr_cpu.to(0) + kv_indices_gpu = kv_indices_cpu.to(0) + kv_last_page_len_gpu = kv_last_page_len_cpu.to(0) + + wrapper_fa2 = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + 
workspace_buffer, kv_layout, backend="fa2" + ) + wrapper_fa2.plan( + q_indptr_gpu, + kv_indptr_gpu, + kv_indices_gpu, + kv_last_page_len_gpu, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + causal=causal, + logits_soft_cap=logits_soft_cap, + prefix_len_ptr=torch.tensor(prefix_len_ptr).to(dtype=torch.uint32).to(0), + token_pos_in_items_ptr=torch.tensor(token_pos_in_items_ptr) + .to(dtype=torch.uint16) + .to(0), + token_pos_in_items_len=torch.tensor(token_pos_in_items_len) + .to(dtype=torch.uint32) + .to(0), + max_item_len_ptr=torch.tensor(max_item_len_ptr).to(dtype=torch.uint16).to(0), + ) + o_fa2, lse_fa2 = wrapper_fa2.run_return_lse(q, kv_data) + + wrapper_fa3 = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout, backend="fa3" + ) + wrapper_fa3.plan( + q_indptr_gpu, + kv_indptr_gpu, + kv_indices_gpu, + kv_last_page_len_gpu, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + causal=causal, + logits_soft_cap=logits_soft_cap, + prefix_len_ptr=torch.tensor(prefix_len_ptr).to(dtype=torch.uint32).to(0), + token_pos_in_items_ptr=torch.tensor(token_pos_in_items_ptr) + .to(dtype=torch.uint16) + .to(0), + token_pos_in_items_len=torch.tensor(token_pos_in_items_len) + .to(dtype=torch.uint32) + .to(0), + max_item_len_ptr=torch.tensor(max_item_len_ptr).to(dtype=torch.uint16).to(0), + ) + + o_fa3, lse_fa3 = wrapper_fa3.run_return_lse(q, kv_data) + + torch.testing.assert_close(lse_fa2, lse_fa3, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(o_fa2, o_fa3, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "kv_len, qo_len, prefix_len_ptr, token_pos_in_items_ptr, token_pos_in_items_len, max_item_len_ptr", + [ + ( + 54, + 37, + [17, 17], + list(range(17)) + + list(range(19)) + + [0] + + [0] * 63 + + list(range(15)) + + list(range(21)) + + [0], + 100, + [18, 20], + ), + ( + 97, + 81, + [16, 16], + list(range(80)) + [0] + [0] * 16 + list(range(76)) + [0], + 97, + [79, 75], + ), + ], +) +@pytest.mark.parametrize("page_size", [1, 5, 16]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("num_qo_heads", [4, 32]) +@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("kv_layout", ["NHD"]) +@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0]) +@pytest.mark.parametrize("return_lse", [True, False]) +def test_batch_prefill_with_paged_kv_cache_multi_item_scoring_fa3_bsz2( + batch_size, + kv_len, + qo_len, + prefix_len_ptr, + token_pos_in_items_ptr, + token_pos_in_items_len, + max_item_len_ptr, + page_size, + num_kv_heads, + num_qo_heads, + head_dim, + causal, + kv_layout, + logits_soft_cap, + return_lse, +): + + q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half() + q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + kv_data = ( + torch.randn(total_num_pages, 2, num_kv_heads, page_size, head_dim).to(0).half() + if kv_layout == "HND" + else torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim) + .to(0) + .half() + ) + kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq + kv_indices_cpu = torch.arange(0, total_num_pages).int() + kv_last_page_len_cpu = torch.full( + (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + ) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) + q_indptr_gpu = q_indptr_cpu.to(0) + 
kv_indptr_gpu = kv_indptr_cpu.to(0) + kv_indices_gpu = kv_indices_cpu.to(0) + kv_last_page_len_gpu = kv_last_page_len_cpu.to(0) + + wrapper_fa2 = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout, backend="fa2" + ) + wrapper_fa2.plan( + q_indptr_gpu, + kv_indptr_gpu, + kv_indices_gpu, + kv_last_page_len_gpu, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + causal=causal, + logits_soft_cap=logits_soft_cap, + prefix_len_ptr=torch.tensor(prefix_len_ptr).to(dtype=torch.uint32).to(0), + token_pos_in_items_ptr=torch.tensor(token_pos_in_items_ptr) + .to(dtype=torch.uint16) + .to(0), + token_pos_in_items_len=torch.tensor(token_pos_in_items_len) + .to(dtype=torch.uint32) + .to(0), + max_item_len_ptr=torch.tensor(max_item_len_ptr).to(dtype=torch.uint16).to(0), + ) + o_fa2, lse_fa2 = wrapper_fa2.run_return_lse(q, kv_data) + + wrapper_fa3 = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout, backend="fa3" + ) + wrapper_fa3.plan( + q_indptr_gpu, + kv_indptr_gpu, + kv_indices_gpu, + kv_last_page_len_gpu, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + causal=causal, + logits_soft_cap=logits_soft_cap, + prefix_len_ptr=torch.tensor(prefix_len_ptr).to(dtype=torch.uint32).to(0), + token_pos_in_items_ptr=torch.tensor(token_pos_in_items_ptr) + .to(dtype=torch.uint16) + .to(0), + token_pos_in_items_len=torch.tensor(token_pos_in_items_len) + .to(dtype=torch.uint32) + .to(0), + max_item_len_ptr=torch.tensor(max_item_len_ptr).to(dtype=torch.uint16).to(0), + ) + + o_fa3, lse_fa3 = wrapper_fa3.run_return_lse(q, kv_data) + + torch.testing.assert_close(lse_fa2, lse_fa3, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(o_fa2, o_fa3, rtol=1e-3, atol=1e-3) + + if __name__ == "__main__": # test_batch_prefill(14, 64, 32, 32, False, 128) # test_batch_prefill(1, 32767, 8, 8, True, 128) diff --git a/tests/test_jit_example.py b/tests/test_jit_example.py index 4a10b0ca7..935a7d9d8 100644 --- a/tests/test_jit_example.py +++ b/tests/test_jit_example.py @@ -635,7 +635,7 @@ def test_sm90_debug_print_logits(): template __device__ __host__ DebugPrintLogits(const MainloopParams& params, const BlockCoord& block_coord) { sm_scale_log2 = params.additional_params.sm_scale * math::log2e; - auto [_, __, ___, ____, _____, qo_len_, kv_len_] = + auto [_, __, ___, ____, _____, qo_len_, kv_len_, batch_idx] = block_coord; qo_len = qo_len_;