Skip to content

Commit 408eb71

Browse files
committed
Merge branch 'fix_12' into 'core_r0.12.0'
Adding skipif to legacy tests. See merge request ADLR/megatron-lm!3107
2 parents 8bda844 + 66b6283 commit 408eb71

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

megatron/core/utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
# This is a WAR for building docs, where torch is not actually imported
4848
_torch_version = PkgVersion("0.0.0")
4949
_te_version = None
50+
_fa_version = None
5051

5152

5253
class ExperimentalNotEnabledError(Exception):
@@ -279,6 +280,30 @@ def is_torch_min_version(version, check_equality=True):
279280
return get_torch_version() > PkgVersion(version)
280281

281282

283+
def get_fa_version():
    """Return the installed Flash Attention version, parsed with ``PkgVersion``.

    The version string comes from ``flash_attn.__version__`` when the module
    exposes that attribute; otherwise it is obtained from
    ``version("flash-attn")``. The parsed result is stored in the module-level
    ``_fa_version`` so that later calls return the cached value without
    importing ``flash_attn`` again.
    """
    global _fa_version

    # Fast path: an earlier call already resolved and cached the version.
    if _fa_version is not None:
        return _fa_version

    # Imported here rather than at module level, so flash-attn is only needed
    # when the version is actually requested.
    import flash_attn as fa

    if hasattr(fa, '__version__'):
        fa_version_str = str(fa.__version__)
    else:
        fa_version_str = version("flash-attn")

    _fa_version = PkgVersion(fa_version_str)
    return _fa_version
298+
299+
300+
def is_fa_min_version(version, check_equality=True):
    """Check if minimum version of `flash-attn` is installed.

    Args:
        version: Version string the installed `flash-attn` is compared against.
        check_equality: When True, an installed version equal to ``version``
            satisfies the check (``>=``); when False the installed version
            must be strictly greater (``>``).

    Returns:
        True if the installed `flash-attn` meets the requirement. False if it
        does not, or if looking up the installed version raises ImportError
        (for example because `flash-attn` is not installed).
    """
    try:
        installed = get_fa_version()
    except ImportError:
        # A package that cannot be imported cannot satisfy any minimum
        # version. Returning False instead of propagating keeps import-time
        # uses such as `@pytest.mark.skipif(not is_fa_min_version(...))` from
        # aborting test collection on machines without flash-attn.
        return False

    required = PkgVersion(version)
    if check_equality:
        return installed >= required
    return installed > required
305+
306+
282307
def ensure_divisibility(numerator, denominator):
283308
"""Ensure that numerator is divisible by the denominator."""
284309
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)

tests/unit_tests/inference/engines/test_dynamic_engine.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from dataclasses import dataclass
66
from typing import List, Optional
77

8+
import pytest
89
import torch
910
from tqdm import tqdm
1011

@@ -30,6 +31,7 @@
3031
from megatron.core.models.gpt.gpt_model import GPTModel
3132
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
3233
from megatron.core.transformer.transformer_config import TransformerConfig
34+
from megatron.core.utils import is_fa_min_version
3335
from tests.unit_tests.test_utilities import Utils
3436

3537
DynamicInferenceContext.ROUNDER = 4 # decreased from 64 for unit tests.
@@ -310,6 +312,9 @@ def setup_method(self, method):
310312
def teardown_method(self, method):
311313
Utils.destroy_model_parallel()
312314

315+
@pytest.mark.skipif(
316+
not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching"
317+
)
313318
def test_simple(self) -> None:
314319
"""Simple test that runs without errors, and validates output."""
315320

@@ -336,6 +341,9 @@ def test_simple(self) -> None:
336341
for request, expected_output in zip(env.requests, expected_outputs):
337342
assert request.output == expected_output
338343

344+
@pytest.mark.skipif(
345+
not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching"
346+
)
339347
def test_overflow_factor(self) -> None:
340348
"""Test overflow factor arg."""
341349

@@ -350,6 +358,9 @@ def test_overflow_factor(self) -> None:
350358
assert env.engine.context.max_requests == 1120
351359
assert env.engine.context.max_tokens == 1120
352360

361+
@pytest.mark.skipif(
362+
not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching"
363+
)
353364
def test_request_overflow(self) -> None:
354365
"""Test request overflow."""
355366
try:
@@ -358,6 +369,9 @@ def test_request_overflow(self) -> None:
358369
return
359370
raise Exception("failed.")
360371

372+
@pytest.mark.skipif(
373+
not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching"
374+
)
361375
def test_token_overflow(self) -> None:
362376
"""Test token overflow."""
363377
try:
@@ -366,6 +380,9 @@ def test_token_overflow(self) -> None:
366380
return
367381
raise Exception("failed.")
368382

383+
@pytest.mark.skipif(
384+
not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching"
385+
)
369386
def test_chunk_overflow(self) -> None:
370387
"""Test chunk overflow."""
371388
env = self._build_test_env(TestConfig())
@@ -378,10 +395,16 @@ def test_chunk_overflow(self) -> None:
378395
return
379396
raise Exception("failed.")
380397

398+
@pytest.mark.skipif(
399+
not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching"
400+
)
381401
def test_multi_add(self) -> None:
382402
"""Test adding multiple requests simultaneously."""
383403
self._run_test(num_gap_steps=0)
384404

405+
@pytest.mark.skipif(
406+
not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching"
407+
)
385408
def test_fixed_output_lengths(self) -> None:
386409
"""Test generating a fixed number of output tokens."""
387410
self._run_test(use_fixed_output_lengths=True)

0 commit comments

Comments
 (0)