@@ -568,9 +568,18 @@ def should_update_inverse_cache(
       self, state: "Optimizer.State"
   ) -> Union[Array, bool]:
     """Whether at the current step the optimizer should update the inverse curvature approximation."""
-    if self._inverse_update_period == 1:
-      return True
-    return state.step_counter % self._inverse_update_period == 0
+    return self._use_cached_inverses and (
+        state.step_counter % self._inverse_update_period == 0)
+
+  def should_sync_estimator(
+      self, state: "Optimizer.State"
+  ) -> Union[Array, bool]:
+    """Whether at the current step the optimizer should synchronize the curvature estimates across devices."""
+
+    if self._use_cached_inverses:
+      return self.should_update_inverse_cache(state)
+
+    return True
 
   @functools.partial(utils.staged, static_argnums=1)
   def _rng_split(
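A minimal sketch of the new gating logic above, with `_use_cached_inverses` and `_inverse_update_period` stubbed out as plain attributes (in the real optimizer they come from the constructor arguments; the class here is purely illustrative):

```python
import dataclasses

@dataclasses.dataclass
class GatingSketch:
  """Stand-in for the optimizer's inverse-cache gating (illustrative only)."""
  use_cached_inverses: bool
  inverse_update_period: int

  def should_update_inverse_cache(self, step_counter: int) -> bool:
    # With caching disabled this is always False; with caching enabled the
    # cached inverses are refreshed every `inverse_update_period` steps.
    return self.use_cached_inverses and (
        step_counter % self.inverse_update_period == 0)

  def should_sync_estimator(self, step_counter: int) -> bool:
    # With cached inverses, cross-device synchronization of the curvature
    # estimate is only needed right before the cache is refreshed; without
    # caching, inverses are recomputed every step, so sync on every step.
    if self.use_cached_inverses:
      return self.should_update_inverse_cache(step_counter)
    return True

gate = GatingSketch(use_cached_inverses=True, inverse_update_period=5)
assert gate.should_sync_estimator(10) and not gate.should_sync_estimator(11)
```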
@@ -970,6 +979,9 @@ def _init(
   ) -> "Optimizer.State":
     """A staged function to initialize the optimizer state."""
 
+    # Note that we can reuse the rng in the func_args construction below, as
+    # these are just dummy values used to perform the tracing.
+
     return Optimizer.State(
         velocities=jax.tree_util.tree_map(jnp.zeros_like, params),
         estimator_state=self.estimator.init(
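The zero-initialized velocities above follow the standard pytree pattern; a self-contained illustration (the parameter pytree here is hypothetical):

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter pytree; in the optimizer `params` comes from the model.
params = {"w": jnp.ones((2, 3)), "b": jnp.ones((3,))}

# Velocities start at zero with the same structure and shapes as the
# parameters, mirroring `jax.tree_util.tree_map(jnp.zeros_like, params)`.
velocities = jax.tree_util.tree_map(jnp.zeros_like, params)
assert all(bool((v == 0).all()) for v in jax.tree_util.tree_leaves(velocities))
```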
@@ -1014,7 +1026,8 @@ def _burnin(
       rng: Array,
       batch: Batch,
       func_state: Optional[FuncState],
-      accumulator: utils.MultiChunkAccumulator
+      accumulator: utils.MultiChunkAccumulator,
+      sync: Union[Array, bool],
   ) -> Tuple["Optimizer.State", utils.MultiChunkAccumulator]:
     """A single burnin step, updating only the curvature estimate."""
 
@@ -1026,7 +1039,7 @@ def _burnin(
 
     # Update curvature estimate
     state.estimator_state = self._update_estimator_curvature(
-        state.estimator_state, func_args, rng, 1.0, 1.0)
+        state.estimator_state, func_args, rng, 1.0, 1.0, sync=sync)
 
     # Optionally update func_state
     if func_state is not None:
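The `sync` flag threaded into `_update_estimator_curvature` above plausibly gates a cross-device reduction of the estimator state. A hedged single-host sketch of that pattern (`maybe_sync` and the mean over the leading axis are stand-ins; in a pmapped setting the true branch would be a `lax.pmean` across devices):

```python
import jax
import jax.numpy as jnp

def maybe_sync(estimate: jax.Array, sync: bool) -> jax.Array:
  # `lax.cond` skips the reduction entirely when `sync` is False; the mean
  # over the leading axis simulates averaging per-device copies on one host.
  return jax.lax.cond(
      sync,
      lambda e: jnp.broadcast_to(e.mean(axis=0), e.shape),
      lambda e: e,
      estimate)

per_device = jnp.array([[1.0, 2.0], [3.0, 4.0]])
print(maybe_sync(per_device, True))   # rows averaged to [2., 3.]
print(maybe_sync(per_device, False))  # unchanged
```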
@@ -1050,16 +1063,18 @@ def burnin(
     """Runs all burnin steps required."""
 
     if num_steps > 0:
+
       rng = self._rng_split(rng, num_steps)
 
       accumulator = utils.MultiChunkAccumulator.zeros_like(
           func_state, self.multi_device)
 
-      for rng_i in rng:
+      for i, rng_i in enumerate(rng):
         batch = next(data_iterator)
 
         state, accumulator = self._burnin(
-            params, state, rng_i, batch, func_state, accumulator)
+            params, state, rng_i, batch, func_state, accumulator,
+            i == num_steps - 1)
 
       func_state = accumulator.value_and_clear()
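With the change above, `sync` is true only on the final burnin step, deferring the cross-device average until the accumulated curvature estimate is complete. A trivial check of that schedule:

```python
num_steps = 3
schedule = [i == num_steps - 1 for i in range(num_steps)]
assert schedule == [False, False, True]
```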
@@ -1099,9 +1114,7 @@ def _step(
         rng,
         self._curvature_ema,
         1.0,
-        sync=self.should_update_inverse_cache(
-            state
-        ),  # sync curvature estimates only before inverses are updated.
+        sync=self.should_sync_estimator(state),
     )
 
     del rng  # should not be used after this point!
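Taken together: with cached inverses, the estimator sync and the cache refresh now fire on the same steps, while without caching the estimator syncs every step and the cache is never refreshed. A compact, self-contained check of both modes (a stand-alone copy of the two predicates, not the library API):

```python
def sync_and_refresh(step: int, use_cached: bool, period: int = 4):
  # Mirrors should_update_inverse_cache / should_sync_estimator above.
  refresh = use_cached and step % period == 0
  sync = refresh if use_cached else True
  return sync, refresh

assert [sync_and_refresh(s, True) for s in (0, 1)] == [(True, True), (False, False)]
assert [sync_and_refresh(s, False) for s in (0, 1)] == [(True, False), (True, False)]
```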