Add deprecation warnings and docstring pointers to flax.training.lr_schedule #2702

Merged · 1 commit · Dec 12, 2022

40 changes: 40 additions & 0 deletions flax/training/lr_schedule.py
@@ -23,6 +23,7 @@
.. _Optimizer Schedules: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules
"""

from absl import logging
import jax.numpy as jnp
import numpy as np

@@ -36,6 +37,14 @@ def create_constant_learning_rate_schedule(base_learning_rate, steps_per_epoch,
warmup_length=0.0):
"""Create a constant learning rate schedule with optional warmup.

Note that with `FLIP #1009`_ learning rate schedules in ``flax.training`` are
**effectively deprecated** in favor of Optax_ schedules. Please refer to
`Optimizer Schedules`_ for more information.

.. _FLIP #1009: https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md
.. _Optax: https://github.com/deepmind/optax
.. _Optimizer Schedules: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules

Holds the learning rate constant. This function also offers a learning rate
warmup as per https://arxiv.org/abs/1706.02677, for the purpose of training
with large mini-batches.
@@ -50,6 +59,11 @@ def create_constant_learning_rate_schedule(base_learning_rate, steps_per_epoch,
Returns:
Function `f(step) -> lr` that computes the learning rate for a given step.
"""
logging.warning(
'Learning rate schedules in ``flax.training`` are effectively deprecated '
'in favor of Optax schedules. Please refer to '
'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules'
' for alternatives.')
def learning_rate_fn(step):
lr = base_learning_rate
if warmup_length > 0.0:
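
For readers migrating off this module, a minimal sketch of an Optax replacement for a constant schedule with linear warmup might look like the following; the hyperparameter values are illustrative placeholders, not taken from this PR, and the warmup shape is not guaranteed to match this module's exactly:

import optax

# Placeholder hyperparameters mirroring the arguments of
# create_constant_learning_rate_schedule (not values from this PR).
base_learning_rate = 0.1
steps_per_epoch = 100
warmup_length = 5.0  # epochs of warmup

warmup_steps = int(warmup_length * steps_per_epoch)

# Linear ramp from 0 to the base rate, then hold it constant.
schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(
            init_value=0.0,
            end_value=base_learning_rate,
            transition_steps=warmup_steps),
        optax.constant_schedule(base_learning_rate),
    ],
    boundaries=[warmup_steps])

# The schedule is a plain `step -> lr` callable and can be passed
# directly to an Optax optimizer.
tx = optax.sgd(learning_rate=schedule)
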
@@ -62,6 +76,14 @@ def create_stepped_learning_rate_schedule(base_learning_rate, steps_per_epoch,
lr_sched_steps, warmup_length=0.0):
"""Create a stepped learning rate schedule with optional warmup.

Note that with `FLIP #1009`_ learning rate schedules in ``flax.training`` are
**effectively deprecated** in favor of Optax_ schedules. Please refer to
`Optimizer Schedules`_ for more information.

.. _FLIP #1009: https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md
.. _Optax: https://github.com/deepmind/optax
.. _Optimizer Schedules: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules

A stepped learning rate schedule decreases the learning rate
by specified amounts at specified epochs. The steps are given as
the `lr_sched_steps` parameter. A common ImageNet schedule decays the
@@ -91,6 +113,11 @@ def create_stepped_learning_rate_schedule(base_learning_rate, steps_per_epoch,
Returns:
Function `f(step) -> lr` that computes the learning rate for a given step.
"""
logging.warning(
'Learning rate schedules in ``flax.training`` are effectively deprecated '
'in favor of Optax schedules. Please refer to '
'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules'
' for alternatives.')
boundaries = [step[0] for step in lr_sched_steps]
decays = [step[1] for step in lr_sched_steps]
boundaries = np.array(boundaries) * steps_per_epoch
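
As a rough guide to migration, a sketch of an Optax counterpart for the stepped schedule using optax.piecewise_constant_schedule is shown below; the epochs and decay factors are hypothetical placeholders, not values from this PR:

import optax

# Placeholder values in the [[epoch, decay_factor], ...] format that
# create_stepped_learning_rate_schedule expects (not values from this PR).
base_learning_rate = 0.1
steps_per_epoch = 100
lr_sched_steps = [[30, 0.1], [60, 0.01]]

# optax.piecewise_constant_schedule takes *multiplicative* scales keyed by
# step index, so convert the absolute decay factors into successive ratios.
boundaries_and_scales = {}
previous_decay = 1.0
for epoch, decay in lr_sched_steps:
    boundaries_and_scales[int(epoch * steps_per_epoch)] = decay / previous_decay
    previous_decay = decay

schedule = optax.piecewise_constant_schedule(
    init_value=base_learning_rate,
    boundaries_and_scales=boundaries_and_scales)

Warmup, if needed, can be prepended with optax.join_schedules as in the previous sketch.
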
@@ -109,6 +136,14 @@ def create_cosine_learning_rate_schedule(base_learning_rate, steps_per_epoch,
halfcos_epochs, warmup_length=0.0):
"""Create a cosine learning rate schedule with optional warmup.

Note that with `FLIP #1009`_ learning rate schedules in ``flax.training`` are
**effectively deprecated** in favor of Optax_ schedules. Please refer to
`Optimizer Schedules`_ for more information.

.. _FLIP #1009: https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md
.. _Optax: https://github.com/deepmind/optax
.. _Optimizer Schedules: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules

A cosine learning rate schedule modulates the learning rate with
half a cosine wave, gradually scaling it to 0 at the end of training.

@@ -128,6 +163,11 @@ def create_cosine_learning_rate_schedule(base_learning_rate, steps_per_epoch,
Returns:
Function `f(step) -> lr` that computes the learning rate for a given step.
"""
logging.warning(
'Learning rate schedules in ``flax.training`` are effectively deprecated '
'in favor of Optax schedules. Please refer to '
'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules'
' for alternatives.')
halfwavelength_steps = halfcos_epochs * steps_per_epoch

def learning_rate_fn(step):
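
Finally, a sketch of an Optax counterpart for the cosine schedule with warmup, using optax.warmup_cosine_decay_schedule; the values are placeholders rather than anything from this PR, and the resulting curve is only approximately equivalent to this module's:

import optax

# Placeholder hyperparameters mirroring the arguments of
# create_cosine_learning_rate_schedule (not values from this PR).
base_learning_rate = 0.1
steps_per_epoch = 100
halfcos_epochs = 90
warmup_length = 5.0  # epochs of warmup

warmup_steps = int(warmup_length * steps_per_epoch)

# decay_steps is the total schedule length in steps, warmup included.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=base_learning_rate,
    warmup_steps=warmup_steps,
    decay_steps=warmup_steps + halfcos_epochs * steps_per_epoch,
    end_value=0.0)

# Schedules are plain `step -> lr` callables.
print(schedule(0), schedule(warmup_steps))
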