3
3
import torch .nn .functional as F
4
4
from sgl_kernel import tree_speculative_sampling_target_only
5
5
6
+ test_cases = [
7
+ (
8
+ 1 ,
9
+ 1 ,
10
+ [3 , - 1 , - 1 , 4 , 5 , 18 , 11 , - 1 , - 1 , - 1 , 12 , 18 ],
11
+ [[0 , 3 , 4 , 5 ], [6 , 10 , 11 , - 1 ]],
12
+ [3 , 2 ],
13
+ ),
14
+ (
15
+ 0 , # threshold_single
16
+ 0 , # threshold_acc
17
+ [1 , 2 , 18 , - 1 , - 1 , - 1 , 11 , - 1 , - 1 , - 1 , 12 , 18 ],
18
+ [[0 , 1 , 2 , - 1 ], [6 , 10 , 11 , - 1 ]],
19
+ [2 , 2 ],
20
+ ),
21
+ ]
22
+
23
+
24
+ @pytest .mark .parametrize (
25
+ "threshold_single, threshold_acc, expected_predicts, expected_accept_index, expected_accept_token_num" ,
26
+ test_cases ,
27
+ )
28
+ def test_tree_speculative_sampling_target_only (
29
+ threshold_single ,
30
+ threshold_acc ,
31
+ expected_predicts ,
32
+ expected_accept_index ,
33
+ expected_accept_token_num ,
34
+ ):
35
+ """
36
+ Tests the tree_speculative_sampling_target_only function using Pytest parameterization.
37
+ """
38
+ device = "cuda"
6
39
7
- def test_tree_speculative_sampling_target_only (threshold_single = 1 , threshold_acc = 1 ):
8
- print (
9
- f"\n ============= run test: { threshold_single = } { threshold_acc = } ==============\n "
10
- )
11
40
candidates = torch .tensor (
12
41
[
13
42
[0 , 1 , 2 , 3 , 4 , 5 ],
14
43
[7 , 8 , 9 , 10 , 11 , 12 ],
15
44
],
16
45
dtype = torch .int32 ,
17
- device = "cuda" ,
46
+ device = device ,
18
47
)
19
48
retrive_index = torch .tensor (
20
49
[
21
50
[0 , 1 , 2 , 3 , 4 , 5 ],
22
51
[6 , 7 , 8 , 9 , 10 , 11 ],
23
52
],
24
53
dtype = torch .int32 ,
25
- device = "cuda" ,
54
+ device = device ,
26
55
)
27
56
retrive_next_token = torch .tensor (
28
57
[
29
58
[1 , 2 , - 1 , 4 , 5 , - 1 ],
30
59
[4 , 2 , 3 , - 1 , 5 , - 1 ],
31
60
],
32
61
dtype = torch .int32 ,
33
- device = "cuda" ,
62
+ device = device ,
34
63
)
35
64
retrive_next_sibling = torch .tensor (
36
65
[
37
66
[- 1 , 3 , - 1 , - 1 , - 1 , - 1 ],
38
67
[- 1 , - 1 , - 1 , - 1 , 1 , - 1 ],
39
68
],
40
69
dtype = torch .int32 ,
41
- device = "cuda" ,
70
+ device = device ,
42
71
)
43
72
44
- target_logits = torch .full ((2 , 6 , 20 ), 1 , dtype = torch .float32 , device = "cuda" )
73
+ target_logits = torch .full ((2 , 6 , 20 ), 1 , dtype = torch .float32 , device = device )
45
74
target_logits [0 , 0 , 3 ] = 10
46
75
target_logits [0 , 3 , 4 ] = 10
47
76
target_logits [0 , 4 , 5 ] = 10
48
77
target_logits [1 , 0 , 11 ] = 10
49
78
target_logits [1 , 4 , 12 ] = 10
79
+
50
80
for i in range (target_logits .shape [0 ]):
51
81
for j in range (target_logits .shape [1 ]):
52
- if torch .max (target_logits [i ][ j ]) < 10 :
53
- target_logits [i ][ j ][ 18 ] = 10
82
+ if torch .max (target_logits [i , j ]) < 10 :
83
+ target_logits [i , j , 18 ] = 10
54
84
55
- temperatures = torch .tensor ([0.01 , 0.01 ], dtype = torch .float32 , device = "cuda" )
56
- predict_shape = (12 ,)
85
+ temperatures = torch .tensor ([0.01 , 0.01 ], dtype = torch .float32 , device = device )
86
+ bs , num_draft_tokens = candidates .shape
87
+ num_spec_step = len (expected_accept_index [0 ])
88
+ predict_shape = (len (expected_predicts ),)
57
89
58
- bs = candidates .shape [0 ]
59
- num_spec_step = 4
60
- num_draft_tokens = candidates .shape [1 ]
61
-
62
- predicts = torch .full (
63
- predict_shape , - 1 , dtype = torch .int32 , device = "cuda"
64
- ) # mutable
65
- accept_index = torch .full (
66
- (bs , num_spec_step ), - 1 , dtype = torch .int32 , device = "cuda"
67
- ) # mutable
68
- accept_token_num = torch .full ((bs ,), 0 , dtype = torch .int32 , device = "cuda" ) # mutable
90
+ predicts = torch .full (predict_shape , - 1 , dtype = torch .int32 , device = device )
91
+ accept_index = torch .full ((bs , num_spec_step ), - 1 , dtype = torch .int32 , device = device )
92
+ accept_token_num = torch .full ((bs ,), 0 , dtype = torch .int32 , device = device )
69
93
70
94
expanded_temperature = temperatures .unsqueeze (1 ).unsqueeze (1 )
71
95
target_probs = F .softmax (target_logits / expanded_temperature , dim = - 1 )
72
- draft_probs = torch .full_like (target_probs , 0 , dtype = torch .float32 , device = "cuda" )
73
-
74
- coins = torch .rand (bs , num_draft_tokens , device = "cuda" ).to (torch .float32 )
75
- print (f"{ candidates = } " )
76
- print (f"{ retrive_index = } " )
77
- print (f"{ retrive_next_token = } " )
78
- print (f"{ retrive_next_sibling = } " )
79
- print (f"{ coins = } " )
96
+ draft_probs = torch .full_like (target_probs , 0 , dtype = torch .float32 , device = device )
97
+ coins = torch .rand (bs , num_draft_tokens , device = device , dtype = torch .float32 )
80
98
81
99
tree_speculative_sampling_target_only (
82
100
predicts = predicts ,
@@ -94,24 +112,15 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc
94
112
deterministic = True ,
95
113
)
96
114
97
- print (f"{ predicts = } " )
98
- print (f"{ accept_index = } " )
99
- print (f"{ accept_token_num = } " )
100
-
101
- if threshold_single == 1 and threshold_acc == 1 :
102
- assert predicts .tolist () == [3 , - 1 , - 1 , 4 , 5 , 18 , 11 , - 1 , - 1 , - 1 , 12 , 18 ]
103
- assert accept_index .tolist () == [
104
- [0 , 3 , 4 , 5 ],
105
- [6 , 10 , 11 , - 1 ],
106
- ]
107
- assert accept_token_num .tolist () == [3 , 2 ]
108
- elif threshold_single == 0 and threshold_acc == 0 :
109
- assert predicts .tolist () == [1 , 2 , 18 , - 1 , - 1 , - 1 , 11 , - 1 , - 1 , - 1 , 12 , 18 ]
110
- assert accept_index .tolist () == [
111
- [0 , 1 , 2 , - 1 ],
112
- [6 , 10 , 11 , - 1 ],
113
- ]
114
- assert accept_token_num .tolist () == [2 , 2 ]
115
+ assert (
116
+ predicts .tolist () == expected_predicts
117
+ ), f"Predicts mismatch for thresholds ({ threshold_single } , { threshold_acc } )"
118
+ assert (
119
+ accept_index .tolist () == expected_accept_index
120
+ ), f"Accept index mismatch for thresholds ({ threshold_single } , { threshold_acc } )"
121
+ assert (
122
+ accept_token_num .tolist () == expected_accept_token_num
123
+ ), f"Accept token num mismatch for thresholds ({ threshold_single } , { threshold_acc } )"
115
124
116
125
117
126
if __name__ == "__main__" :
0 commit comments