Commit 82392da

HandH1998 and yych0745 authored
support w8a8 fp8 kernel with CUTLASS (#3047)
Co-authored-by: yych0745 <[email protected]>
1 parent 95f789a commit 82392da

File tree

8 files changed: +881 −0 lines changed

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
import argparse
import copy
import itertools

import torch
import triton
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant

# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
# - TP1 : K = 14336, N = 4096
# - TP2 : K = 7168, N = 4096
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
# - TP1 : K = 4096, N = 6144
# - TP4 : K = 4096, N = 1536

# TP1 shapes
WEIGHT_SHAPES = {
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 57344], 1),
        ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1),
        ([3584, 3584], 0),
        ([3584, 37888], 1),
        ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1),
        ([5120, 5120], 0),
        ([5120, 55296], 1),
        ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 59136], 1),
        ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1),
        ([2048, 4096], 1),
        ([2048, 2048], 0),
        ([2048, 576], 0),
        ([2048, 21888], 1),
        ([10944, 2048], 0),
        ([2048, 2816], 1),
        ([1408, 2048], 0),
    ],
}


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=[
            "vllm-fp8-fp16",
            "vllm-fp8-bf16",
            "sglang-fp8-fp16",
            "sglang-fp8-bf16",
        ],
        line_names=[
            "vllm-fp8-fp16",
            "vllm-fp8-bf16",
            "sglang-fp8-fp16",
            "sglang-fp8-bf16",
        ],
        styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
        ylabel="GB/s",
        plot_name="fp8 scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    # M, N, K = batch_size, 4096, 8192
    M = batch_size
    a = torch.ones((M, K), device="cuda") * 5.0
    b = torch.ones((N, K), device="cuda") * 5.0
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
    b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
    b_fp8 = b_fp8.t()
    quantiles = [0.5, 0.2, 0.8]

    dtype = torch.float16 if "fp16" in provider else torch.bfloat16

    if "vllm-fp8" in provider:
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
            quantiles=quantiles,
        )
    elif "sglang-fp8" in provider:
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: sgl_scaled_mm(
                a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
            ),
            quantiles=quantiles,
        )

    gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms), gbps(max_ms), gbps(min_ms)


def prepare_shapes(args):
    KN_model_names = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        assert model in WEIGHT_SHAPES
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KN.append(model)
            KN_model_names.append(KN)
    return KN_model_names


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    KN_model_names = prepare_shapes(args)
    for K, N, model_name in KN_model_names:
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(
            print_data=True, show_plots=True, save_path="bench_fp8_res", N=N, K=K
        )

    print("Benchmark finished!")

sgl-kernel/setup.py

Lines changed: 2 additions & 0 deletions
@@ -56,6 +56,7 @@ def _get_version():
     turbomind.resolve(),
     turbomind.resolve() / "src",
 ]
+
 nvcc_flags = [
     "-DNDEBUG",
     f"-DOPERATOR_NAMESPACE={operator_namespace}",
@@ -82,6 +83,7 @@ def _get_version():
     "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
     "src/sgl-kernel/csrc/moe_align_kernel.cu",
     "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
+    "src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
     "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
     "src/sgl-kernel/csrc/rotary_embedding.cu",
     "3rdparty/flashinfer/csrc/activation.cu",

sgl-kernel/src/sgl-kernel/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
     bmm_fp8,
     custom_dispose,
     custom_reduce,
+    fp8_scaled_mm,
     fused_add_rmsnorm,
     gelu_and_mul,
     gelu_tanh_and_mul,
@@ -27,6 +28,7 @@
     "bmm_fp8",
     "custom_dispose",
     "custom_reduce",
+    "fp8_scaled_mm",
     "fused_add_rmsnorm",
     "gelu_and_mul",
     "gelu_tanh_and_mul",
