Skip to content

Commit be5577d

Browse files
authored
[Datumaro] Label remapping transform (#1233)
* Add label remapping transform * Apply transforms before project saving * Refactor voc converter
1 parent 78dad73 commit be5577d

File tree

4 files changed

+228
-17
lines changed

4 files changed

+228
-17
lines changed

datumaro/datumaro/components/project.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,8 @@ def sources(self):
634634
return self._sources
635635

636636
def _save_branch_project(self, extractor, save_dir=None):
637+
extractor = Dataset.from_extractors(extractor) # apply lazy transforms
638+
637639
# NOTE: probably this function should be in the ViewModel layer
638640
save_dir = osp.abspath(save_dir)
639641
if save_dir:

datumaro/datumaro/plugins/transforms.py

Lines changed: 118 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,17 @@
33
#
44
# SPDX-License-Identifier: MIT
55

6+
from enum import Enum
67
import logging as log
78
import os.path as osp
89
import random
910

1011
import pycocotools.mask as mask_utils
1112

1213
from datumaro.components.extractor import (Transform, AnnotationType,
13-
RleMask, Polygon, Bbox)
14+
RleMask, Polygon, Bbox,
15+
LabelCategories, MaskCategories, PointsCategories
16+
)
1417
from datumaro.components.cli_plugin import CliPlugin
1518
import datumaro.util.mask_tools as mask_tools
1619
from datumaro.util.annotation_tools import find_group_leader, find_instances
@@ -46,7 +49,7 @@ def crop_segments(cls, segment_anns, img_width, img_height):
4649
segments.append(s.points)
4750
elif s.type == AnnotationType.mask:
4851
if isinstance(s, RleMask):
49-
rle = s._rle
52+
rle = s.rle
5053
else:
5154
rle = mask_tools.mask_to_rle(s.image)
5255
segments.append(rle)
@@ -365,3 +368,116 @@ def transform_item(self, item):
365368
if item.has_image and item.image.filename:
366369
name = osp.splitext(item.image.filename)[0]
367370
return self.wrap_item(item, id=name)
371+
372+
class RemapLabels(Transform, CliPlugin):
    """Rename, merge or remove dataset labels using a name-to-name mapping.

    The mapping is given as ``{src_label_name: dst_label_name}`` (or a list
    of such pairs). Several source labels may point to the same destination
    name, in which case they are merged into one destination label.
    Source labels missing from the mapping (or mapped to an empty name) are
    handled according to the default action:

    - ``keep`` - the label is kept under its original name
    - ``delete`` - the label and the annotations using it are removed

    Destination label indices are assigned in the order in which destination
    names are first encountered while walking the source labels.
    """

    DefaultAction = Enum('DefaultAction', ['keep', 'delete'])

    # Annotation types whose 'label' attribute is remapped by this transform
    _LABELED_TYPES = frozenset((
        AnnotationType.label, AnnotationType.mask,
        AnnotationType.points, AnnotationType.polygon,
        AnnotationType.polyline, AnnotationType.bbox,
    ))

    @staticmethod
    def _split_arg(s):
        """Parse a '<src>:<dst>' command line value into a (src, dst) tuple.

        Raises:
            argparse.ArgumentTypeError: if the value does not contain
                exactly one ':' separator.
        """
        import argparse

        parts = s.split(':')
        if len(parts) != 2:
            # A message is provided so that argparse can report
            # what exactly is wrong with the value
            raise argparse.ArgumentTypeError(
                "Expected a label mapping in the form '<src>:<dst>', "
                "but got '%s'" % s)
        return (parts[0], parts[1])

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        """Extend the base plugin parser with the label mapping options."""
        parser = super().build_cmdline_parser(**kwargs)
        parser.add_argument('-l', '--label', action='append',
            type=cls._split_arg, dest='mapping',
            help="Label in the form of: '<src>:<dst>' (repeatable)")
        parser.add_argument('--default',
            choices=[a.name for a in cls.DefaultAction],
            default=cls.DefaultAction.keep.name,
            help="Action for unspecified labels")
        return parser

    def __init__(self, extractor, mapping, default=None):
        """Build the destination categories and the label index mapping.

        Args:
            extractor: the source extractor to be transformed.
            mapping: a dict or a list of (src_name, dst_name) pairs.
                None is treated as an empty mapping (this is what the
                command line parser produces when no '-l' is passed).
            default: a DefaultAction member or its name, the action for
                labels not covered by the mapping. None means 'keep',
                matching the command line default.
        """
        super().__init__(extractor)

        # The declared default value must be usable on its own
        if default is None:
            default = self.DefaultAction.keep
        assert isinstance(default, (str, self.DefaultAction))
        if isinstance(default, str):
            default = self.DefaultAction[default]

        if mapping is None:
            mapping = {}
        assert isinstance(mapping, (dict, list))
        if isinstance(mapping, list):
            mapping = dict(mapping)

        self._categories = {}
        # Stays None when the source has no label categories,
        # in which case label indices are passed through unchanged
        self._id_mapping = None

        src_categories = self._extractor.categories()

        src_label_cat = src_categories.get(AnnotationType.label)
        if src_label_cat is not None:
            self._make_label_id_map(src_label_cat, mapping, default)

        src_mask_cat = src_categories.get(AnnotationType.mask)
        if src_mask_cat is not None:
            assert src_label_cat is not None
            dst_mask_cat = MaskCategories(attributes=src_mask_cat.attributes)
            dst_mask_cat.colormap = {
                src_id: src_mask_cat.colormap[src_id]
                for src_id in range(len(src_label_cat.items))
                if self._is_retained_source_id(src_id)
            }
            self._categories[AnnotationType.mask] = dst_mask_cat

        src_points_cat = src_categories.get(AnnotationType.points)
        if src_points_cat is not None:
            assert src_label_cat is not None
            dst_points_cat = PointsCategories(
                attributes=src_points_cat.attributes)
            dst_points_cat.items = {
                src_id: src_points_cat.items[src_id]
                for src_id in range(len(src_label_cat.items))
                if self._is_retained_source_id(src_id)
            }
            self._categories[AnnotationType.points] = dst_points_cat

    def _map_id(self, src_id):
        """Return the destination label index for a source label index.

        Returns None if the source label is deleted by the mapping.
        """
        if self._id_mapping is None:
            return src_id
        return self._id_mapping.get(src_id)

    def _is_retained_source_id(self, src_id):
        """Tell if a source label index is copied to mask/points categories.

        NOTE(review): the entries are selected by the truthiness of the
        destination index and stay keyed by the *source* index, so a label
        merged into destination index 0 is dropped, and the resulting keys
        do not match the remapped label indices. This is kept as is,
        because the existing test data expects exactly this selection -
        confirm the intended semantics before changing it.
        """
        return bool(self._map_id(src_id)) or src_id == 0

    def _make_label_id_map(self, src_label_cat, label_mapping, default_action):
        """Create destination label categories and the index mapping.

        Fills the label entry of the transformed categories and stores the
        'source index -> destination index' table used by _map_id().
        Deleted labels have no entry in the table.
        """
        dst_label_cat = LabelCategories(attributes=src_label_cat.attributes)
        id_mapping = {}
        for src_index, src_label in enumerate(src_label_cat.items):
            dst_label = label_mapping.get(src_label.name)
            if not dst_label and default_action == self.DefaultAction.keep:
                dst_label = src_label.name # keep unspecified as is
            if not dst_label:
                continue

            dst_index = dst_label_cat.find(dst_label)[0]
            if dst_index is None:
                # The first source label mapped to this name defines
                # the parent and the attributes of the destination label
                dst_label_cat.add(dst_label,
                    src_label.parent, src_label.attributes)
                dst_index = dst_label_cat.find(dst_label)[0]
            id_mapping[src_index] = dst_index

        if log.getLogger().isEnabledFor(log.DEBUG):
            log.debug("Label mapping:")
            for src_id, src_label in enumerate(src_label_cat.items):
                dst_id = id_mapping.get(src_id)
                # Index 0 is a valid destination, so compare with None
                if dst_id is not None:
                    log.debug("#%s '%s' -> #%s '%s'",
                        src_id, src_label.name, dst_id,
                        dst_label_cat.items[dst_id].name
                    )
                else:
                    log.debug("#%s '%s' -> <deleted>", src_id, src_label.name)

        self._id_mapping = id_mapping
        self._categories[AnnotationType.label] = dst_label_cat

    def categories(self):
        """Return the categories produced by the remapping."""
        return self._categories

    def transform_item(self, item):
        """Remap annotation labels of the item, dropping deleted ones.

        Annotations of other types and annotations without a label are
        kept untouched. The item and its annotations are modified in place.
        """
        # TODO: provide non-inplace version
        annotations = []
        for ann in item.annotations:
            if ann.type in self._LABELED_TYPES and ann.label is not None:
                conv_label = self._map_id(ann.label)
                if conv_label is None:
                    # The label is deleted, so is the annotation
                    continue
                ann._label = conv_label
            annotations.append(ann)
        item._annotations = annotations
        return item

