13
13
# limitations under the License.
14
14
"""Utilities for setting up different optimizers."""
15
15
import functools
16
- from typing import Any , Callable , Iterator , Mapping , Optional , Tuple , Type , Union
16
+ from typing import Any , Callable , Iterator , Mapping , Optional , Sequence , Tuple , Type , Union
17
17
18
18
from absl import logging
19
19
import chex
@@ -266,18 +266,63 @@ def cosine_schedule(
266
266
"""A cosine schedule described in the TAT paper."""
267
267
if (steps is None ) == (epochs is None ):
268
268
raise ValueError ("Only one of `steps` and `epochs` can be set." )
269
+
269
270
warmup_steps = warmup_epochs * dataset_size / train_total_batch_size
271
+
270
272
if epochs is not None :
271
273
total_steps = epochs * dataset_size / train_total_batch_size
272
274
else :
273
275
total_steps = steps
276
+
274
277
scaled_step = (jnp .maximum (global_step - warmup_steps , 0 ) /
275
278
(total_steps - warmup_steps ))
279
+
276
280
warmup_factor = jnp .minimum (1. , global_step / warmup_steps )
277
281
factor = (1.0 + jnp .cos (jnp .pi * scaled_step )) / 2
282
+
278
283
return initial_learning_rate * warmup_factor * factor
279
284
280
285
286
def stepwise_schedule(
    global_step: chex.Numeric,
    dataset_size: int,
    train_total_batch_size: int,
    lr_decay_factors: Sequence[float],
    initial_learning_rate: float,
    epoch_boundaries: Optional[Sequence[float]],
    warmup_epochs: Optional[int],
    step_boundaries: Optional[Sequence[float]],
    warmup_steps: Optional[int],
    **_: Any,
) -> chex.Array:
  """A basic stepwise (piecewise-constant) learning rate schedule.

  The learning rate equals `initial_learning_rate` scaled by the entry of
  `lr_decay_factors` selected by how many boundaries `global_step` has
  crossed, further scaled by a linear warmup factor during the initial phase
  of training. Boundaries and warmup length may each be given either in
  epochs or in steps, but not both.

  Args:
    global_step: The current optimization step.
    dataset_size: The number of training examples in the dataset.
    train_total_batch_size: The total batch size of a single training step.
    lr_decay_factors: Multiplicative factors applied to
      `initial_learning_rate`; entry ``i`` is used after ``i`` boundaries
      have been crossed.
    initial_learning_rate: The base learning rate before decay and warmup.
    epoch_boundaries: Decay boundaries expressed in epochs. Mutually
      exclusive with `step_boundaries`.
    warmup_epochs: Linear warmup length expressed in epochs. Mutually
      exclusive with `warmup_steps`.
    step_boundaries: Decay boundaries expressed in steps.
    warmup_steps: Linear warmup length expressed in steps.
    **_: Additional keyword arguments, ignored.

  Returns:
    The learning rate for `global_step`.

  Raises:
    ValueError: If not exactly one of `epoch_boundaries`/`step_boundaries`
      is set, or not exactly one of `warmup_epochs`/`warmup_steps` is set.
  """

  if (epoch_boundaries is None) == (step_boundaries is None):
    raise ValueError("Only one of `epoch_boundaries` and `step_boundaries` can "
                     "be set.")

  if (warmup_epochs is None) == (warmup_steps is None):
    raise ValueError("Only one of `warmup_epochs` and `warmup_steps` can be "
                     "set.")

  steps_per_epoch = dataset_size / train_total_batch_size
  current_epoch = global_step / steps_per_epoch

  # Normalize boundaries to step units so the comparison below is uniform.
  if step_boundaries is None:
    step_boundaries = jnp.array(epoch_boundaries) * steps_per_epoch
  else:
    step_boundaries = jnp.array(step_boundaries)

  # `lr_decay_factors` is a Python sequence, so it must be converted to an
  # array before scaling — a bare `sequence * float` raises a TypeError.
  values = jnp.array(lr_decay_factors) * initial_learning_rate
  # The number of boundaries already crossed selects the current decay factor.
  index = jnp.sum(step_boundaries < global_step)
  lr = jnp.take(values, index)

  if warmup_steps is None:
    return lr * jnp.minimum(1., current_epoch / warmup_epochs)
  else:
    return lr * jnp.minimum(1., global_step / warmup_steps)
324
+
325
+
281
326
def construct_schedule (
282
327
name : str ,
283
328
** kwargs ,
@@ -291,6 +336,8 @@ def construct_schedule(
291
336
return functools .partial (kfac_resnet50_schedule , ** kwargs )
292
337
elif name == "cosine" :
293
338
return functools .partial (cosine_schedule , ** kwargs )
339
+ elif name == "stepwise" :
340
+ return functools .partial (stepwise_schedule , ** kwargs )
294
341
else :
295
342
raise NotImplementedError (name )
296
343
0 commit comments