diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 0c7af66bceb98..d88f2ec12827a 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -49,7 +49,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue causing a TypeError when using `torch.compile` as a decorator ([#19627](https://github.com/Lightning-AI/pytorch-lightning/pull/19627)) -- +- Fixed issue where some model methods couldn't be monkeypatched after being Fabric wrapped ([#19705](https://github.com/Lightning-AI/pytorch-lightning/pull/19705)) - diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index 11f1c67211e40..093b355e2c376 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -263,7 +263,7 @@ def __setattr__(self, name: str, value: Any) -> None: original_has_attr = hasattr(original_module, name) # Can't use super().__getattr__ because nn.Module only checks _parameters, _buffers, and _modules # Can't use self.__getattr__ because it would pass through to the original module - fabric_has_attr = name in self.__dict__ + fabric_has_attr = name in dir(self) if not (original_has_attr or fabric_has_attr): setattr(original_module, name, value) diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py index 3d6e47bffa8c5..0923c601d51c3 100644 --- a/tests/tests_fabric/test_wrappers.py +++ b/tests/tests_fabric/test_wrappers.py @@ -155,6 +155,9 @@ def __init__(self): # Modify existing attribute on original_module fabric_module.attribute = 101 + # "attribute" is only in the original_module, so it shouldn't get set in the fabric_module + assert "attribute" not in fabric_module.__dict__ + assert fabric_module.attribute == 101 # returns it from original_module assert original_module.attribute == 101 # Check setattr of original_module @@ -170,6 +173,23 @@ def __init__(self): assert linear in fabric_module.modules() assert linear in original_module.modules() + # Check monkeypatching of methods + fabric_module = _FabricModule(Mock(), Mock()) + original = id(fabric_module.forward) + fabric_module.forward = lambda *_: None + assert id(fabric_module.forward) != original + # Check special methods + assert "__repr__" in dir(fabric_module) + assert "__repr__" not in fabric_module.__dict__ + assert "__repr__" not in _FabricModule.__dict__ + fabric_module.__repr__ = lambda *_: "test" + assert fabric_module.__repr__() == "test" + # needs to be monkeypatched on the class for `repr()` to change + assert repr(fabric_module) == "_FabricModule()" + with mock.patch.object(_FabricModule, "__repr__", return_value="test"): + assert fabric_module.__repr__() == "test" + assert repr(fabric_module) == "test" + def test_fabric_module_state_dict_access(): """Test that state_dict access passes through to the original module."""