Commit f046b05

ruff reformat
1 parent 77cb1d4 commit f046b05

4 files changed (+10 lines, -4 lines)

src/accelerate/test_utils/testing.py (6 additions, 3 deletions)

@@ -543,10 +543,13 @@ def require_transformer_engine(test_case):
 
 def require_transformer_engine_mxfp8(test_case):
     """
-    Decorator marking a test that requires transformers engine FP8 block scaling available. These tests are skipped when transformers
-    engine FP8 block scaling isn't available
+    Decorator marking a test that requires transformers engine FP8 block scaling available. These tests are skipped
+    when transformers engine FP8 block scaling isn't available
     """
-    return unittest.skipUnless(is_transformer_engine_mxfp8_available(), "test requires transformers engine FP8 block scaling")(test_case)
+    return unittest.skipUnless(
+        is_transformer_engine_mxfp8_available(), "test requires transformers engine FP8 block scaling"
+    )(test_case)
+
 
 def require_torchao(test_case):
     """

src/accelerate/utils/imports.py (2 additions, 0 deletions)

@@ -117,9 +117,11 @@ def is_transformer_engine_available():
 def is_transformer_engine_mxfp8_available():
     if _is_package_available("transformer_engine", "transformer-engine"):
         import transformer_engine.pytorch as te
+
         return te.fp8.check_mxfp8_support()[0]
     return False
 
+
 def is_lomo_available():
     return _is_package_available("lomo_optim")
 
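The helper above is a thin availability probe; a hypothetical caller could gate MXFP8-specific setup on it as sketched below (the import path comes from this file, the surrounding branching is illustrative only):

from accelerate.utils.imports import is_transformer_engine_mxfp8_available

# Illustrative guard: take the MXFP8 block-scaling path only when the installed
# transformer_engine build and the current GPU both report support for it.
if is_transformer_engine_mxfp8_available():
    print("MXFP8 block scaling is available")
else:
    print("MXFP8 block scaling is not available; falling back to the default FP8 recipe")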

src/accelerate/utils/transformer_engine.py (2 additions, 0 deletions)

@@ -146,12 +146,14 @@ def apply_fp8_autowrap(model, fp8_recipe_handler):
 
     if is_hpu_available():
         import intel_transformer_engine.recipe as te_recipe
+
         is_fp8_block_scaling_available = False
         message = "MXFP8 block scaling is not available on HPU."
 
     else:
         import transformer_engine.common.recipe as te_recipe
         import transformer_engine.pytorch as te
+
         is_fp8_block_scaling_available, message = te.fp8.check_mxfp8_support()
 
     kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
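As the hunk shows, te.fp8.check_mxfp8_support() returns a (supported, message) pair. A standalone probe along the same lines, assuming a CUDA build of transformer_engine is installed, might look like:

import transformer_engine.pytorch as te

# Same call used in apply_fp8_autowrap: the boolean reports whether MXFP8 block
# scaling is usable on this device, and the message explains why when it is not.
supported, message = te.fp8.check_mxfp8_support()
if not supported:
    print(f"MXFP8 block scaling unavailable: {message}")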

tests/test_fp8.py (0 additions, 1 deletion)

@@ -110,7 +110,6 @@ def test_can_prepare_model_single_gpu_from_config(self):
             command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
             run_command(command)
 
-
     @require_transformer_engine_mxfp8
     def test_can_prepare_model_with_mxfp8_block_scaling(self):
         with tempfile.TemporaryDirectory() as dir_name:
