Skip to content

Commit 625f20c

Browse files
zhiltsov-maxChris Lee-Messer
authored and
Chris Lee-Messer
committed
Coco converter updates (cvat-ai#864)
1 parent 4fe33a1 commit 625f20c

File tree

6 files changed

+214
-104
lines changed

6 files changed

+214
-104
lines changed

datumaro/datumaro/components/config_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
SOURCE_SCHEMA = _SchemaBuilder() \
1212
.add('url', str) \
1313
.add('format', str) \
14-
.add('options', str) \
14+
.add('options', dict) \
1515
.build()
1616

1717
class Source(Config):

datumaro/datumaro/components/converters/ms_coco.py

+75-19
Original file line numberDiff line numberDiff line change
@@ -121,40 +121,96 @@ def save_categories(self, dataset):
121121
})
122122

123123
def save_annotations(self, item):
124-
for ann in item.annotations:
125-
if ann.type != AnnotationType.bbox:
124+
annotations = item.annotations.copy()
125+
126+
while len(annotations) != 0:
127+
ann = annotations.pop()
128+
129+
if ann.type == AnnotationType.bbox and ann.label is not None:
130+
pass
131+
elif ann.type == AnnotationType.polygon and ann.label is not None:
132+
pass
133+
elif ann.type == AnnotationType.mask and ann.label is not None:
134+
pass
135+
else:
126136
continue
127137

128-
is_crowd = ann.attributes.get('is_crowd', False)
138+
bbox = None
129139
segmentation = None
130-
if ann.group is not None:
140+
141+
if ann.type == AnnotationType.bbox:
142+
is_crowd = ann.attributes.get('is_crowd', False)
143+
bbox = ann.get_bbox()
144+
elif ann.type == AnnotationType.polygon:
145+
is_crowd = ann.attributes.get('is_crowd', False)
146+
elif ann.type == AnnotationType.mask:
147+
is_crowd = ann.attributes.get('is_crowd', True)
131148
if is_crowd:
132-
segmentation = find(item.annotations, lambda x: \
133-
x.group == ann.group and x.type == AnnotationType.mask)
134-
if segmentation is not None:
135-
binary_mask = np.array(segmentation.image, dtype=np.bool)
136-
binary_mask = np.asfortranarray(binary_mask, dtype=np.uint8)
137-
segmentation = mask_utils.encode(binary_mask)
138-
area = mask_utils.area(segmentation)
139-
segmentation = mask_tools.convert_mask_to_rle(binary_mask)
140-
else:
141-
segmentation = find(item.annotations, lambda x: \
142-
x.group == ann.group and x.type == AnnotationType.polygon)
143-
if segmentation is not None:
144-
area = ann.area()
145-
segmentation = [segmentation.get_points()]
149+
segmentation = ann
150+
area = None
151+
152+
# If ann in a group, try to find corresponding annotations in
153+
# this group, otherwise try to infer them.
154+
155+
if bbox is None and ann.group is not None:
156+
bbox = find(annotations, lambda x: \
157+
x.group == ann.group and \
158+
x.type == AnnotationType.bbox and \
159+
x.label == ann.label)
160+
if bbox is not None:
161+
bbox = bbox.get_bbox()
162+
163+
if is_crowd:
164+
# is_crowd=True means there should be a mask
165+
if segmentation is None and ann.group is not None:
166+
segmentation = find(annotations, lambda x: \
167+
x.group == ann.group and \
168+
x.type == AnnotationType.mask and \
169+
x.label == ann.label)
170+
if segmentation is not None:
171+
binary_mask = np.array(segmentation.image, dtype=np.bool)
172+
binary_mask = np.asfortranarray(binary_mask, dtype=np.uint8)
173+
segmentation = mask_utils.encode(binary_mask)
174+
area = mask_utils.area(segmentation)
175+
segmentation = mask_tools.convert_mask_to_rle(binary_mask)
176+
else:
177+
# is_crowd=False means there are some polygons
178+
polygons = []
179+
if ann.type == AnnotationType.polygon:
180+
polygons = [ ann ]
181+
if ann.group is not None:
182+
# A single object can consist of several polygons
183+
polygons += [p for p in annotations
184+
if p.group == ann.group and \
185+
p.type == AnnotationType.polygon and \
186+
p.label == ann.label]
187+
if polygons:
188+
segmentation = [p.get_points() for p in polygons]
189+
h, w, _ = item.image.shape
190+
rles = mask_utils.frPyObjects(segmentation, h, w)
191+
rle = mask_utils.merge(rles)
192+
area = mask_utils.area(rle)
193+
194+
if ann.group is not None:
195+
# Mark the group as visited to prevent repeats
196+
for a in annotations[:]:
197+
if a.group == ann.group:
198+
annotations.remove(a)
199+
146200
if segmentation is None:
147201
is_crowd = False
148202
segmentation = [ann.get_polygon()]
149203
area = ann.area()
204+
if bbox is None:
205+
bbox = ann.get_bbox()
150206

151207
elem = {
152208
'id': self._get_ann_id(ann),
153209
'image_id': _cast(item.id, int, 0),
154210
'category_id': _cast(ann.label, int, -1) + 1,
155211
'segmentation': segmentation,
156212
'area': float(area),
157-
'bbox': ann.get_bbox(),
213+
'bbox': bbox,
158214
'iscrowd': int(is_crowd),
159215
}
160216
if 'score' in ann.attributes:

datumaro/datumaro/components/extractor.py

