from pathlib import Path
import json
import triton.profiler as proton
import torch
import triton_bench.swiglu
from triton_bench.mxfp import downcast_to_mxfp
from triton_bench.matmul_ogs import MicroscalingCtx, matmul_ogs, PrecisionConfig, FlexCtx
from triton_bench.numerics import InFlexData
from triton_bench.routing import routing_torch, simulate_expert_sharded_routing
from triton_bench.meta import cuda_capability_geq

if torch.cuda.is_available():
    from triton._C.libtriton import nvidia
    cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    cublas = nvidia.cublas.CublasLt(cublas_workspace)
else:
    cublas = None

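# Peak hardware specs of the current GPU (dense TFLOP/s at 8-bit and 16-bit
# precision, memory bandwidth in TB/s), keyed by the device name reported by
# nvidia-smi; used below to compute a roofline utilization.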
def _query_gpu_specs():
    import subprocess
    cmd = ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader", "-i=0"]
    output = subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode().strip()
    name = output.splitlines()[0]
    return {
        "NVIDIA H100 80GB HBM3": {"MAX_TFLOPS8": 1979, "MAX_TFLOPS16": 989, "MAX_TBPS": 3.35},
        "HGX GB200": {"MAX_TFLOPS8": 4500, "MAX_TFLOPS16": 2250, "MAX_TBPS": 8.0},
    }[name]


SPECS = _query_gpu_specs()

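# Cast a weight tensor to the requested format and return it together with the
# flex-point and microscaling metadata that matmul_ogs expects:
#   - "bf16": plain bfloat16 cast, no scales
#   - "fp8":  float8_e4m3fn cast plus a per-tensor flex scale
#   - "mx4":  MXFP4 via downcast_to_mxfp, with optionally swizzled MX scales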
def quantize(w, dtype, dev, **opt):
    if dtype == "bf16":
        return w.to(torch.bfloat16), InFlexData(), MicroscalingCtx()
    elif dtype == "fp8":
        wq = w.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous().transpose(-1, -2)
        return wq, InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)), \
            MicroscalingCtx()
    else:
        assert dtype == "mx4", f"{dtype=}"
        swizzle_mx_scale = opt["swizzle_mx_scale"]
        swizzle_axis = 2 if swizzle_mx_scale else None
        w = w.to(torch.bfloat16)
        w, mx_scales, weight_scale_shape = downcast_to_mxfp(w, torch.uint8, axis=1, swizzle_axis=swizzle_axis)
        return w, InFlexData(), MicroscalingCtx(weight_scale=mx_scales, swizzle_mx=swizzle_mx_scale,
                                                actual_weight_scale_shape=weight_scale_shape)

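# Run one MoE MLP layer (gating matmul -> routing -> matmul -> swiglu -> matmul)
# 100 times under the proton profiler and return (roofline utilization, achieved
# TFLOP/s, achieved TB/s). n_expts_tot == 1 recovers a dense MLP; TP/EP only size
# the local weight shard, i.e. a single tensor-/expert-parallel rank is timed.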
def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
              # tensor / expert parallelism
              TP=1, EP=1, name=""):
    assert n_expts_tot % EP == 0
    assert dim2 % TP == 0
    dev = "cuda"
    # weights
    wg = torch.randn((dim1, n_expts_tot), device=dev)
    w1 = torch.randn((n_expts_tot // EP, dim1, dim2 // TP), device=dev)
    w2 = torch.randn((n_expts_tot // EP, dim2 // TP // 2, dim1), device=dev)
    # biases
    bg = torch.randn((n_expts_tot, ), device=dev)
    b1 = torch.randn((dim2 // TP, ), device=dev)
    b2 = torch.randn((dim1, ), device=dev)

    # -- numerics --
    optg = dict()
    opt1 = {"swizzle_mx_scale": True} if w_dtype == "mx4" else dict()
    opt2 = {"swizzle_mx_scale": True} if w_dtype == "mx4" else dict()
    wg, wg_flex, wg_mx = quantize(wg, "bf16", dev, **optg)
    w1, w1_flex, w1_mx = quantize(w1, w_dtype, dev, **opt1)
    w2, w2_flex, w2_mx = quantize(w2, w_dtype, dev, **opt2)
    pcg = PrecisionConfig(mx_ctx=wg_mx, flex_ctx=FlexCtx(rhs_data=wg_flex))
    pcs = triton_bench.swiglu.PrecisionConfig(limit=1.0)
    pc1 = PrecisionConfig(mx_ctx=w1_mx, flex_ctx=FlexCtx(rhs_data=w1_flex))
    pc2 = PrecisionConfig(mx_ctx=w2_mx, flex_ctx=FlexCtx(rhs_data=w2_flex))

    # -- benchmark --
    fpath = Path(f"logs/{name}/{batch}-{dim1}-{dim2}-{n_expts_tot}-{n_expts_act}-{x_dtype}-{w_dtype}.hatchet")
    fpath.parent.mkdir(parents=True, exist_ok=True)
    proton.start(str(fpath.with_suffix('')), hook="triton")
    proton.deactivate()
    # run layer
    x_dtype = {"bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[x_dtype]
    for i in range(100):
        x = torch.randn((batch, dim1), device=dev)
        x = x.to(wg.dtype if n_expts_tot > 1 else x_dtype)
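        # routing: compute gating logits and keep the top n_expts_act experts per
        # token; with EP > 1, emulate re-sharding the tokens across expert-parallel ranks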
        # TODO: activate proton here when fast routing is done
        if n_expts_tot > 1:
            logits = matmul_ogs(x, wg, bg, precision_config=pcg)
            rdata, gather_indx, scatter_indx = routing_torch(logits, n_expts_act)
            if EP > 1:
                m = logits.shape[0] * EP
                _, rdata, gather_indx, scatter_indx = simulate_expert_sharded_routing(m, rdata, EP, device=dev)
            x = x.to(x_dtype)
        else:
            rdata, gather_indx, scatter_indx = None, None, None
        proton.activate()
        # c0 = torch.empty((x.shape[0], w1.shape[-1]), device=dev, dtype=x.dtype)
        # c1 = torch.empty((x.shape[0], w2.shape[-1]), device=dev, dtype=x.dtype)
        # cublas.matmul(x, w1.squeeze(0), c0)
        # cublas.matmul(c0, w2.squeeze(0), c1)
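        # expert MLP: gather tokens per expert, project up, apply swiglu, project
        # back down and scatter results into token order; only this part is profiled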
        x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1)
        x = triton_bench.swiglu.swiglu(x, 1.0, pcs)
        x = matmul_ogs(x, w2, b2, rdata, scatter_indx=scatter_indx, precision_config=pc2)
        proton.deactivate()
    proton.finalize()

    # -- analyze --
    with open(f"{fpath}") as fd:
        data = json.load(fd)
        # TODO: this will be broken if kernels use scopes themselves
        # compute useful (a.k.a. matmul) bytes and flops
        matmuls = [x for x in data[0]["children"] if "matmul" in x["frame"]["name"]]
        tot_bytes = sum([x["metrics"]["bytes"] for x in matmuls])
        tot_flops = {w: sum([x["metrics"].get(f"flops{w}", 0) for x in matmuls]) for w in [8, 16]}
        # compute total time (incl. "not useful" work)
        # TODO: proton should really be recording that in the json instead of
        # relying on the user to aggregate
        tot_time = sum(x["metrics"].get("time (ns)", 0) for x in data[0]["children"])
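        # roofline lower bound (ns): the larger of the compute-limited and the
        # bandwidth-limited time; utilization = lower bound / measured time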
        min_time_flops = sum([tot_flops[w] / SPECS[f"MAX_TFLOPS{w}"] for w in [8, 16]]) * 1e-3
        min_time_bytes = tot_bytes / SPECS["MAX_TBPS"] * 1e-3
        min_time = max(min_time_flops, min_time_bytes)
        util = min_time / tot_time
        tflops = sum([tot_flops[w] for w in [8, 16]]) / tot_time * 1e-3
        tbps = tot_bytes / tot_time * 1e-3

    return util, tflops, tbps

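# fp8 activations are only paired with mx4 weights on GPUs with native MXFP
# support (compute capability >= 10); otherwise activations stay in bf16.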
if __name__ == "__main__":
    has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10
    qxdtype = "fp8" if has_native_mx4 else "bf16"
    print(bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8", TP=1, EP=1, name="dense"))
    print(bench_mlp(8192, 8192, 8192, 1, 1, qxdtype, "mx4", TP=1, EP=1, name="dense"))
    print(bench_mlp(1024, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=2, name="llama4"))
    print(bench_mlp(1024, 5120, 8192, 128, 4, qxdtype, "mx4", TP=4, EP=2, name="llama4"))