Skip to content

Commit c79271d

Browse files
Merge pull request #143 from geometric-intelligence/backup-before-detaching-submodules
Backup before detaching submodules
2 parents 1a29509 + b60d7ed commit c79271d

File tree

9 files changed

+13155
-167
lines changed

9 files changed

+13155
-167
lines changed

neurometry/datasets/load_rnn_grid_cells.py

+71-16
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def load_activations(epochs, version="single", verbose=True):
3333
activations = []
3434
rate_maps = []
3535
state_points = []
36+
positions = []
3637

3738
if version == "single":
3839
activations_dir = (
@@ -51,53 +52,95 @@ def load_activations(epochs, version="single", verbose=True):
5152
rate_map_epoch_path = (
5253
activations_dir + f"rate_map_{version}_agent_epoch_{epoch}.npy"
5354
)
55+
positions_epoch_path = (
56+
activations_dir + f"positions_{version}_agent_epoch_{epoch}.npy"
57+
)
58+
5459
if os.path.exists(activations_epoch_path) and os.path.exists(
5560
rate_map_epoch_path
56-
):
61+
) and os.path.exists(positions_epoch_path):
5762
activations.append(np.load(activations_epoch_path))
5863
rate_maps.append(np.load(rate_map_epoch_path))
64+
positions.append(np.load(positions_epoch_path))
5965
if verbose:
60-
print(f"Epoch {epoch} found!!! :D")
66+
print(f"Epoch {epoch} found!")
6167
else:
6268
print(f"Epoch {epoch} not found. Loading ...")
6369
parser = config.parser
6470
options, _ = parser.parse_known_args()
6571
options.run_ID = utils.generate_run_ID(options)
66-
if type == "single":
72+
if version == "single":
6773
(
6874
activations_single_agent,
6975
rate_map_single_agent,
76+
positions_single_agent,
7077
) = single_agent_activity.main(options, epoch=epoch)
7178
activations.append(activations_single_agent)
7279
rate_maps.append(rate_map_single_agent)
73-
elif type == "dual":
74-
activations_dual_agent, rate_map_dual_agent = dual_agent_activity.main(
80+
positions.append(positions_single_agent)
81+
elif version == "dual":
82+
activations_dual_agent, rate_map_dual_agent, positions_dual_agent = dual_agent_activity.main(
7583
options, epoch=epoch
7684
)
7785
activations.append(activations_dual_agent)
7886
rate_maps.append(rate_map_dual_agent)
87+
positions.append(positions_dual_agent)
88+
print(len(activations))
7989
state_points_epoch = activations[-1].reshape(activations[-1].shape[0], -1)
8090
state_points.append(state_points_epoch)
8191

8292
if verbose:
8393
print(f"Loaded epochs {epochs} of {version} agent model.")
8494
print(
85-
f"There are {activations[0].shape[0]} grid cells with {activations[0].shape[1]} x {activations[0].shape[2]} environment resolution, averaged over {activations[0].shape[3]} trajectories."
95+
f"activations has shape {activations[0].shape}. There are {activations[0].shape[0]} grid cells with {activations[0].shape[1]} x {activations[0].shape[2]} environment resolution, averaged over {activations[0].shape[3]} trajectories."
8696
)
8797
print(
88-
f"There are {state_points[0].shape[1]} data points in the {state_points[0].shape[0]}-dimensional state space."
98+
f"state_points has shape {state_points[0].shape}. There are {state_points[0].shape[1]} data points in the {state_points[0].shape[0]}-dimensional state space."
8999
)
90100
print(
91-
f"There are {rate_maps[0].shape[1]} data points averaged over {activations[0].shape[3]} trajectories in the {rate_maps[0].shape[0]}-dimensional state space."
101+
f"rate_maps has shape {rate_maps[0].shape}. There are {rate_maps[0].shape[1]} data points averaged over {activations[0].shape[3]} trajectories in the {rate_maps[0].shape[0]}-dimensional state space."
92102
)
103+
print(f"positions has shape {positions[0].shape}.")
104+
105+
return activations, rate_maps, state_points, positions
106+
107+
108+
# def plot_rate_map(indices, num_plots, activations, title):
109+
# rng = np.random.default_rng(seed=0)
110+
# if indices is None:
111+
# idxs = rng.integers(0, 4095, num_plots)
112+
# else:
113+
# idxs = indices
114+
# num_plots = len(indices)
115+
116+
# rows = 4
117+
# cols = num_plots // rows + (num_plots % rows > 0)
93118

94-
return activations, rate_maps, state_points
119+
# plt.rcParams["text.usetex"] = False
95120

121+
# fig, axes = plt.subplots(rows, cols, figsize=(20, 8))
96122

97-
def plot_rate_map(indices, num_plots, activations):
123+
# for i in range(rows):
124+
# for j in range(cols):
125+
# if i * cols + j < num_plots:
126+
# gc = np.mean(activations[idxs[i * cols + j]], axis=2)
127+
# axes[i, j].imshow(gc)
128+
# axes[i, j].set_title(f"grid cell id: {idxs[i * cols + j]}")
129+
# axes[i, j].axis("off")
130+
# else:
131+
# axes[i, j].axis("off")
132+
133+
# plt.suptitle(title)
134+
# plt.tight_layout()
135+
# plt.show()
136+
137+
import numpy as np
138+
import matplotlib.pyplot as plt
139+
140+
def plot_rate_map(indices, num_plots, activations, title):
98141
rng = np.random.default_rng(seed=0)
99142
if indices is None:
100-
idxs = rng.integers(0, 4095, num_plots)
143+
idxs = rng.integers(0, activations.shape[0]-1, num_plots)
101144
else:
102145
idxs = indices
103146
num_plots = len(indices)
@@ -112,12 +155,24 @@ def plot_rate_map(indices, num_plots, activations):
112155
for i in range(rows):
113156
for j in range(cols):
114157
if i * cols + j < num_plots:
115-
gc = np.mean(activations[idxs[i * cols + j]], axis=2)
116-
axes[i, j].imshow(gc)
117-
axes[i, j].set_title(f"grid cell id: {idxs[i * cols + j]}")
118-
axes[i, j].axis("off")
158+
if len(activations.shape) == 4:
159+
gc = np.mean(activations[idxs[i * cols + j]], axis=2)
160+
else:
161+
gc = activations[idxs[i * cols + j]]
162+
if axes.ndim > 1: # Check if axes is a 2D array
163+
ax = axes[i, j]
164+
else: # If axes is flattened (e.g., only one row of subplots)
165+
ax = axes[i * cols + j]
166+
ax.imshow(gc)
167+
ax.set_title(f"grid cell id: {idxs[i * cols + j]}", fontsize=10)
168+
ax.axis("off")
119169
else:
120-
axes[i, j].axis("off")
170+
if axes.ndim > 1:
171+
axes[i, j].axis("off")
172+
else:
173+
axes[i * cols + j].axis("off")
121174

175+
fig.suptitle(title, fontsize=30)
122176
plt.tight_layout()
123177
plt.show()
178+

neurometry/datasets/rnn_grid_cells/config.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ class Config:
2121
periodic = False # trajectories with periodic boundary conditions
2222
box_width = 2.2 # width of training environment
2323
box_height = 2.2 # height of training environment
24-
device = (
25-
"cuda" if torch.cuda.is_available() else "cpu"
26-
) # device to use for training
24+
# device = (
25+
# "cuda" if torch.cuda.is_available() else "cpu"
26+
# ) # device to use for training
27+
device = torch.device('cuda:8')
2728
n_avg = 50 # number of trajectories to average over for rate maps
2829

2930

neurometry/datasets/rnn_grid_cells/dual_agent_activity.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33

44
import numpy as np
55
import torch
6-
from config import parser
7-
from model_dual_path_integration import RNN
8-
from place_cells_dual_path_integration import PlaceCells
9-
from scores import GridScorer
6+
from .config import parser
7+
from .model_dual_path_integration import RNN
8+
from .place_cells_dual_path_integration import PlaceCells
9+
from .scores import GridScorer
1010
from tqdm import tqdm
11-
from trajectory_generator_dual_path_integration import TrajectoryGenerator
12-
from utils import generate_run_ID
13-
from visualize import compute_ratemaps
11+
from .trajectory_generator_dual_path_integration import TrajectoryGenerator
12+
from .utils import generate_run_ID
13+
from .visualize import compute_ratemaps
1414

1515
parent_dir = os.getcwd() + "/"
1616
model_folder = "Dual agent path integration disjoint PCs/Seed 1 weight decay 1e-06/"
@@ -44,7 +44,7 @@ def main(options, epoch="final", res=20):
4444

4545
Ng = options.Ng
4646
n_avg = options.n_avg
47-
activations_dual_agent, rate_map_dual_agent, _, _ = compute_ratemaps(
47+
activations_dual_agent, rate_map_dual_agent, _, positions_dual_agent = compute_ratemaps(
4848
model_dual_agent,
4949
trajectory_generator,
5050
options,
@@ -63,11 +63,15 @@ def main(options, epoch="final", res=20):
6363
np.save(
6464
activations_dir + f"rate_map_dual_agent_epoch_{epoch}.npy", rate_map_dual_agent
6565
)
66+
np.save(
67+
activations_dir + f"positions_dual_agent_epoch_{epoch}.npy",
68+
positions_dual_agent,
69+
)
6670

6771
# # activations is in the shape [number of grid cells (Ng) x res x res x n_avg]
6872
# # ratemap is in the shape [Ng x res^2]
6973

70-
return activations_dual_agent, rate_map_dual_agent
74+
return activations_dual_agent, rate_map_dual_agent, positions_dual_agent
7175

7276

7377
def compute_grid_scores(res, rate_map_dual_agent, scorer):

neurometry/datasets/rnn_grid_cells/scores.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def get_grid_scores_for_mask(self, sac, rotated_sacs, mask):
161161
masked_sac_centered = (masked_sac - masked_sac_mean) * mask
162162
variance = np.sum(masked_sac_centered**2) / ring_area + 1e-5
163163
corrs = dict()
164-
for angle, rotated_sac in zip(self._corr_angles, rotated_sacs, strict=False):
164+
for angle, rotated_sac in zip(self._corr_angles, rotated_sacs,strict=False):
165165
masked_rotated_sac = (rotated_sac - masked_sac_mean) * mask
166166
cross_prod = np.sum(masked_sac_centered * masked_rotated_sac) / ring_area
167167
corrs[angle] = cross_prod / variance

neurometry/datasets/rnn_grid_cells/single_agent_activity.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,17 @@
55

66
import numpy as np
77
import torch
8-
from config import parser
9-
from model import RNN
10-
from place_cells import PlaceCells
11-
from scores import GridScorer
8+
from .config import parser
9+
from .model import RNN
10+
from .place_cells import PlaceCells
11+
from .scores import GridScorer
1212
from tqdm import tqdm
13-
from trajectory_generator import TrajectoryGenerator
14-
from utils import generate_run_ID
15-
from visualize import compute_ratemaps
13+
from .trajectory_generator import TrajectoryGenerator
14+
from .utils import generate_run_ID
15+
from .visualize import compute_ratemaps
1616

17-
parent_dir = os.getcwd() + "/"
17+
# parent_dir = os.getcwd() + "/"
18+
parent_dir = "/scratch/facosta/rnn_grid_cells/"
1819

1920

2021
model_folder = "Single agent path integration/Seed 1 weight decay 1e-06/"
@@ -38,6 +39,7 @@ def main(options, epoch="final", res=20):
3839

3940
model_single_agent = model.to(options.device)
4041

42+
4143
model_name = "final_model.pth" if epoch == "final" else f"epoch_{epoch}.pth"
4244
saved_model_single_agent = torch.load(
4345
parent_dir + model_folder + model_parameters + model_name
@@ -49,7 +51,7 @@ def main(options, epoch="final", res=20):
4951
Ng = options.Ng
5052
n_avg = options.n_avg
5153

52-
activations_single_agent, rate_map_single_agent, _, _ = compute_ratemaps(
54+
activations_single_agent, rate_map_single_agent, _, positions_single_agent = compute_ratemaps(
5355
model_single_agent,
5456
trajectory_generator,
5557
options,
@@ -69,10 +71,12 @@ def main(options, epoch="final", res=20):
6971
activations_dir + f"rate_map_single_agent_epoch_{epoch}.npy",
7072
rate_map_single_agent,
7173
)
74+
75+
np.save(activations_dir + f"positions_single_agent_epoch_{epoch}.npy", positions_single_agent)
7276
# # activations is in the shape [number of grid cells (Ng) x res x res x n_avg]
7377
# # ratemap is in the shape [Ng x res^2]
7478

75-
return activations_single_agent, rate_map_single_agent
79+
return activations_single_agent, rate_map_single_agent, positions_single_agent
7680

7781

7882
def compute_grid_scores(res, rate_map_single_agent, scorer):

neurometry/datasets/synthetic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch # noqa: E402
66
from geomstats.geometry.euclidean import Euclidean # noqa: E402
77
from geomstats.geometry.hypersphere import Hypersphere # noqa: E402
8-
from geomstats.geometry.klein_bottle import KleinBottle # noqa: E402
8+
#from geomstats.geometry.klein_bottle import KleinBottle # noqa: E402
99
from geomstats.geometry.product_manifold import ProductManifold # noqa: E402
1010

1111

0 commit comments

Comments
 (0)