Skip to content

Commit 3afb6d3

Browse files
authored
benchmark: add batch prefill with ragged kv-cache benchmark (#338)
1 parent 10e6b17 commit 3afb6d3

File tree

3 files changed

+136
-2
lines changed

3 files changed

+136
-2
lines changed

CMakeLists.txt

+9
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,15 @@ if (FLASHINFER_PREFILL)
360360
target_link_libraries(test_single_prefill PRIVATE gtest gtest_main prefill_kernels)
361361
target_compile_options(test_single_prefill PRIVATE -Wno-switch-bool)
362362

363+
message(STATUS "Compile batch prefill kernel benchmarks.")
364+
file(GLOB_RECURSE BENCH_PREFILL_SRCS ${PROJECT_SOURCE_DIR}/src/bench_batch_prefill.cu)
365+
add_executable(bench_batch_prefill ${BENCH_PREFILL_SRCS})
366+
target_include_directories(bench_batch_prefill PRIVATE ${FLASHINFER_INCLUDE_DIR})
367+
target_include_directories(bench_batch_prefill PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/nvbench)
368+
add_dependencies(bench_batch_prefill dispatch_inc)
369+
target_link_libraries(bench_batch_prefill PRIVATE nvbench::main prefill_kernels)
370+
target_compile_options(bench_batch_prefill PRIVATE -Wno-switch-bool)
371+
363372
message(STATUS "Compile batch prefill kernel tests.")
364373
file(GLOB_RECURSE TEST_PREFILL_SRCS ${PROJECT_SOURCE_DIR}/src/test_batch_prefill.cu)
365374
add_executable(test_batch_prefill ${TEST_PREFILL_SRCS})

src/bench_batch_prefill.cu

+126
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
/*
2+
* Copyright (c) 2023 by FlashInfer team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include <thrust/detail/raw_pointer_cast.h>
17+
#include <thrust/device_vector.h>
18+
19+
#include <cstdint>
20+
#include <nvbench/nvbench.cuh>
21+
#include <optional>
22+
23+
#include "flashinfer/attention/handler.cuh"
24+
#include "flashinfer/layout.cuh"
25+
#include "flashinfer/pos_enc.cuh"
26+
#include "flashinfer_ops.cuh"
27+
28+
using namespace flashinfer;
29+
30+
// Integer ceiling division: smallest x such that x * b >= a. Requires b != 0.
// Written as quotient plus remainder-carry so it cannot overflow for large a,
// unlike the common (a + b - 1) / b form, which wraps when a + b - 1 exceeds
// the uint32_t range.
inline uint32_t ceil_div(uint32_t a, uint32_t b) { return a / b + (a % b != 0 ? 1 : 0); }
31+
32+
/*!
 * \brief nvbench benchmark for FlashInfer batch prefill attention over a
 *   ragged (variable-length, non-paged) KV cache.
 *
 * Problem sizes are taken from the nvbench axes registered below; every
 * request in the batch uses the same qo_len == kv_len, so the ragged indptr
 * arrays are uniform. The handler plan (BeginForward) runs outside the timed
 * region; only the attention wrapper call is timed.
 *
 * \tparam dtype_in input dtype for Q/K/V
 * \tparam dtype_out output dtype for O
 * \tparam append unused in this benchmark (kept for signature parity with the
 *   registration macro, which always instantiates it as false)
 */
template <typename dtype_in, typename dtype_out, bool append>
void bench_flashinfer_batch_prefill_with_ragged_kv(nvbench::state& state) {
  size_t kv_len = state.get_int64("kv_len");
  size_t qo_len = kv_len;  // square problem: query length equals kv length
  size_t batch_size = state.get_int64("batch_size");
  size_t num_qo_heads = state.get_int64("num_qo_heads");
  size_t num_kv_heads = state.get_int64("num_kv_heads");
  size_t head_dim = state.get_int64("head_dim");
  size_t pos_encoding_mode = state.get_int64("pos_encoding_mode");
  size_t kv_layout = state.get_int64("kv_layout");
  bool causal = state.get_int64("causal");
  bool cooperative = state.get_int64("cooperative");
  bool allow_fp16_qk_reduction = state.get_int64("allow_fp16_qk_reduction");
  // The "cooperative" axis is recorded for reporting only; the ragged wrapper
  // picks its own schedule. Silence the unused-variable warning explicitly.
  (void)cooperative;

  // Allocate device-side inputs/outputs. Contents are uninitialized — the
  // kernel timing does not depend on the values.
  thrust::device_vector<dtype_in> Q(batch_size * qo_len * num_qo_heads * head_dim);
  thrust::device_vector<dtype_in> K(batch_size * kv_len * num_kv_heads * head_dim);
  thrust::device_vector<dtype_in> V(batch_size * kv_len * num_kv_heads * head_dim);
  thrust::device_vector<dtype_out> O(batch_size * qo_len * num_qo_heads * head_dim);
  // 128 MiB scratch workspace consumed by the handler's work partitioning.
  thrust::device_vector<uint8_t> workspace(128 * 1024 * 1024);

  // Report effective global-memory traffic: Q read once, K and V read once each.
  state.add_global_memory_reads<dtype_in>(
      (batch_size * qo_len * num_qo_heads + 2 * batch_size * kv_len * num_kv_heads) * head_dim,
      "Read");
  state.add_global_memory_writes<dtype_out>(qo_len * batch_size * num_qo_heads * head_dim, "Write");

  // Ragged layout: request i owns rows [indptr[i], indptr[i+1]). All requests
  // share the same lengths here, so the offsets are uniform multiples.
  std::vector<int32_t> qo_indptr_h(batch_size + 1);
  std::vector<int32_t> kv_indptr_h(batch_size + 1);
  for (uint32_t i = 0; i <= batch_size; ++i) {
    qo_indptr_h[i] = i * qo_len;
    kv_indptr_h[i] = i * kv_len;
  }
  thrust::device_vector<int32_t> qo_indptr_d(qo_indptr_h);
  thrust::device_vector<int32_t> kv_indptr_d(kv_indptr_h);

  BatchPrefillHandler handler;
  // Plan the work partition once, outside the timed region. page_size=1
  // because the kv-cache is ragged (non-paged).
  // NOTE(review): BeginForward's status is not checked — confirm failure is
  // impossible for a 128 MiB workspace, or propagate the error via state.skip.
  handler.BeginForward<dtype_out>(thrust::raw_pointer_cast(workspace.data()),
                                 workspace.size() * sizeof(uint8_t), qo_indptr_h.data(),
                                 kv_indptr_h.data(), batch_size, num_qo_heads, num_kv_heads,
                                 head_dim, /*page_size=*/1);

  state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) {
    timer.start();
    cudaError_t status = BatchPrefillWithRaggedKVCacheWrapper<dtype_in, dtype_out, int32_t>(
        &handler, thrust::raw_pointer_cast(Q.data()), thrust::raw_pointer_cast(qo_indptr_d.data()),
        thrust::raw_pointer_cast(K.data()), thrust::raw_pointer_cast(V.data()),
        thrust::raw_pointer_cast(kv_indptr_d.data()),
        /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, thrust::raw_pointer_cast(O.data()),
        /*lse=*/nullptr, batch_size, num_qo_heads, num_kv_heads, head_dim, causal,
        QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction);
    if (status != cudaSuccess) {
      state.skip("CUDA error: " + std::string(cudaGetErrorString(status)));
    }
    timer.stop();
  });

  // Derive achieved TFLOP/s from the mean measured GPU time (seconds).
  const auto measured_mean = static_cast<nvbench::float32_t>(
      state.get_summary("nv/cold/time/gpu/mean").get_float64("value"));
  auto& summ = state.add_summary("nv/tflops");
  summ.set_string("description", "Achieved TFlops/s");
  summ.set_string("name", "TFlops/s");
  float tflops;
  if (causal) {
    // Causal mask halves the attended area: per-batch flops are
    // qo*(2*kv - qo)/2 positions * 2 matmuls (QK^T, PV) * 2 flops/MAC per
    // head-dim element — the /2 and one *2 cancel below.
    tflops = (batch_size * (qo_len * (2 * kv_len - qo_len) * 2 * num_qo_heads * head_dim)) /
             measured_mean / 1e12;
  } else {
    // Full attention: qo*kv positions * 2 matmuls * 2 flops/MAC.
    tflops = (batch_size * qo_len * kv_len * 4 * num_qo_heads * head_dim) / measured_mean / 1e12;
  }
  summ.set_float64("value", tflops);
}
106+
107+
// Two-level stringification helpers so that macro arguments are expanded
// before being turned into string literals.
#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)

