
Commit 07ce748

Fix: properly error when DDP + Dtensor model (#3629)
* Feat: add check
* Refactor: nits
1 parent 175fe91 commit 07ce748

File tree: 3 files changed, +26 −0 lines


src/accelerate/accelerator.py

Lines changed: 5 additions & 0 deletions
@@ -108,6 +108,7 @@
     is_xpu_available,
     load_fsdp_model,
     load_fsdp_optimizer,
+    model_has_dtensor,
     pad_across_processes,
     parse_choice_from_env,
     recursively_apply,

@@ -1631,6 +1632,10 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
             DistributedType.MULTI_XPU,
             DistributedType.MULTI_HPU,
         ):
+            if model_has_dtensor(model):
+                raise ValueError(
+                    "Your model contains `DTensor` parameters, which is incompatible with DDP. Maybe you loaded your model with `device_map='auto'`? Specify `device_map='cuda'` or 'cpu' instead."
+                )
             if any(p.requires_grad for p in model.parameters()):
                 kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
                 # TODO: Look at enabling native TP training directly with a proper config

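For context, a sketch of the scenario this check now reports, under the assumptions the error message itself names: a multi-GPU `accelerate launch` and a model loaded so that its parameters end up as `DTensor`s (the message points at `device_map='auto'` as a typical cause). The model id and the `transformers` call below are illustrative only, not part of the commit.

import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

accelerator = Accelerator()

# Illustrative load; whether this particular call actually yields DTensor
# parameters depends on the torch/transformers versions and the hardware,
# as hinted by the new error message.
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")

# Under MULTI_GPU/MULTI_XPU/... (i.e. DDP), `prepare` now inspects the
# parameters up front and raises the explicit ValueError added above if
# any of them is a DTensor, instead of proceeding to wrap the model in DDP.
model = accelerator.prepare(model)
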
src/accelerate/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -282,6 +282,7 @@
     is_port_in_use,
     load,
     merge_dicts,
+    model_has_dtensor,
     recursive_getattr,
     save,
     wait_for_everyone,

src/accelerate/utils/other.py

Lines changed: 20 additions & 0 deletions
@@ -194,6 +194,26 @@ def compile_regions_deepspeed(module: torch.nn.Module, **compile_kwargs):
     module.compile(**compile_kwargs)


+def model_has_dtensor(model: torch.nn.Module) -> bool:
+    """
+    Check if the model has DTensor parameters.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model to check.
+
+    Returns:
+        `bool`: Whether the model has DTensor parameters.
+    """
+    if is_torch_version(">=", "2.5.0"):
+        from torch.distributed.tensor import DTensor
+    else:
+        # from torch 2.0.0 (oldest supported accelerate torch version), DTensor is in torch.distributed._tensor
+        from torch.distributed._tensor import DTensor
+
+    return any(isinstance(p, DTensor) for p in model.parameters())
+
+
 def extract_model_from_parallel(
     model, keep_fp32_wrapper: bool = True, keep_torch_compile: bool = True, recursive: bool = False
 ):

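A minimal usage sketch of the new helper on its own (hypothetical demo, not part of the commit): it assumes torch >= 2.5, a GPU, and a distributed launch such as `torchrun --nproc_per_node=1 demo.py` so that a device mesh can be created.

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_module

from accelerate.utils import model_has_dtensor

model = torch.nn.Linear(8, 8)
print(model_has_dtensor(model))   # False: plain torch.Tensor parameters

# With no partition function, `distribute_module` replicates the module's
# parameters over the mesh, turning them into DTensors in place.
mesh = init_device_mesh("cuda", (1,))
distribute_module(model, mesh)
print(model_has_dtensor(model))   # True: DDP preparation would now raise
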