Skip to content

Commit e21ca8a

Browse files
Merge pull request #152 from geometric-intelligence/fix_serialization
Fix serialization
2 parents f2cba7c + 8e9a05a commit e21ca8a

21 files changed

+211
-479
lines changed

neurometry/curvature/datasets/utils.py

+22-13
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,18 @@
66
import torch
77
from scipy.signal import savgol_filter
88

9-
import neurometry.curvature.datasets as datasets
9+
from neurometry.curvature.datasets.experimental import load_neural_activity
10+
from neurometry.curvature.datasets.gridcells import load_grid_cells_synthetic
11+
from neurometry.curvature.datasets.synthetic import (
12+
load_images,
13+
load_place_cells,
14+
load_points,
15+
load_projected_images,
16+
load_s1_synthetic,
17+
load_s2_synthetic,
18+
load_t2_synthetic,
19+
load_three_place_cells,
20+
)
1021

1122

1223
def load(config):
@@ -31,7 +42,7 @@ def load(config):
3142
test dataset.
3243
"""
3344
if config.dataset_name == "experimental":
34-
dataset, labels = datasets.experimental.load_neural_activity(
45+
dataset, labels = load_neural_activity(
3546
expt_id=config.expt_id, timestep_microsec=config.timestep_microsec
3647
)
3748
dataset = dataset[labels["velocities"] > 5]
@@ -72,24 +83,22 @@ def load(config):
7283
labels = labels[labels["gains"] == gain]
7384

7485
elif config.dataset_name == "synthetic":
75-
dataset, labels = datasets.synthetic.load_place_cells()
86+
dataset, labels = load_place_cells()
7687
dataset = np.log(dataset.astype(np.float32) + 1)
7788
dataset = (dataset - np.min(dataset)) / (np.max(dataset) - np.min(dataset))
7889
elif config.dataset_name == "images":
79-
dataset, labels = datasets.synthetic.load_images(img_size=config.img_size)
90+
dataset, labels = load_images(img_size=config.img_size)
8091
dataset = (dataset - np.min(dataset)) / (np.max(dataset) - np.min(dataset))
8192
height, width = dataset.shape[1:3]
8293
dataset = dataset.reshape((-1, height * width))
8394
elif config.dataset_name == "projected_images":
84-
dataset, labels = datasets.synthetic.load_projected_images(
85-
img_size=config.img_size
86-
)
95+
dataset, labels = load_projected_images(img_size=config.img_size)
8796
dataset = (dataset - np.min(dataset)) / (np.max(dataset) - np.min(dataset))
8897
elif config.dataset_name == "points":
89-
dataset, labels = datasets.synthetic.load_points()
98+
dataset, labels = load_points()
9099
dataset = dataset.astype(np.float32)
91100
elif config.dataset_name == "s1_synthetic":
92-
dataset, labels = datasets.synthetic.load_s1_synthetic(
101+
dataset, labels = load_s1_synthetic(
93102
synthetic_rotation=config.synthetic_rotation,
94103
n_times=config.n_times,
95104
radius=config.radius,
@@ -100,7 +109,7 @@ def load(config):
100109
geodesic_distortion_func=config.geodesic_distortion_func,
101110
)
102111
elif config.dataset_name == "s2_synthetic":
103-
dataset, labels = datasets.synthetic.load_s2_synthetic(
112+
dataset, labels = load_s2_synthetic(
104113
synthetic_rotation=config.synthetic_rotation,
105114
n_times=config.n_times,
106115
radius=config.radius,
@@ -109,7 +118,7 @@ def load(config):
109118
noise_var=config.noise_var,
110119
)
111120
elif config.dataset_name == "t2_synthetic":
112-
dataset, labels = datasets.synthetic.load_t2_synthetic(
121+
dataset, labels = load_t2_synthetic(
113122
synthetic_rotation=config.synthetic_rotation,
114123
n_times=config.n_times,
115124
major_radius=config.major_radius,
@@ -119,7 +128,7 @@ def load(config):
119128
noise_var=config.noise_var,
120129
)
121130
elif config.dataset_name == "grid_cells":
122-
dataset, labels = datasets.gridcells.load_grid_cells_synthetic(
131+
dataset, labels = load_grid_cells_synthetic(
123132
grid_scale=config.grid_scale,
124133
arena_dims=config.arena_dims,
125134
n_cells=config.n_cells,
@@ -129,7 +138,7 @@ def load(config):
129138
resolution=config.resolution,
130139
)
131140
elif config.dataset_name == "three_place_cells_synthetic":
132-
dataset, labels = datasets.synthetic.load_three_place_cells()
141+
dataset, labels = load_three_place_cells()
133142
print(f"Dataset shape: {dataset.shape}.")
134143
if type(dataset) == np.ndarray:
135144
dataset_torch = torch.from_numpy(dataset)

neurometry/curvature/default_config.py

+12-16
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WANDB API KEY
1212
# Find it here: https://wandb.ai/authorize
1313
# Story it in file: api_key.txt (without extra line break)
14-
api_key_path = os.path.join(os.getcwd(),"api_key.txt")
14+
api_key_path = os.path.join(os.getcwd(), "api_key.txt")
1515
with open(api_key_path) as f:
1616
api_key = f.read()
1717

@@ -28,18 +28,14 @@
2828
if not os.path.exists(curvature_profiles_dir):
2929
os.makedirs(curvature_profiles_dir)
3030

31-
print(configs_dir)
32-
print(trained_models_dir)
33-
34-
3531
# Hardware
3632
device = "cuda" if torch.cuda.is_available() else "cpu"
3733

3834
# Can be replaced by logging.DEBUG or logging.WARNING
3935
logging.basicConfig(level=logging.INFO)
4036

4137
# Results
42-
project = "neurometry"
38+
project = "topo-vae"
4339
trained_model_path = None
4440

4541
### Fixed experiment parameters ###
@@ -164,10 +160,10 @@
164160

165161
# Only used of dataset_name in ["s1_synthetic", "s2_synthetic", "t2_synthetic"]
166162
n_times = [2500] # , 2000] # actual number of times is sqrt_ntimes ** 2
167-
embedding_dim = [5] # for s1 stopped at 5 (not done, but 3 was done)
163+
embedding_dim = [3] # for s1 stopped at 5 (not done, but 3 was done)
168164
geodesic_distortion_amp = [0.4]
169165
# TODO: Add 0.03, possibly 0,000[1
170-
noise_var = [0.1] # , 1e-2, 1e-1] 0.075, 0.1] #[
166+
noise_var = [1e-5] # , 1e-2, 1e-1] 0.075, 0.1] #[
171167

172168
# Only used if dataset_name == "grid_cells"
173169
grid_scale = [1.0]
@@ -188,7 +184,7 @@
188184
scheduler = False
189185
log_interval = 20
190186
checkpt_interval = 20
191-
n_epochs = 60 # 00 # 00 # 50 # 200 # 150 # 240
187+
n_epochs = 400 # 00 # 00 # 50 # 200 # 150 # 240
192188
sftbeta = 4.5 # beta parameter for softplus
193189
alpha = 1.0 # weight for the reconstruction loss
194190
beta = 0.03 # 0.03 # weight for KL loss
@@ -202,14 +198,14 @@
202198
### Ray sweep hyperparameters ###
203199
# --> Lists of values to sweep for each hyperparameter
204200
# Except for lr_min and lr_max which are floats
205-
lr_min = 0.0001
201+
lr_min = [0.001] # 0.0001
206202
lr_max = 0.1
207-
batch_size = [16, 64, 128] # [16,32,64]
208-
encoder_width = [200, 400] # [100,400] # , 100, 200, 300]
209-
encoder_depth = [4, 10, 12] # [4,6,8] # , 10, 20, 50, 100]
210-
decoder_width = [200, 400] # [100,400] # , 100, 200, 300]
211-
decoder_depth = [4, 6, 8] # [4,6,8] # , 10, 20, 50, 100]
212-
drop_out_p = [0, 0.1] # [0,0.1,0.2] # put probability p at 0. for no drop out
203+
batch_size = [64] # [16,32,64]
204+
encoder_width = [400] # [100,400] # , 100, 200, 300]
205+
encoder_depth = [10] # [4,6,8] # , 10, 20, 50, 100]
206+
decoder_width = [200] # [100,400] # , 100, 200, 300]
207+
decoder_depth = [6] # [4,6,8] # , 10, 20, 50, 100]
208+
drop_out_p = [0] # [0,0.1,0.2] # put probability p at 0. for no drop out
213209
for p in drop_out_p:
214210
assert p >= 0.0 and p <= 1, "Probability needs to be in [0, 1]"
215211

neurometry/curvature/evaluate.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66

77
os.environ["GEOMSTATS_BACKEND"] = "pytorch"
88
import geomstats.backend as gs # noqa: E402
9+
from geomstats.geometry.pullback_metric import PullbackMetric # noqa: E402
10+
from geomstats.geometry.special_orthogonal import SpecialOrthogonal # noqa: E402
911

1012
# import gph
1113
from neurometry.curvature.datasets.synthetic import ( # noqa: E402
1214
get_s1_synthetic_immersion,
1315
get_s2_synthetic_immersion,
1416
get_t2_synthetic_immersion,
1517
)
16-
from geomstats.geometry.pullback_metric import PullbackMetric # noqa: E402
17-
from geomstats.geometry.special_orthogonal import SpecialOrthogonal # noqa: E402
1818

1919

2020
def get_learned_immersion(model, config):
@@ -106,7 +106,8 @@ def get_z_grid(config, n_grid_points=100):
106106
z_grid = torch.cartesian_prod(thetas, phis)
107107
return z_grid
108108

109-
#TODO: change instantiation of PullbackMetric to match latest geomstats version
109+
110+
# TODO: change instantiation of PullbackMetric to match latest geomstats version
110111
def _compute_curvature(z_grid, immersion, dim, embedding_dim):
111112
"""Compute mean curvature vector and its norm at each point."""
112113
neural_metric = PullbackMetric(

neurometry/curvature/grid-cells-curvature/models/xu_rnn/LSTM.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
import torch
32
from labml_helpers.module import Module
43
from torch import nn

neurometry/curvature/grid-cells-curvature/models/xu_rnn/input_pipeline.py

+18-17
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def __init__(
1818
self.dx_list = self._generate_dx_list(config.max_dr_trans)
1919
# self.dx_list = self._generate_dx_list_continous(config.max_dr_trans)
2020
self.scale_vector = np.zeros(self.num_blocks) + config.max_dr_isometry
21+
self.rng = self.rng.default_rng()
2122

2223
def __iter__(self):
2324
while True:
@@ -42,9 +43,9 @@ def _gen_data_kernel(self):
4243
batch_size = self.config.batch_size
4344
config = self.config
4445

45-
theta = np.random.random(size=int(batch_size * 1.5)) * 2 * np.pi
46+
theta = self.rng.random(size=int(batch_size * 1.5)) * 2 * np.pi
4647
dr = (
47-
np.abs(np.random.normal(size=int(batch_size * 1.5)) * config.sigma_data)
48+
np.abs(self.rng.normal(size=int(batch_size * 1.5)) * config.sigma_data)
4849
* self.num_grid
4950
)
5051
dx = _dr_theta_to_dx(dr, theta)
@@ -57,7 +58,7 @@ def _gen_data_kernel(self):
5758
x_max, x_min, dx = x_max[select_idx], x_min[select_idx], dx[select_idx]
5859
assert len(dx) == batch_size
5960

60-
x = np.random.random(size=(batch_size, 2)) * (x_max - x_min) + x_min
61+
x = self.rng.random(size=(batch_size, 2)) * (x_max - x_min) + x_min
6162
x_prime = x + dx
6263

6364
return {"x": x, "x_prime": x_prime}
@@ -69,7 +70,7 @@ def _gen_data_trans_rnn(self):
6970
n_steps = self.rnn_step
7071
dx_list = self.dx_list
7172

72-
dx_idx = np.random.choice(len(dx_list), size=[n_traj * 10, n_steps])
73+
dx_idx = self.rng.choice(len(dx_list), size=[n_traj * 10, n_steps])
7374
dx = dx_list[dx_idx] # [N, T, 2]
7475
dx_cumsum = np.cumsum(dx, axis=1) # [N, T, 2]
7576

@@ -86,7 +87,7 @@ def _gen_data_trans_rnn(self):
8687
x_start_max, x_start_min = x_start_max[select_idx], x_start_min[select_idx]
8788
dx_cumsum = dx_cumsum[select_idx]
8889
x_start = (
89-
np.random.random((n_traj, 2)) * (x_start_max - x_start_min) + x_start_min
90+
self.rng.random((n_traj, 2)) * (x_start_max - x_start_min) + x_start_min
9091
)
9192
x_start = x_start[:, None] # [N, 1, 2]
9293
x_start = np.round(x_start - 0.5)
@@ -99,13 +100,13 @@ def _gen_data_iso_numerical(self):
99100
batch_size = self.config.batch_size
100101
config = self.config
101102

102-
theta = np.random.random(size=(batch_size, 2)) * 2 * np.pi
103-
dr = np.sqrt(np.random.random(size=(batch_size, 1))) * config.max_dr_isometry
103+
theta = self.rng.random(size=(batch_size, 2)) * 2 * np.pi
104+
dr = np.sqrt(self.rng.random(size=(batch_size, 1))) * config.max_dr_isometry
104105
dx = _dr_theta_to_dx(dr, theta) # [N, 2, 2]
105106

106107
x_max = np.fmin(self.num_grid - 0.5, np.min(self.num_grid - 0.5 - dx, axis=1))
107108
x_min = np.fmax(-0.5, np.max(-0.5 - dx, axis=1))
108-
x = np.random.random(size=(batch_size, 2)) * (x_max - x_min) + x_min
109+
x = self.rng.random(size=(batch_size, 2)) * (x_max - x_min) + x_min
109110
x_plus_dx1 = x + dx[:, 0]
110111
x_plus_dx2 = x + dx[:, 1]
111112

@@ -117,18 +118,18 @@ def _gen_data_iso_numerical_adaptive(self):
117118
config = self.config
118119

119120
theta = (
120-
np.random.random(size=(batch_size, num_blocks, 2)) * 2 * np.pi
121+
self.rng.random(size=(batch_size, num_blocks, 2)) * 2 * np.pi
121122
) # (batch_size, num_blocks, 2)
122123
dr = (
123-
np.sqrt(np.random.random(size=(batch_size, num_blocks, 1)))
124+
np.sqrt(self.rng.random(size=(batch_size, num_blocks, 1)))
124125
* np.tile(self.scale_vector, (batch_size, 1))[:, :, None]
125126
) # (batch_size, num_blocks, 1)
126127
dx = _dr_theta_to_dx(dr, theta) # [N, num_blocks, 2, 2]
127128

128129
x_max = np.fmin(self.num_grid - 0.5, np.min(self.num_grid - 0.5 - dx, axis=2))
129130
x_min = np.fmax(-0.5, np.max(-0.5 - dx, axis=2))
130131
x = (
131-
np.random.random(size=(batch_size, num_blocks, 2)) * (x_max - x_min) + x_min
132+
self.rng.random(size=(batch_size, num_blocks, 2)) * (x_max - x_min) + x_min
132133
) # (batch_size, num_blocks, 2)
133134
x_plus_dx1 = x + dx[:, :, 0]
134135
x_plus_dx2 = x + dx[:, :, 1]
@@ -157,9 +158,9 @@ def _generate_dx_list_continous(self, max_dr):
157158
dx_list = []
158159
batch_size = self.config.batch_size
159160

160-
dr = np.sqrt(np.random.random(size=(batch_size,))) * max_dr
161-
np.random.shuffle(dr)
162-
theta = np.random.random(size=(batch_size,)) * 2 * np.pi
161+
dr = np.sqrt(self.rng.random(size=(batch_size,))) * max_dr
162+
self.rng.shuffle(dr)
163+
theta = self.rng.random(size=(batch_size,)) * 2 * np.pi
163164

164165
dx = _dr_theta_to_dx(dr, theta)
165166

@@ -202,7 +203,7 @@ def _gen_trajectory_vis(self, n_traj, n_steps):
202203
x_start = np.reshape([5, 5], newshape=(1, 1, 2)) # [1, 1, 2]
203204
dx_idx_pool = np.where((dx_list[:, 0] >= -1) & (dx_list[:, 1] >= -1))[0]
204205
# dx_idx_pool = np.where((dx_list[:, 0] >= 0) & (dx_list[:, 1] >= -1))[0]
205-
dx_idx = np.random.choice(dx_idx_pool, size=[n_traj * 50, n_steps])
206+
dx_idx = self.rng.choice(dx_idx_pool, size=[n_traj * 50, n_steps])
206207
dx = dx_list[dx_idx]
207208
dx_cumsum = np.cumsum(dx, axis=1) # [N, T, 2]
208209

@@ -224,7 +225,7 @@ def _gen_trajectory(self, n_traj, n_steps):
224225
# uniformly wihtin the whole region.
225226
dx_list = self.dx_list
226227

227-
dx_idx = np.random.choice(len(dx_list), size=[n_traj * 10, n_steps])
228+
dx_idx = self.rng.choice(len(dx_list), size=[n_traj * 10, n_steps])
228229
dx = dx_list[dx_idx] # [N, T, 2]
229230
dx_cumsum = np.cumsum(dx, axis=1) # [N, T, 2]
230231

@@ -241,7 +242,7 @@ def _gen_trajectory(self, n_traj, n_steps):
241242
x_start_max, x_start_min = x_start_max[select_idx], x_start_min[select_idx]
242243
dx_cumsum = dx_cumsum[select_idx]
243244
x_start = (
244-
np.random.random((n_traj, 2)) * (x_start_max - x_start_min) + x_start_min
245+
self.rng.random((n_traj, 2)) * (x_start_max - x_start_min) + x_start_min
245246
)
246247
x_start = x_start[:, None] # [N, 1, 2]
247248
x_start = np.round(x_start - 0.5)

neurometry/curvature/grid-cells-curvature/models/xu_rnn/model.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -507,10 +507,9 @@ def get_grid_code(codebook, x, num_grid):
507507
align_corners=False,
508508
) # [1, C, 1, N]
509509

510-
v_x = torch.squeeze(torch.squeeze(v_x, 0), 1).transpose(0, 1) # [N, C]
511510
# v_x = v_x.squeeze().transpose(0, 1)
512511

513-
return v_x
512+
return torch.squeeze(torch.squeeze(v_x, 0), 1).transpose(0, 1) # [N, C]
514513

515514

516515
def get_grid_code_block(codebook, x, num_grid, block_size):
@@ -537,7 +536,10 @@ def get_grid_code_int(codebook, x, num_grid):
537536

538537
# query the 2D codebook, no interpolation
539538
v_x = torch.vstack(
540-
[codebook[:, i, j] for i, j in zip(x_normalized[:, 0], x_normalized[:, 1], strict=False)]
539+
[
540+
codebook[:, i, j]
541+
for i, j in zip(x_normalized[:, 0], x_normalized[:, 1], strict=False)
542+
]
541543
)
542544
# v_x = v_x.squeeze().transpose(0, 1) # [N, C]
543545

neurometry/curvature/grid-cells-curvature/models/xu_rnn/scores.py

-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
"""Grid score calculations."""
1717

18-
1918
import math
2019

2120
import matplotlib.pyplot as plt

0 commit comments

Comments
 (0)