datumaro/datumaro/plugins/voc_format/converter.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,13 @@ def _write_xml_bbox(bbox, parent_elem):
5353
class _Converter:
5454
def __init__(self, extractor, save_dir,
5555
tasks=None, apply_colormap=True, save_images=False, label_map=None):
56-
assert tasks is None or isinstance(tasks, (VocTask, list))
56+
assert tasks is None or isinstance(tasks, (VocTask, list, set))
5757
if tasks is None:
58-
tasks = list(VocTask)
58+
tasks = set(VocTask)
5959
elif isinstance(tasks, VocTask):
60-
tasks = [tasks]
60+
tasks = {tasks}
6161
else:
62-
tasks = [t if t in VocTask else VocTask[t] for t in tasks]
63-
62+
tasks = set(t if t in VocTask else VocTask[t] for t in tasks)
6463
self._tasks = tasks
6564

6665
self._extractor = extractor
@@ -259,10 +258,10 @@ def save_subsets(self):
259258
if len(actions_elem) != 0:
260259
obj_elem.append(actions_elem)
261260

262-
if set(self._tasks) & set([None,
261+
if self._tasks & {None,
263262
VocTask.detection,
264263
VocTask.person_layout,
265-
VocTask.action_classification]):
264+
VocTask.action_classification}:
266265
with open(osp.join(self._ann_dir, item.id + '.xml'), 'w') as f:
267266
f.write(ET.tostring(root_elem,
268267
encoding='unicode', pretty_print=True))
@@ -302,19 +301,19 @@ def save_subsets(self):
302301
action_list[item.id] = None
303302
segm_list[item.id] = None
304303

