Commit 487c60c
SDK: Add an adapter layer that presents a CVAT task as a torchvision dataset (#5417)
1 parent: 82adde4

8 files changed: +577 -6 lines changed

.github/workflows/full.yml

Lines changed: 1 addition & 1 deletion
@@ -196,7 +196,7 @@ jobs:

       - name: Running REST API and SDK tests
         run: |
-          pip3 install --user /tmp/cvat_sdk/
+          pip3 install --user '/tmp/cvat_sdk/[pytorch]'
           pip3 install --user cvat-cli/
           pip3 install --user -r tests/python/requirements.txt
           pytest tests/python -s -v

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
@@ -164,7 +164,7 @@ jobs:

       - name: Running REST API and SDK tests
         run: |
-          pip3 install --user /tmp/cvat_sdk/
+          pip3 install --user '/tmp/cvat_sdk/[pytorch]'
           pip3 install --user cvat-cli/
           pip3 install --user -r tests/python/requirements.txt
           pytest tests/python/ -s -v

.github/workflows/schedule.yml

Lines changed: 1 addition & 1 deletion
@@ -235,7 +235,7 @@ jobs:
           gen/generate.sh
           cd ..

-          pip3 install --user cvat-sdk/
+          pip3 install --user 'cvat-sdk/[pytorch]'
           pip3 install --user cvat-cli/
           pip3 install --user -r tests/python/requirements.txt
           pytest tests/python/
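These three CI jobs now install the SDK with its new pytorch extra so that the adapter's dependencies are available when the tests run. For context, a minimal sketch of how such an extras group is declared in packaging metadata (hypothetical project name and dependency list; the actual setup.py change is among the files not shown in this view):

# Hypothetical setup.py fragment illustrating a "pytorch" extras group;
# `pip install 'example-sdk[pytorch]'` would then also pull in these packages.
from setuptools import find_packages, setup

setup(
    name="example-sdk",
    version="0.0.1",
    packages=find_packages(),
    extras_require={
        "pytorch": ["torch", "torchvision"],
    },
)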

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,8 @@ from online detectors & interactors) (<https://github.com/opencv/cvat/pull/4543>
 - Authentication with social accounts google & github (<https://github.com/opencv/cvat/pull/5147>, <https://github.com/opencv/cvat/pull/5181>, <https://github.com/opencv/cvat/pull/5295>)
 - REST API tests to export job datasets & annotations and validate their structure (<https://github.com/opencv/cvat/pull/5160>)
 - Propagation backward on UI (<https://github.com/opencv/cvat/pull/5355>)
+- A PyTorch dataset adapter layer in the SDK
+  (<https://github.com/opencv/cvat/pull/5417>)

 ### Changed
 - `api/docs`, `api/swagger`, `api/schema`, `server/about` endpoints now allow unauthorized access (<https://github.com/opencv/cvat/pull/4928>, <https://github.com/opencv/cvat/pull/4935>)

cvat-sdk/cvat_sdk/pytorch/__init__.py

Lines changed: 359 additions & 0 deletions
@@ -0,0 +1,359 @@
import base64
import collections
import json
import os
import shutil
import types
import zipfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import (
    Callable,
    Dict,
    FrozenSet,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
)

import appdirs
import attrs
import attrs.validators
import PIL.Image
import torchvision.datasets
from typing_extensions import TypedDict

import cvat_sdk.core
import cvat_sdk.core.exceptions
from cvat_sdk.api_client.model_utils import to_json
from cvat_sdk.core.utils import atomic_writer
from cvat_sdk.models import DataMetaRead, LabeledData, LabeledImage, LabeledShape, TaskRead

_ModelType = TypeVar("_ModelType")

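# Cached data lives under the per-user cache directory as resolved by appdirs
# (e.g. ~/.cache/cvat-sdk on Linux); downloads run on a small thread pool.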
_CACHE_DIR = Path(appdirs.user_cache_dir("cvat-sdk", "CVAT.ai"))
_NUM_DOWNLOAD_THREADS = 4


class UnsupportedDatasetError(cvat_sdk.core.exceptions.CvatSdkException):
    pass


@attrs.frozen
class FrameAnnotations:
    """
    Contains annotations that pertain to a single frame.
    """

    tags: List[LabeledImage] = attrs.Factory(list)
    shapes: List[LabeledShape] = attrs.Factory(list)


@attrs.frozen
class Target:
    """
    Non-image data for a dataset sample.
    """

    annotations: FrameAnnotations
    """Annotations for the frame corresponding to the sample."""

    label_id_to_index: Mapping[int, int]
    """
    A mapping from label_id values in `LabeledImage` and `LabeledShape` objects
    to an index in the range [0, num_labels), where num_labels is the number of labels
    defined in the task. This mapping is consistent across all samples for a given task.
    """


class TaskVisionDataset(torchvision.datasets.VisionDataset):
    """
    Represents a task on a CVAT server as a PyTorch Dataset.

    This dataset contains one sample for each frame in the task, in the same
    order as the frames are in the task. Deleted frames are omitted.
    Before transforms are applied, each sample is a tuple of
    (image, target), where:

    * image is a `PIL.Image.Image` object for the corresponding frame.
    * target is a `Target` object containing annotations for the frame.

    This class caches all data and annotations for the task on the local file system
    during construction. If the task is updated on the server, the cache is updated.

    Limitations:

    * Only tasks with image (not video) data are supported at the moment.
    * Track annotations are currently not accessible.
    """

    def __init__(
        self,
        client: cvat_sdk.core.Client,
        task_id: int,
        *,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """
        Creates a dataset corresponding to the task with ID `task_id` on the
        server that `client` is connected to.

        `transforms`, `transform` and `target_transform` are optional transformation
        functions; see the documentation for `torchvision.datasets.VisionDataset` for
        more information.
        """

        self._logger = client.logger

        self._logger.info(f"Fetching task {task_id}...")
        self._task = client.tasks.retrieve(task_id)

        if not self._task.size or not self._task.data_chunk_size:
            raise UnsupportedDatasetError("The task has no data")

        if self._task.data_original_chunk_type != "imageset":
            raise UnsupportedDatasetError(
                f"{self.__class__.__name__} only supports tasks with image chunks;"
                f" current chunk type is {self._task.data_original_chunk_type!r}"
            )

        # Base64-encode the name to avoid FS-unsafe characters (like slashes)
        server_dir_name = (
            base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode()
        )
        server_dir = _CACHE_DIR / f"servers/{server_dir_name}"

        self._task_dir = server_dir / f"tasks/{self._task.id}"
        self._initialize_task_dir()

        super().__init__(
            os.fspath(self._task_dir),
            transforms=transforms,
            transform=transform,
            target_transform=target_transform,
        )

        data_meta = self._ensure_model(
            "data_meta.json", DataMetaRead, self._task.get_meta, "data metadata"
        )
        self._active_frame_indexes = sorted(
            set(range(self._task.size)) - set(data_meta.deleted_frames)
        )

        self._logger.info("Downloading chunks...")

        self._chunk_dir = self._task_dir / "chunks"
        self._chunk_dir.mkdir(exist_ok=True, parents=True)

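        # Each chunk is a zip archive holding up to data_chunk_size consecutive
        # frames, so mapping a frame index to its chunk is a plain division.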
        needed_chunks = {
            index // self._task.data_chunk_size for index in self._active_frame_indexes
        }

        with ThreadPoolExecutor(_NUM_DOWNLOAD_THREADS) as pool:
            for _ in pool.map(self._ensure_chunk, sorted(needed_chunks)):
                # just need to loop through all results so that any exceptions are propagated
                pass

        self._logger.info("All chunks downloaded")

        self._label_id_to_index = types.MappingProxyType(
            {
                label.id: label_index
                for label_index, label in enumerate(
                    sorted(self._task.labels, key=lambda label: label.id)
                )
            }
        )

        annotations = self._ensure_model(
            "annotations.json", LabeledData, self._task.get_annotations, "annotations"
        )

        self._frame_annotations: Dict[int, FrameAnnotations] = collections.defaultdict(
            FrameAnnotations
        )

        for tag in annotations.tags:
            self._frame_annotations[tag.frame].tags.append(tag)

        for shape in annotations.shapes:
            self._frame_annotations[shape.frame].shapes.append(shape)

        # TODO: tracks?

    def _initialize_task_dir(self) -> None:
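        # Reuses the cached task data if it is still current with the server;
        # otherwise clears the cache, then (re)writes task.json atomically.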
        task_json_path = self._task_dir / "task.json"

        try:
            with open(task_json_path, "rb") as task_json_file:
                saved_task = TaskRead._new_from_openapi_data(**json.load(task_json_file))
        except Exception:
            self._logger.info("Task is not yet cached or the cache is corrupted")

            # If the cache was corrupted, the directory might already be there; clear it.
            if self._task_dir.exists():
                shutil.rmtree(self._task_dir)
        else:
            if saved_task.updated_date < self._task.updated_date:
                self._logger.info(
                    "Task has been updated on the server since it was cached; purging the cache"
                )
                shutil.rmtree(self._task_dir)

        self._task_dir.mkdir(exist_ok=True, parents=True)

        with atomic_writer(task_json_path, "w", encoding="UTF-8") as task_json_file:
            json.dump(to_json(self._task._model), task_json_file, indent=4)
            print(file=task_json_file)  # add final newline

    def _ensure_chunk(self, chunk_index: int) -> None:
        chunk_path = self._chunk_dir / f"{chunk_index}.zip"
        if chunk_path.exists():
            return  # already downloaded previously

        self._logger.info(f"Downloading chunk #{chunk_index}...")

        with atomic_writer(chunk_path, "wb") as chunk_file:
            self._task.download_chunk(chunk_index, chunk_file, quality="original")

    def _ensure_model(
        self,
        filename: str,
        model_type: Type[_ModelType],
        download: Callable[[], _ModelType],
        model_description: str,
    ) -> _ModelType:
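        # Loads the model from the cache file if present and readable;
        # otherwise downloads it and writes it back to the cache atomically.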
        path = self._task_dir / filename

        try:
            with open(path, "rb") as f:
                model = model_type._new_from_openapi_data(**json.load(f))
            self._logger.info(f"Loaded {model_description} from cache")
            return model
        except FileNotFoundError:
            pass
        except Exception:
            self._logger.warning(f"Failed to load {model_description} from cache", exc_info=True)

        self._logger.info(f"Downloading {model_description}...")
        model = download()
        self._logger.info(f"Downloaded {model_description}")

        with atomic_writer(path, "w", encoding="UTF-8") as f:
            json.dump(to_json(model), f, indent=4)
            print(file=f)  # add final newline

        return model

    def __getitem__(self, sample_index: int):
        """
        Returns the sample with index `sample_index`.

        `sample_index` must satisfy the condition `0 <= sample_index < len(self)`.
        """

        frame_index = self._active_frame_indexes[sample_index]
        chunk_index = frame_index // self._task.data_chunk_size
        member_index = frame_index % self._task.data_chunk_size

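        # Relies on chunk zips storing members in frame order, so the member at
        # member_index corresponds to the desired frame within the chunk.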
        with zipfile.ZipFile(self._chunk_dir / f"{chunk_index}.zip", "r") as chunk_zip:
            with chunk_zip.open(chunk_zip.infolist()[member_index]) as chunk_member:
                sample_image = PIL.Image.open(chunk_member)
                sample_image.load()

        sample_target = Target(
            annotations=self._frame_annotations[frame_index],
            label_id_to_index=self._label_id_to_index,
        )

        if self.transforms:
            sample_image, sample_target = self.transforms(sample_image, sample_target)
        return sample_image, sample_target

    def __len__(self) -> int:
        """Returns the number of samples in the dataset."""
        return len(self._active_frame_indexes)


@attrs.frozen
class ExtractSingleLabelIndex:
    """
    A target transform that takes a `Target` object and produces a single label index
    based on the tag in that object.

    This makes the dataset samples compatible with the image classification networks
    in torchvision.

    Raises a `ValueError` if the annotations contain no tags or multiple tags.
    """

    def __call__(self, target: Target) -> int:
        tags = target.annotations.tags
        if not tags:
            raise ValueError("sample has no tags")

        if len(tags) > 1:
            raise ValueError("sample has multiple tags")

        return target.label_id_to_index[tags[0].label_id]


class LabeledBoxes(TypedDict):
    boxes: Sequence[Tuple[float, float, float, float]]
    labels: Sequence[int]


_SUPPORTED_SHAPE_TYPES = frozenset(["rectangle", "polygon", "polyline", "points", "ellipse"])


@attrs.frozen
class ExtractBoundingBoxes:
    """
    A target transform that takes a `Target` object and returns a dictionary compatible
    with the object detection networks in torchvision.

    The dictionary contains the following entries:

    "boxes": a sequence of (xmin, ymin, xmax, ymax) tuples, one for each shape
        in the annotations.
    "labels": a sequence of corresponding label indices.

    Limitations:

    * Only the following shape types are supported: rectangle, polygon, polyline,
      points, ellipse.
    * Rotated shapes are not supported.

    Unsupported shapes will cause an `UnsupportedDatasetError` exception to be
    raised unless they are filtered out by `include_shape_types`.
    """

    include_shape_types: FrozenSet[str] = attrs.field(
        converter=frozenset,
        validator=attrs.validators.deep_iterable(attrs.validators.in_(_SUPPORTED_SHAPE_TYPES)),
        kw_only=True,
    )
    """Shapes whose type is not in this set will be ignored."""

    def __call__(self, target: Target) -> LabeledBoxes:
        boxes = []
        labels = []

        for shape in target.annotations.shapes:
            if shape.type.value not in self.include_shape_types:
                continue

            if shape.rotation != 0:
                raise UnsupportedDatasetError("Rotated shapes are not supported")

            x_coords = shape.points[0::2]
            y_coords = shape.points[1::2]

            boxes.append((min(x_coords), min(y_coords), max(x_coords), max(y_coords)))
            labels.append(target.label_id_to_index[shape.label_id])

        return LabeledBoxes(boxes=boxes, labels=labels)
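
For reference, a minimal usage sketch of the new adapter (hypothetical server address, credentials, and task IDs; assumes one task annotated with a single tag per frame and frames of uniform size, so the default DataLoader collation works, and another task annotated with rectangle shapes):

from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

from cvat_sdk import make_client
from cvat_sdk.pytorch import (
    ExtractBoundingBoxes,
    ExtractSingleLabelIndex,
    TaskVisionDataset,
)

with make_client(host="http://localhost", port=8080, credentials=("user", "password")) as client:
    # Classification: each sample becomes (image tensor, label index).
    classification_dataset = TaskVisionDataset(
        client,
        task_id=123,  # hypothetical task ID
        transform=ToTensor(),
        target_transform=ExtractSingleLabelIndex(),
    )
    loader = DataLoader(classification_dataset, batch_size=16, shuffle=True)

    # Detection: the target becomes {"boxes": [...], "labels": [...]}.
    detection_dataset = TaskVisionDataset(
        client,
        task_id=456,  # hypothetical task ID
        target_transform=ExtractBoundingBoxes(include_shape_types=["rectangle"]),
    )
    image, target = detection_dataset[0]

Since the dataset caches chunks and annotations under the user cache directory, repeated runs against an unchanged task read entirely from local files.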
