Commit bec5915

[nnx] add data parallel toy example

1 parent 591cd40 · commit bec5915
Lines changed: 88 additions & 0 deletions
import os

# simulate 8 devices on CPU so the example runs data-parallel on a single host;
# XLA_FLAGS must be set before JAX initializes its backends
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import nnx
from jax.experimental import mesh_utils
import matplotlib.pyplot as plt

# create a mesh + shardings
num_devices = jax.local_device_count()
mesh = jax.sharding.Mesh(
  mesh_utils.create_device_mesh((num_devices,)), ('data',)
)
model_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('data'))
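# Note: an empty PartitionSpec() places an array unsharded (fully replicated)
# on every device in the mesh, while PartitionSpec('data') splits an array's
# leading axis across the 'data' mesh axis; parameters use the former,
# batches the latter.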


# create model
class MLP(nnx.Module):
  def __init__(self, din, dmid, dout, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    return self.linear2(nnx.relu(self.linear1(x)))


model = MLP(1, 64, 1, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adamw(1e-2))
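# nnx.Optimizer holds the optax state for the model's Params; calling
# optimizer.update(grads) later applies the adamw update to the model's
# parameters in place.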

# replicate state
state = nnx.state((model, optimizer))
state = jax.device_put(state, model_sharding)
nnx.update((model, optimizer), state)
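# The pattern above is how state is moved in nnx: nnx.state() pulls the arrays
# out of (model, optimizer) as a pytree, jax.device_put places that pytree with
# the replicated sharding, and nnx.update() writes the placed arrays back into
# the objects.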

# visualize model sharding
print('model sharding')
jax.debug.visualize_array_sharding(model.linear1.kernel.value)


@nnx.jit
def train_step(model: MLP, optimizer: nnx.Optimizer, x, y):
  def loss_fn(model: MLP):
    y_pred = model(x)
    return jnp.mean((y - y_pred) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)
  return loss

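# With parameters replicated and x/y sharded along 'data', nnx.jit compiles
# train_step as a single SPMD program: each device computes the loss and
# gradients on its slice of the batch, and the mean over the batch axis makes
# the compiler insert the cross-device reduction, so every replica applies the
# same update. Nothing in the step function itself is parallelism-specific.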

def dataset(steps, batch_size):
  for _ in range(steps):
    x = np.random.uniform(-2, 2, size=(batch_size, 1))
    y = 0.8 * x**2 + 0.1 + np.random.normal(0, 0.1, size=x.shape)
    yield x, y


for step, (x, y) in enumerate(dataset(1000, 16)):
  # shard data
  x, y = jax.device_put((x, y), data_sharding)
  # train
  loss = train_step(model, optimizer, x, y)

  if step == 0:
    print('data sharding')
    jax.debug.visualize_array_sharding(x)

  if step % 100 == 0:
    print(f'step={step}, loss={loss}')

# dereplicate state
state = nnx.state((model, optimizer))
state = jax.device_get(state)
nnx.update((model, optimizer), state)

X, Y = next(dataset(1, 1000))
x_range = np.linspace(X.min(), X.max(), 100)[:, None]
y_pred = model(x_range)

# plot
plt.scatter(X, Y, label='data')
plt.plot(x_range, y_pred, color='black', label='model')
plt.legend()
plt.show()
