Skip to content

Commit 1d5fd92

Browse files
tomhennigantensorflower-gardener
authored andcommitted
Add with_path to allow modules to be flattened with key paths.
``` >>> class MyModule(tf.Module): ... @Property ... def state_dict(self): ... return dict(self._flatten( ... predicate=lambda v: isinstance(v, tf.Variable), with_path=True)) >>> mod = MyModule() >>> mod.encoder = Encoder() >>> mod.decoder = mod.encoder >>> mod.state_dict {('encoder', 'w'), <tf.Variable: ...>, ('decoder', 'w'), <tf.Variable: ...>} ``` h/t tensorflow/community#56 PiperOrigin-RevId: 232908045
1 parent 567e02c commit 1d5fd92

File tree

2 files changed

+70
-18
lines changed

2 files changed

+70
-18
lines changed

tensorflow/python/module/module.py

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,8 @@ def variables(self):
207207
208208
Returns:
209209
A sequence of variables for the current module (sorted by attribute
210-
name) followed by variables from all submodules recursively (depth first).
210+
name) followed by variables from all submodules recursively (breadth
211+
first).
211212
"""
212213
return tuple(self._flatten(predicate=_IS_VARIABLE))
213214

@@ -221,7 +222,8 @@ def trainable_variables(self):
221222
222223
Returns:
223224
A sequence of variables for the current module (sorted by attribute
224-
name) followed by variables from all submodules recursively (depth first).
225+
name) followed by variables from all submodules recursively (breadth
226+
first).
225227
"""
226228
return tuple(self._flatten(predicate=_IS_TRAINABLE_VARIABLE))
227229

