Skip to content

Commit 4984a27

Browse files
authored
misc: parallel sampling from probability (#214)
follow up of #213 , add the parallel version of parallel sampling without top-p restriction.
1 parent b3f1ffb commit 4984a27

File tree

1 file changed

+25
-3
lines changed

1 file changed

+25
-3
lines changed

include/flashinfer/sampling.cuh

+25-3
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,9 @@ __device__ void DeviceSamplingFromProb(
118118
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm ALGORITHM, uint32_t VEC_SIZE, typename DType,
119119
typename IdType>
120120
__global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* output,
121-
uint32_t d) {
121+
IdType* row_indices, uint32_t d) {
122122
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
123+
const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];
123124

124125
extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, ALGORITHM>))
125126
uint8_t smem[];
@@ -133,7 +134,7 @@ __global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples, IdT
133134
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
134135
probs_vec.fill(DType(0));
135136
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
136-
probs_vec.load(probs + bx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
137+
probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
137138
}
138139

139140
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, ALGORITHM, DType>(i, DType(0), u, probs_vec,
@@ -314,7 +315,28 @@ cudaError_t SamplingFromProb(T* probs, T* uniform_samples, IdType* output, uint3
314315
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
315316
dim3 nblks(batch_size);
316317
dim3 nthrs(BLOCK_THREADS);
317-
void* args[] = {&probs, &uniform_samples, &output, &d};
318+
IdType* row_indices_placeholder = nullptr;
319+
void* args[] = {&probs, &uniform_samples, &output, &row_indices_placeholder, &d};
320+
const uint32_t smem_size =
321+
sizeof(SamplingTempStorage<T, BLOCK_THREADS, BLOCK_SCAN_RAKING_MEMOIZE>);
322+
323+
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
324+
auto kernel =
325+
SamplingFromProbKernel<BLOCK_THREADS, BLOCK_SCAN_RAKING_MEMOIZE, VEC_SIZE, T, IdType>;
326+
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
327+
});
328+
return cudaSuccess;
329+
}
330+
331+
template <typename T, typename IdType>
332+
cudaError_t ParallelSamplingFromProb(T* probs, T* uniform_samples, IdType* output,
333+
IdType* row_indices, uint32_t batch_size, uint32_t d,
334+
cudaStream_t stream = 0) {
335+
constexpr uint32_t BLOCK_THREADS = 1024;
336+
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
337+
dim3 nblks(batch_size);
338+
dim3 nthrs(BLOCK_THREADS);
339+
void* args[] = {&probs, &uniform_samples, &output, &row_indices, &d};
318340
const uint32_t smem_size =
319341
sizeof(SamplingTempStorage<T, BLOCK_THREADS, BLOCK_SCAN_RAKING_MEMOIZE>);
320342

0 commit comments

Comments
 (0)