Skip to content

Commit d15b304

Browse files
james-martens
authored and
KfacJaxDev
committed
- minor refactoring
PiperOrigin-RevId: 611099938
1 parent 5c64ae7 commit d15b304

File tree

1 file changed

+23
-10
lines changed

1 file changed

+23
-10
lines changed

kfac_jax/_src/optimizer.py

+23-10
Original file line numberDiff line numberDiff line change
@@ -568,9 +568,18 @@ def should_update_inverse_cache(
568568
self, state: "Optimizer.State"
569569
) -> Union[Array, bool]:
570570
"""Whether at the current step the optimizer should update the inverse curvature approximation."""
571-
if self._inverse_update_period == 1:
572-
return True
573-
return state.step_counter % self._inverse_update_period == 0
571+
return self._use_cached_inverses and (
572+
state.step_counter % self._inverse_update_period == 0)
573+
574+
def should_sync_estimator(
    self, state: "Optimizer.State"
) -> Union[Array, bool]:
  """Whether at the current step the optimizer should sync the curvature estimator.

  When cached inverses are in use, the estimator only needs to be synced on
  the steps where the inverse cache is refreshed (syncing just before the
  inverses are recomputed); on all other steps the sync is skipped. When
  cached inverses are not used, the estimator is synced on every step.

  Args:
    state: The current optimizer state.

  Returns:
    A boolean (possibly a traced array) indicating whether to sync the
    estimator at this step.
  """
  if self._use_cached_inverses:
    # Sync only on the steps where the inverse cache is about to be updated.
    return self.should_update_inverse_cache(state)

  return True
574583

575584
@functools.partial(utils.staged, static_argnums=1)
576585
def _rng_split(
@@ -970,6 +979,9 @@ def _init(
970979
) -> "Optimizer.State":
971980
"""A staged function to initialize the optimizer state ."""
972981

982+
# Note that we can reuse the ng in the func_args construction below, as
983+
# these are just dummy values used to perform the tracing.
984+
973985
return Optimizer.State(
974986
velocities=jax.tree_util.tree_map(jnp.zeros_like, params),
975987
estimator_state=self.estimator.init(
@@ -1014,7 +1026,8 @@ def _burnin(
10141026
rng: Array,
10151027
batch: Batch,
10161028
func_state: Optional[FuncState],
1017-
accumulator: utils.MultiChunkAccumulator
1029+
accumulator: utils.MultiChunkAccumulator,
1030+
sync: Union[Array, bool],
10181031
) -> Tuple["Optimizer.State", utils.MultiChunkAccumulator]:
10191032
"""A single burnin step, updating only the curvature estimate."""
10201033

@@ -1026,7 +1039,7 @@ def _burnin(
10261039

10271040
# Update curvature estimate
10281041
state.estimator_state = self._update_estimator_curvature(
1029-
state.estimator_state, func_args, rng, 1.0, 1.0)
1042+
state.estimator_state, func_args, rng, 1.0, 1.0, sync=sync)
10301043

10311044
# Optionally update func_state
10321045
if func_state is not None:
@@ -1050,16 +1063,18 @@ def burnin(
10501063
"""Runs all burnin steps required."""
10511064

10521065
if num_steps > 0:
1066+
10531067
rng = self._rng_split(rng, num_steps)
10541068

10551069
accumulator = utils.MultiChunkAccumulator.zeros_like(
10561070
func_state, self.multi_device)
10571071

1058-
for rng_i in rng:
1072+
for i, rng_i in enumerate(rng):
10591073
batch = next(data_iterator)
10601074

10611075
state, accumulator = self._burnin(
1062-
params, state, rng_i, batch, func_state, accumulator)
1076+
params, state, rng_i, batch, func_state, accumulator,
1077+
i == num_steps - 1)
10631078

10641079
func_state = accumulator.value_and_clear()
10651080

@@ -1099,9 +1114,7 @@ def _step(
10991114
rng,
11001115
self._curvature_ema,
11011116
1.0,
1102-
sync=self.should_update_inverse_cache(
1103-
state
1104-
), # sync curvature estimates only before inverses are updated.
1117+
sync=self.should_sync_estimator(state),
11051118
)
11061119

11071120
del rng # should not be used after this point!

0 commit comments

Comments
 (0)