Skip to content

Commit a376ee7

Browse files
authored
[Datumaro] Introduce image info (#1140)
* Employ transforms and item wrapper * Add image class and tests * Add image info support to formats * Fix cli * Fix merge and voc converter * Update remote images extractor * Codacy * Remove item name, require path in Image * Merge images of dataset items * Update tests * Add image dir converter * Update Datumaro format * Update COCO format with image info * Update CVAT format with image info * Update TFrecord format with image info * Update VOC format with image info * Update YOLO format with image info * Update dataset manager bindings with image info * Add image name to id transform * Fix coco export
1 parent 0db48af commit a376ee7

36 files changed

+848
-487
lines changed

cvat/apps/dataset_manager/bindings.py

+100-67
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from cvat.apps.engine.models import Task, ShapeType, AttributeType
1515

1616
import datumaro.components.extractor as datumaro
17-
from datumaro.util.image import lazy_image
17+
from datumaro.util.image import Image
1818

1919

2020
class CvatImagesDirExtractor(datumaro.Extractor):
@@ -29,8 +29,7 @@ def __init__(self, url):
2929
path = osp.join(dirpath, name)
3030
if self._is_image(path):
3131
item_id = Task.get_image_frame(path)
32-
item = datumaro.DatasetItem(
33-
id=item_id, image=lazy_image(path))
32+
item = datumaro.DatasetItem(id=item_id, image=path)
3433
items.append((item.id, item))
3534

3635
items = sorted(items, key=lambda e: int(e[0]))
@@ -49,112 +48,90 @@ def __len__(self):
4948
def subsets(self):
5049
return self._subsets
5150

52-
def get(self, item_id, subset=None, path=None):
53-
if path or subset:
54-
raise KeyError()
55-
return self._items[item_id]
56-
5751
def _is_image(self, path):
5852
for ext in self._SUPPORTED_FORMATS:
5953
if osp.isfile(path) and path.endswith(ext):
6054
return True
6155
return False
6256

6357

64-
class CvatTaskExtractor(datumaro.Extractor):
65-
def __init__(self, url, db_task, user):
66-
self._db_task = db_task
67-
self._categories = self._load_categories()
68-
69-
cvat_annotations = TaskAnnotation(db_task.id, user)
70-
with transaction.atomic():
71-
cvat_annotations.init_from_db()
72-
cvat_annotations = Annotation(cvat_annotations.ir_data, db_task)
58+
class CvatAnnotationsExtractor(datumaro.Extractor):
59+
def __init__(self, url, cvat_annotations):
60+
self._categories = self._load_categories(cvat_annotations)
7361

7462
dm_annotations = []
7563

76-
for cvat_anno in cvat_annotations.group_by_frame():
77-
dm_anno = self._read_cvat_anno(cvat_anno)
78-
dm_item = datumaro.DatasetItem(
79-
id=cvat_anno.frame, annotations=dm_anno)
64+
for cvat_frame_anno in cvat_annotations.group_by_frame():
65+
dm_anno = self._read_cvat_anno(cvat_frame_anno, cvat_annotations)
66+
dm_image = Image(path=cvat_frame_anno.name, size=(
67+
cvat_frame_anno.height, cvat_frame_anno.width)
68+
)
69+
dm_item = datumaro.DatasetItem(id=cvat_frame_anno.frame,
70+
annotations=dm_anno, image=dm_image)
8071
dm_annotations.append((dm_item.id, dm_item))
8172

8273
dm_annotations = sorted(dm_annotations, key=lambda e: int(e[0]))
8374
self._items = OrderedDict(dm_annotations)
8475

85-
self._subsets = None
86-
8776
def __iter__(self):
8877
for item in self._items.values():
8978
yield item
9079

9180
def __len__(self):
9281
return len(self._items)
9382

83+
# pylint: disable=no-self-use
9484
def subsets(self):
95-
return self._subsets
85+
return []
86+
# pylint: enable=no-self-use
9687

97-
def get(self, item_id, subset=None, path=None):
98-
if path or subset:
99-
raise KeyError()
100-
return self._items[item_id]
88+
def categories(self):
89+
return self._categories
10190

102-
def _load_categories(self):
91+
@staticmethod
92+
def _load_categories(cvat_anno):
10393
categories = {}
10494
label_categories = datumaro.LabelCategories()
10595

106-
db_labels = self._db_task.label_set.all()
107-
for db_label in db_labels:
108-
db_attributes = db_label.attributespec_set.all()
109-
label_categories.add(db_label.name)
110-
111-
for db_attr in db_attributes:
112-
label_categories.attributes.add(db_attr.name)
96+
for _, label in cvat_anno.meta['task']['labels']:
97+
label_categories.add(label['name'])
98+
for _, attr in label['attributes']:
99+
label_categories.attributes.add(attr['name'])
113100

114101
categories[datumaro.AnnotationType.label] = label_categories
115102

116103
return categories
117104

118-
def categories(self):
119-
return self._categories
120-
121-
def _read_cvat_anno(self, cvat_anno):
105+
def _read_cvat_anno(self, cvat_frame_anno, cvat_task_anno):
122106
item_anno = []
123107

124108
categories = self.categories()
125109
label_cat = categories[datumaro.AnnotationType.label]
126-
127-
label_map = {}
128-
label_attrs = {}
129-
db_labels = self._db_task.label_set.all()
130-
for db_label in db_labels:
131-
label_map[db_label.name] = label_cat.find(db_label.name)[0]
132-
133-
attrs = {}
134-
db_attributes = db_label.attributespec_set.all()
135-
for db_attr in db_attributes:
136-
attrs[db_attr.name] = db_attr
137-
label_attrs[db_label.name] = attrs
138-
map_label = lambda label_db_name: label_map[label_db_name]
110+
map_label = lambda name: label_cat.find(name)[0]
111+
label_attrs = {
112+
label['name']: label['attributes']
113+
for _, label in cvat_task_anno.meta['task']['labels']
114+
}
139115

140116
def convert_attrs(label, cvat_attrs):
141117
cvat_attrs = {a.name: a.value for a in cvat_attrs}
142118
dm_attr = dict()
143-
for attr_name, attr_spec in label_attrs[label].items():
144-
attr_value = cvat_attrs.get(attr_name, attr_spec.default_value)
119+
for _, a_desc in label_attrs[label]:
120+
a_name = a_desc['name']
121+
a_value = cvat_attrs.get(a_name, a_desc['default_value'])
145122
try:
146-
if attr_spec.input_type == AttributeType.NUMBER:
147-
attr_value = float(attr_value)
148-
elif attr_spec.input_type == AttributeType.CHECKBOX:
149-
attr_value = attr_value.lower() == 'true'
150-
dm_attr[attr_name] = attr_value
123+
if a_desc['input_type'] == AttributeType.NUMBER:
124+
a_value = float(a_value)
125+
elif a_desc['input_type'] == AttributeType.CHECKBOX:
126+
a_value = (a_value.lower() == 'true')
127+
dm_attr[a_name] = a_value
151128
except Exception as e:
152-
slogger.task[self._db_task.id].error(
153-
"Failed to convert attribute '%s'='%s': %s" % \
154-
(attr_name, attr_value, e))
129+
raise Exception(
130+
"Failed to convert attribute '%s'='%s': %s" %
131+
(a_name, a_value, e))
155132
return dm_attr
156133

157-
for tag_obj in cvat_anno.tags:
134+
for tag_obj in cvat_frame_anno.tags:
158135
anno_group = tag_obj.group
159136
anno_label = map_label(tag_obj.label)
160137
anno_attr = convert_attrs(tag_obj.label, tag_obj.attributes)
@@ -163,7 +140,7 @@ def convert_attrs(label, cvat_attrs):
163140
attributes=anno_attr, group=anno_group)
164141
item_anno.append(anno)
165142

