Skip to content

Commit 59d708c

Browse files
authored
Merge pull request #132 from geometric-intelligence/ninamiolane-docs3
Ruff format
2 parents b9a235f + 768bbbb commit 59d708c

37 files changed

+132
-102
lines changed

docs/conf.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Sphinx configuration file."""
22

3-
43
project = "neurometry"
54
copyright = "2023, Geometric Intelligence Lab."
65
author = "GI Authors"

neurometry/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
"""Initialize folder as submodule."""
2+
23
__version__ = "0.0.1"

neurometry/curvature/datasets/experimental.py

-1
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,6 @@ def _average_variable(variable_to_average, recorded_times, sampling_times):
353353
return np.array(variable_averaged)
354354

355355

356-
357356
def get_place_field_centers(neural_activity, task_variable):
358357
"""Get the center of mass of the place fields of a list of neurons.
359358

neurometry/curvature/datasets/gridcells.py

-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ def create_reference_lattice(lx, ly, arena_dims, lattice_type="hexagonal"):
6767
return np.hstack((np.reshape(X, (-1, 1)), np.reshape(Y, (-1, 1))))
6868

6969

70-
7170
def generate_all_grids(
7271
grid_scale,
7372
arena_dims,

neurometry/curvature/datasets/structures.py

-1
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,3 @@ def get_lattice(scale, lattice_type, dimensions):
2424
Y = ly * N_y
2525

2626
return np.hstack((np.reshape(X, (-1, 1)), np.reshape(Y, (-1, 1))))
27-

neurometry/curvature/datasets/utils.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,10 @@ def load(config):
156156
test_labels = labels.iloc[test_indices]
157157

158158
# The angles are positional angles in the lab frame
159-
if config.dataset_name in ("experimental", "s1_synthetic") or config.dataset_name in ("three_place_cells_synthetic"):
159+
if config.dataset_name in (
160+
"experimental",
161+
"s1_synthetic",
162+
) or config.dataset_name in ("three_place_cells_synthetic"):
160163
train = []
161164
for d, l in zip(train_dataset, train_labels["angles"], strict=False):
162165
train.append([d, float(l)])
@@ -165,10 +168,14 @@ def load(config):
165168
test.append([d, float(l)])
166169
elif config.dataset_name in ("s2_synthetic", "t2_synthetic"):
167170
train = []
168-
for d, t, p in zip(train_dataset, train_labels["thetas"], train_labels["phis"], strict=False):
171+
for d, t, p in zip(
172+
train_dataset, train_labels["thetas"], train_labels["phis"], strict=False
173+
):
169174
train.append([d, torch.tensor([float(t), float(p)])])
170175
test = []
171-
for d, t, p in zip(test_dataset, test_labels["thetas"], test_labels["phis"], strict=False):
176+
for d, t, p in zip(
177+
test_dataset, test_labels["thetas"], test_labels["phis"], strict=False
178+
):
172179
test.append([d, torch.tensor([float(t), float(p)])])
173180
elif config.dataset_name == "grid_cells":
174181
train = []

neurometry/curvature/default_config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@
138138
### ---> Lists of values to try for each parameter
139139

140140
# Datasets
141-
#dataset_name = ["s1_synthetic", "s2_synthetic"]
141+
# dataset_name = ["s1_synthetic", "s2_synthetic"]
142142
dataset_name = ["kb_synthetic"]
143143
for one_dataset_name in dataset_name:
144144
if one_dataset_name not in [

neurometry/curvature/evaluate.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
os.environ["GEOMSTATS_BACKEND"] = "pytorch"
44
import geomstats.backend as gs
55

6-
#import gph
6+
# import gph
77
import numpy as np
88
import torch
99
from datasets.synthetic import (
@@ -53,7 +53,6 @@ def immersion(angle):
5353
z = z.to(config.device)
5454
return model.decode(z)
5555

56-
5756
return immersion
5857

5958

neurometry/curvature/grid-cells-curvature/make_animation.py

+78-17
Original file line numberDiff line numberDiff line change
@@ -14,33 +14,54 @@
1414

1515
def gaussian_on_circle(theta, loc, sigma=0.1):
1616
"""A Gaussian-like function defined on the circle."""
17-
return np.exp(-(theta-loc)**2 / (2 * sigma**2))
17+
return np.exp(-((theta - loc) ** 2) / (2 * sigma**2))
18+
1819

1920
def relu(x):
2021
return np.maximum(0, x)
2122

23+
2224
# Function to plot a harmonic given amplitude and phase
2325
def plot_harmonic(ax, amplitude, phase, n, label, activation="relu"):
24-
2526
harmonic_values = amplitude * np.cos(n * theta + phase)
2627
if activation == "relu":
2728
harmonic_values = relu(harmonic_values)
28-
ax.plot(np.cos(theta), np.sin(theta), zs=0, zdir="z", linestyle="--",linewidth=3, color="black")
29-
normalized_phase = (phase + np.pi) / (2 * np.pi) # Normalizing from -π to π to 0 to 1
29+
ax.plot(
30+
np.cos(theta),
31+
np.sin(theta),
32+
zs=0,
33+
zdir="z",
34+
linestyle="--",
35+
linewidth=3,
36+
color="black",
37+
)
38+
normalized_phase = (phase + np.pi) / (
39+
2 * np.pi
40+
) # Normalizing from -π to π to 0 to 1
3041
color = cm.hsv(normalized_phase)
31-
ax.plot(np.cos(theta), np.sin(theta), harmonic_values, label=label,linewidth=3,color=color,alpha=1-0.1*n)
42+
ax.plot(
43+
np.cos(theta),
44+
np.sin(theta),
45+
harmonic_values,
46+
label=label,
47+
linewidth=3,
48+
color=color,
49+
alpha=1 - 0.1 * n,
50+
)
3251
ax.axis("off")
3352

53+
3454
# Prepare figure for plotting
35-
fig, axs = plt.subplots(2, N+1, figsize=(20, 10), subplot_kw={"projection": "3d"})
55+
fig, axs = plt.subplots(2, N + 1, figsize=(20, 10), subplot_kw={"projection": "3d"})
3656
plt.tight_layout()
3757

58+
3859
def update(loc):
3960
bump_samples = gaussian_on_circle(theta, loc=loc)
4061

4162
# Compute FFT
4263
coefficients_fft = np.fft.fft(bump_samples)
43-
frequencies = np.fft.fftfreq(num_samples, d=(2*np.pi/num_samples))
64+
frequencies = np.fft.fftfreq(num_samples, d=(2 * np.pi / num_samples))
4465

4566
# Clear previous plots
4667
for ax_row in axs:
@@ -49,32 +70,72 @@ def update(loc):
4970
ax.axis("off")
5071

5172
# Plot original function
52-
axs[0, 2].plot(np.cos(theta), np.sin(theta), zs=0, zdir="z", linestyle="--",linewidth=3,color="black")
53-
axs[0, 2].plot(np.cos(theta), np.sin(theta), bump_samples, label="Original Function",linewidth=3,color="tomato")
54-
axs[0, 2].set_title(f"Target place field, position = {loc:.2f}",fontsize=20)
73+
axs[0, 2].plot(
74+
np.cos(theta),
75+
np.sin(theta),
76+
zs=0,
77+
zdir="z",
78+
linestyle="--",
79+
linewidth=3,
80+
color="black",
81+
)
82+
axs[0, 2].plot(
83+
np.cos(theta),
84+
np.sin(theta),
85+
bump_samples,
86+
label="Original Function",
87+
linewidth=3,
88+
color="tomato",
89+
)
90+
axs[0, 2].set_title(f"Target place field, position = {loc:.2f}", fontsize=20)
5591
axs[0, 2].scatter(np.cos(loc), np.sin(loc), zs=0, zdir="z", s=100, c="red")
5692

5793
# Plot each harmonic and the reconstructed function
5894
reconstructed = np.zeros(num_samples)
59-
for n in range(1, N+1):
95+
for n in range(1, N + 1):
6096
index = n if frequencies[n] >= 0 else num_samples + n
6197
amplitude = np.abs(coefficients_fft[index])
6298
phase = np.angle(coefficients_fft[index])
6399

64-
plot_harmonic(axs[1, n-1], amplitude, phase, n, rf"GC module {n}, period $\lambda=${L/n:0.1f}", activation=activation)
65-
axs[1, n-1].set_title(rf"GC module {n}, period $\lambda_{n}=${L/n:0.1f}",fontsize=18)
100+
plot_harmonic(
101+
axs[1, n - 1],
102+
amplitude,
103+
phase,
104+
n,
105+
rf"GC module {n}, period $\lambda=${L/n:0.1f}",
106+
activation=activation,
107+
)
108+
axs[1, n - 1].set_title(
109+
rf"GC module {n}, period $\lambda_{n}=${L/n:0.1f}", fontsize=18
110+
)
66111
if activation == "relu":
67112
reconstructed += relu(amplitude * np.cos(n * theta + phase))
68113
else:
69114
reconstructed += amplitude * np.cos(n * theta + phase)
70115

71116
# Reconstructed function
72-
axs[1, N].plot(np.cos(theta), np.sin(theta), zs=0, zdir="z", linestyle="--",linewidth=3,color="black")
73-
axs[1, N].plot(np.cos(theta), np.sin(theta), reconstructed, label="Reconstructed",linewidth=3,color="limegreen")
74-
axs[1, N].set_title("Place field readout",fontsize=20)
117+
axs[1, N].plot(
118+
np.cos(theta),
119+
np.sin(theta),
120+
zs=0,
121+
zdir="z",
122+
linestyle="--",
123+
linewidth=3,
124+
color="black",
125+
)
126+
axs[1, N].plot(
127+
np.cos(theta),
128+
np.sin(theta),
129+
reconstructed,
130+
label="Reconstructed",
131+
linewidth=3,
132+
color="limegreen",
133+
)
134+
axs[1, N].set_title("Place field readout", fontsize=20)
135+
75136

76137
# Create animation
77-
loc_values = np.linspace(0, 2*np.pi, 100)
138+
loc_values = np.linspace(0, 2 * np.pi, 100)
78139
ani = FuncAnimation(fig, update, frames=loc_values, repeat=True)
79140

80141
# Save the animation

neurometry/curvature/hyperspherical/distributions/hyperspherical_uniform.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@ def device(self, val):
2222
self._device = val if isinstance(val, torch.device) else torch.device(val)
2323

2424
def __init__(self, dim, validate_args=None, device=None):
25-
super().__init__(
26-
torch.Size([dim]), validate_args=validate_args
27-
)
25+
super().__init__(torch.Size([dim]), validate_args=validate_args)
2826
self._dim = dim
2927
self.device = device
3028

neurometry/curvature/hyperspherical/ops/ive.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,7 @@ def delta_a(a):
7070

7171
delta_0 = delta_a(0.0)
7272
delta_2 = delta_a(2.0)
73-
B_0 = z / (
74-
delta_0 + torch.sqrt(torch.pow(delta_0, 2) + torch.pow(z, 2)).clamp(eps)
75-
)
76-
B_2 = z / (
77-
delta_2 + torch.sqrt(torch.pow(delta_2, 2) + torch.pow(z, 2)).clamp(eps)
78-
)
73+
B_0 = z / (delta_0 + torch.sqrt(torch.pow(delta_0, 2) + torch.pow(z, 2)).clamp(eps))
74+
B_2 = z / (delta_2 + torch.sqrt(torch.pow(delta_2, 2) + torch.pow(z, 2)).clamp(eps))
7975

8076
return (B_0 + B_2) / 2.0

neurometry/curvature/losses.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from neurometry.curvature.hyperspherical.distributions import hyperspherical_uniform
66
from neurometry.curvature.hyperspherical.distributions import von_mises_fisher
77

8+
89
def elbo(x, x_mu, posterior_params, z, labels, config):
910
"""Compute the ELBO for the VAE loss.
1011
@@ -44,7 +45,9 @@ def elbo(x, x_mu, posterior_params, z, labels, config):
4445
if config.posterior_type == "hyperspherical":
4546
z_mu, z_kappa = posterior_params
4647
q_z = von_mises_fisher.VonMisesFisher(z_mu, z_kappa)
47-
p_z = hyperspherical_uniform.HypersphericalUniform(config.latent_dim - 1, device=config.device)
48+
p_z = hyperspherical_uniform.HypersphericalUniform(
49+
config.latent_dim - 1, device=config.device
50+
)
4851
kld = torch.distributions.kl.kl_divergence(q_z, p_z).mean()
4952

5053
if config.posterior_type == "toroidal":

neurometry/curvature/models/klein_bottle_vae.py

-3
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ def encode(self, x):
100100

101101
return z_theta_mu, z_theta_kappa, z_u_mu, z_u_kappa
102102

103-
104103
def _build_klein_bottle(self, z_theta, z_u):
105104
# theta = torch.atan2(z_theta[:, 1] / z_theta[:, 0])
106105
# phi = torch.atan2(z_u[:, 1] / z_u[:, 0])
@@ -151,7 +150,6 @@ def reparameterize(self, posterior_params):
151150

152151
return self._build_torus(z_theta, z_u)
153152

154-
155153
def decode(self, z):
156154
"""Decode latent variable z into data.
157155
@@ -173,7 +171,6 @@ def decode(self, z):
173171

174172
return self.fc_x_mu(h)
175173

176-
177174
def forward(self, x):
178175
"""Run VAE: Encode, sample and decode.
179176

neurometry/curvature/models/neural_vae.py

-2
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,6 @@ def reparameterize(self, posterior_params):
139139

140140
return q_z.rsample()
141141

142-
143142
def decode(self, z):
144143
"""Decode latent variable z into data.
145144
@@ -163,7 +162,6 @@ def decode(self, z):
163162

164163
return self.fc_x_mu(h)
165164

166-
167165
def forward(self, x):
168166
"""Run VAE: Encode, sample and decode.
169167

neurometry/curvature/models/toroidal_vae.py

-3
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ def encode(self, x):
100100

101101
return z_theta_mu, z_theta_kappa, z_phi_mu, z_phi_kappa
102102

103-
104103
def _build_torus(self, z_theta, z_phi):
105104
# theta = torch.atan2(z_theta[:, 1] / z_theta[:, 0])
106105
# phi = torch.atan2(z_phi[:, 1] / z_phi[:, 0])
@@ -150,7 +149,6 @@ def reparameterize(self, posterior_params):
150149

151150
return self._build_torus(z_theta, z_phi)
152151

153-
154152
def decode(self, z):
155153
"""Decode latent variable z into data.
156154
@@ -172,7 +170,6 @@ def decode(self, z):
172170

173171
return self.fc_x_mu(h)
174172

175-
176173
def forward(self, x):
177174
"""Run VAE: Encode, sample and decode.
178175

neurometry/curvature/persistent_homology.py

-1
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,3 @@ def compute_persistence_diagrams(point_cloud, maxdim=2, n_threads=-1):
3232
dfs.append(df)
3333

3434
return pd.concat(dfs, ignore_index=True)
35-

neurometry/curvature/viz.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,7 @@ def plot_recon_per_positional_angle(model, dataset_torch, labels, config):
101101
y_rec = rec[:, 1]
102102
y_rec = [y.item() for y in y_rec]
103103
ax_data.set_title("Synthetic data", fontsize=40)
104-
ax_data.scatter(
105-
x_data, y_data, s=400, c=labels["angles"], cmap=colormap
106-
)
104+
ax_data.scatter(x_data, y_data, s=400, c=labels["angles"], cmap=colormap)
107105
plt.xticks(fontsize=24)
108106
plt.yticks(fontsize=24)
109107
ax_rec = fig.add_subplot(1, 2, 2)

neurometry/datasets/load_rnn_grid_cells.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
# Loading single agent model
1414

15-
#parent_dir = os.getcwd() + "/datasets/rnn_grid_cells/"
15+
# parent_dir = os.getcwd() + "/datasets/rnn_grid_cells/"
1616

1717
parent_dir = "/scratch/facosta/rnn_grid_cells/"
1818

@@ -90,8 +90,8 @@ def load_activations(epochs, version="single", verbose=True):
9090

9191
return activations, rate_maps, state_points
9292

93-
def plot_rate_map(indices, num_plots, activations):
9493

94+
def plot_rate_map(indices, num_plots, activations):
9595
if indices is None:
9696
idxs = np.random.randint(0, 4095, num_plots)
9797
else:

neurometry/datasets/rnn_grid_cells/dual_agent_activity.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,12 @@ def main(options, epoch="final", res=20):
7373
def compute_grid_scores(res, rate_map_dual_agent, scorer):
7474
print("Computing grid scores...")
7575
score_60_dual_agent, _, _, _, _, _ = zip(
76-
*[scorer.get_scores(rm.reshape(res, res)) for rm in tqdm(rate_map_dual_agent)], strict=False
76+
*[scorer.get_scores(rm.reshape(res, res)) for rm in tqdm(rate_map_dual_agent)],
77+
strict=False,
7778
)
7879
return np.array(score_60_dual_agent)
7980

8081

81-
8282
def compute_border_scores(box_width, res, rate_map_dual_agent, scorer):
8383
print("Computing border scores...")
8484
border_scores_dual_agent = []

neurometry/datasets/rnn_grid_cells/model.py

-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def predict(self, inputs):
4949
"""
5050
return self.decoder(self.g(inputs))
5151

52-
5352
def compute_loss(self, inputs, pc_outputs, pos):
5453
"""
5554
Compute avg. loss and decoding error.

0 commit comments

Comments
 (0)