|
82 | 82 | "execution_count": 0,
|
83 | 83 | "outputs": []
|
84 | 84 | },
|
| 85 | + { |
| 86 | + "cell_type": "code", |
| 87 | + "metadata": { |
| 88 | + "id": "R-pmd345UT1M", |
| 89 | + "colab_type": "code", |
| 90 | + "colab": { |
| 91 | + "base_uri": "https://localhost:8080/", |
| 92 | + "height": 1000 |
| 93 | + }, |
| 94 | + "outputId": "8318fa7e-3589-4eaf-862f-c0771ff285d8" |
| 95 | + }, |
| 96 | + "source": [ |
| 97 | + "next(iter(train_loader))" |
| 98 | + ], |
| 99 | + "execution_count": 3, |
| 100 | + "outputs": [ |
| 101 | + { |
| 102 | + "output_type": "execute_result", |
| 103 | + "data": { |
| 104 | + "text/plain": [ |
| 105 | + "[tensor([[[[0., 0., 0., ..., 0., 0., 0.],\n", |
| 106 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 107 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 108 | + " ...,\n", |
| 109 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 110 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 111 | + " [0., 0., 0., ..., 0., 0., 0.]]],\n", |
| 112 | + " \n", |
| 113 | + " \n", |
| 114 | + " [[[0., 0., 0., ..., 0., 0., 0.],\n", |
| 115 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 116 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 117 | + " ...,\n", |
| 118 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 119 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 120 | + " [0., 0., 0., ..., 0., 0., 0.]]],\n", |
| 121 | + " \n", |
| 122 | + " \n", |
| 123 | + " [[[0., 0., 0., ..., 0., 0., 0.],\n", |
| 124 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 125 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 126 | + " ...,\n", |
| 127 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 128 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 129 | + " [0., 0., 0., ..., 0., 0., 0.]]],\n", |
| 130 | + " \n", |
| 131 | + " \n", |
| 132 | + " ...,\n", |
| 133 | + " \n", |
| 134 | + " \n", |
| 135 | + " [[[0., 0., 0., ..., 0., 0., 0.],\n", |
| 136 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 137 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 138 | + " ...,\n", |
| 139 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 140 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 141 | + " [0., 0., 0., ..., 0., 0., 0.]]],\n", |
| 142 | + " \n", |
| 143 | + " \n", |
| 144 | + " [[[0., 0., 0., ..., 0., 0., 0.],\n", |
| 145 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 146 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 147 | + " ...,\n", |
| 148 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 149 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 150 | + " [0., 0., 0., ..., 0., 0., 0.]]],\n", |
| 151 | + " \n", |
| 152 | + " \n", |
| 153 | + " [[[0., 0., 0., ..., 0., 0., 0.],\n", |
| 154 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 155 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 156 | + " ...,\n", |
| 157 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 158 | + " [0., 0., 0., ..., 0., 0., 0.],\n", |
| 159 | + " [0., 0., 0., ..., 0., 0., 0.]]]]),\n", |
| 160 | + " tensor([2, 9, 4, 1, 5, 9, 3, 3, 1, 6, 5, 3, 7, 7, 1, 3, 4, 2, 1, 7, 4, 5, 6, 0,\n", |
| 161 | + " 2, 1, 4, 1, 4, 3, 7, 7, 7, 7, 4, 5, 6, 5, 5, 3, 6, 7, 3, 1, 9, 8, 1, 3,\n", |
| 162 | + " 8, 7, 7, 6, 0, 7, 9, 9, 7, 3, 5, 3, 3, 2, 3, 2, 6, 8, 9, 6, 2, 0, 3, 7,\n", |
| 163 | + " 4, 5, 7, 4, 6, 8, 1, 3, 7, 8, 0, 6, 0, 6, 1, 7, 0, 3, 5, 3, 3, 6, 7, 1,\n", |
| 164 | + " 5, 7, 0, 0])]" |
| 165 | + ] |
| 166 | + }, |
| 167 | + "metadata": { |
| 168 | + "tags": [] |
| 169 | + }, |
| 170 | + "execution_count": 3 |
| 171 | + } |
| 172 | + ] |
| 173 | + }, |
85 | 174 | {
|
86 | 175 | "cell_type": "code",
|
87 | 176 | "metadata": {
|
|
116 | 205 | "execution_count": 0,
|
117 | 206 | "outputs": []
|
118 | 207 | },
|
| 208 | + { |
| 209 | + "cell_type": "code", |
| 210 | + "metadata": { |
| 211 | + "id": "0i7c0JK1Vpgo", |
| 212 | + "colab_type": "code", |
| 213 | + "colab": {} |
| 214 | + }, |
| 215 | + "source": [ |
| 216 | + "from IPython.core.debugger import set_trace" |
| 217 | + ], |
| 218 | + "execution_count": 0, |
| 219 | + "outputs": [] |
| 220 | + }, |
| 221 | + { |
| 222 | + "cell_type": "code", |
| 223 | + "metadata": { |
| 224 | + "id": "zdmCH-uSer4Z", |
| 225 | + "colab_type": "code", |
| 226 | + "colab": {} |
| 227 | + }, |
| 228 | + "source": [ |
| 229 | + "class Model1(nn.Module):\n", |
| 230 | + " def __init__(self):\n", |
| 231 | + " super().__init__()\n", |
| 232 | + " # self.i_h = nn.Embedding(nv,nh) # green arrow\n", |
| 233 | + " self.h_h = nn.Linear(nh,nh) # brown arrow\n", |
| 234 | + " self.h_o = nn.Linear(nh,num_classes) # blue arrow\n", |
| 235 | + " self.bn = nn.BatchNorm1d(nh)\n", |
| 236 | + " \n", |
| 237 | + " def forward(self, x):\n", |
| 238 | + " h = torch.zeros(x.shape[0], nh).to(device=x.device)\n", |
| 239 | + " for i in range(x.shape[1]):\n", |
| 240 | + " # h = h + self.i_h(x[:,i])\n", |
| 241 | + " h = h + x[:,i]\n", |
| 242 | + " h = self.bn(F.relu(self.h_h(h)))\n", |
| 243 | + " return self.h_o(h)" |
| 244 | + ], |
| 245 | + "execution_count": 0, |
| 246 | + "outputs": [] |
| 247 | + }, |
| 248 | + { |
| 249 | + "cell_type": "code", |
| 250 | + "metadata": { |
| 251 | + "id": "Z6k7kCc9kPIP", |
| 252 | + "colab_type": "code", |
| 253 | + "colab": { |
| 254 | + "base_uri": "https://localhost:8080/", |
| 255 | + "height": 34 |
| 256 | + }, |
| 257 | + "outputId": "9d114180-e69a-4f41-a59b-70602c53b389" |
| 258 | + }, |
| 259 | + "source": [ |
| 260 | + "nh" |
| 261 | + ], |
| 262 | + "execution_count": 18, |
| 263 | + "outputs": [ |
| 264 | + { |
| 265 | + "output_type": "execute_result", |
| 266 | + "data": { |
| 267 | + "text/plain": [ |
| 268 | + "128" |
| 269 | + ] |
| 270 | + }, |
| 271 | + "metadata": { |
| 272 | + "tags": [] |
| 273 | + }, |
| 274 | + "execution_count": 18 |
| 275 | + } |
| 276 | + ] |
| 277 | + }, |
| 278 | + { |
| 279 | + "cell_type": "code", |
| 280 | + "metadata": { |
| 281 | + "id": "nCSiQLeYh9so", |
| 282 | + "colab_type": "code", |
| 283 | + "colab": {} |
| 284 | + }, |
| 285 | + "source": [ |
| 286 | + "nh = 28" |
| 287 | + ], |
| 288 | + "execution_count": 0, |
| 289 | + "outputs": [] |
| 290 | + }, |
| 291 | + { |
| 292 | + "cell_type": "code", |
| 293 | + "metadata": { |
| 294 | + "id": "CMYxcSDzkaFP", |
| 295 | + "colab_type": "code", |
| 296 | + "colab": {} |
| 297 | + }, |
| 298 | + "source": [ |
| 299 | + "import torch.nn.functional as F" |
| 300 | + ], |
| 301 | + "execution_count": 0, |
| 302 | + "outputs": [] |
| 303 | + }, |
| 304 | + { |
| 305 | + "cell_type": "code", |
| 306 | + "metadata": { |
| 307 | + "id": "R6cCAVPvewKm", |
| 308 | + "colab_type": "code", |
| 309 | + "colab": {} |
| 310 | + }, |
| 311 | + "source": [ |
| 312 | + "model = Model1().to(device)" |
| 313 | + ], |
| 314 | + "execution_count": 0, |
| 315 | + "outputs": [] |
| 316 | + }, |
119 | 317 | {
|
120 | 318 | "cell_type": "code",
|
121 | 319 | "metadata": {
|
|
125 | 323 | "base_uri": "https://localhost:8080/",
|
126 | 324 | "height": 218
|
127 | 325 | },
|
128 |
| - "outputId": "5e6bf8ff-34ff-4d39-8e93-c2741f1cf6f8" |
| 326 | + "outputId": "5aa78f4d-4cf2-4845-cec7-e148a356916f" |
129 | 327 | },
|
130 | 328 | "source": [
|
131 | 329 | "# Loss and optimizer\n",
|
|
136 | 334 | "total_step = len(train_loader)\n",
|
137 | 335 | "for epoch in range(num_epochs):\n",
|
138 | 336 | " for i, (images, labels) in enumerate(train_loader):\n",
|
| 337 | + " # set_trace()\n", |
139 | 338 | " images = images.reshape(-1, sequence_length, input_size).to(device)\n",
|
140 | 339 | " labels = labels.to(device)\n",
|
141 | 340 | " \n",
|
|
152 | 351 | " print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' \n",
|
153 | 352 | " .format(epoch+1, num_epochs, i+1, total_step, loss.item()))"
|
154 | 353 | ],
|
155 |
| - "execution_count": 22, |
| 354 | + "execution_count": 24, |
156 | 355 | "outputs": [
|
157 | 356 | {
|
158 | 357 | "output_type": "stream",
|
159 | 358 | "text": [
|
160 |
| - "Epoch [1/2], Step [100/600], Loss: 0.4692\n", |
161 |
| - "Epoch [1/2], Step [200/600], Loss: 0.2797\n", |
162 |
| - "Epoch [1/2], Step [300/600], Loss: 0.1271\n", |
163 |
| - "Epoch [1/2], Step [400/600], Loss: 0.2750\n", |
164 |
| - "Epoch [1/2], Step [500/600], Loss: 0.1792\n", |
165 |
| - "Epoch [1/2], Step [600/600], Loss: 0.0991\n", |
166 |
| - "Epoch [2/2], Step [100/600], Loss: 0.0826\n", |
167 |
| - "Epoch [2/2], Step [200/600], Loss: 0.1674\n", |
168 |
| - "Epoch [2/2], Step [300/600], Loss: 0.1562\n", |
169 |
| - "Epoch [2/2], Step [400/600], Loss: 0.1447\n", |
170 |
| - "Epoch [2/2], Step [500/600], Loss: 0.0842\n", |
171 |
| - "Epoch [2/2], Step [600/600], Loss: 0.0283\n" |
| 359 | + "Epoch [1/2], Step [100/600], Loss: 1.7748\n", |
| 360 | + "Epoch [1/2], Step [200/600], Loss: 1.5532\n", |
| 361 | + "Epoch [1/2], Step [300/600], Loss: 1.2281\n", |
| 362 | + "Epoch [1/2], Step [400/600], Loss: 1.3798\n", |
| 363 | + "Epoch [1/2], Step [500/600], Loss: 0.7575\n", |
| 364 | + "Epoch [1/2], Step [600/600], Loss: 1.1237\n", |
| 365 | + "Epoch [2/2], Step [100/600], Loss: 0.8953\n", |
| 366 | + "Epoch [2/2], Step [200/600], Loss: 0.8352\n", |
| 367 | + "Epoch [2/2], Step [300/600], Loss: 0.9067\n", |
| 368 | + "Epoch [2/2], Step [400/600], Loss: 0.8490\n", |
| 369 | + "Epoch [2/2], Step [500/600], Loss: 0.8371\n", |
| 370 | + "Epoch [2/2], Step [600/600], Loss: 1.0092\n" |
172 | 371 | ],
|
173 | 372 | "name": "stdout"
|
174 | 373 | }
|
|
183 | 382 | "base_uri": "https://localhost:8080/",
|
184 | 383 | "height": 34
|
185 | 384 | },
|
186 |
| - "outputId": "84f5134c-0b9e-42e3-fcaf-42f528a3f363" |
| 385 | + "outputId": "3940a25a-3267-4854-cdc5-3584f07d8143" |
187 | 386 | },
|
188 | 387 | "source": [
|
189 | 388 | "# Test the model\n",
|
|
203 | 402 | "# Save the model checkpoint\n",
|
204 | 403 | "# torch.save(model.state_dict(), 'model.ckpt')"
|
205 | 404 | ],
|
206 |
| - "execution_count": 23, |
| 405 | + "execution_count": 25, |
207 | 406 | "outputs": [
|
208 | 407 | {
|
209 | 408 | "output_type": "stream",
|
210 | 409 | "text": [
|
211 |
| - "Test Accuracy of the model on the 10000 test images: 97.55 %\n" |
| 410 | + "Test Accuracy of the model on the 10000 test images: 71.49 %\n" |
212 | 411 | ],
|
213 | 412 | "name": "stdout"
|
214 | 413 | }
|
|
219 | 418 | "metadata": {
|
220 | 419 | "id": "cs2NL18qEaz2",
|
221 | 420 | "colab_type": "code",
|
222 |
| - "colab": {} |
| 421 | + "colab": { |
| 422 | + "base_uri": "https://localhost:8080/", |
| 423 | + "height": 706 |
| 424 | + }, |
| 425 | + "outputId": "b683fa26-cbf8-43ec-838b-45a0bb7ef854" |
223 | 426 | },
|
224 | 427 | "source": [
|
225 |
| - "" |
| 428 | + "%debug" |
226 | 429 | ],
|
227 |
| - "execution_count": 0, |
228 |
| - "outputs": [] |
| 430 | + "execution_count": 17, |
| 431 | + "outputs": [ |
| 432 | + { |
| 433 | + "output_type": "stream", |
| 434 | + "text": [ |
| 435 | + "> \u001b[0;32m<ipython-input-12-8318d738d92a>\u001b[0m(13)\u001b[0;36mforward\u001b[0;34m()\u001b[0m\n", |
| 436 | + "\u001b[0;32m 11 \u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 437 | + "\u001b[0m\u001b[0;32m 12 \u001b[0;31m \u001b[0;31m# h = h + self.i_h(x[:,i])\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 438 | + "\u001b[0m\u001b[0;32m---> 13 \u001b[0;31m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mh\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 439 | + "\u001b[0m\u001b[0;32m 14 \u001b[0;31m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mh_h\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 440 | + "\u001b[0m\u001b[0;32m 15 \u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mh_o\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 441 | + "\u001b[0m\n", |
| 442 | + "ipdb> h.shape\n", |
| 443 | + "*** No help for '.shape'\n", |
| 444 | + "ipdb> h\n", |
| 445 | + "\n", |
| 446 | + "Documented commands (type help <topic>):\n", |
| 447 | + "========================================\n", |
| 448 | + "EOF cl disable interact next psource rv unt \n", |
| 449 | + "a clear display j p q s until \n", |
| 450 | + "alias commands down jump pdef quit source up \n", |
| 451 | + "args condition enable l pdoc r step w \n", |
| 452 | + "b cont exit list pfile restart tbreak whatis\n", |
| 453 | + "break continue h ll pinfo return u where \n", |
| 454 | + "bt d help longlist pinfo2 retval unalias \n", |
| 455 | + "c debug ignore n pp run undisplay\n", |
| 456 | + "\n", |
| 457 | + "Miscellaneous help topics:\n", |
| 458 | + "==========================\n", |
| 459 | + "exec pdb\n", |
| 460 | + "\n", |
| 461 | + "ipdb> l\n", |
| 462 | + "\u001b[1;32m 8 \u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 463 | + "\u001b[1;32m 9 \u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 464 | + "\u001b[1;32m 10 \u001b[0m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 465 | + "\u001b[1;32m 11 \u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 466 | + "\u001b[1;32m 12 \u001b[0m \u001b[0;31m# h = h + self.i_h(x[:,i])\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 467 | + "\u001b[0;32m---> 13 \u001b[0;31m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mh\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 468 | + "\u001b[0m\u001b[1;32m 14 \u001b[0m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mh_h\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 469 | + "\u001b[1;32m 15 \u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mh_o\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 470 | + "\n", |
| 471 | + "ipdb> !h.shape\n", |
| 472 | + "torch.Size([100, 128])\n", |
| 473 | + "ipdb> x.shape\n", |
| 474 | + "torch.Size([100, 28, 28])\n", |
| 475 | + "ipdb> q\n" |
| 476 | + ], |
| 477 | + "name": "stdout" |
| 478 | + } |
| 479 | + ] |
229 | 480 | },
|
230 | 481 | {
|
231 | 482 | "cell_type": "code",
|
|
0 commit comments