Skip to content

Commit b2560a4

Browse files
authored
Resolve #656 by fixing progressbar pickling (#663)
The fix is to ignore the tqdm instance in the returned state of the progress bar callback.
1 parent 44bef1b commit b2560a4

File tree

3 files changed

+28
-6
lines changed

3 files changed

+28
-6
lines changed

CHANGES.md

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2626

2727
- Fixed a bug where `CyclicLR` scheduler would update during both training and validation rather than just during training.
2828
- Fixed a bug introduced by moving the `optimizer.zero_grad()` call outside of the train step function, making it incompatible with LBFGS and other optimizers that call the train step several times per batch (#636)
29+
- Fixed pickling of the `ProgressBar` callback (#656)
2930

3031
## [0.8.0] - 2019-04-11
3132

skorch/callbacks/logging.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -556,8 +556,8 @@ def _get_postfix_dict(self, net):
556556

557557
# pylint: disable=attribute-defined-outside-init
558558
def on_batch_end(self, net, **kwargs):
559-
self.pbar.set_postfix(self._get_postfix_dict(net), refresh=False)
560-
self.pbar.update()
559+
self.pbar_.set_postfix(self._get_postfix_dict(net), refresh=False)
560+
self.pbar_.update()
561561

562562
# pylint: disable=attribute-defined-outside-init, arguments-differ
563563
def on_epoch_begin(self, net, dataset_train=None, dataset_valid=None, **kwargs):
@@ -576,12 +576,19 @@ def on_epoch_begin(self, net, dataset_train=None, dataset_valid=None, **kwargs):
576576
batches_per_epoch = len(net.history[-2, 'batches'])
577577

578578
if self._use_notebook():
579-
self.pbar = tqdm.tqdm_notebook(total=batches_per_epoch, leave=False)
579+
self.pbar_ = tqdm.tqdm_notebook(total=batches_per_epoch, leave=False)
580580
else:
581-
self.pbar = tqdm.tqdm(total=batches_per_epoch, leave=False)
581+
self.pbar_ = tqdm.tqdm(total=batches_per_epoch, leave=False)
582582

583583
def on_epoch_end(self, net, **kwargs):
584-
self.pbar.close()
584+
self.pbar_.close()
585+
586+
def __getstate__(self):
587+
# don't save away the temporary pbar_ object which gets created on
588+
# epoch begin anew anyway. This avoids pickling errors with tqdm.
589+
state = self.__dict__.copy()
590+
del state['pbar_']
591+
return state
585592

586593

587594
def rename_tensorboard_key(key):

skorch/tests/callbacks/test_logging.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,20 @@ def test_different_count_schemes(
541541
for i, total in enumerate(expected_total):
542542
assert tqdm_mock.call_args_list[i][1]['total'] == total
543543

544+
def test_pickle(self, net_cls, progressbar_cls, data):
545+
# pickling was an issue since TQDM progress bar instances cannot
546+
# be pickled. Test pickling and restoration.
547+
import pickle
548+
549+
net = net_cls(callbacks=[
550+
progressbar_cls(),
551+
])
552+
net.fit(*data)
553+
dump = pickle.dumps(net)
554+
555+
net = pickle.loads(dump)
556+
net.fit(*data)
557+
544558

545559
@pytest.mark.skipif(
546560
not tensorboard_installed, reason='tensorboard is not installed')
@@ -776,4 +790,4 @@ def forward(self, b, e, c, d, a, **kwargs):
776790
net.fit(X_dict, y)
777791

778792
# is not empty
779-
assert os.listdir(path)
793+
assert os.listdir(path)

0 commit comments

Comments
 (0)