
Flexible and easy to use HSDP setting #19504


Merged
29 commits merged on Jun 6, 2024
Changes from 12 commits

Commits (29)
312caee
Add fsdp_size for FSDPStrategy
Liyang90 Jan 17, 2024
45c1123
fix import
Liyang90 Jan 17, 2024
0ddc51d
Merge branch 'Lightning-AI:master' into hybrid_fsdp_stage
Liyang90 Feb 20, 2024
c952536
Add flexible HSDP in fabric
Liyang90 Feb 20, 2024
8fc2404
minor update
Liyang90 Feb 20, 2024
da3900f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2024
8311be1
Merge branch 'Lightning-AI:master' into hybrid_fsdp_stage
Liyang90 Mar 1, 2024
d1d719a
Use device_mesh arg to set flexible HSDP with a Tuple
Liyang90 Mar 4, 2024
3315893
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
4652b74
Merge branch 'Lightning-AI:master' into hybrid_fsdp_stage
Liyang90 Mar 5, 2024
4049f60
minor fix
Liyang90 Mar 5, 2024
9c14afe
add simple docs
awaelchli Mar 8, 2024
1f2c3ff
correct doc string
Liyang90 Apr 1, 2024
07f7c1b
set as explicit args in FSDPStrategy
Liyang90 Apr 4, 2024
2ab0423
Merge branch 'Lightning-AI:master' into hybrid_fsdp_stage
Liyang90 Apr 4, 2024
899e032
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 4, 2024
4259df2
update fsdp tests
Liyang90 Apr 18, 2024
dbe22f3
Type check error
Liyang90 Apr 18, 2024
2320a4e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 18, 2024
b0d4783
merge
Liyang90 Apr 18, 2024
9d7dfbe
type check
Liyang90 Apr 18, 2024
483f745
Merge branch 'master' into hybrid_fsdp_stage
Liyang90 May 16, 2024
ba0b10b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 16, 2024
d2d9fe8
Merge branch 'Lightning-AI:master' into hybrid_fsdp_stage
Liyang90 Jun 5, 2024
bd03b05
simplify imports
awaelchli Jun 5, 2024
11bc4ee
extend test
awaelchli Jun 5, 2024
c6a052c
add changelog
awaelchli Jun 5, 2024
00efbcf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 5, 2024
949d36f
Merge branch 'master' into hybrid_fsdp_stage
awaelchli Jun 5, 2024
12 changes: 11 additions & 1 deletion src/lightning/fabric/strategies/fsdp.py
@@ -125,10 +125,14 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
- ``"SHARD_GRAD_OP"``: Shards gradients and optimizer states only. Model parameters get replicated.
- ``"NO_SHARD"``: No sharding (identical to regular DDP).
- ``"HYBRID_SHARD"``: Shards model parameters, gradients, and optimizer states within a single machine, but
replicates across machines.
replicates across machines. See also the `device_mesh` parameter below.

Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value.

device_mesh: A tuple `(replication size, sharding size)` that defines over how many devices to shard and
replicate the model. The product of the two numbers must equal the world size. Only valid in combination
with the `HYBRID_SHARD` sharding strategy.

state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.

- ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file.
@@ -253,6 +257,12 @@ def setup_environment(self) -> None:
super().setup_environment()
self._setup_distributed()

# If `device_mesh` is provided as a tuple in `_fsdp_kwargs`, convert it to a `DeviceMesh` object here
if isinstance(self._fsdp_kwargs.get("device_mesh"), tuple):
    from torch.distributed.device_mesh import init_device_mesh

    self._fsdp_kwargs["device_mesh"] = init_device_mesh("cuda", self._fsdp_kwargs["device_mesh"])

@override
def setup_module_and_optimizers(
self, module: Module, optimizers: List[Optimizer]
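
For reference, a minimal usage sketch of this option on the Fabric side. It assumes a `(replication size, sharding size)` reading of the tuple (dim 0 replicates, dim 1 shards, matching how PyTorch FSDP consumes a 2D mesh for `HYBRID_SHARD`) and a world size of 8; the concrete sizes and device counts are illustrative only:

    import lightning as L
    from lightning.fabric.strategies import FSDPStrategy

    # Hybrid sharding: shard parameters within groups of 4 devices and keep 2 replicas.
    # The tuple is converted to a DeviceMesh in setup_environment(), as in the diff above.
    strategy = FSDPStrategy(sharding_strategy="HYBRID_SHARD", device_mesh=(2, 4))
    fabric = L.Fabric(accelerator="cuda", devices=8, strategy=strategy)
    fabric.launch()
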
12 changes: 11 additions & 1 deletion src/lightning/pytorch/strategies/fsdp.py
@@ -125,10 +125,14 @@ class FSDPStrategy(ParallelStrategy):
- ``"SHARD_GRAD_OP"``: Shards gradients and optimizer states only. Model parameters get replicated.
- ``"NO_SHARD"``: No sharding (identical to regular DDP).
- ``"HYBRID_SHARD"``: Shards model parameters, gradients, and optimizer states within a single machine, but
replicates across machines.
replicates across machines. See also the `device_mesh` parameter below.

Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value.

device_mesh: A tuple `(replication size, sharding size)` that defines over how many devices to shard and
replicate the model. The product of the two numbers must equal the world size. Only valid in combination
with the `HYBRID_SHARD` sharding strategy.

state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.

- ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file.
@@ -260,6 +264,12 @@ def setup_environment(self) -> None:
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

# If `device_mesh` is provided as a tuple in `kwargs`, convert it to a `DeviceMesh` object here
if isinstance(self.kwargs.get("device_mesh"), tuple):
    from torch.distributed.device_mesh import init_device_mesh

    self.kwargs["device_mesh"] = init_device_mesh("cuda", self.kwargs["device_mesh"])

def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

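
The Trainer-side change mirrors the Fabric one. A minimal sketch under the same assumptions (tuple read as replication size then sharding size, 8 GPUs); `MyModel` is a placeholder LightningModule:

    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import FSDPStrategy

    # 2 * 4 must equal the world size (here one node with 8 GPUs).
    strategy = FSDPStrategy(sharding_strategy="HYBRID_SHARD", device_mesh=(2, 4))
    trainer = Trainer(accelerator="cuda", devices=8, strategy=strategy)
    trainer.fit(MyModel())
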