Skip to content

Commit 8b301e1

Browse files
zhiltsov-maxChris Lee-Messer
authored and
Chris Lee-Messer
committed
[Datumaro] Fix TFrecord converter constructor (cvat-ai#993)
1 parent a8f186f commit 8b301e1

File tree

3 files changed

+25
-6
lines changed

3 files changed

+25
-6
lines changed

cvat/apps/engine/views.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,6 @@ def frame(self, request, pk, frame):
593593
@action(detail=True, methods=['GET'], serializer_class=None,
594594
url_path='dataset')
595595
def dataset_export(self, request, pk):
596-
597596
db_task = self.get_object()
598597

599598
action = request.query_params.get("action", "")
@@ -611,7 +610,7 @@ def dataset_export(self, request, pk):
611610
raise serializers.ValidationError(
612611
"Unexpected parameter 'format' specified for the request")
613612

614-
rq_id = "task_dataset_export.{}.{}".format(pk, dst_format)
613+
rq_id = "/api/v1/tasks/{}/dataset/{}".format(pk, dst_format)
615614
queue = django_rq.get_queue("default")
616615

617616
rq_job = queue.fetch_job(rq_id)

datumaro/datumaro/components/converters/tfrecord.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,27 @@ def float_list_feature(value):
9898
return tf_example
9999

100100
class DetectionApiConverter:
101-
def __init__(self, save_images=True):
102-
self.save_images = save_images
101+
def __init__(self, save_images=False, cmdline_args=None):
102+
super().__init__()
103+
104+
self._save_images = save_images
105+
106+
if cmdline_args is not None:
107+
options = self._parse_cmdline(cmdline_args)
108+
for k, v in options.items():
109+
if hasattr(self, '_' + str(k)):
110+
setattr(self, '_' + str(k), v)
111+
112+
@classmethod
113+
def build_cmdline_parser(cls, parser=None):
114+
import argparse
115+
if not parser:
116+
parser = argparse.ArgumentParser()
117+
118+
parser.add_argument('--save-images', action='store_true',
119+
help="Save images (default: %(default)s)")
120+
121+
return parser
103122

104123
def __call__(self, extractor, save_dir):
105124
tf = _import_tf()
@@ -141,6 +160,6 @@ def __call__(self, extractor, save_dir):
141160
item,
142161
get_label=get_label,
143162
get_label_id=map_label_id,
144-
save_images=self.save_images,
163+
save_images=self._save_images,
145164
)
146165
writer.write(tf_example.SerializeToString())

datumaro/tests/test_tfrecord_format.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ def categories(self):
112112

113113
with TestDir() as test_dir:
114114
self._test_can_save_and_load(
115-
TestExtractor(), DetectionApiConverter(), test_dir)
115+
TestExtractor(), DetectionApiConverter(save_images=True),
116+
test_dir)
116117

117118
def test_labelmap_parsing(self):
118119
text = """

0 commit comments

Comments
 (0)