Skip to content

Commit a02be3e

Browse files
authored
[TVMWrapper] Add wrapper functions for sampler (#215)
This commit adds TVM runtime wrapper functions for the sampling kernels (parallel sampling and parallel top-p rejection sampling) and registers them as global packed functions.
1 parent 4984a27 commit a02be3e

File tree

2 files changed

+139
-9
lines changed

2 files changed

+139
-9
lines changed

include/flashinfer/sampling.cuh

+14-7
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,14 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
216216
if (tx == 0) {
217217
if (temp_storage.data.block_aggregate.pair.count + k <= d) {
218218
// failed to sample within MAX_TOP_P_ROUNDS
219-
success[bx] = false;
219+
if (success != nullptr) {
220+
success[bx] = false;
221+
}
220222
} else {
221223
output[bx] = sampled_id;
222-
success[bx] = true;
224+
if (success != nullptr) {
225+
success[bx] = true;
226+
}
223227
}
224228
}
225229
}
@@ -300,10 +304,14 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
300304
if (tx == 0) {
301305
if (q + top_p <= 1 + eps) {
302306
// failed to sample within MAX_TOP_P_ROUNDS
303-
success[bx] = false;
307+
if (success != nullptr) {
308+
success[bx] = false;
309+
}
304310
} else {
305311
output[bx] = sampled_id;
306-
success[bx] = true;
312+
if (success != nullptr) {
313+
success[bx] = true;
314+
}
307315
}
308316
}
309317
}
@@ -415,9 +423,8 @@ cudaError_t ParallelTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* o
415423
dim3 nblks(batch_size);
416424
dim3 nthrs(BLOCK_THREADS);
417425
T top_p_placeholder = 0;
418-
void* args[] = {&probs, &uniform_samples, &output,
419-
&success, &row_indices & top_p_arr, &top_p_placeholder,
420-
&d};
426+
void* args[] = {&probs, &uniform_samples, &output, &success, &row_indices,
427+
&top_p_arr, &top_p_placeholder, &d};
421428

422429
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
423430
auto kernel = TopPSamplingFromProbKernel<MAX_TOP_P_ROUNDS, BLOCK_THREADS,

src/tvm_wrapper.cu

+125-2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <flashinfer/attention/cascade.cuh>
2424
#include <flashinfer/decode_attention_decl.cuh>
2525
#include <flashinfer/prefill_attention_decl.cuh>
26+
#include <flashinfer/sampling.cuh>
2627
#include <optional>
2728

2829
using tvm::runtime::Array;
@@ -47,6 +48,26 @@ using namespace flashinfer;
4748
LOG(FATAL) << "Unsupported data type " << dl_dtype.code; \
4849
}
4950

// Dispatch a runtime rejection-sampling round count to a compile-time constant
// NUM_ROUNDS usable as the MAX_TOP_P_ROUNDS template argument of the sampling
// kernels. Supported values: 1, 2, 4, 8, 16; any other value aborts via
// LOG(FATAL).
// NOTE: the constant must be an integer type. Declaring it `bool` (as the
// original code did) silently collapses 2/4/8/16 to `true` (== 1), so every
// configuration would run with a single rejection round.
#define DISPATCH_REJECTIVE_SAMPLING_NUM_ROUNDS(num_rounds, NUM_ROUNDS, ...)       \
  if (num_rounds == 1) {                                                          \
    constexpr uint32_t NUM_ROUNDS = 1;                                            \
    __VA_ARGS__                                                                   \
  } else if (num_rounds == 2) {                                                   \
    constexpr uint32_t NUM_ROUNDS = 2;                                            \
    __VA_ARGS__                                                                   \
  } else if (num_rounds == 4) {                                                   \
    constexpr uint32_t NUM_ROUNDS = 4;                                            \
    __VA_ARGS__                                                                   \
  } else if (num_rounds == 8) {                                                   \
    constexpr uint32_t NUM_ROUNDS = 8;                                            \
    __VA_ARGS__                                                                   \
  } else if (num_rounds == 16) {                                                  \
    constexpr uint32_t NUM_ROUNDS = 16;                                           \
    __VA_ARGS__                                                                   \
  } else {                                                                        \
    LOG(FATAL) << "Unsupported number of rejective sampling rounds " << num_rounds; \
  }
