@@ -21,6 +21,7 @@ def __init__(
       hidden_size,
       rnn_type='lstm',
       num_layers=1,
+      num_hidden_layers=2,
       bias=True,
       batch_first=True,
       dropout=0,
@@ -41,6 +42,7 @@ def __init__(
     self.hidden_size = hidden_size
     self.rnn_type = rnn_type
     self.num_layers = num_layers
+    self.num_hidden_layers = num_hidden_layers
     self.bias = bias
     self.batch_first = batch_first
     self.dropout = dropout
@@ -57,25 +59,34 @@ def __init__(
     self.w = self.cell_size
     self.r = self.read_heads

-    # input size of layer 0
-    self.layer0_input_size = self.r * self.w + self.input_size
-    # input size of subsequent layers
-    self.layern_input_size = self.r * self.w + self.hidden_size
+    # input size
+    self.nn_input_size = self.r * self.w + self.input_size
+    self.nn_output_size = self.r * self.w + self.hidden_size

     self.interface_size = (self.w * self.r) + (3 * self.w) + (5 * self.r) + 3
     self.output_size = self.hidden_size

-    self.rnns = []
+    self.rnns = [[None] * self.num_hidden_layers] * self.num_layers
     self.memories = []

     for layer in range(self.num_layers):
       # controllers for each layer
-      if self.rnn_type.lower() == 'rnn':
-        self.rnns.append(nn.RNNCell(self.layer0_input_size, self.output_size, bias=self.bias, nonlinearity=self.nonlinearity))
-      elif self.rnn_type.lower() == 'gru':
-        self.rnns.append(nn.GRUCell(self.layer0_input_size, self.output_size, bias=self.bias))
-      elif self.rnn_type.lower() == 'lstm':
-        self.rnns.append(nn.LSTMCell(self.layer0_input_size, self.output_size, bias=self.bias))
+      for hlayer in range(self.num_hidden_layers):
+        if self.rnn_type.lower() == 'rnn':
+          if hlayer == 0:
+            self.rnns[layer][hlayer] = nn.RNNCell(self.nn_input_size, self.output_size, bias=self.bias, nonlinearity=self.nonlinearity)
+          else:
+            self.rnns[layer][hlayer] = nn.RNNCell(self.output_size, self.output_size, bias=self.bias, nonlinearity=self.nonlinearity)
+        elif self.rnn_type.lower() == 'gru':
+          if hlayer == 0:
+            self.rnns[layer][hlayer] = nn.GRUCell(self.nn_input_size, self.output_size, bias=self.bias)
+          else:
+            self.rnns[layer][hlayer] = nn.GRUCell(self.output_size, self.output_size, bias=self.bias)
+        elif self.rnn_type.lower() == 'lstm':
+          if hlayer == 0:
+            self.rnns[layer][hlayer] = nn.LSTMCell(self.nn_input_size, self.output_size, bias=self.bias)
+          else:
+            self.rnns[layer][hlayer] = nn.LSTMCell(self.output_size, self.output_size, bias=self.bias)

       # memories for each layer
       if not self.share_memory:
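For reference, a minimal sketch (not part of the commit; all sizes are made up for illustration) of the nested controller grid this loop fills in: one cell per (layer, hidden layer), where only hidden layer 0 sees the read-vector-augmented input. A nested list comprehension keeps each layer's inner list independent, whereas [[None] * num_hidden_layers] * num_layers repeats references to a single inner list.

import torch.nn as nn

# Hypothetical sizes, for illustration only.
num_layers, num_hidden_layers = 2, 2
nn_input_size, output_size = 72, 64

# rnns[layer][hlayer]: hidden layer 0 reads the (input + read vectors) slab,
# deeper hidden layers read the previous cell's hidden output.
rnns = [
    [nn.LSTMCell(nn_input_size if h == 0 else output_size, output_size)
     for h in range(num_hidden_layers)]
    for _ in range(num_layers)
]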
@@ -104,19 +115,20 @@ def __init__(
       )

     for layer in range(self.num_layers):
-      setattr(self, 'rnn_layer_' + str(layer), self.rnns[layer])
+      for hlayer in range(self.num_hidden_layers):
+        setattr(self, 'rnn_layer_' + str(layer) + '_' + str(hlayer), self.rnns[layer][hlayer])
       if not self.share_memory:
         setattr(self, 'rnn_layer_memory_' + str(layer), self.memories[layer])
     if self.share_memory:
       setattr(self, 'rnn_layer_memory_shared', self.memories[0])

     # final output layer
     self.output_weights = nn.Linear(self.output_size, self.output_size)
-    self.mem_out = nn.Linear(self.layern_input_size, self.input_size)
+    self.mem_out = nn.Linear(self.nn_output_size, self.input_size)
     self.dropout_layer = nn.Dropout(self.dropout)

     if self.gpu_id != -1:
-      [x.cuda(self.gpu_id) for x in self.rnns]
+      [x.cuda(self.gpu_id) for y in self.rnns for x in y]
       [x.cuda(self.gpu_id) for x in self.memories]
       self.mem_out.cuda(self.gpu_id)
@@ -128,9 +140,11 @@ def _init_hidden(self, hx, batch_size, reset_experience):

     # initialize hidden state of the controller RNN
     if chx is None:
-      chx = cuda(T.zeros(self.num_layers, batch_size, self.output_size), gpu_id=self.gpu_id)
+      chx = cuda(T.zeros(batch_size, self.output_size), gpu_id=self.gpu_id)
       if self.rnn_type.lower() == 'lstm':
-        chx = (chx, chx)
+        chx = [[(chx.clone(), chx.clone()) for h in range(self.num_hidden_layers)] for l in range(self.num_layers)]
+      else:
+        chx = [[chx.clone() for h in range(self.num_hidden_layers)] for l in range(self.num_layers)]

     # Last read vectors
     if last_read is None:
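A rough sketch (assuming torch imported as T and toy sizes) of the hidden-state layout this change produces: instead of one stacked tensor, the controller state becomes a num_layers x num_hidden_layers nest, holding (h, c) pairs for LSTM cells and single tensors otherwise.

import torch as T

batch_size, output_size = 4, 64        # toy sizes
num_layers, num_hidden_layers = 2, 2

h0 = T.zeros(batch_size, output_size)

# LSTM: an (h, c) pair per (layer, hidden layer) position.
chx_lstm = [[(h0.clone(), h0.clone()) for _ in range(num_hidden_layers)]
            for _ in range(num_layers)]

# GRU / vanilla RNN: a single tensor per position.
chx_rnn = [[h0.clone() for _ in range(num_hidden_layers)]
           for _ in range(num_layers)]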
@@ -158,12 +172,19 @@ def _layer_forward(self, input, layer, hx=(None, None)):

     for time in range(max_length):
       # pass through controller
-      # print('input[time]', input[time].size(), self.layer0_input_size, self.layern_input_size)
-      chx = self.rnns[layer](input[time], chx)
+      layer_input = input[time]
+      hchx = []
+
+      for hlayer in range(self.num_hidden_layers):
+        h = self.rnns[layer][hlayer](layer_input, chx[hlayer])
+        layer_input = h[0] if self.rnn_type.lower() == 'lstm' else h
+        hchx.append(h)
+      chx = hchx
+
       # the interface vector
-      ξ = chx[0] if self.rnn_type.lower() == 'lstm' else chx
+      ξ = layer_input
       # the output
-      out = self.output_weights(chx[0]) if self.rnn_type.lower() == 'lstm' else self.output_weights(chx)
+      out = self.output_weights(layer_input)

       # pass through memory
       if self.share_memory:
@@ -205,10 +226,9 @@ def forward(self, input, hx=(None, None, None), reset_experience=False):
     # outs = [input[:, x, :] for x in range(max_length)]
     outs = [T.cat([input[:, x, :], last_read], 1) for x in range(max_length)]

-    # chx = [x[0] for x in controller_hidden] if self.rnn_type.lower() == 'lstm' else controller_hidden[0]
     for layer in range(self.num_layers):
       # this layer's hidden states
-      chx = [x[layer] for x in controller_hidden] if self.rnn_type.lower() == 'lstm' else controller_hidden[layer]
+      chx = controller_hidden[layer]

       m = mem_hidden if self.share_memory else mem_hidden[layer]
       # pass through controller
@@ -240,21 +260,13 @@ def forward(self, input, hx=(None, None, None), reset_experience=False):
     if self.debug:
       viz = T.cat(viz, 0).transpose(0, 1)

-    # final hidden values
-    if self.rnn_type.lower() == 'lstm':
-      h = T.stack([x[0] for x in chxs], 0)
-      c = T.stack([x[1] for x in chxs], 0)
-      controller_hidden = (h, c)
-    else:
-      controller_hidden = T.stack(chxs, 0)
+    controller_hidden = chxs

     if not self.batch_first:
       outputs = outputs.transpose(0, 1)
     if is_packed:
       outputs = pack(output, lengths)

-    # apply_dict(locals())
-
     if self.debug:
       return outputs, (controller_hidden, mem_hidden, read_vectors[-1]), viz
     else:
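A hypothetical usage sketch showing where the new argument slots in; the class name DNC and the remaining constructor defaults are assumed from the surrounding module rather than shown in this diff.

# Sketch only: DNC is the assumed enclosing class of the __init__ above.
rnn = DNC(
    input_size=64,
    hidden_size=128,
    rnn_type='lstm',
    num_layers=2,
    num_hidden_layers=2,   # new: number of stacked cells inside each DNC layer
    batch_first=True,
)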