Skip to content

Commit ddede9e

Browse files
committed
Update the tests
Signed-off-by: Samet Akcay <[email protected]>
1 parent c44f377 commit ddede9e

File tree

9 files changed

+27
-33
lines changed

9 files changed

+27
-33
lines changed

src/anomalib/data/datamodules/image/datumaro.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from anomalib import TaskType
1212
from anomalib.data.datamodules.base import AnomalibDataModule
1313
from anomalib.data.datasets.image.datumaro import DatumaroDataset
14-
from anomalib.data.utils import Split, TestSplitMode, ValSplitMode
14+
from anomalib.data.utils import Split, SplitMode, TestSplitMode, ValSplitMode
1515

1616

1717
class Datumaro(AnomalibDataModule):
@@ -69,9 +69,9 @@ def __init__(
6969
eval_batch_size: int = 32,
7070
num_workers: int = 8,
7171
task: TaskType = TaskType.CLASSIFICATION,
72-
test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
73-
test_split_ratio: float = 0.5,
74-
val_split_mode: ValSplitMode | str = ValSplitMode.FROM_TEST,
72+
test_split_mode: SplitMode | TestSplitMode | str = SplitMode.AUTO,
73+
test_split_ratio: float | None = None,
74+
val_split_mode: SplitMode | ValSplitMode | str = SplitMode.AUTO,
7575
val_split_ratio: float = 0.5,
7676
seed: int | None = None,
7777
) -> None:

src/anomalib/data/datasets/base/image.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def __getitem__(self, index: int) -> DatasetItem:
288288
DatasetItem: DatasetItem instance containing image and ground truth (if available).
289289
"""
290290
image_path = self.samples.iloc[index].image_path
291+
mask_path = self.samples.iloc[index].get("mask_path") if "mask_path" in self.samples.columns else None
291292
label_index = self.samples.iloc[index].label_index
292293

293294
image = read_image(image_path, as_tensor=True)
@@ -296,16 +297,17 @@ def __getitem__(self, index: int) -> DatasetItem:
296297
if self.task == TaskType.CLASSIFICATION:
297298
item["image"] = self.transform(image) if self.transform else image
298299
elif self.task == TaskType.SEGMENTATION:
300+
if mask_path is None:
301+
msg = "mask_path is required for segmentation tasks but was not found in samples DataFrame"
302+
raise ValueError(msg)
299303
# Only Anomalous (1) images have masks in anomaly datasets
300304
# Therefore, create empty mask for Normal (0) images.
301-
mask_path = self.samples.iloc[index].mask_path
302305
mask = (
303306
Mask(torch.zeros(image.shape[-2:])).to(torch.uint8)
304307
if label_index == LabelName.NORMAL
305308
else read_mask(mask_path, as_tensor=True)
306309
)
307310
item["image"], item["gt_mask"] = self.transform(image, mask) if self.transform else (image, mask)
308-
309311
else:
310312
msg = f"Unknown task type: {self.task}"
311313
raise ValueError(msg)

src/anomalib/data/datasets/base/video.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
# SPDX-License-Identifier: Apache-2.0
55

66
from abc import ABC
7+
from collections.abc import Callable
78
from enum import Enum
8-
from typing import TYPE_CHECKING
99

1010
import torch
1111
from pandas import DataFrame
@@ -14,14 +14,11 @@
1414
from torchvision.tv_tensors import Mask
1515

1616
from anomalib import TaskType
17-
from anomalib.data.dataclasses import VideoItem
17+
from anomalib.data.dataclasses import VideoBatch, VideoItem
1818
from anomalib.data.utils.video import ClipsIndexer
1919

2020
from .image import AnomalibDataset
2121

22-
if TYPE_CHECKING:
23-
from collections.abc import Callable
24-
2522

2623
class VideoTargetFrame(str, Enum):
2724
"""Target frame for a video-clip.
@@ -172,3 +169,8 @@ def __getitem__(self, index: int) -> VideoItem:
172169
item = self._select_targets(item)
173170

174171
return item
172+
173+
@property
174+
def collate_fn(self) -> Callable:
175+
"""Return the collate function for video batches."""
176+
return VideoBatch.collate

src/anomalib/data/datasets/image/kolektor.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
# Copyright (C) 2024 Intel Corporation
1818
# SPDX-License-Identifier: Apache-2.0
1919

