Skip to content

Commit b13d6bc

Browse files
levskayaFlax Authors
authored andcommitted
add logical axis name indirection to new flax partitioning API
PiperOrigin-RevId: 495211165
1 parent 8a9fbc9 commit b13d6bc

File tree

5 files changed

+352
-231
lines changed

5 files changed

+352
-231
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ vNext
66
(Add your change to a random empty line to avoid merge conflicts)
77
-
88
-
9-
-
9+
- Added logical partitioning helpers for using pjit with Flax.
1010
-
1111
-
1212
-

flax/linen/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""The Flax Module system."""
1616

1717

18-
# pylint: disable=g-multiple-import
18+
# pylint: disable=g-multiple-import,useless-import-alias
1919
# re-export commonly used modules and functions
2020
from .activation import (
2121
PReLU as PReLU,
@@ -67,6 +67,16 @@
6767
unbox as unbox,
6868
PARTITION_NAME as PARTITION_NAME,
6969
)
70+
from .spmd import (
71+
logical_axis_rules as logical_axis_rules,
72+
set_logical_axis_rules as set_logical_axis_rules,
73+
get_logical_axis_rules as get_logical_axis_rules,
74+
logical_to_mesh_axes,
75+
logical_to_mesh,
76+
with_logical_constraint,
77+
LogicallyPartitioned as LogicallyPartitioned,
78+
with_logical_partitioning as with_logical_partitioning,
79+
)
7080
from .initializers import (
7181
ones as ones,
7282
zeros as zeros

flax/linen/partitioning.py

Lines changed: 37 additions & 227 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Utilities for working with pjit and partitioned models.
15+
"""Legacy utilities for working with pjit and partitioned models.
1616
1717
**Experimental: please give feedback, and expect changes.**
1818
@@ -28,247 +28,57 @@
2828
logical axis metadata to the underlying Lifted transformations.
2929
"""
3030

31-
import collections
32-
import contextlib
33-
import enum
3431
import functools
3532
import re
36-
import threading
37-
from typing import (Any, Callable, List, Mapping, Optional, Sequence, Tuple,
38-
Union)
33+
from typing import (Any, Callable, Mapping, Optional, Tuple)
34+
3935
import flax
4036
from flax import linen as nn
37+
from flax import struct
4138
from flax.core.frozen_dict import freeze
4239
from flax.core.frozen_dict import unfreeze
4340
from flax.core.lift import In as ScanIn # pylint: disable=unused-import
4441
from flax.core.lift import Out as ScanOut # pylint: disable=unused-import
45-
import flax.struct
42+
from flax.linen.spmd import _axis_rules # pylint: disable=unused-import
43+
from flax.linen.spmd import _AxisRules # pylint: disable=unused-import
44+
from flax.linen.spmd import _is_logical_spec
45+
from flax.linen.spmd import _with_sharding_constraint # pylint: disable=unused-import
46+
from flax.linen.spmd import Array # pylint: disable=unused-import
47+
from flax.linen.spmd import ArrayPytree # pylint: disable=unused-import
48+
from flax.linen.spmd import get_logical_axis_rules as get_axis_rules # pylint: disable=unused-import
49+
from flax.linen.spmd import logical_axis_rules as axis_rules # pylint: disable=unused-import
50+
from flax.linen.spmd import logical_to_mesh # pylint: disable=unused-import
51+
from flax.linen.spmd import logical_to_mesh_axes # pylint: disable=unused-import
52+
from flax.linen.spmd import LogicalPartitionSpec # pylint: disable=unused-import
53+
from flax.linen.spmd import LogicalPartitionSpecPytree
54+
from flax.linen.spmd import LogicalRules # pylint: disable=unused-import
55+
from flax.linen.spmd import PartitionSpecPytree # pylint: disable=unused-import
56+
from flax.linen.spmd import RulesFallback
57+
from flax.linen.spmd import set_logical_axis_rules as set_axis_rules # pylint: disable=unused-import
58+
from flax.linen.spmd import with_logical_constraint as with_sharding_constraint
4659
from flax.traverse_util import flatten_dict
4760
from flax.traverse_util import unflatten_dict
4861
import jax
49-
from jax.experimental import maps
5062
from jax.experimental import pjit
5163

52-
# Real types and dummy aliases for documentation
53-
LogicalRules = Sequence[Tuple[str, Union[str, Tuple[str], None]]]
54-
Array = Any # pylint: disable=invalid-name
55-
ArrayPytree = Any # pylint: disable=invalid-name
56-
LogicalPartitionSpec = Any # pylint: disable=invalid-name
57-
LogicalPartitionSpecPytree = Any # pylint: disable=invalid-name
58-
PartitionSpecPytree = Any # pylint: disable=invalid-name
5964

60-
# Dynamic Axis Mapping Context
6165
# ------------------------------------------------------------------------------
62-
63-
64-
class _AxisRules:
65-
"""Dynamic logical axis to mesh axis binding context."""
66-
67-
def __init__(self):
68-
self._thread_data = threading.local()
69-
70-
@property
71-
def rules(self) -> LogicalRules:
72-
if not hasattr(self._thread_data, 'rules'):
73-
self._thread_data.rules = ()
74-
return self._thread_data.rules
75-
76-
@rules.setter
77-
def rules(self, value: LogicalRules):
78-
self._thread_data.rules = value
79-
80-
81-
# Global axis binding context.
82-
_axis_rules = _AxisRules()
83-
84-
85-
def set_axis_rules(rules: LogicalRules):
86-
"""Sets the global logical axis to mesh axis binding."""
87-
_axis_rules.rules = rules
88-
89-
90-
def get_axis_rules() -> LogicalRules:
91-
"""Returns the global logical axis to mesh axis binding."""
92-
return _axis_rules.rules
93-
94-
95-
@contextlib.contextmanager
96-
def axis_rules(rules: LogicalRules):
97-
"""Context manager for setting the logical to mesh axis bindings."""
98-
old_rules = _axis_rules.rules
99-
try:
100-
_axis_rules.rules = rules
101-
yield
102-
finally:
103-
_axis_rules.rules = old_rules
104-
105-
106-
class _UnassignedAxis:
107-
"""Sentinel class for unassigned logical axis name."""
108-
109-
def __repr__(self):
110-
return 'UnassignedAxis'
111-
112-
def __bool__(self):
113-
return False
114-
115-
116-
_unassigned_axis = _UnassignedAxis()
117-
118-
119-
def _mesh_assignment_free(new_assignment, existing_assignments):
120-
"""Determines if a given mesh axis has already been assigned."""
121-
new = set(jax.tree_util.tree_leaves(new_assignment))
122-
existing = set(jax.tree_util.tree_leaves(existing_assignments))
123-
if existing.intersection(new):
124-
return False
125-
return True
126-
127-
128-
def _logical_to_mesh_axes(
129-
array_dim_names: Optional[Sequence[Optional[str]]],
130-
rules: Optional[LogicalRules] = None,
131-
) -> Optional[List[Union[_UnassignedAxis, None, str, Tuple[str]]]]:
132-
"""Same as logical_to_mesh_axes, but doesn't fill in _unassigned_axis."""
133-
if array_dim_names is None:
134-
return None
135-
if rules is None:
136-
rules = _axis_rules.rules
137-
axis_name_counts = collections.Counter(array_dim_names)
138-
dups = tuple(
139-
k for k, v in axis_name_counts.items() if v > 1 and k is not None)
140-
if dups:
141-
raise ValueError(
142-
f'Unsupported: Dimensions {dups} occur more than once in array names.')
143-
if not isinstance(rules, (tuple, list)):
144-
raise ValueError('Unknown axis rule specification type.')
145-
# We assign mesh axes using a priority based ruleset over logical axis names.
146-
result: List[Union[_UnassignedAxis, None, str, Tuple[str]]]
147-
result = [_unassigned_axis] * len(array_dim_names)
148-
for rule_model_name, rule_mesh_names in rules:
149-
if rule_model_name in array_dim_names:
150-
pos = array_dim_names.index(rule_model_name)
151-
if (_mesh_assignment_free(rule_mesh_names, result) and
152-
result[pos] == _unassigned_axis):
153-
result[pos] = rule_mesh_names
154-
return result
155-
156-
157-
def logical_to_mesh_axes(
158-
array_dim_names: Optional[Sequence[Optional[str]]],
159-
rules: Optional[LogicalRules] = None,
160-
) -> Optional[pjit.PartitionSpec]:
161-
"""Compute layout for an array.
162-
163-
The rules are in order of precedence, and consist of pairs:
164-
(ArrayDimensionName, MeshDimensionName), meaning that the given array
165-
dimension (if present and unused) should be sharded across the given
166-
mesh dimension (if present and unused).
167-
168-
A Layout of an Array is expressed as a tuple with one element for each
169-
dimension in the Array. The element is either None, or is the name of a
170-
mesh-dimension, meaning that this dimension of the array is sharded across
171-
this dimension of the mesh.
172-
173-
For example, given an array with
174-
array_dim_names = ('batch', 'length', 'heads', 'features')
175-
and the layout rules are:
176-
rules = (('batch', 'X'),
177-
('features', 'X'),
178-
('heads', 'Y'),
179-
('batch', 'Z'))
180-
181-
then this function will return
182-
183-
PartitionSpec('X', None, 'Y', None)
184-
185-
Args:
186-
array_dim_names: Tuple of array dimension names or None.
187-
rules: Optional logical to mesh rules override. Defaults to using the
188-
rules defined in the dynamic context set from the `axis_rules` function.
189-
190-
Returns:
191-
PartitionSpec for the parameter.
192-
"""
193-
result = _logical_to_mesh_axes(array_dim_names, rules)
194-
if result is None:
195-
return None
196-
# We default to None - ie unsharded along the dimension.
197-
result = [None if x is _unassigned_axis else x for x in result]
198-
return pjit.PartitionSpec(*result)
199-
200-
201-
def _global_mesh_defined() -> bool:
202-
"""Checks if global xmap/pjit mesh resource environment is defined."""
203-
maps_env = maps.thread_resources.env
204-
return maps_env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison
205-
206-
207-
class RulesFallback(enum.Enum):
208-
"""How a sharding constraint should behave when no matching rule is found."""
209-
AXIS_IS_UNSHARDED = 'axis_is_unsharded'
210-
RAISE_ERROR = 'raise_error'
211-
NO_CONSTRAINT = 'no_constraint'
212-
213-
214-
def _with_sharding_constraint(x: Array, axis_resources: Optional[pjit.PartitionSpec]):
215-
"""Wrapper for pjit with_sharding_constraint, no-op on cpu or outside pjit."""
216-
if jax.devices()[0].platform == 'cpu' or not _global_mesh_defined():
217-
return x
218-
else:
219-
return pjit.with_sharding_constraint(x, axis_resources)
220-
221-
222-
def _with_sharding_constraint_one_fallback(
223-
axis_resources: LogicalPartitionSpec,
224-
x: Array,
225-
fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED):
226-
"""Either imposes a sharding constraint or applies fallback."""
227-
mesh_axes = _logical_to_mesh_axes(axis_resources)
228-
if mesh_axes is None:
229-
return _with_sharding_constraint(x, None)
230-
231-
if fallback == RulesFallback.AXIS_IS_UNSHARDED:
232-
mesh_axes = [None if x is _unassigned_axis else x for x in mesh_axes]
233-
else:
234-
if any(x is _unassigned_axis for x in mesh_axes):
235-
if fallback == RulesFallback.RAISE_ERROR:
236-
raise ValueError(f'Axis names {axis_resources} did not match a rule')
237-
else:
238-
return x
239-
return _with_sharding_constraint(x, pjit.PartitionSpec(*mesh_axes))
240-
241-
242-
def _is_logical_spec(x):
243-
return x is None or (
244-
isinstance(x, tuple) and all(isinstance(e, str) or e is None for e in x))
245-
246-
247-
def with_sharding_constraint(
248-
x: ArrayPytree,
249-
logical_axis_resources: LogicalPartitionSpecPytree,
250-
fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED):
251-
"""Version of pjit's with_sharding_constraint that uses logical axis names."""
252-
# If no axis binding is set, this is a no-op.
253-
if not _axis_rules.rules or logical_axis_resources is None:
254-
return x
255-
# Translate logical names to mesh assignments.
256-
return jax.tree_util.tree_map(
257-
functools.partial(
258-
_with_sharding_constraint_one_fallback, fallback=fallback),
259-
logical_axis_resources,
260-
x,
261-
is_leaf=_is_logical_spec)
66+
# NOTICE: This experimental partitioning utility API is deprecated
67+
#
68+
# We intend to continue supporting it indefinitely for those using it, but
69+
# we encourage new users to adopt the simpler metadata handling system found
70+
# in "spmd.py".
71+
# ------------------------------------------------------------------------------
26272

