@@ -95,9 +95,9 @@ def plot_2d_energy_levels(X, y, energy, v=None, l=None):
95
95
if not l : levels = None
96
96
else : levels = torch .arange (l [0 ], l [1 ], l [2 ])
97
97
plt .figure (figsize = (12 , 10 ))
98
- plt .pcolormesh (xx , yy , F , vmin = vmin , vmax = vmax )
98
+ plt .pcolormesh (xx . numpy () , yy . numpy () , F , vmin = vmin , vmax = vmax )
99
99
plt .colorbar ()
100
- cnt = plt .contour (xx , yy , F , colors = 'w' , linewidths = 1 , levels = levels )
100
+ cnt = plt .contour (xx . numpy () , yy . numpy () , F , colors = 'w' , linewidths = 1 , levels = levels )
101
101
plt .clabel (cnt , inline = True , fontsize = 10 , colors = 'w' )
102
102
s = plot_data (X , y )
103
103
plt .legend (* s .legend_elements (), title = 'Classes' , loc = 'lower right' )
@@ -116,7 +116,7 @@ def plot_3d_energy_levels(X, y, energy, v=None, l=None, cbl=None):
116
116
else : levels = torch .arange (l [0 ], l [1 ], l [2 ])
117
117
fig = plt .figure (figsize = (9.5 , 6 ), facecolor = 'k' )
118
118
ax = fig .add_subplot (projection = '3d' )
119
- cnt = ax .contour (xx , yy , F , levels = levels , vmin = vmin , vmax = vmax )
119
+ cnt = ax .contour (xx . numpy () , yy . numpy () , F , levels = levels , vmin = vmin , vmax = vmax )
120
120
ax .scatter (X [:,0 ], X [:,1 ], zs = 0 , c = y , cmap = plt .cm .Spectral )
121
121
ax .xaxis .set_pane_color (color = (0 ,0 ,0 ))
122
122
ax .yaxis .set_pane_color (color = (0 ,0 ,0 ))
@@ -129,7 +129,7 @@ def plot_3d_energy_levels(X, y, energy, v=None, l=None, cbl=None):
129
129
else : cbl = torch .arange (cbl [0 ], cbl [1 ], cbl [2 ])
130
130
sm = plt .cm .ScalarMappable (norm = norm , cmap = cnt .cmap )
131
131
sm .set_array ([])
132
- fig .colorbar (sm , ticks = cbl )
132
+ fig .colorbar (sm , ticks = cbl , ax = ax )
133
133
ȳ = torch .zeros (K ).int (); ȳ [k ] = 1
134
134
plt .title (f'Free energy F(x, y = { ȳ .tolist ()} )' )
135
135
plt .tight_layout ()
0 commit comments