Skip to content

Commit 2c165a7

Browse files
committed
Use (unsafe) ruff fixes
1 parent 1dfee85 commit 2c165a7

File tree

72 files changed

+12558
-614941
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

72 files changed

+12558
-614941
lines changed

.github/workflows/docs.yml

+1-2
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@ jobs:
4141
branch: main
4242
folder: docs/_build
4343
token: ${{ secrets.DOCUMENTATION_KEY }}
44-
repository-name: geometric-intelligence/geometric-intelligence.github.io
45-
target-folder: neurometry
44+
repository-name: geometric-intelligence/neurometry.github.io
4645
clean: true
4746
numpydoc-validation:
4847
runs-on: ${{matrix.os}}

docs/conf.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Sphinx configuration file."""
22

3-
import neurometry
43

54
project = "neurometry"
65
copyright = "2023, Geometric Intelligence Lab."

neurometry/curvature/datasets/experimental.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def load_neural_activity(expt_id=34, vel_threshold=5, timestep_microsec=int(1e5)
145145
qw = _average_variable(qw, tracked_times, sampling_times)
146146
success = _average_variable(success, tracked_times, sampling_times)
147147

148-
radius2 = [xx**2 + yy**2 for xx, yy in zip(x, y)]
148+
radius2 = [xx**2 + yy**2 for xx, yy in zip(x, y, strict=False)]
149149
angles_tracked = np.arctan2(y, x)
150150

151151
quat_head = np.stack([qx, qy, qz, qw], axis=1) # scalar-last format
@@ -350,9 +350,8 @@ def _average_variable(variable_to_average, recorded_times, sampling_times):
350350
variable_averaged.append(averaged)
351351
cum_count += int(count)
352352

353-
variable_averaged = np.array(variable_averaged)
353+
return np.array(variable_averaged)
354354

355-
return variable_averaged
356355

357356

358357
def get_place_field_centers(neural_activity, task_variable):

neurometry/curvature/datasets/gridcells.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import os
44

5-
import matplotlib.pyplot as plt
65
import numpy as np
76
import pandas as pd
87

@@ -65,9 +64,8 @@ def create_reference_lattice(lx, ly, arena_dims, lattice_type="hexagonal"):
6564
X = lx * N_x
6665
Y = ly * N_y
6766

68-
ref_lattice = np.hstack((np.reshape(X, (-1, 1)), np.reshape(Y, (-1, 1))))
67+
return np.hstack((np.reshape(X, (-1, 1)), np.reshape(Y, (-1, 1))))
6968

70-
return ref_lattice
7169

7270

7371
def generate_all_grids(
@@ -79,7 +77,7 @@ def generate_all_grids(
7977
warp=None,
8078
lattice_type="hexagonal",
8179
):
82-
"""Create lattices for all grid cells within a module, with varying phase & orientation.
80+
r"""Create lattices for all grid cells within a module, with varying phase & orientation.
8381
8482
Parameters
8583
----------
@@ -105,9 +103,9 @@ def generate_all_grids(
105103
scale=grid_scale, lattice_type=lattice_type, dimensions=arena_dims
106104
)
107105

108-
grids = np.zeros((n_cells,) + np.shape(ref_lattice))
106+
grids = np.zeros((n_cells, *np.shape(ref_lattice)))
109107

110-
grids_warped = np.zeros((n_cells,) + np.shape(ref_lattice))
108+
grids_warped = np.zeros((n_cells, *np.shape(ref_lattice)))
111109

112110
arena_dims = np.array(arena_dims)
113111

@@ -122,7 +120,7 @@ def generate_all_grids(
122120
lattice_i = np.matmul(rot_i, ref_lattice.T).T + phase_i
123121
# lattice_i = np.where(abs(lattice_i) < arena_dims / 2, lattice_i, None)
124122

125-
if warp == None:
123+
if warp is None:
126124
pass
127125
else:
128126
for j, point in enumerate(lattice_i):

neurometry/curvature/datasets/structures.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,5 @@ def get_lattice(scale, lattice_type, dimensions):
2323
X = lx * N_x
2424
Y = ly * N_y
2525

26-
lattice = np.hstack((np.reshape(X, (-1, 1)), np.reshape(Y, (-1, 1))))
26+
return np.hstack((np.reshape(X, (-1, 1)), np.reshape(Y, (-1, 1))))
2727

28-
return lattice

neurometry/curvature/datasets/synthetic.py

-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import numpy as np
99
import pandas as pd
1010
import torch
11-
import torch.nn.functional as F
1211
from geomstats.geometry.special_orthogonal import SpecialOrthogonal # NOQA
1312
from torch.distributions.multivariate_normal import MultivariateNormal
1413

neurometry/curvature/datasets/utils.py

+8-19
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
"""Utils to import data from matlab."""
22

3-
import curvature.datasets.experimental
4-
import curvature.datasets.gridcells
5-
import curvature.datasets.synthetic
63
import mat73
74
import numpy as np
85
import scipy.io
96
import torch
107
from scipy.signal import savgol_filter
11-
from sklearn.decomposition import PCA
128

139

1410
def load(config):
@@ -40,7 +36,7 @@ def load(config):
4036
labels = labels[labels["velocities"] > 5]
4137
dataset = np.log(dataset.astype(np.float32) + 1)
4238

43-
if config.smooth == True:
39+
if config.smooth is True:
4440
dataset_smooth = np.zeros_like(dataset)
4541
for _ in range(dataset.shape[1]):
4642
dataset_smooth[:, _] = savgol_filter(
@@ -160,33 +156,26 @@ def load(config):
160156
test_labels = labels.iloc[test_indices]
161157

162158
# The angles are positional angles in the lab frame
163-
if config.dataset_name in ("experimental", "s1_synthetic"):
159+
if config.dataset_name in ("experimental", "s1_synthetic") or config.dataset_name in ("three_place_cells_synthetic"):
164160
train = []
165-
for d, l in zip(train_dataset, train_labels["angles"]):
161+
for d, l in zip(train_dataset, train_labels["angles"], strict=False):
166162
train.append([d, float(l)])
167163
test = []
168-
for d, l in zip(test_dataset, test_labels["angles"]):
169-
test.append([d, float(l)])
170-
elif config.dataset_name in ("three_place_cells_synthetic"):
171-
train = []
172-
for d, l in zip(train_dataset, train_labels["angles"]):
173-
train.append([d, float(l)])
174-
test = []
175-
for d, l in zip(test_dataset, test_labels["angles"]):
164+
for d, l in zip(test_dataset, test_labels["angles"], strict=False):
176165
test.append([d, float(l)])
177166
elif config.dataset_name in ("s2_synthetic", "t2_synthetic"):
178167
train = []
179-
for d, t, p in zip(train_dataset, train_labels["thetas"], train_labels["phis"]):
168+
for d, t, p in zip(train_dataset, train_labels["thetas"], train_labels["phis"], strict=False):
180169
train.append([d, torch.tensor([float(t), float(p)])])
181170
test = []
182-
for d, t, p in zip(test_dataset, test_labels["thetas"], test_labels["phis"]):
171+
for d, t, p in zip(test_dataset, test_labels["thetas"], test_labels["phis"], strict=False):
183172
test.append([d, torch.tensor([float(t), float(p)])])
184173
elif config.dataset_name == "grid_cells":
185174
train = []
186-
for d, l in zip(train_dataset, train_labels["no_labels"]):
175+
for d, l in zip(train_dataset, train_labels["no_labels"], strict=False):
187176
train.append([d, float(l)])
188177
test = []
189-
for d, l in zip(test_dataset, test_labels["no_labels"]):
178+
for d, l in zip(test_dataset, test_labels["no_labels"], strict=False):
190179
test.append([d, float(l)])
191180

192181
train_loader = torch.utils.data.DataLoader(train, batch_size=config.batch_size)

neurometry/curvature/default_config.py

-2
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@
22

33
import logging
44
import os
5-
from datetime import datetime
65

76
import numpy as np
87
import torch
9-
from ray.tune.search.hyperopt import HyperOptSearch
108

119
os.environ["GEOMSTATS_BACKEND"] = "pytorch"
1210

neurometry/curvature/evaluate.py

+7-15
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
os.environ["GEOMSTATS_BACKEND"] = "pytorch"
44
import geomstats.backend as gs
5+
56
#import gph
67
import numpy as np
78
import torch
@@ -10,7 +11,6 @@
1011
get_s2_synthetic_immersion,
1112
get_t2_synthetic_immersion,
1213
)
13-
from decorators import timer
1414
from geomstats.geometry.pullback_metric import PullbackMetric
1515
from geomstats.geometry.special_orthogonal import SpecialOrthogonal # NOQA
1616

@@ -51,9 +51,8 @@ def immersion(angle):
5151
)
5252

5353
z = z.to(config.device)
54-
x_mu = model.decode(z)
54+
return model.decode(z)
5555

56-
return x_mu
5756

5857
return immersion
5958

@@ -112,7 +111,7 @@ def _compute_curvature(z_grid, immersion, dim, embedding_dim):
112111
neural_metric = PullbackMetric(
113112
dim=dim, embedding_dim=embedding_dim, immersion=immersion
114113
)
115-
z0 = torch.unsqueeze(z_grid[0], dim=0)
114+
torch.unsqueeze(z_grid[0], dim=0)
116115
if dim == 1:
117116
curv = gs.zeros(len(z_grid), embedding_dim)
118117
geodesic_dist = gs.zeros(len(z_grid))
@@ -214,8 +213,7 @@ def _integrate_s2(thetas, phis, h):
214213
sum_phis[t] = torch.trapz(
215214
y=h[len(phis) * t : len(phis) * (t + 1)], x=phis
216215
) * np.sin(theta)
217-
integral = torch.trapz(y=sum_phis, x=thetas)
218-
return integral
216+
return torch.trapz(y=sum_phis, x=thetas)
219217

220218

221219
def _compute_curvature_error_s2(thetas, phis, curv_norms_learned, curv_norms_true):
@@ -240,23 +238,17 @@ def _compute_curvature_error_t2(thetas, phis, curv_norms_learned, curv_norms_tru
240238
def compute_curvature_error(
241239
z_grid, curv_norms_learned, curv_norms_true, config
242240
): # Calculate method error
243-
start_time = time.time()
241+
time.time()
244242

245243
if config.dataset_name == "s1_synthetic":
246244
thetas = z_grid
247245
error = _compute_curvature_error_s1(thetas, curv_norms_learned, curv_norms_true)
248-
elif config.dataset_name == "s2_synthetic":
249-
thetas = z_grid[:, 0]
250-
phis = z_grid[:, 1]
251-
error = _compute_curvature_error_s2(
252-
thetas, phis, curv_norms_learned, curv_norms_true
253-
)
254-
elif config.dataset_name == "t2_synthetic":
246+
elif config.dataset_name == "s2_synthetic" or config.dataset_name == "t2_synthetic":
255247
thetas = z_grid[:, 0]
256248
phis = z_grid[:, 1]
257249
error = _compute_curvature_error_s2(
258250
thetas, phis, curv_norms_learned, curv_norms_true
259251
)
260-
end_time = time.time()
252+
time.time()
261253
# print("Computation time: " + "%.3f" % (end_time - start_time) + " seconds.")
262254
return error
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
import numpy as np
1+
import matplotlib.cm as cm
22
import matplotlib.pyplot as plt
3+
import numpy as np
34
from matplotlib.animation import FuncAnimation
4-
import matplotlib.cm as cm
55

66
L = 10 # Length of the domain
77

88
# Initialize parameters
99
num_samples = 1000
1010
theta = np.linspace(0, 2 * np.pi, num_samples, endpoint=False)
1111
N = 4 # Number of harmonics
12-
activation = 'relu'
12+
activation = "relu"
1313

1414

1515
def gaussian_on_circle(theta, loc, sigma=0.1):
@@ -20,19 +20,19 @@ def relu(x):
2020
return np.maximum(0, x)
2121

2222
# Function to plot a harmonic given amplitude and phase
23-
def plot_harmonic(ax, amplitude, phase, n, label, activation='relu'):
23+
def plot_harmonic(ax, amplitude, phase, n, label, activation="relu"):
2424

2525
harmonic_values = amplitude * np.cos(n * theta + phase)
26-
if activation == 'relu':
26+
if activation == "relu":
2727
harmonic_values = relu(harmonic_values)
28-
ax.plot(np.cos(theta), np.sin(theta), zs=0, zdir='z', linestyle='--',linewidth=3, color='black')
28+
ax.plot(np.cos(theta), np.sin(theta), zs=0, zdir="z", linestyle="--",linewidth=3, color="black")
2929
normalized_phase = (phase + np.pi) / (2 * np.pi) # Normalizing from -π to π to 0 to 1
3030
color = cm.hsv(normalized_phase)
3131
ax.plot(np.cos(theta), np.sin(theta), harmonic_values, label=label,linewidth=3,color=color,alpha=1-0.1*n)
32-
ax.axis('off')
32+
ax.axis("off")
3333

3434
# Prepare figure for plotting
35-
fig, axs = plt.subplots(2, N+1, figsize=(20, 10), subplot_kw={'projection': '3d'})
35+
fig, axs = plt.subplots(2, N+1, figsize=(20, 10), subplot_kw={"projection": "3d"})
3636
plt.tight_layout()
3737

3838
def update(loc):
@@ -46,36 +46,36 @@ def update(loc):
4646
for ax_row in axs:
4747
for ax in ax_row:
4848
ax.cla()
49-
ax.axis('off')
49+
ax.axis("off")
5050

5151
# Plot original function
52-
axs[0, 2].plot(np.cos(theta), np.sin(theta), zs=0, zdir='z', linestyle='--',linewidth=3,color='black')
53-
axs[0, 2].plot(np.cos(theta), np.sin(theta), bump_samples, label='Original Function',linewidth=3,color='tomato')
54-
axs[0, 2].set_title('Target place field, position = {:.2f}'.format(loc),fontsize=20)
55-
axs[0, 2].scatter(np.cos(loc), np.sin(loc), zs=0, zdir='z', s=100, c='red')
52+
axs[0, 2].plot(np.cos(theta), np.sin(theta), zs=0, zdir="z", linestyle="--",linewidth=3,color="black")
53+
axs[0, 2].plot(np.cos(theta), np.sin(theta), bump_samples, label="Original Function",linewidth=3,color="tomato")
54+
axs[0, 2].set_title(f"Target place field, position = {loc:.2f}",fontsize=20)
55+
axs[0, 2].scatter(np.cos(loc), np.sin(loc), zs=0, zdir="z", s=100, c="red")
5656

5757
# Plot each harmonic and the reconstructed function
5858
reconstructed = np.zeros(num_samples)
5959
for n in range(1, N+1):
6060
index = n if frequencies[n] >= 0 else num_samples + n
6161
amplitude = np.abs(coefficients_fft[index])
6262
phase = np.angle(coefficients_fft[index])
63-
64-
plot_harmonic(axs[1, n-1], amplitude, phase, n, f'GC module {n}, period $\lambda=${L/n:0.1f}', activation=activation)
65-
axs[1, n-1].set_title(f'GC module {n}, period $\lambda_{n}=${L/n:0.1f}',fontsize=18)
66-
if activation == 'relu':
63+
64+
plot_harmonic(axs[1, n-1], amplitude, phase, n, rf"GC module {n}, period $\lambda=${L/n:0.1f}", activation=activation)
65+
axs[1, n-1].set_title(rf"GC module {n}, period $\lambda_{n}=${L/n:0.1f}",fontsize=18)
66+
if activation == "relu":
6767
reconstructed += relu(amplitude * np.cos(n * theta + phase))
6868
else:
6969
reconstructed += amplitude * np.cos(n * theta + phase)
7070

7171
# Reconstructed function
72-
axs[1, N].plot(np.cos(theta), np.sin(theta), zs=0, zdir='z', linestyle='--',linewidth=3,color='black')
73-
axs[1, N].plot(np.cos(theta), np.sin(theta), reconstructed, label='Reconstructed',linewidth=3,color='limegreen')
74-
axs[1, N].set_title('Place field readout',fontsize=20)
72+
axs[1, N].plot(np.cos(theta), np.sin(theta), zs=0, zdir="z", linestyle="--",linewidth=3,color="black")
73+
axs[1, N].plot(np.cos(theta), np.sin(theta), reconstructed, label="Reconstructed",linewidth=3,color="limegreen")
74+
axs[1, N].set_title("Place field readout",fontsize=20)
7575

7676
# Create animation
77-
loc_values = np.linspace(0, 2*np.pi, 100)
77+
loc_values = np.linspace(0, 2*np.pi, 100)
7878
ani = FuncAnimation(fig, update, frames=loc_values, repeat=True)
7979

8080
# Save the animation
81-
ani.save('position_from_grid_cells.gif', writer='imagemagick', fps=10)
81+
ani.save("position_from_grid_cells.gif", writer="imagemagick", fps=10)

neurometry/curvature/grid-cells-curvature/notebooks/12_path_int.ipynb

+3-6
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,12 @@
2727
"%autoreload 2\n",
2828
"%load_ext jupyter_black\n",
2929
"\n",
30-
"import numpy as np\n",
30+
"import os\n",
3131
"\n",
3232
"import matplotlib.pyplot as plt\n",
33+
"import numpy as np\n",
3334
"\n",
34-
"\n",
35-
"import os\n",
36-
"\n",
37-
"os.environ[\"GEOMSTATS_BACKEND\"] = \"pytorch\"\n",
38-
"import geomstats.backend as gs"
35+
"os.environ[\"GEOMSTATS_BACKEND\"] = \"pytorch\""
3936
]
4037
},
4138
{

neurometry/curvature/hyperspherical/distributions/hyperspherical_uniform.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def device(self, val):
2222
self._device = val if isinstance(val, torch.device) else torch.device(val)
2323

2424
def __init__(self, dim, validate_args=None, device=None):
25-
super(HypersphericalUniform, self).__init__(
25+
super().__init__(
2626
torch.Size([dim]), validate_args=validate_args
2727
)
2828
self._dim = dim

0 commit comments

Comments
 (0)