@@ -362,6 +362,7 @@ __global__ void SamplingFromProbKernel(DType* probs, IdType* output, IdType* ind
   float aggregate(0);
   float u = curand_uniform(&state);

+#pragma unroll 2
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     probs_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
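For context on the new pragma: `#pragma unroll 2` asks nvcc to emit two copies of the loop body per iteration, letting the global loads of consecutive vocabulary chunks overlap without the register pressure of fully unrolling a loop whose trip count depends on the runtime vocabulary size `d`. A minimal standalone sketch of the loop shape it is applied to (not part of the patch; `BLOCK_THREADS`, `VEC_SIZE`, `ceil_div`, the scalar array, and `atomicAdd` are simplified stand-ins for the kernel's vectorized types and block reduction):

```cuda
#include <cstdint>

constexpr uint32_t BLOCK_THREADS = 128;
constexpr uint32_t VEC_SIZE = 4;

__host__ __device__ constexpr uint32_t ceil_div(uint32_t a, uint32_t b) {
  return (a + b - 1) / b;
}

// Each thread strides over the vocabulary in chunks of BLOCK_THREADS * VEC_SIZE
// elements and accumulates a partial sum, mirroring the probability loops above.
__global__ void PartialSumKernel(const float* probs, float* out, uint32_t d) {
  const uint32_t tx = threadIdx.x;
  float aggregate = 0.f;
#pragma unroll 2  // two loop bodies per iteration; the trip count stays runtime-dependent
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    float vec[VEC_SIZE] = {0.f};
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
        uint32_t idx = (i * BLOCK_THREADS + tx) * VEC_SIZE + j;
        if (idx < d) vec[j] = probs[idx];
      }
    }
    for (uint32_t j = 0; j < VEC_SIZE; ++j) aggregate += vec[j];
  }
  atomicAdd(out, aggregate);  // the real kernels feed a CUB block scan/reduce instead
}
```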
@@ -405,14 +406,10 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
       reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
           smem_sampling);

-  float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
-                              SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
-      probs, row_idx, d, temp_storage);
-
   vec_t<float, VEC_SIZE> probs_vec;
   float aggregate;
   float q = 1;
-  double low = 0, high = max_val;
+  double low = 0, high = 1.f;
   int sampled_id;
   int round = 0;
   do {
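The deleted `GetMaxValue` call was a full pass over the row used only to seed the upper bound of the pivot binary search below. Since the kernel consumes probabilities, every entry is already bounded by 1, so `high` can simply start at `1.f`, saving one sweep over the vocabulary per row. A single-threaded sketch of the idea (illustrative only; the kernel's version evaluates two pivots per round and uses block-wide reductions):

```cuda
#include <cstdint>
#include <vector>

// Binary-search a pivot so that at most k probabilities lie strictly above it.
// Inputs are probabilities, so every entry is <= 1 and high = 1.f is always a
// valid initial upper bound -- no separate max-reduction pass is required.
float FindTopKPivot(const std::vector<float>& probs, uint32_t k, int rounds = 32) {
  double low = 0, high = 1.f;  // previously: high = max element of the row
  for (int r = 0; r < rounds; ++r) {
    double mid = (low + high) / 2;
    uint32_t count_gt_mid = 0;
    for (float p : probs) {
      if (p > mid) ++count_gt_mid;
    }
    if (count_gt_mid > k) {
      low = mid;   // too many survivors: raise the pivot
    } else {
      high = mid;  // k or fewer survive: the pivot can come down
    }
  }
  return static_cast<float>(high);
}
```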
@@ -421,6 +418,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
     __syncthreads();
     float u = curand_uniform(&state) * q;
     aggregate = 0;
+#pragma unroll 2
     for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
       probs_vec.fill(0);
       if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -446,6 +444,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
     double pivot_1 = (pivot_0 + high) / 2;

     ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
+#pragma unroll 2
     for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
       probs_vec.fill(0);
       if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -522,20 +521,17 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
       reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
           smem_sampling);

-  float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
-                              SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
-      probs, row_idx, d, temp_storage);
-
   vec_t<float, VEC_SIZE> probs_vec;
   float aggregate;
   float q = 1;
-  double low = 0, high = max_val;
+  double low = 0, high = 1.f;
   int sampled_id;
   do {
     temp_storage.sampled_id = d;
     __syncthreads();
     float u = curand_uniform(&state) * q;
     aggregate = 0;
+#pragma unroll 2
     for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
       probs_vec.fill(0);
       if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -561,6 +557,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
     double pivot_1 = (pivot_0 + high) / 2;

     float aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0;
+#pragma unroll 2
     for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
       probs_vec.fill(0);
       if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -637,6 +634,7 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp

   vec_t<float, VEC_SIZE> probs_vec;
   float aggregate_gt_pivot = 0;
+#pragma unroll 2
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     probs_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -664,6 +662,7 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp
   temp_storage.sampled_id = d;
   __syncthreads();
   float u = curand_uniform(&state) * q;
+#pragma unroll 2
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     probs_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -709,20 +708,17 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr,
       reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
           smem_sampling);

-  float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
-                              SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
-      probs, row_idx, d, temp_storage);
-
   vec_t<float, VEC_SIZE> probs_vec;
   float aggregate;
   float q = 1;
-  double low = 0, high = max_val;
+  double low = 0, high = 1.f;
   int sampled_id;
   do {
     temp_storage.sampled_id = d;
     __syncthreads();
     float u = curand_uniform(&state) * q;
     aggregate = 0;
+#pragma unroll 2
     for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
       probs_vec.fill(0);
       if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -748,6 +744,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr,
     double pivot_1 = (pivot_0 + high) / 2;

     ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
+#pragma unroll 2
     for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
       probs_vec.fill(0);
       if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -988,6 +985,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
     double mid = (low + high) / 2;
     min_gt_low = high;
     max_le_high = low;
+#pragma unroll 2
     for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
       probs_vec.fill(0);
       if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1034,6 +1032,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
   float normalizer = math::ptx_rcp(max(sum_low, 1e-8));

   // normalize
+#pragma unroll 2
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     probs_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1085,6 +1084,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
     double mid = (low + high) / 2;
     min_gt_low = high;
     max_le_high = low;
+#pragma unroll 2
     for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
       logits_vec.fill(0);
       if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1132,6 +1132,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
   }

   // masking
+#pragma unroll 2
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     logits_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1185,6 +1186,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
     double mid = (low + high) / 2;
     min_gt_low = high;
     max_le_high = low;
+#pragma unroll 2
     for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
       probs_vec.fill(0);
       if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1236,6 +1238,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
   }

   // normalize
+#pragma unroll 2
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     probs_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1372,6 +1375,7 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
   float sum_relu_q_minus_p = 0;
   vec_t<float, VEC_SIZE> q_vec, p_vec;
   float relu_q_minus_p[VEC_SIZE];
+#pragma unroll 2
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     q_vec.fill(0);
     p_vec.fill(0);
@@ -1403,6 +1407,7 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
   float u = curand_uniform(&curand_state) * sum_relu_q_minus_p;

   float aggregate_relu_q_minus_p(0);
+#pragma unroll 2
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     q_vec.fill(0);
     p_vec.fill(0);
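For reference, the sampling loops touched in this patch all follow the same inverse-transform pattern: draw `u` uniformly, then stop at the first index whose running prefix sum of (pivot-filtered) probabilities exceeds `u`. A sequential sketch of that step, with the block-wide scan of the real kernels collapsed into a plain loop (illustrative names, not the kernel code):

```cuda
#include <cstdint>
#include <vector>

// Return the first index whose running sum of probabilities exceeds u in [0, 1).
int32_t SampleFromProb(const std::vector<float>& probs, float u) {
  float aggregate = 0.f;
  for (int32_t i = 0; i < static_cast<int32_t>(probs.size()); ++i) {
    aggregate += probs[i];
    if (u < aggregate) return i;
  }
  // Floating-point rounding can leave the total just below u; fall back to the last index.
  return static_cast<int32_t>(probs.size()) - 1;
}
```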