Skip to content

Commit ec5a4d2

Browse files
committed
add pytest for wsd scheduler
1 parent 78700d7 commit ec5a4d2

File tree

5 files changed

+166
-1
lines changed

5 files changed

+166
-1
lines changed

nemo/core/optim/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
CosineAnnealing,
1919
InverseSquareRootAnnealing,
2020
NoamAnnealing,
21+
WarmupHoldAnnealOneMinusSquareRoot,
22+
WarmupHoldAnnealLinear,
2123
PolynomialDecayAnnealing,
2224
PolynomialHoldDecayAnnealing,
2325
SquareAnnealing,

nemo/core/optim/lr_scheduler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,6 +1020,8 @@ def compute_max_steps(
10201020
'CosineAnnealing': CosineAnnealing,
10211021
'NoamAnnealing': NoamAnnealing,
10221022
'NoamHoldAnnealing': NoamHoldAnnealing,
1023+
'WarmupHoldAnnealOneMinusSquareRoot': WarmupHoldAnnealOneMinusSquareRoot,
1024+
'WarmupHoldAnnealLinear': WarmupHoldAnnealLinear,
10231025
'WarmupAnnealing': WarmupAnnealing,
10241026
'InverseSquareRootAnnealing': InverseSquareRootAnnealing,
10251027
'T5InverseSquareRootAnnealing': T5InverseSquareRootAnnealing,

nemo/lightning/pytorch/optim/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
WarmupAnnealingScheduler,
2727
WarmupHoldPolicyScheduler,
2828
WarmupPolicyScheduler,
29+
WarmupHoldAnnealScheduler,
2930
)
3031
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
3132
from nemo.lightning.pytorch.optim.pytorch import PytorchOptimizerModule
@@ -47,4 +48,5 @@
4748
"PolynomialHoldDecayAnnealingScheduler",
4849
"CosineAnnealingScheduler",
4950
"PytorchOptimizerModule",
51+
"WarmupHoldAnnealScheduler",
5052
]

nemo/lightning/pytorch/optim/lr_scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ def scheduler(self, model, optimizer):
486486
"monitor": self.monitor,
487487
}
488488