@@ -249,7 +251,8 @@ def submodules(self):
249251
def _flatten(self,
250252
recursive=True,
251253
predicate=None,
252-
attribute_traversal_key=None):
254+
attribute_traversal_key=None,
255+
with_path=False):
253256
"""Flattened attribute values in sorted order by attribute name.
254257
255258
Modules are flattened by first walking their attributes in name order.
@@ -267,11 +270,15 @@ def _flatten(self,
267270
...
268271
... @property
269272
... def tensors(self):
270-
... return tuple(self._flatten(predicate=is_tensor))
273+
... return tuple(self._flatten(predicate=is_tensor, with_path=True))
271274
272275
>>> foo = Foo()
273276
>>> foo.tensors
274-
(<tf.Tensor...'a'>, <tf.Tensor...'b'>, ...'c'>, ...'d'>, ...'e'>)
277+
((('x', 0), <tf.Tensor: ...'a'>),
278+
(('x', 1), <tf.Tensor: ...'b'>),
279+
(('y', 'i'), <tf.Tensor: ...'c'>),
280+
(('y', 'j'), <tf.Tensor: ...'d'>),
281+
(('z',), <tf.Tensor: ...'e'>))
275282
276283
`attribute_traversal_key` controls the order object properties are visited.
277284
If not set objects are visited in ascending order by name.
@@ -284,6 +291,10 @@ def _flatten(self,
284291
attribute_traversal_key: (Optional) Method to rekey object attributes
285292
before they are sorted. Contract is the same as `key` argument to
286293
builtin `sorted` and only applies to object properties.
294+
with_path: (Optional) Whether to include the path to the object as well
295+
as the object itself. If `with_path` is `True` then leaves will not be
296+
de-duplicated (e.g. if the same leaf instance is reachable via multiple
297+
modules then it will be yielded multiple times with different paths).
287298
288299
Returns:
289300
Flat generator for leaves of the current module and optionally all
@@ -297,7 +308,7 @@ def _flatten(self,
297308
recursive=recursive,
298309
predicate=predicate,
299310
attribute_traversal_key=attribute_traversal_key,
300-
seen=set())
311+
with_path=with_path)
301312

302313
@classmethod
303314
def no_name_scope(cls, method):
@@ -337,8 +348,20 @@ def camel_to_snake(value):
337348
return _CAMEL_TO_SNAKE_R.sub(r"_\1", value).lower()
338349

339350

340-
def _flatten_module(module, recursive, predicate, attribute_traversal_key,
341-
seen):
351+
# AutoCheckpointable adds object attributes that users will not expect us to
352+
# include when flattening (these reference dependencies reachable via other
353+
# object attributes).
354+
AUTO_CHECKPOINTABLE_ATTRS = ("_unconditional_checkpoint_dependencies",
355+
"_unconditional_dependency_names")
356+
357+
358+
def _flatten_module(module,
359+
recursive,
360+
predicate,
361+
attribute_traversal_key,
362+
with_path,
363+
module_path=(),
364+
seen=None):
342365
"""Implementation of `flatten`."""
343366
if seen is None:
344367
seen = set([id(module)])
@@ -347,25 +370,37 @@ def _flatten_module(module, recursive, predicate, attribute_traversal_key,
347370
submodules = []
348371

349372
for key in sorted(module_dict, key=attribute_traversal_key):
350-
for leaf in nest.flatten(module_dict[key]):
351-
leaf_id = id(leaf)
352-
if leaf_id in seen:
353-
continue
373+
if key in AUTO_CHECKPOINTABLE_ATTRS:
374+
continue
375+
376+
for leaf_path, leaf in nest.flatten_with_tuple_paths(module_dict[key]):
377+
leaf_path = (key,) + leaf_path
378+
379+
# TODO(tomhennigan) Handle cycles for `with_path=True` (e.g. `a.a = a`).
380+
if not with_path:
381+
leaf_id = id(leaf)
382+
if leaf_id in seen:
383+
continue
384+
seen.add(leaf_id)
354385

355-
seen.add(leaf_id)
356386
if predicate(leaf):
357-
yield leaf
387+
if with_path:
388+
yield module_path + leaf_path, leaf
389+
else:
390+
yield leaf
358391

359392
if recursive and isinstance(leaf, Module):
360393
# Walk direct properties first then recurse.
361-
submodules.append(leaf)
394+
submodules.append((module_path + leaf_path, leaf))
362395

363-
for submodule in submodules:
396+
for submodule_path, submodule in submodules:
364397
subvalues = _flatten_module(
365398
submodule,
366399
recursive=recursive,
367400
predicate=predicate,
368401
attribute_traversal_key=attribute_traversal_key,
402+
with_path=with_path,
403+
module_path=submodule_path,
369404
seen=seen)
370405

371406
for subvalue in subvalues:

tensorflow/python/module/module_test.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,10 +299,10 @@ def with_name_scope(self):
299299
mk_index_dict = lambda v: dict(enumerate(v))
300300

301301

302-
class WalkTest(parameterized.TestCase, test.TestCase):
302+
class FlattenTest(parameterized.TestCase, test.TestCase):
303303

304304
@parameterized.parameters(lambda v: NamedPair(*v), list, tuple, mk_index_dict)
305-
def test_walk(self, container_type):
305+
def test_flatten(self, container_type):
306306
parent = SimpleModule(container_type=container_type)
307307
child = parent.c
308308

@@ -320,6 +320,23 @@ def test_attribute_traversal_key(self):
320320
mod.variables,
321321
mod._trainable_variables + mod._non_trainable_variables + [mod._bonus])
322322

323+
def test_with_path(self):
324+
mod = module.Module()
325+
mod.w = variables.Variable(1.)
326+
mod.encoder = module.Module()
327+
mod.encoder.w = [({"k": mod.w}, {"k": mod.w})]
328+
mod.decoder = mod.encoder
329+
330+
state_dict = dict(
331+
mod._flatten(with_path=True, predicate=module._IS_VARIABLE))
332+
333+
self.assertEqual(state_dict,
334+
{("w",): mod.w,
335+
("encoder", "w", 0, 0, "k"): mod.encoder.w[0][0]["k"],
336+
("encoder", "w", 0, 1, "k"): mod.encoder.w[0][1]["k"],
337+
("decoder", "w", 0, 0, "k"): mod.decoder.w[0][0]["k"],
338+
("decoder", "w", 0, 1, "k"): mod.decoder.w[0][1]["k"]},)
339+
323340

324341
class LayerModule(module.Module):
325342

0 commit comments

Comments
 (0)