
Commit c6b7c20

doc: update documentation for mask layout (#270)

Follow-up of #266, this PR adds docstrings and diagrams for the 2D ragged tensor mask layout.

1 parent b16bbe4

File tree

2 files changed: +48 -5 lines changed

docs/tutorials/kv_layout.rst (+34 -3)
@@ -25,8 +25,8 @@ Ragged Tensor
 -------------
 
 In batched inference/serving, the input sequence length may vary across different samples.
-When there is no need to change the sequence length (e.g. in prefilling stage), we can use ``RaggedTensor`` to store
-the key/value tensors in KV-Cache:
+When there is no need to change the sequence length (e.g. in prefilling stage), we can use ``RaggedTensor``
+with a single ragged (variable length) dimension to store the key/value tensors in KV-Cache:
 
 .. image:: https://raw.githubusercontent.com/flashinfer-ai/web-data/main/tutorials/ragged.png
   :width: 400
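For illustration, here is a minimal PyTorch sketch (not part of this commit) of the single-ragged-dimension layout described here: variable-length keys are packed along the first dimension and recovered with an ``indptr`` array, following the ``(indptr[-1], num_heads, head_dim)`` shape mentioned in the next hunk. All lengths and sizes below are made up.

```python
import torch

# Hypothetical per-request sequence lengths for a batch of 3 requests.
seq_len = torch.tensor([5, 2, 7])
num_heads, head_dim = 8, 64

# indptr has length batch_size + 1, with indptr[i+1] - indptr[i] == seq_len[i].
indptr = torch.zeros(len(seq_len) + 1, dtype=torch.long)
indptr[1:] = torch.cumsum(seq_len, dim=0)

# Ragged key tensor in NHD layout: all requests packed along the first dimension,
# giving shape (indptr[-1], num_heads, head_dim) with no padding.
key_data = torch.randn(int(indptr[-1]), num_heads, head_dim)

# Keys of request i, shape (seq_len[i], num_heads, head_dim).
i = 1
key_i = key_data[indptr[i]:indptr[i + 1]]
assert key_i.shape == (int(seq_len[i]), num_heads, head_dim)
```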
@@ -41,6 +41,37 @@ shape ``(indptr[-1], num_heads, head_dim)`` when the layout is ``NHD``.
 
 We can use ``data[indptr[i]:indptr[i+1]]`` to slice the keys (or values) of request ``i``.
 
+.. _mask-layout:
+
+Mask Layout (2D Ragged Tensor)
+------------------------------
+
+The aforementioned ragged tensor can be generalized to multiple "ragged" dimensions. For example,
+the attention mask in FlashInfer is a 2D ragged tensor when the batch size is greater than 1:
+
+.. image:: https://raw.githubusercontent.com/flashinfer-ai/web-data/main/tutorials/mask-layout.png
+  :width: 800
+  :align: center
+  :alt: Data structure of Mask Layout.
+
+When the number of requests is greater than 1, different requests may have different query and kv lengths.
+To avoid padding, we use a 2D ragged tensor to store the attention mask. The input ``qo_indptr`` and
+``kv_indptr`` arrays (both of length ``num_requests+1``) store the variable sequence lengths of each request:
+``qo_indptr[i+1]-qo_indptr[i]`` is the query length of request ``i`` (``qo_len[i]``), and
+``kv_indptr[i+1]-kv_indptr[i]`` is the kv length of request ``i`` (``kv_len[i]``).
+
+The mask arrays of all requests are flattened (with query as the first dimension and kv as the last dimension)
+and concatenated into a single 1D array: ``mask_data``. FlashInfer implicitly creates a ``qk_indptr`` array
+to store the start offset of each request's mask in the flattened mask array: ``qk_indptr[1:] = cumsum(qo_len * kv_len)``.
+
+``mask_data`` has shape ``(qk_indptr[-1],)``; we can use ``mask_data[qk_indptr[i]:qk_indptr[i+1]]`` to slice the flattened
+mask of request ``i``.
+
+:class:`flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper` and :class:`flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper`
+allow the user to specify ``qo_indptr``, ``kv_indptr`` and a custom attention mask ``custom_mask`` in their ``begin_forward`` functions;
+the mask data will be added to the attention score before softmax (and after softmax scaling) in the attention kernel.
+
 .. _page-layout:
 
 FlashInfer APIs
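As a plain-PyTorch illustration of the mask layout added above (a sketch with made-up lengths, not code from this commit; the names mirror the tutorial text), the flattened ``mask_data`` and the implicit ``qk_indptr`` can be reproduced as follows:

```python
import torch

# Hypothetical query/kv lengths for 2 requests.
qo_len = torch.tensor([2, 3])
kv_len = torch.tensor([3, 4])
num_requests = len(qo_len)

# qk_indptr[1:] = cumsum(qo_len * kv_len), as stated in the tutorial.
qk_indptr = torch.zeros(num_requests + 1, dtype=torch.long)
qk_indptr[1:] = torch.cumsum(qo_len * kv_len, dim=0)

# One (qo_len[i], kv_len[i]) mask per request (all-zero placeholders here),
# flattened row-major (query first, kv last) and concatenated into a single 1D array.
masks = [torch.zeros(int(q), int(k)) for q, k in zip(qo_len, kv_len)]
mask_data = torch.cat([m.flatten() for m in masks])
assert mask_data.shape == (int(qk_indptr[-1]),)

# Recover the mask of request i from the flattened array.
i = 1
mask_i = mask_data[qk_indptr[i]:qk_indptr[i + 1]].reshape(int(qo_len[i]), int(kv_len[i]))
```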
@@ -92,7 +123,7 @@ FlashInfer APIs
 :meth:`flashinfer.page.append_paged_kv_cache` can append a batch of keys/values (stored as ragged tensors) to the paged KV-Cache
 (the pages for these appended keys/values must be allocated prior to calling this API).
 
-:class:`BatchDecodeWithPagedKVCacheWrapper` and :class:`BatchPrefillWithPagedKVCacheWrapper` implements the decode attention
+:class:`flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper` and :class:`flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper` implement the decode attention
 and prefill/append attention between queries stored in ragged tensors and keys/values stored in paged KV-Cache.
 
 FAQ

python/flashinfer/prefill.py (+14 -2)
@@ -83,6 +83,8 @@ def single_prefill_with_kv_cache(
         ``HND``.
     custom_mask : Optional[torch.Tensor]
         The custom mask tensor, shape: ``[qo_len, kv_len]``.
+        If provided, the custom mask will be added to the attention matrix before
+        softmax and after scaling, and the :attr:`causal` parameter will be ignored.
     causal : bool
         Whether to apply causal mask to the attention matrix.
         This is only effective when :attr:`custom_mask` is not provided.
@@ -201,6 +203,8 @@ def single_prefill_with_kv_cache_return_lse(
         ``HND``.
     custom_mask : Optional[torch.Tensor]
         The custom mask tensor, shape: ``[qo_len, kv_len]``.
+        If provided, the custom mask will be added to the attention matrix before
+        softmax and after scaling, and the :attr:`causal` parameter will be ignored.
     causal : bool
         Whether to apply causal mask to the attention matrix.
         This is only effective when :attr:`custom_mask` is not provided.
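In reference PyTorch terms, the semantics documented in these two docstrings look roughly like the sketch below (single head, made-up sizes; this illustrates the documented math, not the actual kernel):

```python
import math
import torch

qo_len, kv_len, head_dim = 4, 4, 64
q = torch.randn(qo_len, head_dim)
k = torch.randn(kv_len, head_dim)
v = torch.randn(kv_len, head_dim)

# An additive custom mask of shape [qo_len, kv_len]: 0.0 keeps a position,
# -inf disables it. A causal pattern is used here purely as an example.
custom_mask = torch.zeros(qo_len, kv_len)
custom_mask[torch.arange(qo_len)[:, None] < torch.arange(kv_len)[None, :]] = float("-inf")

scores = (q @ k.T) / math.sqrt(head_dim)  # softmax scaling
scores = scores + custom_mask             # mask added after scaling, before softmax
out = torch.softmax(scores, dim=-1) @ v   # `causal` would be ignored when a mask is given
```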
@@ -474,7 +478,11 @@ def begin_forward(
         The size of each page in the paged kv-cache.
     custom_mask : Optional[torch.Tensor]
         The flattened mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size)),)``.
-        The mask tensor will be applied to the attention matrix before softmax if provided.
+        If provided, the custom mask will be added to the attention matrix before softmax
+        and after scaling. The mask tensor should be on the same device as the input tensors.
+
+        Please refer to the :ref:`mask layout <mask-layout>` for more details about the flattened
+        layout of the mask tensor.
 
     Notes
     -----
@@ -845,7 +853,11 @@ def begin_forward(
         The dimension of the heads.
     custom_mask : Optional[torch.Tensor]
         The flattened mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size)),)``.
-        The mask tensor will be added to the attention matrix before softmax.
+        If provided, the custom mask will be added to the attention matrix before softmax
+        and after scaling. The mask tensor should be on the same device as the input tensors.
+
+        Please refer to the :ref:`mask layout <mask-layout>` for more details about the flattened
+        layout of the mask tensor.
 
     Notes
     -----
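A caller-side sketch (hypothetical, using per-request causal masks only as an example) of how a flattened ``custom_mask`` with the documented shape might be assembled before being supplied to ``begin_forward`` together with the indptr arrays:

```python
import torch

# Hypothetical per-request query/kv lengths.
q_len = [2, 3]
k_len = [4, 5]

# Build one additive [q_len[i], k_len[i]] mask per request. The causal pattern used
# here (query j attends to kv positions up to j + k_len[i] - q_len[i]) is just an example.
per_request_masks = []
for ql, kl in zip(q_len, k_len):
    mask = torch.zeros(ql, kl)
    disallow = torch.arange(ql)[:, None] + (kl - ql) < torch.arange(kl)[None, :]
    mask[disallow] = float("-inf")
    per_request_masks.append(mask)

# Flatten and concatenate: shape (sum(q_len[i] * k_len[i] for i in range(batch_size)),).
custom_mask = torch.cat([m.flatten() for m in per_request_masks])
assert custom_mask.numel() == sum(q * k for q, k in zip(q_len, k_len))
```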
