SM-constraint-GEMM by triton persistent kernel #982


Merged (38 commits, Apr 1, 2025)
Commits
58b1ca3 init (yyihuang, Mar 29, 2025)
6cdb997 Merge branch 'main' of https://github.com/flashinfer-ai/flashinfer in… (yyihuang, Mar 29, 2025)
d4c1204 upd (yyihuang, Mar 29, 2025)
bbe6e5c Merge branch 'main' of https://github.com/flashinfer-ai/flashinfer in… (yyihuang, Mar 29, 2025)
6dbad35 upd num_sm (yyihuang, Mar 29, 2025)
1b7276f upd (todo: bf16, fp32, bench) (yyihuang, Mar 29, 2025)
2afe463 add fp32 (todo: fp8, bf16) (yyihuang, Mar 29, 2025)
318cf2a upd (yyihuang, Mar 29, 2025)
36b8d71 Merge branch 'main' of https://github.com/flashinfer-ai/flashinfer in… (yyihuang, Mar 29, 2025)
8b1fcab upd (yyihuang, Mar 29, 2025)
2490994 upd (yyihuang, Mar 29, 2025)
ac797b6 upd (yyihuang, Mar 30, 2025)
14c3ec4 cleanup (yyihuang, Mar 30, 2025)
52a243c upd (yyihuang, Mar 30, 2025)
c8d413a init tma desc gemm (yyihuang, Mar 30, 2025)
5d420cc upd fma (yyihuang, Mar 31, 2025)
3837ee4 add test (yyihuang, Mar 31, 2025)
e562de5 upd (yyihuang, Mar 31, 2025)
c19acf5 upd test (todo: fp8) (yyihuang, Mar 31, 2025)
ed44012 checkpoint: gemv and fp8 fail (yyihuang, Mar 31, 2025)
5e9f9b1 upd (yyihuang, Mar 31, 2025)
9924f17 add tmp check (yyihuang, Mar 31, 2025)
33a782a ckpt: disable fp8 test on overflow (yyihuang, Mar 31, 2025)
4d2cee0 add param: out_dtype (yyihuang, Mar 31, 2025)
62d70cd upd tma size check (yyihuang, Mar 31, 2025)
3fe81fe upd tests (yyihuang, Apr 1, 2025)
f2ced09 cleanup (yyihuang, Apr 1, 2025)
d8f82ad Merge branch 'main' of https://github.com/flashinfer-ai/flashinfer in… (yyihuang, Apr 1, 2025)
a6e85dc upd (yyihuang, Apr 1, 2025)
8b2a219 upd comments (yyihuang, Apr 1, 2025)
b906846 upd isort (yyihuang, Apr 1, 2025)
c527450 upd (yyihuang, Apr 1, 2025)
51c68ff upd assert (yyihuang, Apr 1, 2025)
44544ec upd (yyihuang, Apr 1, 2025)
1f880a2 upd (yyihuang, Apr 1, 2025)
d66d0a6 add benchmark (yyihuang, Apr 1, 2025)
50a5bed fmt (yyihuang, Apr 1, 2025)
2804abf pre-commit (yyihuang, Apr 1, 2025)
1 change: 1 addition & 0 deletions flashinfer/triton/__init__.py
@@ -1 +1,2 @@
from . import cascade # noqa: F401
from . import sm_constraint_gemm  # noqa: F401
200 changes: 200 additions & 0 deletions flashinfer/triton/kernels/sm_constraint_gemm.py
@@ -0,0 +1,200 @@
import triton # type: ignore[import]
import triton.language as tl # type: ignore[import]

# todo(yingyi): tune the autotune configs (only a single config is listed below)


def matmul_get_configs():
return [
triton.Config(
{
"BLOCK_SIZE_M": BM,
"BLOCK_SIZE_N": BN,
"BLOCK_SIZE_K": BK,
"GROUP_SIZE_M": 8,
},
num_stages=s,
num_warps=w,
)
for BM in [128]
for BN in [128]
for BK in [64]
for s in ([3])
for w in [4]
]


def _matmul_launch_metadata(grid, kernel, args):
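    # Launch metadata reported to profiling tools: kernel name plus estimated
    # flop and byte counts for an M x N x K GEMM.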
ret = {}
M, N, K = args["M"], args["N"], args["K"]
ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
if "c_ptr" in args:
bytes_per_elem = args["c_ptr"].element_size()
else:
bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
ret[f"flops{bytes_per_elem * 8}"] = 2.0 * M * N * K
ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N)
return ret


@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
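    # Map a linear tile index onto (pid_m, pid_n) in grouped order: tiles advance
    # down GROUP_SIZE_M tile-rows first, then across columns, so neighboring
    # programs reuse the same blocks of A and B in L2.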
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (tile_id % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
return pid_m, pid_n


@triton.autotune(
configs=matmul_get_configs(),
key=["M", "N", "K"],
)
@triton.jit(launch_metadata=_matmul_launch_metadata)
def gemm_kernel_persistent(
a_ptr,
b_ptr,
c_ptr,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
alpha,
beta,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
NUM_SMS: tl.constexpr,
):
start_pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
num_tiles = num_pid_m * num_pid_n

    # NOTE: There is currently a bug in Blackwell pipelining that means it can't handle a value being
    # used in both the prologue and epilogue, so we duplicate the counters as a workaround.
tile_id_c = start_pid - NUM_SMS

offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n

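    # Persistent loop: only min(NUM_SMS, num_tiles) programs are launched, and
    # each one walks the tile space with stride NUM_SMS until all tiles are covered.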
for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
pid_m, pid_n = _compute_pid(
tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
)
start_m = pid_m * BLOCK_SIZE_M
start_n = pid_n * BLOCK_SIZE_N
offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
offs_am = tl.where(offs_am < M, offs_am, 0)
offs_bn = tl.where(offs_bn < N, offs_bn, 0)
offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
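        # Compiler hints: the index blocks are contiguous and aligned, which
        # enables vectorized global memory accesses.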

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for ki in range(k_tiles):
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
)
b_ptrs = b_ptr + (
offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
)

a = tl.load(
a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0
)
b = tl.load(
b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0
)
accumulator = tl.dot(a, b, accumulator)

tile_id_c += NUM_SMS
pid_m, pid_n = _compute_pid(
tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
c = accumulator.to(c_ptr.dtype.element_ty)

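        # Epilogue: apply alpha/beta scaling in the output dtype, then store with a bounds mask.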
c = alpha * c + beta * tl.load(c_ptrs, mask=c_mask)
tl.store(c_ptrs, c, mask=c_mask)


# only for testing
@triton.autotune(
configs=matmul_get_configs(),
key=["M", "N", "K"],
)
@triton.jit(launch_metadata=_matmul_launch_metadata)
def gemm_kernel(
a_ptr,
b_ptr,
c_ptr, #
M,
N,
K, #
stride_am,
stride_ak, #
stride_bk,
stride_bn, #
stride_cm,
stride_cn, #
alpha,
beta,
BLOCK_SIZE_M: tl.constexpr, #
BLOCK_SIZE_N: tl.constexpr, #
BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, #
):
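    # Standard (non-persistent) tiled GEMM: one program per output tile, kept as
    # a reference implementation for testing.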
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

start_m = pid_m * BLOCK_SIZE_M
start_n = pid_n * BLOCK_SIZE_N

offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
offs_am = tl.where(offs_am < M, offs_am, 0)
offs_bn = tl.where(offs_bn < N, offs_bn, 0)

offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
accumulator = tl.dot(a, b, accumulator)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk

c = accumulator.to(c_ptr.dtype.element_ty)

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
c = alpha * c + beta * tl.load(c_ptrs, mask=c_mask)
tl.store(c_ptrs, c, mask=c_mask)
129 changes: 129 additions & 0 deletions flashinfer/triton/sm_constraint_gemm.py
@@ -0,0 +1,129 @@
import torch
import triton

from .kernels.sm_constraint_gemm import gemm_kernel, gemm_kernel_persistent
from .utils import check_device, check_dim, check_input


def gemm_persistent(a, b, c=None, alpha=1.0, beta=0.0, num_sms=None):
"""
GEMM operation with SM constraint by Triton (Hopper).
C = alpha * (a @ b.T) + beta * C

Args:
a: The first input matrix. Shape: (M, K)
b: The second input matrix. Shape: (K, N)
c: The output matrix. Shape: (M, N). In-place operation is supported.
alpha: The scaling factor for the product of a and b.
beta: The scaling factor for the output matrix c.
num_sms: The number of SMs to use for the computation.
"""

    # Check constraints.
    check_input(a)
    # check_input(b) # b can be non-contiguous
    check_device([a, b])
    check_dim(2, a)
    check_dim(2, b)

    assert a.shape[1] == b.shape[0], "Incompatible dimensions between a and b"
    assert a.dtype == b.dtype, "Incompatible dtypes between a and b"

    M, K = a.shape
    K, N = b.shape
    dtype = a.dtype

    # Allocate the output if none was given; otherwise validate it (written in place).
    if c is None:
        c = torch.empty((M, N), device=a.device, dtype=dtype)
    else:
        check_input(c)
        check_device([c])
        check_dim(2, c)
        assert a.shape[0] == c.shape[0], "Incompatible dimensions between a and c"
        assert b.shape[1] == c.shape[1], "Incompatible dimensions between b and c"

    # Default to all available SMs; clamp a user-provided value to that maximum.
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
    num_sms = NUM_SMS if num_sms is None else min(NUM_SMS, num_sms)

    # Persistent launch: at most num_sms programs; each program iterates over
    # multiple output tiles, so SM usage is capped at num_sms.
grid = lambda META: (
min(
num_sms,
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
),
)

gemm_kernel_persistent[grid](
a,
b,
c,
M,
N,
K,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
alpha=alpha,
beta=beta,
NUM_SMS=num_sms,
)
return c


def gemm(a, b, c=None, alpha=1.0, beta=0.0, num_sms=None):
"""
GEMM operation without SM constraint by Triton.
C = alpha * (a @ b.T) + beta * C

Args:
a: The first input matrix. Shape: (M, K)
b: The second input matrix. Shape: (K, N)
c: The output matrix. Shape: (M, N). In-place operation is supported.
alpha: The scaling factor for the product of a and b.
"""
    # Check constraints.
    check_input(a)
    # check_input(b) # b can be non-contiguous
    check_device([a, b])
    check_dim(2, a)
    check_dim(2, b)

    assert a.shape[1] == b.shape[0], "Incompatible dimensions between a and b"
    assert a.dtype == b.dtype, "Incompatible dtypes between a and b"

    M, K = a.shape
    K, N = b.shape
    dtype = a.dtype

    # Allocate the output if none was given; otherwise validate it (written in place).
    if c is None:
        c = torch.empty((M, N), device=a.device, dtype=dtype)
    else:
        check_input(c)
        check_device([c])
        check_dim(2, c)
        assert a.shape[0] == c.shape[0], "Incompatible dimensions between a and c"
        assert b.shape[1] == c.shape[1], "Incompatible dimensions between b and c"

# 1D launch kernel where each block gets its own program.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )

gemm_kernel[grid](
a,
b,
c,
M,
N,
K,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
alpha=alpha,
beta=beta,
)
return c
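
A minimal usage sketch of the two entry points follows (illustrative only: the shapes, dtype, and num_sms value are arbitrary assumptions, and a CUDA device with Triton support is required):

import torch
import flashinfer.triton

a = torch.randn((1024, 512), device="cuda", dtype=torch.float16)
b = torch.randn((512, 2048), device="cuda", dtype=torch.float16)
c = torch.zeros((1024, 2048), device="cuda", dtype=torch.float16)

# Persistent kernel capped at 64 SMs: c = 1.0 * (a @ b) + 0.0 * c, written in place.
flashinfer.triton.sm_constraint_gemm.gemm_persistent(a, b, c, alpha=1.0, beta=0.0, num_sms=64)

# Unconstrained reference kernel; allocates and returns a fresh output when c is None.
out = flashinfer.triton.sm_constraint_gemm.gemm(a, b, alpha=1.0, beta=0.0)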
61 changes: 61 additions & 0 deletions tests/test_sm_constraint_gemm.py
@@ -0,0 +1,61 @@
import pytest
import torch

import flashinfer
import flashinfer.triton

def torch_gemm(a, b, c, alpha, beta):
x = torch.matmul(a, b.T)
c = alpha * x + beta * c
return c

def torch_addmm(a, b, c, alpha=1.0, beta=0.0):
# Transpose b to match torch_gemm's matmul(a, b.T)
C = torch.addmm(c, a, b.T, beta=beta, alpha=alpha)
return C

@pytest.mark.parametrize("M", [128, 256, 512, 1024, 8192])
@pytest.mark.parametrize("N", [128, 256, 512, 1024, 8192])
@pytest.mark.parametrize("K", [128, 256, 512, 1024, 8192])
@pytest.mark.parametrize("alpha", [1.0, 0.5, 2.0])
@pytest.mark.parametrize("beta", [0.0, 0.5, 2.0])
@pytest.mark.parametrize("num_sms", [1, 16, 64, 128, 132, 133])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.float16, torch.bfloat16, torch.float32])
def test_sm_constraint_gemm(M, N, K, alpha, beta, num_sms, dtype):
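    # Inputs are drawn in fp16 and then cast, since torch.randn does not support fp8 dtypes.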
a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
b = b.T.contiguous()
c = torch.randn((M, N), device="cuda", dtype=torch.float16).to(dtype)
c_clone = c.clone()
assert torch.allclose(c.to(torch.float16), c_clone.to(torch.float16))

    c_torch = (
        torch_gemm(a, b, c, alpha, beta)
        if dtype in (torch.float16, torch.float32, torch.bfloat16)
        else None  # no torch reference path for fp8 inputs
    )
    c_triton = flashinfer.triton.sm_constraint_gemm.gemm_persistent(
        a, b.T, c, alpha, beta, num_sms
    )
    c_naive = flashinfer.triton.sm_constraint_gemm.gemm(a, b.T, c_clone, alpha, beta)

cmp_dtype = torch.float16 if dtype == torch.float8_e4m3fn else dtype
torch_atol = 10.0 if dtype == torch.bfloat16 else 1.0

    in_place_triton_persistent = c_triton.data_ptr() == c.data_ptr() and torch.allclose(
        c_triton.to(cmp_dtype), c.to(cmp_dtype)
    )
    assert in_place_triton_persistent  # c was modified in place

    in_place_naive = c_naive.data_ptr() == c_clone.data_ptr() and torch.allclose(
        c_naive.to(cmp_dtype), c_clone.to(cmp_dtype)
    )
    assert in_place_naive  # c_clone was modified in place

    if c_torch is not None:
        torch_vs_triton = torch.allclose(
            c_torch.to(cmp_dtype), c_triton.to(cmp_dtype), atol=torch_atol
        )
        if not torch_vs_triton:
            print(f"c_torch: {c_torch}")
            print(f"c_triton: {c_triton}")
            max_diff = torch.max(torch.abs(c_torch.to(cmp_dtype) - c_triton.to(cmp_dtype)))
            print(f"max diff: {max_diff}")
        assert torch_vs_triton  # persistent kernel matches the torch reference

    triton_atol = 1.0
    naive_vs_persistent = torch.allclose(
        c_naive.to(cmp_dtype), c_triton.to(cmp_dtype), atol=triton_atol
    )
    if not naive_vs_persistent:
        if c_torch is not None:
            print(f"c_torch: {c_torch}")
        print(f"c_naive: {c_naive}")
        print(f"c_triton: {c_triton}")
        max_diff = torch.max(torch.abs(c_naive.to(cmp_dtype) - c_triton.to(cmp_dtype)))
        print(f"max diff: {max_diff}")
    assert naive_vs_persistent  # persistent and naive kernels agree
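
Commit d66d0a6 adds a benchmark that is not shown in this diff view. A minimal sketch of the measurement it implies, using triton.testing.do_bench (the function name, shapes, and SM sweep below are illustrative assumptions, not the PR's actual benchmark):

import torch
import triton
import flashinfer.triton

def bench_sm_constraint_gemm(M=4096, N=4096, K=4096, dtype=torch.float16):
    a = torch.randn((M, K), device="cuda", dtype=dtype)
    b = torch.randn((K, N), device="cuda", dtype=dtype)
    c = torch.zeros((M, N), device="cuda", dtype=dtype)
    total_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
    # Sweep the SM cap and report median latency and achieved throughput.
    for num_sms in (total_sms, total_sms // 2, total_sms // 4):
        ms = triton.testing.do_bench(
            lambda: flashinfer.triton.sm_constraint_gemm.gemm_persistent(
                a, b, c, num_sms=num_sms
            )
        )
        tflops = 2.0 * M * N * K / (ms * 1e-3) / 1e12
        print(f"num_sms={num_sms}: {ms:.3f} ms, {tflops:.1f} TFLOP/s")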