// Registers an nvbench benchmark instance of
// bench_flashinfer_batch_prefill_with_ragged_kv for the given input/output
// dtypes (append = false) and defines the parameter axes swept by the run.
#define BENCH_FLASHINFER_BATCH_PREFILL_WITH_RAGGED_KV(dtype_in, dtype_out)                 \
  auto bench_flashinfer_batch_prefill_with_ragged_kv_##dtype_in##_##dtype_out##_ =         \
      bench_flashinfer_batch_prefill_with_ragged_kv<dtype_in, dtype_out, false>;           \
  NVBENCH_BENCH(bench_flashinfer_batch_prefill_with_ragged_kv_##dtype_in##_##dtype_out##_) \
      .set_name(                                                                           \
          ("bench_flashinfer_batch_prefill_with_ragged_kv_" STR(dtype_in) "_" STR(dtype_out))) \
      .add_int64_axis("kv_len", {32, 64, 128, 256, 512, 1024, 2048, 4096})                 \
      .add_int64_axis("batch_size", {4, 8, 32})                                            \
      .add_int64_axis("num_qo_heads", {32})                                                \
      .add_int64_axis("num_kv_heads", {32})                                                \
      .add_int64_axis("head_dim", {128})                                                   \
      .add_int64_axis("causal", {0, 1})                                                    \
      .add_int64_axis("kv_layout", {0})                                                    \
      .add_int64_axis("pos_encoding_mode", {0})                                            \
      .add_int64_axis("allow_fp16_qk_reduction", {0})                                      \
      .add_int64_axis("cooperative", {1})

// Instantiate the benchmark for fp16 inputs and fp16 outputs.
BENCH_FLASHINFER_BATCH_PREFILL_WITH_RAGGED_KV(half, half);

src/bench_single_prefill.cu

+1-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
#include <driver_types.h>
1716
#include <thrust/device_vector.h>
1817

1918
#include <nvbench/nvbench.cuh>
@@ -54,7 +53,7 @@ void bench_flashinfer_single_prefill(nvbench::state& state) {
5453

5554
// Provide throughput information:
5655
state.add_global_memory_reads<dtype_in>(
57-
(2 * qo_len * num_qo_heads + 2 * kv_len * num_kv_heads) * head_dim, "Read");
56+
(qo_len * num_qo_heads + 2 * kv_len * num_kv_heads) * head_dim, "Read");
5857
state.add_global_memory_writes<dtype_out>(qo_len * num_qo_heads * head_dim, "Write");
5958

6059
state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) {

0 commit comments

Comments
 (0)