Skip to content

[Misc] Clean sgl-kernel test #5216

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions sgl-kernel/tests/speculative/test_eagle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def test_verify_tree_greedy():
if torch.max(target_logits[i][j]) < 10:
target_logits[i][j][18] = 10

print(f"{target_logits=}")
target_predict = torch.argmax(target_logits, dim=-1).to(torch.int32)
predict_shape = (12,)

Expand All @@ -65,12 +64,6 @@ def test_verify_tree_greedy():
) # mutable
accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda") # mutable

print(f"{candidates=}")
print(f"{retrive_index=}")
print(f"{retrive_next_token=}")
print(f"{retrive_next_sibling=}")
print(f"{target_predict=}")

verify_tree_greedy(
predicts=predicts,
accept_index=accept_index,
Expand All @@ -82,10 +75,6 @@ def test_verify_tree_greedy():
target_predict=target_predict,
)

print(f"{predicts=}")
print(f"{accept_index=}")
print(f"{accept_token_num=}")

# Check the expected output.
assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
assert accept_index.tolist() == [
Expand Down
109 changes: 59 additions & 50 deletions sgl-kernel/tests/speculative/test_speculative_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,80 +3,98 @@
import torch.nn.functional as F
from sgl_kernel import tree_speculative_sampling_target_only

test_cases = [
(
1,
1,
[3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18],
[[0, 3, 4, 5], [6, 10, 11, -1]],
[3, 2],
),
(
0, # threshold_single
0, # threshold_acc
[1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18],
[[0, 1, 2, -1], [6, 10, 11, -1]],
[2, 2],
),
]


@pytest.mark.parametrize(
"threshold_single, threshold_acc, expected_predicts, expected_accept_index, expected_accept_token_num",
test_cases,
)
def test_tree_speculative_sampling_target_only(
threshold_single,
threshold_acc,
expected_predicts,
expected_accept_index,
expected_accept_token_num,
):
"""
Tests the tree_speculative_sampling_target_only function using Pytest parameterization.
"""
device = "cuda"

def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1):
print(
f"\n============= run test: {threshold_single=} {threshold_acc=} ==============\n"
)
candidates = torch.tensor(
[
[0, 1, 2, 3, 4, 5],
[7, 8, 9, 10, 11, 12],
],
dtype=torch.int32,
device="cuda",
device=device,
)
retrive_index = torch.tensor(
[
[0, 1, 2, 3, 4, 5],
[6, 7, 8, 9, 10, 11],
],
dtype=torch.int32,
device="cuda",
device=device,
)
retrive_next_token = torch.tensor(
[
[1, 2, -1, 4, 5, -1],
[4, 2, 3, -1, 5, -1],
],
dtype=torch.int32,
device="cuda",
device=device,
)
retrive_next_sibling = torch.tensor(
[
[-1, 3, -1, -1, -1, -1],
[-1, -1, -1, -1, 1, -1],
],
dtype=torch.int32,
device="cuda",
device=device,
)

target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device="cuda")
target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device=device)
target_logits[0, 0, 3] = 10
target_logits[0, 3, 4] = 10
target_logits[0, 4, 5] = 10
target_logits[1, 0, 11] = 10
target_logits[1, 4, 12] = 10

for i in range(target_logits.shape[0]):
for j in range(target_logits.shape[1]):
if torch.max(target_logits[i][j]) < 10:
target_logits[i][j][18] = 10
if torch.max(target_logits[i, j]) < 10:
target_logits[i, j, 18] = 10

temperatures = torch.tensor([0.01, 0.01], dtype=torch.float32, device="cuda")
predict_shape = (12,)
temperatures = torch.tensor([0.01, 0.01], dtype=torch.float32, device=device)
bs, num_draft_tokens = candidates.shape
num_spec_step = len(expected_accept_index[0])
predict_shape = (len(expected_predicts),)

bs = candidates.shape[0]
num_spec_step = 4
num_draft_tokens = candidates.shape[1]

predicts = torch.full(
predict_shape, -1, dtype=torch.int32, device="cuda"
) # mutable
accept_index = torch.full(
(bs, num_spec_step), -1, dtype=torch.int32, device="cuda"
) # mutable
accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda") # mutable
predicts = torch.full(predict_shape, -1, dtype=torch.int32, device=device)
accept_index = torch.full((bs, num_spec_step), -1, dtype=torch.int32, device=device)
accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device=device)

