Skip to content

Commit 61634e4

Browse files
committed
Rename variable string mapping utils and move them to variableslib
1 parent d28f03f commit 61634e4

File tree

5 files changed

+67
-68
lines changed

5 files changed

+67
-68
lines changed

flax/nnx/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@
1919
from flax.typing import Initializer as Initializer
2020

2121
from .bridge import wrappers as wrappers
22-
from .bridge.variables import (
23-
register_variable_name_type_pair as register_variable_name_type_pair,
24-
)
2522
from .filterlib import WithTag as WithTag
2623
from .filterlib import PathContains as PathContains
2724
from .filterlib import OfType as OfType
@@ -163,6 +160,9 @@
163160
from .variablelib import VariableState as VariableState
164161
from .variablelib import VariableMetadata as VariableMetadata
165162
from .variablelib import with_metadata as with_metadata
163+
from .variablelib import variable_type_from_name as variable_type_from_name
164+
from .variablelib import variable_name_from_type as variable_name_from_type
165+
from .variablelib import register_variable_name_type_pair as register_variable_name_type_pair
166166
from .visualization import display as display
167167
from .extract import to_tree as to_tree
168168
from .extract import from_tree as from_tree

flax/nnx/bridge/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,4 @@
1919
from .wrappers import lazy_init as lazy_init
2020
from .wrappers import ToLinen as ToLinen
2121
from .wrappers import to_linen as to_linen
22-
from .variables import NNXMeta as NNXMeta
23-
from .variables import register_variable_name_type_pair as register_variable_name_type_pair
22+
from .variables import NNXMeta as NNXMeta

flax/nnx/bridge/variables.py

Lines changed: 13 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from flax.core import meta
2121
from flax.nnx import spmd
2222
from flax.nnx import traversals
23-
from flax.nnx import variablelib as variableslib
23+
from flax.nnx import variablelib
2424
from flax.nnx.module import GraphDef
2525
import typing as tp
2626

@@ -29,56 +29,9 @@
2929
B = TypeVar('B')
3030

3131

32-
#######################################################
33-
### Variable type <-> Linen collection name mapping ###
34-
#######################################################
35-
# Assumption: the mapping is 1-1 and unique.
36-
37-
VariableTypeCache: dict[str, tp.Type[variableslib.Variable[tp.Any]]] = {}
38-
39-
40-
def variable_type(name: str) -> tp.Type[variableslib.Variable[tp.Any]]:
41-
"""Given a Linen-style collection name, get or create its corresponding NNX Variable type."""
42-
if name not in VariableTypeCache:
43-
VariableTypeCache[name] = type(name, (variableslib.Variable,), {})
44-
return VariableTypeCache[name]
45-
46-
47-
def variable_type_name(typ: tp.Type[variableslib.Variable[tp.Any]]) -> str:
48-
"""Given an NNX Variable type, get or create its Linen-style collection name.
49-
50-
Should output the exact inversed result of `variable_type()`."""
51-
for name, t in VariableTypeCache.items():
52-
if typ == t:
53-
return name
54-
name = typ.__name__
55-
if name in VariableTypeCache:
56-
raise ValueError(
57-
'Name {name} is already registered in the registry as {VariableTypeCache[name]}. '
58-
'It cannot be linked with this type {typ}.'
59-
)
60-
register_variable_name_type_pair(name, typ)
61-
return name
62-
63-
64-
def register_variable_name_type_pair(name, typ, overwrite = False):
65-
"""Register a pair of Linen collection name and its NNX type."""
66-
if not overwrite and name in VariableTypeCache:
67-
raise ValueError(f'Name {name} already mapped to type {VariableTypeCache[name]}. '
68-
'To overwrite, call register_variable_name_type_pair() with `overwrite=True`.')
69-
VariableTypeCache[name] = typ
70-
71-
72-
# add known variable type names
73-
register_variable_name_type_pair('params', variableslib.Param)
74-
register_variable_name_type_pair('batch_stats', variableslib.BatchStat)
75-
register_variable_name_type_pair('cache', variableslib.Cache)
76-
register_variable_name_type_pair('intermediates', variableslib.Intermediate)
77-
78-
7932
def sort_variable_types(types: tp.Iterable[type]):
8033
def _variable_parents_count(t: type):
81-
return sum(1 for p in t.mro() if issubclass(p, variableslib.Variable))
34+
return sum(1 for p in t.mro() if issubclass(p, variablelib.Variable))
8235
parent_count = {t: _variable_parents_count(t) for t in types}
8336
return sorted(types, key=lambda t: -parent_count[t])
8437

@@ -91,7 +44,7 @@ def _variable_parents_count(t: type):
9144
class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]):
9245
"""Default Flax metadata class for `nnx.VariableState`."""
9346

94-
var_type: type[variableslib.Variable[tp.Any]] = struct.field(pytree_node=False)
47+
var_type: type[variablelib.Variable[tp.Any]] = struct.field(pytree_node=False)
9548
value: Any = struct.field(pytree_node=True)
9649
metadata: dict[str, tp.Any] = struct.field(pytree_node=False)
9750

@@ -114,11 +67,11 @@ def get_partition_spec(self) -> jax.sharding.PartitionSpec:
11467
nnx_var = self.to_nnx_variable().to_state()
11568
return spmd.get_partition_spec(nnx_var).value
11669

117-
def to_nnx_variable(self) -> variableslib.Variable:
70+
def to_nnx_variable(self) -> variablelib.Variable:
11871
return self.var_type(self.value, **self.metadata)
11972

12073