489-
class WarmupHoldAnneal(LRSchedulerModule):
489+
class WarmupHoldAnnealScheduler(LRSchedulerModule):
490490
def __init__(
491491
self,
492492
warmup_ratio: Optional[float] = None,

tests/core/test_optimizers_schedulers.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,165 @@ def test_InverseSquareRootAnnealing(self):
883883

884884
assert final_lr == self.MIN_LR
885885

886+
class TestWarmupHoldAnnealSchedulers:
887+
INITIAL_LR = 0.1
888+
MIN_LR = 0.01
889+
MAX_STEPS = 100
890+
891+
@pytest.mark.unit
892+
def test_WarmupHoldAnnealOneMinusSquareRoot(self):
893+
model = TempModel()
894+
opt_cls = optim.get_optimizer('novograd')
895+
opt = opt_cls(model.parameters(), lr=self.INITIAL_LR)
896+
897+
# Test case 1: No warmup, no hold
898+
policy = optim.lr_scheduler.WarmupHoldAnnealOneMinusSquareRoot(
899+
opt,
900+
warmup_ratio=None,
901+
hold_ratio=None,
902+
max_steps=self.MAX_STEPS,
903+
min_lr=self.MIN_LR
904+
)
905+
initial_lr = policy.get_last_lr()[0]
906+
assert initial_lr == self.INITIAL_LR
907+
908+
# Simulate training steps
909+
lrs = []
910+
for i in range(self.MAX_STEPS):
911+
current_lr = policy.get_last_lr()[0]
912+
lrs.append(current_lr)
913+
assert current_lr <= self.INITIAL_LR
914+
opt.step()
915+
policy.step()
916+
917+
# Check final LR
918+
policy.step()
919+
final_lr = policy.get_last_lr()[0]
920+
assert final_lr == self.MIN_LR
921+
922+
# Test case 2: With warmup and hold
923+
warmup_ratio = 0.1 # 10% warmup
924+
hold_ratio = 0.2 # 20% hold
925+
warmup_steps = int(warmup_ratio * self.MAX_STEPS)
926+
hold_steps = int(hold_ratio * self.MAX_STEPS)
927+
928+
policy = optim.lr_scheduler.WarmupHoldAnnealOneMinusSquareRoot(
929+
opt,
930+
warmup_ratio=warmup_ratio,
931+
hold_ratio=hold_ratio,
932+
max_steps=self.MAX_STEPS,
933+
min_lr=self.MIN_LR
934+
)
935+
936+
initial_lr = policy.get_last_lr()[0]
937+
assert initial_lr < self.INITIAL_LR # Should start at a lower LR
938+
939+
# Simulate training steps
940+
lrs = []
941+
for i in range(self.MAX_STEPS):
942+
current_lr = policy.get_last_lr()[0]
943+
lrs.append(current_lr)
944+
945+
# During warmup, LR should increase
946+
if i < warmup_steps:
947+
if i > 0:
948+
assert current_lr >= lrs[i-1]
949+
assert current_lr <= self.INITIAL_LR
950+
951+
# During hold, LR should remain constant
952+
elif i < warmup_steps + hold_steps:
953+
assert abs(current_lr - self.INITIAL_LR) < 1e-6
954+
955+
# During annealing, LR should decrease
956+
else:
957+
if i > warmup_steps + hold_steps:
958+
assert current_lr <= lrs[i-1]
959+
960+
opt.step()
961+
policy.step()
962+
963+
# Check final LR
964+
policy.step()
965+
final_lr = policy.get_last_lr()[0]
966+
assert final_lr == self.MIN_LR
967+
968+
@pytest.mark.unit
969+
def test_WarmupHoldAnnealLinear(self):
970+
model = TempModel()
971+
opt_cls = optim.get_optimizer('novograd')
972+
opt = opt_cls(model.parameters(), lr=self.INITIAL_LR)
973+
974+
# Test case 1: No warmup, no hold
975+
policy = optim.lr_scheduler.WarmupHoldAnnealLinear(
976+
opt,
977+
warmup_ratio=None,
978+
hold_ratio=None,
979+
max_steps=self.MAX_STEPS,
980+
min_lr=self.MIN_LR
981+
)
982+
initial_lr = policy.get_last_lr()[0]
983+
assert initial_lr == self.INITIAL_LR
984+
985+
# Simulate training steps
986+
lrs = []
987+
for i in range(self.MAX_STEPS):
988+
current_lr = policy.get_last_lr()[0]
989+
lrs.append(current_lr)
990+
assert current_lr <= self.INITIAL_LR
991+
opt.step()
992+
policy.step()
993+
994+
# Check final LR
995+
policy.step()
996+
final_lr = policy.get_last_lr()[0]
997+
assert final_lr == self.MIN_LR
998+
999+
# Test case 2: With warmup and hold
1000+
warmup_ratio = 0.1 # 10% warmup
1001+
hold_ratio = 0.2 # 20% hold
1002+
warmup_steps = int(warmup_ratio * self.MAX_STEPS)
1003+
hold_steps = int(hold_ratio * self.MAX_STEPS)
1004+
1005+
policy = optim.lr_scheduler.WarmupHoldAnnealLinear(
1006+
opt,
1007+
warmup_ratio=warmup_ratio,
1008+
hold_ratio=hold_ratio,
1009+
max_steps=self.MAX_STEPS,
1010+
min_lr=self.MIN_LR
1011+
)
1012+
1013+
initial_lr = policy.get_last_lr()[0]
1014+
assert initial_lr < self.INITIAL_LR # Should start at a lower LR
1015+
1016+
# Simulate training steps
1017+
lrs = []
1018+
for i in range(self.MAX_STEPS):
1019+
current_lr = policy.get_last_lr()[0]
1020+
lrs.append(current_lr)
1021+
1022+
# During warmup, LR should increase
1023+
if i < warmup_steps:
1024+
if i > 0:
1025+
assert current_lr >= lrs[i-1]
1026+
assert current_lr <= self.INITIAL_LR
1027+
1028+
# During hold, LR should remain constant
1029+
elif i < warmup_steps + hold_steps:
1030+
assert abs(current_lr - self.INITIAL_LR) < 1e-6
1031+
1032+
# During annealing, LR should decrease
1033+
else:
1034+
if i > warmup_steps + hold_steps:
1035+
assert current_lr <= lrs[i-1]
1036+
1037+
opt.step()
1038+
policy.step()
1039+
1040+
# Check final LR
1041+
policy.step()
1042+
final_lr = policy.get_last_lr()[0]
1043+
assert final_lr == self.MIN_LR
1044+
8861045
@pytest.mark.unit
8871046
def test_CosineAnnealing_with_noop_steps(self):
8881047
model = TempModel()

0 commit comments

Comments (0)