@@ -155,6 +155,9 @@ def __init__(self):
155
155
156
156
# Modify existing attribute on original_module
157
157
fabric_module .attribute = 101
158
+ # "attribute" is only in the original_module, so it shouldn't get set in the fabric_module
159
+ assert "attribute" not in fabric_module .__dict__
160
+ assert fabric_module .attribute == 101 # returns it from original_module
158
161
assert original_module .attribute == 101
159
162
160
163
# Check setattr of original_module
@@ -170,6 +173,23 @@ def __init__(self):
170
173
assert linear in fabric_module .modules ()
171
174
assert linear in original_module .modules ()
172
175
176
+ # Check monkeypatching of methods
177
+ fabric_module = _FabricModule (Mock (), Mock ())
178
+ original = id (fabric_module .forward )
179
+ fabric_module .forward = lambda * _ : None
180
+ assert id (fabric_module .forward ) != original
181
+ # Check special methods
182
+ assert "__repr__" in dir (fabric_module )
183
+ assert "__repr__" not in fabric_module .__dict__
184
+ assert "__repr__" not in _FabricModule .__dict__
185
+ fabric_module .__repr__ = lambda * _ : "test"
186
+ assert fabric_module .__repr__ () == "test"
187
+ # needs to be monkeypatched on the class for `repr()` to change
188
+ assert repr (fabric_module ) == "_FabricModule()"
189
+ with mock .patch .object (_FabricModule , "__repr__" , return_value = "test" ):
190
+ assert fabric_module .__repr__ () == "test"
191
+ assert repr (fabric_module ) == "test"
192
+
173
193
174
194
def test_fabric_module_state_dict_access ():
175
195
"""Test that state_dict access passes through to the original module."""
0 commit comments