@@ -146,10 +146,11 @@ __global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples, IdT
  output[bx] = (aggregate > u) ? temp_storage.data.sampled_id : d - 1;
}

-template <uint32_t MAX_TOP_K_ROUNDS, uint32_t BLOCK_THREADS, BlockScanAlgorithm ALGORITHM,
-          uint32_t VEC_SIZE, typename DType, typename IdType>
+template <uint32_t BLOCK_THREADS, BlockScanAlgorithm ALGORITHM, uint32_t VEC_SIZE, typename DType,
+          typename IdType>
__global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* output,
-                                           bool* success, uint32_t k, uint32_t d) {
+                                           bool* success, uint32_t k, uint32_t d,
+                                           uint32_t max_top_k_rounds) {
  const uint32_t batch_size = gridDim.x;
  const uint32_t bx = blockIdx.x, tx = threadIdx.x;

@@ -163,7 +164,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
  DType q = DType(0);
  DType pivot = DType(0);
  IdType sampled_id;
-  for (uint32_t round = 0; round < MAX_TOP_K_ROUNDS; ++round) {
+  for (uint32_t round = 0; round < max_top_k_rounds; ++round) {
    DType u = uniform_samples[round * batch_size + bx] * (1 - q);
    aggregate = DType(0);
    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
@@ -230,11 +231,11 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,

constexpr float eps = 1e-5;

-template <uint32_t MAX_TOP_P_ROUNDS, uint32_t BLOCK_THREADS, BlockScanAlgorithm ALGORITHM,
-          uint32_t VEC_SIZE, typename DType, typename IdType>
+template <uint32_t BLOCK_THREADS, BlockScanAlgorithm ALGORITHM, uint32_t VEC_SIZE, typename DType,
+          typename IdType>
__global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* output,
                                           bool* success, IdType* row_indices, float* top_p_arr,
-                                           float top_p, uint32_t d) {
+                                           float top_p, uint32_t d, uint32_t max_top_p_rounds) {
  const uint32_t batch_size = gridDim.x;
  const uint32_t bx = blockIdx.x, tx = threadIdx.x;

@@ -253,7 +254,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
  DType q = DType(0);
  DType pivot = DType(0);
  IdType sampled_id;
-  for (uint32_t round = 0; round < MAX_TOP_P_ROUNDS; ++round) {
+  for (uint32_t round = 0; round < max_top_p_rounds; ++round) {
    DType u = uniform_samples[round * batch_size + bx] * (1 - q);
    aggregate = DType(0);
    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
@@ -356,33 +357,33 @@ cudaError_t ParallelSamplingFromProb(T* probs, T* uniform_samples, IdType* outpu
  return cudaSuccess;
}

-template <uint32_t MAX_TOP_K_ROUNDS, typename T, typename IdType>
+template <typename T, typename IdType>
cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success,
                                 IdType top_k, uint32_t batch_size, uint32_t d,
-                                 cudaStream_t stream = 0) {
+                                 uint32_t max_top_k_rounds, cudaStream_t stream = 0) {
  constexpr uint32_t BLOCK_THREADS = 1024;
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

  const uint32_t smem_size =
      sizeof(SamplingTempStorage<T, BLOCK_THREADS, BLOCK_SCAN_RAKING_MEMOIZE>);
  dim3 nblks(batch_size);
  dim3 nthrs(BLOCK_THREADS);
-  void* args[] = {&probs, &uniform_samples, &output, &success, &top_k, &d};
+  void* args[] = {&probs, &uniform_samples, &output, &success, &top_k, &d, &max_top_k_rounds};

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
-    auto kernel = TopKSamplingFromProbKernel<MAX_TOP_K_ROUNDS, BLOCK_THREADS,
-                                             BLOCK_SCAN_RAKING_MEMOIZE, VEC_SIZE, T, IdType>;
+    auto kernel =
+        TopKSamplingFromProbKernel<BLOCK_THREADS, BLOCK_SCAN_RAKING_MEMOIZE, VEC_SIZE, T, IdType>;
    FLASHINFER_CUDA_CALL(
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  });
  return cudaSuccess;
}

-template <uint32_t MAX_TOP_P_ROUNDS, typename T, typename IdType>
+template <typename T, typename IdType>
cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success,
                                 T top_p, uint32_t batch_size, uint32_t d,
-                                 cudaStream_t stream = 0) {
+                                 uint32_t max_top_p_rounds, cudaStream_t stream = 0) {
  constexpr uint32_t BLOCK_THREADS = 1024;
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

@@ -399,22 +400,24 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b
                  &row_indices_placeholder,
                  &top_p_arr_placeholder,
                  &top_p,
-                  &d};
+                  &d,
+                  &max_top_p_rounds};

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
-    auto kernel = TopPSamplingFromProbKernel<MAX_TOP_P_ROUNDS, BLOCK_THREADS,
-                                             BLOCK_SCAN_RAKING_MEMOIZE, VEC_SIZE, T, IdType>;
+    auto kernel =
+        TopPSamplingFromProbKernel<BLOCK_THREADS, BLOCK_SCAN_RAKING_MEMOIZE, VEC_SIZE, T, IdType>;
    FLASHINFER_CUDA_CALL(
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  });
  return cudaSuccess;
}

-template <uint32_t MAX_TOP_P_ROUNDS, typename T, typename IdType>
+template <typename T, typename IdType>
cudaError_t ParallelTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output,
                                         bool* success, IdType* row_indices, T* top_p_arr,
-                                         uint32_t batch_size, uint32_t d, cudaStream_t stream = 0) {
+                                         uint32_t batch_size, uint32_t d, uint32_t max_top_p_rounds,
+                                         cudaStream_t stream = 0) {
  constexpr uint32_t BLOCK_THREADS = 1024;
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

@@ -423,12 +426,12 @@ cudaError_t ParallelTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* o
  dim3 nblks(batch_size);
  dim3 nthrs(BLOCK_THREADS);
  T top_p_placeholder = 0;
-  void* args[] = {&probs, &uniform_samples, &output, &success, &row_indices,
-                  &top_p_arr, &top_p_placeholder, &d};
+  void* args[] = {&probs, &uniform_samples, &output, &success, &row_indices,
+                  &top_p_arr, &top_p_placeholder, &d, &max_top_p_rounds};

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
-    auto kernel = TopPSamplingFromProbKernel<MAX_TOP_P_ROUNDS, BLOCK_THREADS,
-                                             BLOCK_SCAN_RAKING_MEMOIZE, VEC_SIZE, T, IdType>;
+    auto kernel =
+        TopPSamplingFromProbKernel<BLOCK_THREADS, BLOCK_SCAN_RAKING_MEMOIZE, VEC_SIZE, T, IdType>;
    FLASHINFER_CUDA_CALL(
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
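Taken together, this diff turns the rejection-sampling round budget from a compile-time template parameter into a runtime argument, so callers no longer need a separate kernel instantiation per round count. Below is a minimal host-side sketch of the updated calling convention. The argument order (..., d, max_top_k_rounds, stream) follows the signatures above; the header path, the flashinfer::sampling namespace, the float/int32_t instantiation, and the buffer setup are assumptions made only for illustration. Note that uniform_samples still has to hold max_top_k_rounds * batch_size values, since each round reads uniform_samples[round * batch_size + bx].

// Sketch of a call site after this change; names marked as assumptions above.
#include <flashinfer/sampling.cuh>  // assumed header path

#include <cuda_runtime.h>
#include <cstdint>

using namespace flashinfer::sampling;  // namespace assumed

int main() {
  const uint32_t batch_size = 4, d = 32000;  // rows of probs, vocabulary size
  const int32_t top_k = 40;
  const uint32_t max_top_k_rounds = 32;      // now a runtime knob, not a template parameter

  float *probs, *uniform_samples;
  int32_t* output;
  bool* success;
  cudaMalloc(&probs, sizeof(float) * batch_size * d);
  // One uniform draw per (round, row): indexed as round * batch_size + bx in the kernel.
  cudaMalloc(&uniform_samples, sizeof(float) * max_top_k_rounds * batch_size);
  cudaMalloc(&output, sizeof(int32_t) * batch_size);
  cudaMalloc(&success, sizeof(bool) * batch_size);
  // ... fill probs with normalized probabilities and uniform_samples with U(0,1) draws ...

  // max_top_k_rounds is passed just before the optional stream argument.
  cudaError_t status = TopKSamplingFromProb<float, int32_t>(
      probs, uniform_samples, output, success, top_k, batch_size, d, max_top_k_rounds);
  return status == cudaSuccess ? 0 : 1;
}

TopPSamplingFromProb and ParallelTopPSamplingFromProb change in the same way, taking max_top_p_rounds immediately before the stream argument.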