Skip to content

fix sgl-kernel unit tests #5666

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sgl-kernel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ set(SOURCES
"csrc/speculative/eagle_utils.cu"
"csrc/speculative/speculative_sampling.cu"
"csrc/speculative/packbit.cu"
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
"csrc/common_extension.cc"
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
Expand Down
6 changes: 6 additions & 0 deletions sgl-kernel/csrc/common_extension.cc
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"bool is_causal, float softcap, bool return_softmax, "
"Generator? gen) -> Tensor[]");
m.impl("varlen_fwd_sparse", torch::kCUDA, &flash::mha_varlen_fwd_sparse);

/*
* From XGrammar
*/
m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
}

REGISTER_EXTENSION(common_ops)
251 changes: 251 additions & 0 deletions sgl-kernel/csrc/grammar/apply_token_bitmask_inplace_cuda.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
// Adapted from
// https://github.com/mlc-ai/xgrammar/blob/v0.1.18/python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.cu

/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// clang-format off
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
// clang-format on

#ifndef CUDART_INF_FP16
#define CUDART_INF_FP16 __ushort_as_half((unsigned short)0x7C00U)
#endif

#ifndef CUDART_INF_BF16
#define CUDART_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U)
#endif

constexpr int32_t BITS_PER_BLOCK = 32;
constexpr int32_t THREADS_PER_THREAD_BLOCK = 256;

template <typename T>
__device__ T NegativeInfinity() {
return -INFINITY;
}

template <>
__device__ __half NegativeInfinity<__half>() {
return -CUDART_INF_FP16;
}

template <>
__device__ __nv_bfloat16 NegativeInfinity<__nv_bfloat16>() {
return -CUDART_INF_BF16;
}

template <typename T, typename PackedT>
__device__ PackedT PackedNegativeInfinity() {
constexpr int kAlignment = sizeof(PackedT) / sizeof(T);
T packed[kAlignment];
#pragma unroll
for (int i = 0; i < kAlignment; i++) {
packed[i] = NegativeInfinity<T>();
}
return *reinterpret_cast<PackedT*>(packed);
}

template <typename T, typename PackedT, int32_t kBitsPerThread>
__global__ void __launch_bounds__(THREADS_PER_THREAD_BLOCK) LogitsBitmaskKernel(
T* __restrict__ logits,
const int32_t* __restrict__ bitmask,
const int32_t* __restrict__ indices,
int32_t vocab_size,
int32_t logits_stride,
int32_t bitmask_stride) {
constexpr int kAlignment = sizeof(PackedT) / sizeof(T);
constexpr uint32_t kPackedMask = (1 << kAlignment) - 1;

const int batch_idx = (indices == nullptr) ? blockIdx.y : indices[blockIdx.y];

const int block_offset = blockIdx.x * THREADS_PER_THREAD_BLOCK * kBitsPerThread;
T* logits_gmem_ptr = logits + batch_idx * logits_stride + block_offset;
const int32_t* bitmask_gmem_ptr = bitmask + batch_idx * bitmask_stride + block_offset / BITS_PER_BLOCK;
const int bitmask_inner_idx = threadIdx.x % (BITS_PER_BLOCK / kAlignment);
T logits_reg[kAlignment];

#pragma unroll
for (int offset = threadIdx.x * kAlignment; offset < THREADS_PER_THREAD_BLOCK * kBitsPerThread;
offset += THREADS_PER_THREAD_BLOCK * kAlignment) {
if (block_offset + offset >= vocab_size) {
break;
}

const uint32_t bitmask_val =
(~bitmask_gmem_ptr[offset / BITS_PER_BLOCK] >> (bitmask_inner_idx * kAlignment)) & kPackedMask;

if (bitmask_val == 0) {
continue;
}

if (bitmask_val == kPackedMask) {
*reinterpret_cast<PackedT*>(logits_gmem_ptr + offset) = PackedNegativeInfinity<T, PackedT>();
continue;
}

*reinterpret_cast<PackedT*>(logits_reg) = *reinterpret_cast<PackedT*>(logits_gmem_ptr + offset);
#pragma unroll
for (int i = 0; i < kAlignment; i++) {
if (((bitmask_val >> i) & 1)) {
logits_reg[i] = NegativeInfinity<T>();
}
}
*reinterpret_cast<PackedT*>(logits_gmem_ptr + offset) = *reinterpret_cast<PackedT*>(logits_reg);
}
}

template <typename T, typename = std::enable_if_t<std::is_integral<T>::value>>
constexpr auto CeilDiv(T numerator, T denominator) {
return (numerator + denominator - 1) / denominator;
}

template <typename T, typename PackedT>
void ApplyTokenBitmaskInplaceDispatchToBitsPerThread(
T* __restrict__ logits,
const int32_t* __restrict__ bitmask,
const int32_t* __restrict__ indices,
int32_t vocab_size,
int32_t logits_stride,
int32_t bitmask_stride,
int32_t num_rows) {
constexpr int kAlignment = sizeof(PackedT) / sizeof(T);
const int32_t num_blocks_per_row = CeilDiv(2048 / THREADS_PER_THREAD_BLOCK * 128, num_rows);
const int32_t num_bits_per_thread = CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * num_blocks_per_row);

const dim3 block(THREADS_PER_THREAD_BLOCK);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

if (num_bits_per_thread <= 4 && kAlignment <= 4) {
const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 4), num_rows);
LogitsBitmaskKernel<T, PackedT, 4>
<<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
} else if (num_bits_per_thread <= 8 && kAlignment <= 8) {
const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 8), num_rows);
LogitsBitmaskKernel<T, PackedT, 8>
<<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
} else if (num_bits_per_thread <= 16 && kAlignment <= 16) {
const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 16), num_rows);
LogitsBitmaskKernel<T, PackedT, 16>
<<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
} else {
const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 32), num_rows);
LogitsBitmaskKernel<T, PackedT, 32>
<<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
}
}

template <typename T>
void ApplyTokenBitmaskInplaceDispatchToPackedT(
T* __restrict__ logits,
const int32_t* __restrict__ bitmask,
const int32_t* __restrict__ indices,
int32_t vocab_size,
int32_t logits_stride,
int32_t bitmask_stride,
int32_t num_rows) {
if (logits_stride % (sizeof(float4) / sizeof(T)) == 0) {
ApplyTokenBitmaskInplaceDispatchToBitsPerThread<T, float4>(
logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows);
} else {
ApplyTokenBitmaskInplaceDispatchToBitsPerThread<T, T>(
logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows);
}
}

void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt) {
TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor.");
TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous.");
TORCH_CHECK(logits.dim() == 1 || logits.dim() == 2, "logits must be a 1D or 2D tensor.");
std::pair<int32_t, int32_t> logits_shape =
logits.dim() == 2 ? std::make_pair(static_cast<int32_t>(logits.size(0)), static_cast<int32_t>(logits.size(1)))
: std::make_pair(1, static_cast<int32_t>(logits.size(0)));

TORCH_CHECK(bitmask.is_cuda(), "bitmask must be a CUDA tensor.");
TORCH_CHECK(bitmask.is_contiguous(), "bitmask must be contiguous.");
TORCH_CHECK(bitmask.dim() == 1 || bitmask.dim() == 2, "bitmask must be a 1D or 2D tensor.");
std::pair<int32_t, int32_t> bitmask_shape =
bitmask.dim() == 2 ? std::make_pair(static_cast<int32_t>(bitmask.size(0)), static_cast<int32_t>(bitmask.size(1)))
: std::make_pair(1, static_cast<int32_t>(bitmask.size(0)));

TORCH_CHECK(bitmask.dtype() == torch::kInt32, "bitmask must be of type int32.");

TORCH_CHECK(
(logits_shape.second + BITS_PER_BLOCK - 1) / BITS_PER_BLOCK >= bitmask_shape.second,
"The provided logits's vocab size should be no less than the bitmask's vocab size "
"(converted from bitmask size). But got vocab size ",
logits_shape.second,
" vs bitmask size ",
bitmask_shape.second);

int vocab_size = std::min(logits_shape.second, bitmask_shape.second * BITS_PER_BLOCK);

int32_t num_rows = logits_shape.first;
int32_t* indices_ptr = nullptr;
if (indices) {
TORCH_CHECK(indices->is_cuda(), "indices must be a CUDA tensor.");
TORCH_CHECK(indices->is_contiguous(), "indices must be contiguous.");
TORCH_CHECK(indices->dim() == 1, "indices must be a 1D tensor.");
TORCH_CHECK(indices->dtype() == torch::kInt32, "indices must be of type int32.");
num_rows = indices->size(0);
indices_ptr = indices->data_ptr<int32_t>();
} else {
TORCH_CHECK(logits_shape.first == bitmask_shape.first, "logits and bitmask must have the same batch size.");
}

switch (logits.scalar_type()) {
case torch::kFloat32: {
ApplyTokenBitmaskInplaceDispatchToPackedT(
logits.data_ptr<float>(),
bitmask.data_ptr<int32_t>(),
indices_ptr,
vocab_size,
logits_shape.second,
bitmask_shape.second,
num_rows);
break;
}
case torch::kFloat16: {
ApplyTokenBitmaskInplaceDispatchToPackedT(
reinterpret_cast<__half*>(logits.data_ptr<torch::Half>()),
bitmask.data_ptr<int32_t>(),
indices_ptr,
vocab_size,
logits_shape.second,
bitmask_shape.second,
num_rows);
break;
}
case torch::kBFloat16: {
ApplyTokenBitmaskInplaceDispatchToPackedT(
reinterpret_cast<__nv_bfloat16*>(logits.data_ptr<torch::BFloat16>()),
bitmask.data_ptr<int32_t>(),
indices_ptr,
vocab_size,
logits_shape.second,
bitmask_shape.second,
num_rows);
break;
}
default:
TORCH_CHECK(false, "logits dtype must be float, half or bfloat16.");
break;
}
}
5 changes: 5 additions & 0 deletions sgl-kernel/include/sgl_kernel_ops.h
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -352,3 +352,8 @@ std::vector<at::Tensor> mha_varlen_fwd_sparse(
const bool return_softmax,
c10::optional<at::Generator> gen_);
} // namespace flash

