1
- import numpy as np
1
+ import matplotlib . cm as cm
2
2
import matplotlib .pyplot as plt
3
+ import numpy as np
3
4
from matplotlib .animation import FuncAnimation
4
- import matplotlib .cm as cm
5
5
6
6
L = 10 # Length of the domain
7
7
8
8
# Initialize parameters
9
9
num_samples = 1000
10
10
theta = np .linspace (0 , 2 * np .pi , num_samples , endpoint = False )
11
11
N = 4 # Number of harmonics
12
- activation = ' relu'
12
+ activation = " relu"
13
13
14
14
15
15
def gaussian_on_circle (theta , loc , sigma = 0.1 ):
@@ -20,19 +20,19 @@ def relu(x):
20
20
return np .maximum (0 , x )
21
21
22
22
# Function to plot a harmonic given amplitude and phase
23
- def plot_harmonic (ax , amplitude , phase , n , label , activation = ' relu' ):
23
+ def plot_harmonic (ax , amplitude , phase , n , label , activation = " relu" ):
24
24
25
25
harmonic_values = amplitude * np .cos (n * theta + phase )
26
- if activation == ' relu' :
26
+ if activation == " relu" :
27
27
harmonic_values = relu (harmonic_values )
28
- ax .plot (np .cos (theta ), np .sin (theta ), zs = 0 , zdir = 'z' , linestyle = '--' ,linewidth = 3 , color = ' black' )
28
+ ax .plot (np .cos (theta ), np .sin (theta ), zs = 0 , zdir = "z" , linestyle = "--" ,linewidth = 3 , color = " black" )
29
29
normalized_phase = (phase + np .pi ) / (2 * np .pi ) # Normalizing from -π to π to 0 to 1
30
30
color = cm .hsv (normalized_phase )
31
31
ax .plot (np .cos (theta ), np .sin (theta ), harmonic_values , label = label ,linewidth = 3 ,color = color ,alpha = 1 - 0.1 * n )
32
- ax .axis (' off' )
32
+ ax .axis (" off" )
33
33
34
34
# Prepare figure for plotting
35
- fig , axs = plt .subplots (2 , N + 1 , figsize = (20 , 10 ), subplot_kw = {' projection' : '3d' })
35
+ fig , axs = plt .subplots (2 , N + 1 , figsize = (20 , 10 ), subplot_kw = {" projection" : "3d" })
36
36
plt .tight_layout ()
37
37
38
38
def update (loc ):
@@ -46,36 +46,36 @@ def update(loc):
46
46
for ax_row in axs :
47
47
for ax in ax_row :
48
48
ax .cla ()
49
- ax .axis (' off' )
49
+ ax .axis (" off" )
50
50
51
51
# Plot original function
52
- axs [0 , 2 ].plot (np .cos (theta ), np .sin (theta ), zs = 0 , zdir = 'z' , linestyle = '--' ,linewidth = 3 ,color = ' black' )
53
- axs [0 , 2 ].plot (np .cos (theta ), np .sin (theta ), bump_samples , label = ' Original Function' ,linewidth = 3 ,color = ' tomato' )
54
- axs [0 , 2 ].set_title (' Target place field, position = {:.2f}' . format ( loc ) ,fontsize = 20 )
55
- axs [0 , 2 ].scatter (np .cos (loc ), np .sin (loc ), zs = 0 , zdir = 'z' , s = 100 , c = ' red' )
52
+ axs [0 , 2 ].plot (np .cos (theta ), np .sin (theta ), zs = 0 , zdir = "z" , linestyle = "--" ,linewidth = 3 ,color = " black" )
53
+ axs [0 , 2 ].plot (np .cos (theta ), np .sin (theta ), bump_samples , label = " Original Function" ,linewidth = 3 ,color = " tomato" )
54
+ axs [0 , 2 ].set_title (f" Target place field, position = { loc :.2f} " ,fontsize = 20 )
55
+ axs [0 , 2 ].scatter (np .cos (loc ), np .sin (loc ), zs = 0 , zdir = "z" , s = 100 , c = " red" )
56
56
57
57
# Plot each harmonic and the reconstructed function
58
58
reconstructed = np .zeros (num_samples )
59
59
for n in range (1 , N + 1 ):
60
60
index = n if frequencies [n ] >= 0 else num_samples + n
61
61
amplitude = np .abs (coefficients_fft [index ])
62
62
phase = np .angle (coefficients_fft [index ])
63
-
64
- plot_harmonic (axs [1 , n - 1 ], amplitude , phase , n , f' GC module { n } , period $\lambda=${ L / n :0.1f} ' , activation = activation )
65
- axs [1 , n - 1 ].set_title (f' GC module { n } , period $\lambda_{ n } =${ L / n :0.1f} ' ,fontsize = 18 )
66
- if activation == ' relu' :
63
+
64
+ plot_harmonic (axs [1 , n - 1 ], amplitude , phase , n , rf" GC module { n } , period $\lambda=${ L / n :0.1f} " , activation = activation )
65
+ axs [1 , n - 1 ].set_title (rf" GC module { n } , period $\lambda_{ n } =${ L / n :0.1f} " ,fontsize = 18 )
66
+ if activation == " relu" :
67
67
reconstructed += relu (amplitude * np .cos (n * theta + phase ))
68
68
else :
69
69
reconstructed += amplitude * np .cos (n * theta + phase )
70
70
71
71
# Reconstructed function
72
- axs [1 , N ].plot (np .cos (theta ), np .sin (theta ), zs = 0 , zdir = 'z' , linestyle = '--' ,linewidth = 3 ,color = ' black' )
73
- axs [1 , N ].plot (np .cos (theta ), np .sin (theta ), reconstructed , label = ' Reconstructed' ,linewidth = 3 ,color = ' limegreen' )
74
- axs [1 , N ].set_title (' Place field readout' ,fontsize = 20 )
72
+ axs [1 , N ].plot (np .cos (theta ), np .sin (theta ), zs = 0 , zdir = "z" , linestyle = "--" ,linewidth = 3 ,color = " black" )
73
+ axs [1 , N ].plot (np .cos (theta ), np .sin (theta ), reconstructed , label = " Reconstructed" ,linewidth = 3 ,color = " limegreen" )
74
+ axs [1 , N ].set_title (" Place field readout" ,fontsize = 20 )
75
75
76
76
# Create animation
77
- loc_values = np .linspace (0 , 2 * np .pi , 100 )
77
+ loc_values = np .linspace (0 , 2 * np .pi , 100 )
78
78
ani = FuncAnimation (fig , update , frames = loc_values , repeat = True )
79
79
80
80
# Save the animation
81
- ani .save (' position_from_grid_cells.gif' , writer = ' imagemagick' , fps = 10 )
81
+ ani .save (" position_from_grid_cells.gif" , writer = " imagemagick" , fps = 10 )
0 commit comments