add multi-item scoring #1015


Merged: 41 commits merged into flashinfer-ai:main on Apr 30, 2025

Conversation

@arde171 (Contributor) commented Apr 11, 2025

Co-authored with Qingquan Song (@qingquansong) and Ziang Li (@zianglih)

Multi-item scoring

  1. Concatenate the member prefix with all ranking candidate items of the same member, separated by delimiter tokens:
    <member prefix (profile & history)> + <delimiter> + item 1 + <delimiter> + item 2 + ... + <delimiter> + item N + <delimiter>
  2. Extract the logits of the hidden states at the token before each delimiter, and from those extract the log-probs of the given label tokens. For each prompt, the output returned is a 2D list with shape N * K, where N is the number of candidate items the prompt contains and K is the number of choices provided to the server engine (e.g., 2 for ["Yes", "No"]). This is mainly done in the logit processor; see the sketch after this list.
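As a rough illustration of the two steps above, here is a minimal sketch using toy token ids and random logits. The delimiter id, label ids, and vocabulary size are all hypothetical, and in reality the scoring happens inside the server's logit processor rather than as standalone code like this:

```python
import torch

DELIM = 32000                 # hypothetical delimiter token id
LABEL_IDS = [9642, 2822]      # hypothetical token ids for ["Yes", "No"]

# Step 1: build one prompt = member prefix + delimiter-separated items.
prefix = [101, 102, 103]                        # <member prefix (profile & history)>
items = [[5, 6, 7], [8, 9], [10, 11, 12, 13]]   # 3 items of length 3, 2, 4
tokens = list(prefix)
for item in items:
    tokens.append(DELIM)      # delimiter opening each item
    tokens.extend(item)
tokens.append(DELIM)          # trailing delimiter closing the last item

# Step 2: take the logits at the token *before* each item-closing delimiter
# and read off the log-probs of the label tokens.
token_ids = torch.tensor(tokens)
logits = torch.randn(len(tokens), 32064)        # stand-in for the model's logits
delim_pos = (token_ids == DELIM).nonzero(as_tuple=True)[0]
score_pos = delim_pos[1:] - 1                   # last token of each item
log_probs = torch.log_softmax(logits[score_pos], dim=-1)
scores = log_probs[:, LABEL_IDS]                # shape N x K = 3 x 2
```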

The PR optimizes the multi-item scoring attention by passing four new args and using them to check the masking condition. The provided args are:

prefix_len_ptr : Optional[torch.Tensor]
    The prefix length. A uint32 1D tensor indicating the prefix length of each prompt; the tensor size equals the batch size.
token_pos_in_items_ptr : Optional[torch.Tensor]
    A uint16 1D tensor (it will be converted to uint16 in flashinfer) indicating the position of each token within its item, starting from 0 at each delimiter. E.g., if we have 3 items of length 3, 2, and 4 respectively for this member, this vector will look like
    `[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]` with the 4 delimiters indexed as 0. For batch size > 1,
    we concatenate the per-prompt vectors into one 1D tensor with zero padding so that each prompt's segment has the same length; the padding amount is
    `token_pos_in_items_len` minus the length of that prompt's raw `token_pos_in_items_ptr`.
token_pos_in_items_len : Optional[int]
    The zero-padded length of each prompt's segment in `token_pos_in_items_ptr`, used to handle the batch size > 1 case. Still using the 3, 2, 4 example above:
    if we set `token_pos_in_items_len` to 20, the vector becomes `[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0]`
    with 7 padded zeros. (Note there are 8 zeros at the end, of which the first is the delimiter token at the end of the prompt.)
max_item_len_ptr : Optional[torch.Tensor]
    A uint16 1D tensor containing the maximum token length over all items for each prompt.
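For concreteness, here is a hedged host-side sketch of how these tensors could be assembled for the 3, 2, 4 example above. The helper name is ours, not the PR's API, and we use int32 for portability; the docstrings above call for uint32/uint16, so cast as needed:

```python
import torch

def build_multi_item_args(prefix_len: int, item_lens: list[int],
                          token_pos_in_items_len: int):
    # prefix_len_ptr: one entry per prompt (documented dtype: uint32).
    prefix_len_ptr = torch.tensor([prefix_len], dtype=torch.int32)

    # token_pos_in_items_ptr: 0 at each delimiter, then 1..len within the item,
    # plus a trailing delimiter, zero-padded to token_pos_in_items_len.
    pos: list[int] = []
    for n in item_lens:
        pos.append(0)                 # delimiter opening this item
        pos.extend(range(1, n + 1))   # positions of the item's tokens
    pos.append(0)                     # trailing delimiter ending the prompt
    pos += [0] * (token_pos_in_items_len - len(pos))
    token_pos_in_items_ptr = torch.tensor(pos, dtype=torch.int32)

    # max_item_len_ptr: max item length per prompt (documented dtype: uint16).
    max_item_len_ptr = torch.tensor([max(item_lens)], dtype=torch.int32)
    return prefix_len_ptr, token_pos_in_items_ptr, max_item_len_ptr

p, t, m = build_multi_item_args(prefix_len=128, item_lens=[3, 2, 4],
                                token_pos_in_items_len=20)
# t: [0,1,2,3, 0,1,2, 0,1,2,3,4, 0, 0,0,0,0,0,0,0]  (7 padded zeros)
```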

Optimizations

  1. Implement an efficient multi-item scoring mask for FA2 and FA3.
  2. Enhance FA3 to support batch indices for the multi-item scoring mask.
  3. Implement tile skipping for FA2 and FA3 multi-item scoring.
  4. Optimize the mask by preloading it into L1 cache and thread registers.
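For intuition, a pure-Python reference of the mask these kernels compute might look like the following. This is an unoptimized sketch of our reading of the scheme (each item attends to the shared prefix plus causally to itself, never to other items); it is not the kernel code, and the function name is ours:

```python
import torch

def multi_item_mask(prefix_len: int, token_pos_in_items: list[int]) -> torch.Tensor:
    """Boolean [L, L] mask, True where attention is allowed (reference only)."""
    L = prefix_len + len(token_pos_in_items)
    # Assign an item id to every position: the prefix gets -1, and each
    # delimiter (marked 0 in token_pos_in_items) opens a new item.
    item_id = [-1] * prefix_len
    current = -1
    for pos in token_pos_in_items:
        if pos == 0:
            current += 1
        item_id.append(current)
    mask = torch.zeros(L, L, dtype=torch.bool)
    for q in range(L):
        for k in range(q + 1):  # causal: keys cannot come after the query
            # Allowed if the key is in the shared prefix or in the query's item.
            if item_id[k] == -1 or item_id[k] == item_id[q]:
                mask[q, k] = True
    return mask

# Items of length 3, 2, 4 with a 2-token prefix:
m = multi_item_mask(2, [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0])
```

The tile-skipping optimization in items 3 and 4 follows from this structure: whole key tiles that fall entirely in other items can be skipped rather than masked element by element.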

@qingquansong (Contributor):

Hey @yzh119 as discussed, here's the PR for multi-item scoring masked attention. Please feel free to leave comments and provide suggestions if there could be better ways to help upstream the change. Thank you in advance!

@yzh119 yzh119 self-requested a review April 11, 2025 18:11
@yzh119 (Collaborator) left a comment:

Overall LGTM; I left some comments on the additional parameters.

btw, some unit tests failed (https://github.com/flashinfer-ai/flashinfer/blob/9220fb3443b5a5d274f00ca5552f798e225239b7/tests/test_block_sparse.py) because of the pybind interface change; would you mind fixing the sparse APIs as well (https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/sparse.py)?

@@ -108,6 +112,10 @@ struct PagedParams {
uint32_t* total_num_rows;
uint32_t padded_batch_size;
bool partition_kv;
uint32_t* prefix_len_ptr;
@yzh119 (Collaborator):
Can you move them to additional params? I prefer managing all optional parameters as additional ones rather than default ones, which is easier to maintain.

Examples include:

additional_tensor_names = ["maybe_custom_mask", "maybe_alibi_slopes"]
(we use the prefix maybe_ to indicate these components are optional and have type std::optional<...>).

@arde171 (Contributor, Author):
@yzh119 as suggested, I moved the multi-item scoring parameters to additional params.

@@ -66,6 +67,11 @@ struct RaggedParams {
int window_left;

bool causal;

@yzh119 (Collaborator):
Ditto, better to move to additional params.

@arde171 (Contributor, Author):
Done.

@@ -43,6 +43,7 @@ struct RaggedParams {
IdType* kv_lens;
IdType* head_indices;
IdType* work_indptr;
IdType* batch_indices;
@yzh119 (Collaborator):
Thanks for doing this, yes we have to add it.

@@ -786,6 +786,73 @@ __device__ __forceinline__ void logits_mask(
}
}

template <typename KTraits, typename Params>
__device__ __forceinline__ void logits_mask_customized(
A contributor commented:
Better rename this function to "logits_mask_multi_item_scoring". Or move its body to the previous "logits_mask".

@arde171 (Contributor, Author):
Fixed.

@@ -2114,9 +2216,18 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice(
: chunk_size) /
CTA_TILE_KV;

const uint32_t unified_num_iterations =
@zianglih (Contributor) commented on Apr 16, 2025:
Maybe rename MIS "num_iterations_full" to "num_iterations", and MIS "num_iterations" to "num_iterations_prefix" to avoid redundancy.

@arde171 (Contributor, Author):
Done.

@qingquansong (Contributor):

Hey @yzh119 , @arde171 has resolved the comments, could you help take another look? Thank you!

@yzh119 (Collaborator) left a comment:
Hi @arde171 @qingquansong @zianglih, thanks for the great contribution. The PR looks good to me in general, and I have added some commits to address the conflicts with mainline. Let's merge this PR first and then move forward.

The remaining possible improvements include (not necessarily in this PR):

  1. Account for the skipped blocks in the scheduler (plan function); otherwise the ahead-of-time scheduler may misestimate the execution time of a tile.
  2. Further modularize the template so that the multi-item scoring attention pattern becomes a special form of attention variant; currently we still insert some special-case code into the template to handle this pattern, but we hope to fully decouple attention variants from the templates themselves.

@yzh119 yzh119 merged commit 6c6f1a5 into flashinfer-ai:main Apr 30, 2025
2 checks passed
yzh119 added a commit that referenced this pull request May 23, 2025
## 📌 Description

- `batch_indices_offset` (introduced in #1015) was not passed to the fp8 attention kernels; this PR fixes the issue.
- Adds the fp8 kernels to the AOT generators.

## 🔍 Related Issues

#1064 

Edenzzzz pushed a commit to Edenzzzz/flashinfer that referenced this pull request Jun 6, 2025 (flashinfer-ai#1087): the same `batch_indices_offset` fix as above.