@@ -121,6 +121,7 @@ def __init__(
121
121
default_batch_size_extractor ,
122
122
pmap_axis_name : str = "kfac_axis" ,
123
123
forbid_setting_attributes_after_finalize : bool = True ,
124
+ modifiable_attribute_exceptions : Sequence [str ] = (),
124
125
include_norms_in_stats : bool = False ,
125
126
):
126
127
"""Initializes the K-FAC optimizer with the provided settings.
@@ -266,6 +267,10 @@ def __init__(
266
267
they have been compiled. However, if you are extending this class, and
267
268
clearly understand the risks of modifying attributes, setting this to
268
269
``False`` will remove the restriction. (Default: ``True``)
270
+ modifiable_attribute_exceptions: Sequence of strings. Gives a list
271
+ of names for attributes that can be modified after finalization even
272
+ when ``forbid_setting_attributes_after_finalize`` is ``True``.
273
+ (Default: ``()``)
269
274
include_norms_in_stats: Boolean. It True, the vector norms of the
270
275
gradient, preconditioned gradient, and parameter update are included in
271
276
the statistics returned by the step function. (Default: ``False``)
@@ -276,6 +281,7 @@ def __init__(
276
281
debug = debug ,
277
282
forbid_setting_attributes_after_finalize =
278
283
forbid_setting_attributes_after_finalize ,
284
+ excluded_attribute_names = modifiable_attribute_exceptions ,
279
285
)
280
286
if use_adaptive_damping and initial_damping is None :
281
287
raise ValueError ("When use_adaptive_damping is True you must provide a "
0 commit comments