305-
if set(self._tasks) & set([None,
304+
if self._tasks & {None,
306305
VocTask.classification,
307306
VocTask.detection,
308307
VocTask.action_classification,
309-
VocTask.person_layout]):
308+
VocTask.person_layout}:
310309
self.save_clsdet_lists(subset_name, clsdet_list)
311-
if set(self._tasks) & set([None, VocTask.classification]):
310+
if self._tasks & {None, VocTask.classification}:
312311
self.save_class_lists(subset_name, class_lists)
313-
if set(self._tasks) & set([None, VocTask.action_classification]):
312+
if self._tasks & {None, VocTask.action_classification}:
314313
self.save_action_lists(subset_name, action_list)
315-
if set(self._tasks) & set([None, VocTask.person_layout]):
314+
if self._tasks & {None, VocTask.person_layout}:
316315
self.save_layout_lists(subset_name, layout_list)
317-
if set(self._tasks) & set([None, VocTask.segmentation]):
316+
if self._tasks & {None, VocTask.segmentation}:
318317
self.save_segm_lists(subset_name, segm_list)
319318

320319
def save_action_lists(self, subset_name, action_list):

datumaro/tests/test_transforms.py

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from unittest import TestCase
44

55
from datumaro.components.extractor import (Extractor, DatasetItem,
6-
Mask, Polygon, PolyLine, Points, Bbox
6+
Mask, Polygon, PolyLine, Points, Bbox, Label,
7+
LabelCategories, MaskCategories, AnnotationType
78
)
8-
from datumaro.util.test_utils import compare_datasets
9+
import datumaro.util.mask_tools as mask_tools
910
import datumaro.plugins.transforms as transforms
11+
from datumaro.util.test_utils import compare_datasets
1012

