Skip to content

SM-constraint-GEMM by triton persistent kernel #982

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 38 commits into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
58b1ca3
init
yyihuang Mar 29, 2025
6cdb997
Merge branch 'main' of https://github.com/flashinfer-ai/flashinfer in…
yyihuang Mar 29, 2025
d4c1204
upd
yyihuang Mar 29, 2025
bbe6e5c
Merge branch 'main' of https://github.com/flashinfer-ai/flashinfer in…
yyihuang Mar 29, 2025
6dbad35
upd num_sm
yyihuang Mar 29, 2025
1b7276f
upd (todo: bf16, fp32, bench)
yyihuang Mar 29, 2025
2afe463
add fp32 (todo: fp8, bf16)
yyihuang Mar 29, 2025
318cf2a
upd
yyihuang Mar 29, 2025
36b8d71
Merge branch 'main' of https://github.com/flashinfer-ai/flashinfer in…
yyihuang Mar 29, 2025
8b1fcab
upd
yyihuang Mar 29, 2025
2490994
upd
yyihuang Mar 29, 2025
ac797b6
upd
yyihuang Mar 30, 2025
14c3ec4
cleanup
yyihuang Mar 30, 2025
52a243c
upd
yyihuang Mar 30, 2025
c8d413a
init tma desc gemm
yyihuang Mar 30, 2025
5d420cc
upd fma
yyihuang Mar 31, 2025
3837ee4
add test
yyihuang Mar 31, 2025
e562de5
upd
yyihuang Mar 31, 2025
c19acf5
upd test (todo: fp8)
yyihuang Mar 31, 2025
ed44012
checkpoint: gemv and fp8 fail
yyihuang Mar 31, 2025
5e9f9b1
upd
yyihuang Mar 31, 2025
9924f17
add tmp check
yyihuang Mar 31, 2025
33a782a
ckpt: disable fp8 test on overflow
yyihuang Mar 31, 2025
4d2cee0
add param: out_dtype
yyihuang Mar 31, 2025
62d70cd
upd tma size check
yyihuang Mar 31, 2025
3fe81fe
upd tests
yyihuang Apr 1, 2025
f2ced09
cleanup
yyihuang Apr 1, 2025
d8f82ad
Merge branch 'main' of https://github.com/flashinfer-ai/flashinfer in…
yyihuang Apr 1, 2025
a6e85dc
upd
yyihuang Apr 1, 2025
8b2a219
upd comments
yyihuang Apr 1, 2025
b906846
upd isort
yyihuang Apr 1, 2025
c527450
upd
yyihuang Apr 1, 2025
51c68ff
upd assert
yyihuang Apr 1, 2025
44544ec
upd
yyihuang Apr 1, 2025
1f880a2
upd
yyihuang Apr 1, 2025
d66d0a6
add benchmark
yyihuang Apr 1, 2025
50a5bed
fmt
yyihuang Apr 1, 2025
2804abf
pre-commit
yyihuang Apr 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flashinfer/triton/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from . import cascade # noqa: F401
from . import sm_constraint_gemm # Add this line
290 changes: 290 additions & 0 deletions flashinfer/triton/kernels/sm_constraint_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
import triton # type: ignore[import]
import triton.language as tl # type: ignore[import]


def matmul_get_configs():
    """Return the autotuning search space for the GEMM kernels in this file.

    The space is the cartesian product of the option lists below; every list
    currently holds one value, so exactly one ``triton.Config`` is produced.
    ``GROUP_SIZE_M`` is fixed at 8 for every config.
    """
    block_m_options = [128]
    block_n_options = [128]
    block_k_options = [64]
    stage_options = [3]
    warp_options = [4]

    configs = []
    for block_m in block_m_options:
        for block_n in block_n_options:
            for block_k in block_k_options:
                for stages in stage_options:
                    for warps in warp_options:
                        meta = {
                            "BLOCK_SIZE_M": block_m,
                            "BLOCK_SIZE_N": block_n,
                            "BLOCK_SIZE_K": block_k,
                            "GROUP_SIZE_M": 8,
                        }
                        configs.append(
                            triton.Config(meta, num_stages=stages, num_warps=warps)
                        )
    return configs


