
Commit f367123

Merge pull request #141 from geometric-intelligence/ninamiolane-lint
Lint majority of repo
2 parents 3b0489b + d45738d · commit f367123

25 files changed: +90705 / -41832 lines

neurometry/curvature/datasets/gridcells.py

Lines changed: 3 additions & 3 deletions
@@ -6,9 +6,9 @@
 import pandas as pd

 os.environ["GEOMSTATS_BACKEND"] = "pytorch"
-import geomstats.backend as gs
+import geomstats.backend as gs  # noqa: E402

-import neurometry.curvature.datasets.structures as structures
+import neurometry.curvature.datasets.structures as structures  # noqa: E402


 # TODO
@@ -96,7 +96,7 @@ def generate_all_grids(
     grids : numpy.ndarray, shape=(num_cells, num_fields_per_cell = (ceil(dims[0]/lx)+1)*(ceil(dims[1]/ly)+1),2)
         All the grid cell lattices.
     """
-
+    lx = ly = 10  # TODO: FIX, these values are only placeholders.
     # ref_lattice = create_reference_lattice(lx, ly, arena_dims, lattice_type = lattice_type)
     ref_lattice = structures.get_lattice(
         scale=grid_scale, lattice_type=lattice_type, dimensions=arena_dims
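Note on the # noqa: E402 markers above (the same pattern recurs in synthetic.py, evaluate.py, and main.py below): geomstats selects its backend from the GEOMSTATS_BACKEND environment variable at import time, so the assignment must run before the import, and the resulting module-level import below top-of-file code trips E402. A minimal sketch of the pattern, assuming geomstats is installed with the pytorch backend available:

import os

# Select the backend before geomstats is imported; it reads this variable on import.
os.environ["GEOMSTATS_BACKEND"] = "pytorch"

import geomstats.backend as gs  # noqa: E402  (intentionally imported after code)

print(gs.array([1.0, 2.0]))  # a torch-backed array under the pytorch backend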

neurometry/curvature/datasets/synthetic.py

Lines changed: 7 additions & 6 deletions
@@ -4,12 +4,13 @@
 import os

 os.environ["GEOMSTATS_BACKEND"] = "pytorch"
-import geomstats.backend as gs
-import numpy as np
-import pandas as pd
-import torch
-from geomstats.geometry.special_orthogonal import SpecialOrthogonal  # NOQA
-from torch.distributions.multivariate_normal import MultivariateNormal
+import geomstats.backend as gs  # noqa: E402
+import numpy as np  # noqa: E402
+import pandas as pd  # noqa: E402
+import skimage  # noqa: E402
+import torch  # noqa: E402
+from geomstats.geometry.special_orthogonal import SpecialOrthogonal  # noqa: E402
+from torch.distributions.multivariate_normal import MultivariateNormal  # noqa: E402


 def load_projected_images(n_scalars=5, n_angles=1000, img_size=128):

neurometry/curvature/datasets/utils.py

Lines changed: 14 additions & 12 deletions
@@ -6,6 +6,8 @@
 import torch
 from scipy.signal import savgol_filter

+import neurometry.curvature.datasets as datasets
+

 def load(config):
     """Load dataset according to configuration in config.
@@ -161,29 +163,29 @@ def load(config):
         "s1_synthetic",
     ) or config.dataset_name in ("three_place_cells_synthetic"):
         train = []
-        for d, l in zip(train_dataset, train_labels["angles"], strict=False):
-            train.append([d, float(l)])
+        for data, label in zip(train_dataset, train_labels["angles"], strict=False):
+            train.append([data, float(label)])
         test = []
-        for d, l in zip(test_dataset, test_labels["angles"], strict=False):
-            test.append([d, float(l)])
+        for data, label in zip(test_dataset, test_labels["angles"], strict=False):
+            test.append([data, float(label)])
     elif config.dataset_name in ("s2_synthetic", "t2_synthetic"):
         train = []
-        for d, t, p in zip(
+        for data, theta, phi in zip(
             train_dataset, train_labels["thetas"], train_labels["phis"], strict=False
         ):
-            train.append([d, torch.tensor([float(t), float(p)])])
+            train.append([data, torch.tensor([float(theta), float(phi)])])
         test = []
-        for d, t, p in zip(
+        for data, theta, phi in zip(
             test_dataset, test_labels["thetas"], test_labels["phis"], strict=False
         ):
-            test.append([d, torch.tensor([float(t), float(p)])])
+            test.append([data, torch.tensor([float(theta), float(phi)])])
     elif config.dataset_name == "grid_cells":
         train = []
-        for d, l in zip(train_dataset, train_labels["no_labels"], strict=False):
-            train.append([d, float(l)])
+        for data, label in zip(train_dataset, train_labels["no_labels"], strict=False):
+            train.append([data, float(label)])
         test = []
-        for d, l in zip(test_dataset, test_labels["no_labels"], strict=False):
-            test.append([d, float(l)])
+        for data, label in zip(test_dataset, test_labels["no_labels"], strict=False):
+            test.append([data, float(label)])

     train_loader = torch.utils.data.DataLoader(train, batch_size=config.batch_size)
     test_loader = torch.utils.data.DataLoader(test, batch_size=config.batch_size)
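The renames above (d, l to data, label and d, t, p to data, theta, phi) address the ambiguous-variable-name rule (E741 flags l because it is easily confused with 1 or I) without changing behavior. A minimal, self-contained sketch of the pairing pattern used in load, with made-up stand-ins for train_dataset and train_labels["angles"] (torch is the only dependency):

import torch

train_dataset = torch.randn(4, 3)    # 4 samples, 3 features each (illustrative)
train_angles = [0.1, 0.5, 1.2, 2.0]  # one angle label per sample (illustrative)

# Pair each sample with its float label, as in datasets/utils.py.
train = []
for data, label in zip(train_dataset, train_angles, strict=False):
    train.append([data, float(label)])

train_loader = torch.utils.data.DataLoader(train, batch_size=2)
for batch_data, batch_labels in train_loader:
    print(batch_data.shape, batch_labels)  # e.g. torch.Size([2, 3]) and a tensor of two labels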

neurometry/curvature/evaluate.py

Lines changed: 8 additions & 6 deletions
@@ -1,18 +1,20 @@
 import os
+import time
+
+import numpy as np
+import torch

 os.environ["GEOMSTATS_BACKEND"] = "pytorch"
-import geomstats.backend as gs
+import geomstats.backend as gs  # noqa: E402

 # import gph
-import numpy as np
-import torch
-from datasets.synthetic import (
+from datasets.synthetic import (  # noqa: E402
     get_s1_synthetic_immersion,
     get_s2_synthetic_immersion,
     get_t2_synthetic_immersion,
 )
-from geomstats.geometry.pullback_metric import PullbackMetric
-from geomstats.geometry.special_orthogonal import SpecialOrthogonal  # NOQA
+from geomstats.geometry.pullback_metric import PullbackMetric  # noqa: E402
+from geomstats.geometry.special_orthogonal import SpecialOrthogonal  # noqa: E402


 def get_learned_immersion(model, config):

neurometry/curvature/grid-cells-curvature/notebooks/12_path_int.ipynb

Lines changed: 0 additions & 240 deletions
This file was deleted.

neurometry/curvature/hyperspherical/distributions/hyperspherical_uniform.py

Lines changed: 3 additions & 1 deletion
@@ -26,7 +26,9 @@ def __init__(self, dim, validate_args=None, device=None):
         self._dim = dim
         self.device = device

-    def sample(self, shape=torch.Size()):
+    def sample(self, shape=None):
+        if shape is None:
+            shape = torch.Size()
         output = (
             torch.distributions.Normal(0, 1)
             .sample(

neurometry/curvature/hyperspherical/distributions/von_mises_fisher.py

Lines changed: 6 additions & 2 deletions
@@ -44,11 +44,15 @@ def __init__(self, loc, scale, validate_args=None, k=1):

         super().__init__(self.loc.size(), validate_args=validate_args)

-    def sample(self, shape=torch.Size()):
+    def sample(self, shape=None):
+        if shape is None:
+            shape = torch.Size()
         with torch.no_grad():
             return self.rsample(shape)

-    def rsample(self, shape=torch.Size()):
+    def rsample(self, shape=None):
+        if shape is None:
+            shape = torch.Size()
         shape = shape if isinstance(shape, torch.Size) else torch.Size([shape])

         w = (
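The two distribution classes above replace shape=torch.Size() defaults with a None sentinel. A call in a default argument is evaluated once, at function definition time, and linters such as flake8-bugbear/ruff flag it (rule B008); torch.Size() happens to be immutable, so the original was not actually buggy, but the sentinel form is the conventional fix. A standalone sketch of the pattern (the class below is illustrative, not the library code):

import torch

class UnitNormalSampler:
    # Before: def sample(self, shape=torch.Size()) builds the default when the
    # function is defined, which linters flag as a call in argument defaults.
    def sample(self, shape=None):
        # After: build the default inside the body, only when the caller omits it.
        if shape is None:
            shape = torch.Size()
        return torch.distributions.Normal(0.0, 1.0).sample(shape)

sampler = UnitNormalSampler()
print(sampler.sample().shape)                 # torch.Size([]) -- scalar sample
print(sampler.sample(torch.Size([3])).shape)  # torch.Size([3])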

neurometry/curvature/losses.py

Lines changed: 16 additions & 9 deletions
@@ -2,8 +2,10 @@

 import torch

-from neurometry.curvature.hyperspherical.distributions import hyperspherical_uniform
-from neurometry.curvature.hyperspherical.distributions import von_mises_fisher
+from neurometry.curvature.hyperspherical.distributions import (
+    hyperspherical_uniform,
+    von_mises_fisher,
+)


 def elbo(x, x_mu, posterior_params, z, labels, config):
@@ -25,9 +27,11 @@ def elbo(x, x_mu, posterior_params, z, labels, config):
     x : array-like, shape=[batch_size, data_dim]
         Input data.
     gen_likelihood_params : tuple
-        Learned distributional parameters of generative model. (e.g., (x_mu,x_logvar) for Gaussian).
+        Learned distributional parameters of generative model.
+        (e.g., (x_mu,x_logvar) for Gaussian).
     posterior_params : tuple
-        Learned distributional parameters of approximate posterior. (e.g., (z_mu,z_logvar) for Gaussian).
+        Learned distributional parameters of approximate posterior.
+        (e.g., (z_mu,z_logvar) for Gaussian).
     config : module
         Module specifying various model hyperparameters

@@ -52,9 +56,10 @@ def elbo(x, x_mu, posterior_params, z, labels, config):

     if config.posterior_type == "toroidal":
         z_theta_mu, z_theta_kappa, z_phi_mu, z_phi_kappa = posterior_params
-        q_z_theta = VonMisesFisher(z_theta_mu, z_theta_kappa)
-        q_z_phi = VonMisesFisher(z_phi_mu, z_phi_kappa)
-        p_z = HypersphericalUniform(config.latent_dim - 1, device=config.device)
+        q_z_theta = von_mises_fisher.VonMisesFisher(z_theta_mu, z_theta_kappa)
+        q_z_phi = von_mises_fisher.VonMisesFisher(z_phi_mu, z_phi_kappa)
+        p_z = hyperspherical_uniform.HypersphericalUniform(
+            config.latent_dim - 1, device=config.device)
         kld_theta = torch.distributions.kl.kl_divergence(q_z_theta, p_z).mean()
         kld_phi = torch.distributions.kl.kl_divergence(q_z_phi, p_z).mean()
         kld = kld_theta + kld_phi
@@ -149,7 +154,8 @@ def moving_forward_loss(z, config):
     """
     if config.dataset_name != "experimental":
         # print(
-        #     "WARNING: Moving forward loss only implemented for experimental data --> Skipped."
+        #     "WARNING: Moving forward loss only implemented for experimental data
+        #     --> Skipped."
         # )
         return torch.zeros(1).to(config.device)
     if len(z) == 1:
@@ -176,7 +182,8 @@ def dynamic_loss(labels, z, config):
     """
    if config.dataset_name != "experimental":
         # print(
-        #     "WARNING: Dynamic loss only implemented for experimental data --> Skipped."
+        #     "WARNING: Dynamic loss only implemented for experimental data
+        #     --> Skipped."
         # )
         return torch.zeros(1).to(config.device)
     latent_angles = (torch.atan2(z[:, 1], z[:, 0]) + 2 * torch.pi) % (2 * torch.pi)
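Two things happen in the losses.py hunks above: the two submodule imports are merged into one parenthesized statement, and the call sites switch to qualified names (von_mises_fisher.VonMisesFisher, hyperspherical_uniform.HypersphericalUniform), since the import binds only the submodules, not the classes inside them. A minimal stdlib analogue of that distinction (os.path is used purely for illustration):

# Only the submodule name `path` is bound here, not the names defined inside it.
from os import path

print(path.join("a", "b"))   # qualified access through the imported submodule: works
# print(join("a", "b"))      # NameError: `join` itself was never imported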

neurometry/curvature/main.py

Lines changed: 16 additions & 13 deletions
@@ -4,20 +4,11 @@
 import json
 import logging
 import os
-import time
-
-os.environ["GEOMSTATS_BACKEND"] = "pytorch"
 import random
+import time

-import datasets.utils
-import default_config
-import evaluate
-import geomstats.backend as gs
 import matplotlib
 import matplotlib.pyplot as plt
-import models.klein_bottle_vae
-import models.neural_vae
-import models.toroidal_vae
 import numpy as np
 import pandas as pd

@@ -30,6 +21,15 @@
 from ray.tune.schedulers import AsyncHyperBandScheduler
 from ray.tune.search.hyperopt import HyperOptSearch

+os.environ["GEOMSTATS_BACKEND"] = "pytorch"
+import datasets.utils  # noqa: E402
+import default_config  # noqa: E402
+import evaluate  # noqa: E402
+import geomstats.backend as gs  # noqa: E402
+import models.klein_bottle_vae  # noqa: E402
+import models.neural_vae  # noqa: E402
+import models.toroidal_vae  # noqa: E402
+
 # Required to make matplotlib figures in threads:
 matplotlib.use("Agg")

@@ -92,7 +92,8 @@ def main():
                 f"Manifold cannot be embedded in {embedding_dim} dimensions"
             )
             continue
-        sweep_name = f"{dataset_name}_noise_var_{noise_var}_embedding_dim_{embedding_dim}"
+        sweep_name = f"{dataset_name}_noise_var_{noise_var}"
+        sweep_name += f"_embedding_dim_{embedding_dim}"
         logging.info(f"\n---> START training for ray sweep: {sweep_name}.")
         main_sweep(
             sweep_name=sweep_name,
@@ -120,7 +121,8 @@ def main():
         default_config.field_width,
         default_config.resolution,
     ):
-        sweep_name = f"{dataset_name}_orientation_std_{grid_orientation_std}_ncells_{n_cells}"
+        sweep_name = f"{dataset_name}_orientation_std_{grid_orientation_std}"
+        sweep_name += f"_ncells_{n_cells}"
         logging.info(f"\n---> START training for ray sweep: {sweep_name}.")
         main_sweep(
             sweep_name=sweep_name,
@@ -219,7 +221,8 @@ def main_sweep(
         "grid_orientation_std": grid_orientation_std,
         "field_width": field_width,
         "resolution": resolution,
-        # Parameters fixed across runs and sweeps (unique value depending on dataset_name):
+        # Parameters fixed across runs and sweeps
+        # (unique value depending on dataset_name):
         "manifold_dim": default_config.manifold_dim[dataset_name],
         "latent_dim": default_config.latent_dim[dataset_name],
         "posterior_type": default_config.posterior_type[dataset_name],

neurometry/curvature/viz.py

Lines changed: 2 additions & 4 deletions
@@ -629,10 +629,8 @@ def plot_persistence_diagrams(diagrams_df, density=False):

     handles, _ = ax.get_legend_handles_labels()

-    legend_labels = []
-    legend_labels.append("Infinity")
-    for dimension in plot_df["Dimension"].unique():
-        legend_labels.append(f"Dimension {dimension}")
+    legend_labels = ["Infinity"] + [
+        f"Dimension {dimension}" for dimension in plot_df["Dimension"].unique()]

     ax.legend(handles, legend_labels, loc="lower right")

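The viz.py change replaces an append loop with a list literal plus a comprehension; same labels, fewer statements. A standalone sketch with a made-up stand-in for plot_df["Dimension"].unique():

dimensions = [0, 1, 2]  # illustrative stand-in for plot_df["Dimension"].unique()

legend_labels = ["Infinity"] + [f"Dimension {dimension}" for dimension in dimensions]
print(legend_labels)  # ['Infinity', 'Dimension 0', 'Dimension 1', 'Dimension 2']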
