Hey guys! I'm trying to figure out how to use pjit with a flax struct dataclass as the output object. The issue seems to be that the input or output axis resources cannot be defined when returning the state object. Any ideas how to resolve this?

Here's a simplified reproduction of what I'm trying to achieve. In my actual code the state object is also an input to the function, rather than being created inside the step function:

```python
from typing import Any, Optional

import flax
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import mesh
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.nn import relu


@flax.struct.dataclass
class OptimizerState:
    scaler: Any
    params: Any
    opt_state: Optional[Any]
    optimizer: Optional[Any]


def test_mlp_grad():
    def loss_func(batch, weights):
        x, y = batch
        w1, w2 = weights
        x = x @ w1
        x = relu(x)
        x = with_sharding_constraint(x, P('data_parallel', 'model_parallel'))
        x = x @ w2
        x = relu(x)
        loss = jnp.mean((x - y) ** 2)
        return loss

    def step_serial(batch, weights):
        gradients = jax.grad(loss_func, argnums=1)(batch, weights)
        state = OptimizerState(
            scaler=flax.optim.DynamicScale(),  # todo: this crashes pjit, but I can't define the axis resources for this object
            params=tuple(w - g * lr for w, g in zip(weights, gradients)),
            opt_state=None,
            optimizer=None
        )
        return state

    step_parallel = pjit(
        step_serial,
        in_axis_resources=((P('data_parallel', None), P('data_parallel', None)),
                           (P(None, 'model_parallel'), P('model_parallel', None))),
        out_axis_resources=(P('model_parallel'))
    )

    lr = 0.1
    N = 8
    D = 128
    np.random.seed(1)
    x = np.random.uniform(size=(N, D))
    y = np.random.uniform(size=(N, D))
    w1 = np.random.uniform(size=(D, D))
    w2 = np.random.uniform(size=(D, D))

    mesh_devices = np.array(jax.devices()[:4]).reshape(2, 2)
    with mesh(mesh_devices, ('data_parallel', 'model_parallel')):
        out = step_parallel((x, y), (w1, w2))


if __name__ == "__main__":
    test_mlp_grad()
```

A short-term workaround would be to skip the optimizer state and pass the individual variables into the function, but I'm hoping for a longer-term solution!
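One thing worth checking: `out_axis_resources=(P('model_parallel'))` is just `P('model_parallel')` (the parentheses don't create a tuple), and pjit applies it as a prefix to every leaf of the returned `OptimizerState`, including the scalar leaves of the loss scaler. A plausible cause of the crash is that a rank-0 array can't be split along a mesh axis. Below is a minimal sketch of a workaround, assuming replicated scalars are acceptable: derive a spec tree with the same structure as the output via `jax.eval_shape`, then pass that tree as `out_axis_resources`. `pytree_dataclass`, `ScalerStandIn`, `State` and `step` are hypothetical stand-ins for `flax.struct.dataclass`, `DynamicScale`, `OptimizerState` and `step_serial`, so the snippet needs only JAX. `PartitionSpec` is imported from `jax.sharding`, its location in newer JAX releases.

```python
import dataclasses
from typing import Any, Optional

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P


def pytree_dataclass(cls):
    """Hypothetical stand-in for flax.struct.dataclass: a frozen dataclass
    whose fields are all registered as pytree children."""
    cls = dataclasses.dataclass(frozen=True)(cls)
    names = [f.name for f in dataclasses.fields(cls)]
    jax.tree_util.register_pytree_node(
        cls,
        lambda obj: (tuple(getattr(obj, n) for n in names), None),
        lambda _, children: cls(*children),
    )
    return cls


@pytree_dataclass
class ScalerStandIn:  # hypothetical: scalar array leaves, like a loss scaler's counters
    fin_steps: Any
    scale: Any


@pytree_dataclass
class State:  # mirrors OptimizerState from the question
    scaler: Any
    params: Any
    opt_state: Optional[Any]
    optimizer: Optional[Any]


def step(weights):
    # Same output structure as step_serial; the update rule is a placeholder.
    return State(
        scaler=ScalerStandIn(fin_steps=jnp.array(0), scale=jnp.array(65536.0)),
        params=tuple(w * 0.9 for w in weights),
        opt_state=None,
        optimizer=None,
    )


weights = (jnp.ones((128, 128)), jnp.ones((128, 128)))

# Abstract evaluation: same pytree structure as the real output, no compute.
out_shapes = jax.eval_shape(step, weights)

# Replicate rank-0 leaves with P(); shard the first axis of everything else.
out_specs = jax.tree_util.tree_map(
    lambda s: P() if len(s.shape) == 0 else P('model_parallel'), out_shapes
)
print(out_specs)
```

The rank-based rule is deliberately crude; for the weights you would likely want per-field specs (for example the ones already used in `in_axis_resources`). With the real flax classes, the same `eval_shape` plus `tree_map` approach should apply to `step_serial` directly.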
It looks to me like this is supposed to work or at least not fail due to the dataclass.
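That matches how JAX treats these objects: a `flax.struct.dataclass` is registered as a pytree, so `jit`/`pjit` only see its leaves, and returning one from a transformed function should be fine on its own. A small sketch, using a hand-registered dataclass (hypothetical `State`) in place of flax so it runs with JAX alone, returns such an object from `jax.jit` and lists the leaf shapes that any output sharding spec has to cover. Note the rank-0 leaf:

```python
import dataclasses
from typing import Any, Optional

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class State:  # hypothetical stand-in for a flax.struct.dataclass
    scale: Any                # scalar, like a loss-scaler field
    params: Any               # tuple of weight matrices
    opt_state: Optional[Any]


# Register every field as a pytree child (flax.struct.dataclass does the
# equivalent for fields not marked pytree_node=False).
jax.tree_util.register_pytree_node(
    State,
    lambda s: ((s.scale, s.params, s.opt_state), None),
    lambda _, children: State(*children),
)


@jax.jit
def step(weights):
    return State(scale=jnp.array(65536.0),
                 params=tuple(w * 0.9 for w in weights),
                 opt_state=None)


out = step((jnp.ones((4, 4)), jnp.ones((4, 4))))
print([leaf.shape for leaf in jax.tree_util.tree_leaves(out)])
# [(), (4, 4), (4, 4)]
```

`opt_state=None` contributes no leaves, so it needs no spec either.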
I'll need some more info to help you out here: