@@ -118,8 +118,9 @@ __device__ void DeviceSamplingFromProb(
118
118
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm ALGORITHM, uint32_t VEC_SIZE, typename DType,
119
119
typename IdType>
120
120
__global__ void SamplingFromProbKernel (DType* probs, DType* uniform_samples, IdType* output,
121
- uint32_t d) {
121
+ IdType* row_indices, uint32_t d) {
122
122
const uint32_t bx = blockIdx .x , tx = threadIdx .x ;
123
+ const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];
123
124
124
125
extern __shared__ __align__ (alignof (SamplingTempStorage<DType, BLOCK_THREADS, ALGORITHM>))
125
126
uint8_t smem[];
@@ -133,7 +134,7 @@ __global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples, IdT
133
134
for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
134
135
probs_vec.fill (DType (0 ));
135
136
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
136
- probs_vec.load (probs + bx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
137
+ probs_vec.load (probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
137
138
}
138
139
139
140
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, ALGORITHM, DType>(i, DType (0 ), u, probs_vec,
@@ -314,7 +315,28 @@ cudaError_t SamplingFromProb(T* probs, T* uniform_samples, IdType* output, uint3
314
315
const uint32_t vec_size = std::gcd (16 / sizeof (T), d);
315
316
dim3 nblks (batch_size);
316
317
dim3 nthrs (BLOCK_THREADS);
317
- void * args[] = {&probs, &uniform_samples, &output, &d};
318
+ IdType* row_indices_placeholder = nullptr ;
319
+ void * args[] = {&probs, &uniform_samples, &output, &row_indices_placeholder, &d};
320
+ const uint32_t smem_size =
321
+ sizeof (SamplingTempStorage<T, BLOCK_THREADS, BLOCK_SCAN_RAKING_MEMOIZE>);
322
+
323
+ DISPATCH_ALIGNED_VEC_SIZE (vec_size, VEC_SIZE, {
324
+ auto kernel =
325
+ SamplingFromProbKernel<BLOCK_THREADS, BLOCK_SCAN_RAKING_MEMOIZE, VEC_SIZE, T, IdType>;
326
+ FLASHINFER_CUDA_CALL (cudaLaunchKernel ((void *)kernel, nblks, nthrs, args, smem_size, stream));
327
+ });
328
+ return cudaSuccess;
329
+ }
330
+
331
+ template <typename T, typename IdType>
332
+ cudaError_t ParallelSamplingFromProb (T* probs, T* uniform_samples, IdType* output,
333
+ IdType* row_indices, uint32_t batch_size, uint32_t d,
334
+ cudaStream_t stream = 0 ) {
335
+ constexpr uint32_t BLOCK_THREADS = 1024 ;
336
+ const uint32_t vec_size = std::gcd (16 / sizeof (T), d);
337
+ dim3 nblks (batch_size);
338
+ dim3 nthrs (BLOCK_THREADS);
339
+ void * args[] = {&probs, &uniform_samples, &output, &row_indices, &d};
318
340
const uint32_t smem_size =
319
341
sizeof (SamplingTempStorage<T, BLOCK_THREADS, BLOCK_SCAN_RAKING_MEMOIZE>);
320
342
0 commit comments