Skip to content

Commit 1fb8f76

Browse files
committed
minor lints
Signed-off-by: Peter St. John <[email protected]>
1 parent 131b116 commit 1fb8f76

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

src/accelerate/state.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -948,12 +948,13 @@ def __init__(
948948
"Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` "
949949
"before using any functionality from the `accelerate` library."
950950
)
951-
# deepspeed handles mixed_precision using deepspeed_config
952-
self._mixed_precision = (
953-
"no"
954-
if (self.distributed_type == DistributedType.DEEPSPEED and mixed_precision != "fp8")
955-
else mixed_precision
956-
)
951+
# deepspeed handles mixed_precision using deepspeed_config. But we need to set it to fp8
952+
# if we're using fp8.
953+
if self.distributed_type == DistributedType.DEEPSPEED and mixed_precision != "fp8":
954+
self._mixed_precision = "no"
955+
else:
956+
self._mixed_precision = mixed_precision
957+
957958
if self.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_tpu=True):
958959
if mixed_precision == "bf16":
959960
if os.environ.get("ACCELERATE_DOWNCAST_BF16"):

src/accelerate/test_utils/testing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -543,11 +543,11 @@ def require_transformer_engine(test_case):
543543

544544
def require_transformer_engine_mxfp8(test_case):
545545
"""
546-
Decorator marking a test that requires transformers engine FP8 block scaling available. These tests are skipped
547-
when transformers engine FP8 block scaling isn't available
546+
Decorator marking a test that requires transformers engine MXFP8 block scaling to be available. These tests are skipped
547+
when transformers engine MXFP8 block scaling isn't available
548548
"""
549549
return unittest.skipUnless(
550-
is_transformer_engine_mxfp8_available(), "test requires transformers engine FP8 block scaling"
550+
is_transformer_engine_mxfp8_available(), "test requires transformers engine MXFP8 block scaling"
551551
)(test_case)
552552

553553

0 commit comments

Comments
 (0)