26373

26474
# Annotated parameters and Module axis metadata handling.
26575
# ------------------------------------------------------------------------------
26676

26777

268-
@flax.struct.dataclass
78+
@struct.dataclass
26979
class AxisMetadata:
27080
"""Contains a tuple of axis names, which is passed through FLAX."""
271-
names: LogicalPartitionSpecPytree = flax.struct.field(pytree_node=False)
81+
names: LogicalPartitionSpecPytree = struct.field(pytree_node=False)
27282

27383

27484
def _param_with_axes_sow_reduce_fn(x, y):
@@ -306,7 +116,7 @@ def param_with_axes(
306116
init_fn,
307117
*init_args,
308118
axes: Optional[Tuple[str, ...]] = None,
309-
module: Optional[nn.Module] = None):
119+
module: Optional['nn.Module'] = None):
310120
"""Declares and returns a parameter with logical axes in the current Module.
311121
312122
See :mod:`flax.linen.module.param` for original docstring.
@@ -340,7 +150,7 @@ def param_with_axes(
340150
pjit.PartitionSpec(*axes))
341151
# record logical axis constraint for global axis metadata
342152
module.sow(
343-
'params_axes', f'{name}_axes', AxisMetadata(axes), # type: ignore
153+
'params_axes', f'{name}_axes', AxisMetadata(axes), # type: ignore
344154
reduce_fn=_param_with_axes_sow_reduce_fn)
345155
return module_param
346156

@@ -418,7 +228,7 @@ def variable_with_axes(
418228
init_fn,
419229
*init_args,
420230
axes: Optional[Tuple[str, ...]] = None,
421-
module: Optional[nn.Module] = None,
231+
module: Optional['nn.Module'] = None,
422232
fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED):
423233
"""Declares and returns a variable with logical axes in the current Module.
424234
@@ -459,7 +269,7 @@ def variable_with_axes(
459269
if axes is not None:
460270
# record logical axis constraint for global axis metadata
461271
module.sow(
462-
f'{collection}_axes', f'{name}_axes', AxisMetadata(axes), # type: ignore
272+
f'{collection}_axes', f'{name}_axes', AxisMetadata(axes), # type: ignore
463273
reduce_fn=_param_with_axes_sow_reduce_fn)
464274
return module_var
465275

@@ -567,7 +377,7 @@ def remove_fn(x):
567377

568378
# pylint: disable=dangerous-default-value
569379
def scan_with_axes(
570-
target: flax.linen.transforms.Target,
380+
target: 'flax.linen.transforms.Target',
571381
variable_axes: Mapping[flax.core.lift.CollectionFilter,
572382
flax.core.lift.InOutScanAxis] = {},
573383
variable_broadcast: flax.core.lift.CollectionFilter = False,
@@ -581,7 +391,7 @@ def scan_with_axes(
581391
axis_name: str = 'layers',
582392
axes_collections: Tuple[str, ...] = ('params',),
583393
data_transform: Optional[Callable[..., Any]] = None,
584-
methods=None) -> flax.linen.transforms.Target:
394+
methods=None) -> 'flax.linen.transforms.Target':
585395
"""Wrapped version of nn.scan that handles logical axis metadata."""
586396

587397
# we broadcast the static metadata collections.
@@ -616,7 +426,7 @@ def scan_with_axes(
616426

617427

618428
# pylint: disable=dangerous-default-value
619-
def vmap_with_axes(target: flax.linen.transforms.Target,
429+
def vmap_with_axes(target: 'flax.linen.transforms.Target',
620430
variable_axes: Mapping[flax.core.lift.CollectionFilter,
621431
flax.core.lift.InOutAxis],
622432
split_rngs: Mapping[flax.core.lift.PRNGSequenceFilter,
@@ -627,7 +437,7 @@ def vmap_with_axes(target: flax.linen.transforms.Target,
627437
axis_name: Optional[str] = None,
628438
partitioning_axis_names: Mapping[Any, str] = {},
629439
spmd_axis_name: Optional[str] = None,
630-
methods=None) -> flax.linen.transforms.Target:
440+
methods=None) -> 'flax.linen.transforms.Target':
631441
"""Wrapped version of nn.vmap that handles logical axis metadata."""
632442

633443
# tell normal vmap to broadcast axis metadata.

0 commit comments

Comments
 (0)