Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit f77cfa3

Browse files
committed
Signed-off-by: Abhishek P (VMware) <[email protected]>
Converted HFDatasetSplitReader to HFDatasetReader. Now all splits can be used in the same reader. Support for both pre-load of all splits or on-demand load of a split. Reduced tests to the glue-cola dataset:config, which is ~0.36 MB of download. Updated the datasets dependency to the range >=1.5.0 and <1.6.0.
1 parent 6e613b9 commit f77cfa3

File tree

2 files changed

+51
-26
lines changed

2 files changed

+51
-26
lines changed

allennlp/data/dataset_readers/huggingface_datasets_reader.py

+43-20
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
from allennlp.data import DatasetReader, Token, Field
44
from allennlp.data.fields import TextField, LabelField, ListField
55
from allennlp.data.instance import Instance
6-
from datasets import load_dataset, Dataset, DatasetDict
6+
from datasets import load_dataset, Dataset, DatasetDict, Split
77
from datasets.features import ClassLabel, Sequence, Translation, TranslationVariableLanguages
88
from datasets.features import Value
99

10-
# TODO pab complete the documentation comments
11-
class HuggingfaceDatasetSplitReader(DatasetReader):
10+
11+
# TODO pab-vmware complete the documentation comments
12+
class HuggingfaceDatasetReader(DatasetReader):
1213
"""
1314
This reader implementation wraps the huggingface datasets package
1415
to utilize its dataset management functionality and load the information in AllenNLP friendly formats
@@ -44,6 +45,8 @@ class HuggingfaceDatasetSplitReader(DatasetReader):
4445
pre_load : `bool`, optional (default='False`)
4546
"""
4647

