
Commit 8d29762

diptorupd and yzh119 authored
ci: add pre-commit (#931)
- The PR applies the formatting fixes identified by pre-commit that were previously missing from the CUDA headers inside `include`.
- Also adds a GitHub Actions workflow that runs against PRs to enforce the pre-commit checks.

---------

Co-authored-by: Zihao <[email protected]>
1 parent f959354 commit 8d29762

32 files changed: +416, -275 lines

.github/workflows/pre-commit.yml

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+name: pre-commit
+
+on:
+  pull_request:
+  push:
+    branches: [main]
+
+permissions: read-all
+
+jobs:
+  pre-commit:
+    runs-on: ubuntu-latest
+    timeout-minutes: 30
+    steps:
+      - uses: actions/[email protected]
+      - uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+      - uses: pre-commit/[email protected]
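Note: the pre-commit action step runs the hooks defined in the repository's .pre-commit-config.yaml against the checkout. A minimal sketch, assuming the pre-commit CLI is installed locally (e.g. via `pip install pre-commit`), of reproducing the same check before pushing a PR:

import subprocess

# Minimal sketch: run every configured pre-commit hook against all tracked
# files, roughly what the CI step above enforces. A non-zero exit code means
# a hook failed or reformatted files.
result = subprocess.run(["pre-commit", "run", "--all-files"])
raise SystemExit(result.returncode)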

aot_build_utils/generate.py

Lines changed: 15 additions & 5 deletions
@@ -275,7 +275,9 @@ def write_if_different(path: Path, content: str) -> None:
     )
     parser.add_argument(
         "--use_fp16_qk_reductions",
-        type=lambda x: x if isinstance(x, int) else int(x.lower() == "true" or x.lower() == "on"),
+        type=lambda x: (
+            x if isinstance(x, int) else int(x.lower() == "true" or x.lower() == "on")
+        ),
         required=True,
         nargs="+",
         help="Allow fp16 qk reductions",
@@ -289,28 +291,36 @@ def write_if_different(path: Path, content: str) -> None:
     )
     parser.add_argument(
         "--enable_f16",
-        type=lambda x: x if isinstance(x, int) else (x.lower() == "true" or x.lower() == "on"),
+        type=lambda x: (
+            x if isinstance(x, int) else (x.lower() == "true" or x.lower() == "on")
+        ),
         required=True,
         nargs="?",
         help="Enable fp16",
     )
     parser.add_argument(
         "--enable_bf16",
-        type=lambda x: x if isinstance(x, int) else (x.lower() == "true" or x.lower() == "on"),
+        type=lambda x: (
+            x if isinstance(x, int) else (x.lower() == "true" or x.lower() == "on")
+        ),
         required=True,
         nargs="?",
         help="Enable bf16",
     )
     parser.add_argument(
         "--enable_fp8_e4m3",
-        type=lambda x: x if isinstance(x, int) else (x.lower() == "true" or x.lower() == "on"),
+        type=lambda x: (
+            x if isinstance(x, int) else (x.lower() == "true" or x.lower() == "on")
+        ),
         default=True,
         nargs="?",
         help="Enable fp8_e4m3",
     )
     parser.add_argument(
         "--enable_fp8_e5m2",
-        type=lambda x: x if isinstance(x, int) else (x.lower() == "true" or x.lower() == "on"),
+        type=lambda x: (
+            x if isinstance(x, int) else (x.lower() == "true" or x.lower() == "on")
+        ),
         default=True,
         nargs="?",
         help="Enable fp8_e5m2",

aot_build_utils/generate_dispatch_inc.py

Lines changed: 5 additions & 1 deletion
@@ -100,7 +100,11 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str:
         "--path", type=str, required=True, help="Path to the dispatch inc file"
     )
     parser.add_argument(
-        "--head_dims_sm90", type=str, required=True, nargs="+", help="Head dimensions in format of 'head_dim_qk,head_dim_vo'",
+        "--head_dims_sm90",
+        type=str,
+        required=True,
+        nargs="+",
+        help="Head dimensions in format of 'head_dim_qk,head_dim_vo'",
     )
     parser.add_argument(
         "--head_dims", type=int, required=True, nargs="+", help="Head dimensions"

benchmarks/bench_fused_add_rmsnorm.py

Lines changed: 19 additions & 9 deletions
@@ -6,12 +6,20 @@
 
 import flashinfer
 
+
 @torch.inference_mode()
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--batch-sizes", nargs='+', type=int, default=[1, 19, 99, 989])
-    parser.add_argument("--hidden-sizes", nargs='+', type=int, default=[111, 500, 1024, 3072, 4096, 8192])
-    parser.add_argument("--dtypes", nargs='+', choices=["float16", "bfloat16"], default=["float16"])
+    parser.add_argument("--batch-sizes", nargs="+", type=int, default=[1, 19, 99, 989])
+    parser.add_argument(
+        "--hidden-sizes",
+        nargs="+",
+        type=int,
+        default=[111, 500, 1024, 3072, 4096, 8192],
+    )
+    parser.add_argument(
+        "--dtypes", nargs="+", choices=["float16", "bfloat16"], default=["float16"]
+    )
     args = parser.parse_args()
 
     eps = 1e-6
@@ -27,18 +35,19 @@ def main():
                 residual = torch.randn_like(x)
                 weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
 
-                @torch.cuda.nvtx.range(f"fused_add_rmsnorm batch_size={batch_size}, hidden_size={hidden_size}, dtype={dtype_str}")
+                @torch.cuda.nvtx.range(
+                    f"fused_add_rmsnorm batch_size={batch_size}, hidden_size={hidden_size}, dtype={dtype_str}"
+                )
                 def fn() -> None:
                     flashinfer.fused_add_rmsnorm(x, residual, weight, eps)
 
                 # Run benchmarking
                 latency_ms = cast(float, do_bench(fn))
                 throughput = (
-                    (x.numel() * x.element_size() * 2
-                    + residual.numel() * residual.element_size() * 2
-                    + weight.numel() * weight.element_size())
-                    / (latency_ms * 1e-3)
-                )
+                    x.numel() * x.element_size() * 2
+                    + residual.numel() * residual.element_size() * 2
+                    + weight.numel() * weight.element_size()
+                ) / (latency_ms * 1e-3)
                 print(
                     f"batch_size: {batch_size:3},",
                     f"hidden_size: {hidden_size:5},",
@@ -51,5 +60,6 @@ def fn() -> None:
 
     torch.cuda.profiler.stop()
 
+
 if __name__ == "__main__":
     main()
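The throughput expression reshaped above models memory traffic: the factor of 2 on x and residual presumably accounts for one read plus one in-place write of each tensor, while weight is only read once, and the byte total is divided by the latency in seconds. A small worked example with made-up numbers (fp16 tensors, hypothetical latency):

# Worked example of the throughput formula above, illustration only.
batch_size, hidden_size, elem_size = 989, 4096, 2  # fp16 -> 2 bytes per element
x_bytes = batch_size * hidden_size * elem_size      # x: read and written in place
residual_bytes = x_bytes                            # residual: read and written in place
weight_bytes = hidden_size * elem_size              # weight: read only
latency_ms = 0.05                                   # hypothetical latency
throughput = (x_bytes * 2 + residual_bytes * 2 + weight_bytes) / (latency_ms * 1e-3)
print(f"{throughput / 1e9:.1f} GB/s")  # ~648.3 GB/s for these numbers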

benchmarks/bench_rope.py

Lines changed: 81 additions & 16 deletions
@@ -11,9 +11,13 @@
 
 import torch
 import torch.nn as nn
-from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding as vLLMRotaryEmbedding
 import triton
+from vllm.model_executor.layers.rotary_embedding import (
+    RotaryEmbedding as vLLMRotaryEmbedding,
+)
+
+from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
+
 
 class FlashInferRotaryEmbedding(nn.Module):
 
@@ -39,8 +43,12 @@ def __init__(
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
-        inv_freq = 1.0 / (base**(torch.arange(
-            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
+        inv_freq = 1.0 / (
+            base
+            ** (
+                torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
+            )
+        )
         return inv_freq
 
     def _compute_cos_sin_cache(self) -> torch.Tensor:
@@ -82,7 +90,7 @@ def _apply_rotary_emb(
             return torch.cat((o1, o2), dim=-1)
         else:
             return torch.stack((o1, o2), dim=-1).flatten(-2)
-
+
     def forward_cuda(
         self,
         positions: torch.Tensor,
@@ -100,42 +108,99 @@ def forward_cuda(
         )
         return query, key
 
+
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["seq_len"],
-        x_vals=[2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536],
+        x_vals=[
+            2,
+            4,
+            8,
+            16,
+            32,
+            64,
+            128,
+            256,
+            512,
+            1024,
+            2048,
+            4096,
+            8192,
+            16384,
+            32768,
+            65536,
+        ],
         line_arg="provider",
         line_vals=["flashinfer", "native", "vllm"],
         line_names=["FlashInfer", "Native", "vLLM"],
         styles=[("blue", "-"), ("red", "-"), ("green", "-")],
        ylabel="Latency (ms)",
        plot_name="rope-latency",
-        args={"head_size": 4096//32, "rotary_dim": 4096//32, "max_position_embeddings": 65536, "base": 500000, "is_neox_style": True, "dtype": torch.bfloat16, "device": "cuda", "batch_size": 2, "num_q_heads": 32, "num_kv_heads": 8},
+        args={
+            "head_size": 4096 // 32,
+            "rotary_dim": 4096 // 32,
+            "max_position_embeddings": 65536,
+            "base": 500000,
+            "is_neox_style": True,
+            "dtype": torch.bfloat16,
+            "device": "cuda",
+            "batch_size": 2,
+            "num_q_heads": 32,
+            "num_kv_heads": 8,
+        },
     )
 )
-def benchmark(provider, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads):
-    print(f"provider: {provider}, head_size: {head_size}, rotary_dim: {rotary_dim}, max_position_embeddings: {max_position_embeddings}, base: {base}, is_neox_style: {is_neox_style}, dtype: {dtype}, device: {device}, batch_size: {batch_size}, seq_len: {seq_len}, num_q_heads: {num_q_heads}, num_kv_heads: {num_kv_heads}")
-
+def benchmark(
+    provider,
+    head_size,
+    rotary_dim,
+    max_position_embeddings,
+    base,
+    is_neox_style,
+    dtype,
+    device,
+    batch_size,
+    seq_len,
+    num_q_heads,
+    num_kv_heads,
+):
+    print(
+        f"provider: {provider}, head_size: {head_size}, rotary_dim: {rotary_dim}, max_position_embeddings: {max_position_embeddings}, base: {base}, is_neox_style: {is_neox_style}, dtype: {dtype}, device: {device}, batch_size: {batch_size}, seq_len: {seq_len}, num_q_heads: {num_q_heads}, num_kv_heads: {num_kv_heads}"
+    )
+
     rope_forward = None
 
     if provider == "vllm":
-        rope = vLLMRotaryEmbedding(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype).to(device)
+        rope = vLLMRotaryEmbedding(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        ).to(device)
         rope_forward = rope.forward_cuda
     elif provider == "flashinfer":
-        rope = FlashInferRotaryEmbedding(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype).to(device)
+        rope = FlashInferRotaryEmbedding(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        ).to(device)
         rope_forward = rope.forward_cuda
     elif provider == "native":
-        rope = vLLMRotaryEmbedding(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype).to(device)
+        rope = vLLMRotaryEmbedding(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        ).to(device)
         rope_forward = rope.forward_native
 
     pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
-    query = torch.randn(batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device)
-    key = torch.randn(batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device)
+    query = torch.randn(
+        batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device
+    )
+    key = torch.randn(
+        batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
+    )
 
     quantiles = [0.5, 0.2, 0.8]
-    ms, min_ms, max_ms = triton.testing.do_bench(lambda: rope_forward(pos_ids, query, key), quantiles=quantiles)
+    ms, min_ms, max_ms = triton.testing.do_bench(
+        lambda: rope_forward(pos_ids, query, key), quantiles=quantiles
+    )
 
     return ms, min_ms, max_ms
 
+
 if __name__ == "__main__":
     benchmark.run(print_data=True, show_plots=True, save_path="rope_benchmark.png")
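The _compute_inv_freq reformatting above wraps the standard rotary-embedding frequency formula, inv_freq[i] = 1 / base ** (2 * i / rotary_dim). A minimal standalone sketch of the same computation, using values matching the benchmark defaults (rotary_dim = 4096 // 32, base = 500000):

import torch

# Standalone sketch of the inverse-frequency computation reformatted above:
# one frequency per pair of rotary dimensions.
rotary_dim, base = 4096 // 32, 500000
inv_freq = 1.0 / (
    base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)
)
print(inv_freq.shape)       # torch.Size([64]) -> rotary_dim // 2 frequencies
print(inv_freq[0].item())   # 1.0, the fastest-rotating pair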

csrc/batch_decode.cu

Lines changed: 9 additions & 9 deletions
@@ -19,8 +19,8 @@
 #include <optional>
 
 #include "batch_decode_config.inc"
-#include "pytorch_extension_utils.h"
 #include "pytorch_conversion_utils.h"
+#include "pytorch_extension_utils.h"
 
 namespace flashinfer {
 
@@ -36,9 +36,9 @@ using namespace flashinfer;
 at::Tensor BatchDecodeWithPagedKVCachePlan(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
     at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size,
-    int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size,
-    bool enable_cuda_graph, int64_t window_left, double logits_soft_cap, int64_t head_dim_qk,
-    int64_t head_dim_vo, at::Tensor empty_q_data, at::Tensor empty_kv_data, int64_t cuda_stream) {
+    int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph,
+    int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, int64_t head_dim_vo,
+    at::Tensor empty_q_data, at::Tensor empty_kv_data, int64_t cuda_stream) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
   size_t int_workspace_size_in_bytes =
@@ -78,11 +78,11 @@ at::Tensor BatchDecodeWithPagedKVCachePlan(
 }
 
 void BatchDecodeWithPagedKVCacheRun(
-    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
-    at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
-    at::Tensor paged_v_cache, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices,
-    at::Tensor paged_kv_last_page_len, at::Tensor o, std::optional<at::Tensor> maybe_lse,
-    int64_t kv_layout_code, int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) {
+    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
+    at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor paged_kv_indptr,
+    at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o,
+    std::optional<at::Tensor> maybe_lse, int64_t kv_layout_code,
+    int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) {
   DecodePlanInfo plan_info;
   plan_info.FromVector(tensor_to_vec(plan_info_vec));
   QKVLayout kv_layout = static_cast<QKVLayout>(kv_layout_code);

csrc/batch_decode_jit_pybind.cu

Lines changed: 8 additions & 8 deletions
@@ -19,16 +19,16 @@
 at::Tensor BatchDecodeWithPagedKVCachePlan(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
     at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size,
-    int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size,
-    bool enable_cuda_graph, int64_t window_left, double logits_soft_cap, int64_t head_dim_qk,
-    int64_t head_dim_vo, at::Tensor empty_q_data, at::Tensor empty_kv_data, int64_t cuda_stream);
+    int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph,
+    int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, int64_t head_dim_vo,
+    at::Tensor empty_q_data, at::Tensor empty_kv_data, int64_t cuda_stream);
 
 void BatchDecodeWithPagedKVCacheRun(
-    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
-    at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
-    at::Tensor paged_v_cache, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices,
-    at::Tensor paged_kv_last_page_len, at::Tensor o, std::optional<at::Tensor> maybe_lse,
-    int64_t kv_layout_code, int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream);
+    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
+    at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor paged_kv_indptr,
+    at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o,
+    std::optional<at::Tensor> maybe_lse, int64_t kv_layout_code,
+    int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream);
 
 TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
   // Batched decode with paged KV-Cache plan
