Commit 3d43dc9
perf: use packed bit array for attention mask (#308)
1 parent 876cc53 commit 3d43dc9

23 files changed: +593 −128 lines changed

cmake/config.cmake (+1 −1)

@@ -27,7 +27,7 @@ set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)
 set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
 set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2)
 set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false" "true")
-set(FLASHINFER_GEN_MASK_MODES 0 1)
+set(FLASHINFER_GEN_MASK_MODES 0 1 2)

 # Set target cuda architectures for tests/benchmarks, defaults to native.
 # "native" is a special value for CMAKE_CUDA_ARCHITECTURES which means use the architectures of the host's GPU.

docs/api/python/quantization.rst (+14)

@@ -0,0 +1,14 @@
+.. _apiquantization:
+
+flashinfer.quantization
+=======================
+
+Quantization related kernels.
+
+.. currentmodule:: flashinfer.quantization
+
+.. autosummary::
+   :toctree: _generate
+
+   packbits
+   segment_packbits

docs/index.rst (+1)

@@ -34,3 +34,4 @@ FlashInfer is a library for Large Language Models that provides high-perform…
    api/python/sampling
    api/python/group_gemm
    api/python/norm
+   api/python/quantization

docs/tutorials/kv_layout.rst (+8)

@@ -75,13 +75,21 @@ to store the start offset of each request's mask in the flattened mask array: ``…
 ``mask_data`` has shape ``(qk_indptr[-1],)``; we can use ``mask_data[qk_indptr[i]:qk_indptr[i+1]]`` to slice the flattened
 mask of request ``i``.

+To save memory, we can further pack the flattened boolean mask array into a bit-packed array (1 bit per element, 8 elements
+packed together as a ``uint8``) with "little" bit-order (see `numpy.packbits <https://numpy.org/doc/stable/reference/generated/numpy.packbits.html>`_
+for more details). FlashInfer accepts both boolean and bit-packed masks: if a boolean mask is provided, FlashInfer packs it into a bit-packed
+array internally.
+
 FlashInfer APIs
 ~~~~~~~~~~~~~~~

 :class:`flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper` and :class:`flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper`
 allow users to specify ``qo_indptr``, ``kv_indptr`` and a custom attention mask ``custom_mask`` in their ``begin_forward`` functions;
 the mask data will be added to the attention score before softmax (and after softmax scaling) in the attention kernel.

+:meth:`flashinfer.quantization.packbits` and :meth:`flashinfer.quantization.segment_packbits` are utility functions
+that pack a boolean mask into a bit-packed array.
+
 .. _page-layout:

 Page Table
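For concreteness, here is a small NumPy sketch of the packing convention described above (little bit-order, matching numpy.packbits; the array names are illustrative):

import numpy as np

# Boolean mask for one request (qo_len=4, kv_len=7), flattened row-major.
mask = np.random.rand(4, 7) > 0.5
flat = mask.reshape(-1)                        # shape (28,)

packed = np.packbits(flat, bitorder="little")  # shape (4,) = ceil(28 / 8), dtype uint8

# Element i of the boolean mask is bit (i % 8) of byte (i // 8):
i = 13
assert bool((packed[i // 8] >> (i % 8)) & 1) == flat[i]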

include/flashinfer/attention/prefill.cuh (+14 −14)

@@ -547,7 +547,7 @@ template <bool partition_kv, MaskMode mask_mode, uint32_t num_warps, uint32_t nu…
 __device__ __forceinline__ void mask_s(const uint32_t qo_packed_idx_base,
                                        const uint32_t kv_idx_base, const uint32_t qo_len,
                                        const uint32_t kv_len, const uint32_t chunk_end,
-                                       const uint_fastdiv group_size, float* custom_mask,
+                                       const uint_fastdiv group_size, uint8_t* custom_mask,
                                        DTypeQKAccum (*s_frag)[num_frags_z][8]) {
   const uint32_t tx = threadIdx.x;
 #pragma unroll
@@ -565,11 +565,11 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_packed_idx_base,
               ? (kv_idx > kv_len + q_idx - qo_len || (partition_kv && kv_idx >= chunk_end))
               : kv_idx >= chunk_end);
       s_frag[fx][fz][reg_id] =
-          out_of_boundary ? DTypeQKAccum(-5e4)
-                          : s_frag[fx][fz][reg_id] +
-                                DTypeQKAccum((mask_mode == MaskMode::kCustom && q_idx < qo_len)
-                                                 ? custom_mask[q_idx * kv_len + kv_idx]
-                                                 : 0.f);
+          (out_of_boundary ||
+           (mask_mode == MaskMode::kCustom && q_idx < qo_len &&
+            !((custom_mask[(q_idx * kv_len + kv_idx) / 8] >> ((q_idx * kv_len + kv_idx) % 8)) & 1)))
+              ? DTypeQKAccum(-5e4)
+              : s_frag[fx][fz][reg_id];
     }
   }
 }
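Note the semantic change in this hunk: with the old float mask the kernel added the mask value to the attention score, whereas with the packed mask it tests a single bit and writes -5e4 (effectively masking the entry out) when the bit is unset. Entry (q_idx, kv_idx) flattens to i = q_idx * kv_len + kv_idx, and the kernel reads bit i % 8 of byte i / 8 (little bit-order). A NumPy sketch of the same lookup, for reference:

import numpy as np

def mask_bit(packed: np.ndarray, q_idx: int, kv_idx: int, kv_len: int) -> bool:
    # Mirrors the kernel's little-bit-order lookup into the packed mask.
    i = q_idx * kv_len + kv_idx
    return bool((packed[i // 8] >> (i % 8)) & 1)

mask = np.random.rand(3, 5) > 0.5  # boolean mask: qo_len=3, kv_len=5
packed = np.packbits(mask.reshape(-1), bitorder="little")
assert mask_bit(packed, q_idx=2, kv_idx=4, kv_len=5) == mask[2, 4]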
@@ -891,7 +891,7 @@ template <LogitsPostHook logits_post_hook, bool partition_kv, MaskMode mask_mode…
                typename DTypeQKAccum, typename DTypeOut>
 __global__ void SinglePrefillWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* __restrict__ k,
                                                DTypeIn* __restrict__ v,
-                                               float* __restrict__ custom_mask,
+                                               uint8_t* __restrict__ custom_mask,
                                                DTypeOut* __restrict__ o, void* __restrict__ tmp,
                                                float* __restrict__ lse, const uint32_t qo_len,
                                                const uint32_t kv_len, const uint_fastdiv group_size,

@@ -1107,7 +1107,7 @@ template <LogitsPostHook logits_post_hook, MaskMode mask_mode, QKVLayout kv_layo…
 __global__ void BatchPrefillWithRaggedKVCacheKernel(
     DTypeIn* __restrict__ q, IdType* __restrict__ request_indices,
     IdType* __restrict__ tile_indices, IdType* __restrict__ qo_indptr, DTypeIn* __restrict__ k,
-    DTypeIn* __restrict__ v, IdType* __restrict__ kv_indptr, float* __restrict__ custom_mask,
+    DTypeIn* __restrict__ v, IdType* __restrict__ kv_indptr, uint8_t* __restrict__ custom_mask,
     IdType* __restrict__ qk_indptr, IdType* __restrict__ q_offset,
     IdType* __restrict__ k_rope_pos_offset, DTypeOut* __restrict__ o, float* __restrict__ tmp,
     float* __restrict__ lse, uint32_t batch_size, const uint_fastdiv group_size, float sm_scale,

@@ -1324,9 +1324,9 @@ template <LogitsPostHook logits_post_hook, MaskMode mask_mode, PosEncodingMode p…
 __global__ void BatchPrefillWithPagedKVCacheKernel(
     IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices,
     DTypeIn* __restrict__ q, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
-    IdType* __restrict__ qo_indptr, float* __restrict__ custom_mask, IdType* __restrict__ qk_indptr,
-    IdType* __restrict__ q_offset, DTypeOut* __restrict__ o, float* __restrict__ tmp,
-    float* __restrict__ lse, const uint_fastdiv group_size, float sm_scale,
+    IdType* __restrict__ qo_indptr, uint8_t* __restrict__ custom_mask,
+    IdType* __restrict__ qk_indptr, IdType* __restrict__ q_offset, DTypeOut* __restrict__ o,
+    float* __restrict__ tmp, float* __restrict__ lse, const uint_fastdiv group_size, float sm_scale,
     float log2_rope_rcp_scale, float log2_rope_rcp_theta) {
   static_assert(sizeof(DTypeIn) == 2);
   static_assert(sizeof(DTypeOut) == 2);

@@ -1534,7 +1534,7 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOU…
                 PosEncodingMode pos_encoding_mode, bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE,
                 typename DTypeIn, typename DTypeOut>
 cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v,
-                                               float* custom_mask, DTypeOut* o, float* tmp,
+                                               uint8_t* custom_mask, DTypeOut* o, float* tmp,
                                                float* lse, uint32_t num_qo_heads,
                                                uint32_t num_kv_heads, uint32_t qo_len,
                                                uint32_t kv_len, float sm_scale, float rope_scale,

@@ -1674,7 +1674,7 @@ template <uint32_t num_frags_x, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HO…
                 MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
 cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
     DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k,
-    DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset,
+    DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset,
     IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, const uint32_t batch_size,
     const uint32_t num_qo_heads, const uint32_t num_qo_tiles, const uint32_t num_kv_heads,
     const float sm_scale, const float rope_scale, const float rope_theta,

@@ -1758,7 +1758,7 @@ template <PageStorage page_storage, uint32_t num_frags_x, uint32_t HEAD_DIM,
                 typename IdType>
 cudaError_t BatchPrefillWithPagedKVCacheDispatched(
     DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset,
-    paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, float* custom_mask,
+    paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, uint8_t* custom_mask,
     IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_heads,
     uint32_t num_qo_tiles, float sm_scale, float rope_scale, float rope_theta,
     cudaStream_t stream) {

include/flashinfer/prefill_attention_decl.cuh (+5 −5)

@@ -32,7 +32,7 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOU…
                PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE,
                typename DTypeIn, typename DTypeOut>
 cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v,
-                                               float* custom_mask, DTypeOut* o, float* tmp,
+                                               uint8_t* custom_mask, DTypeOut* o, float* tmp,
                                                float* lse, uint32_t num_qo_heads,
                                                uint32_t num_kv_heads, uint32_t qo_len,
                                                uint32_t kv_len, float sm_scale, float rope_scale,

@@ -43,7 +43,7 @@ template <uint32_t NUM_FRAGS_X, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HO…
                MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
 cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
     DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k,
-    DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset,
+    DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset,
     IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, uint32_t batch_size,
     uint32_t num_qo_tiles, uint32_t num_qo_heads, uint32_t num_kv_heads, float sm_scale,
     float rope_scale, float rope_theta, cudaStream_t stream = nullptr);

@@ -54,7 +54,7 @@ template <PageStorage PAGE_STORAGE, uint32_t NUM_FRAGS_X, uint32_t HEAD_DIM,
                typename IdType>
 cudaError_t BatchPrefillWithPagedKVCacheDispatched(
     DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset,
-    paged_kv_t<PAGE_STORAGE, KV_LAYOUT, DTypeIn, IdType> paged_kv, float* custom_mask,
+    paged_kv_t<PAGE_STORAGE, KV_LAYOUT, DTypeIn, IdType> paged_kv, uint8_t* custom_mask,
     IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles,
     uint32_t num_qo_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream);

@@ -63,7 +63,7 @@ template <PageStorage PAGE_STORAGE, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POS…
                MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
 cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
     BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset,
-    paged_kv_t<PAGE_STORAGE, KV_LAYOUT, DTypeIn, IdType> paged_kv, float* custom_mask,
+    paged_kv_t<PAGE_STORAGE, KV_LAYOUT, DTypeIn, IdType> paged_kv, uint8_t* custom_mask,
     IdType* qk_indptr, DTypeOut* o, float* lse, uint32_t num_qo_heads, float sm_scale,
     float rope_scale, float rope_theta, cudaStream_t stream) {
   float* tmp = nullptr;

@@ -98,7 +98,7 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOU…
                typename DTypeIn, typename DTypeOut, typename IdType>
 cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
     BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v,
-    IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset,
+    IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset,
     IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t batch_size, uint32_t num_qo_heads,
     uint32_t num_kv_heads, float sm_scale, float rope_scale, float rope_theta,
     cudaStream_t stream) {

include/flashinfer/quantization.cuh (+114)

@@ -0,0 +1,114 @@
+/*
+ * Copyright (c) 2024 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FLASHINFER_QUANTIZATION_CUH_
+#define FLASHINFER_QUANTIZATION_CUH_
+#include <cuda_runtime.h>
+#include <cuda_runtime_api.h>
+
+#include <cub/cub.cuh>
+
+#include "utils.cuh"
+
+namespace flashinfer {
+namespace quantization {
+
+enum class BitOrder { kBig = 0U, kLittle = 1U };
+
+#define DISPATCH_BITORDER(bitorder, BITORDER, ...)   \
+  if (bitorder == BitOrder::kBig) {                  \
+    constexpr BitOrder BITORDER = BitOrder::kBig;    \
+    __VA_ARGS__                                      \
+  } else {                                           \
+    constexpr BitOrder BITORDER = BitOrder::kLittle; \
+    __VA_ARGS__                                      \
+  }
+
+template <BitOrder BITORDER>
+__global__ void PackBitsKernel(bool* input, uint8_t* output, int64_t num_elements) {
+  int64_t start_offset = blockIdx.x * blockDim.x * 8, tx = threadIdx.x;
+  uint8_t ret = 0;
+  bool input_vec[8];
+  typedef cub::BlockLoad<bool, 256, 8, cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
+  __shared__ typename BlockLoad::TempStorage temp_storage;
+  BlockLoad(temp_storage)
+      .Load(input + start_offset, input_vec, num_elements - start_offset, /*default=*/0);
+
+  if constexpr (BITORDER == BitOrder::kBig) {
+    ret = (input_vec[0] << 7) | (input_vec[1] << 6) | (input_vec[2] << 5) | (input_vec[3] << 4) |
+          (input_vec[4] << 3) | (input_vec[5] << 2) | (input_vec[6] << 1) | input_vec[7];
+  } else {
+    ret = (input_vec[7] << 7) | (input_vec[6] << 6) | (input_vec[5] << 5) | (input_vec[4] << 4) |
+          (input_vec[3] << 3) | (input_vec[2] << 2) | (input_vec[1] << 1) | input_vec[0];
+  }
+  if (start_offset + tx * 8 < num_elements) output[start_offset / 8 + tx] = ret;
+}
+
+template <BitOrder BITORDER, typename IdType>
+__global__ void SegmentPackBitsKernel(bool* input, uint8_t* output, IdType* input_indptr,
+                                      IdType* output_indptr) {
+  int64_t bx = blockIdx.x, tx = threadIdx.x;
+  bool input_vec[8];
+  typedef cub::BlockLoad<bool, 256, 8, cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
+  __shared__ typename BlockLoad::TempStorage temp_storage;
+  int64_t num_elements = input_indptr[bx + 1] - input_indptr[bx];
+  for (uint32_t start_offset = 0; start_offset < num_elements; start_offset += 8 * blockDim.x) {
+    uint8_t ret = 0;
+    BlockLoad(temp_storage)
+        .Load(input + input_indptr[bx] + start_offset, input_vec, num_elements - start_offset,
+              /*default=*/0);
+
+    if constexpr (BITORDER == BitOrder::kBig) {
+      ret = (input_vec[0] << 7) | (input_vec[1] << 6) | (input_vec[2] << 5) | (input_vec[3] << 4) |
+            (input_vec[4] << 3) | (input_vec[5] << 2) | (input_vec[6] << 1) | input_vec[7];
+    } else {
+      ret = (input_vec[7] << 7) | (input_vec[6] << 6) | (input_vec[5] << 5) | (input_vec[4] << 4) |
+            (input_vec[3] << 3) | (input_vec[2] << 2) | (input_vec[1] << 1) | input_vec[0];
+    }
+    if (start_offset + tx * 8 < num_elements)
+      output[output_indptr[bx] + start_offset / 8 + tx] = ret;
+  }
+}
+
+cudaError_t PackBits(bool* input, uint8_t* output, int64_t num_elements, BitOrder bitorder,
+                     cudaStream_t stream) {
+  DISPATCH_BITORDER(bitorder, BITORDER, {
+    auto kernel = PackBitsKernel<BITORDER>;
+    const dim3 nthrs(256);
+    const dim3 nblks(ceil_div(num_elements, nthrs.x * 8));
+    void* args[] = {&input, &output, &num_elements};
+    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
+  });
+  return cudaSuccess;
+}
+
+template <typename IdType>
+cudaError_t SegmentPackBits(bool* input, uint8_t* output, IdType* input_indptr,
+                            IdType* output_indptr, uint32_t batch_size, BitOrder bitorder,
+                            cudaStream_t stream) {
+  DISPATCH_BITORDER(bitorder, BITORDER, {
+    auto kernel = SegmentPackBitsKernel<BITORDER, IdType>;
+    const dim3 nthrs(256);
+    const dim3 nblks(batch_size);
+    void* args[] = {&input, &output, &input_indptr, &output_indptr};
+    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
+  });
+  return cudaSuccess;
+}
+
+}  // namespace quantization
+}  // namespace flashinfer
+
+#endif  // FLASHINFER_QUANTIZATION_CUH_
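PackBitsKernel has each thread pack 8 consecutive bools (loaded through cub::BlockLoad, with out-of-range elements defaulted to 0) into one output byte, so a 256-thread block covers 2048 input elements. SegmentPackBitsKernel runs one block per segment, so every segment's packed bits start at a fresh byte boundary. A NumPy sketch of the segmented semantics — the helper name and return convention here are illustrative, not the library API:

import numpy as np

def segment_packbits(x: np.ndarray, indptr: np.ndarray):
    # Pack each segment x[indptr[i]:indptr[i+1]] separately with little
    # bit-order; each segment occupies ceil(len/8) bytes of its own,
    # mirroring SegmentPackBitsKernel's per-block layout.
    out_indptr = np.zeros_like(indptr)
    out_indptr[1:] = np.cumsum((np.diff(indptr) + 7) // 8)
    out = np.concatenate([
        np.packbits(x[indptr[i]:indptr[i + 1]], bitorder="little")
        for i in range(len(indptr) - 1)
    ])
    return out, out_indptr

x = np.array([1, 0, 1, 1, 0, 0, 0, 1, 1, 1], dtype=bool)
indptr = np.array([0, 3, 10])            # two segments: lengths 3 and 7
packed, out_indptr = segment_packbits(x, indptr)
assert out_indptr.tolist() == [0, 1, 2]  # one byte per segment here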

python/csrc/batch_prefill.cu (+2 −2)

@@ -232,7 +232,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu…
             handler_.get(), static_cast<c_type*>(q.data_ptr()),
             static_cast<int32_t*>(qo_indptr.data_ptr()),
             /*q_offset=*/nullptr, paged_kv,
-            static_cast<float*>(custom_mask.data_ptr()),
+            static_cast<uint8_t*>(custom_mask.data_ptr()),
             static_cast<int32_t*>(qk_indptr.data_ptr()),
             static_cast<c_type*>(o.data_ptr()),
             /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,

@@ -434,7 +434,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC…
             static_cast<int32_t*>(qo_indptr.data_ptr()),
             static_cast<c_type*>(k.data_ptr()), static_cast<c_type*>(v.data_ptr()),
             static_cast<int32_t*>(kv_indptr.data_ptr()),
-            static_cast<float*>(custom_mask.data_ptr()),
+            static_cast<uint8_t*>(custom_mask.data_ptr()),
             static_cast<int32_t*>(qk_indptr.data_ptr()),
             /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr,
             static_cast<c_type*>(o.data_ptr()),

python/csrc/flashinfer_ops.cu (+2)

@@ -42,6 +42,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("chain_speculative_sampling", &chain_speculative_sampling,
         "Speculative sampling from sequence of probabilities");
   m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
+  m.def("packbits", &packbits, "GPU packbits operator");
+  m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator");
   py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,
                                                         "BatchDecodeWithPagedKVCachePyTorchWrapper")
       .def(py::init<unsigned int, bool, unsigned int>())
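These bindings back the user-facing helpers documented above as flashinfer.quantization.packbits and flashinfer.quantization.segment_packbits. A hedged usage sketch — the bitorder keyword and the tuple return of segment_packbits are assumptions modeled on numpy.packbits and the segmented kernel's indptr outputs, so exact argument names may differ:

import torch
import flashinfer

# Flattened boolean masks for two requests, of sizes 2*5=10 and 3*4=12.
mask = torch.rand(10 + 12, device="cuda") > 0.5
qk_indptr = torch.tensor([0, 10, 22], dtype=torch.int32, device="cuda")

# Pack the whole array, or pack per request so that each request's mask
# starts on a byte boundary (returning the packed array and its indptr).
packed = flashinfer.quantization.packbits(mask, bitorder="little")
packed_seg, packed_indptr = flashinfer.quantization.segment_packbits(
    mask, qk_indptr, bitorder="little")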
