Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit f2824fd

Browse files
authored
Predictors for demo LMs, update for coref predictor (#3202)
* predictors for demo LMs, other fixes * added test * More tests * Add missing method * Add docstring * Fix decode methods * pylint, mypy * more pylint... * docs
1 parent d78ac70 commit f2824fd

12 files changed

+234
-19
lines changed

allennlp/models/masked_language_model.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,10 @@ def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor
145145
for mask_positions in instance_indices])
146146
output_dict["words"] = top_words
147147
tokens = []
148-
for instance_indices in output_dict['token_ids']:
149-
tokens.append([[self.vocab.get_token_from_index(token_id.item(),
150-
namespace=self._target_namespace)
151-
for token_id in token_ids]
152-
for token_ids in instance_indices])
148+
for instance_tokens in output_dict['token_ids']:
149+
tokens.append([self.vocab.get_token_from_index(token_id.item(),
150+
namespace=self._target_namespace)
151+
for token_id in instance_tokens])
153152
output_dict["tokens"] = tokens
154153

155154
return output_dict

allennlp/models/next_token_lm.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,10 @@ def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor
116116
for index in instance_indices]])
117117
output_dict["words"] = top_words
118118
tokens = []
119-
for instance_indices in output_dict['token_ids']:
120-
tokens.append([[self.vocab.get_token_from_index(token_id.item(),
121-
namespace=self._target_namespace)
122-
for token_id in token_ids]
123-
for token_ids in instance_indices])
119+
for instance_tokens in output_dict['token_ids']:
120+
tokens.append([self.vocab.get_token_from_index(token_id.item(),
121+
namespace=self._target_namespace)
122+
for token_id in instance_tokens])
124123
output_dict["tokens"] = tokens
125124

126125
return output_dict

allennlp/nn/util.py

