Skip to content

Commit 1c516bf

Browse files
committed
Switch from logit to linear-sum
1 parent 394147c commit 1c516bf

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

04-spiral_classification.ipynb

+14-14
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@
132132
")\n",
133133
"model.to(device) # possibly send to CUDA\n",
134134
"\n",
135-
"# Cross entropy given the output logits\n",
135+
"# Cross entropy given the linear output\n",
136136
"C = torch.nn.CrossEntropyLoss(reduction='none')\n",
137137
"\n",
138138
"# Using Adam optimiser\n",
@@ -141,11 +141,11 @@
141141
"# Full-batch training loop\n",
142142
"for t in range(2_000):\n",
143143
" \n",
144-
" # Feed forward to get the logits\n",
145-
" l = model(X)\n",
144+
" # Feed forward to get the linear sum s\n",
145+
" s = model(X)\n",
146146
" \n",
147-
" # Compute the free energy F\n",
148-
" F = C(l, y)\n",
147+
" # Compute the free energy F and loss L\n",
148+
" F = C(s, y)\n",
149149
" L = F.mean()\n",
150150
" \n",
151151
" # Zero the gradients\n",
@@ -159,7 +159,7 @@
159159
" optimiser.step()\n",
160160
" \n",
161161
" # Display epoch, L, and accuracy\n",
162-
" overwrite(f'[EPOCH]: {t}, [LOSS]: {L.item():.6f}, [ACCURACY]: {acc(l, y):.3f}')"
162+
" overwrite(f'[EPOCH]: {t}, [LOSS]: {L.item():.6f}, [ACCURACY]: {acc(s, y):.3f}')"
163163
]
164164
},
165165
{
@@ -189,13 +189,13 @@
189189
"metadata": {},
190190
"outputs": [],
191191
"source": [
192-
"# Compute logits for a fine grid over the input space\n",
192+
"# Compute linear output s for a fine grid over the input space\n",
193193
"\n",
194194
"mesh = torch.arange(-1.5, 1.5, 0.01)\n",
195195
"xx, yy = torch.meshgrid(mesh, mesh)\n",
196196
"grid = torch.stack((xx.reshape(-1), yy.reshape(-1)), dim=1)\n",
197197
"with torch.no_grad():\n",
198-
" logits = model(grid)"
198+
" s = model(grid)"
199199
]
200200
},
201201
{
@@ -207,7 +207,7 @@
207207
"# Choice of free energy\n",
208208
"\n",
209209
"fe = 'cross-entropy'\n",
210-
"fe = 'negative logit'"
210+
"fe = 'negative linear output'"
211211
]
212212
},
213213
{
@@ -242,12 +242,12 @@
242242
"\n",
243243
"for k in range(K):\n",
244244
" if fe == 'cross-entropy':\n",
245-
" F = C(logits, torch.LongTensor(1).fill_(k).expand(logits.size(0)))\n",
245+
" F = C(s, torch.LongTensor(1).fill_(k).expand(s.size(0)))\n",
246246
" F = F.reshape(xx.shape)\n",
247247
" plot_2d_energy_levels(X, y, (xx, yy, F, k, K), (0, 35), (1, 35, 4))\n",
248248
"\n",
249-
" elif fe == 'negative logit':\n",
250-
" F = -logits[:, k]\n",
249+
" elif fe == 'negative linear output':\n",
250+
" F = -s[:, k]\n",
251251
" F = F.reshape(xx.shape)\n",
252252
" plot_2d_energy_levels(X, y, (xx, yy, F, k, K), (-20, 20), (-20, 21, 2.5))\n",
253253
" \n",
@@ -282,7 +282,7 @@
282282
"# Cross-entropy\n",
283283
"if fe == 'cross-entropy':\n",
284284
" fig, ax = plot_3d_energy_levels(X, y, (xx, yy, F, k, K), (0, 18), (0, 19, 1), (0, 19, 2))\n",
285-
"elif fe == 'negative logit':\n",
285+
"elif fe == 'negative linear output':\n",
286286
" fig, ax = plot_3d_energy_levels(X, y, (xx, yy, F, k, K), (-30, 20), (-30, 20, 1), (-30, 21, 5))"
287287
]
288288
},
@@ -336,7 +336,7 @@
336336
"name": "python",
337337
"nbconvert_exporter": "python",
338338
"pygments_lexer": "ipython3",
339-
"version": "3.10.12"
339+
"version": "3.10.13"
340340
}
341341
},
342342
"nbformat": 4,

0 commit comments

Comments
 (0)