Refactor train loop for easier customization #699

Merged: 15 commits (Mar 26, 2021)
2 changes: 1 addition & 1 deletion .coveragerc
@@ -1,4 +1,4 @@
[run]
omit =
skorch/tests/*
skorch/tests/*

4 changes: 4 additions & 0 deletions .gitignore
@@ -1,6 +1,9 @@
__pycache__/
dist/
docs/_build/
bin/
.ipynb_checkpoints/
Untitled*.ipynb
data/
*.egg-info
.coverage
@@ -18,3 +21,4 @@ data/
*.w2v
*prof
*.py~
*.pt
1 change: 1 addition & 0 deletions CHANGES.md
@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- We no longer pass the `epoch` parameter to LR schedulers, since that parameter has been deprecated. We now rely on the scheduler to keep track of the epoch itself.
- Changed implementation of `net.history` access to make it faster; this should result in a nice speedup when dealing with very small model/data but otherwise not have any noticeable effects; if you encounter bugs, though, please create an issue
- Changed the signature of `validation_step`, `train_step_single`, `train_step`, `evaluation_step`, `on_batch_begin`, and `on_batch_end` such that instead of receiving `X` and `y`, they receive the whole batch; this makes it easier to deal with datasets that don't strictly return an `(X, y)` tuple, which is true for quite a few PyTorch datasets; please refer to the [migration guide](https://skorch.readthedocs.io/en/stable/user/FAQ.html#migration-from-0-9-to-0-10) if you encounter problems

### Fixed

1 change: 1 addition & 0 deletions docs/index.rst
@@ -59,6 +59,7 @@ User's Guide
user/helper
user/REST
user/parallelism
user/customization
user/performance
user/FAQ

44 changes: 42 additions & 2 deletions docs/user/FAQ.rst
@@ -305,15 +305,15 @@ gradient accumulation yourself:
loss = super().get_loss(*args, **kwargs)
return loss / self.acc_steps # normalize loss

def train_step(self, Xi, yi, **fit_params):
def train_step(self, batch, **fit_params):
"""Perform gradient accumulation

Only optimize every nth batch.

"""
# note that n_train_batches starts at 1 for each epoch
n_train_batches = len(self.history[-1, 'batches'])
step = self.train_step_single(Xi, yi, **fit_params)
step = self.train_step_single(batch, **fit_params)

if n_train_batches % self.acc_steps == 0:
self.optimizer_.step()
@@ -404,3 +404,43 @@ the **greatest** score.
grid_searcher.fit(X, y)
best_net = grid_searcher.best_estimator_
print(best_net.score(X, y))

Migration guide
---------------

Migration from 0.9 to 0.10
^^^^^^^^^^^^^^^^^^^^^^^^^^

With skorch 0.10, we pushed the tuple unpacking of values returned by
the iterator to methods lower down the call chain. This makes it much
easier to work with iterators that don't follow the convention of
returning exactly two values.

A consequence of this is a **change in signature** of these methods:

- :py:meth:`skorch.net.NeuralNet.train_step_single`
- :py:meth:`skorch.net.NeuralNet.validation_step`
- :py:meth:`skorch.callbacks.Callback.on_batch_begin`
- :py:meth:`skorch.callbacks.Callback.on_batch_end`

Instead of receiving the unpacked tuple of ``X`` and ``y``, they just
receive a ``batch``, which is whatever is returned by the
iterator. The tuple unpacking needs to be performed inside these
methods.

If you have customized any of these methods, it is easy to restore
the previous behavior. E.g., if you wrote your own ``on_batch_begin``,
this is how to make the transition:

.. code:: python

    # before
    def on_batch_begin(self, net, X, y, ...):
        ...

    # after
    def on_batch_begin(self, net, batch, ...):
        X, y = batch
        ...

The same goes for the other three methods.
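
For instance, a customized ``validation_step`` could be adapted in the
same way (a minimal sketch; your own method body will of course
differ):

.. code:: python

    # before
    def validation_step(self, Xi, yi, **fit_params):
        ...

    # after
    def validation_step(self, batch, **fit_params):
        Xi, yi = batch
        ...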
8 changes: 4 additions & 4 deletions docs/user/callbacks.rst
@@ -67,16 +67,16 @@ on_epoch_end(net, dataset_train, dataset_valid)
Called once at the end of the epoch, i.e. possibly several times per
fit call. Gets training and validation data as additional input.

on_batch_begin(net, Xi, yi, training)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
on_batch_begin(net, batch, training)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Called once before each batch of data is processed, i.e. possibly
several times per epoch. Gets batch data as additional input.
Also includes a bool indicating if this is a training batch or not.


on_batch_end(net, Xi, yi, training, loss, y_pred)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
on_batch_end(net, batch, training, loss, y_pred)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Called once after each batch of data is processed, i.e. possibly
several times per epoch. Gets batch data as additional input.
98 changes: 98 additions & 0 deletions docs/user/customization.rst
@@ -0,0 +1,98 @@
=============
Customization
=============

Customizing NeuralNet
---------------------

:class:`.NeuralNet` and its subclasses like
:class:`.NeuralNetClassifier` are already very flexible and should
cover many use cases through the provided parameters alone. However,
this may not always be sufficient. If you find yourself wanting to
customize :class:`.NeuralNet`, please follow these guidelines.

Initialization
^^^^^^^^^^^^^^

The method :func:`~skorch.net.NeuralNet.initialize` is responsible for
initializing all the components needed by the net, e.g. the module and
the optimizer. For this, it calls specific initialization methods,
such as :func:`~skorch.net.NeuralNet.initialize_module` and
:func:`~skorch.net.NeuralNet.initialize_optimizer`. If you'd like to
customize the initialization behavior, you should override the
corresponding methods. Following sklearn conventions, the created
components should be set as an attribute with a trailing underscore as
the name, e.g. ``module_`` for the initialized module. Finally, the
method should return ``self``.
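
For illustration, a minimal sketch of an overridden
:func:`~skorch.net.NeuralNet.initialize_module` could look like this;
``MyModule`` stands in for your own module class and is not part of
skorch:

.. code:: python

    from skorch import NeuralNet

    class MyNet(NeuralNet):
        def initialize_module(self):
            # create the component, set it as an attribute with a
            # trailing underscore, and return self
            self.module_ = MyModule()
            return self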

Methods starting with get_*
^^^^^^^^^^^^^^^^^^^^^^^^^^^

The net provides a few ``get_*`` methods, most notably
:func:`~skorch.net.NeuralNet.get_loss`,
:func:`~skorch.net.NeuralNet.get_dataset`, and
:func:`~skorch.net.NeuralNet.get_iterator`. The intent of these
methods should be pretty self-explanatory, and if you are still not
quite sure, consult their documentation. In general, these methods
are fairly safe to override as long as you make sure to conform to the
same signature as the original.
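
As a sketch, a :func:`~skorch.net.NeuralNet.get_loss` override that
keeps the original signature could add a penalty term like this
(``lambda1`` is an assumed extra attribute, not a skorch parameter):

.. code:: python

    from skorch import NeuralNet

    class MyNet(NeuralNet):
        def __init__(self, *args, lambda1=0.01, **kwargs):
            super().__init__(*args, **kwargs)
            self.lambda1 = lambda1

        def get_loss(self, y_pred, y_true, X=None, training=False):
            # compute the regular loss, then add an L1 penalty
            loss = super().get_loss(y_pred, y_true, X=X, training=training)
            loss = loss + self.lambda1 * sum(
                w.abs().sum() for w in self.module_.parameters())
            return loss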

Training and validation
^^^^^^^^^^^^^^^^^^^^^^^

If you would like to customize training and validation, there are
several possibilities. Below are the methods that you most likely want
to customize:

The method :func:`~skorch.net.NeuralNet.train_step_single` performs a
single training step. It accepts the current batch of data as input
(as well as the ``fit_params``) and should return a dictionary
containing the ``loss`` and the prediction ``y_pred``. E.g. you should
override this if your dataset returns some non-standard data that
needs custom handling, and/or if your module has to be called in a
very specific way. If you want to, you can still make use of
:func:`~skorch.net.NeuralNet.infer` and
:func:`~skorch.net.NeuralNet.get_loss` but it's not strictly
necessary. Don't call the optimizer in this method; that is handled by
the next method.
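
To give a rough idea, here is a sketch of such an override, assuming
the iterator yields a dict with ``'data'`` and ``'target'`` entries
instead of an ``(X, y)`` tuple:

.. code:: python

    from skorch import NeuralNet

    class MyNet(NeuralNet):
        def train_step_single(self, batch, **fit_params):
            self.module_.train()
            # unpack the non-standard batch ourselves
            Xi, yi = batch['data'], batch['target']
            y_pred = self.infer(Xi, **fit_params)
            loss = self.get_loss(y_pred, yi, X=Xi, training=True)
            loss.backward()
            return {'loss': loss, 'y_pred': y_pred}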

The method :func:`~skorch.net.NeuralNet.train_step` defines the
complete training procedure performed for each batch. It accepts the
same arguments as :func:`~skorch.net.NeuralNet.train_step_single` but
it differs in that it defines the training closure passed to the
optimizer, which for instance could be called more than once (e.g. in
case of L-BFGS). You might override this if you deal with non-standard
training procedures, such as gradient accumulation.
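
A simplified sketch of such an override could look like this; the
actual skorch implementation performs some additional bookkeeping, so
treat this only as an outline:

.. code:: python

    from skorch import NeuralNet

    class MyNet(NeuralNet):
        def train_step(self, batch, **fit_params):
            step = {}

            def step_fn():
                # closure passed to the optimizer; some optimizers
                # (e.g. LBFGS) may call it more than once
                self.optimizer_.zero_grad()
                step.update(self.train_step_single(batch, **fit_params))
                return step['loss']

            self.optimizer_.step(step_fn)
            return step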

The method :func:`~skorch.net.NeuralNet.validation_step` is
responsible for calculating the prediction and loss on the validation
data (remember that skorch uses an internal validation set for
reporting, early stopping, etc.). Similar to
:func:`~skorch.net.NeuralNet.train_step_single`, it receives the batch
and ``fit_params`` as input and should return a dictionary containing
``loss`` and ``y_pred``. Most likely, when you need to customize
:func:`~skorch.net.NeuralNet.train_step_single`, you'll need to
customize :func:`~skorch.net.NeuralNet.validation_step` accordingly.
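
Continuing the dict-batch example from above, a matching
``validation_step`` could look roughly like this:

.. code:: python

    import torch

    from skorch import NeuralNet

    class MyNet(NeuralNet):
        def validation_step(self, batch, **fit_params):
            self.module_.eval()
            Xi, yi = batch['data'], batch['target']
            # no gradients are needed for validation
            with torch.no_grad():
                y_pred = self.infer(Xi, **fit_params)
                loss = self.get_loss(y_pred, yi, X=Xi, training=False)
            return {'loss': loss, 'y_pred': y_pred}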

Finally, the method :func:`~skorch.net.NeuralNet.evaluation_step` is
called when you use the net for inference, e.g. when calling
:func:`~skorch.net.NeuralNet.forward` or
:func:`~skorch.net.NeuralNet.predict`. You may want to modify this if,
e.g., you want your model to behave differently during training and
during prediction.
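
For instance, here is a sketch that keeps dropout active during
prediction (Monte Carlo dropout); it assumes that the batch is an
``(X, y)`` tuple:

.. code:: python

    import torch

    from skorch import NeuralNet

    class MyNet(NeuralNet):
        def evaluation_step(self, batch, training=False):
            # keep the module in train mode so that dropout stays active
            Xi, _ = batch
            self.module_.train()
            with torch.no_grad():
                return self.infer(Xi)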

You should also be aware that some methods are better left
untouched. E.g., in most cases, the following methods should *not* be
overridden:

* :func:`~skorch.net.NeuralNet.fit`
* :func:`~skorch.net.NeuralNet.partial_fit`
* :func:`~skorch.net.NeuralNet.fit_loop`
* :func:`~skorch.net.NeuralNet.run_single_epoch`

These methods should stay untouched because they perform some
bookkeeping, like making sure that callbacks are handled and that
logs are written to the ``history``. If you do need to override them,
make sure that you perform the same bookkeeping as the original
methods.
106 changes: 92 additions & 14 deletions skorch/callbacks/base.py
@@ -1,6 +1,9 @@
""" Basic callback definition. """

import warnings

from sklearn.base import BaseEstimator
from skorch.exceptions import SkorchWarning


__all__ = ['Callback']
@@ -28,32 +31,26 @@ def initialize(self):
"""
return self

def on_train_begin(self, net,
X=None, y=None, **kwargs):
def on_train_begin(self, net, X=None, y=None, **kwargs):
"""Called at the beginning of training."""

def on_train_end(self, net,
X=None, y=None, **kwargs):
def on_train_end(self, net, X=None, y=None, **kwargs):
"""Called at the end of training."""

def on_epoch_begin(self, net,
dataset_train=None, dataset_valid=None, **kwargs):
def on_epoch_begin(self, net, dataset_train=None, dataset_valid=None, **kwargs):
"""Called at the beginning of each epoch."""

def on_epoch_end(self, net,
dataset_train=None, dataset_valid=None, **kwargs):
def on_epoch_end(self, net, dataset_train=None, dataset_valid=None, **kwargs):
"""Called at the end of each epoch."""

def on_batch_begin(self, net,
X=None, y=None, training=None, **kwargs):
def on_batch_begin(self, net, batch=None, training=None, **kwargs):

Review comment (Member):
I suspect that changing the signature of callbacks would break users' code the most.

Reply (githubnemo, Collaborator, Dec 23, 2020):
Agreed. Not every user might follow the changelog closely. We could wrap the notify call with an exception handler and, in the case of on_batch_begin, raise the original exception together with a note informing the user about the change in signature and how to fix it. This could stay there for one release.

Reply (Collaborator, PR author):
I took a stab at this, but it turns out not to be so simple. The reason is that we always have **kwargs in the signature, so even if someone writes a wrong method like def on_batch_begin(net, X=None, y=None, **kwargs), it would not immediately lead to an error. The error occurs only when they do something with X or y, because these will be None instead of the expected values. But this "something" can be anything.

In my attempt, I just caught any kind of Exception and raised a TypeError from it that pointed at the possible error source. However, this broke other tests, which checked for specific errors (for instance, using scoring with a non-existing function name). We don't want that to happen.

The problem is thus that we don't know what error will occur, but we also don't want to catch all kinds of errors.

We could theoretically inspect the arguments inside the notify wrapper, but I would be afraid that the performance penalty could be too big. Furthermore, we want to allow users to be very liberal in how they write their on_* methods (e.g. only accepting the arguments that they need, which might not even include batch), so a strict check is not possible.

This leaves me with no idea how to implement the suggestion. Do you have any ideas?

Reply (Member):
OK, this is fair. Maybe we can issue a warning on the first call of notify('on_batch_begin') for callbacks that aren't ours, indicating that things might have changed? We can deprecate this warning in the next or a following release, and you could always filter the warning if it gets annoying.

Reply (Collaborator, PR author):
So basically, on the very first call to notify, you'd like to go through all callbacks, check whether each is an instance of a built-in callback, and issue a warning if at least one isn't? I'm not sure this isn't overkill, given that the warning will mostly be a false alarm.

Reply (Member):
Oh, no, more like if not cb.__module__.startswith('skorch'): issueWarning().

"""Called at the beginning of each batch."""

def on_batch_end(self, net,
X=None, y=None, training=None, **kwargs):
def on_batch_end(self, net, batch=None, training=None, **kwargs):
"""Called at the end of each batch."""

def on_grad_computed(self, net, named_parameters,
X=None, y=None, training=None, **kwargs):
def on_grad_computed(
self, net, named_parameters, X=None, y=None, training=None, **kwargs):
"""Called once per batch after gradients have been computed but before
an update step was performed.
"""
@@ -66,3 +63,84 @@ def get_params(self, deep=True):

def set_params(self, **params):
BaseEstimator.set_params(self, **params)


# TODO: remove after some deprecation period, e.g. skorch 0.12
def _on_batch_overridden(callback):
    """Check if on_batch_begin or on_batch_end were overridden

    If the method does not exist at all, it's not considered overridden. This
    is mostly for callbacks that are mocked.

    """
    try:
        base_skorch_cls = next(cls for cls in callback.__class__.__mro__
                               if cls.__module__.startswith('skorch'))
    except StopIteration:
        # does not derive from skorch callback, possibly a mock
        return False

    obb = base_skorch_cls.on_batch_begin
    obe = base_skorch_cls.on_batch_end
    return (
        getattr(callback.__class__, 'on_batch_begin', obb) is not obb
        or getattr(callback.__class__, 'on_batch_end', obe) is not obe
    )


# TODO: remove after some deprecation period, e.g. skorch 0.12
def _issue_warning_if_on_batch_override(callback_list):
    """Check callbacks for overridden on_batch method and issue warning

    We introduced a breaking change by changing the signature of
    on_batch_begin and on_batch_end. To help users, we try to detect if they
    use any custom callback that overrides one of these methods and issue a
    warning if they do. The warning states how to adjust the method signature
    and how it can be filtered.

    After some transition period, the checking and the warning should be
    removed again.

    Parameters
    ----------
    callback_list : list of (str, callback) tuples
      List of initialized callbacks.

    Warns
    -----
    Issues a ``SkorchWarning`` if any of the callbacks fits the conditions.

    """
    if not callback_list:
        return

    callbacks = [callback for _, callback in callback_list]

    # first detect if there are any user defined callbacks
    user_defined_callbacks = [
        callback for callback in callbacks
        if not callback.__module__.startswith('skorch')
    ]
    if not user_defined_callbacks:
        return

    # check if any of these callbacks overrides on_batch_begin or on_batch_end
    overriding_callbacks = [
        callback for callback in user_defined_callbacks
        if _on_batch_overridden(callback)
    ]

    if not overriding_callbacks:
        return

    warning_msg = (
        "You are using a callback that overrides on_batch_begin "
        "or on_batch_end. As of skorch 0.10, the signature was changed "
        "from 'on_batch_{begin,end}(self, X, y, ...)' to "
        "'on_batch_{begin,end}(self, batch, ...)'. To recover, change "
        "the signature accordingly and add 'X, y = batch' on the first "
        "line of the method body. To suppress this warning, add:\n"
        "'import warnings; from skorch.exceptions import SkorchWarning\n"
        "warnings.filterwarnings('ignore', category=SkorchWarning)'.")

    warnings.warn(warning_msg, SkorchWarning)
6 changes: 3 additions & 3 deletions skorch/callbacks/lr_scheduler.py
@@ -143,7 +143,7 @@ def on_train_begin(self, net, **kwargs):
)

def on_epoch_end(self, net, **kwargs):
if not self.step_every == 'epoch':
if self.step_every != 'epoch':
return
if isinstance(self.lr_scheduler_, ReduceLROnPlateau):
if callable(self.monitor):
@@ -166,7 +166,7 @@ def on_epoch_end(self, net, **kwargs):
self.lr_scheduler_.step()

def on_batch_end(self, net, training, **kwargs):
if not training or not self.step_every == 'batch':
if not training or self.step_every != 'batch':
return
if self.event_name is not None and hasattr(
self.lr_scheduler_, "get_last_lr"):
@@ -242,7 +242,7 @@ def __init__(

def _get_current_lr(self, min_lr, max_lr, period, epoch):
return min_lr + 0.5 * (max_lr - min_lr) * (
1 + np.cos(epoch * np.pi / period))
1 + np.cos(epoch * np.pi / period))

def get_lr(self):
epoch_idx = float(self.last_epoch)