Skip to content

[nnx] optimize NodeDef.attributes #4399

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 106 additions & 10 deletions docs_nnx/nnx_basics.ipynb

Large diffs are not rendered by default.

118 changes: 61 additions & 57 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,24 @@ def __treescope_repr__(self, path, subtree_renderer):
jax.tree_util.register_static(VariableDef)


# Entry of `NodeDef.attributes` describing a child that is itself a graph node.
# `_graph_flatten` appends one of these for every (key, value) pair returned by
# `node_impl.flatten(node)` for which `is_node(value)` is true.
@dataclasses.dataclass(frozen=True, slots=True)
class SubGraphAttribute:
# Key of the child, as yielded by `node_impl.flatten(node)`.
key: Key
# Definition of the child node. During unflattening a `NodeRef` here means the
# node already exists and is looked up through `index_ref` by its index.
value: NodeDef[tp.Any] | NodeRef[tp.Any]


# Entry of `NodeDef.attributes` describing a child that is neither a graph node
# nor a `Variable`. `_graph_flatten` stores the raw value here (after rejecting
# `jax.Array` / `np.ndarray` values with a ValueError).
@dataclasses.dataclass(frozen=True, slots=True)
class StaticAttribute:
# Key of the child, as yielded by `node_impl.flatten(node)`.
key: Key
# The value itself; `_graph_unflatten` copies it into the rebuilt node's
# children as-is, and raises ValueError if the state also has an entry for `key`.
value: tp.Any


# Entry of `NodeDef.attributes` describing a child that is a `Variable`.
# `_graph_flatten` appends one of these for every `Variable` value it visits.
@dataclasses.dataclass(frozen=True, slots=True)
class LeafAttribute:
# Key of the child, as yielded by `node_impl.flatten(node)`.
key: Key
# A `VariableDef` the first time a variable is encountered while flattening;
# a `NodeRef` carrying the previously assigned index when the same variable
# was already recorded in `ref_index`.
value: VariableDef | NodeRef[tp.Any]


