| 1 | +import base64 |
| 2 | +import collections |
| 3 | +import json |
| 4 | +import os |
| 5 | +import shutil |
| 6 | +import types |
| 7 | +import zipfile |
| 8 | +from concurrent.futures import ThreadPoolExecutor |
| 9 | +from pathlib import Path |
| 10 | +from typing import ( |
| 11 | + Callable, |
| 12 | + Dict, |
| 13 | + FrozenSet, |
| 14 | + List, |
| 15 | + Mapping, |
| 16 | + Optional, |
| 17 | + Sequence, |
| 18 | + Tuple, |
| 19 | + Type, |
| 20 | + TypeVar, |
| 21 | +) |
| 22 | + |
| 23 | +import appdirs |
| 24 | +import attrs |
| 25 | +import attrs.validators |
| 26 | +import PIL.Image |
| 27 | +import torchvision.datasets |
| 28 | +from typing_extensions import TypedDict |
| 29 | + |
| 30 | +import cvat_sdk.core |
| 31 | +import cvat_sdk.core.exceptions |
| 32 | +from cvat_sdk.api_client.model_utils import to_json |
| 33 | +from cvat_sdk.core.utils import atomic_writer |
| 34 | +from cvat_sdk.models import DataMetaRead, LabeledData, LabeledImage, LabeledShape, TaskRead |
| 35 | + |
| 36 | +_ModelType = TypeVar("_ModelType") |
| 37 | + |
| 38 | +_CACHE_DIR = Path(appdirs.user_cache_dir("cvat-sdk", "CVAT.ai")) |
| 39 | +_NUM_DOWNLOAD_THREADS = 4 |
| 40 | + |
| 41 | + |
| 42 | +class UnsupportedDatasetError(cvat_sdk.core.exceptions.CvatSdkException): |
| 43 | + pass |
| 44 | + |
| 45 | + |
| 46 | +@attrs.frozen |
| 47 | +class FrameAnnotations: |
| 48 | + """ |
| 49 | + Contains annotations that pertain to a single frame. |
| 50 | + """ |
| 51 | + |
| 52 | + tags: List[LabeledImage] = attrs.Factory(list) |
| 53 | + shapes: List[LabeledShape] = attrs.Factory(list) |
| 54 | + |
| 55 | + |
| 56 | +@attrs.frozen |
| 57 | +class Target: |
| 58 | + """ |
| 59 | + Non-image data for a dataset sample. |
| 60 | + """ |
| 61 | + |
| 62 | + annotations: FrameAnnotations |
| 63 | + """Annotations for the frame corresponding to the sample.""" |
| 64 | + |
| 65 | + label_id_to_index: Mapping[int, int] |
| 66 | + """ |
| 67 | + A mapping from label_id values in `LabeledImage` and `LabeledShape` objects |
| 68 | + to an index in the range [0, num_labels), where num_labels is the number of labels |
| 69 | + defined in the task. This mapping is consistent across all samples for a given task. |
| 70 | + """ |
| 71 | + |
| 72 | + |
| 73 | +class TaskVisionDataset(torchvision.datasets.VisionDataset): |
| 74 | + """ |
| 75 | + Represents a task on a CVAT server as a PyTorch Dataset. |
| 76 | + |
| 77 | + This dataset contains one sample for each frame in the task, in the same |
| 78 | + order as the frames appear in the task. Deleted frames are omitted. |
| 79 | + Before transforms are applied, each sample is a tuple of |
| 80 | + (image, target), where: |
| 81 | + |
| 82 | + * image is a `PIL.Image.Image` object for the corresponding frame. |
| 83 | + * target is a `Target` object containing annotations for the frame. |
| 84 | + |
| 85 | + This class caches all data and annotations for the task on the local file system |
| 86 | + during construction. If the task is updated on the server, the cache is purged and rebuilt. |
| 87 | + |
| 88 | + Limitations: |
| 89 | + |
| 90 | + * Only tasks with image (not video) data are supported at the moment. |
| 91 | + * Track annotations are currently not accessible. |
| 92 | + """ |
| 93 | + |
| 94 | + def __init__( |
| 95 | + self, |
| 96 | + client: cvat_sdk.core.Client, |
| 97 | + task_id: int, |
| 98 | + *, |
| 99 | + transforms: Optional[Callable] = None, |
| 100 | + transform: Optional[Callable] = None, |
| 101 | + target_transform: Optional[Callable] = None, |
| 102 | + ) -> None: |
| 103 | + """ |
| 104 | + Creates a dataset corresponding to the task with ID `task_id` on the |
| 105 | + server that `client` is connected to. |
| 106 | + |
| 107 | + `transforms`, `transform` and `target_transform` are optional transformation |
| 108 | + functions; see the documentation for `torchvision.datasets.VisionDataset` for |
| 109 | + more information. |
| 110 | + """ |
| 111 | + |
| 112 | + self._logger = client.logger |
| 113 | + |
| 114 | + self._logger.info(f"Fetching task {task_id}...") |
| 115 | + self._task = client.tasks.retrieve(task_id) |
| 116 | + |
| 117 | + if not self._task.size or not self._task.data_chunk_size: |
| 118 | + raise UnsupportedDatasetError("The task has no data") |
| 119 | + |
| 120 | + if self._task.data_original_chunk_type != "imageset": |
| 121 | + raise UnsupportedDatasetError( |
| 122 | + f"{self.__class__.__name__} only supports tasks with image chunks;" |
| 123 | + f" current chunk type is {self._task.data_original_chunk_type!r}" |
| 124 | + ) |
| 125 | + |
| 126 | + # Base64-encode the host name to avoid FS-unsafe characters (like slashes) |
| 127 | + server_dir_name = ( |
| 128 | + base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode() |
| 129 | + ) |
| 130 | + server_dir = _CACHE_DIR / f"servers/{server_dir_name}" |
| 131 | + |
| 132 | + self._task_dir = server_dir / f"tasks/{self._task.id}" |
| 133 | + self._initialize_task_dir() |
| 134 | + |
| 135 | + super().__init__( |
| 136 | + os.fspath(self._task_dir), |
| 137 | + transforms=transforms, |
| 138 | + transform=transform, |
| 139 | + target_transform=target_transform, |
| 140 | + ) |
| 141 | + |
| 142 | + data_meta = self._ensure_model( |
| 143 | + "data_meta.json", DataMetaRead, self._task.get_meta, "data metadata" |
| 144 | + ) |
| 145 | + self._active_frame_indexes = sorted( |
| 146 | + set(range(self._task.size)) - set(data_meta.deleted_frames) |
| 147 | + ) |
| 148 | + |
| 149 | + self._logger.info("Downloading chunks...") |
| 150 | + |
| 151 | + self._chunk_dir = self._task_dir / "chunks" |
| 152 | + self._chunk_dir.mkdir(exist_ok=True, parents=True) |
| 153 | + |
| 154 | + needed_chunks = { |
| 155 | + index // self._task.data_chunk_size for index in self._active_frame_indexes |
| 156 | + } |
| 157 | + |
| 158 | + with ThreadPoolExecutor(_NUM_DOWNLOAD_THREADS) as pool: |
| 159 | + for _ in pool.map(self._ensure_chunk, sorted(needed_chunks)): |
| 160 | + # just need to loop through all results so that any exceptions are propagated |
| 161 | + pass |
| 162 | + |
| 163 | + self._logger.info("All chunks downloaded") |
| 164 | + |
| 165 | + self._label_id_to_index = types.MappingProxyType( |
| 166 | + { |
| 167 | + label["id"]: label_index |
| 168 | + for label_index, label in enumerate(sorted(self._task.labels, key=lambda l: l.id)) |
| 169 | + } |
| 170 | + ) |
| 171 | + |
| 172 | + annotations = self._ensure_model( |
| 173 | + "annotations.json", LabeledData, self._task.get_annotations, "annotations" |
| 174 | + ) |
| 175 | + |
| 176 | + self._frame_annotations: Dict[int, FrameAnnotations] = collections.defaultdict( |
| 177 | + FrameAnnotations |
| 178 | + ) |
| 179 | + |
| 180 | + for tag in annotations.tags: |
| 181 | + self._frame_annotations[tag.frame].tags.append(tag) |
| 182 | + |
| 183 | + for shape in annotations.shapes: |
| 184 | + self._frame_annotations[shape.frame].shapes.append(shape) |
| 185 | + |
| 186 | + # TODO: tracks? |
| 187 | + |
| 188 | + def _initialize_task_dir(self) -> None: |
| 189 | + task_json_path = self._task_dir / "task.json" |
| 190 | + |
| 191 | + try: |
| 192 | + with open(task_json_path, "rb") as task_json_file: |
| 193 | + saved_task = TaskRead._new_from_openapi_data(**json.load(task_json_file)) |
| 194 | + except Exception: |
| 195 | + self._logger.info("Task is not yet cached or the cache is corrupted") |
| 196 | + |
| 197 | + # If the cache was corrupted, the directory might already be there; clear it. |
| 198 | + if self._task_dir.exists(): |
| 199 | + shutil.rmtree(self._task_dir) |
| 200 | + else: |
| 201 | + if saved_task.updated_date < self._task.updated_date: |
| 202 | + self._logger.info( |
| 203 | + "Task has been updated on the server since it was cached; purging the cache" |
| 204 | + ) |
| 205 | + shutil.rmtree(self._task_dir) |
| 206 | + |
| 207 | + self._task_dir.mkdir(exist_ok=True, parents=True) |
| 208 | + |
| 209 | + with atomic_writer(task_json_path, "w", encoding="UTF-8") as task_json_file: |
| 210 | + json.dump(to_json(self._task._model), task_json_file, indent=4) |
| 211 | + print(file=task_json_file) # add final newline |
| 212 | + |
| 213 | + def _ensure_chunk(self, chunk_index: int) -> None: |
| 214 | + chunk_path = self._chunk_dir / f"{chunk_index}.zip" |
| 215 | + if chunk_path.exists(): |
| 216 | + return # already downloaded previously |
| 217 | + |
| 218 | + self._logger.info(f"Downloading chunk #{chunk_index}...") |
| 219 | + |
| 220 | + with atomic_writer(chunk_path, "wb") as chunk_file: |
| 221 | + self._task.download_chunk(chunk_index, chunk_file, quality="original") |
| 222 | + |
| 223 | + def _ensure_model( |
| 224 | + self, |
| 225 | + filename: str, |
| 226 | + model_type: Type[_ModelType], |
| 227 | + download: Callable[[], _ModelType], |
| 228 | + model_description: str, |
| 229 | + ) -> _ModelType: |
| 230 | + path = self._task_dir / filename |
| 231 | + |
| 232 | + try: |
| 233 | + with open(path, "rb") as f: |
| 234 | + model = model_type._new_from_openapi_data(**json.load(f)) |
| 235 | + self._logger.info(f"Loaded {model_description} from cache") |
| 236 | + return model |
| 237 | + except FileNotFoundError: |
| 238 | + pass |
| 239 | + except Exception: |
| 240 | + self._logger.warning(f"Failed to load {model_description} from cache", exc_info=True) |
| 241 | + |
| 242 | + self._logger.info(f"Downloading {model_description}...") |
| 243 | + model = download() |
| 244 | + self._logger.info(f"Downloaded {model_description}") |
| 245 | + |
| 246 | + with atomic_writer(path, "w", encoding="UTF-8") as f: |
| 247 | + json.dump(to_json(model), f, indent=4) |
| 248 | + print(file=f) # add final newline |
| 249 | + |
| 250 | + return model |
| 251 | + |
| 252 | + def __getitem__(self, sample_index: int): |
| 253 | + """ |
| 254 | + Returns the sample with index `sample_index`. |
| 255 | + |
| 256 | + `sample_index` must satisfy the condition `0 <= sample_index < len(self)`. |
| 257 | + """ |
| 258 | + |
| 259 | + frame_index = self._active_frame_indexes[sample_index] |
| 260 | + chunk_index = frame_index // self._task.data_chunk_size |
| 261 | + member_index = frame_index % self._task.data_chunk_size |
| 262 | + |
| 263 | + with zipfile.ZipFile(self._chunk_dir / f"{chunk_index}.zip", "r") as chunk_zip: |
| 264 | + with chunk_zip.open(chunk_zip.infolist()[member_index]) as chunk_member: |
| 265 | + sample_image = PIL.Image.open(chunk_member) |
| 266 | + sample_image.load() |
| 267 | + |
| 268 | + sample_target = Target( |
| 269 | + annotations=self._frame_annotations[frame_index], |
| 270 | + label_id_to_index=self._label_id_to_index, |
| 271 | + ) |
| 272 | + |
| 273 | + if self.transforms: |
| 274 | + sample_image, sample_target = self.transforms(sample_image, sample_target) |
| 275 | + return sample_image, sample_target |
| 276 | + |
| 277 | + def __len__(self) -> int: |
| 278 | + """Returns the number of samples in the dataset.""" |
| 279 | + return len(self._active_frame_indexes) |
| 280 | + |
| 281 | + |
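A minimal usage sketch for `TaskVisionDataset` follows. The `cvat_sdk.pytorch` import path, host, credentials, and task ID are illustrative assumptions; `make_client` is the SDK's client factory.

```python
# Hedged sketch of using TaskVisionDataset. The module path
# `cvat_sdk.pytorch` is assumed; host, credentials, and task_id are
# placeholders.
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from cvat_sdk import make_client
from cvat_sdk.pytorch import TaskVisionDataset  # assumed import path

with make_client("https://app.cvat.ai", credentials=("user", "password")) as client:
    dataset = TaskVisionDataset(
        client,
        task_id=12345,  # placeholder
        transform=transforms.ToTensor(),
    )

    # Each sample is an (image, Target) pair; Target objects are not
    # tensor-stackable, so zip the batch instead of default collation.
    loader = DataLoader(dataset, batch_size=4, collate_fn=lambda batch: tuple(zip(*batch)))

    image, target = dataset[0]
    print(len(dataset), len(target.annotations.tags), len(target.annotations.shapes))
```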
| 282 | +@attrs.frozen |
| 283 | +class ExtractSingleLabelIndex: |
| 284 | + """ |
| 285 | + A target transform that takes a `Target` object and produces a single label index |
| 286 | + based on the tag in that object. |
| 287 | + |
| 288 | + This makes the dataset samples compatible with the image classification networks |
| 289 | + in torchvision. |
| 290 | + |
| 291 | + If the annotations contain no tags or multiple tags, a `ValueError` is raised. |
| 292 | + """ |
| 293 | + |
| 294 | + def __call__(self, target: Target) -> int: |
| 295 | + tags = target.annotations.tags |
| 296 | + if not tags: |
| 297 | + raise ValueError("sample has no tags") |
| 298 | + |
| 299 | + if len(tags) > 1: |
| 300 | + raise ValueError("sample has multiple tags") |
| 301 | + |
| 302 | + return target.label_id_to_index[tags[0].label_id] |
| 303 | + |
| 304 | + |
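For classification-style training, `ExtractSingleLabelIndex` can serve as the `target_transform`; a hedged sketch, reusing the placeholder client and task ID from the sketch above:

```python
# Hedged sketch: reduce each Target to a single integer label index.
# Import path, client, and task_id are assumptions carried over from
# the previous sketch.
import torchvision.transforms as transforms

from cvat_sdk.pytorch import ExtractSingleLabelIndex, TaskVisionDataset  # assumed path

dataset = TaskVisionDataset(
    client,  # a connected cvat_sdk.core.Client, as in the sketch above
    task_id=12345,  # placeholder
    transform=transforms.ToTensor(),
    target_transform=ExtractSingleLabelIndex(),
)

# Each target is now an int in [0, num_labels), directly usable with
# torchvision classification models and losses such as CrossEntropyLoss.
image, label_index = dataset[0]
```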
| 305 | +class LabeledBoxes(TypedDict): |
| 306 | + boxes: Sequence[Tuple[float, float, float, float]] |
| 307 | + labels: Sequence[int] |
| 308 | + |
| 309 | + |
| 310 | +_SUPPORTED_SHAPE_TYPES = frozenset(["rectangle", "polygon", "polyline", "points", "ellipse"]) |
| 311 | + |
| 312 | + |
| 313 | +@attrs.frozen |
| 314 | +class ExtractBoundingBoxes: |
| 315 | + """ |
| 316 | + A target transform that takes a `Target` object and returns a dictionary compatible |
| 317 | + with the object detection networks in torchvision. |
| 318 | + |
| 319 | + The dictionary contains the following entries: |
| 320 | + |
| 321 | + "boxes": a sequence of (xmin, ymin, xmax, ymax) tuples, one for each shape |
| 322 | + in the annotations. |
| 323 | + "labels": a sequence of corresponding label indices. |
| 324 | + |
| 325 | + Limitations: |
| 326 | +
|
| 327 | + * Only the following shape types are supported: rectangle, polygon, polyline, |
| 328 | + points, ellipse. |
| 329 | + * Rotated shapes are not supported. |
| 330 | + |
| 331 | + Unsupported shapes will cause an `UnsupportedDatasetError` exception to be |
| 332 | + raised unless they are filtered out by `include_shape_types`. |
| 333 | + """ |
| 334 | + |
| 335 | + include_shape_types: FrozenSet[str] = attrs.field( |
| 336 | + converter=frozenset, |
| 337 | + validator=attrs.validators.deep_iterable(attrs.validators.in_(_SUPPORTED_SHAPE_TYPES)), |
| 338 | + kw_only=True, |
| 339 | + ) |
| 340 | + """Shapes whose type is not in this set will be ignored.""" |
| 341 | + |
| 342 | + def __call__(self, target: Target) -> LabeledBoxes: |
| 343 | + boxes = [] |
| 344 | + labels = [] |
| 345 | + |
| 346 | + for shape in target.annotations.shapes: |
| 347 | + if shape.type.value not in self.include_shape_types: |
| 348 | + continue |
| 349 | + |
| 350 | + if shape.rotation != 0: |
| 351 | + raise UnsupportedDatasetError("Rotated shapes are not supported") |
| 352 | + |
| 353 | + x_coords = shape.points[0::2] |
| 354 | + y_coords = shape.points[1::2] |
| 355 | + |
| 356 | + boxes.append((min(x_coords), min(y_coords), max(x_coords), max(y_coords))) |
| 357 | + labels.append(target.label_id_to_index[shape.label_id]) |
| 358 | + |
| 359 | + return LabeledBoxes(boxes=boxes, labels=labels) |
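Similarly, `ExtractBoundingBoxes` produces detection-style targets; a hedged sketch under the same placeholder assumptions:

```python
# Hedged sketch: dict targets with "boxes" and "labels" keys, matching
# the structure torchvision detection models expect (the values here
# are plain sequences rather than tensors).
from cvat_sdk.pytorch import ExtractBoundingBoxes, TaskVisionDataset  # assumed path

dataset = TaskVisionDataset(
    client,  # a connected cvat_sdk.core.Client, as in the sketches above
    task_id=12345,  # placeholder
    target_transform=ExtractBoundingBoxes(include_shape_types={"rectangle"}),
)

image, target = dataset[0]
# target["boxes"]  -> [(xmin, ymin, xmax, ymax), ...]
# target["labels"] -> [label_index, ...]
```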