
add multi-item scoring #1


Merged
merged 2 commits into from
Apr 11, 2025
Conversation

@arde171 commented Apr 11, 2025

Co-authored with Qingquan Song (@qingquansong) and Ziang Li (@zianglih).

Multi-item scoring

  1. Concatenate the member prefix with all of that member's ranking candidates, separated by delimiter tokens:
    `<member prefix (profile & history)> + <delim> + <item 1> + <delim> + <item 2> + ... + <delim> + <item N> + <delim>`
  2. Extract the logits of the hidden states of the tokens before each delimiter token, and from them extract the log-probabilities of the given label tokens. For each prompt, the output is a 2D list of shape N * K, where N is the number of candidate items the prompt contains and K is the number of choices provided to the server engine (e.g., K = 2 for ["Yes", "No"]). This is mainly done in the logit processor.
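The extraction step above can be sketched as follows. This is a minimal pure-Python illustration (the function name is hypothetical, and a real logit processor would operate on torch tensors on the GPU rather than Python lists):

```python
import math

def extract_item_scores(logits, delimiter_positions, label_token_ids):
    """Score each item from the logits of the token *before* each delimiter.

    logits: per-position logit vectors, shape [seq_len][vocab_size]
    delimiter_positions: indices of the N delimiters that end each item
    label_token_ids: ids of the K label choices (e.g. for ["Yes", "No"])
    Returns an N x K list of log-probabilities.
    """
    scores = []
    for pos in delimiter_positions:
        row = logits[pos - 1]  # token immediately before the delimiter
        # log-softmax over the vocabulary, then pick out the label tokens
        log_z = math.log(sum(math.exp(x) for x in row))
        scores.append([row[t] - log_z for t in label_token_ids])
    return scores
```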

This PR optimizes multi-item scoring attention by passing four new arguments and using them to check the masking condition. The new arguments are:

prefix_len_ptr : Optional[torch.Tensor]
    Prefix length. A uint32 1D tensor indicating the prefix length of each prompt. The tensor size is equal to the batch size.
token_pos_in_items_ptr : Optional[torch.Tensor]
    A uint16 1D tensor (it will be converted to uint16 in flashinfer) indicating the token position within each item, starting from 0
    at each delimiter. E.g., if we have 3 items of lengths 3, 2, and 4 for this member, this vector will look like
    `[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]` with the 4 delimiters indexed as 0. For batch size > 1,
    we concatenate the per-prompt vectors into one 1D tensor with zero padding so that each has the same length; the padding length is
    `token_pos_in_items_len` minus the raw length of `token_pos_in_items_ptr` for each prompt.
token_pos_in_items_len : Optional[int]
    Zero-padded length of `token_pos_in_items_ptr`, to better handle the batch size > 1 case. Still using the 3, 2, 4 example above:
    if we set `token_pos_in_items_len` to 20, the vector becomes `[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0]`
    with 7 padded zeros. (Note there are 8 zeros at the end; the first of them is the delimiter token at the end of the prompt.)
max_item_len_ptr : Optional[torch.Tensor]
    A uint16 vector containing the maximum token length over all items for each prompt.
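To make the relationship between the four arguments concrete, here is a hypothetical pure-Python helper that builds them from raw prefix and item lengths. It returns plain Python lists for illustration; the real API passes torch tensors (uint32 for `prefix_len_ptr`, uint16 for the others), and the padded length here is simply chosen as the longest raw vector:

```python
def build_multi_item_args(prefix_lens, item_lens_per_prompt):
    """Build the four multi-item scoring arguments (illustrative sketch).

    prefix_lens: member-prefix length of each prompt
    item_lens_per_prompt: per-prompt lists of item token lengths
    """
    raw = []
    for item_lens in item_lens_per_prompt:
        vec = []
        for n in item_lens:
            vec.extend(range(n + 1))  # 0 (delimiter), 1, ..., n for each item
        vec.append(0)                 # trailing delimiter at the end of the prompt
        raw.append(vec)
    # pad every per-prompt vector to a common length with zeros
    token_pos_in_items_len = max(len(v) for v in raw)
    token_pos_in_items = []
    for vec in raw:
        token_pos_in_items.extend(vec + [0] * (token_pos_in_items_len - len(vec)))
    max_item_len = [max(item_lens) for item_lens in item_lens_per_prompt]
    return prefix_lens, token_pos_in_items, token_pos_in_items_len, max_item_len
```

With the 3, 2, 4 example from the description, the middle return value reproduces `[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]`.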

Optimizations

  1. Implement an efficient multi-item scoring mask for FA2 and FA3.
  2. Enhance FA3 to support batch indices for the multi-item scoring mask.
  3. Implement tile skipping for FA2 and FA3 multi-item scoring.
  4. Optimize the mask by preloading it into the L1 cache for thread registers.
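For intuition, the masking condition these kernels enforce can be sketched as the following predicate. This is my reading of the PR description, not the kernel code; the indexing convention (absolute positions, with `token_pos_in_items` indexed from the end of the prefix) is an assumption:

```python
def multi_item_mask(q, k, prefix_len, token_pos_in_items):
    """Whether query token q may attend to key token k (sketch, per-prompt).

    A query attends to a key if the key lies in the shared member prefix,
    or if both tokens belong to the same item, subject to causality.
    token_pos_in_items[i] is the within-item position of absolute token
    prefix_len + i (0 marks a delimiter).
    """
    if k > q:           # causal mask always applies
        return False
    if k < prefix_len:  # every item sees the shared prefix
        return True
    # same item iff both tokens trace back to the same item start (delimiter)
    q_start = q - token_pos_in_items[q - prefix_len]
    k_start = k - token_pos_in_items[k - prefix_len]
    return q_start == k_start
```

Tile skipping then falls out naturally: a whole key tile can be skipped when no query in the query tile shares an item with it and the tile is past the prefix.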

Co-authored-by: qingquansong <[email protected]>
Co-authored-by: zianglih <[email protected]>
@arde171 arde171 merged commit 1f79f0d into main Apr 11, 2025
1 check passed