14
14
15
15
def gaussian_on_circle (theta , loc , sigma = 0.1 ):
16
16
"""A Gaussian-like function defined on the circle."""
17
- return np .exp (- (theta - loc )** 2 / (2 * sigma ** 2 ))
17
+ return np .exp (- ((theta - loc ) ** 2 ) / (2 * sigma ** 2 ))
18
+
18
19
19
20
def relu (x ):
20
21
return np .maximum (0 , x )
21
22
23
+
22
24
# Function to plot a harmonic given amplitude and phase
23
25
def plot_harmonic (ax , amplitude , phase , n , label , activation = "relu" ):
24
-
25
26
harmonic_values = amplitude * np .cos (n * theta + phase )
26
27
if activation == "relu" :
27
28
harmonic_values = relu (harmonic_values )
28
- ax .plot (np .cos (theta ), np .sin (theta ), zs = 0 , zdir = "z" , linestyle = "--" ,linewidth = 3 , color = "black" )
29
- normalized_phase = (phase + np .pi ) / (2 * np .pi ) # Normalizing from -π to π to 0 to 1
29
+ ax .plot (
30
+ np .cos (theta ),
31
+ np .sin (theta ),
32
+ zs = 0 ,
33
+ zdir = "z" ,
34
+ linestyle = "--" ,
35
+ linewidth = 3 ,
36
+ color = "black" ,
37
+ )
38
+ normalized_phase = (phase + np .pi ) / (
39
+ 2 * np .pi
40
+ ) # Normalizing from -π to π to 0 to 1
30
41
color = cm .hsv (normalized_phase )
31
- ax .plot (np .cos (theta ), np .sin (theta ), harmonic_values , label = label ,linewidth = 3 ,color = color ,alpha = 1 - 0.1 * n )
42
+ ax .plot (
43
+ np .cos (theta ),
44
+ np .sin (theta ),
45
+ harmonic_values ,
46
+ label = label ,
47
+ linewidth = 3 ,
48
+ color = color ,
49
+ alpha = 1 - 0.1 * n ,
50
+ )
32
51
ax .axis ("off" )
33
52
53
+
34
54
# Prepare figure for plotting
35
- fig , axs = plt .subplots (2 , N + 1 , figsize = (20 , 10 ), subplot_kw = {"projection" : "3d" })
55
+ fig , axs = plt .subplots (2 , N + 1 , figsize = (20 , 10 ), subplot_kw = {"projection" : "3d" })
36
56
plt .tight_layout ()
37
57
58
+
38
59
def update (loc ):
39
60
bump_samples = gaussian_on_circle (theta , loc = loc )
40
61
41
62
# Compute FFT
42
63
coefficients_fft = np .fft .fft (bump_samples )
43
- frequencies = np .fft .fftfreq (num_samples , d = (2 * np .pi / num_samples ))
64
+ frequencies = np .fft .fftfreq (num_samples , d = (2 * np .pi / num_samples ))
44
65
45
66
# Clear previous plots
46
67
for ax_row in axs :
@@ -49,32 +70,72 @@ def update(loc):
49
70
ax .axis ("off" )
50
71
51
72
# Plot original function
52
- axs [0 , 2 ].plot (np .cos (theta ), np .sin (theta ), zs = 0 , zdir = "z" , linestyle = "--" ,linewidth = 3 ,color = "black" )
53
- axs [0 , 2 ].plot (np .cos (theta ), np .sin (theta ), bump_samples , label = "Original Function" ,linewidth = 3 ,color = "tomato" )
54
- axs [0 , 2 ].set_title (f"Target place field, position = { loc :.2f} " ,fontsize = 20 )
73
+ axs [0 , 2 ].plot (
74
+ np .cos (theta ),
75
+ np .sin (theta ),
76
+ zs = 0 ,
77
+ zdir = "z" ,
78
+ linestyle = "--" ,
79
+ linewidth = 3 ,
80
+ color = "black" ,
81
+ )
82
+ axs [0 , 2 ].plot (
83
+ np .cos (theta ),
84
+ np .sin (theta ),
85
+ bump_samples ,
86
+ label = "Original Function" ,
87
+ linewidth = 3 ,
88
+ color = "tomato" ,
89
+ )
90
+ axs [0 , 2 ].set_title (f"Target place field, position = { loc :.2f} " , fontsize = 20 )
55
91
axs [0 , 2 ].scatter (np .cos (loc ), np .sin (loc ), zs = 0 , zdir = "z" , s = 100 , c = "red" )
56
92
57
93
# Plot each harmonic and the reconstructed function
58
94
reconstructed = np .zeros (num_samples )
59
- for n in range (1 , N + 1 ):
95
+ for n in range (1 , N + 1 ):
60
96
index = n if frequencies [n ] >= 0 else num_samples + n
61
97
amplitude = np .abs (coefficients_fft [index ])
62
98
phase = np .angle (coefficients_fft [index ])
63
99
64
- plot_harmonic (axs [1 , n - 1 ], amplitude , phase , n , rf"GC module { n } , period $\lambda=${ L / n :0.1f} " , activation = activation )
65
- axs [1 , n - 1 ].set_title (rf"GC module { n } , period $\lambda_{ n } =${ L / n :0.1f} " ,fontsize = 18 )
100
+ plot_harmonic (
101
+ axs [1 , n - 1 ],
102
+ amplitude ,
103
+ phase ,
104
+ n ,
105
+ rf"GC module { n } , period $\lambda=${ L / n :0.1f} " ,
106
+ activation = activation ,
107
+ )
108
+ axs [1 , n - 1 ].set_title (
109
+ rf"GC module { n } , period $\lambda_{ n } =${ L / n :0.1f} " , fontsize = 18
110
+ )
66
111
if activation == "relu" :
67
112
reconstructed += relu (amplitude * np .cos (n * theta + phase ))
68
113
else :
69
114
reconstructed += amplitude * np .cos (n * theta + phase )
70
115
71
116
# Reconstructed function
72
- axs [1 , N ].plot (np .cos (theta ), np .sin (theta ), zs = 0 , zdir = "z" , linestyle = "--" ,linewidth = 3 ,color = "black" )
73
- axs [1 , N ].plot (np .cos (theta ), np .sin (theta ), reconstructed , label = "Reconstructed" ,linewidth = 3 ,color = "limegreen" )
74
- axs [1 , N ].set_title ("Place field readout" ,fontsize = 20 )
117
+ axs [1 , N ].plot (
118
+ np .cos (theta ),
119
+ np .sin (theta ),
120
+ zs = 0 ,
121
+ zdir = "z" ,
122
+ linestyle = "--" ,
123
+ linewidth = 3 ,
124
+ color = "black" ,
125
+ )
126
+ axs [1 , N ].plot (
127
+ np .cos (theta ),
128
+ np .sin (theta ),
129
+ reconstructed ,
130
+ label = "Reconstructed" ,
131
+ linewidth = 3 ,
132
+ color = "limegreen" ,
133
+ )
134
+ axs [1 , N ].set_title ("Place field readout" , fontsize = 20 )
135
+
75
136
76
137
# Create animation
77
- loc_values = np .linspace (0 , 2 * np .pi , 100 )
138
+ loc_values = np .linspace (0 , 2 * np .pi , 100 )
78
139
ani = FuncAnimation (fig , update , frames = loc_values , repeat = True )
79
140
80
141
# Save the animation
0 commit comments