8
8
9
9
from allennlp .data import Vocabulary
10
10
from allennlp .models .model import Model
11
+ from allennlp .models .srl_util import convert_bio_tags_to_conll_format
11
12
from allennlp .nn import InitializerApplicator , RegularizerApplicator
12
13
from allennlp .nn .util import get_text_field_mask , sequence_cross_entropy_with_logits
13
14
from allennlp .nn .util import get_lengths_from_binary_sequence_mask , viterbi_decode
14
- from allennlp .training .metrics import SpanBasedF1Measure
15
+ from allennlp .training .metrics . srl_eval_scorer import SrlEvalScorer , DEFAULT_SRL_EVAL_PATH
15
16
16
17
@Model .register ("srl_bert" )
17
18
class SrlBert (Model ):
@@ -31,6 +32,9 @@ class SrlBert(Model):
31
32
Whether or not to use label smoothing on the labels when computing cross entropy loss.
32
33
ignore_span_metric: ``bool``, optional (default = False)
33
34
Whether to calculate span loss, which is irrelevant when predicting BIO for Open Information Extraction.
35
+ srl_eval_path: ``str``, optional (default=``DEFAULT_SRL_EVAL_PATH``)
36
+ The path to the srl-eval.pl script. By default, this will use the srl-eval.pl included with allennlp,
37
+ which is located at allennlp/tools/srl-eval.pl. If ``None``, srl-eval.pl is not used.
34
38
"""
35
39
def __init__ (self ,
36
40
vocab : Vocabulary ,
@@ -39,7 +43,8 @@ def __init__(self,
39
43
initializer : InitializerApplicator = InitializerApplicator (),
40
44
regularizer : Optional [RegularizerApplicator ] = None ,
41
45
label_smoothing : float = None ,
42
- ignore_span_metric : bool = False ) -> None :
46
+ ignore_span_metric : bool = False ,
47
+ srl_eval_path : str = DEFAULT_SRL_EVAL_PATH ) -> None :
43
48
super (SrlBert , self ).__init__ (vocab , regularizer )
44
49
45
50
if isinstance (bert_model , str ):
@@ -48,9 +53,12 @@ def __init__(self,
48
53
self .bert_model = bert_model
49
54
50
55
self .num_classes = self .vocab .get_vocab_size ("labels" )
51
- # For the span based evaluation, we don't want to consider labels
52
- # for verb, because the verb index is provided to the model.
53
- self .span_metric = SpanBasedF1Measure (vocab , tag_namespace = "labels" , ignore_classes = ["V" ])
56
+ if srl_eval_path is not None :
57
+ # For the span based evaluation, we don't want to consider labels
58
+ # for verb, because the verb index is provided to the model.
59
+ self .span_metric = SrlEvalScorer (srl_eval_path , ignore_classes = ["V" ])
60
+ else :
61
+ self .span_metric = None
54
62
self .tag_projection_layer = Linear (self .bert_model .config .hidden_size , self .num_classes )
55
63
56
64
self .embedding_dropout = Dropout (p = embedding_dropout )
@@ -110,25 +118,38 @@ def forward(self, # type: ignore
110
118
sequence_length ,
111
119
self .num_classes ])
112
120
output_dict = {"logits" : logits , "class_probabilities" : class_probabilities }
113
- if tags is not None :
114
- loss = sequence_cross_entropy_with_logits (logits ,
115
- tags ,
116
- mask ,
117
- label_smoothing = self ._label_smoothing )
118
- if not self .ignore_span_metric :
119
- self .span_metric (class_probabilities , tags , mask )
120
- output_dict ["loss" ] = loss
121
-
122
121
# We need to retain the mask in the output dictionary
123
122
# so that we can crop the sequences to remove padding
124
123
# when we do viterbi inference in self.decode.
125
124
output_dict ["mask" ] = mask
126
-
127
125
# We add in the offsets here so we can compute the un-wordpieced tags.
128
126
words , verbs , offsets = zip (* [(x ["words" ], x ["verb" ], x ["offsets" ]) for x in metadata ])
129
127
output_dict ["words" ] = list (words )
130
128
output_dict ["verb" ] = list (verbs )
131
129
output_dict ["wordpiece_offsets" ] = list (offsets )
130
+
131
+ if tags is not None :
132
+ loss = sequence_cross_entropy_with_logits (logits ,
133
+ tags ,
134
+ mask ,
135
+ label_smoothing = self ._label_smoothing )
136
+ if not self .ignore_span_metric and self .span_metric is not None and not self .training :
137
+ batch_verb_indices = [example_metadata ["verb_index" ] for example_metadata in metadata ]
138
+ batch_sentences = [example_metadata ["words" ] for example_metadata in metadata ]
139
+ # Get the BIO tags from decode()
140
+ # TODO (nfliu): This is kind of a hack, consider splitting out part
141
+ # of decode() to a separate function.
142
+ batch_bio_predicted_tags = self .decode (output_dict ).pop ("tags" )
143
+ batch_conll_predicted_tags = [convert_bio_tags_to_conll_format (tags ) for
144
+ tags in batch_bio_predicted_tags ]
145
+ batch_bio_gold_tags = [example_metadata ["gold_tags" ] for example_metadata in metadata ]
146
+ batch_conll_gold_tags = [convert_bio_tags_to_conll_format (tags ) for
147
+ tags in batch_bio_gold_tags ]
148
+ self .span_metric (batch_verb_indices ,
149
+ batch_sentences ,
150
+ batch_conll_predicted_tags ,
151
+ batch_conll_gold_tags )
152
+ output_dict ["loss" ] = loss
132
153
return output_dict
133
154
134
155
@overrides
0 commit comments