.. _Optimizer Schedules: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules
"""

+ from absl import logging
import jax.numpy as jnp
import numpy as np

@@ -36,6 +37,14 @@ def create_constant_learning_rate_schedule(base_learning_rate, steps_per_epoch,
    warmup_length=0.0):
  """Create a constant learning rate schedule with optional warmup.

+ Note that with `FLIP #1009`_ learning rate schedules in ``flax.training`` are
+ **effectively deprecated** in favor of Optax_ schedules. Please refer to
+ `Optimizer Schedules`_ for more information.
+
+ .. _FLIP #1009: https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md
+ .. _Optax: https://github.com/deepmind/optax
+ .. _Optimizer Schedules: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules
+
  Holds the learning rate constant. This function also offers a learning rate
  warmup as per https://arxiv.org/abs/1706.02677, for the purpose of training
  with large mini-batches.
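For readers migrating off this helper, the constant-plus-warmup behavior described above can be rebuilt from Optax schedule combinators. A minimal sketch, assuming a linear warmup and purely illustrative values (none of them come from this commit)::

    import optax

    base_lr = 0.1            # illustrative
    steps_per_epoch = 100    # illustrative
    warmup_epochs = 5        # illustrative
    warmup_steps = warmup_epochs * steps_per_epoch

    # Ramp linearly from 0 to base_lr, then hold the rate constant, roughly
    # mirroring create_constant_learning_rate_schedule(..., warmup_length).
    schedule = optax.join_schedules(
        schedules=[
            optax.linear_schedule(init_value=0.0, end_value=base_lr,
                                  transition_steps=warmup_steps),
            optax.constant_schedule(base_lr),
        ],
        boundaries=[warmup_steps])

    print(schedule(0), schedule(warmup_steps))  # 0.0 at the start, base_lr after warmup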
@@ -50,6 +59,11 @@ def create_constant_learning_rate_schedule(base_learning_rate, steps_per_epoch,
  Returns:
    Function `f(step) -> lr` that computes the learning rate for a given step.
  """
+ logging.warning(
+     'Learning rate schedules in ``flax.training`` are effectively deprecated '
+     'in favor of Optax schedules. Please refer to '
+     'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules'
+     ' for alternatives.')
  def learning_rate_fn(step):
    lr = base_learning_rate
    if warmup_length > 0.0:
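The helper keeps working after this change; it now just logs the warning above once, when the schedule is created. A quick usage sketch, assuming the module lives at ``flax.training.lr_schedule`` and using illustrative values::

    from flax.training import lr_schedule

    # Creating the schedule emits the warning; the returned function still
    # maps an optimizer step to a learning rate.
    lr_fn = lr_schedule.create_constant_learning_rate_schedule(
        base_learning_rate=0.1, steps_per_epoch=100, warmup_length=1.0)

    print(lr_fn(50))   # partway through the one-epoch warmup
    print(lr_fn(500))  # 0.1 once warmup is over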
@@ -62,6 +76,14 @@ def create_stepped_learning_rate_schedule(base_learning_rate, steps_per_epoch,
    lr_sched_steps, warmup_length=0.0):
  """Create a stepped learning rate schedule with optional warmup.

+ Note that with `FLIP #1009`_ learning rate schedules in ``flax.training`` are
+ **effectively deprecated** in favor of Optax_ schedules. Please refer to
+ `Optimizer Schedules`_ for more information.
+
+ .. _FLIP #1009: https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md
+ .. _Optax: https://github.com/deepmind/optax
+ .. _Optimizer Schedules: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules
+
  A stepped learning rate schedule decreases the learning rate
  by specified amounts at specified epochs. The steps are given as
  the `lr_sched_steps` parameter. A common ImageNet schedule decays the
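The ``lr_sched_steps`` pairs map fairly directly onto ``optax.piecewise_constant_schedule``, with the caveat that Optax expects a *multiplicative* scale at each boundary rather than an absolute factor on the base rate. A sketch with invented ImageNet-style numbers::

    import optax

    steps_per_epoch = 100  # illustrative
    base_lr = 0.1          # illustrative

    # flax-style spec (illustrative): lr_sched_steps = [[30, 0.1], [60, 0.01], [90, 0.001]]
    # Each Optax scale below is the ratio between consecutive levels, i.e. a
    # further 10x drop at epochs 30, 60 and 90.
    schedule = optax.piecewise_constant_schedule(
        init_value=base_lr,
        boundaries_and_scales={
            30 * steps_per_epoch: 0.1,
            60 * steps_per_epoch: 0.1,
            90 * steps_per_epoch: 0.1,
        })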
@@ -91,6 +113,11 @@ def create_stepped_learning_rate_schedule(base_learning_rate, steps_per_epoch,
  Returns:
    Function `f(step) -> lr` that computes the learning rate for a given step.
  """
+ logging.warning(
+     'Learning rate schedules in ``flax.training`` are effectively deprecated '
+     'in favor of Optax schedules. Please refer to '
+     'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules'
+     ' for alternatives.')
  boundaries = [step[0] for step in lr_sched_steps]
  decays = [step[1] for step in lr_sched_steps]
  boundaries = np.array(boundaries) * steps_per_epoch
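The two list comprehensions above split each ``(epoch, factor)`` pair apart and convert the epoch boundaries into optimizer steps. A tiny worked example with invented values::

    import numpy as np

    lr_sched_steps = [[30, 0.1], [60, 0.01]]  # illustrative
    steps_per_epoch = 100                     # illustrative

    boundaries = np.array([step[0] for step in lr_sched_steps]) * steps_per_epoch
    decays = [step[1] for step in lr_sched_steps]
    print(boundaries)  # [3000 6000] -- epoch boundaries expressed in steps
    print(decays)      # [0.1, 0.01] -- factors applied to the base learning rate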
@@ -109,6 +136,14 @@ def create_cosine_learning_rate_schedule(base_learning_rate, steps_per_epoch,
    halfcos_epochs, warmup_length=0.0):
  """Create a cosine learning rate schedule with optional warmup.

+ Note that with `FLIP #1009`_ learning rate schedules in ``flax.training`` are
+ **effectively deprecated** in favor of Optax_ schedules. Please refer to
+ `Optimizer Schedules`_ for more information.
+
+ .. _FLIP #1009: https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md
+ .. _Optax: https://github.com/deepmind/optax
+ .. _Optimizer Schedules: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules
+
  A cosine learning rate schedule modulates the learning rate with
  half a cosine wave, gradually scaling it to 0 at the end of training.
@@ -128,6 +163,11 @@ def create_cosine_learning_rate_schedule(base_learning_rate, steps_per_epoch,
  Returns:
    Function `f(step) -> lr` that computes the learning rate for a given step.
  """
+ logging.warning(
+     'Learning rate schedules in ``flax.training`` are effectively deprecated '
+     'in favor of Optax schedules. Please refer to '
+     'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules'
+     ' for alternatives.')
  halfwavelength_steps = halfcos_epochs * steps_per_epoch

  def learning_rate_fn(step):
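In practice, the Optax schedules these notes point to are consumed by handing the schedule function straight to an optimizer constructor, which evaluates it at the current step. A minimal sketch (optimizer choice and values are illustrative)::

    import optax

    schedule = optax.cosine_decay_schedule(init_value=0.1, decay_steps=10_000)
    tx = optax.sgd(learning_rate=schedule, momentum=0.9)  # schedule in place of a fixed rate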