Commit 07f7c1b

set as explicit args in FSDPStrategy
1 parent 1f2c3ff commit 07f7c1b

2 files changed, +25 -1 lines changed


src/lightning/fabric/strategies/fsdp.py (+11)

@@ -86,6 +86,11 @@
 
 _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]]
 
+if _TORCH_GREATER_EQUAL_2_2:
+    from torch.distributed._tensor import DeviceMesh
+else:
+    DeviceMesh = None
+
 _FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload")
 
 
@@ -158,6 +163,7 @@ def __init__(
         activation_checkpointing_policy: Optional["_POLICY"] = None,
         sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD",
         state_dict_type: Literal["full", "sharded"] = "sharded",
+        device_mesh: Optional[Union[Tuple[int], "DeviceMesh"]] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(
@@ -176,6 +182,11 @@ def __init__(
         # Enables joint setup of model and optimizer, multiple optimizer param groups, and `torch.compile()`
         self._fsdp_kwargs.setdefault("use_orig_params", True)
 
+        if device_mesh is not None:
+            if not _TORCH_GREATER_EQUAL_2_2:
+                raise ValueError("The device_mesh argument is only supported in torch >= 2.2.")
+            self._fsdp_kwargs["device_mesh"] = device_mesh
+
         self._activation_checkpointing_kwargs = _activation_checkpointing_kwargs(
             activation_checkpointing, activation_checkpointing_policy
         )
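For context (not part of the diff above): this change only validates the new device_mesh value and stores it in self._fsdp_kwargs; per the guard, it is only supported with torch >= 2.2. Below is a minimal, hypothetical usage sketch for the Fabric strategy. The (2, 4) mesh shape and the pairing with HYBRID_SHARD are illustrative assumptions, not taken from this commit, and how the tuple form is materialized into a DeviceMesh is not shown here.

    from lightning.fabric import Fabric
    from lightning.fabric.strategies import FSDPStrategy

    # Hypothetical example; requires torch >= 2.2, otherwise the guard above raises ValueError.
    # A prebuilt torch.distributed._tensor.DeviceMesh could be passed instead of the tuple.
    strategy = FSDPStrategy(
        sharding_strategy="HYBRID_SHARD",
        device_mesh=(2, 4),  # stored in self._fsdp_kwargs["device_mesh"], per the diff above
    )
    fabric = Fabric(accelerator="cuda", devices=8, strategy=strategy)
    fabric.launch()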

src/lightning/pytorch/strategies/fsdp.py (+14 -1)

@@ -16,7 +16,7 @@
 from contextlib import contextmanager, nullcontext
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Mapping, Optional, Set, Type, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Mapping, Optional, Set, Type, Union, Tuple
 
 import torch
 from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
@@ -57,6 +57,7 @@
 from lightning.fabric.utilities.imports import (
     _TORCH_GREATER_EQUAL_2_0,
     _TORCH_GREATER_EQUAL_2_1,
+    _TORCH_GREATER_EQUAL_2_2,
 )
 from lightning.fabric.utilities.init import _EmptyInit
 from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
@@ -85,6 +86,11 @@
 
 _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]]
 
+if _TORCH_GREATER_EQUAL_2_2:
+    from torch.distributed._tensor import DeviceMesh
+else:
+    DeviceMesh = None
+
 
 log = logging.getLogger(__name__)
 
@@ -162,6 +168,7 @@ def __init__(
         activation_checkpointing_policy: Optional["_POLICY"] = None,
         sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD",
         state_dict_type: Literal["full", "sharded"] = "full",
+        device_mesh: Optional[Union[Tuple[int], "DeviceMesh"]] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(
@@ -177,6 +184,12 @@ def __init__(
         self.cpu_offload = _init_cpu_offload(cpu_offload)
         self.mixed_precision = mixed_precision
         self.kwargs = _auto_wrap_policy_kwargs(auto_wrap_policy, kwargs)
+
+        if device_mesh is not None:
+            if not _TORCH_GREATER_EQUAL_2_2:
+                raise ValueError("The device_mesh argument is only supported in torch >= 2.2.")
+            self.kwargs["device_mesh"] = device_mesh
+
         self.sharding_strategy = _init_sharding_strategy(sharding_strategy, self.kwargs)
 
         if _TORCH_GREATER_EQUAL_2_0:
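The Trainer-side strategy mirrors the Fabric change, storing the value in self.kwargs instead of self._fsdp_kwargs. An analogous hypothetical sketch, under the same torch >= 2.2 assumption and with an illustrative mesh shape:

    import lightning.pytorch as pl
    from lightning.pytorch.strategies import FSDPStrategy

    # Hypothetical example; on torch < 2.2 the constructor raises ValueError, matching the guard above.
    strategy = FSDPStrategy(device_mesh=(2, 4))  # or a prebuilt DeviceMesh
    trainer = pl.Trainer(accelerator="cuda", devices=8, strategy=strategy)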

0 commit comments
