Skip to content

Commit 90c08ca

Browse files
levskayaFlax Authors
authored andcommitted
Use jax.named_scope for name stack rather than named_call.
Removes old function-transform-based implementation of named call in favor of new (post 0.3.14 JAX) jax.named_scope context manager mechanism. Refactors some tests. PiperOrigin-RevId: 464144791
1 parent 8d855f9 commit 90c08ca

File tree

11 files changed

+157
-220
lines changed

11 files changed

+157
-220
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ vNext
66
(Add your change to a random empty line to avoid merge conflicts)
77
-
88
-
9-
-
9+
- Switched to using `jax.named_scope` for all profile naming, cut some pointless
10+
stack traces out.
1011
-
1112
-
1213
-

flax/config.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
2222
.. data:: flax_profile
2323
24-
Whether to automatically wrap Module methods with named_call for profiles.
25-
Set by the FLAX_PROFILE environment variable. Defaults to False.
24+
Whether to automatically wrap Module methods with jax.named_scope for
25+
profiles. Set by the FLAX_PROFILE environment variable. Defaults to True.
2626
"""
2727

2828
import os
@@ -38,6 +38,8 @@ def bool_env(varname: str, default: bool) -> bool:
3838
Args:
3939
varname: the name of the variable
4040
default: the default boolean value
41+
Returns:
42+
boolean return value derived from defaults and environment.
4143
Raises: ValueError if the environment variable is anything else.
4244
"""
4345
val = os.getenv(varname, str(default))
@@ -56,10 +58,8 @@ def bool_env(varname: str, default: bool) -> bool:
5658
# Whether to hide flax-internal stack frames from tracebacks.
5759
flax_filter_frames = bool_env('FLAX_FILTER_FRAMES', True)
5860

59-
# Whether to automatically wrap Module methods with named_call for profiles.
60-
# We profile by default if JAX's dynamic name-stack based named_call is used.
61-
flax_profile = (bool_env('FLAX_PROFILE', False) or
62-
bool_env('JAX_EXPERIMENTAL_NAME_STACK', False))
61+
# Whether to run Module methods under jax.named_scope for profiles.
62+
flax_profile = bool_env('FLAX_PROFILE', True)
6363

6464
# Whether to use the lazy rng implementation
6565
flax_lazy_rng = bool_env('FLAX_LAZY_RNG', True)

flax/core/lift.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,22 +1345,6 @@ def inner_loop(scope, carry):
13451345
return fn
13461346

13471347

1348-
def named_call(fn: Callable[..., Any], name: str) -> Callable[..., Any]:
1349-
"""Adds a name scope to `fn` during profiling."""
1350-
def inner(scope_fn, repack_fn, variable_groups, rng_groups, args, kwargs):
1351-
@functools.wraps(fn)
1352-
def named(variable_groups, rng_groups):
1353-
scope = scope_fn(variable_groups, rng_groups)
1354-
y = fn(scope, *args, **kwargs)
1355-
return y, repack_fn(scope)
1356-
named = jax.named_call(named, name=name)
1357-
return named(variable_groups, rng_groups)
1358-
lifted = pack(inner, (True,), (True,), (True,))
1359-
def wrapper(scope, *args, **kwargs):
1360-
return lifted(scope, args, kwargs)
1361-
return wrapper
1362-
1363-
13641348
def _unzip2(xs):
13651349
ys = tuple(zip(*xs))
13661350
return ys if ys else ((), ())

