Skip to content

Commit affd636

Browse files
james-martensKfacJaxDev
authored and
KfacJaxDev
committed
Adding stepwise schedule option to examples.
PiperOrigin-RevId: 449204440
1 parent 2ffb778 commit affd636

File tree

1 file changed

+48
-1
lines changed

1 file changed

+48
-1
lines changed

examples/optimizers.py

+48-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
"""Utilities for setting up different optimizers."""
1515
import functools
16-
from typing import Any, Callable, Iterator, Mapping, Optional, Tuple, Type, Union
16+
from typing import Any, Callable, Iterator, Mapping, Optional, Sequence, Tuple, Type, Union
1717

1818
from absl import logging
1919
import chex
@@ -266,18 +266,63 @@ def cosine_schedule(
266266
"""A cosine schedule described in the TAT paper."""
267267
if (steps is None) == (epochs is None):
268268
raise ValueError("Only one of `steps` and `epochs` can be set.")
269+
269270
warmup_steps = warmup_epochs * dataset_size / train_total_batch_size
271+
270272
if epochs is not None:
271273
total_steps = epochs * dataset_size / train_total_batch_size
272274
else:
273275
total_steps = steps
276+
274277
scaled_step = (jnp.maximum(global_step - warmup_steps, 0) /
275278
(total_steps - warmup_steps))
279+
276280
warmup_factor = jnp.minimum(1., global_step / warmup_steps)
277281
factor = (1.0 + jnp.cos(jnp.pi * scaled_step)) / 2
282+
278283
return initial_learning_rate * warmup_factor * factor
279284

280285

286+
def stepwise_schedule(
287+
global_step: chex.Numeric,
288+
dataset_size: int,
289+
train_total_batch_size: int,
290+
lr_decay_factors: Sequence[float],
291+
initial_learning_rate: float,
292+
epoch_boundaries: Optional[Sequence[float]],
293+
warmup_epochs: Optional[int],
294+
step_boundaries: Optional[Sequence[float]],
295+
warmup_steps: Optional[int],
296+
**_: Any,
297+
) -> chex.Array:
298+
"""A basic stepwise schedule."""
299+
300+
if (epoch_boundaries is None) == (step_boundaries is None):
301+
raise ValueError("Only one of `epoch_boundaries` and `step_boundaries` can "
302+
"be set.")
303+
304+
if (warmup_epochs is None) == (warmup_steps is None):
305+
raise ValueError("Only one of `warmup_epochs` and `warmup_steps` can be "
306+
"set.")
307+
308+
steps_per_epoch = dataset_size / train_total_batch_size
309+
current_epoch = global_step / steps_per_epoch
310+
311+
if step_boundaries is None:
312+
step_boundaries = jnp.array(epoch_boundaries) * steps_per_epoch
313+
else:
314+
step_boundaries = jnp.array(step_boundaries)
315+
316+
values = lr_decay_factors * initial_learning_rate
317+
index = jnp.sum(step_boundaries < global_step)
318+
lr = jnp.take(values, index)
319+
320+
if warmup_steps is None:
321+
return lr * jnp.minimum(1., current_epoch / warmup_epochs)
322+
else:
323+
return lr * jnp.minimum(1., global_step / warmup_steps)
324+
325+
281326
def construct_schedule(
282327
name: str,
283328
**kwargs,
@@ -291,6 +336,8 @@ def construct_schedule(
291336
return functools.partial(kfac_resnet50_schedule, **kwargs)
292337
elif name == "cosine":
293338
return functools.partial(cosine_schedule, **kwargs)
339+
elif name == "stepwise":
340+
return functools.partial(stepwise_schedule, **kwargs)
294341
else:
295342
raise NotImplementedError(name)
296343

0 commit comments

Comments
 (0)