diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 3949e0afe56..1814d2dde3c 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -1584,9 +1584,17 @@ def prepare(self, *args, device_placement=None): return result if len(result) > 1 else result[0] def _prepare_tp(self, *args): + # First pass: prepare everything except schedulers (and model, which is prepared separately below) + result = [ + self._prepare_one(obj, first_pass=True) if not isinstance(obj, torch.nn.Module) else obj for obj in args + ] + + # Second pass: prepare schedulers + result = [self._prepare_one(obj) if not isinstance(obj, torch.nn.Module) else obj for obj in result] + device_mesh = self.torch_device_mesh - for arg in args: + for arg in result: if not isinstance(arg, torch.nn.Module): continue