Skip to content

Commit b3f1ffb

Browse files
authored
sampling: support parallel top-p sampling (#213)
Add a new API `ParallelTopPSamplingFromProb`, which enables sampling from the same distribution multiple times and allows users to specify a per-request (batch-specific) `top_p`. cc @MasterJH5574
1 parent fb69910 commit b3f1ffb

File tree

1 file changed

+54
-13
lines changed

1 file changed

+54
-13
lines changed

include/flashinfer/sampling.cuh

+54-13
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#ifndef FLASHINFER_SAMPLING_CUH_
1717
#define FLASHINFER_SAMPLING_CUH_
1818

19-
#include <cstdint>
2019
#include <cub/block/block_adjacent_difference.cuh>
2120
#include <cub/block/block_reduce.cuh>
2221
#include <cub/block/block_scan.cuh>
@@ -229,10 +228,16 @@ constexpr float eps = 1e-5;
229228
template <uint32_t MAX_TOP_P_ROUNDS, uint32_t BLOCK_THREADS, BlockScanAlgorithm ALGORITHM,
230229
uint32_t VEC_SIZE, typename DType, typename IdType>
231230
__global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* output,
232-
bool* success, float p, uint32_t d) {
231+
bool* success, IdType* row_indices, float* top_p_arr,
232+
float top_p, uint32_t d) {
233233
const uint32_t batch_size = gridDim.x;
234234
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
235235

236+
if (top_p_arr != nullptr) {
237+
top_p = top_p_arr[bx];
238+
}
239+
const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];
240+
236241
extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, ALGORITHM>))
237242
uint8_t smem[];
238243
auto& temp_storage =
@@ -249,7 +254,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
249254
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
250255
probs_vec.fill(DType(0));
251256
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
252-
probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
257+
probs_vec.load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
253258
}
254259

255260
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, ALGORITHM, DType>(i, pivot, u, probs_vec,
@@ -260,13 +265,13 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
260265
}
261266
__syncthreads();
262267
sampled_id = (aggregate > u) ? temp_storage.data.sampled_id : d - 1;
263-
pivot = probs[bx * d + sampled_id];
268+
pivot = probs[row_idx * d + sampled_id];
264269

265270
DType aggregate_leq_pivot = DType(0);
266271
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
267272
probs_vec.fill(DType(0));
268273
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
269-
probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
274+
probs_vec.load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
270275
}
271276

