Commit 7dcec50

Language Modeling Datasets and Sampler (apache#9514)
* refactor dataset
* add interval sampler
* wikitext-2/-103
* update word language model
* address comments
* move interval sampler to contrib
* update
* add frequencies property
1 parent 0ff26df commit 7dcec50
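
The commit replaces the example's hand-rolled corpus loading with reusable Gluon dataset classes and a contrib sampler. A minimal sketch of the new API surface, assuming this commit is installed (seq_len=35 and batch_size=32 are illustrative values, not defaults from the commit):

    from mxnet import gluon
    from mxnet.gluon import contrib

    # Load WikiText-2 pre-segmented into fixed-length sample sequences.
    train_dataset = contrib.data.text.WikiText2('./data', 'train', seq_len=35)
    vocab = train_dataset.vocabulary      # shared with the validation/test splits

    # One pass over the corpus = nbatch batches; the IntervalSampler added in
    # this commit keeps each batch row on a contiguous stretch of text.
    nbatch = len(train_dataset) // 32
    train_data = gluon.data.DataLoader(
        train_dataset, batch_size=32,
        sampler=contrib.data.IntervalSampler(len(train_dataset), nbatch),
        last_batch='discard')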

File tree

10 files changed (+369, -155 lines)

example/gluon/word_language_model/data.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

example/gluon/word_language_model/train.py

Lines changed: 39 additions & 29 deletions

@@ -20,12 +20,11 @@
 import math
 import mxnet as mx
 from mxnet import gluon, autograd
+from mxnet.gluon import contrib
 import model
 import data
 
-parser = argparse.ArgumentParser(description='MXNet Autograd PennTreeBank RNN/LSTM Language Model')
-parser.add_argument('--data', type=str, default='./data/wikitext-2/wiki.',
-                    help='location of the data corpus')
+parser = argparse.ArgumentParser(description='MXNet Autograd RNN/LSTM Language Model on Wikitext-2.')
 parser.add_argument('--model', type=str, default='lstm',
                     help='type of recurrent net (rnn_tanh, rnn_relu, lstm, gru)')
 parser.add_argument('--emsize', type=int, default=200,
@@ -72,26 +71,41 @@
 else:
     context = mx.cpu(0)
 
-corpus = data.Corpus(args.data)
-
-def batchify(data, batch_size):
-    """Reshape data into (num_example, batch_size)"""
-    nbatch = data.shape[0] // batch_size
-    data = data[:nbatch * batch_size]
-    data = data.reshape((batch_size, nbatch)).T
-    return data
-
-train_data = batchify(corpus.train, args.batch_size).as_in_context(context)
-val_data = batchify(corpus.valid, args.batch_size).as_in_context(context)
-test_data = batchify(corpus.test, args.batch_size).as_in_context(context)
+train_dataset = contrib.data.text.WikiText2('./data', 'train', seq_len=args.bptt)
+vocab = train_dataset.vocabulary
+val_dataset, test_dataset = [contrib.data.text.WikiText2('./data', segment,
+                                                         vocab=vocab,
+                                                         seq_len=args.bptt)
+                             for segment in ['validation', 'test']]
+
+nbatch_train = len(train_dataset) / args.batch_size
+train_data = gluon.data.DataLoader(train_dataset,
+                                   batch_size=args.batch_size,
+                                   sampler=contrib.data.IntervalSampler(len(train_dataset),
+                                                                        nbatch_train),
+                                   last_batch='discard')
+
+nbatch_val = len(val_dataset) / args.batch_size
+val_data = gluon.data.DataLoader(val_dataset,
+                                 batch_size=args.batch_size,
+                                 sampler=contrib.data.IntervalSampler(len(val_dataset),
+                                                                      nbatch_val),
+                                 last_batch='discard')
+
+nbatch_test = len(test_dataset) / args.batch_size
+test_data = gluon.data.DataLoader(test_dataset,
+                                  batch_size=args.batch_size,
+                                  sampler=contrib.data.IntervalSampler(len(test_dataset),
+                                                                       nbatch_test),
+                                  last_batch='discard')
 
 
 ###############################################################################
 # Build the model
 ###############################################################################
 
 
-ntokens = len(corpus.dictionary)
+ntokens = len(vocab)
 model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
                        args.nlayers, args.dropout, args.tied)
 model.collect_params().initialize(mx.init.Xavier(), ctx=context)
@@ -108,12 +122,6 @@ def batchify(data, batch_size):
 # Training code
 ###############################################################################
 
-def get_batch(source, i):
-    seq_len = min(args.bptt, source.shape[0] - 1 - i)
-    data = source[i:i+seq_len]
-    target = source[i+1:i+1+seq_len]
-    return data, target.reshape((-1,))
-
 def detach(hidden):
     if isinstance(hidden, (tuple, list)):
         hidden = [i.detach() for i in hidden]
@@ -125,8 +133,9 @@ def eval(data_source):
     total_L = 0.0
     ntotal = 0
     hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
-    for i in range(0, data_source.shape[0] - 1, args.bptt):
-        data, target = get_batch(data_source, i)
+    for i, (data, target) in enumerate(data_source):
+        data = data.as_in_context(context).T
+        target = target.as_in_context(context).T.reshape((-1, 1))
         output, hidden = model(data, hidden)
         L = loss(output, target)
         total_L += mx.nd.sum(L).asscalar()
@@ -139,26 +148,27 @@ def train():
         total_L = 0.0
         start_time = time.time()
         hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
-        for ibatch, i in enumerate(range(0, train_data.shape[0] - 1, args.bptt)):
-            data, target = get_batch(train_data, i)
+        for i, (data, target) in enumerate(train_data):
+            data = data.as_in_context(context).T
+            target = target.as_in_context(context).T.reshape((-1, 1))
             hidden = detach(hidden)
             with autograd.record():
                 output, hidden = model(data, hidden)
                 L = loss(output, target)
                 L.backward()
 
-            grads = [i.grad(context) for i in model.collect_params().values()]
+            grads = [p.grad(context) for p in model.collect_params().values()]
             # Here gradient is for the whole batch.
             # So we multiply max_norm by batch_size and bptt size to balance it.
             gluon.utils.clip_global_norm(grads, args.clip * args.bptt * args.batch_size)
 
             trainer.step(args.batch_size)
             total_L += mx.nd.sum(L).asscalar()
 
-            if ibatch % args.log_interval == 0 and ibatch > 0:
+            if i % args.log_interval == 0 and i > 0:
                 cur_L = total_L / args.bptt / args.batch_size / args.log_interval
                 print('[Epoch %d Batch %d] loss %.2f, ppl %.2f'%(
-                    epoch, ibatch, cur_L, math.exp(cur_L)))
+                    epoch, i, cur_L, math.exp(cur_L)))
                 total_L = 0.0
 
         val_L = eval(val_data)
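
A note on the loader setup above: with interval = nbatch, the sampler makes the DataLoader's k-th batch [k, nbatch+k, 2*nbatch+k, ...], so row i of successive batches walks a contiguous stretch of the corpus. That recreates the layout the deleted batchify() produced and is what lets `hidden` be carried across batches. Also note that `len(train_dataset) / args.batch_size` relies on Python 2 integer division; under Python 3 it would need `//` to keep the interval integral. A quick standalone sketch of the ordering (toy sizes, not from the commit):

    import mxnet as mx
    from mxnet import gluon
    from mxnet.gluon.contrib.data import IntervalSampler

    dataset = gluon.data.SimpleDataset(list(range(12)))  # stand-in for text chunks
    nbatch = len(dataset) // 4                           # 3 batches of batch_size=4
    loader = gluon.data.DataLoader(dataset, batch_size=4,
                                   sampler=IntervalSampler(len(dataset), nbatch),
                                   last_batch='discard')
    print([b.asnumpy().astype(int).tolist() for b in loader])
    # [[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]]
    # Reading down a slot: batch row 0 sees items 0, 1, 2 in order, row 1
    # sees 3, 4, 5, and so on -- contiguous text per row.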

python/mxnet/gluon/contrib/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -21,3 +21,5 @@
 from . import nn
 
 from . import rnn
+
+from . import data

example/gluon/word_language_model/get_wikitext2_data.sh → python/mxnet/gluon/contrib/data/__init__.py

Lines changed: 5 additions & 17 deletions

@@ -1,5 +1,3 @@
-#!/usr/bin/env bash
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements. See the NOTICE file
 # distributed with this work for additional information
@@ -17,20 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 
+# coding: utf-8
+# pylint: disable=wildcard-import
+"""Contrib datasets."""
 
-RNN_DIR=$(cd `dirname $0`; pwd)
-DATA_DIR="${RNN_DIR}/data/"
-
-if [[ ! -d "${DATA_DIR}" ]]; then
-    echo "${DATA_DIR} doesn't exist, will create one";
-    mkdir -p ${DATA_DIR}
-fi
-
-wget -P ${DATA_DIR} https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
-cd ${DATA_DIR}
-unzip wikitext-2-v1.zip
+from . import text
 
-# rename
-mv ${DATA_DIR}/wikitext-2/wiki.test.tokens ${DATA_DIR}/wikitext-2/wiki.test.txt
-mv ${DATA_DIR}/wikitext-2/wiki.valid.tokens ${DATA_DIR}/wikitext-2/wiki.valid.txt
-mv ${DATA_DIR}/wikitext-2/wiki.train.tokens ${DATA_DIR}/wikitext-2/wiki.train.txt
+from .sampler import *
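
With this rename the example's shell download step disappears; fetching the corpus is presumably handled inside the WikiText2 dataset itself (text.py is not part of this capture). The resulting package surface:

    from mxnet.gluon.contrib import data

    data.text.WikiText2     # dataset classes live under contrib.data.text
    data.IntervalSampler    # re-exported here via `from .sampler import *`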

python/mxnet/gluon/contrib/data/_constants.py

Lines changed: 22 additions & 0 deletions (new file)

@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+
+"""Read text files and load embeddings."""
+
+EOS_TOKEN = '<eos>'
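
EOS_TOKEN is this file's only content besides the license header. A hypothetical use (the text reader that consumes it is not shown in this capture): append the marker after each line so sentence boundaries survive flattening into one token stream.

    from mxnet.gluon.contrib.data._constants import EOS_TOKEN

    tokens = []
    for line in ['the quick brown fox', 'jumps over the lazy dog']:
        tokens.extend(line.split() + [EOS_TOKEN])  # boundary marker per line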

python/mxnet/gluon/contrib/data/sampler.py

Lines changed: 62 additions & 0 deletions (new file)

@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=
+"""Dataset sampler."""
+__all__ = ['IntervalSampler']
+
+from ...data import sampler
+
+class IntervalSampler(sampler.Sampler):
+    """Samples elements from [0, length) at fixed intervals.
+
+    Parameters
+    ----------
+    length : int
+        Length of the sequence.
+    interval : int
+        The number of items to skip between two samples.
+    rollover : bool, default True
+        Whether to start again from the first skipped item after reaching the end.
+        If true, this sampler would start again from the first skipped item until all items
+        are visited.
+        Otherwise, iteration stops when end is reached and skipped items are ignored.
+
+    Examples
+    --------
+    >>> sampler = contrib.data.IntervalSampler(13, interval=3)
+    >>> list(sampler)
+    [0, 3, 6, 9, 12, 1, 4, 7, 10, 2, 5, 8, 11]
+    >>> sampler = contrib.data.IntervalSampler(13, interval=3, rollover=False)
+    >>> list(sampler)
+    [0, 3, 6, 9, 12]
+    """
+    def __init__(self, length, interval, rollover=True):
+        assert interval < length, \
+            "Interval {} must be smaller than length {}".format(interval, length)
+        self._length = length
+        self._interval = interval
+        self._rollover = rollover
+
+    def __iter__(self):
+        for i in range(self._interval if self._rollover else 1):
+            for j in range(i, self._length, self._interval):
+                yield j
+
+    def __len__(self):
+        return self._length
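
For intuition, a self-contained pure-Python rendering of the traversal that __iter__ implements; the asserts reproduce the docstring examples:

    def interval_order(length, interval, rollover=True):
        # Visit one "column" (residue class modulo interval) at a time.
        starts = range(interval) if rollover else range(1)
        return [j for i in starts for j in range(i, length, interval)]

    assert interval_order(13, 3) == [0, 3, 6, 9, 12, 1, 4, 7, 10, 2, 5, 8, 11]
    assert interval_order(13, 3, rollover=False) == [0, 3, 6, 9, 12]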