+24
Original file line numberDiff line numberDiff line change
@@ -1472,3 +1472,27 @@ def inspect_parameters(module: torch.nn.Module, quiet: bool = False) -> Dict[str
14721472
if not quiet:
14731473
print(json.dumps(results, indent=4))
14741474
return results
1475+
1476+
1477+
def find_embedding_layer(model: torch.nn.Module) -> torch.nn.Module:
1478+
"""
1479+
Takes a model (typically an AllenNLP ``Model``, but this works for any ``torch.nn.Module``) and
1480+
makes a best guess about which module is the embedding layer. For typical AllenNLP models,
1481+
this often is the ``TextFieldEmbedder``, but if you're using a pre-trained contextualizer, we
1482+
really want layer 0 of that contextualizer, not the output. So there are a bunch of hacks in
1483+
here for specific pre-trained contextualizers.
1484+
"""
1485+
# We'll look for a few special cases in a first pass, then fall back to just finding a
1486+
# TextFieldEmbedder in a second pass if we didn't find a special case.
1487+
from pytorch_pretrained_bert.modeling import BertEmbeddings
1488+
from pytorch_transformers.modeling_gpt2 import GPT2Model
1489+
from allennlp.modules.text_field_embedders.text_field_embedder import TextFieldEmbedder
1490+
for module in model.modules():
1491+
if isinstance(module, BertEmbeddings):
1492+
return module.word_embeddings
1493+
if isinstance(module, GPT2Model):
1494+
return module.wte
1495+
for module in model.modules():
1496+
if isinstance(module, TextFieldEmbedder):
1497+
return module
1498+
raise RuntimeError("No embedding module found!")

allennlp/predictors/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from allennlp.predictors.decomposable_attention import DecomposableAttentionPredictor
1616
from allennlp.predictors.dialog_qa import DialogQAPredictor
1717
from allennlp.predictors.event2mind import Event2MindPredictor
18+
from allennlp.predictors.masked_language_model import MaskedLanguageModelPredictor
19+
from allennlp.predictors.next_token_lm import NextTokenLMPredictor
1820
from allennlp.predictors.nlvr_parser import NlvrParserPredictor
1921
from allennlp.predictors.open_information_extraction import OpenIePredictor
2022
from allennlp.predictors.quarel_parser import QuarelParserPredictor

allennlp/predictors/coref.py

+39-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1-
from typing import List
1+
from copy import deepcopy
2+
from typing import List, Dict
23

34
from overrides import overrides
45
from spacy.tokens import Doc
6+
import numpy
57

68
from allennlp.common.util import JsonDict
79
from allennlp.common.util import get_spacy_model
810
from allennlp.data import DatasetReader, Instance
11+
from allennlp.data.fields import ListField, SequenceLabelField
912
from allennlp.models import Model
1013
from allennlp.predictors.predictor import Predictor
1114

@@ -72,6 +75,41 @@ def predict_tokenized(self, tokenized_document: List[str]) -> JsonDict:
7275
instance = self._words_list_to_instance(tokenized_document)
7376
return self.predict_instance(instance)
7477

78+
@overrides
79+
def predictions_to_labeled_instances(self,
80+
instance: Instance,
81+
outputs: Dict[str, numpy.ndarray]) -> List[Instance]:
82+
"""
83+
Takes each predicted cluster and makes it into a labeled ``Instance`` with only that
84+
cluster labeled, so we can compute gradients of the loss `on the model's prediction of that
85+
cluster`. This lets us run interpretation methods using those gradients. See superclass
86+
docstring for more info.
87+
"""
88+
# Digging into an Instance makes mypy go crazy, because we have all kinds of things where
89+
# the type has been lost. So there are lots of `type: ignore`s here...
90+
predicted_clusters = outputs['clusters']
91+
span_field: ListField = instance['spans'] # type: ignore
92+
instances = []
93+
for cluster in predicted_clusters:
94+
new_instance = deepcopy(instance)
95+
span_labels = [0 if (span.span_start, span.span_end) in cluster else -1 # type: ignore
96+
for span in span_field] # type: ignore
97+
new_instance.add_field('span_labels',
98+
SequenceLabelField(span_labels, span_field),
99+
self._model.vocab)
100+
new_instance['metadata'].metadata['clusters'] = [cluster] # type: ignore
101+
instances.append(new_instance)
102+
if not instances:
103+
# No predicted clusters; we just give an empty coref prediction.
104+
new_instance = deepcopy(instance)
105+
span_labels = [-1] * len(span_field) # type: ignore
106+
new_instance.add_field('span_labels',
107+
SequenceLabelField(span_labels, span_field),
108+
self._model.vocab)
109+
new_instance['metadata'].metadata['clusters'] = [] # type: ignore
110+
instances.append(new_instance)
111+
return instances
112+
75113
@staticmethod
76114
def replace_corefs(document: Doc, clusters: List[List[List[int]]]) -> str:
77115
"""
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from copy import deepcopy
2+
from typing import Dict
3+
4+
from overrides import overrides
5+
import numpy
6+
7+
from allennlp.common.util import JsonDict
8+
from allennlp.data import Instance, Token
9+
from allennlp.data.fields import TextField
10+
from allennlp.predictors.predictor import Predictor
11+
12+
13+
@Predictor.register('masked_language_model')
14+
class MaskedLanguageModelPredictor(Predictor):
15+
def predict(self, sentence_with_masks: str) -> JsonDict:
16+
return self.predict_json({"sentence" : sentence_with_masks})
17+
18+
@overrides
19+
def predictions_to_labeled_instances(self,
20+
instance: Instance,
21+
outputs: Dict[str, numpy.ndarray]):
22+
new_instance = deepcopy(instance)
23+
token_field: TextField = instance['tokens'] # type: ignore
24+
mask_targets = [Token(target_top_k[0]) for target_top_k in outputs['words']]
25+
# pylint: disable=protected-access
26+
new_instance.add_field('target_ids',
27+
TextField(mask_targets, token_field._token_indexers),
28+
vocab=self._model.vocab)
29+
return [new_instance]
30+
31+
@overrides
32+
def _json_to_instance(self, json_dict: JsonDict) -> Instance:
33+
"""
34+
Expects JSON that looks like ``{"sentence": "..."}``.
35+
"""
36+
sentence = json_dict["sentence"]
37+
return self._dataset_reader.text_to_instance(sentence=sentence) # type: ignore

allennlp/predictors/next_token_lm.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from copy import deepcopy
2+
from typing import Dict
3+
4+
from overrides import overrides
5+
import numpy
6+
7+
from allennlp.common.util import JsonDict
8+
from allennlp.data import Instance, Token
9+
from allennlp.data.fields import TextField
10+
from allennlp.predictors.predictor import Predictor
11+
12+
13+
@Predictor.register('next_token_lm')
14+
class NextTokenLMPredictor(Predictor):
15+
def predict(self, sentence: str) -> JsonDict:
16+
return self.predict_json({"sentence" : sentence})
17+
18+
@overrides
19+
def predictions_to_labeled_instances(self,
20+
instance: Instance,
21+
outputs: Dict[str, numpy.ndarray]):
22+
new_instance = deepcopy(instance)
23+
token_field: TextField = instance['tokens'] # type: ignore
24+
mask_targets = [Token(target_top_k[0]) for target_top_k in outputs['words']]
25+
# pylint: disable=protected-access
26+
new_instance.add_field('target_ids',
27+
TextField(mask_targets, token_field._token_indexers),
28+
vocab=self._model.vocab)
29+
return [new_instance]
30+
31+
@overrides
32+
def _json_to_instance(self, json_dict: JsonDict) -> Instance:
33+
"""
34+
Expects JSON that looks like ``{"sentence": "..."}``.
35+
"""
36+
sentence = json_dict["sentence"]
37+
return self._dataset_reader.text_to_instance(sentence=sentence) # type: ignore

allennlp/predictors/predictor.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import List, Iterator, Dict, Tuple, Any
22
import json
33
from contextlib import contextmanager
4+
45
import numpy
56
from torch.utils.hooks import RemovableHandle
67
from torch import Tensor
@@ -9,10 +10,10 @@
910
from allennlp.common.checks import ConfigurationError
1011
from allennlp.common.util import JsonDict, sanitize
1112
from allennlp.data import DatasetReader, Instance
13+
from allennlp.data.dataset import Batch
1214
from allennlp.models import Model
1315
from allennlp.models.archival import Archive, load_archive
14-
from allennlp.modules.text_field_embedders import TextFieldEmbedder
15-
from allennlp.data.dataset import Batch
16+
from allennlp.nn import util
1617

1718
# a mapping from model `type` to the default Predictor for that type
1819
DEFAULT_PREDICTORS = {
@@ -137,10 +138,8 @@ def hook_layers(module, grad_in, grad_out): # pylint: disable=unused-argument
137138
embedding_gradients.append(grad_out[0])
138139

139140
backward_hooks = []
140-
for module in self._model.modules():
141-
if isinstance(module, TextFieldEmbedder):
142-
backward_hooks.append(module.register_backward_hook(hook_layers))
143-
141+
embedding_layer = util.find_embedding_layer(self._model)
142+
backward_hooks.append(embedding_layer.register_backward_hook(hook_layers))
144143
return backward_hooks
145144

146145
@contextmanager
@@ -192,7 +191,6 @@ def predictions_to_labeled_instances(self,
192191
multiple predictions in the output (e.g., in NER a model predicts multiple spans). In this
193192
case, each instance in the returned list of Instances contains an individual
194193
entity prediction as the label.
195-
196194
"""
197195
# pylint: disable=unused-argument,no-self-use
198196
raise RuntimeError("implement this method for model interpretations or attacks")

allennlp/tests/predictors/coref_test.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# pylint: disable=no-self-use,invalid-name
1+
# pylint: disable=no-self-use,invalid-name,protected-access
22
import spacy
33

44
from allennlp.common.testing import AllenNlpTestCase
@@ -93,3 +93,24 @@ def test_replace_corefs(self):
9393
doc = nlp(text)
9494
output = CorefPredictor.replace_corefs(doc, clusters)
9595
assert output == expected_outputs[i]
96+
97+
def test_predictions_to_labeled_instances(self):
98+
inputs = {"document": "This is a single string document about a test. Sometimes it "
99+
"contains coreferent parts."}
100+
archive = load_archive(self.FIXTURES_ROOT / 'coref' / 'serialization' / 'model.tar.gz')
101+
predictor = Predictor.from_archive(archive, 'coreference-resolution')
102+
103+
instance = predictor._json_to_instance(inputs)
104+
outputs = predictor._model.forward_on_instance(instance)
105+
new_instances = predictor.predictions_to_labeled_instances(instance, outputs)
106+
assert new_instances is not None
107+
108+
for new_instance in new_instances:
109+
assert 'span_labels' in new_instance
110+
assert len(new_instance['span_labels']) == 60 # 7 words in input
111+
true_top_spans = set(tuple(span) for span in outputs['top_spans'])
112+
pred_clust_spans = set()
113+
for i, span in enumerate(outputs['top_spans']):
114+
if new_instance['span_labels'][i]:
115+
pred_clust_spans.add(tuple(span))
116+
assert true_top_spans == pred_clust_spans
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# pylint: disable=no-self-use, protected-access
2+
from allennlp.common.testing import AllenNlpTestCase
3+
from allennlp.models.archival import load_archive
4+
from allennlp.predictors import Predictor
5+
6+
from ..modules.language_model_heads.linear import LinearLanguageModelHead # pylint: disable=unused-import
7+
8+
9+
class TestMaskedLanguageModelPredictor(AllenNlpTestCase):
10+
def test_predictions_to_labeled_instances(self):
11+
inputs = {
12+
"sentence": "Eric [MASK] was an intern at [MASK]",
13+
}
14+
15+
archive = load_archive(self.FIXTURES_ROOT / 'masked_language_model' / 'serialization' / 'model.tar.gz')
16+
predictor = Predictor.from_archive(archive, 'masked_language_model')
17+
18+
instance = predictor._json_to_instance(inputs)
19+
outputs = predictor._model.forward_on_instance(instance)
20+
new_instances = predictor.predictions_to_labeled_instances(instance, outputs)
21+
assert len(new_instances) == 1
22+
assert 'target_ids' in new_instances[0]
23+
assert len(new_instances[0]['target_ids'].tokens) == 2 # should have added two words
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# pylint: disable=no-self-use, protected-access
2+
from allennlp.common.testing import AllenNlpTestCase
3+
from allennlp.models.archival import load_archive
4+
from allennlp.predictors import Predictor
5+
6+
from ..modules.language_model_heads.linear import LinearLanguageModelHead # pylint: disable=unused-import
7+
8+
9+
class TestNextTokenLMPredictor(AllenNlpTestCase):
10+
def test_predictions_to_labeled_instances(self):
11+
inputs = {
12+
"sentence": "Eric Wallace was an intern at",
13+
}
14+
15+
archive = load_archive(self.FIXTURES_ROOT / 'next_token_lm' / 'serialization' / 'model.tar.gz')
16+
predictor = Predictor.from_archive(archive, 'next_token_lm')
17+
18+
instance = predictor._json_to_instance(inputs)
19+
outputs = predictor._model.forward_on_instance(instance)
20+
new_instances = predictor.predictions_to_labeled_instances(instance, outputs)
21+
assert len(new_instances) == 1
22+
assert 'target_ids' in new_instances[0]
23+
assert len(new_instances[0]['target_ids'].tokens) == 1 # should have added one word

doc/api/allennlp.predictors.rst

+14
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ allennlp.predictors
2323
* :ref:`Event2MindPredictor<event2mind>`
2424
* :ref:`AtisParserPredictor<atis-parser>`
2525
* :ref:`TextClassifierPredictor<text_classifier>`
26+
* :ref:`MaskedLanguageModelPredictor<masked-language-model>`
27+
* :ref:`NextTokenLMPredictor<next-token-lm>`
2628

2729
.. _predictor:
2830
.. automodule:: allennlp.predictors.predictor
@@ -131,3 +133,15 @@ allennlp.predictors
131133
:members:
132134
:undoc-members:
133135
:show-inheritance:
136+
137+
.. _masked-language-model:
138+
.. automodule:: allennlp.predictors.masked_language_model
139+
:members:
140+
:undoc-members:
141+
:show-inheritance:
142+
143+
.. _next-token-lm:
144+
.. automodule:: allennlp.predictors.next_token_lm
145+
:members:
146+
:undoc-members:
147+
:show-inheritance:

0 commit comments

Comments
 (0)