Skip to content

Commit 4b27040

Browse files
authored
misc: make max_top_p/k_rounds an input argument instead of a template parameter (#219)
max_top_p_rounds and max_top_k_rounds don't have to be template parameters.
1 parent f978e02 commit 4b27040

File tree

4 files changed

+47
-66
lines changed

4 files changed

+47
-66
lines changed

include/flashinfer/sampling.cuh

+27-24
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,11 @@ __global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples, IdT
146146
output[bx] = (aggregate > u) ? temp_storage.data.sampled_id : d - 1;
147147
}
148148

149-
template <uint32_t MAX_TOP_K_ROUNDS, uint32_t BLOCK_THREADS, BlockScanAlgorithm ALGORITHM,
150-
uint32_t VEC_SIZE, typename DType, typename IdType>
149+
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm ALGORITHM, uint32_t VEC_SIZE, typename DType,
150+
typename IdType>
151151
__global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* output,
152-
bool* success, uint32_t k, uint32_t d) {
152+
bool* success, uint32_t k, uint32_t d,
153+
uint32_t max_top_k_rounds) {
153154
const uint32_t batch_size = gridDim.x;
154155
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
155156

@@ -163,7 +164,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
163164
DType q = DType(0);
164165
DType pivot = DType(0);
165166
IdType sampled_id;
166-
for (uint32_t round = 0; round < MAX_TOP_K_ROUNDS; ++round) {
167+
for (uint32_t round = 0; round < max_top_k_rounds; ++round) {
167168
DType u = uniform_samples[round * batch_size + bx] * (1 - q);
168169
aggregate = DType(0);
169170
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
@@ -230,11 +231,11 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
230231

231232
constexpr float eps = 1e-5;
232233

233-
template <uint32_t MAX_TOP_P_ROUNDS, uint32_t BLOCK_THREADS, BlockScanAlgorithm ALGORITHM,
234-
uint32_t VEC_SIZE, typename DType, typename IdType>
234+
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm ALGORITHM, uint32_t VEC_SIZE, typename DType,
235+
typename IdType>
235236
__global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* output,
236237
bool* success, IdType* row_indices, float* top_p_arr,
237-
float top_p, uint32_t d) {
238+
float top_p, uint32_t d, uint32_t max_top_p_rounds) {
238239
const uint32_t batch_size = gridDim.x;
239240
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
240241

@@ -253,7 +254,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
253254
DType q = DType(0);
254255
DType pivot = DType(0);
255256
IdType sampled_id;
256-
for (uint32_t round = 0; round < MAX_TOP_P_ROUNDS; ++round) {
257+
for (uint32_t round = 0; round < max_top_p_rounds; ++round) {
257258
DType u = uniform_samples[round * batch_size + bx] * (1 - q);
258259
aggregate = DType(0);
259260
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
@@ -356,33 +357,33 @@ cudaError_t ParallelSamplingFromProb(T* probs, T* uniform_samples, IdType* outpu
356357
return cudaSuccess;
357358
}
358359

359-
template <uint32_t MAX_TOP_K_ROUNDS, typename T, typename IdType>
360+
template <typename T, typename IdType>
360361
cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success,
361362
IdType top_k, uint32_t batch_size, uint32_t d,
362-
cudaStream_t stream = 0) {
363+
uint32_t max_top_k_rounds, cudaStream_t stream = 0) {
363364
constexpr uint32_t BLOCK_THREADS = 1024;
364365
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
365366

366367
const uint32_t smem_size =
367368
sizeof(SamplingTempStorage<T, BLOCK_THREADS, BLOCK_SCAN_RAKING_MEMOIZE>);
368369
dim3 nblks(batch_size);
369370
dim3 nthrs(BLOCK_THREADS);
370-
void* args[] = {&probs, &uniform_samples, &output, &success, &top_k, &d};
371+
void* args[] = {&probs, &uniform_samples, &output, &success, &top_k, &d, &max_top_k_rounds};
371372

372373
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
373-
auto kernel = TopKSamplingFromProbKernel<MAX_TOP_K_ROUNDS, BLOCK_THREADS,
374-
BLOCK_SCAN_RAKING_MEMOIZE, VEC_SIZE, T, IdType>;
374+
auto kernel =
375+
TopKSamplingFromProbKernel<BLOCK_THREADS, BLOCK_SCAN_RAKING_MEMOIZE, VEC_SIZE, T, IdType>;
375376
FLASHINFER_CUDA_CALL(
376377
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
377378
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
378379
});
379380
return cudaSuccess;
380381
}
381382

382-
template <uint32_t MAX_TOP_P_ROUNDS, typename T, typename IdType>
383+
template <typename T, typename IdType>
383384
cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success,
384385
T top_p, uint32_t batch_size, uint32_t d,
385-
cudaStream_t stream = 0) {
386+
uint32_t max_top_p_rounds, cudaStream_t stream = 0) {
386387
constexpr uint32_t BLOCK_THREADS = 1024;
387388
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
388389

@@ -399,22 +400,24 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b
399400
&row_indices_placeholder,
400401
&top_p_arr_placeholder,
401402
&top_p,
402-
&d};
403+
&d,
404+
&max_top_p_rounds};
403405

404406
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
405-
auto kernel = TopPSamplingFromProbKernel<MAX_TOP_P_ROUNDS, BLOCK_THREADS,
406-
BLOCK_SCAN_RAKING_MEMOIZE, VEC_SIZE, T, IdType>;
407+
auto kernel =
408+
TopPSamplingFromProbKernel<BLOCK_THREADS, BLOCK_SCAN_RAKING_MEMOIZE, VEC_SIZE, T, IdType>;
407409
FLASHINFER_CUDA_CALL(
408410
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
409411
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
410412
});
411413
return cudaSuccess;
412414
}
413415

