|
12 | 12 | "from torch.utils.data import DataLoader\n",
|
13 | 13 | "from torchvision import transforms\n",
|
14 | 14 | "from torchvision.datasets import MNIST\n",
|
15 |
| - "from matplotlib import pyplot as plt" |
| 15 | + "from matplotlib import pyplot as plt\n", |
| 16 | + "from res.plot_lib import set_default" |
16 | 17 | ]
|
17 | 18 | },
|
18 | 19 | {
|
|
23 | 24 | "source": [
|
24 | 25 | "# Displaying routine\n",
|
25 | 26 | "\n",
|
26 |
| - "def display_images(in_, out, n=1, label=None, count=False):\n", |
| 27 | + "def display_images(in_, out, n=1, label='', count=False, energy=None):\n", |
27 | 28 | " for N in range(n):\n",
|
28 | 29 | " if in_ is not None:\n",
|
29 | 30 | " in_pic = in_.data.cpu().view(-1, 28, 28)\n",
|
|
39 | 40 | " plt.subplot(1,4,i+1)\n",
|
40 | 41 | " plt.imshow(out_pic[i+4*N])\n",
|
41 | 42 | " plt.axis('off')\n",
|
42 |
| - " if count: plt.title(str(4 * N + i), color='w')" |
| 43 | + " c = 4 * N + i\n", |
| 44 | + " if count: plt.title(str(c), color='w')\n", |
| 45 | + " if count and energy is not None: plt.title(f'{c}, e={energy[c].item():.2f}', color='w')\n" |
43 | 46 | ]
|
44 | 47 | },
|
45 | 48 | {
|
|
166 | 169 | {
|
167 | 170 | "cell_type": "code",
|
168 | 171 | "execution_count": null,
|
169 |
| - "metadata": { |
170 |
| - "scrolled": false |
171 |
| - }, |
| 172 | + "metadata": {}, |
172 | 173 | "outputs": [],
|
173 | 174 | "source": [
|
174 | 175 | "# Training and testing the VAE\n",
|
|
238 | 239 | "outputs": [],
|
239 | 240 | "source": [
|
240 | 241 | "# Display last test batch\n",
|
241 |
| - "\n", |
242 |
| - "display_images(None, y, 4, count=True)" |
| 242 | + "with torch.no_grad():\n", |
| 243 | + " ỹ = model(y)[0].view(-1, 28, 28)\n", |
| 244 | + "energy = y.squeeze().sub(ỹ).pow(2).sum(dim=(1,2))\n", |
| 245 | + "display_images(None, y, 4, count=True, energy=energy)" |
243 | 246 | ]
|
244 | 247 | },
|
245 | 248 | {
|
|
250 | 253 | "source": [
|
251 | 254 | "# Choose starting and ending point for the interpolation -> shows original and reconstructed\n",
|
252 | 255 | "\n",
|
253 |
| - "A, B = 1, 14\n", |
| 256 | + "A, B = 0, 6\n", |
254 | 257 | "sample = model.decoder(torch.stack((mu[A].data, mu[B].data), 0))\n",
|
255 | 258 | "display_images(None, torch.stack(((\n",
|
256 | 259 | " y[A].data.view(-1),\n",
|
257 | 260 | " y[B].data.view(-1),\n",
|
258 | 261 | " sample.data[0],\n",
|
259 |
| - " sample.data[1]\n", |
260 |
| - ")), 0))" |
| 262 | + " sample.data[1],\n", |
| 263 | + " sample.data[0],\n", |
| 264 | + " sample.data[1],\n", |
| 265 | + " y[A].data.view(-1) - sample.data[0],\n", |
| 266 | + " y[B].data.view(-1) - sample.data[1]\n", |
| 267 | + ")), 0), 2)" |
261 | 268 | ]
|
262 | 269 | },
|
263 | 270 | {
|
|
269 | 276 | "# Perform an interpolation between input A and B, in N steps\n",
|
270 | 277 | "\n",
|
271 | 278 | "N = 16\n",
|
272 |
| - "code = torch.Tensor(N, 20).to(device)\n", |
273 |
| - "sample = torch.Tensor(N, 28, 28).to(device)\n", |
| 279 | + "# code = torch.Tensor(N, 20).to(device)\n", |
| 280 | + "samples = torch.Tensor(N, 28, 28).to(device)\n", |
274 | 281 | "for i in range(N):\n",
|
275 |
| - " #code[i] = i / (N - 1) * mu[B].data + (1 - i / (N - 1) ) * mu[A].data\n", |
276 |
| - " sample[i] = i / (N - 1) * y[B].data + (1 - i / (N - 1) ) * y[A].data\n", |
277 |
| - "# sample = model.decoder(code)\n", |
278 |
| - "display_images(None, sample, N // 4, count=True)" |
| 282 | + " # code[i] = i / (N - 1) * mu[B].data + (1 - i / (N - 1) ) * mu[A].data\n", |
| 283 | + " samples[i] = i / (N - 1) * y[B].data + (1 - i / (N - 1) ) * y[A].data\n", |
| 284 | + "# samples = model.decoder(code)\n", |
| 285 | + "display_images(None, samples, N // 4, count=True)" |
279 | 286 | ]
|
280 | 287 | },
|
281 | 288 | {
|
|
288 | 295 | "with torch.no_grad():\n",
|
289 | 296 | " ỹ = model(ẏ)[0]\n",
|
290 | 297 | "plt.figure(figsize=(10,5))\n",
|
291 |
| - "plt.subplot(121)\n", |
292 |
| - "plt.imshow((ẏ).view(28, 28))\n", |
293 |
| - "plt.subplot(122)\n", |
294 |
| - "plt.imshow((ỹ).view(28, 28))" |
| 298 | + "plt.subplot(121), plt.imshow((ẏ).view(28, 28))\n", |
| 299 | + "plt.subplot(122), plt.imshow((ỹ).view(28, 28))" |
| 300 | + ] |
| 301 | + }, |
| 302 | + { |
| 303 | + "cell_type": "code", |
| 304 | + "execution_count": null, |
| 305 | + "metadata": { |
| 306 | + "tags": [] |
| 307 | + }, |
| 308 | + "outputs": [], |
| 309 | + "source": [ |
| 310 | + "N = 16\n", |
| 311 | + "samples = torch.Tensor(N, 28, 28).to(device)\n", |
| 312 | + "for i in range(N):\n", |
| 313 | + " samples[i] = i / (N - 1) * y[B].data + (1 - i / (N - 1) ) * y[A].data\n", |
| 314 | + "with torch.no_grad():\n", |
| 315 | + " reconstructions = model(samples)[0].view(-1, 28, 28)\n", |
| 316 | + "\n", |
| 317 | + "plt.title(f'{A = }, {B = }')\n", |
| 318 | + "plt.plot(samples.sub(reconstructions).pow(2).sum(dim=(1,2)), '-o')" |
295 | 319 | ]
|
296 | 320 | },
|
297 | 321 | {
|
|
301 | 325 | "outputs": [],
|
302 | 326 | "source": [
|
303 | 327 | "import numpy as np\n",
|
304 |
| - "from sklearn.manifold import TSNE\n", |
305 |
| - "from res.plot_lib import set_default" |
| 328 | + "from sklearn.manifold import TSNE" |
306 | 329 | ]
|
307 | 330 | },
|
308 | 331 | {
|
|
343 | 366 | " a[i].axis('equal')\n",
|
344 | 367 | "f.colorbar(s, ax=a[:], ticks=np.arange(10), boundaries=np.arange(11) - .5)"
|
345 | 368 | ]
|
346 |
| - }, |
347 |
| - { |
348 |
| - "cell_type": "code", |
349 |
| - "execution_count": null, |
350 |
| - "metadata": {}, |
351 |
| - "outputs": [], |
352 |
| - "source": [] |
353 | 369 | }
|
354 | 370 | ],
|
355 | 371 | "metadata": {
|
356 | 372 | "kernelspec": {
|
357 |
| - "display_name": "Python 3", |
| 373 | + "display_name": "Python 3 (ipykernel)", |
358 | 374 | "language": "python",
|
359 | 375 | "name": "python3"
|
360 | 376 | },
|
|
368 | 384 | "name": "python",
|
369 | 385 | "nbconvert_exporter": "python",
|
370 | 386 | "pygments_lexer": "ipython3",
|
371 |
| - "version": "3.8.3" |
| 387 | + "version": "3.8.13" |
372 | 388 | }
|
373 | 389 | },
|
374 | 390 | "nbformat": 4,
|
|
0 commit comments