
Commit 3b50dd5

feat: add use_tensor_cores option to decode kernels to accelerate GQA (#317)
The tensor-core-accelerated GQA described in our [blog post](https://flashinfer.ai/2024/02/02/introduce-flashinfer.html) was not enabled by default (users had to call the prefill kernels/wrappers for decode to get that acceleration). This PR adds a `use_tensor_cores` option to the decode operators/wrappers, so users can choose whether to use tensor cores depending on their use case. Note that our prefill kernels are compiled for all possible group sizes (#301), but the decode kernels are not, so users who need a general group size are encouraged to set `use_tensor_cores=True`.
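For illustration, a minimal sketch of how the new option might be used from the Python decode wrapper; only the `use_tensor_cores` flag comes from this PR, while the workspace size and other arguments below are assumptions for the example:

import torch
import flashinfer

# Allocate the auxiliary workspace buffer (128MB matches the size recommended elsewhere in this PR).
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")

# Opt in to the tensor-core (prefill-based) decode path for GQA workloads.
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer, "NHD", use_tensor_cores=True
)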
1 parent 2ef20c1 commit 3b50dd5

6 files changed (+506, −137 lines)


include/flashinfer/utils.cuh (−15 lines)

@@ -91,24 +91,9 @@
   if (group_size == 1) {                  \
     constexpr size_t GROUP_SIZE = 1;      \
     __VA_ARGS__                           \
-  } else if (group_size == 2) {           \
-    constexpr size_t GROUP_SIZE = 2;      \
-    __VA_ARGS__                           \
-  } else if (group_size == 3) {           \
-    constexpr size_t GROUP_SIZE = 3;      \
-    __VA_ARGS__                           \
   } else if (group_size == 4) {           \
     constexpr size_t GROUP_SIZE = 4;      \
     __VA_ARGS__                           \
-  } else if (group_size == 5) {           \
-    constexpr size_t GROUP_SIZE = 5;      \
-    __VA_ARGS__                           \
-  } else if (group_size == 6) {           \
-    constexpr size_t GROUP_SIZE = 6;      \
-    __VA_ARGS__                           \
-  } else if (group_size == 7) {           \
-    constexpr size_t GROUP_SIZE = 7;      \
-    __VA_ARGS__                           \
   } else if (group_size == 8) {           \
     constexpr size_t GROUP_SIZE = 8;      \
     __VA_ARGS__                           \
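After this change the decode dispatch macro is only specialized for group sizes 1, 4, and 8, while the prefill path still covers all group sizes (#301). A purely illustrative sketch of how a caller could decide when to opt into tensor cores; the head counts and the heuristic itself are assumptions, not library behavior:

# Illustrative heuristic only: enable tensor cores when the GQA group size is not
# one of the decode-specialized values {1, 4, 8}.
num_qo_heads, num_kv_heads = 28, 4           # hypothetical model config
group_size = num_qo_heads // num_kv_heads    # 7: not compiled into the decode kernels
use_tensor_cores = group_size not in (1, 4, 8)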

python/csrc/single_prefill.cu (+4, −8 lines)

@@ -38,16 +38,14 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
   unsigned int head_dim = q.size(2);
   unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads;
   QKVLayout kv_layout = static_cast<QKVLayout>(layout);
+  qo_len = q.size(0);
+  num_qo_heads = q.size(1);
   if (kv_layout == QKVLayout::kNHD) {
     kv_len = k.size(0);
-    qo_len = q.size(0);
     num_kv_heads = k.size(1);
-    num_qo_heads = q.size(1);
   } else {
     kv_len = k.size(1);
-    qo_len = q.size(1);
     num_kv_heads = k.size(0);
-    num_qo_heads = q.size(0);
   }
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();

@@ -122,16 +120,14 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache_custom_mask(
   unsigned int head_dim = q.size(2);
   unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads;
   QKVLayout kv_layout = static_cast<QKVLayout>(layout);
+  qo_len = q.size(0);
+  num_qo_heads = q.size(1);
   if (kv_layout == QKVLayout::kNHD) {
     kv_len = k.size(0);
-    qo_len = q.size(0);
     num_kv_heads = k.size(1);
-    num_qo_heads = q.size(1);
   } else {
     kv_len = k.size(1);
-    qo_len = q.size(1);
     num_kv_heads = k.size(0);
-    num_qo_heads = q.size(0);
   }
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
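The refactor above reads qo_len and num_qo_heads from q before the layout branch, so only k's shape interpretation depends on kv_layout. A small Python sketch of the equivalent shape bookkeeping, with made-up tensor shapes (the variable names mirror the C++ code but are otherwise illustrative):

import torch

kv_layout = "NHD"
q = torch.randn(7, 32, 128)     # always read as (qo_len, num_qo_heads, head_dim)
k = torch.randn(100, 4, 128)    # NHD: (kv_len, num_kv_heads, head_dim)

qo_len, num_qo_heads = q.shape[0], q.shape[1]
if kv_layout == "NHD":
    kv_len, num_kv_heads = k.shape[0], k.shape[1]
else:                           # "HND": (num_kv_heads, kv_len, head_dim)
    kv_len, num_kv_heads = k.shape[1], k.shape[0]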

python/flashinfer/cascade.py (+5, −5 lines)

@@ -307,8 +307,8 @@ class BatchDecodeWithSharedPrefixPagedKVCacheWrapper:
     >>> head_dim = 128
     >>> max_num_pages = 128
     >>> page_size = 16
-    >>> # allocate 16MB workspace buffer
-    >>> workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
+    >>> # allocate 128MB workspace buffer
+    >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
     >>> wrapper = flashinfer.BatchDecodeWithSharedPrefixPagedKVCacheWrapper(
     ...     workspace_buffer, "NHD"
     ... )

@@ -540,8 +540,8 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper:
     >>> head_dim = 128
     >>> max_num_pages = 128
     >>> page_size = 16
-    >>> # allocate 16MB workspace buffer
-    >>> workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
+    >>> # allocate 128MB workspace buffer
+    >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
     >>> prefill_wrapper = flashinfer.BatchPrefillWithSharedPrefixPagedKVCacheWrapper(
     ...     workspace_buffer, "NHD"
     ... )

@@ -617,7 +617,7 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"):
         ----------
         workspace_buffer : torch.Tensor
             The user reserved workspace buffer used to store auxiliary data structures,
-            recommended size is 16MB, the device of the workspace buffer should be the
+            recommended size is 128MB, the device of the workspace buffer should be the
             same as the device of the input tensors.
         kv_layout : str
             The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
