12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- """Utilities for working with pjit and partitioned models.
15
+ """Legacy utilities for working with pjit and partitioned models.
16
16
17
17
**Experimental: please give feedback, and expect changes.**
18
18
28
28
logical axis metadata to the underlying Lifted transformations.
29
29
"""
30
30
31
- import collections
32
- import contextlib
33
- import enum
34
31
import functools
35
32
import re
36
- import threading
37
- from typing import (Any , Callable , List , Mapping , Optional , Sequence , Tuple ,
38
- Union )
33
+ from typing import (Any , Callable , Mapping , Optional , Tuple )
34
+
39
35
import flax
40
36
from flax import linen as nn
37
+ from flax import struct
41
38
from flax .core .frozen_dict import freeze
42
39
from flax .core .frozen_dict import unfreeze
43
40
from flax .core .lift import In as ScanIn # pylint: disable=unused-import
44
41
from flax .core .lift import Out as ScanOut # pylint: disable=unused-import
45
- import flax .struct
42
+ from flax .linen .spmd import _axis_rules # pylint: disable=unused-import
43
+ from flax .linen .spmd import _AxisRules # pylint: disable=unused-import
44
+ from flax .linen .spmd import _is_logical_spec
45
+ from flax .linen .spmd import _with_sharding_constraint # pylint: disable=unused-import
46
+ from flax .linen .spmd import Array # pylint: disable=unused-import
47
+ from flax .linen .spmd import ArrayPytree # pylint: disable=unused-import
48
+ from flax .linen .spmd import get_logical_axis_rules as get_axis_rules # pylint: disable=unused-import
49
+ from flax .linen .spmd import logical_axis_rules as axis_rules # pylint: disable=unused-import
50
+ from flax .linen .spmd import logical_to_mesh # pylint: disable=unused-import
51
+ from flax .linen .spmd import logical_to_mesh_axes # pylint: disable=unused-import
52
+ from flax .linen .spmd import LogicalPartitionSpec # pylint: disable=unused-import
53
+ from flax .linen .spmd import LogicalPartitionSpecPytree
54
+ from flax .linen .spmd import LogicalRules # pylint: disable=unused-import
55
+ from flax .linen .spmd import PartitionSpecPytree # pylint: disable=unused-import
56
+ from flax .linen .spmd import RulesFallback
57
+ from flax .linen .spmd import set_logical_axis_rules as set_axis_rules # pylint: disable=unused-import
58
+ from flax .linen .spmd import with_logical_constraint as with_sharding_constraint
46
59
from flax .traverse_util import flatten_dict
47
60
from flax .traverse_util import unflatten_dict
48
61
import jax
49
- from jax .experimental import maps
50
62
from jax .experimental import pjit
51
63
52
- # Real types and dummy aliases for documentation
53
- LogicalRules = Sequence [Tuple [str , Union [str , Tuple [str ], None ]]]
54
- Array = Any # pylint: disable=invalid-name
55
- ArrayPytree = Any # pylint: disable=invalid-name
56
- LogicalPartitionSpec = Any # pylint: disable=invalid-name
57
- LogicalPartitionSpecPytree = Any # pylint: disable=invalid-name
58
- PartitionSpecPytree = Any # pylint: disable=invalid-name
59
64
60
- # Dynamic Axis Mapping Context
61
65
# ------------------------------------------------------------------------------
62
-
63
-
64
- class _AxisRules :
65
- """Dynamic logical axis to mesh axis binding context."""
66
-
67
- def __init__ (self ):
68
- self ._thread_data = threading .local ()
69
-
70
- @property
71
- def rules (self ) -> LogicalRules :
72
- if not hasattr (self ._thread_data , 'rules' ):
73
- self ._thread_data .rules = ()
74
- return self ._thread_data .rules
75
-
76
- @rules .setter
77
- def rules (self , value : LogicalRules ):
78
- self ._thread_data .rules = value
79
-
80
-
81
- # Global axis binding context.
82
- _axis_rules = _AxisRules ()
83
-
84
-
85
- def set_axis_rules (rules : LogicalRules ):
86
- """Sets the global logical axis to mesh axis binding."""
87
- _axis_rules .rules = rules
88
-
89
-
90
- def get_axis_rules () -> LogicalRules :
91
- """Returns the global logical axis to mesh axis binding."""
92
- return _axis_rules .rules
93
-
94
-
95
- @contextlib .contextmanager
96
- def axis_rules (rules : LogicalRules ):
97
- """Context manager for setting the logical to mesh axis bindings."""
98
- old_rules = _axis_rules .rules
99
- try :
100
- _axis_rules .rules = rules
101
- yield
102
- finally :
103
- _axis_rules .rules = old_rules
104
-
105
-
106
- class _UnassignedAxis :
107
- """Sentinel class for unassigned logical axis name."""
108
-
109
- def __repr__ (self ):
110
- return 'UnassignedAxis'
111
-
112
- def __bool__ (self ):
113
- return False
114
-
115
-
116
- _unassigned_axis = _UnassignedAxis ()
117
-
118
-
119
- def _mesh_assignment_free (new_assignment , existing_assignments ):
120
- """Determines if a given mesh axis has already been assigned."""
121
- new = set (jax .tree_util .tree_leaves (new_assignment ))
122
- existing = set (jax .tree_util .tree_leaves (existing_assignments ))
123
- if existing .intersection (new ):
124
- return False
125
- return True
126
-
127
-
128
- def _logical_to_mesh_axes (
129
- array_dim_names : Optional [Sequence [Optional [str ]]],
130
- rules : Optional [LogicalRules ] = None ,
131
- ) -> Optional [List [Union [_UnassignedAxis , None , str , Tuple [str ]]]]:
132
- """Same as logical_to_mesh_axes, but doesn't fill in _unassigned_axis."""
133
- if array_dim_names is None :
134
- return None
135
- if rules is None :
136
- rules = _axis_rules .rules
137
- axis_name_counts = collections .Counter (array_dim_names )
138
- dups = tuple (
139
- k for k , v in axis_name_counts .items () if v > 1 and k is not None )
140
- if dups :
141
- raise ValueError (
142
- f'Unsupported: Dimensions { dups } occur more than once in array names.' )
143
- if not isinstance (rules , (tuple , list )):
144
- raise ValueError ('Unknown axis rule specification type.' )
145
- # We assign mesh axes using a priority based ruleset over logical axis names.
146
- result : List [Union [_UnassignedAxis , None , str , Tuple [str ]]]
147
- result = [_unassigned_axis ] * len (array_dim_names )
148
- for rule_model_name , rule_mesh_names in rules :
149
- if rule_model_name in array_dim_names :
150
- pos = array_dim_names .index (rule_model_name )
151
- if (_mesh_assignment_free (rule_mesh_names , result ) and
152
- result [pos ] == _unassigned_axis ):
153
- result [pos ] = rule_mesh_names
154
- return result
155
-
156
-
157
- def logical_to_mesh_axes (
158
- array_dim_names : Optional [Sequence [Optional [str ]]],
159
- rules : Optional [LogicalRules ] = None ,
160
- ) -> Optional [pjit .PartitionSpec ]:
161
- """Compute layout for an array.
162
-
163
- The rules are in order of precedence, and consist of pairs:
164
- (ArrayDimensionName, MeshDimensionName), meaning that the given array
165
- dimension (if present and unused) should be sharded across the given
166
- mesh dimension (if present and unused).
167
-
168
- A Layout of an Array is expressed as a tuple with one element for each
169
- dimension in the Array. The element is either None, or is the name of a
170
- mesh-dimension, meaning that this dimension of the array is sharded across
171
- this dimension of the mesh.
172
-
173
- For example, given an array with
174
- array_dim_names = ('batch', 'length', 'heads', 'features')
175
- and the layout rules are:
176
- rules = (('batch', 'X'),
177
- ('features', 'X'),
178
- ('heads', 'Y'),
179
- ('batch', 'Z'))
180
-
181
- then this function will return
182
-
183
- PartitionSpec('X', None, 'Y', None)
184
-
185
- Args:
186
- array_dim_names: Tuple of array dimension names or None.
187
- rules: Optional logical to mesh rules override. Defaults to using the
188
- rules defined in the dynamic context set from the `axis_rules` function.
189
-
190
- Returns:
191
- PartitionSpec for the parameter.
192
- """
193
- result = _logical_to_mesh_axes (array_dim_names , rules )
194
- if result is None :
195
- return None
196
- # We default to None - ie unsharded along the dimension.
197
- result = [None if x is _unassigned_axis else x for x in result ]
198
- return pjit .PartitionSpec (* result )
199
-
200
-
201
- def _global_mesh_defined () -> bool :
202
- """Checks if global xmap/pjit mesh resource environment is defined."""
203
- maps_env = maps .thread_resources .env
204
- return maps_env .physical_mesh .devices .shape != () # pylint: disable=g-explicit-bool-comparison
205
-
206
-
207
- class RulesFallback (enum .Enum ):
208
- """How a sharding constraint should behave when no matching rule is found."""
209
- AXIS_IS_UNSHARDED = 'axis_is_unsharded'
210
- RAISE_ERROR = 'raise_error'
211
- NO_CONSTRAINT = 'no_constraint'
212
-
213
-
214
- def _with_sharding_constraint (x : Array , axis_resources : Optional [pjit .PartitionSpec ]):
215
- """Wrapper for pjit with_sharding_constraint, no-op on cpu or outside pjit."""
216
- if jax .devices ()[0 ].platform == 'cpu' or not _global_mesh_defined ():
217
- return x
218
- else :
219
- return pjit .with_sharding_constraint (x , axis_resources )
220
-
221
-
222
- def _with_sharding_constraint_one_fallback (
223
- axis_resources : LogicalPartitionSpec ,
224
- x : Array ,
225
- fallback : RulesFallback = RulesFallback .AXIS_IS_UNSHARDED ):
226
- """Either imposes a sharding constraint or applies fallback."""
227
- mesh_axes = _logical_to_mesh_axes (axis_resources )
228
- if mesh_axes is None :
229
- return _with_sharding_constraint (x , None )
230
-
231
- if fallback == RulesFallback .AXIS_IS_UNSHARDED :
232
- mesh_axes = [None if x is _unassigned_axis else x for x in mesh_axes ]
233
- else :
234
- if any (x is _unassigned_axis for x in mesh_axes ):
235
- if fallback == RulesFallback .RAISE_ERROR :
236
- raise ValueError (f'Axis names { axis_resources } did not match a rule' )
237
- else :
238
- return x
239
- return _with_sharding_constraint (x , pjit .PartitionSpec (* mesh_axes ))
240
-
241
-
242
- def _is_logical_spec (x ):
243
- return x is None or (
244
- isinstance (x , tuple ) and all (isinstance (e , str ) or e is None for e in x ))
245
-
246
-
247
- def with_sharding_constraint (
248
- x : ArrayPytree ,
249
- logical_axis_resources : LogicalPartitionSpecPytree ,
250
- fallback : RulesFallback = RulesFallback .AXIS_IS_UNSHARDED ):
251
- """Version of pjit's with_sharding_constraint that uses logical axis names."""
252
- # If no axis binding is set, this is a no-op.
253
- if not _axis_rules .rules or logical_axis_resources is None :
254
- return x
255
- # Translate logical names to mesh assignments.
256
- return jax .tree_util .tree_map (
257
- functools .partial (
258
- _with_sharding_constraint_one_fallback , fallback = fallback ),
259
- logical_axis_resources ,
260
- x ,
261
- is_leaf = _is_logical_spec )
66
+ # NOTICE: This experimental partitioning utility API is deprecated
67
+ #
68
+ # We intend to continue supporting it indefinitely for those using it, but
69
+ # we encourage new users to adopt the simpler metadata handling system found
70
+ # in "spmd.py".
71
+ # ------------------------------------------------------------------------------
262
72
263
73
264
74
# Annotated parameters and Module axis metadata handling.
265
75
# ------------------------------------------------------------------------------
266
76
267
77
268
- @flax . struct .dataclass
78
+ @struct .dataclass
269
79
class AxisMetadata :
270
80
"""Contains a tuple of axis names, which is passed through FLAX."""
271
- names : LogicalPartitionSpecPytree = flax . struct .field (pytree_node = False )
81
+ names : LogicalPartitionSpecPytree = struct .field (pytree_node = False )
272
82
273
83
274
84
def _param_with_axes_sow_reduce_fn (x , y ):
@@ -306,7 +116,7 @@ def param_with_axes(
306
116
init_fn ,
307
117
* init_args ,
308
118
axes : Optional [Tuple [str , ...]] = None ,
309
- module : Optional [nn .Module ] = None ):
119
+ module : Optional [' nn.Module' ] = None ):
310
120
"""Declares and returns a parameter with logical axes in the current Module.
311
121
312
122
See :mod:`flax.linen.module.param` for original docstring.
@@ -340,7 +150,7 @@ def param_with_axes(
340
150
pjit .PartitionSpec (* axes ))
341
151
# record logical axis constraint for global axis metadata
342
152
module .sow (
343
- 'params_axes' , f'{ name } _axes' , AxisMetadata (axes ), # type: ignore
153
+ 'params_axes' , f'{ name } _axes' , AxisMetadata (axes ), # type: ignore
344
154
reduce_fn = _param_with_axes_sow_reduce_fn )
345
155
return module_param
346
156
@@ -418,7 +228,7 @@ def variable_with_axes(
418
228
init_fn ,
419
229
* init_args ,
420
230
axes : Optional [Tuple [str , ...]] = None ,
421
- module : Optional [nn .Module ] = None ,
231
+ module : Optional [' nn.Module' ] = None ,
422
232
fallback : RulesFallback = RulesFallback .AXIS_IS_UNSHARDED ):
423
233
"""Declares and returns a variable with logical axes in the current Module.
424
234
@@ -459,7 +269,7 @@ def variable_with_axes(
459
269
if axes is not None :
460
270
# record logical axis constraint for global axis metadata
461
271
module .sow (
462
- f'{ collection } _axes' , f'{ name } _axes' , AxisMetadata (axes ), # type: ignore
272
+ f'{ collection } _axes' , f'{ name } _axes' , AxisMetadata (axes ), # type: ignore
463
273
reduce_fn = _param_with_axes_sow_reduce_fn )
464
274
return module_var
465
275
@@ -567,7 +377,7 @@ def remove_fn(x):
567
377
568
378
# pylint: disable=dangerous-default-value
569
379
def scan_with_axes (
570
- target : flax .linen .transforms .Target ,
380
+ target : ' flax.linen.transforms.Target' ,
571
381
variable_axes : Mapping [flax .core .lift .CollectionFilter ,
572
382
flax .core .lift .InOutScanAxis ] = {},
573
383
variable_broadcast : flax .core .lift .CollectionFilter = False ,
@@ -581,7 +391,7 @@ def scan_with_axes(
581
391
axis_name : str = 'layers' ,
582
392
axes_collections : Tuple [str , ...] = ('params' ,),
583
393
data_transform : Optional [Callable [..., Any ]] = None ,
584
- methods = None ) -> flax .linen .transforms .Target :
394
+ methods = None ) -> ' flax.linen.transforms.Target' :
585
395
"""Wrapped version of nn.scan that handles logical axis metadata."""
586
396
587
397
# we broadcast the static metadata collections.
@@ -616,7 +426,7 @@ def scan_with_axes(
616
426
617
427
618
428
# pylint: disable=dangerous-default-value
619
- def vmap_with_axes (target : flax .linen .transforms .Target ,
429
+ def vmap_with_axes (target : ' flax.linen.transforms.Target' ,
620
430
variable_axes : Mapping [flax .core .lift .CollectionFilter ,
621
431
flax .core .lift .InOutAxis ],
622
432
split_rngs : Mapping [flax .core .lift .PRNGSequenceFilter ,
@@ -627,7 +437,7 @@ def vmap_with_axes(target: flax.linen.transforms.Target,
627
437
axis_name : Optional [str ] = None ,
628
438
partitioning_axis_names : Mapping [Any , str ] = {},
629
439
spmd_axis_name : Optional [str ] = None ,
630
- methods = None ) -> flax .linen .transforms .Target :
440
+ methods = None ) -> ' flax.linen.transforms.Target' :
631
441
"""Wrapped version of nn.vmap that handles logical axis metadata."""
632
442
633
443
# tell normal vmap to broadcast axis metadata.
0 commit comments