20-
import logging
2120
from pathlib import Path
2221

2322
import numpy as np
@@ -29,19 +28,10 @@
2928
from anomalib import TaskType
3029
from anomalib.data.datasets import AnomalibDataset
3130
from anomalib.data.errors import MisMatchError
32-
from anomalib.data.utils import DownloadInfo, Split, validate_path
31+
from anomalib.data.utils import Split, validate_path
3332

3433
__all__ = ["KolektorDataset", "make_kolektor_dataset"]
3534

36-
logger = logging.getLogger(__name__)
37-
38-
DOWNLOAD_INFO = DownloadInfo(
39-
name="kolektor",
40-
url="https://go.vicos.si/kolektorsdd",
41-
hashsum="65dc621693418585de9c4467d1340ea7958a6181816f0dc2883a1e8b61f9d4dc",
42-
filename="KolektorSDD.zip",
43-
)
44-
4535

4636
class KolektorDataset(AnomalibDataset):
4737
"""Kolektor dataset class.

tests/helpers/data.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,8 @@ class DummyImageDatasetGenerator(DummyDatasetGenerator):
265265
Args:
266266
data_format (DataFormat): Data format of the dataset.
267267
root (Path | str, optional): Path to the root directory. Defaults to None.
268-
num_train (int, optional): Number of training images to generate. Defaults to 5.
269-
num_test (int, optional): Number of testing images to generate per category. Defaults to 5.
268+
num_train (int, optional): Number of training images to generate. Defaults to 8.
269+
num_test (int, optional): Number of testing images to generate per category. Defaults to 8.
270270
img_height (int, optional): Height of the image. Defaults to 128.
271271
img_width (int, optional): Width of the image. Defaults to 128.
272272
max_size (Optional[int], optional): Maximum size of the test shapes. Defaults to 10.
@@ -301,8 +301,8 @@ def __init__(
301301
root: Path | str | None = None,
302302
normal_category: str = "good",
303303
abnormal_category: str = "bad",
304-
num_train: int = 5,
305-
num_test: int = 5,
304+
num_train: int = 8,
305+
num_test: int = 8,
306306
image_shape: tuple[int, int] = (256, 256),
307307
num_channels: int = 3,
308308
min_size: int = 64,

tests/unit/data/datamodule/image/test_csv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ def datamodule(dataset_path: Path, task_type: TaskType) -> CSV:
2626
_datamodule = CSV(
2727
name="dummy_csv",
2828
csv_path=dataset_path / "csv" / "samples.csv",
29-
train_batch_size=2,
30-
eval_batch_size=2,
29+
train_batch_size=4,
30+
eval_batch_size=4,
3131
num_workers=0,
3232
task=task_type,
3333
test_split_mode="predefined",

tests/unit/data/datamodule/image/test_folder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def datamodule(dataset_path: Path, task_type: TaskType) -> Folder:
3333
normal_dir="train/good",
3434
abnormal_dir="test/bad",
3535
mask_dir=mask_dir,
36-
train_batch_size=2,
37-
eval_batch_size=2,
36+
train_batch_size=4,
37+
eval_batch_size=4,
3838
num_workers=0,
3939
task=task_type,
4040
)

tests/unit/data/datamodule/image/test_mvtec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def datamodule(dataset_path: Path, task_type: TaskType) -> MVTec:
2323
root=dataset_path / "mvtec",
2424
category="dummy",
2525
task=task_type,
26-
train_batch_size=2,
27-
eval_batch_size=2,
26+
train_batch_size=4,
27+
eval_batch_size=4,
2828
)
2929
_datamodule.prepare_data()
3030
_datamodule.setup()

tests/unit/data/utils/test_image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_existing_image_directory(dataset_path: Path) -> None:
2525
"""Test ``get_image_filenames`` returns the correct image filenames from an existing directory."""
2626
directory_path = dataset_path / "mvtec/dummy/train/good"
2727
image_filenames = get_image_filenames(directory_path)
28-
expected_filenames = [(directory_path / f"{i:03d}.png").resolve() for i in range(5)]
28+
expected_filenames = [(directory_path / f"{i:03d}.png").resolve() for i in range(8)]
2929
assert set(image_filenames) == set(expected_filenames)
3030

3131
@staticmethod

0 commit comments

Comments
 (0)