@dataclasses.dataclass(frozen=True, repr=False, slots=True)
class NodeDef(GraphDef[Node], reprlib.Representable):
"""A dataclass that denotes the tree structure of a
Expand All @@ -298,10 +316,7 @@ class NodeDef(GraphDef[Node], reprlib.Representable):

type: tp.Type[Node]
index: int
attributes: tuple[Key, ...]
subgraphs: HashableMapping[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]
static_fields: HashableMapping[Key, tp.Any]
leaves: HashableMapping[Key, VariableDef | NodeRef[tp.Any]]
attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...]
metadata: tp.Any
index_mapping: HashableMapping[Index, Index] | None

Expand All @@ -310,20 +325,14 @@ def create(
cls,
type: tp.Type[Node],
index: int,
attributes: tuple[Key, ...],
subgraphs: tp.Iterable[tuple[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]],
static_fields: tp.Iterable[tuple[Key, tp.Any]],
leaves: tp.Iterable[tuple[Key, VariableDef | NodeRef[tp.Any]]],
attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...],
metadata: tp.Any,
index_mapping: tp.Mapping[Index, Index] | None,
):
return cls(
type=type,
index=index,
attributes=attributes,
subgraphs=HashableMapping(subgraphs),
static_fields=HashableMapping(static_fields),
leaves=HashableMapping(leaves),
metadata=metadata,
index_mapping=HashableMapping(index_mapping)
if index_mapping is not None
Expand All @@ -335,12 +344,7 @@ def __nnx_repr__(self):

yield reprlib.Attr('type', self.type.__name__)
yield reprlib.Attr('index', self.index)
yield reprlib.Attr('attributes', self.attributes)
yield reprlib.Attr('subgraphs', reprlib.PrettyMapping(self.subgraphs))
yield reprlib.Attr(
'static_fields', reprlib.PrettyMapping(self.static_fields)
)
yield reprlib.Attr('leaves', reprlib.PrettyMapping(self.leaves))
yield reprlib.Attr('attributes', reprlib.PrettySequence(self.attributes))
yield reprlib.Attr('metadata', self.metadata)
yield reprlib.Attr(
'index_mapping',
Expand All @@ -352,18 +356,15 @@ def __nnx_repr__(self):
def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]
return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes={
'type': self.type,
'index': self.index,
'attributes': self.attributes,
'subgraphs': dict(self.subgraphs),
'static_fields': dict(self.static_fields),
'leaves': dict(self.leaves),
'metadata': self.metadata,
},
path=path,
subtree_renderer=subtree_renderer,
object_type=type(self),
attributes={
'type': self.type,
'index': self.index,
'attributes': self.attributes,
'metadata': self.metadata,
},
path=path,
subtree_renderer=subtree_renderer,
)

def apply(
Expand Down Expand Up @@ -426,40 +427,39 @@ def _graph_flatten(
else:
index = -1

subgraphs: list[tuple[Key, NodeDef[Node] | NodeRef]] = []
static_fields: list[tuple[Key, tp.Any]] = []
leaves: list[tuple[Key, VariableDef | NodeRef]] = []
attributes: list[SubGraphAttribute | StaticAttribute | LeafAttribute] = []

values, metadata = node_impl.flatten(node)
for key, value in values:
if is_node(value):
nodedef = _graph_flatten((*path, key), ref_index, flat_state, value)
subgraphs.append((key, nodedef))
# subgraphs.append((key, nodedef))
attributes.append(SubGraphAttribute(key, nodedef))
elif isinstance(value, Variable):
if value in ref_index:
leaves.append((key, NodeRef(type(value), ref_index[value])))
attributes.append(
LeafAttribute(key, NodeRef(type(value), ref_index[value]))
)
else:
flat_state[(*path, key)] = value.to_state()
variable_index = ref_index[value] = len(ref_index)
variabledef = VariableDef(
type(value), variable_index, HashableMapping(value.get_metadata())
)
leaves.append((key, variabledef))
attributes.append(LeafAttribute(key, variabledef))
else:
if isinstance(value, (jax.Array, np.ndarray)):
path_str = '/'.join(map(str, (*path, key)))
raise ValueError(
f'Arrays leaves are not supported, at {path_str!r}: {value}'
)
static_fields.append((key, value))
# static_fields.append((key, value))
attributes.append(StaticAttribute(key, value))

nodedef = NodeDef.create(
type=node_impl.type,
index=index,
attributes=tuple(key for key, _ in values),
subgraphs=subgraphs,
static_fields=static_fields,
leaves=leaves,
attributes=tuple(attributes),
metadata=metadata,
index_mapping=None,
)
Expand Down Expand Up @@ -529,22 +529,20 @@ def _graph_unflatten(

def _get_children():
children: dict[Key, NodeLeaf | Node] = {}

# NOTE: we could allow adding new StateLeafs here
if unkown_keys := set(state) - set(nodedef.attributes):
raise ValueError(f'Unknown keys: {unkown_keys}')
state_keys: set = set(state.keys())

# for every key in attributes there are 6 possible cases:
# - (2) the key can either be present in the state or not
# - (3) the key can be a subgraph, a leaf, or a static attribute
for key in nodedef.attributes:
for attribute in nodedef.attributes:
key = attribute.key
if key not in state:
# if the key is not present, create an empty type
if key in nodedef.static_fields:
children[key] = nodedef.static_fields[key]
elif key in nodedef.subgraphs:
if type(attribute) is StaticAttribute:
children[key] = attribute.value
elif type(attribute) is SubGraphAttribute:
# if the key is a subgraph we create an empty node
subgraphdef = nodedef.subgraphs[key]
subgraphdef = attribute.value
assert not isinstance(subgraphdef, VariableDef)
if isinstance(subgraphdef, NodeRef):
# subgraph exists, take it from the cache
Expand All @@ -558,8 +556,8 @@ def _get_children():
children[key] = _graph_unflatten(
subgraphdef, substate, index_ref, index_ref_cache
)
elif key in nodedef.leaves:
variabledef = nodedef.leaves[key]
elif type(attribute) is LeafAttribute:
variabledef = attribute.value
if variabledef.index in index_ref:
# variable exists, take it from the cache
children[key] = index_ref[variabledef.index]
Expand All @@ -572,19 +570,21 @@ def _get_children():
else:
raise RuntimeError(f'Unknown static field: {key!r}')
else:
state_keys.remove(key)
value = state[key]
if key in nodedef.static_fields:
# if key in nodedef.static_fields:
if type(attribute) is StaticAttribute:
raise ValueError(
f'Got state for static field {key!r}, this is not supported.'
)
if key in nodedef.subgraphs:
elif type(attribute) is SubGraphAttribute:
if is_state_leaf(value):
raise ValueError(
f'Expected value of type {nodedef.subgraphs[key]} for '
f'Expected value of type {attribute.value} for '
f'{key!r}, but got {value!r}'
)
assert isinstance(value, dict)
subgraphdef = nodedef.subgraphs[key]
subgraphdef = attribute.value

if isinstance(subgraphdef, NodeRef):
children[key] = index_ref[subgraphdef.index]
Expand All @@ -593,8 +593,8 @@ def _get_children():
subgraphdef, value, index_ref, index_ref_cache
)

elif key in nodedef.leaves:
variabledef = nodedef.leaves[key]
elif type(attribute) is LeafAttribute:
variabledef = attribute.value

if variabledef.index in index_ref:
# add an existing variable
Expand Down Expand Up @@ -631,6 +631,10 @@ def _get_children():
else:
raise RuntimeError(f'Unknown key: {key!r}, this is a bug.')

# NOTE: we could allow adding new StateLeafs here
if state_keys:
raise ValueError(f'Unknown keys: {state_keys}')

return children

if isinstance(node_impl, GraphNodeImpl):
Expand Down
12 changes: 11 additions & 1 deletion flax/nnx/reprlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,14 @@ def __nnx_repr__(self):
yield Object(type='', value_sep=': ', start='{', end='}')

for key, value in self.mapping.items():
yield Attr(repr(key), value)
yield Attr(repr(key), value)

@dataclasses.dataclass(repr=False)
class PrettySequence(Representable):
  """Representable wrapper around a sequence.

  `__nnx_repr__` first yields an `Object` header configured with `[` / `]`
  delimiters, an empty type name and an empty key/value separator, and then
  one `Attr` with an empty name per element, in sequence order.
  """

  # The wrapped sequence. The field name is part of the constructor interface,
  # so it is kept even though it shadows the builtin inside the class body.
  list: tp.Sequence

  def __nnx_repr__(self):
    # Header describing how the enclosing brackets are drawn.
    yield Object(type='', value_sep='', start='[', end=']')
    # Elements carry no key, hence the empty attribute name.
    yield from (Attr('', item) for item in self.list)
3 changes: 1 addition & 2 deletions flax/nnx/scripts/run-all-examples.bash
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
set -e

source .venv/bin/activate
cd flax/nnx

for f in $(find examples/toy_examples -name "*.py" -maxdepth 1); do
for f in $(find examples/nnx_toy_examples -name "*.py" -maxdepth 1); do
echo -e "\n---------------------------------"
echo "$f"
echo "---------------------------------"
Expand Down
15 changes: 10 additions & 5 deletions flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1341,11 +1341,16 @@ def per_node_def(nd: graph.NodeDef | graph.NodeRef):
global_index_mapping[nd.index] = nd.index
if isinstance(nd, graph.NodeRef):
return
for sub_nd in nd.subgraphs.values():
per_node_def(sub_nd)
for l in nd.leaves.values():
if isinstance(l, (graph.VariableDef, graph.NodeRef)) and l.index >= 0:
global_index_mapping[l.index] = l.index

for attribute in nd.attributes:
if type(attribute) is graph.SubGraphAttribute:
per_node_def(attribute.value)
elif (
type(attribute) is graph.LeafAttribute
and isinstance(attribute.value, (graph.VariableDef, graph.NodeRef))
and attribute.value.index >= 0
):
global_index_mapping[attribute.value.index] = attribute.value.index
return

per_node_def(ns._graphdef)
Expand Down
2 changes: 1 addition & 1 deletion tests/nnx/graph_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def __init__(self):

assert 'tree' in state
assert 'a' in state.tree
assert graphdef.subgraphs['tree'].type is nnx.graph.GenericPytree
assert graphdef.attributes[0].value.type is nnx.graph.GenericPytree

m2 = nnx.merge(graphdef, state)

Expand Down
Loading