Skip to content

Commit 37313f5

Browse files
authored
FIX Allows TrainEndCheckpoint to be unpickled (#778)
1 parent 54796f1 commit 37313f5

File tree

3 files changed

+23
-5
lines changed

3 files changed

+23
-5
lines changed

CHANGES.md

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1919
### Fixed
2020

2121
- Fixed a few bugs in the `net.history` implementation (#776)
22+
- Fixed a bug in `TrainEndCheckpoint` that prevented it from being unpickled (#773)
2223

2324
## [0.10.0] - 2021-03-23
2425

skorch/callbacks/training.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -639,8 +639,11 @@ def on_train_begin(self, net,
639639
X=None, y=None, **kwargs):
640640
if not self.did_load_:
641641
self.did_load_ = True
642-
with suppress(Exception):
643-
net.load_params(checkpoint=self.checkpoint)
642+
with suppress(FileNotFoundError):
643+
if isinstance(self.checkpoint, TrainEndCheckpoint):
644+
net.load_params(checkpoint=self.checkpoint.checkpoint_)
645+
else:
646+
net.load_params(checkpoint=self.checkpoint)
644647

645648

646649
class TrainEndCheckpoint(Callback):
@@ -752,10 +755,9 @@ def initialize(self):
752755
**self._f_kwargs()
753756
)
754757
self.checkpoint_.initialize()
758+
return self
755759

756760
def on_train_end(self, net, **kwargs):
757761
self.checkpoint_.save_model(net)
758762
self.checkpoint_._sink("Final checkpoint triggered", net.verbose)
759-
760-
def __getattr__(self, attr):
761-
return getattr(self.checkpoint_, attr)
763+
return self

skorch/tests/callbacks/test_training.py

+15
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Tests for callbacks in training.py"""
22

33
from functools import partial
4+
import pickle
45
from unittest.mock import Mock
56
from unittest.mock import patch
67
from unittest.mock import call
@@ -1125,3 +1126,17 @@ def initialize_module(self, *args, **kwargs):
11251126

11261127
assert save_params_mock.call_count == 1
11271128
save_params_mock.assert_has_calls([call(f_mymodule='train_end_mymodule.pt')])
1129+
1130+
def test_pickle_uninitialized_callback(self, trainendcheckpoint_cls):
1131+
# isuue 773
1132+
cp = trainendcheckpoint_cls()
1133+
# does not raise
1134+
s = pickle.dumps(cp)
1135+
pickle.loads(s)
1136+
1137+
def test_pickle_initialized_callback(self, trainendcheckpoint_cls):
1138+
# issue 773
1139+
cp = trainendcheckpoint_cls().initialize()
1140+
# does not raise
1141+
s = pickle.dumps(cp)
1142+
pickle.loads(s)

0 commit comments

Comments
 (0)