This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 4e92a88

Author: Sheng Zha (committed)
update word language model
1 parent 1eb7e17 commit 4e92a88

File tree: 3 files changed, +34 -131 lines

example/gluon/word_language_model/data.py: 0 additions, 66 deletions (file deleted)

example/gluon/word_language_model/get_wikitext2_data.sh: 0 additions, 36 deletions (file deleted)

example/gluon/word_language_model/train.py: 34 additions, 29 deletions
@@ -23,9 +23,7 @@
 import model
 import data
 
-parser = argparse.ArgumentParser(description='MXNet Autograd PennTreeBank RNN/LSTM Language Model')
-parser.add_argument('--data', type=str, default='./data/wikitext-2/wiki.',
-                    help='location of the data corpus')
+parser = argparse.ArgumentParser(description='MXNet Autograd RNN/LSTM Language Model on Wikitext-2.')
 parser.add_argument('--model', type=str, default='lstm',
                     help='type of recurrent net (rnn_tanh, rnn_relu, lstm, gru)')
 parser.add_argument('--emsize', type=int, default=200,
@@ -72,26 +70,37 @@
 else:
     context = mx.cpu(0)
 
-corpus = data.Corpus(args.data)
-
-def batchify(data, batch_size):
-    """Reshape data into (num_example, batch_size)"""
-    nbatch = data.shape[0] // batch_size
-    data = data[:nbatch * batch_size]
-    data = data.reshape((batch_size, nbatch)).T
-    return data
-
-train_data = batchify(corpus.train, args.batch_size).as_in_context(context)
-val_data = batchify(corpus.valid, args.batch_size).as_in_context(context)
-test_data = batchify(corpus.test, args.batch_size).as_in_context(context)
+train_dataset = gluon.data.text.WikiText2('./data', 'train', seq_len=args.bptt)
+indexer = train_dataset.indexer
+val_dataset, test_dataset = [gluon.data.text.WikiText2('./data', segment,
+                                                       indexer=indexer,
+                                                       seq_len=args.bptt)
+                             for segment in ['validation', 'test']]
+train_data = gluon.data.DataLoader(train_dataset,
+                                   batch_size=args.batch_size,
+                                   sampler=gluon.data.IntervalSampler(len(train_dataset),
+                                                                      args.batch_size),
+                                   last_batch='discard')
+
+val_data = gluon.data.DataLoader(val_dataset,
+                                 batch_size=args.batch_size,
+                                 sampler=gluon.data.IntervalSampler(len(val_dataset),
+                                                                    args.batch_size),
+                                 last_batch='discard')
+
+test_data = gluon.data.DataLoader(test_dataset,
+                                  batch_size=args.batch_size,
+                                  sampler=gluon.data.IntervalSampler(len(test_dataset),
+                                                                     args.batch_size),
+                                  last_batch='discard')
 
 
 ###############################################################################
 # Build the model
 ###############################################################################
 
 
-ntokens = len(corpus.dictionary)
+ntokens = len(indexer)
 model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
                        args.nlayers, args.dropout, args.tied)
 model.collect_params().initialize(mx.init.Xavier(), ctx=context)
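Note: the new loader relies on gluon.data.IntervalSampler so that each batch position keeps reading a contiguous stream of the corpus, which is what makes carrying the hidden state across batches meaningful. The sketch below is a plain-Python illustration of an interval-style index order, not the actual IntervalSampler implementation; the stride behaviour shown is an assumption for illustration only.

# Hypothetical sketch of interval-style sampling (assumed behaviour, for illustration).
def interval_indices(length, interval):
    """Yield 0, interval, 2*interval, ..., then 1, 1+interval, ... over [0, length)."""
    for start in range(interval):
        for idx in range(start, length, interval):
            yield idx

# Example: 10 samples with interval 3 -> [0, 3, 6, 9, 1, 4, 7, 2, 5, 8]
print(list(interval_indices(10, 3)))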
@@ -108,12 +117,6 @@ def batchify(data, batch_size):
 # Training code
 ###############################################################################
 
-def get_batch(source, i):
-    seq_len = min(args.bptt, source.shape[0] - 1 - i)
-    data = source[i:i+seq_len]
-    target = source[i+1:i+1+seq_len]
-    return data, target.reshape((-1,))
-
 def detach(hidden):
     if isinstance(hidden, (tuple, list)):
         hidden = [i.detach() for i in hidden]
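Note: detach(), kept from the original script, is what truncates backpropagation at batch boundaries: the hidden state is carried forward as data, but gradients do not flow back into earlier batches. A toy example of the effect, assuming standard MXNet NDArray attach_grad()/detach() semantics:

import mxnet as mx
from mxnet import autograd

x = mx.nd.ones((2, 2))
x.attach_grad()
with autograd.record():
    h = x * 2                      # stand-in for the hidden state from batch t
    y = (h.detach() * 3).sum()     # batch t+1 consumes a detached hidden state
y.backward()
print(x.grad)                      # all zeros: no gradient flows through the detach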
@@ -125,8 +128,9 @@ def eval(data_source):
     total_L = 0.0
     ntotal = 0
     hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
-    for i in range(0, data_source.shape[0] - 1, args.bptt):
-        data, target = get_batch(data_source, i)
+    for i, (data, target) in enumerate(data_source):
+        data = data.as_in_context(context).T
+        target = target.as_in_context(context).T.reshape((-1, 1))
         output, hidden = model(data, hidden)
         L = loss(output, target)
         total_L += mx.nd.sum(L).asscalar()
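Note: the transposes exist because the DataLoader is assumed to yield batch-major (batch_size, seq_len) arrays while the RNN model consumes time-major input, and the loss works on flattened targets. A minimal NumPy sketch of the reshaping (shapes are hypothetical):

import numpy as np

batch_size, bptt = 4, 5
data = np.arange(batch_size * bptt).reshape(batch_size, bptt)  # batch-major, as loaded

time_major = data.T                  # (bptt, batch_size), what the RNN consumes
flat_target = data.T.reshape(-1, 1)  # (bptt * batch_size, 1), what the loss compares against
print(time_major.shape, flat_target.shape)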
@@ -139,26 +143,27 @@ def train():
         total_L = 0.0
         start_time = time.time()
         hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
-        for ibatch, i in enumerate(range(0, train_data.shape[0] - 1, args.bptt)):
-            data, target = get_batch(train_data, i)
+        for i, (data, target) in enumerate(train_data):
+            data = data.as_in_context(context).T
+            target = target.as_in_context(context).T.reshape((-1, 1))
             hidden = detach(hidden)
             with autograd.record():
                 output, hidden = model(data, hidden)
                 L = loss(output, target)
                 L.backward()
 
-            grads = [i.grad(context) for i in model.collect_params().values()]
+            grads = [p.grad(context) for p in model.collect_params().values()]
             # Here gradient is for the whole batch.
             # So we multiply max_norm by batch_size and bptt size to balance it.
             gluon.utils.clip_global_norm(grads, args.clip * args.bptt * args.batch_size)
 
             trainer.step(args.batch_size)
             total_L += mx.nd.sum(L).asscalar()
 
-            if ibatch % args.log_interval == 0 and ibatch > 0:
+            if i % args.log_interval == 0 and i > 0:
                 cur_L = total_L / args.bptt / args.batch_size / args.log_interval
                 print('[Epoch %d Batch %d] loss %.2f, ppl %.2f'%(
-                    epoch, ibatch, cur_L, math.exp(cur_L)))
+                    epoch, i, cur_L, math.exp(cur_L)))
                 total_L = 0.0
 
         val_L = eval(val_data)
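Note on the clipping comment: the loss L is summed, not averaged, over every token in the batch, so gradient magnitude grows with both batch_size and bptt; scaling max_norm by the same factor keeps the effective per-token clipping threshold constant. A quick arithmetic sketch with hypothetical values (the defaults below are assumptions, not taken from this diff):

clip = 0.25        # hypothetical args.clip: per-token gradient-norm budget
bptt = 35          # hypothetical sequence length per sample
batch_size = 20    # hypothetical samples per batch

# Gradients come from a loss summed over batch_size * bptt tokens,
# so the global-norm cap is scaled by the same factor before clipping.
max_norm = clip * bptt * batch_size
print(max_norm)    # 175.0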
