Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit ae72f79

Browse files
Better multi-word predicates in Open IE predictors (#1759)
* merging overlapping predicates
1 parent cca99b9 commit ae72f79

File tree

2 files changed

+164
-7
lines changed

2 files changed

+164
-7
lines changed

allennlp/predictors/open_information_extraction.py

+117-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List
1+
from typing import List, Dict
22

33
from overrides import overrides
44

@@ -60,6 +60,116 @@ def make_oie_string(tokens: List[Token], tags: List[str]) -> str:
6060

6161
return " ".join(frame)
6262

63+
def get_predicate_indices(tags: List[str]) -> List[int]:
64+
"""
65+
Return the word indices of a predicate in BIO tags.
66+
"""
67+
return [ind for ind, tag in enumerate(tags) if 'V' in tag]
68+
69+
def get_predicate_text(sent_tokens: List[Token], tags: List[str]) -> str:
70+
"""
71+
Get the predicate in this prediction.
72+
"""
73+
return " ".join([sent_tokens[pred_id].text
74+
for pred_id in get_predicate_indices(tags)])
75+
76+
def predicates_overlap(tags1: List[str], tags2: List[str]) -> bool:
77+
"""
78+
Tests whether the predicate in BIO tags1 overlap
79+
with those of tags2.
80+
"""
81+
# Get predicate word indices from both predictions
82+
pred_ind1 = get_predicate_indices(tags1)
83+
pred_ind2 = get_predicate_indices(tags2)
84+
85+
# Return if pred_ind1 pred_ind2 overlap
86+
return any(set.intersection(set(pred_ind1), set(pred_ind2)))
87+
88+
def get_coherent_next_tag(prev_label: str, cur_label: str) -> str:
89+
"""
90+
Generate a coherent tag, given previous tag and current label.
91+
"""
92+
if cur_label == "O":
93+
# Don't need to add prefix to an "O" label
94+
return "O"
95+
96+
if prev_label == cur_label:
97+
return f"I-{cur_label}"
98+
else:
99+
return f"B-{cur_label}"
100+
101+
def merge_overlapping_predictions(tags1: List[str], tags2: List[str]) -> List[str]:
102+
"""
103+
Merge two predictions into one. Assumes the predicate in tags1 overlap with
104+
the predicate of tags2.
105+
"""
106+
ret_sequence = []
107+
prev_label = "O"
108+
109+
# Build a coherent sequence out of two
110+
# spans which predicates' overlap
111+
112+
for tag1, tag2 in zip(tags1, tags2):
113+
label1 = tag1.split("-")[-1]
114+
label2 = tag2.split("-")[-1]
115+
if (label1 == "V") or (label2 == "V"):
116+
# Construct maximal predicate length -
117+
# add predicate tag if any of the sequence predict it
118+
cur_label = "V"
119+
120+
# Else - prefer an argument over 'O' label
121+
elif label1 != "O":
122+
cur_label = label1
123+
else:
124+
cur_label = label2
125+
126+
# Append cur tag to the returned sequence
127+
cur_tag = get_coherent_next_tag(prev_label, cur_label)
128+
prev_label = cur_label
129+
ret_sequence.append(cur_tag)
130+
return ret_sequence
131+
132+
def consolidate_predictions(outputs: List[List[str]], sent_tokens: List[Token]) -> Dict[str, List[str]]:
133+
"""
134+
Identify that certain predicates are part of a multiword predicate
135+
(e.g., "decided to run") in which case, we don't need to return
136+
the embedded predicate ("run").
137+
"""
138+
pred_dict: Dict[str, List[str]] = {}
139+
merged_outputs = [join_mwp(output) for output in outputs]
140+
predicate_texts = [get_predicate_text(sent_tokens, tags)
141+
for tags in merged_outputs]
142+
143+
for pred1_text, tags1 in zip(predicate_texts, merged_outputs):
144+
# A flag indicating whether to add tags1 to predictions
145+
add_to_prediction = True
146+
147+
# Check if this predicate overlaps another predicate
148+
for pred2_text, tags2 in pred_dict.items():
149+
if predicates_overlap(tags1, tags2):
150+
# tags1 overlaps tags2
151+
pred_dict[pred2_text] = merge_overlapping_predictions(tags1, tags2)
152+
add_to_prediction = False
153+
154+
# This predicate doesn't overlap - add as a new predicate
155+
if add_to_prediction:
156+
pred_dict[pred1_text] = tags1
157+
158+
return pred_dict
159+
160+
161+
def sanitize_label(label: str) -> str:
162+
"""
163+
Sanitize a BIO label - this deals with OIE
164+
labels sometimes having some noise, as parentheses.
165+
"""
166+
if "-" in label:
167+
prefix, suffix = label.split("-")
168+
suffix = suffix.split("(")[-1]
169+
return f"{prefix}-{suffix}"
170+
else:
171+
return label
172+
63173
@Predictor.register('open-information-extraction')
64174
class OpenIePredictor(Predictor):
65175
"""
@@ -116,13 +226,16 @@ def predict_json(self, inputs: JsonDict) -> JsonDict:
116226
for pred_id in pred_ids]
117227

118228
# Run model
119-
outputs = [self._model.forward_on_instance(instance)["tags"]
229+
outputs = [[sanitize_label(label) for label in self._model.forward_on_instance(instance)["tags"]]
120230
for instance in instances]
121231

232+
# Consolidate predictions
233+
pred_dict = consolidate_predictions(outputs, sent_tokens)
234+
122235
# Build and return output dictionary
123236
results = {"verbs": [], "words": sent_tokens}
124237

125-
for tags, pred_id in zip(outputs, pred_ids):
238+
for tags in pred_dict.values():
126239
# Join multi-word predicates
127240
tags = join_mwp(tags)
128241

@@ -131,7 +244,7 @@ def predict_json(self, inputs: JsonDict) -> JsonDict:
131244

132245
# Add a predicate prediction to the return dictionary.
133246
results["verbs"].append({
134-
"verb": sent_tokens[pred_id].text,
247+
"verb": get_predicate_text(sent_tokens, tags),
135248
"description": description,
136249
"tags": tags,
137250
})

allennlp/tests/predictors/open_information_extraction_test.py

+47-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
from allennlp.common.testing import AllenNlpTestCase
33
from allennlp.models.archival import load_archive
44
from allennlp.predictors import Predictor
5+
from allennlp.predictors.open_information_extraction import consolidate_predictions, get_predicate_text
6+
from allennlp.data.tokenizers import WordTokenizer
7+
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
58

69
class TestOpenIePredictor(AllenNlpTestCase):
710
def test_uses_named_inputs(self):
@@ -27,9 +30,6 @@ def test_uses_named_inputs(self):
2730
assert verbs is not None
2831
assert isinstance(verbs, list)
2932

30-
predicates = [verb["verb"] for verb in verbs]
31-
assert predicates == ["met", "spoke"]
32-
3333
for verb in verbs:
3434
tags = verb.get("tags")
3535
assert tags is not None
@@ -49,3 +49,47 @@ def test_prediction_with_no_verbs(self):
4949

5050
result = predictor.predict_json(input1)
5151
assert result == {'words': ['Blah', 'no', 'verb', 'sentence', '.'], 'verbs': []}
52+
53+
def test_predicate_consolidation(self):
54+
"""
55+
Test whether the predictor can correctly consolidate multiword
56+
predicates.
57+
"""
58+
tokenizer = WordTokenizer(word_splitter=SpacyWordSplitter(pos_tags=True))
59+
60+
sent_tokens = tokenizer.tokenize("In December, John decided to join the party.")
61+
62+
# Emulate predications - for both "decided" and "join"
63+
predictions = [['B-ARG2', 'I-ARG2', 'O', 'B-ARG0', 'B-V', 'B-ARG1', 'I-ARG1', \
64+
'I-ARG1', 'I-ARG1', 'O'],
65+
['O', 'O', 'O', 'B-ARG0', 'B-BV', 'I-BV', 'B-V', 'B-ARG1', \
66+
'I-ARG1', 'O']]
67+
# Consolidate
68+
pred_dict = consolidate_predictions(predictions, sent_tokens)
69+
70+
# Check that only "decided to join" is left
71+
assert len(pred_dict) == 1
72+
tags = list(pred_dict.values())[0]
73+
assert get_predicate_text(sent_tokens, tags) == "decided to join"
74+
75+
def test_more_than_two_overlapping_predicates(self):
76+
"""
77+
Test whether the predictor can correctly consolidate multiword
78+
predicates.
79+
"""
80+
tokenizer = WordTokenizer(word_splitter=SpacyWordSplitter(pos_tags=True))
81+
82+
sent_tokens = tokenizer.tokenize("John refused to consider joining the club.")
83+
84+
# Emulate predications - for "refused" and "consider" and "joining"
85+
predictions = [['B-ARG0', 'B-V', 'B-ARG1', 'I-ARG1', 'I-ARG1', 'I-ARG1', 'I-ARG1', 'O'],\
86+
['B-ARG0', 'B-BV', 'I-BV', 'B-V', 'B-ARG1', 'I-ARG1', 'I-ARG1', 'O'],\
87+
['B-ARG0', 'B-BV', 'I-BV', 'I-BV', 'B-V', 'B-ARG1', 'I-ARG1', 'O']]
88+
89+
# Consolidate
90+
pred_dict = consolidate_predictions(predictions, sent_tokens)
91+
92+
# Check that only "refused to consider to join" is left
93+
assert len(pred_dict) == 1
94+
tags = list(pred_dict.values())[0]
95+
assert get_predicate_text(sent_tokens, tags) == "refused to consider joining"

0 commit comments

Comments
 (0)