@@ -1,11 +1,14 @@
-from typing import List
+from copy import deepcopy
+from typing import List, Dict

 from overrides import overrides
 from spacy.tokens import Doc
+import numpy

 from allennlp.common.util import JsonDict
 from allennlp.common.util import get_spacy_model
 from allennlp.data import DatasetReader, Instance
+from allennlp.data.fields import ListField, SequenceLabelField
 from allennlp.models import Model
 from allennlp.predictors.predictor import Predictor

@@ -72,6 +75,41 @@ def predict_tokenized(self, tokenized_document: List[str]) -> JsonDict:
         instance = self._words_list_to_instance(tokenized_document)
         return self.predict_instance(instance)

+    @overrides
+    def predictions_to_labeled_instances(self,
+                                         instance: Instance,
+                                         outputs: Dict[str, numpy.ndarray]) -> List[Instance]:
+        """
+        Takes each predicted cluster and makes it into a labeled ``Instance`` with only that
+        cluster labeled, so we can compute gradients of the loss `on the model's prediction of that
+        cluster`. This lets us run interpretation methods using those gradients. See superclass
+        docstring for more info.
+        """
+        # Digging into an Instance makes mypy go crazy, because we have all kinds of things where
+        # the type has been lost. So there are lots of `type: ignore`s here...
+        predicted_clusters = outputs['clusters']
+        span_field: ListField = instance['spans']  # type: ignore
+        instances = []
+        for cluster in predicted_clusters:
+            new_instance = deepcopy(instance)
+            span_labels = [0 if (span.span_start, span.span_end) in cluster else -1  # type: ignore
+                           for span in span_field]  # type: ignore
+            new_instance.add_field('span_labels',
+                                   SequenceLabelField(span_labels, span_field),
+                                   self._model.vocab)
+            new_instance['metadata'].metadata['clusters'] = [cluster]  # type: ignore
+            instances.append(new_instance)
+        if not instances:
+            # No predicted clusters; we just give an empty coref prediction.
+            new_instance = deepcopy(instance)
+            span_labels = [-1] * len(span_field)  # type: ignore
+            new_instance.add_field('span_labels',
+                                   SequenceLabelField(span_labels, span_field),
+                                   self._model.vocab)
+            new_instance['metadata'].metadata['clusters'] = []  # type: ignore
+            instances.append(new_instance)
+        return instances
+
     @staticmethod
     def replace_corefs(document: Doc, clusters: List[List[List[int]]]) -> str:
         """