Skip to content

Commit 89709ba

Browse files
committed
Created using Colaboratory
1 parent 73be7e6 commit 89709ba

File tree

1 file changed

+272
-21
lines changed

1 file changed

+272
-21
lines changed

deep_learning/lstm.ipynb

+272-21
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,95 @@
8282
"execution_count": 0,
8383
"outputs": []
8484
},
85+
{
86+
"cell_type": "code",
87+
"metadata": {
88+
"id": "R-pmd345UT1M",
89+
"colab_type": "code",
90+
"colab": {
91+
"base_uri": "https://localhost:8080/",
92+
"height": 1000
93+
},
94+
"outputId": "8318fa7e-3589-4eaf-862f-c0771ff285d8"
95+
},
96+
"source": [
97+
"next(iter(train_loader))"
98+
],
99+
"execution_count": 3,
100+
"outputs": [
101+
{
102+
"output_type": "execute_result",
103+
"data": {
104+
"text/plain": [
105+
"[tensor([[[[0., 0., 0., ..., 0., 0., 0.],\n",
106+
" [0., 0., 0., ..., 0., 0., 0.],\n",
107+
" [0., 0., 0., ..., 0., 0., 0.],\n",
108+
" ...,\n",
109+
" [0., 0., 0., ..., 0., 0., 0.],\n",
110+
" [0., 0., 0., ..., 0., 0., 0.],\n",
111+
" [0., 0., 0., ..., 0., 0., 0.]]],\n",
112+
" \n",
113+
" \n",
114+
" [[[0., 0., 0., ..., 0., 0., 0.],\n",
115+
" [0., 0., 0., ..., 0., 0., 0.],\n",
116+
" [0., 0., 0., ..., 0., 0., 0.],\n",
117+
" ...,\n",
118+
" [0., 0., 0., ..., 0., 0., 0.],\n",
119+
" [0., 0., 0., ..., 0., 0., 0.],\n",
120+
" [0., 0., 0., ..., 0., 0., 0.]]],\n",
121+
" \n",
122+
" \n",
123+
" [[[0., 0., 0., ..., 0., 0., 0.],\n",
124+
" [0., 0., 0., ..., 0., 0., 0.],\n",
125+
" [0., 0., 0., ..., 0., 0., 0.],\n",
126+
" ...,\n",
127+
" [0., 0., 0., ..., 0., 0., 0.],\n",
128+
" [0., 0., 0., ..., 0., 0., 0.],\n",
129+
" [0., 0., 0., ..., 0., 0., 0.]]],\n",
130+
" \n",
131+
" \n",
132+
" ...,\n",
133+
" \n",
134+
" \n",
135+
" [[[0., 0., 0., ..., 0., 0., 0.],\n",
136+
" [0., 0., 0., ..., 0., 0., 0.],\n",
137+
" [0., 0., 0., ..., 0., 0., 0.],\n",
138+
" ...,\n",
139+
" [0., 0., 0., ..., 0., 0., 0.],\n",
140+
" [0., 0., 0., ..., 0., 0., 0.],\n",
141+
" [0., 0., 0., ..., 0., 0., 0.]]],\n",
142+
" \n",
143+
" \n",
144+
" [[[0., 0., 0., ..., 0., 0., 0.],\n",
145+
" [0., 0., 0., ..., 0., 0., 0.],\n",
146+
" [0., 0., 0., ..., 0., 0., 0.],\n",
147+
" ...,\n",
148+
" [0., 0., 0., ..., 0., 0., 0.],\n",
149+
" [0., 0., 0., ..., 0., 0., 0.],\n",
150+
" [0., 0., 0., ..., 0., 0., 0.]]],\n",
151+
" \n",
152+
" \n",
153+
" [[[0., 0., 0., ..., 0., 0., 0.],\n",
154+
" [0., 0., 0., ..., 0., 0., 0.],\n",
155+
" [0., 0., 0., ..., 0., 0., 0.],\n",
156+
" ...,\n",
157+
" [0., 0., 0., ..., 0., 0., 0.],\n",
158+
" [0., 0., 0., ..., 0., 0., 0.],\n",
159+
" [0., 0., 0., ..., 0., 0., 0.]]]]),\n",
160+
" tensor([2, 9, 4, 1, 5, 9, 3, 3, 1, 6, 5, 3, 7, 7, 1, 3, 4, 2, 1, 7, 4, 5, 6, 0,\n",
161+
" 2, 1, 4, 1, 4, 3, 7, 7, 7, 7, 4, 5, 6, 5, 5, 3, 6, 7, 3, 1, 9, 8, 1, 3,\n",
162+
" 8, 7, 7, 6, 0, 7, 9, 9, 7, 3, 5, 3, 3, 2, 3, 2, 6, 8, 9, 6, 2, 0, 3, 7,\n",
163+
" 4, 5, 7, 4, 6, 8, 1, 3, 7, 8, 0, 6, 0, 6, 1, 7, 0, 3, 5, 3, 3, 6, 7, 1,\n",
164+
" 5, 7, 0, 0])]"
165+
]
166+
},
167+
"metadata": {
168+
"tags": []
169+
},
170+
"execution_count": 3
171+
}
172+
]
173+
},
85174
{
86175
"cell_type": "code",
87176
"metadata": {
@@ -116,6 +205,115 @@
116205
"execution_count": 0,
117206
"outputs": []
118207
},
208+
{
209+
"cell_type": "code",
210+
"metadata": {
211+
"id": "0i7c0JK1Vpgo",
212+
"colab_type": "code",
213+
"colab": {}
214+
},
215+
"source": [
216+
"from IPython.core.debugger import set_trace"
217+
],
218+
"execution_count": 0,
219+
"outputs": []
220+
},
221+
{
222+
"cell_type": "code",
223+
"metadata": {
224+
"id": "zdmCH-uSer4Z",
225+
"colab_type": "code",
226+
"colab": {}
227+
},
228+
"source": [
229+
class Model1(nn.Module):
    """Hand-rolled recurrent classifier that consumes one feature row per time step.

    At every step the current input row is projected to the hidden width
    (green arrow), added to the running hidden state, passed through the
    hidden-to-hidden layer (brown arrow) with ReLU and batch norm, and the
    final hidden state is mapped to class scores (blue arrow).

    The input projection is what allows the hidden width to differ from the
    per-step feature width. Without it, ``h + x[:, i]`` only broadcasts when
    both widths are equal (a ``[100, 128]`` state cannot be added to a
    ``[100, 28]`` input row).

    Args:
        input_size: Number of features per time step. Defaults to the
            notebook-level ``input_size``.
        hidden_size: Width of the hidden state. Defaults to the
            notebook-level ``nh``.
        n_classes: Number of output classes. Defaults to the notebook-level
            ``num_classes``.
    """

    def __init__(self, input_size=None, hidden_size=None, n_classes=None):
        super().__init__()
        # Resolve sizes lazily so that a bare `Model1()` keeps picking up the
        # notebook globals, while explicit arguments need no globals at all.
        # NOTE(review): the `input_size` global is the one used by the training
        # loop's reshape; confirm it is defined before the model is built.
        n_in = globals()["input_size"] if input_size is None else input_size
        self.nh = nh if hidden_size is None else hidden_size
        n_out = num_classes if n_classes is None else n_classes

        self.i_h = nn.Linear(n_in, self.nh)     # green arrow
        self.h_h = nn.Linear(self.nh, self.nh)  # brown arrow
        self.h_o = nn.Linear(self.nh, n_out)    # blue arrow
        self.bn = nn.BatchNorm1d(self.nh)

    def forward(self, x):
        """Return class scores for a batch of sequences.

        Args:
            x: Tensor indexed as ``(batch, step, feature)``; the training loop
                feeds ``(batch, sequence_length, input_size)``.

        Returns:
            Tensor of shape ``(batch, n_classes)`` computed from the hidden
            state after the last step.
        """
        # Allocate the initial state directly on the input's device instead of
        # creating it on the CPU and copying it over.
        h = torch.zeros(x.shape[0], self.nh, device=x.device)
        for i in range(x.shape[1]):
            # Project the step input to the hidden width before accumulating.
            h = h + self.i_h(x[:, i])
            h = self.bn(F.relu(self.h_h(h)))
        return self.h_o(h)
244+
],
245+
"execution_count": 0,
246+
"outputs": []
247+
},
248+
{
249+
"cell_type": "code",
250+
"metadata": {
251+
"id": "Z6k7kCc9kPIP",
252+
"colab_type": "code",
253+
"colab": {
254+
"base_uri": "https://localhost:8080/",
255+
"height": 34
256+
},
257+
"outputId": "9d114180-e69a-4f41-a59b-70602c53b389"
258+
},
259+
"source": [
260+
"nh"
261+
],
262+
"execution_count": 18,
263+
"outputs": [
264+
{
265+
"output_type": "execute_result",
266+
"data": {
267+
"text/plain": [
268+
"128"
269+
]
270+
},
271+
"metadata": {
272+
"tags": []
273+
},
274+
"execution_count": 18
275+
}
276+
]
277+
},
278+
{
279+
"cell_type": "code",
280+
"metadata": {
281+
"id": "nCSiQLeYh9so",
282+
"colab_type": "code",
283+
"colab": {}
284+
},
285+
"source": [
286+
"nh = 28"
287+
],
288+
"execution_count": 0,
289+
"outputs": []
290+
},
291+
{
292+
"cell_type": "code",
293+
"metadata": {
294+
"id": "CMYxcSDzkaFP",
295+
"colab_type": "code",
296+
"colab": {}
297+
},
298+
"source": [
299+
"import torch.nn.functional as F"
300+
],
301+
"execution_count": 0,
302+
"outputs": []
303+
},
304+
{
305+
"cell_type": "code",
306+
"metadata": {
307+
"id": "R6cCAVPvewKm",
308+
"colab_type": "code",
309+
"colab": {}
310+
},
311+
"source": [
312+
"model = Model1().to(device)"
313+
],
314+
"execution_count": 0,
315+
"outputs": []
316+
},
119317
{
120318
"cell_type": "code",
121319
"metadata": {
@@ -125,7 +323,7 @@
125323
"base_uri": "https://localhost:8080/",
126324
"height": 218
127325
},
128-
"outputId": "5e6bf8ff-34ff-4d39-8e93-c2741f1cf6f8"
326+
"outputId": "5aa78f4d-4cf2-4845-cec7-e148a356916f"
129327
},
130328
"source": [
131329
"# Loss and optimizer\n",
@@ -136,6 +334,7 @@
136334
"total_step = len(train_loader)\n",
137335
"for epoch in range(num_epochs):\n",
138336
" for i, (images, labels) in enumerate(train_loader):\n",
337+
" # set_trace()\n",
139338
" images = images.reshape(-1, sequence_length, input_size).to(device)\n",
140339
" labels = labels.to(device)\n",
141340
" \n",
@@ -152,23 +351,23 @@
152351
" print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' \n",
153352
" .format(epoch+1, num_epochs, i+1, total_step, loss.item()))"
154353
],
155-
"execution_count": 22,
354+
"execution_count": 24,
156355
"outputs": [
157356
{
158357
"output_type": "stream",
159358
"text": [
160-
"Epoch [1/2], Step [100/600], Loss: 0.4692\n",
161-
"Epoch [1/2], Step [200/600], Loss: 0.2797\n",
162-
"Epoch [1/2], Step [300/600], Loss: 0.1271\n",
163-
"Epoch [1/2], Step [400/600], Loss: 0.2750\n",
164-
"Epoch [1/2], Step [500/600], Loss: 0.1792\n",
165-
"Epoch [1/2], Step [600/600], Loss: 0.0991\n",
166-
"Epoch [2/2], Step [100/600], Loss: 0.0826\n",
167-
"Epoch [2/2], Step [200/600], Loss: 0.1674\n",
168-
"Epoch [2/2], Step [300/600], Loss: 0.1562\n",
169-
"Epoch [2/2], Step [400/600], Loss: 0.1447\n",
170-
"Epoch [2/2], Step [500/600], Loss: 0.0842\n",
171-
"Epoch [2/2], Step [600/600], Loss: 0.0283\n"
359+
"Epoch [1/2], Step [100/600], Loss: 1.7748\n",
360+
"Epoch [1/2], Step [200/600], Loss: 1.5532\n",
361+
"Epoch [1/2], Step [300/600], Loss: 1.2281\n",
362+
"Epoch [1/2], Step [400/600], Loss: 1.3798\n",
363+
"Epoch [1/2], Step [500/600], Loss: 0.7575\n",
364+
"Epoch [1/2], Step [600/600], Loss: 1.1237\n",
365+
"Epoch [2/2], Step [100/600], Loss: 0.8953\n",
366+
"Epoch [2/2], Step [200/600], Loss: 0.8352\n",
367+
"Epoch [2/2], Step [300/600], Loss: 0.9067\n",
368+
"Epoch [2/2], Step [400/600], Loss: 0.8490\n",
369+
"Epoch [2/2], Step [500/600], Loss: 0.8371\n",
370+
"Epoch [2/2], Step [600/600], Loss: 1.0092\n"
172371
],
173372
"name": "stdout"
174373
}
@@ -183,7 +382,7 @@
183382
"base_uri": "https://localhost:8080/",
184383
"height": 34
185384
},
186-
"outputId": "84f5134c-0b9e-42e3-fcaf-42f528a3f363"
385+
"outputId": "3940a25a-3267-4854-cdc5-3584f07d8143"
187386
},
188387
"source": [
189388
"# Test the model\n",
@@ -203,12 +402,12 @@
203402
"# Save the model checkpoint\n",
204403
"# torch.save(model.state_dict(), 'model.ckpt')"
205404
],
206-
"execution_count": 23,
405+
"execution_count": 25,
207406
"outputs": [
208407
{
209408
"output_type": "stream",
210409
"text": [
211-
"Test Accuracy of the model on the 10000 test images: 97.55 %\n"
410+
"Test Accuracy of the model on the 10000 test images: 71.49 %\n"
212411
],
213412
"name": "stdout"
214413
}
@@ -219,13 +418,65 @@
219418
"metadata": {
220419
"id": "cs2NL18qEaz2",
221420
"colab_type": "code",
222-
"colab": {}
421+
"colab": {
422+
"base_uri": "https://localhost:8080/",
423+
"height": 706
424+
},
425+
"outputId": "b683fa26-cbf8-43ec-838b-45a0bb7ef854"
223426
},
224427
"source": [
225-
""
428+
"%debug"
226429
],
227-
"execution_count": 0,
228-
"outputs": []
430+
"execution_count": 17,
431+
"outputs": [
432+
{
433+
"output_type": "stream",
434+
"text": [
435+
"> \u001b[0;32m<ipython-input-12-8318d738d92a>\u001b[0m(13)\u001b[0;36mforward\u001b[0;34m()\u001b[0m\n",
436+
"\u001b[0;32m 11 \u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
437+
"\u001b[0m\u001b[0;32m 12 \u001b[0;31m \u001b[0;31m# h = h + self.i_h(x[:,i])\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
438+
"\u001b[0m\u001b[0;32m---> 13 \u001b[0;31m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mh\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
439+
"\u001b[0m\u001b[0;32m 14 \u001b[0;31m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mh_h\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
440+
"\u001b[0m\u001b[0;32m 15 \u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mh_o\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
441+
"\u001b[0m\n",
442+
"ipdb> h.shape\n",
443+
"*** No help for '.shape'\n",
444+
"ipdb> h\n",
445+
"\n",
446+
"Documented commands (type help <topic>):\n",
447+
"========================================\n",
448+
"EOF cl disable interact next psource rv unt \n",
449+
"a clear display j p q s until \n",
450+
"alias commands down jump pdef quit source up \n",
451+
"args condition enable l pdoc r step w \n",
452+
"b cont exit list pfile restart tbreak whatis\n",
453+
"break continue h ll pinfo return u where \n",
454+
"bt d help longlist pinfo2 retval unalias \n",
455+
"c debug ignore n pp run undisplay\n",
456+
"\n",
457+
"Miscellaneous help topics:\n",
458+
"==========================\n",
459+
"exec pdb\n",
460+
"\n",
461+
"ipdb> l\n",
462+
"\u001b[1;32m 8 \u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
463+
"\u001b[1;32m 9 \u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
464+
"\u001b[1;32m 10 \u001b[0m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
465+
"\u001b[1;32m 11 \u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
466+
"\u001b[1;32m 12 \u001b[0m \u001b[0;31m# h = h + self.i_h(x[:,i])\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
467+
"\u001b[0;32m---> 13 \u001b[0;31m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mh\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
468+
"\u001b[0m\u001b[1;32m 14 \u001b[0m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mh_h\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
469+
"\u001b[1;32m 15 \u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mh_o\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
470+
"\n",
471+
"ipdb> !h.shape\n",
472+
"torch.Size([100, 128])\n",
473+
"ipdb> x.shape\n",
474+
"torch.Size([100, 28, 28])\n",
475+
"ipdb> q\n"
476+
],
477+
"name": "stdout"
478+
}
479+
]
229480
},
230481
{
231482
"cell_type": "code",

0 commit comments

Comments
 (0)