|
7 | 7 | from sglang.srt.utils import get_device_sm, kill_process_tree
|
8 | 8 | from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
9 | 9 | from sglang.test.test_utils import (
|
| 10 | + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, |
| 11 | + DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, |
10 | 12 | DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
11 | 13 | DEFAULT_MODEL_NAME_FOR_TEST,
|
12 | 14 | DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
@@ -123,11 +125,56 @@ def get_server_args(cls):
|
123 | 125 | class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
|
124 | 126 | """Test FlashAttention3 with speculative decode enabled."""
|
125 | 127 |
|
| 128 | + model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST |
| 129 | + |
126 | 130 | @classmethod
|
127 | 131 | def get_server_args(cls):
|
128 | 132 | args = super().get_server_args()
|
| 133 | + args.extend( |
| 134 | + [ |
| 135 | + "--cuda-graph-max-bs", |
| 136 | + "2", |
| 137 | + "--speculative-algorithm", |
| 138 | + "EAGLE3", |
| 139 | + "--speculative-draft", |
| 140 | + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, |
| 141 | + "--speculative-num-steps", |
| 142 | + "3", |
| 143 | + "--speculative-eagle-topk", |
| 144 | + "1", |
| 145 | + "--speculative-num-draft-tokens", |
| 146 | + "3", |
| 147 | + "--dtype", |
| 148 | + "float16", |
| 149 | + ] |
| 150 | + ) |
129 | 151 | return args
|
130 | 152 |
|
| 153 | + def test_gsm8k(self): |
| 154 | + """ |
| 155 | + Override the test_gsm8k to further test for average speculative accept length. |
| 156 | + """ |
| 157 | + requests.get(self.base_url + "/flush_cache") |
| 158 | + |
| 159 | + args = SimpleNamespace( |
| 160 | + num_shots=5, |
| 161 | + data_path=DATA_PATH, |
| 162 | + num_questions=200, |
| 163 | + max_new_tokens=512, |
| 164 | + parallel=128, |
| 165 | + host="http://127.0.0.1", |
| 166 | + port=int(self.base_url.split(":")[-1]), |
| 167 | + ) |
| 168 | + metrics = run_eval_few_shot_gsm8k(args) |
| 169 | + print(metrics) |
| 170 | + |
| 171 | + self.assertGreater(metrics["accuracy"], 0.60) |
| 172 | + |
| 173 | + server_info = requests.get(self.base_url + "/get_server_info") |
| 174 | + avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] |
| 175 | + print(f"{avg_spec_accept_length=}") |
| 176 | + self.assertGreater(avg_spec_accept_length, 1.5) |
| 177 | + |
131 | 178 |
|
132 | 179 | if __name__ == "__main__":
|
133 | 180 | unittest.main()
|
0 commit comments