This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit ebdf5fa

Add HuggingfaceDatasetSplitReader for using Huggingface datasets
Added a new reader to allow for reading huggingface datasets as instances.
- Mapped limited `datasets.features` to `allennlp.data.fields`
- Verified for a selection of datasets and dataset configurations
- New dependency: "datasets==1.5.0"

Signed-off-by: Abhishek P (VMware) <[email protected]>
1 parent f82d3f1 commit ebdf5fa
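For context, a minimal usage sketch of the new reader (not part of the commit itself): the module path is taken from the accompanying test file, and `glue`/`cola` is one of the verified configurations listed in the reader's docstring. The dataset is downloaded via `datasets.load_dataset` when the reader is constructed.

from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetSplitReader

# Wraps a single split of a huggingface dataset; the dataset is downloaded on construction.
reader = HuggingfaceDatasetSplitReader(dataset_name="glue", config_name="cola", split="train")

# The file_path argument is unused by this reader, so None is passed (as in the accompanying tests).
for instance in reader.read(None):
    print(instance)
    break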

File tree

4 files changed: +233 -2 lines changed

CHANGELOG.md

+2-2
@@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## Unreleased

 ### Added
-
+- Add `HuggingfaceDatasetSplitReader` for using huggingface datasets in AllenNLP with limited support
 - Ported the following Huggingface `LambdaLR`-based schedulers: `ConstantLearningRateScheduler`, `ConstantWithWarmupLearningRateScheduler`, `CosineWithWarmupLearningRateScheduler`, `CosineHardRestartsWithWarmupLearningRateScheduler`.

 ### Changed
@@ -264,7 +264,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added sampler class and parameter in beam search for non-deterministic search, with several
   implementations, including `MultinomialSampler`, `TopKSampler`, `TopPSampler`, and
   `GumbelSampler`. Utilizing `GumbelSampler` will give [Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039).
-
+
 ### Changed

 - Pass batch metrics to `BatchCallback`.
allennlp/data/dataset_readers/huggingface_datasets_reader.py

+183
@@ -0,0 +1,183 @@
from typing import Iterable, Optional

from allennlp.data import DatasetReader, Token
from allennlp.data.fields import TextField, LabelField, ListField
from allennlp.data.instance import Instance
from datasets import load_dataset
from datasets.features import ClassLabel, Sequence, Translation, TranslationVariableLanguages
from datasets.features import Value


class HuggingfaceDatasetSplitReader(DatasetReader):
    """
    This reader wraps the huggingface datasets package to utilize its dataset management
    functionality and load the information in AllenNLP-friendly formats.
    Note: the reader works with only one split of the dataset,
    i.e. you would need to create a separate reader for each split.

    The following datasets and configurations have been verified and work with this reader:

    Dataset                    Dataset Configuration
    `xnli`                     `ar`
    `xnli`                     `en`
    `xnli`                     `de`
    `xnli`                     `all_languages`
    `glue`                     `cola`
    `glue`                     `mrpc`
    `glue`                     `sst2`
    `glue`                     `qqp`
    `glue`                     `mnli`
    `glue`                     `mnli_matched`
    `universal_dependencies`   `en_lines`
    `universal_dependencies`   `ko_kaist`
    `universal_dependencies`   `af_afribooms`
    `afrikaans_ner_corpus`     `NA`
    `swahili`                  `NA`
    `conll2003`                `NA`
    `dbpedia_14`               `NA`
    `trec`                     `NA`
    `emotion`                  `NA`
    """

    def __init__(
        self,
        max_instances: Optional[int] = None,
        manual_distributed_sharding: bool = False,
        manual_multiprocess_sharding: bool = False,
        serialization_dir: Optional[str] = None,
        dataset_name: Optional[str] = None,
        split: str = "train",
        config_name: Optional[str] = None,
    ) -> None:
        super().__init__(
            max_instances,
            manual_distributed_sharding,
            manual_multiprocess_sharding,
            serialization_dir,
        )

        # It would be cleaner to create a separate reader object for each dataset
        self.dataset = None
        self.dataset_name = dataset_name
        self.config_name = config_name
        self.index = -1

        if config_name:
            self.dataset = load_dataset(self.dataset_name, self.config_name, split=split)
        else:
            self.dataset = load_dataset(self.dataset_name, split=split)

    def _read(self, file_path) -> Iterable[Instance]:
        """
        Reads the dataset and converts each entry to an AllenNLP-friendly instance.
        """
        for entry in self.dataset:
            yield self.text_to_instance(entry)

    def text_to_instance(self, *inputs) -> Instance:
        """
        Takes care of converting a dataset entry into an AllenNLP-friendly instance.
        Currently this is implemented in an ad-hoc, catch-up manner:
        only the `datasets.features` required for the supported datasets are converted.
        Ideally we would deliberately decide how each `datasets.features` type maps to an
        `allennlp.data.fields` type and then convert accordingly; doing that would give the
        best chance of covering the largest possible set of datasets.

        Currently `datasets.features` types are mapped to AllenNLP fields as follows:

        dataset.feature type             allennlp.data.fields
        `ClassLabel`                     `LabelField` in feature name namespace
        `Value.string`                   `TextField` with value as Token
        `Value.*`                        `LabelField` with value being label in feature name namespace
        `Sequence.string`                `ListField` of `TextField` with individual string as token
        `Sequence.ClassLabel`            `ListField` of `ClassLabel` in feature name namespace
        `Translation`                    `ListField` of 2 ListField (ClassLabel and TextField)
        `TranslationVariableLanguages`   `ListField` of 2 ListField (ClassLabel and TextField)
        """

        # Features indicate the different kinds of information available in each dataset entry;
        # feature types describe what type of information they hold.
        # e.g. in a sentiment dataset an entry could have one feature (of type string) holding the text
        # and another (of type int32/ClassLabel) holding the sentiment.
        features = self.dataset.features
        fields = dict()

        # TODO we need to support all the different dataset features described
        # in https://huggingface.co/docs/datasets/features.html
        for feature in features:
            value = features[feature]

            # datasets ClassLabel maps to LabelField
            if isinstance(value, ClassLabel):
                field = LabelField(inputs[0][feature], label_namespace=feature, skip_indexing=True)

            # datasets Value can be of different types
            elif isinstance(value, Value):

                # String value maps to TextField
                if value.dtype == "string":
                    # Since a TextField has to be made of Tokens, add the whole text as a single token
                    # TODO Should we use simple heuristics to identify what is a token and what is not?
                    field = TextField([Token(inputs[0][feature])])

                else:
                    field = LabelField(
                        inputs[0][feature], label_namespace=feature, skip_indexing=True
                    )

            elif isinstance(value, Sequence):
                # datasets Sequence of strings maps to a ListField of TextField
                if value.feature.dtype == "string":
                    field_list = list()
                    for item in inputs[0][feature]:
                        item_field = TextField([Token(item)])
                        field_list.append(item_field)
                    if len(field_list) == 0:
                        continue
                    field = ListField(field_list)

                # datasets Sequence of ClassLabel maps to a ListField of LabelField
                elif isinstance(value.feature, ClassLabel):
                    field_list = list()
                    for item in inputs[0][feature]:
                        item_field = LabelField(
                            label=item, label_namespace=feature, skip_indexing=True
                        )
                        field_list.append(item_field)
                    if len(field_list) == 0:
                        continue
                    field = ListField(field_list)

            # datasets.Translation cannot be mapped directly,
            # but its dict structure can be mapped to a ListField of 2 ListField
            elif isinstance(value, Translation):
                if value.dtype == "dict":
                    input_dict = inputs[0][feature]
                    langs = list(input_dict.keys())
                    field_langs = [LabelField(lang, label_namespace="languages") for lang in langs]
                    langs_field = ListField(field_langs)
                    texts = list()
                    for lang in langs:
                        texts.append(TextField([Token(input_dict[lang])]))
                    field = ListField([langs_field, ListField(texts)])

            # datasets.TranslationVariableLanguages
            # is functionally a pair of lists and hence is mapped to a ListField of 2 ListField
            elif isinstance(value, TranslationVariableLanguages):
                if value.dtype == "dict":
                    input_dict = inputs[0][feature]
                    langs = input_dict["language"]
                    field_langs = [LabelField(lang, label_namespace="languages") for lang in langs]
                    langs_field = ListField(field_langs)
                    texts = list()
                    for lang in langs:
                        index = langs.index(lang)
                        texts.append(TextField([Token(input_dict["translation"][index])]))
                    field = ListField([langs_field, ListField(texts)])

            else:
                raise ValueError(f"Datasets feature type {type(value)} is not supported yet.")

            fields[feature] = field

        return Instance(fields)
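To make the mapping table above concrete, here is an illustrative sketch (not part of the commit) of what `text_to_instance` produces for a `glue`/`cola` entry, assuming that dataset's usual schema of `sentence` (string), `label` (ClassLabel), and `idx` (int32), and reusing the `reader` from the usage sketch near the top:

# Hypothetical glue/cola entry, for illustration only.
entry = {"sentence": "The quick brown fox jumps over the lazy dog.", "label": 1, "idx": 0}
instance = reader.text_to_instance(entry)

# Per the table in the docstring:
#   "sentence" (Value "string")  -> TextField with the whole sentence as a single Token
#   "label"    (ClassLabel)      -> LabelField in the "label" namespace (skip_indexing=True)
#   "idx"      (Value "int32")   -> LabelField in the "idx" namespace (skip_indexing=True)
print(instance.fields.keys())  # dict_keys(['sentence', 'label', 'idx'])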

setup.py

+1
@@ -73,6 +73,7 @@
         "lmdb",
         "more-itertools",
         "wandb>=0.10.0,<0.11.0",
+        "datasets==1.5.0",
     ],
     entry_points={"console_scripts": ["allennlp=allennlp.__main__:run"]},
     include_package_data=True,
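Since the new dependency is pinned to an exact version, a quick sanity check (a sketch, not part of the commit) that the installed `datasets` package matches the pin in setup.py:

import datasets

# The reader was only verified against this exact release, per the setup.py pin above.
assert datasets.__version__ == "1.5.0"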
+47
@@ -0,0 +1,47 @@
import pytest

from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetSplitReader
import logging

logger = logging.getLogger(__name__)


# TODO these UTs are actually downloading the datasets and will be very slow
# TODO add a UT where we compare the huggingface wrapped reader with an explicitly coded builder
class HuggingfaceDatasetSplitReaderTest:

    SUPPORTED_DATASETS_WITHOUT_CONFIG = [
        "afrikaans_ner_corpus",
        "dbpedia_14",
        "trec",
        "swahili",
        "conll2003",
        "emotion",
    ]

    """
    Running the tests for supported datasets which do not require a config name to be specified
    """

    @pytest.mark.parametrize("dataset", SUPPORTED_DATASETS_WITHOUT_CONFIG)
    def test_read_for_datasets_without_config(self, dataset):
        huggingface_reader = HuggingfaceDatasetSplitReader(dataset_name=dataset)
        instances = list(huggingface_reader.read(None))
        assert len(instances) == len(huggingface_reader.dataset)

    # Not testing for all configurations, only some
    SUPPORTED_DATASET_CONFIGURATION = (
        ("glue", "cola"),
        ("universal_dependencies", "af_afribooms"),
        ("xnli", "all_languages"),
    )

    """
    Running the tests for supported datasets which require a config name to be specified
    """

    @pytest.mark.parametrize("dataset, config", SUPPORTED_DATASET_CONFIGURATION)
    def test_read_for_datasets_requiring_config(self, dataset, config):
        huggingface_reader = HuggingfaceDatasetSplitReader(dataset_name=dataset, config_name=config)
        instances = list(huggingface_reader.read(None))
        assert len(instances) == len(huggingface_reader.dataset)
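Because these tests download real datasets, a lighter manual spot-check of the same assertion can be handy; a sketch using `trec`, one of the verified datasets that needs no config:

from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetSplitReader

# Mirrors the assertion in the tests above for a single no-config dataset.
reader = HuggingfaceDatasetSplitReader(dataset_name="trec")
instances = list(reader.read(None))
assert len(instances) == len(reader.dataset)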
