Skip to content

Commit 6852b45

Browse files
Merge pull request #149 from geometric-intelligence/analyze_dual_agents
Run topo vae, Analyze dual agents
2 parents 99a8043 + 5dad96d commit 6852b45

18 files changed

+11101
-6515
lines changed

.gitignore

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# API KEY
2-
neurometry/api_key.txt
2+
*api_key.txt
3+
4+
neurometry/results/*
5+
neurometry/wandb/*
6+
7+
neurometry/datasets/rnn_grid_cells/Dual agent path integration high res/*
8+
neurometry/datasets/rnn_grid_cells/Single agent path integration high res/*
39

410

511
*viewer*

neurometry/curvature/default_config.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
# WANDB API KEY
1212
# Find it here: https://wandb.ai/authorize
1313
# Story it in file: api_key.txt (without extra line break)
14-
with open("api_key.txt") as f:
14+
api_key_path = os.path.join(os.getcwd(), "curvature","api_key.txt")
15+
with open(api_key_path) as f:
1516
api_key = f.read()
1617

1718
# Directories
@@ -139,7 +140,7 @@
139140

140141
# Datasets
141142
# dataset_name = ["s1_synthetic", "s2_synthetic"]
142-
dataset_name = ["kb_synthetic"]
143+
dataset_name = ["s1_synthetic"]
143144
for one_dataset_name in dataset_name:
144145
if one_dataset_name not in [
145146
"s1_synthetic",
@@ -163,10 +164,10 @@
163164

164165
# Only used of dataset_name in ["s1_synthetic", "s2_synthetic", "t2_synthetic"]
165166
n_times = [2500] # , 2000] # actual number of times is sqrt_ntimes ** 2
166-
embedding_dim = [3, 10, 20, 30] # for s1 stopped at 5 (not done, but 3 was done)
167+
embedding_dim = [5] # for s1 stopped at 5 (not done, but 3 was done)
167168
geodesic_distortion_amp = [0.4]
168169
# TODO: Add 0.03, possibly 0,000[1
169-
noise_var = [0.1, 0.075, 0.05, 0.03, 0.01, 0.005, 0.001] # , 1e-2, 1e-1] 0.075, 0.1] #[
170+
noise_var = [0.1] # , 1e-2, 1e-1] 0.075, 0.1] #[
170171

171172
# Only used if dataset_name == "grid_cells"
172173
grid_scale = [1.0]

neurometry/curvature/evaluate.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import geomstats.backend as gs # noqa: E402
99

1010
# import gph
11-
from datasets.synthetic import ( # noqa: E402
11+
from neurometry.curvature.datasets.synthetic import ( # noqa: E402
1212
get_s1_synthetic_immersion,
1313
get_s2_synthetic_immersion,
1414
get_t2_synthetic_immersion,
@@ -106,7 +106,7 @@ def get_z_grid(config, n_grid_points=100):
106106
z_grid = torch.cartesian_prod(thetas, phis)
107107
return z_grid
108108

109-
109+
#TODO: change instantiation of PullbackMetric to match latest geomstats version
110110
def _compute_curvature(z_grid, immersion, dim, embedding_dim):
111111
"""Compute mean curvature vector and its norm at each point."""
112112
neural_metric = PullbackMetric(

neurometry/curvature/losses.py

+2
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ def elbo(x, x_mu, posterior_params, z, labels, config):
6767

6868
if config.gen_likelihood_type == "gaussian":
6969
recon_loss = torch.mean((x - x_mu).pow(2))
70+
else:
71+
raise NotImplementedError
7072

7173
if config.dataset_name == "s1_synthetic":
7274
recon_loss = recon_loss / (config.radius**2)

neurometry/curvature/main.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,21 @@
1414

1515
# from ray.tune.integration.wandb import wandb_mixin
1616
import torch
17-
import train
18-
import viz
17+
import neurometry.curvature.train as train
18+
import neurometry.curvature.viz as viz
1919
import wandb
2020
from ray import air, tune
2121
from ray.tune.schedulers import AsyncHyperBandScheduler
2222
from ray.tune.search.hyperopt import HyperOptSearch
2323

2424
os.environ["GEOMSTATS_BACKEND"] = "pytorch"
25-
import datasets.utils # noqa: E402
26-
import default_config # noqa: E402
27-
import evaluate # noqa: E402
25+
import neurometry.curvature.datasets.utils as utils # noqa: E402
26+
import neurometry.curvature.default_config as default_config # noqa: E402
27+
import neurometry.curvature.evaluate as evaluate # noqa: E402
2828
import geomstats.backend as gs # noqa: E402
29-
import models.klein_bottle_vae # noqa: E402
30-
import models.neural_vae # noqa: E402
31-
import models.toroidal_vae # noqa: E402
29+
import neurometry.curvature.models.klein_bottle_vae as klein_bottle_vae # noqa: E402
30+
import neurometry.curvature.models.neural_vae as neural_vae # noqa: E402
31+
import neurometry.curvature.models.toroidal_vae as toroidal_vae # noqa: E402
3232

3333
# Required to make matplotlib figures in threads:
3434
matplotlib.use("Agg")
@@ -262,7 +262,7 @@ def main_run(sweep_config):
262262
wandb.run.name = run_name
263263

264264
# Load data, labels
265-
dataset, labels, train_loader, test_loader = datasets.utils.load(wandb_config)
265+
dataset, labels, train_loader, test_loader = utils.load(wandb_config)
266266
data_n_times, data_dim = dataset.shape
267267
wandb_config.update(
268268
{
@@ -344,7 +344,7 @@ def create_model_and_train_test(config, train_loader, test_loader):
344344
torch.manual_seed(0)
345345
if torch.cuda.is_available():
346346
torch.cuda.manual_seed_all(0)
347-
model = models.neural_vae.NeuralVAE(
347+
model = neural_vae.NeuralVAE(
348348
data_dim=data_dim,
349349
latent_dim=config.latent_dim,
350350
sftbeta=config.sftbeta,
@@ -360,7 +360,7 @@ def create_model_and_train_test(config, train_loader, test_loader):
360360
torch.manual_seed(0)
361361
if torch.cuda.is_available():
362362
torch.cuda.manual_seed_all(0)
363-
model = models.toroidal_vae.ToroidalVAE(
363+
model = toroidal_vae.ToroidalVAE(
364364
data_dim=data_dim,
365365
latent_dim=config.latent_dim,
366366
sftbeta=config.sftbeta,
@@ -375,7 +375,7 @@ def create_model_and_train_test(config, train_loader, test_loader):
375375
torch.manual_seed(0)
376376
if torch.cuda.is_available():
377377
torch.cuda.manual_seed_all(0)
378-
model = models.klein_bottle_vae.KleinBottleVAE(
378+
model = klein_bottle_vae.KleinBottleVAE(
379379
data_dim=data_dim,
380380
latent_dim=config.latent_dim,
381381
sftbeta=config.sftbeta,
@@ -587,4 +587,4 @@ def curvature_compute_plot_log(config, dataset, labels, model):
587587
plt.close("all")
588588

589589

590-
main()
590+
#main()

neurometry/curvature/models/klein_bottle_vae.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import geomstats.backend as gs
77
import torch
8-
from hyperspherical.distributions import VonMisesFisher
8+
from neurometry.curvature.hyperspherical.distributions.von_mises_fisher import VonMisesFisher
99
from torch.nn import functional as F
1010

1111

neurometry/curvature/models/neural_vae.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
import torch
7-
from hyperspherical.distributions import VonMisesFisher
7+
from neurometry.curvature.hyperspherical.distributions.von_mises_fisher import VonMisesFisher
88
from torch.distributions.normal import Normal
99
from torch.nn import functional as F
1010

neurometry/curvature/models/toroidal_vae.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import geomstats.backend as gs
77
import torch
8-
from hyperspherical.distributions import VonMisesFisher
8+
from neurometry.curvature.hyperspherical.distributions.von_mises_fisher import VonMisesFisher
99
from torch.nn import functional as F
1010

1111

neurometry/curvature/train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import copy
44

5-
import losses
5+
import neurometry.curvature.losses as losses
66
import torch
77
import wandb
88

neurometry/datasets/load_rnn_grid_cells.py

+31-51
Original file line numberDiff line numberDiff line change
@@ -12,60 +12,31 @@
1212
utils,
1313
)
1414

15-
# Loading single agent model
1615

17-
# parent_dir = os.getcwd() + "/datasets/rnn_grid_cells/"
18-
19-
parent_dir = "/scratch/facosta/rnn_grid_cells/"
20-
21-
22-
single_model_folder = "Single agent path integration/Seed 1 weight decay 1e-06/"
23-
single_model_parameters = "steps_20_batch_200_RNN_4096_relu_rf_012_DoG_True_periodic_False_lr_00001_weight_decay_1e-06/"
24-
25-
26-
dual_model_folder = (
27-
"Dual agent path integration disjoint PCs/Seed 1 weight decay 1e-06/"
28-
)
29-
dual_model_parameters = "steps_20_batch_200_RNN_4096_relu_rf_012_DoG_True_periodic_False_lr_00001_weight_decay_1e-06/"
30-
31-
32-
def load_activations(epochs, version="single", verbose=True):
16+
def load_activations(epochs, file_path, version="single", verbose=True, save = True):
3317
activations = []
3418
rate_maps = []
3519
state_points = []
3620
positions = []
21+
g_s = []
3722

38-
if version == "single":
39-
activations_dir = (
40-
parent_dir + single_model_folder + single_model_parameters + "activations/"
41-
)
42-
elif version == "dual":
43-
activations_dir = (
44-
parent_dir + dual_model_folder + dual_model_parameters + "activations/"
45-
)
23+
activations_dir = os.path.join(file_path, "activations")
4624

47-
random.seed(0)
4825
for epoch in epochs:
49-
activations_epoch_path = (
50-
activations_dir + f"activations_{version}_agent_epoch_{epoch}.npy"
51-
)
52-
rate_map_epoch_path = (
53-
activations_dir + f"rate_map_{version}_agent_epoch_{epoch}.npy"
54-
)
55-
positions_epoch_path = (
56-
activations_dir + f"positions_{version}_agent_epoch_{epoch}.npy"
57-
)
58-
59-
if (
60-
os.path.exists(activations_epoch_path)
61-
and os.path.exists(rate_map_epoch_path)
62-
and os.path.exists(positions_epoch_path)
63-
):
26+
activations_epoch_path = os.path.join(activations_dir, f"activations_{version}_agent_epoch_{epoch}.npy")
27+
rate_map_epoch_path = os.path.join(activations_dir, f"rate_map_{version}_agent_epoch_{epoch}.npy")
28+
positions_epoch_path = os.path.join(activations_dir, f"positions_{version}_agent_epoch_{epoch}.npy")
29+
gs_epoch_path = os.path.join(activations_dir, f"g_{version}_agent_epoch_{epoch}.npy")
30+
31+
if os.path.exists(activations_epoch_path) and os.path.exists(
32+
rate_map_epoch_path
33+
) and os.path.exists(positions_epoch_path) and os.path.exists(gs_epoch_path):
6434
activations.append(np.load(activations_epoch_path))
6535
rate_maps.append(np.load(rate_map_epoch_path))
6636
positions.append(np.load(positions_epoch_path))
37+
g_s.append(np.load(gs_epoch_path))
6738
if verbose:
68-
print(f"Epoch {epoch} found!")
39+
print(f"Epoch {epoch} found.")
6940
else:
7041
print(f"Epoch {epoch} not found. Loading ...")
7142
parser = config.parser
@@ -75,22 +46,32 @@ def load_activations(epochs, version="single", verbose=True):
7546
(
7647
activations_single_agent,
7748
rate_map_single_agent,
49+
g_single_agent,
7850
positions_single_agent,
79-
) = single_agent_activity.main(options, epoch=epoch)
51+
) = single_agent_activity.main(options, file_path, epoch=epoch)
8052
activations.append(activations_single_agent)
8153
rate_maps.append(rate_map_single_agent)
8254
positions.append(positions_single_agent)
55+
g_s.append(g_single_agent)
8356
elif version == "dual":
84-
activations_dual_agent, rate_map_dual_agent, positions_dual_agent = (
85-
dual_agent_activity.main(options, epoch=epoch)
86-
)
57+
activations_dual_agent, rate_map_dual_agent, g_dual_agent, positions_dual_agent = dual_agent_activity.main(
58+
options, file_path, epoch=epoch)
8759
activations.append(activations_dual_agent)
8860
rate_maps.append(rate_map_dual_agent)
8961
positions.append(positions_dual_agent)
90-
print(len(activations))
62+
g_s.append(g_dual_agent)
63+
64+
if save:
65+
np.save(activations_epoch_path, activations[-1])
66+
np.save(rate_map_epoch_path, rate_maps[-1])
67+
np.save(positions_epoch_path, positions[-1])
68+
np.save(gs_epoch_path, g_s[-1])
69+
9170
state_points_epoch = activations[-1].reshape(activations[-1].shape[0], -1)
9271
state_points.append(state_points_epoch)
9372

73+
74+
9475
if verbose:
9576
print(f"Loaded epochs {epochs} of {version} agent model.")
9677
print(
@@ -104,7 +85,7 @@ def load_activations(epochs, version="single", verbose=True):
10485
)
10586
print(f"positions has shape {positions[0].shape}.")
10687

107-
return activations, rate_maps, state_points, positions
88+
return activations, rate_maps, state_points, positions, g_s
10889

10990

11091
# def plot_rate_map(indices, num_plots, activations, title):
@@ -137,9 +118,8 @@ def load_activations(epochs, version="single", verbose=True):
137118
# plt.show()
138119

139120

140-
141-
def plot_rate_map(indices, num_plots, activations, title):
142-
rng = np.random.default_rng(seed=0)
121+
def plot_rate_map(indices, num_plots, activations, title, seed=None):
122+
rng = np.random.default_rng(seed=seed)
143123
if indices is None:
144124
idxs = rng.integers(0, activations.shape[0] - 1, num_plots)
145125
else:

neurometry/datasets/rnn_grid_cells/config.py

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class Config:
2828
n_avg = 50 # number of trajectories to average over for rate maps
2929

3030

31+
3132
# If you need to access the configuration as a dictionary
3233
config = Config.__dict__
3334

0 commit comments

Comments
 (0)