import model
import data

- parser = argparse.ArgumentParser(description='MXNet Autograd PennTreeBank RNN/LSTM Language Model')
- parser.add_argument('--data', type=str, default='./data/wikitext-2/wiki.',
-                     help='location of the data corpus')
+ parser = argparse.ArgumentParser(description='MXNet Autograd RNN/LSTM Language Model on Wikitext-2.')
parser.add_argument('--model', type=str, default='lstm',
                    help='type of recurrent net (rnn_tanh, rnn_relu, lstm, gru)')
parser.add_argument('--emsize', type=int, default=200,
else:
    context = mx.cpu(0)

- corpus = data.Corpus(args.data)
-
- def batchify(data, batch_size):
-     """Reshape data into (num_example, batch_size)"""
-     nbatch = data.shape[0] // batch_size
-     data = data[:nbatch * batch_size]
-     data = data.reshape((batch_size, nbatch)).T
-     return data
-
- train_data = batchify(corpus.train, args.batch_size).as_in_context(context)
- val_data = batchify(corpus.valid, args.batch_size).as_in_context(context)
- test_data = batchify(corpus.test, args.batch_size).as_in_context(context)
+ train_dataset = gluon.data.text.WikiText2('./data', 'train', seq_len=args.bptt)
+ indexer = train_dataset.indexer
+ val_dataset, test_dataset = [gluon.data.text.WikiText2('./data', segment,
+                                                        indexer=indexer,
+                                                        seq_len=args.bptt)
+                              for segment in ['validation', 'test']]
+ train_data = gluon.data.DataLoader(train_dataset,
+                                    batch_size=args.batch_size,
+                                    sampler=gluon.data.IntervalSampler(len(train_dataset),
+                                                                       args.batch_size),
+                                    last_batch='discard')
+
+ val_data = gluon.data.DataLoader(val_dataset,
+                                  batch_size=args.batch_size,
+                                  sampler=gluon.data.IntervalSampler(len(val_dataset),
+                                                                     args.batch_size),
+                                  last_batch='discard')
+
+ test_data = gluon.data.DataLoader(test_dataset,
+                                   batch_size=args.batch_size,
+                                   sampler=gluon.data.IntervalSampler(len(test_dataset),
+                                                                      args.batch_size),
+                                   last_batch='discard')
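Note (illustration, not part of the commit): the new loaders lean on the sampler's ordering rather than the old batchify() reshape. Assuming IntervalSampler(length, interval) simply walks the indices at a fixed stride, one offset at a time, its ordering behaves like the hypothetical helper below.

def interval_order(length, interval):
    """Yield 0, interval, 2*interval, ..., then 1, 1+interval, ..., and so on."""
    for offset in range(interval):
        for idx in range(offset, length, interval):
            yield idx

# Example: list(interval_order(10, 3)) -> [0, 3, 6, 9, 1, 4, 7, 2, 5, 8]

Under that assumption, with the interval set to batch_size, adjacent text chunks end up in different batches rather than side by side within the same batch.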


###############################################################################
# Build the model
###############################################################################


- ntokens = len(corpus.dictionary)
+ ntokens = len(indexer)
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
                       args.nlayers, args.dropout, args.tied)
model.collect_params().initialize(mx.init.Xavier(), ctx=context)
@@ -108,12 +117,6 @@ def batchify(data, batch_size):
# Training code
###############################################################################

- def get_batch(source, i):
-     seq_len = min(args.bptt, source.shape[0] - 1 - i)
-     data = source[i:i + seq_len]
-     target = source[i + 1:i + 1 + seq_len]
-     return data, target.reshape((-1,))
-
def detach(hidden):
    if isinstance(hidden, (tuple, list)):
        hidden = [i.detach() for i in hidden]
@@ -125,8 +128,9 @@ def eval(data_source):
    total_L = 0.0
    ntotal = 0
    hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
-     for i in range(0, data_source.shape[0] - 1, args.bptt):
-         data, target = get_batch(data_source, i)
+     for i, (data, target) in enumerate(data_source):
+         data = data.as_in_context(context).T
+         target = target.as_in_context(context).T.reshape((-1, 1))
        output, hidden = model(data, hidden)
        L = loss(output, target)
        total_L += mx.nd.sum(L).asscalar()
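A note on the transposes added in eval() and mirrored in train() below: this is a sketch assuming each DataLoader batch arrives shaped (batch_size, bptt); the concrete sizes are made up for illustration.

import mxnet as mx

batch = mx.nd.zeros((32, 35))   # assumed (batch_size, bptt) batch from the loader
tnc = batch.T                   # time axis first: the TNC layout that gluon
                                # recurrent layers use by default
print(tnc.shape)                # (35, 32)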
@@ -139,26 +143,27 @@ def train():
        total_L = 0.0
        start_time = time.time()
        hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
-         for ibatch, i in enumerate(range(0, train_data.shape[0] - 1, args.bptt)):
-             data, target = get_batch(train_data, i)
+         for i, (data, target) in enumerate(train_data):
+             data = data.as_in_context(context).T
+             target = target.as_in_context(context).T.reshape((-1, 1))
            hidden = detach(hidden)
            with autograd.record():
                output, hidden = model(data, hidden)
                L = loss(output, target)
                L.backward()

-             grads = [i.grad(context) for i in model.collect_params().values()]
+             grads = [p.grad(context) for p in model.collect_params().values()]
            # Here gradient is for the whole batch.
            # So we multiply max_norm by batch_size and bptt size to balance it.
            gluon.utils.clip_global_norm(grads, args.clip * args.bptt * args.batch_size)

            trainer.step(args.batch_size)
            total_L += mx.nd.sum(L).asscalar()

-             if ibatch % args.log_interval == 0 and ibatch > 0:
+             if i % args.log_interval == 0 and i > 0:
                cur_L = total_L / args.bptt / args.batch_size / args.log_interval
                print('[Epoch %d Batch %d] loss %.2f, ppl %.2f' % (
-                     epoch, ibatch, cur_L, math.exp(cur_L)))
+                     epoch, i, cur_L, math.exp(cur_L)))
                total_L = 0.0

        val_L = eval(val_data)
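On the clipping scale factor in train() above: the loss holds one value per token and L.backward() sums over all of them, so the gradient norm grows roughly with batch_size * bptt; scaling the clip threshold by the same factor keeps the per-token clipping strength constant. A small self-contained illustration of clip_global_norm rescaling in place (the gradient values are made up):

import mxnet as mx
from mxnet import gluon

grads = [mx.nd.ones((2, 3)) * 100]        # hypothetical oversized gradient
gluon.utils.clip_global_norm(grads, 1.0)  # rescales the arrays in place
print(mx.nd.norm(grads[0]).asscalar())    # ~1.0 after clipping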