Commit edbe9d5

add and fix test for deepspeed / fp8 from config
1 parent 19077b3 · commit edbe9d5

File tree

2 files changed: +37, -2 lines


src/accelerate/state.py

Lines changed: 6 additions & 2 deletions
@@ -945,7 +945,11 @@ def __init__(
                     "before using any functionality from the `accelerate` library."
                 )
             # deepspeed handles mixed_precision using deepspeed_config
-            self._mixed_precision = "no" if self.distributed_type == DistributedType.DEEPSPEED else mixed_precision
+            self._mixed_precision = (
+                "no"
+                if (self.distributed_type == DistributedType.DEEPSPEED and mixed_precision != "fp8")
+                else mixed_precision
+            )
             if self.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_tpu=True):
                 if mixed_precision == "bf16":
                     if os.environ.get("ACCELERATE_DOWNCAST_BF16"):
@@ -1035,7 +1039,7 @@ def _check_initialized(self, mixed_precision=None, cpu=None):
 
     @property
     def mixed_precision(self):
-        if self.distributed_type == DistributedType.DEEPSPEED:
+        if self.distributed_type == DistributedType.DEEPSPEED and self._mixed_precision != "fp8":
             config = self.deepspeed_plugin.deepspeed_config
             if config.get("fp16", {}).get("enabled", False):
                 mixed_precision = "fp16"

tests/test_fp8.py

Lines changed: 31 additions & 0 deletions
@@ -50,6 +50,8 @@ def can_convert_te_model(from_config=False):
         accelerator_kwargs = {}
 
     accelerator = Accelerator(**accelerator_kwargs)
+    assert accelerator.fp8_enabled, "FP8 is not enabled"
+
     dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)
     model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 16))
     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
@@ -168,6 +170,35 @@ def test_can_prepare_model_multigpu_deepspeed(self):
         command += ["-m", "tests.test_fp8", "--test_te"]
         run_command(command)
 
+    @require_deepspeed
+    @require_multi_device
+    def test_can_prepare_model_multigpu_deepspeed_from_config(self):
+        os.environ["ZERO_STAGE"] = str(1)
+        with tempfile.TemporaryDirectory() as dir_name:
+            config_file = Path(dir_name) / "config.yaml"
+            config_file.write_text(
+                textwrap.dedent(
+                    """
+                    distributed_type: "DEEPSPEED"
+                    deepspeed_config:
+                      gradient_clipping: 1.0
+                      gradient_accumulation_steps: 1
+                      offload_optimizer_device: none
+                      offload_param_device: none
+                      zero3_init_flag: false
+                      zero_stage: 1
+                      deepspeed_multinode_launcher: standard
+                    num_processes: 2
+                    mixed_precision: fp8
+                    fp8_config:
+                      backend: TE
+                    """
+                )
+            )
+            command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
+            command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
+            run_command(command)
+
 
 @require_torchao
 @require_huggingface_suite
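
For context, a minimal sketch of what the launched subprocess does in the new --from_config case: no FP8 kwargs are passed to `Accelerator()` directly, so `mixed_precision: fp8` and the TE backend have to reach it through the YAML file handed to `accelerate launch`, and the `fp8_enabled` assertion added in the first hunk verifies they did. This is illustrative only; the real flow lives in `can_convert_te_model` in tests/test_fp8.py.

from accelerate import Accelerator

# Run under `accelerate launch --config_file config.yaml ...`: the fp8 settings
# come from the config file, not from arguments passed here.
accelerator = Accelerator()
assert accelerator.fp8_enabled, "FP8 is not enabled"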
