Skip to content

Commit 7e83e63

Browse files
ruff
1 parent 6e8d629 commit 7e83e63

File tree

7 files changed

+51
-61
lines changed

7 files changed

+51
-61
lines changed

neurometry/curvature/grid-cells-curvature/models/xu_rnn/experiment.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import ml_collections
88
import model as model
99
import numpy as np
10+
1011
#import tensorflow as tf
1112
import torch
1213
import torch.nn as nn
@@ -431,6 +432,6 @@ def _save_checkpoint(self, step, ckpt_dir):
431432
}
432433
with open(activations_filename, "wb") as f:
433434
pickle.dump(activations, f)
434-
435+
435436

436437
logging.info(f"Saving activations: {activations_filename} ...")

neurometry/curvature/grid-cells-curvature/models/xu_rnn/viz.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
import os
22
import pickle
33

4+
import matplotlib.cm as cm
45
import matplotlib.pyplot as plt
56
import numpy as np
67
from moviepy.editor import ImageSequenceClip
78

8-
from neurometry.datasets.load_rnn_grid_cells import plot_rate_map
9-
import matplotlib.cm as cm
10-
import yaml
11-
129
logs_dir = "logs/rnn_isometry"
1310

1411
activations_dir = "ckpt/activations"
@@ -55,7 +52,7 @@ def draw_heatmap(activations, title):
5552
image_from_plot = image_from_plot.reshape(
5653
fig.canvas.get_width_height()[::-1] + (3,)
5754
)
58-
fig.suptitle(title, fontsize=20, fontweight='bold', verticalalignment='top')
55+
fig.suptitle(title, fontsize=20, fontweight="bold", verticalalignment="top")
5956

6057
plt.tight_layout(rect=[0, 0, 1, 0.95])
6158
plt.show()
@@ -68,6 +65,7 @@ def save_rate_maps_as_image(activations, image_path, title):
6865
fig = draw_heatmap(activations, title)
6966
plt.savefig(image_path)
7067
plt.close()
68+
return fig
7169

7270

7371
def generate_videos(run_id, start_epoch=25000, end_epoch=65000, step=500):
@@ -80,8 +78,7 @@ def generate_videos(run_id, start_epoch=25000, end_epoch=65000, step=500):
8078
os.makedirs(u_images_dir, exist_ok=True)
8179
os.makedirs(v_images_dir, exist_ok=True)
8280

83-
config_file = os.path.join(logs_dir, run_id, "config.txt")
84-
81+
#config_file = os.path.join(logs_dir, run_id, "config.txt")
8582
# with open(config_file, 'r') as file:
8683
# config = yaml.safe_load(file)
8784

neurometry/datasets/load_rnn_grid_cells.py

Lines changed: 34 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,38 @@
11
import os
2-
3-
import matplotlib.pyplot as plt
4-
import numpy as np
52
import pickle
63

7-
# sys.path.append(str(Path(__file__).parent.parent))
8-
from .rnn_grid_cells import config, dual_agent_activity, single_agent_activity, utils
94
import matplotlib.cm as cm
5+
import matplotlib.pyplot as plt
6+
import numpy as np
107
import tensorflow as tf
11-
import pickle
12-
import yaml
138
import torch
9+
import umap
10+
import yaml
11+
from sklearn.cluster import DBSCAN
12+
1413
from neurometry.datasets.rnn_grid_cells.scores import GridScorer
1514

16-
from sklearn.cluster import DBSCAN
17-
import umap
15+
# sys.path.append(str(Path(__file__).parent.parent))
16+
from .rnn_grid_cells import config, dual_agent_activity, single_agent_activity, utils
17+
1818

1919
def load_rate_maps(run_id, step):
2020
#XU_RNN
2121
model_dir = os.path.join(os.getcwd(), "curvature/grid-cells-curvature/models/xu_rnn")
2222
run_dir = os.path.join(model_dir, f"logs/rnn_isometry/{run_id}")
2323
activations_file = os.path.join(run_dir, f"ckpt/activations/activations-step{step}.pkl")
2424
with open(activations_file, "rb") as f:
25-
activations = pickle.load(f)
26-
27-
return activations
25+
return pickle.load(f)
2826

2927
def load_config(run_id):
3028
model_dir = os.path.join(os.getcwd(), "curvature/grid-cells-curvature/models/xu_rnn")
3129
run_dir = os.path.join(model_dir, f"logs/rnn_isometry/{run_id}")
3230
config_file = os.path.join(run_dir, "config.txt")
3331

34-
with open(config_file, 'r') as file:
35-
config = yaml.safe_load(file)
32+
with open(config_file) as file:
33+
return yaml.safe_load(file)
34+
3635

37-
return config
3836

3937

4038

