[MRG] LRScheduler batch option #626
@@ -3,6 +3,8 @@
import sys

# pylint: disable=unused-import
import warnings

import numpy as np
import torch
from torch.optim.lr_scheduler import _LRScheduler
@@ -57,6 +59,10 @@ class LRScheduler(Callback):
      Pass ``None`` to disable placing events in history.
      **Note:** This feature works only for pytorch version >=1.4

    step_every: str, (default='epoch')
      Value for when to apply the learning scheduler step. Can be either 'batch'
      or 'epoch'.

    kwargs
      Additional arguments passed to the lr scheduler.
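As an illustration of how the new parameter is meant to be used, here is a self-contained sketch based on the constructions in the tests further down; the toy module and random data are made up for this example:

import numpy as np
from torch import nn
from torch.optim.lr_scheduler import CyclicLR

from skorch import NeuralNetClassifier
from skorch.callbacks import LRScheduler

# Toy data and module, only so that fit() has something to run on.
X = np.random.randn(128, 20).astype(np.float32)
y = np.random.randint(0, 2, size=128).astype(np.int64)
module = nn.Sequential(nn.Linear(20, 2), nn.Softmax(dim=-1))

# step_every='batch' makes the callback step the scheduler after every batch;
# the default, 'epoch', steps it once at the end of each epoch.
lr_policy = LRScheduler(CyclicLR, base_lr=1e-3, max_lr=6e-3,
                        step_size_up=4, step_every='batch')

net = NeuralNetClassifier(
    module,
    max_epochs=2,
    callbacks=[('cyclic_lr', lr_policy)],
)
net.fit(X, y)

# With pytorch >= 1.4 the per-batch learning rates are recorded in the
# history under the callback's event_name ('event_lr' by default).
print(net.history[-1, 'batches', -1, 'event_lr'])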
@@ -66,10 +72,12 @@ def __init__(self,
                 policy='WarmRestartLR',
                 monitor='train_loss',
                 event_name="event_lr",
                 step_every='epoch',
                 **kwargs):
        self.policy = policy
        self.monitor = monitor
        self.event_name = event_name
        self.step_every = step_every
        vars(self).update(kwargs)

    def simulate(self, steps, initial_lr):
@@ -107,6 +115,15 @@ def initialize(self):
        self.policy_ = self._get_policy_cls()
        self.lr_scheduler_ = None
        self.batch_idx_ = 0
        # TODO: Remove this warning on 0.10 release
        if self.policy_ == TorchCyclicLR or self.policy_ == "TorchCyclicLR":
            warnings.warn(
                "The LRScheduler now makes a step every epoch by default. "
                "To have the cyclic lr scheduler update "
                "every batch set step_every='batch'",
                FutureWarning
            )
        return self

    def _get_policy_cls(self):
@@ -119,7 +136,7 @@ def kwargs(self):
        # These are the parameters that are passed to the
        # scheduler. Parameters that don't belong there must be
        # excluded.
        excluded = ('policy', 'monitor', 'event_name')
        excluded = ('policy', 'monitor', 'event_name', 'step_every')
        kwargs = {key: val for key, val in vars(self).items()
                  if not (key in excluded or key.endswith('_'))}
        return kwargs
@@ -135,35 +152,37 @@ def on_train_begin(self, net, **kwargs):
        )

    def on_epoch_end(self, net, **kwargs):
        epoch = len(net.history) - 1
        if not self.step_every == 'epoch':
            return
        epoch = len(net.history)
        if isinstance(self.lr_scheduler_, ReduceLROnPlateau):
            if callable(self.monitor):
                score = self.monitor(net)
            else:
                if epoch:
                    score = net.history[-2, self.monitor]
                if self.lr_scheduler_.mode == 'max':
                    score = -np.inf
                elif self.lr_scheduler_.mode == 'min':
                    score = np.inf
                else:
                    if self.lr_scheduler_.mode == 'max':
                        score = -np.inf
                    else:
                        score = np.inf
                score = net.history[-1, self.monitor]
Review comment on the score logic above: I'm no expert on lr schedulers, so I can't comment on the exact logic that should be applied here. However, what I'm seeing is that the logic might have a "hole", where score wouldn't be defined at all if none of the … Could it be that we don't actually need …?

Reply: Makes sense. Pushed a change.
            self.lr_scheduler_.step(score, epoch)
            # ReduceLROnPlateau does not expose the current lr so it can't be recorded
        else:
            self.lr_scheduler_.step(epoch)
            if self.event_name is not None and hasattr(
                    self.lr_scheduler_, "get_last_lr"):
                net.history.record(self.event_name, self.lr_scheduler_.get_last_lr()[0])
                net.history.record(self.event_name,
                                   self.lr_scheduler_.get_last_lr()[0])
            self.lr_scheduler_.step(epoch)

    def on_batch_end(self, net, training, **kwargs):
        if not training:
        if not training or not self.step_every == 'batch':
            return
        if TorchCyclicLR and isinstance(self.lr_scheduler_, TorchCyclicLR):
            self.lr_scheduler_.step(self.batch_idx_)
            if self.event_name is not None and hasattr(
                    self.lr_scheduler_, "get_last_lr"):
                net.history.record_batch(self.event_name,
                                         self.lr_scheduler_.get_last_lr()[0])
        if self.event_name is not None and hasattr(
                self.lr_scheduler_, "get_last_lr"):
            net.history.record_batch(self.event_name,
                                     self.lr_scheduler_.get_last_lr()[0])
        self.lr_scheduler_.step()
        self.batch_idx_ += 1

    def _get_scheduler(self, net, policy, **scheduler_kwargs):
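For context, the two stepping modes correspond roughly to the following pattern in plain PyTorch (illustrative only; the callback above wires this into skorch's training loop and history):

import torch
from torch import nn
from torch.optim.lr_scheduler import CyclicLR

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
scheduler = CyclicLR(optimizer, base_lr=1e-3, max_lr=6e-3, step_size_up=4)

for epoch in range(2):
    for batch in range(8):
        optimizer.step()   # the usual forward/backward pass is omitted here
        scheduler.step()   # step_every='batch': the lr changes within the epoch
    # step_every='epoch' would instead call scheduler.step() here,
    # once per epoch, after all batches are done.
    print(epoch, scheduler.get_last_lr()[0])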
@@ -232,7 +251,8 @@ def __init__(
        super(WarmRestartLR, self).__init__(optimizer, last_epoch)

    def _get_current_lr(self, min_lr, max_lr, period, epoch):
        return min_lr + 0.5 * (max_lr - min_lr) * (1 + np.cos(epoch * np.pi / period))
        return min_lr + 0.5 * (max_lr - min_lr) * (
            1 + np.cos(epoch * np.pi / period))

    def get_lr(self):
        epoch_idx = float(self.last_epoch)
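To make the reformatted expression above concrete, here is a quick standalone check of the cosine curve it computes (the numbers are arbitrary): the learning rate starts at max_lr, reaches the midpoint halfway through the period, and decays to min_lr at the end of the period.

import numpy as np

def current_lr(min_lr, max_lr, period, epoch):
    # Same cosine expression as WarmRestartLR._get_current_lr above.
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + np.cos(epoch * np.pi / period))

assert np.isclose(current_lr(0.01, 0.1, period=10, epoch=0), 0.1)    # starts at max_lr
assert np.isclose(current_lr(0.01, 0.1, period=10, epoch=5), 0.055)  # midpoint
assert np.isclose(current_lr(0.01, 0.1, period=10, epoch=10), 0.01)  # ends at min_lr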
@@ -32,7 +32,7 @@ def test_simulate_lrs_epoch_step(self, policy):
    @pytest.mark.parametrize('policy', [TorchCyclicLR])
    def test_simulate_lrs_batch_step(self, policy):
        lr_sch = LRScheduler(
            policy, base_lr=1, max_lr=5, step_size_up=4)
            policy, base_lr=1, max_lr=5, step_size_up=4, step_every='batch')
        lrs = lr_sch.simulate(11, 1)
        expected = np.array([1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3])
        assert np.allclose(expected, lrs)
@@ -94,10 +94,10 @@ def test_lr_callback_steps_correctly(
        )
        net.fit(X, y)
        # pylint: disable=protected-access
        assert lr_policy.lr_scheduler_.last_epoch == max_epochs - 1
        assert lr_policy.lr_scheduler_.last_epoch == max_epochs
    @pytest.mark.parametrize('policy, kwargs', [
        (TorchCyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3}),
        (TorchCyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}),
    ])
    def test_lr_callback_batch_steps_correctly(
        self,

Review discussion on this change:

This would not be backwards compatible with …

So the implementation is fine and we just need to restore the …

Now that … We can have a …

That would mean introducing the entire concept of an 'auto' option, and having to be backwards compatible with it in the future. However, all it will do is special-case cyclic (for the record, cosine annealing isn't hard-coded to work with batches, only cyclic). This basically means that if we don't want 'auto' to just be a cover-up for the cyclic case and instead be meaningful, we will have to add reasonable recommendations for every scheduler we introduce (which will be the default 'auto' options). That will most likely involve having to read the paper in which it was introduced, or its documentation.

This is a bit of a conundrum. Do you know why some schedulers prefer batch and others epoch? Is it some derived property or more arbitrary? This is certainly not the way to go. Could we not just use whatever the default is?

Not for certain. Naturally, if you wanted your lr to change smoothly, you'd do it per batch rather than per epoch. I'm just plugging in the scheduler as it is written in its documentation.

A scheduler doesn't have a default. It has a step function which the user can use whenever and wherever they'd like. Setting mode … Since the cyclic scheduler is a special case for which skorch hacked a solution, I suggested leaving that solution as is (so that it won't conflict with this feature or have compatibility issues) until skorch makes an update that breaks compatibility, and just getting rid of that hack then.

Good points. Maybe it would be worth it to actually break backwards compatibility here for the sake of getting a clean solution. In fact, we could even detect when such a breakage happens and issue a warning to the user, with instructions on how to revert to the old behavior. Would that be a compromise?

I would be happy with a deprecation warning with a suggestion of using …

Done.
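Purely to illustrate the 'auto' idea debated above (it was not adopted in this PR; the helper name and the per-scheduler mapping are hypothetical):

from torch.optim.lr_scheduler import CyclicLR as TorchCyclicLR

# Hypothetical: schedulers whose documentation recommends per-batch stepping.
BATCH_STEP_SCHEDULERS = (TorchCyclicLR,)

def resolve_step_every(policy, step_every):
    # 'auto' would fall back to a per-scheduler recommendation,
    # defaulting to 'epoch' for everything else.
    if step_every != 'auto':
        return step_every
    return 'batch' if policy in BATCH_STEP_SCHEDULERS else 'epoch'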
@@ -126,7 +126,7 @@ def test_lr_callback_batch_steps_correctly(
        assert lr_policy.batch_idx_ == expected

    @pytest.mark.parametrize('policy, kwargs', [
        (TorchCyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3}),
        (TorchCyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}),
    ])
    def test_lr_callback_batch_steps_correctly_fallback(
        self,
@@ -177,7 +177,8 @@ def test_lr_scheduler_cloneable(self):
        clone(scheduler)  # does not raise

    def test_lr_scheduler_set_params(self, classifier_module, classifier_data):
        scheduler = LRScheduler(TorchCyclicLR, base_lr=123, max_lr=999)
        scheduler = LRScheduler(
            TorchCyclicLR, base_lr=123, max_lr=999, step_every='batch')
        net = NeuralNetClassifier(
            classifier_module,
            max_epochs=0,
@@ -205,7 +206,7 @@ def test_lr_scheduler_record_epoch_step(self,
        net = NeuralNetClassifier(
            classifier_module,
            max_epochs=epochs,
            lr=123,
            lr=123.,
            callbacks=[('scheduler', scheduler)]
        )
        net.fit(*classifier_data)
@@ -219,7 +220,13 @@ def test_lr_scheduler_record_batch_step(self, classifier_module, classifier_data
        X, y = classifier_data
        batch_size = 128

        scheduler = LRScheduler(TorchCyclicLR, base_lr=1, max_lr=5, step_size_up=4)
        scheduler = LRScheduler(
            TorchCyclicLR,
            base_lr=1,
            max_lr=5,
            step_size_up=4,
            step_every='batch'
        )
        net = NeuralNetClassifier(
            classifier_module,
            max_epochs=1,
Review comment:

Only need to warn if step_every is 'epoch'. We should also have a test to make sure the warning message is raised.
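A minimal sketch of such a test, assuming pytest and the warning text from the diff above (the test name and placement are made up):

import pytest
from torch.optim.lr_scheduler import CyclicLR

from skorch.callbacks import LRScheduler


def test_cyclic_lr_warns_about_default_step_every():
    # The FutureWarning added in initialize() should fire for the cyclic
    # scheduler when step_every is left at its default of 'epoch'.
    scheduler = LRScheduler(CyclicLR, base_lr=1e-3, max_lr=6e-3)
    with pytest.warns(FutureWarning, match="step every epoch"):
        scheduler.initialize()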