Commit a2237ca

[nnx] add data parallel toy example

1 parent 591cd40 · commit a2237ca
1 file changed: +102 −0 lines changed
@@ -0,0 +1,102 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
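# setting this before importing jax makes the CPU backend expose 8 XLA devices,
# so the example runs on a single host without accelerators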
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import nnx
from jax.experimental import mesh_utils
import matplotlib.pyplot as plt

# create a mesh + shardings
num_devices = jax.local_device_count()
mesh = jax.sharding.Mesh(
  mesh_utils.create_device_mesh((num_devices,)), ('data',)
)
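# an empty PartitionSpec replicates an array on every device of the mesh;
# PartitionSpec('data') shards an array's leading (batch) axis across the 'data' axis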
model_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec())
data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('data'))


# create model
class MLP(nnx.Module):
  def __init__(self, din, dmid, dout, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    return self.linear2(nnx.relu(self.linear1(x)))


model = MLP(1, 64, 1, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adamw(1e-2))

# replicate state
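# (nnx.state collects the model/optimizer arrays, device_put places them with the
# replicated sharding, and nnx.update writes them back in place)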
state = nnx.state((model, optimizer))
state = jax.device_put(state, model_sharding)
nnx.update((model, optimizer), state)

# visualize model sharding
print('model sharding')
jax.debug.visualize_array_sharding(model.linear1.kernel.value)
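

# under nnx.jit, XLA partitions this step into a data-parallel computation:
# each device works on its shard of the batch and gradient contributions are
# reduced across devices automatically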
@nnx.jit
def train_step(model: MLP, optimizer: nnx.Optimizer, x, y):
  def loss_fn(model: MLP):
    y_pred = model(x)
    return jnp.mean((y - y_pred) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)
  return loss


def dataset(steps, batch_size):
  for _ in range(steps):
    x = np.random.uniform(-2, 2, size=(batch_size, 1))
    y = 0.8 * x**2 + 0.1 + np.random.normal(0, 0.1, size=x.shape)
    yield x, y


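# each 16-example batch is placed with data_sharding, so every one of the 8 devices gets 2 examples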
for step, (x, y) in enumerate(dataset(1000, 16)):
  # shard data
  x, y = jax.device_put((x, y), data_sharding)
  # train
  loss = train_step(model, optimizer, x, y)

  if step == 0:
    print('data sharding')
    jax.debug.visualize_array_sharding(x)

  if step % 100 == 0:
    print(f'step={step}, loss={loss}')

# dereplicate state
state = nnx.state((model, optimizer))
state = jax.device_get(state)
nnx.update((model, optimizer), state)
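
# the state is now a single host-local copy, so the model can be evaluated and plotted directly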
X, Y = next(dataset(1, 1000))
x_range = np.linspace(X.min(), X.max(), 100)[:, None]
y_pred = model(x_range)

# plot
plt.scatter(X, Y, label='data')
plt.plot(x_range, y_pred, color='black', label='model')
plt.legend()
plt.show()
