import model
import data

- parser = argparse.ArgumentParser(description='MXNet Autograd PennTreeBank RNN/LSTM Language Model')
- parser.add_argument('--data', type=str, default='./data/wikitext-2/wiki.',
-                     help='location of the data corpus')
+ parser = argparse.ArgumentParser(description='MXNet Autograd RNN/LSTM Language Model on Wikitext-2.')
parser.add_argument('--model', type=str, default='lstm',
                    help='type of recurrent net (rnn_tanh, rnn_relu, lstm, gru)')
parser.add_argument('--emsize', type=int, default=200,
else:
    context = mx.cpu(0)

- corpus = data.Corpus(args.data)
-
- def batchify(data, batch_size):
-     """Reshape data into (num_example, batch_size)"""
-     nbatch = data.shape[0] // batch_size
-     data = data[:nbatch * batch_size]
-     data = data.reshape((batch_size, nbatch)).T
-     return data
-
- train_data = batchify(corpus.train, args.batch_size).as_in_context(context)
- val_data = batchify(corpus.valid, args.batch_size).as_in_context(context)
- test_data = batchify(corpus.test, args.batch_size).as_in_context(context)
+ train_dataset = gluon.data.text.WikiText2('./data', 'train', seq_len=args.bptt)
+ indexer = train_dataset.indexer
+ val_dataset, test_dataset = [gluon.data.text.WikiText2('./data', segment,
+                                                        indexer=indexer,
+                                                        seq_len=args.bptt)
+                              for segment in ['validation', 'test']]
+
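+ # The interval sampler below walks the dataset with a stride equal to the number
+ # of batches, so each position in consecutive batches continues the same text
+ # stream; this is what lets the recurrent hidden state carry across batches.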
+ nbatch_train = len(train_dataset) // args.batch_size
+ train_data = gluon.data.DataLoader(train_dataset,
+                                    batch_size=args.batch_size,
+                                    sampler=gluon.data.IntervalSampler(len(train_dataset),
+                                                                       nbatch_train),
+                                    last_batch='discard')
+
+ nbatch_val = len(val_dataset) // args.batch_size
+ val_data = gluon.data.DataLoader(val_dataset,
+                                  batch_size=args.batch_size,
+                                  sampler=gluon.data.IntervalSampler(len(val_dataset),
+                                                                     nbatch_val),
+                                  last_batch='discard')
+
+ nbatch_test = len(test_dataset) // args.batch_size
+ test_data = gluon.data.DataLoader(test_dataset,
+                                   batch_size=args.batch_size,
+                                   sampler=gluon.data.IntervalSampler(len(test_dataset),
+                                                                      nbatch_test),
+                                   last_batch='discard')
###############################################################################
# Build the model
###############################################################################
- ntokens = len(corpus.dictionary)
+ ntokens = len(indexer)
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
                       args.nlayers, args.dropout, args.tied)
model.collect_params().initialize(mx.init.Xavier(), ctx=context)
@@ -108,12 +121,6 @@ def batchify(data, batch_size):
# Training code
###############################################################################
- def get_batch(source, i):
-     seq_len = min(args.bptt, source.shape[0] - 1 - i)
-     data = source[i:i + seq_len]
-     target = source[i + 1:i + 1 + seq_len]
-     return data, target.reshape((-1,))
-
def detach(hidden):
    if isinstance(hidden, (tuple, list)):
        hidden = [i.detach() for i in hidden]
@@ -125,8 +132,9 @@ def eval(data_source):
    total_L = 0.0
    ntotal = 0
    hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
-     for i in range(0, data_source.shape[0] - 1, args.bptt):
-         data, target = get_batch(data_source, i)
+     for i, (data, target) in enumerate(data_source):
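+         # Each DataLoader batch has shape (batch_size, seq_len); transpose to the
+         # time-major (seq_len, batch_size) layout that the model consumes.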
+         data = data.as_in_context(context).T
+         target = target.as_in_context(context).T.reshape((-1, 1))
        output, hidden = model(data, hidden)
        L = loss(output, target)
        total_L += mx.nd.sum(L).asscalar()
@@ -139,26 +147,27 @@ def train():
        total_L = 0.0
        start_time = time.time()
        hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
-         for ibatch, i in enumerate(range(0, train_data.shape[0] - 1, args.bptt)):
-             data, target = get_batch(train_data, i)
+         for i, (data, target) in enumerate(train_data):
+             data = data.as_in_context(context).T
+             target = target.as_in_context(context).T.reshape((-1, 1))
            hidden = detach(hidden)
            with autograd.record():
                output, hidden = model(data, hidden)
                L = loss(output, target)
                L.backward()

-             grads = [i.grad(context) for i in model.collect_params().values()]
+             grads = [p.grad(context) for p in model.collect_params().values()]
            # Here gradient is for the whole batch.
            # So we multiply max_norm by batch_size and bptt size to balance it.
            gluon.utils.clip_global_norm(grads, args.clip * args.bptt * args.batch_size)

            trainer.step(args.batch_size)
            total_L += mx.nd.sum(L).asscalar()

-             if ibatch % args.log_interval == 0 and ibatch > 0:
+             if i % args.log_interval == 0 and i > 0:
                cur_L = total_L / args.bptt / args.batch_size / args.log_interval
                print('[Epoch %d Batch %d] loss %.2f, ppl %.2f' % (
-                     epoch, ibatch, cur_L, math.exp(cur_L)))
+                     epoch, i, cur_L, math.exp(cur_L)))
                total_L = 0.0

        val_L = eval(val_data)