def _matmul_launch_metadata(grid, kernel, args):
    """Build the launch-metadata dict attached to the GEMM kernels.

    This function is passed as ``launch_metadata=`` to ``triton.jit`` for the
    kernels in this file.

    Args:
        grid: Launch grid. Unused here.
        kernel: Object exposing a ``name`` attribute.
        args: Mapping from kernel argument name to value. Must contain
            ``"M"``, ``"N"`` and ``"K"``. If it contains ``"c_ptr"``, that
            value's ``element_size()`` gives the bytes per element; otherwise
            the optional ``"FP8_OUTPUT"`` flag selects 1 byte (truthy) or
            2 bytes (falsy or absent).

    Returns:
        A dict with three entries:
        ``"name"``: the kernel name annotated with the problem size;
        ``"flops{bits}"``: ``2 * M * N * K``, where ``bits`` is the element
        width in bits;
        ``"bytes"``: element size times ``M*K + N*K + M*N`` (the same element
        size is applied to all three operands).
    """
    M, N, K = args["M"], args["N"], args["K"]

    if "c_ptr" in args:
        bytes_per_elem = args["c_ptr"].element_size()
    else:
        # None of the kernels in this file declare FP8_OUTPUT, so a plain
        # subscript would raise KeyError here. Treat a missing flag as
        # "not FP8" and fall back to 2-byte elements.
        bytes_per_elem = 1 if args.get("FP8_OUTPUT", False) else 2

    return {
        "name": f"{kernel.name} [M={M}, N={N}, K={K}]",
        f"flops{bytes_per_elem * 8}": 2.0 * M * N * K,
        "bytes": bytes_per_elem * (M * K + N * K + M * N),
    }


