#include <flashinfer/attention/cascade.cuh>
#include <flashinfer/decode_attention_decl.cuh>
#include <flashinfer/prefill_attention_decl.cuh>
+ #include <flashinfer/sampling.cuh>
#include <optional>

using tvm::runtime::Array;
@@ -47,6 +48,26 @@ using namespace flashinfer;
    LOG(FATAL) << "Unsupported data type " << dl_dtype.code; \
  }
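+ // Maps the runtime `num_rounds` argument to a compile-time constant so the
+ // rejection-sampling kernel can be instantiated per round count. Only the
+ // power-of-two values listed here (1, 2, 4, 8, 16) are supported.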
+ #define DISPATCH_REJECTIVE_SAMPLING_NUM_ROUNDS(num_rounds, NUM_ROUNDS, ...) \
+   if (num_rounds == 1) {                                                    \
+     constexpr uint32_t NUM_ROUNDS = 1;                                      \
+     __VA_ARGS__                                                             \
+   } else if (num_rounds == 2) {                                             \
+     constexpr uint32_t NUM_ROUNDS = 2;                                      \
+     __VA_ARGS__                                                             \
+   } else if (num_rounds == 4) {                                             \
+     constexpr uint32_t NUM_ROUNDS = 4;                                      \
+     __VA_ARGS__                                                             \
+   } else if (num_rounds == 8) {                                             \
+     constexpr uint32_t NUM_ROUNDS = 8;                                      \
+     __VA_ARGS__                                                             \
+   } else if (num_rounds == 16) {                                            \
+     constexpr uint32_t NUM_ROUNDS = 16;                                     \
+     __VA_ARGS__                                                             \
+   } else {                                                                  \
+     LOG(FATAL) << "Unsupported number of rejective sampling rounds " << num_rounds; \
+   }
+
int _FlashInferSinglePrefillWithKVCache(DLTensor* q, DLTensor* k, DLTensor* v, DLTensor* tmp,
                                        bool causal, int64_t kv_layout, int64_t pos_encoding_mode,
                                        bool allow_fp16_qk_reduction, double rope_scale,
@@ -534,6 +555,10 @@ void _FlashInferAttentionPrefillWithRaggedKVCache(
      /*causal=*/bool(causal), QKVLayout::kNHD, PosEncodingMode(pos_encoding_mode),
      /*allow_fp16_qk_reduction=*/false, sm_scale, rope_scale, rope_theta,
      /*stream=*/0);
+   if (status != cudaSuccess) {
+     LOG(FATAL) << "FlashInfer AttentionPrefillWithRaggedKVCache error "
+                << cudaGetErrorString(status);
+   }
  })})})
}
@@ -609,7 +634,7 @@ void _FlashInferMergeState(DLTensor* v_a, DLTensor* s_a, DLTensor* v_b, DLTensor
        static_cast<dtype_out*>(v_merged->data), static_cast<float*>(s_merged->data),
        batch_size, num_heads, head_dim);
    if (status != cudaSuccess) {
-     LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status);
+     LOG(FATAL) << "FlashInfer CUDA MergeState error " << cudaGetErrorString(status);
    }
  })});
}
@@ -651,7 +676,7 @@ void _FlashInferMergeStateInPlace(DLTensor* v, DLTensor* s, DLTensor* v_other, D
        static_cast<dtype*>(v_other->data), static_cast<float*>(s_other->data),
        batch_size, num_heads, head_dim);
    if (status != cudaSuccess) {
-     LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status);
+     LOG(FATAL) << "FlashInfer CUDA MergeStateInPlace error " << cudaGetErrorString(status);
    }
  });
}
@@ -672,6 +697,97 @@ void _FlashInferBatchQKApplyRotaryInPlace(DLTensor* q, DLTensor* k, DLTensor* in
  })});
}

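+ // Samples one token id per batch element: element i draws from row
+ // `row_indices[i]` of `probs` using the pre-drawn uniform random number
+ // `uniform_samples[i]`, writing the result to `sampled_token_ids[i]`.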
+ void _FlashInferParallelSamplingFromProb(DLTensor* probs, DLTensor* uniform_samples,
+                                          DLTensor* row_indices, DLTensor* sampled_token_ids) {
+   CHECK_EQ(probs->device.device_type, kDLCUDA) << "The device of probs must be CUDA.";
+   CHECK_EQ(uniform_samples->device.device_type, kDLCUDA)
+       << "The device of uniform_samples must be CUDA.";
+   CHECK_EQ(row_indices->device.device_type, kDLCUDA) << "The device of row_indices must be CUDA.";
+   CHECK_EQ(sampled_token_ids->device.device_type, kDLCUDA)
+       << "The device of sampled_token_ids must be CUDA.";
+
+   int dev_id = probs->device.device_id;
+   CHECK_EQ(uniform_samples->device.device_id, dev_id);
+   CHECK_EQ(row_indices->device.device_id, dev_id);
+   CHECK_EQ(sampled_token_ids->device.device_id, dev_id);
+
+   CHECK(probs->dtype.lanes == 1 && uniform_samples->dtype.lanes == 1 &&
+         row_indices->dtype.lanes == 1 && sampled_token_ids->dtype.lanes == 1);
+   CHECK(probs->dtype.code == kDLFloat && probs->dtype.bits == 32);
+   CHECK(uniform_samples->dtype.code == kDLFloat && uniform_samples->dtype.bits == 32);
+   CHECK(row_indices->dtype.code == kDLInt && row_indices->dtype.bits == 32);
+   CHECK(sampled_token_ids->dtype.code == kDLInt && sampled_token_ids->dtype.bits == 32);
+
+   CHECK_EQ(probs->ndim, 2);              // num_probs, vocab_size
+   CHECK_EQ(uniform_samples->ndim, 1);    // batch_size,
+   CHECK_EQ(row_indices->ndim, 1);        // batch_size,
+   CHECK_EQ(sampled_token_ids->ndim, 1);  // batch_size,
+   int64_t num_probs = probs->shape[0];
+   int64_t vocab_size = probs->shape[1];
+   int64_t batch_size = row_indices->shape[0];
+   CHECK_EQ(uniform_samples->shape[0], batch_size);
+   CHECK_EQ(sampled_token_ids->shape[0], batch_size);
+
+   cudaError_t status = sampling::ParallelSamplingFromProb<float, int32_t>(
+       static_cast<float*>(probs->data), static_cast<float*>(uniform_samples->data),
+       static_cast<int32_t*>(sampled_token_ids->data), static_cast<int32_t*>(row_indices->data),
+       batch_size, vocab_size);
+   if (status != cudaSuccess) {
+     LOG(FATAL) << "FlashInfer ParallelSamplingFromProb error " << cudaGetErrorString(status);
+   }
+ }
+
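+ // Top-p (nucleus) sampling via rejection sampling: each batch element
+ // consumes `num_rounds` uniform draws from `uniform_samples`, with the
+ // per-distribution truncation threshold given by `top_p`.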
+ void _FlashInferParallelTopPSamplingFromProb(DLTensor* probs, DLTensor* uniform_samples,
+                                              DLTensor* row_indices, DLTensor* top_p,
+                                              DLTensor* sampled_token_ids, int num_rounds) {
+   CHECK_EQ(probs->device.device_type, kDLCUDA) << "The device of probs must be CUDA.";
+   CHECK_EQ(uniform_samples->device.device_type, kDLCUDA)
+       << "The device of uniform_samples must be CUDA.";
+   CHECK_EQ(row_indices->device.device_type, kDLCUDA) << "The device of row_indices must be CUDA.";
+   CHECK_EQ(top_p->device.device_type, kDLCUDA) << "The device of top_p must be CUDA.";
+   CHECK_EQ(sampled_token_ids->device.device_type, kDLCUDA)
+       << "The device of sampled_token_ids must be CUDA.";
+
+   int dev_id = probs->device.device_id;
+   CHECK_EQ(uniform_samples->device.device_id, dev_id);
+   CHECK_EQ(row_indices->device.device_id, dev_id);
+   CHECK_EQ(top_p->device.device_id, dev_id);
+   CHECK_EQ(sampled_token_ids->device.device_id, dev_id);
+
+   CHECK(probs->dtype.lanes == 1 && uniform_samples->dtype.lanes == 1 &&
+         row_indices->dtype.lanes == 1 && top_p->dtype.lanes == 1 &&
+         sampled_token_ids->dtype.lanes == 1);
+   CHECK(probs->dtype.code == kDLFloat && probs->dtype.bits == 32);
+   CHECK(uniform_samples->dtype.code == kDLFloat && uniform_samples->dtype.bits == 32);
+   CHECK(top_p->dtype.code == kDLFloat && top_p->dtype.bits == 32);
+   CHECK(row_indices->dtype.code == kDLInt && row_indices->dtype.bits == 32);
+   CHECK(sampled_token_ids->dtype.code == kDLInt && sampled_token_ids->dtype.bits == 32);
+
+   CHECK_EQ(probs->ndim, 2);              // num_probs, vocab_size
+   CHECK_EQ(uniform_samples->ndim, 1);    // batch_size * num_rounds,
+   CHECK_EQ(row_indices->ndim, 1);        // batch_size,
+   CHECK_EQ(top_p->ndim, 1);              // num_probs,
+   CHECK_EQ(sampled_token_ids->ndim, 1);  // batch_size,
+   int64_t num_probs = probs->shape[0];
+   int64_t vocab_size = probs->shape[1];
+   int64_t batch_size = row_indices->shape[0];
+   CHECK_EQ(uniform_samples->shape[0], batch_size * num_rounds);
+   CHECK_EQ(top_p->shape[0], num_probs);
+   CHECK_EQ(sampled_token_ids->shape[0], batch_size);
+
+   DISPATCH_REJECTIVE_SAMPLING_NUM_ROUNDS(num_rounds, rej_sampling_num_rounds, {
+     cudaError_t status =
+         sampling::ParallelTopPSamplingFromProb<rej_sampling_num_rounds, float, int32_t>(
+             static_cast<float*>(probs->data), static_cast<float*>(uniform_samples->data),
+             static_cast<int32_t*>(sampled_token_ids->data), /*success=*/nullptr,
+             static_cast<int32_t*>(row_indices->data), static_cast<float*>(top_p->data),
+             batch_size, vocab_size);
+     if (status != cudaSuccess) {
+       LOG(FATAL) << "FlashInfer ParallelTopPSamplingFromProb error " << cudaGetErrorString(status);
+     }
+   });
+ }
+
TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_prefill_with_paged_kv_cache")
    .set_body_typed(_FlashInferAttentionPrefillWithPagedKVCache);
@@ -708,4 +824,11 @@ TVM_REGISTER_GLOBAL("flashinfer.batch_qk_apply_rotary_in_place")
TVM_REGISTER_GLOBAL("flashinfer.single_prefill")
    .set_body_typed(_FlashInferSinglePrefillWithKVCache);
+
TVM_REGISTER_GLOBAL("flashinfer.single_decode").set_body_typed(_FlashInferSingleDecodeWithKVCache);
+
+ TVM_REGISTER_GLOBAL("flashinfer.sampling.parallel_sampling_from_prob")
+     .set_body_typed(_FlashInferParallelSamplingFromProb);
+
+ TVM_REGISTER_GLOBAL("flashinfer.sampling.parallel_top_p_sampling_from_prob")
+     .set_body_typed(_FlashInferParallelTopPSamplingFromProb);
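
For reference, a minimal call-site sketch (not part of this commit; the function name SampleTokens is illustrative) of how the newly registered sampler could be looked up and invoked through TVM's global registry:

#include <tvm/runtime/logging.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

// Look up the packed function registered above and sample one token per batch
// row. The NDArrays must live on the same CUDA device and match the shapes and
// dtypes checked inside the binding (float32 probs/samples, int32 indices/ids).
void SampleTokens(tvm::runtime::NDArray probs, tvm::runtime::NDArray uniform_samples,
                  tvm::runtime::NDArray row_indices, tvm::runtime::NDArray sampled_token_ids) {
  const tvm::runtime::PackedFunc* f =
      tvm::runtime::Registry::Get("flashinfer.sampling.parallel_sampling_from_prob");
  ICHECK(f != nullptr) << "flashinfer sampling function not registered";
  (*f)(probs, uniform_samples, row_indices, sampled_token_ids);
}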