Skip to content

Commit 9847dc1

Browse files
authored
Document.add_all_annotations_from_other returns a mapping (#418)
* Document.add_all_annotations_from_other returns a mapping from original annotations to new ones * make mypy happy * improve tests
1 parent c61f8e1 commit 9847dc1

File tree

2 files changed

+28
-14
lines changed

2 files changed

+28
-14
lines changed

src/pytorch_ie/core/document.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ def add_all_annotations_from_other(
771771
process_predictions: bool = True,
772772
strict: bool = True,
773773
verbose: bool = True,
774-
) -> Dict[str, List[Annotation]]:
774+
) -> Dict[str, Dict[Annotation, Annotation]]:
775775
"""Adds all annotations from another document to this document. It allows to blacklist annotations
776776
and also to override annotations. It returns the original annotations for which a new annotation was
777777
added to the current document.
@@ -854,7 +854,7 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc
854854
```
855855
"""
856856
removed_annotations = defaultdict(set, removed_annotations or dict())
857-
added_annotations = defaultdict(list)
857+
added_annotations: Dict[str, Dict[Annotation, Annotation]] = defaultdict(dict)
858858

859859
annotation_store: Dict[str, Dict[int, Annotation]] = defaultdict(dict)
860860
named_annotation_fields = {field.name: field for field in self.annotation_fields()}
@@ -897,7 +897,7 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc
897897
if ann._id != new_ann._id:
898898
annotation_store[field_name][ann._id] = new_ann
899899
self[field_name].append(new_ann)
900-
added_annotations[field_name].append(ann)
900+
added_annotations[field_name][ann] = new_ann
901901
else:
902902
if strict:
903903
raise ValueError(
@@ -922,7 +922,7 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc
922922
if ann._id != new_ann._id:
923923
annotation_store[field_name][ann._id] = new_ann
924924
self[field_name].predictions.append(new_ann)
925-
added_annotations[field_name].append(ann)
925+
added_annotations[field_name][ann] = new_ann
926926
else:
927927
if strict:
928928
raise ValueError(

tests/test_document.py

+24-10
Original file line numberDiff line numberDiff line change
@@ -662,10 +662,13 @@ def test_document_extend_from_other_full_copy(text_document):
662662
"relation_attributes",
663663
"labels",
664664
}
665-
for layer_name, annotation_set in added_annotations.items():
666-
assert len(annotation_set) > 0
665+
for layer_name, annotation_mapping in added_annotations.items():
666+
assert len(annotation_mapping) > 0
667667
available_annotations = text_document[layer_name]
668-
assert annotation_set == list(available_annotations)
668+
assert set(annotation_mapping) == set(available_annotations)
669+
assert len(annotation_mapping) == 1
670+
# since we have only one annotation, we can construct the expected mapping
671+
assert annotation_mapping == {available_annotations[0]: doc_new[layer_name][0]}
669672

670673

671674
def test_document_extend_from_other_wrong_override_annotation_mapping(text_document):
@@ -705,12 +708,19 @@ class TestDocument2(TokenBasedDocument):
705708
added_annotations = token_document.add_all_annotations_from_other(
706709
text_document, override_annotations=annotation_mapping
707710
)
711+
added_annotation_sets = {k: set(v) for k, v in added_annotations.items()}
708712
# check that the added annotations are as expected (the entity annotations are already there)
709-
assert added_annotations == {
710-
"relations": list(text_document.relations),
711-
"relation_attributes": list(text_document.relation_attributes),
712-
"labels": list(text_document.labels),
713+
assert added_annotation_sets == {
714+
"relations": set(text_document.relations),
715+
"relation_attributes": set(text_document.relation_attributes),
716+
"labels": set(text_document.labels),
713717
}
718+
for layer_name, annotation_mapping in added_annotations.items():
719+
text_annotations = text_document[layer_name]
720+
token_annotations = token_document[layer_name]
721+
assert len(annotation_mapping) == len(text_annotations) == len(token_annotations) == 1
722+
# since we have only one annotation, we can construct the expected mapping
723+
assert annotation_mapping == {text_annotations[0]: token_annotations[0]}
714724

715725
assert (
716726
len(token_document.entities1)
@@ -740,11 +750,15 @@ def test_document_extend_from_other_remove(text_document):
740750
removed_annotations={"entities1": {text_document.entities1[0]._id}},
741751
strict=False,
742752
)
743-
753+
added_annotation_sets = {k: set(v) for k, v in added_annotations.items()}
744754
# the only entity in entities1 is removed and since the relation has it as head, the relation is removed as well
755+
assert added_annotation_sets == {
756+
"entities2": set(text_document.entities2),
757+
"labels": set(text_document.labels),
758+
}
745759
assert added_annotations == {
746-
"entities2": list(text_document.entities2),
747-
"labels": list(text_document.labels),
760+
"entities2": {text_document.entities2[0]: doc_new.entities2[0]},
761+
"labels": {text_document.labels[0]: doc_new.labels[0]},
748762
}
749763

750764
assert len(doc_new.entities1) == 0

0 commit comments

Comments
 (0)