Skip to content

Commit f579ca2

Browse files
authored
misc: more benchmark scripts in Python (#1010)
Move benchmarks from the C++ side to Python for easier performance tracking.
1 parent db0b975 commit f579ca2

14 files changed

+441
-25
lines changed

benchmarks/bench_batch_decode.py

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
"""
2+
Copyright (c) 2024 by FlashInfer team.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import numpy as np
18+
import torch
19+
from triton.testing import do_bench
20+
21+
import flashinfer
22+
23+
# Fixed attention configuration that the __main__ sweep passes to
# bench_batch_decode for every (dtype, batch_size, seq_len) combination.
page_block_size = 16
24+
num_kv_heads = 4
25+
num_qo_heads = 32
26+
head_dim = 128
27+
28+
29+
def bench_batch_decode(
    batch_size,
    seq_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_block_size,
    q_dtype,
    kv_dtype,
):
    """Benchmark one paged-KV batch-decode configuration and print the result.

    Builds a paged KV cache in which every request holds ``seq_len`` tokens,
    plans a ``flashinfer.BatchDecodeWithPagedKVCacheWrapper`` for it, times
    ``wrapper.run`` with ``do_bench`` and prints the configuration, the
    measured time and the memory bandwidth derived from it.

    Args:
        batch_size: Number of requests in the batch.
        seq_len: KV length, in tokens, of every request.
        num_qo_heads: Number of query/output heads.
        num_kv_heads: Number of key/value heads.
        head_dim: Size of each attention head.
        page_block_size: Number of tokens stored per KV-cache page.
        q_dtype: torch dtype of the query tensor.
        kv_dtype: torch dtype of the paged KV cache.
    """
    # Every input below is drawn by torch, so the torch generator is the one
    # that has to be seeded; seeding NumPy would leave the data unseeded.
    torch.manual_seed(42)

    # Page-table metadata: all requests have the same length, so each one
    # needs ceil(seq_len / page_block_size) pages.
    seq_lens = torch.full((batch_size,), seq_len)
    seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int()
    kv_indptr = torch.cat([torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0)
    kv_indptr = kv_indptr.int()
    # Tokens that land in the final (possibly partially filled) page.
    last_page_len = seq_lens - (seq_lens_blocks - 1) * page_block_size
    last_page_len = last_page_len.int()
    num_blocks = kv_indptr[-1].item()

    q = torch.rand(batch_size, num_qo_heads, head_dim, dtype=q_dtype, device="cuda:0")
    # Generated in the default dtype first and converted afterwards, as in the
    # original script, so that kv_dtype may be a type randn cannot sample.
    kv_data = torch.randn(
        num_blocks, 2, page_block_size, num_kv_heads, head_dim, device="cuda:0"
    ).to(kv_dtype)
    workspace_buffer = torch.empty(
        128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0"
    )
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout="NHD", use_tensor_cores=True
    )
    wrapper.plan(
        kv_indptr.to(0),
        # Pages are assigned contiguously: request i owns pages
        # kv_indptr[i] .. kv_indptr[i + 1] - 1.
        torch.arange(num_blocks).int().to(0),
        last_page_len.to(0),
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_block_size,
        data_type=kv_dtype,
        q_data_type=q_dtype,
    )

    ms = do_bench(lambda: wrapper.run(q, kv_data))

    # Bytes of the query tensor plus the whole paged KV cache.
    io = q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
    # bytes / ms -> GB / s: multiply by 1e3 (ms -> s) and divide by 1e9 (B -> GB).
    bandwidth_gb_per_s = io / ms * 1e-6
    print(
        f"batch_size={batch_size}, seq_len={seq_len}, num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim={head_dim}, page_block_size={page_block_size}, q_dtype={q_dtype}, kv_dtype={kv_dtype}"
    )
    print(f"execution time: {ms}ms")
    print(f"memory bandwidth: {bandwidth_gb_per_s:.2f} GB/s")
78+
79+
80+
if __name__ == "__main__":
    # Sweep every query dtype, KV dtype, batch size and sequence length; the
    # head counts, head size and page size come from the module-level constants.
    for q_dtype in [torch.bfloat16]:
        for kv_dtype in [torch.bfloat16, torch.float8_e4m3fn]:
            for batch_size in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]:
                for seq_len in [512, 1024, 2048, 4096, 8192, 16384]:
                    bench_batch_decode(
                        batch_size=batch_size,
                        seq_len=seq_len,
                        num_qo_heads=num_qo_heads,
                        num_kv_heads=num_kv_heads,
                        head_dim=head_dim,
                        page_block_size=page_block_size,
                        q_dtype=q_dtype,
                        kv_dtype=kv_dtype,
                    )

benchmarks/bench_grouped_gemm.py

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""
2+
Copyright (c) 2024 by FlashInfer team.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import numpy as np
18+
import torch
19+
from triton.testing import do_bench
20+
21+
import flashinfer
22+
23+
24+
def bench_grouped_gemm(
    batch_size, num_tokens_per_group, d_in, d_out, dtype, output_dtype
):
    """Benchmark one grouped (segment) GEMM configuration and print the result.

    Creates ``batch_size`` weight matrices of shape ``(d_out, d_in)`` and
    ``batch_size`` equally sized groups of ``num_tokens_per_group`` input rows,
    times ``SegmentGEMMWrapper.run`` with ``do_bench`` and prints the
    configuration together with the achieved throughput.

    Args:
        batch_size: Number of groups (one weight matrix per group).
        num_tokens_per_group: Number of input rows in every group.
        d_in: Input feature dimension.
        d_out: Output feature dimension.
        dtype: torch dtype of the inputs ``X`` and weights ``W``.
        output_dtype: torch dtype of the output buffer ``Y``.
    """
    # The operands are drawn by torch, so the torch generator is the one that
    # has to be seeded; seeding NumPy would leave the data unseeded.
    torch.manual_seed(42)

    # Generated in the default dtype first and converted afterwards, as in the
    # original script, so that dtype may be a type randn cannot sample.
    W = torch.randn(batch_size, d_out, d_in, device="cuda:0").to(dtype)
    X = torch.randn(batch_size * num_tokens_per_group, d_in, device="cuda:0").to(dtype)
    Y = torch.empty(
        batch_size * num_tokens_per_group, d_out, dtype=output_dtype, device="cuda:0"
    )

    workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
    segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer, backend="auto")
    # Row offsets of the groups: 0, n, 2n, ..., batch_size * n.
    seg_indptr = torch.arange(
        0,
        (batch_size + 1) * num_tokens_per_group,
        num_tokens_per_group,
        dtype=torch.int64,
        device="cuda:0",
    )

    # NOTE(review): the positional ``True`` is presumably the
    # weight_column_major flag (W is laid out as (d_out, d_in)) -- confirm
    # against SegmentGEMMWrapper.run.
    ms = do_bench(
        lambda: segment_gemm.run(X, W, batch_size, True, out=Y, seg_indptr=seg_indptr)
    )
    # One multiply and one add per (row, d_in, d_out) triple.
    flops = 2 * batch_size * num_tokens_per_group * d_in * d_out

    print(
        f"Config: batch_size={batch_size}, num_tokens_per_group={num_tokens_per_group}, d_in={d_in}, d_out={d_out}, dtype={dtype}, output_dtype={output_dtype}"
    )
    # flops / ms -> TFLOPs / s: multiply by 1e3 (ms -> s) and divide by 1e12.
    print(f"FLOPs: {flops / ms * 1e-9:.2f} TFLOPs/s")
53+
54+
55+
if __name__ == "__main__":
    # Sweep input dtype, output dtype, group count, rows per group and both
    # feature dimensions.
    for dtype_in in [torch.float8_e4m3fn, torch.bfloat16]:
        for dtype_out in [torch.bfloat16]:
            for batch_size in [1, 3, 8, 16]:
                for num_tokens_per_group in [32, 64, 128, 256, 512]:
                    for d_in in [4096, 8192]:
                        for d_out in [4096, 8192]:
                            bench_grouped_gemm(
                                batch_size=batch_size,
                                num_tokens_per_group=num_tokens_per_group,
                                d_in=d_in,
                                d_out=d_out,
                                dtype=dtype_in,
                                output_dtype=dtype_out,
                            )

csrc/flashinfer_gemm_sm90_ops.cu

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
1919
at::Tensor all_problems, at::Tensor x_ptr, at::Tensor w_ptr,
2020
at::Tensor y_ptr, at::Tensor x_stride, at::Tensor weight_stride,
21-
at::Tensor y_stride, at::Tensor empty_x_data, bool weight_column_major);
21+
at::Tensor y_stride, at::Tensor empty_x_data, at::Tensor empty_y_data,
22+
bool weight_column_major);
2223

2324
TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
2425
// "Cutlass Segment GEMM operator for SM90"

csrc/flashinfer_ops_sm90.cu

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
2020
at::Tensor all_problems, at::Tensor x_ptr, at::Tensor w_ptr,
2121
at::Tensor y_ptr, at::Tensor x_stride, at::Tensor weight_stride,
22-
at::Tensor y_stride, at::Tensor empty_x_data, bool weight_column_major);
22+
at::Tensor y_stride, at::Tensor empty_x_data, at::Tensor empty_y_data,
23+
bool weight_column_major);
2324

2425
void single_prefill_with_kv_cache_sm90(
2526
at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp, at::Tensor o,

csrc/group_gemm_bf16_bf16_sm90.cu

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright (c) 2024 by FlashInfer team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include <flashinfer/gemm/group_gemm_sm90.cuh>
17+
18+
using namespace flashinfer;
19+
using namespace flashinfer::group_gemm;
20+
21+
namespace flashinfer {
22+
namespace group_gemm {
23+
24+
template cudaError_t CutlassSegmentGEMMSM90Run<cutlass::bfloat16_t, cutlass::bfloat16_t>(
25+
void* float_buffer, size_t float_buffer_size_in_bytes, void* int_buffer,
26+
size_t int_buffer_size_in_bytes, void* all_problems, int64_t batch_size, void* x, void* w,
27+
void* y, void* x_stride, void* w_stride, void* y_stride, bool weight_column_major,
28+
cudaStream_t stream);
29+
30+
}; // namespace group_gemm
31+
}; // namespace flashinfer

csrc/group_gemm_e4m3_bf16_sm90.cu

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright (c) 2024 by FlashInfer team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include <flashinfer/gemm/group_gemm_sm90.cuh>
17+
18+
using namespace flashinfer;
19+
using namespace flashinfer::group_gemm;
20+
21+
namespace flashinfer {
22+
namespace group_gemm {
23+
24+
template cudaError_t CutlassSegmentGEMMSM90Run<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
25+
void* float_buffer, size_t float_buffer_size_in_bytes, void* int_buffer,
26+
size_t int_buffer_size_in_bytes, void* all_problems, int64_t batch_size, void* x, void* w,
27+
void* y, void* x_stride, void* w_stride, void* y_stride, bool weight_column_major,
28+
cudaStream_t stream);
29+
30+
}; // namespace group_gemm
31+
}; // namespace flashinfer

csrc/group_gemm_e4m3_f16_sm90.cu

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright (c) 2024 by FlashInfer team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include <flashinfer/gemm/group_gemm_sm90.cuh>
17+
18+
using namespace flashinfer;
19+
using namespace flashinfer::group_gemm;
20+
21+
namespace flashinfer {
22+
namespace group_gemm {
23+
24+
template cudaError_t CutlassSegmentGEMMSM90Run<cutlass::float_e4m3_t, cutlass::half_t>(
25+
void* float_buffer, size_t float_buffer_size_in_bytes, void* int_buffer,
26+
size_t int_buffer_size_in_bytes, void* all_problems, int64_t batch_size, void* x, void* w,
27+
void* y, void* x_stride, void* w_stride, void* y_stride, bool weight_column_major,
28+
cudaStream_t stream);
29+
30+
}; // namespace group_gemm
31+
}; // namespace flashinfer

csrc/group_gemm_e5m2_bf16_sm90.cu

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright (c) 2024 by FlashInfer team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include <flashinfer/gemm/group_gemm_sm90.cuh>
17+
18+
using namespace flashinfer;
19+
using namespace flashinfer::group_gemm;
20+
21+
namespace flashinfer {
22+
namespace group_gemm {
23+
24+
template cudaError_t CutlassSegmentGEMMSM90Run<cutlass::float_e5m2_t, cutlass::bfloat16_t>(
25+
void* float_buffer, size_t float_buffer_size_in_bytes, void* int_buffer,
26+
size_t int_buffer_size_in_bytes, void* all_problems, int64_t batch_size, void* x, void* w,
27+
void* y, void* x_stride, void* w_stride, void* y_stride, bool weight_column_major,
28+
cudaStream_t stream);
29+
30+
}; // namespace group_gemm
31+
}; // namespace flashinfer

csrc/group_gemm_e5m2_f16_sm90.cu

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright (c) 2024 by FlashInfer team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include <flashinfer/gemm/group_gemm_sm90.cuh>
17+
18+
using namespace flashinfer;
19+
using namespace flashinfer::group_gemm;
20+
21+
namespace flashinfer {
22+
namespace group_gemm {
23+
24+
template cudaError_t CutlassSegmentGEMMSM90Run<cutlass::float_e5m2_t, cutlass::half_t>(
25+
void* float_buffer, size_t float_buffer_size_in_bytes, void* int_buffer,
26+
size_t int_buffer_size_in_bytes, void* all_problems, int64_t batch_size, void* x, void* w,
27+
void* y, void* x_stride, void* w_stride, void* y_stride, bool weight_column_major,
28+
cudaStream_t stream);
29+
30+
}; // namespace group_gemm
31+
}; // namespace flashinfer

csrc/group_gemm_f16_f16_sm90.cu

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright (c) 2024 by FlashInfer team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include <flashinfer/gemm/group_gemm_sm90.cuh>
17+
18+
using namespace flashinfer;
19+
using namespace flashinfer::group_gemm;
20+
21+
namespace flashinfer {
22+
namespace group_gemm {
23+
24+
template cudaError_t CutlassSegmentGEMMSM90Run<cutlass::half_t, cutlass::half_t>(
25+
void* float_buffer, size_t float_buffer_size_in_bytes, void* int_buffer,
26+
size_t int_buffer_size_in_bytes, void* all_problems, int64_t batch_size, void* x, void* w,
27+
void* y, void* x_stride, void* w_stride, void* y_stride, bool weight_column_major,
28+
cudaStream_t stream);
29+
30+
}; // namespace group_gemm
31+
}; // namespace flashinfer

0 commit comments

Comments
 (0)