|
| 1 | +// Adapted from |
| 2 | +// https://github.com/mlc-ai/xgrammar/blob/v0.1.18/python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.cu |
| 3 | + |
| 4 | +/* |
| 5 | + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 6 | + * SPDX-License-Identifier: Apache-2.0 |
| 7 | + * |
| 8 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 9 | + * you may not use this file except in compliance with the License. |
| 10 | + * You may obtain a copy of the License at |
| 11 | + * |
| 12 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 13 | + * |
| 14 | + * Unless required by applicable law or agreed to in writing, software |
| 15 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 16 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 17 | + * See the License for the specific language governing permissions and |
| 18 | + * limitations under the License. |
| 19 | + */ |
| 20 | + |
| 21 | +// clang-format off |
| 22 | +#include <cuda_bf16.h> |
| 23 | +#include <cuda_fp16.h> |
| 24 | +#include <cuda_runtime.h> |
| 25 | +#include <torch/all.h> |
| 26 | +#include <ATen/cuda/CUDAContext.h> |
| 27 | +// clang-format on |
| 28 | + |
| 29 | +#ifndef CUDART_INF_FP16 |
| 30 | +#define CUDART_INF_FP16 __ushort_as_half((unsigned short)0x7C00U) |
| 31 | +#endif |
| 32 | + |
| 33 | +#ifndef CUDART_INF_BF16 |
| 34 | +#define CUDART_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U) |
| 35 | +#endif |
| 36 | + |
| 37 | +constexpr int32_t BITS_PER_BLOCK = 32; |
| 38 | +constexpr int32_t THREADS_PER_THREAD_BLOCK = 256; |
| 39 | + |
| 40 | +template <typename T> |
| 41 | +__device__ T NegativeInfinity() { |
| 42 | + return -INFINITY; |
| 43 | +} |
| 44 | + |
| 45 | +template <> |
| 46 | +__device__ __half NegativeInfinity<__half>() { |
| 47 | + return -CUDART_INF_FP16; |
| 48 | +} |
| 49 | + |
| 50 | +template <> |
| 51 | +__device__ __nv_bfloat16 NegativeInfinity<__nv_bfloat16>() { |
| 52 | + return -CUDART_INF_BF16; |
| 53 | +} |
| 54 | + |
| 55 | +template <typename T, typename PackedT> |
| 56 | +__device__ PackedT PackedNegativeInfinity() { |
| 57 | + constexpr int kAlignment = sizeof(PackedT) / sizeof(T); |
| 58 | + T packed[kAlignment]; |
| 59 | +#pragma unroll |
| 60 | + for (int i = 0; i < kAlignment; i++) { |
| 61 | + packed[i] = NegativeInfinity<T>(); |
| 62 | + } |
| 63 | + return *reinterpret_cast<PackedT*>(packed); |
| 64 | +} |
| 65 | + |
| 66 | +template <typename T, typename PackedT, int32_t kBitsPerThread> |
| 67 | +__global__ void __launch_bounds__(THREADS_PER_THREAD_BLOCK) LogitsBitmaskKernel( |
| 68 | + T* __restrict__ logits, |
| 69 | + const int32_t* __restrict__ bitmask, |
| 70 | + const int32_t* __restrict__ indices, |
| 71 | + int32_t vocab_size, |
| 72 | + int32_t logits_stride, |
| 73 | + int32_t bitmask_stride) { |
| 74 | + constexpr int kAlignment = sizeof(PackedT) / sizeof(T); |
| 75 | + constexpr uint32_t kPackedMask = (1 << kAlignment) - 1; |
| 76 | + |
| 77 | + const int batch_idx = (indices == nullptr) ? blockIdx.y : indices[blockIdx.y]; |
| 78 | + |
| 79 | + const int block_offset = blockIdx.x * THREADS_PER_THREAD_BLOCK * kBitsPerThread; |
| 80 | + T* logits_gmem_ptr = logits + batch_idx * logits_stride + block_offset; |
| 81 | + const int32_t* bitmask_gmem_ptr = bitmask + batch_idx * bitmask_stride + block_offset / BITS_PER_BLOCK; |
| 82 | + const int bitmask_inner_idx = threadIdx.x % (BITS_PER_BLOCK / kAlignment); |
| 83 | + T logits_reg[kAlignment]; |
| 84 | + |
| 85 | +#pragma unroll |
| 86 | + for (int offset = threadIdx.x * kAlignment; offset < THREADS_PER_THREAD_BLOCK * kBitsPerThread; |
| 87 | + offset += THREADS_PER_THREAD_BLOCK * kAlignment) { |
| 88 | + if (block_offset + offset >= vocab_size) { |
| 89 | + break; |
| 90 | + } |
| 91 | + |
| 92 | + const uint32_t bitmask_val = |
| 93 | + (~bitmask_gmem_ptr[offset / BITS_PER_BLOCK] >> (bitmask_inner_idx * kAlignment)) & kPackedMask; |
| 94 | + |
| 95 | + if (bitmask_val == 0) { |
| 96 | + continue; |
| 97 | + } |
| 98 | + |
| 99 | + if (bitmask_val == kPackedMask) { |
| 100 | + *reinterpret_cast<PackedT*>(logits_gmem_ptr + offset) = PackedNegativeInfinity<T, PackedT>(); |
| 101 | + continue; |
| 102 | + } |
| 103 | + |
| 104 | + *reinterpret_cast<PackedT*>(logits_reg) = *reinterpret_cast<PackedT*>(logits_gmem_ptr + offset); |
| 105 | +#pragma unroll |
| 106 | + for (int i = 0; i < kAlignment; i++) { |
| 107 | + if (((bitmask_val >> i) & 1)) { |
| 108 | + logits_reg[i] = NegativeInfinity<T>(); |
| 109 | + } |
| 110 | + } |
| 111 | + *reinterpret_cast<PackedT*>(logits_gmem_ptr + offset) = *reinterpret_cast<PackedT*>(logits_reg); |
| 112 | + } |
| 113 | +} |
| 114 | + |
| 115 | +template <typename T, typename = std::enable_if_t<std::is_integral<T>::value>> |
| 116 | +constexpr auto CeilDiv(T numerator, T denominator) { |
| 117 | + return (numerator + denominator - 1) / denominator; |
| 118 | +} |
| 119 | + |
| 120 | +template <typename T, typename PackedT> |
| 121 | +void ApplyTokenBitmaskInplaceDispatchToBitsPerThread( |
| 122 | + T* __restrict__ logits, |
| 123 | + const int32_t* __restrict__ bitmask, |
| 124 | + const int32_t* __restrict__ indices, |
| 125 | + int32_t vocab_size, |
| 126 | + int32_t logits_stride, |
| 127 | + int32_t bitmask_stride, |
| 128 | + int32_t num_rows) { |
| 129 | + constexpr int kAlignment = sizeof(PackedT) / sizeof(T); |
| 130 | + const int32_t num_blocks_per_row = CeilDiv(2048 / THREADS_PER_THREAD_BLOCK * 128, num_rows); |
| 131 | + const int32_t num_bits_per_thread = CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * num_blocks_per_row); |
| 132 | + |
| 133 | + const dim3 block(THREADS_PER_THREAD_BLOCK); |
| 134 | + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); |
| 135 | + |
| 136 | + if (num_bits_per_thread <= 4 && kAlignment <= 4) { |
| 137 | + const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 4), num_rows); |
| 138 | + LogitsBitmaskKernel<T, PackedT, 4> |
| 139 | + <<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); |
| 140 | + } else if (num_bits_per_thread <= 8 && kAlignment <= 8) { |
| 141 | + const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 8), num_rows); |
| 142 | + LogitsBitmaskKernel<T, PackedT, 8> |
| 143 | + <<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); |
| 144 | + } else if (num_bits_per_thread <= 16 && kAlignment <= 16) { |
| 145 | + const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 16), num_rows); |
| 146 | + LogitsBitmaskKernel<T, PackedT, 16> |
| 147 | + <<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); |
| 148 | + } else { |
| 149 | + const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 32), num_rows); |
| 150 | + LogitsBitmaskKernel<T, PackedT, 32> |
| 151 | + <<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); |
| 152 | + } |
| 153 | +} |
| 154 | + |
| 155 | +template <typename T> |
| 156 | +void ApplyTokenBitmaskInplaceDispatchToPackedT( |
| 157 | + T* __restrict__ logits, |
| 158 | + const int32_t* __restrict__ bitmask, |
| 159 | + const int32_t* __restrict__ indices, |
| 160 | + int32_t vocab_size, |
| 161 | + int32_t logits_stride, |
| 162 | + int32_t bitmask_stride, |
| 163 | + int32_t num_rows) { |
| 164 | + if (logits_stride % (sizeof(float4) / sizeof(T)) == 0) { |
| 165 | + ApplyTokenBitmaskInplaceDispatchToBitsPerThread<T, float4>( |
| 166 | + logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows); |
| 167 | + } else { |
| 168 | + ApplyTokenBitmaskInplaceDispatchToBitsPerThread<T, T>( |
| 169 | + logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows); |
| 170 | + } |
| 171 | +} |
| 172 | + |
| 173 | +void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt) { |
| 174 | + TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor."); |
| 175 | + TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous."); |
| 176 | + TORCH_CHECK(logits.dim() == 1 || logits.dim() == 2, "logits must be a 1D or 2D tensor."); |
| 177 | + std::pair<int32_t, int32_t> logits_shape = |
| 178 | + logits.dim() == 2 ? std::make_pair(static_cast<int32_t>(logits.size(0)), static_cast<int32_t>(logits.size(1))) |
| 179 | + : std::make_pair(1, static_cast<int32_t>(logits.size(0))); |
| 180 | + |
| 181 | + TORCH_CHECK(bitmask.is_cuda(), "bitmask must be a CUDA tensor."); |
| 182 | + TORCH_CHECK(bitmask.is_contiguous(), "bitmask must be contiguous."); |
| 183 | + TORCH_CHECK(bitmask.dim() == 1 || bitmask.dim() == 2, "bitmask must be a 1D or 2D tensor."); |
| 184 | + std::pair<int32_t, int32_t> bitmask_shape = |
| 185 | + bitmask.dim() == 2 ? std::make_pair(static_cast<int32_t>(bitmask.size(0)), static_cast<int32_t>(bitmask.size(1))) |
| 186 | + : std::make_pair(1, static_cast<int32_t>(bitmask.size(0))); |
| 187 | + |
| 188 | + TORCH_CHECK(bitmask.dtype() == torch::kInt32, "bitmask must be of type int32."); |
| 189 | + |
| 190 | + TORCH_CHECK( |
| 191 | + (logits_shape.second + BITS_PER_BLOCK - 1) / BITS_PER_BLOCK >= bitmask_shape.second, |
| 192 | + "The provided logits's vocab size should be no less than the bitmask's vocab size " |
| 193 | + "(converted from bitmask size). But got vocab size ", |
| 194 | + logits_shape.second, |
| 195 | + " vs bitmask size ", |
| 196 | + bitmask_shape.second); |
| 197 | + |
| 198 | + int vocab_size = std::min(logits_shape.second, bitmask_shape.second * BITS_PER_BLOCK); |
| 199 | + |
| 200 | + int32_t num_rows = logits_shape.first; |
| 201 | + int32_t* indices_ptr = nullptr; |
| 202 | + if (indices) { |
| 203 | + TORCH_CHECK(indices->is_cuda(), "indices must be a CUDA tensor."); |
| 204 | + TORCH_CHECK(indices->is_contiguous(), "indices must be contiguous."); |
| 205 | + TORCH_CHECK(indices->dim() == 1, "indices must be a 1D tensor."); |
| 206 | + TORCH_CHECK(indices->dtype() == torch::kInt32, "indices must be of type int32."); |
| 207 | + num_rows = indices->size(0); |
| 208 | + indices_ptr = indices->data_ptr<int32_t>(); |
| 209 | + } else { |
| 210 | + TORCH_CHECK(logits_shape.first == bitmask_shape.first, "logits and bitmask must have the same batch size."); |
| 211 | + } |
| 212 | + |
| 213 | + switch (logits.scalar_type()) { |
| 214 | + case torch::kFloat32: { |
| 215 | + ApplyTokenBitmaskInplaceDispatchToPackedT( |
| 216 | + logits.data_ptr<float>(), |
| 217 | + bitmask.data_ptr<int32_t>(), |
| 218 | + indices_ptr, |
| 219 | + vocab_size, |
| 220 | + logits_shape.second, |
| 221 | + bitmask_shape.second, |
| 222 | + num_rows); |
| 223 | + break; |
| 224 | + } |
| 225 | + case torch::kFloat16: { |
| 226 | + ApplyTokenBitmaskInplaceDispatchToPackedT( |
| 227 | + reinterpret_cast<__half*>(logits.data_ptr<torch::Half>()), |
| 228 | + bitmask.data_ptr<int32_t>(), |
| 229 | + indices_ptr, |
| 230 | + vocab_size, |
| 231 | + logits_shape.second, |
| 232 | + bitmask_shape.second, |
| 233 | + num_rows); |
| 234 | + break; |
| 235 | + } |
| 236 | + case torch::kBFloat16: { |
| 237 | + ApplyTokenBitmaskInplaceDispatchToPackedT( |
| 238 | + reinterpret_cast<__nv_bfloat16*>(logits.data_ptr<torch::BFloat16>()), |
| 239 | + bitmask.data_ptr<int32_t>(), |
| 240 | + indices_ptr, |
| 241 | + vocab_size, |
| 242 | + logits_shape.second, |
| 243 | + bitmask_shape.second, |
| 244 | + num_rows); |
| 245 | + break; |
| 246 | + } |
| 247 | + default: |
| 248 | + TORCH_CHECK(false, "logits dtype must be float, half or bfloat16."); |
| 249 | + break; |
| 250 | + } |
| 251 | +} |
0 commit comments