expanded_temperature = temperatures.unsqueeze(1).unsqueeze(1)
target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device="cuda")

coins = torch.rand(bs, num_draft_tokens, device="cuda").to(torch.float32)
print(f"{candidates=}")
print(f"{retrive_index=}")
print(f"{retrive_next_token=}")
print(f"{retrive_next_sibling=}")
print(f"{coins=}")
draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device=device)
coins = torch.rand(bs, num_draft_tokens, device=device, dtype=torch.float32)

tree_speculative_sampling_target_only(
predicts=predicts,
Expand All @@ -94,24 +112,15 @@ def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc
deterministic=True,
)

print(f"{predicts=}")
print(f"{accept_index=}")
print(f"{accept_token_num=}")

if threshold_single == 1 and threshold_acc == 1:
assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
assert accept_index.tolist() == [
[0, 3, 4, 5],
[6, 10, 11, -1],
]
assert accept_token_num.tolist() == [3, 2]
elif threshold_single == 0 and threshold_acc == 0:
assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
assert accept_index.tolist() == [
[0, 1, 2, -1],
[6, 10, 11, -1],
]
assert accept_token_num.tolist() == [2, 2]
assert (
predicts.tolist() == expected_predicts
), f"Predicts mismatch for thresholds ({threshold_single}, {threshold_acc})"
assert (
accept_index.tolist() == expected_accept_index
), f"Accept index mismatch for thresholds ({threshold_single}, {threshold_acc})"
assert (
accept_token_num.tolist() == expected_accept_token_num
), f"Accept token num mismatch for thresholds ({threshold_single}, {threshold_acc})"


if __name__ == "__main__":
Expand Down
1 change: 0 additions & 1 deletion sgl-kernel/tests/test_fp8_blockwise_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def _test_accuracy_once(M, N, K, out_dtype, device):
rtol = 0.02
atol = 1
torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")


@pytest.mark.parametrize("M", [1, 3, 5, 127, 128, 512, 1024, 4096])
Expand Down
1 change: 0 additions & 1 deletion sgl-kernel/tests/test_int8_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
torch.testing.assert_close(o, o1)
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")


@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
Expand Down
4 changes: 0 additions & 4 deletions sgl-kernel/tests/test_lightning_attention_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,13 @@ def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim
ref_output,
rtol=rtol,
atol=atol,
msg=f"Output mismatch for batch_size={batch_size}, num_heads={num_heads}, "
f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
)

torch.testing.assert_close(
new_kv,
ref_new_kv,
rtol=rtol,
atol=atol,
msg=f"New KV mismatch for batch_size={batch_size}, num_heads={num_heads}, "
f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
)


Expand Down
6 changes: 2 additions & 4 deletions sgl-kernel/tests/test_moe_topk_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,10 @@ def test_topk_softmax(num_tokens, num_experts, topk):
topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3
), f"Weights mismatch: torch={topk_indices_ref} vs SGLang={topk_weights}"

assert torch.equal(
topk_indices_ref, topk_indices
assert torch.allclose(
topk_indices_ref.int(), topk_indices, atol=0, rtol=0
), f"Indices mismatch: torch={topk_indices_ref}, SGLang={topk_indices}"

print("✅ Native torch and custom kernel implementations match.")


if __name__ == "__main__":
pytest.main([__file__])
4 changes: 2 additions & 2 deletions sgl-kernel/tests/test_per_token_group_quant_8bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,10 @@ def test_per_token_group_quant_with_column_major(
scale_tma_aligned=scale_tma_aligned,
)

assert torch.allclose(
torch.testing.assert_close(
x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
)
assert torch.allclose(
torch.testing.assert_close(
x_s_triton.contiguous(), x_s_sglang.contiguous(), rtol=1e-3, atol=1e-5
)

Expand Down
3 changes: 0 additions & 3 deletions sgl-kernel/tests/test_rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,6 @@ def test_correctness(
pos_ids, query_flashinfer, key_flashinfer
)

print(query_ref_out)
print(query_flashinfer_out)

torch.testing.assert_close(
query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
)
Expand Down
Loading