+8
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,14 @@ def __init__(self, points=None,
271271
def get_polygon(self):
272272
return self.get_points()
273273

274+
def area(self):
275+
import pycocotools.mask as mask_utils
276+
277+
_, _, w, h = self.get_bbox()
278+
rle = mask_utils.frPyObjects([self.get_points()], h, w)
279+
area = mask_utils.area(rle)
280+
return area
281+
274282
class BboxObject(ShapeObject):
275283
# pylint: disable=redefined-builtin
276284
def __init__(self, x=0, y=0, w=0, h=0,

datumaro/datumaro/components/extractors/ms_coco.py

+27-21
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __len__(self):
6161
def categories(self):
6262
return self._parent.categories()
6363

64-
def __init__(self, path, task):
64+
def __init__(self, path, task, merge_instance_polygons=False):
6565
super().__init__()
6666

6767
rootpath = path.rsplit(CocoPath.ANNOTATIONS_DIR, maxsplit=1)[0]
@@ -80,6 +80,8 @@ def __init__(self, path, task):
8080

8181
self._load_categories()
8282

83+
self._merge_instance_polygons = merge_instance_polygons
84+
8385
@staticmethod
8486
def _make_subset_loader(path):
8587
# COCO API has an 'unclosed file' warning
@@ -212,20 +214,22 @@ def _parse_annotation(self, ann, ann_type, parsed_annotations,
212214
segmentation = ann.get('segmentation')
213215
if segmentation is not None:
214216
group = ann_id
217+
rle = None
215218

216219
if isinstance(segmentation, list):
217-
# polygon -- a single object might consist of multiple parts
220+
# polygon - a single object can consist of multiple parts
218221
for polygon_points in segmentation:
219222
parsed_annotations.append(PolygonObject(
220223
points=polygon_points, label=label_id,
221-
group=group
224+
id=ann_id, group=group, attributes=attributes
222225
))
223226

224-
# we merge all parts into one mask RLE code
225-
img_h = image_info['height']
226-
img_w = image_info['width']
227-
rles = mask_utils.frPyObjects(segmentation, img_h, img_w)
228-
rle = mask_utils.merge(rles)
227+
if self._merge_instance_polygons:
228+
# merge all parts into a single mask RLE
229+
img_h = image_info['height']
230+
img_w = image_info['width']
231+
rles = mask_utils.frPyObjects(segmentation, img_h, img_w)
232+
rle = mask_utils.merge(rles)
229233
elif isinstance(segmentation['counts'], list):
230234
# uncompressed RLE
231235
img_h, img_w = segmentation['size']
@@ -234,9 +238,10 @@ def _parse_annotation(self, ann, ann_type, parsed_annotations,
234238
# compressed RLE
235239
rle = segmentation
236240

237-
parsed_annotations.append(RleMask(rle=rle, label=label_id,
238-
group=group
239-
))
241+
if rle is not None:
242+
parsed_annotations.append(RleMask(rle=rle, label=label_id,
243+
id=ann_id, group=group, attributes=attributes
244+
))
240245

241246
parsed_annotations.append(
242247
BboxObject(x, y, w, h, label=label_id,
@@ -277,21 +282,22 @@ def _parse_annotation(self, ann, ann_type, parsed_annotations,
277282
return parsed_annotations
278283

279284
class CocoImageInfoExtractor(CocoExtractor):
280-
def __init__(self, path):
281-
super().__init__(path, task=CocoAnnotationType.image_info)
285+
def __init__(self, path, **kwargs):
286+
super().__init__(path, task=CocoAnnotationType.image_info, **kwargs)
282287

283288
class CocoCaptionsExtractor(CocoExtractor):
284-
def __init__(self, path):
285-
super().__init__(path, task=CocoAnnotationType.captions)
289+
def __init__(self, path, **kwargs):
290+
super().__init__(path, task=CocoAnnotationType.captions, **kwargs)
286291

287292
class CocoInstancesExtractor(CocoExtractor):
288-
def __init__(self, path):
289-
super().__init__(path, task=CocoAnnotationType.instances)
293+
def __init__(self, path, **kwargs):
294+
super().__init__(path, task=CocoAnnotationType.instances, **kwargs)
290295

291296
class CocoPersonKeypointsExtractor(CocoExtractor):
292-
def __init__(self, path):
293-
super().__init__(path, task=CocoAnnotationType.person_keypoints)
297+
def __init__(self, path, **kwargs):
298+
super().__init__(path, task=CocoAnnotationType.person_keypoints,
299+
**kwargs)
294300

295301
class CocoLabelsExtractor(CocoExtractor):
296-
def __init__(self, path):
297-
super().__init__(path, task=CocoAnnotationType.labels)
302+
def __init__(self, path, **kwargs):
303+
super().__init__(path, task=CocoAnnotationType.labels, **kwargs)

datumaro/datumaro/components/importers/ms_coco.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class CocoImporter:
2222
def __init__(self, task_filter=None):
2323
self._task_filter = task_filter
2424

25-
def __call__(self, path):
25+
def __call__(self, path, **extra_params):
2626
from datumaro.components.project import Project # cyclic import
2727
project = Project()
2828

@@ -37,6 +37,7 @@ def __call__(self, path):
3737
project.add_source(source_name, {
3838
'url': ann_file,
3939
'format': self._COCO_EXTRACTORS[ann_type],
40+
'options': extra_params,
4041
})
4142

4243
return project

0 commit comments

Comments
 (0)