Skip to content

Commit 1218658

Browse files
awaelchlicarmocca
authored andcommitted
Fix monkeypatching of _FabricModule methods (#19705)
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent 41bbd23 commit 1218658

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

src/lightning/fabric/wrappers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def __setattr__(self, name: str, value: Any) -> None:
244244
original_has_attr = hasattr(original_module, name)
245245
# Can't use super().__getattr__ because nn.Module only checks _parameters, _buffers, and _modules
246246
# Can't use self.__getattr__ because it would pass through to the original module
247-
fabric_has_attr = name in self.__dict__
247+
fabric_has_attr = name in dir(self)
248248

249249
if not (original_has_attr or fabric_has_attr):
250250
setattr(original_module, name, value)

tests/tests_fabric/test_wrappers.py

+20
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@ def __init__(self):
155155

156156
# Modify existing attribute on original_module
157157
fabric_module.attribute = 101
158+
# "attribute" is only in the original_module, so it shouldn't get set in the fabric_module
159+
assert "attribute" not in fabric_module.__dict__
160+
assert fabric_module.attribute == 101 # returns it from original_module
158161
assert original_module.attribute == 101
159162

160163
# Check setattr of original_module
@@ -170,6 +173,23 @@ def __init__(self):
170173
assert linear in fabric_module.modules()
171174
assert linear in original_module.modules()
172175

176+
# Check monkeypatching of methods
177+
fabric_module = _FabricModule(Mock(), Mock())
178+
original = id(fabric_module.forward)
179+
fabric_module.forward = lambda *_: None
180+
assert id(fabric_module.forward) != original
181+
# Check special methods
182+
assert "__repr__" in dir(fabric_module)
183+
assert "__repr__" not in fabric_module.__dict__
184+
assert "__repr__" not in _FabricModule.__dict__
185+
fabric_module.__repr__ = lambda *_: "test"
186+
assert fabric_module.__repr__() == "test"
187+
# needs to be monkeypatched on the class for `repr()` to change
188+
assert repr(fabric_module) == "_FabricModule()"
189+
with mock.patch.object(_FabricModule, "__repr__", return_value="test"):
190+
assert fabric_module.__repr__() == "test"
191+
assert repr(fabric_module) == "test"
192+
173193

174194
def test_fabric_module_state_dict_access():
175195
"""Test that state_dict access passes through to the original module."""

0 commit comments

Comments
 (0)