@@ -44,9 +42,11 @@ def extract_tensor_events(event_file, verbose=True):
4442
losses = []
4543
try:
4644
for e in tf.compat.v1.train.summary_iterator(event_file):
47-
if verbose: print(f"Found event at step {e.step} with wall time {e.wall_time}")
45+
if verbose:
46+
print(f"Found event at step {e.step} with wall time {e.wall_time}")
4847
for v in e.summary.value:
49-
if verbose: print(f"Found value with tag: {v.tag}")
48+
if verbose:
49+
print(f"Found value with tag: {v.tag}")
5050
if v.HasField("tensor"):
5151
tensor = tf.make_ndarray(v.tensor)
5252
record = {
@@ -60,7 +60,8 @@ def extract_tensor_events(event_file, verbose=True):
6060
loss = {"step": e.step, "loss": tensor}
6161
losses.append(loss)
6262
else:
63-
if verbose: print(f"No 'tensor' found for tag {v.tag}")
63+
if verbose:
64+
print(f"No 'tensor' found for tag {v.tag}")
6465
except Exception as e:
6566
print(f"An error occurred: {e}")
6667
return records, losses
@@ -74,23 +75,21 @@ def _compute_scores(activations, config):
7475

7576
starts = [0.1] * 20
7677
ends = np.linspace(0.2, 1.4, num=20)
77-
masks_parameters = zip(starts, ends.tolist())
78-
79-
ncol, nrow = block_size, num_block
78+
masks_parameters = zip(starts, ends.tolist(), strict=False)
8079

8180
scorer = GridScorer(40, ((0, 1), (0, 1)), masks_parameters)
8281

83-
score_list = np.zeros(shape=[len(activations['v'])], dtype=np.float32)
84-
scale_list = np.zeros(shape=[len(activations['v'])], dtype=np.float32)
82+
score_list = np.zeros(shape=[len(activations["v"])], dtype=np.float32)
83+
scale_list = np.zeros(shape=[len(activations["v"])], dtype=np.float32)
8584
#orientation_list = np.zeros(shape=[len(weights)], dtype=np.float32)
8685
sac_list = []
8786

88-
for i in range(len(activations['v'])):
89-
rate_map = activations['v'][i]
87+
for i in range(len(activations["v"])):
88+
rate_map = activations["v"][i]
9089
rate_map = (rate_map - rate_map.min()) / (rate_map.max() - rate_map.min())
9190

9291
score_60, score_90, max_60_mask, max_90_mask, sac, _ = scorer.get_scores(
93-
activations['v'][i])
92+
activations["v"][i])
9493
sac_list.append(sac)
9594

9695
score_list[i] = score_60
@@ -109,10 +108,9 @@ def _compute_scores(activations, config):
109108
# score_tensor = score_tensor.reshape((num_block, block_size))
110109
score_tensor = torch.mean(score_tensor)
111110
sac_array = np.array(sac_list)
112-
113-
scores = {"sac":sac_array, "scale":scale_tensor, "score": score_tensor, "max_scale": max_scale}
114111

115-
return scores
112+
return {"sac":sac_array, "scale":scale_tensor, "score": score_tensor, "max_scale": max_scale}
113+
116114

117115

118116

@@ -288,7 +286,7 @@ def draw_heatmap(activations, title):
288286
fig.canvas.get_width_height()[::-1] + (3,)
289287
)
290288

291-
fig.suptitle(title, fontsize=20, fontweight='bold', verticalalignment='top')
289+
fig.suptitle(title, fontsize=20, fontweight="bold", verticalalignment="top")
292290

293291
plt.tight_layout(rect=[0, 0, 1, 0.95])
294292
plt.show()
@@ -320,9 +318,9 @@ def _vectorized_spatial_autocorrelation_matrix(spatial_autocorrelation):
320318
def umap_dbscan(activations, run_dir, config, sac_array=None, plot=True):
321319
if sac_array is None:
322320
sac_array = get_scores(run_dir, activations, config)["sac"]
323-
321+
324322
spatial_autocorrelation_matrix = _vectorized_spatial_autocorrelation_matrix(sac_array)
325-
323+
326324
umap_reducer_2d = umap.UMAP(n_components=2, random_state=10)
327325
umap_embedding = umap_reducer_2d.fit_transform(spatial_autocorrelation_matrix.T)
328326

@@ -336,7 +334,7 @@ def umap_dbscan(activations, run_dir, config, sac_array=None, plot=True):
336334
if plot:
337335
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
338336

339-
for k, col in zip(unique_labels, colors):
337+
for k, col in zip(unique_labels, colors, strict=False):
340338
if k == -1:
341339
# Black used for noise.
342340
# col = [0, 0, 0, 1]
@@ -346,20 +344,20 @@ def umap_dbscan(activations, run_dir, config, sac_array=None, plot=True):
346344

347345
xy = umap_embedding[class_member_mask]
348346
if plot:
349-
axes[0].plot(xy[:, 0], xy[:, 1], 'o', markerfacecolor=tuple(col), markeredgecolor='none', markersize=5, label=f'Cluster {k}')
347+
axes[0].plot(xy[:, 0], xy[:, 1], "o", markerfacecolor=tuple(col), markeredgecolor="none", markersize=5, label=f"Cluster {k}")
350348

