Skip to content

Commit 1165a78

Browse files
Fix a bug that led to double-registration (#781)
After cloning a net, _module, _criteria, and _optimizers are already populated. Then, when loading params, there is yet another registration, i.e. a double registration. As a consequence, there would be two 'modules' etc. This is a fix for that.
1 parent 852383e commit 1165a78

File tree

2 files changed

+56
-4
lines changed

2 files changed

+56
-4
lines changed

skorch/net.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -1792,7 +1792,10 @@ def get_params(self, deep=True, **kwargs):
17921792
# special treatment.
17931793
params_cb = self._get_params_callbacks(deep=deep)
17941794
params.update(params_cb)
1795-
return params
1795+
1796+
# don't include the following attributes
1797+
to_exclude = {'_modules', '_criteria', '_optimizers'}
1798+
return {key: val for key, val in params.items() if key not in to_exclude}
17961799

17971800
def _check_kwargs(self, kwargs):
17981801
"""Check argument names passed at initialization.
@@ -2095,11 +2098,13 @@ def _register_attribute(
20952098
self.cuda_dependent_attributes_ = (
20962099
self.cuda_dependent_attributes_[:] + [name + '_'])
20972100

2098-
if self.init_context_ == 'module':
2101+
# make sure to not double register -- this should never happen, but
2102+
# still better to check
2103+
if (self.init_context_ == 'module') and (name not in self._modules):
20992104
self._modules = self._modules[:] + [name]
2100-
elif self.init_context_ == 'criterion':
2105+
elif (self.init_context_ == 'criterion') and (name not in self._criteria):
21012106
self._criteria = self._criteria[:] + [name]
2102-
elif self.init_context_ == 'optimizer':
2107+
elif (self.init_context_ == 'optimizer') and (name not in self._optimizers):
21032108
self._optimizers = self._optimizers[:] + [name]
21042109

21052110
def _unregister_attribute(

skorch/tests/test_net.py

+47
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,44 @@ def test_save_load_state_dict_str(
572572
score_after = accuracy_score(y, net.predict(X))
573573
assert np.isclose(score_after, score_before)
574574

575+
def test_save_load_state_dict_no_duplicate_registration_after_initialize(
576+
self, net_cls, module_cls, net_fit, tmpdir):
577+
# #781
578+
net = net_cls(module_cls).initialize()
579+
580+
p = tmpdir.mkdir('skorch').join('testmodel.pkl')
581+
with open(str(p), 'wb') as f:
582+
net_fit.save_params(f_params=f)
583+
del net_fit
584+
585+
with open(str(p), 'rb') as f:
586+
net.load_params(f_params=f)
587+
588+
# check that there are no duplicates in _modules, _criteria, _optimizers
589+
# pylint: disable=protected-access
590+
assert net._modules == ['module']
591+
assert net._criteria == ['criterion']
592+
assert net._optimizers == ['optimizer']
593+
594+
def test_save_load_state_dict_no_duplicate_registration_after_clone(
595+
self, net_fit, tmpdir):
596+
# #781
597+
net = clone(net_fit).initialize()
598+
599+
p = tmpdir.mkdir('skorch').join('testmodel.pkl')
600+
with open(str(p), 'wb') as f:
601+
net_fit.save_params(f_params=f)
602+
del net_fit
603+
604+
with open(str(p), 'rb') as f:
605+
net.load_params(f_params=f)
606+
607+
# check that there are no duplicates in _modules, _criteria, _optimizers
608+
# pylint: disable=protected-access
609+
assert net._modules == ['module']
610+
assert net._criteria == ['criterion']
611+
assert net._optimizers == ['optimizer']
612+
575613
@pytest.fixture(scope='module')
576614
def net_fit_adam(self, net_cls, module_cls, data):
577615
net = net_cls(
@@ -1426,6 +1464,15 @@ def test_get_params_works(self, net_cls, module_cls):
14261464
# now initialized
14271465
assert 'callbacks__myscore__scoring' in params
14281466

1467+
def test_get_params_no_unwanted_params(self, net, net_fit):
1468+
# #781
1469+
# make sure certain keys are not returned
1470+
keys_unwanted = {'_modules', '_criteria', '_optimizers'}
1471+
for net_ in (net, net_fit):
1472+
keys_found = set(net_.get_params())
1473+
overlap = keys_found & keys_unwanted
1474+
assert not overlap
1475+
14291476
def test_get_params_with_uninit_callbacks(self, net_cls, module_cls):
14301477
from skorch.callbacks import EpochTimer
14311478

0 commit comments

Comments
 (0)