This wrapper module takes care of a few things for you, notably:

- Strategy: Handles strategy-specific logic for the forward method (DDP, FSDP, etc.).
- Precision: Inputs and outputs passed through ``forward`` get automatically converted to the right precision depending on the ``Fabric(precision=...)`` setting.
- Device: The wrapper remembers which device the model is on. You can access it with ``model.device``, as shown in the short sketch below.

.. note::
    The ``FabricModule`` wrapper is completely transparent, and most users will never need to interact with it directly.

Below we describe a few functions and properties of the wrapper for advanced use cases.
This might be useful if you are building a custom Trainer using Fabric as the core.
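
To make the device and precision points concrete, here is a minimal sketch (the ``accelerator="cpu"`` and ``precision="bf16-mixed"`` settings are just example choices, not a recommendation):

.. code-block:: python

    import torch
    import lightning as L

    fabric = L.Fabric(accelerator="cpu", precision="bf16-mixed")
    model = torch.nn.Linear(10, 2)
    fabric_model = fabric.setup(model)

    # The wrapper remembers which device the model lives on
    print(fabric_model.device)  # cpu

    # Inputs and outputs of ``forward`` pass through the configured precision handling
    output = fabric_model(torch.randn(2, 10))
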
----

********************************
Accessing methods and attributes
********************************

Access to methods and attributes gets redirected to the original model automatically:

.. code-block:: python

    import torch
    import lightning as L

    fabric = L.Fabric()
    model = torch.nn.Linear(10, 2)
    fabric_model = fabric.setup(model)

    # You can access attributes and methods normally
    print(fabric_model.weight is model.weight)  # True

----

********************
Unwrapping the model
********************

You can check whether a model is wrapped in a ``FabricModule`` with the ``is_wrapped`` utility function:

.. code-block:: python

    import torch
    import lightning as L
    from lightning.fabric import is_wrapped

    fabric = L.Fabric()
    model = torch.nn.Linear(10, 2)
    fabric_model = fabric.setup(model)

    print(is_wrapped(model))  # False
    print(is_wrapped(fabric_model))  # True

If you ever need to, you can access the original model explicitly via ``.module``:

.. code-block:: python

    # Access the original model explicitly
    original_model = fabric_model.module

    print(original_model is model)  # True

----

************************************************
Using methods other than forward for computation
************************************************

PyTorch's ``nn.Module`` has a special contract you need to follow when using it for training: your forward computation has to be defined in the **forward** method, and you should call this method directly.
But sometimes your model may need to define different flavors of ``forward``, like in the example below, where the regular ``forward`` is used for training but the ``generate`` method does something slightly different for inference:

.. code-block:: python

    import torch
    import lightning as L


    class MyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(10, 2)

        def forward(self, x):
            return self.layer(x)

        def generate(self):
            sample = torch.randn(10)
            return self(sample)

If you were to run this model in Fabric with multiple devices (DDP or FSDP), you would get an error:

.. code-block:: python

    fabric = L.Fabric(accelerator="cpu", devices=2)
    fabric.launch()
    model = MyModel()
    model = fabric.setup(model)

    # OK: Calling the model directly
    output = model(torch.randn(10))

    # OK: Calling the model's forward (equivalent to the above)
    output = model.forward(torch.randn(10))

    # ERROR: Calling another method that calls forward indirectly
    output = model.generate()

Fabric raises an error here to inform the user about the incorrect usage, because this is normally not allowed in PyTorch and could potentially lead to silent correctness bugs.
If you want to use such methods, you need to mark them explicitly with ``.mark_forward_method()`` so that Fabric can do the necessary rerouting behind the scenes:

.. code-block:: python

    # You must mark special forward methods explicitly:
    model.mark_forward_method(model.generate)

    # Passing just the name is also sufficient
    model.mark_forward_method("generate")

    # OK: Fabric will do some rerouting behind the scenes now
    output = model.generate()