
Commit 9eb0a61

IvyZX authored and Flax Authors committed
Partially revert #4192, which set back a bunch of previously merged pushes.
PiperOrigin-RevId: 675337465
1 parent 03e034d commit 9eb0a61

17 files changed: +274 −158 lines changed


flax/core/meta.py

Lines changed: 14 additions & 0 deletions
@@ -22,6 +22,7 @@
 """
 
 import abc
+import dataclasses
 import functools
 from typing import Any, Generic, TypeVar
 from collections.abc import Callable
@@ -287,6 +288,19 @@ def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding:
     """Returns the ``NamedSharding`` for this partitioned value."""
     return jax.sharding.NamedSharding(mesh, self.get_partition_spec())
 
+  def to_nnx_metadata(self) -> dict[str, Any]:
+    """Return a dict of metadata that can translate into an `nnx.Variable`."""
+    metadata = vars(self)
+    metadata['sharding'] = metadata.pop('names')
+    return metadata
+
+  @classmethod
+  def from_nnx_metadata(cls, metadata: dict[str, Any]):
+    """Given a dict of `nnx.Variable` format metadata, create a `nn.Partitioned`."""
+    metadata['names'] = metadata.pop('sharding')
+    fields = {x.name for x in dataclasses.fields(cls)}
+    return cls(**{k: v for k, v in metadata.items() if k in fields})
+
 
 def with_partitioning(
   fn: Callable[..., Any],
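
As a quick illustration of the round trip these two helpers enable (a minimal sketch; the array shape, axis names, and extra key below are made-up examples, not part of the diff):

import jax.numpy as jnp
import flax.linen as nn

# Linen box with logical axis names; 'data' and the (8, 4) shape are illustrative.
boxed = nn.Partitioned(jnp.ones((8, 4)), names=('data', None))

# Linen -> NNX-style metadata: the 'names' field travels as 'sharding'.
md = boxed.to_nnx_metadata()
assert md['sharding'] == ('data', None)

# NNX-style metadata -> Linen box: 'sharding' is renamed back to 'names',
# and keys that are not dataclass fields of Partitioned are dropped.
restored = nn.Partitioned.from_nnx_metadata(
    {'value': md['value'], 'sharding': ('data', None), 'extra_key': 0})
assert restored.names == ('data', None)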

flax/errors.py

Lines changed: 9 additions & 0 deletions
@@ -64,6 +64,15 @@ def __reduce__(self):
     return (FlaxError, (str(self),))
 
 
+#################################################
+# NNX errors                                    #
+#################################################
+
+
+class TraceContextError(FlaxError):
+  pass
+
+
 #################################################
 # lazy_init.py errors                           #
 #################################################
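
This moves NNX's TraceContextError into the shared flax.errors module (the NNX-local errors.py is deleted further down). A small sketch of what that means for callers; the message string here is made up:

from flax import errors

try:
  # NNX raises TraceContextError when an object is mutated from a different
  # JAX trace level; here we raise it by hand just to show the hierarchy.
  raise errors.TraceContextError('cannot mutate from a different trace level')
except errors.FlaxError:  # TraceContextError now subclasses FlaxError
  pass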

flax/linen/spmd.py

Lines changed: 15 additions & 0 deletions
@@ -328,6 +328,21 @@ def unbox(self, apply_constraint=True) -> Any:
     else:
       return self.value
 
+  def to_nnx_metadata(self) -> dict[str, Any]:
+    """Return a dict of metadata that can translate into an `nnx.Variable`."""
+    metadata = vars(self)
+    metadata['sharding'] = metadata.pop('names')
+    metadata['sharding_rules'] = metadata.pop('rules')
+    return metadata
+
+  @classmethod
+  def from_nnx_metadata(cls, metadata: dict[str, Any]):
+    """Given a dict of `nnx.Variable` format metadata, create a `nn.LogicallyPartitioned`."""
+    metadata['names'] = metadata.pop('sharding')
+    metadata['rules'] = metadata.pop('sharding_rules')
+    fields = {x.name for x in dataclasses.fields(cls)}
+    return cls(**{k: v for k, v in metadata.items() if k in fields})
+
 
 def with_logical_partitioning(
   fn: Callable[..., Any],
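
Same pattern as nn.Partitioned above, with the extra rules field mapped to sharding_rules. A minimal sketch (the axis names and rules are illustrative):

import jax.numpy as jnp
import flax.linen as nn

boxed = nn.LogicallyPartitioned(
    jnp.ones((8, 4)), names=('embed', 'mlp'), rules=(('embed', 'data'),))
md = boxed.to_nnx_metadata()
# 'names' -> 'sharding' and 'rules' -> 'sharding_rules', matching nnx.Variable metadata keys.
assert md['sharding'] == ('embed', 'mlp')
assert md['sharding_rules'] == (('embed', 'data'),)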

flax/nnx/bridge/variables.py

Lines changed: 15 additions & 15 deletions
@@ -58,10 +58,10 @@ def variable_type_name(typ: tp.Type[variableslib.Variable[tp.Any]]) -> str:
 
 
 def register_variable_name_type_pair(name, typ, overwrite = False):
-  """Register a pair of variable type name (like Linen collections) and its NNX type."""
+  """Register a pair of Linen collection name and its NNX type."""
   if not overwrite and name in VariableTypeCache:
     raise ValueError(f'Name {name} already mapped to type {VariableTypeCache[name]}. '
                      'To overwrite, call with `overwrite=True`.')
-                     'To overwrite, call with `overwrite=True`.')
+                     'To overwrite, call register_variable_name_type_pair() with `overwrite=True`.')
   VariableTypeCache[name] = typ
 
 
@@ -85,8 +85,7 @@ def _variable_parents_count(t: type):
 
 
 class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]):
-  """Default Flax metadata class for `nnx.VariableState`.
-  """
+  """Default Flax metadata class for `nnx.VariableState`."""
 
   var_type: type[variableslib.Variable[tp.Any]] = struct.field(pytree_node=False)
   value: Any = struct.field(pytree_node=True)
@@ -110,10 +109,11 @@ def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]':
 def to_linen_var(vs: variableslib.VariableState) -> meta.AxisMetadata:
   metadata = vs.get_metadata()
   if 'linen_meta_type' in metadata:
-    if metadata['linen_meta_type'] is not meta.Partitioned:
-      raise ValueError('Not supporting Linen metadata types other than nn.Partitioned')
-    return meta.Partitioned(vs.value, names=metadata['sharding'], mesh=metadata['mesh'])
-  return NNXMeta(vs.type, vs.value, vs.get_metadata())
+    linen_type = metadata['linen_meta_type']
+    if hasattr(linen_type, 'from_nnx_metadata'):
+      return linen_type.from_nnx_metadata({'value': vs.value, **metadata})
+    return linen_type(vs.value, **metadata)
+  return NNXMeta(vs.type, vs.value, metadata)
 
 
 def get_col_name(keypath: tp.Sequence[Any]) -> str:
@@ -124,15 +124,15 @@ def get_col_name(keypath: tp.Sequence[Any]) -> str:
 
 
 def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variableslib.Variable:
-  """Convert a Linen variable to an NNX variable.
-  This process needs the collection name,
-  """
+  """Convert a Linen variable to an NNX variable."""
   vtype = variable_type(col)
   if isinstance(x, NNXMeta):
     assert vtype == x.var_type, f'Type stored in NNXMeta {x.var_type} != type inferred from collection name {vtype}'
     return x.var_type(x.value, **x.metadata)
   if isinstance(x, meta.AxisMetadata):
-    if isinstance(x, meta.Partitioned):
-      return vtype(x.value, sharding=x.names, mesh=x.mesh, linen_meta_type=meta.Partitioned)
-    raise ValueError('Not yet supporting metadata types other than nn.Partitioned and NNXMeta')
-  return vtype(x)
+    x_metadata = vars(x)
+    if hasattr(x, 'to_nnx_metadata'):
+      x_metadata = x.to_nnx_metadata()
+    assert hasattr(x, 'value')
+    return vtype(**x_metadata, linen_meta_type=type(x))
+  return vtype(x)
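
In effect, any Linen AxisMetadata class that implements to_nnx_metadata / from_nnx_metadata can now cross the bridge, with its class recorded under linen_meta_type for the reverse conversion. A rough sketch using the internal helper (the collection name 'params' and the axis name 'data' are assumptions for illustration):

import jax.numpy as jnp
import flax.linen as nn
from flax.nnx.bridge import variables as bv

boxed = nn.Partitioned(jnp.ones((4,)), names=('data',))
var = bv.to_nnx_var('params', boxed)          # duck-typed: uses boxed.to_nnx_metadata()
assert var.sharding == ('data',)              # 'names' arrived as 'sharding'
assert var.linen_meta_type is nn.Partitioned  # remembered for the reverse conversion

On the way back, to_linen_var reads linen_meta_type and, when that class defines from_nnx_metadata, rebuilds the original box type instead of being hard-coded to nn.Partitioned.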

flax/nnx/bridge/wrappers.py

Lines changed: 27 additions & 26 deletions
@@ -74,7 +74,7 @@ def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs):
     module = fn
     assert callable(fn)
   else:
-    if not (hasattr(fn, '__self__') and isinstance(fn.__self__, Module)):
+    if not hasattr(fn, '__self__') and isinstance(fn.__self__, Module):
       raise ValueError(f'{fn = } needs to be a method of an NNX Module.')
     module = fn.__self__
   _set_initializing(module, True)
@@ -124,6 +124,7 @@ def __init__(
     self.linen_collections: tuple[str, ...] = ()
 
   def lazy_init(self, *args, **kwargs):
+    """A shortcut of calling `nnx.bridge.lazy_init()` upon this module."""
     return lazy_init(self, *args, **kwargs)
 
   def __call__(
@@ -224,28 +225,6 @@ class ToLinen(linen.Module):
   skip_rng: bool = False
   metadata_type: tp.Type = bv.NNXMeta
 
-  def update_variables(self, module):
-    """Store the NNX module's graph def and state inside Linen module variables."""
-    gdef, state = nnx.split(module)
-    # Save the graph def.
-    if self.is_mutable_collection('nnx'):
-      self.put_variable('nnx', 'graphdef', gdef)
-    # Sort all the variable types.
-    types = set(jax.tree.leaves(
-      jax.tree.map(lambda x: x.type, state,
-                   is_leaf=lambda x: isinstance(x, nnx.VariableState))))
-    types = bv.sort_variable_types(types)
-    _, *state_by_types = nnx.split(module, *types)
-    # Each variable type goes to its own linen collection, and
-    # each attribute goes to its own linen variable
-    for typ, state in zip(types, state_by_types):
-      collection = bv.variable_type_name(typ)
-      if self.is_mutable_collection(collection):
-        for k, v in state.raw_mapping.items():
-          v = jax.tree.map(bv.to_linen_var, v,
-                           is_leaf=lambda x: isinstance(x, nnx.VariableState))
-          self.put_variable(collection, k, v)
-
   @linen.compact
   def __call__(self, *args, **kwargs):
     # init codepath
@@ -255,7 +234,7 @@ def __call__(self, *args, **kwargs):
         module_kwargs |= dict(rngs=nnx.Rngs(**linen_rngs_dict(self)))
       module = self.nnx_class(*self.args, **module_kwargs)
       # TODO: add lazy_init here in case there's an `ToNNX` submodule under `module`.
-      self.update_variables(module)
+      self._update_variables(module)
       return module(*args, **kwargs)
 
     # apply codepath
@@ -270,11 +249,33 @@ def __call__(self, *args, **kwargs):
     module = nnx.merge(gdef, nnx_state)
     nnx.reseed(module, **linen_rngs_dict(self))  # reseed with keys from linen apply call.
     out = module(*args, **kwargs)
-    self.update_variables(module)
+    self._update_variables(module)
     return out
 
+  def _update_variables(self, module):
+    """Store the NNX module's graph def and state inside Linen module variables."""
+    gdef, state = nnx.split(module)
+    # Save the graph def.
+    if self.is_mutable_collection('nnx'):
+      self.put_variable('nnx', 'graphdef', gdef)
+    # Sort all the variable types.
+    types = set(jax.tree.leaves(
+      jax.tree.map(lambda x: x.type, state,
+                   is_leaf=lambda x: isinstance(x, nnx.VariableState))))
+    types = bv.sort_variable_types(types)
+    _, *state_by_types = nnx.split(module, *types)
+    # Each variable type goes to its own linen collection, and
+    # each attribute goes to its own linen variable
+    for typ, state in zip(types, state_by_types):
+      collection = bv.variable_type_name(typ)
+      if self.is_mutable_collection(collection):
+        for k, v in state.raw_mapping.items():
+          v = jax.tree.map(bv.to_linen_var, v,
+                           is_leaf=lambda x: isinstance(x, nnx.VariableState))
+          self.put_variable(collection, k, v)
+
 
 def to_linen(nnx_class: tp.Callable[..., Module], *args,
              name: str | None = None, **kwargs):
-  """Shortcut of `ToLinen` if user is not changing any of `ToLinen` default fields."""
+  """Shortcut of `nnx.bridge.ToLinen` if user is not changing any of its default fields."""
   return ToLinen(nnx_class, args=args, kwargs=kwargs, name=name)
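
For reference, here is roughly how the to_linen shortcut gets used; the layer sizes, input shape, and RNG seed are arbitrary examples:

import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx import bridge

# Wrap an NNX module class as a Linen module; positional args go to the NNX constructor.
model = bridge.to_linen(nnx.Linear, 4, 8)      # same as ToLinen(nnx.Linear, args=(4, 8))
x = jnp.ones((1, 4))
variables = model.init(jax.random.key(0), x)   # stores the 'nnx' graphdef and 'params' collections
y = model.apply(variables, x)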

flax/nnx/errors.py

Lines changed: 0 additions & 17 deletions
This file was deleted.

flax/nnx/extract.py

Lines changed: 10 additions & 8 deletions
@@ -22,7 +22,7 @@
 
 from flax import struct
 from flax.nnx.object import Object
-from flax.typing import MISSING, PathParts
+from flax.typing import Missing, PathParts
 from flax.nnx import graph
 
 
@@ -59,7 +59,7 @@ def extract_graph_nodes(
   pytree: A,
   /,
   *,
-  prefix: tp.Any = MISSING,
+  prefix: tp.Any = Missing,
   validate_fn: tp.Callable[[KeyPath, Prefix, Leaf], None] | None = None,
 ) -> (
   tuple[A, tuple[tp.Any, ...]]
@@ -101,7 +101,7 @@ def extract_graph_nodes(
 
   pytree_out = jax.tree.unflatten(treedef, leaves)
 
-  if prefix is MISSING:
+  if prefix is Missing:
     return pytree_out, tuple(nodes)  # type: ignore[bad-return-type]
   else:
     return pytree_out, tuple(nodes), tuple(node_prefixes)  # type: ignore[bad-return-type]
@@ -330,12 +330,13 @@ def to_tree(
   tree,
   /,
   *,
-  prefix: tp.Any = MISSING,
+  prefix: tp.Any = Missing,
   split_fn: tp.Callable[
     [graph.SplitContext, KeyPath, Prefix, Leaf], tp.Any
   ] = default_split_fn,
   map_non_graph_nodes: bool = False,
   ctxtag: str | None = None,
+  check_aliasing: bool = True,
 ) -> tp.Any:
   leaf_prefixes = broadcast_prefix(
     prefix,
@@ -351,9 +352,10 @@
   with graph.split_context(ctxtag) as split_ctx:
     for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes):
       if graph.is_graph_node(leaf):
-        check_consistent_aliasing(
-          leaf, leaf_prefix, node_prefixes=node_prefixes
-        )
+        if check_aliasing:
+          check_consistent_aliasing(
+            leaf, leaf_prefix, node_prefixes=node_prefixes
+          )
         tree_node = split_fn(split_ctx, keypath, leaf_prefix, leaf)
         leaves_out.append(tree_node)
       else:
@@ -381,7 +383,7 @@ def from_tree(
   tree: tp.Any,
   /,
   *,
-  prefix: tp.Any = MISSING,
+  prefix: tp.Any = Missing,
   merge_fn: tp.Callable[
     [graph.MergeContext, KeyPath, Prefix, Leaf], tp.Any
   ] = merge_tree_node,

flax/nnx/object.py

Lines changed: 1 addition & 1 deletion
@@ -25,13 +25,13 @@
 import numpy as np
 
 from flax.nnx import (
-  errors,
   reprlib,
   tracers,
 )
 from flax.nnx import graph
 from flax.nnx.variables import Variable, VariableState
 from flax.typing import Key
+from flax import errors
 
 G = tp.TypeVar('G', bound='Object')
 
flax/nnx/spmd.py

Lines changed: 8 additions & 2 deletions
@@ -44,7 +44,7 @@ def _add_axis(x: tp.Any):
         sharding.insert(index, axis_name)
         x.sharding = tuple(sharding)  # type: ignore
 
-      x.add_axis(axis_name, index)
+      x.add_axis(index, axis_name)
     return x
 
   return jax.tree.map(
@@ -61,7 +61,7 @@ def _remove_axis(x: tp.Any):
         sharding = list(x.sharding)
         assert sharding.pop(index) == axis_name
         x.sharding = tuple(sharding)
-      x.remove_axis(axis_name, index)
+      x.remove_axis(index, axis_name)
     return x
 
   return jax.tree.map(
@@ -89,9 +89,15 @@ def _maybe_replicate(x):
     else:
       return None
 
+  def from_rules(sharding, sharding_rules):
+    rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules}
+    return (rules[s] if s in rules else s for s in sharding)
+
   def f(x):
     if isinstance(x, (variables.VariableState, variables.Variable)):
       if hasattr(x, 'sharding') and x.sharding:
+        if hasattr(x, 'sharding_rules') and x.sharding_rules:
+          return x.replace(PartitionSpec(*from_rules(x.sharding, x.sharding_rules)))
         return x.replace(PartitionSpec(*x.sharding))
       else:
         return x.replace(_maybe_replicate(x.value))
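
The new from_rules branch resolves logical axis names through a variable's sharding_rules before building the PartitionSpec. The mapping in isolation, as a minimal sketch (the axis and mesh names are made up):

from jax.sharding import PartitionSpec

sharding = ('embed', 'mlp', None)                       # logical names on an nnx variable
sharding_rules = (('embed', 'data'), ('mlp', 'model'))  # logical name -> mesh axis
rules = dict(sharding_rules)
spec = PartitionSpec(*(rules.get(s, s) for s in sharding))
assert spec == PartitionSpec('data', 'model', None)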

flax/nnx/transforms/compilation.py

Lines changed: 1 addition & 0 deletions
@@ -324,6 +324,7 @@ def jit_wrapper(*args, **kwargs):
       (args, kwargs),
       prefix=(in_shardings, kwarg_shardings),
       split_fn=_jit_split_fn,
+      check_aliasing=in_shardings is not None,
       ctxtag='jit',
     )
     pure_args_out, pure_kwargs_out, pure_out = jitted_fn(
