Skip to content

Commit 2ebca5b

Browse files
authored
[Datumaro] Dataset format auto detection (#1242)
* Add dataset format detection
* Add auto format detection for import
* Split VOC extractor
1 parent 24130cd commit 2ebca5b

File tree

17 files changed: +572 -857 lines changed

datumaro/datumaro/cli/contexts/project/__init__.py

Lines changed: 46 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -132,8 +132,8 @@ def build_import_parser(parser_ctor=argparse.ArgumentParser):
132132
help="Overwrite existing files in the save directory")
133133
parser.add_argument('-i', '--input-path', required=True, dest='source',
134134
help="Path to import project from")
135-
parser.add_argument('-f', '--format', required=True,
136-
help="Source project format")
135+
parser.add_argument('-f', '--format',
136+
help="Source project format. Will try to detect, if not specified.")
137137
parser.add_argument('extra_args', nargs=argparse.REMAINDER,
138138
help="Additional arguments for importer (pass '-- -h' for help)")
139139
parser.set_defaults(command=import_command)
@@ -164,22 +164,53 @@ def import_command(args):
164164
if project_name is None:
165165
project_name = osp.basename(project_dir)
166166

167-
try:
168-
env = Environment()
169-
importer = env.make_importer(args.format)
170-
except KeyError:
171-
raise CliException("Importer for format '%s' is not found" % \
172-
args.format)
173-
174-
extra_args = {}
175-
if hasattr(importer, 'from_cmdline'):
176-
extra_args = importer.from_cmdline(args.extra_args)
167+
env = Environment()
168+
log.info("Importing project from '%s'" % args.source)
169+
170+
if not args.format:
171+
if args.extra_args:
172+
raise CliException("Extra args can not be used without format")
173+
174+
log.info("Trying to detect dataset format...")
175+
176+
matches = []
177+
for format_name in env.importers.items:
178+
log.debug("Checking '%s' format...", format_name)
179+
importer = env.make_importer(format_name)
180+
try:
181+
match = importer.detect(args.source)
182+
if match:
183+
log.debug("format matched")
184+
matches.append((format_name, importer))
185+
except NotImplementedError:
186+
log.debug("Format '%s' does not support auto detection.",
187+
format_name)
188+
189+
if len(matches) == 0:
190+
log.error("Failed to detect dataset format automatically. "
191+
"Try to specify format with '-f/--format' parameter.")
192+
return 1
193+
elif len(matches) != 1:
194+
log.error("Multiple formats match the dataset: %s. "
195+
"Try to specify format with '-f/--format' parameter.",
196+
', '.join(m[0] for m in matches))
197+
return 2
198+
199+
format_name, importer = matches[0]
200+
args.format = format_name
201+
else:
202+
try:
203+
importer = env.make_importer(args.format)
204+
if hasattr(importer, 'from_cmdline'):
205+
extra_args = importer.from_cmdline(args.extra_args)
206+
except KeyError:
207+
raise CliException("Importer for format '%s' is not found" % \
208+
args.format)
177209

178-
log.info("Importing project from '%s' as '%s'" % \
179-
(args.source, args.format))
210+
log.info("Importing project as '%s'" % args.format)
180211

181212
source = osp.abspath(args.source)
182-
project = importer(source, **extra_args)
213+
project = importer(source, **locals().get('extra_args', {}))
183214
project.config.project_name = project_name
184215
project.config.project_dir = project_dir
185216

datumaro/datumaro/components/extractor.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -743,6 +743,10 @@ class SourceExtractor(Extractor):
743743
pass
744744

745745
class Importer:
746+
@classmethod
747+
def detect(cls, path):
748+
raise NotImplementedError()
749+
746750
def __call__(self, path, **extra_params):
747751
raise NotImplementedError()
748752

datumaro/datumaro/plugins/coco_format/importer.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
import os.path as osp
1010

1111
from datumaro.components.extractor import Importer
12+
from datumaro.util.log_utils import logging_disabled
1213

1314
from .format import CocoTask, CocoPath
1415