414-
template <uint32_t MAX_TOP_P_ROUNDS, typename T, typename IdType>
416+
template <typename T, typename IdType>
415417
cudaError_t ParallelTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output,
416418
bool* success, IdType* row_indices, T* top_p_arr,
417-
uint32_t batch_size, uint32_t d, cudaStream_t stream = 0) {
419+
uint32_t batch_size, uint32_t d, uint32_t max_top_p_rounds,
420+
cudaStream_t stream = 0) {
418421
constexpr uint32_t BLOCK_THREADS = 1024;
419422
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
420423

@@ -423,12 +426,12 @@ cudaError_t ParallelTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* o
423426
dim3 nblks(batch_size);
424427
dim3 nthrs(BLOCK_THREADS);
425428
T top_p_placeholder = 0;
426-
void* args[] = {&probs, &uniform_samples, &output, &success, &row_indices,
427-
&top_p_arr, &top_p_placeholder, &d};
429+
void* args[] = {&probs, &uniform_samples, &output, &success, &row_indices,
430+
&top_p_arr, &top_p_placeholder, &d, &max_top_p_rounds};
428431

429432
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
430-
auto kernel = TopPSamplingFromProbKernel<MAX_TOP_P_ROUNDS, BLOCK_THREADS,
431-
BLOCK_SCAN_RAKING_MEMOIZE, VEC_SIZE, T, IdType>;
433+
auto kernel =
434+
TopPSamplingFromProbKernel<BLOCK_THREADS, BLOCK_SCAN_RAKING_MEMOIZE, VEC_SIZE, T, IdType>;
432435
FLASHINFER_CUDA_CALL(
433436
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
434437
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));

src/bench_sampling.cu

