
Commit bfebf66

Compute energy for interpolation in ambient space
1 parent 43a92d7 commit bfebf66
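
The new cells in both notebooks walk a straight line between two inputs A and B in ambient (pixel) space and score every interpolant x by its reconstruction energy E(x) = ||x - f(x)||^2, where f is the trained model. A minimal sketch of the idea (not verbatim from the diff below), under the assumption the new cells themselves make that the model's first output is the reconstruction and inputs are flattened to 784 pixels:

import torch

def ambient_energy(model, a, b, n=16):
    t = torch.linspace(0, 1, n).view(-1, 1)          # mixing weights, 0 -> 1
    x = (1 - t) * a.view(1, -1) + t * b.view(1, -1)  # straight line from a to b
    with torch.no_grad():
        x_hat = model(x)[0]                          # reconstruction f(x)
    return x.sub(x_hat).pow(2).sum(dim=1)            # one energy per step

Interpolants near the data manifold should reconstruct well; midpoints of two digits typically do not, so the energy curve tends to bump up in the middle.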

2 files changed: +84 -45 lines

10-autoencoder.ipynb (+36 -13)
@@ -14,7 +14,9 @@
     "from torch.utils.data import DataLoader\n",
     "from torchvision import transforms\n",
     "from torchvision.datasets import MNIST\n",
-    "from matplotlib import pyplot as plt"
+    "from matplotlib import pyplot as plt\n",
+    "from res.plot_lib import set_default\n",
+    "set_default(figsize=(15, 4))"
    ]
   },
   {
@@ -145,9 +147,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "scrolled": false
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "# Train standard or denoising autoencoder (AE)\n",
@@ -176,9 +176,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "scrolled": false
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "# Visualise a few kernels of the encoder\n",
@@ -247,9 +245,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "scrolled": false
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "# Train standard or denoising autoencoder (AE)\n",
@@ -336,6 +332,15 @@
     "    display_images(TELEA, NS)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Experimenting\n",
+    "\n",
+    "The section below needs some more love and refactoring"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -351,7 +356,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "y = (img[1:2] + img[15:16])/2\n",
+    "A, B = 0, 9\n",
+    "y = (img[[A]] + img[[B]])/2\n",
     "with torch.no_grad():\n",
     "    ỹ = model(y)\n",
     "plt.figure(figsize=(10,5))\n",
@@ -360,11 +366,28 @@
     "plt.subplot(122)\n",
     "plt.imshow(to_img(ỹ).squeeze())"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "N = 16\n",
+    "samples = torch.Tensor(N, 28 * 28).to(device)\n",
+    "for i in range(N):\n",
+    "    samples[i] = i / (N - 1) * img[B].data + (1 - i / (N - 1) ) * img[A].data\n",
+    "with torch.no_grad():\n",
+    "    reconstructions = model(samples)[0]\n",
+    "\n",
+    "plt.title(f'{A = }, {B = }')\n",
+    "plt.plot(samples.sub(reconstructions).pow(2).sum(dim=(1)), '-o')"
+   ]
   }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -378,7 +401,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.3"
+   "version": "3.8.13"
   }
  },
  "nbformat": 4,

11-VAE.ipynb (+48 -32)
@@ -12,7 +12,8 @@
     "from torch.utils.data import DataLoader\n",
     "from torchvision import transforms\n",
     "from torchvision.datasets import MNIST\n",
-    "from matplotlib import pyplot as plt"
+    "from matplotlib import pyplot as plt\n",
+    "from res.plot_lib import set_default"
    ]
   },
   {
@@ -23,7 +24,7 @@
    "source": [
     "# Displaying routine\n",
     "\n",
-    "def display_images(in_, out, n=1, label=None, count=False):\n",
+    "def display_images(in_, out, n=1, label='', count=False, energy=None):\n",
     "    for N in range(n):\n",
     "        if in_ is not None:\n",
     "            in_pic = in_.data.cpu().view(-1, 28, 28)\n",
@@ -39,7 +40,9 @@
     "            plt.subplot(1,4,i+1)\n",
     "            plt.imshow(out_pic[i+4*N])\n",
     "            plt.axis('off')\n",
-    "            if count: plt.title(str(4 * N + i), color='w')"
+    "            c = 4 * N + i\n",
+    "            if count: plt.title(str(c), color='w')\n",
+    "            if count and energy is not None: plt.title(f'{c}, e={energy[c].item():.2f}', color='w')\n"
    ]
   },
   {
@@ -166,9 +169,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "scrolled": false
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "# Training and testing the VAE\n",
@@ -238,8 +239,10 @@
    "outputs": [],
    "source": [
     "# Display last test batch\n",
-    "\n",
-    "display_images(None, y, 4, count=True)"
+    "with torch.no_grad():\n",
+    "    ỹ = model(y)[0].view(-1, 28, 28)\n",
+    "energy = y.squeeze().sub(ỹ).pow(2).sum(dim=(1,2))\n",
+    "display_images(None, y, 4, count=True, energy=energy)"
    ]
   },
   {
@@ -250,14 +253,18 @@
    "source": [
     "# Choose starting and ending point for the interpolation -> shows original and reconstructed\n",
     "\n",
-    "A, B = 1, 14\n",
+    "A, B = 0, 6\n",
     "sample = model.decoder(torch.stack((mu[A].data, mu[B].data), 0))\n",
     "display_images(None, torch.stack(((\n",
     "    y[A].data.view(-1),\n",
     "    y[B].data.view(-1),\n",
     "    sample.data[0],\n",
-    "    sample.data[1]\n",
-    ")), 0))"
+    "    sample.data[1],\n",
+    "    sample.data[0],\n",
+    "    sample.data[1],\n",
+    "    y[A].data.view(-1) - sample.data[0],\n",
+    "    y[B].data.view(-1) - sample.data[1]\n",
+    ")), 0), 2)"
    ]
   },
   {
@@ -269,13 +276,13 @@
     "# Perform an interpolation between input A and B, in N steps\n",
     "\n",
     "N = 16\n",
-    "code = torch.Tensor(N, 20).to(device)\n",
-    "sample = torch.Tensor(N, 28, 28).to(device)\n",
+    "# code = torch.Tensor(N, 20).to(device)\n",
+    "samples = torch.Tensor(N, 28, 28).to(device)\n",
     "for i in range(N):\n",
-    "    #code[i] = i / (N - 1) * mu[B].data + (1 - i / (N - 1) ) * mu[A].data\n",
-    "    sample[i] = i / (N - 1) * y[B].data + (1 - i / (N - 1) ) * y[A].data\n",
-    "# sample = model.decoder(code)\n",
-    "display_images(None, sample, N // 4, count=True)"
+    "    # code[i] = i / (N - 1) * mu[B].data + (1 - i / (N - 1) ) * mu[A].data\n",
+    "    samples[i] = i / (N - 1) * y[B].data + (1 - i / (N - 1) ) * y[A].data\n",
+    "# samples = model.decoder(code)\n",
+    "display_images(None, samples, N // 4, count=True)"
    ]
   },
   {
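
The commented-out lines above keep the latent-space route around for comparison: blending the encoder means mu[A] -> mu[B] and decoding gives the usual VAE interpolation, while the active lines blend the test images themselves. A hedged sketch of the two routes side by side, assuming the notebook's decoder and the latent size of 20 seen in the diff:

import torch

def interpolation_paths(model, mu_a, mu_b, y_a, y_b, n=16):
    t = torch.linspace(0, 1, n).view(-1, 1)
    # Latent route (the commented-out path): blend the codes, then decode.
    codes = (1 - t) * mu_a.view(1, -1) + t * mu_b.view(1, -1)
    latent = model.decoder(codes).view(-1, 28, 28)
    # Ambient route (what this commit measures): blend the pixels directly.
    w = t.view(-1, 1, 1)
    ambient = (1 - w) * y_a + w * y_b
    return latent, ambient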
@@ -288,10 +295,27 @@
     "with torch.no_grad():\n",
     "    ỹ = model(ẏ)[0]\n",
     "plt.figure(figsize=(10,5))\n",
-    "plt.subplot(121)\n",
-    "plt.imshow((ẏ).view(28, 28))\n",
-    "plt.subplot(122)\n",
-    "plt.imshow((ỹ).view(28, 28))"
+    "plt.subplot(121), plt.imshow((ẏ).view(28, 28))\n",
+    "plt.subplot(122), plt.imshow((ỹ).view(28, 28))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "N = 16\n",
+    "samples = torch.Tensor(N, 28, 28).to(device)\n",
+    "for i in range(N):\n",
+    "    samples[i] = i / (N - 1) * y[B].data + (1 - i / (N - 1) ) * y[A].data\n",
+    "with torch.no_grad():\n",
+    "    reconstructions = model(samples)[0].view(-1, 28, 28)\n",
+    "\n",
+    "plt.title(f'{A = }, {B = }')\n",
+    "plt.plot(samples.sub(reconstructions).pow(2).sum(dim=(1,2)), '-o')"
    ]
   },
   {
@@ -301,8 +325,7 @@
    "outputs": [],
    "source": [
     "import numpy as np\n",
-    "from sklearn.manifold import TSNE\n",
-    "from res.plot_lib import set_default"
+    "from sklearn.manifold import TSNE"
    ]
   },
   {
@@ -343,18 +366,11 @@
     "    a[i].axis('equal')\n",
     "f.colorbar(s, ax=a[:], ticks=np.arange(10), boundaries=np.arange(11) - .5)"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -368,7 +384,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.3"
+   "version": "3.8.13"
   }
  },
  "nbformat": 4,