121-
def is_vanilla_variable(vs: variableslib.VariableState) -> bool:
74+
def is_vanilla_variable(vs: variablelib.VariableState) -> bool:
12275
"""A variables state is vanilla if its metadata is essentially blank.
12376
12477
Returns False only if it has non-empty hooks or any non-built-in attribute.
@@ -132,7 +85,7 @@ def is_vanilla_variable(vs: variableslib.VariableState) -> bool:
13285
return True
13386

13487

135-
def to_linen_var(vs: variableslib.VariableState) -> meta.AxisMetadata:
88+
def to_linen_var(vs: variablelib.VariableState) -> meta.AxisMetadata:
13689
metadata = vs.get_metadata()
13790
if 'linen_meta_type' in metadata:
13891
linen_type = metadata['linen_meta_type']
@@ -151,9 +104,9 @@ def get_col_name(keypath: tp.Sequence[Any]) -> str:
151104
return str(keypath[0].key)
152105

153106

154-
def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variableslib.Variable:
107+
def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variablelib.Variable:
155108
"""Convert a Linen variable to an NNX variable."""
156-
vtype = variable_type(col)
109+
vtype = variablelib.variable_type_from_name(col)
157110
if isinstance(x, NNXMeta):
158111
assert vtype == x.var_type, f'Type stored in NNXMeta {x.var_type} != type inferred from collection name {vtype}'
159112
return x.to_nnx_variable()
@@ -196,14 +149,14 @@ def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict:
196149
for kp, v in traversals.flatten_mapping(
197150
nnx_attrs,
198151
is_leaf=lambda _, x: isinstance(
199-
x, variableslib.Variable | variableslib.VariableState | GraphDef
152+
x, variablelib.Variable | variablelib.VariableState | GraphDef
200153
),
201154
).items():
202-
if isinstance(v, variableslib.Variable):
203-
col_name = variable_type_name(type(v))
155+
if isinstance(v, variablelib.Variable):
156+
col_name = variablelib.variable_name_from_type(type(v))
204157
v = to_linen_var(v.to_state())
205-
elif isinstance(v, variableslib.VariableState):
206-
col_name = variable_type_name(v.type)
158+
elif isinstance(v, variablelib.VariableState):
159+
col_name = variablelib.variable_name_from_type(v.type)
207160
v = to_linen_var(v)
208161
else:
209162
col_name = 'nnx' # it must be an nnx.GraphDef, for some ToLinen submodule

flax/nnx/bridge/wrappers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from flax.core import FrozenDict
2222
from flax.core import meta
2323
from flax.nnx import graph
24+
from flax.nnx import variablelib
2425
from flax.nnx.bridge import variables as bv
2526
from flax.nnx.module import GraphDef, Module
2627
from flax.nnx.object import Object
@@ -271,7 +272,7 @@ def _update_variables(self, module):
271272
# Each variable type goes to its own linen collection, and
272273
# each attribute goes to its own linen variable
273274
for typ, state in zip(types, state_by_types):
274-
collection = bv.variable_type_name(typ)
275+
collection = variablelib.variable_name_from_type(typ)
275276
if self.is_mutable_collection(collection):
276277
for k, v in state.raw_mapping.items():
277278
v = jax.tree.map(bv.to_linen_var, v,

flax/nnx/variablelib.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,6 @@
4444
AddAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None]
4545
RemoveAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None]
4646

47-
VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = {}
48-
4947

5048

5149
@dataclasses.dataclass
@@ -966,3 +964,51 @@ def split_flat_state(
966964
)
967965

968966
return flat_states
967+
968+
969+
970+
###################################################
971+
### Variable type/class <-> string name mapping ###
972+
###################################################
973+
# Assumption: the mapping is 1-1 and unique.
974+
975+
VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = {}
976+
977+
978+
def variable_type_from_name(name: str) -> tp.Type[Variable[tp.Any]]:
979+
"""Given a Linen-style collection name, get or create its NNX Variable class."""
980+
if name not in VariableTypeCache:
981+
VariableTypeCache[name] = type(name, (Variable,), {})
982+
return VariableTypeCache[name]
983+
984+
985+
def variable_name_from_type(typ: tp.Type[Variable[tp.Any]]) -> str:
986+
"""Given an NNX Variable type, get its Linen-style collection name.
987+
988+
Should output the exact inversed result of `variable_type_from_name()`."""
989+
for name, t in VariableTypeCache.items():
990+
if typ == t:
991+
return name
992+
name = typ.__name__
993+
if name in VariableTypeCache:
994+
raise ValueError(
995+
'Name {name} is already registered in the registry as {VariableTypeCache[name]}. '
996+
'It cannot be linked with this type {typ}.'
997+
)
998+
register_variable_name_type_pair(name, typ)
999+
return name
1000+
1001+
1002+
def register_variable_name_type_pair(name, typ, overwrite = False):
1003+
"""Register a pair of Linen collection name and its NNX type."""
1004+
if not overwrite and name in VariableTypeCache:
1005+
raise ValueError(f'Name {name} already mapped to type {VariableTypeCache[name]}. '
1006+
'To overwrite, call register_variable_name_type_pair() with `overwrite=True`.')
1007+
VariableTypeCache[name] = typ
1008+
1009+
1010+
# add known variable type names
1011+
register_variable_name_type_pair('params', Param)
1012+
register_variable_name_type_pair('batch_stats', BatchStat)
1013+
register_variable_name_type_pair('cache', Cache)
1014+
register_variable_name_type_pair('intermediates', Intermediate)

0 commit comments

Comments
 (0)