+4-4
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,11 @@ void bench_top_p_sampling_with_probability(nvbench::state& state) {
9696

9797
state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) {
9898
timer.start();
99-
cudaError_t status = sampling::TopPSamplingFromProb<max_top_p_rounds, T, int32_t>(
99+
cudaError_t status = sampling::TopPSamplingFromProb<T, int32_t>(
100100
thrust::raw_pointer_cast(probs_d.data()),
101101
thrust::raw_pointer_cast(uniform_samples_d.data()),
102102
thrust::raw_pointer_cast(output_d.data()), thrust::raw_pointer_cast(success_d.data()), p,
103-
batch_size, vocab_size);
103+
batch_size, vocab_size, max_top_p_rounds);
104104
timer.stop();
105105
if (status != cudaSuccess) {
106106
state.skip("CUDA error: " + std::string(cudaGetErrorString(status)));
@@ -141,11 +141,11 @@ void bench_top_k_sampling_with_probability(nvbench::state& state) {
141141

142142
state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) {
143143
timer.start();
144-
cudaError_t status = sampling::TopKSamplingFromProb<max_top_k_rounds, T, int32_t>(
144+
cudaError_t status = sampling::TopKSamplingFromProb<T, int32_t>(
145145
thrust::raw_pointer_cast(probs_d.data()),
146146
thrust::raw_pointer_cast(uniform_samples_d.data()),
147147
thrust::raw_pointer_cast(output_d.data()), thrust::raw_pointer_cast(success_d.data()), k,
148-
batch_size, vocab_size);
148+
batch_size, vocab_size, max_top_k_rounds);
149149
timer.stop();
150150
if (status != cudaSuccess) {
151151
state.skip("CUDA error: " + std::string(cudaGetErrorString(status)));

src/test_sampling.cu

+4-4
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ void _TestTopKSamplingFromProb(size_t batch_size, uint32_t k, size_t vocab_size)
5656
utils::vec_uniform_<T>(uniform_samples_h, 0, 1);
5757
thrust::device_vector<T> uniform_samples_d(uniform_samples_h);
5858

59-
auto status = sampling::TopKSamplingFromProb<max_top_p_rounds, T, IdType>(
59+
auto status = sampling::TopKSamplingFromProb<T, IdType>(
6060
thrust::raw_pointer_cast(probs_d.data()),
6161
thrust::raw_pointer_cast(uniform_samples_d.data()),
6262
thrust::raw_pointer_cast(sampled_ids_d.data()), thrust::raw_pointer_cast(success_d.data()),
63-
k, batch_size, vocab_size);
63+
k, batch_size, vocab_size, max_top_p_rounds);
6464

6565
EXPECT_EQ(status, cudaSuccess) << "TopKSamplingFromProb kernel launch failed, error message: "
6666
<< cudaGetErrorString(status);
@@ -121,11 +121,11 @@ void _TestTopPSamplingFromProb(size_t batch_size, uint32_t k, size_t vocab_size)
121121
utils::vec_uniform_<T>(uniform_samples_h, 0, 1);
122122
thrust::device_vector<T> uniform_samples_d(uniform_samples_h);
123123

124-
auto status = sampling::TopPSamplingFromProb<max_top_p_rounds, T, IdType>(
124+
auto status = sampling::TopPSamplingFromProb<T, IdType>(
125125
thrust::raw_pointer_cast(probs_d.data()),
126126
thrust::raw_pointer_cast(uniform_samples_d.data()),
127127
thrust::raw_pointer_cast(sampled_ids_d.data()), thrust::raw_pointer_cast(success_d.data()),
128-
p, batch_size, vocab_size);
128+
p, batch_size, vocab_size, max_top_p_rounds);
129129

130130
EXPECT_EQ(status, cudaSuccess) << "TopPSamplingFromProb kernel launch failed, error message: "
131131
<< cudaGetErrorString(status);

src/tvm_wrapper.cu

+12-34
Original file line numberDiff line numberDiff line change
@@ -48,26 +48,6 @@ using namespace flashinfer;
4848
LOG(FATAL) << "Unsupported data type " << dl_dtype.code; \
4949
}
5050

51-
#define DISPATCH_REJECTIVE_SAMPLING_NUM_ROUNDS(num_rounds, NUM_ROUNDS, ...) \
52-
if (num_rounds == 1) { \
53-
constexpr bool NUM_ROUNDS = 1; \
54-
__VA_ARGS__ \
55-
} else if (num_rounds == 2) { \
56-
constexpr bool NUM_ROUNDS = 2; \
57-
__VA_ARGS__ \
58-
} else if (num_rounds == 4) { \
59-
constexpr bool NUM_ROUNDS = 4; \
60-
__VA_ARGS__ \
61-
} else if (num_rounds == 8) { \
62-
constexpr bool NUM_ROUNDS = 8; \
63-
__VA_ARGS__ \
64-
} else if (num_rounds == 16) { \
65-
constexpr bool NUM_ROUNDS = 16; \
66-
__VA_ARGS__ \
67-
} else { \
68-
LOG(FATAL) << "Unsupported number of rejective sampling rounds " << num_rounds; \
69-
}
70-
7151
int _FlashInferSinglePrefillWithKVCache(DLTensor* q, DLTensor* k, DLTensor* v, DLTensor* tmp,
7252
bool causal, int64_t kv_layout, int64_t pos_encoding_mode,
7353
bool allow_fp16_qk_reduction, double rope_scale,
@@ -739,7 +719,7 @@ void _FlashInferParallelSamplingFromProb(DLTensor* probs, DLTensor* uniform_samp
739719

740720
void _FlashInferParallelTopPSamplingFromProb(DLTensor* probs, DLTensor* uniform_samples,
741721
DLTensor* row_indices, DLTensor* top_p,
742-
DLTensor* sampled_token_ids, int num_rounds) {
722+
DLTensor* sampled_token_ids) {
743723
CHECK_EQ(probs->device.device_type, kDLCUDA) << "The device of probs must be CUDA.";
744724
CHECK_EQ(uniform_samples->device.device_type, kDLCUDA)
745725
<< "The device of uniform_samples must be CUDA.";
@@ -764,28 +744,26 @@ void _FlashInferParallelTopPSamplingFromProb(DLTensor* probs, DLTensor* uniform_
764744
CHECK(sampled_token_ids->dtype.code == kDLInt && sampled_token_ids->dtype.bits == 32);
765745

766746
CHECK_EQ(probs->ndim, 2); // num_probs, vocab_size
767-
CHECK_EQ(uniform_samples->ndim, 1); // batch_size * num_rounds,
747+
CHECK_EQ(uniform_samples->ndim, 2); // num_rounds, batch_size
768748
CHECK_EQ(row_indices->ndim, 1); // batch_size,
769749
CHECK_EQ(top_p->ndim, 1); // num_probs,
770750
CHECK_EQ(sampled_token_ids->ndim, 1); // batch_size,
771751
int64_t num_probs = probs->shape[0];
772752
int64_t vocab_size = probs->shape[1];
773753
int64_t batch_size = row_indices->shape[0];
774-
CHECK_EQ(uniform_samples->shape[0], batch_size * num_rounds);
754+
int64_t num_rounds = uniform_samples->shape[0];
755+
CHECK_EQ(uniform_samples->shape[1], batch_size);
775756
CHECK_EQ(top_p->shape[0], num_probs);
776757
CHECK_EQ(sampled_token_ids->shape[0], batch_size);
777758

778-
DISPATCH_REJECTIVE_SAMPLING_NUM_ROUNDS(num_rounds, rej_samping_num_rounds, {
779-
cudaError_t status =
780-
sampling::ParallelTopPSamplingFromProb<rej_samping_num_rounds, float, int32_t>(
781-
static_cast<float*>(probs->data), static_cast<float*>(uniform_samples->data),
782-
static_cast<int32_t*>(sampled_token_ids->data), /*success=*/nullptr,
783-
static_cast<int32_t*>(row_indices->data), static_cast<float*>(top_p->data), batch_size,
784-
vocab_size);
785-
if (status != cudaSuccess) {
786-
LOG(FATAL) << "FlashInfer ParallelTopPSamplingFromProb error " << cudaGetErrorString(status);
787-
}
788-
});
759+
cudaError_t status = sampling::ParallelTopPSamplingFromProb<float, int32_t>(
760+
static_cast<float*>(probs->data), static_cast<float*>(uniform_samples->data),
761+
static_cast<int32_t*>(sampled_token_ids->data), /*success=*/nullptr,
762+
static_cast<int32_t*>(row_indices->data), static_cast<float*>(top_p->data), batch_size,
763+
vocab_size, num_rounds);
764+
if (status != cudaSuccess) {
765+
LOG(FATAL) << "FlashInfer ParallelTopPSamplingFromProb error " << cudaGetErrorString(status);
766+
}
789767
}
790768

791769
TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_prefill_with_paged_kv_cache")

0 commit comments

Comments
 (0)