Skip to content

Commit 192fd72

Browse files
authored
Fix creation of tasks with Git repositories via the SDK (#5409)
Fixes #4365
1 parent 8b13a2c commit 192fd72

File tree

12 files changed

+116
-65
lines changed

12 files changed

+116
-65
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ non-ascii paths while adding files from "Connected file share" (issue #4428)
9393
- Fix chart not being upgradable (<https://github.com/opencv/cvat/pull/5371>)
9494
- Broken helm chart - if using custom release name (<https://github.com/opencv/cvat/pull/5403>)
9595
- Missing source tag in project annotations (<https://github.com/opencv/cvat/pull/5408>)
96+
- Creating a task with a Git repository via the SDK
97+
(<https://github.com/opencv/cvat/issues/4365>)
9698

9799
### Security
98100
- TDB

cvat-sdk/cvat_sdk/core/client.py

+3
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,9 @@ def git_create(self, task_id: int) -> str:
291291
def git_check(self, rq_id: int) -> str:
292292
return self.git + f"check/{rq_id}"
293293

294+
def git_get(self, task_id: int) -> str:
295+
return self.git + f"get/{task_id}"
296+
294297
def make_endpoint_url(
295298
self,
296299
path: str,

cvat-sdk/cvat_sdk/core/git.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def create_git_repo(
3030
post_params={"path": repo_url, "lfs": use_lfs, "tid": task_id},
3131
headers=common_headers,
3232
)
33-
response_json = json.loads(response)
33+
response_json = json.loads(response.data)
3434
rq_id = response_json["rq_id"]
3535
client.logger.info(f"Create RQ ID: {rq_id}")
3636

cvat-sdk/cvat_sdk/core/proxies/tasks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def create_from_data(
339339

340340
if dataset_repository_url:
341341
git.create_git_repo(
342-
self,
342+
self._client,
343343
task_id=task.id,
344344
repo_url=dataset_repository_url,
345345
status_check_period=status_check_period,

cvat-ui/src/utils/git-utils.ts

+3
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,9 @@ export async function changeRepo(taskId: number, type: string, value: any): Prom
201201
core.server
202202
.request(`${baseURL}/git/repository/${taskId}`, {
203203
method: 'PATCH',
204+
headers: {
205+
'Content-type': 'application/json',
206+
},
204207
data: JSON.stringify({
205208
type,
206209
value,

cvat/apps/dataset_repo/views.py

+42-24
Original file line numberDiff line numberDiff line change
@@ -3,69 +3,87 @@
33
# SPDX-License-Identifier: MIT
44
import http.client
55

6-
from django.http import HttpResponseBadRequest, JsonResponse, HttpResponse
6+
from django.http import HttpResponseBadRequest, HttpResponse
77
from rules.contrib.views import permission_required, objectgetter
88

9-
from cvat.apps.iam.decorators import login_required
9+
from rest_framework.permissions import IsAuthenticated
10+
from rest_framework.response import Response
11+
from rest_framework.request import Request
12+
from rest_framework.decorators import api_view, permission_classes
13+
14+
from drf_spectacular.utils import extend_schema
15+
1016
from cvat.apps.engine.log import slogger
1117
from cvat.apps.engine import models
1218
from cvat.apps.dataset_repo.models import GitData
1319
import contextlib
1420

1521
import cvat.apps.dataset_repo.dataset_repo as CVATGit
1622
import django_rq
17-
import json
1823

19-
@login_required
24+
def _legacy_api_view(allowed_method_names=None):
25+
# Currently, the views in this file use the legacy permission-checking
26+
# approach, so this decorator disables the default DRF permission classes.
27+
# TODO: migrate to DRF permissions, make the views compatible with drf-spectacular,
28+
# and remove this decorator.
29+
def decorator(view):
30+
view = permission_classes([IsAuthenticated])(view)
31+
view = api_view(allowed_method_names)(view)
32+
view = extend_schema(exclude=True)(view)
33+
return view
34+
35+
return decorator
36+
37+
@_legacy_api_view()
2038
def check_process(request, rq_id):
2139
try:
2240
queue = django_rq.get_queue('default')
2341
rq_job = queue.fetch_job(rq_id)
2442

2543
if rq_job is not None:
2644
if rq_job.is_queued or rq_job.is_started:
27-
return JsonResponse({"status": rq_job.get_status()})
45+
return Response({"status": rq_job.get_status()})
2846
elif rq_job.is_finished:
29-
return JsonResponse({"status": rq_job.get_status()})
47+
return Response({"status": rq_job.get_status()})
3048
else:
31-
return JsonResponse({"status": rq_job.get_status(), "stderr": rq_job.exc_info})
49+
return Response({"status": rq_job.get_status(), "stderr": rq_job.exc_info})
3250
else:
33-
return JsonResponse({"status": "unknown"})
51+
return Response({"status": "unknown"})
3452
except Exception as ex:
3553
slogger.glob.error("error occurred during checking repository request with rq id {}".format(rq_id), exc_info=True)
3654
return HttpResponseBadRequest(str(ex))
3755

3856

39-
@login_required
57+
@_legacy_api_view(['POST'])
4058
@permission_required(perm=['engine.task.create'],
4159
fn=objectgetter(models.Task, 'tid'), raise_exception=True)
42-
def create(request, tid):
60+
def create(request: Request, tid):
4361
try:
4462
slogger.task[tid].info("create repository request")
45-
body = json.loads(request.body.decode('utf-8'))
63+
body = request.data
4664
path = body["path"]
47-
export_format = body["format"]
65+
export_format = body.get("format")
4866
lfs = body["lfs"]
4967
rq_id = "git.create.{}".format(tid)
5068
queue = django_rq.get_queue("default")
5169

5270
queue.enqueue_call(func = CVATGit.initial_create, args = (tid, path, export_format, lfs, request.user), job_id = rq_id)
53-
return JsonResponse({ "rq_id": rq_id })
71+
return Response({ "rq_id": rq_id })
5472
except Exception as ex:
5573
slogger.glob.error("error occurred during initial cloning repository request with rq id {}".format(rq_id), exc_info=True)
5674
return HttpResponseBadRequest(str(ex))
5775

5876

59-
@login_required
60-
def push_repository(request, tid):
77+
@_legacy_api_view()
78+
def push_repository(request: Request, tid):
6179
try:
6280
slogger.task[tid].info("push repository request")
6381

6482
rq_id = "git.push.{}".format(tid)
6583
queue = django_rq.get_queue('default')
6684
queue.enqueue_call(func = CVATGit.push, args = (tid, request.user, request.scheme, request.get_host()), job_id = rq_id)
6785

68-
return JsonResponse({ "rq_id": rq_id })
86+
return Response({ "rq_id": rq_id })
6987
except Exception as ex:
7088
with contextlib.suppress(Exception):
7189
slogger.task[tid].error("error occurred during pushing repository request",
@@ -74,24 +92,24 @@ def push_repository(request, tid):
7492
return HttpResponseBadRequest(str(ex))
7593

7694

77-
@login_required
78-
def get_repository(request, tid):
95+
@_legacy_api_view()
96+
def get_repository(request: Request, tid):
7997
try:
8098
slogger.task[tid].info("get repository request")
81-
return JsonResponse(CVATGit.get(tid, request.user))
99+
return Response(CVATGit.get(tid, request.user))
82100
except Exception as ex:
83101
with contextlib.suppress(Exception):
84102
slogger.task[tid].error("error occurred during getting repository info request",
85103
exc_info=True)
86104

87105
return HttpResponseBadRequest(str(ex))
88106

89-
@login_required
107+
@_legacy_api_view(['PATCH'])
90108
@permission_required(perm=['engine.task.access'],
91109
fn=objectgetter(models.Task, 'tid'), raise_exception=True)
92-
def update_git_repo(request, tid):
110+
def update_git_repo(request: Request, tid):
93111
try:
94-
body = json.loads(request.body.decode('utf-8'))
112+
body = request.data
95113
req_type = body["type"]
96114
value = body["value"]
97115
git_data_obj = GitData.objects.filter(task_id=tid)[0]
@@ -114,15 +132,15 @@ def update_git_repo(request, tid):
114132
return HttpResponseBadRequest(str(ex))
115133

116134

117-
@login_required
135+
@_legacy_api_view()
118136
def get_meta_info(request):
119137
try:
120138
db_git_records = GitData.objects.all()
121139
response = {}
122140
for db_git in db_git_records:
123141
response[db_git.task_id] = db_git.status
124142

125-
return JsonResponse(response, safe = False)
143+
return Response(response)
126144
except Exception as ex:
127145
slogger.glob.exception("error occurred during get meta request", exc_info = True)
128146
return HttpResponseBadRequest(str(ex))

cvat/apps/iam/decorators.py

-36
This file was deleted.

cvat/settings/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,10 @@
5353
def generate_ssh_keys():
5454
keys_dir = '{}/keys'.format(os.getcwd())
5555
ssh_dir = '{}/.ssh'.format(os.getenv('HOME'))
56-
pidfile = os.path.join(ssh_dir, 'ssh.pid')
56+
pidfile = os.path.join(keys_dir, 'ssh.pid')
5757

5858
def add_ssh_keys():
59-
IGNORE_FILES = ('README.md', 'ssh.pid')
59+
IGNORE_FILES = ('README.md',)
6060
keys_to_add = [entry.name for entry in os.scandir(ssh_dir) if entry.name not in IGNORE_FILES]
6161
keys_to_add = ' '.join(os.path.join(ssh_dir, f) for f in keys_to_add)
6262
subprocess.run(['ssh-add {}'.format(keys_to_add)], # nosec

tests/docker-compose.webhook.yml renamed to tests/docker-compose.test_servers.yml

+14
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,17 @@ services:
1515
cvat:
1616
aliases:
1717
- webhooks
18+
19+
git_server:
20+
image: alpine/git
21+
restart: always
22+
depends_on:
23+
- cvat_server
24+
entrypoint: /mnt/scripts/entrypoint.sh
25+
volumes:
26+
- ./tests/git_server/:/mnt/scripts:ro
27+
- cvat_keys:/mnt/keys:ro
28+
networks:
29+
cvat:
30+
aliases:
31+
- gitserver

tests/git_server/entrypoint.sh

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#!/bin/sh
2+
3+
set -e
4+
5+
mkdir -p ~/repos/repo.git
6+
git -C ~/repos/repo.git init --bare
7+
8+
mkdir -p ~/.ssh
9+
# Authorize CVAT's client keys
10+
cat /mnt/keys/*.pub > ~/.ssh/authorized_keys
11+
12+
ssh-keygen -A
13+
exec /usr/sbin/sshd -D

tests/python/sdk/test_tasks.py

+34
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: MIT
44

55
import io
6+
import json
67
import os.path as osp
78
import zipfile
89
from logging import Logger
@@ -169,6 +170,39 @@ def test_cant_create_task_with_no_data(self):
169170
assert capture.match("No media data found")
170171
assert self.stdout.getvalue() == ""
171172

173+
def test_can_create_task_with_git_repo(self, fxt_image_file: Path):
174+
pbar_out = io.StringIO()
175+
pbar = make_pbar(file=pbar_out)
176+
177+
task_spec = {
178+
"name": f"task with Git repo",
179+
"labels": [{"name": "car"}],
180+
}
181+
182+
repository_url = "root@gitserver:repos/repo.git [annotations/annot.zip]"
183+
184+
task = self.client.tasks.create_from_data(
185+
spec=task_spec,
186+
resource_type=ResourceType.LOCAL,
187+
resources=[str(fxt_image_file)],
188+
pbar=pbar,
189+
dataset_repository_url=repository_url,
190+
)
191+
192+
assert task.size == 1
193+
assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1]
194+
assert self.stdout.getvalue() == ""
195+
196+
git_get_response = self.client.api_client.rest_client.GET(
197+
self.client.api_map.git_get(task.id),
198+
headers=self.client.api_client.get_common_headers(),
199+
)
200+
201+
response_json = json.loads(git_get_response.data)
202+
assert response_json["url"]["value"] == repository_url
203+
assert response_json["format"] == "CVAT for images 1.1"
204+
assert response_json["lfs"] is False
205+
172206
def test_can_retrieve_task(self, fxt_new_task: Task):
173207
task_id = fxt_new_task.id
174208

tests/python/shared/fixtures/init.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"docker-compose.dev.yml",
3333
"tests/docker-compose.file_share.yml",
3434
"tests/docker-compose.minio.yml",
35-
"tests/docker-compose.webhook.yml",
35+
"tests/docker-compose.test_servers.yml",
3636
)
3737
] + CONTAINER_NAME_FILES
3838

0 commit comments

Comments
 (0)