Skip to content

Fix a bug that led to double-registration #781

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -1792,7 +1792,10 @@ def get_params(self, deep=True, **kwargs):
# special treatment.
params_cb = self._get_params_callbacks(deep=deep)
params.update(params_cb)
return params

# don't include the following attributes
to_exclude = {'_modules', '_criteria', '_optimizers'}
return {key: val for key, val in params.items() if key not in to_exclude}
Comment on lines +1797 to +1798
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On master, these attributes are not included in get_params.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean? On current master, after merging #751, get_params includes those.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I was mistaken. (I was on an older version of master locally.)

Can we add a test here:

def test_get_params_works(self, net_cls, module_cls):

Something like:

assert '_modules' not in params

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. I added a separate test below because that one was already big enough.


def _check_kwargs(self, kwargs):
"""Check argument names passed at initialization.
Expand Down Expand Up @@ -2095,11 +2098,13 @@ def _register_attribute(
self.cuda_dependent_attributes_ = (
self.cuda_dependent_attributes_[:] + [name + '_'])

if self.init_context_ == 'module':
# make sure to not double register -- this should never happen, but
# still better to check
if (self.init_context_ == 'module') and (name not in self._modules):
self._modules = self._modules[:] + [name]
Copy link
Member

@thomasjpfan thomasjpfan Jun 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is what fixed the bug. Would a set be a better data structure for _modules?

elif self.init_context_ == 'criterion':
elif (self.init_context_ == 'criterion') and (name not in self._criteria):
self._criteria = self._criteria[:] + [name]
elif self.init_context_ == 'optimizer':
elif (self.init_context_ == 'optimizer') and (name not in self._optimizers):
self._optimizers = self._optimizers[:] + [name]

def _unregister_attribute(
Expand Down
47 changes: 47 additions & 0 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,44 @@ def test_save_load_state_dict_str(
score_after = accuracy_score(y, net.predict(X))
assert np.isclose(score_after, score_before)

def test_save_load_state_dict_no_duplicate_registration_after_initialize(
        self, net_cls, module_cls, net_fit, tmpdir):
    """Regression test for #781: loading saved params into a freshly
    initialized net must not register module/criterion/optimizer twice."""
    net = net_cls(module_cls).initialize()

    param_file = tmpdir.mkdir('skorch').join('testmodel.pkl')
    with open(str(param_file), 'wb') as fp:
        net_fit.save_params(f_params=fp)
    del net_fit

    with open(str(param_file), 'rb') as fp:
        net.load_params(f_params=fp)

    # each attribute must appear exactly once in the registration lists
    # pylint: disable=protected-access
    assert net._modules == ['module']
    assert net._criteria == ['criterion']
    assert net._optimizers == ['optimizer']

def test_save_load_state_dict_no_duplicate_registration_after_clone(
        self, net_fit, tmpdir):
    """Regression test for #781: loading saved params into a cloned and
    re-initialized net must not register module/criterion/optimizer twice."""
    net = clone(net_fit).initialize()

    param_file = tmpdir.mkdir('skorch').join('testmodel.pkl')
    with open(str(param_file), 'wb') as fp:
        net_fit.save_params(f_params=fp)
    del net_fit

    with open(str(param_file), 'rb') as fp:
        net.load_params(f_params=fp)

    # each attribute must appear exactly once in the registration lists
    # pylint: disable=protected-access
    assert net._modules == ['module']
    assert net._criteria == ['criterion']
    assert net._optimizers == ['optimizer']

@pytest.fixture(scope='module')
def net_fit_adam(self, net_cls, module_cls, data):
net = net_cls(
Expand Down Expand Up @@ -1426,6 +1464,15 @@ def test_get_params_works(self, net_cls, module_cls):
# now initialized
assert 'callbacks__myscore__scoring' in params

def test_get_params_no_unwanted_params(self, net, net_fit):
    """Regression test for #781: get_params must not expose the internal
    registration attributes (_modules, _criteria, _optimizers)."""
    blocklist = {'_modules', '_criteria', '_optimizers'}
    for estimator in (net, net_fit):
        # set.isdisjoint iterates the dict's keys, so this is equivalent
        # to asserting that no blocklisted key shows up in get_params()
        assert blocklist.isdisjoint(estimator.get_params())

def test_get_params_with_uninit_callbacks(self, net_cls, module_cls):
from skorch.callbacks import EpochTimer

Expand Down