Skip to content

Commit 46ebf27

Browse files
committed
minor lints
Signed-off-by: Peter St. John <[email protected]>
1 parent f079afd commit 46ebf27

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

src/accelerate/state.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -944,12 +944,13 @@ def __init__(
944944
"Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` "
945945
"before using any functionality from the `accelerate` library."
946946
)
947-
# deepspeed handles mixed_precision using deepspeed_config
948-
self._mixed_precision = (
949-
"no"
950-
if (self.distributed_type == DistributedType.DEEPSPEED and mixed_precision != "fp8")
951-
else mixed_precision
952-
)
947+
# deepspeed handles mixed_precision via its own deepspeed_config, so record "no" here —
948+
# except when fp8 is requested, which accelerate must still track itself.
949+
if self.distributed_type == DistributedType.DEEPSPEED and mixed_precision != "fp8":
950+
self._mixed_precision = "no"
951+
else:
952+
self._mixed_precision = mixed_precision
953+
953954
if self.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_tpu=True):
954955
if mixed_precision == "bf16":
955956
if os.environ.get("ACCELERATE_DOWNCAST_BF16"):

src/accelerate/test_utils/testing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -543,11 +543,11 @@ def require_transformer_engine(test_case):
543543

544544
def require_transformer_engine_mxfp8(test_case):
545545
"""
546-
Decorator marking a test that requires transformers engine FP8 block scaling available. These tests are skipped
547-
when transformers engine FP8 block scaling isn't available
546+
Decorator marking a test that requires transformers engine MXFP8 block scaling to be available. These tests are
547+
skipped when transformers engine MXFP8 block scaling isn't available.
548548
"""
549549
return unittest.skipUnless(
550-
is_transformer_engine_mxfp8_available(), "test requires transformers engine FP8 block scaling"
550+
is_transformer_engine_mxfp8_available(), "test requires transformers engine MXFP8 block scaling"
551551
)(test_case)
552552

553553

0 commit comments

Comments
 (0)