You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
revise axes_scan to flatten argument pytrees only once
A user has a custom pytree node with the unusual behavior that it introduces
new arrays when flattening. That is, it's as if we had:
```python
# a custom object with two leaf arrays
custom_tree_object = SomeObject(jax_arrray1, jax_array2)
# convert leaves to ShapedArrays
custom_tree_object2 = jax.tree.map(core.typeof, custom_tree_object)
# flatten, should only see ShapedArrays, right?
leaves, treedef = jax.tree.flatten(custom_tree_object2)
print(leaves)
# [ShapedArray(...), ShapedArray(...), np.array(...)]
```
This change makes the `flax.nn.scan` function robust to such behavior. Without it, we were passing non-AbstractValues into JAX where JAX required AbstractValues.
I don't think we want to support this in general, but this fix seemed like the most
expedient way to roll fowrard jax-ml/jax#29273
PiperOrigin-RevId: 768175118
0 commit comments