This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 96a08a1

Author: Sheng Zha
update word language model
1 parent 2445ade

File tree

3 files changed (+38, -131 lines)

example/gluon/word_language_model/data.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

example/gluon/word_language_model/get_wikitext2_data.sh

Lines changed: 0 additions & 36 deletions
This file was deleted.

example/gluon/word_language_model/train.py

Lines changed: 38 additions & 29 deletions
@@ -23,9 +23,7 @@
 import model
 import data
 
-parser = argparse.ArgumentParser(description='MXNet Autograd PennTreeBank RNN/LSTM Language Model')
-parser.add_argument('--data', type=str, default='./data/wikitext-2/wiki.',
-                    help='location of the data corpus')
+parser = argparse.ArgumentParser(description='MXNet Autograd RNN/LSTM Language Model on Wikitext-2.')
 parser.add_argument('--model', type=str, default='lstm',
                     help='type of recurrent net (rnn_tanh, rnn_relu, lstm, gru)')
 parser.add_argument('--emsize', type=int, default=200,
@@ -72,26 +70,41 @@
 else:
     context = mx.cpu(0)
 
-corpus = data.Corpus(args.data)
-
-def batchify(data, batch_size):
-    """Reshape data into (num_example, batch_size)"""
-    nbatch = data.shape[0] // batch_size
-    data = data[:nbatch * batch_size]
-    data = data.reshape((batch_size, nbatch)).T
-    return data
-
-train_data = batchify(corpus.train, args.batch_size).as_in_context(context)
-val_data = batchify(corpus.valid, args.batch_size).as_in_context(context)
-test_data = batchify(corpus.test, args.batch_size).as_in_context(context)
+train_dataset = gluon.data.text.WikiText2('./data', 'train', seq_len=args.bptt)
+indexer = train_dataset.indexer
+val_dataset, test_dataset = [gluon.data.text.WikiText2('./data', segment,
+                                                       indexer=indexer,
+                                                       seq_len=args.bptt)
+                             for segment in ['validation', 'test']]
+
+nbatch_train = len(train_dataset) / args.batch_size
+train_data = gluon.data.DataLoader(train_dataset,
+                                   batch_size=args.batch_size,
+                                   sampler=gluon.data.IntervalSampler(len(train_dataset),
+                                                                      nbatch_train),
+                                   last_batch='discard')
+
+nbatch_val = len(val_dataset) / args.batch_size
+val_data = gluon.data.DataLoader(val_dataset,
+                                 batch_size=args.batch_size,
+                                 sampler=gluon.data.IntervalSampler(len(val_dataset),
+                                                                    nbatch_val),
+                                 last_batch='discard')
+
+nbatch_test = len(test_dataset) / args.batch_size
+test_data = gluon.data.DataLoader(test_dataset,
+                                  batch_size=args.batch_size,
+                                  sampler=gluon.data.IntervalSampler(len(test_dataset),
+                                                                     nbatch_test),
+                                  last_batch='discard')
 
 
 ###############################################################################
 # Build the model
 ###############################################################################
 
 
-ntokens = len(corpus.dictionary)
+ntokens = len(indexer)
 model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
                        args.nlayers, args.dropout, args.tied)
 model.collect_params().initialize(mx.init.Xavier(), ctx=context)
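
The loaders above replace the deleted batchify reshape with stride-based sampling. A minimal plain-Python sketch of the intended batch order, under the assumption that IntervalSampler(length, interval) visits index 0, interval, 2*interval, ... and then 1, interval+1, ... (the helper name interval_indices below is made up for illustration, not part of the commit):

def interval_indices(length, interval):
    # Walk the dataset with stride `interval`, restarting one position later
    # on each pass -- the order the loaders above appear to rely on.
    for start in range(interval):
        for idx in range(start, length, interval):
            yield idx

length, batch_size = 12, 3
nbatch = length // batch_size                  # plays the role of nbatch_train above
order = list(interval_indices(length, nbatch))
batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
print(batches)   # [[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]]

Read column-wise, slot k of consecutive batches holds consecutive bptt-sized segments of the corpus, so the hidden state carried from one batch to the next stays aligned with the text, just as the deleted batchify/get_batch pair arranged.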
@@ -108,12 +121,6 @@ def batchify(data, batch_size):
 # Training code
 ###############################################################################
 
-def get_batch(source, i):
-    seq_len = min(args.bptt, source.shape[0] - 1 - i)
-    data = source[i:i+seq_len]
-    target = source[i+1:i+1+seq_len]
-    return data, target.reshape((-1,))
-
 def detach(hidden):
     if isinstance(hidden, (tuple, list)):
         hidden = [i.detach() for i in hidden]
@@ -125,8 +132,9 @@ def eval(data_source):
     total_L = 0.0
     ntotal = 0
     hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
-    for i in range(0, data_source.shape[0] - 1, args.bptt):
-        data, target = get_batch(data_source, i)
+    for i, (data, target) in enumerate(data_source):
+        data = data.as_in_context(context).T
+        target = target.as_in_context(context).T.reshape((-1, 1))
         output, hidden = model(data, hidden)
         L = loss(output, target)
         total_L += mx.nd.sum(L).asscalar()
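
The two new lines in the loop above move each batch into the layout the model consumes. A small sketch of the shape bookkeeping, assuming the DataLoader yields (batch_size, bptt) arrays; the sizes here are illustrative only:

import mxnet as mx

batch_size, bptt = 32, 35
data = mx.nd.zeros((batch_size, bptt))      # stand-in for a batch from the loader
target = mx.nd.zeros((batch_size, bptt))
data = data.T                               # (bptt, batch_size): time-major input for the RNN
target = target.T.reshape((-1, 1))          # (bptt * batch_size, 1), matching the flattened output
print(data.shape, target.shape)             # (35, 32) (1120, 1)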
@@ -139,26 +147,27 @@ def train():
         total_L = 0.0
         start_time = time.time()
         hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
-        for ibatch, i in enumerate(range(0, train_data.shape[0] - 1, args.bptt)):
-            data, target = get_batch(train_data, i)
+        for i, (data, target) in enumerate(train_data):
+            data = data.as_in_context(context).T
+            target = target.as_in_context(context).T.reshape((-1, 1))
             hidden = detach(hidden)
             with autograd.record():
                 output, hidden = model(data, hidden)
                 L = loss(output, target)
                 L.backward()
 
-            grads = [i.grad(context) for i in model.collect_params().values()]
+            grads = [p.grad(context) for p in model.collect_params().values()]
             # Here gradient is for the whole batch.
             # So we multiply max_norm by batch_size and bptt size to balance it.
             gluon.utils.clip_global_norm(grads, args.clip * args.bptt * args.batch_size)
 
             trainer.step(args.batch_size)
             total_L += mx.nd.sum(L).asscalar()
 
-            if ibatch % args.log_interval == 0 and ibatch > 0:
+            if i % args.log_interval == 0 and i > 0:
                 cur_L = total_L / args.bptt / args.batch_size / args.log_interval
                 print('[Epoch %d Batch %d] loss %.2f, ppl %.2f'%(
-                    epoch, ibatch, cur_L, math.exp(cur_L)))
+                    epoch, i, cur_L, math.exp(cur_L)))
                 total_L = 0.0
 
         val_L = eval(val_data)
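
Since total_L accumulates the batch-summed loss (mx.nd.sum(L)) over every token in every logged batch, dividing by bptt, batch_size and log_interval recovers the mean per-token loss, whose exponential is the reported perplexity. A quick numeric check with made-up values:

import math

bptt, batch_size, log_interval = 35, 32, 200
mean_token_loss = 5.6                                   # pretend value
total_L = mean_token_loss * bptt * batch_size * log_interval
cur_L = total_L / bptt / batch_size / log_interval      # recovers 5.6
print(cur_L, math.exp(cur_L))                           # 5.6 and perplexity ~270.4

The same batch-summed gradient is why clip_global_norm is given args.clip * args.bptt * args.batch_size rather than args.clip alone, as the comment in the hunk above notes.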

0 commit comments