diff --git a/nemo/core/optim/__init__.py b/nemo/core/optim/__init__.py
index 488f4f57ea58..5abeb179dd98 100644
--- a/nemo/core/optim/__init__.py
+++ b/nemo/core/optim/__init__.py
@@ -24,6 +24,8 @@
     SquareRootAnnealing,
     T5InverseSquareRootAnnealing,
     WarmupAnnealing,
+    WarmupHoldAnnealLinear,
+    WarmupHoldAnnealOneMinusSquareRoot,
     WarmupHoldPolicy,
     WarmupPolicy,
     prepare_lr_scheduler,
diff --git a/nemo/core/optim/lr_scheduler.py b/nemo/core/optim/lr_scheduler.py
index 7697a27f6d4b..a5936fc6d41e 100644
--- a/nemo/core/optim/lr_scheduler.py
+++ b/nemo/core/optim/lr_scheduler.py
@@ -225,6 +225,49 @@ def get_lr(self):
         return self._get_lr(step)


+class WarmupHoldAnnealOneMinusSquareRoot(WarmupHoldPolicy):
+    """Learning rate scheduler with warmup, hold, and one-minus-square-root annealing phases.
+
+    This scheduler follows a three-phase pattern:
+    1. Warmup phase: LR increases linearly from 0 to base_lr
+    2. Hold phase: LR remains constant at base_lr
+    3. Annealing phase: LR decreases following a one-minus-square-root curve from base_lr to min_lr
+
+    The annealing follows the formula: LR = base_lr * (1 - sqrt((step - hold_steps) / (max_steps - hold_steps)))
+    The min_lr floor is enforced during the annealing phase, i.e. the learning rate will not decay below min_lr.
+
+    Reference: https://arxiv.org/html/2408.11029
+    """
+
+    def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs):
+        super().__init__(optimizer=optimizer, max_steps=max_steps, **kwargs, last_epoch=last_epoch, min_lr=min_lr)
+
+    def _get_lr(self, step):
+        mult = 1 - ((step - self.hold_steps) / (self.max_steps - self.hold_steps)) ** 0.5  # from 1 to 0
+        out_lr = [max(self.min_lr, initial_lr * mult) for initial_lr in self.base_lrs]
+        return out_lr
+
+
+class WarmupHoldAnnealLinear(WarmupHoldPolicy):
+    """Learning rate scheduler with warmup, hold, and linear annealing phases.
+
+    This scheduler follows a three-phase pattern:
+    1. Warmup phase: LR increases linearly from 0 to base_lr
+    2. Hold phase: LR remains constant at base_lr
+    3. Annealing phase: LR decreases linearly from base_lr to min_lr
+
+    Reference: https://arxiv.org/pdf/2404.06395
+    """
+
+    def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs):
+        super().__init__(optimizer=optimizer, max_steps=max_steps, **kwargs, last_epoch=last_epoch, min_lr=min_lr)
+
+    def _get_lr(self, step):
+        ratio = (step - self.hold_steps) / (self.max_steps - self.hold_steps)  # from 0 to 1
+        out_lr = [initial_lr - (initial_lr - self.min_lr) * ratio for initial_lr in self.base_lrs]
+        return out_lr
+
+
 class WarmupAnnealHoldPolicy(_LRScheduler):
     """Adds warmup kwargs and warmup logic to lr policy.
     All arguments should be passed as kwargs for clarity,
@@ -1002,6 +1045,8 @@ def compute_max_steps(
     'CosineAnnealing': CosineAnnealing,
     'NoamAnnealing': NoamAnnealing,
     'NoamHoldAnnealing': NoamHoldAnnealing,
+    'WarmupHoldAnnealOneMinusSquareRoot': WarmupHoldAnnealOneMinusSquareRoot,
+    'WarmupHoldAnnealLinear': WarmupHoldAnnealLinear,
     'WarmupAnnealing': WarmupAnnealing,
     'InverseSquareRootAnnealing': InverseSquareRootAnnealing,
     'T5InverseSquareRootAnnealing': T5InverseSquareRootAnnealing,
diff --git a/nemo/lightning/pytorch/optim/__init__.py b/nemo/lightning/pytorch/optim/__init__.py
index f7af606c0916..cb7ab8c1c4ee 100644
--- a/nemo/lightning/pytorch/optim/__init__.py
+++ b/nemo/lightning/pytorch/optim/__init__.py
@@ -24,6 +24,7 @@
     SquareRootAnnealingScheduler,
     T5InverseSquareRootAnnealingScheduler,
     WarmupAnnealingScheduler,
+    WarmupHoldAnnealScheduler,
     WarmupHoldPolicyScheduler,
     WarmupPolicyScheduler,
 )
@@ -47,4 +48,5 @@
     "PolynomialHoldDecayAnnealingScheduler",
     "CosineAnnealingScheduler",
     "PytorchOptimizerModule",
+    "WarmupHoldAnnealScheduler",
 ]
diff --git a/nemo/lightning/pytorch/optim/lr_scheduler.py b/nemo/lightning/pytorch/optim/lr_scheduler.py
index 5966c7e6fcdf..40fa1c3f4c90 100644
--- a/nemo/lightning/pytorch/optim/lr_scheduler.py
+++ b/nemo/lightning/pytorch/optim/lr_scheduler.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional
+from typing import Literal, Optional

 from nemo.core.optim.lr_scheduler import (
     InverseSquareRootAnnealing,
@@ -24,6 +24,8 @@
     SquareRootAnnealing,
     T5InverseSquareRootAnnealing,
     WarmupAnnealing,
+    WarmupHoldAnnealLinear,
+    WarmupHoldAnnealOneMinusSquareRoot,
     WarmupHoldPolicy,
     WarmupPolicy,
 )
@@ -484,3 +486,60 @@ def scheduler(self, model, optimizer):
             # Metric to monitor for schedulers like `ReduceLROnPlateau`
             "monitor": self.monitor,
         }
+
+
+class WarmupHoldAnnealScheduler(LRSchedulerModule):
+    """Warmup Hold Annealing Learning Rate Scheduler.
+
+    A learning rate scheduler with warmup, hold, and annealing phases; the annealing
+    decay schedule can be either linear or one-minus-square-root.
+ """ + + def __init__( + self, + warmup_ratio: Optional[float] = None, + hold_ratio: Optional[float] = None, + max_steps: int = 10, + decay_schedule: Literal["one_minus_square_root", "linear"] = "linear", + min_lr: float = 0.0, + interval: str = "step", + frequency: int = 1, + monitor: str = "val_loss", + ): + super().__init__() + self.warmup_ratio = warmup_ratio + self.hold_ratio = hold_ratio + self.max_steps = max_steps + self.decay_schedule = decay_schedule + self.min_lr = min_lr + self.interval = interval + self.frequency = frequency + self.monitor = monitor + + def scheduler(self, model, optimizer): + if self.decay_schedule == "one_minus_square_root": + lr_scheduler = WarmupHoldAnnealOneMinusSquareRoot( + optimizer, + warmup_ratio=self.warmup_ratio, + hold_ratio=self.hold_ratio, + max_steps=self.max_steps, + min_lr=self.min_lr, + ) + elif self.decay_schedule == "linear": + lr_scheduler = WarmupHoldAnnealLinear( + optimizer, + warmup_ratio=self.warmup_ratio, + hold_ratio=self.hold_ratio, + max_steps=self.max_steps, + min_lr=self.min_lr, + ) + else: + raise ValueError(f"Unknown decay schedule: {self.decay_schedule}") + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": self.interval, + "frequency": self.frequency, + }, + "monitor": self.monitor, + } diff --git a/tests/core/test_optimizers_schedulers.py b/tests/core/test_optimizers_schedulers.py index 419db309a918..698993fb3ab1 100644 --- a/tests/core/test_optimizers_schedulers.py +++ b/tests/core/test_optimizers_schedulers.py @@ -883,6 +883,150 @@ def test_InverseSquareRootAnnealing(self): assert final_lr == self.MIN_LR + +class TestWarmupHoldAnnealSchedulers: + INITIAL_LR = 0.1 + MIN_LR = 0.01 + MAX_STEPS = 100 + + @pytest.mark.unit + def test_WarmupHoldAnnealOneMinusSquareRoot(self): + model = TempModel() + opt_cls = optim.get_optimizer('novograd') + opt = opt_cls(model.parameters(), lr=self.INITIAL_LR) + + # Test case 1: No warmup, no hold + policy = optim.lr_scheduler.WarmupHoldAnnealOneMinusSquareRoot( + opt, warmup_ratio=None, hold_ratio=None, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR + ) + initial_lr = policy.get_last_lr()[0] + assert initial_lr == self.INITIAL_LR + + # Simulate training steps + lrs = [] + for i in range(self.MAX_STEPS): + current_lr = policy.get_last_lr()[0] + lrs.append(current_lr) + assert current_lr <= self.INITIAL_LR + opt.step() + policy.step() + + # Check final LR + policy.step() + final_lr = policy.get_last_lr()[0] + assert final_lr == self.MIN_LR + + # Test case 2: With warmup and hold + warmup_ratio = 0.1 # 10% warmup + hold_ratio = 0.2 # 20% hold + warmup_steps = int(warmup_ratio * self.MAX_STEPS) + hold_steps = int(hold_ratio * self.MAX_STEPS) + + policy = optim.lr_scheduler.WarmupHoldAnnealOneMinusSquareRoot( + opt, warmup_ratio=warmup_ratio, hold_ratio=hold_ratio, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR + ) + + initial_lr = policy.get_last_lr()[0] + assert initial_lr < self.INITIAL_LR # Should start at a lower LR + + # Simulate training steps + lrs = [] + for i in range(self.MAX_STEPS): + current_lr = policy.get_last_lr()[0] + lrs.append(current_lr) + + # During warmup, LR should increase + if i < warmup_steps: + if i > 0: + assert current_lr >= lrs[i - 1] + assert current_lr <= self.INITIAL_LR + + # During hold, LR should remain constant + elif i < warmup_steps + hold_steps: + assert abs(current_lr - self.INITIAL_LR) < 1e-6 + + # During annealing, LR should decrease + else: + if i > warmup_steps + hold_steps: + assert current_lr <= lrs[i - 1] 
+
+            opt.step()
+            policy.step()
+
+        # Check final LR
+        policy.step()
+        final_lr = policy.get_last_lr()[0]
+        assert final_lr == self.MIN_LR
+
+    @pytest.mark.unit
+    def test_WarmupHoldAnnealLinear(self):
+        model = TempModel()
+        opt_cls = optim.get_optimizer('novograd')
+        opt = opt_cls(model.parameters(), lr=self.INITIAL_LR)
+
+        # Test case 1: No warmup, no hold
+        policy = optim.lr_scheduler.WarmupHoldAnnealLinear(
+            opt, warmup_ratio=None, hold_ratio=None, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR
+        )
+        initial_lr = policy.get_last_lr()[0]
+        assert initial_lr == self.INITIAL_LR
+
+        # Simulate training steps
+        lrs = []
+        for i in range(self.MAX_STEPS):
+            current_lr = policy.get_last_lr()[0]
+            lrs.append(current_lr)
+            assert current_lr <= self.INITIAL_LR
+            opt.step()
+            policy.step()
+
+        # Check final LR
+        policy.step()
+        final_lr = policy.get_last_lr()[0]
+        assert final_lr == self.MIN_LR
+
+        # Test case 2: With warmup and hold
+        warmup_ratio = 0.1  # 10% warmup
+        hold_ratio = 0.2  # 20% hold
+        warmup_steps = int(warmup_ratio * self.MAX_STEPS)
+        hold_steps = int(hold_ratio * self.MAX_STEPS)
+
+        policy = optim.lr_scheduler.WarmupHoldAnnealLinear(
+            opt, warmup_ratio=warmup_ratio, hold_ratio=hold_ratio, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR
+        )
+
+        initial_lr = policy.get_last_lr()[0]
+        assert initial_lr < self.INITIAL_LR  # Should start at a lower LR
+
+        # Simulate training steps
+        lrs = []
+        for i in range(self.MAX_STEPS):
+            current_lr = policy.get_last_lr()[0]
+            lrs.append(current_lr)
+
+            # During warmup, LR should increase
+            if i < warmup_steps:
+                if i > 0:
+                    assert current_lr >= lrs[i - 1]
+                assert current_lr <= self.INITIAL_LR
+
+            # During hold, LR should remain constant
+            elif i < warmup_steps + hold_steps:
+                assert abs(current_lr - self.INITIAL_LR) < 1e-6
+
+            # During annealing, LR should decrease
+            else:
+                if i > warmup_steps + hold_steps:
+                    assert current_lr <= lrs[i - 1]
+
+            opt.step()
+            policy.step()
+
+        # Check final LR
+        policy.step()
+        final_lr = policy.get_last_lr()[0]
+        assert final_lr == self.MIN_LR
+
     @pytest.mark.unit
     def test_CosineAnnealing_with_noop_steps(self):
         model = TempModel()
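
For reference, a minimal usage sketch of the new core-level scheduler (not part of the diff). It assumes a NeMo environment that already contains this change plus PyTorch, and mirrors the constructor arguments exercised by the unit tests above; the toy model, optimizer, and step counts are illustrative only.

import torch

from nemo.core.optim.lr_scheduler import WarmupHoldAnnealOneMinusSquareRoot

# Any torch optimizer works: the scheduler is a torch LR scheduler via WarmupHoldPolicy.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

scheduler = WarmupHoldAnnealOneMinusSquareRoot(
    optimizer,
    warmup_ratio=0.1,  # first 10% of max_steps: linear ramp up to base_lr
    hold_ratio=0.2,    # next 20% of max_steps: hold at base_lr
    max_steps=100,
    min_lr=0.01,       # floor enforced during the annealing phase
)

for step in range(100):
    optimizer.step()   # stands in for a real forward/backward pass
    scheduler.step()

print(scheduler.get_last_lr())  # annealed from 0.1 down to the 0.01 floor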
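
And a sketch of how the new Lightning-facing wrapper resolves to one of the two core schedulers. The import path follows the __init__.py change above; calling scheduler() directly (rather than through a trainer or recipe) is only for illustration, and passing model=None relies on the fact that the implementation shown in this diff does not use the model argument.

import torch

from nemo.lightning.pytorch.optim import WarmupHoldAnnealScheduler

optimizer = torch.optim.SGD(torch.nn.Linear(4, 4).parameters(), lr=0.1)

sched_module = WarmupHoldAnnealScheduler(
    warmup_ratio=0.01,
    hold_ratio=0.40,
    max_steps=10_000,
    decay_schedule="one_minus_square_root",
    min_lr=1e-5,
)

# Returns the Lightning-style config dict built by scheduler() above.
cfg = sched_module.scheduler(model=None, optimizer=optimizer)
print(type(cfg["lr_scheduler"]["scheduler"]).__name__)  # WarmupHoldAnnealOneMinusSquareRoot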