Commit ed01b45

[Misc] Clean sgl-kernel test (#5216)
1 parent d050df3 commit ed01b45

8 files changed: 63 additions and 76 deletions

sgl-kernel/tests/speculative/test_eagle_utils.py

Lines changed: 0 additions & 11 deletions
@@ -49,7 +49,6 @@ def test_verify_tree_greedy():
             if torch.max(target_logits[i][j]) < 10:
                 target_logits[i][j][18] = 10

-    print(f"{target_logits=}")
     target_predict = torch.argmax(target_logits, dim=-1).to(torch.int32)
     predict_shape = (12,)

@@ -65,12 +64,6 @@ def test_verify_tree_greedy():
     )  # mutable
     accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda")  # mutable

-    print(f"{candidates=}")
-    print(f"{retrive_index=}")
-    print(f"{retrive_next_token=}")
-    print(f"{retrive_next_sibling=}")
-    print(f"{target_predict=}")
-
     verify_tree_greedy(
         predicts=predicts,
         accept_index=accept_index,
@@ -82,10 +75,6 @@ def test_verify_tree_greedy():
         target_predict=target_predict,
     )

-    print(f"{predicts=}")
-    print(f"{accept_index=}")
-    print(f"{accept_token_num=}")
-
     # Check the expected output.
     assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
     assert accept_index.tolist() == [
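These deletions follow a general pytest convention: stdout is captured per test and only replayed on failure, so success-path debug prints never show up in a normal run. A minimal sketch (not from this commit; the helper name is hypothetical) of keeping the same detail available by moving it into the assertion message instead:

    import torch

    def check_predicts(predicts: torch.Tensor, expected: list[int]) -> None:
        # Hypothetical helper: the tensor contents appear in pytest's failure
        # output exactly when they are needed, instead of on every run.
        actual = predicts.tolist()
        assert actual == expected, f"verify_tree_greedy mismatch: {actual=} {expected=}"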

sgl-kernel/tests/speculative/test_speculative_sampling.py

Lines changed: 59 additions & 50 deletions
@@ -3,80 +3,98 @@
 import torch.nn.functional as F
 from sgl_kernel import tree_speculative_sampling_target_only

+test_cases = [
+    (
+        1,
+        1,
+        [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18],
+        [[0, 3, 4, 5], [6, 10, 11, -1]],
+        [3, 2],
+    ),
+    (
+        0,  # threshold_single
+        0,  # threshold_acc
+        [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18],
+        [[0, 1, 2, -1], [6, 10, 11, -1]],
+        [2, 2],
+    ),
+]
+
+
+@pytest.mark.parametrize(
+    "threshold_single, threshold_acc, expected_predicts, expected_accept_index, expected_accept_token_num",
+    test_cases,
+)
+def test_tree_speculative_sampling_target_only(
+    threshold_single,
+    threshold_acc,
+    expected_predicts,
+    expected_accept_index,
+    expected_accept_token_num,
+):
+    """
+    Tests the tree_speculative_sampling_target_only function using Pytest parameterization.
+    """
+    device = "cuda"

-def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1):
-    print(
-        f"\n============= run test: {threshold_single=} {threshold_acc=} ==============\n"
-    )
     candidates = torch.tensor(
         [
             [0, 1, 2, 3, 4, 5],
             [7, 8, 9, 10, 11, 12],
         ],
         dtype=torch.int32,
-        device="cuda",
+        device=device,
     )
     retrive_index = torch.tensor(
         [
             [0, 1, 2, 3, 4, 5],
             [6, 7, 8, 9, 10, 11],
         ],
         dtype=torch.int32,
-        device="cuda",
+        device=device,
     )
     retrive_next_token = torch.tensor(
         [
             [1, 2, -1, 4, 5, -1],
             [4, 2, 3, -1, 5, -1],
         ],
         dtype=torch.int32,
-        device="cuda",
+        device=device,
     )
     retrive_next_sibling = torch.tensor(
         [
             [-1, 3, -1, -1, -1, -1],
             [-1, -1, -1, -1, 1, -1],
         ],
         dtype=torch.int32,
-        device="cuda",
+        device=device,
     )

-    target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device="cuda")
+    target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device=device)
     target_logits[0, 0, 3] = 10
     target_logits[0, 3, 4] = 10
     target_logits[0, 4, 5] = 10
     target_logits[1, 0, 11] = 10
     target_logits[1, 4, 12] = 10
+
     for i in range(target_logits.shape[0]):
         for j in range(target_logits.shape[1]):
-            if torch.max(target_logits[i][j]) < 10:
-                target_logits[i][j][18] = 10
+            if torch.max(target_logits[i, j]) < 10:
+                target_logits[i, j, 18] = 10

-    temperatures = torch.tensor([0.01, 0.01], dtype=torch.float32, device="cuda")
-    predict_shape = (12,)
+    temperatures = torch.tensor([0.01, 0.01], dtype=torch.float32, device=device)
+    bs, num_draft_tokens = candidates.shape
+    num_spec_step = len(expected_accept_index[0])
+    predict_shape = (len(expected_predicts),)

-    bs = candidates.shape[0]
-    num_spec_step = 4
-    num_draft_tokens = candidates.shape[1]
-
-    predicts = torch.full(
-        predict_shape, -1, dtype=torch.int32, device="cuda"
-    )  # mutable
-    accept_index = torch.full(
-        (bs, num_spec_step), -1, dtype=torch.int32, device="cuda"
-    )  # mutable
-    accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda")  # mutable
+    predicts = torch.full(predict_shape, -1, dtype=torch.int32, device=device)
+    accept_index = torch.full((bs, num_spec_step), -1, dtype=torch.int32, device=device)
+    accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device=device)

     expanded_temperature = temperatures.unsqueeze(1).unsqueeze(1)
     target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
-    draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device="cuda")
-
-    coins = torch.rand(bs, num_draft_tokens, device="cuda").to(torch.float32)
-    print(f"{candidates=}")
-    print(f"{retrive_index=}")
-    print(f"{retrive_next_token=}")
-    print(f"{retrive_next_sibling=}")
-    print(f"{coins=}")
+    draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device=device)
+    coins = torch.rand(bs, num_draft_tokens, device=device, dtype=torch.float32)

     tree_speculative_sampling_target_only(
         predicts=predicts,
@@ -94,24 +112,15 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc
         deterministic=True,
     )

-    print(f"{predicts=}")
-    print(f"{accept_index=}")
-    print(f"{accept_token_num=}")
-
-    if threshold_single == 1 and threshold_acc == 1:
-        assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
-        assert accept_index.tolist() == [
-            [0, 3, 4, 5],
-            [6, 10, 11, -1],
-        ]
-        assert accept_token_num.tolist() == [3, 2]
-    elif threshold_single == 0 and threshold_acc == 0:
-        assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
-        assert accept_index.tolist() == [
-            [0, 1, 2, -1],
-            [6, 10, 11, -1],
-        ]
-        assert accept_token_num.tolist() == [2, 2]
+    assert (
+        predicts.tolist() == expected_predicts
+    ), f"Predicts mismatch for thresholds ({threshold_single}, {threshold_acc})"
+    assert (
+        accept_index.tolist() == expected_accept_index
+    ), f"Accept index mismatch for thresholds ({threshold_single}, {threshold_acc})"
+    assert (
+        accept_token_num.tolist() == expected_accept_token_num
+    ), f"Accept token num mismatch for thresholds ({threshold_single}, {threshold_acc})"


 if __name__ == "__main__":
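The core of this rewrite is the move from an if/elif over two hard-coded threshold settings to @pytest.mark.parametrize, which turns each (threshold_single, threshold_acc) combination into a separately reported test item; under plain pytest collection, the old version only ever ran with its default (1, 1) arguments. A stripped-down sketch of the pattern (function name and values here are illustrative, not from the commit):

    import pytest

    cases = [
        # (threshold_single, threshold_acc, expected_accept_token_num)
        (1, 1, [3, 2]),
        (0, 0, [2, 2]),
    ]

    @pytest.mark.parametrize("threshold_single, threshold_acc, expected", cases)
    def test_thresholds(threshold_single, threshold_acc, expected):
        # pytest generates one test per tuple and unpacks it into the named
        # arguments, so each case passes or fails independently.
        assert len(expected) == 2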

sgl-kernel/tests/test_fp8_blockwise_gemm.py

Lines changed: 0 additions & 1 deletion
@@ -79,7 +79,6 @@ def _test_accuracy_once(M, N, K, out_dtype, device):
     rtol = 0.02
     atol = 1
     torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
-    print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")


 @pytest.mark.parametrize("M", [1, 3, 5, 127, 128, 512, 1024, 4096])

sgl-kernel/tests/test_int8_gemm.py

Lines changed: 0 additions & 1 deletion
@@ -28,7 +28,6 @@ def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
     o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
     o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
     torch.testing.assert_close(o, o1)
-    print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")


 @pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
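In both GEMM tests the removed line was a success-path print after torch.testing.assert_close. The print is redundant: assert_close is silent when the tensors match and raises an AssertionError with a detailed report (mismatch count, greatest absolute and relative difference) when they do not. A minimal sketch of that behavior:

    import torch

    a = torch.tensor([1.0, 2.0, 3.0])

    torch.testing.assert_close(a, a.clone())  # passes silently; nothing to report

    try:
        torch.testing.assert_close(
            a, torch.tensor([1.0, 2.0, 3.1]), rtol=1e-3, atol=1e-5
        )
    except AssertionError as err:
        print(err)  # the report locates the mismatched element and quantifies it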

sgl-kernel/tests/test_lightning_attention_decode.py

Lines changed: 0 additions & 4 deletions
@@ -70,17 +70,13 @@ def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim
         ref_output,
         rtol=rtol,
         atol=atol,
-        msg=f"Output mismatch for batch_size={batch_size}, num_heads={num_heads}, "
-        f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
     )

     torch.testing.assert_close(
         new_kv,
         ref_new_kv,
         rtol=rtol,
         atol=atol,
-        msg=f"New KV mismatch for batch_size={batch_size}, num_heads={num_heads}, "
-        f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
     )
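The deleted msg= arguments repeated parameters that pytest already surfaces: with parametrized tests, the failing configuration is encoded in the test's ID and visible in the captured call arguments. A sketch under that assumption (parameter sets here are illustrative):

    import pytest
    import torch

    @pytest.mark.parametrize("batch_size", [1, 4])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_decode_output(dtype, batch_size):
        out = torch.zeros(batch_size, dtype=dtype)
        ref = torch.zeros(batch_size, dtype=dtype)
        # A failure is reported under the parametrized test ID, which already
        # identifies the (dtype, batch_size) combination, so a hand-written
        # msg= repeating those values adds nothing.
        torch.testing.assert_close(out, ref)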

sgl-kernel/tests/test_moe_topk_softmax.py

Lines changed: 2 additions & 4 deletions
@@ -42,12 +42,10 @@ def test_topk_softmax(num_tokens, num_experts, topk):
         topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3
     ), f"Weights mismatch: torch={topk_indices_ref} vs SGLang={topk_weights}"

-    assert torch.equal(
-        topk_indices_ref, topk_indices
+    assert torch.allclose(
+        topk_indices_ref.int(), topk_indices, atol=0, rtol=0
     ), f"Indices mismatch: torch={topk_indices_ref}, SGLang={topk_indices}"

-    print("✅ Native torch and custom kernel implementations match.")
-

 if __name__ == "__main__":
     pytest.main([__file__])
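The index check changes from torch.equal to torch.allclose with zero tolerances plus an explicit .int() cast. The motivation appears to be dtype: torch.equal returns False whenever the two tensors' dtypes differ, even for identical values, whereas casting the reference to int32 first lets an exact value comparison succeed. A small sketch (tensor values are illustrative):

    import torch

    ref = torch.tensor([1, 2, 3], dtype=torch.int64)  # e.g. indices from torch.topk
    out = torch.tensor([1, 2, 3], dtype=torch.int32)  # e.g. kernel output

    print(torch.equal(ref, out))                           # False: dtype mismatch
    print(torch.allclose(ref.int(), out, atol=0, rtol=0))  # True: exact value match

With atol=0 and rtol=0 the comparison remains element-for-element exact, so no strictness is given up.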

sgl-kernel/tests/test_per_token_group_quant_8bit.py

Lines changed: 2 additions & 2 deletions
@@ -304,10 +304,10 @@ def test_per_token_group_quant_with_column_major(
         scale_tma_aligned=scale_tma_aligned,
     )

-    assert torch.allclose(
+    torch.testing.assert_close(
         x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
     )
-    assert torch.allclose(
+    torch.testing.assert_close(
         x_s_triton.contiguous(), x_s_sglang.contiguous(), rtol=1e-3, atol=1e-5
     )
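Swapping assert torch.allclose(...) for torch.testing.assert_close(...) keeps the same tolerances; the practical difference is in failure reporting. allclose returns a bare bool, so a failing assert reports nothing beyond AssertionError, while assert_close raises an error that names and quantifies the mismatch. A sketch of the contrast:

    import torch

    x = torch.tensor([1.0, 2.0])
    y = torch.tensor([1.0, 2.5])

    try:
        assert torch.allclose(x, y, rtol=1e-3, atol=1e-5)
    except AssertionError as err:
        print(repr(err))  # bare AssertionError; no hint of where x and y differ

    try:
        torch.testing.assert_close(x, y, rtol=1e-3, atol=1e-5)
    except AssertionError as err:
        print(err)  # report includes the mismatched index and the difference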

sgl-kernel/tests/test_rotary_embedding.py

Lines changed: 0 additions & 3 deletions
@@ -187,9 +187,6 @@ def test_correctness(
         pos_ids, query_flashinfer, key_flashinfer
     )

-    print(query_ref_out)
-    print(query_flashinfer_out)
-
     torch.testing.assert_close(
         query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
     )
