Commit 45583a5

Andrew Gusvekars authored and committed
[FSDP2] Move to public torch.distributed.fsdp (pytorch#141868)
**Overview**

This PR moves `torch/distributed/_composable/fsdp` to
`torch/distributed/fsdp/_fully_shard` and makes the public APIs available from
`torch.distributed.fsdp`, e.g.:

```
from torch.distributed.fsdp import fully_shard
```

This is targeting the 2.6 release. I rewrote some of the documentation with
(hopefully) improved phrasing.

**Follow-Ups**

- [x] Add some explanation in the docs about FSDP1 vs. FSDP2
- [ ] Move unit tests from `test/distributed/_composable/fsdp` to
  `test/distributed/fsdp/fully_shard/`

Pull Request resolved: pytorch#141868
Approved by: https://github.com/kwen2501, https://github.com/wconstab, https://github.com/weifengpy
Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent f9af86d commit 45583a5

44 files changed (+363, -174 lines)
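As described in the commit message, the FSDP2 entry points are now importable directly from `torch.distributed.fsdp`. The following is a minimal sketch of that public import path, adapted from the `TestFullyShardOldImport` test added in this commit; it is illustrative rather than part of the diff, and assumes an already-initialized distributed process group and an available CUDA device.

```python
# Sketch only: mirrors the usage in TestFullyShardOldImport, but via the new
# public import path. Assumes torch.distributed is already initialized and a
# CUDA device is available.
import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
for layer in model:
    fully_shard(layer, mp_policy=mp_policy)  # shard each submodule first
fully_shard(model, mp_policy=mp_policy)      # then shard the root module

inp = torch.randn((8, 16), device="cuda")
model(inp).sum().backward()  # parameters are all-gathered per layer as needed
```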
docs/source/distributed.fsdp.fully_shard.rst

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+torch.distributed.fsdp.fully_shard
+==================================
+
+PyTorch FSDP2 (``fully_shard``)
+-------------------------------
+
+PyTorch FSDP2 provides a fully sharded data parallelism (FSDP) implementation
+targeting performant eager-mode while using per-parameter sharding for improved
+usability.
+
+- If you are new to FSDP, we recommend that you start with FSDP2 due to improved
+  usability.
+- If you are currently using FSDP1, consider evaluating the following
+  differences to see if you should switch to FSDP2:
+
+Compared to PyTorch FSDP1 (``FullyShardedDataParallel``):
+
+- FSDP2 uses ``DTensor``-based dim-0 per-parameter sharding for a simpler
+  sharding representation compared to FSDP1's flat-parameter sharding, while
+  preserving similar throughput performance. More specifically, FSDP2 chunks
+  each parameter on dim-0 across the data parallel workers (using
+  ``torch.chunk(dim=0)``), whereas FSDP1 flattens, concatenates, and chunks a
+  group of tensors together, making reasoning about what data is present on
+  each worker and resharding to different parallelisms complex. Per-parameter
+  sharding provides a more intuitive user experience, relaxes constraints
+  around frozen parameters, and allows for communication-free (sharded) state
+  dicts, which otherwise require all-gathers in FSDP1.
+- FSDP2 implements a different memory management approach to handle the
+  multi-stream usages that avoids ``torch.Tensor.record_stream``. This ensures
+  deterministic and expected memory usage and does not require blocking the CPU
+  like in FSDP1's ``limit_all_gathers=True``.
+- FSDP2 exposes APIs for manual control over prefetching and collective
+  scheduling, allowing power users more customization. See the methods on
+  ``FSDPModule`` below for details.
+- FSDP2 simplifies some of the API surface: e.g. FSDP2 does not directly
+  support full state dicts. Instead, users can reshard the sharded state dicts
+  containing ``DTensor`` s to full state dicts themselves using ``DTensor``
+  APIs like ``DTensor.full_tensor()`` or by using higher-level APIs like
+  `PyTorch Distributed Checkpoint <https://pytorch.org/docs/stable/distributed.checkpoint.html>`_ 's
+  distributed state dict APIs. Also, some other args have been removed; see
+  `here <https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md>`_ for
+  details.
+
+If you are onboarding FSDP for the first time or if any of the above appeals to
+your use case, we recommend that you consider using FSDP2.
+
+See `this RFC <https://github.com/pytorch/pytorch/issues/114299>`_ for details
+on system design and implementation.
+
+.. note::
+   ``torch.distributed.fsdp.fully_shard`` is currently in prototype state and
+   under development. The core API will likely not change, but we may make some
+   API changes if necessary.
+
+.. currentmodule:: torch.distributed.fsdp
+
+The frontend API is ``fully_shard`` that can be called on a ``module``:
+
+.. autofunction:: fully_shard
+
+Calling ``fully_shard(module)`` dynamically constructs a new class that
+subclasses ``type(module)`` and an FSDP class ``FSDPModule``. For example, if
+we call ``fully_shard(linear)`` on a module ``linear: nn.Linear``, then FSDP
+constructs a new class ``FSDPLinear`` and changes ``linear`` 's type to this.
+Otherwise, ``fully_shard`` does not change the module structure and parameter
+fully-qualified names. The class ``FSDPModule`` allows providing some
+FSDP-specific methods on the module.
+
+.. autoclass:: FSDPModule
+    :members:
+    :member-order: bysource
+
+.. autoclass:: UnshardHandle
+    :members:
+
+.. autofunction:: register_fsdp_forward_method
+
+.. autoclass:: MixedPrecisionPolicy
+    :members:
+
+.. autoclass:: OffloadPolicy
+    :members:
+
+.. autoclass:: CPUOffloadPolicy
+    :members:
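To make the class-swizzling behavior described in the new docs concrete, here is a hedged sketch (not part of this PR): after `fully_shard(linear)`, the module is an instance of both its original class and `FSDPModule`, so FSDP-specific methods such as `unshard()` and `reshard()` become available on it. The `linear` module and the initialized process group are assumptions carried over from the earlier example.

```python
# Illustrative only; assumes `linear = nn.Linear(16, 16)` has already been
# passed to fully_shard(linear) in an initialized distributed setting.
import torch.nn as nn
from torch.distributed.fsdp import FSDPModule

assert isinstance(linear, nn.Linear) and isinstance(linear, FSDPModule)
print(type(linear).__name__)  # dynamically created subclass, e.g. "FSDPLinear"

linear.unshard()  # manually all-gather the sharded parameters
linear.reshard()  # free the unsharded parameters again
```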
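The FSDP1 vs. FSDP2 comparison above also notes that FSDP2 produces sharded, `DTensor`-valued state dicts instead of full state dicts. Below is a hedged sketch of resharding such a state dict to plain full tensors with `DTensor.full_tensor()`; `model` is assumed to be a module already wrapped with `fully_shard`, and since `full_tensor()` is a collective, every rank must run it.

```python
# Sketch under the assumptions above: materialize a full (unsharded) state
# dict from FSDP2's sharded one. DTensor.full_tensor() all-gathers the shards,
# so this must run on every rank.
from torch.distributed.tensor import DTensor

sharded_sd = model.state_dict()  # communication-free; values are DTensors
full_sd = {
    key: value.full_tensor() if isinstance(value, DTensor) else value
    for key, value in sharded_sd.items()
}
```

For large models, the docs instead point to PyTorch Distributed Checkpoint's distributed state dict APIs, which avoid materializing every full tensor on every rank.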

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -79,6 +79,7 @@ Features described in this documentation are classified by release status:
    torch.distributed.algorithms.join <distributed.algorithms.join>
    torch.distributed.elastic <distributed.elastic>
    torch.distributed.fsdp <fsdp>
+   torch.distributed.fsdp.fully_shard <distributed.fsdp.fully_shard>
    torch.distributed.tensor.parallel <distributed.tensor.parallel>
    torch.distributed.optim <distributed.optim>
    torch.distributed.pipelining <distributed.pipelining>

test/distributed/_composable/fsdp/test_fully_shard_autograd.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed.fsdp import fully_shard
 from torch.nn.parallel.scatter_gather import _is_namedtuple
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu

test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py

Lines changed: 1 addition & 1 deletion
@@ -7,8 +7,8 @@
 import torch
 import torch.nn as nn
 from torch.distributed._composable import replicate
-from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed.fsdp import fully_shard
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import FSDPTest, MLPStack

test/distributed/_composable/fsdp/test_fully_shard_comm.py

Lines changed: 9 additions & 9 deletions
@@ -11,30 +11,30 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed._composable import checkpoint, replicate
-from torch.distributed._composable.fsdp import (
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed.fsdp import (
     FSDPModule,
     fully_shard,
     MixedPrecisionPolicy,
     OffloadPolicy,
 )
-from torch.distributed._composable.fsdp._fsdp_collectives import (
+from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
     _div_if_needed,
     _get_gradient_divide_factors,
     foreach_all_gather,
     foreach_all_gather_copy_out,
     foreach_reduce,
 )
-from torch.distributed._composable.fsdp._fsdp_common import FSDPMeshInfo, TrainingState
-from torch.distributed._composable.fsdp._fsdp_init import (
+from torch.distributed.fsdp._fully_shard._fsdp_common import FSDPMeshInfo, TrainingState
+from torch.distributed.fsdp._fully_shard._fsdp_init import (
     _get_post_forward_mesh_info,
     _init_default_fully_shard_mesh,
 )
-from torch.distributed._composable.fsdp._fsdp_param import ShardedState
-from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup
-from torch.distributed._tensor import DTensor
-from torch.distributed._tensor.experimental import implicit_replication
-from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed.fsdp._fully_shard._fsdp_param import ShardedState
+from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup
+from torch.distributed.tensor import DTensor
 from torch.distributed.tensor.debug import CommDebugMode
+from torch.distributed.tensor.experimental import implicit_replication
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (

test/distributed/_composable/fsdp/test_fully_shard_compile.py

Lines changed: 8 additions & 6 deletions
@@ -12,17 +12,19 @@

 import torch
 import torch._dynamo.testing
-import torch.distributed._composable.fsdp._fsdp_param
 import torch.nn.functional as F
 from torch import nn
 from torch._dynamo.utils import counters
 from torch._inductor import comms
 from torch._inductor.utils import is_fallback_op, run_and_get_code
-from torch.distributed._composable.fsdp import fully_shard
-from torch.distributed._composable.fsdp._fsdp_common import TrainingState
-from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup
 from torch.distributed._tensor import init_device_mesh
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
+from torch.distributed.fsdp import (
+    fully_shard,
+    FullyShardedDataParallel as FSDP,
+    ShardingStrategy,
+)
+from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
+from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup
 from torch.testing import FileCheck
 from torch.testing._internal.common_distributed import (
     at_least_x_gpu,
@@ -83,7 +85,7 @@ def _test_disable_compiling_hooks(
     ):
         torch._dynamo.reset()
         trace_rules_check_count = 0
-        HOOKS_FILE_NAME = "torch/distributed/_composable/fsdp/_fsdp_state.py"
+        HOOKS_FILE_NAME = "torch/distributed/fsdp/_fully_shard/_fsdp_state.py"
         HOOK_WRAPPER_NAME = "fsdp_hook_wrapper"

         def patched_trace_rules_check(*args, **kwargs):

test/distributed/_composable/fsdp/test_fully_shard_extensions.py

Lines changed: 1 addition & 1 deletion
@@ -13,8 +13,8 @@
 import torch.nn as nn
 import torch.utils._pytree as pytree
 from torch.autograd.grad_mode import _unsafe_preserve_version_counter
-from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (

test/distributed/_composable/fsdp/test_fully_shard_frozen.py

Lines changed: 2 additions & 2 deletions
@@ -10,8 +10,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed._composable import checkpoint, replicate
-from torch.distributed._composable.fsdp import fully_shard
-from torch.distributed._composable.fsdp._fsdp_param_group import (
+from torch.distributed.fsdp import fully_shard
+from torch.distributed.fsdp._fully_shard._fsdp_param_group import (
     RegisterPostBackwardFunction,
 )
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu

test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py

Lines changed: 1 addition & 1 deletion
@@ -4,8 +4,8 @@
 import torch
 import torch.nn as nn
 from torch.amp.grad_scaler import GradScaler, OptState
-from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed._tensor import init_device_mesh
+from torch.distributed.fsdp import fully_shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,

test/distributed/_composable/fsdp/test_fully_shard_init.py

Lines changed: 30 additions & 7 deletions
@@ -9,13 +9,6 @@
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed._composable import replicate
-from torch.distributed._composable.fsdp import fully_shard
-from torch.distributed._composable.fsdp._fsdp_init import (
-    _get_managed_modules,
-    _get_managed_states,
-)
-from torch.distributed._composable.fsdp._fsdp_param import ParamModuleInfo
-from torch.distributed._composable.fsdp._fsdp_param_group import _get_param_module_infos
 from torch.distributed._tensor import (
     DeviceMesh,
     distribute_tensor,
@@ -24,6 +17,15 @@
     Shard,
 )
 from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.fsdp import fully_shard
+from torch.distributed.fsdp._fully_shard._fsdp_init import (
+    _get_managed_modules,
+    _get_managed_states,
+)
+from torch.distributed.fsdp._fully_shard._fsdp_param import ParamModuleInfo
+from torch.distributed.fsdp._fully_shard._fsdp_param_group import (
+    _get_param_module_infos,
+)
 from torch.distributed.fsdp._init_utils import (
     _init_inter_node_process_group,
     _init_intra_node_process_group,
@@ -1156,5 +1158,26 @@ def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
         fully_shard(model, shard_placement_fn=shard_placement_fn)


+# TODO: Remove this test class once we remove the old import path:
+# torch/distributed/_composable/fsdp
+class TestFullyShardOldImport(FSDPTestMultiThread):
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_old_import_training(self):
+        from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+
+        model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
+        mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
+        fully_shard(model[0], mp_policy=mp_policy)
+        fully_shard(model[1], mp_policy=mp_policy)
+        fully_shard(model, mp_policy=mp_policy)
+
+        inp = torch.randn((8, 16), device="cuda")
+        model(inp).sum().backward()
+
+
 if __name__ == "__main__":
     run_tests()

0 commit comments
