Skip to content

Use jax.named_scope for name stack rather than named_call. #2349

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ vNext
(Add your change to a random empty line to avoid merge conflicts)
-
-
-
- Switched to using `jax.named_scope` for all profile naming, and removed some
pointless stack traces.
-
-
-
Expand Down
12 changes: 6 additions & 6 deletions flax/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

.. data:: flax_profile

Whether to automatically wrap Module methods with named_call for profiles.
Set by the FLAX_PROFILE environment variable. Defaults to False.
Whether to automatically wrap Module methods with jax.named_scope for
profiles. Set by the FLAX_PROFILE environment variable. Defaults to True.
"""

import os
Expand All @@ -38,6 +38,8 @@ def bool_env(varname: str, default: bool) -> bool:
Args:
varname: the name of the variable
default: the default boolean value
Returns:
boolean return value derived from defaults and environment.
Raises: ValueError if the environment variable is anything else.
"""
val = os.getenv(varname, str(default))
Expand All @@ -56,10 +58,8 @@ def bool_env(varname: str, default: bool) -> bool:
# Whether to hide flax-internal stack frames from tracebacks.
flax_filter_frames = bool_env('FLAX_FILTER_FRAMES', True)

# Whether to automatically wrap Module methods with named_call for profiles.
# We profile by default if JAX's dynamic name-stack based named_call is used.
flax_profile = (bool_env('FLAX_PROFILE', False) or
bool_env('JAX_EXPERIMENTAL_NAME_STACK', False))
# Whether to run Module methods under jax.named_scope for profiles.
flax_profile = bool_env('FLAX_PROFILE', True)

# Whether to use the lazy rng implementation
flax_lazy_rng = bool_env('FLAX_LAZY_RNG', True)
16 changes: 0 additions & 16 deletions flax/core/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -1345,22 +1345,6 @@ def inner_loop(scope, carry):
return fn


def named_call(fn: Callable[..., Any], name: str) -> Callable[..., Any]:
  """Wraps `fn` so that it executes under the profiling label `name`.

  Args:
    fn: a function that takes a scope as its first argument.
    name: the label handed to `jax.named_call`.
  Returns:
    A function called as `wrapper(scope, *args, **kwargs)`.
  """
  def inner(scope_fn, repack_fn, variable_groups, rng_groups, args, kwargs):
    @functools.wraps(fn)
    def run_in_scope(variable_groups, rng_groups):
      inner_scope = scope_fn(variable_groups, rng_groups)
      result = fn(inner_scope, *args, **kwargs)
      return result, repack_fn(inner_scope)
    labelled = jax.named_call(run_in_scope, name=name)
    return labelled(variable_groups, rng_groups)

  packed = pack(inner, (True,), (True,), (True,))

  def wrapper(scope, *args, **kwargs):
    # args and kwargs are forwarded as two plain positional values.
    return packed(scope, args, kwargs)

  return wrapper


def _unzip2(xs):
  """Transposes an iterable of pairs; empty input yields two empty tuples."""
  transposed = tuple(zip(*xs))
  if not transposed:
    return ((), ())
  return transposed