@@ -22,6 +23,11 @@ class CocoImporter(Importer):
2223
CocoTask.image_info: 'coco_image_info',
2324
}
2425

26+
@classmethod
27+
def detect(cls, path):
28+
with logging_disabled(log.WARN):
29+
return len(cls.find_subsets(path)) != 0
30+
2531
def __call__(self, path, **extra_params):
2632
from datumaro.components.project import Project # cyclic import
2733
project = Project()
@@ -53,7 +59,7 @@ def find_subsets(path):
5359

5460
if osp.basename(osp.normpath(path)) != CocoPath.ANNOTATIONS_DIR:
5561
path = osp.join(path, CocoPath.ANNOTATIONS_DIR)
56-
subset_paths += glob(osp.join(path, '*_*.json'))
62+
subset_paths += glob(osp.join(path, '*_*.json'))
5763

5864
subsets = defaultdict(dict)
5965
for subset_path in subset_paths:

datumaro/datumaro/plugins/cvat_format/importer.py

Lines changed: 17 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -15,18 +15,15 @@
1515
class CvatImporter(Importer):
1616
EXTRACTOR_NAME = 'cvat'
1717

18+
@classmethod
19+
def detect(cls, path):
20+
return len(cls.find_subsets(path)) != 0
21+
1822
def __call__(self, path, **extra_params):
1923
from datumaro.components.project import Project # cyclic import
2024
project = Project()
2125

22-
if path.endswith('.xml') and osp.isfile(path):
23-
subset_paths = [path]
24-
else:
25-
subset_paths = glob(osp.join(path, '*.xml'))
26-
27-
if osp.basename(osp.normpath(path)) != CvatPath.ANNOTATIONS_DIR:
28-
path = osp.join(path, CvatPath.ANNOTATIONS_DIR)
29-
subset_paths += glob(osp.join(path, '*.xml'))
26+
subset_paths = self.find_subsets(path)
3027

3128
if len(subset_paths) == 0:
3229
raise Exception("Failed to find 'cvat' dataset at '%s'" % path)
@@ -46,3 +43,15 @@ def __call__(self, path, **extra_params):
4643
})
4744

4845
return project
46+
47+
@staticmethod
48+
def find_subsets(path):
49+
if path.endswith('.xml') and osp.isfile(path):
50+
subset_paths = [path]
51+
else:
52+
subset_paths = glob(osp.join(path, '*.xml'))
53+
54+
if osp.basename(osp.normpath(path)) != CvatPath.ANNOTATIONS_DIR:
55+
path = osp.join(path, CvatPath.ANNOTATIONS_DIR)
56+
subset_paths += glob(osp.join(path, '*.xml'))
57+
return subset_paths

datumaro/datumaro/plugins/datumaro_format/importer.py

Lines changed: 17 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -15,19 +15,15 @@
1515
class DatumaroImporter(Importer):
1616
EXTRACTOR_NAME = 'datumaro'
1717

18+
@classmethod
19+
def detect(cls, path):
20+
return len(cls.find_subsets(path)) != 0
21+
1822
def __call__(self, path, **extra_params):
1923
from datumaro.components.project import Project # cyclic import
2024
project = Project()
2125

22-
if path.endswith('.json') and osp.isfile(path):
23-
subset_paths = [path]
24-
else:
25-
subset_paths = glob(osp.join(path, '*.json'))
26-
27-
if osp.basename(osp.normpath(path)) != DatumaroPath.ANNOTATIONS_DIR:
28-
path = osp.join(path, DatumaroPath.ANNOTATIONS_DIR)
29-
subset_paths += glob(osp.join(path, '*.json'))
30-
26+
subset_paths = self.find_subsets(path)
3127
if len(subset_paths) == 0:
3228
raise Exception("Failed to find 'datumaro' dataset at '%s'" % path)
3329

@@ -46,3 +42,15 @@ def __call__(self, path, **extra_params):
4642
})
4743

