import shutil
import numpy as np

+ from . import _constants as C
from ...data import dataset
from ...utils import download, check_sha1
from ....contrib import text
from .... import nd


- class WikiText2(dataset._DownloadedDataset):
-     """WikiText-2 word-level dataset for language modeling, from Salesforce research.
+ class _TextDataset(dataset._DownloadedDataset):  # pylint: disable=abstract-method
+     def __init__(self, repo_dir, root, vocabulary, transform):
+         self._vocab = vocabulary
+         super(_TextDataset, self).__init__(repo_dir, root, transform)

-     From
-     https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset
-
-     License: Creative Commons Attribution-ShareAlike
+     @property
+     def vocabulary(self):
+         return self._vocab

-     Each sample is a vector of length equal to the specified sequence length.
-     At the end of each sentence, an end-of-sentence token '<eos>' is added.
-
-     Parameters
-     ----------
-     root : str, default '~/.mxnet/datasets/cifar10'
-         Path to temp folder for storing data.
-     segment : str, default 'train'
-         Dataset segment. Options are 'train', 'validation', 'test'.
-     indexer : :class:`~mxnet.contrib.text.indexer.TokenIndexer`, default None
-         The indexer to use for indexing the text dataset. If None, a default indexer is created.
-     seq_len : int, default 35
-         The sequence length of each sample, regardless of the sentence boundary.
-     transform : function, default None
-         A user defined callback that transforms each sample. For example:
-     ::

-         transform=lambda data, label: (data.astype(np.float32)/255, label)
+ class _WikiText(_TextDataset):

-     """
-     def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-2'),
-                  segment='train', indexer=None, seq_len=35, transform=None):
-         self._archive_file = ('wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')
-         self._data_file = {'train': ('wiki.train.tokens',
-                                      '863f29c46ef9d167fff4940ec821195882fe29d1'),
-                            'validation': ('wiki.valid.tokens',
-                                           '0418625c8b4da6e4b5c7a0b9e78d4ae8f7ee5422'),
-                            'test': ('wiki.test.tokens',
-                                     'c7b8ce0aa086fb34dab808c5c49224211eb2b172')}
-         self._segment = segment
-         self._seq_len = seq_len
-         self.indexer = indexer
-         super(WikiText2, self).__init__('wikitext-2', root, transform)
+     def _build_vocab(self, content):
+         if not self._vocab:
+             counter = text.utils.count_tokens_from_str(content)
+             self._vocab = text.vocab.Vocabulary(counter=counter,
+                                                 reserved_tokens=[C.EOS_TOKEN])

    def _read_batch(self, filename):
        with io.open(filename, 'r', encoding='utf8') as fin:
            content = fin.read()
-         eos_token = '<eos>'
-         if not self.indexer:
-             counter = text.utils.count_tokens_from_str(content)
-             self.indexer = text.indexer.TokenIndexer(counter=counter,
-                                                      reserved_tokens=[eos_token])
+         self._build_vocab(content)
        raw_data = [line for line in [x.strip().split() for x in content.splitlines()]
                    if line]
        for line in raw_data:
-             line.append(eos_token)
+             line.append(C.EOS_TOKEN)
        raw_data = [x for line in raw_data for x in line if x]
-         raw_data = self.indexer.to_indices(raw_data)
+         raw_data = self.vocabulary.to_indices(raw_data)
        data = raw_data[0:-1]
        label = raw_data[1:]
        return np.array(data, dtype=np.int32), np.array(label, dtype=np.int32)

-
    def _get_data(self):
        archive_file_name, archive_hash = self._archive_file
        data_file_name, data_hash = self._data_file[self._segment]
@@ -117,7 +89,50 @@ def _get_data(self):
        self._label = nd.array(label, dtype=label.dtype).reshape((-1, self._seq_len))


- class WikiText103(WikiText2):
+ class WikiText2(_WikiText):
+     """WikiText-2 word-level dataset for language modeling, from Salesforce research.
+
+     From
+     https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset
+
+     License: Creative Commons Attribution-ShareAlike
+
+     Each sample is a vector of length equal to the specified sequence length.
+     At the end of each sentence, an end-of-sentence token '<eos>' is added.
+
+     Parameters
+     ----------
+     root : str, default '~/.mxnet/datasets/wikitext-2'
+         Path to temp folder for storing data.
+     segment : str, default 'train'
+         Dataset segment. Options are 'train', 'validation', 'test'.
+     vocab : :class:`~mxnet.contrib.text.vocab.Vocabulary`, default None
+         The vocabulary to use for indexing the text dataset.
+         If None, a default vocabulary is created.
+     seq_len : int, default 35
+         The sequence length of each sample, regardless of the sentence boundary.
+     transform : function, default None
+         A user defined callback that transforms each sample. For example:
+     ::
+
+         transform=lambda data, label: (data.astype(np.float32)/255, label)
+
+     """
+     def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-2'),
+                  segment='train', vocab=None, seq_len=35, transform=None):
+         self._archive_file = ('wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')
+         self._data_file = {'train': ('wiki.train.tokens',
+                                      '863f29c46ef9d167fff4940ec821195882fe29d1'),
+                            'validation': ('wiki.valid.tokens',
+                                           '0418625c8b4da6e4b5c7a0b9e78d4ae8f7ee5422'),
+                            'test': ('wiki.test.tokens',
+                                     'c7b8ce0aa086fb34dab808c5c49224211eb2b172')}
+         self._segment = segment
+         self._seq_len = seq_len
+         super(WikiText2, self).__init__('wikitext-2', root, vocab, transform)
+
+
+ class WikiText103(_WikiText):
    """WikiText-103 word-level dataset for language modeling, from Salesforce research.

    From
@@ -134,8 +149,9 @@ class WikiText103(WikiText2):
        Path to temp folder for storing data.
    segment : str, default 'train'
        Dataset segment. Options are 'train', 'validation', 'test'.
-     indexer : :class:`~mxnet.contrib.text.indexer.TokenIndexer`, default None
-         The indexer to use for indexing the text dataset. If None, a default indexer is created.
+     vocab : :class:`~mxnet.contrib.text.vocab.Vocabulary`, default None
+         The vocabulary to use for indexing the text dataset.
+         If None, a default vocabulary is created.
    seq_len : int, default 35
        The sequence length of each sample, regardless of the sentence boundary.
    transform : function, default None
@@ -146,7 +162,7 @@ class WikiText103(WikiText2):

    """
    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-103'),
-                  segment='train', indexer=None, seq_len=35, transform=None):
+                  segment='train', vocab=None, seq_len=35, transform=None):
        self._archive_file = ('wikitext-103-v1.zip', '0aec09a7537b58d4bb65362fee27650eeaba625a')
        self._data_file = {'train': ('wiki.train.tokens',
                                     'b7497e2dfe77e72cfef5e3dbc61b7b53712ac211'),
@@ -156,5 +172,4 @@ def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-103'),
                                    '8a5befc548865cec54ed4273cf87dbbad60d1e47')}
        self._segment = segment
        self._seq_len = seq_len
-         self.indexer = indexer
-         super(WikiText2, self).__init__('wikitext-103', root, transform)  # pylint: disable=bad-super-call
+         super(WikiText103, self).__init__('wikitext-103', root, vocab, transform)
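
For context, a minimal usage sketch of the classes touched by this change. It assumes the module is importable as mxnet.gluon.data.text and that gluon's DataLoader is used for batching; everything outside the WikiText2/WikiText103 signatures shown in the diff is illustrative, not part of this change.

    from mxnet import gluon
    from mxnet.gluon.data.text import WikiText2  # assumed import path for this module

    # Downloads (if needed) and indexes the training split; each sample is a
    # (data, label) pair of length-35 int32 vectors, with label shifted by one token.
    train = WikiText2(segment='train', seq_len=35)

    # Reuse the vocabulary built from the training text so the validation split
    # is indexed consistently (_build_vocab only builds one when none is given).
    val = WikiText2(segment='validation', vocab=train.vocabulary)

    train_loader = gluon.data.DataLoader(train, batch_size=32, shuffle=True)
    for data, label in train_loader:
        pass  # (batch_size, seq_len) batches, ready to feed a language model

Factoring the shared logic into _TextDataset/_WikiText lets both public classes reuse one vocabulary-building and batching path, while WikiText2 and WikiText103 only carry their own archive names and file hashes.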