5071
int _FlashInferSinglePrefillWithKVCache(DLTensor* q, DLTensor* k, DLTensor* v, DLTensor* tmp,
5172
bool causal, int64_t kv_layout, int64_t pos_encoding_mode,
5273
bool allow_fp16_qk_reduction, double rope_scale,
@@ -534,6 +555,10 @@ void _FlashInferAttentionPrefillWithRaggedKVCache(
534555
/*causal=*/bool(causal), QKVLayout::kNHD, PosEncodingMode(pos_encoding_mode),
535556
/*allow_fp16_qk_reduction=*/false, sm_scale, rope_scale, rope_theta,
536557
/*sm_scale=*/0);
558+
if (status != cudaSuccess) {
559+
LOG(FATAL) << "FlashInfer AttentionPrefillWithRaggedKVCache error "
560+
<< cudaGetErrorString(status);
561+
}
537562
})})})
538563
}
539564

@@ -609,7 +634,7 @@ void _FlashInferMergeState(DLTensor* v_a, DLTensor* s_a, DLTensor* v_b, DLTensor
609634
static_cast<dtype_out*>(v_merged->data), static_cast<float*>(s_merged->data),
610635
batch_size, num_heads, head_dim);
611636
if (status != cudaSuccess) {
612-
LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status);
637+
LOG(FATAL) << "FlashInfer CUDA MergeState error " << cudaGetErrorString(status);
613638
}
614639
})});
615640
}
@@ -651,7 +676,7 @@ void _FlashInferMergeStateInPlace(DLTensor* v, DLTensor* s, DLTensor* v_other, D
651676
static_cast<dtype*>(v_other->data), static_cast<float*>(s_other->data),
652677
batch_size, num_heads, head_dim);
653678
if (status != cudaSuccess) {
654-
LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status);
679+
LOG(FATAL) << "FlashInfer CUDA MergeStateInPlace error " << cudaGetErrorString(status);
655680
}
656681
});
657682
}
@@ -672,6 +697,97 @@ void _FlashInferBatchQKApplyRotaryInPlace(DLTensor* q, DLTensor* k, DLTensor* in
672697
})});
673698
}
674699

700+
/*!
 * \brief TVM wrapper for flashinfer's parallel sampling from probability
 *        distributions (no top-p truncation).
 * \param probs Float32 CUDA tensor of shape (num_probs, vocab_size); each row
 *        is a probability distribution over the vocabulary.
 * \param uniform_samples Float32 CUDA tensor of shape (batch_size,); uniform
 *        random numbers consumed by the sampler.
 * \param row_indices Int32 CUDA tensor of shape (batch_size,); maps each batch
 *        element to a row of `probs`. Assumed to contain values in
 *        [0, num_probs) — not verifiable on the host here.
 * \param sampled_token_ids Int32 CUDA tensor of shape (batch_size,); output
 *        token ids.
 */
