Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 335d899

Browse files
Haoxun Zhan and matt-gardner
Haoxun Zhan
authored and committed
Make SpanBasedF1Measure support BMES (#1692)
* Make SpanBasedF1Measure support BMES. * Decoration change. * Fix code style. * Add test for SpanBasedF1Measure(label_encoding="BMES"). * Bugfix. * Fix testcase. * Use text[2:] instead of text.partition('-')[2]
1 parent d16f6c0 commit 335d899

File tree

5 files changed

+159
-4
lines changed

5 files changed

+159
-4
lines changed

allennlp/data/dataset_readers/dataset_utils/span_utils.py

+69
Original file line numberDiff line numberDiff line change
@@ -371,3 +371,72 @@ def process_stack(stack, out_stack):
371371
process_stack(stack, bioul_sequence)
372372

373373
return bioul_sequence
374+
375+
376+
def bmes_tags_to_spans(tag_sequence: List[str],
                       classes_to_ignore: List[str] = None) -> "List[TypedStringSpan]":
    """
    Given a sequence corresponding to BMES tags, extracts spans.
    Spans are inclusive; a single-word span is represented with
    ``span_start == span_end``. Ill-formed spans are not allowed and will raise
    ``InvalidTagSequence``. This function works properly when the spans are
    unlabeled (i.e., your labels are simply "B", "M", "E" and "S").

    Parameters
    ----------
    tag_sequence : List[str], required.
        The string tags for a sequence, e.g. ``["B-ARG1", "E-ARG1", "S-ARG2"]``.
    classes_to_ignore : List[str], optional (default = None).
        A list of string class labels `excluding` the BMES tag
        which should be ignored when extracting spans.

    Returns
    -------
    spans : List[TypedStringSpan]
        The typed, extracted spans from the sequence, in the format (label, (span_start, span_end)).
        Note that the label `does not` contain any BMES tag prefixes.
    """
    def extract_bmes_tag_label(text):
        # "B-ARG1" -> ("B", "ARG1"); a bare tag like "B" yields the empty label "".
        bmes_tag = text[0]
        label = text[2:]
        return bmes_tag, label

    spans = []
    classes_to_ignore = classes_to_ignore or []
    invalid = False
    index = 0
    while index < len(tag_sequence) and not invalid:
        start_bmes_tag, start_label = extract_bmes_tag_label(tag_sequence[index])
        start_index = index

        if start_bmes_tag == 'B':
            # Consume 'M' tags until the matching 'E' closes the span.
            index += 1
            while index < len(tag_sequence):
                bmes_tag, label = extract_bmes_tag_label(tag_sequence[index])
                # Stop conditions: a label mismatch, or any tag other than
                # M/E, means the span was never closed properly.
                if label != start_label or bmes_tag not in ('M', 'E'):
                    invalid = True
                    break
                if bmes_tag == 'E':
                    break
                # bmes_tag == 'M', move to next.
                index += 1

            if index >= len(tag_sequence):
                # Ran off the end of the sequence without seeing 'E'.
                invalid = True
            if not invalid:
                spans.append((start_label, (start_index, index)))

        elif start_bmes_tag == 'S':
            # Single-token span.
            spans.append((start_label, (start_index, start_index)))

        else:
            # 'M' or 'E' (or anything else) cannot start a span.
            invalid = True

        # Move to next span.
        index += 1

    if invalid:
        raise InvalidTagSequence(tag_sequence)

    return [span for span in spans if span[0] not in classes_to_ignore]

allennlp/models/crf_tagger.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class CrfTagger(Model):
3737
An optional feedforward layer to apply after the encoder.
3838
label_encoding : ``str``, optional (default=``None``)
3939
Label encoding to use when calculating span f1 and constraining
40-
the CRF at decoding time . Valid options are "BIO", "BIOUL", "IOB1".
40+
the CRF at decoding time . Valid options are "BIO", "BIOUL", "IOB1", "BMES".
4141
Required if ``calculate_span_f1`` or ``constrain_crf_decoding`` is true.
4242
constraint_type : ``str``, optional (default=``None``)
4343
If provided, the CRF will be constrained at decoding time

allennlp/tests/data/dataset_readers/dataset_utils/span_utils_test.py

+42
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,45 @@ def test_bio_to_bioul(self):
141141
with self.assertRaises(span_utils.InvalidTagSequence):
142142
tag_sequence = ['O', 'I-PER', 'B-PER', 'I-PER', 'I-PER', 'B-PER']
143143
bioul_sequence = span_utils.to_bioul(tag_sequence, encoding="BIO")
144+
145+
def test_bmes_tags_to_spans_extracts_correct_spans(self):
    # Well-formed labeled BMES sequences decode into (label, (start, end)) spans.
    tags = ["B-ARG1", "M-ARG1", "E-ARG1", "B-ARG2", "E-ARG2", "S-ARG3"]
    assert set(span_utils.bmes_tags_to_spans(tags)) == {("ARG1", (0, 2)),
                                                        ("ARG2", (3, 4)),
                                                        ("ARG3", (5, 5))}

    tags = ["S-ARG1", "B-ARG2", "E-ARG2", "S-ARG3"]
    assert set(span_utils.bmes_tags_to_spans(tags)) == {("ARG1", (0, 0)),
                                                        ("ARG2", (1, 2)),
                                                        ("ARG3", (3, 3))}

    # A label change inside an open span is ill-formed.
    with self.assertRaises(span_utils.InvalidTagSequence):
        span_utils.bmes_tags_to_spans(["B-ARG1", "M-ARG2", "E-ARG1"])

    # Illegal tag transitions (a span opened by 'B' but never closed by 'E').
    for tags in (["B-ARG1", "B-ARG1"], ["B-ARG1", "S-ARG1"]):
        with self.assertRaises(span_utils.InvalidTagSequence):
            span_utils.bmes_tags_to_spans(tags)
167+
def test_bmes_tags_to_spans_extracts_correct_spans_without_labels(self):
    # Bare B/M/E/S tags (no "-LABEL" suffix) produce spans labeled "".
    tags = ["B", "M", "E", "B", "E", "S"]
    assert set(span_utils.bmes_tags_to_spans(tags)) == {("", (0, 2)),
                                                        ("", (3, 4)),
                                                        ("", (5, 5))}

    tags = ["S", "B", "E", "S"]
    assert set(span_utils.bmes_tags_to_spans(tags)) == {("", (0, 0)),
                                                        ("", (1, 2)),
                                                        ("", (3, 3))}

    # Illegal tag transitions must raise, even without labels.
    for tags in (["B", "B"], ["B", "S"], ["B", "E", "M"]):
        with self.assertRaises(span_utils.InvalidTagSequence):
            span_utils.bmes_tags_to_spans(tags)

allennlp/tests/training/metrics/span_based_f1_measure_test.py

+40
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ def setUp(self):
3131
vocab.add_token_to_namespace("B-ARGM-ADJ", "tags")
3232
vocab.add_token_to_namespace("I-ARGM-ADJ", "tags")
3333

34+
# BMES.
35+
vocab.add_token_to_namespace("B", "bmes_tags")
36+
vocab.add_token_to_namespace("M", "bmes_tags")
37+
vocab.add_token_to_namespace("E", "bmes_tags")
38+
vocab.add_token_to_namespace("S", "bmes_tags")
39+
3440
self.vocab = vocab
3541

3642
def test_span_metrics_are_computed_correcly_with_prediction_map(self):
@@ -167,6 +173,40 @@ def test_span_metrics_are_computed_correctly(self):
167173
numpy.testing.assert_almost_equal(metric_dict["precision-overall"], 0.5)
168174
numpy.testing.assert_almost_equal(metric_dict["f1-measure-overall"], 0.5)
169175

176+
def test_bmes_span_metrics_are_computed_correctly(self):
    # Tag indices in the "bmes_tags" namespace: B:0, M:1, E:2, S:3.
    # Gold rows decode to [S, B, M, E, S] and [S, S, S, S, S].
    gold_tensor = torch.Tensor([[3, 0, 1, 2, 3],
                                [3, 3, 3, 3, 3]])

    prediction_tensor = torch.rand([2, 5, 4])
    # Force the argmax at each timestep: rand() < 1, so a written 1 always wins.
    # Row 0 decodes to [S, B, E, S, S]: TP 2 (the two matching S spans),
    #   FP 2 (the B-E span and the spurious S at position 3), FN 1.
    # Row 1 decodes to [B, E, S, B, E]: TP 1 (the S at position 2),
    #   FP 2 (both B-E spans), FN 4.
    for batch, step, tag in [(0, 0, 3), (0, 1, 0), (0, 2, 2), (0, 3, 3), (0, 4, 3),
                             (1, 0, 0), (1, 1, 2), (1, 2, 3), (1, 3, 0), (1, 4, 2)]:
        prediction_tensor[batch, step, tag] = 1

    metric = SpanBasedF1Measure(self.vocab, "bmes_tags", label_encoding="BMES")
    metric(prediction_tensor, gold_tensor)

    # Overall totals: TP 3, FP 4, FN 5.
    metric_dict = metric.get_metric()

    numpy.testing.assert_almost_equal(metric_dict["recall-overall"], 0.375)
    numpy.testing.assert_almost_equal(metric_dict["precision-overall"], 0.428, decimal=3)
    numpy.testing.assert_almost_equal(metric_dict["f1-measure-overall"], 0.4)
209+
170210
def test_span_f1_can_build_from_params(self):
171211
params = Params({"type": "span_f1", "tag_namespace": "tags", "ignore_classes": ["V"]})
172212
metric = Metric.from_params(params=params, vocabulary=self.vocab)

allennlp/training/metrics/span_based_f1_measure.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
bio_tags_to_spans,
1212
bioul_tags_to_spans,
1313
iob1_tags_to_spans,
14+
bmes_tags_to_spans,
1415
TypedStringSpan
1516
)
1617

