Skip to content

Commit 6c398ec

Browse files
committed
Fix: prepare works even if nothing except tp specified (rare)
1 parent cb343c6 commit 6c398ec

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

src/accelerate/accelerator.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1584,9 +1584,17 @@ def prepare(self, *args, device_placement=None):
15841584
return result if len(result) > 1 else result[0]
15851585

15861586
def _prepare_tp(self, *args):
1587+
# First pass: prepare everything except schedulers (and model, which is prepared separately below)
1588+
result = [
1589+
self._prepare_one(obj, first_pass=True) if not isinstance(obj, torch.nn.Module) else obj for obj in args
1590+
]
1591+
1592+
# Second pass: prepare schedulers
1593+
result = [self._prepare_one(obj) if not isinstance(obj, torch.nn.Module) else obj for obj in result]
1594+
15871595
device_mesh = self.torch_device_mesh
15881596

1589-
for arg in args:
1597+
for arg in result:
15901598
if not isinstance(arg, torch.nn.Module):
15911599
continue
15921600

0 commit comments

Comments
 (0)