Commit ec49a5a

Merge branch 'release/v2.0.0' of github.com:openvinotoolkit/anomalib into add-missing-aux-components

2 parents: e2472db + 8bd06a9

15 files changed: +56 additions, -43 deletions

src/anomalib/data/validators/torch/video.py
Lines changed: 11 additions & 10 deletions

@@ -588,10 +588,10 @@ def validate_gt_mask(mask: torch.Tensor | None) -> Mask | None:
         Examples:
             >>> import torch
             >>> from anomalib.data.validators import VideoBatchValidator
-            >>> gt_masks = torch.rand(2, 10, 224, 224) > 0.5  # 2 videos, 10 frames each
+            >>> gt_masks = torch.rand(10, 224, 224) > 0.5  # 10 frames each
             >>> validated_masks = VideoBatchValidator.validate_gt_mask(gt_masks)
             >>> print(validated_masks.shape)
-            torch.Size([2, 10, 224, 224])
+            torch.Size([10, 224, 224])
             >>> single_frame_masks = torch.rand(4, 456, 256) > 0.5  # 4 single-frame images
             >>> validated_single_frame = VideoBatchValidator.validate_gt_mask(single_frame_masks)
             >>> print(validated_single_frame.shape)
@@ -600,17 +600,18 @@ def validate_gt_mask(mask: torch.Tensor | None) -> Mask | None:
         if mask is None:
             return None
         if not isinstance(mask, torch.Tensor):
-            msg = f"Masks must be a torch.Tensor, got {type(mask)}."
+            msg = f"Ground truth mask must be a torch.Tensor, got {type(mask)}."
             raise TypeError(msg)
-        if mask.ndim not in {3, 4, 5}:
-            msg = f"Masks must have shape [B, H, W], [B, T, H, W] or [B, T, 1, H, W], got shape {mask.shape}."
+        if mask.ndim not in {2, 3, 4}:
+            msg = f"Ground truth mask must have shape [H, W] or [N, H, W] or [N, 1, H, W] got shape {mask.shape}."
             raise ValueError(msg)
-        if mask.ndim == 5:
-            if mask.shape[2] != 1:
-                msg = f"Masks must have 1 channel, got {mask.shape[2]}."
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        if mask.ndim == 4:
+            if mask.shape[1] != 1:
+                msg = f"Ground truth mask must have 1 channel, got {mask.shape[1]}."
                 raise ValueError(msg)
-            mask = mask.squeeze(2)
-
+            mask = mask.squeeze(1)
         return Mask(mask, dtype=torch.bool)

     @staticmethod
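
In plain terms, the validator now works on a batch of frames rather than a batch of videos: [H, W], [N, H, W], and [N, 1, H, W] are accepted, and everything is normalized to [N, H, W]. A minimal sketch of that reshaping contract, using a hypothetical normalize_mask_shape helper that mirrors validate_gt_mask without anomalib's Mask wrapper:

    import torch

    def normalize_mask_shape(mask: torch.Tensor) -> torch.Tensor:
        # [H, W] -> [1, H, W]; [N, 1, H, W] -> [N, H, W]; [N, H, W] passes through
        if mask.ndim not in {2, 3, 4}:
            raise ValueError(f"expected [H, W], [N, H, W] or [N, 1, H, W], got {tuple(mask.shape)}")
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        if mask.ndim == 4:
            if mask.shape[1] != 1:
                raise ValueError(f"expected 1 channel, got {mask.shape[1]}")
            mask = mask.squeeze(1)
        return mask.bool()

    print(normalize_mask_shape(torch.rand(10, 1, 224, 224) > 0.5).shape)  # torch.Size([10, 224, 224])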

src/anomalib/engine/engine.py
Lines changed: 0 additions & 3 deletions

@@ -259,9 +259,6 @@ def _setup_trainer(self, model: AnomalibModule) -> None:
         # Setup anomalib callbacks to be used with the trainer
         self._setup_anomalib_callbacks()

-        # Temporarily set devices to 1 to avoid issues with multiple processes
-        self._cache.args["devices"] = 1
-
         # Instantiate the trainer if it is not already instantiated
         if self._trainer is None:
             self._trainer = Trainer(**self._cache.args)
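
With the override gone, whatever device count the user passes now reaches the Lightning Trainer unchanged. A sketch of what this unblocks, assuming Engine keeps forwarding its trainer keyword arguments as elsewhere in the codebase:

    from anomalib.engine import Engine

    # previously coerced to devices=1 before the Trainer was built
    engine = Engine(accelerator="gpu", devices=2)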

src/anomalib/metrics/evaluator.py
Lines changed: 12 additions & 2 deletions

@@ -3,6 +3,7 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

+import logging
 from collections.abc import Sequence
 from typing import Any

@@ -14,6 +15,8 @@

 from anomalib.metrics import AnomalibMetric

+logger = logging.getLogger(__name__)
+

 class Evaluator(nn.Module, Callback):
     """Evaluator module for LightningModule.
@@ -53,8 +56,15 @@ def __init__(
         super().__init__()
         self.val_metrics = ModuleList(self.validate_metrics(val_metrics))
         self.test_metrics = ModuleList(self.validate_metrics(test_metrics))
-
-        if compute_on_cpu:
+        self.compute_on_cpu = compute_on_cpu
+
+    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
+        """Move metrics to cpu if ``num_devices == 1`` and ``compute_on_cpu`` is set to ``True``."""
+        del pl_module, stage  # Unused arguments.
+        if trainer.num_devices > 1:
+            if self.compute_on_cpu:
+                logger.warning("Number of devices is greater than 1, setting compute_on_cpu to False.")
+        elif self.compute_on_cpu:
             self.metrics_to_cpu(self.val_metrics)
             self.metrics_to_cpu(self.test_metrics)
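
Moving the decision into setup() means the Evaluator sees the actual device count instead of guessing at construction time. A usage sketch (the AUROC field names are assumed from the anomalib metrics API):

    from anomalib.metrics import AUROC, Evaluator

    evaluator = Evaluator(
        test_metrics=[AUROC(fields=["pred_score", "gt_label"])],
        compute_on_cpu=True,
    )
    # Single device: setup() moves metric state to the CPU, as before.
    # trainer.num_devices > 1: setup() only warns, because CPU-resident metric
    # state cannot take part in cross-process synchronization.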

src/anomalib/models/components/base/memory_bank_module.py
Lines changed: 3 additions & 2 deletions

@@ -19,6 +19,7 @@ class MemoryBankMixin(nn.Module):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.register_buffer("_is_fitted", torch.tensor([False]))
+        self.device: torch.device  # defined in lightning module
         self._is_fitted: torch.Tensor

     @abstractmethod
@@ -34,10 +35,10 @@ def on_validation_start(self) -> None:
         """Ensure that the model is fitted before validation starts."""
         if not self._is_fitted:
             self.fit()
-            self._is_fitted = torch.tensor([True])
+            self._is_fitted = torch.tensor([True], device=self.device)

     def on_train_epoch_end(self) -> None:
         """Ensure that the model is fitted before validation starts."""
         if not self._is_fitted:
             self.fit()
-            self._is_fitted = torch.tensor([True])
+            self._is_fitted = torch.tensor([True], device=self.device)
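
The explicit device= matters because plain attribute assignment does not follow the module around the way a registered buffer does. A toy reproduction of the mismatch this avoids (hypothetical Bank module; assumes a CUDA device is available):

    import torch

    class Bank(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.register_buffer("_is_fitted", torch.tensor([False]))

    bank = Bank().to("cuda")                # the registered buffer moves with the module
    bank._is_fitted = torch.tensor([True])  # a bare reassignment lands back on the CPU
    print(bank._is_fitted.device)           # cpu: later device-sensitive ops can fail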

src/anomalib/models/components/classification/kde_classifier.py
Lines changed: 4 additions & 1 deletion

@@ -93,7 +93,10 @@ def fit(self, embeddings: torch.Tensor) -> bool:

         # if max training points is non-zero and smaller than number of staged features, select random subset
         if embeddings.shape[0] > self.max_training_points:
-            selected_idx = torch.tensor(random.sample(range(embeddings.shape[0]), self.max_training_points))
+            selected_idx = torch.tensor(
+                random.sample(range(embeddings.shape[0]), self.max_training_points),
+                device=embeddings.device,
+            )
             selected_features = embeddings[selected_idx]
         else:
             selected_features = embeddings
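
Creating the index tensor on embeddings.device keeps the fancy indexing that follows on a single device and avoids an implicit cross-device copy. The pattern in isolation, with stand-in shapes:

    import random

    import torch

    embeddings = torch.randn(1000, 64)  # stand-in for the staged features
    selected_idx = torch.tensor(random.sample(range(1000), 100), device=embeddings.device)
    print(embeddings[selected_idx].shape)  # torch.Size([100, 64])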

src/anomalib/models/components/dimensionality_reduction/pca.py
Lines changed: 2 additions & 2 deletions

@@ -74,7 +74,7 @@ def fit(self, dataset: torch.Tensor) -> None:
         else:
             num_components = int(self.n_components)

-        self.num_components = torch.Tensor([num_components])
+        self.num_components = torch.tensor([num_components], device=dataset.device)

         self.singular_vectors = v_h.transpose(-2, -1)[:, :num_components].float()
         self.singular_values = sig[:num_components].float()
@@ -98,7 +98,7 @@ def fit_transform(self, dataset: torch.Tensor) -> torch.Tensor:
         mean = dataset.mean(dim=0)
         dataset -= mean
         num_components = int(self.n_components)
-        self.num_components = torch.Tensor([num_components])
+        self.num_components = torch.tensor([num_components], device=dataset.device)

         v_h = torch.linalg.svd(dataset)[-1]
         self.singular_vectors = v_h.transpose(-2, -1)[:, :num_components]
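
Beyond the device argument, torch.Tensor and torch.tensor are different factories: the legacy constructor always produces default-dtype (float32) CPU storage, while torch.tensor infers dtype and accepts device=. A quick illustration:

    import torch

    print(torch.Tensor([5]).dtype)  # torch.float32: legacy constructor, no dtype inference or device kwarg
    print(torch.tensor([5]).dtype)  # torch.int64: infers dtype and accepts device=...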

src/anomalib/models/image/dfkde/lightning_model.py
Lines changed: 3 additions & 0 deletions

@@ -94,6 +94,9 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
         embedding = self.model(batch.image)
         self.embeddings.append(embedding)

+        # Return a dummy loss tensor
+        return torch.tensor(0.0, requires_grad=True, device=self.device)
+
     def fit(self) -> None:
         """Fit a KDE Model to the embedding collected from the training set."""
         embeddings = torch.vstack(self.embeddings)
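
DFKDE does no gradient descent, yet Lightning's automatic optimization expects training_step to return a loss; a zero tensor with requires_grad=True satisfies the loop while leaving the weights untouched. The same pattern recurs in the DFM, PaDiM, PatchCore, and AI-VAD diffs below. A sketch of the idea outside anomalib (hypothetical FeatureCollector standing in for a memory-bank model):

    import torch
    from torch import nn

    class FeatureCollector(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.embeddings: list[torch.Tensor] = []

        def training_step(self, image: torch.Tensor) -> torch.Tensor:
            self.embeddings.append(image.flatten(1))  # stage features for a later fit()
            # backward() on this is a no-op for the model, but the trainer gets a valid loss
            return torch.tensor(0.0, requires_grad=True, device=image.device)

    step_out = FeatureCollector().training_step(torch.randn(4, 3, 8, 8))
    step_out.backward()  # runs without touching any parameters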

src/anomalib/models/image/dfm/lightning_model.py
Lines changed: 3 additions & 0 deletions

@@ -100,6 +100,9 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
         embedding = self.model.get_features(batch.image).squeeze()
         self.embeddings.append(embedding)

+        # Return a dummy loss tensor
+        return torch.tensor(0.0, requires_grad=True, device=self.device)
+
     def fit(self) -> None:
         """Fit a PCA transformation and a Gaussian model to dataset."""
         logger.info("Aggregating the embedding extracted from the training set.")

src/anomalib/models/image/dfm/torch_model.py
Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@ def fit(self, dataset: torch.Tensor) -> None:
             dataset (torch.Tensor): Input dataset to fit the model.
         """
         num_samples = dataset.shape[1]
-        self.mean_vec = torch.mean(dataset, dim=1)
+        self.mean_vec = torch.mean(dataset, dim=1, device=dataset.device)
         data_centered = (dataset - self.mean_vec.reshape(-1, 1)) / math.sqrt(num_samples)
         self.u_mat, self.sigma_mat, _ = torch.linalg.svd(data_centered, full_matrices=False)

src/anomalib/models/image/dsr/anomaly_generator.py
Lines changed: 1 addition & 1 deletion

@@ -73,7 +73,7 @@ def augment_batch(self, batch: Tensor) -> Tensor:
         masks_list: list[Tensor] = []
         for _ in range(batch_size):
             if torch.rand(1) > self.p_anomalous:  # include normal samples
-                masks_list.append(torch.zeros((1, height, width)))
+                masks_list.append(torch.zeros((1, height, width), device=batch.device))
             else:
                 mask = self.generate_anomaly(height, width)
                 masks_list.append(mask)
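
The normal-sample branch was the only place still allocating on the CPU, and torch.stack over masks_list requires every entry to share a device. The corrected allocation pattern in miniature (CUDA assumed available; shapes are stand-ins):

    import torch

    batch = torch.zeros(2, 3, 16, 16, device="cuda")  # stand-in batch on the GPU
    masks = [torch.zeros((1, 16, 16), device=batch.device) for _ in range(batch.shape[0])]
    print(torch.stack(masks).device)  # cuda:0 - stack succeeds because every entry shares the batch device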

src/anomalib/models/image/padim/lightning_model.py
Lines changed: 4 additions & 1 deletion

@@ -91,7 +91,10 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
         del args, kwargs  # These variables are not used.

         embedding = self.model(batch.image)
-        self.embeddings.append(embedding.cpu())
+        self.embeddings.append(embedding)
+
+        # Return a dummy loss tensor
+        return torch.tensor(0.0, requires_grad=True, device=self.device)

     def fit(self) -> None:
         """Fit a Gaussian to the embedding collected from the training set."""

src/anomalib/models/image/patchcore/lightning_model.py
Lines changed: 2 additions & 0 deletions

@@ -118,6 +118,8 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:

         embedding = self.model(batch.image)
         self.embeddings.append(embedding)
+        # Return a dummy loss tensor
+        return torch.tensor(0.0, requires_grad=True, device=self.device)

     def fit(self) -> None:
         """Apply subsampling to the embedding collected from the training set."""

src/anomalib/models/video/ai_vad/lightning_model.py
Lines changed: 5 additions & 8 deletions

@@ -7,9 +7,9 @@
 # SPDX-License-Identifier: Apache-2.0

 import logging
-from dataclasses import replace
 from typing import Any

+import torch
 from lightning.pytorch.utilities.types import STEP_OUTPUT

 from anomalib import LearningType
@@ -124,6 +124,9 @@ def training_step(self, batch: VideoBatch) -> None:
         self.model.density_estimator.update(features, video_path)
         self.total_detections += len(next(iter(features.values())))

+        # Return a dummy loss tensor
+        return torch.tensor(0.0, requires_grad=True, device=self.device)
+
     def fit(self) -> None:
         """Fit the density estimators to the extracted features from the training set."""
         if self.total_detections == 0:
@@ -147,13 +150,7 @@ def validation_step(self, batch: VideoBatch, *args, **kwargs) -> STEP_OUTPUT:
         del args, kwargs  # Unused arguments.

         predictions = self.model(batch.image)
-
-        return replace(
-            batch,
-            pred_score=predictions.pred_score,
-            anomaly_map=predictions.anomaly_map,
-            pred_mask=predictions.pred_mask,
-        )
+        return batch.update(pred_score=predictions.pred_score, anomaly_map=predictions.anomaly_map)

     @property
     def trainer_arguments(self) -> dict[str, Any]:
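
The validation change swaps dataclasses.replace, which allocates a fresh batch, for an in-place update on the existing one (no longer forwarding pred_mask). A sketch of the assumed Batch.update semantics, on a hypothetical two-field TinyBatch:

    from dataclasses import dataclass

    import torch

    @dataclass
    class TinyBatch:
        image: torch.Tensor
        pred_score: torch.Tensor | None = None

        def update(self, **changes: torch.Tensor) -> "TinyBatch":
            # in-place counterpart of dataclasses.replace: mutate and return self
            for name, value in changes.items():
                setattr(self, name, value)
            return self

    batch = TinyBatch(image=torch.zeros(2, 3, 8, 8))
    out = batch.update(pred_score=torch.tensor([0.1, 0.9]))
    assert out is batch  # no new object, unlike dataclasses.replace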

src/anomalib/utils/config.py
Lines changed: 0 additions & 7 deletions

@@ -254,10 +254,3 @@ def _show_warnings(config: DictConfig | ListConfig | Namespace) -> None:
         "Anomalib's models and visualizer are currently not compatible with video datasets with a clip length > 1. "
         "Custom changes to these modules will be needed to prevent errors and/or unpredictable behaviour.",
     )
-    if (
-        "devices" in config.trainer
-        and (config.trainer.devices is None or config.trainer.devices != 1)
-        and config.trainer.accelerator != "cpu"
-    ):
-        logger.warning("Anomalib currently does not support multi-gpu training. Setting devices to 1.")
-        config.trainer.devices = 1

tests/unit/data/validators/torch/test_video.py
Lines changed: 5 additions & 5 deletions

@@ -174,10 +174,10 @@ def test_validate_gt_label_invalid_type(self) -> None:

     def test_validate_gt_mask_valid(self) -> None:
         """Test validation of valid ground truth masks."""
-        masks = torch.randint(0, 2, (2, 10, 224, 224))
+        masks = torch.randint(0, 2, (10, 1, 224, 224))
         validated_masks = self.validator.validate_gt_mask(masks)
         assert isinstance(validated_masks, Mask)
-        assert validated_masks.shape == (2, 10, 224, 224)
+        assert validated_masks.shape == (10, 224, 224)
         assert validated_masks.dtype == torch.bool

     def test_validate_gt_mask_none(self) -> None:
@@ -186,13 +186,13 @@ def test_validate_gt_mask_none(self) -> None:

     def test_validate_gt_mask_invalid_type(self) -> None:
         """Test validation of ground truth masks with invalid type."""
-        with pytest.raises(TypeError, match="Masks must be a torch.Tensor"):
+        with pytest.raises(TypeError, match="Ground truth mask must be a torch.Tensor"):
             self.validator.validate_gt_mask([torch.zeros(10, 224, 224)])

     def test_validate_gt_mask_invalid_shape(self) -> None:
         """Test validation of ground truth masks with invalid shape."""
-        with pytest.raises(ValueError, match="Masks must have 1 channel, got 2."):
-            self.validator.validate_gt_mask(torch.zeros(2, 10, 2, 224, 224))
+        with pytest.raises(ValueError, match="Ground truth mask must have 1 channel, got 2."):
+            self.validator.validate_gt_mask(torch.zeros(10, 2, 224, 224))

     def test_validate_anomaly_map_valid(self) -> None:
         """Test validation of a valid anomaly map batch."""