/*
* From XGrammar
*/
void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt);
1 change: 1 addition & 0 deletions sgl-kernel/python/sgl_kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
sgl_per_token_group_quant_int8,
sgl_per_token_quant_fp8,
)
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.moe import (
fp8_blockwise_scaled_grouped_mm,
moe_align_block_size,
Expand Down
15 changes: 15 additions & 0 deletions sgl-kernel/python/sgl_kernel/grammar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import List, Optional, Union

import torch


def apply_token_bitmask_inplace_cuda(
logits: torch.Tensor,
bitmask: torch.Tensor,
indices: Optional[Union[List[int], torch.Tensor]] = None,
) -> None:
if isinstance(indices, list):
indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
if indices is not None:
indices = indices.to(logits.device)
torch.ops.sgl_kernel.apply_token_bitmask_inplace_cuda(logits, bitmask, indices)
23 changes: 23 additions & 0 deletions sgl-kernel/tests/test_apply_token_bitmask_inplace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest
import torch
from sgl_kernel import apply_token_bitmask_inplace_cuda


def test_apply_token_bitmask_inplace_kernel():
neginf = float("-inf")
bool_mask = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], dtype=torch.bool)
logits = torch.tensor(
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], dtype=torch.float32
)
expected = torch.where(bool_mask, logits, neginf)

logits_gpu = logits.to("cuda")
bitmask = torch.tensor([0b1010101010], dtype=torch.int32).to("cuda")
apply_token_bitmask_inplace_cuda(logits_gpu, bitmask)
torch.cuda.synchronize()
torch.testing.assert_close(logits_gpu, expected.to("cuda"))


if __name__ == "__main__":
test_apply_token_bitmask_inplace_kernel()
pytest.main([__file__])
10 changes: 10 additions & 0 deletions sgl-kernel/tests/test_fp8_blockwise_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ def group_broadcast(t, shape):
).to(out_dtype)


def is_sm100_supported(device=None) -> bool:
return (torch.cuda.get_device_capability(device)[0] == 10) and (
torch.version.cuda >= "12.8"
)


@pytest.mark.skipif(
not is_sm100_supported(),
reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100",
)
@pytest.mark.parametrize("num_experts", [8, 16])
@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
Expand Down
1 change: 1 addition & 0 deletions sgl-kernel/tests/test_moe_fused_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusi
topk_group=topk_group,
compiled=False,
n_share_experts_fusion=n_share_experts_fusion,
routed_scaling_factor=2.5,
)

# When n_share_experts_fusion > 0, ignore the comparison of the last topk dimension
Expand Down
Loading