Skip to content

Commit f2d7f99

Browse files
botev authored and KfacJaxDev committed
Adding logging for the number of parameters and optimizer state.
PiperOrigin-RevId: 533111618
1 parent 9bc0c3d commit f2d7f99

File tree

3 files changed

+47
-5
lines changed

3 files changed

+47
-5
lines changed

examples/training.py

+41
Original file line number | Diff line number | Diff line change
@@ -126,6 +126,9 @@ def __init__(
126126

127127
self._params, self._state, self._opt_state = None, None, None
128128
self._python_step = 0
129+
self._num_tensors = 0
130+
self._num_parameters = 0
131+
self._optimizer_state_size = 0
129132

130133
def log_machines_setup(self):
131134
"""Logs the machine setup for the experiment."""
@@ -377,12 +380,14 @@ def maybe_initialize_state(self):
377380
init_rng = kfac_jax.utils.replicate_all_local_devices(self.init_rng)
378381
# Initialize parameters and optional state
379382
params_rng, optimizer_rng = kfac_jax.utils.p_split(init_rng)
383+
logging.info("Initializing parameters.")
380384
if self.has_func_state:
381385
self._params, self._state = self.params_init(params_rng, self.init_batch)
382386
else:
383387
self._params = self.params_init(params_rng, self.init_batch)
384388

385389
# Initialize optimizer state
390+
logging.info("Initializing optimizer state.")
386391
self._opt_state = self.optimizer.init(
387392
self._params, optimizer_rng, self.init_batch, self._state
388393
)
@@ -391,6 +396,42 @@ def maybe_initialize_state(self):
391396
# Needed for checkpointing
392397
self._state = ()
393398

399+
# Log parameters
400+
self._num_tensors = 0
401+
self._num_parameters = 0
402+
logging.info("%s %s %s", "=" * 20, "Parameters", "=" * 20)
403+
for path, var in jax.tree_util.tree_flatten_with_path(self._params)[0]:
404+
# Because of pmap
405+
var = var[0]
406+
logging.info(
407+
"%s - %s, %s",
408+
"-".join(str(p)[2:-2] for p in path),
409+
var.shape,
410+
var.dtype,
411+
)
412+
self._num_parameters = self._num_parameters + var.size
413+
self._num_tensors = self._num_tensors + 1
414+
logging.info("Total parameters: %s", f"{self._num_parameters:,}")
415+
416+
# Log optimizer state
417+
self._optimizer_state_size = 0
418+
logging.info("%s %s %s", "=" * 20, "Optimizer State", "=" * 20)
419+
easy_state = kfac_jax.utils.serialize_state_tree(self._opt_state)
420+
for path, var in jax.tree_util.tree_flatten_with_path(easy_state)[0]:
421+
if isinstance(var, str):
422+
# For __class__ entries
423+
continue
424+
# Because of pmap
425+
var = var[0]
426+
logging.info(
427+
"%s - %s, %s",
428+
"/".join(str(p)[2:-2] for p in path),
429+
var.shape,
430+
var.dtype,
431+
)
432+
self._optimizer_state_size = self._optimizer_state_size + var.size
433+
logging.info("Total optimizer state: %s", f"{self._optimizer_state_size:,}")
434+
394435
# _ _
395436
# | |_ _ __ __ _(_)_ __
396437
# | __| "__/ _` | | "_ \

kfac_jax/_src/utils/__init__.py

+2
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,8 @@
4848
del types
4949

5050
# misc
51+
deserialize_state_tree = misc.deserialize_state_tree
52+
serialize_state_tree = misc.serialize_state_tree
5153
to_tuple_or_repeat = misc.to_tuple_or_repeat
5254
first_dim_is_size = misc.first_dim_is_size
5355
fake_element_from_iterator = misc.fake_element_from_iterator

kfac_jax/_src/utils/misc.py

+4 -5
Original file line number | Diff line number | Diff line change
@@ -117,17 +117,16 @@ def copy(self):
117117
(flattened, structure) = jax.tree_util.tree_flatten(self)
118118
return jax.tree_util.tree_unflatten(structure, flattened)
119119

120-
def tree_flatten(self) -> Tuple[Tuple[ArrayTree, ...], None]:
121-
return self.field_values, None
120+
def tree_flatten(self) -> Tuple[Tuple[ArrayTree, ...], Tuple[str, ...]]:
121+
return self.field_values, self.field_names()
122122

123123
@classmethod
124124
def tree_unflatten(
125125
cls,
126-
aux_data: None,
126+
aux_data: Tuple[str, ...],
127127
children: Tuple[ArrayTree, ...],
128128
):
129-
del aux_data # not used
130-
return cls(**dict(zip(cls.field_names(), children)))
129+
return cls(**dict(zip(aux_data, children)))
131130

132131
def __repr__(self) -> str:
133132
return (f"{self.__class__.__name__}(" +

0 commit comments

Comments
 (0)