void _FlashInferParallelSamplingFromProb(DLTensor* probs, DLTensor* uniform_samples,
                                         DLTensor* row_indices, DLTensor* sampled_token_ids) {
  // All tensors must live on the same CUDA device.
  CHECK_EQ(probs->device.device_type, kDLCUDA) << "The device of probs must be CUDA.";
  CHECK_EQ(uniform_samples->device.device_type, kDLCUDA)
      << "The device of uniform_samples must be CUDA.";
  CHECK_EQ(row_indices->device.device_type, kDLCUDA) << "The device of row_indices must be CUDA.";
  CHECK_EQ(sampled_token_ids->device.device_type, kDLCUDA)
      << "The device of sampled_token_ids must be CUDA.";

  int dev_id = probs->device.device_id;
  CHECK_EQ(uniform_samples->device.device_id, dev_id);
  CHECK_EQ(row_indices->device.device_id, dev_id);
  CHECK_EQ(sampled_token_ids->device.device_id, dev_id);

  // Only scalar (non-vectorized) float32 probabilities/samples and int32
  // indices/outputs are supported.
  CHECK(probs->dtype.lanes == 1 && uniform_samples->dtype.lanes == 1 &&
        row_indices->dtype.lanes == 1 && sampled_token_ids->dtype.lanes == 1);
  CHECK(probs->dtype.code == kDLFloat && probs->dtype.bits == 32);
  CHECK(uniform_samples->dtype.code == kDLFloat && uniform_samples->dtype.bits == 32);
  CHECK(row_indices->dtype.code == kDLInt && row_indices->dtype.bits == 32);
  CHECK(sampled_token_ids->dtype.code == kDLInt && sampled_token_ids->dtype.bits == 32);

  CHECK_EQ(probs->ndim, 2);              // num_probs, vocab_size
  CHECK_EQ(uniform_samples->ndim, 1);    // batch_size,
  CHECK_EQ(row_indices->ndim, 1);        // batch_size,
  CHECK_EQ(sampled_token_ids->ndim, 1);  // batch_size,
  int64_t vocab_size = probs->shape[1];
  int64_t batch_size = row_indices->shape[0];
  CHECK_EQ(uniform_samples->shape[0], batch_size);
  CHECK_EQ(sampled_token_ids->shape[0], batch_size);

  cudaError_t status = sampling::ParallelSamplingFromProb<float, int32_t>(
      static_cast<float*>(probs->data), static_cast<float*>(uniform_samples->data),
      static_cast<int32_t*>(sampled_token_ids->data), static_cast<int32_t*>(row_indices->data),
      batch_size, vocab_size);
  if (status != cudaSuccess) {
    // Fixed copy-paste error: this is the plain sampler, not the top-p one.
    LOG(FATAL) << "FlashInfer ParallelSamplingFromProb error " << cudaGetErrorString(status);
  }
}
739+
740+
/*!
 * \brief TVM wrapper for flashinfer's parallel top-p (nucleus) rejection
 *        sampling from probability distributions.
 * \param probs Float32 CUDA tensor of shape (num_probs, vocab_size); each row
 *        is a probability distribution over the vocabulary.
 * \param uniform_samples Float32 CUDA tensor of shape
 *        (batch_size * num_rounds,); uniform random numbers, one group per
 *        rejection round.
 * \param row_indices Int32 CUDA tensor of shape (batch_size,); maps each batch
 *        element to a row of `probs`. Assumed in [0, num_probs) — not
 *        verifiable on the host here.
 * \param top_p Float32 CUDA tensor of shape (num_probs,); per-row top-p
 *        threshold.
 * \param sampled_token_ids Int32 CUDA tensor of shape (batch_size,); output
 *        token ids.
 * \param num_rounds Maximum number of rejection-sampling rounds; must be one
 *        of 1, 2, 4, 8, 16 (see DISPATCH_REJECTIVE_SAMPLING_NUM_ROUNDS).
 * \note `success` is passed as nullptr, so failure to sample within
 *       `num_rounds` rounds is not reported back to the caller.
 */
