
Commit 7668a6b

Liyang90 and awaelchli authored
Flexible and easy to use HSDP setting (#19504)
Co-authored-by: awaelchli <[email protected]>
1 parent 1a6786d commit 7668a6b

File tree

6 files changed: +73 −6 lines changed


src/lightning/fabric/CHANGELOG.md
Lines changed: 2 additions & 0 deletions

@@ -19,6 +19,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added a call to `torch.distributed.destroy_process_group` in atexit handler if process group needs destruction ([#19931](https://github.com/Lightning-AI/pytorch-lightning/pull/19931))

+- Added support for configuring hybrid-sharding by passing a tuple for the `FSDPStrategy(device_mesh=...)` argument ([#19504](https://github.com/Lightning-AI/pytorch-lightning/pull/19504))
+

 ### Changed

src/lightning/fabric/strategies/fsdp.py
Lines changed: 19 additions & 1 deletion

@@ -74,12 +74,14 @@
 from lightning.fabric.utilities.types import _PATH, _Stateful

 if TYPE_CHECKING:
+    from torch.distributed.device_mesh import DeviceMesh
     from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy
     from torch.distributed.fsdp.wrap import ModuleWrapPolicy

     _POLICY = Union[Set[Type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy]
     _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]]
+

 _FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload")


@@ -117,10 +119,14 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
         - ``"SHARD_GRAD_OP"``: Shards gradients and optimizer states only. Model parameters get replicated.
         - ``"NO_SHARD"``: No sharding (identical to regular DDP).
         - ``"HYBRID_SHARD"``: Shards model parameters, gradients, and optimizer states within a single machine, but
-          replicates across machines.
+          replicates across machines. See also the `device_mesh` parameter below.

         Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value.

+        device_mesh: A tuple `(replication size, sharding size)` that defines over how many devices to shard and
+            replicate the model. The product of the two numbers must equal the world size. Only valid in combination
+            with the `HYBRID_SHARD` sharding strategy.
+
         state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.

         - ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file.

@@ -146,6 +152,7 @@ def __init__(
         activation_checkpointing_policy: Optional["_POLICY"] = None,
         sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD",
         state_dict_type: Literal["full", "sharded"] = "sharded",
+        device_mesh: Optional[Union[Tuple[int], "DeviceMesh"]] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(

@@ -163,6 +170,11 @@ def __init__(
         # Enables joint setup of model and optimizer, multiple optimizer param groups, and `torch.compile()`
         self._fsdp_kwargs.setdefault("use_orig_params", True)

+        if device_mesh is not None:
+            if not _TORCH_GREATER_EQUAL_2_2:
+                raise ValueError("The `device_mesh` argument is only supported in torch >= 2.2.")
+            self._fsdp_kwargs["device_mesh"] = device_mesh
+
         self._activation_checkpointing_kwargs = _activation_checkpointing_kwargs(
             activation_checkpointing, activation_checkpointing_policy
         )

@@ -244,6 +256,12 @@ def setup_environment(self) -> None:
         super().setup_environment()
         self._setup_distributed()

+        # if 'device_mesh' in the `_fsdp_kwargs` is provided as a tuple, update it into the `DeviceMesh` object here
+        if isinstance(self._fsdp_kwargs.get("device_mesh"), tuple):
+            from torch.distributed.device_mesh import init_device_mesh
+
+            self._fsdp_kwargs["device_mesh"] = init_device_mesh("cuda", self._fsdp_kwargs["device_mesh"])
+
     @override
     def setup_module_and_optimizers(
         self, module: Module, optimizers: List[Optimizer]
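
For context, a minimal usage sketch of the new argument on the Fabric side. Only the tuple semantics `(replication size, sharding size)` come from the docstring above; the 2-node, 4-GPU-per-node layout is an assumption for illustration:

import lightning as L
from lightning.fabric.strategies import FSDPStrategy

# Assumed layout: 2 nodes x 4 GPUs = world size 8.
# The tuple is (replication size, sharding size): shard within each node's 4 GPUs,
# replicate across the 2 nodes. The product must equal the world size.
strategy = FSDPStrategy(sharding_strategy="HYBRID_SHARD", device_mesh=(2, 4))

fabric = L.Fabric(accelerator="cuda", devices=4, num_nodes=2, strategy=strategy)
fabric.launch()  # the tuple is converted to a DeviceMesh in setup_environment()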

src/lightning/pytorch/CHANGELOG.md
Lines changed: 2 additions & 0 deletions

@@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added a call to `torch.distributed.destroy_process_group` in atexit handler if process group needs destruction ([#19931](https://github.com/Lightning-AI/pytorch-lightning/pull/19931))

+- Added support for configuring hybrid-sharding by passing a tuple for the `FSDPStrategy(device_mesh=...)` argument ([#19504](https://github.com/Lightning-AI/pytorch-lightning/pull/19504))
+

 ### Changed


src/lightning/pytorch/strategies/fsdp.py
Lines changed: 38 additions & 3 deletions

@@ -16,7 +16,21 @@
 from contextlib import contextmanager, nullcontext
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Mapping, Optional, Set, Type, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+    Union,
+)

 import torch
 from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only

@@ -53,7 +67,10 @@
     _sync_ddp_if_available,
 )
 from lightning.fabric.utilities.distributed import group as _group
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
+from lightning.fabric.utilities.imports import (
+    _TORCH_GREATER_EQUAL_2_1,
+    _TORCH_GREATER_EQUAL_2_2,
+)
 from lightning.fabric.utilities.init import _EmptyInit, _has_meta_device_parameters_or_buffers
 from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
 from lightning.fabric.utilities.optimizer import _optimizers_to_device

@@ -70,6 +87,7 @@
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn

 if TYPE_CHECKING:
+    from torch.distributed.device_mesh import DeviceMesh
     from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy
     from torch.distributed.fsdp.wrap import ModuleWrapPolicy


@@ -114,10 +132,14 @@ class FSDPStrategy(ParallelStrategy):
         - ``"SHARD_GRAD_OP"``: Shards gradients and optimizer states only. Model parameters get replicated.
         - ``"NO_SHARD"``: No sharding (identical to regular DDP).
         - ``"HYBRID_SHARD"``: Shards model parameters, gradients, and optimizer states within a single machine, but
-          replicates across machines.
+          replicates across machines. See also the `device_mesh` parameter below.

         Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value.

+        device_mesh: A tuple `(replication size, sharding size)` that defines over how many devices to shard and
+            replicate the model. The product of the two numbers must equal the world size. Only valid in combination
+            with the `HYBRID_SHARD` sharding strategy.
+
         state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.

         - ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file.

@@ -147,6 +169,7 @@ def __init__(
         activation_checkpointing_policy: Optional["_POLICY"] = None,
         sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD",
         state_dict_type: Literal["full", "sharded"] = "full",
+        device_mesh: Optional[Union[Tuple[int], "DeviceMesh"]] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(

@@ -162,6 +185,12 @@ def __init__(
         self.cpu_offload = _init_cpu_offload(cpu_offload)
         self.mixed_precision = mixed_precision
         self.kwargs = _auto_wrap_policy_kwargs(auto_wrap_policy, kwargs)
+
+        if device_mesh is not None:
+            if not _TORCH_GREATER_EQUAL_2_2:
+                raise ValueError("The `device_mesh` argument is only supported in torch >= 2.2.")
+            self.kwargs["device_mesh"] = device_mesh
+
         self.sharding_strategy = _init_sharding_strategy(sharding_strategy, self.kwargs)

         # Avoids the need for user to reference params in `configure_optimizers` via

@@ -242,6 +271,12 @@ def setup_environment(self) -> None:
         assert self.cluster_environment is not None
         _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

+        # if 'device_mesh' in the `kwargs` is provided as a tuple, update it into the `DeviceMesh` object here
+        if isinstance(self.kwargs.get("device_mesh"), tuple):
+            from torch.distributed.device_mesh import init_device_mesh
+
+            self.kwargs["device_mesh"] = init_device_mesh("cuda", self.kwargs["device_mesh"])
+
     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

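
The same setting from the Trainer side, again as a sketch under the assumed 2-node by 4-GPU layout; `MyLightningModule` is a placeholder, not part of the commit:

import lightning as L
from lightning.pytorch.strategies import FSDPStrategy

# Assumed layout: replicate across 2 nodes, shard across the 4 GPUs within each node.
strategy = FSDPStrategy(sharding_strategy="HYBRID_SHARD", device_mesh=(2, 4))
trainer = L.Trainer(accelerator="cuda", devices=4, num_nodes=2, strategy=strategy)
# trainer.fit(MyLightningModule())  # placeholder LightningModule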

tests/tests_fabric/strategies/test_fsdp.py
Lines changed: 6 additions & 1 deletion

@@ -72,7 +72,7 @@ def test_sharding_strategy():


 @pytest.mark.parametrize("sharding_strategy", ["HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"])
-def test_hybrid_shard_configuration(sharding_strategy):
+def test_hybrid_shard_configuration(sharding_strategy, monkeypatch):
     """Test that the hybrid sharding strategies can only be used with automatic wrapping or a manually specified pg."""
     with pytest.raises(RuntimeError, match="The hybrid sharding strategy requires you to pass at least one of"):
         FSDPStrategy(sharding_strategy=sharding_strategy)

@@ -85,6 +85,11 @@ def test_hybrid_shard_configuration(sharding_strategy):
     assert strategy.sharding_strategy.name == sharding_strategy
     assert strategy._fsdp_kwargs["process_group"] is process_group

+    monkeypatch.setattr("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", False)
+    with pytest.raises(ValueError, match="`device_mesh` argument is only supported in torch >= 2.2."):
+        FSDPStrategy(device_mesh=Mock())
+
+    monkeypatch.setattr("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", True)
     device_mesh = Mock()
     strategy = FSDPStrategy(sharding_strategy=sharding_strategy, device_mesh=device_mesh)
     assert strategy.sharding_strategy.name == sharding_strategy
tests/tests_pytorch/strategies/test_fsdp.py
Lines changed: 6 additions & 1 deletion

@@ -501,7 +501,7 @@ def test_sharding_strategy():


 @pytest.mark.parametrize("sharding_strategy", ["HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"])
-def test_hybrid_sharding_strategy(sharding_strategy):
+def test_hybrid_shard_configuration(sharding_strategy, monkeypatch):
     """Test that the hybrid sharding strategies can only be used with automatic wrapping or a manually specified pg."""
     with pytest.raises(RuntimeError, match="The hybrid sharding strategy requires you to pass at least one of"):
         FSDPStrategy(sharding_strategy=sharding_strategy)

@@ -514,6 +514,11 @@ def test_hybrid_sharding_strategy(sharding_strategy):
     assert strategy.sharding_strategy.name == sharding_strategy
     assert strategy.kwargs["process_group"] is process_group

+    monkeypatch.setattr("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", False)
+    with pytest.raises(ValueError, match="`device_mesh` argument is only supported in torch >= 2.2."):
+        FSDPStrategy(device_mesh=Mock())
+
+    monkeypatch.setattr("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", True)
     device_mesh = Mock()
     strategy = FSDPStrategy(sharding_strategy=sharding_strategy, device_mesh=device_mesh)
     assert strategy.sharding_strategy.name == sharding_strategy
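
For completeness, a sketch of what the tuple expands to internally: `setup_environment` in both strategies hands it to `torch.distributed.device_mesh.init_device_mesh`, which requires torch >= 2.2 and an already-initialized process group. The sizes below are assumptions for illustration:

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Assumes torch >= 2.2 and a default process group initialized with 8 ranks
# (e.g. launched via torchrun).
replication_size, sharding_size = 2, 4
assert replication_size * sharding_size == dist.get_world_size()

mesh = init_device_mesh("cuda", (replication_size, sharding_size))
# A prebuilt DeviceMesh like this can also be passed directly: FSDPStrategy(device_mesh=mesh).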
