
Error when using nn.scan with negative output_axes #3460

Closed
@lucaslingle

Description


System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): N/A
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax==0.6.11, jax==0.4.9, jaxlib==0.4.9
  • Python version: 3.8
  • GPU/TPU model and memory: N/A
  • CUDA version (if applicable): N/A

Problem you have encountered:

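When flax.linen.scan is called with a negative out_axes, it raises an unexpected AssertionError. If I have understood the source code correctly, the cause is a typo in flax/core/axes_scan.py around line 101 (https://github.com/google/flax/blob/main/flax/core/axes_scan.py#L101): the negative axis is normalized with a minus sign where a plus sign is needed, so the normalized axis lands past the last dimension and fails the `assert pax < x.ndim` check.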
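A minimal sketch of the suspected sign error, written in plain NumPy rather than Flax's actual code. The helpers `normalize_axis_buggy`, `normalize_axis_fixed` and `transpose_from_front` are hypothetical reconstructions: they are modeled on the `transpose_from_front` / `trans` frames and the `assert pax < x.ndim` line in the traceback, assuming the library moves the stacked scan axis from the front to position `pax`.

```python
import numpy as np


def normalize_axis_buggy(ax: int, ndim: int) -> int:
    # Hypothetical reconstruction of the suspected typo: subtracting a
    # negative axis pushes it past the last valid index (-1 -> ndim + 1).
    return ndim - ax if ax < 0 else ax


def normalize_axis_fixed(ax: int, ndim: int) -> int:
    # Adding a negative axis to ndim maps -1 -> ndim - 1, -2 -> ndim - 2, ...
    return ndim + ax if ax < 0 else ax


def transpose_from_front(x: np.ndarray, ax: int) -> np.ndarray:
    """Move the leading (scan) axis of x to position ax (hypothetical helper)."""
    pax = normalize_axis_fixed(ax, x.ndim)
    assert 0 <= pax < x.ndim
    # Axes 1..pax shift left by one, and the former axis 0 lands at pax.
    perm = tuple(range(1, pax + 1)) + (0,) + tuple(range(pax + 1, x.ndim))
    return x.transpose(perm)


# Outputs stacked by scan in the repro: 8 steps, each producing (128, 100).
ys = np.zeros((8, 128, 100))
print(normalize_axis_buggy(-1, ys.ndim))   # 4, which fails `pax < x.ndim`
print(normalize_axis_fixed(-1, ys.ndim))   # 2
print(transpose_from_front(ys, -1).shape)  # (128, 100, 8)
```

With the plus sign, `out_axes=-1` behaves like `np.moveaxis(ys, 0, -1)`.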

What you expected to happen:

nn.scan should run as usual and stack the per-step outputs along the specified axis (here, the last axis).
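To make the expected result concrete, the NumPy sketch below emulates the repro's scan by hand: it loops over axis 0 of the input (`in_axes=0`) and stacks the per-step outputs on the last axis (`out_axes=-1`). The random matrix `w` is a hypothetical stand-in for the `Dense(100)` kernel (bias omitted); this only illustrates the shapes nn.scan should produce, not how it is implemented.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 128, 16))  # scanned over axis 0, as with in_axes=0
w = rng.normal(size=(16, 100))     # hypothetical stand-in for the Dense(100) kernel

# Reference loop: one (128, 100) output per scan step.
step_outputs = [x[t] @ w for t in range(x.shape[0])]

# out_axes=-1: the scan axis becomes the last axis of the result.
expected = np.stack(step_outputs, axis=-1)
print(expected.shape)  # (128, 100, 8)

# Equivalent: stack on the front (out_axes=0), then move that axis to the end.
assert np.allclose(expected, np.moveaxis(np.stack(step_outputs, axis=0), 0, -1))
```

The last line suggests a possible workaround until negative axes are handled: scan with `out_axes=0` and call `jnp.moveaxis(outputs, 0, -1)` afterwards. This has not been verified against flax 0.6.11.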

Logs, error messages, etc:

(projectabcde) lucaslingle@Lucass-MacBook-Pro projectabcde % python3 scripts/scan_issue.py
Traceback (most recent call last):
  File "scripts/scan_issue.py", line 39, in <module>
    main()
  File "scripts/scan_issue.py", line 32, in main
    params = cls().init(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/module.py", line 1689, in init
    _, v_out = self.init_with_output(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/module.py", line 1594, in init_with_output
    return init_with_output(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/scope.py", line 968, in wrapper
    return apply(fn, mutable=mutable, flags=init_flags)({}, *args, rngs=rngs,
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/scope.py", line 936, in wrapper
    y = fn(root, *args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/module.py", line 2170, in scope_fn
    return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/module.py", line 432, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/module.py", line 868, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "scripts/scan_issue.py", line 18, in __call__
    _, outputs = nn.scan(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/transforms.py", line 323, in wrapped_fn
    ret = trafo_fn(module_scopes, *args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/lift.py", line 219, in wrapper
    y, out_variable_groups_xs_t = fn(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/lift.py", line 806, in inner
    broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 151, in scan_fn
    ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/tree_util.py", line 210, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/tree_util.py", line 210, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 106, in transpose_from_front
    return jax.tree_util.tree_map(trans, xs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/tree_util.py", line 210, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/tree_util.py", line 210, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 103, in trans
    assert pax < x.ndim
jax._src.traceback_util.UnfilteredStackTrace: AssertionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "scripts/scan_issue.py", line 39, in <module>
    main()
  File "scripts/scan_issue.py", line 32, in main
    params = cls().init(
  File "scripts/scan_issue.py", line 18, in __call__
    _, outputs = nn.scan(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 151, in scan_fn
    ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 106, in transpose_from_front
    return jax.tree_util.tree_map(trans, xs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 103, in trans
    assert pax < x.ndim
AssertionError

Steps to reproduce:

# issue appears to be at https://github.com/google/flax/blob/main/flax/core/axes_scan.py#L101

import flax.linen as nn
import jax.random


class Foo(nn.Module):
    unused_config: int

    @nn.compact
    def __call__(self, state, input_dict):
        return state, nn.Dense(100)(input_dict["x"])


class Bar(nn.Module):
    @nn.compact
    def __call__(self, x):
        _, outputs = nn.scan(
            Foo,
            variable_broadcast="params",
            split_rngs=dict(
                params=False,
            ),
            in_axes=0,
            out_axes=-1,  # a negative out_axes triggers the AssertionError
        )(unused_config=123)(dict(unused_state_item=None), dict(x=x))
        return outputs


def main():
    cls = Bar
    params = cls().init(
        {"params": jax.random.PRNGKey(0)},
        jax.random.normal(jax.random.PRNGKey(1), shape=[8, 128, 16])
    )["params"]


if __name__ == "__main__":
    main()

Thank you for your attention to this matter!
