This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 2d42d83

Author: Sheng Zha
Commit message: update
1 parent: 12891d4

File tree: 3 files changed (+101 / -63 lines)


example/gluon/word_language_model/train.py

Lines changed: 13 additions & 12 deletions
@@ -20,6 +20,7 @@
 import math
 import mxnet as mx
 from mxnet import gluon, autograd
+from mxnet.gluon import contrib
 import model
 import data

@@ -70,32 +71,32 @@
 else:
     context = mx.cpu(0)

-train_dataset = gluon.data.text.WikiText2('./data', 'train', seq_len=args.bptt)
-indexer = train_dataset.indexer
-val_dataset, test_dataset = [gluon.data.text.WikiText2('./data', segment,
-                                                        indexer=indexer,
-                                                        seq_len=args.bptt)
+train_dataset = contrib.data.text.WikiText2('./data', 'train', seq_len=args.bptt)
+vocab = train_dataset.vocabulary
+val_dataset, test_dataset = [contrib.data.text.WikiText2('./data', segment,
+                                                          vocab=vocab,
+                                                          seq_len=args.bptt)
                              for segment in ['validation', 'test']]

 nbatch_train = len(train_dataset) / args.batch_size
 train_data = gluon.data.DataLoader(train_dataset,
                                    batch_size=args.batch_size,
-                                   sampler=gluon.data.IntervalSampler(len(train_dataset),
-                                                                      nbatch_train),
+                                   sampler=contrib.data.IntervalSampler(len(train_dataset),
+                                                                        nbatch_train),
                                    last_batch='discard')

 nbatch_val = len(val_dataset) / args.batch_size
 val_data = gluon.data.DataLoader(val_dataset,
                                  batch_size=args.batch_size,
-                                 sampler=gluon.data.IntervalSampler(len(val_dataset),
-                                                                    nbatch_val),
+                                 sampler=contrib.data.IntervalSampler(len(val_dataset),
+                                                                      nbatch_val),
                                  last_batch='discard')

 nbatch_test = len(test_dataset) / args.batch_size
 test_data = gluon.data.DataLoader(test_dataset,
                                   batch_size=args.batch_size,
-                                  sampler=gluon.data.IntervalSampler(len(test_dataset),
-                                                                     nbatch_test),
+                                  sampler=contrib.data.IntervalSampler(len(test_dataset),
+                                                                       nbatch_test),
                                   last_batch='discard')


@@ -104,7 +105,7 @@
 ###############################################################################


-ntokens = len(indexer)
+ntokens = len(vocab)
 model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
                        args.nlayers, args.dropout, args.tied)
 model.collect_params().initialize(mx.init.Xavier(), ctx=context)
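
For orientation, here is a minimal sketch of the data pipeline as it looks after this change. The class and parameter names (contrib.data.text.WikiText2, vocabulary, vocab, seq_len, contrib.data.IntervalSampler, last_batch='discard') come from the diff above; the batch size, BPTT length, and data root are placeholder values, so treat this as an illustration rather than the example script itself.

# Sketch of the updated pipeline; batch_size, bptt, and './data' are placeholders,
# while the API names are taken from this commit.
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import contrib

batch_size, bptt = 32, 35

# The training split builds a Vocabulary; the other splits reuse it so that
# token indices stay consistent across splits.
train_dataset = contrib.data.text.WikiText2('./data', 'train', seq_len=bptt)
vocab = train_dataset.vocabulary
val_dataset = contrib.data.text.WikiText2('./data', 'validation',
                                          vocab=vocab, seq_len=bptt)

# IntervalSampler draws samples at a fixed interval, so sample i of batch t
# is followed in the corpus by sample i of batch t+1, which lets a training
# loop carry RNN hidden state across batches.
nbatch_train = len(train_dataset) // batch_size   # integer division so the interval is an int
train_data = gluon.data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   sampler=contrib.data.IntervalSampler(len(train_dataset),
                                                                        nbatch_train),
                                   last_batch='discard')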
python/mxnet/gluon/contrib/data/_constants.py (new file, imported as _constants in text.py below)

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+
+"""Read text files and load embeddings."""
+
+EOS_TOKEN = '<eos>'

python/mxnet/gluon/contrib/data/text.py

Lines changed: 66 additions & 51 deletions
@@ -26,73 +26,45 @@
 import shutil
 import numpy as np

+from . import _constants as C
 from ...data import dataset
 from ...utils import download, check_sha1
 from ....contrib import text
 from .... import nd


-class WikiText2(dataset._DownloadedDataset):
-    """WikiText-2 word-level dataset for language modeling, from Salesforce research.
+class _TextDataset(dataset._DownloadedDataset):  # pylint: disable=abstract-method
+    def __init__(self, repo_dir, root, vocabulary, transform):
+        self._vocab = vocabulary
+        super(_TextDataset, self).__init__(repo_dir, root, transform)

-    From
-    https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset
-
-    License: Creative Commons Attribution-ShareAlike
+    @property
+    def vocabulary(self):
+        return self._vocab

-    Each sample is a vector of length equal to the specified sequence length.
-    At the end of each sentence, an end-of-sentence token '<eos>' is added.
-
-    Parameters
-    ----------
-    root : str, default '~/.mxnet/datasets/cifar10'
-        Path to temp folder for storing data.
-    segment : str, default 'train'
-        Dataset segment. Options are 'train', 'validation', 'test'.
-    indexer : :class:`~mxnet.contrib.text.indexer.TokenIndexer`, default None
-        The indexer to use for indexing the text dataset. If None, a default indexer is created.
-    seq_len : int, default 35
-        The sequence length of each sample, regardless of the sentence boundary.
-    transform : function, default None
-        A user defined callback that transforms each sample. For example:
-        ::

-            transform=lambda data, label: (data.astype(np.float32)/255, label)
+class _WikiText(_TextDataset):

-    """
-    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-2'),
-                 segment='train', indexer=None, seq_len=35, transform=None):
-        self._archive_file = ('wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')
-        self._data_file = {'train': ('wiki.train.tokens',
-                                     '863f29c46ef9d167fff4940ec821195882fe29d1'),
-                           'validation': ('wiki.valid.tokens',
-                                          '0418625c8b4da6e4b5c7a0b9e78d4ae8f7ee5422'),
-                           'test': ('wiki.test.tokens',
-                                    'c7b8ce0aa086fb34dab808c5c49224211eb2b172')}
-        self._segment = segment
-        self._seq_len = seq_len
-        self.indexer = indexer
-        super(WikiText2, self).__init__('wikitext-2', root, transform)
+    def _build_vocab(self, content):
+        if not self._vocab:
+            counter = text.utils.count_tokens_from_str(content)
+            self._vocab = text.vocab.Vocabulary(counter=counter,
+                                                reserved_tokens=[C.EOS_TOKEN])

     def _read_batch(self, filename):
         with io.open(filename, 'r', encoding='utf8') as fin:
             content = fin.read()
-        eos_token = '<eos>'
-        if not self.indexer:
-            counter = text.utils.count_tokens_from_str(content)
-            self.indexer = text.indexer.TokenIndexer(counter=counter,
-                                                     reserved_tokens=[eos_token])
+        self._build_vocab(content)
         raw_data = [line for line in [x.strip().split() for x in content.splitlines()]
                     if line]
         for line in raw_data:
-            line.append(eos_token)
+            line.append(C.EOS_TOKEN)
         raw_data = [x for x in line for line in raw_data if x]
-        raw_data = self.indexer.to_indices(raw_data)
+        raw_data = self.vocabulary.to_indices(raw_data)
         data = raw_data[0:-1]
         label = raw_data[1:]
         return np.array(data, dtype=np.int32), np.array(label, dtype=np.int32)

-
     def _get_data(self):
         archive_file_name, archive_hash = self._archive_file
         data_file_name, data_hash = self._data_file[self._segment]
@@ -117,7 +89,50 @@ def _get_data(self):
         self._label = nd.array(label, dtype=label.dtype).reshape((-1, self._seq_len))


-class WikiText103(WikiText2):
+class WikiText2(_WikiText):
+    """WikiText-2 word-level dataset for language modeling, from Salesforce research.
+
+    From
+    https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset
+
+    License: Creative Commons Attribution-ShareAlike
+
+    Each sample is a vector of length equal to the specified sequence length.
+    At the end of each sentence, an end-of-sentence token '<eos>' is added.
+
+    Parameters
+    ----------
+    root : str, default '~/.mxnet/datasets/cifar10'
+        Path to temp folder for storing data.
+    segment : str, default 'train'
+        Dataset segment. Options are 'train', 'validation', 'test'.
+    vocab : :class:`~mxnet.contrib.text.vocab.Vocabulary`, default None
+        The vocabulary to use for indexing the text dataset.
+        If None, a default vocabulary is created.
+    seq_len : int, default 35
+        The sequence length of each sample, regardless of the sentence boundary.
+    transform : function, default None
+        A user defined callback that transforms each sample. For example:
+        ::
+
+            transform=lambda data, label: (data.astype(np.float32)/255, label)
+
+    """
+    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-2'),
+                 segment='train', vocab=None, seq_len=35, transform=None):
+        self._archive_file = ('wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')
+        self._data_file = {'train': ('wiki.train.tokens',
+                                     '863f29c46ef9d167fff4940ec821195882fe29d1'),
+                           'validation': ('wiki.valid.tokens',
+                                          '0418625c8b4da6e4b5c7a0b9e78d4ae8f7ee5422'),
+                           'test': ('wiki.test.tokens',
+                                    'c7b8ce0aa086fb34dab808c5c49224211eb2b172')}
+        self._segment = segment
+        self._seq_len = seq_len
+        super(WikiText2, self).__init__('wikitext-2', root, vocab, transform)
+
+
+class WikiText103(_WikiText):
     """WikiText-103 word-level dataset for language modeling, from Salesforce research.

     From
@@ -134,8 +149,9 @@ class WikiText103(WikiText2):
         Path to temp folder for storing data.
     segment : str, default 'train'
         Dataset segment. Options are 'train', 'validation', 'test'.
-    indexer : :class:`~mxnet.contrib.text.indexer.TokenIndexer`, default None
-        The indexer to use for indexing the text dataset. If None, a default indexer is created.
+    vocab : :class:`~mxnet.contrib.text.vocab.Vocabulary`, default None
+        The vocabulary to use for indexing the text dataset.
+        If None, a default vocabulary is created.
     seq_len : int, default 35
         The sequence length of each sample, regardless of the sentence boundary.
     transform : function, default None
@@ -146,7 +162,7 @@ class WikiText103(WikiText2):

     """
     def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-103'),
-                 segment='train', indexer=None, seq_len=35, transform=None):
+                 segment='train', vocab=None, seq_len=35, transform=None):
         self._archive_file = ('wikitext-103-v1.zip', '0aec09a7537b58d4bb65362fee27650eeaba625a')
         self._data_file = {'train': ('wiki.train.tokens',
                                      'b7497e2dfe77e72cfef5e3dbc61b7b53712ac211'),
@@ -156,5 +172,4 @@ def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-103'),
                                      '8a5befc548865cec54ed4273cf87dbbad60d1e47')}
         self._segment = segment
         self._seq_len = seq_len
-        self.indexer = indexer
-        super(WikiText2, self).__init__('wikitext-103', root, transform)  # pylint: disable=bad-super-call
+        super(WikiText103, self).__init__('wikitext-103', root, vocab, transform)
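
As a small usage sketch of the refactored hierarchy: WikiText2 and WikiText103 now share the _WikiText/_TextDataset base classes, accept a vocab argument, and expose the built vocabulary via the vocabulary property. Only those names and the documented parameters come from this commit; the seq_len value, variable names, and the float32 transform below are illustrative assumptions.

# Illustrative use of the refactored datasets; downloads WikiText on first use.
import numpy as np
from mxnet.gluon import contrib

seq_len = 35

# When no vocab is given, the training split builds its own Vocabulary.
wt2_train = contrib.data.text.WikiText2(segment='train', seq_len=seq_len)
vocab = wt2_train.vocabulary          # property added by the _TextDataset base class

# That Vocabulary can be reused so every split indexes tokens consistently;
# the transform here is just a hypothetical per-sample callback.
wt2_val = contrib.data.text.WikiText2(segment='validation', vocab=vocab, seq_len=seq_len,
                                      transform=lambda d, l: (d.astype(np.float32), l))

# WikiText103 follows the same constructor signature and builds its own vocab here.
wt103_train = contrib.data.text.WikiText103(segment='train', seq_len=seq_len)

ntokens = len(vocab)                  # vocabulary size, e.g. for an embedding layer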
