@@ -50,6 +50,8 @@ def can_convert_te_model(from_config=False):
50
50
accelerator_kwargs = {}
51
51
52
52
accelerator = Accelerator (** accelerator_kwargs )
53
+ assert accelerator .fp8_enabled , "FP8 is not enabled"
54
+
53
55
dataloader = torch .utils .data .DataLoader (torch .randn (10 , 32 ), batch_size = 2 )
54
56
model = torch .nn .Sequential (torch .nn .Linear (32 , 32 ), torch .nn .Linear (32 , 16 ))
55
57
optimizer = torch .optim .Adam (model .parameters (), lr = 1e-3 )
@@ -168,6 +170,35 @@ def test_can_prepare_model_multigpu_deepspeed(self):
168
170
command += ["-m" , "tests.test_fp8" , "--test_te" ]
169
171
run_command (command )
170
172
173
+ @require_deepspeed
174
+ @require_multi_device
175
+ def test_can_prepare_model_multigpu_deepspeed_from_config (self ):
176
+ os .environ ["ZERO_STAGE" ] = str (1 )
177
+ with tempfile .TemporaryDirectory () as dir_name :
178
+ config_file = Path (dir_name ) / "config.yaml"
179
+ config_file .write_text (
180
+ textwrap .dedent (
181
+ """
182
+ distributed_type: "DEEPSPEED"
183
+ deepspeed_config:
184
+ gradient_clipping: 1.0
185
+ gradient_accumulation_steps: 1
186
+ offload_optimizer_device: none
187
+ offload_param_device: none
188
+ zero3_init_flag: false
189
+ zero_stage: 1
190
+ deepspeed_multinode_launcher: standard
191
+ num_processes: 2
192
+ mixed_precision: fp8
193
+ fp8_config:
194
+ backend: TE
195
+ """
196
+ )
197
+ )
198
+ command = get_launch_command (config_file = str (config_file ), monitor_interval = 0.1 )
199
+ command += ["-m" , "tests.test_fp8" , "--test_te" , "--from_config" ]
200
+ run_command (command )
201
+
171
202
172
203
@require_torchao
173
204
@require_huggingface_suite
0 commit comments