
Commit 7292e02

fix bug with interleaving dataset reader (#5122)
* fix bug with interleaving dataset reader
* more tests
* Update allennlp/data/dataset_readers/interleaving_dataset_reader.py
* Update allennlp/data/dataset_readers/interleaving_dataset_reader.py
1 parent 6adba7e commit 7292e02

File tree

3 files changed: +61 −6 lines

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed a stall when using distributed training and gradient accumulation at the same time
 - Fixed an issue where using the `from_pretrained_transformer` `Vocabulary` constructor in distributed training via the `allennlp train` command
   would result in the data being iterated through unnecessarily.
+- Fixed a bug regarding token indexers with the `InterleavingDatasetReader` when used with multi-process data loading.
 - Fixed a warning from `transformers` when using `max_length` in the `PretrainedTransformerTokenizer`.
 
 ### Removed

allennlp/data/dataset_readers/interleaving_dataset_reader.py

Lines changed: 29 additions & 4 deletions
@@ -1,8 +1,15 @@
-from typing import Dict, Mapping, Iterable, Union
+from typing import Dict, Mapping, Iterable, Union, Optional
 import json
 
+from overrides import overrides
+
 from allennlp.common.checks import ConfigurationError
-from allennlp.data.dataset_readers.dataset_reader import DatasetReader, PathOrStr
+from allennlp.data.dataset_readers.dataset_reader import (
+    DatasetReader,
+    PathOrStr,
+    WorkerInfo,
+    DistributedInfo,
+)
 from allennlp.data.fields import MetadataField
 from allennlp.data.instance import Instance
 
@@ -52,6 +59,18 @@ def __init__(
             raise ConfigurationError(f"invalid scheme: {scheme}")
         self._scheme = scheme
 
+    @overrides
+    def _set_worker_info(self, info: Optional[WorkerInfo]) -> None:
+        super()._set_worker_info(info)
+        for reader in self._readers.values():
+            reader._set_worker_info(info)
+
+    @overrides
+    def _set_distributed_info(self, info: Optional[DistributedInfo]) -> None:
+        super()._set_distributed_info(info)
+        for reader in self._readers.values():
+            reader._set_distributed_info(info)
+
     def _read_round_robin(self, datasets: Mapping[str, Iterable[Instance]]) -> Iterable[Instance]:
         remaining = set(datasets)
         dataset_iterators = {key: iter(dataset) for key, dataset in datasets.items()}
@@ -72,6 +91,7 @@ def _read_all_at_once(self, datasets: Mapping[str, Iterable[Instance]]) -> Iterable[Instance]:
                 instance.fields[self._dataset_field_name] = MetadataField(key)
                 yield instance
 
+    @overrides
     def _read(self, file_path: Union[str, Dict[str, PathOrStr]]) -> Iterable[Instance]:
         if isinstance(file_path, str):
             try:
@@ -97,6 +117,11 @@ def _read(self, file_path: Union[str, Dict[str, PathOrStr]]) -> Iterable[Instance]:
         else:
             raise RuntimeError("impossible to get here")
 
-    def text_to_instance(self) -> Instance:  # type: ignore
+    @overrides
+    def text_to_instance(self, dataset_key: str, *args, **kwargs) -> Instance:  # type: ignore
+        return self._readers[dataset_key].text_to_instance(*args, **kwargs)  # type: ignore[call-arg]
 
-        raise RuntimeError("text_to_instance doesn't make sense here")
+    @overrides
+    def apply_token_indexers(self, instance: Instance) -> None:
+        dataset = instance.fields[self._dataset_field_name].metadata  # type: ignore[attr-defined]
+        self._readers[dataset].apply_token_indexers(instance)
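
Taken together, the reader changes implement a routing pattern: each instance is tagged with the key of the sub-reader that produced it (the MetadataField stored under self._dataset_field_name), apply_token_indexers later uses that key to hand the instance back to the matching sub-reader, and _set_worker_info / _set_distributed_info propagate sharding state down to the sub-readers. The sketch below is a minimal, self-contained illustration of that pattern only; Instance, LineReader, and ToyInterleavingReader are simplified stand-ins, not allennlp's actual classes.

from typing import Dict, Iterable, Iterator


class Instance:
    """Stand-in for allennlp's Instance: just a dict of fields."""

    def __init__(self, fields: Dict[str, object]) -> None:
        self.fields = fields


class LineReader:
    """Stand-in sub-reader: yields un-indexed instances, indexes later."""

    def read(self, lines: Iterable[str]) -> Iterator[Instance]:
        for line in lines:
            # As in the fixed PlainTextReader below: no indexers attached yet.
            yield Instance({"line": line.split()})

    def apply_token_indexers(self, instance: Instance) -> None:
        # In allennlp this would attach the reader's token indexers.
        instance.fields["indexed"] = True


class ToyInterleavingReader:
    """Mirrors the routing pattern from the diff above."""

    def __init__(self, readers: Dict[str, LineReader]) -> None:
        self._readers = readers
        self._dataset_field_name = "dataset"

    def read(self, inputs: Dict[str, Iterable[str]]) -> Iterator[Instance]:
        # "all_at_once" scheme only, for brevity; the real reader can also round-robin.
        for key, reader in self._readers.items():
            for instance in reader.read(inputs[key]):
                # Tag each instance with its source, as MetadataField does above.
                instance.fields[self._dataset_field_name] = key
                yield instance

    def apply_token_indexers(self, instance: Instance) -> None:
        # The fix in miniature: route to the sub-reader that made this instance.
        key = instance.fields[self._dataset_field_name]
        self._readers[key].apply_token_indexers(instance)


reader = ToyInterleavingReader({"a": LineReader(), "b": LineReader()})
for inst in reader.read({"a": ["x y"], "b": ["z"]}):
    reader.apply_token_indexers(inst)  # the data loader does this step in allennlp
    assert inst.fields["indexed"] and inst.fields["dataset"] in {"a", "b"}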

tests/data/dataset_readers/interleaving_dataset_reader_test.py

Lines changed: 31 additions & 2 deletions
@@ -1,11 +1,15 @@
 from typing import Iterable
 
+import pytest
+
 from allennlp.common.testing import AllenNlpTestCase
 from allennlp.data.dataset_readers import DatasetReader, InterleavingDatasetReader
+from allennlp.data.data_loaders import MultiProcessDataLoader
 from allennlp.data.fields import TextField
 from allennlp.data.instance import Instance
 from allennlp.data.token_indexers import SingleIdTokenIndexer
 from allennlp.data.tokenizers import SpacyTokenizer
+from allennlp.data.vocabulary import Vocabulary
 
 
 class PlainTextReader(DatasetReader):
@@ -20,9 +24,11 @@ def _read(self, file_path: str) -> Iterable[Instance]:
                 yield self.text_to_instance(line)
 
     def text_to_instance(self, line: str) -> Instance:  # type: ignore
-
         tokens = self._tokenizer.tokenize(line)
-        return Instance({"line": TextField(tokens, self._token_indexers)})
+        return Instance({"line": TextField(tokens)})
+
+    def apply_token_indexers(self, instance):
+        instance.fields["line"].token_indexers = self._token_indexers
 
 
 class TestInterleavingDatasetReader(AllenNlpTestCase):
@@ -72,3 +78,26 @@ def test_all_at_once(self):
 
         # should be in 3 buckets
         assert len(buckets) == 3
+
+    @pytest.mark.parametrize("lazy", (True, False))
+    def test_with_multi_process_loading(self, lazy):
+        readers = {"a": PlainTextReader(), "b": PlainTextReader(), "c": PlainTextReader()}
+        reader = InterleavingDatasetReader(readers)
+        data_dir = self.FIXTURES_ROOT / "data"
+        file_path = {
+            "a": data_dir / "babi.txt",
+            "b": data_dir / "conll2003.txt",
+            "c": data_dir / "conll2003.txt",
+        }
+        vocab = Vocabulary.from_instances(reader.read(file_path))
+        loader = MultiProcessDataLoader(
+            reader,
+            file_path,
+            num_workers=1,
+            batch_size=1,
+            max_instances_in_memory=2 if lazy else None,
+        )
+        loader.index_with(vocab)
+
+        list(loader.iter_instances())
+        list(loader)
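
The new test drives the reader through MultiProcessDataLoader with num_workers=1, so instances are produced in a worker process and shipped back via pickle; the usual motivation for deferring indexers to apply_token_indexers is that token indexers can hold large or unpicklable objects (e.g. transformers tokenizers) that should stay out of that per-instance serialization. A toy illustration of the constraint; Indexer here is a deliberately unpicklable stand-in, not an allennlp class.

import pickle


class Indexer:
    """Deliberately unpicklable stand-in for a heavyweight token indexer."""

    def __reduce__(self):
        raise TypeError("pretend this wraps an unpicklable tokenizer")


# Bare instances (no indexers attached) cross the process boundary cheaply:
instance = {"tokens": ["a", "b"]}
assert pickle.loads(pickle.dumps(instance)) == instance

# Attaching the indexer up front would make every instance expensive or
# impossible to ship between processes:
try:
    pickle.dumps({"tokens": ["a", "b"], "indexer": Indexer()})
except TypeError as err:
    print(err)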
