Commit 5751fc6
feat: SM-constraint-GEMM by triton persistent kernel (#982)
Add an SM-constrained GEMM operation implemented with a Triton persistent kernel, to support Nanoflow-style intra-device parallelism.

Checklist:
- [x] functional test passed
- [x] benchmark
- [x] SM usage checked with nsys profile
- [ ] (optional for this PR) tune: get best config for GEMM

**Benchmark results**: https://docs.google.com/document/d/189f1VdZ36B-iJTYlC2LDgWGDSWKCiZI6PTfg-ltjXv4/edit?usp=sharing

**Nsys results**

> When num_sm = 1:
> gemm_kernel_persistent
> Begins: 3.89591s
> Ends: 3.89592s (+5.248 μs)
> grid: <<<1, 1, 1>>>
> block: <<<128, 1, 1>>>
>
> When num_sm = 32:
> gemm_kernel_persistent
> Begins: 3.91269s
> Ends: 3.92016s (+7.466 ms)
> grid: <<<32, 1, 1>>>
> block: <<<128, 1, 1>>>
>
> When num_sm = 64:
> gemm_kernel_persistent
> Begins: 3.59851s
> Ends: 3.60234s (+3.829 ms)
> grid: <<<64, 1, 1>>>
> block: <<<128, 1, 1>>>
> Launch Type: Regular
>
> When num_sm = 128:
> gemm_kernel_persistent
> Begins: 3.17387s
> Ends: 3.17586s (+1.992 ms)
> grid: <<<128, 1, 1>>>
> block: <<<128, 1, 1>>>
>
> When num_sm = 133:
> gemm_kernel_persistent
> Begins: 3.51542s
> Ends: 3.5173s (+1.879 ms)
> grid: <<<132, 1, 1>>>
> block: <<<128, 1, 1>>>

Related issues: #591 #675
1 parent d7a9234 commit 5751fc6
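
For a quick sense of the new API surface, here is a minimal usage sketch based on the call pattern exercised in `benchmarks/bench_persistent_gemm.py` below. The wrapper's full signature is not visible on this page, so the returned-output assumption is noted in the comments.

```python
# Minimal usage sketch, based on the call pattern in benchmarks/bench_persistent_gemm.py.
# Whether gemm_persistent returns a freshly allocated output or accepts a preallocated
# one is not visible on this page, so the assignment below is an assumption.
import torch

import flashinfer.triton

M, N, K = 4096, 4096, 4096
a = torch.randn((M, K), device="cuda", dtype=torch.float16)
b = torch.randn((N, K), device="cuda", dtype=torch.float16)  # note the (N, K) layout

# Constrain the GEMM to at most 64 SMs, leaving the remaining SMs free for
# other kernels running concurrently (the Nanoflow use case).
out = flashinfer.triton.sm_constraint_gemm.gemm_persistent(
    a=a, b=b, alpha=1.0, beta=0.0, num_sms=64
)
```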

File tree

5 files changed: +815 -0 lines changed

benchmarks/bench_persistent_gemm.py

+76
@@ -0,0 +1,76 @@
import pytest
import torch
import triton
from triton.testing import do_bench

import flashinfer
import flashinfer.triton


def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def supports_tma():
    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9


def bench_gemm_persistent(num_sms, dtype, M, N, K, reps=1000, warmup_reps=10000):
    ms = do_bench(
        lambda: flashinfer.triton.sm_constraint_gemm.gemm_persistent(
            a=torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype),
            b=torch.randn((N, K), device="cuda", dtype=torch.float16).to(dtype),
            alpha=1.0,
            beta=0.0,
            num_sms=num_sms,
        ),
        warmup=warmup_reps,
        rep=reps,
    )

    # matmul: 2 * M * N * K
    # scale and add: 3 * M * N
    flops = (2 * M * N * K + 3 * M * N) / ms / 1e9
    print(
        f"GEMM Persistent | num_sms: {num_sms}, M: {M}, N: {N}, K: {K}, {dtype}: {flops:.3f} TFLOPs/s"
    )


def bench_gemm_descriptor_persistent(
    num_sms, dtype, M, N, K, reps=1000, warmup_reps=10000
):
    if dtype == torch.float32:
        return
    ms = do_bench(
        lambda: flashinfer.triton.sm_constraint_gemm.gemm_descriptor_persistent(
            a=torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype),
            b=torch.randn((N, K), device="cuda", dtype=torch.float16).to(dtype),
            alpha=1.0,
            beta=0.0,
            num_sms=num_sms,
        ),
        warmup=warmup_reps,
        rep=reps,
    )

    # matmul: 2 * M * N * K
    # scale and add: 3 * M * N
    flops = (2 * M * N * K + 3 * M * N) / ms / 1e9
    print(
        f"GEMM Descriptor | num_sms: {num_sms}, M: {M}, N: {N}, K: {K}, {dtype}: {flops:.3f} TFLOPs/s"
    )


if __name__ == "__main__":
    assert supports_tma()

    for M, N, K in [(4096, 4096, 4096), (8192, 8192, 8192)]:
        for dtype in [
            torch.float8_e4m3fn,
            torch.float16,
            torch.bfloat16,
            torch.float32,
        ]:
            for num_sms in [1, 16, 32, 64, 128, 132, 133, 256]:
                bench_gemm_persistent(num_sms, dtype, M, N, K)
                bench_gemm_descriptor_persistent(num_sms, dtype, M, N, K)
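
The throughput formula above divides the operation count by the measured time in milliseconds and by 1e9, which yields TFLOP/s directly. A small worked example, using an illustrative (not measured) latency:

```python
# Worked example of the throughput formula above (the 1.0 ms latency is
# illustrative, not a measured value).
M = N = K = 4096
ms = 1.0
tflops = (2 * M * N * K + 3 * M * N) / ms / 1e9  # ops per millisecond, scaled to TFLOP/s
print(f"{tflops:.1f}")  # ~137.5
```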

flashinfer/triton/__init__.py

+1
@@ -1 +1,2 @@
 from . import cascade # noqa: F401
+from . import sm_constraint_gemm # noqa: F401
@@ -0,0 +1,290 @@
import triton  # type: ignore[import]
import triton.language as tl  # type: ignore[import]


def matmul_get_configs():
    return [
        triton.Config(
            {
                "BLOCK_SIZE_M": BM,
                "BLOCK_SIZE_N": BN,
                "BLOCK_SIZE_K": BK,
                "GROUP_SIZE_M": 8,
            },
            num_stages=s,
            num_warps=w,
        )
        for BM in [128]
        for BN in [128]
        for BK in [64]
        for s in ([3])
        for w in [4]
    ]


def _matmul_launch_metadata(grid, kernel, args):
    ret = {}
    M, N, K = args["M"], args["N"], args["K"]
    ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
    if "c_ptr" in args:
        bytes_per_elem = args["c_ptr"].element_size()
    else:
        bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
    ret[f"flops{bytes_per_elem * 8}"] = 2.0 * M * N * K
    ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N)
    return ret


@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n


@triton.autotune(
    configs=matmul_get_configs(),
    key=["M", "N", "K"],
)
@triton.jit(launch_metadata=_matmul_launch_metadata)
def gemm_kernel_persistent(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    alpha,
    beta,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    NUM_SMS: tl.constexpr,
):
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n

    # NOTE: There is currently a bug in blackwell pipelining that means it can't handle a value being
    # used in both the prologue and epilogue, so we duplicate the counters as a work-around.
    tile_id_c = start_pid - NUM_SMS

    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
        pid_m, pid_n = _compute_pid(
            tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        start_m = pid_m * BLOCK_SIZE_M
        start_n = pid_n * BLOCK_SIZE_N
        offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
        offs_am = tl.where(offs_am < M, offs_am, 0)
        offs_bn = tl.where(offs_bn < N, offs_bn, 0)
        offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
        offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for ki in range(k_tiles):
            offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + (
                offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
            )
            b_ptrs = b_ptr + (
                offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
            )

            a = tl.load(
                a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0
            )
            b = tl.load(
                b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0
            )
            accumulator = tl.dot(a, b, accumulator)

        tile_id_c += NUM_SMS
        pid_m, pid_n = _compute_pid(
            tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        c = accumulator.to(c_ptr.dtype.element_ty)

        c = tl.fma(c, alpha, beta * tl.load(c_ptrs, mask=c_mask))
        tl.store(c_ptrs, c, mask=c_mask)


@triton.jit(launch_metadata=_matmul_launch_metadata)
def gemm_kernel_descriptor_persistent(
    a_ptr,
    b_ptr,
    c_ptr,  #
    M,
    N,
    K,  #
    alpha,
    beta,
    BLOCK_SIZE_M: tl.constexpr,  #
    BLOCK_SIZE_N: tl.constexpr,  #
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
    EPILOGUE_SUBTILE: tl.constexpr,  #
    NUM_SMS: tl.constexpr,
):  #
    dtype = c_ptr.dtype.element_ty
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n

    a_desc = tl.make_tensor_descriptor(
        a_ptr,
        shape=[M, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
    )
    b_desc = tl.make_tensor_descriptor(
        b_ptr,
        shape=[N, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
    )
    c_desc = tl.make_tensor_descriptor(
        c_ptr,
        shape=[M, N],
        strides=[N, 1],
        block_shape=[
            BLOCK_SIZE_M,
            BLOCK_SIZE_N if not EPILOGUE_SUBTILE else BLOCK_SIZE_N // 2,
        ],
    )

    # tile_id_c is used in the epilogue to break the dependency between
    # the prologue and the epilogue
    tile_id_c = start_pid - NUM_SMS
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
        pid_m, pid_n = _compute_pid(
            tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        offs_am = pid_m * BLOCK_SIZE_M
        offs_bn = pid_n * BLOCK_SIZE_N

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for ki in range(k_tiles):
            offs_k = ki * BLOCK_SIZE_K
            a = a_desc.load([offs_am, offs_k])
            b = b_desc.load([offs_bn, offs_k])
            accumulator = tl.dot(a, b.T, accumulator)

        tile_id_c += NUM_SMS
        pid_m, pid_n = _compute_pid(
            tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        offs_cm = pid_m * BLOCK_SIZE_M
        offs_cn = pid_n * BLOCK_SIZE_N

        if EPILOGUE_SUBTILE:
            acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
            acc = tl.permute(acc, (0, 2, 1))
            acc0, acc1 = tl.split(acc)
            acc0 = tl.fma(acc0, alpha, beta * c_desc.load([offs_cm, offs_cn]))
            acc1 = tl.fma(
                acc1, alpha, beta * c_desc.load([offs_cm, offs_cn + BLOCK_SIZE_N // 2])
            )
            c0 = acc0.to(dtype)
            c_desc.store([offs_cm, offs_cn], c0)
            c1 = acc1.to(dtype)
            c_desc.store([offs_cm, offs_cn + BLOCK_SIZE_N // 2], c1)
        else:
            accumulator = tl.fma(
                accumulator, alpha, beta * c_desc.load([offs_cm, offs_cn])
            )
            c = accumulator.to(dtype)
            c_desc.store([offs_cm, offs_cn], c)


# only for testing
@triton.autotune(
    configs=matmul_get_configs(),
    key=["M", "N", "K"],
)
@triton.jit(launch_metadata=_matmul_launch_metadata)
def gemm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,  #
    M,
    N,
    K,  #
    stride_am,
    stride_ak,  #
    stride_bk,
    stride_bn,  #
    stride_cm,
    stride_cn,  #
    alpha,
    beta,
    BLOCK_SIZE_M: tl.constexpr,  #
    BLOCK_SIZE_N: tl.constexpr,  #
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    start_m = pid_m * BLOCK_SIZE_M
    start_n = pid_n * BLOCK_SIZE_N

    offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
    offs_am = tl.where(offs_am < M, offs_am, 0)
    offs_bn = tl.where(offs_bn < N, offs_bn, 0)

    offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    c = accumulator.to(c_ptr.dtype.element_ty)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    c = tl.fma(c, alpha, beta * tl.load(c_ptrs, mask=c_mask))
    tl.store(c_ptrs, c, mask=c_mask)
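
The host-side wrapper that launches these kernels (exposed to the benchmark as `flashinfer.triton.sm_constraint_gemm.gemm_persistent`) is among the five changed files but is not shown on this page. As a rough sketch of how a persistent launch constrains SM usage, the grid is capped at `num_sms` (and presumably at the device SM count, which would explain the `<<<132, 1, 1>>>` grid observed for num_sm = 133), and each program then strides over output tiles by `NUM_SMS`. The launcher below is illustrative only, not the PR's actual wrapper; its name and defaults are assumptions.

```python
import torch
import triton

# Illustrative launcher (an assumption, not the PR's wrapper): the grid is capped at
# num_sms, so at most num_sms programs, and therefore at most num_sms SMs, are occupied.
def launch_gemm_persistent(a, b, c, alpha=1.0, beta=0.0, num_sms=None):
    M, K = a.shape
    N, _ = b.shape  # b stored as (N, K), matching the benchmark above
    max_sms = torch.cuda.get_device_properties(a.device).multi_processor_count
    num_sms = max_sms if num_sms is None else min(num_sms, max_sms)

    def grid(meta):
        num_tiles = triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(
            N, meta["BLOCK_SIZE_N"]
        )
        # Never launch more programs than allowed SMs (or than there are tiles);
        # inside the kernel each program strides over tiles by NUM_SMS.
        return (min(num_sms, num_tiles),)

    gemm_kernel_persistent[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(1), b.stride(0),  # b is (N, K): stride_bk = b.stride(1), stride_bn = b.stride(0)
        c.stride(0), c.stride(1),
        alpha, beta,
        NUM_SMS=num_sms,
    )
    return c
```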
