Difference in Memory Usage Between jacrev(vmap(...)) and vmap(jacrev(...)) in JAX #26127

Answered by jakevdp
unik-w asked this question in Q&A

One way to see what's happening is to look at the jaxpr for your computation. This may get pretty complicated in practice, but let's look at a very simple example:

```python
import jax.numpy as jnp
from jax import jit, jacrev, vmap

def f(params, input):
  return params * input

params = 2.0
inputs = jnp.arange(1000.0)
print('jacrev(vmap(f)):')
print(jit(jacrev(vmap(f, in_axes=(None, 0)), argnums=0)).trace(params, inputs).jaxpr)
print()
print('vmap(jacrev(f)):')
print(jit(vmap(jacrev(f, argnums=0), in_axes=(None, 0))).trace(params, inputs).jaxpr)
```

```
jacrev(vmap(f)):
{ lambda ; a:f32[] b:f32[1000]. let
    c:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
    _:f32[1000] = mul c b
   …
```
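
A rough way to quantify the difference without reading the full jaxpr (a sketch, not part of the original answer): trace both orderings with `jax.make_jaxpr`, confirm they return the same values, and report the largest intermediate array each traced program produces. The helper `max_intermediate_size` is a hypothetical name introduced here, and it relies on jaxpr internals (`eqns`, `outvars`, `aval`, `params`) that are not a stable public API. Because `jacrev` builds a Jacobian by vmapping the VJP over a standard basis of the output, one would expect `jacrev(vmap(f))` to contain an n×n intermediate (n = 1000 outputs), while `vmap(jacrev(f))` only needs a size-1 basis per example.

```python
import math

import jax
import jax.numpy as jnp
from jax import jacrev, vmap


def f(params, input):
  return params * input


def max_intermediate_size(closed_jaxpr):
  """Hypothetical helper: element count of the largest array produced by any
  equation in the jaxpr, recursing into nested jaxprs (e.g. inner jit calls).
  Relies on jaxpr internals, which are not a stable public API."""
  def walk(jaxpr):
    best = 0
    for eqn in jaxpr.eqns:
      for v in eqn.outvars:
        # Tokens and similar avals have no shape; treat them as size 1.
        best = max(best, math.prod(getattr(v.aval, "shape", ())))
      for p in eqn.params.values():
        for sub in (p if isinstance(p, (tuple, list)) else (p,)):
          inner = getattr(sub, "jaxpr", sub)  # ClosedJaxpr -> Jaxpr
          if hasattr(inner, "eqns"):
            best = max(best, walk(inner))
    return best
  return walk(closed_jaxpr.jaxpr)


params = 2.0
inputs = jnp.arange(1000.0)
n = inputs.shape[0]

outer_rev = jacrev(vmap(f, in_axes=(None, 0)), argnums=0)   # jacrev(vmap(f))
outer_vmap = vmap(jacrev(f, argnums=0), in_axes=(None, 0))  # vmap(jacrev(f))

# Both orderings give the same per-example derivative d(params * x_i)/d(params) = x_i.
jac_a = outer_rev(params, inputs)
jac_b = outer_vmap(params, inputs)
print("same values:", bool(jnp.allclose(jac_a, jac_b)))

# Compare the largest intermediate each traced computation materializes.
size_a = max_intermediate_size(jax.make_jaxpr(outer_rev)(params, inputs))
size_b = max_intermediate_size(jax.make_jaxpr(outer_vmap)(params, inputs))
print("jacrev(vmap(f)) largest intermediate:", size_a)
print("vmap(jacrev(f)) largest intermediate:", size_b)
```

This inspects the traced program before XLA optimization, so it is only a proxy for peak memory: XLA may fuse away some intermediates at runtime. It does, however, show which ordering asks for an n×n array in the first place.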

Answer selected by unik-w