@@ -300,84 +300,13 @@ def _compare_timestamps_between_state_dicts(state_dict1, state_dict2):
300
300
pytest .param (2 , 'adam' , False , 'amp_bf16' , False , True , False , False , False , marks = pytest .mark .world_size (2 )),
301
301
pytest .param (2 , 'adam' , False , 'amp_bf16' , False , False , True , False , False , marks = pytest .mark .world_size (2 )),
302
302
pytest .param (4 , 'adam' , False , 'amp_bf16' , False , False , False , True , False , marks = pytest .mark .world_size (4 )),
303
- pytest .param (
304
- 4 ,
305
- 'adam' ,
306
- False ,
307
- 'amp_bf16' ,
308
- False ,
309
- False ,
310
- False ,
311
- False ,
312
- True ,
313
- marks = [pytest .mark .world_size (4 ),
314
- pytest .mark .xfail (reason = 'Known issue, waiting for composer bump' )],
315
- ),
316
- pytest .param (
317
- 4 ,
318
- 'adamw' ,
319
- False ,
320
- 'amp_bf16' ,
321
- False ,
322
- False ,
323
- False ,
324
- False ,
325
- True ,
326
- marks = [pytest .mark .world_size (4 ),
327
- pytest .mark .xfail (reason = 'Known issue, waiting for composer bump' )],
328
- ),
329
- pytest .param (
330
- 4 ,
331
- 'adam' ,
332
- True ,
333
- 'amp_bf16' ,
334
- False ,
335
- False ,
336
- False ,
337
- False ,
338
- True ,
339
- marks = [pytest .mark .world_size (4 ),
340
- pytest .mark .xfail (reason = 'Known issue, waiting for composer bump' )],
341
- ),
342
- pytest .param (
343
- 4 ,
344
- 'adam' ,
345
- False ,
346
- 'amp_fp16' ,
347
- False ,
348
- False ,
349
- False ,
350
- False ,
351
- True ,
352
- marks = [pytest .mark .world_size (4 ),
353
- pytest .mark .xfail (reason = 'Known issue, waiting for composer bump' )],
354
- ),
355
- pytest .param (
356
- 4 ,
357
- 'adam' ,
358
- False ,
359
- 'amp_bf16' ,
360
- True ,
361
- True ,
362
- False ,
363
- False ,
364
- True ,
365
- marks = [pytest .mark .world_size (4 ),
366
- pytest .mark .xfail (reason = 'Known issue, waiting for composer bump' )],
367
- ), # save_weights_only requires load_weights_only
368
- pytest .param (
369
- 4 ,
370
- 'adam' ,
371
- False ,
372
- 'amp_bf16' ,
373
- False ,
374
- True ,
375
- False ,
376
- False ,
377
- True ,
378
- marks = [pytest .mark .world_size (4 ),
379
- pytest .mark .xfail (reason = 'Known issue, waiting for composer bump' )],
380
- ),
303
+ pytest .param (4 , 'adam' , False , 'amp_bf16' , False , False , False , False , True , marks = pytest .mark .world_size (4 )),
304
+ pytest .param (4 , 'adamw' , False , 'amp_bf16' , False , False , False , False , True , marks = pytest .mark .world_size (4 )),
305
+ pytest .param (4 , 'adam' , True , 'amp_bf16' , False , False , False , False , True , marks = pytest .mark .world_size (4 )),
306
+ pytest .param (4 , 'adam' , False , 'amp_fp16' , False , False , False , False , True , marks = pytest .mark .world_size (4 )),
307
+ pytest .param (4 , 'adam' , False , 'amp_bf16' , True , True , False , False , True ,
308
+ marks = pytest .mark .world_size (4 )), # save_weights_only requires load_weights_only
309
+ pytest .param (4 , 'adam' , False , 'amp_bf16' , False , True , False , False , True , marks = pytest .mark .world_size (4 )),
381
310
],
382
311
)
383
312
def test_fsdp_full_state_dict_load (
@@ -392,6 +321,8 @@ def test_fsdp_full_state_dict_load(
392
321
use_tp : bool ,
393
322
use_hsdp : bool ,
394
323
):
324
+ if use_hsdp :
325
+ pytest .xfail ('Known Pytorch issue with HSDP, waiting for pytorch patch' )
395
326
if autoresume :
396
327
run_name = 'my-cool-autoresume-run'
397
328
else :
0 commit comments