13 changes: 7 additions & 6 deletions flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def child(self,
fn: the function to partially apply the child Scope to.
name: optional name of the child.
prefix: prefix used for generating name if it is `None`.
named_call: if true, `fn` will be wrapped with `lift.named_call`. The XLA
named_call: if true, `fn` will be run under `jax.named_scope`. The XLA
profiler will use this to name tag the computation.
**partial_kwargs: additional kwargs partially applied to `fn`.

Expand All @@ -593,15 +593,16 @@ def child(self,
prefix = fn.__name__ + '_' if hasattr(fn, '__name__') else ''
name = self.default_name(prefix)
scope = self.push(name)
if named_call:
# We import named_call at runtime to avoid a circular import issue.
from . import lift # pylint: disable=g-import-not-at-top
fn = lift.named_call(fn, name)

@functools.wraps(fn)
def wrapper(*args, **kwargs):
kwargs = dict(partial_kwargs, **kwargs)
return fn(scope.rewound(), *args, **kwargs)
if named_call:
with jax.named_scope(name):
res = fn(scope.rewound(), *args, **kwargs)
else:
res = fn(scope.rewound(), *args, **kwargs)
return res

return wrapper

Expand Down
12 changes: 2 additions & 10 deletions flax/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,19 +588,11 @@ def __init__(self, path, step):
#################################################

class TransformedMethodReturnValueError(FlaxError):
"""
Transformed Module methods cannot return other Modules or Variables.
"""Transformed Module methods cannot return other Modules or Variables."""

This commonly occurs when ``@nn.named_call`` is automatically applied to
helper constructor methods when profiling is enabled (``FLAX_PROFILE=true``
environment variable or via ``nn.enable_named_call()``), and can be mitigated
by using the ``@nn.nowrap`` decorator to prevent automatic wrapping.
"""
def __init__(self, name):
super().__init__(
f'Transformed module method {name} cannot return Modules or Variables. '
f'For helper constructor methods use the @nn.nowrap decorator to prevent '
f'decoration by the automatic named_call transform.')
f'Transformed module method {name} cannot return Modules or Variables.')


class TransformTargetError(FlaxError):
Expand Down
33 changes: 20 additions & 13 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,26 @@ class _Sentinel:
_use_named_call = config.flax_profile


def _derive_profiling_name(module, fn):
  """Builds the label used to tag a call of `fn` on `module`.

  Args:
    module: object with a `name` attribute; its class name is used instead
      when `name` is falsy.
    fn: the callable being labelled. Any `functools.partial` layers are
      peeled off to reach the underlying callable's `__name__`.
  Returns:
    The module name alone when the callable is named `__call__`, otherwise
    `<module name>.<callable name>`.
  """
  target = fn
  # Partials carry no __name__ of their own; walk down to the wrapped callable.
  while isinstance(target, functools.partial):
    target = target.func
  fn_name = target.__name__
  module_name = module.name or module.__class__.__name__
  if fn_name == '__call__':
    return f'{module_name}'
  return f'{module_name}.{fn_name}'


def enable_named_call():
"""Enables named call wrapping for labelling profile traces.

When named call wrapping is enabled all JAX ops executed in a Module
will be wrapped with ``jax.named_call``. The ``Module`` class name will
will be run under ``jax.named_scope``. The ``Module`` class name will
show up around the operations belonging to that Module in the
Tensorboard profiling UI, simplifying the profiling process.

Note that ``jax.named_call`` only works for
Note that ``jax.named_scope`` only works for
compiled functions (e.g.: using jax.jit or jax.pmap).
"""
global _use_named_call
Expand Down Expand Up @@ -281,9 +292,6 @@ def nowrap(fun: _CallableT) -> _CallableT:
with the state handler or a separate named_call transform.

This is needed in several concrete instances:
- if you have a helper method that returns Modules or Variables to prevent
it from being functionalized by named_call. (Functionalized methods
can't return Modules/Variables.)
- if you're subclassing a method like Module.param and don't want this
overridden core function decorated with the state management wrapper.
- If you want a method to be callable from an unbound Module (e.g.: a
Expand Down Expand Up @@ -608,12 +616,7 @@ def _wrap_module_methods(cls):
method = getattr(cls, key)
if hasattr(method, 'nowrap'):
continue
wrapped_method = wrap_method_once(method)
if key != 'setup':
# We import named_call at runtime to avoid a circular import issue.
from flax.linen.transforms import named_call # pylint: disable=g-import-not-at-top
wrapped_method = named_call(wrapped_method, force=False)
setattr(cls, key, wrapped_method)
setattr(cls, key, wrap_method_once(method))
return cls

def _call_wrapped_method(self, fun, args, kwargs):
Expand Down Expand Up @@ -649,7 +652,11 @@ def _call_wrapped_method(self, fun, args, kwargs):
self._state.in_compact_method = True
_context.module_stack.append(self)
try:
y = fun(self, *args, **kwargs)
if _use_named_call:
with jax.named_scope(_derive_profiling_name(self, fun)):
y = fun(self, *args, **kwargs)
else:
y = fun(self, *args, **kwargs)
if _context.capture_stack:
filter_fn = _context.capture_stack[-1]
if filter_fn and filter_fn(self, fun_name):
Expand Down Expand Up @@ -880,7 +887,7 @@ def run_setup_only(x):
def _name_taken(self,
name: str,
module: 'Module' = None,
reuse_scopes : bool = False) -> bool:
reuse_scopes: bool = False) -> bool:
if name in _all_names_on_object(self):
val = getattr(self, name, None)
if module is not None and val is module:
Expand Down
53 changes: 12 additions & 41 deletions flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from flax.linen.module import Variable
from flax.linen.module import wrap_method_once
from flax.linen.module import _get_unbound_fn
from flax.linen.module import _derive_profiling_name
import jax

traceback_util.register_exclusion(__file__)
Expand Down Expand Up @@ -1296,55 +1297,25 @@ def shared_forward_fn(*args, needs_residual, **kwargs):
multi_scope=False)


# Special case of decorator_lift_transform to handle named calls for profiling.
def named_call(class_fn, force=True):
  """Labels a method for labelled traces in profiles.

  Note that it is better to use the `jax.named_scope` context manager directly
  to add names to JAX's metadata name stack.

  Args:
    class_fn: The class method to label.
    force: If True, the named_call transform is applied even if it is globally
      disabled. (e.g.: by calling `flax.linen.disable_named_call()`)
  Returns:
    A wrapped version of ``class_fn`` that is labeled.
  """
  # We use JAX's dynamic name-stack named_call. No transform boundary needed!
  @functools.wraps(class_fn)
  def wrapped_fn(self, *args, **kwargs):
    # Run the method unlabelled when labelling is globally disabled (and not
    # forced), or while the module is in its setup phase.
    if ((not force and not linen_module._use_named_call)  # pylint: disable=protected-access
        or self._state.in_setup):  # pylint: disable=protected-access
      return class_fn(self, *args, **kwargs)
    full_name = _derive_profiling_name(self, class_fn)
    return jax.named_call(class_fn, name=full_name)(self, *args, **kwargs)
  return wrapped_fn

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

install_requires = [
"numpy>=1.12",
"jax>=0.3.2",
"jax>=0.3.14",
"matplotlib", # only needed for tensorboard export
"msgpack",
"optax",
Expand Down
12 changes: 5 additions & 7 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from flax import struct
from flax.core import Scope, freeze, tracers
from flax.linen import compact
from flax.linen.module import override_named_call
import jax
from jax import random
from jax.nn import initializers
Expand Down Expand Up @@ -1584,13 +1583,12 @@ class MyComponent2(Generic[T], nn.Module):
class MyModule2(nn.Module):
submodule: MyComponent2[jnp.ndarray]

def test_named_call_rng_equivalance(self):
def test_jit_rng_equivalance(self):
model = nn.Dense(1, use_bias=False)
with override_named_call(False):
param = model.init(random.PRNGKey(0), np.ones((1, 1)))['params']['kernel']
with override_named_call(True):
param_2 = model.init(random.PRNGKey(0), np.ones(
(1, 1)))['params']['kernel']
jit_model = nn.jit(nn.Dense)(1, use_bias=False)
param = model.init(random.PRNGKey(0), np.ones((1, 1)))['params']['kernel']
param_2 = jit_model.init(random.PRNGKey(0), np.ones(
(1, 1)))['params']['kernel']
self.assertEqual(param, param_2)

def test_rng_reuse_after_rewind(self):
Expand Down
Loading