
Commit e030727

Add function to explicitly mark forward methods in Fabric (#19690)

awaelchli and rasbt authored
Co-authored-by: Sebastian Raschka <[email protected]>
1 parent 0c8a193 commit e030727

File tree

6 files changed: +250 -17 lines changed


docs/source-fabric/api/fabric_methods.rst

+1 -1

@@ -49,7 +49,7 @@ Moves the model and optimizer to the correct device automatically.
 
 
 The setup method also prepares the model for the selected precision choice so that operations during ``forward()`` get
-cast automatically.
+cast automatically. Advanced users should read :doc:`the notes on models wrapped by Fabric <../api/wrappers>`.
 
 setup_dataloaders
 =================
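The amended sentence concerns automatic precision casting. As a rough sketch of what that casting means in practice, assuming ``bf16-true`` precision on a CPU machine (illustrative, not part of the commit):

import torch
import lightning as L

fabric = L.Fabric(accelerator="cpu", precision="bf16-true")
model = fabric.setup(torch.nn.Linear(10, 2))  # parameters get converted to bfloat16

x = torch.randn(4, 10)  # a plain float32 input
output = model(x)  # no dtype mismatch: the wrapper casts inputs during forward()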

docs/source-fabric/api/wrappers.rst

+147

@@ -0,0 +1,147 @@ (new file)

########################
Models wrapped by Fabric
########################

When you :doc:`set up <../api/fabric_methods>` a model in Fabric, it gets automatically wrapped by a new module, the ``FabricModule``:

.. code-block:: python

    import torch
    import lightning as L

    fabric = L.Fabric()
    model = torch.nn.Linear(10, 2)
    model = fabric.setup(model)

    print(type(model))  # <class 'lightning.fabric.wrappers._FabricModule'>

This wrapper module takes care of a few things for you, notably:

- Strategy: Handles strategy-specific logic for the forward method (DDP, FSDP, etc.).
- Precision: Inputs and outputs passed through ``forward`` get automatically converted to the right precision depending on the ``Fabric(precision=...)`` setting.
- Device: The wrapper remembers which device the model is on. You can access it with ``model.device``.

.. note::
    The ``FabricModule`` wrapper is completely transparent and most users will never need to interact with it directly.

Below we describe a few functions and properties of the wrapper for advanced use cases.
This might be useful if you are building a custom Trainer using Fabric as the core.


----


********************************
Accessing methods and attributes
********************************

Access to methods and attributes gets redirected to the original model automatically:

.. code-block:: python

    import torch
    import lightning as L

    fabric = L.Fabric()
    model = torch.nn.Linear(10, 2)
    fabric_model = fabric.setup(model)

    # You can access attributes and methods normally
    print(fabric_model.weight is model.weight)  # True


----


********************
Unwrapping the model
********************

You can check whether a model is wrapped in a ``FabricModule`` with the ``is_wrapped`` utility function:

.. code-block:: python

    import torch
    import lightning as L
    from lightning.fabric import is_wrapped

    fabric = L.Fabric()
    model = torch.nn.Linear(10, 2)
    fabric_model = fabric.setup(model)

    print(is_wrapped(model))  # False
    print(is_wrapped(fabric_model))  # True

If you ever need to, you can access the original model explicitly via ``.module``:

.. code-block:: python

    # Access the original model explicitly
    original_model = fabric_model.module

    print(original_model is model)  # True


----


************************************************
Using methods other than forward for computation
************************************************

PyTorch's ``nn.Module`` has a special contract you need to follow when using it for training: your forward computation has to be defined in the **forward** method, and you should call this forward method directly.
But sometimes your model may need to define different flavors of ``forward``, like in this example below where the regular ``forward`` is used for training, but the ``generate`` method does something slightly different for inference:

.. code-block:: python

    import torch
    import lightning as L


    class MyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(10, 2)

        def forward(self, x):
            return self.layer(x)

        def generate(self):
            sample = torch.randn(10)
            return self(sample)

If you were to run this model in Fabric with multiple devices (DDP or FSDP), you would get an error:

.. code-block:: python

    fabric = L.Fabric(accelerator="cpu", devices=2)
    fabric.launch()
    model = MyModel()
    model = fabric.setup(model)

    # OK: Calling the model directly
    output = model(torch.randn(10))

    # OK: Calling the model's forward (equivalent to the above)
    output = model.forward(torch.randn(10))

    # ERROR: Calling another method that calls forward indirectly
    output = model.generate()

Fabric raises an error here to inform the user about incorrect usage, because this is normally not allowed in PyTorch and could lead to silent correctness bugs in distributed settings.
If you want to use such methods, you need to mark them explicitly with ``mark_forward_method()`` so that Fabric can reroute them through ``forward`` behind the scenes:

.. code-block:: python

    # You must mark special forward methods explicitly:
    model.mark_forward_method(model.generate)

    # Passing just the name is also sufficient
    model.mark_forward_method("generate")

    # OK: Fabric will reroute the call through forward behind the scenes
    output = model.generate()

|
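The bullet list on the new page also mentions the device bookkeeping; a minimal sketch of how that surfaces, assuming a CPU run (illustrative, not part of the commit):

import torch
import lightning as L

fabric = L.Fabric(accelerator="cpu")
model = fabric.setup(torch.nn.Linear(10, 2))

# The FabricModule remembers which device setup() moved the model to
print(model.device)  # cpu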

docs/source-fabric/glossary/index.rst

+6

@@ -8,6 +8,7 @@ Glossary
 
     Checkpoint <../guide/checkpoint/index>
     Weights and Biases <../guide/loggers/wandb>
+    Wrappers <../api/wrappers>
 
 
 .. raw:: html

@@ -80,6 +81,11 @@ Glossary
    :button_link: ../fundamentals/accelerators.html
    :col_css: col-md-4
 
+.. displayitem::
+   :header: FabricModule
+   :button_link: ../api/wrappers.html
+   :col_css: col-md-4
+
 .. displayitem::
    :header: FSDP
    :button_link: ../advanced/model_parallel/fsdp.html

src/lightning/fabric/CHANGELOG.md

+3 -1

@@ -9,7 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
-- Enabled consolidating distributed checkpoints through `fabric consolidate` in the new CLI [#19560](https://github.com/Lightning-AI/pytorch-lightning/pull/19560))
+- Enabled consolidating distributed checkpoints through `fabric consolidate` in the new CLI ([#19560](https://github.com/Lightning-AI/pytorch-lightning/pull/19560))
+
+- Added the ability to explicitly mark forward methods in Fabric via `_FabricModule.mark_forward_method()` ([#19690](https://github.com/Lightning-AI/pytorch-lightning/pull/19690))
 
 - Added support for PyTorch 2.3 ([#19708](https://github.com/Lightning-AI/pytorch-lightning/pull/19708))
src/lightning/fabric/wrappers.py

+24 -4

@@ -14,6 +14,7 @@
 import inspect
 from copy import deepcopy
 from functools import partial, wraps
+from types import MethodType
 from typing import (
     Any,
     Callable,

@@ -123,6 +124,7 @@ def __init__(
         self._forward_module = forward_module
         self._original_module = original_module or forward_module
         self._strategy = strategy
+        self._forward_methods = set(_LIGHTNING_MODULE_STEP_METHODS)
         self._fabric_module_initialized = True
 
     @property

@@ -165,6 +167,20 @@ def load_state_dict(  # type: ignore[override]
     ) -> _IncompatibleKeys:
         return self._original_module.load_state_dict(state_dict=state_dict, strict=strict, **kwargs)
 
+    def mark_forward_method(self, method: Union[MethodType, str]) -> None:
+        """Mark a method as a 'forward' method to prevent it bypassing the strategy wrapper (e.g., DDP)."""
+        if not isinstance(method, (MethodType, str)):
+            raise TypeError(f"Expected a method or a string, but got: {type(method).__name__}")
+        name = method if isinstance(method, str) else method.__name__
+        if name == "forward":
+            raise ValueError("You cannot mark the forward method itself as a forward method.")
+        if not isinstance(getattr(self._original_module, name, None), MethodType):
+            raise AttributeError(
+                f"You marked '{name}' as a forward method, but `{type(self._original_module).__name__}.{name}` does not"
+                f" exist or is not a method."
+            )
+        self._forward_methods.add(name)
+
     def _redirection_through_forward(self, method_name: str) -> Callable:
         assert method_name != "forward"
         original_forward = self._original_module.forward

@@ -207,8 +223,8 @@ def _wrapped_method(*args: Any, **kwargs: Any) -> Any:
             if module_called:
                 raise RuntimeError(
                     f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the"
-                    " model. This will bypass the wrapper from the strategy and result in incorrect behavior in"
-                    " `.backward()`. You should pass your inputs through `forward()`.",
+                    " model. To avoid issues with the currently selected strategy, explicitly mark it as a"
+                    f" forward method with `fabric_model.mark_forward_method({name!r})` after `fabric.setup()`."
                 )
             for handle in handles:
                 handle.remove()

@@ -231,8 +247,12 @@ def _register_backward_hook(self, tensor: Tensor) -> Tensor:
 
     @override
     def __getattr__(self, item: Any) -> Any:
-        if item in _LIGHTNING_MODULE_STEP_METHODS and self._forward_module != self._original_module:
-            # Special support for `LightningModule`, to prevent bypassing DDP's forward
+        if (
+            item != "_forward_methods"
+            and item in self._forward_methods
+            and self._forward_module != self._original_module
+        ):
+            # Special support for methods marked by `mark_forward_method` to prevent bypassing DDP's forward
             return self._redirection_through_forward(item)
 
         try:
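
For context on the `_redirection_through_forward` helper referenced in this hunk: the idea is to temporarily patch the original module's `forward` so that the call still enters through the strategy wrapper's `forward` and then lands in the marked method. A simplified, self-contained sketch of the mechanism (`StrategyWrapper` and `reroute_through_forward` are illustrative names, not Fabric's API):

import torch


class StrategyWrapper(torch.nn.Module):
    """Stands in for a strategy wrapper such as DDP: all work must enter through forward()."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        # A real strategy wrapper would add gradient syncing, precision handling, etc. here
        return self.module(*args, **kwargs)


def reroute_through_forward(wrapper, method_name, *args, **kwargs):
    """Call module.<method_name> such that execution still enters wrapper.forward()."""
    module = wrapper.module
    original_forward = module.forward

    def patched_forward(*f_args, **f_kwargs):
        # Restore forward first, because the marked method may itself call self(...)
        module.forward = original_forward
        return getattr(module, method_name)(*f_args, **f_kwargs)

    module.forward = patched_forward
    # The caller's arguments travel through the wrapper's forward and land in the marked method
    return wrapper(*args, **kwargs)


class MyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2

    def generate(self, x):
        return self(x) + 1


wrapped = StrategyWrapper(MyModel())
print(reroute_through_forward(wrapped, "generate", torch.ones(3)))  # tensor([3., 3., 3.])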

tests/tests_fabric/test_wrappers.py

+69 -11

@@ -102,15 +102,20 @@ def __init__(self, module):
             super().__init__()
             self.wrapped = module
 
+        def forward(self, *args, **kwargs):
+            return self.wrapped(*args, **kwargs)
+
     # Regular case: forward_module == original_module -> no warnings
     original_module = OriginalModule()
     fabric_module = _FabricModule(forward_module=original_module, strategy=Mock(), original_module=original_module)
     assert fabric_module.method_without_module_invocation() == 100
 
-    # Special case: original module wrapped by forward module: -> warn if method accepts args
+    # Special case: original module wrapped by forward module: -> error if method requires rerouting
     original_module = OriginalModule()
     wrapped_module = ModuleWrapper(original_module)
-    fabric_module = _FabricModule(forward_module=wrapped_module, strategy=Mock(), original_module=original_module)
+    fabric_module = _FabricModule(
+        forward_module=wrapped_module, strategy=Mock(precision=Precision()), original_module=original_module
+    )
     assert fabric_module.method_without_module_invocation() == 100
     with pytest.raises(
         RuntimeError, match=r"You are calling the method `OriginalModule.method_with_submodule_invocation\(\)` from"

@@ -121,6 +126,51 @@ def __init__(self, module):
     ):
         assert fabric_module.method_with_self_invocation() == 102
 
+    # No error if explicitly marked as forward method
+    fabric_module.mark_forward_method("method_with_self_invocation")
+    assert fabric_module.method_with_self_invocation() == 102
+
+
+def test_fabric_module_mark_forward_method():
+    class OriginalModule(torch.nn.Module):
+        attribute = 1
+
+        def forward(self, x):
+            return x
+
+        def special(self):
+            pass
+
+    original_module = OriginalModule()
+    fabric_module = _FabricModule(original_module, Mock(), original_module=original_module)
+
+    with pytest.raises(ValueError, match="You cannot mark the forward method itself"):
+        fabric_module.mark_forward_method("forward")
+
+    with pytest.raises(AttributeError, match="`OriginalModule.not_exist` does not exist or is not a method."):
+        fabric_module.mark_forward_method("not_exist")
+
+    with pytest.raises(AttributeError, match="`OriginalModule.attribute` does not exist or is not a method."):
+        fabric_module.mark_forward_method("attribute")
+
+    def special(x):
+        return x
+
+    with pytest.raises(TypeError, match="Expected a method or a string"):
+        fabric_module.mark_forward_method(special)
+
+    lightning_module_methods = {"training_step", "validation_step", "test_step", "predict_step"}
+    assert fabric_module._forward_methods == lightning_module_methods
+
+    # Mark via name
+    fabric_module.mark_forward_method("special")
+    assert fabric_module._forward_methods == {"special"} | lightning_module_methods
+
+    # Mark by passing in the method itself
+    fabric_module = _FabricModule(original_module, Mock(), original_module=original_module)
+    fabric_module.mark_forward_method(original_module.special)
+    assert fabric_module._forward_methods == {"special"} | lightning_module_methods
+
 
 def test_fabric_module_setattr():
     """Test that setattr sets attributes on the original module."""

@@ -549,8 +599,8 @@ def test_unwrap_objects(compile):
 
 
 def test_step_method_redirection():
-    """Test that the FabricModule redirects the special `LightningModule.*_step` methods through the forward-
-    module."""
+    """Test that the FabricModule redirects methods marked as 'forward methods' through forward to avoid bypassing the
+    DDP/FSDP wrappers."""
 
     class DDP(torch.nn.Module):
         def __init__(self, module):

@@ -570,11 +620,11 @@ def training_step(self, arg, kwarg=None):
             assert kwarg == "train_kwarg"
             return "training_step_return"
 
-        def validation_step(self, arg, kwarg=None):
+        def marked_method(self, arg, kwarg=None):
             assert self() == "forward_return"
-            assert arg == "val_arg"
-            assert kwarg == "val_kwarg"
-            return "validation_step_return"
+            assert arg == "marked_arg"
+            assert kwarg == "marked_kwarg"
+            return "marked_method_return"
 
         def normal_method(self):
             pass

@@ -602,18 +652,26 @@ def normal_method(self):
     assert original_module.forward.__name__ == "forward"
 
     # The special methods get redirected correctly to produce the expected output
+    strategy.precision.forward_context.reset_mock()
     assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return"
     assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return"  # call 2nd time
-    assert fabric_module.validation_step("val_arg", kwarg="val_kwarg") == "validation_step_return"
-    strategy.precision.forward_context.assert_called()
+    assert strategy.precision.forward_context.call_count == 2
+
+    # Other methods must be marked explicitly to be redirected
+    strategy.precision.forward_context.reset_mock()
+    with pytest.raises(RuntimeError, match="You are calling the method .* from outside the model"):
+        fabric_module.marked_method("marked_arg", kwarg="marked_kwarg")
+    fabric_module.mark_forward_method("marked_method")
+    assert fabric_module.marked_method("marked_arg", kwarg="marked_kwarg") == "marked_method_return"
+    strategy.precision.forward_context.assert_called_once()
 
     # The forward method remains untouched/unpatched after the special methods have been called
     assert original_module.forward.__name__ == "forward"
 
     # Special case: forward_module == original_module -> no special treatment applied
     fabric_module = _FabricModule(forward_module=original_module, strategy=Mock(), original_module=original_module)
     assert fabric_module.training_step == original_module.training_step
-    assert fabric_module.validation_step == original_module.validation_step
+    assert fabric_module.marked_method == original_module.marked_method
 
 
 @RunIf(dynamo=True)
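
The tests above build `_FabricModule` directly with a mocked strategy rather than going through `fabric.setup()`. A condensed, standalone sketch of that harness pattern (`Original` and `FakeDDP` are illustrative names only):

import torch
from unittest.mock import Mock

from lightning.fabric.plugins import Precision
from lightning.fabric.wrappers import _FabricModule


class Original(torch.nn.Module):
    def forward(self, x):
        return x + 1

    def double(self, x):
        return self(x) * 2


class FakeDDP(torch.nn.Module):
    """Minimal stand-in for a strategy wrapper like DistributedDataParallel."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


original = Original()
fabric_module = _FabricModule(
    forward_module=FakeDDP(original), strategy=Mock(precision=Precision()), original_module=original
)

# Once marked, the call is rerouted through FakeDDP.forward
fabric_module.mark_forward_method("double")
print(fabric_module.double(torch.zeros(1)))  # tensor([2.])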
