[MRG] Adds Batch Count in History #445

Merged
merged 7 commits into from
May 1, 2019
1 change: 1 addition & 0 deletions CHANGES.md
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Adds FAQ entry regarding the initialization behavior of `NeuralNet` when passed instantiated models. (#409)
- Added CUDA pickle test including an artifact that supports testing on CUDA-less CI machines
- Adds `train_batch_count` and `valid_batch_count` to history in the training loop. (#445)

### Changed

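For context (not part of this diff): once a net is fitted, the new keys can be read back from the history like any other per-epoch entry. A minimal usage sketch; the module, data, and sizes below are made-up placeholders for illustration.

```python
import numpy as np
import torch
from torch import nn
from skorch import NeuralNetClassifier

class TinyClassifier(nn.Module):
    """Hypothetical two-class module, used only for this sketch."""
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(20, 2)

    def forward(self, X):
        return torch.softmax(self.lin(X), dim=-1)

X = np.random.randn(1000, 20).astype(np.float32)
y = np.random.randint(0, 2, size=1000).astype(np.int64)

net = NeuralNetClassifier(TinyClassifier, max_epochs=2, batch_size=100)
net.fit(X, y)

# one value per epoch; with the default 80/20 train/valid split this is
# 800 / 100 = 8 training batches and 200 / 100 = 2 validation batches
print(net.history[:, 'train_batch_count'])  # [8, 8]
print(net.history[:, 'valid_batch_count'])  # [2, 2]
```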
1 change: 1 addition & 0 deletions skorch/callbacks/logging.py
@@ -148,6 +148,7 @@ def _sorted_keys(self, keys):
(key in ('epoch', 'dur')) or
(key in self.keys_ignored_) or
key.endswith('_best') or
key.endswith('_batch_count') or
key.startswith('event_')
):
sorted_keys.append(key)
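The `logging.py` change keeps the new keys out of the printed epoch table while leaving them available in the history. A standalone sketch of the filtering condition, simplified from `_sorted_keys` for illustration (not the actual implementation):

```python
def shown_in_print_log(key, keys_ignored=()):
    """Simplified stand-in for the condition in PrintLog._sorted_keys."""
    hidden = (
        key in ('epoch', 'dur')           # handled separately by PrintLog
        or key in keys_ignored
        or key.endswith('_best')
        or key.endswith('_batch_count')   # the condition added in this PR
        or key.startswith('event_')
    )
    return not hidden

print(shown_in_print_log('train_loss'))         # True  -> column in the table
print(shown_in_print_log('train_batch_count'))  # False -> hidden, history only
```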
13 changes: 9 additions & 4 deletions skorch/callbacks/lr_scheduler.py
@@ -116,7 +116,10 @@ def kwargs(self):

def on_train_begin(self, net, **kwargs):
if net.history:
self.batch_idx_ = len(net.history[:, 'batches']) - 1
try:
self.batch_idx_ = sum(net.history[:, 'train_batch_count'])
except KeyError:
self.batch_idx_ = sum(len(b) for b in net.history[:, 'batches'])
self.lr_scheduler_ = self._get_scheduler(
net, self.policy_, **self.kwargs
)
@@ -138,15 +141,17 @@ def on_epoch_begin(self, net, **kwargs):
else:
self.lr_scheduler_.step(epoch)

def on_batch_begin(self, net, **kwargs):
def on_batch_begin(self, net, training, **kwargs):
if (
training and
hasattr(self.lr_scheduler_, 'batch_step') and
callable(self.lr_scheduler_.batch_step)
):
self.lr_scheduler_.batch_step(self.batch_idx_)

def on_batch_end(self, net, **kwargs):
self.batch_idx_ += 1
def on_batch_end(self, net, training, **kwargs):
if training:
self.batch_idx_ += 1

def _get_scheduler(self, net, policy, **scheduler_kwargs):
"""Return scheduler, based on indicated policy, with appropriate
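The try/except in `on_train_begin` is needed because histories written before this change have per-epoch `batches` but no `train_batch_count`. A small sketch against skorch's `History` (API names as in the repo, the numbers are made up); note that the fallback counts validation batches too, which is why recording the training count separately gives a more accurate `batch_idx_`:

```python
from skorch.history import History

history = History()

# an "old" epoch, recorded before this change: batches were logged,
# but no 'train_batch_count' key exists
history.new_epoch()
for _ in range(10):          # e.g. 8 train + 2 valid batches in this epoch
    history.new_batch()
    history.record_batch('loss', 0.5)

try:
    batch_idx = sum(history[:, 'train_batch_count'])
except KeyError:
    # fallback: count every recorded batch, train and valid alike
    batch_idx = sum(len(b) for b in history[:, 'batches'])

print(batch_idx)  # 10 (would be 8 if 'train_batch_count' had been recorded)
```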
6 changes: 6 additions & 0 deletions skorch/net.py
@@ -734,6 +734,7 @@ def fit_loop(self, X, y=None, epochs=None, **fit_params):
for _ in range(epochs):
self.notify('on_epoch_begin', **on_epoch_kwargs)

train_batch_count = 0
for data in self.get_iterator(dataset_train, training=True):
Xi, yi = unpack_data(data)
yi_res = yi if not y_train_is_ph else None
@@ -742,11 +743,14 @@ def fit_loop(self, X, y=None, epochs=None, **fit_params):
self.history.record_batch('train_loss', step['loss'].item())
self.history.record_batch('train_batch_size', get_len(Xi))
self.notify('on_batch_end', X=Xi, y=yi_res, training=True, **step)
train_batch_count += 1
self.history.record("train_batch_count", train_batch_count)

if dataset_valid is None:
self.notify('on_epoch_end', **on_epoch_kwargs)
continue

valid_batch_count = 0
for data in self.get_iterator(dataset_valid, training=False):
Xi, yi = unpack_data(data)
yi_res = yi if not y_valid_is_ph else None
@@ -755,6 +759,8 @@ def fit_loop(self, X, y=None, epochs=None, **fit_params):
self.history.record_batch('valid_loss', step['loss'].item())
self.history.record_batch('valid_batch_size', get_len(Xi))
self.notify('on_batch_end', X=Xi, y=yi_res, training=False, **step)
valid_batch_count += 1
self.history.record("valid_batch_count", valid_batch_count)

self.notify('on_epoch_end', **on_epoch_kwargs)
return self
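For a sense of what `fit_loop` now records per epoch, a short worked example (numbers made up, assuming skorch's default 80/20 internal train/valid split):

```python
num_examples = 1000
batch_size = 100
train_examples = int(0.8 * num_examples)          # 800 with the default split
valid_examples = num_examples - train_examples    # 200

# ceiling division: a trailing partial batch still counts as one batch
train_batch_count = -(-train_examples // batch_size)   # 8
valid_batch_count = -(-valid_examples // batch_size)   # 2

# fit_loop records these once per epoch:
#   history[-1, 'train_batch_count'] == 8
#   history[-1, 'valid_batch_count'] == 2
print(train_batch_count, valid_batch_count)
```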
58 changes: 55 additions & 3 deletions skorch/tests/callbacks/test_lr_scheduler.py
@@ -104,18 +104,70 @@ def test_lr_callback_batch_steps_correctly(
policy,
kwargs,
):
num_examples = 1000
batch_size = 100
max_epochs = 2

X, y = classifier_data
num_examples = len(X)

lr_policy = LRScheduler(policy, **kwargs)
net = NeuralNetClassifier(classifier_module(), max_epochs=max_epochs,
batch_size=batch_size, callbacks=[lr_policy])
net.fit(X, y)
expected = (num_examples // batch_size) * max_epochs - 1

total_iterations_per_epoch = num_examples / batch_size
# 80% of the samples are used for training by default
total_training_iterations_per_epoch = 0.8 * total_iterations_per_epoch

expected = int(total_training_iterations_per_epoch * max_epochs)
# pylint: disable=protected-access
assert lr_policy.batch_idx_ == expected

@pytest.mark.parametrize('policy, kwargs', [
('CyclicLR', {}),
])
def test_lr_callback_batch_steps_correctly_fallback(
self,
classifier_module,
classifier_data,
policy,
kwargs,
):
batch_size = 100
max_epochs = 2

X, y = classifier_data
num_examples = len(X)

lr_policy = LRScheduler(policy, **kwargs)
net = NeuralNetClassifier(classifier_module(), max_epochs=max_epochs,
batch_size=batch_size, callbacks=[lr_policy])
net.fit(X, y)

# Removes batch count information in the last two epochs
for i in range(2):
Collaborator: shouldn't 2 be max_epochs here?

Member Author: Yup! PR updated to reflect this.

del net.history[i]["train_batch_count"]
del net.history[i]["valid_batch_count"]
net.partial_fit(X, y)

total_iterations_per_epoch = num_examples / batch_size

# the batch counts were removed, so the fallback counts all recorded
# batches (train and valid) for the fit run
total_iterations_fit_run = total_iterations_per_epoch * max_epochs

# 80% of the samples are used for training by default
total_iterations_partial_fit_run = (
0.8 * total_iterations_per_epoch * max_epochs)

# called fit AND partial_fit
total_iterations = (total_iterations_fit_run +
total_iterations_partial_fit_run)
# fall back to using both valid and training batch counts on the
# second run
expected = int(total_iterations)
# pylint: disable=protected-access
assert lr_policy.lr_scheduler_.last_batch_idx == expected
assert lr_policy.batch_idx_ == expected

def test_lr_scheduler_cloneable(self):
# reproduces bug #271
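Spelled out, the expectation in the fallback test combines the two counting paths (same illustrative numbers as above: 1000 samples, batch_size=100, max_epochs=2):

```python
iterations_per_epoch = 1000 / 100   # 10 batches per epoch, train + valid
max_epochs = 2

# the fit run's *_batch_count keys were deleted, so at the start of
# partial_fit the callback falls back to len(history[:, 'batches']),
# which counts training AND validation batches of the fit run
fit_run_batches = iterations_per_epoch * max_epochs              # 20.0

# during the partial_fit run itself, on_batch_end only counts the
# training batches (80% of the data with the default split)
partial_fit_batches = 0.8 * iterations_per_epoch * max_epochs    # 16.0

expected = int(fit_run_batches + partial_fit_batches)            # 36
print(expected)
```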
13 changes: 13 additions & 0 deletions skorch/tests/test_net.py
@@ -2030,6 +2030,19 @@ def test_batch_size_neg_1_uses_whole_dataset(
assert train_kwargs['batch_size'] == expected_train_batch_size
assert valid_kwargs['batch_size'] == expected_valid_batch_size

@pytest.mark.parametrize('batch_size', [40, 100])
def test_batch_count(self, net_cls, module_cls, data, batch_size):

net = net_cls(module_cls, max_epochs=1, batch_size=batch_size)
X, y = data
net.fit(X, y)

train_batch_count = int(0.8 * len(X)) / batch_size
valid_batch_count = int(0.2 * len(X)) / batch_size

assert net.history[:, "train_batch_count"] == [train_batch_count]
assert net.history[:, "valid_batch_count"] == [valid_batch_count]

def test_fit_lbfgs_optimizer(self, net_cls, module_cls, data):
X, y = data
net = net_cls(