Skip to content

Commit d39ff53

Browse files
author
vvv
committed
fix .get_last_lr() error; see ildoonet/pytorch-gradual-warmup-lr#14
1 parent 6b5e895 commit d39ff53

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

warmup_scheduler/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
optim = SGD(model, 0.1)
1111

1212
# scheduler_warmup is chained with scheduler_steplr
13-
scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1)
13+
scheduler_steplr = StepLR(optim, step_size=5, gamma=0.1)
1414
scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr)
1515

1616
# this zero gradient update is needed to avoid a warning message, issue #8.

warmup_scheduler/scheduler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def get_lr(self):
2828
if not self.finished:
2929
self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
3030
self.finished = True
31-
return self.after_scheduler.get_last_lr()
31+
return [group['lr'] for group in self.optimizer.param_groups]
3232
return [base_lr * self.multiplier for base_lr in self.base_lrs]
3333

3434
if self.multiplier == 1.0:
@@ -57,7 +57,7 @@ def step(self, epoch=None, metrics=None):
5757
self.after_scheduler.step(None)
5858
else:
5959
self.after_scheduler.step(epoch - self.total_epoch)
60-
self._last_lr = self.after_scheduler.get_last_lr()
60+
self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
6161
else:
6262
return super(GradualWarmupScheduler, self).step(epoch)
6363
else:

0 commit comments

Comments
 (0)