Skip to content

Commit b33eccb

Browse files
Merge pull request #146 from geometric-intelligence/wandb
add wandb integration, update pyproject.toml and .gitignore
2 parents ccae50c + c855b48 commit b33eccb

File tree

5 files changed

+28
-6
lines changed

5 files changed

+28
-6
lines changed

.gitignore

+5-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@ neurometry/datasets/rnn_grid_cells/Dual agent path integration disjoint PCs/*
2323
neurometry/datasets/rnn_grid_cells/Single agent path integration/*
2424

2525
# Wandb files
26-
wandb/*
26+
*wandb/*
27+
*logs/*
28+
29+
neurometry/curvature/grid-cells-curvature/models/xu_rnn/logs/*
30+
neurometry/curvature/grid-cells-curvature/models/xu_rnn/wandb/*
2731

2832

2933
# Result files

neurometry/curvature/grid-cells-curvature/models/xu_rnn/configs/rnn_isometry.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ def get_config():
1313

1414
# training config
1515
config.train = d(
16-
num_steps_train=20, #100000
16+
num_steps_train=25000, #100000
1717
lr=0.006,
1818
lr_decay_from=10000,
1919
steps_per_logging=20,
20-
steps_per_large_logging=5, #500
20+
steps_per_large_logging=500, #500
2121
steps_per_integration=2000,
2222
norm_v=True,
2323
positive_v=True,

neurometry/curvature/grid-cells-curvature/models/xu_rnn/experiment.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,16 @@
1717
import model as model
1818
import utils
1919
import pickle
20+
import wandb
2021

2122
class Experiment:
2223
def __init__(self, config: ml_collections.ConfigDict, device):
24+
2325
self.config = config
2426
self.device = device
2527

28+
wandb.init(project='grid-cell-rnns', entity='bioshape-lab', config=config.to_dict())
29+
2630
# initialize models
2731
logging.info("==== initialize model ====")
2832
self.model_config = model.GridCellConfig(**config.model)
@@ -118,6 +122,7 @@ def train_and_evaluate(self, workdir):
118122
if step % config.steps_per_logging == 0 or step == 1:
119123
train_metrics = utils.average_appended_metrics(train_metrics)
120124
writer.write_scalars(step, train_metrics)
125+
wandb.log({key: value for key, value in train_metrics.items()}, step=step)
121126
train_metrics = []
122127

123128
if step % config.steps_per_large_logging == 0:
@@ -131,7 +136,9 @@ def visualize(activations, name):
131136
activations = activations.data.cpu().detach().numpy()
132137
activations = activations.reshape(
133138
(-1, block_size, num_grid, num_grid))[:10, :10]
134-
writer.write_images(step, {name: utils.draw_heatmap(activations)})
139+
images = utils.draw_heatmap(activations)
140+
writer.write_images(step, {name: images})
141+
wandb.log({name: wandb.Image(images)}, step=step)
135142

136143
visualize(self.model.encoder.v, 'v')
137144
visualize(self.model.decoder.u, 'u')
@@ -172,11 +179,13 @@ def visualize(activations, name):
172179
heatmaps = heatmaps.cpu().detach().numpy()[None, ...]
173180
writer.write_images(
174181
step, {'vu_heatmap': utils.draw_heatmap(heatmaps)})
182+
wandb.log({'vu_heatmap': wandb.Image(utils.draw_heatmap(heatmaps))}, step=step)
175183

176184
err = torch.mean(torch.sum((x_eval - x_pred) ** 2, dim=-1))
177185
writer.write_scalars(step, {'pred_x': err.item()})
178186
writer.write_scalars(step, {'error_fixed': error_fixed.item()})
179187
writer.write_scalars(step, {'error_fixed_zero': error_fixed_zero.item()})
188+
wandb.log({'pred_x': err.item(), 'error_fixed': error_fixed.item(), 'error_fixed_zero': error_fixed_zero.item()}, step=step)
180189

181190
if step % config.steps_per_integration == 0 or step == 1:
182191
# perform path integration
@@ -193,6 +202,7 @@ def visualize(activations, name):
193202
writer.write_scalars(step, {'score': score.item()})
194203
writer.write_scalars(step, {'scale': scale_tensor[0].item() * num_grid})
195204
writer.write_scalars(step, {'scale_mean': torch.mean(scale_tensor).item() * num_grid})
205+
wandb.log({'score': score.item(), 'scale': scale_tensor[0].item() * num_grid, 'scale_mean': torch.mean(scale_tensor).item() * num_grid}, step=step)
196206

197207
# for visualization
198208
if self.config.model.trans_type == 'nonlinear_simple':
@@ -209,6 +219,7 @@ def visualize(activations, name):
209219
'heatmaps': utils.draw_heatmap(outputs['heatmaps'][:, ::5]),
210220
}
211221
writer.write_images(step, images)
222+
wandb.log({key: wandb.Image(value) for key, value in images.items()}, step=step)
212223

213224
# for quantitative evaluation
214225
if self.config.model.trans_type == 'nonlinear_simple':
@@ -218,6 +229,7 @@ def visualize(activations, name):
218229

219230
err = utils.dict_to_numpy(outputs['err'])
220231
writer.write_scalars(step, err)
232+
wandb.log({key: value for key, value in err.items()}, step=step)
221233

222234
if step == config.num_steps_train:
223235
ckpt_dir = os.path.join(workdir, 'ckpt')
@@ -300,8 +312,9 @@ def _save_checkpoint(self, step, ckpt_dir):
300312
if not tf.io.gfile.exists(model_dir):
301313
tf.io.gfile.makedirs(model_dir)
302314
model_filename = os.path.join(model_dir, 'checkpoint-step{}.pth'.format(step))
303-
torch.save(state, model_filename)
304315
logging.info("Saving model checkpoint: {} ...".format(model_filename))
316+
torch.save(state, model_filename)
317+
wandb.save(model_filename)
305318

306319
activations_dir = os.path.join(ckpt_dir, 'activations')
307320
if not tf.io.gfile.exists(activations_dir):

neurometry/curvature/grid-cells-curvature/models/xu_rnn/main.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
FLAGS = flags.FLAGS
1414
config_flags.DEFINE_config_file(
1515
"config", None, "Training configuration.", lock_config=True)
16-
flags.DEFINE_string("workdir", "../logs", "Work unit directory.")
16+
flags.DEFINE_string("workdir", "logs", "Work unit directory.")
1717
flags.mark_flags_as_required(["config"])
1818
flags.DEFINE_string("mode", 'train', "train / visualize / integration / correction")
1919

pyproject.toml

+5
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ dependencies=[
4646
"scikit-dimension",
4747
"umap-learn",
4848
"ripser",
49+
"absl-py",
50+
"ml-collections",
51+
"tensowflow-cpu",
52+
"clu",
53+
"labml-helpers",
4954
"giotto-ph @ git+https://github.com/alibayeh/giotto-ph.git",
5055
"pyflagser @ git+https://github.com/alibayeh/pyflagser.git",
5156
"giotto-tda @ git+https://github.com/alibayeh/giotto-tda.git",

0 commit comments

Comments
 (0)