Scale interface vectors, dynamic memory pass #17

Merged: 6 commits, Nov 30, 2017

8 changes: 5 additions & 3 deletions README.md
@@ -116,9 +116,9 @@ The copy task, as described in the original paper, is included in the repo.

From the project root:
```bash
python ./tasks/copy_task.py -cuda 0 -optim rmsprop -batch_size 32 -mem_slot 64 # (original implementation)
python ./tasks/copy_task.py -cuda 0 -optim rmsprop -batch_size 32 -mem_slot 64 # (like original implementation)

python ./tasks/copy_task.py -cuda 0 -lr 0.001 -rnn_type lstm -nlayer 1 -nhlayer 2 -mem_slot 32 -batch_size 32 -optim adam # (faster convergence)
python3 ./tasks/copy_task.py -cuda 0 -lr 0.001 -rnn_type lstm -nlayer 1 -nhlayer 2 -dropout 0 -mem_slot 32 -batch_size 1000 -optim adam -sequence_max_length 8 # (faster convergence)
```

For the full set of options, see:
@@ -148,7 +148,9 @@ The visdom dashboard shows memory as a heatmap for batch 0 every `-summarize_freq`

## General noteworthy stuff

1. DNCs converge with Adam and RMSProp learning rules, SGD generally causes them to diverge.
1. DNCs converge faster with Adam and RMSProp learning rules; SGD generally converges extremely slowly.
The copy task, for example, takes 25k iterations with SGD at lr 1, compared to 3.5k for Adam at lr 0.01.
2. `nan`s in the gradients are common; try different batch sizes (one way to guard a training step is sketched below).
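
For example, a minimal training-step sketch (a hypothetical helper, not part of this repo) that pairs the optimizer with gradient-norm clipping and skips the update whenever the gradient norm comes back non-finite:

```python
import math
import torch as T

# Sketch only: `model`, `criterion`, `x` and `y` stand in for your own network,
# loss function and batch; none of these names come from this repo.
def train_step(model, criterion, x, y, optimizer, clip=10.0):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    # clip the gradients, then skip the update entirely if their norm is nan/inf
    total_norm = float(T.nn.utils.clip_grad_norm_(model.parameters(), clip))
    if math.isfinite(total_norm):
        optimizer.step()
    return loss.item()
```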

Repos referred to for creation of this repo:

12 changes: 7 additions & 5 deletions dnc/dnc.py
@@ -74,13 +74,13 @@ def __init__(
for layer in range(self.num_layers):
if self.rnn_type.lower() == 'rnn':
self.rnns.append(nn.RNN((self.nn_input_size if layer == 0 else self.nn_output_size), self.output_size,
bias=self.bias, nonlinearity=self.nonlinearity, batch_first=True, dropout=self.dropout))
bias=self.bias, nonlinearity=self.nonlinearity, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
elif self.rnn_type.lower() == 'gru':
self.rnns.append(nn.GRU((self.nn_input_size if layer == 0 else self.nn_output_size),
self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout))
self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
if self.rnn_type.lower() == 'lstm':
self.rnns.append(nn.LSTM((self.nn_input_size if layer == 0 else self.nn_output_size),
self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout))
self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
setattr(self, self.rnn_type.lower() + '_layer_' + str(layer), self.rnns[layer])

# memories for each layer
@@ -191,7 +191,7 @@ def _layer_forward(self, input, layer, hx=(None, None), pass_through_memory=True):
else:
read_vectors = None

return output, read_vectors, (chx, mhx)
return output, (chx, mhx, read_vectors)

def forward(self, input, hx=(None, None, None), reset_experience=False, pass_through_memory=True):
# handle packed data
@@ -229,7 +229,7 @@ def forward(self, input, hx=(None, None, None), reset_experience=False, pass_through_memory=True):
chx = controller_hidden[layer]
m = mem_hidden if self.share_memory else mem_hidden[layer]
# pass through controller
outs[time], read_vectors, (chx, m) = \
outs[time], (chx, m, read_vectors) = \
self._layer_forward(inputs[time], layer, (chx, m), pass_through_memory)

# debug memory
@@ -246,6 +246,8 @@ def forward(self, input, hx=(None, None, None), reset_experience=False, pass_through_memory=True):
if read_vectors is not None:
# the controller output + read vectors go into next layer
outs[time] = T.cat([outs[time], read_vectors], 1)
else:
outs[time] = T.cat([outs[time], last_read], 1)
inputs[time] = outs[time]

if self.debug:
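
A side effect of forwarding `num_layers=self.num_hidden_layers` to the wrapped `nn.RNN`/`nn.GRU`/`nn.LSTM` is that each DNC layer's controller becomes a stacked recurrent network, so its hidden state gains a leading `num_hidden_layers` dimension. A minimal standalone sketch of that wiring (sizes are illustrative, not taken from the repo):

```python
import torch as T
import torch.nn as nn

# Illustrative sizes only.
nn_input_size, output_size = 140, 100
num_hidden_layers, dropout = 2, 0.0
batch, seq_len = 10, 13

# Each DNC layer's controller is now itself a stacked LSTM.
controller = nn.LSTM(nn_input_size, output_size, num_layers=num_hidden_layers,
                     bias=True, batch_first=True, dropout=dropout)

x = T.randn(batch, seq_len, nn_input_size)
h0 = T.zeros(num_hidden_layers, batch, output_size)
c0 = T.zeros(num_hidden_layers, batch, output_size)
out, (hn, cn) = controller(x, (h0, c0))

# The hidden state now carries one slice per stacked layer, which is what the
# updated test assertions check (e.g. chx[0][0].size() in test_lstm.py).
assert hn.size() == T.Size([num_hidden_layers, batch, output_size])
```
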
38 changes: 18 additions & 20 deletions dnc/memory.py
@@ -50,11 +50,11 @@ def reset(self, batch_size=1, hidden=None, erase=True):

if hidden is None:
return {
'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.gpu_id),
'memory': cuda(T.zeros(b, m, w).fill_(0), gpu_id=self.gpu_id),
'link_matrix': cuda(T.zeros(b, 1, m, m), gpu_id=self.gpu_id),
'precedence': cuda(T.zeros(b, 1, m), gpu_id=self.gpu_id),
'read_weights': cuda(T.zeros(b, r, m).fill_(δ), gpu_id=self.gpu_id),
'write_weights': cuda(T.zeros(b, 1, m).fill_(δ), gpu_id=self.gpu_id),
'read_weights': cuda(T.zeros(b, r, m).fill_(0), gpu_id=self.gpu_id),
'write_weights': cuda(T.zeros(b, 1, m).fill_(0), gpu_id=self.gpu_id),
'usage_vector': cuda(T.zeros(b, m), gpu_id=self.gpu_id)
}
else:
@@ -66,11 +66,11 @@ def reset(self, batch_size=1, hidden=None, erase=True):
hidden['usage_vector'] = hidden['usage_vector'].clone()

if erase:
hidden['memory'].data.fill_(δ)
hidden['memory'].data.fill_(0)
hidden['link_matrix'].data.zero_()
hidden['precedence'].data.zero_()
hidden['read_weights'].data.fill_(δ)
hidden['write_weights'].data.fill_(δ)
hidden['read_weights'].data.fill_(0)
hidden['write_weights'].data.fill_(0)
hidden['usage_vector'].data.zero_()
return hidden

@@ -116,7 +116,7 @@ def get_link_matrix(self, link_matrix, write_weights, precedence):
new_link_matrix = write_weights_i * precedence

link_matrix = prev_scale * link_matrix + new_link_matrix
# elaborate trick to delete diag elems
# trick to delete diag elems
return self.I.expand_as(link_matrix) * link_matrix

def update_precedence(self, precedence, write_weights):
@@ -139,7 +139,6 @@ def write(self, write_key, write_vector, erase_vector, free_gates, read_strength
hidden['usage_vector'],
allocation_gate * write_gate
)
# print((alloc).data.cpu().numpy())

# get write weightings
hidden['write_weights'] = self.write_weighting(
@@ -170,8 +169,7 @@ def write(self, write_key, write_vector, erase_vector, free_gates, read_strength

def content_weightings(self, memory, keys, strengths):
d = θ(memory, keys)
strengths = F.softplus(strengths).unsqueeze(2)
return σ(d * strengths, 2)
return σ(d * strengths.unsqueeze(2), 2)

def directional_weightings(self, link_matrix, read_weights):
rw = read_weights.unsqueeze(1)
@@ -215,17 +213,17 @@ def forward(self, ξ, hidden):

if self.independent_linears:
# r read keys (b * r * w)
read_keys = self.read_keys_transform(ξ).view(b, r, w)
read_keys = F.tanh(self.read_keys_transform(ξ).view(b, r, w))
# r read strengths (b * r)
read_strengths = self.read_strengths_transform(ξ).view(b, r)
read_strengths = F.softplus(self.read_strengths_transform(ξ).view(b, r))
# write key (b * 1 * w)
write_key = self.write_key_transform(ξ).view(b, 1, w)
write_key = F.tanh(self.write_key_transform(ξ).view(b, 1, w))
# write strength (b * 1)
write_strength = self.write_strength_transform(ξ).view(b, 1)
write_strength = F.softplus(self.write_strength_transform(ξ).view(b, 1))
# erase vector (b * 1 * w)
erase_vector = F.sigmoid(self.erase_vector_transform(ξ).view(b, 1, w))
# write vector (b * 1 * w)
write_vector = self.write_vector_transform(ξ).view(b, 1, w)
write_vector = F.tanh(self.write_vector_transform(ξ).view(b, 1, w))
# r free gates (b * r)
free_gates = F.sigmoid(self.free_gates_transform(ξ).view(b, r))
# allocation gate (b * 1)
@@ -237,17 +235,17 @@
else:
ξ = self.interface_weights(ξ)
# r read keys (b * w * r)
read_keys = ξ[:, :r * w].contiguous().view(b, r, w)
read_keys = F.tanh(ξ[:, :r * w].contiguous().view(b, r, w))
# r read strengths (b * r)
read_strengths = ξ[:, r * w:r * w + r].contiguous().view(b, r)
read_strengths = F.softplus(ξ[:, r * w:r * w + r].contiguous().view(b, r))
# write key (b * w * 1)
write_key = ξ[:, r * w + r:r * w + r + w].contiguous().view(b, 1, w)
write_key = F.tanh(ξ[:, r * w + r:r * w + r + w].contiguous().view(b, 1, w))
# write strength (b * 1)
write_strength = ξ[:, r * w + r + w].contiguous().view(b, 1)
write_strength = F.softplus(ξ[:, r * w + r + w].contiguous().view(b, 1))
# erase vector (b * w)
erase_vector = F.sigmoid(ξ[:, r * w + r + w + 1: r * w + r + 2 * w + 1].contiguous().view(b, 1, w))
# write vector (b * w)
write_vector = ξ[:, r * w + r + 2 * w + 1: r * w + r + 3 * w + 1].contiguous().view(b, 1, w)
write_vector = F.tanh(ξ[:, r * w + r + 2 * w + 1: r * w + r + 3 * w + 1].contiguous().view(b, 1, w))
# r free gates (b * r)
free_gates = F.sigmoid(ξ[:, r * w + r + 3 * w + 1: r * w + 2 * r + 3 * w + 1].contiguous().view(b, r))
# allocation gate (b * 1)
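
The interface-vector scaling this PR introduces bounds each component before it reaches memory: keys and the write vector pass through `tanh`, strengths through `softplus` (keeping them strictly positive), gates through `sigmoid`, and `content_weightings` no longer applies `softplus` itself. A small standalone sketch of the idea for the read-key/strength slice (sizes are made up; the slicing mirrors the non-independent-linears branch):

```python
import torch as T
import torch.nn.functional as F

b, r, w = 10, 4, 16                 # batch size, read heads, cell size (illustrative)
ξ = T.randn(b, r * w + r + w + 1)   # a prefix of the raw interface vector

# keys squashed into (-1, 1), strengths made strictly positive
read_keys      = T.tanh(ξ[:, :r * w].contiguous().view(b, r, w))
read_strengths = F.softplus(ξ[:, r * w:r * w + r].contiguous().view(b, r))
write_key      = T.tanh(ξ[:, r * w + r:r * w + r + w].contiguous().view(b, 1, w))
write_strength = F.softplus(ξ[:, r * w + r + w].contiguous().view(b, 1))

# content addressing then uses the pre-scaled strengths directly:
# σ(cosine_similarity(memory, keys) * strengths.unsqueeze(2), dim=2)
```
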
28 changes: 20 additions & 8 deletions tasks/copy_task.py
@@ -37,8 +37,8 @@

parser.add_argument('-batch_size', type=int, default=100, metavar='N', help='batch size')
parser.add_argument('-mem_size', type=int, default=16, help='memory dimension')
parser.add_argument('-mem_slot', type=int, default=10, help='number of memory slots')
parser.add_argument('-read_heads', type=int, default=1, help='number of read heads')
parser.add_argument('-mem_slot', type=int, default=16, help='number of memory slots')
parser.add_argument('-read_heads', type=int, default=4, help='number of read heads')

parser.add_argument('-sequence_max_length', type=int, default=4, metavar='N', help='sequence_max_length')
parser.add_argument('-cuda', type=int, default=-1, help='Cuda GPU ID, -1 for CPU')
@@ -121,7 +121,8 @@ def criterion(predictions, targets):
read_heads=read_heads,
gpu_id=args.cuda,
debug=True,
batch_first=True
batch_first=True,
independent_linears=True
)
print(rnn)

@@ -131,9 +132,20 @@ def criterion(predictions, targets):
last_save_losses = []

if args.optim == 'adam':
optimizer = optim.Adam(rnn.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98])
optimizer = optim.Adam(rnn.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98]) # 0.0001
if args.optim == 'sparseadam':
optimizer = optim.SparseAdam(rnn.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98]) # 0.0001
if args.optim == 'adamax':
optimizer = optim.Adamax(rnn.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98]) # 0.0001
elif args.optim == 'rmsprop':
optimizer = optim.RMSprop(rnn.parameters(), lr=args.lr, eps=1e-10)
optimizer = optim.RMSprop(rnn.parameters(), lr=args.lr, eps=1e-10) # 0.0001
elif args.optim == 'sgd':
optimizer = optim.SGD(rnn.parameters(), lr=args.lr) # 0.01
elif args.optim == 'adagrad':
optimizer = optim.Adagrad(rnn.parameters(), lr=args.lr)
elif args.optim == 'adadelta':
optimizer = optim.Adadelta(rnn.parameters(), lr=args.lr)


for epoch in range(iterations + 1):
llprint("\rIteration {ep}/{tot}".format(ep=epoch, tot=iterations))
@@ -183,13 +195,13 @@ def criterion(predictions, targets):
)

viz.heatmap(
v['link_matrix'],
v['link_matrix'][-1].reshape(args.mem_slot, args.mem_slot),
opts=dict(
xtickstep=10,
ytickstep=2,
title='Link Matrix, t: ' + str(epoch) + ', loss: ' + str(loss),
ylabel='layer * time',
xlabel='mem_slot * mem_slot'
ylabel='mem_slot',
xlabel='mem_slot'
)
)

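
The optimizer selection added to `tasks/copy_task.py` is a chain of `if`/`elif` branches; an equivalent, more compact formulation (sketch only, reusing the script's own `args` and `rnn`) is a dispatch table:

```python
from torch import optim

# Sketch of an equivalent dispatch table; `args.optim`, `args.lr` and `rnn`
# are the script's objects. Suggested learning rates from the comments above.
optimizers = {
    'adam':       lambda p: optim.Adam(p, lr=args.lr, eps=1e-9, betas=[0.9, 0.98]),       # 0.0001
    'sparseadam': lambda p: optim.SparseAdam(p, lr=args.lr, eps=1e-9, betas=[0.9, 0.98]),  # 0.0001
    'adamax':     lambda p: optim.Adamax(p, lr=args.lr, eps=1e-9, betas=[0.9, 0.98]),      # 0.0001
    'rmsprop':    lambda p: optim.RMSprop(p, lr=args.lr, eps=1e-10),                       # 0.0001
    'sgd':        lambda p: optim.SGD(p, lr=args.lr),                                      # 0.01
    'adagrad':    lambda p: optim.Adagrad(p, lr=args.lr),
    'adadelta':   lambda p: optim.Adadelta(p, lr=args.lr),
}
optimizer = optimizers[args.optim](rnn.parameters())
```
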
67 changes: 66 additions & 1 deletion test/test_gru.py
@@ -17,6 +17,8 @@
import time
sys.path.insert(0, '.')

import functools

from dnc import DNC
from test_utils import generate_data, criterion

@@ -128,6 +130,69 @@ def test_rnn_n():
optimizer.step()

assert target_output.size() == T.Size([27, 10, 100])
assert chx[1].size() == T.Size([1,10,100])
assert chx[1].size() == T.Size([num_hidden_layers,10,100])
assert mhx['memory'].size() == T.Size([10,12,17])
assert rv.size() == T.Size([10, 51])


def test_rnn_no_memory_pass():
T.manual_seed(1111)

input_size = 100
hidden_size = 100
rnn_type = 'gru'
num_layers = 3
num_hidden_layers = 5
dropout = 0.2
nr_cells = 12
cell_size = 17
read_heads = 3
gpu_id = -1
debug = True
lr = 0.001
sequence_max_length = 10
batch_size = 10
cuda = gpu_id
clip = 20
length = 13

rnn = DNC(
input_size=input_size,
hidden_size=hidden_size,
rnn_type=rnn_type,
num_layers=num_layers,
num_hidden_layers=num_hidden_layers,
dropout=dropout,
nr_cells=nr_cells,
cell_size=cell_size,
read_heads=read_heads,
gpu_id=gpu_id,
debug=debug
)

optimizer = optim.Adam(rnn.parameters(), lr=lr)
optimizer.zero_grad()

input_data, target_output = generate_data(batch_size, length, input_size, cuda)
target_output = target_output.transpose(0, 1).contiguous()

(chx, mhx, rv) = (None, None, None)
outputs = []
for x in range(6):
output, (chx, mhx, rv), v = rnn(input_data, (chx, mhx, rv), pass_through_memory=False)
output = output.transpose(0, 1)
outputs.append(output)

output = functools.reduce(lambda x,y: x + y, outputs)
loss = criterion((output), target_output)
loss.backward()

T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
optimizer.step()

assert target_output.size() == T.Size([27, 10, 100])
assert chx[0].size() == T.Size([num_hidden_layers,10,100])
assert mhx['memory'].size() == T.Size([10,12,17])
assert rv == None
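
These new tests also pin down the calling convention after this PR: hidden state travels as a single `(chx, mhx, rv)` tuple, and `rv` comes back as `None` whenever the memory pass is skipped. A condensed usage sketch, assuming a debug-enabled `rnn` and `input_data` built exactly as above and that the two call modes can be interleaved:

```python
(chx, mhx, rv) = (None, None, None)

# regular pass: read vectors are returned and fed back in on the next call
output, (chx, mhx, rv), debug_info = rnn(input_data, (chx, mhx, rv))
assert rv.size() == T.Size([batch_size, read_heads * cell_size])  # here (10, 51)

# memory bypassed: the controller still runs, but no read vectors are produced
output, (chx, mhx, rv), debug_info = rnn(input_data, (chx, mhx, rv), pass_through_memory=False)
assert rv is None
```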


65 changes: 64 additions & 1 deletion test/test_lstm.py
@@ -15,6 +15,7 @@
import os
import math
import time
import functools
sys.path.insert(0, '.')

from dnc import DNC
@@ -128,6 +129,68 @@ def test_rnn_n():
optimizer.step()

assert target_output.size() == T.Size([27, 10, 100])
assert chx[0][0].size() == T.Size([1,10,100])
assert chx[0][0].size() == T.Size([num_hidden_layers,10,100])
assert mhx['memory'].size() == T.Size([10,12,17])
assert rv.size() == T.Size([10, 51])


def test_rnn_no_memory_pass():
T.manual_seed(1111)

input_size = 100
hidden_size = 100
rnn_type = 'lstm'
num_layers = 3
num_hidden_layers = 5
dropout = 0.2
nr_cells = 12
cell_size = 17
read_heads = 3
gpu_id = -1
debug = True
lr = 0.001
sequence_max_length = 10
batch_size = 10
cuda = gpu_id
clip = 20
length = 13

rnn = DNC(
input_size=input_size,
hidden_size=hidden_size,
rnn_type=rnn_type,
num_layers=num_layers,
num_hidden_layers=num_hidden_layers,
dropout=dropout,
nr_cells=nr_cells,
cell_size=cell_size,
read_heads=read_heads,
gpu_id=gpu_id,
debug=debug
)

optimizer = optim.Adam(rnn.parameters(), lr=lr)
optimizer.zero_grad()

input_data, target_output = generate_data(batch_size, length, input_size, cuda)
target_output = target_output.transpose(0, 1).contiguous()

(chx, mhx, rv) = (None, None, None)
outputs = []
for x in range(6):
output, (chx, mhx, rv), v = rnn(input_data, (chx, mhx, rv), pass_through_memory=False)
output = output.transpose(0, 1)
outputs.append(output)

output = functools.reduce(lambda x,y: x + y, outputs)
loss = criterion((output), target_output)
loss.backward()

T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
optimizer.step()

assert target_output.size() == T.Size([27, 10, 100])
assert chx[0][0].size() == T.Size([num_hidden_layers,10,100])
assert mhx['memory'].size() == T.Size([10,12,17])
assert rv == None
