
Commit 0823ab8

🚀 Add datumaro annotation dataloader (#2377)
* Add datumaro annotation dataloader

Signed-off-by: Ashwin Vaidya <[email protected]>

* Update changelog

Signed-off-by: Ashwin Vaidya <[email protected]>

* Add examples

Signed-off-by: Ashwin Vaidya <[email protected]>

---------

Signed-off-by: Ashwin Vaidya <[email protected]>
1 parent c99f868 commit 0823ab8

File tree

7 files changed: +327 −5 lines changed


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

 ### Added

+- Add `Datumaro` annotation format support by @ashwinvaidya17 in https://github.com/openvinotoolkit/anomalib/pull/2377
 - Add `AUPIMO` tutorials notebooks in https://github.com/openvinotoolkit/anomalib/pull/2330 and https://github.com/openvinotoolkit/anomalib/pull/2336
 - Add `AUPIMO` metric by [jpcbertoldo](https://github.com/jpcbertoldo) in https://github.com/openvinotoolkit/anomalib/pull/1726 and refactored by [ashwinvaidya17](https://github.com/ashwinvaidya17) in https://github.com/openvinotoolkit/anomalib/pull/2329

configs/data/datumaro.yaml

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
class_path: anomalib.data.Datumaro
init_args:
  root: "datasets/datumaro"
  train_batch_size: 32
  eval_batch_size: 32
  num_workers: 8
  image_size: null
  transform: null
  train_transform: null
  eval_transform: null
  test_split_mode: FROM_DIR
  test_split_ratio: 0.2
  val_split_mode: FROM_TEST
  val_split_ratio: 0.5
  seed: null
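The file follows the `class_path`/`init_args` layout that `get_datamodule` (exported from `anomalib.data`, see below) instantiates. A minimal sketch of loading it from Python, assuming the file sits at `configs/data/datumaro.yaml` and a Geti-exported dataset exists under `datasets/datumaro`:

    from omegaconf import OmegaConf

    from anomalib.data import get_datamodule

    # Parse the YAML into a DictConfig; get_datamodule resolves class_path
    # (anomalib.data.Datumaro) and passes init_args to its constructor.
    config = OmegaConf.load("configs/data/datumaro.yaml")
    datamodule = get_datamodule(config)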

src/anomalib/data/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -14,7 +14,7 @@

 from .base import AnomalibDataModule, AnomalibDataset
 from .depth import DepthDataFormat, Folder3D, MVTec3D
-from .image import BTech, Folder, ImageDataFormat, Kolektor, MVTec, Visa
+from .image import BTech, Datumaro, Folder, ImageDataFormat, Kolektor, MVTec, Visa
 from .predict import PredictDataset
 from .utils import LabelName
 from .video import Avenue, ShanghaiTech, UCSDped, VideoDataFormat
@@ -70,6 +70,7 @@ def get_datamodule(config: DictConfig | ListConfig | dict) -> AnomalibDataModule
     "VideoDataFormat",
     "get_datamodule",
     "BTech",
+    "Datumaro",
     "Folder",
     "Folder3D",
     "PredictDataset",

src/anomalib/data/image/__init__.py

Lines changed: 6 additions & 4 deletions
@@ -9,6 +9,7 @@
 from enum import Enum

 from .btech import BTech
+from .datumaro import Datumaro
 from .folder import Folder
 from .kolektor import Kolektor
 from .mvtec import MVTec
@@ -18,13 +19,14 @@
 class ImageDataFormat(str, Enum):
     """Supported Image Dataset Types."""

-    MVTEC = "mvtec"
-    MVTEC_3D = "mvtec_3d"
     BTECH = "btech"
-    KOLEKTOR = "kolektor"
+    DATUMARO = "datumaro"
     FOLDER = "folder"
     FOLDER_3D = "folder_3d"
+    KOLEKTOR = "kolektor"
+    MVTEC = "mvtec"
+    MVTEC_3D = "mvtec_3d"
     VISA = "visa"


-__all__ = ["BTech", "Folder", "Kolektor", "MVTec", "Visa"]
+__all__ = ["BTech", "Datumaro", "Folder", "Kolektor", "MVTec", "Visa"]

src/anomalib/data/image/datumaro.py

Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
"""Dataloader for Datumaro format.

Note: This currently only works for annotations exported from Intel Geti™.
"""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
from pathlib import Path

import pandas as pd
from torchvision.transforms.v2 import Transform

from anomalib import TaskType
from anomalib.data.base import AnomalibDataModule, AnomalibDataset
from anomalib.data.utils import LabelName, Split, TestSplitMode, ValSplitMode


def make_datumaro_dataset(root: str | Path, split: str | Split | None = None) -> pd.DataFrame:
    """Make Datumaro Dataset.

    Assumes the following directory structure:

    dataset
    ├── annotations
    │   └── default.json
    └── images
        └── default
            ├── image1.jpg
            ├── image2.jpg
            └── ...

    Args:
        root (str | Path): Path to the dataset root directory.
        split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST.
            Defaults to ``None``.

    Examples:
        >>> root = Path("path/to/dataset")
        >>> samples = make_datumaro_dataset(root)
        >>> samples.head()
                   image_path   label  label_index        split mask_path
        0  path/to/dataset...  Normal            0  Split.TRAIN
        1  path/to/dataset...  Normal            0  Split.TRAIN
        2  path/to/dataset...  Normal            0  Split.TRAIN
        3  path/to/dataset...  Normal            0  Split.TRAIN
        4  path/to/dataset...  Normal            0  Split.TRAIN

    Returns:
        DataFrame: an output dataframe containing samples for the requested split (i.e., train or test).
    """
    annotation_file = Path(root) / "annotations" / "default.json"
    with annotation_file.open() as f:
        annotations = json.load(f)

    categories = annotations["categories"]
    categories = {idx: label["name"] for idx, label in enumerate(categories["label"]["labels"])}

    samples = []
    for item in annotations["items"]:
        image_path = Path(root) / "images" / "default" / item["image"]["path"]
        label_index = item["annotations"][0]["label_id"]
        label = categories[label_index]
        samples.append({
            "image_path": str(image_path),
            "label": label,
            "label_index": label_index,
            "split": None,
            "mask_path": "",  # mask is provided in the annotation file and is not on disk.
        })
    samples_df = pd.DataFrame(
        samples,
        columns=["image_path", "label", "label_index", "split", "mask_path"],
        index=range(len(samples)),
    )
    # Create the test/train split.
    # By default, assign all "Normal" samples to train and all "Anomalous" samples to test.
    samples_df.loc[samples_df["label_index"] == LabelName.NORMAL, "split"] = Split.TRAIN
    samples_df.loc[samples_df["label_index"] == LabelName.ABNORMAL, "split"] = Split.TEST

    # Get the data frame for the split.
    if split:
        samples_df = samples_df[samples_df.split == split].reset_index(drop=True)

    return samples_df
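For reference, a minimal `annotations/default.json` accepted by this parser is sketched below. The field layout follows the parsing code above and the dummy-dataset generator in `tests/helpers/data.py`; the `dm_format_version` and `infos` values are placeholders, and real Intel Geti™ exports carry additional keys per item.

    {
      "dm_format_version": "1.0",
      "infos": {},
      "categories": {"label": {"labels": [{"name": "Normal"}, {"name": "Anomalous"}]}},
      "items": [
        {"image": {"path": "image1.jpg"}, "annotations": [{"label_id": 0}]},
        {"image": {"path": "image2.jpg"}, "annotations": [{"label_id": 1}]}
      ]
    }

`label_id` indexes into `categories.label.labels`, so `0` maps to "Normal" (`LabelName.NORMAL`) and `1` to "Anomalous" (`LabelName.ABNORMAL`) here; note that only the first annotation of each item is read.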
class DatumaroDataset(AnomalibDataset):
    """Datumaro dataset class.

    Args:
        task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``.
        root (str | Path): Path to the dataset root directory.
        transform (Transform, optional): Transforms that should be applied to the input images.
            Defaults to ``None``.
        split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST.
            Defaults to ``None``.

    Examples:
        .. code-block:: python

            from anomalib.data.image.datumaro import DatumaroDataset
            from torchvision.transforms.v2 import Resize

            dataset = DatumaroDataset(
                root=root,
                task="classification",
                transform=Resize((256, 256)),
            )
            print(dataset[0].keys())
            # Output: dict_keys(['image_path', 'label', 'image'])

    """

    def __init__(
        self,
        task: TaskType,
        root: str | Path,
        transform: Transform | None = None,
        split: str | Split | None = None,
    ) -> None:
        super().__init__(task, transform)
        self.split = split
        self.samples = make_datumaro_dataset(root, split)
class Datumaro(AnomalibDataModule):
    """Datumaro datamodule.

    Args:
        root (str | Path): Path to the dataset root directory.
        train_batch_size (int): Batch size for training dataloader.
            Defaults to ``32``.
        eval_batch_size (int): Batch size for evaluation dataloader.
            Defaults to ``32``.
        num_workers (int): Number of workers for dataloaders.
            Defaults to ``8``.
        task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``.
            Defaults to ``TaskType.CLASSIFICATION``. Currently only supports classification.
        image_size (tuple[int, int], optional): Size to which input images should be resized.
            Defaults to ``None``.
        transform (Transform, optional): Transforms that should be applied to the input images.
            Defaults to ``None``.
        train_transform (Transform, optional): Transforms that should be applied to the input images during training.
            Defaults to ``None``.
        eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
            Defaults to ``None``.
        test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
            Defaults to ``TestSplitMode.FROM_DIR``.
        test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
            Defaults to ``0.5``.
        val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
            Defaults to ``ValSplitMode.FROM_TEST``.
        val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
            Defaults to ``0.5``.
        seed (int | None, optional): Seed which may be set to a fixed value for reproducibility.
            Defaults to ``None``.

    Examples:
        To create a Datumaro datamodule:

        >>> from pathlib import Path
        >>> from torchvision.transforms.v2 import Resize
        >>> root = Path("path/to/dataset")
        >>> datamodule = Datumaro(root, transform=Resize((256, 256)))
        >>> datamodule.setup()
        >>> i, data = next(enumerate(datamodule.train_dataloader()))
        >>> data.keys()
        dict_keys(['image_path', 'label', 'image'])

        >>> data["image"].shape
        torch.Size([32, 3, 256, 256])
    """

    def __init__(
        self,
        root: str | Path,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        num_workers: int = 8,
        task: TaskType = TaskType.CLASSIFICATION,
        image_size: tuple[int, int] | None = None,
        transform: Transform | None = None,
        train_transform: Transform | None = None,
        eval_transform: Transform | None = None,
        test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
        test_split_ratio: float = 0.5,
        val_split_mode: ValSplitMode | str = ValSplitMode.FROM_TEST,
        val_split_ratio: float = 0.5,
        seed: int | None = None,
    ) -> None:
        if task != TaskType.CLASSIFICATION:
            msg = "Datumaro dataloader currently only supports classification task."
            raise ValueError(msg)
        super().__init__(
            train_batch_size=train_batch_size,
            eval_batch_size=eval_batch_size,
            num_workers=num_workers,
            val_split_mode=val_split_mode,
            val_split_ratio=val_split_ratio,
            test_split_mode=test_split_mode,
            test_split_ratio=test_split_ratio,
            image_size=image_size,
            transform=transform,
            train_transform=train_transform,
            eval_transform=eval_transform,
            seed=seed,
        )
        self.root = root
        self.task = task

    def _setup(self, _stage: str | None = None) -> None:
        self.train_data = DatumaroDataset(
            task=self.task,
            root=self.root,
            transform=self.train_transform,
            split=Split.TRAIN,
        )
        self.test_data = DatumaroDataset(
            task=self.task,
            root=self.root,
            transform=self.eval_transform,
            split=Split.TEST,
        )
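Beyond the standalone example in the docstring, the datamodule plugs into anomalib's usual training loop. A hypothetical end-to-end sketch, assuming anomalib's v1 `Engine` and `Padim` APIs and a Geti-exported dataset under `datasets/datumaro`:

    from anomalib import TaskType
    from anomalib.data import Datumaro
    from anomalib.engine import Engine
    from anomalib.models import Padim

    # Datumaro defaults to classification, the only task it currently supports.
    datamodule = Datumaro(root="datasets/datumaro")
    engine = Engine(task=TaskType.CLASSIFICATION)
    engine.fit(model=Padim(), datamodule=datamodule)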

tests/helpers/data.py

Lines changed: 38 additions & 0 deletions
@@ -5,6 +5,7 @@

 from __future__ import annotations

+import json
 import shutil
 from contextlib import ContextDecorator
 from pathlib import Path
@@ -319,6 +320,43 @@ def __init__(
         self.min_size = min_size
         self.image_generator = DummyImageGenerator(image_shape=image_shape, rng=self.rng)

+    def _generate_dummy_datumaro_dataset(self) -> None:
+        """Generates a dummy Datumaro dataset in a temporary directory."""
+        # Generate images.
+        image_root = self.dataset_root / "images" / "default"
+        image_root.mkdir(parents=True, exist_ok=True)
+
+        file_names: list[Path] = []
+
+        # Create normal images.
+        for i in range(self.num_train + self.num_test):
+            label = LabelName.NORMAL
+            image_filename = image_root / f"normal_{i:03}.png"
+            file_names.append(image_filename)
+            self.image_generator.generate_image(label, image_filename)
+
+        # Create abnormal images.
+        for i in range(self.num_test):
+            label = LabelName.ABNORMAL
+            image_filename = image_root / f"abnormal_{i:03}.png"
+            file_names.append(image_filename)
+            self.image_generator.generate_image(label, image_filename)
+
+        # Create the annotation file.
+        annotation_file = self.dataset_root / "annotations" / "default.json"
+        annotation_file.parent.mkdir(parents=True, exist_ok=True)
+        annotations = {
+            "categories": {"label": {"labels": [{"name": "Normal"}, {"name": "Anomalous"}]}},
+            "items": [],
+        }
+        for file_name in file_names:
+            annotations["items"].append({
+                "annotations": [{"label_id": 1 if "abnormal" in str(file_name) else 0}],
+                "image": {"path": file_name.name},
+            })
+        with annotation_file.open("w") as f:
+            json.dump(annotations, f)
+
     def _generate_dummy_mvtec_dataset(
         self,
         normal_dir: str = "good",
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
"""Unit tests - Datumaro Datamodule."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path

import pytest

from anomalib import TaskType
from anomalib.data import Datumaro
from tests.unit.data.base.image import _TestAnomalibImageDatamodule


class TestDatumaro(_TestAnomalibImageDatamodule):
    """Datumaro Datamodule Unit Tests."""

    @pytest.fixture()
    @staticmethod
    def datamodule(dataset_path: Path, task_type: TaskType) -> Datumaro:
        """Create and return a Datumaro datamodule."""
        if task_type != TaskType.CLASSIFICATION:
            pytest.skip("Datumaro only supports classification tasks.")

        _datamodule = Datumaro(
            root=dataset_path / "datumaro",
            task=task_type,
            train_batch_size=4,
            eval_batch_size=4,
        )
        _datamodule.setup()

        return _datamodule

    @pytest.fixture()
    @staticmethod
    def fxt_data_config_path() -> str:
        """Return the path to the test data config."""
        return "configs/data/datumaro.yaml"
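Once the dummy-dataset generator above has produced the `datumaro` fixture directory, a generic pytest selection such as `pytest -k TestDatumaro` (not a command documented in this commit) exercises the suite.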