@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
    """Translate a linear tile index into ``(pid_m, pid_n)`` tile coordinates.

    Tiles are numbered group by group. A group covers up to ``GROUP_SIZE_M``
    consecutive tile rows and ``num_pid_in_group`` linear tile ids. The last
    group is shorter when ``num_pid_m`` is not a multiple of ``GROUP_SIZE_M``.

    ``NUM_SMS`` is accepted to keep the call signature but is not used.
    """
    # First tile row belonging to the group that contains this tile.
    group_row_start = (tile_id // num_pid_in_group) * GROUP_SIZE_M
    # Number of tile rows actually present in that group.
    rows_in_group = min(num_pid_m - group_row_start, GROUP_SIZE_M)
    # Position of the tile within its group.
    index_in_group = tile_id % num_pid_in_group

    tile_row = group_row_start + (tile_id % rows_in_group)
    tile_col = index_in_group // rows_in_group
    return tile_row, tile_col


@triton.autotune(
    configs=matmul_get_configs(),
    key=["M", "N", "K"],
)
@triton.jit(launch_metadata=_matmul_launch_metadata)
def gemm_kernel_persistent(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    alpha,
    beta,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    NUM_SMS: tl.constexpr,
):
    """Persistent strided GEMM computing ``C = alpha * (A @ B) + beta * C`` in place.

    Each program starts at tile ``program_id(0)`` and then handles every
    ``NUM_SMS``-th output tile until ``num_tiles`` is exhausted.

    Args:
        a_ptr, b_ptr, c_ptr: Base pointers. Element ``A[m, k]`` is read at
            ``a_ptr + m * stride_am + k * stride_ak``, ``B[k, n]`` at
            ``b_ptr + k * stride_bk + n * stride_bn``, and ``C[m, n]`` is read
            and written at ``c_ptr + m * stride_cm + n * stride_cn``.
        M, N, K: Problem sizes.
        stride_*: Element strides used in the address computations above.
        alpha, beta: Scalars applied to the product and to the existing C.
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: Tile sizes (compile-time).
        GROUP_SIZE_M: Tile rows per group in the tile ordering (compile-time).
        NUM_SMS: Stride between tiles handled by one program (compile-time).
            Presumably equal to the launch grid size; confirm at the call site.

    NOTE(review): C is always loaded, even when ``beta`` is 0, so NaN/Inf
    already present in C would propagate through ``beta * C``. Confirm that
    callers pass an initialized C buffer.
    """
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n

    # NOTE: There is currently a bug in blackwell pipelining that means it can't handle a value being
    # used in both the prologue and epilogue, so we duplicate the counters as a work-around.
    tile_id_c = start_pid - NUM_SMS

    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
        pid_m, pid_n = _compute_pid(
            tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        start_m = pid_m * BLOCK_SIZE_M
        start_n = pid_n * BLOCK_SIZE_N
        offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
        # Rows/columns past the matrix edge are redirected to index 0 so the
        # loads stay in bounds; their results are dropped by c_mask on store.
        offs_am = tl.where(offs_am < M, offs_am, 0)
        offs_bn = tl.where(offs_bn < N, offs_bn, 0)
        offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
        offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for ki in range(k_tiles):
            offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + (
                offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
            )
            b_ptrs = b_ptr + (
                offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
            )

            # The K tail of the last tile is zero-filled so it adds nothing.
            a = tl.load(
                a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0
            )
            b = tl.load(
                b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0
            )
            accumulator = tl.dot(a, b, accumulator)

        # Epilogue: recompute the tile coordinates from the duplicated counter.
        tile_id_c += NUM_SMS
        pid_m, pid_n = _compute_pid(
            tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

        # Apply alpha/beta on the fp32 accumulator and round to the output
        # dtype once. Casting first would round the product to the narrower
        # type before scaling. This matches the order used by
        # gemm_kernel_descriptor_persistent.
        c_prev = tl.load(c_ptrs, mask=c_mask).to(tl.float32)
        scaled = tl.fma(accumulator, alpha, beta * c_prev)
        c = scaled.to(c_ptr.dtype.element_ty)
        tl.store(c_ptrs, c, mask=c_mask)


@triton.jit(launch_metadata=_matmul_launch_metadata)
def gemm_kernel_descriptor_persistent(
    a_ptr,
    b_ptr,
    c_ptr,  #
    M,
    N,
    K,  #
    alpha,
    beta,
    BLOCK_SIZE_M: tl.constexpr,  #
    BLOCK_SIZE_N: tl.constexpr,  #
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
    EPILOGUE_SUBTILE: tl.constexpr,  #
    NUM_SMS: tl.constexpr,
):  #
    """Persistent GEMM using tensor descriptors: ``C = alpha * (A @ B^T) + beta * C``.

    No stride arguments are taken. The descriptors below hard-code the
    layouts: A is ``[M, K]`` with strides ``[K, 1]``, B is ``[N, K]`` with
    strides ``[K, 1]`` (it is transposed before the dot), and C is ``[M, N]``
    with strides ``[N, 1]``.

    Each program starts at tile ``program_id(0)`` and then handles every
    ``NUM_SMS``-th output tile until ``num_tiles`` is exhausted.

    When ``EPILOGUE_SUBTILE`` is set, the output tile is loaded and stored as
    two halves of ``BLOCK_SIZE_N // 2`` columns each.

    NOTE(review): C is always loaded, even when ``beta`` is 0, so NaN/Inf
    already present in C would propagate through ``beta * C``. Confirm that
    callers pass an initialized C buffer.
    """
    # Output element type, used for the final cast before the store.
    dtype = c_ptr.dtype.element_ty
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n

    a_desc = tl.make_tensor_descriptor(
        a_ptr,
        shape=[M, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
    )
    # B is described as [N, K]; the loaded block is transposed in the dot.
    b_desc = tl.make_tensor_descriptor(
        b_ptr,
        shape=[N, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
    )
    # With EPILOGUE_SUBTILE the C block is half as wide, matching the
    # two half-tile loads/stores in the epilogue.
    c_desc = tl.make_tensor_descriptor(
        c_ptr,
        shape=[M, N],
        strides=[N, 1],
        block_shape=[
            BLOCK_SIZE_M,
            BLOCK_SIZE_N if not EPILOGUE_SUBTILE else BLOCK_SIZE_N // 2,
        ],
    )

    # tile_id_c is used in the epilogue to break the dependency between
    # the prologue and the epilogue
    tile_id_c = start_pid - NUM_SMS
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
        pid_m, pid_n = _compute_pid(
            tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        # Scalar offsets of the tile's first row of A and first row of B.
        offs_am = pid_m * BLOCK_SIZE_M
        offs_bn = pid_n * BLOCK_SIZE_N

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for ki in range(k_tiles):
            offs_k = ki * BLOCK_SIZE_K
            a = a_desc.load([offs_am, offs_k])
            b = b_desc.load([offs_bn, offs_k])
            accumulator = tl.dot(a, b.T, accumulator)

        # Epilogue: advance the duplicated counter and recompute the tile
        # coordinates from it rather than reusing tile_id.
        tile_id_c += NUM_SMS
        pid_m, pid_n = _compute_pid(
            tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        offs_cm = pid_m * BLOCK_SIZE_M
        offs_cn = pid_n * BLOCK_SIZE_N

        if EPILOGUE_SUBTILE:
            # Split the accumulator into two (BLOCK_SIZE_M, BLOCK_SIZE_N // 2)
            # halves; acc0 goes to column offset offs_cn and acc1 to
            # offs_cn + BLOCK_SIZE_N // 2.
            acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
            acc = tl.permute(acc, (0, 2, 1))
            acc0, acc1 = tl.split(acc)
            # alpha/beta are applied before the cast to the output dtype.
            acc0 = tl.fma(acc0, alpha, beta * c_desc.load([offs_cm, offs_cn]))
            acc1 = tl.fma(
                acc1, alpha, beta * c_desc.load([offs_cm, offs_cn + BLOCK_SIZE_N // 2])
            )
            c0 = acc0.to(dtype)
            c_desc.store([offs_cm, offs_cn], c0)
            c1 = acc1.to(dtype)
            c_desc.store([offs_cm, offs_cn + BLOCK_SIZE_N // 2], c1)
        else:
            # alpha/beta are applied before the cast to the output dtype.
            accumulator = tl.fma(
                accumulator, alpha, beta * c_desc.load([offs_cm, offs_cn])
            )
            c = accumulator.to(dtype)
            c_desc.store([offs_cm, offs_cn], c)


# only for testing
@triton.autotune(
    configs=matmul_get_configs(),
    key=["M", "N", "K"],
)
@triton.jit(launch_metadata=_matmul_launch_metadata)
def gemm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,  #
    M,
    N,
    K,  #
    stride_am,
    stride_ak,  #
    stride_bk,
    stride_bn,  #
    stride_cm,
    stride_cn,  #
    alpha,
    beta,
    BLOCK_SIZE_M: tl.constexpr,  #
    BLOCK_SIZE_N: tl.constexpr,  #
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
):
    """Non-persistent strided GEMM computing ``C = alpha * (A @ B) + beta * C`` in place.

    Each program computes exactly one output tile, selected from
    ``program_id(0)`` using the same grouped tile ordering as the persistent
    kernels in this file.

    Args:
        a_ptr, b_ptr, c_ptr: Base pointers. Element ``A[m, k]`` is read at
            ``a_ptr + m * stride_am + k * stride_ak``, ``B[k, n]`` at
            ``b_ptr + k * stride_bk + n * stride_bn``, and ``C[m, n]`` is read
            and written at ``c_ptr + m * stride_cm + n * stride_cn``.
        M, N, K: Problem sizes.
        stride_*: Element strides used in the address computations above.
        alpha, beta: Scalars applied to the product and to the existing C.
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: Tile sizes (compile-time).
        GROUP_SIZE_M: Tile rows per group in the tile ordering (compile-time).

    NOTE(review): C is always loaded, even when ``beta`` is 0, so NaN/Inf
    already present in C would propagate through ``beta * C``. Confirm that
    callers pass an initialized C buffer.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group is shorter when num_pid_m is not a multiple of GROUP_SIZE_M.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    start_m = pid_m * BLOCK_SIZE_M
    start_n = pid_n * BLOCK_SIZE_N

    offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
    # Rows/columns past the matrix edge are redirected to index 0 so the
    # loads stay in bounds; their results are dropped by c_mask on store.
    offs_am = tl.where(offs_am < M, offs_am, 0)
    offs_bn = tl.where(offs_bn < N, offs_bn, 0)

    offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # The K tail of the last tile is zero-filled so it adds nothing.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        # Advance both operand pointers by one K tile.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

    # Apply alpha/beta on the fp32 accumulator and round to the output dtype
    # once. Casting first would round the product to the narrower type before
    # scaling. This matches the order used by gemm_kernel_descriptor_persistent.
    c_prev = tl.load(c_ptrs, mask=c_mask).to(tl.float32)
    scaled = tl.fma(accumulator, alpha, beta * c_prev)
    c = scaled.to(c_ptr.dtype.element_ty)
    tl.store(c_ptrs, c, mask=c_mask)
Loading