@@ -58,10 +59,10 @@ def __init__(self,
5859
spans in a BIO tagging scheme which are typically not included.
5960
label_encoding : ``str``, optional (default = "BIO")
6061
The encoding used to specify label span endpoints in the sequence.
61-
Valid options are "BIO", "IOB1", or BIOUL".
62+
Valid options are "BIO", "IOB1", "BIOUL" or "BMES".
6263
"""
63-
if label_encoding not in ["BIO", "IOB1", "BIOUL"]:
64-
raise ConfigurationError("Unknown label encoding - expected 'BIO', 'IOB1', 'BIOUL'.")
64+
if label_encoding not in ["BIO", "IOB1", "BIOUL", "BMES"]:
65+
raise ConfigurationError("Unknown label encoding - expected 'BIO', 'IOB1', 'BIOUL', 'BMES'.")
6566

6667
self._label_encoding = label_encoding
6768
self._label_vocabulary = vocabulary.get_index_to_token_vocabulary(tag_namespace)
@@ -143,6 +144,9 @@ def __call__(self,
143144
elif self._label_encoding == "BIOUL":
144145
predicted_spans = bioul_tags_to_spans(predicted_string_labels, self._ignore_classes)
145146
gold_spans = bioul_tags_to_spans(gold_string_labels, self._ignore_classes)
147+
elif self._label_encoding == "BMES":
148+
predicted_spans = bmes_tags_to_spans(predicted_string_labels, self._ignore_classes)
149+
gold_spans = bmes_tags_to_spans(gold_string_labels, self._ignore_classes)
146150

147151
predicted_spans = self._handle_continued_spans(predicted_spans)
148152
gold_spans = self._handle_continued_spans(gold_spans)

0 commit comments

Comments
 (0)