Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 6adba7e

Browse files
dirkgrepwalsh
authored and committed
Predicting with a dataset reader on a multitask model (#5115)
* Create a way to use allennlp predict with a dataset and a multitask model * Fix type ignoration * Changelog * Fix to the predictor
1 parent 900914a commit 6adba7e

File tree

5 files changed

+94
-50
lines changed

5 files changed

+94
-50
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1212
- Ported the following Huggingface `LambdaLR`-based schedulers: `ConstantLearningRateScheduler`, `ConstantWithWarmupLearningRateScheduler`, `CosineWithWarmupLearningRateScheduler`, `CosineHardRestartsWithWarmupLearningRateScheduler`.
1313
- Added a T5 implementation to `modules.transformers`.
1414
- Added new `sub_token_mode` parameter to `pretrained_transformer_mismatched_embedder` class to support first sub-token embedding
15+
- Added a way to run a multi task model with a dataset reader as part of `allennlp predict`.
1516
- Added new `eval_mode` in `PretrainedTransformerEmbedder`. If it is set to `True`, the transformer is _always_ run in evaluation mode, which, e.g., disables dropout and does not update batch normalization statistics.
1617
- Added additional parameters to the W&B callback: `entity`, `group`, `name`, `notes`, and `wandb_kwargs`.
1718

allennlp/commands/predict.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
or dataset to JSON predictions using a trained model and its
44
[`Predictor`](../predictors/predictor.md#predictor) wrapper.
55
"""
6-
76
from typing import List, Iterator, Optional
87
import argparse
98
import sys
@@ -16,6 +15,7 @@
1615
from allennlp.common.checks import check_for_gpu, ConfigurationError
1716
from allennlp.common.file_utils import cached_path
1817
from allennlp.common.util import lazy_groups_of
18+
from allennlp.data.dataset_readers import MultiTaskDatasetReader
1919
from allennlp.models.archival import load_archive
2020
from allennlp.predictors.predictor import Predictor, JsonDict
2121
from allennlp.data import Instance
@@ -73,6 +73,14 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument
7373
"flag is set.",
7474
)
7575

76+
subparser.add_argument(
77+
"--multitask-head",
78+
type=str,
79+
default=None,
80+
help="If you are using a dataset reader to make predictions, and the model is a"
81+
"multitask model, you have to specify the name of the model head to use here.",
82+
)
83+
7684
subparser.add_argument(
7785
"-o",
7886
"--overrides",
@@ -144,6 +152,7 @@ def __init__(
144152
batch_size: int,
145153
print_to_console: bool,
146154
has_dataset_reader: bool,
155+
multitask_head: Optional[str] = None,
147156
) -> None:
148157
self._predictor = predictor
149158
self._input_file = input_file
@@ -152,6 +161,24 @@ def __init__(
152161
self._print_to_console = print_to_console
153162
self._dataset_reader = None if not has_dataset_reader else predictor._dataset_reader
154163

164+
self._multitask_head = multitask_head
165+
if self._multitask_head is not None:
166+
if self._dataset_reader is None:
167+
raise ConfigurationError(
168+
"You must use a dataset reader when using --multitask-head."
169+
)
170+
if not isinstance(self._dataset_reader, MultiTaskDatasetReader):
171+
raise ConfigurationError(
172+
"--multitask-head only works with a multitask dataset reader."
173+
)
174+
if (
175+
isinstance(self._dataset_reader, MultiTaskDatasetReader)
176+
and self._multitask_head is None
177+
):
178+
raise ConfigurationError(
179+
"You must specify --multitask-head when using a multitask dataset reader."
180+
)
181+
155182
def _predict_json(self, batch_data: List[JsonDict]) -> Iterator[str]:
156183
if len(batch_data) == 1:
157184
results = [self._predictor.predict_json(batch_data[0])]
@@ -196,7 +223,15 @@ def _get_instance_data(self) -> Iterator[Instance]:
196223
elif self._dataset_reader is None:
197224
raise ConfigurationError("To generate instances directly, pass a DatasetReader.")
198225
else:
199-
yield from self._dataset_reader.read(self._input_file)
226+
if isinstance(self._dataset_reader, MultiTaskDatasetReader):
227+
assert (
228+
self._multitask_head is not None
229+
) # This is properly checked by the constructor.
230+
yield from self._dataset_reader.read(
231+
self._input_file, force_task=self._multitask_head
232+
)
233+
else:
234+
yield from self._dataset_reader.read(self._input_file)
200235

201236
def run(self) -> None:
202237
has_reader = self._dataset_reader is not None
@@ -235,5 +270,6 @@ def _predict(args: argparse.Namespace) -> None:
235270
args.batch_size,
236271
not args.silent,
237272
args.use_dataset_reader,
273+
args.multitask_head,
238274
)
239275
manager.run()

allennlp/data/data_loaders/multitask_data_loader.py

Lines changed: 1 addition & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,6 @@
66
from overrides import overrides
77

88
from allennlp.common import util
9-
from allennlp.data.dataset_readers.dataset_reader import (
10-
DatasetReader,
11-
DatasetReaderInput,
12-
WorkerInfo,
13-
)
149
from allennlp.data.batch import Batch
1510
from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict
1611
from allennlp.data.data_loaders.multiprocess_data_loader import MultiProcessDataLoader
@@ -251,7 +246,7 @@ def _get_instances_for_epoch(self) -> Dict[str, Iterable[Instance]]:
251246

252247
def _make_data_loader(self, key: str) -> MultiProcessDataLoader:
253248
kwargs: Dict[str, Any] = {
254-
"reader": _MultitaskDatasetReaderShim(self.readers[key], key),
249+
"reader": self.readers[key],
255250
"data_path": self.data_paths[key],
256251
# We don't load batches from this data loader, only instances, but we have to set
257252
# something for the batch size, so we set 1.
@@ -264,39 +259,3 @@ def _make_data_loader(self, key: str) -> MultiProcessDataLoader:
264259
if key in self._start_method:
265260
kwargs["start_method"] = self._start_method[key]
266261
return MultiProcessDataLoader(**kwargs)
267-
268-
269-
@DatasetReader.register("multitask_shim")
270-
class _MultitaskDatasetReaderShim(DatasetReader):
271-
"""This dataset reader wraps another dataset reader and adds the name of the "task" into
272-
each instance as a metadata field. This exists only to support `MultitaskDataLoader`. You
273-
should not have to use this yourself."""
274-
275-
def __init__(self, inner: DatasetReader, head: str, **kwargs):
276-
super().__init__(**kwargs)
277-
self.inner = inner
278-
self.head = head
279-
280-
def _set_worker_info(self, info: Optional[WorkerInfo]) -> None:
281-
"""
282-
Should only be used internally.
283-
"""
284-
super()._set_worker_info(info)
285-
self.inner._set_worker_info(info)
286-
287-
def read(self, file_path: DatasetReaderInput) -> Iterator[Instance]:
288-
from allennlp.data.fields import MetadataField
289-
290-
for instance in self.inner.read(file_path):
291-
instance.add_field("task", MetadataField(self.head))
292-
yield instance
293-
294-
def text_to_instance(self, *inputs) -> Instance:
295-
from allennlp.data.fields import MetadataField
296-
297-
instance = self.inner.text_to_instance(*inputs)
298-
instance.add_field("task", MetadataField(self.head))
299-
return instance
300-
301-
def apply_token_indexers(self, instance: Instance) -> None:
302-
self.inner.apply_token_indexers(instance)

allennlp/data/dataset_readers/multitask.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
from os import PathLike
2-
from typing import Dict, Iterator, Union
2+
from typing import Dict, Iterator, Union, Optional
33

44
from allennlp.data.instance import Instance
5-
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
5+
from allennlp.data.dataset_readers.dataset_reader import (
6+
DatasetReader,
7+
WorkerInfo,
8+
DatasetReaderInput,
9+
)
610

711

812
@DatasetReader.register("multitask")
@@ -25,7 +29,51 @@ class MultiTaskDatasetReader(DatasetReader):
2529
"""
2630

2731
def __init__(self, readers: Dict[str, DatasetReader]) -> None:
    """
    Wrap every task's reader in a `_MultitaskDatasetReaderShim`, so that each
    instance produced is tagged with the name of the head it belongs to.
    """
    self.readers = {
        name: _MultitaskDatasetReaderShim(reader, name) for name, reader in readers.items()
    }
2935

30-
def read(self, file_paths: Dict[str, Union[PathLike, str]]) -> Dict[str, Iterator[Instance]]: # type: ignore
31-
raise RuntimeError("This class is not designed to be called like this")
36+
def read( # type: ignore
37+
self,
38+
file_paths: Union[PathLike, str, Dict[str, Union[PathLike, str]]],
39+
*,
40+
force_task: Optional[str] = None
41+
) -> Union[Iterator[Instance], Dict[str, Iterator[Instance]]]:
42+
if force_task is None:
43+
raise RuntimeError("This class is not designed to be called like this.")
44+
return self.readers[force_task].read(file_paths)
45+
46+
47+
@DatasetReader.register("multitask_shim")
class _MultitaskDatasetReaderShim(DatasetReader):
    """
    Wraps another dataset reader and attaches the name of its "task" to every
    instance as a metadata field. This exists only to support
    `MultiTaskDatasetReader`; you should not have to use it yourself.
    """

    def __init__(self, inner: DatasetReader, head: str, **kwargs):
        super().__init__(**kwargs)
        # The wrapped reader, and the name of the model head its instances feed.
        self.inner = inner
        self.head = head

    def _set_worker_info(self, info: Optional[WorkerInfo]) -> None:
        """
        Should only be used internally.
        """
        # Propagate worker sharding info to the wrapped reader as well.
        super()._set_worker_info(info)
        self.inner._set_worker_info(info)

    def read(self, file_path: DatasetReaderInput) -> Iterator[Instance]:
        # Imported lazily to avoid a circular import at module load time.
        from allennlp.data.fields import MetadataField

        for wrapped_instance in self.inner.read(file_path):
            wrapped_instance.add_field("task", MetadataField(self.head))
            yield wrapped_instance

    def text_to_instance(self, *inputs) -> Instance:
        # Imported lazily to avoid a circular import at module load time.
        from allennlp.data.fields import MetadataField

        tagged = self.inner.text_to_instance(*inputs)
        tagged.add_field("task", MetadataField(self.head))
        return tagged

    def apply_token_indexers(self, instance: Instance) -> None:
        # Token indexing is entirely the wrapped reader's responsibility.
        self.inner.apply_token_indexers(instance)

allennlp/predictors/multitask.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __init__(self, model: MultiTaskModel, dataset_reader: MultiTaskDatasetReader
5252
predictor_class: Type[Predictor] = (
5353
Predictor.by_name(predictor_name) if predictor_name is not None else Predictor # type: ignore
5454
)
55-
self.predictors[name] = predictor_class(model, dataset_reader.readers[name])
55+
self.predictors[name] = predictor_class(model, dataset_reader.readers[name].inner)
5656

5757
@overrides
5858
def predict_instance(self, instance: Instance) -> JsonDict:

0 commit comments

Comments (0)