
[MRG] LRScheduler batch option #626


Merged
2 changes: 2 additions & 0 deletions CHANGES.md
@@ -11,12 +11,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added the `event_name` argument for `LRScheduler` for optional recording of LR changes inside `net.history`. NOTE: Supported only in Pytorch>=1.4
- Make it easier to add custom modules or optimizers to a neural net class by automatically registering them where necessary and by making them available to set_params
- Added the `step_every` argument for `LRScheduler` to set whether the scheduler step should be taken on every epoch or on every batch.

### Changed

- Removed support for schedulers with a `batch_step()` method in `LRScheduler`.
- Raise `FutureWarning` in `CVSplit` when `random_state` is not used. Will raise an exception in a future release (#620)
- The behavior of method `net.get_params` changed to make it more consistent with sklearn: it will no longer return "learned" attributes like `module_`; therefore, functions like `sklearn.base.clone`, when called with a fitted net, will no longer return a fitted net but instead an uninitialized net; if you want a copy of a fitted net, use `copy.deepcopy` instead;`net.get_params` is used under the hood by many sklearn functions and classes, such as `GridSearchCV`, whose behavior may thus be affected by the change. (#521, #527)
- Raise `FutureWarning` when using `CyclicLR` scheduler, because the default behavior has changed from taking a step every batch to taking a step every epoch.

### Fixed

57 changes: 39 additions & 18 deletions skorch/callbacks/lr_scheduler.py
@@ -3,6 +3,8 @@
import sys

# pylint: disable=unused-import
import warnings

import numpy as np
import torch
from torch.optim.lr_scheduler import _LRScheduler
@@ -57,6 +59,10 @@ class LRScheduler(Callback):
Pass ``None`` to disable placing events in history.
**Note:** This feature works only for pytorch version >=1.4

step_every: str, (default='epoch')
Value for when to apply the learning rate scheduler step. Can be either
'batch' or 'epoch'.

kwargs
Additional arguments passed to the lr scheduler.
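
For context, a minimal usage sketch of the new `step_every` argument (a hedged illustration, not part of the PR: `PlaceholderModule`, the random data, and the hyperparameter values are placeholders; `NeuralNetClassifier`, `LRScheduler`, and `CyclicLR` are the real APIs):

```python
import numpy as np
import torch
from torch.optim.lr_scheduler import CyclicLR

from skorch import NeuralNetClassifier
from skorch.callbacks import LRScheduler


class PlaceholderModule(torch.nn.Module):
    """Tiny stand-in module, only to make the sketch self-contained."""
    def __init__(self):
        super().__init__()
        self.dense = torch.nn.Linear(20, 2)

    def forward(self, X):
        return torch.softmax(self.dense(X), dim=-1)


X = np.random.randn(128, 20).astype(np.float32)
y = np.random.randint(0, 2, size=128).astype(np.int64)

# step_every='batch' keeps the per-batch stepping that CyclicLR expects;
# with the new default, step_every='epoch', the scheduler steps once per epoch.
cyclic_lr = LRScheduler(CyclicLR, base_lr=1e-3, max_lr=6e-3, step_every='batch')

net = NeuralNetClassifier(
    PlaceholderModule,
    max_epochs=2,
    callbacks=[('scheduler', cyclic_lr)],
)
net.fit(X, y)
```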

@@ -66,10 +72,12 @@ def __init__(self,
policy='WarmRestartLR',
monitor='train_loss',
event_name="event_lr",
step_every='epoch',
**kwargs):
self.policy = policy
self.monitor = monitor
self.event_name = event_name
self.step_every = step_every
vars(self).update(kwargs)

def simulate(self, steps, initial_lr):
@@ -107,6 +115,16 @@ def initialize(self):
self.policy_ = self._get_policy_cls()
self.lr_scheduler_ = None
self.batch_idx_ = 0
# TODO: Remove this warning on 0.10 release
if ((self.policy_ == TorchCyclicLR or self.policy_ == "TorchCyclicLR")
        and self.step_every == 'epoch'):
warnings.warn(
"The LRScheduler now makes a step every epoch by default. "
"To have the cyclic lr scheduler update "
"every batch set step_every='batch'",
FutureWarning
)

return self

def _get_policy_cls(self):
@@ -119,7 +137,7 @@ def kwargs(self):
# These are the parameters that are passed to the
# scheduler. Parameters that don't belong there must be
# excluded.
excluded = ('policy', 'monitor', 'event_name')
excluded = ('policy', 'monitor', 'event_name', 'step_every')
kwargs = {key: val for key, val in vars(self).items()
if not (key in excluded or key.endswith('_'))}
return kwargs
@@ -135,35 +153,37 @@ def on_train_begin(self, net, **kwargs):
)

def on_epoch_end(self, net, **kwargs):
epoch = len(net.history) - 1
if not self.step_every == 'epoch':
return
epoch = len(net.history)
if isinstance(self.lr_scheduler_, ReduceLROnPlateau):
if callable(self.monitor):
score = self.monitor(net)
else:
if epoch:
score = net.history[-2, self.monitor]
if self.lr_scheduler_.mode == 'max':
score = -np.inf
elif self.lr_scheduler_.mode == 'min':
score = np.inf
else:
if self.lr_scheduler_.mode == 'max':
score = -np.inf
else:
score = np.inf
score = net.history[-1, self.monitor]
Collaborator:

I'm no expert on lr schedulers, so can't comment on the exact logic that should be applied here. However, what I'm seeing is that the logic might have a "hole", where score wouldn't be defined at all if none of the if or elif matches. In the logic before, score would always be defined.

Could it be that we don't actually need elif epoch? Previously, we accessed net.history[-2, self.monitor], i.e. the second to last row in the history. Therefore, we had to make sure that we're not in the first epoch (at least that's my interpretation of things). Now that we're accessing net.history[-1, self.monitor], we can probably never get an index error here, hence don't need the guard. Would that make sense?

Contributor Author:

Makes sense. Pushed a change.
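
As an aside, a self-contained illustration of the indexing point discussed above (this uses skorch's `History` directly; the values are made up and not from this PR):

```python
from skorch.history import History

history = History()
history.new_epoch()
history.record('train_loss', 0.25)

# The last row exists as soon as the first epoch has started, so indexing
# with -1 is safe even during epoch 1:
assert history[-1, 'train_loss'] == 0.25

# A second-to-last row does not exist yet, so accessing it needs a guard:
try:
    history[-2, 'train_loss']
except IndexError:
    print("no previous epoch recorded yet")
```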


self.lr_scheduler_.step(score, epoch)
# ReduceLROnPlateau does not expose the current lr so it can't be recorded
else:
self.lr_scheduler_.step(epoch)
if self.event_name is not None and hasattr(
self.lr_scheduler_, "get_last_lr"):
net.history.record(self.event_name, self.lr_scheduler_.get_last_lr()[0])
net.history.record(self.event_name,
self.lr_scheduler_.get_last_lr()[0])
self.lr_scheduler_.step(epoch)

def on_batch_end(self, net, training, **kwargs):
if not training:
if not training or not self.step_every == 'batch':
return
if TorchCyclicLR and isinstance(self.lr_scheduler_, TorchCyclicLR):
self.lr_scheduler_.step(self.batch_idx_)
if self.event_name is not None and hasattr(
self.lr_scheduler_, "get_last_lr"):
net.history.record_batch(self.event_name,
self.lr_scheduler_.get_last_lr()[0])
if self.event_name is not None and hasattr(
self.lr_scheduler_, "get_last_lr"):
net.history.record_batch(self.event_name,
self.lr_scheduler_.get_last_lr()[0])
self.lr_scheduler_.step()
self.batch_idx_ += 1

def _get_scheduler(self, net, policy, **scheduler_kwargs):
@@ -232,7 +252,8 @@ def __init__(
super(WarmRestartLR, self).__init__(optimizer, last_epoch)

def _get_current_lr(self, min_lr, max_lr, period, epoch):
return min_lr + 0.5 * (max_lr - min_lr) * (1 + np.cos(epoch * np.pi / period))
return min_lr + 0.5 * (max_lr - min_lr) * (
1 + np.cos(epoch * np.pi / period))

def get_lr(self):
epoch_idx = float(self.last_epoch)
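
As a quick sanity check of the cosine expression in `_get_current_lr` above (the standalone helper below just mirrors that expression and is not part of the PR):

```python
import numpy as np


def warm_restart_lr(min_lr, max_lr, period, epoch):
    # Mirrors WarmRestartLR._get_current_lr: cosine annealing from max_lr
    # down to min_lr over one period.
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + np.cos(epoch * np.pi / period))


print(warm_restart_lr(1e-4, 1e-2, period=10, epoch=0))   # 0.01     -> max_lr at the start
print(warm_restart_lr(1e-4, 1e-2, period=10, epoch=5))   # ~0.00505 -> halfway point
print(warm_restart_lr(1e-4, 1e-2, period=10, epoch=10))  # 0.0001   -> min_lr at the end
```
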
38 changes: 31 additions & 7 deletions skorch/tests/callbacks/test_lr_scheduler.py
@@ -32,7 +32,7 @@ def test_simulate_lrs_epoch_step(self, policy):
@pytest.mark.parametrize('policy', [TorchCyclicLR])
def test_simulate_lrs_batch_step(self, policy):
lr_sch = LRScheduler(
policy, base_lr=1, max_lr=5, step_size_up=4)
policy, base_lr=1, max_lr=5, step_size_up=4, step_every='batch')
lrs = lr_sch.simulate(11, 1)
expected = np.array([1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3])
assert np.allclose(expected, lrs)
@@ -94,10 +94,10 @@ def test_lr_callback_steps_correctly(
)
net.fit(X, y)
# pylint: disable=protected-access
assert lr_policy.lr_scheduler_.last_epoch == max_epochs - 1
assert lr_policy.lr_scheduler_.last_epoch == max_epochs

@pytest.mark.parametrize('policy, kwargs', [
(TorchCyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3}),
(TorchCyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}),
Member:

This would not be backwards compatible with TorchCyclicLR under default arguments: TorchCyclicLR would step every epoch if a user does not pass step_every='batch'.

Contributor Author:

So the implementation is fine and we just need to restore the if TorchCyclicLR and isinstance(self.lr_scheduler_, TorchCyclicLR):... segment?

Member:

Now that step_every is a parameter to LRScheduler, if we keep the isinstance(self.lr_scheduler_, TorchCyclicLR) piece, we will not be strictly following the step_every parameter.

We can have a step_every='auto' option that special-cases CosineAnnealingLR and TorchCyclicLR. Specifically, if the class is CosineAnnealingLR or TorchCyclicLR we use batch steps, otherwise we use epoch steps. We will still allow the 'batch' and 'epoch' options for custom lr schedulers.

Contributor Author:

That would mean introducing the entire concept of an 'auto' option, and having to stay backwards compatible with it in the future. However, all it would do is special-case cyclic (for the record, cosine annealing isn't hard-coded to work with batches, only cyclic is).

This basically means that if we don't want 'auto' to just be a cover-up for the cyclic case and instead be meaningful, we will have to add reasonable recommendations for every scheduler we introduce (which would become the default 'auto' options). That will most likely involve reading the paper in which each scheduler was introduced, or its documentation.

Collaborator:

This is a bit of a conundrum. Do you know why some schedulers prefer batch and others epoch? Is it some derived property or more arbitrary?

> we will have to add reasonable recommendations for every scheduler we introduce

This is certainly not the way to go. Could we not just use whatever the default is?

Contributor Author:

> This is a bit of a conundrum. Do you know why some schedulers prefer batch and others epoch? Is it some derived property or more arbitrary?

Not for certain. Naturally, if you wanted your lr to change smoothly you'd do it per batch rather than per epoch. I'm just plugging in each scheduler as written in its documentation.

> This is certainly not the way to go. Could we not just use whatever the default is?

A scheduler doesn't have a default. It has a step function which the user can call whenever and wherever they'd like. Setting epoch or batch as the default mode will break backwards compatibility for anything that isn't intended to use the default mode we choose.

Since the Cyclic scheduler is a special case for which skorch hacked a solution, I suggested leaving that solution as is (so that it won't conflict with this feature or have compatibility issues) until skorch makes an update that breaks compatibility, and just getting rid of that hack then.

Collaborator:

Good points. Maybe it would be worth it to actually break backwards compatibility here for the sake of getting a clean solution. In fact, we could even detect when such a breakage happens and issue a warning to the user, with instructions on how to revert to the old behavior. Would that be a compromise?

Member:

> Would that be a compromise?

I would be happy with a deprecation warning with a suggestion of using step_every='batch' in the future.

Contributor Author:

Done.
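
To summarize where this thread landed: `step_every` stays explicit, and a `FutureWarning` is emitted when a cyclic scheduler meets the new per-epoch default. A hedged before/after sketch of the user-facing effect (the hyperparameter values are placeholders):

```python
from torch.optim.lr_scheduler import CyclicLR
from skorch.callbacks import LRScheduler

# New default: steps once per epoch; for CyclicLR this also triggers the
# FutureWarning added in LRScheduler.initialize() once the net is initialized.
scheduler = LRScheduler(CyclicLR, base_lr=1e-3, max_lr=6e-3)

# Opting back in to per-batch stepping restores the old behavior and avoids
# the warning.
scheduler = LRScheduler(CyclicLR, base_lr=1e-3, max_lr=6e-3, step_every='batch')
```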

])
def test_lr_callback_batch_steps_correctly(
self,
@@ -126,7 +126,7 @@ def test_lr_callback_batch_steps_correctly(
assert lr_policy.batch_idx_ == expected

@pytest.mark.parametrize('policy, kwargs', [
(TorchCyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3}),
(TorchCyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}),
])
def test_lr_callback_batch_steps_correctly_fallback(
self,
@@ -177,7 +177,8 @@ def test_lr_scheduler_cloneable(self):
clone(scheduler) # does not raise

def test_lr_scheduler_set_params(self, classifier_module, classifier_data):
scheduler = LRScheduler(TorchCyclicLR, base_lr=123, max_lr=999)
scheduler = LRScheduler(
TorchCyclicLR, base_lr=123, max_lr=999, step_every='batch')
net = NeuralNetClassifier(
classifier_module,
max_epochs=0,
@@ -205,7 +206,7 @@ def test_lr_scheduler_record_epoch_step(self,
net = NeuralNetClassifier(
classifier_module,
max_epochs=epochs,
lr=123,
lr=123.,
callbacks=[('scheduler', scheduler)]
)
net.fit(*classifier_data)
@@ -219,7 +220,13 @@ def test_lr_scheduler_record_batch_step(self, classifier_module, classifier_data
X, y = classifier_data
batch_size = 128

scheduler = LRScheduler(TorchCyclicLR, base_lr=1, max_lr=5, step_size_up=4)
scheduler = LRScheduler(
TorchCyclicLR,
base_lr=1,
max_lr=5,
step_size_up=4,
step_every='batch'
)
net = NeuralNetClassifier(
classifier_module,
max_epochs=1,
@@ -234,6 +241,23 @@
)
assert np.all(net.history[-1, 'batches', :, 'event_lr'] == new_lrs)

def test_cyclic_lr_with_epoch_step_warning(self,
classifier_module,
classifier_data):
msg = ("The LRScheduler now makes a step every epoch by default. "
"To have the cyclic lr scheduler update "
"every batch set step_every='batch'")
with pytest.warns(FutureWarning, match=msg) as record:
scheduler = LRScheduler(
TorchCyclicLR, base_lr=123, max_lr=999)
net = NeuralNetClassifier(
classifier_module,
max_epochs=0,
callbacks=[('scheduler', scheduler)],
)
net.initialize()
assert len(record) == 1


class TestReduceLROnPlateau:
