
Commit d082e55

Fix Doc mistake and the dataset availability check
Signed-off-by: Abhishek P (VMware) <[email protected]>
1 parent edf2681 commit d082e55

File tree

3 files changed, +26 -25 lines


CHANGELOG.md (+1 -1)
@@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 ## Unreleased
 
 ### Added
-- Add `HuggingfaceDatasetSplitReader` for using huggingface datasets in AllenNLP with limited support
+- Add `HuggingfaceDatasetReader` for using huggingface datasets in AllenNLP with limited support
 - Ported the following Huggingface `LambdaLR`-based schedulers: `ConstantLearningRateScheduler`, `ConstantWithWarmupLearningRateScheduler`, `CosineWithWarmupLearningRateScheduler`, `CosineHardRestartsWithWarmupLearningRateScheduler`.
 
 ### Changed

allennlp/data/dataset_readers/huggingface_datasets_reader.py (+20 -19)
@@ -1,9 +1,10 @@
+import typing
 from typing import Iterable, Optional
 
 from allennlp.data import DatasetReader, Token, Field, Tokenizer
 from allennlp.data.fields import TextField, LabelField, ListField
 from allennlp.data.instance import Instance
-from datasets import load_dataset, DatasetDict, Split
+from datasets import load_dataset, DatasetDict, Split, list_datasets
 from datasets.features import ClassLabel, Sequence, Translation, TranslationVariableLanguages
 from datasets.features import Value
 
@@ -43,15 +44,15 @@ class HuggingfaceDatasetReader(DatasetReader):
     # Parameters
 
     dataset_name : `str`
-        Name of the dataset from huggingface datasets the reader will be used for
-    config_name : `str`, optional (default=`None`)
-        Configuration(mandatory for some datasets) of the dataset
-    pre_load : `bool`, optional (default='False`)
+        Name of the dataset from huggingface datasets the reader will be used for.
+    config_name : `str`, optional (default=`None`)
+        Configuration(mandatory for some datasets) of the dataset.
+    preload : `bool`, optional (default=`False`)
         If `True` all splits for the dataset is loaded(includes download etc) as part of the initialization,
-        otherwise each split is loaded on when `read()` is used for the same for the first time
-    tokenizer : `Tokenizer`, optional (default=`None`)
-        If specified is used for tokenization of string and text fields from the dataset
-        This is useful since Text in allennlp is dealt with as a series of tokens.
+        otherwise each split is loaded on when `read()` is used for the same for the first time.
+    tokenizer : `Tokenizer`, optional (default=`None`)
+        If specified is used for tokenization of string and text fields from the dataset.
+        This is useful since text in allennlp is dealt with as a series of tokens.
     """
 
     SUPPORTED_SPLITS = [Split.TRAIN, Split.TEST, Split.VALIDATION]
@@ -60,7 +61,7 @@ def __init__(
         self,
         dataset_name: str = None,
         config_name: Optional[str] = None,
-        pre_load: Optional[bool] = False,
+        preload: Optional[bool] = False,
         tokenizer: Optional[Tokenizer] = None,
         **kwargs,
     ) -> None:
@@ -71,17 +72,17 @@ def __init__(
         )
 
         # It would be cleaner to create a separate reader object for diferent dataset
-        if dataset_name not in load_dataset():
-            raise NotImplementedError(
+        if dataset_name not in list_datasets():
+            raise ValueError(
                 f"Dataset {dataset_name} does not seem to available in huggingface datasets"
             )
         self.dataset: DatasetDict = DatasetDict()
         self.dataset_name = dataset_name
         self.config_name = config_name
         self.tokenizer = tokenizer
 
-        if pre_load:
-            load_dataset()
+        if preload:
+            self.load_dataset()
 
     def load_dataset(self):
         if self.config_name is not None:
@@ -152,7 +153,7 @@ def text_to_instance(self, *inputs) -> Instance:
         # TODO we need to support all different datasets features described
         # in https://huggingface.co/docs/datasets/features.html
         for feature in features:
-            fields_to_be_added = dict[str, Field]()
+            fields_to_be_added: typing.Dict[str, Field] = dict()
             item_field: Field
             field_list: list
             value = features[feature]
@@ -188,21 +189,21 @@ def text_to_instance(self, *inputs) -> Instance:
                 # We do not know if the string is token or text, we will assume text and make each a TextField
                 # datasets.features.Sequence of strings maps to ListField of TextField
                 if value.feature.dtype == "string":
-                    field_list = list[TextField]()
+                    field_list2: typing.List[TextField] = list()
                     for item in inputs[1][feature]:
                         # If tokenizer is provided we will use it to split it to tokens
                         # Else put whole text as a single token
-                        tokens: list[Token]
+                        tokens: typing.List[Token]
                        if self.tokenizer is not None:
                             tokens = self.tokenizer.tokenize(item)
 
                         else:
                             tokens = [Token(item)]
 
                         item_field = TextField(tokens)
-                        field_list.append(item_field)
+                        field_list2.append(item_field)
 
-                    fields_to_be_added[feature] = ListField(field_list)
+                    fields_to_be_added[feature] = ListField(field_list2)
 
                 # datasets Sequence of strings to ListField of LabelField
                 elif isinstance(value.feature, ClassLabel):
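For context, a minimal usage sketch of the constructor behavior this commit changes. The dataset and config names mirror the tests in this commit and the import path is inferred from the file path above; everything else is illustrative rather than the project's documented API.

    from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetReader

    # preload (renamed from pre_load) now actually calls self.load_dataset(),
    # downloading and loading every supported split during construction.
    reader = HuggingfaceDatasetReader(dataset_name="glue", config_name="cola", preload=True)

    # An unknown name is now checked against datasets.list_datasets() and rejected
    # with ValueError (previously the code called load_dataset() by mistake and
    # raised NotImplementedError).
    try:
        HuggingfaceDatasetReader(dataset_name="surely-such-a-dataset-cannot-exist")
    except ValueError as err:
        print(err)

The switch from `dict[str, Field]()` and `list[Token]` to `typing.Dict` and `typing.List` also matters for compatibility: subscripting the built-in `dict`/`list` types at runtime only works on Python 3.9+ (PEP 585), so the `typing` spellings keep the module importable on older interpreters.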

tests/data/dataset_readers/huggingface_datasets_test.py (+5 -5)
@@ -15,11 +15,7 @@ class HuggingfaceDatasetReaderTest:
 
     @pytest.mark.parametrize(
         "dataset, config, split",
-        (
-            ("glue", "cola", "train"),
-            ("glue", "cola", "test"),
-            ("universal_dependencies", "en_lines", "validation"),
-        ),
+        (("glue", "cola", "train"), ("glue", "cola", "test")),
     )
     def test_read(self, dataset, config, split):
         huggingface_reader = HuggingfaceDatasetReader(dataset_name=dataset, config_name=config)
@@ -75,3 +71,7 @@ def test_xnli_all_languages(self):
         # datasets.features.TranslationVariableLanguages into two fields each
         # For XNLI that means 3 fields become 5
         assert len(instance.fields) == 5
+
+    def test_non_available_dataset(self):
+        with pytest.raises(ValueError):
+            HuggingfaceDatasetReader(dataset_name="surely-such-a-dataset-cannot-exist")
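And a sketch of what the remaining parametrized cases exercise, assuming `read()` accepts the split name as the reader's docstring suggests; apart from the names visible in the diffs above, the details are illustrative.

    from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetReader

    # With the default preload=False, the split is downloaded and loaded lazily
    # on the first read() call for that split.
    reader = HuggingfaceDatasetReader(dataset_name="glue", config_name="cola")
    train_instances = list(reader.read("train"))
    print(len(train_instances))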
