add multi-item scoring #1015
Conversation
Co-authored-by: qingquansong <[email protected]> Co-authored-by: zianglih <[email protected]>
add multi-item scoring
Hey @yzh119, as discussed, here's the PR for multi-item scoring masked attention. Please feel free to leave comments and suggest better ways to upstream the change. Thank you in advance!
Overall LGTM; left some comments on the additional parameters.
By the way, some unit tests failed (https://github.com/flashinfer-ai/flashinfer/blob/9220fb3443b5a5d274f00ca5552f798e225239b7/tests/test_block_sparse.py) because of the pybind interface change; would you mind fixing the sparse APIs as well (https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/sparse.py)?
@@ -108,6 +112,10 @@ struct PagedParams {
  uint32_t* total_num_rows;
  uint32_t padded_batch_size;
  bool partition_kv;
  uint32_t* prefix_len_ptr;
Can you move them to additional params? I tend to manage all of the optional parameters as additional params rather than as default ones, which is easier to maintain.
Examples include:
additional_tensor_names = ["maybe_custom_mask", "maybe_alibi_slopes"]
(the "maybe_" prefix indicates these components are optional and have type std::optional<...>).
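For illustration, a minimal sketch of how the multi-item scoring tensors could be appended to such a list under the same `maybe_` convention. Only `prefix_len_ptr` appears in the diff above; the second name is made up for the example and is not the upstream identifier.

```python
# Hypothetical sketch only: extends the "additional tensors" list with the
# multi-item scoring inputs as optional (std::optional<...>) parameters.
additional_tensor_names = [
    "maybe_custom_mask",
    "maybe_alibi_slopes",
    "maybe_prefix_len_ptr",          # per-request shared-prefix length (from the diff above)
    "maybe_token_pos_in_items_ptr",  # illustrative name, not the exact upstream identifier
]
```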
@yzh119 as suggested, moved the multi-item scoring parameters to additional params.
@@ -66,6 +67,11 @@ struct RaggedParams {
  int window_left;

  bool causal;
Ditto, better to move to additional params.
Done.
@@ -43,6 +43,7 @@ struct RaggedParams {
  IdType* kv_lens;
  IdType* head_indices;
  IdType* work_indptr;
  IdType* batch_indices;
Thanks for doing this, yes we have to add it.
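A tiny illustrative sketch (made-up values, not code from the PR) of why a per-tile batch index is useful: each scheduled work tile presumably needs to look up per-request multi-item metadata, such as the shared-prefix length, for the request it belongs to.

```python
# Illustrative only: the values and the lookup are assumptions, not PR code.
prefix_len = [3, 5]                # per-request shared-prefix lengths
batch_indices = [0, 0, 1, 1, 1]    # request id that each scheduled work tile belongs to
per_tile_prefix_len = [prefix_len[b] for b in batch_indices]
print(per_tile_prefix_len)         # [3, 3, 5, 5, 5]
```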
@@ -786,6 +786,73 @@ __device__ __forceinline__ void logits_mask(
  }
}

template <typename KTraits, typename Params>
__device__ __forceinline__ void logits_mask_customized(
Better to rename this function to "logits_mask_multi_item_scoring", or move its body into the previous "logits_mask".
fixed
@@ -2114,9 +2216,18 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice(
        : chunk_size) /
      CTA_TILE_KV;

  const uint32_t unified_num_iterations =
Maybe rename the multi-item scoring "num_iterations_full" to "num_iterations", and the multi-item scoring "num_iterations" to "num_iterations_prefix", to avoid redundancy.
Done.
use additional params, refactor and simplify the code
Hi @arde171 @qingquansong @zianglih, thanks for the great contribution. The PR looks good to me in general, and I have added some commits to address the conflicts with mainline. Let's merge this PR first and then move forward.
The remaining possible improvements include (not necessarily in this PR):
- Consider the skipped blocks in the scheduler (`plan` function); otherwise the ahead-of-time scheduler might make a false estimation of a tile's execution time.
- Further modularize the template so that the attention pattern in multi-item scoring becomes a special form of attention variant. Currently we still insert some special code into the template to handle this pattern, but we hope to fully decouple attention variants from the template itself.
Referenced in follow-up PR #1087:

## 📌 Description

- `batch_indices_offset` (introduced in #1015) is not passed to the fp8 attention kernels; this PR fixes the issue.
- Adding fp8 kernels to the AOT generators.

## 🔍 Related Issues

#1064

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

cc @abcdabcd987
Co-authored with Qingquan Song (@qingquansong) and Ziang Li (@zianglih).
Multi-item scoring
<member prefix (profile & history)> + item 1 + item 2 + ... + item N
The PR optimizes multi-item scoring attention by passing four new arguments to the kernels (among them `prefix_len_ptr`, shown in the diffs above) and using them to check the masking condition.
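As a rough reference for that masking condition, here is a minimal NumPy sketch, not the kernel code: it assumes the layout above and that each item's tokens attend causally to the shared prefix and to earlier tokens of the same item, but never to other items. The function name and signature are made up for illustration.

```python
import numpy as np

def multi_item_mask(prefix_len: int, item_lens: list[int]) -> np.ndarray:
    """Boolean [seq, seq] attention mask for one request laid out as
    <prefix> + item 1 + ... + item N (True = may attend)."""
    seq_len = prefix_len + sum(item_lens)
    # item id of every position: -1 for the shared prefix, i for item i
    item_id = np.full(seq_len, -1, dtype=np.int64)
    start = prefix_len
    for i, n in enumerate(item_lens):
        item_id[start:start + n] = i
        start += n

    q = np.arange(seq_len)[:, None]   # query positions
    k = np.arange(seq_len)[None, :]   # key positions
    causal = k <= q
    same_item = item_id[:, None] == item_id[None, :]
    key_in_prefix = k < prefix_len
    # a query token may look at the shared prefix or at its own item, causally
    return causal & (key_in_prefix | same_item)

# e.g. 2 prefix tokens followed by items of length 2 and 3
print(multi_item_mask(prefix_len=2, item_lens=[2, 3]).astype(int))
```

The kernels evaluate this condition on the fly from the new arguments instead of materializing a dense mask.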
Optimizations