Skip to content

Commit 394147c

Browse files
committed
Switch to JupyterLab interactive matplotlib
1 parent 98e0368 commit 394147c

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

04-spiral_classification.ipynb

+2-3
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@
207207
"# Choice of free energy\n",
208208
"\n",
209209
"fe = 'cross-entropy'\n",
210-
"# fe = 'negative logit'"
210+
"fe = 'negative logit'"
211211
]
212212
},
213213
{
@@ -270,8 +270,7 @@
270270
"outputs": [],
271271
"source": [
272272
"# Switch to interactive matplotlib\n",
273-
"%matplotlib notebook\n",
274-
"set_default()"
273+
"%matplotlib widget"
275274
]
276275
},
277276
{

res/plot_lib.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ def plot_2d_energy_levels(X, y, energy, v=None, l=None):
9595
if not l: levels = None
9696
else: levels = torch.arange(l[0], l[1], l[2])
9797
plt.figure(figsize=(12, 10))
98-
plt.pcolormesh(xx, yy, F, vmin=vmin, vmax=vmax)
98+
plt.pcolormesh(xx.numpy(), yy.numpy(), F, vmin=vmin, vmax=vmax)
9999
plt.colorbar()
100-
cnt = plt.contour(xx, yy, F, colors='w', linewidths=1, levels=levels)
100+
cnt = plt.contour(xx.numpy(), yy.numpy(), F, colors='w', linewidths=1, levels=levels)
101101
plt.clabel(cnt, inline=True, fontsize=10, colors='w')
102102
s = plot_data(X, y)
103103
plt.legend(*s.legend_elements(), title='Classes', loc='lower right')
@@ -116,7 +116,7 @@ def plot_3d_energy_levels(X, y, energy, v=None, l=None, cbl=None):
116116
else: levels = torch.arange(l[0], l[1], l[2])
117117
fig = plt.figure(figsize=(9.5, 6), facecolor='k')
118118
ax = fig.add_subplot(projection='3d')
119-
cnt = ax.contour(xx, yy, F, levels=levels, vmin=vmin, vmax=vmax)
119+
cnt = ax.contour(xx.numpy(), yy.numpy(), F, levels=levels, vmin=vmin, vmax=vmax)
120120
ax.scatter(X[:,0], X[:,1], zs=0, c=y, cmap=plt.cm.Spectral)
121121
ax.xaxis.set_pane_color(color=(0,0,0))
122122
ax.yaxis.set_pane_color(color=(0,0,0))
@@ -129,7 +129,7 @@ def plot_3d_energy_levels(X, y, energy, v=None, l=None, cbl=None):
129129
else: cbl = torch.arange(cbl[0], cbl[1], cbl[2])
130130
sm = plt.cm.ScalarMappable(norm=norm, cmap=cnt.cmap)
131131
sm.set_array([])
132-
fig.colorbar(sm, ticks=cbl)
132+
fig.colorbar(sm, ticks=cbl, ax=ax)
133133
= torch.zeros(K).int(); [k] = 1
134134
plt.title(f'Free energy F(x, y = {.tolist()})')
135135
plt.tight_layout()

0 commit comments

Comments
 (0)