Refactor train loop for easier customization #699
Merged: ottonemo merged 15 commits into master from refactor-train-loop-for-easier-customization on Mar 26, 2021.
Commits (15):
* 0d70af6 Small refactoring of some training/evaluation methods
* 3c24974 Clean ups to satisfy pylint
* fbfb827 Add more to .gitignore and remove duplication from .coveragerc
* b767e30 Clarify doc for return value of train_step
* f5dcd10 Reviewer comment: avoid unnecessary lambda in test
* 11eb398 (BenjaminBossan) Add migration guide for affected methods to FAQ
* c49f14c Merge branch 'refactor-train-loop-for-easier-customization' of https:…
* 1ab3f21 Merge branch 'master' into refactor-train-loop-for-easier-customization
* 3072e5e Reviewer comment: tuple unpacking
* c030dd2 (BenjaminBossan) Reviewer comment: tuple unpacking
* d416749 (BenjaminBossan) Merge branch 'master' into refactor-train-loop-for-easier-customization
* 60911b8 (BenjaminBossan) Add a helpful warning if user-defined on_batch
* 775378e Merge branch 'refactor-train-loop-for-easier-customization' of https:…
* 8477edc Merge branch 'master' into refactor-train-loop-for-easier-customization
* ef468ed (BenjaminBossan) Add tests for warning when on_batch overridden
@@ -1,4 +1,4 @@ (.coveragerc)
[run]
omit =
    skorch/tests/*
@@ -59,6 +59,7 @@ User's Guide
   user/helper
   user/REST
   user/parallelism
   user/customization
   user/FAQ
@@ -0,0 +1,98 @@ (new file)
=============
Customization
=============

Customizing NeuralNet
---------------------

:class:`.NeuralNet` and its subclasses like
:class:`.NeuralNetClassifier` are already very flexible as they are
and should cover many use cases by adjusting the provided
parameters. However, this may not always be sufficient for your use
case. If you find yourself wanting to customize
:class:`.NeuralNet`, please follow these guidelines.
Initialization
^^^^^^^^^^^^^^

The method :func:`~skorch.net.NeuralNet.initialize` is responsible for
initializing all the components needed by the net, e.g. the module and
the optimizer. For this, it calls specific initialization methods,
such as :func:`~skorch.net.NeuralNet.initialize_module` and
:func:`~skorch.net.NeuralNet.initialize_optimizer`. If you'd like to
customize the initialization behavior, you should override the
corresponding methods. Following sklearn conventions, each created
component should be set as an attribute with a trailing underscore in
its name, e.g. ``module_`` for the initialized module. Finally, the
method should return ``self``.
Methods starting with get_*
^^^^^^^^^^^^^^^^^^^^^^^^^^^

The net provides a few ``get_*`` methods, most notably
:func:`~skorch.net.NeuralNet.get_loss`,
:func:`~skorch.net.NeuralNet.get_dataset`, and
:func:`~skorch.net.NeuralNet.get_iterator`. The intent of these
methods should be pretty self-explanatory, and if you are still not
quite sure, consult their documentation. In general, these methods
are fairly safe to override as long as you make sure to conform to the
same signature as the original.
Training and validation
^^^^^^^^^^^^^^^^^^^^^^^

If you would like to customize training and validation, there are
several possibilities. Below are the methods that you most likely want
to customize:

The method :func:`~skorch.net.NeuralNet.train_step_single` performs a
single training step. It accepts the current batch of data as input
(as well as the ``fit_params``) and should return a dictionary
containing the ``loss`` and the prediction ``y_pred``. E.g. you should
override this if your dataset returns some non-standard data that
needs custom handling, or if your module has to be called in a
very specific way. If you want to, you can still make use of
:func:`~skorch.net.NeuralNet.infer` and
:func:`~skorch.net.NeuralNet.get_loss`, but it's not strictly
necessary. Don't call the optimizer in this method; that is handled by
:func:`~skorch.net.NeuralNet.train_step`.
The method :func:`~skorch.net.NeuralNet.train_step` defines the
complete training procedure performed for each batch. It accepts the
same arguments as :func:`~skorch.net.NeuralNet.train_step_single`, but
it differs in that it defines the training closure passed to the
optimizer, which, for instance, could be called more than once (e.g. in
the case of L-BFGS). You might override this if you deal with
non-standard training procedures, e.g. gradient accumulation.
The method :func:`~skorch.net.NeuralNet.validation_step` is
responsible for calculating the prediction and loss on the validation
data (remember that skorch uses an internal validation set for
reporting, early stopping, etc.). Similar to
:func:`~skorch.net.NeuralNet.train_step_single`, it receives the batch
and ``fit_params`` as input and should return a dictionary containing
``loss`` and ``y_pred``. Most likely, when you need to customize
:func:`~skorch.net.NeuralNet.train_step_single`, you'll need to
customize :func:`~skorch.net.NeuralNet.validation_step` accordingly.
Finally, the method :func:`~skorch.net.NeuralNet.evaluation_step` is
called when you use the net for inference, e.g. when calling
:func:`~skorch.net.NeuralNet.forward` or
:func:`~skorch.net.NeuralNet.predict`. You may want to modify this if,
e.g., you want your model to behave differently during training and
during prediction.
You should also be aware that some methods are better left
untouched. E.g., in most cases, the following methods should *not* be
overridden:

* :func:`~skorch.net.NeuralNet.fit`
* :func:`~skorch.net.NeuralNet.partial_fit`
* :func:`~skorch.net.NeuralNet.fit_loop`
* :func:`~skorch.net.NeuralNet.run_single_epoch`

These methods should stay untouched because they perform
bookkeeping, like making sure that callbacks are handled
and that logs are written to the ``history``. If you do need to
override them, make sure that you perform the same bookkeeping as the
original methods.
I suspect that changing the signature of callbacks is what would break users' code the most.
Agreed. Not every user might follow the changelog closely. We could wrap the ``notify`` call with an exception handler and, in the case of ``on_batch_begin``, re-raise the original exception together with a note informing the user about the change in signature and how to fix it. This could stay there for one release.
I took a stab at this, but it turns out not to be so simple. The reason is that we always have ``**kwargs`` in the signature, so even if someone writes a wrong method like ``def on_batch_begin(net, X=None, y=None, **kwargs)``, it would not immediately lead to an error. The error occurs only when they do something with X or y, because these will be None instead of the expected values. But this "something" can be anything.

In my attempt, I just caught any kind of ``Exception`` and raised a ``TypeError`` from it that pointed at the possible error source. However, this broke other tests, which checked for specific errors (for instance, using scoring with a non-existing function name). We don't want that to happen. The problem is thus that we don't know what error will occur, but we also don't want to catch all kinds of errors.

We could theoretically inspect the arguments inside the ``notify`` wrapper, but there I would be afraid that the performance penalty could be too big. Furthermore, we want to allow users to be very liberal in how they call their ``on_*`` methods (e.g. only passing the arguments that they need, which might not even include ``batch``), so it's not possible to do a strict check.

This leaves me with no idea how to implement the suggestion. Do you have any ideas?
OK, this is fair. Maybe we can issue a warning on the first call of ``notify('on_batch_begin')`` for callbacks that aren't ours, indicating that things might have changed? We can deprecate this warning in the next or a following release, and you could always filter the warning if you are getting annoyed by it.
So basically, on the very first call to ``notify``, you'd like to go through all callbacks and check whether each is an instance of a built-in callback, and if at least one of them isn't, issue a warning? I'm not sure whether this isn't overkill, given that the warning will mostly be a false alarm.
Oh, no, more like ``if not cb.__module__.startswith('skorch'): issueWarning()``.
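That heuristic can be sketched in a few lines of plain Python. The callback classes here are stand-ins made up for illustration, not skorch's real ones, and ``warn_on_foreign_callbacks`` is a hypothetical helper name:

```python
import warnings

class SkorchLikeCallback:
    pass

# pretend this callback ships with skorch itself
SkorchLikeCallback.__module__ = 'skorch.callbacks'

class UserCallback:
    pass  # defined in user code, so its module is not under 'skorch'

def warn_on_foreign_callbacks(callbacks):
    """Warn for callbacks that do not come from skorch itself, since
    their on_* methods may still expect the old (X, y) signature."""
    for cb in callbacks:
        if not cb.__class__.__module__.startswith('skorch'):
            warnings.warn(
                "{} is not a built-in skorch callback; note that the "
                "on_batch_* signatures changed, see the migration "
                "guide in the FAQ.".format(cb.__class__.__name__))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_on_foreign_callbacks([SkorchLikeCallback(), UserCallback()])

print(len(caught))  # 1 warning, for UserCallback only
```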