48+
SUPPORTED_SPLITS = [Split.TRAIN, Split.TEST, Split.VALIDATION]
49+
4750
def __init__(
4851
self,
4952
max_instances: Optional[int] = None,
@@ -52,7 +55,7 @@ def __init__(
5255
serialization_dir: Optional[str] = None,
5356
dataset_name: str = None,
5457
config_name: Optional[str] = None,
55-
pre_load: Optional[bool] = False
58+
pre_load: Optional[bool] = False,
5659
) -> None:
5760
super().__init__(
5861
max_instances,
@@ -61,7 +64,7 @@ def __init__(
6164
serialization_dir,
6265
)
6366

64-
# It would be cleaner to create a separate reader object for different dataset
67+
# It would be cleaner to create a separate reader object for different dataset
6568
self.dataset: Dataset = None
6669
self.datasets: DatasetDict = DatasetDict()
6770
self.dataset_name = dataset_name
@@ -77,22 +80,33 @@ def load_dataset(self):
7780
else:
7881
self.datasets = load_dataset(self.dataset_name)
7982

80-
def load_dataset_split(self, split):
81-
if self.config_name is not None:
82-
self.datasets[split] = load_dataset(self.dataset_name, self.config_name, split=split)
83+
def load_dataset_split(self, split: str):
84+
# TODO add support for datasets.split.NamedSplit
85+
if split in self.SUPPORTED_SPLITS:
86+
if self.config_name is not None:
87+
self.datasets[split] = load_dataset(
88+
self.dataset_name, self.config_name, split=split
89+
)
90+
else:
91+
self.datasets[split] = load_dataset(self.dataset_name, split=split)
8392
else:
84-
self.datasets[split] = load_dataset(self.dataset_name, split=split)
93+
raise ValueError(
94+
f"Only default splits:{self.SUPPORTED_SPLITS} are currently supported."
95+
)
8596

86-
def _read(self, file_path) -> Iterable[Instance]:
97+
def _read(self, file_path: str) -> Iterable[Instance]:
8798
"""
8899
Reads the dataset and converts the entry to AllenNLP friendly instance
89100
"""
101+
if file_path is None:
102+
raise ValueError("parameter split cannot be None")
103+
104+
# If split is not loaded, load the specific split
90105
if file_path not in self.datasets:
91106
self.load_dataset_split(file_path)
92107

93-
if self.datasets is not None and self.datasets[file_path] is not None:
94-
for entry in self.datasets[file_path]:
95-
yield self.text_to_instance(entry)
108+
for entry in self.datasets[file_path]:
109+
yield self.text_to_instance(entry)
96110

97111
def raise_feature_not_supported_value_error(self, value):
98112
raise ValueError(f"Datasets feature type {type(value)} is not supported yet.")
@@ -136,7 +150,9 @@ def text_to_instance(self, *inputs) -> Instance:
136150

137151
# datasets ClassLabel maps to LabelField
138152
if isinstance(value, ClassLabel):
139-
field = LabelField(inputs[0][feature], label_namespace=feature, skip_indexing=True)
153+
field = LabelField(
154+
inputs[0][feature], label_namespace=feature, skip_indexing=True
155+
)
140156

141157
# datasets Value can be of different types
142158
elif isinstance(value, Value):
@@ -179,30 +195,35 @@ def text_to_instance(self, *inputs) -> Instance:
179195
else:
180196
self.raise_feature_not_supported_value_error(value)
181197

182-
183198
# datasets.Translation cannot be mapped directly
184199
# but it's dict structure can be mapped to a ListField of 2 ListField
185200
elif isinstance(value, Translation):
186201
if value.dtype == "dict":
187202
input_dict = inputs[0][feature]
188203
langs = list(input_dict.keys())
189-
field_langs = [LabelField(lang, label_namespace="languages") for lang in langs]
204+
field_langs = [
205+
LabelField(lang, label_namespace="languages") for lang in langs
206+
]
190207
langs_field = ListField(field_langs)
191208
texts = list()
192209
for lang in langs:
193210
texts.append(TextField([Token(input_dict[lang])]))
194211
field = ListField([langs_field, ListField(texts)])
195212

196213
else:
197-
raise ValueError(f"Datasets feature type {type(value)} is not supported yet.")
214+
raise ValueError(
215+
f"Datasets feature type {type(value)} is not supported yet."
216+
)
198217

199218
# datasets.TranslationVariableLanguages
200219
# is functionally a pair of Lists and hence mapped to a ListField of 2 ListField
201220
elif isinstance(value, TranslationVariableLanguages):
202221
if value.dtype == "dict":
203222
input_dict = inputs[0][feature]
204223
langs = input_dict["language"]
205-
field_langs = [LabelField(lang, label_namespace="languages") for lang in langs]
224+
field_langs = [
225+
LabelField(lang, label_namespace="languages") for lang in langs
226+
]
206227
langs_field = ListField(field_langs)
207228
texts = list()
208229
for lang in langs:
@@ -211,12 +232,14 @@ def text_to_instance(self, *inputs) -> Instance:
211232
field = ListField([langs_field, ListField(texts)])
212233

213234
else:
214-
raise ValueError(f"Datasets feature type {type(value)} is not supported yet.")
235+
raise ValueError(
236+
f"Datasets feature type {type(value)} is not supported yet."
237+
)
215238

216239
else:
217240
raise ValueError(f"Datasets feature type {type(value)} is not supported yet.")
218241

219-
if field:
242+
if field is not None:
220243
fields[feature] = field
221244

222245
return Instance(fields)
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
import pytest
22

3-
from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetSplitReader
3+
from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetReader
44
import logging
55

66
logger = logging.getLogger(__name__)
77

88

9-
# TODO these UTs are actually downloading the datasets and will be very very slow
10-
# TODO add UT were we compare huggingface wrapped reader with an explicitly coded builder
9+
# TODO add UT where we compare huggingface wrapped reader with an explicitly coded dataset
1110
class HuggingfaceDatasetSplitReaderTest:
1211

1312
"""
14-
Running the tests for supported datasets which require config name to be specified
13+
Running the tests for supported datasets which require config name to be specified
1514
"""
16-
@pytest.mark.parametrize("dataset, config, split", (("glue", "cola", "train"), ("glue", "cola", "test")))
15+
16+
@pytest.mark.parametrize(
17+
"dataset, config, split", (("glue", "cola", "train"), ("glue", "cola", "test"))
18+
)
1719
def test_read_for_datasets_requiring_config(self, dataset, config, split):
18-
huggingface_reader = HuggingfaceDatasetSplitReader(dataset_name=dataset, config_name=config)
20+
huggingface_reader = HuggingfaceDatasetReader(dataset_name=dataset, config_name=config)
1921
instances = list(huggingface_reader.read(split))
2022
assert len(instances) == len(huggingface_reader.datasets[split])
2123
print(instances[0], print(huggingface_reader.datasets[split][0]))

0 commit comments

Comments
 (0)