
Commit 9ebdb9c

Author: Russi Chatterjee
Merge pull request #9 from ixaxaar/hidden_layers
Implement Hidden layers, small enhancements, cleanups
2 parents fc863a9 + 522a810 commit 9ebdb9c

8 files changed: +499 -37 lines

.travis.yml (+11)

@@ -0,0 +1,11 @@
+language: python
+python:
+- "3.6"
+# command to install dependencies
+install:
+- pip install http://download.pytorch.org/whl/cu75/torch-0.2.0.post3-cp36-cp36m-manylinux1_x86_64.whl
+- pip install numpy
+- pip install visdom
+# command to run tests
+script:
+- pytest

dnc/dnc.py (+43 -31)

@@ -21,6 +21,7 @@ def __init__(
       hidden_size,
       rnn_type='lstm',
       num_layers=1,
+      num_hidden_layers=2,
       bias=True,
       batch_first=True,
       dropout=0,
@@ -41,6 +42,7 @@ def __init__(
     self.hidden_size = hidden_size
     self.rnn_type = rnn_type
     self.num_layers = num_layers
+    self.num_hidden_layers = num_hidden_layers
     self.bias = bias
     self.batch_first = batch_first
     self.dropout = dropout
@@ -57,25 +59,34 @@ def __init__(
     self.w = self.cell_size
     self.r = self.read_heads
 
-    # input size of layer 0
-    self.layer0_input_size = self.r * self.w + self.input_size
-    # input size of subsequent layers
-    self.layern_input_size = self.r * self.w + self.hidden_size
+    # input size
+    self.nn_input_size = self.r * self.w + self.input_size
+    self.nn_output_size = self.r * self.w + self.hidden_size
 
     self.interface_size = (self.w * self.r) + (3 * self.w) + (5 * self.r) + 3
     self.output_size = self.hidden_size
 
-    self.rnns = []
+    self.rnns = [[None] * self.num_hidden_layers] * self.num_layers
     self.memories = []
 
     for layer in range(self.num_layers):
       # controllers for each layer
-      if self.rnn_type.lower() == 'rnn':
-        self.rnns.append(nn.RNNCell(self.layer0_input_size, self.output_size, bias=self.bias, nonlinearity=self.nonlinearity))
-      elif self.rnn_type.lower() == 'gru':
-        self.rnns.append(nn.GRUCell(self.layer0_input_size, self.output_size, bias=self.bias))
-      elif self.rnn_type.lower() == 'lstm':
-        self.rnns.append(nn.LSTMCell(self.layer0_input_size, self.output_size, bias=self.bias))
+      for hlayer in range(self.num_hidden_layers):
+        if self.rnn_type.lower() == 'rnn':
+          if hlayer == 0:
+            self.rnns[layer][hlayer] = nn.RNNCell(self.nn_input_size, self.output_size, bias=self.bias, nonlinearity=self.nonlinearity)
+          else:
+            self.rnns[layer][hlayer] = nn.RNNCell(self.output_size, self.output_size, bias=self.bias, nonlinearity=self.nonlinearity)
+        elif self.rnn_type.lower() == 'gru':
+          if hlayer == 0:
+            self.rnns[layer][hlayer] = nn.GRUCell(self.nn_input_size, self.output_size, bias=self.bias)
+          else:
+            self.rnns[layer][hlayer] = nn.GRUCell(self.output_size, self.output_size, bias=self.bias)
+        elif self.rnn_type.lower() == 'lstm':
+          if hlayer == 0:
+            self.rnns[layer][hlayer] = nn.LSTMCell(self.nn_input_size, self.output_size, bias=self.bias)
+          else:
+            self.rnns[layer][hlayer] = nn.LSTMCell(self.output_size, self.output_size, bias=self.bias)
 
       # memories for each layer
       if not self.share_memory:
@@ -104,19 +115,20 @@ def __init__(
       )
 
     for layer in range(self.num_layers):
-      setattr(self, 'rnn_layer_' + str(layer), self.rnns[layer])
+      for hlayer in range(self.num_hidden_layers):
+        setattr(self, 'rnn_layer_' + str(layer) + '_' + str(hlayer), self.rnns[layer][hlayer])
       if not self.share_memory:
         setattr(self, 'rnn_layer_memory_' + str(layer), self.memories[layer])
       if self.share_memory:
         setattr(self, 'rnn_layer_memory_shared', self.memories[0])
 
     # final output layer
     self.output_weights = nn.Linear(self.output_size, self.output_size)
-    self.mem_out = nn.Linear(self.layern_input_size, self.input_size)
+    self.mem_out = nn.Linear(self.nn_output_size, self.input_size)
     self.dropout_layer = nn.Dropout(self.dropout)
 
     if self.gpu_id != -1:
-      [x.cuda(self.gpu_id) for x in self.rnns]
+      [x.cuda(self.gpu_id) for y in self.rnns for x in y]
       [x.cuda(self.gpu_id) for x in self.memories]
       self.mem_out.cuda(self.gpu_id)
 
@@ -128,9 +140,11 @@ def _init_hidden(self, hx, batch_size, reset_experience):
 
     # initialize hidden state of the controller RNN
     if chx is None:
-      chx = cuda(T.zeros(self.num_layers, batch_size, self.output_size), gpu_id=self.gpu_id)
+      chx = cuda(T.zeros(batch_size, self.output_size), gpu_id=self.gpu_id)
       if self.rnn_type.lower() == 'lstm':
-        chx = (chx, chx)
+        chx = [ [ (chx.clone(), chx.clone()) for h in range(self.num_hidden_layers) ] for l in range(self.num_layers) ]
+      else:
+        chx = [ [ chx.clone() for h in range(self.num_hidden_layers) ] for l in range(self.num_layers) ]
 
     # Last read vectors
     if last_read is None:
@@ -158,12 +172,19 @@ def _layer_forward(self, input, layer, hx=(None, None)):
 
     for time in range(max_length):
       # pass through controller
-      # print('input[time]', input[time].size(), self.layer0_input_size, self.layern_input_size)
-      chx = self.rnns[layer](input[time], chx)
+      layer_input = input[time]
+      hchx = []
+
+      for hlayer in range(self.num_hidden_layers):
+        h = self.rnns[layer][hlayer](layer_input, chx[hlayer])
+        layer_input = h[0] if self.rnn_type.lower() == 'lstm' else h
+        hchx.append(h)
+      chx = hchx
+
       # the interface vector
-      ξ = chx[0] if self.rnn_type.lower() == 'lstm' else chx
+      ξ = layer_input
       # the output
-      out = self.output_weights(chx[0]) if self.rnn_type.lower() == 'lstm' else self.output_weights(chx)
+      out = self.output_weights(layer_input)
 
       # pass through memory
       if self.share_memory:
@@ -205,10 +226,9 @@ def forward(self, input, hx=(None, None, None), reset_experience=False):
     # outs = [input[:, x, :] for x in range(max_length)]
     outs = [T.cat([input[:, x, :], last_read], 1) for x in range(max_length)]
 
-    # chx = [x[0] for x in controller_hidden] if self.rnn_type.lower() == 'lstm' else controller_hidden[0]
     for layer in range(self.num_layers):
       # this layer's hidden states
-      chx = [x[layer] for x in controller_hidden] if self.rnn_type.lower() == 'lstm' else controller_hidden[layer]
+      chx = controller_hidden[layer]
 
       m = mem_hidden if self.share_memory else mem_hidden[layer]
       # pass through controller
@@ -240,21 +260,13 @@ def forward(self, input, hx=(None, None, None), reset_experience=False):
     if self.debug:
      viz = T.cat(viz, 0).transpose(0, 1)
 
-    # final hidden values
-    if self.rnn_type.lower() == 'lstm':
-      h = T.stack([x[0] for x in chxs], 0)
-      c = T.stack([x[1] for x in chxs], 0)
-      controller_hidden = (h, c)
-    else:
-      controller_hidden = T.stack(chxs, 0)
+    controller_hidden = chxs
 
     if not self.batch_first:
       outputs = outputs.transpose(0, 1)
     if is_packed:
       outputs = pack(output, lengths)
 
-    # apply_dict(locals())
-
     if self.debug:
       return outputs, (controller_hidden, mem_hidden, read_vectors[-1]), viz
     else:
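
With this change, each of the num_layers DNC layers now runs a stack of num_hidden_layers controller cells: only the first cell in a stack receives the memory-augmented input (the read vectors concatenated with the layer input), while the remaining cells operate on the previous cell's output. The sketch below shows how the new constructor argument is used; it mirrors the defaults in tasks/copy_task.py further down, and the memory sizes and input shape are assumptions chosen for illustration, not values taken from this diff.

# Sketch only: instantiating the DNC with the new num_hidden_layers argument.
# Memory sizes and the input tensor shape below are assumed for illustration.
import torch as T
from torch.autograd import Variable as var
from dnc.dnc import DNC

rnn = DNC(
    input_size=6,           # feature dimension per timestep
    hidden_size=64,         # controller hidden units
    rnn_type='lstm',        # 'rnn', 'gru' or 'lstm'
    num_layers=2,           # DNC layers, each with its own (or shared) memory
    num_hidden_layers=2,    # new: stacked controller cells inside each DNC layer
    nr_cells=10,            # assumed number of memory slots
    cell_size=8,            # assumed memory cell width
    read_heads=1,
    gpu_id=-1,
    debug=True,
)

# batch_first input of shape (batch, time, input_size)
x = var(T.randn(4, 10, 6))
output, (controller_hidden, mem_hidden, last_read), viz = rnn(x, (None, None, None))

# controller_hidden is now a nested list indexed as [layer][hidden_layer]
print(len(controller_hidden), len(controller_hidden[0]))   # 2 2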

dnc/memory.py (+3 -3)

@@ -50,11 +50,11 @@ def reset(self, batch_size=1, hidden=None, erase=True):
 
     if hidden is None:
       return {
-        'memory': cuda(T.zeros(b, m, w).fill_(0), gpu_id=self.gpu_id),
+        'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.gpu_id),
         'link_matrix': cuda(T.zeros(b, 1, m, m), gpu_id=self.gpu_id),
         'precedence': cuda(T.zeros(b, 1, m), gpu_id=self.gpu_id),
-        'read_weights': cuda(T.zeros(b, r, m).fill_(0), gpu_id=self.gpu_id),
-        'write_weights': cuda(T.zeros(b, 1, m).fill_(0), gpu_id=self.gpu_id),
+        'read_weights': cuda(T.zeros(b, r, m).fill_(δ), gpu_id=self.gpu_id),
+        'write_weights': cuda(T.zeros(b, 1, m).fill_(δ), gpu_id=self.gpu_id),
         'usage_vector': cuda(T.zeros(b, m), gpu_id=self.gpu_id)
       }
     else:
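
The reset state of the memory, read weights, and write weights is now filled with δ rather than 0. δ is not defined in this hunk; assuming it is a small positive constant (as the symbol suggests), the non-zero fill keeps the very first content-based lookup well behaved, since cosine similarity against an all-zero memory would divide by a zero norm. A rough sketch of that idea, with the δ value itself assumed:

# Rough sketch with an assumed δ = 1e-6 (the actual constant lives in dnc/memory.py).
# A tiny non-zero fill keeps memory norms positive, so cosine-similarity
# content addressing at the first step does not divide by zero.
import torch as T

δ = 1e-6
memory = T.zeros(2, 16, 8).fill_(δ)     # (batch, nr_cells, cell_size)
keys = T.randn(2, 1, 8)                 # one lookup key per batch item

dot = (memory * keys).sum(dim=2)                        # (batch, nr_cells)
norms = memory.norm(2, dim=2) * keys.norm(2, dim=2)     # strictly positive
similarity = dot / norms                                # finite at step 0
print(similarity.shape)                                 # torch.Size([2, 16])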

tasks/copy_task.py (+8 -3)

@@ -27,9 +27,10 @@
 parser.add_argument('-input_size', type=int, default=6, help='dimension of input feature')
 parser.add_argument('-rnn_type', type=str, default='lstm', help='type of recurrent cells to use for the controller')
 parser.add_argument('-nhid', type=int, default=64, help='number of hidden units of the inner nn')
-parser.add_argument('-dropout', type=float, default=0.3, help='controller dropout')
+parser.add_argument('-dropout', type=float, default=0, help='controller dropout')
 
 parser.add_argument('-nlayer', type=int, default=2, help='number of layers')
+parser.add_argument('-nhlayer', type=int, default=2, help='number of hidden layers')
 parser.add_argument('-lr', type=float, default=1e-2, help='initial learning rate')
 parser.add_argument('-clip', type=float, default=0.5, help='gradient clipping')
 
@@ -110,14 +111,17 @@ def criterion(predictions, targets):
   rnn = DNC(
       input_size=args.input_size,
       hidden_size=args.nhid,
-      rnn_type='lstm',
+      rnn_type=args.rnn_type,
       num_layers=args.nlayer,
+      num_hidden_layers=args.nhlayer,
+      dropout=args.dropout,
       nr_cells=mem_slot,
       cell_size=mem_size,
       read_heads=read_heads,
       gpu_id=args.cuda,
       debug=True
   )
+  print(rnn)
 
   if args.cuda != -1:
     rnn = rnn.cuda(args.cuda)
@@ -147,6 +151,7 @@ def criterion(predictions, targets):
     # apply_dict(locals())
     loss.backward()
 
+    T.nn.utils.clip_grad_norm(rnn.parameters(), args.clip)
     optimizer.step()
     loss_value = loss.data[0]
 
@@ -166,7 +171,7 @@ def criterion(predictions, targets):
           xtickstep=10,
           ytickstep=2,
          title='Timestep: ' + str(epoch) + ', loss: ' + str(loss),
-          xlabel='mem_slot * time',
+          xlabel='mem_slot * layer',
          ylabel='mem_size'
         )
       )
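
The copy-task training loop now clips the gradient norm between the backward pass and the optimizer step, using the existing -clip argument (default 0.5). Below is a minimal sketch of that ordering; rnn, input_data, target_output and criterion are assumed to be in scope, Adam is an assumed optimizer choice, and clip_grad_norm is the pre-0.4 PyTorch name (later releases rename it clip_grad_norm_).

# Sketch of the updated training step: clip after backward(), before step().
# lr and the clip value mirror the argparse defaults shown above.
import torch as T
import torch.optim as optim

optimizer = optim.Adam(rnn.parameters(), lr=1e-2)

optimizer.zero_grad()
output, (chx, mhx, rv), v = rnn(input_data, None)        # debug=True adds the viz value
loss = criterion(output.transpose(0, 1), target_output)
loss.backward()
T.nn.utils.clip_grad_norm(rnn.parameters(), 0.5)         # new clipping step
optimizer.step()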

test/test_gru.py (+134)

@@ -0,0 +1,134 @@
+#!/usr/bin/env python3
+
+import pytest
+import numpy as np
+
+import torch.nn as nn
+import torch as T
+from torch.autograd import Variable as var
+import torch.nn.functional as F
+from torch.nn.utils import clip_grad_norm
+import torch.optim as optim
+import numpy as np
+
+import sys
+import os
+import math
+import time
+sys.path.append('./src/')
+sys.path.insert(0, os.path.join('..', '..'))
+
+from dnc.dnc import DNC
+from test_utils import generate_data, criterion
+
+
+def test_rnn_1():
+  T.manual_seed(1111)
+
+  input_size = 100
+  hidden_size = 100
+  rnn_type = 'gru'
+  num_layers = 1
+  num_hidden_layers = 1
+  dropout = 0
+  nr_cells = 1
+  cell_size = 1
+  read_heads = 1
+  gpu_id = -1
+  debug = True
+  lr = 0.001
+  sequence_max_length = 10
+  batch_size = 10
+  cuda = gpu_id
+  clip = 10
+  length = 10
+
+  rnn = DNC(
+      input_size=input_size,
+      hidden_size=hidden_size,
+      rnn_type=rnn_type,
+      num_layers=num_layers,
+      num_hidden_layers=num_hidden_layers,
+      dropout=dropout,
+      nr_cells=nr_cells,
+      cell_size=cell_size,
+      read_heads=read_heads,
+      gpu_id=gpu_id,
+      debug=debug
+  )
+
+  optimizer = optim.Adam(rnn.parameters(), lr=lr)
+  optimizer.zero_grad()
+
+  input_data, target_output = generate_data(batch_size, length, input_size, cuda)
+  target_output = target_output.transpose(0, 1).contiguous()
+
+  output, (chx, mhx, rv), v = rnn(input_data, None)
+  output = output.transpose(0, 1)
+
+  loss = criterion((output), target_output)
+  loss.backward()
+
+  T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+  optimizer.step()
+
+  assert target_output.size() == T.Size([21, 10, 100])
+  assert chx[0][0].size() == T.Size([10, 100])
+  assert mhx['memory'].size() == T.Size([10, 1, 1])
+  assert rv.size() == T.Size([10, 1])
+
+
+def test_rnn_n():
+  T.manual_seed(1111)
+
+  input_size = 100
+  hidden_size = 100
+  rnn_type = 'gru'
+  num_layers = 3
+  num_hidden_layers = 5
+  dropout = 0.2
+  nr_cells = 12
+  cell_size = 17
+  read_heads = 3
+  gpu_id = -1
+  debug = True
+  lr = 0.001
+  sequence_max_length = 10
+  batch_size = 10
+  cuda = gpu_id
+  clip = 20
+  length = 13
+
+  rnn = DNC(
+      input_size=input_size,
+      hidden_size=hidden_size,
+      rnn_type=rnn_type,
+      num_layers=num_layers,
+      num_hidden_layers=num_hidden_layers,
+      dropout=dropout,
+      nr_cells=nr_cells,
+      cell_size=cell_size,
+      read_heads=read_heads,
+      gpu_id=gpu_id,
+      debug=debug
+  )
+
+  optimizer = optim.Adam(rnn.parameters(), lr=lr)
+  optimizer.zero_grad()
+
+  input_data, target_output = generate_data(batch_size, length, input_size, cuda)
+  target_output = target_output.transpose(0, 1).contiguous()
+
+  output, (chx, mhx, rv), v = rnn(input_data, None)
+  output = output.transpose(0, 1)
+
+  loss = criterion((output), target_output)
+  loss.backward()
+
+  T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+  optimizer.step()
+
+  assert target_output.size() == T.Size([27, 10, 100])
+  assert chx[1][2].size() == T.Size([10, 100])
+  assert mhx['memory'].size() == T.Size([10, 12, 17])
+  assert rv.size() == T.Size([10, 51])
