12
12
utils ,
13
13
)
14
14
15
- # Loading single agent model
16
15
17
- # parent_dir = os.getcwd() + "/datasets/rnn_grid_cells/"
18
-
19
- parent_dir = "/scratch/facosta/rnn_grid_cells/"
20
-
21
-
22
- single_model_folder = "Single agent path integration/Seed 1 weight decay 1e-06/"
23
- single_model_parameters = "steps_20_batch_200_RNN_4096_relu_rf_012_DoG_True_periodic_False_lr_00001_weight_decay_1e-06/"
24
-
25
-
26
- dual_model_folder = (
27
- "Dual agent path integration disjoint PCs/Seed 1 weight decay 1e-06/"
28
- )
29
- dual_model_parameters = "steps_20_batch_200_RNN_4096_relu_rf_012_DoG_True_periodic_False_lr_00001_weight_decay_1e-06/"
30
-
31
-
32
- def load_activations (epochs , version = "single" , verbose = True ):
16
+ def load_activations (epochs , file_path , version = "single" , verbose = True , save = True ):
33
17
activations = []
34
18
rate_maps = []
35
19
state_points = []
36
20
positions = []
21
+ g_s = []
37
22
38
- if version == "single" :
39
- activations_dir = (
40
- parent_dir + single_model_folder + single_model_parameters + "activations/"
41
- )
42
- elif version == "dual" :
43
- activations_dir = (
44
- parent_dir + dual_model_folder + dual_model_parameters + "activations/"
45
- )
23
+ activations_dir = os .path .join (file_path , "activations" )
46
24
47
- random .seed (0 )
48
25
for epoch in epochs :
49
- activations_epoch_path = (
50
- activations_dir + f"activations_{ version } _agent_epoch_{ epoch } .npy"
51
- )
52
- rate_map_epoch_path = (
53
- activations_dir + f"rate_map_{ version } _agent_epoch_{ epoch } .npy"
54
- )
55
- positions_epoch_path = (
56
- activations_dir + f"positions_{ version } _agent_epoch_{ epoch } .npy"
57
- )
58
-
59
- if (
60
- os .path .exists (activations_epoch_path )
61
- and os .path .exists (rate_map_epoch_path )
62
- and os .path .exists (positions_epoch_path )
63
- ):
26
+ activations_epoch_path = os .path .join (activations_dir , f"activations_{ version } _agent_epoch_{ epoch } .npy" )
27
+ rate_map_epoch_path = os .path .join (activations_dir , f"rate_map_{ version } _agent_epoch_{ epoch } .npy" )
28
+ positions_epoch_path = os .path .join (activations_dir , f"positions_{ version } _agent_epoch_{ epoch } .npy" )
29
+ gs_epoch_path = os .path .join (activations_dir , f"g_{ version } _agent_epoch_{ epoch } .npy" )
30
+
31
+ if os .path .exists (activations_epoch_path ) and os .path .exists (
32
+ rate_map_epoch_path
33
+ ) and os .path .exists (positions_epoch_path ) and os .path .exists (gs_epoch_path ):
64
34
activations .append (np .load (activations_epoch_path ))
65
35
rate_maps .append (np .load (rate_map_epoch_path ))
66
36
positions .append (np .load (positions_epoch_path ))
37
+ g_s .append (np .load (gs_epoch_path ))
67
38
if verbose :
68
- print (f"Epoch { epoch } found! " )
39
+ print (f"Epoch { epoch } found. " )
69
40
else :
70
41
print (f"Epoch { epoch } not found. Loading ..." )
71
42
parser = config .parser
@@ -75,22 +46,32 @@ def load_activations(epochs, version="single", verbose=True):
75
46
(
76
47
activations_single_agent ,
77
48
rate_map_single_agent ,
49
+ g_single_agent ,
78
50
positions_single_agent ,
79
- ) = single_agent_activity .main (options , epoch = epoch )
51
+ ) = single_agent_activity .main (options , file_path , epoch = epoch )
80
52
activations .append (activations_single_agent )
81
53
rate_maps .append (rate_map_single_agent )
82
54
positions .append (positions_single_agent )
55
+ g_s .append (g_single_agent )
83
56
elif version == "dual" :
84
- activations_dual_agent , rate_map_dual_agent , positions_dual_agent = (
85
- dual_agent_activity .main (options , epoch = epoch )
86
- )
57
+ activations_dual_agent , rate_map_dual_agent , g_dual_agent , positions_dual_agent = dual_agent_activity .main (
58
+ options , file_path , epoch = epoch )
87
59
activations .append (activations_dual_agent )
88
60
rate_maps .append (rate_map_dual_agent )
89
61
positions .append (positions_dual_agent )
90
- print (len (activations ))
62
+ g_s .append (g_dual_agent )
63
+
64
+ if save :
65
+ np .save (activations_epoch_path , activations [- 1 ])
66
+ np .save (rate_map_epoch_path , rate_maps [- 1 ])
67
+ np .save (positions_epoch_path , positions [- 1 ])
68
+ np .save (gs_epoch_path , g_s [- 1 ])
69
+
91
70
state_points_epoch = activations [- 1 ].reshape (activations [- 1 ].shape [0 ], - 1 )
92
71
state_points .append (state_points_epoch )
93
72
73
+
74
+
94
75
if verbose :
95
76
print (f"Loaded epochs { epochs } of { version } agent model." )
96
77
print (
@@ -104,7 +85,7 @@ def load_activations(epochs, version="single", verbose=True):
104
85
)
105
86
print (f"positions has shape { positions [0 ].shape } ." )
106
87
107
- return activations , rate_maps , state_points , positions
88
+ return activations , rate_maps , state_points , positions , g_s
108
89
109
90
110
91
# def plot_rate_map(indices, num_plots, activations, title):
@@ -137,9 +118,8 @@ def load_activations(epochs, version="single", verbose=True):
137
118
# plt.show()
138
119
139
120
140
-
141
- def plot_rate_map (indices , num_plots , activations , title ):
142
- rng = np .random .default_rng (seed = 0 )
121
+ def plot_rate_map (indices , num_plots , activations , title , seed = None ):
122
+ rng = np .random .default_rng (seed = seed )
143
123
if indices is None :
144
124
idxs = rng .integers (0 , activations .shape [0 ] - 1 , num_plots )
145
125
else :
0 commit comments