Refactor train loop for easier customization #699

Merged
15 commits merged on Mar 26, 2021
Changes from 4 commits
2 changes: 1 addition & 1 deletion .coveragerc
@@ -1,4 +1,4 @@
[run]
omit =
skorch/tests/*
skorch/tests/*

4 changes: 4 additions & 0 deletions .gitignore
@@ -1,6 +1,9 @@
__pycache__/
dist/
docs/_build/
bin/
.ipynb_checkpoints/
Untitled*.ipynb
data/
*.egg-info
.coverage
@@ -18,3 +21,4 @@ data/
*.w2v
*prof
*.py~
*.pt
4 changes: 3 additions & 1 deletion CHANGES.md
@@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Changed the signature of `validation_step`, `train_step_single`, `train_step`, `evaluation_step`, `on_batch_begin`, and `on_batch_end` such that instead of receiving `X` and `y`, they receive the whole batch; this makes it easier to deal with datasets that don't strictly return an `(X, y)` tuple, which is true for quite a few PyTorch datasets

### Fixed

## [0.9.0] - 2020-08-30
@@ -222,4 +224,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[0.6.0]: https://github.com/skorch-dev/skorch/compare/v0.5.0...v0.6.0
[0.7.0]: https://github.com/skorch-dev/skorch/compare/v0.6.0...v0.7.0
[0.8.0]: https://github.com/skorch-dev/skorch/compare/v0.7.0...v0.8.0
[0.9.0]: https://github.com/skorch-dev/skorch/compare/v0.8.0...v0.9.0
[0.9.0]: https://github.com/skorch-dev/skorch/compare/v0.8.0...v0.9.0
1 change: 1 addition & 0 deletions docs/index.rst
@@ -59,6 +59,7 @@ User's Guide
user/helper
user/REST
user/parallelism
user/customization
user/FAQ


4 changes: 2 additions & 2 deletions docs/user/FAQ.rst
@@ -305,15 +305,15 @@ gradient accumulation yourself:
loss = super().get_loss(*args, **kwargs)
return loss / self.acc_steps # normalize loss

def train_step(self, Xi, yi, **fit_params):
def train_step(self, batch, **fit_params):
"""Perform gradient accumulation

Only optimize every nth batch.

"""
# note that n_train_batches starts at 1 for each epoch
n_train_batches = len(self.history[-1, 'batches'])
step = self.train_step_single(Xi, yi, **fit_params)
step = self.train_step_single(batch, **fit_params)

if n_train_batches % self.acc_steps == 0:
self.optimizer_.step()
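
For context, a hedged reconstruction of the full FAQ class that this excerpt is taken from; details such as the class name and how ``acc_steps`` is set are assumptions for illustration:

.. code:: python

    from skorch import NeuralNetClassifier

    class GradAccNet(NeuralNetClassifier):
        """Net that accumulates gradients over acc_steps batches."""
        def __init__(self, *args, acc_steps=2, **kwargs):
            super().__init__(*args, **kwargs)
            self.acc_steps = acc_steps

        def get_loss(self, *args, **kwargs):
            loss = super().get_loss(*args, **kwargs)
            return loss / self.acc_steps  # normalize loss

        def train_step(self, batch, **fit_params):
            """Only optimize every acc_steps-th batch."""
            # note that n_train_batches starts at 1 for each epoch
            n_train_batches = len(self.history[-1, 'batches'])
            step = self.train_step_single(batch, **fit_params)

            if n_train_batches % self.acc_steps == 0:
                self.optimizer_.step()
                self.optimizer_.zero_grad()
            return step
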
8 changes: 4 additions & 4 deletions docs/user/callbacks.rst
@@ -67,16 +67,16 @@ on_epoch_end(net, dataset_train, dataset_valid)
Called once at the end of the epoch, i.e. possibly several times per
fit call. Gets training and validation data as additional input.

on_batch_begin(net, Xi, yi, training)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
on_batch_begin(net, batch, training)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Called once before each batch of data is processed, i.e. possibly
several times per epoch. Gets batch data as additional input.
Also includes a bool indicating if this is a training batch or not.


on_batch_end(net, Xi, yi, training, loss, y_pred)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
on_batch_end(net, batch, training, loss, y_pred)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Called once after each batch of data is processed, i.e. possibly
several times per epoch. Gets batch data as additional input.
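
To make the new batch-based signature concrete, here is a small hedged sketch of a custom callback; the callback name and what it prints are made up for this example:

.. code:: python

    from skorch.callbacks import Callback
    from skorch.dataset import unpack_data

    class BatchSizeLogger(Callback):
        def on_batch_begin(self, net, batch=None, training=None, **kwargs):
            # the batch is whatever the iterator yields; unpack_data splits it
            X, _ = unpack_data(batch)
            print("train" if training else "valid", "batch of size", len(X))

        def on_batch_end(self, net, batch=None, training=None, loss=None,
                         y_pred=None, **kwargs):
            if training:
                print("train loss:", loss.item())
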
98 changes: 98 additions & 0 deletions docs/user/customization.rst
@@ -0,0 +1,98 @@
=============
Customization
=============

Customizing NeuralNet
---------------------

:class:`.NeuralNet` and its subclasses like
:class:`.NeuralNetClassifier` are already very flexible and should
cover many use cases through the parameters they provide (a brief
example is sketched below). However, this may not always be
sufficient. If you find yourself wanting to customize
:class:`.NeuralNet`, please follow these guidelines.
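
As a rough sketch of what parameter-based customization looks like;
``MyModule`` and ``num_units`` are made up for this example, while the
remaining arguments are regular :class:`.NeuralNet` parameters:

.. code:: python

    import torch
    from torch import nn
    from skorch import NeuralNetClassifier

    class MyModule(nn.Module):
        def __init__(self, num_units=10):
            super().__init__()
            self.dense = nn.Linear(20, num_units)
            self.output = nn.Linear(num_units, 2)

        def forward(self, X):
            X = torch.relu(self.dense(X))
            # log-probabilities, matching the default NLLLoss criterion
            return torch.log_softmax(self.output(X), dim=-1)

    net = NeuralNetClassifier(
        MyModule,
        module__num_units=20,             # passed through to MyModule
        max_epochs=20,
        lr=0.05,
        optimizer=torch.optim.Adam,       # swap out the default optimizer
        iterator_train__shuffle=True,     # passed to the train DataLoader
    )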

Initialization
^^^^^^^^^^^^^^

The method :func:`~skorch.net.NeuralNet.initialize` is responsible for
initializing all the components needed by the net, e.g. the module and
the optimizer. For this, it calls specific initialization methods,
such as :func:`~skorch.net.NeuralNet.initialize_module` and
:func:`~skorch.net.NeuralNet.initialize_optimizer`. If you'd like to
customize the initialization behavior, you should override the
corresponding methods. Following sklearn conventions, the created
components should be set as attributes whose names end in a trailing
underscore, e.g. ``module_`` for the initialized module. Finally, the
method should return ``self``.
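
A minimal sketch of such an override, assuming that
:func:`~skorch.net.NeuralNet.get_params_for` is available to collect
the ``module__*`` arguments, as in recent skorch versions:

.. code:: python

    from skorch import NeuralNet

    class MyNet(NeuralNet):
        def initialize_module(self):
            # gather all arguments passed as module__<name>
            kwargs = self.get_params_for('module')
            module = self.module(**kwargs)
            # sklearn convention: initialized components end in an underscore
            self.module_ = module.to(self.device)
            return self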

Methods starting with get_*
^^^^^^^^^^^^^^^^^^^^^^^^^^^

The net provides a few ``get_*`` methods, most notably
:func:`~skorch.net.NeuralNet.get_loss`,
:func:`~skorch.net.NeuralNet.get_dataset`, and
:func:`~skorch.net.NeuralNet.get_iterator`. The intent of these
methods should be self-explanatory; if you are still not quite sure,
consult their documentation. In general, these methods are safe to
override as long as your override conforms to the signature of the
original.
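
For instance, a hedged sketch of overriding
:func:`~skorch.net.NeuralNet.get_loss` to add an L1 penalty; the
``lambda1`` parameter is made up for this example:

.. code:: python

    from skorch import NeuralNet

    class RegularizedNet(NeuralNet):
        def __init__(self, *args, lambda1=0.01, **kwargs):
            super().__init__(*args, **kwargs)
            self.lambda1 = lambda1

        def get_loss(self, y_pred, y_true, X=None, training=False):
            # keep the original signature, only add an L1 penalty on the weights
            loss = super().get_loss(y_pred, y_true, X=X, training=training)
            loss = loss + self.lambda1 * sum(
                p.abs().sum() for p in self.module_.parameters())
            return loss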

Training and validation
^^^^^^^^^^^^^^^^^^^^^^^

If you would like to customize training and validation, there are
several possibilities. Below are the methods that you most likely want
to customize:

The method :func:`~skorch.net.NeuralNet.train_step_single` performs a
single training step. It accepts the current batch of data as input
(as well as the ``fit_params``) and should return a dictionary
containing the ``loss`` and the prediction ``y_pred``. For example,
you should override this if your dataset returns non-standard data
that needs custom handling, or if your module has to be called in a
very specific way. You can still make use of
:func:`~skorch.net.NeuralNet.infer` and
:func:`~skorch.net.NeuralNet.get_loss`, but it's not strictly
necessary. Don't call the optimizer in this method; that is handled
by :func:`~skorch.net.NeuralNet.train_step`, described next.
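
A sketch of such an override, assuming (hypothetically) that the
dataset yields a dict with ``'features'`` and ``'target'`` keys
instead of an ``(X, y)`` tuple:

.. code:: python

    from skorch import NeuralNet

    class DictBatchNet(NeuralNet):
        def train_step_single(self, batch, **fit_params):
            self.module_.train()
            # unpack the non-standard batch format
            X, y = batch['features'], batch['target']
            y_pred = self.infer(X, **fit_params)
            loss = self.get_loss(y_pred, y, X=X, training=True)
            loss.backward()
            return {'loss': loss, 'y_pred': y_pred}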

The method :func:`~skorch.net.NeuralNet.train_step` defines the
complete training procedure performed for each batch. It accepts the
same arguments as :func:`~skorch.net.NeuralNet.train_step_single` but
differs in that it defines the training closure passed to the
optimizer, which may be called more than once (e.g. by the L-BFGS
optimizer). You might override this if you deal with non-standard
training procedures, such as gradient accumulation.
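
To illustrate the role of the closure, here is a simplified sketch; it
omits the bookkeeping that the real method performs, so treat it as a
rough outline rather than a drop-in replacement:

.. code:: python

    from skorch import NeuralNet

    class ClosureNet(NeuralNet):
        def train_step(self, batch, **fit_params):
            step = {}

            def closure():
                # some optimizers (e.g. LBFGS) call this more than once
                self.optimizer_.zero_grad()
                step.update(self.train_step_single(batch, **fit_params))
                return step['loss']

            self.optimizer_.step(closure)
            return step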

The method :func:`~skorch.net.NeuralNet.validation_step` is
responsible for calculating the prediction and loss on the validation
data (remember that skorch uses an internal validation set for
reporting, early stopping, etc.). Similar to
:func:`~skorch.net.NeuralNet.train_step_single`, it receives the batch
and ``fit_params`` as input and should return a dictionary containing
``loss`` and ``y_pred``. Most likely, when you need to customize
:func:`~skorch.net.NeuralNet.train_step_single`, you'll need to
customize :func:`~skorch.net.NeuralNet.validation_step` accordingly.
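
Continuing the hypothetical dict-based dataset from above, the
matching validation step might look like this:

.. code:: python

    import torch
    from skorch import NeuralNet

    class DictBatchNet(NeuralNet):
        # train_step_single as sketched above, plus:
        def validation_step(self, batch, **fit_params):
            self.module_.eval()
            X, y = batch['features'], batch['target']
            with torch.no_grad():
                y_pred = self.infer(X, **fit_params)
                loss = self.get_loss(y_pred, y, X=X, training=False)
            return {'loss': loss, 'y_pred': y_pred}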

Finally, the method :func:`~skorch.net.NeuralNet.evaluation_step` is
called when you use the net for inference, e.g. when calling
:func:`~skorch.net.NeuralNet.forward` or
:func:`~skorch.net.NeuralNet.predict`. You may want to modify this if,
e.g., you want your model to behave differently during training and
during prediction.
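
For example, a minimal sketch of Monte Carlo dropout, where dropout is
deliberately kept active at prediction time; this behaviour is an
assumption for illustration, not something skorch does by default:

.. code:: python

    from skorch import NeuralNet

    class MCDropoutNet(NeuralNet):
        def evaluation_step(self, batch, training=False):
            # keep the module in training mode so that dropout stays active
            return super().evaluation_step(batch, training=True)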

You should also be aware that some methods are better left
untouched. E.g., in most cases, the following methods should *not* be
overridden:

* :func:`~skorch.net.NeuralNet.fit`
* :func:`~skorch.net.NeuralNet.partial_fit`
* :func:`~skorch.net.NeuralNet.fit_loop`
* :func:`~skorch.net.NeuralNet.run_single_epoch`

These methods should stay untouched because they perform bookkeeping
tasks, like making sure that callbacks are handled and that logs are
written to the ``history``. If you do need to override them, make
sure that you perform the same bookkeeping as the original methods.
22 changes: 8 additions & 14 deletions skorch/callbacks/base.py
@@ -28,32 +28,26 @@ def initialize(self):
"""
return self

def on_train_begin(self, net,
X=None, y=None, **kwargs):
def on_train_begin(self, net, X=None, y=None, **kwargs):
"""Called at the beginning of training."""

def on_train_end(self, net,
X=None, y=None, **kwargs):
def on_train_end(self, net, X=None, y=None, **kwargs):
"""Called at the end of training."""

def on_epoch_begin(self, net,
dataset_train=None, dataset_valid=None, **kwargs):
def on_epoch_begin(self, net, dataset_train=None, dataset_valid=None, **kwargs):
"""Called at the beginning of each epoch."""

def on_epoch_end(self, net,
dataset_train=None, dataset_valid=None, **kwargs):
def on_epoch_end(self, net, dataset_train=None, dataset_valid=None, **kwargs):
"""Called at the end of each epoch."""

def on_batch_begin(self, net,
X=None, y=None, training=None, **kwargs):
def on_batch_begin(self, net, batch=None, training=None, **kwargs):
Review discussion on this line:

Member

I suspect that changing the signature of callbacks would break users' code the most.

Collaborator (@githubnemo, Dec 23, 2020)

> I suspect that changing the signature of callbacks would break users' code the most.

Agreed. Not every user might follow the changelog closely. We could wrap the notify call with an exception handler, and in the case of on_batch_begin we could raise the original exception along with a note informing the user about the change in signature and how to fix it. This could stay there for one release.

Collaborator Author

I took a stab at this, but it turns out not to be so simple. The reason is that we always have `**kwargs` in the signature, so even if someone writes a wrong method like `def on_batch_begin(net, X=None, y=None, **kwargs)`, it does not immediately lead to an error. The error occurs only when they do something with `X` or `y`, because these will be `None` instead of the expected values. But this "something" can be anything.

In my attempt, I just caught any kind of Exception and raised a TypeError from it that pointed at the possible error source. However, this broke other tests, which checked for specific errors (for instance, using scoring with a non-existing function name). We don't want that to happen.

The problem is thus that we don't know what error will occur, but we also don't want to catch all kinds of errors.

We could theoretically inspect the arguments inside the notify wrapper, but I would be afraid that the performance penalty could be too big. Furthermore, we want to allow users to be very liberal in how they call their `on_*` methods (e.g. only passing the arguments that they need, which might not even include `batch`), so it's not possible to do a strict check.

This leaves me with no idea how to implement the suggestion. Do you have any ideas?

Member

OK, this is fair. Maybe we can issue a warning on the first call of `notify('on_batch_begin')` for callbacks that aren't ours, indicating that things might have changed? We could deprecate this warning in the next or a following release, and you could always filter the warning if it gets annoying.

Collaborator Author

So basically, on the very first call to notify, you'd like to go through all callbacks and check whether they are an instance of any built-in callback, and if at least one of them isn't, issue a warning? I'm not sure this isn't overkill, given that the warning will mostly be a false alarm.

Member

Oh, no, more like `if not cb.__module__.startswith('skorch'): issueWarning()`.

"""Called at the beginning of each batch."""

def on_batch_end(self, net,
X=None, y=None, training=None, **kwargs):
def on_batch_end(self, net, batch=None, training=None, **kwargs):
"""Called at the end of each batch."""

def on_grad_computed(self, net, named_parameters,
X=None, y=None, training=None, **kwargs):
def on_grad_computed(
self, net, named_parameters, X=None, y=None, training=None, **kwargs):
"""Called once per batch after gradients have been computed but before
an update step was performed.
"""
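
Picking up the last suggestion from the review thread above, a hedged sketch of what such a check could look like; the helper name is made up, and skorch may implement this differently, if at all:

    import warnings

    def warn_if_non_skorch_callbacks(callbacks):
        """Warn once if any callback does not ship with skorch itself."""
        for cb in callbacks:
            if not cb.__module__.startswith('skorch'):
                warnings.warn(
                    "on_batch_begin/on_batch_end now receive `batch` instead "
                    "of `X` and `y`; please update your custom callbacks.",
                    UserWarning,
                )
                break
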
6 changes: 3 additions & 3 deletions skorch/callbacks/lr_scheduler.py
@@ -153,7 +153,7 @@ def on_train_begin(self, net, **kwargs):
)

def on_epoch_end(self, net, **kwargs):
if not self.step_every == 'epoch':
if self.step_every != 'epoch':
return
epoch = len(net.history)
if isinstance(self.lr_scheduler_, ReduceLROnPlateau):
@@ -177,7 +177,7 @@ def on_epoch_end(self, net, **kwargs):
self.lr_scheduler_.step(epoch)

def on_batch_end(self, net, training, **kwargs):
if not training or not self.step_every == 'batch':
if not training or self.step_every != 'batch':
return
if self.event_name is not None and hasattr(
self.lr_scheduler_, "get_last_lr"):
@@ -253,7 +253,7 @@ def __init__(

def _get_current_lr(self, min_lr, max_lr, period, epoch):
return min_lr + 0.5 * (max_lr - min_lr) * (
1 + np.cos(epoch * np.pi / period))
1 + np.cos(epoch * np.pi / period))

def get_lr(self):
epoch_idx = float(self.last_epoch)
14 changes: 8 additions & 6 deletions skorch/callbacks/scoring.py
@@ -15,10 +15,11 @@
else:
from sklearn.metrics.scorer import _BaseScorer

from skorch.callbacks import Callback
from skorch.dataset import unpack_data
from skorch.utils import data_from_dataset
from skorch.utils import is_skorch_dataset
from skorch.utils import to_numpy
from skorch.callbacks import Callback
from skorch.utils import check_indexing
from skorch.utils import to_device

@@ -146,9 +147,8 @@ def _get_name(self):
if hasattr(self.scoring_._score_func, '__name__'):
# sklearn < 0.22
return self.scoring_._score_func.__name__
else:
# sklearn >= 0.22
return self.scoring_._score_func._score_func.__name__
# sklearn >= 0.22
return self.scoring_._score_func._score_func.__name__
if isinstance(self.scoring_, dict):
raise ValueError("Dict not supported as scorer for multi-metric scoring."
" Register multiple scoring callbacks instead.")
@@ -239,10 +239,11 @@ class BatchScoring(ScoringBase):
"""
# pylint: disable=unused-argument,arguments-differ

def on_batch_end(self, net, X, y, training, **kwargs):
def on_batch_end(self, net, batch, training, **kwargs):
if training != self.on_train:
return

X, y = unpack_data(batch)
y_preds = [kwargs['y_pred']]
with _cache_net_forward_iter(net, self.use_caching, y_preds) as cached_net:
# In case of y=None we will not have gathered any samples.
@@ -367,7 +368,7 @@ def on_epoch_begin(self, net, dataset_train, dataset_valid, **kwargs):

# pylint: disable=arguments-differ
def on_batch_end(
self, net, y, y_pred, training, **kwargs):
self, net, batch, y_pred, training, **kwargs):
if not self.use_caching or training != self.on_train:
return

@@ -378,6 +379,7 @@ def on_batch_end(
# self.target_extractor(y) here but on epoch end, so that
# there are no copies of parts of y hanging around during
# training.
y = unpack_data(batch)[1]
if y is not None:
self.y_trues_.append(y)
self.y_preds_.append(y_pred)
2 changes: 1 addition & 1 deletion skorch/classifier.py
@@ -15,7 +15,6 @@
from skorch.dataset import CVSplit
from skorch.utils import get_dim
from skorch.utils import is_dataset
from skorch.utils import to_numpy


neural_net_clf_doc_start = """NeuralNet for classification tasks
@@ -117,6 +116,7 @@ def check_data(self, X, y):
"respectively.")
raise ValueError(msg)
if y is not None:
# pylint: disable=attribute-defined-outside-init
self.classes_inferred_ = np.unique(y)

# pylint: disable=arguments-differ
11 changes: 0 additions & 11 deletions skorch/dataset.py
@@ -77,17 +77,6 @@ def get_len(data):
return list(len_set)[0]


def uses_placeholder_y(ds):
"""If ``ds`` is a ``skorch.dataset.Dataset`` or a
``skorch.dataset.Dataset`` nested inside a
``torch.utils.data.Subset`` and uses
y as a placeholder, return ``True``."""

if isinstance(ds, torch.utils.data.Subset):
return uses_placeholder_y(ds.dataset)
return isinstance(ds, Dataset) and hasattr(ds, "y") and ds.y is None


def unpack_data(data):
"""Unpack data returned by the net's iterator into a 2-tuple.

3 changes: 1 addition & 2 deletions skorch/helper.py
@@ -4,15 +4,14 @@

"""
from collections import Sequence
from collections import namedtuple
from functools import partial

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.base import TransformerMixin
import torch

from skorch.cli import parse_args
from skorch.cli import parse_args # pylint: disable=unused-import
from skorch.utils import _make_split
from skorch.utils import is_torch_data_type
from skorch.utils import to_tensor