166-
for shape_obj in cvat_anno.labeled_shapes:
143+
for shape_obj in cvat_frame_anno.labeled_shapes:
167144
anno_group = shape_obj.group
168145
anno_label = map_label(shape_obj.label)
169146
anno_attr = convert_attrs(shape_obj.label, shape_obj.attributes)
@@ -183,8 +160,64 @@ def convert_attrs(label, cvat_attrs):
183160
anno = datumaro.Bbox(x0, y0, x1 - x0, y1 - y0,
184161
label=anno_label, attributes=anno_attr, group=anno_group)
185162
else:
186-
raise Exception("Unknown shape type '%s'" % (shape_obj.type))
163+
raise Exception("Unknown shape type '%s'" % shape_obj.type)
187164

188165
item_anno.append(anno)
189166

190-
return item_anno
167+
return item_anno
168+
169+
170+
class CvatTaskExtractor(CvatAnnotationsExtractor):
171+
def __init__(self, url, db_task, user):
172+
cvat_annotations = TaskAnnotation(db_task.id, user)
173+
with transaction.atomic():
174+
cvat_annotations.init_from_db()
175+
cvat_annotations = Annotation(cvat_annotations.ir_data, db_task)
176+
super().__init__(url, cvat_annotations)
177+
178+
179+
def match_frame(item, cvat_task_anno):
180+
frame_number = None
181+
if frame_number is None:
182+
try:
183+
frame_number = cvat_task_anno.match_frame(item.id)
184+
except Exception:
185+
pass
186+
if frame_number is None and item.has_image:
187+
try:
188+
frame_number = cvat_task_anno.match_frame(item.image.filename)
189+
except Exception:
190+
pass
191+
if frame_number is None:
192+
try:
193+
frame_number = int(item.id)
194+
except Exception:
195+
pass
196+
if not frame_number in cvat_task_anno.frame_info:
197+
raise Exception("Could not match item id: '%s' with any task frame" %
198+
item.id)
199+
return frame_number
200+
201+
def import_dm_annotations(dm_dataset, cvat_task_anno):
202+
shapes = {
203+
datumaro.AnnotationType.bbox: ShapeType.RECTANGLE,
204+
datumaro.AnnotationType.polygon: ShapeType.POLYGON,
205+
datumaro.AnnotationType.polyline: ShapeType.POLYLINE,
206+
datumaro.AnnotationType.points: ShapeType.POINTS,
207+
}
208+
209+
label_cat = dm_dataset.categories()[datumaro.AnnotationType.label]
210+
211+
for item in dm_dataset:
212+
frame_number = match_frame(item, cvat_task_anno)
213+
214+
for ann in item.annotations:
215+
if ann.type in shapes:
216+
cvat_task_anno.add_shape(cvat_task_anno.LabeledShape(
217+
type=shapes[ann.type],
218+
frame=frame_number,
219+
label=label_cat.items[ann.label].name,
220+
points=ann.points,
221+
occluded=False,
222+
attributes=[],
223+
))

