Skip to content

Commit d12e3a6

Browse files
Merge pull request #155 from geometric-intelligence/half_plane_s_x
Half plane s x
2 parents f177e9d + 3d71b59 commit d12e3a6

File tree

8 files changed

+90030
-85319
lines changed

8 files changed

+90030
-85319
lines changed

neurometry/curvature/grid-cells-curvature/models/xu_rnn/configs/rnn_isometry.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ def d(**kwargs):
1010
def get_config():
1111
"""Get the hyperparameters for the model"""
1212
config = ml_collections.ConfigDict()
13-
config.gpu = 4
13+
config.gpu = 5
1414

1515
# training config
1616
config.train = d(
17-
num_steps_train=30000, # 100000
17+
load_pretrain=True,
18+
pretrain_dir="logs/rnn_isometry/20240418-180712/ckpt/model/checkpoint-step25000.pth",
19+
num_steps_train=40000, # 100000
1820
lr=0.006,
1921
lr_decay_from=10000,
2022
steps_per_logging=20,
@@ -45,20 +47,22 @@ def get_config():
4547
block_size=12,
4648
sigma=0.07,
4749
w_kernel=1.05,
48-
w_trans=0.1,
50+
w_trans=10,#0.1
4951
w_isometry=0.005,
5052
w_reg_u=0.2,
5153
reg_decay_until=15000,
5254
adaptive_dr=True,
53-
s_0 = 0.5,
54-
x_star = torch.tensor([0.2, 0.5]),
55-
sigma_star = 0.3,
56-
reward_step = 11000,
55+
s_0 = 5,
56+
x_star = torch.tensor([0.5, 0.5]),
57+
sigma_star_x = 0.2,
58+
sigma_star_y = 0.2,
59+
reward_step = 3000,
60+
saliency_type = "left_half",
5761
)
5862

5963
# path integration
6064
config.integration = d(
61-
n_inte_step=30,
65+
n_inte_step=30, # 50
6266
n_traj=100,
6367
n_inte_step_vis=30,
6468
n_traj_vis=5,

neurometry/curvature/grid-cells-curvature/models/xu_rnn/experiment.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,18 @@ def __init__(self, rng, config: ml_collections.ConfigDict, device):
5858
self.model.parameters(), lr=config.train.lr, momentum=0.9
5959
)
6060

61+
if config.train.load_pretrain:
62+
logging.info("==== load pretrain model ====")
63+
ckpt_model_path = config.train.pretrain_dir
64+
logging.info(f"Loading pretrain model from {ckpt_model_path}")
65+
ckpt = torch.load(ckpt_model_path, map_location=device)
66+
self.model.load_state_dict(ckpt["state_dict"])
67+
logging.info("==== load pretrained optimizer ====")
68+
self.optimizer.load_state_dict(ckpt["optimizer"])
69+
self.starting_step = ckpt["step"]
70+
else:
71+
self.starting_step = 1
72+
6173
def train_and_evaluate(self, workdir):
6274
logging.info("==== Experiment.train_and_evaluate() ===")
6375

@@ -81,7 +93,7 @@ def train_and_evaluate(self, workdir):
8193

8294
logging.info("==== Start of training ====")
8395
with metric_writers.ensure_flushes(writer):
84-
for step in range(1, config.num_steps_train + 1):
96+
for step in range(self.starting_step, config.num_steps_train + self.starting_step):
8597
batch_data = utils.dict_to_device(next(self.train_iter), self.device)
8698

8799
if 120000 > step > 10000:
@@ -151,7 +163,7 @@ def train_and_evaluate(self, workdir):
151163
)
152164
train_metrics = []
153165

154-
if step % config.steps_per_large_logging == 0:
166+
if step == self.starting_step or step % config.steps_per_large_logging == 0:
155167
ckpt_dir = os.path.join(workdir, "ckpt")
156168
if not tf.io.gfile.exists(ckpt_dir):
157169
tf.io.gfile.makedirs(ckpt_dir)
@@ -166,8 +178,8 @@ def visualize(activations):
166178
)[:10, :10]
167179
return utils.draw_heatmap(activations)
168180

169-
images_v = visualize(self.model.encoder.v, "v")
170-
images_u = visualize(self.model.decoder.u, "u")
181+
images_v = visualize(self.model.encoder.v)
182+
images_u = visualize(self.model.decoder.u)
171183

172184
writer.write_images(step, {"v": images_v})
173185
wandb.log({"v": wandb.Image(images_v)}, step=step)
@@ -239,7 +251,7 @@ def visualize(activations):
239251
step=step,
240252
)
241253

