When the number of requests is greater than 1, different requests may have different query lengths and kv lengths.
To avoid padding, we use a 2D ragged tensor to store the attention mask. The input ``qo_indptr`` and
``kv_indptr`` arrays (both of length ``num_requests+1``) store the variable sequence lengths of each request:
``qo_indptr[i+1]-qo_indptr[i]`` is the query length of request ``i`` (``qo_len[i]``), and
``kv_indptr[i+1]-kv_indptr[i]`` is the kv length of request ``i`` (``kv_len[i]``).
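
As an illustration, the indptr arrays can be built from the per-request lengths with a cumulative sum. The lengths below are made-up example values (a NumPy sketch, not FlashInfer code):

```python
import numpy as np

# Hypothetical per-request lengths for a batch of 3 requests.
qo_len = np.array([2, 3, 1])
kv_len = np.array([5, 4, 6])

# indptr arrays have length num_requests + 1: element 0 is 0, and the
# remaining elements are the running sum of the per-request lengths.
qo_indptr = np.concatenate([[0], np.cumsum(qo_len)])
kv_indptr = np.concatenate([[0], np.cumsum(kv_len)])

# Recover the length of request i from the indptr array.
i = 1
assert qo_indptr[i + 1] - qo_indptr[i] == qo_len[i]
assert kv_indptr[i + 1] - kv_indptr[i] == kv_len[i]
```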

The mask arrays of all requests are flattened (with query as the first dimension and kv as the last dimension)
and concatenated into a single 1D array ``mask_data``. FlashInfer implicitly creates a ``qk_indptr`` array
to store the start offset of each request's mask within the flattened mask array: ``qk_indptr[1:] = cumsum(qo_len * kv_len)``.

``mask_data`` has shape ``(qk_indptr[-1],)``; we can use ``mask_data[qk_indptr[i]:qk_indptr[i+1]]`` to slice out
the flattened mask of request ``i``.
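
The flatten-and-slice round trip described above can be sketched in NumPy (the per-request masks here are randomly generated placeholders, not real attention masks):

```python
import numpy as np

qo_len = np.array([2, 3])
kv_len = np.array([4, 5])

# Hypothetical per-request 2D boolean masks of shape (qo_len[i], kv_len[i]).
rng = np.random.default_rng(0)
masks = [rng.random((q, k)) < 0.5 for q, k in zip(qo_len, kv_len)]

# Flatten each mask (query as the leading dimension) and concatenate
# everything into one 1D array.
mask_data = np.concatenate([m.reshape(-1) for m in masks])

# qk_indptr[0] = 0 and qk_indptr[1:] = cumsum(qo_len * kv_len).
qk_indptr = np.concatenate([[0], np.cumsum(qo_len * kv_len)])
assert mask_data.shape == (qk_indptr[-1],)

# Slicing with qk_indptr recovers the flattened mask of request i,
# which can be reshaped back to its original 2D form.
i = 1
flat_i = mask_data[qk_indptr[i]:qk_indptr[i + 1]]
assert (flat_i.reshape(qo_len[i], kv_len[i]) == masks[i]).all()
```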

:class:`flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper` and :class:`flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper`
allow users to specify ``qo_indptr``, ``kv_indptr`` and a custom attention mask ``custom_mask`` in their ``begin_forward`` functions;
the mask data is added to the attention scores before softmax (and after softmax scaling) in the attention kernel.

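The additive-mask semantics can be written down as a small NumPy reference, to make the "added after scaling, before softmax" ordering concrete. This is only the mathematical behavior, not the fused FlashInfer kernel; all tensors here are made-up example data:

```python
import numpy as np

def masked_attention(q, k, v, mask):
    """Reference semantics of single-head attention with an additive mask."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = q @ k.T * scale     # softmax scaling is applied first...
    scores = scores + mask       # ...then the custom mask is added (use -inf to disable a position)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    p = np.exp(scores)
    p = p / p.sum(axis=-1, keepdims=True)
    return p @ v

rng = np.random.default_rng(0)
q = rng.standard_normal((2, 4))   # 2 query positions, head_dim 4
k = rng.standard_normal((3, 4))   # 3 kv positions
v = rng.standard_normal((3, 4))

# A mask entry of 0 keeps a position; -inf removes it entirely.
mask = np.array([[0.0, -np.inf, -np.inf],   # query 0 may only attend to kv position 0
                 [0.0, 0.0, 0.0]])          # query 1 attends to everything
out = masked_attention(q, k, v, mask)
# With all other positions masked out, query 0's output is exactly v[0].
assert np.allclose(out[0], v[0])
```
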
.. _page-layout:

FlashInfer APIs
---------------
:meth:`flashinfer.page.append_paged_kv_cache` can append a batch of keys/values (stored as ragged tensors) to the paged KV-Cache
(the pages for these appended keys/values must be allocated prior to calling this API).
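
Since the pages must exist before the append, a caller needs to know how many new pages an append requires. The helper below is a hypothetical illustration (``pages_to_allocate`` is not a FlashInfer API), assuming a fixed ``page_size`` of slots per page:

```python
import math

def pages_to_allocate(cur_len, append_len, page_size):
    # Pages already allocated cover ceil(cur_len / page_size) token slots;
    # after the append we need ceil((cur_len + append_len) / page_size) pages.
    before = math.ceil(cur_len / page_size)
    after = math.ceil((cur_len + append_len) / page_size)
    return after - before

# e.g. 13 tokens stored in pages of 16 slots occupy one page; appending
# 5 more tokens spills into a second page.
assert pages_to_allocate(13, 5, 16) == 1
```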

:class:`flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper` and :class:`flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper` implement the decode attention
and prefill/append attention between queries stored in ragged tensors and keys/values stored in paged KV-Cache.