diff --git a/CHANGELOG.md b/CHANGELOG.md
index 48c0d0ff31ba..48da28d95872 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Siammask tracker as DL serverless function ()
 - [Datumaro] Added model info and source info commands ()
 - [Datumaro] Dataset statistics ()
+- [Datumaro] Multi-dataset merge (https://github.com/opencv/cvat/pull/1695)
 
 ### Changed
 - Shape coordinates are rounded to 2 digits in dumped annotations ()
diff --git a/datumaro/datumaro/cli/__main__.py b/datumaro/datumaro/cli/__main__.py
index a2946e968b49..fabe43f82b1d 100644
--- a/datumaro/datumaro/cli/__main__.py
+++ b/datumaro/datumaro/cli/__main__.py
@@ -68,6 +68,7 @@ def make_parser():
         ('remove', commands.remove, "Remove source from project"),
         ('export', commands.export, "Export project"),
         ('explain', commands.explain, "Run Explainable AI algorithm for model"),
+        ('merge', commands.merge, "Merge datasets"),
         ('convert', commands.convert, "Convert dataset"),
     ]
 
diff --git a/datumaro/datumaro/cli/commands/__init__.py b/datumaro/datumaro/cli/commands/__init__.py
index 3c3bffe6a8a3..7249842e5a61 100644
--- a/datumaro/datumaro/cli/commands/__init__.py
+++ b/datumaro/datumaro/cli/commands/__init__.py
@@ -3,4 +3,4 @@
 #
 # SPDX-License-Identifier: MIT
 
-from . import add, create, explain, export, remove, convert
+from . import add, create, explain, export, remove, merge, convert
diff --git a/datumaro/datumaro/cli/commands/merge.py b/datumaro/datumaro/cli/commands/merge.py
new file mode 100644
index 000000000000..2583cd8641bb
--- /dev/null
+++ b/datumaro/datumaro/cli/commands/merge.py
@@ -0,0 +1,125 @@
+
+# Copyright (C) 2020 Intel Corporation
+#
+# SPDX-License-Identifier: MIT
+
+import argparse
+import json
+import logging as log
+import os
+import os.path as osp
+from collections import OrderedDict
+
+from datumaro.components.project import Project
+from datumaro.components.operations import (IntersectMerge,
+    QualityError, MergeError)
+
+from ..util import at_least, MultilineFormatter, CliException
+from ..util.project import generate_next_file_name, load_project
+
+
+def build_parser(parser_ctor=argparse.ArgumentParser):
+    parser = parser_ctor(help="Merge multiple projects",
+        description="""
+        Merges multiple datasets into one. This can be useful if you
+        have several annotations of the same data and wish to merge them,
+        taking into consideration potential overlaps and conflicts.
+        This command can try to find a common ground by voting or
+        return a list of conflicts.|n
+        |n
+        Examples:|n
+        - Merge annotations from 3 (or more) annotators:|n
+        |s|smerge project1/ project2/ project3/|n
+        - Check groups of the merged dataset for consistency:|n
+        |s|s|slook for groups consisting of 'person', 'hand', 'head', 'foot'|n
+        |s|smerge project1/ project2/ -g 'person,hand?,head,foot?'
+ """, + formatter_class=MultilineFormatter) + + def _group(s): + return s.split(',') + + parser.add_argument('project', nargs='+', action=at_least(2), + help="Path to a project (repeatable)") + parser.add_argument('-iou', '--iou-thresh', default=0.25, type=float, + help="IoU match threshold for segments (default: %(default)s)") + parser.add_argument('-oconf', '--output-conf-thresh', + default=0.0, type=float, + help="Confidence threshold for output " + "annotations (default: %(default)s)") + parser.add_argument('--quorum', default=0, type=int, + help="Minimum count for a label and attribute voting " + "results to be counted (default: %(default)s)") + parser.add_argument('-g', '--groups', action='append', type=_group, + default=[], + help="A comma-separated list of labels in " + "annotation groups to check. '?' postfix can be added to a label to" + "make it optional in the group (repeatable)") + parser.add_argument('-o', '--output-dir', dest='dst_dir', default=None, + help="Output directory (default: current project's dir)") + parser.add_argument('--overwrite', action='store_true', + help="Overwrite existing files in the save directory") + parser.set_defaults(command=merge_command) + + return parser + +def merge_command(args): + source_projects = [load_project(p) for p in args.project] + + dst_dir = args.dst_dir + if dst_dir: + if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir): + raise CliException("Directory '%s' already exists " + "(pass --overwrite to overwrite)" % dst_dir) + else: + dst_dir = generate_next_file_name('merged') + + source_datasets = [] + for p in source_projects: + log.debug("Loading project '%s' dataset", p.config.project_name) + source_datasets.append(p.make_dataset()) + + merger = IntersectMerge(conf=IntersectMerge.Conf( + pairwise_dist=args.iou_thresh, groups=args.groups, + output_conf_thresh=args.output_conf_thresh, quorum=args.quorum + )) + merged_dataset = merger(source_datasets) + + merged_project = Project() + output_dataset = merged_project.make_dataset() + output_dataset.define_categories(merged_dataset.categories()) + merged_dataset = output_dataset.update(merged_dataset) + merged_dataset.save(save_dir=dst_dir) + + report_path = osp.join(dst_dir, 'merge_report.json') + save_merge_report(merger, report_path) + + dst_dir = osp.abspath(dst_dir) + log.info("Merge results have been saved to '%s'" % dst_dir) + log.info("Report has been saved to '%s'" % report_path) + + return 0 + +def save_merge_report(merger, path): + item_errors = OrderedDict() + source_errors = OrderedDict() + all_errors = [] + + for e in merger.errors: + if isinstance(e, QualityError): + item_errors[str(e.item_id)] = item_errors.get(str(e.item_id), 0) + 1 + elif isinstance(e, MergeError): + for s in e.sources: + source_errors[s] = source_errors.get(s, 0) + 1 + item_errors[str(e.item_id)] = item_errors.get(str(e.item_id), 0) + 1 + + all_errors.append(str(e)) + + errors = OrderedDict([ + ('Item errors', item_errors), + ('Source errors', source_errors), + ('All errors', all_errors), + ]) + + with open(path, 'w') as f: + json.dump(errors, f, indent=4) \ No newline at end of file diff --git a/datumaro/datumaro/cli/util/__init__.py b/datumaro/datumaro/cli/util/__init__.py index 2d04a6936b71..3884b156a5cd 100644 --- a/datumaro/datumaro/cli/util/__init__.py +++ b/datumaro/datumaro/cli/util/__init__.py @@ -37,6 +37,28 @@ def _fill_text(self, text, width, indent): multiline_text += formatted_paragraph return multiline_text +def required_count(nmin=0, nmax=0): + assert 0 <= nmin and 0 
+
+    class RequiredCount(argparse.Action):
+        def __call__(self, parser, args, values, option_string=None):
+            k = len(values)
+            if not ((nmin and (nmin <= k) or not nmin) and \
+                    (nmax and (k <= nmax) or not nmax)):
+                msg = "Argument '%s' requires" % self.dest
+                if nmin and nmax:
+                    msg += " from %s to %s arguments" % (nmin, nmax)
+                elif nmin:
+                    msg += " at least %s arguments" % nmin
+                else:
+                    msg += " no more than %s arguments" % nmax
+                raise argparse.ArgumentError(self, msg)
+            setattr(args, self.dest, values)
+    return RequiredCount
+
+def at_least(n):
+    return required_count(n, 0)
+
 def make_file_name(s):
     # adapted from
     # https://docs.djangoproject.com/en/2.1/_modules/django/utils/text/#slugify
diff --git a/datumaro/datumaro/components/algorithms/rise.py b/datumaro/datumaro/components/algorithms/rise.py
index 277bedd2d3b4..2f65c8cfe676 100644
--- a/datumaro/datumaro/components/algorithms/rise.py
+++ b/datumaro/datumaro/components/algorithms/rise.py
@@ -9,6 +9,7 @@
 from math import ceil
 
 from datumaro.components.extractor import AnnotationType
+from datumaro.util.annotation_util import nms
 
 
 def flatmatvec(mat):
@@ -51,24 +52,6 @@ def split_outputs(annotations):
                 bboxes.append(r)
         return labels, bboxes
 
-    @staticmethod
-    def nms(boxes, iou_thresh=0.5):
-        indices = np.argsort([b.attributes['score'] for b in boxes])
-        ious = np.array([[a.iou(b) for b in boxes] for a in boxes])
-
-        predictions = []
-        while len(indices) != 0:
-            i = len(indices) - 1
-            pred_idx = indices[i]
-            to_remove = [i]
-            predictions.append(boxes[pred_idx])
-            for i, box_idx in enumerate(indices[:i]):
-                if iou_thresh < ious[pred_idx, box_idx]:
-                    to_remove.append(i)
-            indices = np.delete(indices, to_remove)
-
-        return predictions
-
     def normalize_hmaps(self, heatmaps, counts):
         eps = np.finfo(heatmaps.dtype).eps
         mhmaps = flatmatvec(heatmaps)
@@ -106,7 +89,7 @@ def apply(self, image, progressive=False):
         result_bboxes = [b for b in result_bboxes \
             if self.det_conf_thresh <= b.attributes['score']]
         if 0 < self.nms_thresh:
-            result_bboxes = self.nms(result_bboxes, self.nms_thresh)
+            result_bboxes = nms(result_bboxes, self.nms_thresh)
 
         predicted_labels = set()
         if len(result_labels) != 0:
@@ -194,7 +177,7 @@ def apply(self, image, progressive=False):
         result_bboxes = [b for b in result_bboxes \
             if self.det_conf_thresh <= b.attributes['score']]
         if 0 < self.nms_thresh:
-            result_bboxes = self.nms(result_bboxes, self.nms_thresh)
+            result_bboxes = nms(result_bboxes, self.nms_thresh)
 
         for detection in result_bboxes:
             for pred_idx, pred in enumerate(predicted_bboxes):
@@ -202,7 +185,7 @@ def apply(self, image, progressive=False):
                     continue
 
                 iou = pred.iou(detection)
-                assert 0 <= iou and iou <= 1
+                assert iou == -1 or 0 <= iou and iou <= 1
                 if iou < iou_thresh:
                     continue
diff --git a/datumaro/datumaro/components/extractor.py b/datumaro/datumaro/components/extractor.py
index 609c6e9ed876..d7991cd121e0 100644
--- a/datumaro/datumaro/components/extractor.py
+++ b/datumaro/datumaro/components/extractor.py
@@ -13,6 +13,7 @@
 from datumaro.util.image import Image
 from datumaro.util.attrs_util import not_empty, default_if_none
 
+
 AnnotationType = Enum('AnnotationType', [
     'label',
@@ -28,9 +29,9 @@
 @attrs
 class Annotation:
-    id = attrib(converter=int, default=0, kw_only=True)
-    attributes = attrib(converter=dict, factory=dict, kw_only=True)
-    group = attrib(converter=int, default=0, kw_only=True)
+    id = attrib(default=0, validator=default_if_none(int), kw_only=True)
+    attributes = attrib(factory=dict, validator=default_if_none(dict), kw_only=True)
+    
group = attrib(default=0, validator=default_if_none(int), kw_only=True) def __attrs_post_init__(self): assert isinstance(self.type, AnnotationType) @@ -92,7 +93,7 @@ def _reindex(self): self._indices = indices def add(self, name, parent=None, attributes=None): - assert name not in self._indices + assert name not in self._indices, name if attributes is None: attributes = set() else: @@ -110,7 +111,7 @@ def add(self, name, parent=None, attributes=None): def find(self, name): index = self._indices.get(name) - if index: + if index is not None: return index, self.items[index] return index, None @@ -148,7 +149,7 @@ class Mask(Annotation): _image = attrib() label = attrib(converter=attr.converters.optional(int), default=None, kw_only=True) - z_order = attrib(converter=int, default=0, kw_only=True) + z_order = attrib(default=0, validator=default_if_none(int), kw_only=True) @property def image(self): @@ -274,31 +275,13 @@ def extract(self, instance_id): def lazy_extract(self, instance_id): return lambda: self.extract(instance_id) -def compute_iou(bbox_a, bbox_b): - aX, aY, aW, aH = bbox_a - bX, bY, bW, bH = bbox_b - in_right = min(aX + aW, bX + bW) - in_left = max(aX, bX) - in_top = max(aY, bY) - in_bottom = min(aY + aH, bY + bH) - - in_w = max(0, in_right - in_left) - in_h = max(0, in_bottom - in_top) - intersection = in_w * in_h - - a_area = aW * aH - b_area = bW * bH - union = a_area + b_area - intersection - - return intersection / max(1.0, union) - @attrs class _Shape(Annotation): points = attrib(converter=lambda x: [round(p, _COORDINATE_ROUNDING_DIGITS) for p in x]) label = attrib(converter=attr.converters.optional(int), default=None, kw_only=True) - z_order = attrib(converter=int, default=0, kw_only=True) + z_order = attrib(default=0, validator=default_if_none(int), kw_only=True) def get_area(self): raise NotImplementedError() @@ -386,7 +369,8 @@ def as_polygon(self): ] def iou(self, other): - return compute_iou(self.get_bbox(), other.get_bbox()) + from datumaro.util.annotation_util import bbox_iou + return bbox_iou(self.get_bbox(), other.get_bbox()) def wrap(item, **kwargs): d = {'x': item.x, 'y': item.y, 'w': item.w, 'h': item.h} diff --git a/datumaro/datumaro/components/operations.py b/datumaro/datumaro/components/operations.py index 7961775e8b90..9e63d3a7e84e 100644 --- a/datumaro/datumaro/components/operations.py +++ b/datumaro/datumaro/components/operations.py @@ -3,14 +3,782 @@ # # SPDX-License-Identifier: MIT -import logging as log +from collections import OrderedDict from copy import deepcopy +import logging as log +import attr import cv2 import numpy as np +from attr import attrib, attrs + +from datumaro.components.cli_plugin import CliPlugin +from datumaro.components.extractor import AnnotationType, Bbox, Label +from datumaro.components.project import Dataset +from datumaro.util import find +from datumaro.util.attrs_util import ensure_cls +from datumaro.util.annotation_util import (segment_iou, bbox_iou, + mean_bbox, OKS, find_instances, max_bbox, smooth_line) + +def get_ann_type(anns, t): + return [a for a in anns if a.type == t] + +def match_annotations_equal(a, b): + matches = [] + a_unmatched = a[:] + b_unmatched = b[:] + for a_ann in a: + for b_ann in b_unmatched: + if a_ann != b_ann: + continue + + matches.append((a_ann, b_ann)) + a_unmatched.remove(a_ann) + b_unmatched.remove(b_ann) + break + + return matches, a_unmatched, b_unmatched + +def merge_annotations_equal(a, b): + matches, a_unmatched, b_unmatched = match_annotations_equal(a, b) + return [ann_a for (ann_a, _) in 
matches] + a_unmatched + b_unmatched
+
+def merge_categories(sources):
+    categories = {}
+    for source in sources:
+        categories.update(source)
+    for source in sources:
+        for cat_type, source_cat in source.items():
+            if not categories[cat_type] == source_cat:
+                raise NotImplementedError(
+                    "Merging different categories is not implemented yet")
+    return categories
+
+class MergingStrategy(CliPlugin):
+    @classmethod
+    def merge(cls, sources, **options):
+        instance = cls(**options)
+        return instance(sources)
+
+    def __init__(self, **options):
+        super().__init__(**options)
+        self.__dict__['_sources'] = None
+
+    def __call__(self, sources):
+        raise NotImplementedError()
+
+
+@attrs
+class DatasetError:
+    item_id = attrib()
+
+@attrs
+class QualityError(DatasetError):
+    pass
+
+@attrs
+class TooCloseError(QualityError):
+    a = attrib()
+    b = attrib()
+    distance = attrib()
+
+    def __str__(self):
+        return "Item %s: annotations are too close: %s, %s, distance = %s" % \
+            (self.item_id, self.a, self.b, self.distance)
+
+@attrs
+class WrongGroupError(QualityError):
+    found = attrib(converter=set)
+    expected = attrib(converter=set)
+    group = attrib(converter=list)
+
+    def __str__(self):
+        return "Item %s: annotation group has wrong labels: " \
+            "found %s, expected %s, group %s" % \
+            (self.item_id, self.found, self.expected, self.group)
+
+@attrs
+class MergeError(DatasetError):
+    sources = attrib(converter=set)
+
+@attrs
+class NoMatchingAnnError(MergeError):
+    ann = attrib()
+
+    def __str__(self):
+        return "Item %s: can't find matching annotation " \
+            "in sources %s, annotation is %s" % \
+            (self.item_id, self.sources, self.ann)
+
+@attrs
+class NoMatchingItemError(MergeError):
+    def __str__(self):
+        return "Item %s: can't find matching item in sources %s" % \
+            (self.item_id, self.sources)
+
+@attrs
+class FailedLabelVotingError(MergeError):
+    votes = attrib()
+    ann = attrib(default=None)
+
+    def __str__(self):
+        return "Item %s: label voting failed%s, votes %s, sources %s" % \
+            (self.item_id, ' for ann %s' % self.ann if self.ann else '',
+            self.votes, self.sources)
+
+@attrs
+class FailedAttrVotingError(MergeError):
+    attr = attrib()
+    votes = attrib()
+    ann = attrib()
+
+    def __str__(self):
+        return "Item %s: attribute voting failed " \
+            "for ann %s, votes %s, sources %s" % \
+            (self.item_id, self.ann, self.votes, self.sources)
+
+@attrs
+class IntersectMerge(MergingStrategy):
+    @attrs(repr_ns='IntersectMerge', kw_only=True)
+    class Conf:
+        pairwise_dist = attrib(converter=float, default=0.5)
+        sigma = attrib(converter=list, factory=list)
+
+        output_conf_thresh = attrib(converter=float, default=0)
+        quorum = attrib(converter=int, default=0)
+        ignored_attributes = attrib(converter=set, factory=set)
+
+        def _groups_converter(value):
+            result = []
+            for group in value:
+                rg = set()
+                for label in group:
+                    optional = label.endswith('?')
+                    name = label if not optional else label[:-1]
+                    rg.add((name, optional))
+                result.append(rg)
+            return result
+        groups = attrib(converter=_groups_converter, factory=list)
+        close_distance = attrib(converter=float, default=0.75)
+    conf = attrib(converter=ensure_cls(Conf), factory=Conf)
+
+    # Error trackers:
+    errors = attrib(factory=list, init=False)
+    def add_item_error(self, error, *args, **kwargs):
+        self.errors.append(error(self._item_id, *args, **kwargs))
+
+    # Indexes:
+    _dataset_map = attrib(init=False) # id(dataset) -> (dataset, index)
+    _item_map = attrib(init=False) # id(item) -> (item, id(dataset))
+    _ann_map = attrib(init=False) # id(ann) -> (ann, id(item))
+    
_item_id = attrib(init=False) + _item = attrib(init=False) + + # Misc. + _categories = attrib(init=False) # merged categories + + def __call__(self, datasets): + self._categories = merge_categories(d.categories() for d in datasets) + merged = Dataset(categories=self._categories) + + self._check_groups_definition() + + item_matches, item_map = self.match_items(datasets) + self._item_map = item_map + self._dataset_map = { id(d): (d, i) for i, d in enumerate(datasets) } + + for item_id, items in item_matches.items(): + self._item_id = item_id + + if len(items) < len(datasets): + missing_sources = set(id(s) for s in datasets) - set(items) + missing_sources = [self._dataset_map[s][1] + for s in missing_sources] + self.add_item_error(NoMatchingItemError, missing_sources) + merged.put(self.merge_items(items)) + + return merged + + def get_ann_source(self, ann_id): + return self._item_map[self._ann_map[ann_id][1]][1] + + def merge_items(self, items): + self._item = next(iter(items.values())) + + self._ann_map = {} + sources = [] + for item in items.values(): + self._ann_map.update({ id(a): (a, id(item)) + for a in item.annotations }) + sources.append(item.annotations) + log.debug("Merging item %s: source annotations %s" % \ + (self._item_id, list(map(len, sources)))) + + annotations = self.merge_annotations(sources) + + annotations = [a for a in annotations + if self.conf.output_conf_thresh <= a.attributes.get('score', 1)] + + return self._item.wrap(annotations=annotations) + + def merge_annotations(self, sources): + self._make_mergers(sources) + + clusters = self._match_annotations(sources) + + joined_clusters = sum(clusters.values(), []) + group_map = self._find_cluster_groups(joined_clusters) + + annotations = [] + for t, clusters in clusters.items(): + for cluster in clusters: + self._check_cluster_sources(cluster) + + merged_clusters = self._merge_clusters(t, clusters) + + for merged_ann, cluster in zip(merged_clusters, clusters): + attributes = self._find_cluster_attrs(cluster, merged_ann) + attributes = { k: v for k, v in attributes.items() + if k not in self.conf.ignored_attributes } + attributes.update(merged_ann.attributes) + merged_ann.attributes = attributes + + new_group_id = find(enumerate(group_map), + lambda e: id(cluster) in e[1][0]) + if new_group_id is None: + new_group_id = 0 + else: + new_group_id = new_group_id[0] + 1 + merged_ann.group = new_group_id + + if self.conf.close_distance: + self._check_annotation_distance(t, merged_clusters) + + annotations += merged_clusters + + if self.conf.groups: + self._check_groups(annotations) + + return annotations + + @staticmethod + def match_items(datasets): + item_ids = set((item.id, item.subset) for d in datasets for item in d) + + item_map = {} # id(item) -> (item, id(dataset)) + + matches = OrderedDict() + for (item_id, item_subset) in sorted(item_ids, key=lambda e: e[0]): + items = {} + for d in datasets: + try: + item = d.get(item_id, subset=item_subset) + items[id(d)] = item + item_map[id(item)] = (item, id(d)) + except KeyError: + pass + matches[(item_id, item_subset)] = items + + return matches, item_map + + def _match_annotations(self, sources): + all_by_type = {} + for s in sources: + src_by_type = {} + for a in s: + src_by_type.setdefault(a.type, []).append(a) + for k, v in src_by_type.items(): + all_by_type.setdefault(k, []).append(v) + + clusters = {} + for k, v in all_by_type.items(): + clusters.setdefault(k, []).extend(self._match_ann_type(k, v)) + + return clusters + + def _make_mergers(self, sources): + def _make(c, 
**kwargs): + kwargs.update(attr.asdict(self.conf)) + fields = attr.fields_dict(c) + return c(**{ k: v for k, v in kwargs.items() if k in fields }, + context=self) + + def _for_type(t, **kwargs): + if t is AnnotationType.label: + return _make(LabelMerger, **kwargs) + elif t is AnnotationType.bbox: + return _make(BboxMerger, **kwargs) + elif t is AnnotationType.mask: + return _make(MaskMerger, **kwargs) + elif t is AnnotationType.polygon: + return _make(PolygonMerger, **kwargs) + elif t is AnnotationType.polyline: + return _make(LineMerger, **kwargs) + elif t is AnnotationType.points: + return _make(PointsMerger, **kwargs) + elif t is AnnotationType.caption: + return _make(CaptionsMerger, **kwargs) + else: + raise NotImplementedError("Type %s is not supported" % t) + + instance_map = {} + for s in sources: + s_instances = find_instances(s) + for inst in s_instances: + inst_bbox = max_bbox([a for a in inst if a.type in + {AnnotationType.polygon, + AnnotationType.mask, AnnotationType.bbox} + ]) + for ann in inst: + instance_map[id(ann)] = [inst, inst_bbox] + + self._mergers = { t: _for_type(t, instance_map=instance_map) + for t in AnnotationType } + + def _match_ann_type(self, t, sources): + return self._mergers[t].match_annotations(sources) + + def _merge_clusters(self, t, clusters): + return self._mergers[t].merge_clusters(clusters) + + @staticmethod + def _find_cluster_groups(clusters): + cluster_groups = [] + visited = set() + for a_idx, cluster_a in enumerate(clusters): + if a_idx in visited: + continue + visited.add(a_idx) + + cluster_group = { id(cluster_a) } + + # find segment groups in the cluster group + a_groups = set(ann.group for ann in cluster_a) + for cluster_b in clusters[a_idx+1 :]: + b_groups = set(ann.group for ann in cluster_b) + if a_groups & b_groups: + a_groups |= b_groups + + # now we know all the segment groups in this cluster group + # so we can find adjacent clusters + for b_idx, cluster_b in enumerate(clusters[a_idx+1 :]): + b_idx = a_idx + 1 + b_idx + b_groups = set(ann.group for ann in cluster_b) + if a_groups & b_groups: + cluster_group.add( id(cluster_b) ) + visited.add(b_idx) + + if a_groups == {0}: + continue # skip annotations without a group + cluster_groups.append( (cluster_group, a_groups) ) + return cluster_groups + + def _find_cluster_attrs(self, cluster, ann): + quorum = self.conf.quorum or 0 + + # TODO: when attribute types are implemented, add linear + # interpolation for contiguous values + + attr_votes = {} # name -> { value: score , ... 
}
+        for s in cluster:
+            for name, value in s.attributes.items():
+                votes = attr_votes.get(name, {})
+                votes[value] = 1 + votes.get(value, 0)
+                attr_votes[name] = votes
+
+        attributes = {}
+        for name, votes in attr_votes.items():
+            winner, count = max(votes.items(), key=lambda e: e[1])
+            if count < quorum:
+                if sum(votes.values()) < quorum:
+                    # blame provokers
+                    missing_sources = set(
+                        self.get_ann_source(id(a)) for a in cluster
+                        if a.attributes.get(name) == winner)
+                else:
+                    # blame outliers
+                    missing_sources = set(
+                        self.get_ann_source(id(a)) for a in cluster
+                        if a.attributes.get(name) != winner)
+                missing_sources = [self._dataset_map[s][1]
+                    for s in missing_sources]
+                self.add_item_error(FailedAttrVotingError,
+                    missing_sources, name, votes, ann)
+                continue
+            attributes[name] = winner
+
+        return attributes
+
+    def _check_cluster_sources(self, cluster):
+        if len(cluster) == len(self._dataset_map):
+            return
+
+        def _has_item(s):
+            try:
+                item = self._dataset_map[s][0].get(*self._item_id)
+                if len(item.annotations) == 0:
+                    return False
+                return True
+            except KeyError:
+                return False
+
+        missing_sources = set(self._dataset_map) - \
+            set(self.get_ann_source(id(a)) for a in cluster)
+        missing_sources = [self._dataset_map[s][1] for s in missing_sources
+            if _has_item(s)]
+        if missing_sources:
+            self.add_item_error(NoMatchingAnnError, missing_sources, cluster[0])
+
+    def _check_annotation_distance(self, t, annotations):
+        for a_idx, a_ann in enumerate(annotations):
+            for b_ann in annotations[a_idx+1:]:
+                d = self._mergers[t].distance(a_ann, b_ann)
+                if self.conf.close_distance < d:
+                    self.add_item_error(TooCloseError, a_ann, b_ann, d)
+
+    def _check_groups(self, annotations):
+        check_groups = []
+        for check_group_raw in self.conf.groups:
+            check_group = set(l[0] for l in check_group_raw)
+            optional = set(l[0] for l in check_group_raw if l[1])
+            check_groups.append((check_group, optional))
+
+        def _check_group(group_labels, group):
+            for check_group, optional in check_groups:
+                common = check_group & group_labels
+                real_miss = check_group - common - optional
+                extra = group_labels - check_group
+                if common and (extra or real_miss):
+                    self.add_item_error(WrongGroupError, group_labels,
+                        check_group, group)
+                    break
+
+        groups = find_instances(annotations)
+        for group in groups:
+            group_labels = set()
+            for ann in group:
+                if not hasattr(ann, 'label'):
+                    continue
+                label = self._get_label_name(ann.label)
+
+                if ann.group:
+                    group_labels.add(label)
+                else:
+                    _check_group({label}, [ann])
+
+            if not group_labels:
+                continue
+            _check_group(group_labels, group)
+
+    def _get_label_name(self, label_id):
+        return self._categories[AnnotationType.label].items[label_id].name
+
+    def _check_groups_definition(self):
+        for group in self.conf.groups:
+            for label, _ in group:
+                _, entry = self._categories[AnnotationType.label].find(label)
+                if entry is None:
+                    raise ValueError("Datasets do not contain "
+                        "label '%s', available labels %s" % \
+                        (label, [i.name for i in
+                            self._categories[AnnotationType.label].items])
+                    )
+
+@attrs
+class AnnotationMatcher:
+    def match_annotations(self, sources):
+        raise NotImplementedError()
+
+@attrs
+class LabelMatcher(AnnotationMatcher):
+    @staticmethod
+    def distance(a, b):
+        return a.label == b.label
+
+    def match_annotations(self, sources):
+        return [sum(sources, [])]
+
+@attrs(kw_only=True)
+class _ShapeMatcher(AnnotationMatcher):
+    pairwise_dist = attrib(converter=float, default=0.9)
+    cluster_dist = attrib(converter=float, default=-1.0)
+
+    def match_annotations(self, sources):
+        distance = self.distance
+        pairwise_dist = self.pairwise_dist
+        cluster_dist = self.cluster_dist
+
+        if cluster_dist < 0:
+            cluster_dist = pairwise_dist
+
+        id_segm = { id(a): (a, id(s)) for s in sources for a in s }
+
+        def _is_close_enough(cluster, extra_id):
+            # check if the whole cluster IoU will not be broken
+            # when this segment is added
+            b = id_segm[extra_id][0]
+            for a_id in cluster:
+                a = id_segm[a_id][0]
+                if distance(a, b) < cluster_dist:
+                    return False
+            return True
+
+        def _has_same_source(cluster, extra_id):
+            b = id_segm[extra_id][1]
+            for a_id in cluster:
+                a = id_segm[a_id][1]
+                if a == b:
+                    return True
+            return False
+
+        # match segments in sources, pairwise
+        adjacent = { i: [] for i in id_segm } # id(sgm) -> [id(adj_sgm1), ...]
+        for a_idx, src_a in enumerate(sources):
+            for src_b in sources[a_idx+1 :]:
+                matches, _, _, _ = match_segments(src_a, src_b,
+                    dist_thresh=pairwise_dist, distance=distance)
+                for m in matches:
+                    adjacent[id(m[0])].append(id(m[1]))
+
+        # join all segments into matching clusters
+        clusters = []
+        visited = set()
+        for cluster_idx in adjacent:
+            if cluster_idx in visited:
+                continue
+
+            cluster = set()
+            to_visit = { cluster_idx }
+            while to_visit:
+                c = to_visit.pop()
+                cluster.add(c)
+                visited.add(c)
+
+                for i in adjacent[c]:
+                    if i in visited:
+                        continue
+                    if 0 < cluster_dist and not _is_close_enough(cluster, i):
+                        continue
+                    if _has_same_source(cluster, i):
+                        continue
+
+                    to_visit.add(i)
+
+            clusters.append([id_segm[i][0] for i in cluster])
+
+        return clusters
+
+    @staticmethod
+    def distance(a, b):
+        return segment_iou(a, b)
+
+@attrs
+class BboxMatcher(_ShapeMatcher):
+    pass
+
+@attrs
+class PolygonMatcher(_ShapeMatcher):
+    pass
+
+@attrs
+class MaskMatcher(_ShapeMatcher):
+    pass
+
+@attrs(kw_only=True)
+class PointsMatcher(_ShapeMatcher):
+    sigma = attrib(converter=list, default=None)
+    instance_map = attrib(converter=dict)
+
+    def distance(self, a, b):
+        a_bbox = self.instance_map[id(a)][1]
+        b_bbox = self.instance_map[id(b)][1]
+        if bbox_iou(a_bbox, b_bbox) <= 0:
+            return 0
+        bbox = mean_bbox([a_bbox, b_bbox])
+        return OKS(a, b, sigma=self.sigma, bbox=bbox)
+
+@attrs
+class LineMatcher(_ShapeMatcher):
+    @staticmethod
+    def distance(a, b):
+        a_bbox = a.get_bbox()
+        b_bbox = b.get_bbox()
+        bbox = max_bbox([a_bbox, b_bbox])
+        area = bbox[2] * bbox[3]
+        if not area:
+            return 1
+
+        # compute inter-line area, normalize by the common bbox
+        point_count = max(max(len(a.points) // 2, len(b.points) // 2), 5)
+        a, sa = smooth_line(a.points, point_count)
+        b, sb = smooth_line(b.points, point_count)
+        dists = np.linalg.norm(a - b, axis=1)
+        dists = (dists[:-1] + dists[1:]) * 0.5
+        s = np.sum(dists) * 0.5 * (sa + sb) / area
+        return abs(1 - s)
+
+@attrs
+class CaptionsMatcher(AnnotationMatcher):
+    def match_annotations(self, sources):
+        raise NotImplementedError()
+
+
+@attrs(kw_only=True)
+class AnnotationMerger:
+    _context = attrib(type=IntersectMerge, default=None)
+
+    def merge_clusters(self, clusters):
+        raise NotImplementedError()
+
+@attrs(kw_only=True)
+class LabelMerger(AnnotationMerger, LabelMatcher):
+    quorum = attrib(converter=int, default=0)
+
+    def merge_clusters(self, clusters):
+        assert len(clusters) <= 1
+        if len(clusters) == 0:
+            return []
+
+        votes = {} # label -> score
+        for label_ann in clusters[0]:
+            votes[label_ann.label] = 1 + votes.get(label_ann.label, 0)
+
+        merged = []
+        for label, count in votes.items():
+            if count < self.quorum:
+                sources = set(self._context.get_ann_source(id(a))
+                    for a in clusters[0] if a.label != label)
+                sources = [self._context._dataset_map[s][1] for s in sources]
+                self._context.add_item_error(FailedLabelVotingError,
+                    sources, votes)
+                continue
+
+            merged.append(Label(label, attributes={
+                'score': count / len(self._context._dataset_map)
+            }))
+
+        return merged
+
+@attrs(kw_only=True)
+class _ShapeMerger(AnnotationMerger, _ShapeMatcher):
+    quorum = attrib(converter=int, default=0)
+
+    def merge_clusters(self, clusters):
+        merged = []
+        for cluster in clusters:
+            label, label_score = self.find_cluster_label(cluster)
+            shape, shape_score = self.merge_cluster_shape(cluster)
+
+            shape.z_order = max(cluster, key=lambda a: a.z_order).z_order
+            shape.label = label
+            shape.attributes['score'] = label_score * shape_score \
+                if label is not None else shape_score
+
+            merged.append(shape)
+
+        return merged
+
+    def find_cluster_label(self, cluster):
+        votes = {}
+        for s in cluster:
+            state = votes.setdefault(s.label, [0, 0])
+            state[0] += s.attributes.get('score', 1.0)
+            state[1] += 1
+
+        label, (score, count) = max(votes.items(), key=lambda e: e[1][0])
+        if count < self.quorum:
+            self._context.add_item_error(FailedLabelVotingError,
+                set(self._context.get_ann_source(id(a)) for a in cluster),
+                votes)
+        score = score / count if count else None
+        return label, score
+
+    @staticmethod
+    def _merge_cluster_shape_mean_box_nearest(cluster):
+        mbbox = Bbox(*mean_bbox(cluster))
+        dist = (segment_iou(mbbox, s) for s in cluster)
+        nearest_pos, _ = max(enumerate(dist), key=lambda e: e[1])
+        return cluster[nearest_pos]
+
+    def merge_cluster_shape(self, cluster):
+        shape = self._merge_cluster_shape_mean_box_nearest(cluster)
+        shape_score = sum(max(0, self.distance(shape, s))
+            for s in cluster) / len(cluster)
+        return shape, shape_score
+
+@attrs
+class BboxMerger(_ShapeMerger, BboxMatcher):
+    pass
+
+@attrs
+class PolygonMerger(_ShapeMerger, PolygonMatcher):
+    pass
+
+@attrs
+class MaskMerger(_ShapeMerger, MaskMatcher):
+    pass
+
+@attrs
+class PointsMerger(_ShapeMerger, PointsMatcher):
+    pass
+
+@attrs
+class LineMerger(_ShapeMerger, LineMatcher):
+    pass
+
+@attrs
+class CaptionsMerger(AnnotationMerger, CaptionsMatcher):
+    pass
+
+def match_segments(a_segms, b_segms, distance='iou', dist_thresh=1.0):
+    if distance == 'iou':
+        distance = segment_iou
+    else:
+        assert callable(distance)
+
+    a_segms.sort(key=lambda ann: 1 - ann.attributes.get('score', 1))
+    b_segms.sort(key=lambda ann: 1 - ann.attributes.get('score', 1))
+
+    # a_matches: indices of b_segms matched to a bboxes
+    # b_matches: indices of a_segms matched to b bboxes
+    a_matches = -np.ones(len(a_segms), dtype=int)
+    b_matches = -np.ones(len(b_segms), dtype=int)
+
+    distances = np.array([[distance(a, b) for b in b_segms] for a in a_segms])
+
+    # matches: boxes we succeeded to match completely
+    # mispred: boxes we succeeded to match, having a label mismatch
+    matches = []
+    mispred = []
+
+    for a_idx, a_segm in enumerate(a_segms):
+        if len(b_segms) == 0:
+            break
+        matched_b = -1
+        max_dist = dist_thresh
+        for b_idx, b_segm in enumerate(b_segms):
+            if 0 <= b_matches[b_idx]: # assign a_segm with max conf
+                continue
+            d = distances[a_idx, b_idx]
+            if d < max_dist:
+                continue
+            max_dist = d
+            matched_b = b_idx
+
+        if matched_b < 0:
+            continue
+        a_matches[a_idx] = matched_b
+        b_matches[matched_b] = a_idx
+
+        b_segm = b_segms[matched_b]
+
+        if a_segm.label == b_segm.label:
+            matches.append( (a_segm, b_segm) )
+        else:
+            mispred.append( (a_segm, b_segm) )
-from datumaro.components.extractor import AnnotationType
+    # *_unmatched: boxes of (*) we failed to match
+    
a_unmatched = [a_segms[i] for i, m in enumerate(a_matches) if m < 0] + b_unmatched = [b_segms[i] for i, m in enumerate(b_matches) if m < 0] + return matches, mispred, a_unmatched, b_unmatched def mean_std(dataset): """ diff --git a/datumaro/datumaro/components/project.py b/datumaro/datumaro/components/project.py index 9ee388397379..8ac3ceb02844 100644 --- a/datumaro/datumaro/components/project.py +++ b/datumaro/datumaro/components/project.py @@ -351,16 +351,7 @@ def categories(self): @classmethod def from_extractors(cls, *sources): - # merge categories - # TODO: implement properly with merging and annotations remapping - categories = {} - for source in sources: - categories.update(source.categories()) - for source in sources: - for cat_type, source_cat in source.categories().items(): - if not categories[cat_type] == source_cat: - raise NotImplementedError( - "Merging different categories is not implemented yet") + categories = cls._merge_categories(s.categories() for s in sources) dataset = Dataset(categories=categories) # merge items @@ -457,7 +448,7 @@ def _lazy_image(item): @classmethod def _merge_items(cls, existing_item, current_item, path=None): return existing_item.wrap(path=path, - image=cls._merge_images(existing_item, current_item), + image=cls._merge_images(existing_item, current_item), annotations=cls._merge_anno( existing_item.annotations, current_item.annotations)) @@ -489,18 +480,15 @@ def _merge_images(existing_item, current_item): @staticmethod def _merge_anno(a, b): - from itertools import chain - merged = [] - for item in chain(a, b): - found = False - for elem in merged: - if elem == item: - found = True - break - if not found: - merged.append(item) - - return merged + # TODO: implement properly with merging and annotations remapping + from .operations import merge_annotations_equal + return merge_annotations_equal(a, b) + + @staticmethod + def _merge_categories(sources): + # TODO: implement properly with merging and annotations remapping + from .operations import merge_categories + return merge_categories(sources) class ProjectDataset(Dataset): def __init__(self, project): @@ -535,14 +523,9 @@ def __init__(self, project): # merge categories # TODO: implement properly with merging and annotations remapping - categories = {} - for source in self._sources.values(): - categories.update(source.categories()) - for source in self._sources.values(): - for cat_type, source_cat in source.categories().items(): - if not categories[cat_type] == source_cat: - raise NotImplementedError( - "Merging different categories is not implemented yet") + categories = self._merge_categories(s.categories() + for s in self._sources.values()) + # ovewrite with own categories if own_source is not None and (not categories or len(own_source) != 0): categories.update(own_source.categories()) self._categories = categories diff --git a/datumaro/datumaro/plugins/accuracy_checker_plugin/details/representation.py b/datumaro/datumaro/plugins/accuracy_checker_plugin/details/representation.py index 023c0955f862..d7007806bfde 100644 --- a/datumaro/datumaro/plugins/accuracy_checker_plugin/details/representation.py +++ b/datumaro/datumaro/plugins/accuracy_checker_plugin/details/representation.py @@ -9,7 +9,7 @@ import accuracy_checker.representation as ac import datumaro.components.extractor as dm -from datumaro.util.annotation_tools import softmax +from datumaro.util.annotation_util import softmax def import_predictions(predictions): # Convert Accuracy checker predictions to Datumaro annotations diff --git 
a/datumaro/datumaro/plugins/coco_format/converter.py b/datumaro/datumaro/plugins/coco_format/converter.py index 392f02619351..27cdd08754a3 100644 --- a/datumaro/datumaro/plugins/coco_format/converter.py +++ b/datumaro/datumaro/plugins/coco_format/converter.py @@ -12,7 +12,7 @@ import pycocotools.mask as mask_utils -import datumaro.util.annotation_tools as anno_tools +import datumaro.util.annotation_util as anno_tools import datumaro.util.mask_tools as mask_tools from datumaro.components.converter import Converter from datumaro.components.extractor import (_COORDINATE_ROUNDING_DIGITS, @@ -202,7 +202,7 @@ def find_instance_parts(self, group, img_width, img_height): anns = boxes + polygons + masks leader = anno_tools.find_group_leader(anns) - bbox = anno_tools.compute_bbox(anns) + bbox = anno_tools.max_bbox(anns) mask = None polygons = [p.points for p in polygons] diff --git a/datumaro/datumaro/plugins/cvat_format/converter.py b/datumaro/datumaro/plugins/cvat_format/converter.py index 0db14e1a17d3..37751703aba3 100644 --- a/datumaro/datumaro/plugins/cvat_format/converter.py +++ b/datumaro/datumaro/plugins/cvat_format/converter.py @@ -11,7 +11,7 @@ from datumaro.components.converter import Converter from datumaro.components.extractor import DEFAULT_SUBSET_NAME, AnnotationType -from datumaro.util import cast, pairwise +from datumaro.util import cast, pairs from .format import CvatPath @@ -246,7 +246,7 @@ def _write_shape(self, shape): ','.join(( "{:.2f}".format(x), "{:.2f}".format(y) - )) for x, y in pairwise(shape.points)) + )) for x, y in pairs(shape.points)) )), ])) diff --git a/datumaro/datumaro/plugins/tf_detection_api_format/converter.py b/datumaro/datumaro/plugins/tf_detection_api_format/converter.py index 481e8a2ed149..7ff3569dba23 100644 --- a/datumaro/datumaro/plugins/tf_detection_api_format/converter.py +++ b/datumaro/datumaro/plugins/tf_detection_api_format/converter.py @@ -16,7 +16,7 @@ ) from datumaro.components.converter import Converter from datumaro.util.image import encode_image -from datumaro.util.annotation_tools import (compute_bbox, +from datumaro.util.annotation_util import (max_bbox, find_group_leader, find_instances) from datumaro.util.mask_tools import merge_masks from datumaro.util.tf_util import import_tf as _import_tf @@ -111,7 +111,7 @@ def _find_instance_parts(self, group, img_width, img_height): anns = boxes + masks leader = find_group_leader(anns) - bbox = compute_bbox(anns) + bbox = max_bbox(anns) mask = None if self._save_masks: diff --git a/datumaro/datumaro/plugins/transforms.py b/datumaro/datumaro/plugins/transforms.py index b31b4762ccdf..368a891a02d2 100644 --- a/datumaro/datumaro/plugins/transforms.py +++ b/datumaro/datumaro/plugins/transforms.py @@ -17,7 +17,7 @@ ) from datumaro.components.cli_plugin import CliPlugin import datumaro.util.mask_tools as mask_tools -from datumaro.util.annotation_tools import find_group_leader, find_instances +from datumaro.util.annotation_util import find_group_leader, find_instances class CropCoveredSegments(Transform, CliPlugin): diff --git a/datumaro/datumaro/util/__init__.py b/datumaro/datumaro/util/__init__.py index 126a365bcc64..293bb5f62f34 100644 --- a/datumaro/datumaro/util/__init__.py +++ b/datumaro/datumaro/util/__init__.py @@ -62,7 +62,7 @@ def to_snake_case(s): name.append(char) return ''.join(name) -def pairwise(iterable): +def pairs(iterable): a = iter(iterable) return zip(a, a) diff --git a/datumaro/datumaro/util/annotation_tools.py b/datumaro/datumaro/util/annotation_tools.py deleted file mode 100644 index 
add234e782b4..000000000000 --- a/datumaro/datumaro/util/annotation_tools.py +++ /dev/null @@ -1,34 +0,0 @@ - -# Copyright (C) 2020 Intel Corporation -# -# SPDX-License-Identifier: MIT - -from itertools import groupby - -import numpy as np - - -def find_instances(instance_anns): - instance_anns = sorted(instance_anns, key=lambda a: a.group) - ann_groups = [] - for g_id, group in groupby(instance_anns, lambda a: a.group): - if not g_id: - ann_groups.extend(([a] for a in group)) - else: - ann_groups.append(list(group)) - - return ann_groups - -def find_group_leader(group): - return max(group, key=lambda x: x.get_area()) - -def compute_bbox(annotations): - boxes = [ann.get_bbox() for ann in annotations] - x0 = min((b[0] for b in boxes), default=0) - y0 = min((b[1] for b in boxes), default=0) - x1 = max((b[0] + b[2] for b in boxes), default=0) - y1 = max((b[1] + b[3] for b in boxes), default=0) - return [x0, y0, x1 - x0, y1 - y0] - -def softmax(x): - return np.exp(x) / sum(np.exp(x)) diff --git a/datumaro/datumaro/util/annotation_util.py b/datumaro/datumaro/util/annotation_util.py new file mode 100644 index 000000000000..38a2c814c02d --- /dev/null +++ b/datumaro/datumaro/util/annotation_util.py @@ -0,0 +1,213 @@ + +# Copyright (C) 2020 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from itertools import groupby + +import numpy as np + +from datumaro.components.extractor import _Shape, Mask, AnnotationType, RleMask +from datumaro.util.mask_tools import mask_to_rle + + +def find_instances(instance_anns): + instance_anns = sorted(instance_anns, key=lambda a: a.group) + ann_groups = [] + for g_id, group in groupby(instance_anns, lambda a: a.group): + if not g_id: + ann_groups.extend(([a] for a in group)) + else: + ann_groups.append(list(group)) + + return ann_groups + +def find_group_leader(group): + return max(group, key=lambda x: x.get_area()) + +def _get_bbox(ann): + if isinstance(ann, (_Shape, Mask)): + return ann.get_bbox() + else: + return ann + +def max_bbox(annotations): + boxes = [_get_bbox(ann) for ann in annotations] + x0 = min((b[0] for b in boxes), default=0) + y0 = min((b[1] for b in boxes), default=0) + x1 = max((b[0] + b[2] for b in boxes), default=0) + y1 = max((b[1] + b[3] for b in boxes), default=0) + return [x0, y0, x1 - x0, y1 - y0] + +def mean_bbox(annotations): + le = len(annotations) + boxes = [_get_bbox(ann) for ann in annotations] + mlb = sum(b[0] for b in boxes) / le + mtb = sum(b[1] for b in boxes) / le + mrb = sum(b[0] + b[2] for b in boxes) / le + mbb = sum(b[1] + b[3] for b in boxes) / le + return [mlb, mtb, mrb - mlb, mbb - mtb] + +def softmax(x): + return np.exp(x) / sum(np.exp(x)) + +def nms(segments, iou_thresh=0.5): + """ + Non-maxima suppression algorithm. 
+ """ + + indices = np.argsort([b.attributes['score'] for b in segments]) + ious = np.array([[iou(a, b) for b in segments] for a in segments]) + + predictions = [] + while len(indices) != 0: + i = len(indices) - 1 + pred_idx = indices[i] + to_remove = [i] + predictions.append(segments[pred_idx]) + for i, box_idx in enumerate(indices[:i]): + if iou_thresh < ious[pred_idx, box_idx]: + to_remove.append(i) + indices = np.delete(indices, to_remove) + + return predictions + +def bbox_iou(a, b): + """ + IoU computations for simple cases with bounding boxes + """ + bbox_a = _get_bbox(a) + bbox_b = _get_bbox(b) + + aX, aY, aW, aH = bbox_a + bX, bY, bW, bH = bbox_b + in_right = min(aX + aW, bX + bW) + in_left = max(aX, bX) + in_top = max(aY, bY) + in_bottom = min(aY + aH, bY + bH) + + in_w = max(0, in_right - in_left) + in_h = max(0, in_bottom - in_top) + intersection = in_w * in_h + if not intersection: + return -1 + + a_area = aW * aH + b_area = bW * bH + union = a_area + b_area - intersection + return intersection / union + +def segment_iou(a, b): + """ + Generic IoU computation with masks, polygons, and boxes. + Returns -1 if no intersection, [0; 1] otherwise + """ + from pycocotools import mask as mask_utils + + a_bbox = a.get_bbox() + b_bbox = b.get_bbox() + + is_bbox = AnnotationType.bbox in [a.type, b.type] + if is_bbox: + a = [a_bbox] + b = [b_bbox] + else: + w = max(a_bbox[0] + a_bbox[2], b_bbox[0] + b_bbox[2]) + h = max(a_bbox[1] + a_bbox[3], b_bbox[1] + b_bbox[3]) + + def _to_rle(ann): + if ann.type == AnnotationType.polygon: + return mask_utils.frPyObjects([ann.points], h, w) + elif isinstance(ann, RleMask): + return [ann._rle] + elif ann.type == AnnotationType.mask: + return mask_utils.frPyObjects([mask_to_rle(ann.image)], h, w) + else: + raise TypeError("Unexpected arguments: %s, %s" % (a, b)) + a = _to_rle(a) + b = _to_rle(b) + return float(mask_utils.iou(a, b, [not is_bbox])) + +def PDJ(a, b, eps=None, ratio=0.05, bbox=None): + """ + Percentage of Detected Joints metric. + Counts the number of matching points. + """ + + assert eps is not None or ratio is not None + + p1 = np.array(a.points).reshape((-1, 2)) + p2 = np.array(b.points).reshape((-1, 2)) + if len(p1) != len(p2): + return 0 + + if not eps: + if bbox is None: + bbox = mean_bbox([a, b]) + + diag = (bbox[2] ** 2 + bbox[3] ** 2) ** 0.5 + eps = ratio * diag + + dists = np.linalg.norm(p1 - p2, axis=1) + return np.sum(dists < eps) / len(p1) + +def OKS(a, b, sigma=None, bbox=None, scale=None): + """ + Object Keypoint Similarity metric. 
+    https://cocodataset.org/#keypoints-eval
+    """
+
+    p1 = np.array(a.points).reshape((-1, 2))
+    p2 = np.array(b.points).reshape((-1, 2))
+    if len(p1) != len(p2):
+        return 0
+
+    if not sigma:
+        sigma = 0.1
+    else:
+        assert len(sigma) == len(p1)
+
+    if not scale:
+        if bbox is None:
+            bbox = mean_bbox([a, b])
+        scale = bbox[2] * bbox[3]
+
+    dists = np.linalg.norm(p1 - p2, axis=1)
+    return np.sum(np.exp(-(dists ** 2) / (2 * scale * (2 * sigma) ** 2)))
+
+def smooth_line(points, segments):
+    assert 2 <= len(points) // 2 and len(points) % 2 == 0
+
+    points = list(points)
+    if len(points) == 2:
+        points.extend(points)
+    points = np.array(points).reshape((-1, 2))
+
+    lengths = np.linalg.norm(points[1:] - points[:-1], axis=1)
+    dists = [0]
+    for l in lengths:
+        dists.append(dists[-1] + l)
+
+    step = dists[-1] / segments
+
+    new_points = np.zeros((segments + 1, 2))
+    new_points[0] = points[0]
+
+    old_segment = 0
+    for new_segment in range(1, segments + 1):
+        pos = new_segment * step
+        while dists[old_segment + 1] < pos and old_segment + 2 < len(dists):
+            old_segment += 1
+
+        segment_start = dists[old_segment]
+        segment_len = lengths[old_segment]
+        prev_p = points[old_segment]
+        next_p = points[old_segment + 1]
+        r = (pos - segment_start) / segment_len
+
+        new_points[new_segment] = prev_p * (1 - r) + next_p * r
+
+    return new_points, step
diff --git a/datumaro/datumaro/util/attrs_util.py b/datumaro/datumaro/util/attrs_util.py
index af92c5499a78..15f0c3183ef0 100644
--- a/datumaro/datumaro/util/attrs_util.py
+++ b/datumaro/datumaro/util/attrs_util.py
@@ -23,4 +23,12 @@ def validator(inst, attribute, value):
         elif not isinstance(value, attribute.type or conv):
             value = conv(value)
         setattr(inst, attribute.name, value)
-    return validator
\ No newline at end of file
+    return validator
+
+def ensure_cls(c):
+    def converter(arg):
+        if isinstance(arg, c):
+            return arg
+        else:
+            return c(**arg)
+    return converter
\ No newline at end of file
diff --git a/datumaro/datumaro/util/test_utils.py b/datumaro/datumaro/util/test_utils.py
index cca952787d81..f93a74ce1b37 100644
--- a/datumaro/datumaro/util/test_utils.py
+++ b/datumaro/datumaro/util/test_utils.py
@@ -65,7 +65,22 @@ def compare_categories(test, expected, actual):
         actual[AnnotationType.points].items,
     )
 
-def compare_datasets(test, expected, actual):
+def _compare_annotations(expected, actual, ignored_attrs=None):
+    if not ignored_attrs:
+        return expected == actual
+
+    a_attr = expected.attributes
+    b_attr = actual.attributes
+
+    expected.attributes = {k: v for k, v in a_attr.items() if k not in ignored_attrs}
+    actual.attributes = {k: v for k, v in b_attr.items() if k not in ignored_attrs}
+    r = expected == actual
+
+    expected.attributes = a_attr
+    actual.attributes = b_attr
+    return r
+
+def compare_datasets(test, expected, actual, ignored_attrs=None):
     compare_categories(test, expected.categories(), actual.categories())
 
     test.assertEqual(sorted(expected.subsets()), sorted(actual.subsets()))
@@ -82,8 +97,11 @@ def compare_datasets(test, expected, actual):
                 if x.type == ann_a.type]
             test.assertFalse(len(ann_b_matches) == 0, 'ann id: %s' % ann_a.id)
 
-            ann_b = find(ann_b_matches, lambda x: x == ann_a)
-            test.assertEqual(ann_a, ann_b, 'ann %s, candidates %s' % (ann_a, ann_b_matches))
+            ann_b = find(ann_b_matches, lambda x:
+                _compare_annotations(x, ann_a, ignored_attrs=ignored_attrs))
+            if ann_b is None:
+                test.assertEqual(ann_a, ann_b,
+                    'ann %s, candidates %s' % (ann_a, ann_b_matches))
 
             
item_b.annotations.remove(ann_b) # avoid repeats def compare_datasets_strict(test, expected, actual): diff --git a/datumaro/tests/test_ops.py b/datumaro/tests/test_ops.py index ed165b2ddd19..dd4520b52f6b 100644 --- a/datumaro/tests/test_ops.py +++ b/datumaro/tests/test_ops.py @@ -1,11 +1,14 @@ +from unittest import TestCase + import numpy as np -from datumaro.components.extractor import (Extractor, DatasetItem, Label, - Mask, Bbox, Points, Caption) +from datumaro.components.extractor import (Bbox, Caption, DatasetItem, + Extractor, Label, Mask, Points, Polygon, PolyLine) +from datumaro.components.operations import (FailedAttrVotingError, + IntersectMerge, NoMatchingAnnError, NoMatchingItemError, WrongGroupError, + compute_ann_statistics, mean_std) from datumaro.components.project import Dataset -from datumaro.components.operations import mean_std, compute_ann_statistics - -from unittest import TestCase +from datumaro.util.test_utils import compare_datasets class TestOperations(TestCase): @@ -131,4 +134,234 @@ def test_stats(self): actual = compute_ann_statistics(dataset) - self.assertEqual(expected, actual) \ No newline at end of file + self.assertEqual(expected, actual) + +class TestMultimerge(TestCase): + def test_can_match_items(self): + # items 1 and 3 are unique, item 2 is common and should be merged + + source0 = Dataset.from_iterable([ + DatasetItem(1, annotations=[ Label(0), ]), + DatasetItem(2, annotations=[ Label(0), ]), + ], categories=['a', 'b']) + + source1 = Dataset.from_iterable([ + DatasetItem(2, annotations=[ Label(1), ]), + DatasetItem(3, annotations=[ Label(0), ]), + ], categories=['a', 'b']) + + source2 = Dataset.from_iterable([ + DatasetItem(2, annotations=[ Label(0), Bbox(1, 2, 3, 4) ]), + ], categories=['a', 'b']) + + expected = Dataset.from_iterable([ + DatasetItem(1, annotations=[ + Label(0, attributes={'score': 1/3}), + ]), + DatasetItem(2, annotations=[ + Label(0, attributes={'score': 2/3}), + Label(1, attributes={'score': 1/3}), + Bbox(1, 2, 3, 4, attributes={'score': 1.0}), + ]), + DatasetItem(3, annotations=[ + Label(0, attributes={'score': 1/3}), + ]), + ], categories=['a', 'b']) + + merger = IntersectMerge() + merged = merger([source0, source1, source2]) + + compare_datasets(self, expected, merged) + self.assertEqual( + [ + NoMatchingItemError(item_id=('1', ''), sources={1, 2}), + NoMatchingItemError(item_id=('3', ''), sources={0, 2}), + ], + sorted((e for e in merger.errors + if isinstance(e, NoMatchingItemError)), + key=lambda e: e.item_id) + ) + self.assertEqual( + [ + NoMatchingAnnError(item_id=('2', ''), sources={0, 1}, + ann=source2.get('2').annotations[1]), + ], + sorted((e for e in merger.errors + if isinstance(e, NoMatchingAnnError)), + key=lambda e: e.item_id) + ) + + def test_can_match_shapes(self): + source0 = Dataset.from_iterable([ + DatasetItem(1, annotations=[ + # unique + Bbox(1, 2, 3, 4, label=1), + + # common + Mask(label=3, z_order=2, image=np.array([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 0], + ])), + Polygon([1, 0, 3, 2, 1, 2]), + + # an instance with keypoints + Bbox(4, 5, 2, 4, label=2, z_order=1, group=1), + Points([5, 6], label=0, group=1), + Points([6, 8], label=1, group=1), + + PolyLine([1, 1, 2, 1, 3, 1]), + ]), + ], categories=['a', 'b', 'c']) + + source1 = Dataset.from_iterable([ + DatasetItem(1, annotations=[ + # common + Mask(label=3, image=np.array([ + [0, 0, 0, 0], + [0, 1, 1, 1], + [0, 1, 1, 1], + [0, 1, 1, 1], + ])), + Polygon([0, 2, 2, 0, 2, 1]), + + # an instance with keypoints + Bbox(4, 4, 2, 5, 
label=2, z_order=1, group=2), + Points([5.5, 6.5], label=0, group=2), + Points([6, 8], label=1, group=2), + + PolyLine([1, 1.5, 2, 1.5]), + ]), + ], categories=['a', 'b', 'c']) + + source2 = Dataset.from_iterable([ + DatasetItem(1, annotations=[ + # common + Mask(label=3, z_order=3, image=np.array([ + [0, 0, 1, 1], + [0, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 0], + ])), + Polygon([3, 1, 2, 2, 0, 1]), + + # an instance with keypoints, one is missing + Bbox(3, 6, 2, 3, label=2, z_order=4, group=3), + Points([4.5, 5.5], label=0, group=3), + + PolyLine([1, 1.25, 3, 1, 4, 2]), + ]), + ], categories=['a', 'b', 'c']) + + expected = Dataset.from_iterable([ + DatasetItem(1, annotations=[ + # unique + Bbox(1, 2, 3, 4, label=1), + + # common + # nearest to mean bbox + Mask(label=3, z_order=3, image=np.array([ + [0, 0, 0, 0], + [0, 1, 1, 1], + [0, 1, 1, 1], + [0, 1, 1, 1], + ])), + Polygon([1, 0, 3, 2, 1, 2]), + + # an instance with keypoints + Bbox(4, 5, 2, 4, label=2, z_order=4, group=1), + Points([5, 6], label=0, group=1), + Points([6, 8], label=1, group=1), + + PolyLine([1, 1.25, 3, 1, 4, 2]), + ]), + ], categories=['a', 'b', 'c']) + + merger = IntersectMerge(conf={'quorum': 1, 'pairwise_dist': 0.1}) + merged = merger([source0, source1, source2]) + + compare_datasets(self, expected, merged, ignored_attrs={'score'}) + self.assertEqual( + [ + NoMatchingAnnError(item_id=('1', ''), sources={2}, + ann=source0.get('1').annotations[5]), + NoMatchingAnnError(item_id=('1', ''), sources={1, 2}, + ann=source0.get('1').annotations[0]), + ], + sorted((e for e in merger.errors + if isinstance(e, NoMatchingAnnError)), + key=lambda e: len(e.sources)) + ) + + def test_attributes(self): + source0 = Dataset.from_iterable([ + DatasetItem(1, annotations=[ + Label(2, attributes={ + 'unique': 1, + 'common_under_quorum': 2, + 'common_over_quorum': 3, + 'ignored': 'q', + }), + ]), + ], categories=['a', 'b', 'c']) + + source1 = Dataset.from_iterable([ + DatasetItem(1, annotations=[ + Label(2, attributes={ + 'common_under_quorum': 2, + 'common_over_quorum': 3, + 'ignored': 'q', + }), + ]), + ], categories=['a', 'b', 'c']) + + source2 = Dataset.from_iterable([ + DatasetItem(1, annotations=[ + Label(2, attributes={ + 'common_over_quorum': 3, + 'ignored': 'q', + }), + ]), + ], categories=['a', 'b', 'c']) + + expected = Dataset.from_iterable([ + DatasetItem(1, annotations=[ + Label(2, attributes={ 'common_over_quorum': 3 }), + ]), + ], categories=['a', 'b', 'c']) + + merger = IntersectMerge(conf={ + 'quorum': 3, 'ignored_attributes': {'ignored'}}) + merged = merger([source0, source1, source2]) + + compare_datasets(self, expected, merged, ignored_attrs={'score'}) + self.assertEqual(2, len([e for e in merger.errors + if isinstance(e, FailedAttrVotingError)]) + ) + + def test_group_checks(self): + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[ + Bbox(0, 0, 0, 0, label=0, group=1), # misses an optional label + Bbox(0, 0, 0, 0, label=1, group=1), + + Bbox(0, 0, 0, 0, label=2, group=2), # misses a mandatory label - error + Bbox(0, 0, 0, 0, label=2, group=2), + + Bbox(0, 0, 0, 0, label=4), # misses an optional label + Bbox(0, 0, 0, 0, label=5), # misses a mandatory label - error + Bbox(0, 0, 0, 0, label=0), # misses a mandatory label - error + + Bbox(0, 0, 0, 0, label=3), # not listed - not checked + ]), + ], categories=['a', 'a_g1', 'a_g2_opt', 'b', 'c', 'c_g1_opt']) + + merger = IntersectMerge(conf={'groups': [ + ['a', 'a_g1', 'a_g2_opt?'], ['c', 'c_g1_opt?'] + ]}) + merger([dataset, dataset]) + + 
self.assertEqual(3, len([e for e in merger.errors + if isinstance(e, WrongGroupError)]), merger.errors + )
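

Beyond the new `merge` CLI command, the merging workflow in this patch can also be driven from Python. The following is a minimal sketch distilled from the tests above; the `source()` helper is hypothetical, while `IntersectMerge`, `Dataset.from_iterable`, and the plain-dict `conf` (converted by the `ensure_cls` helper added in `attrs_util.py`) are taken from the code introduced here.

```python
# Sketch: merging three one-item datasets with label voting,
# assuming the post-patch API from this diff.
from datumaro.components.extractor import DatasetItem, Label
from datumaro.components.project import Dataset
from datumaro.components.operations import IntersectMerge

def source(label_id):
    # Hypothetical helper: one dataset with a single labeled item
    return Dataset.from_iterable(
        [DatasetItem(1, annotations=[Label(label_id)])],
        categories=['cat', 'dog'])

# Two annotators say 'cat' (label 0), one says 'dog' (label 1)
sources = [source(0), source(0), source(1)]

# quorum=2: a label needs at least two votes to survive the merge
merger = IntersectMerge(conf={'quorum': 2})
merged = merger(sources)

for item in merged:
    # Label 0 wins with score 2/3; label 1 falls below the quorum
    print(item.id, [(a.label, a.attributes['score']) for a in item.annotations])

# The dropped label is reported rather than silently discarded
print(merger.errors)  # e.g. [FailedLabelVotingError(...)]
```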
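Likewise, the helpers relocated into `datumaro/util/annotation_util.py` are plain functions and can be exercised directly. A small worked example follows; the printed values are hand-computed from the formulas above and assume the post-patch behavior (`bbox_iou` returning -1 for disjoint boxes, `nms` suppressing by box IoU).

```python
# Sketch: the relocated IoU/NMS helpers on plain Bbox annotations.
from datumaro.components.extractor import Bbox
from datumaro.util.annotation_util import bbox_iou, nms

a = Bbox(0, 0, 10, 10, attributes={'score': 0.9})
b = Bbox(1, 1, 10, 10, attributes={'score': 0.5})  # overlaps a heavily
c = Bbox(20, 20, 5, 5, attributes={'score': 0.7})  # disjoint from both

print(bbox_iou(a, b))  # 81 / (100 + 100 - 81) ~= 0.68
print(bbox_iou(a, c))  # -1 signals "no intersection"

# b is suppressed by the higher-scoring a; c survives untouched
kept = nms([a, b, c], iou_thresh=0.5)
print([k.attributes['score'] for k in kept])  # [0.9, 0.7]
```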