16
16
from datumaro .components .converter import Converter
17
17
from datumaro .components .cli_plugin import CliPlugin
18
18
from datumaro .util .image import encode_image
19
+ from datumaro .util .mask_tools import merge_masks
20
+ from datumaro .util .annotation_tools import (compute_bbox ,
21
+ find_group_leader , find_instances )
19
22
from datumaro .util .tf_util import import_tf as _import_tf
20
23
21
24
from .format import DetectionApiPath
22
25
tf = _import_tf ()
23
26
24
27
25
- # we need it to filter out non-ASCII characters, otherwise training will crash
28
+ # filter out non-ASCII characters, otherwise training will crash
26
29
_printable = set (string .printable )
27
30
def _make_printable (s ):
28
31
return '' .join (filter (lambda x : x in _printable , s ))
29
32
30
- def _make_tf_example (item , get_label_id , get_label , save_images = False ):
31
- def int64_feature (value ):
32
- return tf .train .Feature (int64_list = tf .train .Int64List (value = [value ]))
33
-
34
- def int64_list_feature (value ):
35
- return tf .train .Feature (int64_list = tf .train .Int64List (value = value ))
36
-
37
- def bytes_feature (value ):
38
- return tf .train .Feature (bytes_list = tf .train .BytesList (value = [value ]))
39
-
40
- def bytes_list_feature (value ):
41
- return tf .train .Feature (bytes_list = tf .train .BytesList (value = value ))
42
-
43
- def float_list_feature (value ):
44
- return tf .train .Feature (float_list = tf .train .FloatList (value = value ))
45
-
46
-
47
- features = {
48
- 'image/source_id' : bytes_feature (str (item .id ).encode ('utf-8' )),
49
- 'image/filename' : bytes_feature (
50
- ('%s%s' % (item .id , DetectionApiPath .IMAGE_EXT )).encode ('utf-8' )),
51
- }
52
-
53
- if not item .has_image :
54
- raise Exception ("Failed to export dataset item '%s': "
55
- "item has no image info" % item .id )
56
- height , width = item .image .size
57
-
58
- features .update ({
59
- 'image/height' : int64_feature (height ),
60
- 'image/width' : int64_feature (width ),
61
- })
62
-
63
- features .update ({
64
- 'image/encoded' : bytes_feature (b'' ),
65
- 'image/format' : bytes_feature (b'' )
66
- })
67
- if save_images :
68
- if item .has_image and item .image .has_data :
69
- fmt = DetectionApiPath .IMAGE_FORMAT
70
- buffer = encode_image (item .image .data , DetectionApiPath .IMAGE_EXT )
71
-
72
- features .update ({
73
- 'image/encoded' : bytes_feature (buffer ),
74
- 'image/format' : bytes_feature (fmt .encode ('utf-8' )),
75
- })
76
- else :
77
- log .warning ("Item '%s' has no image" % item .id )
78
-
79
- xmins = [] # List of normalized left x coordinates in bounding box (1 per box)
80
- xmaxs = [] # List of normalized right x coordinates in bounding box (1 per box)
81
- ymins = [] # List of normalized top y coordinates in bounding box (1 per box)
82
- ymaxs = [] # List of normalized bottom y coordinates in bounding box (1 per box)
83
- classes_text = [] # List of string class name of bounding box (1 per box)
84
- classes = [] # List of integer class id of bounding box (1 per box)
85
-
86
- boxes = [ann for ann in item .annotations if ann .type is AnnotationType .bbox ]
87
- for box in boxes :
88
- box_label = _make_printable (get_label (box .label ))
89
-
90
- xmins .append (box .points [0 ] / width )
91
- xmaxs .append (box .points [2 ] / width )
92
- ymins .append (box .points [1 ] / height )
93
- ymaxs .append (box .points [3 ] / height )
94
- classes_text .append (box_label .encode ('utf-8' ))
95
- classes .append (get_label_id (box .label ))
96
-
97
- if boxes :
98
- features .update ({
99
- 'image/object/bbox/xmin' : float_list_feature (xmins ),
100
- 'image/object/bbox/xmax' : float_list_feature (xmaxs ),
101
- 'image/object/bbox/ymin' : float_list_feature (ymins ),
102
- 'image/object/bbox/ymax' : float_list_feature (ymaxs ),
103
- 'image/object/class/text' : bytes_list_feature (classes_text ),
104
- 'image/object/class/label' : int64_list_feature (classes ),
105
- })
33
+ def int64_feature (value ):
34
+ return tf .train .Feature (int64_list = tf .train .Int64List (value = [value ]))
35
+
36
+ def int64_list_feature (value ):
37
+ return tf .train .Feature (int64_list = tf .train .Int64List (value = value ))
106
38
107
- tf_example = tf . train . Example (
108
- features = tf .train .Features ( feature = features ))
39
+ def bytes_feature ( value ):
40
+ return tf . train . Feature ( bytes_list = tf .train .BytesList ( value = [ value ] ))
109
41
110
- return tf_example
42
+ def bytes_list_feature (value ):
43
+ return tf .train .Feature (bytes_list = tf .train .BytesList (value = value ))
44
+
45
+ def float_list_feature (value ):
46
+ return tf .train .Feature (float_list = tf .train .FloatList (value = value ))
111
47
112
48
class TfDetectionApiConverter (Converter , CliPlugin ):
113
49
@classmethod
114
50
def build_cmdline_parser (cls , ** kwargs ):
115
51
parser = super ().build_cmdline_parser (** kwargs )
116
52
parser .add_argument ('--save-images' , action = 'store_true' ,
117
53
help = "Save images (default: %(default)s)" )
54
+ parser .add_argument ('--save-masks' , action = 'store_true' ,
55
+ help = "Include instance masks (default: %(default)s)" )
118
56
return parser
119
57
120
- def __init__ (self , save_images = False ):
58
+ def __init__ (self , save_images = False , save_masks = False ):
121
59
super ().__init__ ()
122
60
123
61
self ._save_images = save_images
62
+ self ._save_masks = save_masks
124
63
125
64
def __call__ (self , extractor , save_dir ):
126
65
os .makedirs (save_dir , exist_ok = True )
127
66
67
+ label_categories = extractor .categories ().get (AnnotationType .label ,
68
+ LabelCategories ())
69
+ get_label = lambda label_id : label_categories .items [label_id ].name \
70
+ if label_id is not None else ''
71
+ label_ids = OrderedDict ((label .name , 1 + idx )
72
+ for idx , label in enumerate (label_categories .items ))
73
+ map_label_id = lambda label_id : label_ids .get (get_label (label_id ), 0 )
74
+ self ._get_label = get_label
75
+ self ._get_label_id = map_label_id
76
+
128
77
subsets = extractor .subsets ()
129
78
if len (subsets ) == 0 :
130
79
subsets = [ None ]
@@ -136,14 +85,6 @@ def __call__(self, extractor, save_dir):
136
85
subset_name = DEFAULT_SUBSET_NAME
137
86
subset = extractor
138
87
139
- label_categories = subset .categories ().get (AnnotationType .label ,
140
- LabelCategories ())
141
- get_label = lambda label_id : label_categories .items [label_id ].name \
142
- if label_id is not None else ''
143
- label_ids = OrderedDict ((label .name , 1 + idx )
144
- for idx , label in enumerate (label_categories .items ))
145
- map_label_id = lambda label_id : label_ids .get (get_label (label_id ), 0 )
146
-
147
88
labelmap_path = osp .join (save_dir , DetectionApiPath .LABELMAP_FILE )
148
89
with codecs .open (labelmap_path , 'w' , encoding = 'utf8' ) as f :
149
90
for label , idx in label_ids .items ():
@@ -157,10 +98,106 @@ def __call__(self, extractor, save_dir):
157
98
anno_path = osp .join (save_dir , '%s.tfrecord' % (subset_name ))
158
99
with tf .io .TFRecordWriter (anno_path ) as writer :
159
100
for item in subset :
160
- tf_example = _make_tf_example (
161
- item ,
162
- get_label = get_label ,
163
- get_label_id = map_label_id ,
164
- save_images = self ._save_images ,
165
- )
101
+ tf_example = self ._make_tf_example (item )
166
102
writer .write (tf_example .SerializeToString ())
103
+
104
+ @staticmethod
105
+ def _find_instances (annotations ):
106
+ return find_instances (a for a in annotations
107
+ if a .type in { AnnotationType .bbox , AnnotationType .mask })
108
+
109
+ def _find_instance_parts (self , group , img_width , img_height ):
110
+ boxes = [a for a in group if a .type == AnnotationType .bbox ]
111
+ masks = [a for a in group if a .type == AnnotationType .mask ]
112
+
113
+ anns = boxes + masks
114
+ leader = find_group_leader (anns )
115
+ bbox = compute_bbox (anns )
116
+
117
+ mask = None
118
+ if self ._save_masks :
119
+ mask = merge_masks ([m .image for m in masks ])
120
+
121
+ return [leader , mask , bbox ]
122
+
123
+ def _export_instances (self , instances , width , height ):
124
+ xmins = [] # List of normalized left x coordinates of bounding boxes (1 per box)
125
+ xmaxs = [] # List of normalized right x coordinates of bounding boxes (1 per box)
126
+ ymins = [] # List of normalized top y coordinates of bounding boxes (1 per box)
127
+ ymaxs = [] # List of normalized bottom y coordinates of bounding boxes (1 per box)
128
+ classes_text = [] # List of class names of bounding boxes (1 per box)
129
+ classes = [] # List of class ids of bounding boxes (1 per box)
130
+ masks = [] # List of PNG-encoded instance masks (1 per box)
131
+
132
+ for leader , mask , box in instances :
133
+ label = _make_printable (self ._get_label (leader .label ))
134
+ classes_text .append (label .encode ('utf-8' ))
135
+ classes .append (self ._get_label_id (leader .label ))
136
+
137
+ xmins .append (box [0 ] / width )
138
+ xmaxs .append ((box [0 ] + box [2 ]) / width )
139
+ ymins .append (box [1 ] / height )
140
+ ymaxs .append ((box [1 ] + box [3 ]) / height )
141
+
142
+ if self ._save_masks :
143
+ if mask is not None :
144
+ mask = encode_image (mask , '.png' )
145
+ else :
146
+ mask = b''
147
+ masks .append (mask )
148
+
149
+ result = {}
150
+ if classes :
151
+ result = {
152
+ 'image/object/bbox/xmin' : float_list_feature (xmins ),
153
+ 'image/object/bbox/xmax' : float_list_feature (xmaxs ),
154
+ 'image/object/bbox/ymin' : float_list_feature (ymins ),
155
+ 'image/object/bbox/ymax' : float_list_feature (ymaxs ),
156
+ 'image/object/class/text' : bytes_list_feature (classes_text ),
157
+ 'image/object/class/label' : int64_list_feature (classes ),
158
+ }
159
+ if masks :
160
+ result ['image/object/mask' ] = bytes_list_feature (masks )
161
+ return result
162
+
163
+ def _make_tf_example (self , item ):
164
+ features = {
165
+ 'image/source_id' : bytes_feature (str (item .id ).encode ('utf-8' )),
166
+ 'image/filename' : bytes_feature (
167
+ ('%s%s' % (item .id , DetectionApiPath .IMAGE_EXT )).encode ('utf-8' )),
168
+ }
169
+
170
+ if not item .has_image :
171
+ raise Exception ("Failed to export dataset item '%s': "
172
+ "item has no image info" % item .id )
173
+ height , width = item .image .size
174
+
175
+ features .update ({
176
+ 'image/height' : int64_feature (height ),
177
+ 'image/width' : int64_feature (width ),
178
+ })
179
+
180
+ features .update ({
181
+ 'image/encoded' : bytes_feature (b'' ),
182
+ 'image/format' : bytes_feature (b'' )
183
+ })
184
+ if self ._save_images :
185
+ if item .has_image and item .image .has_data :
186
+ fmt = DetectionApiPath .IMAGE_FORMAT
187
+ buffer = encode_image (item .image .data , DetectionApiPath .IMAGE_EXT )
188
+
189
+ features .update ({
190
+ 'image/encoded' : bytes_feature (buffer ),
191
+ 'image/format' : bytes_feature (fmt .encode ('utf-8' )),
192
+ })
193
+ else :
194
+ log .warning ("Item '%s' has no image" % item .id )
195
+
196
+ instances = self ._find_instances (item .annotations )
197
+ instances = [self ._find_instance_parts (i , width , height ) for i in instances ]
198
+ features .update (self ._export_instances (instances , width , height ))
199
+
200
+ tf_example = tf .train .Example (
201
+ features = tf .train .Features (feature = features ))
202
+
203
+ return tf_example
0 commit comments