Commit 58c95be

Merge remote-tracking branch 'upstream/main'
2 parents 421dd5a + 11b288c commit 58c95be

150 files changed: 9565 additions, 2812 deletions


.github/workflows/create_release.yml

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
name: Create Release

on:
  push:
    branches:
      - main
      - release/*
    tags:
      # Final Release tags look like: v1.11.0
      - v[0-9]+.[0-9]+.[0-9]+
      # Release candidate tags look like: v1.11.0-rc1
      - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
  release:
    types: [published]
  pull_request:
    paths: [.github/workflows/create_release.yml]

jobs:

  release:
    if: ${{ github.repository == 'triton-lang/triton' }}
    name: Create Release
    runs-on: ubuntu-latest
    permissions:
      contents: write
    outputs:
      release_name: "${{ steps.release_name.outputs.name }}"
    steps:
      - uses: actions/checkout@v4
        with:
          show-progress: false
          submodules: 'recursive'
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
      - name: Fake name for PRs
        if: ${{ github.event_name == 'pull_request' }}
        run: echo "PT_GITHUB_REF=refs/tags/pr-tag" >> "$GITHUB_ENV"
      - name: Real name for non-PRs
        if: ${{ github.event_name != 'pull_request' }}
        run: echo "PT_GITHUB_REF=$GITHUB_REF" >> "$GITHUB_ENV"
      - name: Set filenames
        run: |
          tag_or_branch="${PT_GITHUB_REF#refs/tags/}"
          tag_or_branch="${tag_or_branch#refs/heads/}"
          # replace directory separators with _ in branch name
          tag_or_branch="${tag_or_branch//\//_}"
          echo "RELEASE_NAME=triton-$tag_or_branch" >> "$GITHUB_ENV"
          echo "RELEASE_FILE=triton-$tag_or_branch.tar.gz" >> "$GITHUB_ENV"
      - name: Create source distribution
        run: |
          # Create new folder with specified name so extracting the archive yields that
          rm -rf "/tmp/$RELEASE_NAME"
          cp -r "$PWD" "/tmp/$RELEASE_NAME"
          mv "/tmp/$RELEASE_NAME" .
          # Cleanup
          find "$RELEASE_NAME" -name '.git*' -exec rm -rv {} \; || true
          # Create archive
          tar -czf "$RELEASE_FILE" "$RELEASE_NAME"
          echo "Created source archive $RELEASE_FILE with content: $(ls -a "$RELEASE_NAME")"
      - name: Upload source distribution for release
        if: ${{ github.event_name == 'release' }}
        uses: softprops/action-gh-release@v2
        with:
          files: ${{env.RELEASE_FILE}}
      - name: Upload source distribution to GHA artifacts for release tags
        if: ${{ github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') && contains(github.ref, 'rc') }}
        uses: actions/[email protected]
        with:
          name: ${{ env.RELEASE_FILE }}
          path: ${{ env.RELEASE_FILE }}
      - name: Set output
        id: release_name
        run: echo "name=release_name::${{ env.RELEASE_NAME }}.tar.gz" >> "${GITHUB_OUTPUT}"

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
  cancel-in-progress: true
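
As a reading aid for the "Set filenames" step above: it strips the refs/tags/ or refs/heads/ prefix from the ref, replaces any / in a branch name with _, and derives RELEASE_NAME and RELEASE_FILE from the result. A minimal Python sketch of that derivation (illustrative only, not part of this commit; the example refs are hypothetical):

def release_file_for(ref: str) -> str:
    # Mirror the shell parameter expansions in the workflow's "Set filenames" step.
    tag_or_branch = ref.removeprefix("refs/tags/").removeprefix("refs/heads/")
    # Replace directory separators with '_' in branch names.
    tag_or_branch = tag_or_branch.replace("/", "_")
    return f"triton-{tag_or_branch}.tar.gz"


assert release_file_for("refs/tags/v1.11.0-rc1") == "triton-v1.11.0-rc1.tar.gz"
assert release_file_for("refs/heads/release/some-branch") == "triton-release_some-branch.tar.gz"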

bench/bench/bench_mlp.py

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
from pathlib import Path
import json
import triton.profiler as proton
import torch
import triton_bench.swiglu
from triton_bench.mxfp import downcast_to_mxfp
from triton_bench.matmul_ogs import MicroscalingCtx, matmul_ogs, PrecisionConfig, FlexCtx
from triton_bench.numerics import InFlexData
from triton_bench.routing import routing_torch, simulate_expert_sharded_routing
from triton_bench.meta import cuda_capability_geq

if torch.cuda.is_available():
    from triton._C.libtriton import nvidia
    cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    cublas = nvidia.cublas.CublasLt(cublas_workspace)
else:
    cublas = None


def _query_gpu_specs():
    import subprocess
    cmd = ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader", "-i=0"]
    output = subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode().strip()
    name = output.splitlines()[0]
    return {
        "NVIDIA H100 80GB HBM3": {"MAX_TFLOPS8": 1979, "MAX_TFLOPS16": 989, "MAX_TBPS": 3.35}, "HGX GB200":
        {"MAX_TFLOPS8": 4500, "MAX_TFLOPS16": 2250, "MAX_TBPS": 8.0}
    }[name]


SPECS = _query_gpu_specs()


def quantize(w, dtype, dev, **opt):
    if dtype == "bf16":
        return w.to(torch.bfloat16), InFlexData(), MicroscalingCtx()
    elif dtype == "fp8":
        wq = w.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous().transpose(-1, -2)
        return wq, InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)), \
            MicroscalingCtx()
    else:
        assert dtype == "mx4", f"{dtype=}"
        swizzle_mx_scale = opt["swizzle_mx_scale"]
        swizzle_axis = 2 if swizzle_mx_scale else None
        w = w.to(torch.bfloat16)
        w, mx_scales, weight_scale_shape = downcast_to_mxfp(w, torch.uint8, axis=1, swizzle_axis=swizzle_axis)
        return w, InFlexData(), MicroscalingCtx(weight_scale=mx_scales, swizzle_mx=swizzle_mx_scale,
                                                actual_weight_scale_shape=weight_scale_shape)


def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
              # tensor / expert parallelism
              TP=1, EP=1, name=""):
    assert n_expts_tot % EP == 0
    assert dim2 % TP == 0
    dev = "cuda"
    # input
    # weights
    wg = torch.randn((dim1, n_expts_tot), device=dev)
    w1 = torch.randn((n_expts_tot // EP, dim1, dim2 // TP), device=dev)
    w2 = torch.randn((n_expts_tot // EP, dim2 // TP // 2, dim1), device=dev)
    # biases
    bg = torch.randn((n_expts_tot, ), device=dev)
    b1 = torch.randn((dim2 // TP, ), device=dev)
    b2 = torch.randn((dim1, ), device=dev)

    # -- numerics --
    optg = dict()
    opt1 = {"swizzle_mx_scale": True} if w_dtype == "mx4" else dict()
    opt2 = {"swizzle_mx_scale": True} if w_dtype == "mx4" else dict()
    wg, wg_flex, wg_mx = quantize(wg, "bf16", dev, **optg)
    w1, w1_flex, w1_mx = quantize(w1, w_dtype, dev, **opt1)
    w2, w2_flex, w2_mx = quantize(w2, w_dtype, dev, **opt2)
    pcg = PrecisionConfig(mx_ctx=wg_mx, flex_ctx=FlexCtx(rhs_data=wg_flex))
    pcs = triton_bench.swiglu.PrecisionConfig(limit=1.0)
    pc1 = PrecisionConfig(mx_ctx=w1_mx, flex_ctx=FlexCtx(rhs_data=w1_flex))
    pc2 = PrecisionConfig(mx_ctx=w2_mx, flex_ctx=FlexCtx(rhs_data=w2_flex))

    # -- benchmark --
    fpath = Path(f"logs/{name}/{batch}-{dim1}-{dim2}-{n_expts_tot}-{n_expts_act}-{x_dtype}-{w_dtype}.hatchet")
    fpath.parent.mkdir(parents=True, exist_ok=True)
    proton.start(str(fpath.with_suffix('')), hook="triton")
    proton.deactivate()
    # run layer
    x_dtype = {"bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[x_dtype]
    for i in range(100):
        x = torch.randn((batch, dim1), device=dev)
        x = x.to(wg.dtype if n_expts_tot > 1 else x_dtype)
        # TODO: activate proton here when fast routing is done
        if n_expts_tot > 1:
            logits = matmul_ogs(x, wg, bg, precision_config=pcg)
            rdata, gather_indx, scatter_indx = routing_torch(logits, n_expts_act)
            if EP > 1:
                m = logits.shape[0] * EP
                _, rdata, gather_indx, scatter_indx = simulate_expert_sharded_routing(m, rdata, EP, device=dev)
            x = x.to(x_dtype)
        else:
            rdata, gather_indx, scatter_indx = None, None, None
        proton.activate()
        # c0 = torch.empty((x.shape[0], w1.shape[-1]), device=dev, dtype=x.dtype)
        # c1 = torch.empty((x.shape[0], w2.shape[-1]), device=dev, dtype=x.dtype)
        # cublas.matmul(x, w1.squeeze(0), c0)
        # cublas.matmul(c0, w2.squeeze(0), c1)
        x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1)
        x = triton_bench.swiglu.swiglu(x, 1.0, pcs)
        x = matmul_ogs(x, w2, b2, rdata, scatter_indx=scatter_indx, precision_config=pc2)
        proton.deactivate()
    proton.finalize()

    # -- analyze --
    with open(f"{fpath}") as fd:
        data = json.load(fd)
        # TODO: this will be broken if kernels use scopes themselves
        # compute useful (a.k.a. matmul) bytes and flops
        matmuls = [x for x in data[0]["children"] if "matmul" in x["frame"]["name"]]
        tot_bytes = sum([x["metrics"]["bytes"] for x in matmuls])
        tot_flops = {w: sum([x["metrics"].get(f"flops{w}", 0) for x in matmuls]) for w in [8, 16]}
        # compute total time (incl. "not useful" work)
        # TODO: proton should really be recording that in the json instead of
        # relying on the user to aggregate
        tot_time = sum(x["metrics"].get("time (ns)", 0) for x in data[0]["children"])
        min_time_flops = sum([tot_flops[w] / SPECS[f"MAX_TFLOPS{w}"] for w in [8, 16]]) * 1e-3
        min_time_bytes = tot_bytes / SPECS["MAX_TBPS"] * 1e-3
        min_time = max(min_time_flops, min_time_bytes)
        util = min_time / tot_time
        tflops = sum([tot_flops[w] for w in [8, 16]]) / tot_time * 1e-3
        tbps = tot_bytes / tot_time * 1e-3

    return util, tflops, tbps


if __name__ == "__main__":
    has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10
    qxdtype = "fp8" if has_native_mx4 else "bf16"
    print(bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8", TP=1, EP=1, name="dense"))
    print(bench_mlp(8192, 8192, 8192, 1, 1, qxdtype, "mx4", TP=1, EP=1, name="dense"))
    print(bench_mlp(1024, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=2, name="llama4"))
    print(bench_mlp(1024, 5120, 8192, 128, 4, qxdtype, "mx4", TP=4, EP=2, name="llama4"))
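
The "-- analyze --" section of bench_mlp reads the proton .hatchet profile back, sums bytes and flops over the matmul frames, and compares the measured time against a roofline lower bound built from the MAX_TFLOPS8/MAX_TFLOPS16/MAX_TBPS peaks in _query_gpu_specs: the run can be no faster than the larger of its compute-bound and memory-bound floors, and util reports measured efficiency against that bound. A standalone restatement of the arithmetic, with made-up inputs for illustration (not part of this commit):

def roofline_util(tot_flops, tot_bytes, tot_time_ns, specs):
    # Compute-bound floor: MAX_TFLOPS is in Tflop/s, i.e. 1e3 flop/ns, so
    # flops / (MAX_TFLOPS * 1e3) is the fastest possible time in nanoseconds.
    min_time_flops = sum(tot_flops[w] / specs[f"MAX_TFLOPS{w}"] for w in (8, 16)) * 1e-3
    # Memory-bound floor: MAX_TBPS is in TB/s, i.e. 1e3 bytes/ns.
    min_time_bytes = tot_bytes / specs["MAX_TBPS"] * 1e-3
    # The kernels cannot beat the slower of the two roofs.
    min_time = max(min_time_flops, min_time_bytes)
    return min_time / tot_time_ns


# 1 Tflop of fp8 work, 1 GB of traffic, 700 us measured, on the H100 row above:
h100 = {"MAX_TFLOPS8": 1979, "MAX_TFLOPS16": 989, "MAX_TBPS": 3.35}
print(roofline_util({8: 1e12, 16: 0}, 1e9, 700e3, h100))  # ~0.72: compute-bound, 72% of roofline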

bench/pyproject.toml

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
[project]
name = "triton_bench"
version = "1.0.0"
dependencies = ["torch", "numpy", "pytest"]

[build-system]
requires = ["setuptools>=64.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
include = ["triton_bench*"]
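
bench/pyproject.toml packages the benchmark helper library as a triton_bench distribution (version 1.0.0, depending on torch, numpy and pytest), with setuptools discovering any triton_bench* packages. A hypothetical post-install smoke check, assuming the package has already been installed from the bench/ directory (the install command itself is not part of this diff):

# Hypothetical check: confirm the installed distribution matches bench/pyproject.toml.
import importlib.metadata

import triton_bench  # package picked up via [tool.setuptools.packages.find]

assert importlib.metadata.version("triton_bench") == "1.0.0"  # [project].version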

bench/tests/__init__.py

Whitespace-only changes.
