import math
import mxnet as mx
from mxnet import gluon, autograd
+ from mxnet.gluon import contrib
import model
import data

- parser = argparse.ArgumentParser(description='MXNet Autograd PennTreeBank RNN/LSTM Language Model')
- parser.add_argument('--data', type=str, default='./data/wikitext-2/wiki.',
-                     help='location of the data corpus')
+ parser = argparse.ArgumentParser(description='MXNet Autograd RNN/LSTM Language Model on Wikitext-2.')
parser.add_argument('--model', type=str, default='lstm',
                    help='type of recurrent net (rnn_tanh, rnn_relu, lstm, gru)')
parser.add_argument('--emsize', type=int, default=200,

else:
    context = mx.cpu(0)

- corpus = data.Corpus(args.data)
-
- def batchify(data, batch_size):
-     """Reshape data into (num_example, batch_size)"""
-     nbatch = data.shape[0] // batch_size
-     data = data[:nbatch * batch_size]
-     data = data.reshape((batch_size, nbatch)).T
-     return data
-
- train_data = batchify(corpus.train, args.batch_size).as_in_context(context)
- val_data = batchify(corpus.valid, args.batch_size).as_in_context(context)
- test_data = batchify(corpus.test, args.batch_size).as_in_context(context)
+ train_dataset = contrib.data.text.WikiText2('./data', 'train', seq_len=args.bptt)
+ vocab = train_dataset.vocabulary
+ val_dataset, test_dataset = [contrib.data.text.WikiText2('./data', segment,
+                                                          vocab=vocab,
+                                                          seq_len=args.bptt)
+                              for segment in ['validation', 'test']]
+
+ nbatch_train = len(train_dataset) // args.batch_size
+ train_data = gluon.data.DataLoader(train_dataset,
+                                    batch_size=args.batch_size,
+                                    sampler=contrib.data.IntervalSampler(len(train_dataset),
+                                                                         nbatch_train),
+                                    last_batch='discard')
+
+ nbatch_val = len(val_dataset) // args.batch_size
+ val_data = gluon.data.DataLoader(val_dataset,
+                                  batch_size=args.batch_size,
+                                  sampler=contrib.data.IntervalSampler(len(val_dataset),
+                                                                       nbatch_val),
+                                  last_batch='discard')
+
+ nbatch_test = len(test_dataset) // args.batch_size
+ test_data = gluon.data.DataLoader(test_dataset,
+                                   batch_size=args.batch_size,
+                                   sampler=contrib.data.IntervalSampler(len(test_dataset),
+                                                                        nbatch_test),
+                                   last_batch='discard')
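
An aside on the sampler choice (a minimal sketch, with made-up sizes): IntervalSampler strides through the dataset at an interval equal to the number of batches, so each slot of a batch reads consecutive sequences of text from one step to the next, which is what makes carrying the RNN hidden state across batches meaningful.

from mxnet.gluon.contrib.data import IntervalSampler

# 6 sequences with batch_size=2 gives 3 batches, so the interval is 3.
print(list(IntervalSampler(6, 3)))   # [0, 3, 1, 4, 2, 5]
# The DataLoader then forms batches (0, 3), (1, 4), (2, 5): batch slot 0
# reads sequences 0, 1, 2 in order and slot 1 reads 3, 4, 5, so every
# slot walks through contiguous text between training steps.
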

###############################################################################
# Build the model
###############################################################################

- ntokens = len(corpus.dictionary)
+ ntokens = len(vocab)
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
                       args.nlayers, args.dropout, args.tied)
model.collect_params().initialize(mx.init.Xavier(), ctx=context)
@@ -108,12 +122,6 @@ def batchify(data, batch_size):
# Training code
###############################################################################

- def get_batch(source, i):
-     seq_len = min(args.bptt, source.shape[0] - 1 - i)
-     data = source[i:i+seq_len]
-     target = source[i+1:i+1+seq_len]
-     return data, target.reshape((-1,))
-
def detach(hidden):
    if isinstance(hidden, (tuple, list)):
        hidden = [i.detach() for i in hidden]
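
An aside on detach(), with illustrative shapes: the carried hidden state is detached before every batch so that the next backward() stops at the batch boundary instead of backpropagating through the whole history (truncated backpropagation through time).

import mxnet as mx

# e.g. an LSTM's (h, c) pair for 2 layers, batch size 32, 200 hidden units
hidden = [mx.nd.zeros((2, 32, 200)), mx.nd.zeros((2, 32, 200))]
hidden = [h.detach() for h in hidden]   # same values, but the autograd
                                        # history is dropped, so gradients
                                        # of the next batch stop here
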
@@ -125,8 +133,9 @@ def eval(data_source):
    total_L = 0.0
    ntotal = 0
    hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
-     for i in range(0, data_source.shape[0] - 1, args.bptt):
-         data, target = get_batch(data_source, i)
+     for i, (data, target) in enumerate(data_source):
+         data = data.as_in_context(context).T
+         target = target.as_in_context(context).T.reshape((-1, 1))
        output, hidden = model(data, hidden)
        L = loss(output, target)
        total_L += mx.nd.sum(L).asscalar()
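
A note on the reshaping in the new loop body (layout only; the concrete sizes below are illustrative): the DataLoader yields data and target batches of shape (batch_size, bptt), while the RNN layers and begin_state here use the (sequence length, batch size) layout, so both arrays are transposed; the target is then flattened so each label lines up with one row of the decoder output.

import mxnet as mx

data = mx.nd.zeros((32, 35))             # one batch as the DataLoader yields it
print(data.T.shape)                      # (35, 32): (sequence, batch) for the RNN
print(data.T.reshape((-1, 1)).shape)     # (1120, 1): one entry per predicted token
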
@@ -139,26 +148,27 @@ def train():
    total_L = 0.0
    start_time = time.time()
    hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
-     for ibatch, i in enumerate(range(0, train_data.shape[0] - 1, args.bptt)):
-         data, target = get_batch(train_data, i)
+     for i, (data, target) in enumerate(train_data):
+         data = data.as_in_context(context).T
+         target = target.as_in_context(context).T.reshape((-1, 1))
        hidden = detach(hidden)
        with autograd.record():
            output, hidden = model(data, hidden)
            L = loss(output, target)
            L.backward()

-         grads = [i.grad(context) for i in model.collect_params().values()]
+         grads = [p.grad(context) for p in model.collect_params().values()]
        # Here gradient is for the whole batch.
        # So we multiply max_norm by batch_size and bptt size to balance it.
        gluon.utils.clip_global_norm(grads, args.clip * args.bptt * args.batch_size)

        trainer.step(args.batch_size)
        total_L += mx.nd.sum(L).asscalar()

-         if ibatch % args.log_interval == 0 and ibatch > 0:
+         if i % args.log_interval == 0 and i > 0:
            cur_L = total_L / args.bptt / args.batch_size / args.log_interval
            print('[Epoch %d Batch %d] loss %.2f, ppl %.2f' % (
-                 epoch, ibatch, cur_L, math.exp(cur_L)))
+                 epoch, i, cur_L, math.exp(cur_L)))
            total_L = 0.0

    val_L = eval(val_data)
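
Finally, a worked example of the loss bookkeeping above (the numbers are purely illustrative; the example's actual defaults may differ): the per-token losses are summed over each batch, so the clipping threshold is scaled by bptt * batch_size, and the running total is divided by bptt * batch_size * log_interval to recover the mean cross-entropy per token before taking its exponential as perplexity.

import math

bptt, batch_size, log_interval, clip = 35, 32, 200, 0.25   # hypothetical settings

# Gradients accumulate over bptt * batch_size token losses per step,
# so the per-token clip value is scaled before clip_global_norm:
max_norm = clip * bptt * batch_size                         # 280.0

# total_L sums the same per-token losses over log_interval steps:
total_L = 1030400.0                                         # hypothetical accumulated loss
cur_L = total_L / bptt / batch_size / log_interval          # 4.6 nats per token
print('loss %.2f, ppl %.2f' % (cur_L, math.exp(cur_L)))     # loss 4.60, ppl 99.48
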