
Commit 6a0823b ("the long explanation")
1 parent 358d725

File tree: 2 files changed (+65 −5 lines changed)

docs/source-fabric/api/fabric_methods.rst

+1 −1

@@ -49,7 +49,7 @@ Moves the model and optimizer to the correct device automatically.
 
 
 The setup method also prepares the model for the selected precision choice so that operations during ``forward()`` get
-cast automatically.
+cast automatically. Advanced users should read :doc:`the notes on models wrapped by Fabric <../api/wrappers>`.
 
 setup_dataloaders
 =================
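The automatic precision handling described in the changed line can be pictured as a thin wrapper that converts inputs before they reach ``forward``. Below is a minimal stdlib-only sketch of that idea; it is not Fabric's implementation, and the names ``PrecisionWrapper`` and ``TinyModel`` are hypothetical:

```python
# Illustrative sketch only -- NOT Fabric's actual code.
# It mimics the idea of casting forward() inputs to a target precision.


class PrecisionWrapper:
    """Wraps a module-like object and casts inputs before calling it."""

    def __init__(self, module, dtype=float):
        self._module = module
        self._dtype = dtype  # stand-in for e.g. a half-precision dtype

    def __call__(self, *args):
        # Cast every argument to the configured "precision" first.
        cast_args = [self._dtype(a) for a in args]
        return self._module(*cast_args)


class TinyModel:
    """Stand-in for an nn.Module with a simple forward computation."""

    def __call__(self, x):
        return x * 2


model = PrecisionWrapper(TinyModel(), dtype=float)
print(model(3))  # the int input 3 is cast to 3.0 before forward -> 6.0
```

The user keeps calling the model as before; the cast happens transparently inside the wrapper's ``__call__``.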

docs/source-fabric/api/wrappers.rst

+64 −4

@@ -1,6 +1,6 @@
-#########################
-Modules wrapped by Fabric
-#########################
+########################
+Models wrapped by Fabric
+########################
 
 When you :doc:`set up <../api/fabric_methods>` a model in Fabric, it gets automatically wrapped by a new module, the ``FabricModule``:
 
@@ -21,7 +21,9 @@ This wrapper module takes care of a few things for you, notably:
 - Precision: Inputs and outputs passed through ``forward`` get automatically converted to the right precision depending on the ``Fabric(precision=...)`` setting.
 - Device: The wrapper remembers which device the model is on, you can access `model.device`.
 
-The ``FabricModule`` wrapper is completely transparent and most users will never need to interact with it directly.
+.. note::
+    The ``FabricModule`` wrapper is completely transparent and most users will never need to interact with it directly.
+
 Below we describe a few functions and properties of the wrapper for advanced use cases.
 This might be useful if you are building a custom Trainer using Fabric as the core.
 
@@ -80,3 +82,61 @@ If you ever need to, you can access the original model explicitly via ``.module``
 
     print(original_model is model) # True
+
+
+----
+
+
+************************************************
+Using methods other than forward for computation
+************************************************
+
+PyTorch's ``nn.Module`` has a special contract you need to follow when using it for training: your forward computation has to be defined in the **forward** method, and you should call this forward method directly.
+But sometimes your model may need to define different flavors of ``forward``, like in the example below where the regular forward is used for training, but the ``generate`` method does something slightly different for inference:
+
+.. code-block:: python
+
+    import torch
+    import lightning as L
+
+
+    class MyModel(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.layer = torch.nn.Linear(10, 2)
+
+        def forward(self, x):
+            return self.layer(x)
+
+        def generate(self):
+            sample = torch.randn(10)
+            return self(sample)
+
+
+If you were to run this model in Fabric with multiple devices (DDP or FSDP), you would get an error:
+
+.. code-block:: python
+
+    fabric = L.Fabric(accelerator="cpu", devices=2)
+    fabric.launch()
+    model = MyModel()
+    model = fabric.setup(model)
+
+    # OK: Calling the model directly
+    output = model(torch.randn(10))
+
+    # OK: Calling the model's forward (equivalent to the above)
+    output = model.forward(torch.randn(10))
+
+    # ERROR: Calling another method that calls forward indirectly
+    output = model.generate()
+
+Fabric produces an error here informing the user about incorrect usage, because this is normally not allowed in PyTorch and could potentially lead to silent correctness bugs.
+If you want to use such methods, you need to mark them explicitly with ``mark_forward_method()`` so that Fabric can do the rerouting behind the scenes for you:
+
+.. code-block:: python
+
+    # You must mark special forward methods explicitly:
+    model.mark_forward_method(model.generate)
+
+    # OK: Fabric will do the rerouting behind the scenes now
+    output = model.generate()
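The marking-and-rerouting mechanism the added docs describe can be sketched in plain Python: a wrapper that blocks calls to unmarked methods of the wrapped model and lets marked ones through. This is a conceptual sketch of the idea only, not Fabric's implementation; ``ForwardCallRouter`` and ``TinyModel`` are hypothetical names, and the real ``FabricModule`` additionally routes marked calls through its own ``forward`` hooks:

```python
# Conceptual sketch of forward-method guarding -- NOT Fabric's actual code.
import functools


class ForwardCallRouter:
    """Blocks indirect forward methods until they are explicitly marked."""

    def __init__(self, module):
        self._module = module
        self._marked = set()

    def mark_forward_method(self, method):
        # Accept either a name or a (possibly wrapped) method object.
        name = method if isinstance(method, str) else method.__name__
        self._marked.add(name)

    def __call__(self, *args, **kwargs):
        # Direct calls (the forward path) are always allowed.
        return self._module(*args, **kwargs)

    def __getattr__(self, name):
        attr = getattr(self._module, name)
        if not callable(attr) or name in self._marked:
            return attr

        @functools.wraps(attr)  # keeps __name__, so marking still works
        def blocked(*args, **kwargs):
            raise RuntimeError(
                f"Mark '{name}' with mark_forward_method() before calling it."
            )

        return blocked


class TinyModel:
    """Stand-in for an nn.Module with an extra inference method."""

    def __call__(self, x):
        return x + 1

    def generate(self):
        return self(41)


model = ForwardCallRouter(TinyModel())

try:
    model.generate()  # blocked: not marked yet
except RuntimeError as err:
    print(err)

model.mark_forward_method(model.generate)
print(model.generate())  # allowed after marking -> 42
```

Note that in this sketch the marked call goes straight to the wrapped model; Fabric's real rerouting also ensures the strategy wrappers (DDP, FSDP) still see the ``forward`` invocation.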
