import os

# The XLA flag below makes the CPU backend report 8 virtual devices, so this
# data-parallel example can run on a single host without accelerators.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import nnx
from jax.experimental import mesh_utils
import matplotlib.pyplot as plt

# create a mesh + shardings
num_devices = jax.local_device_count()
mesh = jax.sharding.Mesh(
  mesh_utils.create_device_mesh((num_devices,)), ('data',)
)
model_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('data'))
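# Note: model_sharding uses an empty PartitionSpec, so arrays placed with it
# are replicated on every device, while data_sharding splits the leading
# (batch) axis across the 'data' mesh axis.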


# create model
class MLP(nnx.Module):
  def __init__(self, din, dmid, dout, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    return self.linear2(nnx.relu(self.linear1(x)))


model = MLP(1, 64, 1, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adamw(1e-2))

# replicate state
state = nnx.state((model, optimizer))
state = jax.device_put(state, model_sharding)
nnx.update((model, optimizer), state)

# visualize model sharding
print('model sharding')
jax.debug.visualize_array_sharding(model.linear1.kernel.value)
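# Sketch of the expected output (an assumption, not captured from a run): the
# kernel should be drawn as a single tile replicated across all 8 devices.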


@nnx.jit
def train_step(model: MLP, optimizer: nnx.Optimizer, x, y):
  def loss_fn(model: MLP):
    y_pred = model(x)
    return jnp.mean((y - y_pred) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)
  return loss
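
# With parameters replicated and inputs sharded along 'data', jit compiles
# this step as standard data parallelism: each device computes gradients on
# its batch shard, and the cross-device reduction needed to keep parameters
# in sync is inserted automatically by the compiler.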


def dataset(steps, batch_size):
  for _ in range(steps):
    x = np.random.uniform(-2, 2, size=(batch_size, 1))
    y = 0.8 * x**2 + 0.1 + np.random.normal(0, 0.1, size=x.shape)
    yield x, y
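
# Note: batch_size should be a multiple of num_devices so each batch can be
# sharded evenly along the 'data' axis (here, 16 examples over 8 devices).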

for step, (x, y) in enumerate(dataset(1000, 16)):
  # shard data
  x, y = jax.device_put((x, y), data_sharding)
  # train
  loss = train_step(model, optimizer, x, y)

  if step == 0:
    print('data sharding')
    jax.debug.visualize_array_sharding(x)

  if step % 100 == 0:
    print(f'step={step}, loss={loss}')
+
75
+ # dereplicate state
76
+ state = nnx .state ((model , optimizer ))
77
+ state = jax .device_get (state )
78
+ nnx .update ((model , optimizer ), state )
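# jax.device_get pulls the (replicated) arrays back to the host as NumPy
# arrays, so the evaluation below runs on the default device.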

X, Y = next(dataset(1, 1000))
x_range = np.linspace(X.min(), X.max(), 100)[:, None]
y_pred = model(x_range)

# plot
plt.scatter(X, Y, label='data')
plt.plot(x_range, y_pred, color='black', label='model')
plt.legend()
plt.show()