Skip to content

Commit 2be9ad7

Browse files
yzh119Yong Wu
and
Yong Wu
authored
ci: improve jenkins (#943)
- cancel previous build if new commit comes
- add task 3
- accelerate rope test

Co-authored-by: Yong Wu <[email protected]>
1 parent 594febe commit 2be9ad7

10 files changed

+316
-76
lines changed

Jenkinsfile

+36-4
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,17 @@ def unpack_lib(name, libs) {
6161
"""
6262
}
6363

64+
/**
 * Cancel an older build of the same branch when a newer one starts.
 *
 * Does nothing on the 'main' branch. On any other branch the build passes
 * two milestones: the previous build's number (only when this is not the
 * first build) and then its own number. Per the Milestone API, passing a
 * milestone lets Jenkins cancel an earlier build associated with it.
 */
def cancel_previous_build() {
  // Builds on main are never cancelled.
  if (env.BRANCH_NAME == 'main') {
    return
  }
  def thisBuildNumber = env.BUILD_NUMBER as int
  def previousBuildNumber = thisBuildNumber - 1
  // There is no predecessor to cancel for the very first build.
  if (previousBuildNumber >= 1) {
    milestone(previousBuildNumber)
  }
  milestone(thisBuildNumber)
}
74+
6475
def init_git(submodule = false) {
6576
cleanWs()
6677
// add retry in case checkout timeouts
@@ -84,10 +95,21 @@ def init_git(submodule = false) {
8495
// }
8596
// }
8697

87-
stage('JIT Unittest') {
98+
stage('Unittest') {
99+
cancel_previous_build()
88100
parallel(
89101
failFast: true,
90-
'GPU-G5-Test-1': {
102+
'AOT-Build-Import': {
103+
node('CPU-LARGE-SPOT') {
104+
ws(per_exec_ws('flashinfer-aot')) {
105+
init_git(true)
106+
sh(script: "ls -alh", label: 'Show work directory')
107+
sh(script: "./scripts/task_show_node_info.sh", label: 'Show node info')
108+
sh(script: "${docker_run} --no-gpu ./scripts/task_test_aot_build_import.sh", label: 'Test AOT Build and Import')
109+
}
110+
}
111+
},
112+
'JIT-Unittest-1': {
91113
node('GPU-G5-SPOT') {
92114
ws(per_exec_ws('flashinfer-unittest')) {
93115
init_git(true) // we need cutlass submodule
@@ -97,7 +119,7 @@ stage('JIT Unittest') {
97119
}
98120
}
99121
},
100-
'GPU-G5-Test-2': {
122+
'JIT-Unittest-2': {
101123
node('GPU-G5-SPOT') {
102124
ws(per_exec_ws('flashinfer-unittest')) {
103125
init_git(true) // we need cutlass submodule
@@ -107,7 +129,17 @@ stage('JIT Unittest') {
107129
}
108130
}
109131
},
110-
'GPU-G5-Test-4': {
132+
'JIT-Unittest-3': {
133+
node('GPU-G5-SPOT') {
134+
ws(per_exec_ws('flashinfer-unittest')) {
135+
init_git(true) // we need cutlass submodule
136+
sh(script: "ls -alh", label: 'Show work directory')
137+
sh(script: "./scripts/task_show_node_info.sh", label: 'Show node info')
138+
sh(script: "${docker_run} ./scripts/task_jit_run_tests_part3.sh", label: 'JIT Unittest Part 3')
139+
}
140+
}
141+
},
142+
'JIT-Unittest-4': {
111143
node('GPU-G5-SPOT') {
112144
ws(per_exec_ws('flashinfer-unittest')) {
113145
init_git(true) // we need cutlass submodule

benchmarks/bench_sampling.py

+144
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import torch
2+
from triton.testing import do_bench
3+
4+
import flashinfer
5+
6+
7+
def normal_distribution(std):
    """Return a noise generator that draws ``torch.randn`` samples scaled by ``std``.

    The returned callable has the signature ``(shape, device)``. Its
    ``__name__`` is overwritten with ``normal_distribution(std=<std>)`` so
    that callers can print a readable label for it.
    """

    def normal_noise(shape, device):
        # Standard-normal draw, then scale by the captured std.
        unit_noise = torch.randn(shape, device=device)
        return unit_noise * std

    label = f"normal_distribution(std={std})"
    normal_noise.__name__ = label
    return normal_noise
13+
14+
15+
def gumbel_distribution(beta):
    """Return a noise generator built from a double-log transform of uniform samples.

    The returned callable has the signature ``(shape, device)`` and computes
    ``log(-log(U + eps) + eps) / beta`` with ``U = torch.rand(shape)``. Its
    ``__name__`` is overwritten with ``gumbel_distribution(beta=<beta>)`` so
    that callers can print a readable label for it.

    NOTE(review): the textbook Gumbel transform is ``-log(-log(U))``; this
    expression has no leading negation. Behaviour is kept as-is here —
    TODO confirm whether the sign is intentional.
    """

    def gumbel_noise(shape, device):
        tiny = 1e-20  # added before each log so its argument is never exactly zero
        uniform = torch.rand(shape, device=device)
        neg_log_uniform = -torch.log(uniform + tiny)
        return torch.log(neg_log_uniform + tiny) / beta

    label = f"gumbel_distribution(beta={beta})"
    gumbel_noise.__name__ = label
    return gumbel_noise
23+
24+
25+
def init_seed_sampling(*args, **kwargs):
    """Seed torch with 42, then call ``flashinfer.sampling.sampling_from_probs``.

    All positional and keyword arguments are forwarded unchanged and the
    callee's result is returned as-is.
    """
    # Reseed on every call; presumably so each timed run sees the same
    # random stream — confirm against the benchmark's intent.
    torch.manual_seed(42)
    sampled = flashinfer.sampling.sampling_from_probs(*args, **kwargs)
    return sampled
28+
29+
30+
def init_seed_top_k_sampling(*args, **kwargs):
    """Seed torch with 42, then call ``flashinfer.sampling.top_k_sampling_from_probs``.

    All positional and keyword arguments are forwarded unchanged and the
    callee's result is returned as-is.
    """
    # Reseed on every call; presumably so each timed run sees the same
    # random stream — confirm against the benchmark's intent.
    torch.manual_seed(42)
    sampled = flashinfer.sampling.top_k_sampling_from_probs(*args, **kwargs)
    return sampled
33+
34+
35+
def init_seed_top_p_sampling(*args, **kwargs):
    """Seed torch with 42, then call ``flashinfer.sampling.top_p_sampling_from_probs``.

    All positional and keyword arguments are forwarded unchanged and the
    callee's result is returned as-is.
    """
    # Reseed on every call; presumably so each timed run sees the same
    # random stream — confirm against the benchmark's intent.
    torch.manual_seed(42)
    sampled = flashinfer.sampling.top_p_sampling_from_probs(*args, **kwargs)
    return sampled
38+
39+
40+
def _run_sampling_benchmark(title, sampler, param_name=None, param_values=(None,)):
    """Time ``sampler`` over the benchmark grid and print one line per case.

    The grid is vocab size x batch size x logit distribution x deterministic
    flag x (optionally) one extra swept parameter, iterated in that order.

    Args:
        title: Section heading printed right after the ``---`` separator.
        sampler: Callable invoked as ``sampler(probs, deterministic=...)``,
            or ``sampler(probs, value, deterministic=...)`` when
            ``param_name`` is given.
        param_name: Label of the extra positional argument (e.g. ``"k"``),
            or ``None`` when the sampler takes no extra argument.
        param_values: Values swept for that extra argument. The default
            single ``None`` entry makes the innermost loop run exactly once.
    """
    print("---")
    print(title)
    for vocab_size in [128512]:
        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
            for distrib in [
                normal_distribution(1),
                normal_distribution(5),
                gumbel_distribution(0.1),
                gumbel_distribution(1),
            ]:
                for deterministic in [True, False]:
                    for value in param_values:
                        logits = distrib((batch_size, vocab_size), device="cuda")
                        probs = torch.softmax(logits, dim=-1)
                        # Placeholder used only to size the I/O estimate below;
                        # assumes the sampler emits one int32 per row — TODO confirm.
                        samples = torch.zeros(
                            batch_size, dtype=torch.int32, device=probs.device
                        )
                        extra_args = () if param_name is None else (value,)
                        # The lambda is invoked by do_bench within this iteration,
                        # so capturing the loop variables by reference is safe.
                        ms = do_bench(
                            lambda: sampler(
                                probs, *extra_args, deterministic=deterministic
                            ),
                            warmup=100,
                            rep=1000,
                        )

                        # Bytes of the probability matrix plus the placeholder output.
                        io = (
                            probs.numel() * probs.element_size()
                            + samples.numel() * samples.element_size()
                        )
                        # bytes * 1e-6 / ms == GB/s when ms is in milliseconds.
                        bandwidth = io * 1e-6 / ms
                        extra = "" if param_name is None else f"{param_name}: {value}, "
                        print(
                            f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, {extra}duration: {ms*1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
                        )


@torch.inference_mode()
def main():
    """Benchmark naive, top-k and top-p sampling and print the timings.

    Each section prints a ``---`` separator, a heading, and then one result
    line per grid point with the measured duration and effective bandwidth.
    Runs under ``torch.inference_mode()`` and places all tensors on ``"cuda"``.
    """
    _run_sampling_benchmark("naive sampling", init_seed_sampling)
    _run_sampling_benchmark(
        "top-k sampling", init_seed_top_k_sampling, "k", [10, 100, 1000, 5000]
    )
    _run_sampling_benchmark(
        "top-p sampling", init_seed_top_p_sampling, "p", [0.1, 0.5, 0.9]
    )
141+
142+
143+
if __name__ == "__main__":
144+
main()

include/flashinfer/sampling.cuh

+20-15
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@ __global__ void SamplingFromProbKernel(DType* probs, IdType* output, IdType* ind
362362
float aggregate(0);
363363
float u = curand_uniform(&state);
364364

365+
#pragma unroll 2
365366
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
366367
probs_vec.fill(0);
367368
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -405,14 +406,10 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
405406
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
406407
smem_sampling);
407408

408-
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
409-
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
410-
probs, row_idx, d, temp_storage);
411-
412409
vec_t<float, VEC_SIZE> probs_vec;
413410
float aggregate;
414411
float q = 1;
415-
double low = 0, high = max_val;
412+
double low = 0, high = 1.f;
416413
int sampled_id;
417414
int round = 0;
418415
do {
@@ -421,6 +418,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
421418
__syncthreads();
422419
float u = curand_uniform(&state) * q;
423420
aggregate = 0;
421+
#pragma unroll 2
424422
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
425423
probs_vec.fill(0);
426424
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -446,6 +444,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
446444
double pivot_1 = (pivot_0 + high) / 2;
447445

448446
ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
447+
#pragma unroll 2
449448
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
450449
probs_vec.fill(0);
451450
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -522,20 +521,17 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
522521
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
523522
smem_sampling);
524523

525-
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
526-
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
527-
probs, row_idx, d, temp_storage);
528-
529524
vec_t<float, VEC_SIZE> probs_vec;
530525
float aggregate;
531526
float q = 1;
532-
double low = 0, high = max_val;
527+
double low = 0, high = 1.f;
533528
int sampled_id;
534529
do {
535530
temp_storage.sampled_id = d;
536531
__syncthreads();
537532
float u = curand_uniform(&state) * q;
538533
aggregate = 0;
534+
#pragma unroll 2
539535
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
540536
probs_vec.fill(0);
541537
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -561,6 +557,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
561557
double pivot_1 = (pivot_0 + high) / 2;
562558

563559
float aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0;
560+
#pragma unroll 2
564561
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
565562
probs_vec.fill(0);
566563
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -637,6 +634,7 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp
637634

638635
vec_t<float, VEC_SIZE> probs_vec;
639636
float aggregate_gt_pivot = 0;
637+
#pragma unroll 2
640638
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
641639
probs_vec.fill(0);
642640
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -664,6 +662,7 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp
664662
temp_storage.sampled_id = d;
665663
__syncthreads();
666664
float u = curand_uniform(&state) * q;
665+
#pragma unroll 2
667666
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
668667
probs_vec.fill(0);
669668
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -709,20 +708,17 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr,
709708
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
710709
smem_sampling);
711710

712-
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
713-
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
714-
probs, row_idx, d, temp_storage);
715-
716711
vec_t<float, VEC_SIZE> probs_vec;
717712
float aggregate;
718713
float q = 1;
719-
double low = 0, high = max_val;
714+
double low = 0, high = 1.f;
720715
int sampled_id;
721716
do {
722717
temp_storage.sampled_id = d;
723718
__syncthreads();
724719
float u = curand_uniform(&state) * q;
725720
aggregate = 0;
721+
#pragma unroll 2
726722
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
727723
probs_vec.fill(0);
728724
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -748,6 +744,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr,
748744
double pivot_1 = (pivot_0 + high) / 2;
749745

750746
ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
747+
#pragma unroll 2
751748
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
752749
probs_vec.fill(0);
753750
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -988,6 +985,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
988985
double mid = (low + high) / 2;
989986
min_gt_low = high;
990987
max_le_high = low;
988+
#pragma unroll 2
991989
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
992990
probs_vec.fill(0);
993991
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1034,6 +1032,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
10341032
float normalizer = math::ptx_rcp(max(sum_low, 1e-8));
10351033

10361034
// normalize
1035+
#pragma unroll 2
10371036
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
10381037
probs_vec.fill(0);
10391038
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1085,6 +1084,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
10851084
double mid = (low + high) / 2;
10861085
min_gt_low = high;
10871086
max_le_high = low;
1087+
#pragma unroll 2
10881088
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
10891089
logits_vec.fill(0);
10901090
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1132,6 +1132,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
11321132
}
11331133

11341134
// masking
1135+
#pragma unroll 2
11351136
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
11361137
logits_vec.fill(0);
11371138
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1185,6 +1186,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
11851186
double mid = (low + high) / 2;
11861187
min_gt_low = high;
11871188
max_le_high = low;
1189+
#pragma unroll 2
11881190
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
11891191
probs_vec.fill(0);
11901192
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1236,6 +1238,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
12361238
}
12371239

12381240
// normalize
1241+
#pragma unroll 2
12391242
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
12401243
probs_vec.fill(0);
12411244
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1372,6 +1375,7 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
13721375
float sum_relu_q_minus_p = 0;
13731376
vec_t<float, VEC_SIZE> q_vec, p_vec;
13741377
float relu_q_minus_p[VEC_SIZE];
1378+
#pragma unroll 2
13751379
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
13761380
q_vec.fill(0);
13771381
p_vec.fill(0);
@@ -1403,6 +1407,7 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
14031407
float u = curand_uniform(&curand_state) * sum_relu_q_minus_p;
14041408

14051409
float aggregate_relu_q_minus_p(0);
1410+
#pragma unroll 2
14061411
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
14071412
q_vec.fill(0);
14081413
p_vec.fill(0);

0 commit comments

Comments
 (0)