3
3
#
4
4
# SPDX-License-Identifier: MIT
5
5
6
+ from enum import Enum
6
7
import logging as log
7
8
import os .path as osp
8
9
import random
9
10
10
11
import pycocotools .mask as mask_utils
11
12
12
13
from datumaro .components .extractor import (Transform , AnnotationType ,
13
- RleMask , Polygon , Bbox )
14
+ RleMask , Polygon , Bbox ,
15
+ LabelCategories , MaskCategories , PointsCategories
16
+ )
14
17
from datumaro .components .cli_plugin import CliPlugin
15
18
import datumaro .util .mask_tools as mask_tools
16
19
from datumaro .util .annotation_tools import find_group_leader , find_instances
@@ -46,7 +49,7 @@ def crop_segments(cls, segment_anns, img_width, img_height):
46
49
segments .append (s .points )
47
50
elif s .type == AnnotationType .mask :
48
51
if isinstance (s , RleMask ):
49
- rle = s ._rle
52
+ rle = s .rle
50
53
else :
51
54
rle = mask_tools .mask_to_rle (s .image )
52
55
segments .append (rle )
@@ -365,3 +368,116 @@ def transform_item(self, item):
365
368
if item .has_image and item .image .filename :
366
369
name = osp .splitext (item .image .filename )[0 ]
367
370
return self .wrap_item (item , id = name )
371
+
372
+ class RemapLabels (Transform , CliPlugin ):
373
+ DefaultAction = Enum ('DefaultAction' , ['keep' , 'delete' ])
374
+
375
+ @staticmethod
376
+ def _split_arg (s ):
377
+ parts = s .split (':' )
378
+ if len (parts ) != 2 :
379
+ import argparse
380
+ raise argparse .ArgumentTypeError ()
381
+ return (parts [0 ], parts [1 ])
382
+
383
+ @classmethod
384
+ def build_cmdline_parser (cls , ** kwargs ):
385
+ parser = super ().build_cmdline_parser (** kwargs )
386
+ parser .add_argument ('-l' , '--label' , action = 'append' ,
387
+ type = cls ._split_arg , dest = 'mapping' ,
388
+ help = "Label in the form of: '<src>:<dst>' (repeatable)" )
389
+ parser .add_argument ('--default' ,
390
+ choices = [a .name for a in cls .DefaultAction ],
391
+ default = cls .DefaultAction .keep .name ,
392
+ help = "Action for unspecified labels" )
393
+ return parser
394
+
395
+ def __init__ (self , extractor , mapping , default = None ):
396
+ super ().__init__ (extractor )
397
+
398
+ assert isinstance (default , (str , self .DefaultAction ))
399
+ if isinstance (default , str ):
400
+ default = self .DefaultAction [default ]
401
+
402
+ assert isinstance (mapping , (dict , list ))
403
+ if isinstance (mapping , list ):
404
+ mapping = dict (mapping )
405
+
406
+ self ._categories = {}
407
+
408
+ src_label_cat = self ._extractor .categories ().get (AnnotationType .label )
409
+ if src_label_cat is not None :
410
+ self ._make_label_id_map (src_label_cat , mapping , default )
411
+
412
+ src_mask_cat = self ._extractor .categories ().get (AnnotationType .mask )
413
+ if src_mask_cat is not None :
414
+ assert src_label_cat is not None
415
+ dst_mask_cat = MaskCategories (attributes = src_mask_cat .attributes )
416
+ dst_mask_cat .colormap = {
417
+ id : src_mask_cat .colormap [id ]
418
+ for id , _ in enumerate (src_label_cat .items )
419
+ if self ._map_id (id ) or id == 0
420
+ }
421
+ self ._categories [AnnotationType .mask ] = dst_mask_cat
422
+
423
+ src_points_cat = self ._extractor .categories ().get (AnnotationType .points )
424
+ if src_points_cat is not None :
425
+ assert src_label_cat is not None
426
+ dst_points_cat = PointsCategories (attributes = src_points_cat .attributes )
427
+ dst_points_cat .items = {
428
+ id : src_points_cat .items [id ]
429
+ for id , item in enumerate (src_label_cat .items )
430
+ if self ._map_id (id ) or id == 0
431
+ }
432
+ self ._categories [AnnotationType .points ] = dst_points_cat
433
+
434
+ def _make_label_id_map (self , src_label_cat , label_mapping , default_action ):
435
+ dst_label_cat = LabelCategories (attributes = src_label_cat .attributes )
436
+ id_mapping = {}
437
+ for src_index , src_label in enumerate (src_label_cat .items ):
438
+ dst_label = label_mapping .get (src_label .name )
439
+ if not dst_label and default_action == self .DefaultAction .keep :
440
+ dst_label = src_label .name # keep unspecified as is
441
+ if not dst_label :
442
+ continue
443
+
444
+ dst_index = dst_label_cat .find (dst_label )[0 ]
445
+ if dst_index is None :
446
+ dst_label_cat .add (dst_label ,
447
+ src_label .parent , src_label .attributes )
448
+ dst_index = dst_label_cat .find (dst_label )[0 ]
449
+ id_mapping [src_index ] = dst_index
450
+
451
+ if log .getLogger ().isEnabledFor (log .DEBUG ):
452
+ log .debug ("Label mapping:" )
453
+ for src_id , src_label in enumerate (src_label_cat .items ):
454
+ if id_mapping .get (src_id ):
455
+ log .debug ("#%s '%s' -> #%s '%s'" ,
456
+ src_id , src_label .name , id_mapping [src_id ],
457
+ dst_label_cat .items [id_mapping [src_id ]].name
458
+ )
459
+ else :
460
+ log .debug ("#%s '%s' -> <deleted>" , src_id , src_label .name )
461
+
462
+ self ._map_id = lambda src_id : id_mapping .get (src_id , None )
463
+ self ._categories [AnnotationType .label ] = dst_label_cat
464
+
465
+ def categories (self ):
466
+ return self ._categories
467
+
468
+ def transform_item (self , item ):
469
+ # TODO: provide non-inplace version
470
+ annotations = []
471
+ for ann in item .annotations :
472
+ if ann .type in { AnnotationType .label , AnnotationType .mask ,
473
+ AnnotationType .points , AnnotationType .polygon ,
474
+ AnnotationType .polyline , AnnotationType .bbox
475
+ } and ann .label is not None :
476
+ conv_label = self ._map_id (ann .label )
477
+ if conv_label is not None :
478
+ ann ._label = conv_label
479
+ annotations .append (ann )
480
+ else :
481
+ annotations .append (ann )
482
+ item ._annotations = annotations
483
+ return item
0 commit comments