Skip to content

Commit 76a546f

Browse files
committed
add guards
1 parent 168b520 commit 76a546f

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

src/accelerate/accelerator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,8 +448,9 @@ def __init__(
448448
**kwargs,
449449
)
450450

451-
self._build_torch_device_mesh(self.parallelism_config)
452-
self.parallelism_config.validate_accelerator(self)
451+
if self.parallelism_config:
452+
self._build_torch_device_mesh(self.parallelism_config)
453+
self.parallelism_config.validate_accelerator(self)
453454

454455
self.fp8_enabled = self.state.mixed_precision == "fp8" or mixed_precision == "fp8"
455456

src/accelerate/utils/dataclasses.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2998,7 +2998,10 @@ def active_mesh_dims(self):
29982998
return self.dp_dim_names + self.non_dp_dim_names
29992999

30003000
def build_device_mesh(self, device_type: str):
3001-
mesh_dim_names, mesh_shape = self.get_mesh()
3001+
mesh = self.get_mesh()
3002+
if not mesh:
3003+
return
3004+
mesh_dim_names, mesh_shape = mesh
30023005
device_mesh = torch.distributed.init_device_mesh(
30033006
device_type,
30043007
mesh_shape,

0 commit comments

Comments
 (0)