from allennlp.data import DatasetReader, Token, Field
from allennlp.data.fields import TextField, LabelField, ListField
from allennlp.data.instance import Instance
- from datasets import load_dataset, Dataset, DatasetDict
+ from datasets import load_dataset, Dataset, DatasetDict, Split
from datasets.features import ClassLabel, Sequence, Translation, TranslationVariableLanguages
from datasets.features import Value

- # TODO pab complete the documentation comments
- class HuggingfaceDatasetSplitReader(DatasetReader):
+
+ # TODO pab-vmware complete the documentation comments
+ class HuggingfaceDatasetReader(DatasetReader):
    """
    This reader implementation wraps the huggingface datasets package
    to utilize its dataset management functionality and load the information in AllenNLP-friendly formats
@@ -44,6 +45,8 @@ class HuggingfaceDatasetSplitReader(DatasetReader):
    pre_load : `bool`, optional (default=`False`)
    """

+   SUPPORTED_SPLITS = [Split.TRAIN, Split.TEST, Split.VALIDATION]
+
    def __init__(
        self,
        max_instances: Optional[int] = None,
@@ -52,7 +55,7 @@ def __init__(
        serialization_dir: Optional[str] = None,
        dataset_name: str = None,
        config_name: Optional[str] = None,
-       pre_load: Optional[bool] = False
+       pre_load: Optional[bool] = False,
    ) -> None:
        super().__init__(
            max_instances,
@@ -61,7 +64,7 @@ def __init__(
            serialization_dir,
        )

-       # It would be cleaner to create a separate reader object for different dataset
+       # It would be cleaner to create a separate reader object for different datasets
        self.dataset: Dataset = None
        self.datasets: DatasetDict = DatasetDict()
        self.dataset_name = dataset_name
@@ -77,22 +80,33 @@ def load_dataset(self):
        else:
            self.datasets = load_dataset(self.dataset_name)

-   def load_dataset_split(self, split):
-       if self.config_name is not None:
-           self.datasets[split] = load_dataset(self.dataset_name, self.config_name, split=split)
+   def load_dataset_split(self, split: str):
+       # TODO add support for datasets.split.NamedSplit
+       if split in self.SUPPORTED_SPLITS:
+           if self.config_name is not None:
+               self.datasets[split] = load_dataset(
+                   self.dataset_name, self.config_name, split=split
+               )
+           else:
+               self.datasets[split] = load_dataset(self.dataset_name, split=split)
        else:
-           self.datasets[split] = load_dataset(self.dataset_name, split=split)
+           raise ValueError(
+               f"Only default splits: {self.SUPPORTED_SPLITS} are currently supported."
+           )

-   def _read(self, file_path) -> Iterable[Instance]:
+   def _read(self, file_path: str) -> Iterable[Instance]:
        """
        Reads the dataset and converts each entry to an AllenNLP-friendly Instance
        """
+       if file_path is None:
+           raise ValueError("parameter split cannot be None")
+
+       # If the split is not loaded, load the specific split
        if file_path not in self.datasets:
            self.load_dataset_split(file_path)

-       if self.datasets is not None and self.datasets[file_path] is not None:
-           for entry in self.datasets[file_path]:
-               yield self.text_to_instance(entry)
+       for entry in self.datasets[file_path]:
+           yield self.text_to_instance(entry)

    def raise_feature_not_supported_value_error(self, value):
        raise ValueError(f"Datasets feature type {type(value)} is not supported yet.")
@@ -136,7 +150,9 @@ def text_to_instance(self, *inputs) -> Instance:

        # datasets ClassLabel maps to LabelField
        if isinstance(value, ClassLabel):
-           field = LabelField(inputs[0][feature], label_namespace=feature, skip_indexing=True)
+           field = LabelField(
+               inputs[0][feature], label_namespace=feature, skip_indexing=True
+           )

        # datasets Value can be of different types
        elif isinstance(value, Value):
@@ -179,30 +195,35 @@ def text_to_instance(self, *inputs) -> Instance:
            else:
                self.raise_feature_not_supported_value_error(value)

-
        # datasets.Translation cannot be mapped directly
        # but its dict structure can be mapped to a ListField of 2 ListFields
        elif isinstance(value, Translation):
            if value.dtype == "dict":
                input_dict = inputs[0][feature]
                langs = list(input_dict.keys())
-               field_langs = [LabelField(lang, label_namespace="languages") for lang in langs]
+               field_langs = [
+                   LabelField(lang, label_namespace="languages") for lang in langs
+               ]
                langs_field = ListField(field_langs)
                texts = list()
                for lang in langs:
                    texts.append(TextField([Token(input_dict[lang])]))
                field = ListField([langs_field, ListField(texts)])

            else:
-               raise ValueError(f"Datasets feature type {type(value)} is not supported yet.")
+               raise ValueError(
+                   f"Datasets feature type {type(value)} is not supported yet."
+               )

        # datasets.TranslationVariableLanguages
        # is functionally a pair of Lists and hence mapped to a ListField of 2 ListFields
        elif isinstance(value, TranslationVariableLanguages):
            if value.dtype == "dict":
                input_dict = inputs[0][feature]
                langs = input_dict["language"]
-               field_langs = [LabelField(lang, label_namespace="languages") for lang in langs]
+               field_langs = [
+                   LabelField(lang, label_namespace="languages") for lang in langs
+               ]
                langs_field = ListField(field_langs)
                texts = list()
                for lang in langs:
@@ -211,12 +232,14 @@ def text_to_instance(self, *inputs) -> Instance:
                field = ListField([langs_field, ListField(texts)])

            else:
-               raise ValueError(f"Datasets feature type {type(value)} is not supported yet.")
+               raise ValueError(
+                   f"Datasets feature type {type(value)} is not supported yet."
+               )

        else:
            raise ValueError(f"Datasets feature type {type(value)} is not supported yet.")

-       if field:
+       if field is not None:
            fields[feature] = field

        return Instance(fields)
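
For reference, a minimal usage sketch of the reader after this change. The "squad" dataset name is an illustrative assumption, not something this diff pins down; the key point is that the split name is passed where a file path would normally go:

# Hypothetical example: "squad" stands in for any dataset name accepted
# by datasets.load_dataset; config_name and pre_load are optional.
reader = HuggingfaceDatasetReader(dataset_name="squad")

# read() forwards the split name to _read, which loads that split on
# demand via load_dataset_split and yields one Instance per dataset entry.
for instance in reader.read("train"):
    print(instance.fields)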