nnx.jit(aux_fn) is slower than directly using nnx.jit(model.__call__) #4218

Open
@JunhongXu

Description

System information

  • OS Platform and Distribution: Linux Ubuntu 22.04
  • Flax, jax, jaxlib versions (obtained with pip show flax jax jaxlib): flax 0.9.0, jax 0.4.30, jaxlib 0.4.30
  • Python version: 3.11
  • GPU/TPU model and memory: NVIDIA RTX 4090
  • CUDA version: 12.2

Problem you have encountered:

nnx.jit(aux_fn) is slower than directly using nnx.jit(model.__call__), where aux_fn is defined by

def aux_fn(model, x):
    return model(x)

From my understanding, using an auxiliary function with nnx.jit is common practice, and it is required if we want to modify the model's internal state inside the jitted function (#3998). However, it appears to be noticeably slower than wrapping model.__call__ directly with nnx.jit.
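
For context, here is a minimal sketch of the state-mutation case (assuming flax 0.9.0's nnx API; the Counter module and increment function are hypothetical illustrations, not from the benchmark below): because the module is passed as an argument, nnx.jit can propagate state updates made inside the traced function back to the original object.

import jax.numpy as jnp
from flax import nnx

class Counter(nnx.Module):
    def __init__(self):
        # nnx.BatchStat is an nnx.Variable subclass holding mutable state
        self.count = nnx.BatchStat(jnp.array(0))

    def __call__(self):
        self.count.value += 1

@nnx.jit
def increment(counter):  # auxiliary function: the module is an argument
    counter()

counter = Counter()
increment(counter)
assert counter.count.value == 1  # the mutation is visible outside the jitted call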

See the Colab link below to reproduce.

Steps to reproduce:

Colab link: https://colab.research.google.com/drive/1cGpcaBaJABUxhZuywgLZELZRwFsT5zve?usp=sharing

For completeness, I also copy the code here:

import time
import jax
from flax import nnx


class MLP(nnx.Module):
    def __init__(self, din: int, dout: int, rngs: nnx.Rngs) -> None:
        self.fc1 = nnx.Linear(din, 128, rngs=rngs)
        self.fc2 = nnx.Linear(128, 128, rngs=rngs)
        self.fc3 = nnx.Linear(128, 128, rngs=rngs)
        self.out = nnx.Linear(128, dout, rngs=rngs)

    def __call__(self, x):
        x = self.fc1(x)
        x = nnx.relu(x)
        x = self.fc2(x)
        x = nnx.relu(x)
        x = self.fc3(x)
        x = nnx.relu(x)
        x = self.out(x)
        return x


def nn_forward(model, x):
    # auxiliary function, as aux_fn above: the model is passed as an argument
    return model(x)


def benchmark_jax():
    rngs = nnx.Rngs(0)
    din, dout = 29, 7  # Example dimensions
    mlp = MLP(din, dout, rngs)
    nn_forward_call_no_aux = nnx.jit(mlp.__call__)

    # Prepare data
    x = jax.random.normal(rngs(), shape=(1, din))
    num_iterations = 1000
    warmup_iters = 100

    for _ in range(warmup_iters):
        _ = nn_forward_call_no_aux(x)

    start_time = time.time()
    for _ in range(num_iterations):
        _ = nn_forward_call_no_aux(x)
    end_time = time.time()

    print(f"JAX forward pass time for {num_iterations} iterations: {end_time - start_time:.5f} seconds")
    print(f"JAX forward pass average time: {(end_time - start_time) / num_iterations:.5f} seconds")

    print("-------------------")
    nn_forward_jit = nnx.jit(nn_forward)
    for _ in range(warmup_iters):
        _ = nn_forward_jit(mlp, x)

    start_time = time.time()
    for _ in range(num_iterations):
        _ = nn_forward_jit(mlp, x)
    end_time = time.time()
    print(f"JAX forward pass time while using auxiliary functions for {num_iterations} iterations: {end_time - start_time:.5f} seconds")
    print(f"JAX forward pass average while using auxiliary functions time: {(end_time - start_time) / num_iterations:.5f} seconds")

The outputs on an RTX 4090 are:

JAX forward pass time for 1000 iterations: 0.10531 seconds
JAX forward pass average time: 0.00011 seconds
-------------------
JAX forward pass time while using auxiliary functions for 1000 iterations: 0.59596 seconds
JAX forward pass average time while using auxiliary functions: 0.00060 seconds
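
Presumably the gap comes from per-call Python-side overhead in nnx.jit (traversing and splitting the module graph on every invocation when the module is passed as an argument) rather than from compiled-kernel time; note also that the loops above never call jax.block_until_ready, so with JAX's asynchronous dispatch they largely measure exactly this per-call dispatch cost. One possible workaround, as a sketch (assuming the nnx.split / nnx.merge functional API; forward is a hypothetical name, and mlp and x are the objects from the benchmark above), is to split the model once outside the hot loop and jit a pure function over the resulting graphdef and state:

import jax
from flax import nnx

@jax.jit
def forward(graphdef, state, x):
    # rebuild the module from its static graphdef and array state
    model = nnx.merge(graphdef, state)
    return model(x)

graphdef, state = nnx.split(mlp)  # module-graph traversal happens once here
y = forward(graphdef, state, x)

With this variant the traversal cost is paid once at split time rather than on every call.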
