File tree Expand file tree Collapse file tree 2 files changed +7
-3
lines changed Expand file tree Collapse file tree 2 files changed +7
-3
lines changed Original file line number Diff line number Diff line change @@ -448,8 +448,9 @@ def __init__(
448
448
** kwargs ,
449
449
)
450
450
451
- self ._build_torch_device_mesh (self .parallelism_config )
452
- self .parallelism_config .validate_accelerator (self )
451
+ if self .parallelism_config :
452
+ self ._build_torch_device_mesh (self .parallelism_config )
453
+ self .parallelism_config .validate_accelerator (self )
453
454
454
455
self .fp8_enabled = self .state .mixed_precision == "fp8" or mixed_precision == "fp8"
455
456
Original file line number Diff line number Diff line change @@ -2998,7 +2998,10 @@ def active_mesh_dims(self):
2998
2998
return self .dp_dim_names + self .non_dp_dim_names
2999
2999
3000
3000
def build_device_mesh (self , device_type : str ):
3001
- mesh_dim_names , mesh_shape = self .get_mesh ()
3001
+ mesh = self .get_mesh ()
3002
+ if not mesh :
3003
+ return
3004
+ mesh_dim_names , mesh_shape = mesh
3002
3005
device_mesh = torch .distributed .init_device_mesh (
3003
3006
device_type ,
3004
3007
mesh_shape ,
You can’t perform that action at this time.
0 commit comments