From 967df28be13cd35821512ea373b2cfa692ea1cfc Mon Sep 17 00:00:00 2001 From: BenjaminBossan Date: Wed, 12 Feb 2020 21:51:38 +0100 Subject: [PATCH 1/2] Fix a bug with custom attribute name clashes When a user adds a new "settable" attribute (i.e. that works with set_params) whose names starts the same as an existing attribute (say "optimizer_2"), and adds a corresponding argument (say "optimizer_2__lr"), skorch will erroneously complain about this argument because it thinks it belongs to the other attribute ("optimizer" in this example). The added unit test should illustrate this behavior. --- CHANGES.md | 1 + skorch/net.py | 2 +- skorch/tests/test_net.py | 23 ++++++++++++++++++++++- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 033e453b5..4a5f8e554 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Make skorch compatible with sklearn 0.22 +- Fixed a bug that could occur when a new "settable" (via `set_params`) attribute was added to `NeuralNet` whose name starts the same as an existing attribute's name ## [0.7.0] - 2019-11-29 diff --git a/skorch/net.py b/skorch/net.py index f954aaca6..a71bb845c 100644 --- a/skorch/net.py +++ b/skorch/net.py @@ -1334,7 +1334,7 @@ def _check_kwargs(self, kwargs): for key in kwargs: if key.endswith('_'): continue - for prefix in self.prefixes_: + for prefix in sorted(self.prefixes_, key=lambda s: (-len(s), s)): if key.startswith(prefix): if not key.startswith(prefix + '__'): missing_dunder_kwargs.append((prefix, key)) diff --git a/skorch/tests/test_net.py b/skorch/tests/test_net.py index 31679ae96..e1e5a6f98 100644 --- a/skorch/tests/test_net.py +++ b/skorch/tests/test_net.py @@ -169,7 +169,7 @@ def test_net_init_one_unknown_argument(self, net_cls, module_cls): "should deal with the new arguments explicitely.") assert e.value.args[0] == expected - def test_net_init_two_unknown_argument(self, net_cls, module_cls): + def test_net_init_two_unknown_arguments(self, net_cls, module_cls): with pytest.raises(TypeError) as e: net_cls(module_cls, lr=0.1, mxa_epochs=5, warm_start=False, bathc_size=20) @@ -231,6 +231,27 @@ def test_net_init_missing_dunder_and_unknown( "did you mean iterator_train__shuffle?") assert e.value.args[0] == expected + def test_net_with_new_attribute_with_name_clash( + self, net_cls, module_cls): + # This covers a bug that existed when a new "settable" + # argument was added whose name starts the same as the name + # for an existing argument + class MyNet(net_cls): + # add "optimizer_2" as a valid prefix so that it works + # with set_params + prefixes_ = net_cls.prefixes_[:] + ['optimizer_2'] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.optimizer_2 = torch.optim.SGD + + # the following line used to raise this error: "TypeError: Got + # an unexpected argument optimizer_2__lr, did you mean + # optimizer__2__lr?" because it was erronously assumed that + # "optimizer_2__lr" should be dispatched to "optimizer", not + # "optimizer_2". + MyNet(module_cls, optimizer_2__lr=0.123) # should not raise + def test_fit(self, net_fit): # fitting does not raise anything pass From fb8c9a1652139f0eb2333059b4927503f9ee6855 Mon Sep 17 00:00:00 2001 From: BenjaminBossan Date: Wed, 12 Feb 2020 23:06:02 +0100 Subject: [PATCH 2/2] Add a comment about why we need sorting --- skorch/net.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/skorch/net.py b/skorch/net.py index a71bb845c..23ed01c10 100644 --- a/skorch/net.py +++ b/skorch/net.py @@ -1334,6 +1334,9 @@ def _check_kwargs(self, kwargs): for key in kwargs: if key.endswith('_'): continue + + # see https://github.com/skorch-dev/skorch/pull/590 for + # why this must be sorted for prefix in sorted(self.prefixes_, key=lambda s: (-len(s), s)): if key.startswith(prefix): if not key.startswith(prefix + '__'):