# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
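# Ask XLA to expose 8 virtual CPU devices so this single-host, data-parallel
# example runs even without accelerators; the flag must be set before JAX
# initializes its backend.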
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import nnx
from jax.experimental import mesh_utils
import matplotlib.pyplot as plt

# create a mesh + shardings
num_devices = jax.local_device_count()
mesh = jax.sharding.Mesh(
  mesh_utils.create_device_mesh((num_devices,)), ('data',)
)
model_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec())
data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('data'))
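# An empty PartitionSpec() replicates an array on every device in the mesh, while
# PartitionSpec('data') splits the leading (batch) axis across the 'data' mesh
# axis: replicated parameters plus sharded batches, i.e. plain data parallelism.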


# create model
class MLP(nnx.Module):
  def __init__(self, din, dmid, dout, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    return self.linear2(nnx.relu(self.linear1(x)))


model = MLP(1, 64, 1, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adamw(1e-2))

# replicate state
state = nnx.state((model, optimizer))
state = jax.device_put(state, model_sharding)
nnx.update((model, optimizer), state)
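# nnx.state pulls the arrays out of (model, optimizer), jax.device_put places a
# replicated copy of them on every device per model_sharding, and nnx.update
# writes the resulting arrays back into the original objects.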

# visualize model sharding
print('model sharding')
jax.debug.visualize_array_sharding(model.linear1.kernel.value)


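# Under nnx.jit, XLA partitions the computation according to the shardings of the
# inputs and parameters, so the cross-device communication needed for the batch
# mean and the gradients is inserted automatically; no explicit pmap/psum is
# required. optimizer.update applies the gradients to the model in place.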
@nnx.jit
def train_step(model: MLP, optimizer: nnx.Optimizer, x, y):
  def loss_fn(model: MLP):
    y_pred = model(x)
    return jnp.mean((y - y_pred) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)
  return loss


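# Toy data: y = 0.8 * x**2 + 0.1 plus Gaussian noise, regenerated each step.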
def dataset(steps, batch_size):
  for _ in range(steps):
    x = np.random.uniform(-2, 2, size=(batch_size, 1))
    y = 0.8 * x**2 + 0.1 + np.random.normal(0, 0.1, size=x.shape)
    yield x, y


for step, (x, y) in enumerate(dataset(1000, 16)):
  # shard data
  x, y = jax.device_put((x, y), data_sharding)
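  # With the 8 host devices requested above and a batch of 16, each device
  # holds 2 examples of x and y.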
  # train
  loss = train_step(model, optimizer, x, y)

  if step == 0:
    print('data sharding')
    jax.debug.visualize_array_sharding(x)

  if step % 100 == 0:
    print(f'step={step}, loss={loss}')

# dereplicate state
state = nnx.state((model, optimizer))
state = jax.device_get(state)
nnx.update((model, optimizer), state)
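# jax.device_get copies the replicated state back to host memory, and nnx.update
# restores it into model and optimizer for ordinary, unsharded evaluation below.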

X, Y = next(dataset(1, 1000))
x_range = np.linspace(X.min(), X.max(), 100)[:, None]
y_pred = model(x_range)

# plot
plt.scatter(X, Y, label='data')
plt.plot(x_range, y_pred, color='black', label='model')
plt.legend()
plt.show()