Add unit test on page_size > 1 and mla and integration test for Flash Attention 3 #4760
Merged
Changes from 7 commits
Commits (46)
286ed41 add integration test for flash attention 3 (yubofredwang)
48e904b Merge branch 'sgl-project:main' into main (yubofredwang)
f208a10 minor fix (yubofredwang)
22ddd57 Merge branch 'main' of https://github.com/yubofredwang/sglang (yubofredwang)
d9ddd67 add end file (yubofredwang)
8bd7995 set datapath to None (yubofredwang)
0c03ef4 revert model (yubofredwang)
c80fe75 Merge branch 'main' into main (yubofredwang)
4b764c6 Merge branch 'main' into main (hebiao064)
0b8e428 use llama 8b for testing (yubofredwang)
ca6beed Merge branch 'sgl-project:main' into main (yubofredwang)
7c6bf1c Merge branch 'main' into main (hebiao064)
b09d4e2 Use sglkernel's fa3 (hebiao064)
588df10 Merge branch 'main' into add-test-fa3 (hebiao064)
8c9ca10 Merge pull request #1 from yubofredwang/add-test-fa3 (hebiao064)
acdc853 add mla integration tes (yubofredwang)
4a5fb80 format fix (yubofredwang)
61d6cdc Merge branch 'main' into main (hebiao064)
031ae71 Merge branch 'sgl-project:main' into main (yubofredwang)
fa4a193 fix circular dependency (yubofredwang)
8fe5e99 refactor the code (yubofredwang)
79d6bee page size done (yubofredwang)
112996a add mla unit test, clean up int test (yubofredwang)
6f0fd66 merge main (yubofredwang)
ffbd153 skip test if < 90 (yubofredwang)
0ddab6b Merge branch 'main' into main (yubofredwang)
46f8aae fix device sm (yubofredwang)
fd867eb Merge branch 'main' into main (yubofredwang)
8bda15b Merge branch 'main' into main (hebiao064)
7fe328c merge main (yubofredwang)
b5357ff Merge branch 'main' into main (zhyncs)
0a5f3d4 Merge remote-tracking branch 'upstream/main' into main (yubofredwang)
8e1f76b Fix conflict (hebiao064)
272ce3f Merge pull request #2 from yubofredwang/fa3_merge_conflict (hebiao064)
bd29b3a fix according to comments (yubofredwang)
b452023 Merge branch 'main' of https://github.com/yubofredwang/sglang into main (yubofredwang)
28e7d86 Merge branch 'main' into main (yubofredwang)
a5274b0 fix comments (yubofredwang)
ec15675 Merge branch 'main' of https://github.com/yubofredwang/sglang into main (yubofredwang)
2fca209 fix disable mla (yubofredwang)
389d4f0 Merge branch 'main' into main (yubofredwang)
03f4a29 Merge branch 'main' into main (yubofredwang)
a566531 trim (yubofredwang)
f7cfc36 Merge branch 'main' into main (zhyncs)
d01c6d7 use 3 draft tokens (yubofredwang)
77d6ef4 Merge pull request #3 from yubofredwang/add-spec-dec-top-1 (hebiao064)
@@ -0,0 +1,113 @@
import unittest
from types import SimpleNamespace

import requests
import torch

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

"""
Integration test for python/sglang/srt/layers/attention/flashattention_backend.py
"""
# Change to your own model if testing model is not public.
MODEL_USED_FOR_TEST = "lmsys/sglang-ci-dsv3-test"

# Setting data path to None uses default data path in few_shot_gsm8k eval test.
DATA_PATH = None


class TestFlashAttention3(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.model = MODEL_USED_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        other_args = ["--trust-remote-code"]
        if torch.cuda.is_available() and torch.version.cuda:
            other_args.extend(
                [
                    "--enable-torch-compile",
                    "--cuda-graph-max-bs",
                    "2",
                    "--attention-backend",
                    "fa3",
                ]
            )
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )

    def test_gsm8k(self):
        args = SimpleNamespace(
            num_shots=5,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
            data_path=DATA_PATH,
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(metrics)

        self.assertGreater(metrics["accuracy"], 0.62)

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)


class TestFlashAttention3DisableCudaGraph(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = MODEL_USED_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        other_args = ["--trust-remote-code"]
        if torch.cuda.is_available() and torch.version.cuda:
            other_args.extend(
                [
                    "--enable-torch-compile",
                    "--disable-cuda-graph",
                    "--cuda-graph-max-bs",
                    "4",
                    "--attention-backend",
                    "fa3",
                ]
            )
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        args = SimpleNamespace(
            num_shots=5,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
            data_path=DATA_PATH,
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(metrics)

        self.assertGreater(metrics["accuracy"], 0.62)


if __name__ == "__main__":
    unittest.main()
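
Both test classes in this file assemble nearly the same launch flags and differ only in the CUDA graph settings. A minimal sketch of how that shared setup could be factored out is shown below. The helper name `build_fa3_server_args` and its parameters are hypothetical and not part of this PR; only the flag strings are taken from the test file above.

```python
from typing import List


def build_fa3_server_args(
    cuda_available: bool,
    cuda_graph_max_bs: int = 2,
    disable_cuda_graph: bool = False,
) -> List[str]:
    """Hypothetical helper: build the extra server flags used by the FA3 tests."""
    args = ["--trust-remote-code"]
    if not cuda_available:
        # Mirrors the tests: without CUDA only the base flag is passed.
        return args
    args.append("--enable-torch-compile")
    if disable_cuda_graph:
        args.append("--disable-cuda-graph")
    # The tests always pass a max batch size, even when CUDA graphs are disabled.
    args.extend(["--cuda-graph-max-bs", str(cuda_graph_max_bs)])
    args.extend(["--attention-backend", "fa3"])
    return args


if __name__ == "__main__":
    # Flags equivalent to TestFlashAttention3 on a CUDA machine.
    print(build_fa3_server_args(cuda_available=True))
    # Flags equivalent to TestFlashAttention3DisableCudaGraph.
    print(
        build_fa3_server_args(
            cuda_available=True, cuda_graph_max_bs=4, disable_cuda_graph=True
        )
    )
```

In the tests, `cuda_available` would be the existing condition `torch.cuda.is_available() and torch.version.cuda`, and the returned list would be passed as `other_args` to `popen_launch_server`. Taking the condition as a plain boolean keeps the helper importable on machines without a GPU.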