351349
umap_cluster_labels = umap_dbscan.fit_predict(umap_embedding)
352350
clusters = {}
353351
for i in np.unique(umap_cluster_labels):
354352
#cluster = _get_data_from_cluster(activations,i, umap_cluster_labels)
355353
cluster = activations[umap_cluster_labels == i]
356354
clusters[i] = cluster
357-
355+
358356
if plot:
359357
axes[0].set_xlabel("UMAP 1")
360358
axes[0].set_ylabel("UMAP 2")
361359
axes[0].set_title("UMAP embedding of spatial autocorrelation")
362-
axes[0].legend(title="Cluster IDs", loc='center left', bbox_to_anchor=(1, 0.5))
360+
axes[0].legend(title="Cluster IDs", loc="center left", bbox_to_anchor=(1, 0.5))
363361

364362
axes[1].hist(umap_cluster_labels, bins=len(np.unique(umap_cluster_labels)))
365363
axes[1].set_xlabel("Cluster ID")
@@ -369,4 +367,3 @@ def umap_dbscan(activations, run_dir, config, sac_array=None, plot=True):
369367
plt.show()
370368
return clusters, umap_cluster_labels
371369

372-

neurometry/dimension/dim_reduction.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
import matplotlib.pyplot as plt
2-
from sklearn.decomposition import PCA
32
import numpy as np
4-
5-
import umap
6-
from sklearn.manifold import Isomap, MDS, LocallyLinearEmbedding, SpectralEmbedding, TSNE
3+
from sklearn.decomposition import PCA
74

85

96
def plot_pca_projections(X, K, title):
@@ -29,8 +26,8 @@ def plot_pca_projections(X, K, title):
2926
)
3027

3128
plt.tight_layout()
32-
33-
fig.suptitle(title, fontsize=30, fontweight='bold', verticalalignment='top')
29+
30+
fig.suptitle(title, fontsize=30, fontweight="bold", verticalalignment="top")
3431
plt.show()
3532

3633
print(f"The {K} top PCs explain {100*np.cumsum(ev)[-1]:.2f}% of the variance")

neurometry/topology/pd_distances.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ def compute_pairwise_distances(diagrams, metric="bottleneck"):
99
def compare_representation_to_references(
1010
representation, reference_topologies, metric="bottleneck"
1111
):
12-
raise NotImplementedError
12+
raise NotImplementedError

neurometry/topology/persistent_homology.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from dreimac import CircularCoords, ToroidalCoords
33
from gtda.homology import VietorisRipsPersistence, WeightedRipsPersistence
44

5+
56
def compute_persistence_diagrams(
67
representations,
78
homology_dimensions=(0, 1, 2),
@@ -23,17 +24,15 @@ def compute_persistence_diagrams(
2324

2425

2526
def _shuffle_entries(data, rng):
26-
shuffled_data = np.array([rng.permutation(row) for row in data])
27-
return shuffled_data
27+
return np.array([rng.permutation(row) for row in data])
2828

2929

3030
def compute_diagrams_shuffle(X, num_shuffles, seed=0, homology_dimensions=(0, 1)):
3131
rng = np.random.default_rng(seed)
3232
shuffled_Xs = [_shuffle_entries(X, rng) for _ in range(num_shuffles)]
33-
diagrams = compute_persistence_diagrams(
34-
[X] + shuffled_Xs, homology_dimensions=homology_dimensions
33+
return compute_persistence_diagrams(
34+
[X, *shuffled_Xs], homology_dimensions=homology_dimensions
3535
)
36-
return diagrams
3736

3837
def cohomological_toroidal_coordinates(data):
3938
n_landmarks = data.shape[0]

neurometry/topology/plotting.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import plotly.graph_objects as go
44

55

6-
76
def _plot_bars_from_diagrams(ax, diagrams, dim, **kwargs):
87
birth = diagrams[:, 0]
98
death = diagrams[:, 1]
@@ -20,7 +19,7 @@ def _plot_bars_from_diagrams(ax, diagrams, dim, **kwargs):
2019
linewidth = kwargs.get("linewidth", 5)
2120

2221
# Plotting each bar
23-
for i, (b, d) in enumerate(zip(birth, death)):
22+
for i, (b, d) in enumerate(zip(birth, death, strict=False)):
2423
ax.plot(
2524
[0, d - b],
2625
[i * offset, i * offset],
@@ -47,7 +46,7 @@ def plot_all_barcodes_with_null(diagrams, **kwargs):
4746
if num_dims == 1:
4847
axs = [axs] # Make it iterable if there's only one subplot
4948

50-
for ax, dim, color in zip(axs, dims, colors):
49+
for ax, dim, color in zip(axs, dims, colors, strict=False):
5150
diag_dim = original_diagram[original_diagram[:, 2] == dim]
5251
null_diag_dim = shuffled_diagrams[:, :, 2] == dim
5352
null_diag = shuffled_diagrams[null_diag_dim]
@@ -126,4 +125,4 @@ def plot_activity_on_torus(neural_activations, toroidal_coords, neuron_id, neuro
126125

127126
fig.update_layout(title=title, autosize=False, width=800, height=500)
128127
fig.show()
129-
return fig
128+
return fig

Comments (0) — this commit has no comments.