|
132 | 132 | ")\n",
|
133 | 133 | "model.to(device) # possibly send to CUDA\n",
|
134 | 134 | "\n",
|
135 |     | - "# Cross entropy given the output logits\n",
    | 135 | + "# Cross entropy given the linear output\n",
136 | 136 | "C = torch.nn.CrossEntropyLoss(reduction='none')\n",
|
137 | 137 | "\n",
|
138 | 138 | "# Using Adam optimiser\n",
|
|
141 | 141 | "# Full-batch training loop\n",
|
142 | 142 | "for t in range(2_000):\n",
|
143 | 143 | " \n",
|
144 |     | - " # Feed forward to get the logits\n",
145 |     | - " l = model(X)\n",
    | 144 | + " # Feed forward to get the linear sum s\n",
    | 145 | + " s = model(X)\n",
146 | 146 | " \n",
|
147 |     | - " # Compute the free energy F\n",
148 |     | - " F = C(l, y)\n",
    | 147 | + " # Compute the free energy F and loss L\n",
    | 148 | + " F = C(s, y)\n",
149 | 149 | " L = F.mean()\n",
|
150 | 150 | " \n",
|
151 | 151 | " # Zero the gradients\n",
|
|
159 | 159 | " optimiser.step()\n",
|
160 | 160 | " \n",
|
161 | 161 | " # Display epoch, L, and accuracy\n",
|
162 |     | - " overwrite(f'[EPOCH]: {t}, [LOSS]: {L.item():.6f}, [ACCURACY]: {acc(l, y):.3f}')"
    | 162 | + " overwrite(f'[EPOCH]: {t}, [LOSS]: {L.item():.6f}, [ACCURACY]: {acc(s, y):.3f}')"
163 | 163 | ]
|
164 | 164 | },
|
165 | 165 | {
|
|
189 | 189 | "metadata": {},
|
190 | 190 | "outputs": [],
|
191 | 191 | "source": [
|
192 |     | - "# Compute logits for a fine grid over the input space\n",
    | 192 | + "# Compute linear output s for a fine grid over the input space\n",
193 | 193 | "\n",
|
194 | 194 | "mesh = torch.arange(-1.5, 1.5, 0.01)\n",
|
195 | 195 | "xx, yy = torch.meshgrid(mesh, mesh)\n",
|
196 | 196 | "grid = torch.stack((xx.reshape(-1), yy.reshape(-1)), dim=1)\n",
|
197 | 197 | "with torch.no_grad():\n",
|
198 |     | - " logits = model(grid)"
    | 198 | + " s = model(grid)"
199 | 199 | ]
|
200 | 200 | },
|
201 | 201 | {
|
|
207 | 207 | "# Choice of free energy\n",
|
208 | 208 | "\n",
|
209 | 209 | "fe = 'cross-entropy'\n",
|
210 |     | - "fe = 'negative logit'"
    | 210 | + "fe = 'negative linear output'"
211 | 211 | ]
|
212 | 212 | },
|
213 | 213 | {
|
|
242 | 242 | "\n",
|
243 | 243 | "for k in range(K):\n",
|
244 | 244 | " if fe == 'cross-entropy':\n",
|
245 |     | - " F = C(logits, torch.LongTensor(1).fill_(k).expand(logits.size(0)))\n",
    | 245 | + " F = C(s, torch.LongTensor(1).fill_(k).expand(s.size(0)))\n",
246 | 246 | " F = F.reshape(xx.shape)\n",
|
247 | 247 | " plot_2d_energy_levels(X, y, (xx, yy, F, k, K), (0, 35), (1, 35, 4))\n",
|
248 | 248 | "\n",
|
249 |     | - " elif fe == 'negative logit':\n",
250 |     | - " F = -logits[:, k]\n",
    | 249 | + " elif fe == 'negative linear output':\n",
    | 250 | + " F = -s[:, k]\n",
251 | 251 | " F = F.reshape(xx.shape)\n",
|
252 | 252 | " plot_2d_energy_levels(X, y, (xx, yy, F, k, K), (-20, 20), (-20, 21, 2.5))\n",
|
253 | 253 | " \n",
|
|
282 | 282 | "# Cross-entropy\n",
|
283 | 283 | "if fe == 'cross-entropy':\n",
|
284 | 284 | " fig, ax = plot_3d_energy_levels(X, y, (xx, yy, F, k, K), (0, 18), (0, 19, 1), (0, 19, 2))\n",
|
285 |     | - "elif fe == 'negative logit':\n",
    | 285 | + "elif fe == 'negative linear output':\n",
286 | 286 | " fig, ax = plot_3d_energy_levels(X, y, (xx, yy, F, k, K), (-30, 20), (-30, 20, 1), (-30, 21, 5))"
|
287 | 287 | ]
|
288 | 288 | },
|
|
336 | 336 | "name": "python",
|
337 | 337 | "nbconvert_exporter": "python",
|
338 | 338 | "pygments_lexer": "ipython3",
|
339 |     | - "version": "3.10.12"
    | 339 | + "version": "3.10.13"
340 | 340 | }
|
341 | 341 | },
|
342 | 342 | "nbformat": 4,
|
|
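A note on the two free-energy choices renamed in the hunks above ('cross-entropy' vs 'negative linear output'). The following is a minimal sketch, not part of the commit, assuming only PyTorch: cross-entropy applied to the linear output s gives F(x, y) = -s_y + logsumexp(s), i.e. the negative linear output for the target class plus a log-partition term.

import torch

# Sketch only (not from the commit): relate the two free-energy choices.
s = torch.randn(5, 3)                            # linear output s: 5 points, 3 classes
y = torch.randint(0, 3, (5,))                    # target classes
C = torch.nn.CrossEntropyLoss(reduction='none')

F_ce = C(s, y)                                   # 'cross-entropy' free energy
F_neg = -s.gather(1, y[:, None]).squeeze(1)      # 'negative linear output' free energy, -s_y
print(torch.allclose(F_ce, F_neg + s.logsumexp(dim=1)))  # True: they differ by logsumexp(s)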