I'm computing the Jacobian of a neural network function with respect to its parameters for multiple inputs using Flax/JAX: `f(params, input) -> output`. I tried two approaches to calculate the Jacobian:

Jacobian after batching: `jit(jacrev(vmap(f, in_axes=(None, 0)), argnums=0))(params, inputs)`

Batching after Jacobian: `jit(vmap(jacrev(f, argnums=0), in_axes=(None, 0)))(params, inputs)`

Question: I expected XLA to optimize both approaches similarly, but the first one seems much more memory-intensive (I am assuming it's a bug, based on my naive understanding of XLA and JIT). Any insights into why this happens would be greatly appreciated.
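For concreteness, here is a minimal self-contained sketch of the setup I mean, with a toy `flax.linen.Dense` layer standing in for the actual network (the model, sizes, and names are illustrative, not my real code):

```python
import jax
import jax.numpy as jnp
from jax import jit, jacrev, vmap
import flax.linen as nn

# Toy stand-in for the real network: a single Dense layer.
model = nn.Dense(features=4)
inputs = jnp.ones((32, 8))  # batch of 32 inputs, each of dimension 8
params = model.init(jax.random.PRNGKey(0), inputs[0])

def f(params, x):
    # f(params, input) -> output for a single (unbatched) input
    return model.apply(params, x)

# Approach 1: Jacobian after batching
jac1 = jit(jacrev(vmap(f, in_axes=(None, 0)), argnums=0))(params, inputs)

# Approach 2: batching after Jacobian
jac2 = jit(vmap(jacrev(f, argnums=0), in_axes=(None, 0)))(params, inputs)

# Both return a pytree of per-example Jacobians with the same leaf shapes,
# e.g. (32, 4, 8, 4) for the Dense kernel, but their memory behavior differs.
```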
Replies: 1 comment 1 reply
One way to see what's happening is to look at the `jaxpr` for your computation. This may get pretty complicated in practice, but let's look at a very simple example:

import jax.numpy as jnp
from jax import jit, jacrev, vmap
def f(params, input):
return params * input
params = 2.0
inputs = jnp.arange(1000.0)
print('jacrev(vmap(f)):')
print(jit(jacrev(vmap(f, in_axes=(None, 0)), argnums=0)).trace(params, inputs).jaxpr)
print()
print('vmap(jacrev(f)):')
print(jit(vmap(jacrev(f, argnums=0), in_axes=(None, 0))).trace(params, inputs).jaxpr)
Notice the two are almost identical, except for the fact that the first one operates on 1000×1000 arrays, while the second only ever works with arrays of length 1000. The reason for this is that the Jacobian of a function with a length-1000 output is computed in reverse mode by pulling back each of the 1000 standard output basis vectors, i.e. by pushing a 1000×1000 identity matrix through the backward pass; since that matrix is diagonal, nearly all of this work amounts to multiplying by zero. `vmap(jacrev(f))` avoids this because each per-example Jacobian has a scalar output, so its cotangent is just a scalar. Note that this kind of optimization (recognizing that a matrix is diagonal and rewriting the computation to account for that) is not something XLA's compiler will do automatically.
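To make the difference concrete, you can compare the largest intermediate array that appears in each jaxpr. The snippet below repeats the example above so it is self-contained; note that it pokes at jaxpr internals (`eqns`, `outvars`, `aval`), which are not a stable public API, so treat it as a rough diagnostic rather than a guaranteed recipe:

```python
import jax.numpy as jnp
from jax import jit, jacrev, vmap

def f(params, input):
    return params * input

params = 2.0
inputs = jnp.arange(1000.0)

jac_after_batch = jit(jacrev(vmap(f, in_axes=(None, 0)), argnums=0))
batch_after_jac = jit(vmap(jacrev(f, argnums=0), in_axes=(None, 0)))

# The two compute the same values: d(params * x_i)/d(params) = x_i.
assert jnp.allclose(jac_after_batch(params, inputs),
                    batch_after_jac(params, inputs))

def largest_intermediate(traced):
    # Size (in elements) of the biggest array produced by any equation in the jaxpr.
    return max(v.aval.size for eqn in traced.jaxpr.eqns for v in eqn.outvars)

print(largest_intermediate(jac_after_batch.trace(params, inputs)))  # expect ~1000*1000
print(largest_intermediate(batch_after_jac.trace(params, inputs)))  # expect ~1000
```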