5
5
from dataclasses import dataclass
6
6
from typing import List , Optional
7
7
8
+ import pytest
8
9
import torch
9
10
from tqdm import tqdm
10
11
30
31
from megatron .core .models .gpt .gpt_model import GPTModel
31
32
from megatron .core .tensor_parallel .random import model_parallel_cuda_manual_seed
32
33
from megatron .core .transformer .transformer_config import TransformerConfig
34
+ from megatron .core .utils import is_fa_min_version
33
35
from tests .unit_tests .test_utilities import Utils
34
36
35
37
DynamicInferenceContext .ROUNDER = 4 # decreased from 64 for unit tests.
@@ -310,6 +312,9 @@ def setup_method(self, method):
310
312
def teardown_method (self , method ):
311
313
Utils .destroy_model_parallel ()
312
314
315
+ @pytest .mark .skipif (
316
+ not is_fa_min_version ("2.7.3" ), reason = "need latest flash attn for dynamic batching"
317
+ )
313
318
def test_simple (self ) -> None :
314
319
"""Simple test that runs without errors, and validates output."""
315
320
@@ -336,6 +341,9 @@ def test_simple(self) -> None:
336
341
for request , expected_output in zip (env .requests , expected_outputs ):
337
342
assert request .output == expected_output
338
343
344
+ @pytest .mark .skipif (
345
+ not is_fa_min_version ("2.7.3" ), reason = "need latest flash attn for dynamic batching"
346
+ )
339
347
def test_overflow_factor (self ) -> None :
340
348
"""Test overflow factor arg."""
341
349
@@ -350,6 +358,9 @@ def test_overflow_factor(self) -> None:
350
358
assert env .engine .context .max_requests == 1120
351
359
assert env .engine .context .max_tokens == 1120
352
360
361
+ @pytest .mark .skipif (
362
+ not is_fa_min_version ("2.7.3" ), reason = "need latest flash attn for dynamic batching"
363
+ )
353
364
def test_request_overflow (self ) -> None :
354
365
"""Test request overflow."""
355
366
try :
@@ -358,6 +369,9 @@ def test_request_overflow(self) -> None:
358
369
return
359
370
raise Exception ("failed." )
360
371
372
+ @pytest .mark .skipif (
373
+ not is_fa_min_version ("2.7.3" ), reason = "need latest flash attn for dynamic batching"
374
+ )
361
375
def test_token_overflow (self ) -> None :
362
376
"""Test token overflow."""
363
377
try :
@@ -366,6 +380,9 @@ def test_token_overflow(self) -> None:
366
380
return
367
381
raise Exception ("failed." )
368
382
383
+ @pytest .mark .skipif (
384
+ not is_fa_min_version ("2.7.3" ), reason = "need latest flash attn for dynamic batching"
385
+ )
369
386
def test_chunk_overflow (self ) -> None :
370
387
"""Test chunk overflow."""
371
388
env = self ._build_test_env (TestConfig ())
@@ -378,10 +395,16 @@ def test_chunk_overflow(self) -> None:
378
395
return
379
396
raise Exception ("failed." )
380
397
398
+ @pytest .mark .skipif (
399
+ not is_fa_min_version ("2.7.3" ), reason = "need latest flash attn for dynamic batching"
400
+ )
381
401
def test_multi_add (self ) -> None :
382
402
"""Test adding multiple requests simultaneously."""
383
403
self ._run_test (num_gap_steps = 0 )
384
404
405
+ @pytest .mark .skipif (
406
+ not is_fa_min_version ("2.7.3" ), reason = "need latest flash attn for dynamic batching"
407
+ )
385
408
def test_fixed_output_lengths (self ) -> None :
386
409
"""Test generating a fixed number of output tokens."""
387
410
self ._run_test (use_fixed_output_lengths = True )
0 commit comments