This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 0f6b3b8

nelson-liu authored and DeNeutoy committed
Make SrlBert model use SrlEvalMetric (#3168)
* Switch SemanticRoleLabeler metric to SrlEvalScorer.
* Switch back to https links, per f9e2029
* Add ignore_classes to SrlEvalScorer and ignore V in SRL model
* Enable specifying path to srl-eval.pl
* Only run span metric if it is enabled, and during evaluation
* Add comment explaining ignore_classes
* Add doc for srl_util
* Add srl_util.rst to allennlp.models.rst
* Fix position of comment
* Use SrlEvalMetric in SrlBert
1 parent adad1bc commit 0f6b3b8

File tree

1 file changed: +36 −15 lines


allennlp/models/srl_bert.py (+36 −15)
```diff
@@ -8,10 +8,11 @@
 
 from allennlp.data import Vocabulary
 from allennlp.models.model import Model
+from allennlp.models.srl_util import convert_bio_tags_to_conll_format
 from allennlp.nn import InitializerApplicator, RegularizerApplicator
 from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
 from allennlp.nn.util import get_lengths_from_binary_sequence_mask, viterbi_decode
-from allennlp.training.metrics import SpanBasedF1Measure
+from allennlp.training.metrics.srl_eval_scorer import SrlEvalScorer, DEFAULT_SRL_EVAL_PATH
 
 @Model.register("srl_bert")
 class SrlBert(Model):
```
```diff
@@ -31,6 +32,9 @@ class SrlBert(Model):
         Whether or not to use label smoothing on the labels when computing cross entropy loss.
     ignore_span_metric: ``bool``, optional (default = False)
         Whether to calculate span loss, which is irrelevant when predicting BIO for Open Information Extraction.
+    srl_eval_path: ``str``, optional (default=``DEFAULT_SRL_EVAL_PATH``)
+        The path to the srl-eval.pl script. By default, will use the srl-eval.pl included with allennlp,
+        which is located at allennlp/tools/srl-eval.pl . If ``None``, srl-eval.pl is not used.
     """
     def __init__(self,
                  vocab: Vocabulary,
```
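For illustration, a minimal sketch of what the new argument enables when constructing the model directly. The `bert-base-uncased` name and the bare `Vocabulary()` are placeholder assumptions, not values from this commit:

```python
from allennlp.data import Vocabulary
from allennlp.models.srl_bert import SrlBert
from allennlp.training.metrics.srl_eval_scorer import DEFAULT_SRL_EVAL_PATH

vocab = Vocabulary()  # placeholder; a real model needs a populated "labels" namespace

# Default: evaluate with the srl-eval.pl bundled at allennlp/tools/srl-eval.pl.
model = SrlBert(vocab,
                bert_model="bert-base-uncased",
                srl_eval_path=DEFAULT_SRL_EVAL_PATH)

# Passing None disables srl-eval.pl, so no span metric is computed.
model_without_metric = SrlBert(vocab,
                               bert_model="bert-base-uncased",
                               srl_eval_path=None)
```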
```diff
@@ -39,7 +43,8 @@ def __init__(self,
                  initializer: InitializerApplicator = InitializerApplicator(),
                  regularizer: Optional[RegularizerApplicator] = None,
                  label_smoothing: float = None,
-                 ignore_span_metric: bool = False) -> None:
+                 ignore_span_metric: bool = False,
+                 srl_eval_path: str = DEFAULT_SRL_EVAL_PATH) -> None:
         super(SrlBert, self).__init__(vocab, regularizer)
 
         if isinstance(bert_model, str):
```
```diff
@@ -48,9 +53,12 @@ def __init__(self,
             self.bert_model = bert_model
 
         self.num_classes = self.vocab.get_vocab_size("labels")
-        # For the span based evaluation, we don't want to consider labels
-        # for verb, because the verb index is provided to the model.
-        self.span_metric = SpanBasedF1Measure(vocab, tag_namespace="labels", ignore_classes=["V"])
+        if srl_eval_path is not None:
+            # For the span based evaluation, we don't want to consider labels
+            # for verb, because the verb index is provided to the model.
+            self.span_metric = SrlEvalScorer(srl_eval_path, ignore_classes=["V"])
+        else:
+            self.span_metric = None
         self.tag_projection_layer = Linear(self.bert_model.config.hidden_size, self.num_classes)
 
         self.embedding_dropout = Dropout(p=embedding_dropout)
```
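Unlike SpanBasedF1Measure, which consumed tensors of class probabilities and tags, SrlEvalScorer is called with whole sentences and CoNLL-formatted tag strings, mirroring the positional arguments the model passes in forward() below. A minimal sketch with a made-up one-sentence batch; running get_metric() shells out to srl-eval.pl, so it assumes a perl interpreter is available:

```python
from allennlp.training.metrics.srl_eval_scorer import SrlEvalScorer, DEFAULT_SRL_EVAL_PATH

scorer = SrlEvalScorer(DEFAULT_SRL_EVAL_PATH, ignore_classes=["V"])

# One made-up sentence, "The cat sat .", with its verb at index 2.
scorer([2],                              # batch verb indices
       [["The", "cat", "sat", "."]],     # batch sentences
       [["(ARG0*", "*)", "(V*)", "*"]],  # predicted tags, CoNLL bracket format
       [["(ARG0*", "*)", "(V*)", "*"]])  # gold tags, CoNLL bracket format

metrics = scorer.get_metric()
print(metrics["f1-measure-overall"])  # assumed key, following SpanBasedF1Measure's naming
```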
```diff
@@ -110,25 +118,38 @@ def forward(self, # type: ignore
                                                        sequence_length,
                                                        self.num_classes])
         output_dict = {"logits": logits, "class_probabilities": class_probabilities}
-        if tags is not None:
-            loss = sequence_cross_entropy_with_logits(logits,
-                                                      tags,
-                                                      mask,
-                                                      label_smoothing=self._label_smoothing)
-            if not self.ignore_span_metric:
-                self.span_metric(class_probabilities, tags, mask)
-            output_dict["loss"] = loss
-
         # We need to retain the mask in the output dictionary
         # so that we can crop the sequences to remove padding
         # when we do viterbi inference in self.decode.
         output_dict["mask"] = mask
-
         # We add in the offsets here so we can compute the un-wordpieced tags.
         words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"]) for x in metadata])
         output_dict["words"] = list(words)
         output_dict["verb"] = list(verbs)
         output_dict["wordpiece_offsets"] = list(offsets)
+
+        if tags is not None:
+            loss = sequence_cross_entropy_with_logits(logits,
+                                                      tags,
+                                                      mask,
+                                                      label_smoothing=self._label_smoothing)
+            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
+                batch_verb_indices = [example_metadata["verb_index"] for example_metadata in metadata]
+                batch_sentences = [example_metadata["words"] for example_metadata in metadata]
+                # Get the BIO tags from decode()
+                # TODO (nfliu): This is kind of a hack, consider splitting out part
+                # of decode() to a separate function.
+                batch_bio_predicted_tags = self.decode(output_dict).pop("tags")
+                batch_conll_predicted_tags = [convert_bio_tags_to_conll_format(tags) for
+                                              tags in batch_bio_predicted_tags]
+                batch_bio_gold_tags = [example_metadata["gold_tags"] for example_metadata in metadata]
+                batch_conll_gold_tags = [convert_bio_tags_to_conll_format(tags) for
+                                         tags in batch_bio_gold_tags]
+                self.span_metric(batch_verb_indices,
+                                 batch_sentences,
+                                 batch_conll_predicted_tags,
+                                 batch_conll_gold_tags)
+            output_dict["loss"] = loss
         return output_dict
 
     @overrides
```
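The newly imported convert_bio_tags_to_conll_format helper lives in allennlp/models/srl_util.py and does not appear in this diff. As an illustration only, not necessarily the exact srl_util implementation, here is a sketch of the standard BIO-to-CoNLL bracket conversion it performs:

```python
from typing import List

def bio_to_conll(labels: List[str]) -> List[str]:
    """Convert BIO tags, e.g. ["B-ARG0", "I-ARG0", "B-V", "O"], into the
    CoNLL-2005 bracket format srl-eval.pl expects: ["(ARG0*", "*)", "(V*)", "*"]."""
    conll_labels = []
    for i, label in enumerate(labels):
        if label == "O":
            conll_labels.append("*")
            continue
        new_label = "*"
        # Open a bracket if a labeled span starts at this token.
        if label[0] == "B" or i == 0 or labels[i - 1][2:] != label[2:]:
            new_label = "(" + label[2:] + new_label
        # Close the bracket if the span ends at this token.
        if i == len(labels) - 1 or labels[i + 1][0] != "I" or labels[i + 1][2:] != label[2:]:
            new_label = new_label + ")"
        conll_labels.append(new_label)
    return conll_labels

print(bio_to_conll(["B-ARG0", "I-ARG0", "B-V", "O"]))  # ['(ARG0*', '*)', '(V*)', '*']
```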
