Skip to content

Add unit test on page_size > 1 and mla and integration test for Flash Attention 3 #4760

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 46 commits into from
Apr 8, 2025
Merged
Changes from 7 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
286ed41
add integration test for flash attention 3
yubofredwang Mar 25, 2025
48e904b
Merge branch 'sgl-project:main' into main
yubofredwang Mar 25, 2025
f208a10
minor fix
yubofredwang Mar 25, 2025
22ddd57
Merge branch 'main' of https://github.com/yubofredwang/sglang
yubofredwang Mar 25, 2025
d9ddd67
add end file
yubofredwang Mar 25, 2025
8bd7995
set datapath to None
yubofredwang Mar 25, 2025
0c03ef4
revert model
yubofredwang Mar 25, 2025
c80fe75
Merge branch 'main' into main
yubofredwang Mar 26, 2025
4b764c6
Merge branch 'main' into main
hebiao064 Mar 27, 2025
0b8e428
use llama 8b for testing
yubofredwang Mar 27, 2025
ca6beed
Merge branch 'sgl-project:main' into main
yubofredwang Mar 28, 2025
7c6bf1c
Merge branch 'main' into main
hebiao064 Mar 31, 2025
b09d4e2
Use sglkernel's fa3
hebiao064 Mar 31, 2025
588df10
Merge branch 'main' into add-test-fa3
hebiao064 Mar 31, 2025
8c9ca10
Merge pull request #1 from yubofredwang/add-test-fa3
hebiao064 Mar 31, 2025
acdc853
add mla integration tes
yubofredwang Apr 1, 2025
4a5fb80
format fix
yubofredwang Apr 1, 2025
61d6cdc
Merge branch 'main' into main
hebiao064 Apr 1, 2025
031ae71
Merge branch 'sgl-project:main' into main
yubofredwang Apr 1, 2025
fa4a193
fix circular dependency
yubofredwang Apr 1, 2025
8fe5e99
refactor the code
yubofredwang Apr 3, 2025
79d6bee
page size done
yubofredwang Apr 3, 2025
112996a
add mla unit test, clean up int test
yubofredwang Apr 4, 2025
6f0fd66
merge main
yubofredwang Apr 4, 2025
ffbd153
skip test if < 90
yubofredwang Apr 4, 2025
0ddab6b
Merge branch 'main' into main
yubofredwang Apr 4, 2025
46f8aae
fix device sm
yubofredwang Apr 4, 2025
fd867eb
Merge branch 'main' into main
yubofredwang Apr 4, 2025
8bda15b
Merge branch 'main' into main
hebiao064 Apr 4, 2025
7fe328c
merge main
yubofredwang Apr 5, 2025
b5357ff
Merge branch 'main' into main
zhyncs Apr 6, 2025
0a5f3d4
Merge remote-tracking branch 'upstream/main' into main
yubofredwang Apr 7, 2025
8e1f76b
Fix conflict
hebiao064 Apr 8, 2025
272ce3f
Merge pull request #2 from yubofredwang/fa3_merge_conflict
hebiao064 Apr 8, 2025
bd29b3a
fix according to comments
yubofredwang Apr 8, 2025
b452023
Merge branch 'main' of https://github.com/yubofredwang/sglang into main
yubofredwang Apr 8, 2025
28e7d86
Merge branch 'main' into main
yubofredwang Apr 8, 2025
a5274b0
fix comments
yubofredwang Apr 8, 2025
ec15675
Merge branch 'main' of https://github.com/yubofredwang/sglang into main
yubofredwang Apr 8, 2025
2fca209
fix disable mla
yubofredwang Apr 8, 2025
389d4f0
Merge branch 'main' into main
yubofredwang Apr 8, 2025
03f4a29
Merge branch 'main' into main
yubofredwang Apr 8, 2025
a566531
trim
yubofredwang Apr 8, 2025
f7cfc36
Merge branch 'main' into main
zhyncs Apr 8, 2025
d01c6d7
use 3 draft tokens
yubofredwang Apr 8, 2025
77d6ef4
Merge pull request #3 from yubofredwang/add-spec-dec-top-1
hebiao064 Apr 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions test/srt/test_flash_attention3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import unittest
from types import SimpleNamespace

import requests
import torch

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)

"""
Integration test for python/sglang/srt/layers/attention/flashattention_backend.py
"""
# Change to your own model if testing model is not public.
MODEL_USED_FOR_TEST = "lmsys/sglang-ci-dsv3-test"
# Setting data path to None uses default data path in few_shot_gsm8k eval test.
DATA_PATH = None


class TestFlashAttention3(unittest.TestCase):

@classmethod
def setUpClass(cls):
cls.model = MODEL_USED_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = ["--trust-remote-code"]
if torch.cuda.is_available() and torch.version.cuda:
other_args.extend(
[
"--enable-torch-compile",
"--cuda-graph-max-bs",
"2",
"--attention-backend",
"fa3",
]
)
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)

def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
data_path=DATA_PATH,
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)

self.assertGreater(metrics["accuracy"], 0.62)

@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)


class TestFlashAttention3DisableCudaGraph(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = MODEL_USED_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = ["--trust-remote-code"]
if torch.cuda.is_available() and torch.version.cuda:
other_args.extend(
[
"--enable-torch-compile",
"--disable-cuda-graph",
"--cuda-graph-max-bs",
"4",
"--attention-backend",
"fa3",
]
)
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)

@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)

def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
data_path=DATA_PATH,
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)

self.assertGreater(metrics["accuracy"], 0.62)


if __name__ == "__main__":
unittest.main()