Commit 893a660

mattjj authored and Flax Authors committed
revise axes_scan to flatten argument pytrees only once
A user has a custom pytree node with the unusual behavior that it introduces new arrays when flattening. That is, it's as if we had:

```python
# a custom object with two leaf arrays
custom_tree_object = SomeObject(jax_array1, jax_array2)

# convert leaves to ShapedArrays
custom_tree_object2 = jax.tree.map(core.typeof, custom_tree_object)

# flatten, should only see ShapedArrays, right?
leaves, treedef = jax.tree.flatten(custom_tree_object2)
print(leaves)  # [ShapedArray(...), ShapedArray(...), np.array(...)]
```

This change makes the `flax.nn.scan` function robust to such behavior. Without it, we were passing non-AbstractValues into JAX where JAX required AbstractValues.

I don't think we want to support this in general, but this fix seemed like the most expedient way to roll forward jax-ml/jax#29273

PiperOrigin-RevId: 768175118
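The behavior described above can be reproduced with a small, self-contained sketch. Everything here is illustrative: `LeakyNode` and `to_spec` are hypothetical names, and `jax.ShapeDtypeStruct` is used as a public stand-in for the `core.ShapedArray` values in the commit, so this is not the user's actual pytree node.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class LeakyNode:
    """Hypothetical node whose flatten emits an extra array leaf."""

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def tree_flatten(self):
        # The unusual part: a fresh NumPy array shows up on every flatten.
        return (self.a, self.b, np.zeros(())), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        a, b, _extra = children  # the extra leaf is dropped on rebuild
        return cls(a, b)


def to_spec(x):
    # Stand-in (assumption) for converting a leaf to an abstract value.
    return jax.ShapeDtypeStruct(jnp.shape(x), jnp.result_type(x))


node = LeakyNode(jnp.ones((3,)), jnp.ones((2, 2)))

# Map, then flatten again: the second flatten re-introduces a concrete array.
mapped = jax.tree_util.tree_map(to_spec, node)
leaves_twice, _ = jax.tree_util.tree_flatten(mapped)

# Flatten once, then convert the flat leaves: every entry is a spec.
leaves_once, treedef = jax.tree_util.tree_flatten(node)
specs_once = [to_spec(x) for x in leaves_once]
```

With a node like this, the last entry of `leaves_twice` is a plain `np.ndarray`, while `specs_once` holds only `ShapeDtypeStruct` values. That difference is what the flatten-once approach relies on.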
1 parent 8233415 commit 893a660

1 file changed: +9 -8 lines changed


flax/core/axes_scan.py

Lines changed: 9 additions & 8 deletions
@@ -147,15 +147,16 @@ def body_fn(c, xs, init_mode=False):

     broadcast_body = functools.partial(body_fn, init_mode=True)

-    carry_avals = jax.tree_util.tree_map(
-      lambda x: core.ShapedArray(jnp.shape(x), jnp.result_type(x)), init
-    )
-    scan_avals = jax.tree_util.tree_map(
-      lambda x: core.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x)), xs
-    )
-    input_avals = (carry_avals, scan_avals)
+    init_flat, carry_tree = jax.tree.flatten(init)
+    xs_flat, scan_tree = jax.tree.flatten(xs)
+    carry_avals = [core.ShapedArray(jnp.shape(x), jnp.result_type(x))
+                   for x in init_flat]
+    scan_avals = [core.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x))
+                  for x in xs_flat]
+    in_avals = [*carry_avals, *scan_avals]
+    in_tree = jax.tree_util.treedef_tuple((carry_tree, scan_tree))
+    assert all(isinstance(a, core.AbstractValue) for a in in_avals), in_avals

-    in_avals, in_tree = jax.tree_util.tree_flatten(input_avals)
     debug_info = jax.api_util.debug_info("flax scan", broadcast_body,
                                          (in_tree,), {})
     f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
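The flatten-once pattern in this diff can be tried in isolation. The sketch below is a rough illustration and not the Flax code: `to_spec` is a hypothetical helper, the `init` and `xs` values are made up, and `jax.ShapeDtypeStruct` stands in for `core.ShapedArray`. It flattens the carry and the scanned inputs separately, builds the specs from the flat leaves, and joins the two treedefs with `jax.tree_util.treedef_tuple`.

```python
import jax
import jax.numpy as jnp


def to_spec(x, drop_leading_axis=False):
    # Hypothetical helper: describe a leaf by shape and dtype only.
    shape = jnp.shape(x)
    if drop_leading_axis:
        shape = shape[1:]  # scanned inputs lose their leading (scan) axis
    return jax.ShapeDtypeStruct(shape, jnp.result_type(x))


# Made-up example inputs: a carry dict and a tuple of scanned arrays.
init = {"h": jnp.zeros((4,))}
xs = (jnp.ones((10, 4)), jnp.ones((10,)))

# Flatten each argument exactly once.
init_flat, carry_tree = jax.tree_util.tree_flatten(init)
xs_flat, scan_tree = jax.tree_util.tree_flatten(xs)

carry_specs = [to_spec(x) for x in init_flat]
scan_specs = [to_spec(x, drop_leading_axis=True) for x in xs_flat]
in_specs = [*carry_specs, *scan_specs]

# Treedef of the pair (init, xs), built without flattening a second time.
in_tree = jax.tree_util.treedef_tuple((carry_tree, scan_tree))

# The flat specs and the combined treedef round-trip back to the pair.
carry_struct, scan_struct = jax.tree_util.tree_unflatten(in_tree, in_specs)
```

For ordinary containers such as these, `in_tree` should match the treedef obtained by flattening `(init, xs)` directly. The difference only matters for nodes whose flatten is not stable, since the leaves are never passed through a second flatten.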
