Skip to content

Commit 5f2ec52

Browse files
botevKfacJaxDev
authored and
KfacJaxDev
committed
* Adding an argument to set the reduction ratio thresholds for automatic damping adjustment.
* Bug fix: get_default_tag now correctly returns None if the tag is not present.
* Adding an option to skip raising an error when we are running a different graph.
* Fixing a bug in ExplicitExactCurvature.update_cache().
PiperOrigin-RevId: 443657951
1 parent c30fa53 commit 5f2ec52

File tree

4 files changed

+37
-18
lines changed

4 files changed

+37
-18
lines changed

kfac_jax/_src/curvature_estimator.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,12 @@
8080
)
8181

8282

83-
def get_default_tag_to_block_ctor(tag_name: str) -> CurvatureBlockCtor:
83+
def get_default_tag_to_block_ctor(
84+
tag_name: str
85+
) -> Optional[CurvatureBlockCtor]:
8486
"""Returns the default curvature block constructor for the given tag name."""
8587
global _DEFAULT_TAG_TO_BLOCK_CTOR
86-
return _DEFAULT_TAG_TO_BLOCK_CTOR[tag_name]
88+
return _DEFAULT_TAG_TO_BLOCK_CTOR.get(tag_name)
8789

8890

8991
def set_default_tag_to_block_ctor(
@@ -1262,7 +1264,7 @@ def blocks_vectors_to_params_vector(
12621264

12631265
def update_curvature_matrix_estimate(
12641266
self,
1265-
state: curvature_blocks.Full.State,
1267+
state: BlockDiagonalCurvature.State,
12661268
ema_old: chex.Numeric,
12671269
ema_new: chex.Numeric,
12681270
batch_size: int,
@@ -1297,18 +1299,19 @@ def single_state_update(
12971299

12981300
def update_cache(
12991301
self,
1300-
state: curvature_blocks.Full.State,
1302+
state: BlockDiagonalCurvature.State,
13011303
identity_weight: chex.Numeric,
13021304
exact_powers: Optional[curvature_blocks.ScalarOrSequence],
13031305
approx_powers: Optional[curvature_blocks.ScalarOrSequence],
13041306
eigenvalues: bool,
13051307
pmap_axis_name: Optional[str],
13061308
) -> curvature_blocks.Full.State:
1307-
return self.blocks[0].update_cache(
1308-
state=state,
1309+
block_state = self.blocks[0].update_cache(
1310+
state=state.blocks_states[0],
13091311
identity_weight=identity_weight,
13101312
exact_powers=exact_powers,
13111313
approx_powers=approx_powers,
13121314
eigenvalues=eigenvalues,
13131315
pmap_axis_name=pmap_axis_name,
13141316
)
1317+
return BlockDiagonalCurvature.State(blocks_states=(block_state,))

kfac_jax/_src/optimizer.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ def __init__(
101101
include_damping_in_quad_change: bool = False,
102102
damping_adaptation_interval: int = 5,
103103
damping_adaptation_decay: chex.Numeric = 0.9,
104+
damping_lower_threshold: chex.Numeric = 0.25,
105+
damping_upper_threshold: chex.Numeric = 0.75,
104106
always_use_exact_qmodel_for_damping_adjustment: bool = False,
105107
norm_constraint: Optional[chex.Numeric] = None,
106108
num_burnin_steps: int = 10,
@@ -200,6 +202,10 @@ def __init__(
200202
damping_adaptation_decay: Scalar. The ``damping`` parameter is multiplied
201203
by the ``damping_adaptation_decay`` every
202204
``damping_adaptation_interval`` number of iterations. (Default: ``0.9``)
205+
damping_lower_threshold: Scalar. The ``damping`` parameter is increased if
206+
the reduction ratio is below this threshold. (Default: ``0.25``)
207+
damping_upper_threshold: Scalar. The ``damping`` parameter is decreased if
208+
the reduction ratio is above this threshold. (Default: ``0.75``)
203209
always_use_exact_qmodel_for_damping_adjustment: Boolean. When using
204210
learning rate and/or momentum adaptation, the quadratic model change
205211
used for damping adaption is always computed using the exact curvature
@@ -314,6 +320,8 @@ def schedule_with_first_step_zero(global_step: chex.Array) -> chex.Array:
314320
self._include_damping_in_quad_change = include_damping_in_quad_change
315321
self._damping_adaptation_decay = damping_adaptation_decay
316322
self._damping_adaptation_interval = damping_adaptation_interval
323+
self._damping_lower_threshold = damping_lower_threshold
324+
self._damping_upper_threshold = damping_upper_threshold
317325
self._always_use_exact_qmodel_for_damping_adjustment = (
318326
always_use_exact_qmodel_for_damping_adjustment)
319327
self._norm_constraint = norm_constraint
@@ -1111,10 +1119,10 @@ def _compute_new_damping_and_rho(
11111119
rho = (new_loss - old_loss) / quad_change
11121120

11131121
# Update damping
1114-
should_decrease = rho > 0.75
1115-
decreased_damping = current_damping * self.damping_decay_factor
1116-
should_increase = rho < 0.25
1122+
should_increase = rho < self._damping_lower_threshold
11171123
increased_damping = current_damping / self.damping_decay_factor
1124+
should_decrease = rho > self._damping_upper_threshold
1125+
decreased_damping = current_damping * self.damping_decay_factor
11181126

11191127
# This is basically an if-else statement
11201128
damping = (should_decrease * decreased_damping +

kfac_jax/_src/tracer.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ def cached_transformation(
282282
auto_register_tags: bool = True,
283283
allow_left_out_params: bool = False,
284284
allow_no_losses: bool = False,
285+
raise_error_on_diff_jaxpr: bool = True,
285286
**auto_registration_kwargs: Any,
286287
) -> TransformedFunction[T, T]:
287288
"""Caches ``transformation(preprocessed_jaxpr, func_args, *args)``.
@@ -302,6 +303,9 @@ def cached_transformation(
302303
tag.
303304
allow_no_losses: If this is set to ``False`` an error would be raised if no
304305
registered losses have been found when tracing the function.
306+
raise_error_on_diff_jaxpr: Whether to raise an exception if the function has
307+
been traced before, with different arguments, and the new Jaxpr graph
308+
differs in more than just the shapes and dtypes of the Jaxpr equations.
305309
**auto_registration_kwargs: Any additional keyword arguments, to be passed
306310
to the automatic registration pass.
307311
@@ -341,8 +345,8 @@ def wrapped_transformation(
341345
if not allow_no_losses and not processed_jaxpr.loss_tags:
342346
raise ValueError("No registered losses have been found during tracing.")
343347

344-
# If any previous `ProcessedJaxpr` exits verify that they are equivalent
345-
if cache:
348+
if cache and raise_error_on_diff_jaxpr:
349+
# If any previous `ProcessedJaxpr` exists verify that they are equivalent
346350
ref_jaxpr, _ = cache[next(iter(cache))]
347351
if ref_jaxpr != processed_jaxpr:
348352
raise ValueError("The consecutive tracing of the provided function "
@@ -889,6 +893,7 @@ def layer_tags_vjp(
889893
func: utils.Func,
890894
params_index: int = 0,
891895
auto_register_tags: bool = True,
896+
raise_error_on_diff_jaxpr: bool = True,
892897
**auto_registration_kwargs,
893898
) -> ...:
894899
"""Creates a function for primal values and tangents w.r.t. all layer tags.
@@ -910,6 +915,8 @@ def layer_tags_vjp(
910915
parameters.
911916
auto_register_tags: Whether to run an automatic layer registration on the
912917
function (e.g. :func:`~auto_register_tags`).
918+
raise_error_on_diff_jaxpr: Whether to raise an exception when tracing with
919+
different arguments produces a jaxpr with a different graph.
913920
**auto_registration_kwargs: Any additional keyword arguments, to be passed
914921
to the automatic registration pass.
915922
@@ -924,5 +931,6 @@ def layer_tags_vjp(
924931
params_index=params_index,
925932
auto_register_tags=auto_register_tags,
926933
allow_left_out_params=False,
934+
raise_error_on_diff_jaxpr=raise_error_on_diff_jaxpr,
927935
**auto_registration_kwargs
928936
)

tests/test_estimator.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_explicit_exact_full(
106106
data = {}
107107
for name, shape in data_point_shapes.items():
108108
data_key, key = jax.random.split(data_key)
109-
data[name] = jax.random.uniform(key, (data_size,) + shape)
109+
data[name] = jax.random.uniform(key, (data_size, *shape))
110110
if name == "labels":
111111
data[name] = jnp.argmax(data[name], axis=-1)
112112

@@ -167,7 +167,7 @@ def test_block_diagonal_full(
167167
data = {}
168168
for name, shape in data_point_shapes.items():
169169
data_key, key = jax.random.split(data_key)
170-
data[name] = jax.random.uniform(key, (data_size,) + shape)
170+
data[name] = jax.random.uniform(key, (data_size, *shape))
171171
if name == "labels":
172172
data[name] = jnp.argmax(data[name], axis=-1)
173173

@@ -231,7 +231,7 @@ def test_block_diagonal_full_to_hessian(
231231
data = {}
232232
for name, shape in data_point_shapes.items():
233233
data_key, key = jax.random.split(data_key)
234-
data[name] = jax.random.uniform(key, (data_size,) + shape)
234+
data[name] = jax.random.uniform(key, (data_size, *shape))
235235
if name == "labels":
236236
data[name] = jnp.argmax(data[name], axis=-1)
237237

@@ -300,7 +300,7 @@ def test_diagonal(
300300
data = {}
301301
for name, shape in data_point_shapes.items():
302302
data_key, key = jax.random.split(data_key)
303-
data[name] = jax.random.uniform(key, (data_size,) + shape)
303+
data[name] = jax.random.uniform(key, (data_size, *shape))
304304
if name == "labels":
305305
data[name] = jnp.argmax(data[name], axis=-1)
306306

@@ -366,7 +366,7 @@ def test_kronecker_factored(
366366
data = {}
367367
for name, shape in data_point_shapes.items():
368368
data_key, key = jax.random.split(data_key)
369-
data[name] = jax.random.uniform(key, (data_size,) + shape)
369+
data[name] = jax.random.uniform(key, (data_size, *shape))
370370
if name == "labels":
371371
data[name] = jnp.argmax(data[name], axis=-1)
372372

@@ -446,7 +446,7 @@ def test_eigenvalues(
446446
data = {}
447447
for name, shape in data_point_shapes.items():
448448
data_key, key = jax.random.split(data_key)
449-
data[name] = jax.random.uniform(key, (data_size,) + shape)
449+
data[name] = jax.random.uniform(key, (data_size, *shape))
450450
if name == "labels":
451451
data[name] = jnp.argmax(data[name], axis=-1)
452452

@@ -534,7 +534,7 @@ def test_matmul(
534534
data = {}
535535
for name, shape in data_point_shapes.items():
536536
data_key, key = jax.random.split(data_key)
537-
data[name] = jax.random.uniform(key, (data_size,) + shape)
537+
data[name] = jax.random.uniform(key, (data_size, *shape))
538538
if name == "labels":
539539
data[name] = jnp.argmax(data[name], axis=-1)
540540

0 commit comments

Comments
 (0)