@@ -207,7 +207,8 @@ def variables(self):
207
207
208
208
Returns:
209
209
A sequence of variables for the current module (sorted by attribute
210
- name) followed by variables from all submodules recursively (depth first).
210
+ name) followed by variables from all submodules recursively (breadth
211
+ first).
211
212
"""
212
213
return tuple (self ._flatten (predicate = _IS_VARIABLE ))
213
214
@@ -221,7 +222,8 @@ def trainable_variables(self):
221
222
222
223
Returns:
223
224
A sequence of variables for the current module (sorted by attribute
224
- name) followed by variables from all submodules recursively (depth first).
225
+ name) followed by variables from all submodules recursively (breadth
226
+ first).
225
227
"""
226
228
return tuple (self ._flatten (predicate = _IS_TRAINABLE_VARIABLE ))
227
229
@@ -249,7 +251,8 @@ def submodules(self):
249
251
def _flatten (self ,
250
252
recursive = True ,
251
253
predicate = None ,
252
- attribute_traversal_key = None ):
254
+ attribute_traversal_key = None ,
255
+ with_path = False ):
253
256
"""Flattened attribute values in sorted order by attribute name.
254
257
255
258
Modules are flattened by first walking their attributes in name order.
@@ -267,11 +270,15 @@ def _flatten(self,
267
270
...
268
271
... @property
269
272
... def tensors(self):
270
- ... return tuple(self._flatten(predicate=is_tensor))
273
+ ... return tuple(self._flatten(predicate=is_tensor, with_path=True ))
271
274
272
275
>>> foo = Foo()
273
276
>>> foo.tensors
274
- (<tf.Tensor...'a'>, <tf.Tensor...'b'>, ...'c'>, ...'d'>, ...'e'>)
277
+ ((('x', 0), <tf.Tensor: ...'a'>),
278
+ (('x', 1), <tf.Tensor: ...'b'>),
279
+ (('y', 'i'), <tf.Tensor: ...'c'>),
280
+ (('y', 'j'), <tf.Tensor: ...'d'>),
281
+ (('z',), <tf.Tensor: ...'e'>))
275
282
276
283
`attribute_traversal_key` controls the order object properties are visited.
277
284
If not set objects are visited in ascending order by name.
@@ -284,6 +291,10 @@ def _flatten(self,
284
291
attribute_traversal_key: (Optional) Method to rekey object attributes
285
292
before they are sorted. Contract is the same as `key` argument to
286
293
builtin `sorted` and only applies to object properties.
294
+ with_path: (Optional) Whether to include the path to the object as well
295
+ as the object itself. If `with_path` is `True` then leaves will not be
296
+ de-duplicated (e.g. if the same leaf instance is reachable via multiple
297
+ modules then it will be yielded multiple times with different paths).
287
298
288
299
Returns:
289
300
Flat generator for leaves of the current module and optionally all
@@ -297,7 +308,7 @@ def _flatten(self,
297
308
recursive = recursive ,
298
309
predicate = predicate ,
299
310
attribute_traversal_key = attribute_traversal_key ,
300
- seen = set () )
311
+ with_path = with_path )
301
312
302
313
@classmethod
303
314
def no_name_scope (cls , method ):
@@ -337,8 +348,20 @@ def camel_to_snake(value):
337
348
return _CAMEL_TO_SNAKE_R .sub (r"_\1" , value ).lower ()
338
349
339
350
340
- def _flatten_module (module , recursive , predicate , attribute_traversal_key ,
341
- seen ):
351
+ # AutoCheckpointable adds object attributes that users will not expect us to
352
+ # include when flattening (these reference dependencies reachable via other
353
+ # object attributes).
354
+ AUTO_CHECKPOINTABLE_ATTRS = ("_unconditional_checkpoint_dependencies" ,
355
+ "_unconditional_dependency_names" )
356
+
357
+
358
+ def _flatten_module (module ,
359
+ recursive ,
360
+ predicate ,
361
+ attribute_traversal_key ,
362
+ with_path ,
363
+ module_path = (),
364
+ seen = None ):
342
365
"""Implementation of `flatten`."""
343
366
if seen is None :
344
367
seen = set ([id (module )])
@@ -347,25 +370,37 @@ def _flatten_module(module, recursive, predicate, attribute_traversal_key,
347
370
submodules = []
348
371
349
372
for key in sorted (module_dict , key = attribute_traversal_key ):
350
- for leaf in nest .flatten (module_dict [key ]):
351
- leaf_id = id (leaf )
352
- if leaf_id in seen :
353
- continue
373
+ if key in AUTO_CHECKPOINTABLE_ATTRS :
374
+ continue
375
+
376
+ for leaf_path , leaf in nest .flatten_with_tuple_paths (module_dict [key ]):
377
+ leaf_path = (key ,) + leaf_path
378
+
379
+ # TODO(tomhennigan) Handle cycles for `with_path=True` (e.g. `a.a = a`).
380
+ if not with_path :
381
+ leaf_id = id (leaf )
382
+ if leaf_id in seen :
383
+ continue
384
+ seen .add (leaf_id )
354
385
355
- seen .add (leaf_id )
356
386
if predicate (leaf ):
357
- yield leaf
387
+ if with_path :
388
+ yield module_path + leaf_path , leaf
389
+ else :
390
+ yield leaf
358
391
359
392
if recursive and isinstance (leaf , Module ):
360
393
# Walk direct properties first then recurse.
361
- submodules .append (leaf )
394
+ submodules .append (( module_path + leaf_path , leaf ) )
362
395
363
- for submodule in submodules :
396
+ for submodule_path , submodule in submodules :
364
397
subvalues = _flatten_module (
365
398
submodule ,
366
399
recursive = recursive ,
367
400
predicate = predicate ,
368
401
attribute_traversal_key = attribute_traversal_key ,
402
+ with_path = with_path ,
403
+ module_path = submodule_path ,
369
404
seen = seen )
370
405
371
406
for subvalue in subvalues :
0 commit comments