Commit 5bd84bd

ENH: ReduceLROnPlateau records the learning rate and works on batches (#1075)
Previously, when using ReduceLROnPlateau, we did not record the learning rates in the history. The comment said that's because this class does not expose the get_last_lr method; I checked again and it is now present, so let's use it. Furthermore, I made a change to enable ReduceLROnPlateau to step on each batch instead of each epoch, which is consistent with the other learning rate schedulers.
1 parent 4f755b9 commit 5bd84bd
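
For context, here is a minimal usage sketch of the behavior this commit adds, mirroring the tests added below. The module and data (ClassifierModule, X, y) are illustrative placeholders, not part of this commit:

import numpy as np
import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from skorch import NeuralNetClassifier
from skorch.callbacks import LRScheduler


class ClassifierModule(nn.Module):
    # stand-in module; any skorch-compatible classifier module works
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(20, 2)

    def forward(self, X):
        return nn.functional.softmax(self.dense(X), dim=-1)


X = np.random.rand(100, 20).astype(np.float32)
y = np.random.randint(0, 2, size=100).astype(np.int64)

net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=30,
    lr=0.1,
    callbacks=[
        # with the default step_every='epoch' the scheduler steps once per epoch;
        # passing step_every='batch' (new in this commit) steps it on every batch
        ('scheduler', LRScheduler(ReduceLROnPlateau, monitor='train_loss')),
    ],
)
net.fit(X, y)

# the learning rate is now recorded in the history under 'event_lr' (the default
# event_name); with step_every='batch' it is recorded per batch instead and can be
# read via net.history[:, 'batches', :, 'event_lr']
print(net.history[:, 'event_lr'])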

3 files changed: +97 −13


CHANGES.md (+1)
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 ### Changed
 
 - All neural net classes now inherit from sklearn's [`BaseEstimator`](https://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html). This is to support compatibility with sklearn 1.6.0 and above. Classification models additionally inherit from [`ClassifierMixin`](https://scikit-learn.org/stable/modules/generated/sklearn.base.ClassifierMixin.html) and regressors from [`RegressorMixin`](https://scikit-learn.org/stable/modules/generated/sklearn.base.RegressorMixin.html).
+- When using the `ReduceLROnPlateau` learning rate scheduler, we now record the learning rate in the net history (`net.history[:, 'event_lr']` by default). It is now also possible to step per batch, not only per epoch.
 
 ### Fixed
 

skorch/callbacks/lr_scheduler.py (+47 −13)
@@ -162,9 +162,36 @@ def _step(self, net, lr_scheduler, score=None):
         else:
             lr_scheduler.step(score)
 
+    def _record_last_lr(self, net, kind):
+        # helper function to record the last learning rate if possible;
+        # only record the first lr returned if more than 1 param group
+        if kind not in ('epoch', 'batch'):
+            raise ValueError(f"Argument 'kind' should be 'batch' or 'epoch', got {kind}.")
+
+        if (
+                (self.event_name is None)
+                or not hasattr(self.lr_scheduler_, 'get_last_lr')
+        ):
+            return
+
+        try:
+            last_lrs = self.lr_scheduler_.get_last_lr()
+        except AttributeError:
+            # get_last_lr fails for ReduceLROnPlateau with PyTorch <= 2.2 on the
+            # 1st epoch; take the initial lr instead.
+            last_lrs = [group['lr'] for group in net.optimizer_.param_groups]
+
+        if kind == 'epoch':
+            net.history.record(self.event_name, last_lrs[0])
+        else:
+            net.history.record_batch(self.event_name, last_lrs[0])
+
     def on_epoch_end(self, net, **kwargs):
         if self.step_every != 'epoch':
             return
+
+        self._record_last_lr(net, kind='epoch')
+
         if isinstance(self.lr_scheduler_, ReduceLROnPlateau):
             if callable(self.monitor):
                 score = self.monitor(net)
@@ -179,25 +206,32 @@ def on_epoch_end(self, net, **kwargs):
                     ) from e
 
             self._step(net, self.lr_scheduler_, score=score)
-            # ReduceLROnPlateau does not expose the current lr so it can't be recorded
         else:
-            if (
-                    (self.event_name is not None)
-                    and hasattr(self.lr_scheduler_, "get_last_lr")
-            ):
-                net.history.record(self.event_name, self.lr_scheduler_.get_last_lr()[0])
             self._step(net, self.lr_scheduler_)
 
     def on_batch_end(self, net, training, **kwargs):
         if not training or self.step_every != 'batch':
             return
-        if (
-                (self.event_name is not None)
-                and hasattr(self.lr_scheduler_, "get_last_lr")
-        ):
-            net.history.record_batch(
-                self.event_name, self.lr_scheduler_.get_last_lr()[0])
-        self._step(net, self.lr_scheduler_)
+
+        self._record_last_lr(net, kind='batch')
+
+        if isinstance(self.lr_scheduler_, ReduceLROnPlateau):
+            if callable(self.monitor):
+                score = self.monitor(net)
+            else:
+                try:
+                    score = net.history[-1, 'batches', -1, self.monitor]
+                except KeyError as e:
+                    raise ValueError(
+                        f"'{self.monitor}' was not found in history. A "
+                        f"Scoring callback with name='{self.monitor}' "
+                        "should be placed before the LRScheduler callback"
+                    ) from e
+
+            self._step(net, self.lr_scheduler_, score=score)
+        else:
+            self._step(net, self.lr_scheduler_)
+
         self.batch_idx_ += 1
 
     def _get_scheduler(self, net, policy, **scheduler_kwargs):
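
The `except AttributeError` branch above is worth noting: on PyTorch <= 2.2, `ReduceLROnPlateau` does not provide `get_last_lr()` (or has not populated it before the first step), so the current learning rate is read from the optimizer's parameter groups instead. A standalone sketch of that fallback pattern in plain PyTorch; the model and optimizer here are illustrative:

import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=10)

try:
    # available on recent PyTorch; may raise AttributeError on older versions
    # or before the scheduler has stepped for the first time
    last_lrs = scheduler.get_last_lr()
except AttributeError:
    # fallback used in the diff above: read the lr from the optimizer directly
    last_lrs = [group['lr'] for group in optimizer.param_groups]

# only the first param group's lr gets recorded in the skorch history
print(last_lrs[0])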

skorch/tests/callbacks/test_lr_scheduler.py (+49)
@@ -315,6 +315,55 @@ def test_reduce_lr_raise_error_when_key_does_not_exist(
         with pytest.raises(ValueError, match=msg):
             net.fit(X, y)
 
+    def test_reduce_lr_record_epoch_step(self, classifier_module, classifier_data):
+        epochs = 10 * 3  # patience = 10, get 3 full cycles of lr reduction
+        lr = 123.
+        net = NeuralNetClassifier(
+            classifier_module,
+            max_epochs=epochs,
+            lr=lr,
+            callbacks=[
+                ('scheduler', LRScheduler(ReduceLROnPlateau, monitor='train_loss')),
+            ],
+        )
+        net.fit(*classifier_data)
+
+        # We cannot compare lrs to simulation data, as ReduceLROnPlateau cannot be
+        # simulated. Instead we expect the lr to be reduced by a factor of 10 every
+        # 10+ epochs (as patience = 10), with the exact number depending on the training
+        # progress. Therefore, we can have at most 3 distinct lrs, but it could be less,
+        # so we need to slice the expected lrs.
+        lrs = net.history[:, 'event_lr']
+        lrs_unique = np.unique(lrs)
+        expected = np.unique([123., 12.3, 1.23])[-len(lrs_unique):]
+        assert np.allclose(lrs_unique, expected)
+
+    def test_reduce_lr_record_batch_step(self, classifier_module, classifier_data):
+        epochs = 3
+        lr = 123.
+        net = NeuralNetClassifier(
+            classifier_module,
+            max_epochs=epochs,
+            lr=lr,
+            callbacks=[
+                ('scheduler', LRScheduler(
+                    ReduceLROnPlateau, monitor='train_loss', step_every='batch'
+                )),
+            ],
+        )
+        net.fit(*classifier_data)
+
+        # We cannot compare lrs to simulation data, as ReduceLROnPlateau cannot be
+        # simulated. Instead we expect the lr to be reduced by a factor of 10 every
+        # 10+ batches (as patience = 10), with the exact number depending on the
+        # training progress. Therefore, we can have at most 3 distinct lrs, but it
+        # could be less, so we need to slice the expected lrs.
+        lrs_nested = net.history[:, 'batches', :, 'event_lr']
+        lrs_flat = sum(lrs_nested, [])
+        lrs_unique = np.unique(lrs_flat)
+        expected = np.unique([123., 12.3, 1.23])[-len(lrs_unique):]
+        assert np.allclose(lrs_unique, expected)
+
 
 class TestWarmRestartLR():
     def assert_lr_correct(
