Commit 15fabcc

fix sgl-kernel unit tests (#5666)

1 parent e62c495 commit 15fabcc

9 files changed: +313 additions, -0 deletions

sgl-kernel/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -199,6 +199,7 @@ set(SOURCES
     "csrc/speculative/eagle_utils.cu"
     "csrc/speculative/speculative_sampling.cu"
     "csrc/speculative/packbit.cu"
+    "csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
     "csrc/common_extension.cc"
     "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
     "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"

sgl-kernel/csrc/common_extension.cc

File mode changed: 100755 → 100644
Lines changed: 6 additions & 0 deletions

@@ -233,6 +233,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "bool is_causal, float softcap, bool return_softmax, "
       "Generator? gen) -> Tensor[]");
   m.impl("varlen_fwd_sparse", torch::kCUDA, &flash::mha_varlen_fwd_sparse);
+
+  /*
+   * From XGrammar
+   */
+  m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
+  m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
 }

 REGISTER_EXTENSION(common_ops)
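
The schema registered above surfaces the kernel to Python as torch.ops.sgl_kernel.apply_token_bitmask_inplace_cuda; the trailing "-> ()" marks the op as returning nothing because it mutates logits in place, and "Tensor? indices=None" makes the row-selection argument optional. A minimal sketch of calling the raw op directly (illustrative only; the tensor values are made up, and the Python wrapper added later in this commit is the intended entry point):

import torch
import sgl_kernel  # importing the package loads the extension and registers the ops

logits = torch.randn(2, 32, dtype=torch.float32, device="cuda")
bitmask = torch.zeros(2, 1, dtype=torch.int32, device="cuda")  # every bit 0 -> every token masked
torch.ops.sgl_kernel.apply_token_bitmask_inplace_cuda(logits, bitmask, None)
# logits is now -inf everywhere, since no bit in the bitmask allows a token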

sgl-kernel/csrc/grammar/apply_token_bitmask_inplace_cuda.cu

Lines changed: 251 additions & 0 deletions

@@ -0,0 +1,251 @@
// Adapted from
// https://github.com/mlc-ai/xgrammar/blob/v0.1.18/python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.cu

/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// clang-format off
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
// clang-format on

#ifndef CUDART_INF_FP16
#define CUDART_INF_FP16 __ushort_as_half((unsigned short)0x7C00U)
#endif

#ifndef CUDART_INF_BF16
#define CUDART_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U)
#endif

constexpr int32_t BITS_PER_BLOCK = 32;
constexpr int32_t THREADS_PER_THREAD_BLOCK = 256;

template <typename T>
__device__ T NegativeInfinity() {
  return -INFINITY;
}

template <>
__device__ __half NegativeInfinity<__half>() {
  return -CUDART_INF_FP16;
}

template <>
__device__ __nv_bfloat16 NegativeInfinity<__nv_bfloat16>() {
  return -CUDART_INF_BF16;
}

template <typename T, typename PackedT>
__device__ PackedT PackedNegativeInfinity() {
  constexpr int kAlignment = sizeof(PackedT) / sizeof(T);
  T packed[kAlignment];
#pragma unroll
  for (int i = 0; i < kAlignment; i++) {
    packed[i] = NegativeInfinity<T>();
  }
  return *reinterpret_cast<PackedT*>(packed);
}

template <typename T, typename PackedT, int32_t kBitsPerThread>
__global__ void __launch_bounds__(THREADS_PER_THREAD_BLOCK) LogitsBitmaskKernel(
    T* __restrict__ logits,
    const int32_t* __restrict__ bitmask,
    const int32_t* __restrict__ indices,
    int32_t vocab_size,
    int32_t logits_stride,
    int32_t bitmask_stride) {
  constexpr int kAlignment = sizeof(PackedT) / sizeof(T);
  constexpr uint32_t kPackedMask = (1 << kAlignment) - 1;

  const int batch_idx = (indices == nullptr) ? blockIdx.y : indices[blockIdx.y];

  const int block_offset = blockIdx.x * THREADS_PER_THREAD_BLOCK * kBitsPerThread;
  T* logits_gmem_ptr = logits + batch_idx * logits_stride + block_offset;
  const int32_t* bitmask_gmem_ptr = bitmask + batch_idx * bitmask_stride + block_offset / BITS_PER_BLOCK;
  const int bitmask_inner_idx = threadIdx.x % (BITS_PER_BLOCK / kAlignment);
  T logits_reg[kAlignment];

#pragma unroll
  for (int offset = threadIdx.x * kAlignment; offset < THREADS_PER_THREAD_BLOCK * kBitsPerThread;
       offset += THREADS_PER_THREAD_BLOCK * kAlignment) {
    if (block_offset + offset >= vocab_size) {
      break;
    }

    const uint32_t bitmask_val =
        (~bitmask_gmem_ptr[offset / BITS_PER_BLOCK] >> (bitmask_inner_idx * kAlignment)) & kPackedMask;

    if (bitmask_val == 0) {
      continue;
    }

    if (bitmask_val == kPackedMask) {
      *reinterpret_cast<PackedT*>(logits_gmem_ptr + offset) = PackedNegativeInfinity<T, PackedT>();
      continue;
    }

    *reinterpret_cast<PackedT*>(logits_reg) = *reinterpret_cast<PackedT*>(logits_gmem_ptr + offset);
#pragma unroll
    for (int i = 0; i < kAlignment; i++) {
      if (((bitmask_val >> i) & 1)) {
        logits_reg[i] = NegativeInfinity<T>();
      }
    }
    *reinterpret_cast<PackedT*>(logits_gmem_ptr + offset) = *reinterpret_cast<PackedT*>(logits_reg);
  }
}

template <typename T, typename = std::enable_if_t<std::is_integral<T>::value>>
constexpr auto CeilDiv(T numerator, T denominator) {
  return (numerator + denominator - 1) / denominator;
}

template <typename T, typename PackedT>
void ApplyTokenBitmaskInplaceDispatchToBitsPerThread(
    T* __restrict__ logits,
    const int32_t* __restrict__ bitmask,
    const int32_t* __restrict__ indices,
    int32_t vocab_size,
    int32_t logits_stride,
    int32_t bitmask_stride,
    int32_t num_rows) {
  constexpr int kAlignment = sizeof(PackedT) / sizeof(T);
  const int32_t num_blocks_per_row = CeilDiv(2048 / THREADS_PER_THREAD_BLOCK * 128, num_rows);
  const int32_t num_bits_per_thread = CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * num_blocks_per_row);

  const dim3 block(THREADS_PER_THREAD_BLOCK);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

  if (num_bits_per_thread <= 4 && kAlignment <= 4) {
    const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 4), num_rows);
    LogitsBitmaskKernel<T, PackedT, 4>
        <<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
  } else if (num_bits_per_thread <= 8 && kAlignment <= 8) {
    const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 8), num_rows);
    LogitsBitmaskKernel<T, PackedT, 8>
        <<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
  } else if (num_bits_per_thread <= 16 && kAlignment <= 16) {
    const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 16), num_rows);
    LogitsBitmaskKernel<T, PackedT, 16>
        <<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
  } else {
    const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 32), num_rows);
    LogitsBitmaskKernel<T, PackedT, 32>
        <<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
  }
}

template <typename T>
void ApplyTokenBitmaskInplaceDispatchToPackedT(
    T* __restrict__ logits,
    const int32_t* __restrict__ bitmask,
    const int32_t* __restrict__ indices,
    int32_t vocab_size,
    int32_t logits_stride,
    int32_t bitmask_stride,
    int32_t num_rows) {
  if (logits_stride % (sizeof(float4) / sizeof(T)) == 0) {
    ApplyTokenBitmaskInplaceDispatchToBitsPerThread<T, float4>(
        logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows);
  } else {
    ApplyTokenBitmaskInplaceDispatchToBitsPerThread<T, T>(
        logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows);
  }
}

void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt) {
  TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor.");
  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous.");
  TORCH_CHECK(logits.dim() == 1 || logits.dim() == 2, "logits must be a 1D or 2D tensor.");
  std::pair<int32_t, int32_t> logits_shape =
      logits.dim() == 2 ? std::make_pair(static_cast<int32_t>(logits.size(0)), static_cast<int32_t>(logits.size(1)))
                        : std::make_pair(1, static_cast<int32_t>(logits.size(0)));

  TORCH_CHECK(bitmask.is_cuda(), "bitmask must be a CUDA tensor.");
  TORCH_CHECK(bitmask.is_contiguous(), "bitmask must be contiguous.");
  TORCH_CHECK(bitmask.dim() == 1 || bitmask.dim() == 2, "bitmask must be a 1D or 2D tensor.");
  std::pair<int32_t, int32_t> bitmask_shape =
      bitmask.dim() == 2 ? std::make_pair(static_cast<int32_t>(bitmask.size(0)), static_cast<int32_t>(bitmask.size(1)))
                         : std::make_pair(1, static_cast<int32_t>(bitmask.size(0)));

  TORCH_CHECK(bitmask.dtype() == torch::kInt32, "bitmask must be of type int32.");

  TORCH_CHECK(
      (logits_shape.second + BITS_PER_BLOCK - 1) / BITS_PER_BLOCK >= bitmask_shape.second,
      "The provided logits's vocab size should be no less than the bitmask's vocab size "
      "(converted from bitmask size). But got vocab size ",
      logits_shape.second,
      " vs bitmask size ",
      bitmask_shape.second);

  int vocab_size = std::min(logits_shape.second, bitmask_shape.second * BITS_PER_BLOCK);

  int32_t num_rows = logits_shape.first;
  int32_t* indices_ptr = nullptr;
  if (indices) {
    TORCH_CHECK(indices->is_cuda(), "indices must be a CUDA tensor.");
    TORCH_CHECK(indices->is_contiguous(), "indices must be contiguous.");
    TORCH_CHECK(indices->dim() == 1, "indices must be a 1D tensor.");
    TORCH_CHECK(indices->dtype() == torch::kInt32, "indices must be of type int32.");
    num_rows = indices->size(0);
    indices_ptr = indices->data_ptr<int32_t>();
  } else {
    TORCH_CHECK(logits_shape.first == bitmask_shape.first, "logits and bitmask must have the same batch size.");
  }

  switch (logits.scalar_type()) {
    case torch::kFloat32: {
      ApplyTokenBitmaskInplaceDispatchToPackedT(
          logits.data_ptr<float>(),
          bitmask.data_ptr<int32_t>(),
          indices_ptr,
          vocab_size,
          logits_shape.second,
          bitmask_shape.second,
          num_rows);
      break;
    }
    case torch::kFloat16: {
      ApplyTokenBitmaskInplaceDispatchToPackedT(
          reinterpret_cast<__half*>(logits.data_ptr<torch::Half>()),
          bitmask.data_ptr<int32_t>(),
          indices_ptr,
          vocab_size,
          logits_shape.second,
          bitmask_shape.second,
          num_rows);
      break;
    }
    case torch::kBFloat16: {
      ApplyTokenBitmaskInplaceDispatchToPackedT(
          reinterpret_cast<__nv_bfloat16*>(logits.data_ptr<torch::BFloat16>()),
          bitmask.data_ptr<int32_t>(),
          indices_ptr,
          vocab_size,
          logits_shape.second,
          bitmask_shape.second,
          num_rows);
      break;
    }
    default:
      TORCH_CHECK(false, "logits dtype must be float, half or bfloat16.");
      break;
  }
}
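
For readers checking the bit layout used above: bit j of 32-bit bitmask word k governs token k * 32 + j; a set bit keeps the logit and a cleared bit drives it to -inf (the kernel inverts the word, so set bits in the inverted value are the ones to mask). A minimal PyTorch reference of that semantics can help validate the kernel on small inputs; this sketch is not part of the commit and the function name is hypothetical:

import torch

def apply_token_bitmask_reference(logits: torch.Tensor, bitmask: torch.Tensor) -> torch.Tensor:
    # Expand each int32 word into its 32 bits, least-significant bit first.
    vocab_size = logits.shape[-1]
    shifts = torch.arange(32, device=bitmask.device)
    bits = (bitmask.unsqueeze(-1) >> shifts) & 1        # (..., num_words, 32)
    keep = bits.flatten(-2)[..., :vocab_size].bool()    # (..., vocab_size)
    neg_inf = torch.tensor(float("-inf"), dtype=logits.dtype, device=logits.device)
    return torch.where(keep, logits, neg_inf)

On the unit test's inputs further below, this reproduces torch.where(bool_mask, logits, float("-inf")).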

sgl-kernel/include/sgl_kernel_ops.h

File mode changed: 100755 → 100644
Lines changed: 5 additions & 0 deletions

@@ -352,3 +352,8 @@ std::vector<at::Tensor> mha_varlen_fwd_sparse(
     const bool return_softmax,
     c10::optional<at::Generator> gen_);
 }  // namespace flash
+
+/*
+ * From XGrammar
+ */
+void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt);

sgl-kernel/python/sgl_kernel/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -41,6 +41,7 @@
     sgl_per_token_group_quant_int8,
     sgl_per_token_quant_fp8,
 )
+from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
 from sgl_kernel.moe import (
     fp8_blockwise_scaled_grouped_mm,
     moe_align_block_size,

sgl-kernel/python/sgl_kernel/grammar.py

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
from typing import List, Optional, Union

import torch


def apply_token_bitmask_inplace_cuda(
    logits: torch.Tensor,
    bitmask: torch.Tensor,
    indices: Optional[Union[List[int], torch.Tensor]] = None,
) -> None:
    if isinstance(indices, list):
        indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
    if indices is not None:
        indices = indices.to(logits.device)
    torch.ops.sgl_kernel.apply_token_bitmask_inplace_cuda(logits, bitmask, indices)
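
A hedged usage sketch for this wrapper, showing the optional indices path in which only the listed batch rows are masked (the shapes and values below are illustrative, not from the commit):

import torch
from sgl_kernel import apply_token_bitmask_inplace_cuda

logits = torch.randn(4, 64, dtype=torch.float32, device="cuda")
bitmask = torch.full((4, 2), -1, dtype=torch.int32, device="cuda")  # all bits set: keep everything
bitmask[:, 0] = 0b1111  # in word 0, keep only tokens 0-3; word 1 still keeps tokens 32-63

# Rows 0 and 2 are masked in place; rows 1 and 3 are left untouched.
apply_token_bitmask_inplace_cuda(logits, bitmask, indices=[0, 2])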

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
import pytest
import torch
from sgl_kernel import apply_token_bitmask_inplace_cuda


def test_apply_token_bitmask_inplace_kernel():
    neginf = float("-inf")
    bool_mask = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], dtype=torch.bool)
    logits = torch.tensor(
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], dtype=torch.float32
    )
    expected = torch.where(bool_mask, logits, neginf)

    logits_gpu = logits.to("cuda")
    bitmask = torch.tensor([0b1010101010], dtype=torch.int32).to("cuda")
    apply_token_bitmask_inplace_cuda(logits_gpu, bitmask)
    torch.cuda.synchronize()
    torch.testing.assert_close(logits_gpu, expected.to("cuda"))


if __name__ == "__main__":
    test_apply_token_bitmask_inplace_kernel()
    pytest.main([__file__])
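
The test above covers float32 only; since the kernel also dispatches on half and bfloat16, a parametrized variant along these lines could extend it (a sketch, not part of the commit; it assumes the same imports as the test file):

@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_apply_token_bitmask_inplace_dtypes(dtype):
    bool_mask = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], dtype=torch.bool)
    logits = torch.arange(1, 11, dtype=dtype)
    expected = torch.where(bool_mask, logits, torch.tensor(float("-inf"), dtype=dtype))

    logits_gpu = logits.to("cuda")
    bitmask = torch.tensor([0b1010101010], dtype=torch.int32, device="cuda")
    apply_token_bitmask_inplace_cuda(logits_gpu, bitmask)
    torch.cuda.synchronize()
    torch.testing.assert_close(logits_gpu.cpu(), expected)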

sgl-kernel/tests/test_fp8_blockwise_moe.py

Lines changed: 10 additions & 0 deletions

@@ -47,6 +47,16 @@ def group_broadcast(t, shape):
     ).to(out_dtype)


+def is_sm100_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 10) and (
+        torch.version.cuda >= "12.8"
+    )
+
+
+@pytest.mark.skipif(
+    not is_sm100_supported(),
+    reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100",
+)
 @pytest.mark.parametrize("num_experts", [8, 16])
 @pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
 def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
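
One caveat in the gate above: torch.version.cuda >= "12.8" compares version strings lexicographically, so a hypothetical CUDA "12.10" would sort below "12.8". If that ever matters, a variant using the packaging library is a possible alternative (a sketch, assuming packaging is available in the test environment):

from packaging.version import Version

import torch


def is_sm100_supported(device=None) -> bool:
    return (
        torch.cuda.get_device_capability(device)[0] == 10
        and torch.version.cuda is not None
        and Version(torch.version.cuda) >= Version("12.8")
    )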

sgl-kernel/tests/test_moe_fused_gate.py

Lines changed: 1 addition & 0 deletions

@@ -48,6 +48,7 @@ def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusi
         topk_group=topk_group,
         compiled=False,
         n_share_experts_fusion=n_share_experts_fusion,
+        routed_scaling_factor=2.5,
     )

     # When n_share_experts_fusion > 0, ignore the comparison of the last topk dimension
