#ifndef FLASHINFER_SAMPLING_CUH_
#define FLASHINFER_SAMPLING_CUH_

- #include <cstdint>
#include <cub/block/block_adjacent_difference.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_scan.cuh>
@@ -229,10 +228,16 @@ constexpr float eps = 1e-5;
template <uint32_t MAX_TOP_P_ROUNDS, uint32_t BLOCK_THREADS, BlockScanAlgorithm ALGORITHM,
          uint32_t VEC_SIZE, typename DType, typename IdType>
__global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* output,
-                                            bool* success, float p, uint32_t d) {
+                                            bool* success, IdType* row_indices, float* top_p_arr,
+                                            float top_p, uint32_t d) {
  const uint32_t batch_size = gridDim.x;
  const uint32_t bx = blockIdx.x, tx = threadIdx.x;

+   if (top_p_arr != nullptr) {
+     top_p = top_p_arr[bx];
+   }
+   const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];
+
  extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, ALGORITHM>))
      uint8_t smem[];
  auto& temp_storage =
@@ -249,7 +254,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
      probs_vec.fill(DType(0));
      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-         probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+         probs_vec.load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
      }

      DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, ALGORITHM, DType>(i, pivot, u, probs_vec,
@@ -260,13 +265,13 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
    }
    __syncthreads();
    sampled_id = (aggregate > u) ? temp_storage.data.sampled_id : d - 1;
-     pivot = probs[bx * d + sampled_id];
+     pivot = probs[row_idx * d + sampled_id];

    DType aggregate_leq_pivot = DType(0);
    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
      probs_vec.fill(DType(0));
      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-         probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+         probs_vec.load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
      }

      DType probs_leq_pivot[VEC_SIZE];
@@ -281,18 +286,18 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
        temp_storage.data.block_aggregate.value = aggregate_leq_pivot;
      }
      __syncthreads();
-       if (temp_storage.data.block_aggregate.value + p > 1 + eps) {
+       if (temp_storage.data.block_aggregate.value + top_p > 1 + eps) {
        break;
      }
    }
    q = temp_storage.data.block_aggregate.value;
-     if (q + p > 1 + eps) {
+     if (q + top_p > 1 + eps) {
      break;
    }
  }
  __syncthreads();
  if (tx == 0) {
-     if (q + p <= 1 + eps) {
+     if (q + top_p <= 1 + eps) {
      // failed to sample within MAX_TOP_P_ROUNDS
      success[bx] = false;
    } else {
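
Aside (not part of the patch): the acceptance test renamed above can be read as follows. q is the total probability mass of tokens whose probability does not exceed the sampled pivot, so q + top_p > 1 + eps means the mass strictly above the pivot is below top_p, i.e. the drawn token already lies inside the top-p nucleus and the rejection loop can stop. A minimal scalar sketch of that predicate, with made-up toy values:

// Toy host-side illustration of the kernel's acceptance test (values are
// hypothetical, not taken from the diff).
#include <cstdio>

int main() {
  constexpr float eps = 1e-5f;
  const float probs[4] = {0.5f, 0.3f, 0.15f, 0.05f};
  const float top_p = 0.8f;
  const float pivot = 0.3f;  // probability of the token drawn this round
  float q = 0.f;             // mass of tokens with prob <= pivot
  for (float p : probs) q += (p <= pivot) ? p : 0.f;
  // Mass above the pivot is 1 - q; accept when it is below top_p, i.e. the
  // drawn token is still inside the top-p nucleus.
  const bool accept = q + top_p > 1.f + eps;
  std::printf("q = %.2f, accept = %d\n", q, accept);  // prints q = 0.50, accept = 1
  return 0;
}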
@@ -323,7 +328,7 @@ cudaError_t SamplingFromProb(T* probs, T* uniform_samples, IdType* output, uint3

template <uint32_t MAX_TOP_K_ROUNDS, typename T, typename IdType>
cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success,
-                                  IdType k, uint32_t batch_size, uint32_t d,
+                                  IdType top_k, uint32_t batch_size, uint32_t d,
                                 cudaStream_t stream = 0) {
  constexpr uint32_t BLOCK_THREADS = 1024;
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
@@ -332,7 +337,7 @@ cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b
      sizeof(SamplingTempStorage<T, BLOCK_THREADS, BLOCK_SCAN_RAKING_MEMOIZE>);
  dim3 nblks(batch_size);
  dim3 nthrs(BLOCK_THREADS);
-   void* args[] = {&probs, &uniform_samples, &output, &success, &k, &d};
+   void* args[] = {&probs, &uniform_samples, &output, &success, &top_k, &d};

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    auto kernel = TopKSamplingFromProbKernel<MAX_TOP_K_ROUNDS, BLOCK_THREADS,
@@ -345,16 +350,52 @@ cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b
}

template <uint32_t MAX_TOP_P_ROUNDS, typename T, typename IdType>
- cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success, T p,
-                                  uint32_t batch_size, uint32_t d, cudaStream_t stream = 0) {
+ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success,
+                                  T top_p, uint32_t batch_size, uint32_t d,
+                                  cudaStream_t stream = 0) {
+   constexpr uint32_t BLOCK_THREADS = 1024;
+   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+   const uint32_t smem_size =
+       sizeof(SamplingTempStorage<T, BLOCK_THREADS, BLOCK_SCAN_RAKING_MEMOIZE>);
+   dim3 nblks(batch_size);
+   dim3 nthrs(BLOCK_THREADS);
+   IdType* row_indices_placeholder = nullptr;
+   T* top_p_arr_placeholder = nullptr;
+   void* args[] = {&probs,
+                   &uniform_samples,
+                   &output,
+                   &success,
+                   &row_indices_placeholder,
+                   &top_p_arr_placeholder,
+                   &top_p,
+                   &d};
+
+   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+     auto kernel = TopPSamplingFromProbKernel<MAX_TOP_P_ROUNDS, BLOCK_THREADS,
+                                              BLOCK_SCAN_RAKING_MEMOIZE, VEC_SIZE, T, IdType>;
+     FLASHINFER_CUDA_CALL(
+         cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+     FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+   });
+   return cudaSuccess;
+ }
+
+ template <uint32_t MAX_TOP_P_ROUNDS, typename T, typename IdType>
+ cudaError_t ParallelTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output,
+                                          bool* success, IdType* row_indices, T* top_p_arr,
+                                          uint32_t batch_size, uint32_t d, cudaStream_t stream = 0) {
  constexpr uint32_t BLOCK_THREADS = 1024;
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

  const uint32_t smem_size =
      sizeof(SamplingTempStorage<T, BLOCK_THREADS, BLOCK_SCAN_RAKING_MEMOIZE>);
  dim3 nblks(batch_size);
  dim3 nthrs(BLOCK_THREADS);
-   void* args[] = {&probs, &uniform_samples, &output, &success, &p, &d};
+   T top_p_placeholder = 0;
+   void* args[] = {&probs, &uniform_samples, &output,
+                   &success, &row_indices, &top_p_arr, &top_p_placeholder,
+                   &d};

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    auto kernel = TopPSamplingFromProbKernel<MAX_TOP_P_ROUNDS, BLOCK_THREADS,
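
For readers of this patch, a minimal host-side usage sketch of the two entry points follows. The include path, the unqualified call syntax, and the layout of uniform_samples (assumed here to hold MAX_TOP_P_ROUNDS x batch_size pre-drawn uniforms, one per rejection round per block) are illustrative assumptions, not something the diff itself pins down.

// Hypothetical caller of the shared-top_p and per-row-top_p wrappers.
#include <cuda_runtime.h>
#include <cstdint>

#include "flashinfer/sampling.cuh"  // assumed include path; adjust namespace qualification as needed

int main() {
  constexpr uint32_t MAX_TOP_P_ROUNDS = 32;
  const uint32_t batch_size = 8, d = 32000;

  float *probs, *uniform_samples, *top_p_arr;
  int32_t *output, *row_indices;
  bool *success;
  cudaMalloc(&probs, sizeof(float) * batch_size * d);
  cudaMalloc(&uniform_samples, sizeof(float) * MAX_TOP_P_ROUNDS * batch_size);
  cudaMalloc(&top_p_arr, sizeof(float) * batch_size);
  cudaMalloc(&output, sizeof(int32_t) * batch_size);
  cudaMalloc(&row_indices, sizeof(int32_t) * batch_size);
  cudaMalloc(&success, sizeof(bool) * batch_size);
  // ... fill probs with normalized rows, uniform_samples with U(0,1) draws,
  // top_p_arr with per-request thresholds, and row_indices with the probs row
  // each block should sample from ...

  // One top-p threshold shared by the whole batch: block i samples row i.
  cudaError_t status = TopPSamplingFromProb<MAX_TOP_P_ROUNDS>(
      probs, uniform_samples, output, success, /*top_p=*/0.9f, batch_size, d);

  // Per-request thresholds with an explicit row mapping: block i samples row
  // row_indices[i] using threshold top_p_arr[i].
  status = ParallelTopPSamplingFromProb<MAX_TOP_P_ROUNDS>(
      probs, uniform_samples, output, success, row_indices, top_p_arr, batch_size, d);

  return status == cudaSuccess ? 0 : 1;
}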