Skip to content

Commit dddf979

Browse files
james-martensKfacJaxDev
authored and
KfacJaxDev
committed
- Adding "modifiable_attribute_exceptions" argument to optimizer
- Renaming "preprocess_rng" to "seed_rng" in examples (since it's used for more than just preprocessing).
PiperOrigin-RevId: 444946712
1 parent f8b6405 commit dddf979

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

examples/training.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -271,17 +271,18 @@ def create_optimizer(self) -> Union[
271271

272272
def initialize_state(self):
273273
"""Initializes all of the experiment's state variables."""
274-
init_rng, preprocess_rng = jax.random.split(self.init_rng)
274+
init_rng, seed_rng = jax.random.split(self.init_rng)
275275
init_rng = kfac_jax.utils.replicate_all_local_devices(init_rng)
276-
preprocess_rng = jax.random.fold_in(preprocess_rng, jax.process_index())
276+
seed_rng = jax.random.fold_in(seed_rng, jax.process_index())
277+
seed = int(seed_rng[0])
277278

278279
# Initialize and load dataset
279280
if self.mode == "train":
280281
self._train_input = pipe_utils.py_prefetch(
281282
datasets.dataset_as_generator(
282283
self._build_train_input,
283284
split="train",
284-
seed=int(preprocess_rng[0]),
285+
seed=seed,
285286
device_batch_size=self.train_per_device_batch_size,
286287
)
287288
)
@@ -293,12 +294,12 @@ def initialize_state(self):
293294
self._eval_input = dict(
294295
train=self._build_eval_input(
295296
split="train",
296-
seed=int(preprocess_rng[0]),
297+
seed=seed,
297298
device_batch_size=self.eval_per_device_batch_size
298299
),
299300
test=self._build_eval_input(
300301
split="test",
301-
seed=int(preprocess_rng[0]),
302+
seed=seed,
302303
device_batch_size=self.eval_per_device_batch_size
303304
),
304305
)

kfac_jax/_src/optimizer.py

+6
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def __init__(
121121
default_batch_size_extractor,
122122
pmap_axis_name: str = "kfac_axis",
123123
forbid_setting_attributes_after_finalize: bool = True,
124+
modifiable_attribute_exceptions: Sequence[str] = (),
124125
include_norms_in_stats: bool = False,
125126
):
126127
"""Initializes the K-FAC optimizer with the provided settings.
@@ -266,6 +267,10 @@ def __init__(
266267
they have been compiled. However, if you are extending this class, and
267268
clearly understand the risks of modifying attributes, setting this to
268269
``False`` will remove the restriction. (Default: ``True``)
270+
modifiable_attribute_exceptions: Sequence of strings. Gives a list
271+
of names for attributes that can be modified after finalization even
272+
when ``forbid_setting_attributes_after_finalize`` is ``True``.
273+
(Default: ``()``)
269274
include_norms_in_stats: Boolean. If True, the vector norms of the
270275
gradient, preconditioned gradient, and parameter update are included in
271276
the statistics returned by the step function. (Default: ``False``)
@@ -276,6 +281,7 @@ def __init__(
276281
debug=debug,
277282
forbid_setting_attributes_after_finalize=
278283
forbid_setting_attributes_after_finalize,
284+
excluded_attribute_names=modifiable_attribute_exceptions,
279285
)
280286
if use_adaptive_damping and initial_damping is None:
281287
raise ValueError("When use_adaptive_damping is True you must provide a "

0 commit comments

Comments
 (0)