Skip to content

Commit b36f402

Browse files
authored
[Datumaro] Add masks to tfrecord format (#1156)
* Employ transforms and item wrapper * Add image class and tests * Add image info support to formats * Fix cli * Fix merge and VOC converter * Update remote images extractor * Codacy * Remove item name, require path in Image * Merge images of dataset items * Update tests * Add image dir converter * Update Datumaro format * Update COCO format with image info * Update CVAT format with image info * Update TFrecord format with image info * Update VOC format with image info * Update YOLO format with image info * Update dataset manager bindings with image info * Add image name to id transform * Fix coco export * Add masks support for tfrecord * Refactor coco * Fix comparison * Remove dead code * Extract common code for instances
1 parent f208cfe commit b36f402

File tree

8 files changed

+253
-194
lines changed

8 files changed

+253
-194
lines changed

datumaro/datumaro/plugins/coco_format/converter.py

+10-48
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414

1515
from datumaro.components.converter import Converter
1616
from datumaro.components.extractor import (DEFAULT_SUBSET_NAME,
17-
AnnotationType, Points, Mask
17+
AnnotationType, Points
1818
)
1919
from datumaro.components.cli_plugin import CliPlugin
2020
from datumaro.util import find
2121
from datumaro.util.image import save_image
2222
import datumaro.util.mask_tools as mask_tools
23+
import datumaro.util.annotation_tools as anno_tools
2324

2425
from .format import CocoTask, CocoPath
2526

@@ -194,7 +195,7 @@ def crop_segments(cls, instances, img_width, img_height):
194195
if inst[1]:
195196
inst[1] = sum(new_segments, [])
196197
else:
197-
mask = cls.merge_masks(new_segments)
198+
mask = mask_tools.merge_masks(new_segments)
198199
inst[2] = mask_tools.mask_to_rle(mask)
199200

200201
return instances
@@ -205,8 +206,8 @@ def find_instance_parts(self, group, img_width, img_height):
205206
masks = [a for a in group if a.type == AnnotationType.mask]
206207

207208
anns = boxes + polygons + masks
208-
leader = self.find_group_leader(anns)
209-
bbox = self.compute_bbox(anns)
209+
leader = anno_tools.find_group_leader(anns)
210+
bbox = anno_tools.compute_bbox(anns)
210211
mask = None
211212
polygons = [p.points for p in polygons]
212213

@@ -228,68 +229,29 @@ def find_instance_parts(self, group, img_width, img_height):
228229
if masks:
229230
if mask is not None:
230231
masks += [mask]
231-
mask = self.merge_masks(masks)
232+
mask = mask_tools.merge_masks([m.image for m in masks])
232233

233234
if mask is not None:
234235
mask = mask_tools.mask_to_rle(mask)
235236
polygons = []
236237
else:
237238
if masks:
238-
mask = self.merge_masks(masks)
239+
mask = mask_tools.merge_masks([m.image for m in masks])
239240
polygons += mask_tools.mask_to_polygons(mask)
240241
mask = None
241242

242243
return [leader, polygons, mask, bbox]
243244

244-
@staticmethod
245-
def find_group_leader(group):
246-
return max(group, key=lambda x: x.get_area())
247-
248-
@staticmethod
249-
def merge_masks(masks):
250-
if not masks:
251-
return None
252-
253-
def get_mask(m):
254-
if isinstance(m, Mask):
255-
return m.image
256-
else:
257-
return m
258-
259-
binary_mask = get_mask(masks[0])
260-
for m in masks[1:]:
261-
binary_mask |= get_mask(m)
262-
263-
return binary_mask
264-
265-
@staticmethod
266-
def compute_bbox(annotations):
267-
boxes = [ann.get_bbox() for ann in annotations]
268-
x0 = min((b[0] for b in boxes), default=0)
269-
y0 = min((b[1] for b in boxes), default=0)
270-
x1 = max((b[0] + b[2] for b in boxes), default=0)
271-
y1 = max((b[1] + b[3] for b in boxes), default=0)
272-
return [x0, y0, x1 - x0, y1 - y0]
273-
274245
@staticmethod
275246
def find_instance_anns(annotations):
276247
return [a for a in annotations
277-
if a.type in { AnnotationType.bbox, AnnotationType.polygon } or \
278-
a.type == AnnotationType.mask and a.label is not None
248+
if a.type in { AnnotationType.bbox,
249+
AnnotationType.polygon, AnnotationType.mask }
279250
]
280251

281252
@classmethod
282253
def find_instances(cls, annotations):
283-
instance_anns = cls.find_instance_anns(annotations)
284-
285-
ann_groups = []
286-
for g_id, group in groupby(instance_anns, lambda a: a.group):
287-
if not g_id:
288-
ann_groups.extend(([a] for a in group))
289-
else:
290-
ann_groups.append(list(group))
291-
292-
return ann_groups
254+
return anno_tools.find_instances(cls.find_instance_anns(annotations))
293255

294256
def save_annotations(self, item):
295257
instances = self.find_instances(item.annotations)

datumaro/datumaro/plugins/tf_detection_api_format/converter.py

+132-95
Original file line numberDiff line numberDiff line change
@@ -16,115 +16,64 @@
1616
from datumaro.components.converter import Converter
1717
from datumaro.components.cli_plugin import CliPlugin
1818
from datumaro.util.image import encode_image
19+
from datumaro.util.mask_tools import merge_masks
20+
from datumaro.util.annotation_tools import (compute_bbox,
21+
find_group_leader, find_instances)
1922
from datumaro.util.tf_util import import_tf as _import_tf
2023

2124
from .format import DetectionApiPath
2225
tf = _import_tf()
2326

2427

25-
# we need it to filter out non-ASCII characters, otherwise training will crash
28+
# filter out non-ASCII characters, otherwise training will crash
2629
_printable = set(string.printable)
2730
def _make_printable(s):
2831
return ''.join(filter(lambda x: x in _printable, s))
2932

30-
def _make_tf_example(item, get_label_id, get_label, save_images=False):
31-
def int64_feature(value):
32-
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
33-
34-
def int64_list_feature(value):
35-
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
36-
37-
def bytes_feature(value):
38-
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
39-
40-
def bytes_list_feature(value):
41-
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
42-
43-
def float_list_feature(value):
44-
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
45-
46-
47-
features = {
48-
'image/source_id': bytes_feature(str(item.id).encode('utf-8')),
49-
'image/filename': bytes_feature(
50-
('%s%s' % (item.id, DetectionApiPath.IMAGE_EXT)).encode('utf-8')),
51-
}
52-
53-
if not item.has_image:
54-
raise Exception("Failed to export dataset item '%s': "
55-
"item has no image info" % item.id)
56-
height, width = item.image.size
57-
58-
features.update({
59-
'image/height': int64_feature(height),
60-
'image/width': int64_feature(width),
61-
})
62-
63-
features.update({
64-
'image/encoded': bytes_feature(b''),
65-
'image/format': bytes_feature(b'')
66-
})
67-
if save_images:
68-
if item.has_image and item.image.has_data:
69-
fmt = DetectionApiPath.IMAGE_FORMAT
70-
buffer = encode_image(item.image.data, DetectionApiPath.IMAGE_EXT)
71-
72-
features.update({
73-
'image/encoded': bytes_feature(buffer),
74-
'image/format': bytes_feature(fmt.encode('utf-8')),
75-
})
76-
else:
77-
log.warning("Item '%s' has no image" % item.id)
78-
79-
xmins = [] # List of normalized left x coordinates in bounding box (1 per box)
80-
xmaxs = [] # List of normalized right x coordinates in bounding box (1 per box)
81-
ymins = [] # List of normalized top y coordinates in bounding box (1 per box)
82-
ymaxs = [] # List of normalized bottom y coordinates in bounding box (1 per box)
83-
classes_text = [] # List of string class name of bounding box (1 per box)
84-
classes = [] # List of integer class id of bounding box (1 per box)
85-
86-
boxes = [ann for ann in item.annotations if ann.type is AnnotationType.bbox]
87-
for box in boxes:
88-
box_label = _make_printable(get_label(box.label))
89-
90-
xmins.append(box.points[0] / width)
91-
xmaxs.append(box.points[2] / width)
92-
ymins.append(box.points[1] / height)
93-
ymaxs.append(box.points[3] / height)
94-
classes_text.append(box_label.encode('utf-8'))
95-
classes.append(get_label_id(box.label))
96-
97-
if boxes:
98-
features.update({
99-
'image/object/bbox/xmin': float_list_feature(xmins),
100-
'image/object/bbox/xmax': float_list_feature(xmaxs),
101-
'image/object/bbox/ymin': float_list_feature(ymins),
102-
'image/object/bbox/ymax': float_list_feature(ymaxs),
103-
'image/object/class/text': bytes_list_feature(classes_text),
104-
'image/object/class/label': int64_list_feature(classes),
105-
})
33+
def int64_feature(value):
34+
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
35+
36+
def int64_list_feature(value):
37+
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
10638

107-
tf_example = tf.train.Example(
108-
features=tf.train.Features(feature=features))
39+
def bytes_feature(value):
40+
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
10941

110-
return tf_example
42+
def bytes_list_feature(value):
43+
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
44+
45+
def float_list_feature(value):
46+
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
11147

11248
class TfDetectionApiConverter(Converter, CliPlugin):
11349
@classmethod
11450
def build_cmdline_parser(cls, **kwargs):
11551
parser = super().build_cmdline_parser(**kwargs)
11652
parser.add_argument('--save-images', action='store_true',
11753
help="Save images (default: %(default)s)")
54+
parser.add_argument('--save-masks', action='store_true',
55+
help="Include instance masks (default: %(default)s)")
11856
return parser
11957

120-
def __init__(self, save_images=False):
58+
def __init__(self, save_images=False, save_masks=False):
12159
super().__init__()
12260

12361
self._save_images = save_images
62+
self._save_masks = save_masks
12463

12564
def __call__(self, extractor, save_dir):
12665
os.makedirs(save_dir, exist_ok=True)
12766

67+
label_categories = extractor.categories().get(AnnotationType.label,
68+
LabelCategories())
69+
get_label = lambda label_id: label_categories.items[label_id].name \
70+
if label_id is not None else ''
71+
label_ids = OrderedDict((label.name, 1 + idx)
72+
for idx, label in enumerate(label_categories.items))
73+
map_label_id = lambda label_id: label_ids.get(get_label(label_id), 0)
74+
self._get_label = get_label
75+
self._get_label_id = map_label_id
76+
12877
subsets = extractor.subsets()
12978
if len(subsets) == 0:
13079
subsets = [ None ]
@@ -136,14 +85,6 @@ def __call__(self, extractor, save_dir):
13685
subset_name = DEFAULT_SUBSET_NAME
13786
subset = extractor
13887

139-
label_categories = subset.categories().get(AnnotationType.label,
140-
LabelCategories())
141-
get_label = lambda label_id: label_categories.items[label_id].name \
142-
if label_id is not None else ''
143-
label_ids = OrderedDict((label.name, 1 + idx)
144-
for idx, label in enumerate(label_categories.items))
145-
map_label_id = lambda label_id: label_ids.get(get_label(label_id), 0)
146-
14788
labelmap_path = osp.join(save_dir, DetectionApiPath.LABELMAP_FILE)
14889
with codecs.open(labelmap_path, 'w', encoding='utf8') as f:
14990
for label, idx in label_ids.items():
@@ -157,10 +98,106 @@ def __call__(self, extractor, save_dir):
15798
anno_path = osp.join(save_dir, '%s.tfrecord' % (subset_name))
15899
with tf.io.TFRecordWriter(anno_path) as writer:
159100
for item in subset:
160-
tf_example = _make_tf_example(
161-
item,
162-
get_label=get_label,
163-
get_label_id=map_label_id,
164-
save_images=self._save_images,
165-
)
101+
tf_example = self._make_tf_example(item)
166102
writer.write(tf_example.SerializeToString())
103+
104+
@staticmethod
105+
def _find_instances(annotations):
106+
return find_instances(a for a in annotations
107+
if a.type in { AnnotationType.bbox, AnnotationType.mask })
108+
109+
def _find_instance_parts(self, group, img_width, img_height):
110+
boxes = [a for a in group if a.type == AnnotationType.bbox]
111+
masks = [a for a in group if a.type == AnnotationType.mask]
112+
113+
anns = boxes + masks
114+
leader = find_group_leader(anns)
115+
bbox = compute_bbox(anns)
116+
117+
mask = None
118+
if self._save_masks:
119+
mask = merge_masks([m.image for m in masks])
120+
121+
return [leader, mask, bbox]
122+
123+
def _export_instances(self, instances, width, height):
124+
xmins = [] # List of normalized left x coordinates of bounding boxes (1 per box)
125+
xmaxs = [] # List of normalized right x coordinates of bounding boxes (1 per box)
126+
ymins = [] # List of normalized top y coordinates of bounding boxes (1 per box)
127+
ymaxs = [] # List of normalized bottom y coordinates of bounding boxes (1 per box)
128+
classes_text = [] # List of class names of bounding boxes (1 per box)
129+
classes = [] # List of class ids of bounding boxes (1 per box)
130+
masks = [] # List of PNG-encoded instance masks (1 per box)
131+
132+
for leader, mask, box in instances:
133+
label = _make_printable(self._get_label(leader.label))
134+
classes_text.append(label.encode('utf-8'))
135+
classes.append(self._get_label_id(leader.label))
136+
137+
xmins.append(box[0] / width)
138+
xmaxs.append((box[0] + box[2]) / width)
139+
ymins.append(box[1] / height)
140+
ymaxs.append((box[1] + box[3]) / height)
141+
142+
if self._save_masks:
143+
if mask is not None:
144+
mask = encode_image(mask, '.png')
145+
else:
146+
mask = b''
147+
masks.append(mask)
148+
149+
result = {}
150+
if classes:
151+
result = {
152+
'image/object/bbox/xmin': float_list_feature(xmins),
153+
'image/object/bbox/xmax': float_list_feature(xmaxs),
154+
'image/object/bbox/ymin': float_list_feature(ymins),
155+
'image/object/bbox/ymax': float_list_feature(ymaxs),
156+
'image/object/class/text': bytes_list_feature(classes_text),
157+
'image/object/class/label': int64_list_feature(classes),
158+
}
159+
if masks:
160+
result['image/object/mask'] = bytes_list_feature(masks)
161+
return result
162+
163+
def _make_tf_example(self, item):
164+
features = {
165+
'image/source_id': bytes_feature(str(item.id).encode('utf-8')),
166+
'image/filename': bytes_feature(
167+
('%s%s' % (item.id, DetectionApiPath.IMAGE_EXT)).encode('utf-8')),
168+
}
169+
170+
if not item.has_image:
171+
raise Exception("Failed to export dataset item '%s': "
172+
"item has no image info" % item.id)
173+
height, width = item.image.size
174+
175+
features.update({
176+
'image/height': int64_feature(height),
177+
'image/width': int64_feature(width),
178+
})
179+
180+
features.update({
181+
'image/encoded': bytes_feature(b''),
182+
'image/format': bytes_feature(b'')
183+
})
184+
if self._save_images:
185+
if item.has_image and item.image.has_data:
186+
fmt = DetectionApiPath.IMAGE_FORMAT
187+
buffer = encode_image(item.image.data, DetectionApiPath.IMAGE_EXT)
188+
189+
features.update({
190+
'image/encoded': bytes_feature(buffer),
191+
'image/format': bytes_feature(fmt.encode('utf-8')),
192+
})
193+
else:
194+
log.warning("Item '%s' has no image" % item.id)
195+
196+
instances = self._find_instances(item.annotations)
197+
instances = [self._find_instance_parts(i, width, height) for i in instances]
198+
features.update(self._export_instances(instances, width, height))
199+
200+
tf_example = tf.train.Example(
201+
features=tf.train.Features(feature=features))
202+
203+
return tf_example

0 commit comments

Comments
 (0)