@@ -126,6 +126,9 @@ def __init__(
126
126
127
127
self ._params , self ._state , self ._opt_state = None , None , None
128
128
self ._python_step = 0
129
+ self ._num_tensors = 0
130
+ self ._num_parameters = 0
131
+ self ._optimizer_state_size = 0
129
132
130
133
def log_machines_setup (self ):
131
134
"""Logs the machine setup for the experiment."""
@@ -377,12 +380,14 @@ def maybe_initialize_state(self):
377
380
init_rng = kfac_jax .utils .replicate_all_local_devices (self .init_rng )
378
381
# Initialize parameters and optional state
379
382
params_rng , optimizer_rng = kfac_jax .utils .p_split (init_rng )
383
+ logging .info ("Initializing parameters." )
380
384
if self .has_func_state :
381
385
self ._params , self ._state = self .params_init (params_rng , self .init_batch )
382
386
else :
383
387
self ._params = self .params_init (params_rng , self .init_batch )
384
388
385
389
# Initialize optimizer state
390
+ logging .info ("Initializing optimizer state." )
386
391
self ._opt_state = self .optimizer .init (
387
392
self ._params , optimizer_rng , self .init_batch , self ._state
388
393
)
@@ -391,6 +396,42 @@ def maybe_initialize_state(self):
391
396
# Needed for checkpointing
392
397
self ._state = ()
393
398
399
+ # Log parameters
400
+ self ._num_tensors = 0
401
+ self ._num_parameters = 0
402
+ logging .info ("%s %s %s" , "=" * 20 , "Parameters" , "=" * 20 )
403
+ for path , var in jax .tree_util .tree_flatten_with_path (self ._params )[0 ]:
404
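+ # Because of pmap: each leaf presumably carries a leading per-device axis, so index 0 to log one replica's shape and size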
405
+ var = var [0 ]
406
+ logging .info (
407
+ "%s - %s, %s" ,
408
+ "-" .join (str (p )[2 :- 2 ] for p in path ),
409
+ var .shape ,
410
+ var .dtype ,
411
+ )
412
+ self ._num_parameters = self ._num_parameters + var .size
413
+ self ._num_tensors = self ._num_tensors + 1
414
+ logging .info ("Total parameters: %s" , f"{ self ._num_parameters :,} " )
415
+
416
+ # Log optimizer state
417
+ self ._optimizer_state_size = 0
418
+ logging .info ("%s %s %s" , "=" * 20 , "Optimizer State" , "=" * 20 )
419
+ easy_state = kfac_jax .utils .serialize_state_tree (self ._opt_state )
420
+ for path , var in jax .tree_util .tree_flatten_with_path (easy_state )[0 ]:
421
+ if isinstance (var , str ):
422
+ # Skip string leaves (e.g. the __class__ entries of the serialized state): they have no shape, dtype or size to log
423
+ continue
424
+ # Because of pmap: each leaf presumably carries a leading per-device axis, so index 0 to count one replica only
425
+ var = var [0 ]
426
+ logging .info (
427
+ "%s - %s, %s" ,
428
+ "/" .join (str (p )[2 :- 2 ] for p in path ),
429
+ var .shape ,
430
+ var .dtype ,
431
+ )
432
+ self ._optimizer_state_size = self ._optimizer_state_size + var .size
433
+ logging .info ("Total optimizer state: %s" , f"{ self ._optimizer_state_size :,} " )
434
+
394
435
# _ _
395
436
# | |_ _ __ __ _(_)_ __
396
437
# | __| "__/ _` | | "_ \
0 commit comments