Open · wants to merge 13 commits into base: main
Changes from all commits
2 changes: 2 additions & 0 deletions nemo/core/optim/__init__.py
@@ -24,6 +24,8 @@
SquareRootAnnealing,
T5InverseSquareRootAnnealing,
WarmupAnnealing,
WarmupHoldAnnealLinear,
WarmupHoldAnnealOneMinusSquareRoot,
WarmupHoldPolicy,
WarmupPolicy,
prepare_lr_scheduler,
45 changes: 45 additions & 0 deletions nemo/core/optim/lr_scheduler.py
@@ -225,6 +225,49 @@ def get_lr(self):
return self._get_lr(step)


class WarmupHoldAnnealOneMinusSquareRoot(WarmupHoldPolicy):
"""Learning rate scheduler with warmup, hold, and one-minus-square-root annealing phases.

This scheduler follows a three-phase pattern:
1. Warmup phase: LR increases linearly from 0 to base_lr
2. Hold phase: LR remains constant at base_lr
3. Annealing phase: LR decreases following a one-minus-square-root curve from base_lr to min_lr

The annealing follows the formula: LR = base_lr * (1 - sqrt((step - hold_steps)/(max_steps - hold_steps)))
The min_lr is enforced as a floor during the annealing phase, i.e. the learning rate will not decay below min_lr.

Reference: https://arxiv.org/html/2408.11029
"""

def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs):
super().__init__(optimizer=optimizer, max_steps=max_steps, **kwargs, last_epoch=last_epoch, min_lr=min_lr)

def _get_lr(self, step):
mult = 1 - ((step - self.hold_steps) / (self.max_steps - self.hold_steps)) ** 0.5 # from 1 to 0
out_lr = [max(self.min_lr, initial_lr * mult) for initial_lr in self.base_lrs]
return out_lr


class WarmupHoldAnnealLinear(WarmupHoldPolicy):
"""Learning rate scheduler with warmup, hold, and linear annealing phases.

This scheduler follows a three-phase pattern:
1. Warmup phase: LR increases linearly from 0 to base_lr
2. Hold phase: LR remains constant at base_lr
3. Annealing phase: LR decreases linearly from base_lr to min_lr

Reference: https://arxiv.org/pdf/2404.06395
"""

def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs):
super().__init__(optimizer=optimizer, max_steps=max_steps, **kwargs, last_epoch=last_epoch, min_lr=min_lr)

def _get_lr(self, step):
ratio = (step - self.hold_steps) / (self.max_steps - self.hold_steps) # from 0 to 1
out_lr = [initial_lr - (initial_lr - self.min_lr) * ratio for initial_lr in self.base_lrs]
return out_lr


class WarmupAnnealHoldPolicy(_LRScheduler):
"""Adds warmup kwargs and warmup logic to lr policy.
All arguments should be passed as kwargs for clarity,
@@ -1002,6 +1045,8 @@ def compute_max_steps(
'CosineAnnealing': CosineAnnealing,
'NoamAnnealing': NoamAnnealing,
'NoamHoldAnnealing': NoamHoldAnnealing,
'WarmupHoldAnnealOneMinusSquareRoot': WarmupHoldAnnealOneMinusSquareRoot,
'WarmupHoldAnnealLinear': WarmupHoldAnnealLinear,
'WarmupAnnealing': WarmupAnnealing,
'InverseSquareRootAnnealing': InverseSquareRootAnnealing,
'T5InverseSquareRootAnnealing': T5InverseSquareRootAnnealing,
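
To make the shape of the new schedules concrete, here is a small standalone sketch (not part of the diff) that mirrors the formulas in the docstrings above. All names and the explicit phase handling are illustrative: in the PR, the warmup and hold phases are handled by the WarmupHoldPolicy base class, and _get_lr only implements the annealing phase.

def wsd_lr(step: int, base_lr: float, min_lr: float,
           warmup_steps: int, anneal_start: int, max_steps: int,
           decay: str = "one_minus_square_root") -> float:
    """LR at `step`; `anneal_start` is warmup_steps + hold_steps."""
    if step < warmup_steps:                        # 1) linear warmup: 0 -> base_lr
        return base_lr * step / max(1, warmup_steps)
    if step < anneal_start:                        # 2) hold at base_lr
        return base_lr
    frac = (step - anneal_start) / (max_steps - anneal_start)   # 3) anneal: 0 -> 1
    if decay == "one_minus_square_root":
        return max(min_lr, base_lr * (1 - frac ** 0.5))
    return base_lr - (base_lr - min_lr) * frac     # linear decay to min_lr

# e.g. 10% warmup, 20% hold, 100 steps total (mirrors the unit tests below)
lrs = [wsd_lr(s, 0.1, 0.01, warmup_steps=10, anneal_start=30, max_steps=100)
       for s in range(101)]
assert abs(lrs[100] - 0.01) < 1e-9   # both variants end at min_lr
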
2 changes: 2 additions & 0 deletions nemo/lightning/pytorch/optim/__init__.py
@@ -24,6 +24,7 @@
SquareRootAnnealingScheduler,
T5InverseSquareRootAnnealingScheduler,
WarmupAnnealingScheduler,
WarmupHoldAnnealScheduler,
WarmupHoldPolicyScheduler,
WarmupPolicyScheduler,
)
@@ -47,4 +48,5 @@
"PolynomialHoldDecayAnnealingScheduler",
"CosineAnnealingScheduler",
"PytorchOptimizerModule",
"WarmupHoldAnnealScheduler",
]
56 changes: 55 additions & 1 deletion nemo/lightning/pytorch/optim/lr_scheduler.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import Literal, Optional

from nemo.core.optim.lr_scheduler import (
InverseSquareRootAnnealing,
@@ -24,6 +24,8 @@
SquareRootAnnealing,
T5InverseSquareRootAnnealing,
WarmupAnnealing,
WarmupHoldAnnealLinear,
WarmupHoldAnnealOneMinusSquareRoot,
WarmupHoldPolicy,
WarmupPolicy,
)
@@ -484,3 +486,55 @@ def scheduler(self, model, optimizer):
# Metric to monitor for schedulers like `ReduceLROnPlateau`
"monitor": self.monitor,
}


class WarmupHoldAnnealScheduler(LRSchedulerModule):

Review comment (Collaborator): Doc string missing here.

Review comment (Collaborator): @hawkoli1987 are you able to add a docstring here to address this comment?

def __init__(
self,
warmup_ratio: Optional[float] = None,
hold_ratio: Optional[float] = None,
max_steps: int = 10,
decay_schedule: Literal["one_minus_square_root", "linear"] = "linear",
min_lr: float = 0.0,
interval: str = "step",
frequency: int = 1,
monitor: str = "val_loss",
):
super().__init__()
self.warmup_ratio = warmup_ratio
self.hold_ratio = hold_ratio
self.max_steps = max_steps
self.decay_schedule = decay_schedule
self.min_lr = min_lr
self.interval = interval
self.frequency = frequency
self.monitor = monitor

def scheduler(self, model, optimizer):
if self.decay_schedule == "one_minus_square_root":
lr_scheduler = WarmupHoldAnnealOneMinusSquareRoot(
optimizer,
warmup_ratio=self.warmup_ratio,
hold_ratio=self.hold_ratio,
max_steps=self.max_steps,
min_lr=self.min_lr,
)
elif self.decay_schedule == "linear":
lr_scheduler = WarmupHoldAnnealLinear(
optimizer,
warmup_ratio=self.warmup_ratio,
hold_ratio=self.hold_ratio,
max_steps=self.max_steps,
min_lr=self.min_lr,
)
else:
raise ValueError(f"Unknown decay schedule: {self.decay_schedule}")
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": lr_scheduler,
"interval": self.interval,
"frequency": self.frequency,
},
"monitor": self.monitor,
}
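
For reference, a minimal usage sketch of the new Lightning module (not part of the diff), assuming only the constructor signature shown above; the ratios and step count are placeholder values.

from nemo.lightning.pytorch.optim import WarmupHoldAnnealScheduler

# WSD-style schedule: short warmup, long hold, decay over the remaining steps.
lr_sched = WarmupHoldAnnealScheduler(
    warmup_ratio=0.01,                       # 1% of max_steps for linear warmup
    hold_ratio=0.80,                         # hold base LR for 80% of training
    max_steps=100_000,
    decay_schedule="one_minus_square_root",  # or "linear"
    min_lr=0.0,
)
# The module is then passed wherever an LRSchedulerModule is accepted, e.g. as
# the lr_scheduler of an optimizer module in a NeMo recipe.
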
144 changes: 144 additions & 0 deletions tests/core/test_optimizers_schedulers.py
@@ -883,6 +883,150 @@ def test_InverseSquareRootAnnealing(self):

assert final_lr == self.MIN_LR


class TestWarmupHoldAnnealSchedulers:
INITIAL_LR = 0.1
MIN_LR = 0.01
MAX_STEPS = 100

@pytest.mark.unit
def test_WarmupHoldAnnealOneMinusSquareRoot(self):
model = TempModel()
opt_cls = optim.get_optimizer('novograd')
opt = opt_cls(model.parameters(), lr=self.INITIAL_LR)

# Test case 1: No warmup, no hold
policy = optim.lr_scheduler.WarmupHoldAnnealOneMinusSquareRoot(
opt, warmup_ratio=None, hold_ratio=None, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR
)
initial_lr = policy.get_last_lr()[0]
assert initial_lr == self.INITIAL_LR

# Simulate training steps
lrs = []
for i in range(self.MAX_STEPS):
current_lr = policy.get_last_lr()[0]
lrs.append(current_lr)
assert current_lr <= self.INITIAL_LR
opt.step()
policy.step()

# Check final LR
policy.step()
final_lr = policy.get_last_lr()[0]
assert final_lr == self.MIN_LR

# Test case 2: With warmup and hold
warmup_ratio = 0.1 # 10% warmup
hold_ratio = 0.2 # 20% hold
warmup_steps = int(warmup_ratio * self.MAX_STEPS)
hold_steps = int(hold_ratio * self.MAX_STEPS)

policy = optim.lr_scheduler.WarmupHoldAnnealOneMinusSquareRoot(
opt, warmup_ratio=warmup_ratio, hold_ratio=hold_ratio, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR
)

initial_lr = policy.get_last_lr()[0]
assert initial_lr < self.INITIAL_LR # Should start at a lower LR

# Simulate training steps
lrs = []
for i in range(self.MAX_STEPS):
current_lr = policy.get_last_lr()[0]
lrs.append(current_lr)

# During warmup, LR should increase
if i < warmup_steps:
if i > 0:
assert current_lr >= lrs[i - 1]
assert current_lr <= self.INITIAL_LR

# During hold, LR should remain constant
elif i < warmup_steps + hold_steps:
assert abs(current_lr - self.INITIAL_LR) < 1e-6

# During annealing, LR should decrease
else:
if i > warmup_steps + hold_steps:
assert current_lr <= lrs[i - 1]

opt.step()
policy.step()

# Check final LR
policy.step()
final_lr = policy.get_last_lr()[0]
assert final_lr == self.MIN_LR

@pytest.mark.unit
def test_WarmupHoldAnnealLinear(self):
model = TempModel()
opt_cls = optim.get_optimizer('novograd')
opt = opt_cls(model.parameters(), lr=self.INITIAL_LR)

# Test case 1: No warmup, no hold
policy = optim.lr_scheduler.WarmupHoldAnnealLinear(
opt, warmup_ratio=None, hold_ratio=None, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR
)
initial_lr = policy.get_last_lr()[0]
assert initial_lr == self.INITIAL_LR

# Simulate training steps
lrs = []
for i in range(self.MAX_STEPS):
current_lr = policy.get_last_lr()[0]
lrs.append(current_lr)
assert current_lr <= self.INITIAL_LR
opt.step()
policy.step()

# Check final LR
policy.step()
final_lr = policy.get_last_lr()[0]
assert final_lr == self.MIN_LR

# Test case 2: With warmup and hold
warmup_ratio = 0.1 # 10% warmup
hold_ratio = 0.2 # 20% hold
warmup_steps = int(warmup_ratio * self.MAX_STEPS)
hold_steps = int(hold_ratio * self.MAX_STEPS)

policy = optim.lr_scheduler.WarmupHoldAnnealLinear(
opt, warmup_ratio=warmup_ratio, hold_ratio=hold_ratio, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR
)

initial_lr = policy.get_last_lr()[0]
assert initial_lr < self.INITIAL_LR # Should start at a lower LR

# Simulate training steps
lrs = []
for i in range(self.MAX_STEPS):
current_lr = policy.get_last_lr()[0]
lrs.append(current_lr)

# During warmup, LR should increase
if i < warmup_steps:
if i > 0:
assert current_lr >= lrs[i - 1]
assert current_lr <= self.INITIAL_LR

# During hold, LR should remain constant
elif i < warmup_steps + hold_steps:
assert abs(current_lr - self.INITIAL_LR) < 1e-6

# During annealing, LR should decrease
else:
if i > warmup_steps + hold_steps:
assert current_lr <= lrs[i - 1]

opt.step()
policy.step()

# Check final LR
policy.step()
final_lr = policy.get_last_lr()[0]
assert final_lr == self.MIN_LR

@pytest.mark.unit
def test_CosineAnnealing_with_noop_steps(self):
model = TempModel()