20
20
from flax .core import meta
21
21
from flax .nnx import spmd
22
22
from flax .nnx import traversals
23
- from flax .nnx import variablelib as variableslib
23
+ from flax .nnx import variablelib
24
24
from flax .nnx .module import GraphDef
25
25
import typing as tp
26
26
29
29
B = TypeVar ('B' )
30
30
31
31
32
- #######################################################
33
- ### Variable type <-> Linen collection name mapping ###
34
- #######################################################
35
- # Assumption: the mapping is 1-1 and unique.
36
-
37
- VariableTypeCache : dict [str , tp .Type [variableslib .Variable [tp .Any ]]] = {}
38
-
39
-
40
- def variable_type (name : str ) -> tp .Type [variableslib .Variable [tp .Any ]]:
41
- """Given a Linen-style collection name, get or create its corresponding NNX Variable type."""
42
- if name not in VariableTypeCache :
43
- VariableTypeCache [name ] = type (name , (variableslib .Variable ,), {})
44
- return VariableTypeCache [name ]
45
-
46
-
47
- def variable_type_name (typ : tp .Type [variableslib .Variable [tp .Any ]]) -> str :
48
- """Given an NNX Variable type, get or create its Linen-style collection name.
49
-
50
- Should output the exact inversed result of `variable_type()`."""
51
- for name , t in VariableTypeCache .items ():
52
- if typ == t :
53
- return name
54
- name = typ .__name__
55
- if name in VariableTypeCache :
56
- raise ValueError (
57
- 'Name {name} is already registered in the registry as {VariableTypeCache[name]}. '
58
- 'It cannot be linked with this type {typ}.'
59
- )
60
- register_variable_name_type_pair (name , typ )
61
- return name
62
-
63
-
64
- def register_variable_name_type_pair (name , typ , overwrite = False ):
65
- """Register a pair of Linen collection name and its NNX type."""
66
- if not overwrite and name in VariableTypeCache :
67
- raise ValueError (f'Name { name } already mapped to type { VariableTypeCache [name ]} . '
68
- 'To overwrite, call register_variable_name_type_pair() with `overwrite=True`.' )
69
- VariableTypeCache [name ] = typ
70
-
71
-
72
- # add known variable type names
73
- register_variable_name_type_pair ('params' , variableslib .Param )
74
- register_variable_name_type_pair ('batch_stats' , variableslib .BatchStat )
75
- register_variable_name_type_pair ('cache' , variableslib .Cache )
76
- register_variable_name_type_pair ('intermediates' , variableslib .Intermediate )
77
-
78
-
79
32
def sort_variable_types (types : tp .Iterable [type ]):
80
33
def _variable_parents_count (t : type ):
81
- return sum (1 for p in t .mro () if issubclass (p , variableslib .Variable ))
34
+ return sum (1 for p in t .mro () if issubclass (p , variablelib .Variable ))
82
35
parent_count = {t : _variable_parents_count (t ) for t in types }
83
36
return sorted (types , key = lambda t : - parent_count [t ])
84
37
@@ -91,7 +44,7 @@ def _variable_parents_count(t: type):
91
44
class NNXMeta (struct .PyTreeNode , meta .AxisMetadata [A ]):
92
45
"""Default Flax metadata class for `nnx.VariableState`."""
93
46
94
- var_type : type [variableslib .Variable [tp .Any ]] = struct .field (pytree_node = False )
47
+ var_type : type [variablelib .Variable [tp .Any ]] = struct .field (pytree_node = False )
95
48
value : Any = struct .field (pytree_node = True )
96
49
metadata : dict [str , tp .Any ] = struct .field (pytree_node = False )
97
50
@@ -114,11 +67,11 @@ def get_partition_spec(self) -> jax.sharding.PartitionSpec:
114
67
nnx_var = self .to_nnx_variable ().to_state ()
115
68
return spmd .get_partition_spec (nnx_var ).value
116
69
117
- def to_nnx_variable (self ) -> variableslib .Variable :
70
+ def to_nnx_variable (self ) -> variablelib .Variable :
118
71
return self .var_type (self .value , ** self .metadata )
119
72
120
73
121
- def is_vanilla_variable (vs : variableslib .VariableState ) -> bool :
74
+ def is_vanilla_variable (vs : variablelib .VariableState ) -> bool :
122
75
"""A variables state is vanilla if its metadata is essentially blank.
123
76
124
77
Returns False only if it has non-empty hooks or any non-built-in attribute.
@@ -132,7 +85,7 @@ def is_vanilla_variable(vs: variableslib.VariableState) -> bool:
132
85
return True
133
86
134
87
135
- def to_linen_var (vs : variableslib .VariableState ) -> meta .AxisMetadata :
88
+ def to_linen_var (vs : variablelib .VariableState ) -> meta .AxisMetadata :
136
89
metadata = vs .get_metadata ()
137
90
if 'linen_meta_type' in metadata :
138
91
linen_type = metadata ['linen_meta_type' ]
@@ -151,9 +104,9 @@ def get_col_name(keypath: tp.Sequence[Any]) -> str:
151
104
return str (keypath [0 ].key )
152
105
153
106
154
- def to_nnx_var (col : str , x : meta .AxisMetadata | Any ) -> variableslib .Variable :
107
+ def to_nnx_var (col : str , x : meta .AxisMetadata | Any ) -> variablelib .Variable :
155
108
"""Convert a Linen variable to an NNX variable."""
156
- vtype = variable_type (col )
109
+ vtype = variablelib . variable_type_from_name (col )
157
110
if isinstance (x , NNXMeta ):
158
111
assert vtype == x .var_type , f'Type stored in NNXMeta { x .var_type } != type inferred from collection name { vtype } '
159
112
return x .to_nnx_variable ()
@@ -196,14 +149,14 @@ def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict:
196
149
for kp , v in traversals .flatten_mapping (
197
150
nnx_attrs ,
198
151
is_leaf = lambda _ , x : isinstance (
199
- x , variableslib .Variable | variableslib .VariableState | GraphDef
152
+ x , variablelib .Variable | variablelib .VariableState | GraphDef
200
153
),
201
154
).items ():
202
- if isinstance (v , variableslib .Variable ):
203
- col_name = variable_type_name (type (v ))
155
+ if isinstance (v , variablelib .Variable ):
156
+ col_name = variablelib . variable_name_from_type (type (v ))
204
157
v = to_linen_var (v .to_state ())
205
- elif isinstance (v , variableslib .VariableState ):
206
- col_name = variable_type_name (v .type )
158
+ elif isinstance (v , variablelib .VariableState ):
159
+ col_name = variablelib . variable_name_from_type (v .type )
207
160
v = to_linen_var (v )
208
161
else :
209
162
col_name = 'nnx' # it must be an nnx.GraphDef, for some ToLinen submodule
0 commit comments