flax/core/scope.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ def child(self,
581581
fn: the function to partially apply the child Scope to.
582582
name: optional name of the child.
583583
prefix: prefix used for generating name if it is `None`.
584-
named_call: if true, `fn` will be wrapped with `lift.named_call`. The XLA
584+
named_call: if true, `fn` will be run under `jax.named_scope`. The XLA
585585
profiler will use this to name tag the computation.
586586
**partial_kwargs: additional kwargs partially applied to `fn`.
587587
@@ -593,15 +593,16 @@ def child(self,
593593
prefix = fn.__name__ + '_' if hasattr(fn, '__name__') else ''
594594
name = self.default_name(prefix)
595595
scope = self.push(name)
596-
if named_call:
597-
# We import named_call at runtime to avoid a circular import issue.
598-
from . import lift # pylint: disable=g-import-not-at-top
599-
fn = lift.named_call(fn, name)
600596

601597
@functools.wraps(fn)
602598
def wrapper(*args, **kwargs):
603599
kwargs = dict(partial_kwargs, **kwargs)
604-
return fn(scope.rewound(), *args, **kwargs)
600+
if named_call:
601+
with jax.named_scope(name):
602+
res = fn(scope.rewound(), *args, **kwargs)
603+
else:
604+
res = fn(scope.rewound(), *args, **kwargs)
605+
return res
605606

606607
return wrapper
607608

flax/errors.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -588,19 +588,11 @@ def __init__(self, path, step):
588588
#################################################
589589

590590
class TransformedMethodReturnValueError(FlaxError):
591-
"""
592-
Transformed Module methods cannot return other Modules or Variables.
591+
"""Transformed Module methods cannot return other Modules or Variables."""
593592

594-
This commonly occurs when ``@nn.named_call`` is automatically applied to
595-
helper constructor methods when profiling is enabled (``FLAX_PROFILE=true``
596-
environment variable or via ``nn.enable_named_call()``), and can be mitigated
597-
by using the ``@nn.nowrap`` decorator to prevent automatic wrapping.
598-
"""
599593
def __init__(self, name):
600594
super().__init__(
601-
f'Transformed module method {name} cannot return Modules or Variables. '
602-
f'For helper constructor methods use the @nn.nowrap decorator to prevent '
603-
f'decoration by the automatic named_call transform.')
595+
f'Transformed module method {name} cannot return Modules or Variables.')
604596

605597

606598
class TransformTargetError(FlaxError):

flax/linen/module.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -142,15 +142,26 @@ class _Sentinel:
142142
_use_named_call = config.flax_profile
143143

144144

145+
def _derive_profiling_name(module, fn):
146+
def _get_fn_name(fn):
147+
if isinstance(fn, functools.partial):
148+
return _get_fn_name(fn.func)
149+
return fn.__name__
150+
fn_name = _get_fn_name(fn)
151+
method_suffix = f'.{fn_name}' if fn_name != '__call__' else ''
152+
module_name = module.name or module.__class__.__name__
153+
return f'{module_name}{method_suffix}'
154+
155+
145156
def enable_named_call():
146157
"""Enables named call wrapping for labelling profile traces.
147158
148159
When named call wrapping is enabled all JAX ops executed in a Module
149-
will be wrapped with ``jax.named_call``. The ``Module`` class name will
160+
will be run under ``jax.named_scope``. The ``Module`` class name will
150161
show up around the operations belonging to that Module in the
151162
Tensorboard profiling UI, simplifying the profiling process.
152163
153-
Note that ``jax.named_call`` only works for
164+
Note that ``jax.named_scope`` only works for
154165
compiled functions (e.g.: using jax.jit or jax.pmap).
155166
"""
156167
global _use_named_call
@@ -281,9 +292,6 @@ def nowrap(fun: _CallableT) -> _CallableT:
281292
with the state handler or a separate named_call transform.
282293
283294
This is needed in several concrete instances:
284-
- if you have a helper method that returns Modules or Variables to prevent
285-
it from being functionalized by named_call. (Functionalized methods
286-
can't return Modules/Variables.)
287295
- if you're subclassing a method like Module.param and don't want this
288296
overriden core function decorated with the state management wrapper.
289297
- If you want a method to be callable from an unbound Module (e.g.: a
@@ -608,12 +616,7 @@ def _wrap_module_methods(cls):
608616
method = getattr(cls, key)
609617
if hasattr(method, 'nowrap'):
610618
continue
611-
wrapped_method = wrap_method_once(method)
612-
if key != 'setup':
613-
# We import named_call at runtime to avoid a circular import issue.
614-
from flax.linen.transforms import named_call # pylint: disable=g-import-not-at-top
615-
wrapped_method = named_call(wrapped_method, force=False)
616-
setattr(cls, key, wrapped_method)
619+
setattr(cls, key, wrap_method_once(method))
617620
return cls
618621

619622
def _call_wrapped_method(self, fun, args, kwargs):
@@ -649,7 +652,11 @@ def _call_wrapped_method(self, fun, args, kwargs):
649652
self._state.in_compact_method = True
650653
_context.module_stack.append(self)
651654
try:
652-
y = fun(self, *args, **kwargs)
655+
if _use_named_call:
656+
with jax.named_scope(_derive_profiling_name(self, fun)):
657+
y = fun(self, *args, **kwargs)
658+
else:
659+
y = fun(self, *args, **kwargs)
653660
if _context.capture_stack:
654661
filter_fn = _context.capture_stack[-1]
655662
if filter_fn and filter_fn(self, fun_name):
@@ -880,7 +887,7 @@ def run_setup_only(x):
880887
def _name_taken(self,
881888
name: str,
882889
module: 'Module' = None,
883-
reuse_scopes : bool = False) -> bool:
890+
reuse_scopes: bool = False) -> bool:
884891
if name in _all_names_on_object(self):
885892
val = getattr(self, name, None)
886893
if module is not None and val is module:

flax/linen/transforms.py

Lines changed: 12 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from flax.linen.module import Variable
4040
from flax.linen.module import wrap_method_once
4141
from flax.linen.module import _get_unbound_fn
42+
from flax.linen.module import _derive_profiling_name
4243
import jax
4344

4445
traceback_util.register_exclusion(__file__)
@@ -1296,55 +1297,25 @@ def shared_forward_fn(*args, needs_residual, **kwargs):
12961297
multi_scope=False)
12971298

12981299

1299-
# Special case of decorator_lift_transform to handle named calls for profiling.
13001300
def named_call(class_fn, force=True):
13011301
"""Labels a method for labelled traces in profiles.
13021302
1303+
Note that it is better to use the `jax.named_scope` context manager directly
1304+
to add names to JAX's metadata name stack.
1305+
13031306
Args:
13041307
class_fn: The class method to label.
13051308
force: If True, the named_call transform is applied even if it is globally
13061309
disabled. (e.g.: by calling `flax.linen.disable_named_call()`)
13071310
Returns:
13081311
A wrapped version of ``class_fn`` that is labeled.
13091312
"""
1310-
if (hasattr(jax.config, 'jax_experimental_name_stack') and
1311-
jax.config.jax_experimental_name_stack):
1312-
# Use JAX's improved dynamic name-stack named_call.
1313-
# No transform boundary needed!
1314-
@functools.wraps(class_fn)
1315-
def wrapped_fn(self, *args, **kwargs):
1316-
fn_name = class_fn.__name__
1317-
method_suffix = f'.{fn_name}' if fn_name != '__call__' else ''
1318-
module_name = self.name or self.__class__.__name__
1319-
full_name = f'{module_name}{method_suffix}'
1320-
return jax.named_call(class_fn, name=full_name)(self, *args, **kwargs)
1321-
else:
1322-
# Use JAX's old purely-functional call-based named_call.
1323-
# Due to the ordering of method decorators, we must wrap the class_fn
1324-
# with the module state management wrapper first to maintain Module state
1325-
# correctly.
1326-
prewrapped_fn = wrap_method_once(class_fn)
1327-
@functools.wraps(prewrapped_fn)
1328-
def wrapped_fn(self, *args, **kwargs):
1329-
if ((not force and not linen_module._use_named_call)
1330-
or self._state.in_setup):
1331-
return prewrapped_fn(self, *args, **kwargs)
1332-
fn_name = class_fn.__name__
1333-
method_suffix = f'.{fn_name}' if fn_name != '__call__' else ''
1334-
module_name = self.name or self.__class__.__name__
1335-
full_name = f'{module_name}{method_suffix}'
1336-
# make a scope-function to transform
1337-
def core_fn(scopes, *args, **kwargs):
1338-
cloned, args, kwargs = set_module_scopes(self, args, kwargs, scopes)
1339-
object.__setattr__(cloned, '_state', self._state.export())
1340-
res = prewrapped_fn(cloned, *args, **kwargs)
1341-
self._state.reimport(cloned._state)
1342-
_test_transformed_return_values(res, fn_name)
1343-
return res
1344-
# here we apply the given lifting transform to the scope-ingesting fn
1345-
trafo_fn = lift.named_call(core_fn, full_name)
1346-
module_scopes, args, kwargs = get_module_scopes(self, args, kwargs)
1347-
return trafo_fn(module_scopes, *args, **kwargs)
1348-
1313+
# We use JAX's dynamic name-stack named_call. No transform boundary needed!
1314+
@functools.wraps(class_fn)
1315+
def wrapped_fn(self, *args, **kwargs):
1316+
if ((not force and not linen_module._use_named_call) # pylint: disable=protected-access
1317+
or self._state.in_setup): # pylint: disable=protected-access
1318+
return class_fn(self, *args, **kwargs)
1319+
full_name = _derive_profiling_name(self, class_fn)
1320+
return jax.named_call(class_fn, name=full_name)(self, *args, **kwargs)
13491321
return wrapped_fn
1350-

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
install_requires = [
2828
"numpy>=1.12",
29-
"jax>=0.3.2",
29+
"jax>=0.3.14",
3030
"matplotlib", # only needed for tensorboard export
3131
"msgpack",
3232
"optax",

tests/linen/linen_module_test.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from flax import struct
2727
from flax.core import Scope, freeze, tracers
2828
from flax.linen import compact
29-
from flax.linen.module import override_named_call
3029
import jax
3130
from jax import random
3231
from jax.nn import initializers
@@ -1584,13 +1583,12 @@ class MyComponent2(Generic[T], nn.Module):
15841583
class MyModule2(nn.Module):
15851584
submodule: MyComponent2[jnp.ndarray]
15861585

1587-
def test_named_call_rng_equivalance(self):
1586+
def test_jit_rng_equivalance(self):
15881587
model = nn.Dense(1, use_bias=False)
1589-
with override_named_call(False):
1590-
param = model.init(random.PRNGKey(0), np.ones((1, 1)))['params']['kernel']
1591-
with override_named_call(True):
1592-
param_2 = model.init(random.PRNGKey(0), np.ones(
1593-
(1, 1)))['params']['kernel']
1588+
jit_model = nn.jit(nn.Dense)(1, use_bias=False)
1589+
param = model.init(random.PRNGKey(0), np.ones((1, 1)))['params']['kernel']
1590+
param_2 = jit_model.init(random.PRNGKey(0), np.ones(
1591+
(1, 1)))['params']['kernel']
15941592
self.assertEqual(param, param_2)
15951593

15961594
def test_rng_reuse_after_rewind(self):

0 commit comments

Comments
 (0)