
Commit fd69b9e

Author: Flax Authors
Merge pull request #2702 from IvyZX:lrdep
PiperOrigin-RevId: 494689571
2 parents (12f2b27 + 5f8d876), commit fd69b9e

File tree: 1 file changed, +40 -0 lines

flax/training/lr_schedule.py

Lines changed: 40 additions & 0 deletions
@@ -23,6 +23,7 @@
 .. _Optimizer Schedules: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules
 """
 
+from absl import logging
 import jax.numpy as jnp
 import numpy as np
 
@@ -36,6 +37,14 @@ def create_constant_learning_rate_schedule(base_learning_rate, steps_per_epoch,
                                            warmup_length=0.0):
   """Create a constant learning rate schedule with optional warmup.
 
+  Note that with `FLIP #1009`_ learning rate schedules in ``flax.training`` are
+  **effectively deprecated** in favor of Optax_ schedules. Please refer to
+  `Optimizer Schedules`_ for more information.
+
+  .. _FLIP #1009: https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md
+  .. _Optax: https://github.com/deepmind/optax
+  .. _Optimizer Schedules: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules
+
   Holds the learning rate constant. This function also offers a learning rate
   warmup as per https://arxiv.org/abs/1706.02677, for the purpose of training
   with large mini-batches.
@@ -50,6 +59,11 @@ def create_constant_learning_rate_schedule(base_learning_rate, steps_per_epoch,
   Returns:
     Function `f(step) -> lr` that computes the learning rate for a given step.
   """
+  logging.warning(
+      'Learning rate schedules in ``flax.training`` are effectively deprecated '
+      'in favor of Optax schedules. Please refer to '
+      'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules'
+      ' for alternatives.')
   def learning_rate_fn(step):
     lr = base_learning_rate
     if warmup_length > 0.0:
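
To migrate, the constant-rate-with-warmup behaviour above can be reproduced with Optax built-ins. The following sketch is illustrative only and not part of this commit; base_learning_rate, steps_per_epoch and warmup_length are placeholder values that mirror the function's arguments.

import optax

base_learning_rate = 0.1  # placeholder values mirroring the arguments above
steps_per_epoch = 100
warmup_length = 2.0       # epochs of linear warmup

warmup_steps = int(warmup_length * steps_per_epoch)
schedule = optax.join_schedules(
    schedules=[
        # Ramp linearly from 0 to the base rate over the warmup period.
        optax.linear_schedule(init_value=0.0,
                              end_value=base_learning_rate,
                              transition_steps=warmup_steps),
        # Then hold the base rate constant.
        optax.constant_schedule(base_learning_rate),
    ],
    boundaries=[warmup_steps])

# Like the deprecated helpers, the result is a callable f(step) -> lr.
print(schedule(0), schedule(warmup_steps + 10))
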
@@ -62,6 +76,14 @@ def create_stepped_learning_rate_schedule(base_learning_rate, steps_per_epoch,
                                           lr_sched_steps, warmup_length=0.0):
   """Create a stepped learning rate schedule with optional warmup.
 
+  Note that with `FLIP #1009`_ learning rate schedules in ``flax.training`` are
+  **effectively deprecated** in favor of Optax_ schedules. Please refer to
+  `Optimizer Schedules`_ for more information.
+
+  .. _FLIP #1009: https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md
+  .. _Optax: https://github.com/deepmind/optax
+  .. _Optimizer Schedules: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules
+
   A stepped learning rate schedule decreases the learning rate
   by specified amounts at specified epochs. The steps are given as
   the `lr_sched_steps` parameter. A common ImageNet schedule decays the
@@ -91,6 +113,11 @@ def create_stepped_learning_rate_schedule(base_learning_rate, steps_per_epoch,
   Returns:
     Function `f(step) -> lr` that computes the learning rate for a given step.
   """
+  logging.warning(
+      'Learning rate schedules in ``flax.training`` are effectively deprecated '
+      'in favor of Optax schedules. Please refer to '
+      'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules'
+      ' for alternatives.')
   boundaries = [step[0] for step in lr_sched_steps]
   decays = [step[1] for step in lr_sched_steps]
   boundaries = np.array(boundaries) * steps_per_epoch
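
The stepped schedule maps onto optax.piecewise_constant_schedule. One caveat, assuming the Flax-style lr_sched_steps gives absolute factors of the base rate (as in the ImageNet example above): Optax instead applies a multiplicative scale at each boundary. A rough, non-authoritative sketch with placeholder values, not part of this commit:

import optax

base_learning_rate = 0.1  # placeholder values
steps_per_epoch = 100

# A Flax-style lr_sched_steps of [[30, 0.1], [60, 0.01], [80, 0.001]]
# corresponds to multiplying by 0.1 at epochs 30, 60 and 80 in Optax:
schedule = optax.piecewise_constant_schedule(
    init_value=base_learning_rate,
    boundaries_and_scales={
        30 * steps_per_epoch: 0.1,
        60 * steps_per_epoch: 0.1,
        80 * steps_per_epoch: 0.1,
    })
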
@@ -109,6 +136,14 @@ def create_cosine_learning_rate_schedule(base_learning_rate, steps_per_epoch,
                                          halfcos_epochs, warmup_length=0.0):
   """Create a cosine learning rate schedule with optional warmup.
 
+  Note that with `FLIP #1009`_ learning rate schedules in ``flax.training`` are
+  **effectively deprecated** in favor of Optax_ schedules. Please refer to
+  `Optimizer Schedules`_ for more information.
+
+  .. _FLIP #1009: https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md
+  .. _Optax: https://github.com/deepmind/optax
+  .. _Optimizer Schedules: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules
+
   A cosine learning rate schedule modulates the learning rate with
   half a cosine wave, gradually scaling it to 0 at the end of training.
 
@@ -128,6 +163,11 @@ def create_cosine_learning_rate_schedule(base_learning_rate, steps_per_epoch,
   Returns:
     Function `f(step) -> lr` that computes the learning rate for a given step.
   """
+  logging.warning(
+      'Learning rate schedules in ``flax.training`` are effectively deprecated '
+      'in favor of Optax schedules. Please refer to '
+      'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules'
+      ' for alternatives.')
   halfwavelength_steps = halfcos_epochs * steps_per_epoch
 
   def learning_rate_fn(step):
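
The cosine schedule with warmup has a close counterpart in optax.warmup_cosine_decay_schedule. A minimal sketch with placeholder values, not part of this commit; note that Optax's decay_steps counts the total schedule length including warmup, so the mapping below is approximate rather than a drop-in replacement.

import optax

base_learning_rate = 0.1  # placeholder values mirroring the arguments above
steps_per_epoch = 100
halfcos_epochs = 90
warmup_length = 5.0

warmup_steps = int(warmup_length * steps_per_epoch)
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=base_learning_rate,
    warmup_steps=warmup_steps,
    # Total length of the schedule, including the warmup phase.
    decay_steps=warmup_steps + int(halfcos_epochs * steps_per_epoch),
    end_value=0.0)

# Any Optax schedule plugs directly into an Optax optimizer:
tx = optax.adam(learning_rate=schedule)
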

0 commit comments