cvat/apps/dataset_manager/export_templates/plugins/cvat_rest_api_task_images.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
SchemaBuilder as _SchemaBuilder,
1414
)
1515
import datumaro.components.extractor as datumaro
16-
from datumaro.util.image import lazy_image, load_image
16+
from datumaro.util.image import lazy_image, load_image, Image
1717

1818
from cvat.utils.cli.core import CLI as CVAT_CLI, CVAT_API_V1
1919

@@ -103,8 +103,11 @@ def __init__(self, url):
103103
items = []
104104
for entry in image_list:
105105
item_id = entry['id']
106-
item = datumaro.DatasetItem(
107-
id=item_id, image=self._make_image_loader(item_id))
106+
size = None
107+
if entry.get('height') and entry.get('width'):
108+
size = (entry['height'], entry['width'])
109+
image = Image(data=self._make_image_loader(item_id), size=size)
110+
item = datumaro.DatasetItem(id=item_id, image=image)
108111
items.append((item.id, item))
109112

110113
items = sorted(items, key=lambda e: int(e[0]))

datumaro/datumaro/cli/contexts/project/__init__.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -156,16 +156,17 @@ def import_command(args):
156156
if project_name is None:
157157
project_name = osp.basename(project_dir)
158158

159-
extra_args = {}
160159
try:
161160
env = Environment()
162161
importer = env.make_importer(args.format)
163-
if hasattr(importer, 'from_cmdline'):
164-
extra_args = importer.from_cmdline(args.extra_args)
165162
except KeyError:
166163
raise CliException("Importer for format '%s' is not found" % \
167164
args.format)
168165

166+
extra_args = {}
167+
if hasattr(importer, 'from_cmdline'):
168+
extra_args = importer.from_cmdline(args.extra_args)
169+
169170
log.info("Importing project from '%s' as '%s'" % \
170171
(args.source, args.format))
171172

@@ -293,13 +294,14 @@ def export_command(args):
293294

294295
try:
295296
converter = project.env.converters.get(args.format)
296-
if hasattr(converter, 'from_cmdline'):
297-
extra_args = converter.from_cmdline(args.extra_args)
298-
converter = converter(**extra_args)
299297
except KeyError:
300298
raise CliException("Converter for format '%s' is not found" % \
301299
args.format)
302300

301+
if hasattr(converter, 'from_cmdline'):
302+
extra_args = converter.from_cmdline(args.extra_args)
303+
converter = converter(**extra_args)
304+
303305
filter_args = FilterModes.make_filter_args(args.filter_mode)
304306

305307
log.info("Loading the project...")
@@ -559,14 +561,15 @@ def transform_command(args):
559561
(project.config.project_name, make_file_name(args.transform)))
560562
dst_dir = osp.abspath(dst_dir)
561563

562-
extra_args = {}
563564
try:
564565
transform = project.env.transforms.get(args.transform)
565-
if hasattr(transform, 'from_cmdline'):
566-
extra_args = transform.from_cmdline(args.extra_args)
567566
except KeyError:
568567
raise CliException("Transform '%s' is not found" % args.transform)
569568

569+
extra_args = {}
570+
if hasattr(transform, 'from_cmdline'):
571+
extra_args = transform.from_cmdline(args.extra_args)
572+
570573
log.info("Loading the project...")
571574
dataset = project.make_dataset()
572575

0 commit comments

Comments
 (0)