242-
if step % config.steps_per_integration == 0 or step == 1:
254+
if step % config.steps_per_integration == 0 or step == self.starting_step:
243255
# perform path integration
244256
with torch.no_grad():
245257
eval_data = utils.dict_to_device(

neurometry/curvature/grid-cells-curvature/models/xu_rnn/model.py

+38-5
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@ class GridCellConfig:
2323
adaptive_dr: bool
2424
s_0: float
2525
x_star: torch.Tensor
26-
sigma_star: float
26+
sigma_star_x: float
27+
sigma_star_y: float
2728
reward_step: int
29+
saliency_type: str
2830

2931

3032
class GridCell(nn.Module):
@@ -222,7 +224,7 @@ def _loss_trans_rnn(self, traj, step):
222224
heatmap_reshape = heatmap.reshape((heatmap.shape[0], -1)) # (N, 1600)
223225
y_hat = heatmap_reshape # actual "place cell" activity over the grid (linear readout of grid cells)
224226

225-
saliency_kernel = self._saliency_kernel(x_grid).unsqueeze(0).to(traj.device) # (1, 1600)
227+
saliency_kernel = self._saliency_kernel(x_grid, config.saliency_type).unsqueeze(0).to(traj.device) # (1, 1600)
226228
if step < config.reward_step:
227229
L_error = (y - y_hat) ** 2
228230
else:
@@ -233,12 +235,43 @@ def _loss_trans_rnn(self, traj, step):
233235

234236
return loss_trans * config.w_trans
235237

236-
def _saliency_kernel(self, x_grid):
238+
def _saliency_kernel(self, x_grid, saliency_type):
239+
if saliency_type == "gaussian":
240+
return self._saliency_kernel_gaussian(x_grid)
241+
if saliency_type == "left_half":
242+
return self._saliency_kernel_left_half(x_grid)
243+
raise NotImplementedError
244+
245+
# def _saliency_kernel_gaussian(self, x_grid, sigma_star_x, sigma_star_y):
246+
# config = self.config
247+
# s_0 = config.s_0
248+
# x_star = config.x_star
249+
# sigma_star = config.sigma_star
250+
# s_x = s_0*torch.exp(-torch.sum((x_grid - x_star)**2, dim=1)/(2*sigma_star**2))/np.sqrt(2*np.pi*sigma_star**2)
251+
# return 1 + s_x
252+
253+
def _saliency_kernel_gaussian(self, x_grid):
237254
config = self.config
238255
s_0 = config.s_0
239256
x_star = config.x_star
240-
sigma_star = config.sigma_star
241-
s_x = s_0*torch.exp(-torch.sum((x_grid - x_star)**2, dim=1)/(2*sigma_star**2))/np.sqrt(2*np.pi*sigma_star**2)
257+
sigma_star_x = config.sigma_star_x
258+
sigma_star_y = config.sigma_star_y
259+
260+
# Calculate the squared differences, scaled by respective sigma values
261+
diff = x_grid - x_star
262+
scaled_diff_sq = (diff[:, 0]**2 / sigma_star_x**2) + (diff[:, 1]**2 / sigma_star_y**2)
263+
264+
# Compute the Gaussian function
265+
normalization_factor = 2 * np.pi * sigma_star_x * sigma_star_y
266+
s_x = s_0 * torch.exp(-0.5 * scaled_diff_sq) / normalization_factor
267+
268+
return 1 + s_x
269+
270+
271+
def _saliency_kernel_left_half(self, x_grid):
272+
config = self.config
273+
s_0 = config.s_0
274+
s_x = s_0 * (x_grid[:, 0] < 0.5).float()
242275
return 1 + s_x
243276

244277
def _loss_trans_lstm(self, traj):

neurometry/curvature/grid-cells-curvature/models/xu_rnn/utils.py

+58-58
Original file line numberDiff line numberDiff line change
@@ -118,32 +118,32 @@ def block_diagonal(matrices):
118118
return blocked
119119

120120

121-
def draw_heatmap(data, save_path, xlabels=None, ylabels=None):
122-
# data = np.clip(data, -0.05, 0.05)
123-
cmap = cm.get_cmap("rainbow", 1000)
124-
figure = plt.figure(facecolor="w")
125-
ax = figure.add_subplot(1, 1, 1, position=[0.1, 0.15, 0.8, 0.8])
126-
if xlabels is not None:
127-
ax.set_xticks(range(len(xlabels)))
128-
ax.set_xticklabels(xlabels)
129-
if ylabels is not None:
130-
ax.set_yticks(range(len(ylabels)))
131-
ax.set_yticklabels(ylabels)
132-
133-
vmax = data[0][0]
134-
vmin = data[0][0]
135-
for i in data:
136-
for j in i:
137-
if j > vmax:
138-
vmax = j
139-
if j < vmin:
140-
vmin = j
141-
map = ax.imshow(
142-
data, interpolation="nearest", cmap=cmap, aspect="auto", vmin=vmin, vmax=vmax
143-
)
144-
plt.colorbar(mappable=map, cax=None, ax=None, shrink=0.5)
145-
plt.savefig(save_path)
146-
plt.close()
121+
# def draw_heatmap(data, save_path, xlabels=None, ylabels=None):
122+
# # data = np.clip(data, -0.05, 0.05)
123+
# cmap = cm.get_cmap("rainbow", 1000)
124+
# figure = plt.figure(facecolor="w")
125+
# ax = figure.add_subplot(1, 1, 1, position=[0.1, 0.15, 0.8, 0.8])
126+
# if xlabels is not None:
127+
# ax.set_xticks(range(len(xlabels)))
128+
# ax.set_xticklabels(xlabels)
129+
# if ylabels is not None:
130+
# ax.set_yticks(range(len(ylabels)))
131+
# ax.set_yticklabels(ylabels)
132+
133+
# vmax = data[0][0]
134+
# vmin = data[0][0]
135+
# for i in data:
136+
# for j in i:
137+
# if j > vmax:
138+
# vmax = j
139+
# if j < vmin:
140+
# vmin = j
141+
# map = ax.imshow(
142+
# data, interpolation="nearest", cmap=cmap, aspect="auto", vmin=vmin, vmax=vmax
143+
# )
144+
# plt.colorbar(mappable=map, cax=None, ax=None, shrink=0.5)
145+
# plt.savefig(save_path)
146+
# plt.close()
147147

148148

149149
def shape_mask(size, shape):
@@ -452,39 +452,39 @@ def _draw_real_pred_pairs(real, pred, area_size: int):
452452
ax.set_aspect(1)
453453

454454

455-
# def draw_heatmap(weights):
456-
# # weights should a 4-D tensor: [M, N, H, W]
457-
# nrow, ncol = weights.shape[0], weights.shape[1]
458-
# fig = plt.figure(figsize=(ncol, nrow))
459-
460-
# for i in range(nrow):
461-
# for j in range(ncol):
462-
# plt.subplot(nrow, ncol, i * ncol + j + 1)
463-
# weight = weights[i, j]
464-
# vmin, vmax = weight.min() - 0.01, weight.max()
465-
466-
# cmap = cm.get_cmap("rainbow", 1000)
467-
# cmap.set_under("w")
468-
469-
# plt.imshow(
470-
# weight,
471-
# interpolation="nearest",
472-
# cmap=cmap,
473-
# aspect="auto",
474-
# vmin=vmin,
475-
# vmax=vmax,
476-
# )
477-
# plt.axis("off")
478-
479-
# fig.canvas.draw()
480-
# image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
481-
# image_from_plot = image_from_plot.reshape(
482-
# fig.canvas.get_width_height()[::-1] + (3,)
483-
# )
484-
# # plt.show()
485-
# plt.close(fig)
455+
def draw_heatmap(weights):
456+
# weights should a 4-D tensor: [M, N, H, W]
457+
nrow, ncol = weights.shape[0], weights.shape[1]
458+
fig = plt.figure(figsize=(ncol, nrow))
486459

487-
# return np.expand_dims(image_from_plot, axis=0)
460+
for i in range(nrow):
461+
for j in range(ncol):
462+
plt.subplot(nrow, ncol, i * ncol + j + 1)
463+
weight = weights[i, j]
464+
vmin, vmax = weight.min() - 0.01, weight.max()
465+
466+
cmap = cm.get_cmap("rainbow", 1000)
467+
cmap.set_under("w")
468+
469+
plt.imshow(
470+
weight,
471+
interpolation="nearest",
472+
cmap=cmap,
473+
aspect="auto",
474+
vmin=vmin,
475+
vmax=vmax,
476+
)
477+
plt.axis("off")
478+
479+
fig.canvas.draw()
480+
image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
481+
image_from_plot = image_from_plot.reshape(
482+
fig.canvas.get_width_height()[::-1] + (3,)
483+
)
484+
# plt.show()
485+
plt.close(fig)
486+
487+
return np.expand_dims(image_from_plot, axis=0)
488488

489489

490490
def average_appended_metrics(metrics):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import os
2+
import pickle
3+
4+
import matplotlib.pyplot as plt
5+
import numpy as np
6+
from moviepy.editor import ImageSequenceClip
7+
8+
from neurometry.datasets.load_rnn_grid_cells import plot_rate_map
9+
10+
logs_dir = "logs/rnn_isometry"
11+
12+
activations_dir = "ckpt/activations"
13+
14+
def load_matrices(file_path):
15+
with open(file_path, "rb") as file:
16+
data = pickle.load(file)
17+
return data["u"], data["v"]
18+
19+
20+
# def save_matrix_as_image(matrix, image_path):
21+
# plt.imshow(matrix, cmap='viridis')
22+
# plt.colorbar()
23+
# plt.savefig(image_path)
24+
# plt.close()
25+
26+
27+
def create_video(image_folder, output_file, fps=10):
28+
image_files = [os.path.join(image_folder, img) for img in sorted(os.listdir(image_folder))]
29+
clip = ImageSequenceClip(image_files, fps=fps)
30+
clip.write_videofile(output_file, codec="libx264")
31+
32+
33+
# def generate_videos(run_id, start_epoch=25000, end_epoch=65000, step=500):
34+
35+
# data_dir = os.path.join(logs_dir, run_id, activations_dir)
36+
37+
# print(data_dir)
38+
39+
# output_dir = data_dir
40+
# os.makedirs(output_dir, exist_ok=True)
41+
# u_images_dir = os.path.join(output_dir, 'u_images')
42+
# v_images_dir = os.path.join(output_dir, 'v_images')
43+
# os.makedirs(u_images_dir, exist_ok=True)
44+
# os.makedirs(v_images_dir, exist_ok=True)
45+
46+
# # Generate images
47+
# for epoch in range(start_epoch, end_epoch, step):
48+
# file_name = f'activations-step{epoch}.pkl'
49+
# file_path = os.path.join(data_dir, file_name)
50+
# u, v = load_matrices(file_path)
51+
# save_matrix_as_image(u, os.path.join(u_images_dir, f'u_{epoch}.png'))
52+
# save_matrix_as_image(v, os.path.join(v_images_dir, f'v_{epoch}.png'))
53+
54+
# # Create videos
55+
# create_video(u_images_dir, os.path.join(output_dir, 'u_video.mp4'))
56+
# create_video(v_images_dir, os.path.join(output_dir, 'v_video.mp4'))
57+
58+
59+
def save_rate_maps_as_image(indices, num_plots, activations, image_path, title, seed=None):
60+
plot_rate_map(indices, num_plots, activations, title, seed=seed)
61+
plt.savefig(image_path)
62+
plt.close()
63+
64+
65+
def generate_videos(run_id, start_epoch=25000, end_epoch=65000, step=500, num_cells_per_image=20, seed=None):
66+
67+
data_dir = os.path.join(logs_dir, run_id, activations_dir)
68+
#os.makedirs(output_dir, exist_ok=True)
69+
output_dir = data_dir
70+
u_images_dir = os.path.join(output_dir, "u_images")
71+
v_images_dir = os.path.join(output_dir, "v_images")
72+
os.makedirs(u_images_dir, exist_ok=True)
73+
os.makedirs(v_images_dir, exist_ok=True)
74+
75+
rng = np.random.default_rng(seed=seed)
76+
idxs = rng.integers(0, 1799, num_cells_per_image)
77+
78+
# Iterate through epochs
79+
for epoch in range(start_epoch, end_epoch, step):
80+
file_name = f"activations-step{epoch}.pkl"
81+
file_path = os.path.join(data_dir, file_name)
82+
u, v = load_matrices(file_path)
83+
84+
# Save images using your custom plotting function
85+
u_image_path = os.path.join(u_images_dir, f"u_{epoch}.png")
86+
v_image_path = os.path.join(v_images_dir, f"v_{epoch}.png")
87+
save_rate_maps_as_image(idxs, num_cells_per_image, u, u_image_path, f"U matrices at epoch {epoch}", seed=seed)
88+
save_rate_maps_as_image(idxs, num_cells_per_image, v, v_image_path, f"V matrices at epoch {epoch}", seed=seed)
89+
90+
# Create videos
91+
create_video(u_images_dir, os.path.join(output_dir, "u_video.mp4"))
92+
create_video(v_images_dir, os.path.join(output_dir, "v_video.mp4"))
93+
94+
95+
96+
97+

0 commit comments

Comments
 (0)