void _FlashInferParallelTopPSamplingFromProb(DLTensor* probs, DLTensor* uniform_samples,
                                             DLTensor* row_indices, DLTensor* top_p,
                                             DLTensor* sampled_token_ids, int num_rounds) {
  // All tensors must live on the same CUDA device.
  CHECK_EQ(probs->device.device_type, kDLCUDA) << "The device of probs must be CUDA.";
  CHECK_EQ(uniform_samples->device.device_type, kDLCUDA)
      << "The device of uniform_samples must be CUDA.";
  CHECK_EQ(row_indices->device.device_type, kDLCUDA) << "The device of row_indices must be CUDA.";
  CHECK_EQ(top_p->device.device_type, kDLCUDA) << "The device of top_p must be CUDA.";
  CHECK_EQ(sampled_token_ids->device.device_type, kDLCUDA)
      << "The device of sampled_token_ids must be CUDA.";

  int dev_id = probs->device.device_id;
  CHECK_EQ(uniform_samples->device.device_id, dev_id);
  CHECK_EQ(row_indices->device.device_id, dev_id);
  CHECK_EQ(top_p->device.device_id, dev_id);
  CHECK_EQ(sampled_token_ids->device.device_id, dev_id);

  // Only scalar (non-vectorized) float32 probabilities/samples/thresholds and
  // int32 indices/outputs are supported.
  CHECK(probs->dtype.lanes == 1 && uniform_samples->dtype.lanes == 1 &&
        row_indices->dtype.lanes == 1 && top_p->dtype.lanes == 1 &&
        sampled_token_ids->dtype.lanes == 1);
  CHECK(probs->dtype.code == kDLFloat && probs->dtype.bits == 32);
  CHECK(uniform_samples->dtype.code == kDLFloat && uniform_samples->dtype.bits == 32);
  CHECK(top_p->dtype.code == kDLFloat && top_p->dtype.bits == 32);
  CHECK(row_indices->dtype.code == kDLInt && row_indices->dtype.bits == 32);
  CHECK(sampled_token_ids->dtype.code == kDLInt && sampled_token_ids->dtype.bits == 32);

  CHECK_EQ(probs->ndim, 2);              // num_probs, vocab_size
  CHECK_EQ(uniform_samples->ndim, 1);    // batch_size * num_rounds,
  CHECK_EQ(row_indices->ndim, 1);        // batch_size,
  CHECK_EQ(top_p->ndim, 1);              // num_probs,
  CHECK_EQ(sampled_token_ids->ndim, 1);  // batch_size,
  int64_t num_probs = probs->shape[0];
  int64_t vocab_size = probs->shape[1];
  int64_t batch_size = row_indices->shape[0];
  CHECK_EQ(uniform_samples->shape[0], batch_size * num_rounds);
  CHECK_EQ(top_p->shape[0], num_probs);
  CHECK_EQ(sampled_token_ids->shape[0], batch_size);

  // Bind the runtime round count to a compile-time constant for the kernel
  // template. (Identifier typo fixed: "samping" -> "sampling".)
  DISPATCH_REJECTIVE_SAMPLING_NUM_ROUNDS(num_rounds, rej_sampling_num_rounds, {
    cudaError_t status =
        sampling::ParallelTopPSamplingFromProb<rej_sampling_num_rounds, float, int32_t>(
            static_cast<float*>(probs->data), static_cast<float*>(uniform_samples->data),
            static_cast<int32_t*>(sampled_token_ids->data), /*success=*/nullptr,
            static_cast<int32_t*>(row_indices->data), static_cast<float*>(top_p->data), batch_size,
            vocab_size);
    if (status != cudaSuccess) {
      LOG(FATAL) << "FlashInfer ParallelTopPSamplingFromProb error " << cudaGetErrorString(status);
    }
  });
}
790+
675791
TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_prefill_with_paged_kv_cache")
676792
.set_body_typed(_FlashInferAttentionPrefillWithPagedKVCache);
677793

@@ -708,4 +824,11 @@ TVM_REGISTER_GLOBAL("flashinfer.batch_qk_apply_rotary_in_place")
708824

709825
TVM_REGISTER_GLOBAL("flashinfer.single_prefill")
710826
.set_body_typed(_FlashInferSinglePrefillWithKVCache);
827+
711828
TVM_REGISTER_GLOBAL("flashinfer.single_decode").set_body_typed(_FlashInferSingleDecodeWithKVCache);
829+
830+
TVM_REGISTER_GLOBAL("flashinfer.sampling.parallel_sampling_from_prob")
831+
.set_body_typed(_FlashInferParallelSamplingFromProb);
832+
833+
TVM_REGISTER_GLOBAL("flashinfer.sampling.parallel_top_p_sampling_from_prob")
834+
.set_body_typed(_FlashInferParallelTopPSamplingFromProb);

0 commit comments

Comments
 (0)