This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 191b641

make existing readers work with multi-process loading (#4597)
* make existing readers work with multi-process loading
* add 'overrides' decorator
* call apply_token_indexers in predictor
* clean up
* fix tests
1 parent d7124d4 commit 191b641

File tree

10 files changed: +53, -14 lines changed

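The pattern is the same in every reader below: text_to_instance now builds its TextFields without token indexers, so instances stay lightweight when data-loader worker processes pickle them back to the main process, and the new apply_token_indexers hook attaches the reader's indexers afterwards. A minimal sketch of that pattern for a custom reader (the MyReader class, its registered name, and the one-example-per-line format are illustrative, not part of this commit):

from typing import Dict, Iterable

from overrides import overrides

from allennlp.data import DatasetReader, Instance, Token, TokenIndexer
from allennlp.data.fields import TextField


@DatasetReader.register("my-reader")  # hypothetical reader, for illustration only
class MyReader(DatasetReader):
    def __init__(self, token_indexers: Dict[str, TokenIndexer], **kwargs) -> None:
        super().__init__(**kwargs)
        self._token_indexers = token_indexers

    def _read(self, file_path: str) -> Iterable[Instance]:
        # Assumes one whitespace-tokenized example per line.
        with open(file_path) as data_file:
            for line in data_file:
                yield self.text_to_instance(line.strip())

    def text_to_instance(self, text: str) -> Instance:  # type: ignore
        # No indexers passed here: the field is created unindexed.
        tokens = [Token(word) for word in text.split()]
        return Instance({"tokens": TextField(tokens)})

    @overrides
    def apply_token_indexers(self, instance: Instance) -> None:
        # Runs in the main process, before indexing and batching.
        instance.fields["tokens"]._token_indexers = self._token_indexers  # type: ignore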

allennlp/data/dataset_readers/babi.py (+16, -9)

@@ -85,22 +85,29 @@ def text_to_instance(
 
         if self._keep_sentences:
             context_field_ks = ListField(
-                [
-                    TextField([Token(word) for word in line], self._token_indexers)
-                    for line in context
-                ]
+                [TextField([Token(word) for word in line]) for line in context]
             )
 
             fields["supports"] = ListField(
                 [IndexField(support, context_field_ks) for support in supports]
             )
         else:
-            context_field = TextField(
-                [Token(word) for line in context for word in line], self._token_indexers
-            )
+            context_field = TextField([Token(word) for line in context for word in line])
 
         fields["context"] = context_field_ks if self._keep_sentences else context_field
-        fields["question"] = TextField([Token(word) for word in question], self._token_indexers)
-        fields["answer"] = TextField([Token(answer)], self._token_indexers)
+        fields["question"] = TextField(
+            [Token(word) for word in question],
+        )
+        fields["answer"] = TextField([Token(answer)])
 
         return Instance(fields)
+
+    @overrides
+    def apply_token_indexers(self, instance: Instance) -> None:
+        if self._keep_sentences:
+            for text_field in instance.fields["context"]:  # type: ignore
+                text_field._token_indexers = self._token_indexers  # type: ignore
+        else:
+            instance.fields["context"]._token_indexers = self._token_indexers  # type: ignore
+        instance.fields["question"]._token_indexers = self._token_indexers  # type: ignore
+        instance.fields["answer"]._token_indexers = self._token_indexers  # type: ignore

allennlp/data/dataset_readers/conll2003.py (+5, -1)

@@ -143,7 +143,7 @@ def text_to_instance(  # type: ignore
         We take `pre-tokenized` input here, because we don't have a tokenizer in this class.
         """
 
-        sequence = TextField(tokens, self._token_indexers)
+        sequence = TextField(tokens)
         instance_fields: Dict[str, Field] = {"tokens": sequence}
         instance_fields["metadata"] = MetadataField({"words": [x.text for x in tokens]})
 
@@ -198,3 +198,7 @@ def text_to_instance(  # type: ignore
             )
 
         return Instance(instance_fields)
+
+    @overrides
+    def apply_token_indexers(self, instance: Instance) -> None:
+        instance.fields["tokens"]._token_indexers = self._token_indexers  # type: ignore

allennlp/data/dataset_readers/nlvr2_reader.py (+5, -1)

@@ -184,7 +184,7 @@ def text_to_instance(
         only_predictions: bool = False,
     ) -> Instance:
         tokenized_sentence = self._tokenizer.tokenize(question)
-        sentence_field = TextField(tokenized_sentence, self._token_indexers)
+        sentence_field = TextField(tokenized_sentence)
 
         original_identifier = identifier
         all_boxes = []
@@ -220,3 +220,7 @@ def text_to_instance(
         if denotation is not None:
             fields["denotation"] = LabelField(int(denotation), skip_indexing=True)
         return Instance(fields)
+
+    @overrides
+    def apply_token_indexers(self, instance: Instance) -> None:
+        instance.fields["sentence_field"]._token_indexers = self._token_indexers  # type: ignore

allennlp/data/dataset_readers/sequence_tagging.py (+5, -1)

@@ -86,9 +86,13 @@ def text_to_instance(  # type: ignore
         """
 
         fields: Dict[str, Field] = {}
-        sequence = TextField(tokens, self._token_indexers)
+        sequence = TextField(tokens)
         fields["tokens"] = sequence
         fields["metadata"] = MetadataField({"words": [x.text for x in tokens]})
         if tags is not None:
             fields["tags"] = SequenceLabelField(tags, sequence)
         return Instance(fields)
+
+    @overrides
+    def apply_token_indexers(self, instance: Instance) -> None:
+        instance.fields["tokens"]._token_indexers = self._token_indexers  # type: ignore

allennlp/data/dataset_readers/text_classification_json.py (+10, -2)

@@ -124,13 +124,21 @@ def text_to_instance(
                 word_tokens = self._tokenizer.tokenize(sentence)
                 if self._max_sequence_length is not None:
                     word_tokens = self._truncate(word_tokens)
-                sentences.append(TextField(word_tokens, self._token_indexers))
+                sentences.append(TextField(word_tokens))
             fields["tokens"] = ListField(sentences)
         else:
             tokens = self._tokenizer.tokenize(text)
             if self._max_sequence_length is not None:
                 tokens = self._truncate(tokens)
-            fields["tokens"] = TextField(tokens, self._token_indexers)
+            fields["tokens"] = TextField(tokens)
         if label is not None:
             fields["label"] = LabelField(label, skip_indexing=self._skip_label_indexing)
         return Instance(fields)
+
+    @overrides
+    def apply_token_indexers(self, instance: Instance) -> None:
+        if self._segment_sentences:
+            for text_field in instance.fields["tokens"]:  # type: ignore
+                text_field._token_indexers = self._token_indexers
+        else:
+            instance.fields["tokens"]._token_indexers = self._token_indexers  # type: ignore

allennlp/interpret/attackers/hotflip.py (+1)

@@ -194,6 +194,7 @@ def attack_from_json(
         whatever it was to `"she"`.
         """
         instance = self.predictor._json_to_instance(inputs)
+        self.predictor._dataset_reader.apply_token_indexers(instance)
         if target is None:
             output_dict = self.predictor._model.forward_on_instance(instance)
         else:

allennlp/predictors/predictor.py (+7)

@@ -61,6 +61,7 @@ def json_to_labeled_instances(self, inputs: JsonDict) -> List[Instance]:
         """
 
         instance = self._json_to_instance(inputs)
+        self._dataset_reader.apply_token_indexers(instance)
         outputs = self._model.forward_on_instance(instance)
         new_instances = self.predictions_to_labeled_instances(instance, outputs)
         return new_instances
@@ -98,6 +99,9 @@ def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict
         embedding_gradients: List[Tensor] = []
         hooks: List[RemovableHandle] = self._register_embedding_gradient_hooks(embedding_gradients)
 
+        for instance in instances:
+            self._dataset_reader.apply_token_indexers(instance)
+
         dataset = Batch(instances)
         dataset.index_instances(self._model.vocab)
         dataset_tensor_dict = util.move_to_device(dataset.as_tensor_dict(), self.cuda_device)
@@ -181,6 +185,7 @@ def _add_output(mod, _, outputs):
             hook.remove()
 
     def predict_instance(self, instance: Instance) -> JsonDict:
+        self._dataset_reader.apply_token_indexers(instance)
         outputs = self._model.forward_on_instance(instance)
         return sanitize(outputs)
 
@@ -212,6 +217,8 @@ def predict_batch_json(self, inputs: List[JsonDict]) -> List[JsonDict]:
         return self.predict_batch_instance(instances)
 
     def predict_batch_instance(self, instances: List[Instance]) -> List[JsonDict]:
+        for instance in instances:
+            self._dataset_reader.apply_token_indexers(instance)
         outputs = self._model.forward_on_instances(instances)
         return sanitize(outputs)
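The division of labor after this change: _json_to_instance still returns an unindexed instance, and the public entry points (predict_instance, predict_batch_instance, get_gradients, json_to_labeled_instances) apply the reader's token indexers before the model runs. Code that bypasses those entry points must do the same step itself, as the updated tests below show. A sketch of the calling pattern, where inputs is a JSON dict appropriate to the model:

instance = predictor._json_to_instance(inputs)
predictor._dataset_reader.apply_token_indexers(instance)
outputs = predictor._model.forward_on_instance(instance)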

tests/predictors/predictor_test.py (+2)

@@ -45,6 +45,7 @@ def test_get_gradients(self):
         predictor = Predictor.from_archive(archive)
 
         instance = predictor._json_to_instance(inputs)
+        predictor._dataset_reader.apply_token_indexers(instance)
        outputs = predictor._model.forward_on_instance(instance)
        labeled_instances = predictor.predictions_to_labeled_instances(instance, outputs)
        for instance in labeled_instances:
@@ -70,6 +71,7 @@ def test_get_gradients_when_requires_grad_is_false(self):
         embedding_layer = util.find_embedding_layer(predictor._model)
         assert not embedding_layer.weight.requires_grad
         instance = predictor._json_to_instance(inputs)
+        predictor._dataset_reader.apply_token_indexers(instance)
         outputs = predictor._model.forward_on_instance(instance)
         labeled_instances = predictor.predictions_to_labeled_instances(instance, outputs)
         # ensure that gradients are always present, despite requires_grad being false on the embedding layer

tests/predictors/sentence_tagger_test.py (+1)

@@ -13,6 +13,7 @@ def test_predictions_to_labeled_instances(self):
         predictor = Predictor.from_archive(archive, "sentence_tagger")
 
         instance = predictor._json_to_instance(inputs)
+        predictor._dataset_reader.apply_token_indexers(instance)
         outputs = predictor._model.forward_on_instance(instance)
         new_instances = predictor.predictions_to_labeled_instances(instance, outputs)
         assert len(new_instances) > 1

tests/predictors/text_classifier_test.py (+1)

@@ -90,6 +90,7 @@ def test_predictions_to_labeled_instances(self):
         predictor = Predictor.from_archive(archive, "text_classifier")
 
         instance = predictor._json_to_instance(inputs)
+        predictor._dataset_reader.apply_token_indexers(instance)
         outputs = predictor._model.forward_on_instance(instance)
         new_instances = predictor.predictions_to_labeled_instances(instance, outputs)
         assert "label" in new_instances[0].fields