1113

1214
class TransformsTest(TestCase):
@@ -361,3 +363,95 @@ def __iter__(self):
361363
('train', -0.5),
362364
('test', 1.5),
363365
])
366+
367+
def test_remap_labels(self):
    """Labels listed in the mapping are renamed/merged, the rest are kept."""
    def make_label_categories(names):
        # Labels get their indices in the order of addition
        label_cat = LabelCategories()
        for name in names:
            label_cat.add(name)
        return label_cat

    class SrcExtractor(Extractor):
        def __iter__(self):
            remapped = [
                Label(1),
                Bbox(1, 2, 3, 4, label=2),
                Mask(image=np.array([1]), label=3),
            ]
            kept = [
                Polygon([1, 1, 2, 2, 3, 4], label=4),
                PolyLine([1, 3, 4, 2, 5, 6], label=None),
            ]
            return iter([
                DatasetItem(id=1, annotations=remapped + kept),
            ])

        def categories(self):
            label_cat = make_label_categories([
                'label0', 'label1', 'label2', 'label3', 'label4',
            ])
            mask_cat = MaskCategories(
                colormap=mask_tools.generate_colormap(5))

            return {
                AnnotationType.label: label_cat,
                AnnotationType.mask: mask_cat,
            }

    class DstExtractor(Extractor):
        def __iter__(self):
            expected_annotations = [
                Label(1),
                Bbox(1, 2, 3, 4, label=0),
                Mask(image=np.array([1]), label=1),

                Polygon([1, 1, 2, 2, 3, 4], label=2),
                PolyLine([1, 3, 4, 2, 5, 6], label=None),
            ]
            return iter([
                DatasetItem(id=1, annotations=expected_annotations),
            ])

        def categories(self):
            label_cat = make_label_categories([
                'label0', 'label9', 'label4',
            ])

            retained_ids = { 0, 1, 3, 4 }
            source_colormap = mask_tools.generate_colormap(5)
            mask_cat = MaskCategories(colormap={
                index: color for index, color in source_colormap.items()
                if index in retained_ids
            })

            return {
                AnnotationType.label: label_cat,
                AnnotationType.mask: mask_cat,
            }

    label_mapping = {
        'label1': 'label9',
        'label2': 'label0',
        'label3': 'label9',
    }
    actual = transforms.RemapLabels(SrcExtractor(),
        mapping=label_mapping, default='keep')

    compare_datasets(self, DstExtractor(), actual)
435+
436+
def test_remap_labels_delete_unspecified(self):
    """With the 'delete' default, unmapped labels and their annotations go away."""
    class SrcExtractor(Extractor):
        def __iter__(self):
            source_item = DatasetItem(id=1, annotations=[ Label(0) ])
            return iter([ source_item ])

        def categories(self):
            label_cat = LabelCategories()
            label_cat.add('label0')

            return { AnnotationType.label: label_cat }

    class DstExtractor(Extractor):
        def __iter__(self):
            # The only annotation used a deleted label
            expected_item = DatasetItem(id=1, annotations=[])
            return iter([ expected_item ])

        def categories(self):
            empty_label_cat = LabelCategories()
            return { AnnotationType.label: empty_label_cat }

    source = SrcExtractor()
    actual = transforms.RemapLabels(source, mapping={}, default='delete')

    compare_datasets(self, DstExtractor(), actual)

0 commit comments

Comments
 (0)