support w8a8 fp8 kernel with CUTLASS #3047
Merged

Commits (43)
955a2fb  Add performance and accuracy test code for FP8 GEMM operations  (yych0745)
30bdf20  support w8a8 fp8  (HandH1998)
4cac9fb  support bias  (HandH1998)
ecc90a4  opitmize  (yych0745)
3497950  add config_profile for sm_89  (yych0745)
05eb204  fp8 sm90-H100 singleTest done  (yych0745)
8d95538  fp8 sm90-H100 singleTest done  (yych0745)
724cf62  clean code  (yych0745)
93e2d85  fix  (yych0745)
8c08dbb  clean code  (yych0745)
fb95b0e  clean code  (yych0745)
2bac342  fp8 dispatch change  (yych0745)
ba7ca85  clean code  (yych0745)
2727d7d  fix  (yych0745)
b11682e  clean code  (yych0745)
fe490cc  Add performance and accuracy test code for FP8 GEMM operations  (yych0745)
b2de73d  support w8a8 fp8  (HandH1998)
3691d68  support bias  (HandH1998)
38bcf52  fix compilation  (HandH1998)
d57f756  clean code  (yych0745)
e620244  clean code  (yych0745)
699fe9e  Merge pull request #6 from HandH1998/tmptmp  (HandH1998)
b6a88bb  Merge remote-tracking branch 'origin/main' into main_w8a8_fp8  (HandH1998)
604f4f5  format  (HandH1998)
98dc70d  format  (HandH1998)
a4025f6  Merge branch 'main' into main_w8a8_fp8  (zhyncs)
8b87aad  upd  (zhyncs)
b287319  Merge branch 'main' into main_w8a8_fp8  (zhyncs)
6de3ad4  Merge branch 'main_w8a8_fp8' of https://github.com/HandH1998/sglang i…  (yych0745)
b4195b0  fix include  (HandH1998)
8290ba6  add more shapes for benchmark  (HandH1998)
a455233  Merge remote-tracking branch 'origin/main' into main_w8a8_fp8  (HandH1998)
42f408f  fix bug  (HandH1998)
1739631  Merge branch 'main_w8a8_fp8' of https://github.com/HandH1998/sglang i…  (yych0745)
0666d39  cutlass optimization  (yych0745)
b9980af  clean code  (HandH1998)
cd51083  fix reivew issues  (HandH1998)
4a98c75  Merge remote-tracking branch 'origin/main' into main_w8a8_fp8  (HandH1998)
8c3dc13  fix bug  (HandH1998)
a1b582e  Merge remote-tracking branch 'origin/main' into main_w8a8_fp8  (HandH1998)
248391e  Merge branch 'main' into main_w8a8_fp8  (zhyncs)
62bf9a4  fix name conflict  (HandH1998)
0d7f5a0  Merge remote-tracking branch 'origin/main' into main_w8a8_fp8  (HandH1998)
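This PR adds a W8A8 FP8 (8-bit floating-point weights and activations) GEMM kernel built on CUTLASS, exposed through sgl_kernel as fp8_scaled_mm. As a rough reference for the computation the kernel performs, inferred from the benchmark file below (per-token activation scales, per-channel weight scales, an output dtype, and an optional bias), here is a plain-PyTorch sketch. The function name reference_fp8_scaled_mm and the exact scale layout are assumptions for illustration, not part of the PR.

# Reference sketch (not the CUTLASS kernel): what a w8a8 FP8 scaled GEMM with
# per-token activation scales, per-channel weight scales, and an optional bias
# computes. Name and scale layout are assumptions based on the benchmark below.
import torch

def reference_fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias=None):
    # a_fp8: [M, K] float8_e4m3fn, b_fp8: [K, N] float8_e4m3fn
    # scale_a: [M] float32 (per token), scale_b: [N] float32 (per output channel)
    out = torch.mm(a_fp8.float(), b_fp8.float())            # accumulate in fp32
    out = out * scale_a.view(-1, 1) * scale_b.view(1, -1)   # dequantize
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)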
@@ -0,0 +1,164 @@ (new file: FP8 GEMM benchmark script)
import argparse
import copy
import itertools

import torch
import triton
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant

# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
#  A shape of ([14336, 4096], 0) indicates the following GEMM shape,
#   - TP1 : K = 14336, N = 4096
#   - TP2 : K = 7168, N = 4096
#  A shape of ([4096, 6144], 1) indicates the following GEMM shape,
#   - TP1 : K = 4096, N = 6144
#   - TP4 : K = 4096, N = 1536

# TP1 shapes
WEIGHT_SHAPES = {
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 57344], 1),
        ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1),
        ([3584, 3584], 0),
        ([3584, 37888], 1),
        ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1),
        ([5120, 5120], 0),
        ([5120, 55296], 1),
        ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 59136], 1),
        ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1),
        ([2048, 4096], 1),
        ([2048, 2048], 0),
        ([2048, 576], 0),
        ([2048, 21888], 1),
        ([10944, 2048], 0),
        ([2048, 2816], 1),
        ([1408, 2048], 0),
    ],
}


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=[
            "vllm-fp8-fp16",
            "vllm-fp8-bf16",
            "sglang-fp8-fp16",
            "sglang-fp8-bf16",
        ],
        line_names=[
            "vllm-fp8-fp16",
            "vllm-fp8-bf16",
            "sglang-fp8-fp16",
            "sglang-fp8-bf16",
        ],
        styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
        ylabel="GB/s",
        plot_name="fp8 scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    # M, N, K = batch_size, 4096, 8192
    M = batch_size
    a = torch.ones((M, K), device="cuda") * 5.0
    b = torch.ones((N, K), device="cuda") * 5.0
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
    b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
    b_fp8 = b_fp8.t()
    quantiles = [0.5, 0.2, 0.8]

    dtype = torch.float16 if "fp16" in provider else torch.bfloat16

    if "vllm-fp8" in provider:
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
            quantiles=quantiles,
        )
    elif "sglang-fp8" in provider:
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: sgl_scaled_mm(
                a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
            ),
            quantiles=quantiles,
        )

    gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms), gbps(max_ms), gbps(min_ms)


def prepare_shapes(args):
    KN_model_names = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        assert model in WEIGHT_SHAPES
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KN.append(model)
            KN_model_names.append(KN)
    return KN_model_names


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    KN_model_names = prepare_shapes(args)
    for K, N, model_name in KN_model_names:
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(
            print_data=True, show_plots=True, save_path="bench_fp8_res", N=N, K=K
        )

    print("Benchmark finished!")