4844
return project
45+
46+
@staticmethod
47+
def find_subsets(path):
48+
if path.endswith('.json') and osp.isfile(path):
49+
subset_paths = [path]
50+
else:
51+
subset_paths = glob(osp.join(path, '*.json'))
52+
53+
if osp.basename(osp.normpath(path)) != DatumaroPath.ANNOTATIONS_DIR:
54+
path = osp.join(path, DatumaroPath.ANNOTATIONS_DIR)
55+
subset_paths += glob(osp.join(path, '*.json'))
56+
return subset_paths

datumaro/datumaro/plugins/tf_detection_api_format/importer.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -13,15 +13,15 @@
1313
class TfDetectionApiImporter(Importer):
1414
EXTRACTOR_NAME = 'tf_detection_api'
1515

16+
@classmethod
17+
def detect(cls, path):
18+
return len(cls.find_subsets(path)) != 0
19+
1620
def __call__(self, path, **extra_params):
1721
from datumaro.components.project import Project # cyclic import
1822
project = Project()
1923

20-
if path.endswith('.tfrecord') and osp.isfile(path):
21-
subset_paths = [path]
22-
else:
23-
subset_paths = glob(osp.join(path, '*.tfrecord'))
24-
24+
subset_paths = self.find_subsets(path)
2525
if len(subset_paths) == 0:
2626
raise Exception(
2727
"Failed to find 'tf_detection_api' dataset at '%s'" % path)
@@ -42,3 +42,10 @@ def __call__(self, path, **extra_params):
4242

4343
return project
4444

45+
@staticmethod
46+
def find_subsets(path):
47+
if path.endswith('.tfrecord') and osp.isfile(path):
48+
subset_paths = [path]
49+
else:
50+
subset_paths = glob(osp.join(path, '*.tfrecord'))
51+
return subset_paths

datumaro/datumaro/plugins/voc_format/converter.py

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -317,6 +317,9 @@ def save_subsets(self):
317317
self.save_segm_lists(subset_name, segm_list)
318318

319319
def save_action_lists(self, subset_name, action_list):
320+
if not action_list:
321+
return
322+
320323
os.makedirs(self._action_subsets_dir, exist_ok=True)
321324

322325
ann_file = osp.join(self._action_subsets_dir, subset_name + '.txt')
@@ -342,11 +345,11 @@ def save_action_lists(self, subset_name, action_list):
342345
(item, 1 + obj_id, 1 if presented else -1))
343346

344347
def save_class_lists(self, subset_name, class_lists):
345-
os.makedirs(self._cls_subsets_dir, exist_ok=True)
346-
347-
if len(class_lists) == 0:
348+
if not class_lists:
348349
return
349350

351+
os.makedirs(self._cls_subsets_dir, exist_ok=True)
352+
350353
for label in self._label_map:
351354
ann_file = osp.join(self._cls_subsets_dir,
352355
'%s_%s.txt' % (label, subset_name))
@@ -360,6 +363,9 @@ def save_class_lists(self, subset_name, class_lists):
360363
f.write('%s % d\n' % (item, 1 if presented else -1))
361364

362365
def save_clsdet_lists(self, subset_name, clsdet_list):
366+
if not clsdet_list:
367+
return
368+
363369
os.makedirs(self._cls_subsets_dir, exist_ok=True)
364370

365371
ann_file = osp.join(self._cls_subsets_dir, subset_name + '.txt')
@@ -368,6 +374,9 @@ def save_clsdet_lists(self, subset_name, clsdet_list):
368374
f.write('%s\n' % item)
369375

370376
def save_segm_lists(self, subset_name, segm_list):
377+
if not segm_list:
378+
return
379+
371380
os.makedirs(self._segm_subsets_dir, exist_ok=True)
372381

373382
ann_file = osp.join(self._segm_subsets_dir, subset_name + '.txt')
@@ -376,6 +385,9 @@ def save_segm_lists(self, subset_name, segm_list):
376385
f.write('%s\n' % item)
377386

378387
def save_layout_lists(self, subset_name, layout_list):
388+
if not layout_list:
389+
return
390+
379391
os.makedirs(self._layout_subsets_dir, exist_ok=True)
380392

381393
ann_file = osp.join(self._layout_subsets_dir, subset_name + '.txt')

0 commit comments

Comments
 (0)