272277
DType probs_leq_pivot[VEC_SIZE];
@@ -281,18 +286,18 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
281286
temp_storage.data.block_aggregate.value = aggregate_leq_pivot;
282287
}
283288
__syncthreads();
284-
if (temp_storage.data.block_aggregate.value + p > 1 + eps) {
289+
if (temp_storage.data.block_aggregate.value + top_p > 1 + eps) {
285290
break;
286291
}
287292
}
288293
q = temp_storage.data.block_aggregate.value;
289-
if (q + p > 1 + eps) {
294+
if (q + top_p > 1 + eps) {
290295
break;
291296
}
292297
}
293298
__syncthreads();
294299
if (tx == 0) {
295-
if (q + p <= 1 + eps) {
300+
if (q + top_p <= 1 + eps) {
296301
// failed to sample within MAX_TOP_P_ROUNDS
297302
success[bx] = false;
298303
} else {
@@ -323,7 +328,7 @@ cudaError_t SamplingFromProb(T* probs, T* uniform_samples, IdType* output, uint3
323328

324329
template <uint32_t MAX_TOP_K_ROUNDS, typename T, typename IdType>
325330
cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success,
326-
IdType k, uint32_t batch_size, uint32_t d,
331+
IdType top_k, uint32_t batch_size, uint32_t d,
327332
cudaStream_t stream = 0) {
328333
constexpr uint32_t BLOCK_THREADS = 1024;
329334
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
@@ -332,7 +337,7 @@ cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b
332337
sizeof(SamplingTempStorage<T, BLOCK_THREADS, BLOCK_SCAN_RAKING_MEMOIZE>);
333338
dim3 nblks(batch_size);
334339
dim3 nthrs(BLOCK_THREADS);
335-
void* args[] = {&probs, &uniform_samples, &output, &success, &k, &d};
340+
void* args[] = {&probs, &uniform_samples, &output, &success, &top_k, &d};
336341

337342
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
338343
auto kernel = TopKSamplingFromProbKernel<MAX_TOP_K_ROUNDS, BLOCK_THREADS,
@@ -345,16 +350,52 @@ cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b
345350
}
346351

347352
// Launch top-p (nucleus) sampling over a batch of probability rows, one CUDA
// block per row.
//
// probs           device matrix of probabilities; row `bx` starts at probs + bx * d.
// uniform_samples pre-drawn U(0,1) samples consumed by the rejection rounds
//                 (presumably MAX_TOP_P_ROUNDS values per row — TODO confirm
//                 against the kernel).
// output          per-row sampled token index.
// success         per-row flag; the kernel clears it when sampling does not
//                 terminate within MAX_TOP_P_ROUNDS rounds.
// top_p           single nucleus-probability threshold shared by all rows.
// batch_size      number of rows (grid size).
// d               vocabulary size (row length).
// stream          CUDA stream the kernel is enqueued on.
//
// Returns cudaSuccess, or the first failing CUDA runtime status surfaced by
// FLASHINFER_CUDA_CALL.
template <uint32_t MAX_TOP_P_ROUNDS, typename T, typename IdType>
cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success,
                                 T top_p, uint32_t batch_size, uint32_t d,
                                 cudaStream_t stream = 0) {
  constexpr uint32_t BLOCK_THREADS = 1024;
  // Widest aligned vector load that evenly divides the row length.
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

  const uint32_t smem_size =
      sizeof(SamplingTempStorage<T, BLOCK_THREADS, BLOCK_SCAN_RAKING_MEMOIZE>);
  dim3 nblks(batch_size);
  dim3 nthrs(BLOCK_THREADS);

  // The kernel is shared with ParallelTopPSamplingFromProb; passing nullptr for
  // both per-row arrays selects the identity row mapping and the scalar top_p.
  IdType* row_indices_placeholder = nullptr;
  T* top_p_arr_placeholder = nullptr;
  void* args[] = {&probs,   &uniform_samples,         &output,
                  &success, &row_indices_placeholder, &top_p_arr_placeholder,
                  &top_p,   &d};

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    auto kernel = TopPSamplingFromProbKernel<MAX_TOP_P_ROUNDS, BLOCK_THREADS,
                                             BLOCK_SCAN_RAKING_MEMOIZE, VEC_SIZE, T, IdType>;
    // Opt in to the requested dynamic shared memory size before launching.
    FLASHINFER_CUDA_CALL(
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  });
  return cudaSuccess;
}
383+
384+
template <uint32_t MAX_TOP_P_ROUNDS, typename T, typename IdType>
385+
cudaError_t ParallelTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output,
386+
bool* success, IdType* row_indices, T* top_p_arr,
387+
uint32_t batch_size, uint32_t d, cudaStream_t stream = 0) {
350388
constexpr uint32_t BLOCK_THREADS = 1024;
351389
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
352390

353391
const uint32_t smem_size =
354392
sizeof(SamplingTempStorage<T, BLOCK_THREADS, BLOCK_SCAN_RAKING_MEMOIZE>);
355393
dim3 nblks(batch_size);
356394
dim3 nthrs(BLOCK_THREADS);
357-
void* args[] = {&probs, &uniform_samples, &output, &success, &p, &d};
395+
T top_p_placeholder = 0;
396+
void* args[] = {&probs, &uniform_samples, &output,
397+
&success, &row_indices, &top_p_arr, &top_p_placeholder,
398+
&d};
358399

359400
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
360401
auto kernel = TopPSamplingFromProbKernel<MAX_TOP_P_ROUNDS, BLOCK_THREADS,

0 commit comments

Comments
 (0)