@@ -1,40 +1,38 @@
 import os
-
-import matplotlib.pyplot as plt
-import numpy as np
 import pickle
 
-# sys.path.append(str(Path(__file__).parent.parent))
-from .rnn_grid_cells import config, dual_agent_activity, single_agent_activity, utils
 import matplotlib.cm as cm
+import matplotlib.pyplot as plt
+import numpy as np
 import tensorflow as tf
-import pickle
-import yaml
 import torch
+import umap
+import yaml
+from sklearn.cluster import DBSCAN
+
 from neurometry.datasets.rnn_grid_cells.scores import GridScorer
 
-from sklearn.cluster import DBSCAN
-import umap
+# sys.path.append(str(Path(__file__).parent.parent))
+from .rnn_grid_cells import config, dual_agent_activity, single_agent_activity, utils
+
 
 def load_rate_maps(run_id, step):
     #XU_RNN
     model_dir = os.path.join(os.getcwd(), "curvature/grid-cells-curvature/models/xu_rnn")
     run_dir = os.path.join(model_dir, f"logs/rnn_isometry/{run_id}")
     activations_file = os.path.join(run_dir, f"ckpt/activations/activations-step{step}.pkl")
     with open(activations_file, "rb") as f:
-        activations = pickle.load(f)
-
-    return activations
+        return pickle.load(f)
 
 def load_config(run_id):
     model_dir = os.path.join(os.getcwd(), "curvature/grid-cells-curvature/models/xu_rnn")
     run_dir = os.path.join(model_dir, f"logs/rnn_isometry/{run_id}")
     config_file = os.path.join(run_dir, "config.txt")
 
-    with open(config_file, 'r') as file:
-        config = yaml.safe_load(file)
+    with open(config_file) as file:
+        return yaml.safe_load(file)
+
 
-    return config
 
 
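Review note: both loaders now return directly from the `with` block, which still closes the file before the value is returned. A minimal usage sketch, assuming a checkpoint exists at the expected path; the run id and step below are hypothetical placeholders, not values from this repo:

```python
# Hypothetical usage; "20240418-1" and 30000 are placeholder values.
activations = load_rate_maps(run_id="20240418-1", step=30000)
cfg = load_config(run_id="20240418-1")

# The scoring code below indexes activations["v"], the array of per-unit
# rate maps, so that key is what downstream callers rely on.
rate_maps = activations["v"]
```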
@@ -44,9 +42,11 @@ def extract_tensor_events(event_file, verbose=True):
     losses = []
     try:
         for e in tf.compat.v1.train.summary_iterator(event_file):
-            if verbose: print(f"Found event at step {e.step} with wall time {e.wall_time}")
+            if verbose:
+                print(f"Found event at step {e.step} with wall time {e.wall_time}")
             for v in e.summary.value:
-                if verbose: print(f"Found value with tag: {v.tag}")
+                if verbose:
+                    print(f"Found value with tag: {v.tag}")
                 if v.HasField("tensor"):
                     tensor = tf.make_ndarray(v.tensor)
                     record = {
@@ -60,7 +60,8 @@ def extract_tensor_events(event_file, verbose=True):
                     loss = {"step": e.step, "loss": tensor}
                     losses.append(loss)
                 else:
-                    if verbose: print(f"No 'tensor' found for tag {v.tag}")
+                    if verbose:
+                        print(f"No 'tensor' found for tag {v.tag}")
     except Exception as e:
         print(f"An error occurred: {e}")
     return records, losses
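Review note: splitting the one-line `if verbose: print(...)` statements onto two lines is the standard E701 fix and keeps behavior identical. A sketch of how the extractor might be driven; the event-file path is a hypothetical placeholder:

```python
# Hypothetical path; extract_tensor_events returns (records, losses) per the diff.
event_file = "logs/rnn_isometry/20240418-1/tensorboard/events.out.tfevents.12345"
records, losses = extract_tensor_events(event_file, verbose=False)
print(f"{len(records)} tensor records, {len(losses)} loss entries")
```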
@@ -74,23 +75,21 @@ def _compute_scores(activations, config):
 
     starts = [0.1] * 20
     ends = np.linspace(0.2, 1.4, num=20)
-    masks_parameters = zip(starts, ends.tolist())
-
-    ncol, nrow = block_size, num_block
+    masks_parameters = zip(starts, ends.tolist(), strict=False)
 
     scorer = GridScorer(40, ((0, 1), (0, 1)), masks_parameters)
 
-    score_list = np.zeros(shape=[len(activations['v'])], dtype=np.float32)
-    scale_list = np.zeros(shape=[len(activations['v'])], dtype=np.float32)
+    score_list = np.zeros(shape=[len(activations["v"])], dtype=np.float32)
+    scale_list = np.zeros(shape=[len(activations["v"])], dtype=np.float32)
     #orientation_list = np.zeros(shape=[len(weights)], dtype=np.float32)
     sac_list = []
 
-    for i in range(len(activations['v'])):
-        rate_map = activations['v'][i]
+    for i in range(len(activations["v"])):
+        rate_map = activations["v"][i]
         rate_map = (rate_map - rate_map.min()) / (rate_map.max() - rate_map.min())
 
         score_60, score_90, max_60_mask, max_90_mask, sac, _ = scorer.get_scores(
-            activations['v'][i])
+            activations["v"][i])
         sac_list.append(sac)
 
         score_list[i] = score_60
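Review note: `strict=False` only spells out `zip`'s default truncating behavior; the keyword requires Python 3.10+, and this looks like a ruff B905 (`zip-without-explicit-strict`) autofix. Since `starts` and `ends` are both built with length 20, `strict=True` would arguably be the safer choice, as this small illustration shows:

```python
import numpy as np

starts = [0.1] * 20
ends = np.linspace(0.2, 1.4, num=20)

# strict=False (the default behavior): zip silently stops at the shorter iterable.
pairs = list(zip(starts, ends.tolist(), strict=False))

# strict=True raises ValueError on a length mismatch, catching bugs early;
# here both sequences have length 20, so the result is identical.
pairs_checked = list(zip(starts, ends.tolist(), strict=True))
assert pairs == pairs_checked
```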
@@ -109,10 +108,9 @@ def _compute_scores(activations, config):
     # score_tensor = score_tensor.reshape((num_block, block_size))
     score_tensor = torch.mean(score_tensor)
     sac_array = np.array(sac_list)
-
-    scores = {"sac": sac_array, "scale": scale_tensor, "score": score_tensor, "max_scale": max_scale}
 
-    return scores
+    return {"sac": sac_array, "scale": scale_tensor, "score": score_tensor, "max_scale": max_scale}
+
 
 
 
@@ -288,7 +286,7 @@ def draw_heatmap(activations, title):
         fig.canvas.get_width_height()[::-1] + (3,)
     )
 
-    fig.suptitle(title, fontsize=20, fontweight='bold', verticalalignment='top')
+    fig.suptitle(title, fontsize=20, fontweight="bold", verticalalignment="top")
 
     plt.tight_layout(rect=[0, 0, 1, 0.95])
     plt.show()
@@ -320,9 +318,9 @@ def _vectorized_spatial_autocorrelation_matrix(spatial_autocorrelation):
 def umap_dbscan(activations, run_dir, config, sac_array=None, plot=True):
     if sac_array is None:
         sac_array = get_scores(run_dir, activations, config)["sac"]
-
+
     spatial_autocorrelation_matrix = _vectorized_spatial_autocorrelation_matrix(sac_array)
-
+
     umap_reducer_2d = umap.UMAP(n_components=2, random_state=10)
     umap_embedding = umap_reducer_2d.fit_transform(spatial_autocorrelation_matrix.T)
 
@@ -336,7 +334,7 @@ def umap_dbscan(activations, run_dir, config, sac_array=None, plot=True):
     if plot:
         fig, axes = plt.subplots(1, 2, figsize=(12, 4))
 
-    for k, col in zip(unique_labels, colors):
+    for k, col in zip(unique_labels, colors, strict=False):
         if k == -1:
             # Black used for noise.
             # col = [0, 0, 0, 1]
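Review note on the `k == -1` special case below: DBSCAN labels points that belong to no cluster as `-1`, so the loop treats that id as noise rather than a real cluster. A minimal, self-contained illustration with synthetic data and hypothetical `eps`/`min_samples` values:

```python
import numpy as np
from sklearn.cluster import DBSCAN

rng = np.random.RandomState(0)
# Two tight blobs plus one far-away outlier.
X = np.vstack([rng.normal(0.0, 0.05, (20, 2)),
               rng.normal(1.0, 0.05, (20, 2)),
               [[5.0, 5.0]]])
labels = DBSCAN(eps=0.2, min_samples=5).fit_predict(X)
print(np.unique(labels))  # e.g. [-1  0  1]: -1 is DBSCAN's noise label
```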
@@ -346,20 +344,20 @@ def umap_dbscan(activations, run_dir, config, sac_array=None, plot=True):
 
         xy = umap_embedding[class_member_mask]
         if plot:
-            axes[0].plot(xy[:, 0], xy[:, 1], 'o', markerfacecolor=tuple(col), markeredgecolor='none', markersize=5, label=f'Cluster {k}')
+            axes[0].plot(xy[:, 0], xy[:, 1], "o", markerfacecolor=tuple(col), markeredgecolor="none", markersize=5, label=f"Cluster {k}")
 
     umap_cluster_labels = umap_dbscan.fit_predict(umap_embedding)
     clusters = {}
     for i in np.unique(umap_cluster_labels):
         #cluster = _get_data_from_cluster(activations,i, umap_cluster_labels)
         cluster = activations[umap_cluster_labels == i]
         clusters[i] = cluster
-
+
     if plot:
         axes[0].set_xlabel("UMAP 1")
         axes[0].set_ylabel("UMAP 2")
         axes[0].set_title("UMAP embedding of spatial autocorrelation")
-        axes[0].legend(title="Cluster IDs", loc='center left', bbox_to_anchor=(1, 0.5))
+        axes[0].legend(title="Cluster IDs", loc="center left", bbox_to_anchor=(1, 0.5))
 
         axes[1].hist(umap_cluster_labels, bins=len(np.unique(umap_cluster_labels)))
         axes[1].set_xlabel("Cluster ID")
@@ -369,4 +367,3 @@ def umap_dbscan(activations, run_dir, config, sac_array=None, plot=True):
         plt.show()
     return clusters, umap_cluster_labels
 
-
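Review note: tying the pieces together, the clustering entry point might be exercised as below. This is a sketch under assumptions: the run id, step, and `run_dir` are placeholders, and the exact `activations` format expected by `umap_dbscan` depends on `get_scores`, which is outside this diff.

```python
# Hypothetical end-to-end sketch; run id, step, and run_dir are placeholders.
run_id, step = "20240418-1", 30000
activations = load_rate_maps(run_id, step)["v"]  # boolean-mask indexable, per the cluster loop
cfg = load_config(run_id)
run_dir = f"logs/rnn_isometry/{run_id}"

clusters, labels = umap_dbscan(activations, run_dir, cfg, plot=False)
for k in sorted(clusters):
    print(f"cluster {k}: {len(clusters[k])} rate maps")
```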