|
| 1 | +import argparse |
| 2 | +import copy |
| 3 | +import itertools |
| 4 | + |
| 5 | +import torch |
| 6 | +import triton |
| 7 | +from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm |
| 8 | +from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm |
| 9 | +from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant |
| 10 | + |
| 11 | +# Weight Shapes are in the format |
| 12 | +# ([K, N], TP_SPLIT_DIM) |
| 13 | +# Example: |
| 14 | +# A shape of ([14336, 4096], 0) indicates the following GEMM shape, |
| 15 | +# - TP1 : K = 14336, N = 4096 |
| 16 | +# - TP2 : K = 7168, N = 4096 |
| 17 | +# A shape of ([4096, 6144], 1) indicates the following GEMM shape, |
| 18 | +# - TP1 : K = 4096, N = 6144 |
| 19 | +# - TP4 : K = 4096, N = 1536 |
| 20 | + |
| 21 | +# TP1 shapes |
| 22 | +WEIGHT_SHAPES = { |
| 23 | + "meta-llama/Llama-3.1-8B-Instruct": [ |
| 24 | + ([4096, 6144], 1), |
| 25 | + ([4096, 4096], 0), |
| 26 | + ([4096, 28672], 1), |
| 27 | + ([14336, 4096], 0), |
| 28 | + ], |
| 29 | + "meta-llama/Llama-3.3-70B-Instruct": [ |
| 30 | + ([8192, 10240], 1), |
| 31 | + ([8192, 8192], 0), |
| 32 | + ([8192, 57344], 1), |
| 33 | + ([28672, 8192], 0), |
| 34 | + ], |
| 35 | + "mistralai/Mistral-Large-Instruct-2407": [ |
| 36 | + ([12288, 14336], 1), |
| 37 | + ([12288, 12288], 0), |
| 38 | + ([12288, 57344], 1), |
| 39 | + ([28672, 12288], 0), |
| 40 | + ], |
| 41 | + "Qwen/Qwen2.5-7B-Instruct": [ |
| 42 | + ([3584, 4608], 1), |
| 43 | + ([3584, 3584], 0), |
| 44 | + ([3584, 37888], 1), |
| 45 | + ([18944, 3584], 0), |
| 46 | + ], |
| 47 | + "Qwen/Qwen2.5-32B-Instruct": [ |
| 48 | + ([5120, 7168], 1), |
| 49 | + ([5120, 5120], 0), |
| 50 | + ([5120, 55296], 1), |
| 51 | + ([27648, 5120], 0), |
| 52 | + ], |
| 53 | + "Qwen/Qwen2.5-72B-Instruct": [ |
| 54 | + ([8192, 10240], 1), |
| 55 | + ([8192, 8192], 0), |
| 56 | + ([8192, 59136], 1), |
| 57 | + ([29568, 8192], 0), |
| 58 | + ], |
| 59 | + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [ |
| 60 | + ([2048, 3072], 1), |
| 61 | + ([2048, 4096], 1), |
| 62 | + ([2048, 2048], 0), |
| 63 | + ([2048, 576], 0), |
| 64 | + ([2048, 21888], 1), |
| 65 | + ([10944, 2048], 0), |
| 66 | + ([2048, 2816], 1), |
| 67 | + ([1408, 2048], 0), |
| 68 | + ], |
| 69 | +} |
| 70 | + |
| 71 | + |
| 72 | +@triton.testing.perf_report( |
| 73 | + triton.testing.Benchmark( |
| 74 | + x_names=["batch_size"], |
| 75 | + x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048], |
| 76 | + x_log=False, |
| 77 | + line_arg="provider", |
| 78 | + line_vals=[ |
| 79 | + "vllm-fp8-fp16", |
| 80 | + "vllm-fp8-bf16", |
| 81 | + "sglang-fp8-fp16", |
| 82 | + "sglang-fp8-bf16", |
| 83 | + ], |
| 84 | + line_names=[ |
| 85 | + "vllm-fp8-fp16", |
| 86 | + "vllm-fp8-bf16", |
| 87 | + "sglang-fp8-fp16", |
| 88 | + "sglang-fp8-bf16", |
| 89 | + ], |
| 90 | + styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")], |
| 91 | + ylabel="GB/s", |
| 92 | + plot_name="fp8 scaled matmul", |
| 93 | + args={}, |
| 94 | + ) |
| 95 | +) |
| 96 | +def benchmark(batch_size, provider, N, K): |
| 97 | + # M, N, K = batch_size, 4096, 8192 |
| 98 | + M = batch_size |
| 99 | + a = torch.ones((M, K), device="cuda") * 5.0 |
| 100 | + b = torch.ones((N, K), device="cuda") * 5.0 |
| 101 | + scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) |
| 102 | + scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) |
| 103 | + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) |
| 104 | + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) |
| 105 | + b_fp8 = b_fp8.t() |
| 106 | + quantiles = [0.5, 0.2, 0.8] |
| 107 | + |
| 108 | + dtype = torch.float16 if "fp16" in provider else torch.bfloat16 |
| 109 | + |
| 110 | + if "vllm-fp8" in provider: |
| 111 | + ms, min_ms, max_ms = triton.testing.do_bench( |
| 112 | + lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype), |
| 113 | + quantiles=quantiles, |
| 114 | + ) |
| 115 | + elif "sglang-fp8" in provider: |
| 116 | + ms, min_ms, max_ms = triton.testing.do_bench( |
| 117 | + lambda: sgl_scaled_mm( |
| 118 | + a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None |
| 119 | + ), |
| 120 | + quantiles=quantiles, |
| 121 | + ) |
| 122 | + |
| 123 | + gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3) |
| 124 | + return gbps(ms), gbps(max_ms), gbps(min_ms) |
| 125 | + |
| 126 | + |
| 127 | +def prepare_shapes(args): |
| 128 | + KN_model_names = [] |
| 129 | + models_tps = list(itertools.product(args.models, args.tp_sizes)) |
| 130 | + for model, tp_size in models_tps: |
| 131 | + assert model in WEIGHT_SHAPES |
| 132 | + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]): |
| 133 | + KN[tp_split_dim] = KN[tp_split_dim] // tp_size |
| 134 | + KN.append(model) |
| 135 | + KN_model_names.append(KN) |
| 136 | + return KN_model_names |
| 137 | + |
| 138 | + |
| 139 | +if __name__ == "__main__": |
| 140 | + parser = argparse.ArgumentParser() |
| 141 | + parser.add_argument( |
| 142 | + "--models", |
| 143 | + nargs="+", |
| 144 | + type=str, |
| 145 | + default=["meta-llama/Llama-3.1-8B-Instruct"], |
| 146 | + help="List of models to benchmark", |
| 147 | + ) |
| 148 | + parser.add_argument( |
| 149 | + "--tp-sizes", |
| 150 | + nargs="+", |
| 151 | + type=int, |
| 152 | + default=[1], |
| 153 | + help="List of tensor parallel sizes", |
| 154 | + ) |
| 155 | + args = parser.parse_args() |
| 156 | + |
| 157 | + KN_model_names = prepare_shapes(args) |
| 158 | + for K, N, model_name in KN_model_names: |
| 159 | + print(f"{model_name} N={N} K={K}: ") |
| 160 | + benchmark.run( |
| 161 | + print_data=True, show_plots=True, save_path="bench_fp8_res", N=N, K=K |
| 162 | + ) |
| 163 | + |
| 164 | + print("Benchmark finished!") |
0 commit comments