Closed
Description
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): N/A
- Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`): flax==0.6.11, jax==0.4.9, jaxlib==0.4.9
- Python version: 3.8
- GPU/TPU model and memory: N/A
- CUDA version (if applicable): N/A
Problem you have encountered:
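When using `flax.linen.scan` with a negative `out_axes`, there is an unexpected `AssertionError`. If I have understood the source code correctly, it is due to a typo in `flax/core/axes_scan.py`: a minus sign is used instead of a plus sign when normalizing the negative axis. As a result, the computed axis exceeds `x.ndim` and trips `assert pax < x.ndim`.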
What you expected to happen:
Scan should run as usual, stacking the outputs along the specified axis (here, the last axis).
Logs, error messages, etc:
(projectabcde) lucaslingle@Lucass-MacBook-Pro projectabcde % python3 scripts/scan_issue.py
Traceback (most recent call last):
File "scripts/scan_issue.py", line 39, in <module>
main()
File "scripts/scan_issue.py", line 32, in main
params = cls().init(
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/module.py", line 1689, in init
_, v_out = self.init_with_output(
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/module.py", line 1594, in init_with_output
return init_with_output(
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/scope.py", line 968, in wrapper
return apply(fn, mutable=mutable, flags=init_flags)({}, *args, rngs=rngs,
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/scope.py", line 936, in wrapper
y = fn(root, *args, **kwargs)
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/module.py", line 2170, in scope_fn
return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/module.py", line 432, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/module.py", line 868, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "scripts/scan_issue.py", line 18, in __call__
_, outputs = nn.scan(
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/transforms.py", line 323, in wrapped_fn
ret = trafo_fn(module_scopes, *args, **kwargs)
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/lift.py", line 219, in wrapper
y, out_variable_groups_xs_t = fn(
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/lift.py", line 806, in inner
broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 151, in scan_fn
ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys)
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/tree_util.py", line 210, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/tree_util.py", line 210, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 106, in transpose_from_front
return jax.tree_util.tree_map(trans, xs)
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/tree_util.py", line 210, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/tree_util.py", line 210, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 103, in trans
assert pax < x.ndim
jax._src.traceback_util.UnfilteredStackTrace: AssertionError
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "scripts/scan_issue.py", line 39, in <module>
main()
File "scripts/scan_issue.py", line 32, in main
params = cls().init(
File "scripts/scan_issue.py", line 18, in __call__
_, outputs = nn.scan(
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 151, in scan_fn
ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys)
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 106, in transpose_from_front
return jax.tree_util.tree_map(trans, xs)
File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 103, in trans
assert pax < x.ndim
AssertionError
Steps to reproduce:
# The issue appears to be at https://github.com/google/flax/blob/main/flax/core/axes_scan.py#L101 (line anchor refers to `main` at the time of writing and may drift).
import flax.linen as nn
import jax.random
class Foo(nn.Module):
    """Scan body: pass the carry through unchanged and project ``input_dict["x"]``.

    Attributes:
        unused_config: Never read by this module. ``Bar`` sets it to 123 through
            the scanned class's constructor.
    """

    unused_config: int

    @nn.compact
    def __call__(self, state, input_dict):
        # Per-step output: a Dense projection of the step's "x" slice to 100 features.
        projected = nn.Dense(100)(input_dict["x"])
        # The carry (state) is returned untouched; only the output varies per step.
        return state, projected
class Bar(nn.Module):
    """Scan ``Foo`` over axis 0 of ``x`` and stack its outputs along axis -1.

    The negative ``out_axes=-1`` is what triggers the AssertionError raised from
    ``flax/core/axes_scan.py`` in this report.
    """

    @nn.compact
    def __call__(self, x):
        scanned_foo_cls = nn.scan(
            Foo,
            variable_broadcast="params",  # broadcast (rather than scan) the "params" collection
            split_rngs={"params": False},  # do not split the "params" RNG across steps
            in_axes=0,  # iterate over axis 0 of the non-carry input
            out_axes=-1,  # negative axis: stack per-step outputs along the last axis
        )
        scanned_foo = scanned_foo_cls(unused_config=123)
        initial_carry = {"unused_state_item": None}
        _, outputs = scanned_foo(initial_carry, {"x": x})
        return outputs
def main():
    """Initialize ``Bar`` on a random (8, 128, 16) input.

    With the versions reported in this issue, ``init`` fails with an
    AssertionError raised inside ``flax/core/axes_scan.py``.
    """
    init_rngs = {"params": jax.random.PRNGKey(0)}
    inputs = jax.random.normal(jax.random.PRNGKey(1), shape=[8, 128, 16])
    variables = Bar().init(init_rngs, inputs)
    # Not used further; the repro only needs ``init`` to run.
    params = variables["params"]


if __name__ == "__main__":
    main()
Thank you for your attention to this matter!
Metadata
Assignees
Labels
No labels