17
17
import model as model
18
18
import utils
19
19
import pickle
20
+ import wandb
20
21
21
22
class Experiment :
22
23
def __init__ (self , config : ml_collections .ConfigDict , device ):
24
+
23
25
self .config = config
24
26
self .device = device
25
27
28
+ wandb .init (project = 'grid-cell-rnns' , entity = 'bioshape-lab' , config = config .to_dict ())
29
+
26
30
# initialize models
27
31
logging .info ("==== initialize model ====" )
28
32
self .model_config = model .GridCellConfig (** config .model )
@@ -118,6 +122,7 @@ def train_and_evaluate(self, workdir):
118
122
if step % config .steps_per_logging == 0 or step == 1 :
119
123
train_metrics = utils .average_appended_metrics (train_metrics )
120
124
writer .write_scalars (step , train_metrics )
125
+ wandb .log ({key : value for key , value in train_metrics .items ()}, step = step )
121
126
train_metrics = []
122
127
123
128
if step % config .steps_per_large_logging == 0 :
@@ -131,7 +136,9 @@ def visualize(activations, name):
131
136
activations = activations .data .cpu ().detach ().numpy ()
132
137
activations = activations .reshape (
133
138
(- 1 , block_size , num_grid , num_grid ))[:10 , :10 ]
134
- writer .write_images (step , {name : utils .draw_heatmap (activations )})
139
+ images = utils .draw_heatmap (activations )
140
+ writer .write_images (step , {name : images })
141
+ wandb .log ({name : wandb .Image (images )}, step = step )
135
142
136
143
visualize (self .model .encoder .v , 'v' )
137
144
visualize (self .model .decoder .u , 'u' )
@@ -172,11 +179,13 @@ def visualize(activations, name):
172
179
heatmaps = heatmaps .cpu ().detach ().numpy ()[None , ...]
173
180
writer .write_images (
174
181
step , {'vu_heatmap' : utils .draw_heatmap (heatmaps )})
182
+ wandb .log ({'vu_heatmap' : wandb .Image (utils .draw_heatmap (heatmaps ))}, step = step )
175
183
176
184
err = torch .mean (torch .sum ((x_eval - x_pred ) ** 2 , dim = - 1 ))
177
185
writer .write_scalars (step , {'pred_x' : err .item ()})
178
186
writer .write_scalars (step , {'error_fixed' : error_fixed .item ()})
179
187
writer .write_scalars (step , {'error_fixed_zero' : error_fixed_zero .item ()})
188
+ wandb .log ({'pred_x' : err .item (), 'error_fixed' : error_fixed .item (), 'error_fixed_zero' : error_fixed_zero .item ()}, step = step )
180
189
181
190
if step % config .steps_per_integration == 0 or step == 1 :
182
191
# perform path integration
@@ -193,6 +202,7 @@ def visualize(activations, name):
193
202
writer .write_scalars (step , {'score' : score .item ()})
194
203
writer .write_scalars (step , {'scale' : scale_tensor [0 ].item () * num_grid })
195
204
writer .write_scalars (step , {'scale_mean' : torch .mean (scale_tensor ).item () * num_grid })
205
+ wandb .log ({'score' : score .item (), 'scale' : scale_tensor [0 ].item () * num_grid , 'scale_mean' : torch .mean (scale_tensor ).item () * num_grid }, step = step )
196
206
197
207
# for visualization
198
208
if self .config .model .trans_type == 'nonlinear_simple' :
@@ -209,6 +219,7 @@ def visualize(activations, name):
209
219
'heatmaps' : utils .draw_heatmap (outputs ['heatmaps' ][:, ::5 ]),
210
220
}
211
221
writer .write_images (step , images )
222
+ wandb .log ({key : wandb .Image (value ) for key , value in images .items ()}, step = step )
212
223
213
224
# for quantitative evaluation
214
225
if self .config .model .trans_type == 'nonlinear_simple' :
@@ -218,6 +229,7 @@ def visualize(activations, name):
218
229
219
230
err = utils .dict_to_numpy (outputs ['err' ])
220
231
writer .write_scalars (step , err )
232
+ wandb .log ({key : value for key , value in err .items ()}, step = step )
221
233
222
234
if step == config .num_steps_train :
223
235
ckpt_dir = os .path .join (workdir , 'ckpt' )
@@ -300,8 +312,9 @@ def _save_checkpoint(self, step, ckpt_dir):
300
312
if not tf .io .gfile .exists (model_dir ):
301
313
tf .io .gfile .makedirs (model_dir )
302
314
model_filename = os .path .join (model_dir , 'checkpoint-step{}.pth' .format (step ))
303
- torch .save (state , model_filename )
304
315
logging .info ("Saving model checkpoint: {} ..." .format (model_filename ))
316
+ torch .save (state , model_filename )
317
+ wandb .save (model_filename )
305
318
306
319
activations_dir = os .path .join (ckpt_dir , 'activations' )
307
320
if not tf .io .gfile .exists (activations_dir ):
0 commit comments