Skip to content

Commit ca6c94c

Browse files
authored
Fix monkeypatching of _FabricModule methods (#19705)
1 parent 0fb267b commit ca6c94c

File tree

3 files changed

+22
-2
lines changed

3 files changed

+22
-2
lines changed

src/lightning/fabric/CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4949

5050
- Fixed an issue causing a TypeError when using `torch.compile` as a decorator ([#19627](https://github.com/Lightning-AI/pytorch-lightning/pull/19627))
5151

52-
-
52+
- Fixed issue where some model methods couldn't be monkeypatched after being Fabric wrapped ([#19705](https://github.com/Lightning-AI/pytorch-lightning/pull/19705))
5353

5454
-
5555

src/lightning/fabric/wrappers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def __setattr__(self, name: str, value: Any) -> None:
263263
original_has_attr = hasattr(original_module, name)
264264
# Can't use super().__getattr__ because nn.Module only checks _parameters, _buffers, and _modules
265265
# Can't use self.__getattr__ because it would pass through to the original module
266-
fabric_has_attr = name in self.__dict__
266+
fabric_has_attr = name in dir(self)
267267

268268
if not (original_has_attr or fabric_has_attr):
269269
setattr(original_module, name, value)

tests/tests_fabric/test_wrappers.py

+20
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@ def __init__(self):
155155

156156
# Modify existing attribute on original_module
157157
fabric_module.attribute = 101
158+
# "attribute" is only in the original_module, so it shouldn't get set in the fabric_module
159+
assert "attribute" not in fabric_module.__dict__
160+
assert fabric_module.attribute == 101 # returns it from original_module
158161
assert original_module.attribute == 101
159162

160163
# Check setattr of original_module
@@ -170,6 +173,23 @@ def __init__(self):
170173
assert linear in fabric_module.modules()
171174
assert linear in original_module.modules()
172175

176+
# Check monkeypatching of methods
177+
fabric_module = _FabricModule(Mock(), Mock())
178+
original = id(fabric_module.forward)
179+
fabric_module.forward = lambda *_: None
180+
assert id(fabric_module.forward) != original
181+
# Check special methods
182+
assert "__repr__" in dir(fabric_module)
183+
assert "__repr__" not in fabric_module.__dict__
184+
assert "__repr__" not in _FabricModule.__dict__
185+
fabric_module.__repr__ = lambda *_: "test"
186+
assert fabric_module.__repr__() == "test"
187+
# needs to be monkeypatched on the class for `repr()` to change
188+
assert repr(fabric_module) == "_FabricModule()"
189+
with mock.patch.object(_FabricModule, "__repr__", return_value="test"):
190+
assert fabric_module.__repr__() == "test"
191+
assert repr(fabric_module) == "test"
192+
173193

174194
def test_fabric_module_state_dict_access():
175195
"""Test that state_dict access passes through to the original module."""

0 commit comments

Comments
 (0)