1
- from typing import List
1
+ from typing import List , Dict
2
2
3
3
from overrides import overrides
4
4
@@ -60,6 +60,116 @@ def make_oie_string(tokens: List[Token], tags: List[str]) -> str:
60
60
61
61
return " " .join (frame )
62
62
63
+ def get_predicate_indices (tags : List [str ]) -> List [int ]:
64
+ """
65
+ Return the word indices of a predicate in BIO tags.
66
+ """
67
+ return [ind for ind , tag in enumerate (tags ) if 'V' in tag ]
68
+
69
+ def get_predicate_text (sent_tokens : List [Token ], tags : List [str ]) -> str :
70
+ """
71
+ Get the predicate in this prediction.
72
+ """
73
+ return " " .join ([sent_tokens [pred_id ].text
74
+ for pred_id in get_predicate_indices (tags )])
75
+
76
+ def predicates_overlap (tags1 : List [str ], tags2 : List [str ]) -> bool :
77
+ """
78
+ Tests whether the predicate in BIO tags1 overlap
79
+ with those of tags2.
80
+ """
81
+ # Get predicate word indices from both predictions
82
+ pred_ind1 = get_predicate_indices (tags1 )
83
+ pred_ind2 = get_predicate_indices (tags2 )
84
+
85
+ # Return if pred_ind1 pred_ind2 overlap
86
+ return any (set .intersection (set (pred_ind1 ), set (pred_ind2 )))
87
+
88
+ def get_coherent_next_tag (prev_label : str , cur_label : str ) -> str :
89
+ """
90
+ Generate a coherent tag, given previous tag and current label.
91
+ """
92
+ if cur_label == "O" :
93
+ # Don't need to add prefix to an "O" label
94
+ return "O"
95
+
96
+ if prev_label == cur_label :
97
+ return f"I-{ cur_label } "
98
+ else :
99
+ return f"B-{ cur_label } "
100
+
101
+ def merge_overlapping_predictions (tags1 : List [str ], tags2 : List [str ]) -> List [str ]:
102
+ """
103
+ Merge two predictions into one. Assumes the predicate in tags1 overlap with
104
+ the predicate of tags2.
105
+ """
106
+ ret_sequence = []
107
+ prev_label = "O"
108
+
109
+ # Build a coherent sequence out of two
110
+ # spans which predicates' overlap
111
+
112
+ for tag1 , tag2 in zip (tags1 , tags2 ):
113
+ label1 = tag1 .split ("-" )[- 1 ]
114
+ label2 = tag2 .split ("-" )[- 1 ]
115
+ if (label1 == "V" ) or (label2 == "V" ):
116
+ # Construct maximal predicate length -
117
+ # add predicate tag if any of the sequence predict it
118
+ cur_label = "V"
119
+
120
+ # Else - prefer an argument over 'O' label
121
+ elif label1 != "O" :
122
+ cur_label = label1
123
+ else :
124
+ cur_label = label2
125
+
126
+ # Append cur tag to the returned sequence
127
+ cur_tag = get_coherent_next_tag (prev_label , cur_label )
128
+ prev_label = cur_label
129
+ ret_sequence .append (cur_tag )
130
+ return ret_sequence
131
+
132
+ def consolidate_predictions (outputs : List [List [str ]], sent_tokens : List [Token ]) -> Dict [str , List [str ]]:
133
+ """
134
+ Identify that certain predicates are part of a multiword predicate
135
+ (e.g., "decided to run") in which case, we don't need to return
136
+ the embedded predicate ("run").
137
+ """
138
+ pred_dict : Dict [str , List [str ]] = {}
139
+ merged_outputs = [join_mwp (output ) for output in outputs ]
140
+ predicate_texts = [get_predicate_text (sent_tokens , tags )
141
+ for tags in merged_outputs ]
142
+
143
+ for pred1_text , tags1 in zip (predicate_texts , merged_outputs ):
144
+ # A flag indicating whether to add tags1 to predictions
145
+ add_to_prediction = True
146
+
147
+ # Check if this predicate overlaps another predicate
148
+ for pred2_text , tags2 in pred_dict .items ():
149
+ if predicates_overlap (tags1 , tags2 ):
150
+ # tags1 overlaps tags2
151
+ pred_dict [pred2_text ] = merge_overlapping_predictions (tags1 , tags2 )
152
+ add_to_prediction = False
153
+
154
+ # This predicate doesn't overlap - add as a new predicate
155
+ if add_to_prediction :
156
+ pred_dict [pred1_text ] = tags1
157
+
158
+ return pred_dict
159
+
160
+
161
+ def sanitize_label (label : str ) -> str :
162
+ """
163
+ Sanitize a BIO label - this deals with OIE
164
+ labels sometimes having some noise, as parentheses.
165
+ """
166
+ if "-" in label :
167
+ prefix , suffix = label .split ("-" )
168
+ suffix = suffix .split ("(" )[- 1 ]
169
+ return f"{ prefix } -{ suffix } "
170
+ else :
171
+ return label
172
+
63
173
@Predictor .register ('open-information-extraction' )
64
174
class OpenIePredictor (Predictor ):
65
175
"""
@@ -116,13 +226,16 @@ def predict_json(self, inputs: JsonDict) -> JsonDict:
116
226
for pred_id in pred_ids ]
117
227
118
228
# Run model
119
- outputs = [self ._model .forward_on_instance (instance )["tags" ]
229
+ outputs = [[ sanitize_label ( label ) for label in self ._model .forward_on_instance (instance )["tags" ] ]
120
230
for instance in instances ]
121
231
232
+ # Consolidate predictions
233
+ pred_dict = consolidate_predictions (outputs , sent_tokens )
234
+
122
235
# Build and return output dictionary
123
236
results = {"verbs" : [], "words" : sent_tokens }
124
237
125
- for tags , pred_id in zip ( outputs , pred_ids ):
238
+ for tags in pred_dict . values ( ):
126
239
# Join multi-word predicates
127
240
tags = join_mwp (tags )
128
241
@@ -131,7 +244,7 @@ def predict_json(self, inputs: JsonDict) -> JsonDict:
131
244
132
245
# Add a predicate prediction to the return dictionary.
133
246
results ["verbs" ].append ({
134
- "verb" : sent_tokens [ pred_id ]. text ,
247
+ "verb" : get_predicate_text ( sent_tokens , tags ) ,
135
248
"description" : description ,
136
249
"tags" : tags ,
137
250
})
0 commit comments