Skip to content

Commit 5bd6b69

Browse files
fix video mask
Signed-off-by: Ashwin Vaidya <[email protected]>
1 parent 07e3a9a commit 5bd6b69

File tree

2 files changed

+10
-16
lines changed

2 files changed

+10
-16
lines changed

src/anomalib/data/validators/torch/video.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -600,17 +600,18 @@ def validate_gt_mask(mask: torch.Tensor | None) -> Mask | None:
600600
if mask is None:
601601
return None
602602
if not isinstance(mask, torch.Tensor):
603-
msg = f"Masks must be a torch.Tensor, got {type(mask)}."
603+
msg = f"Ground truth mask must be a torch.Tensor, got {type(mask)}."
604604
raise TypeError(msg)
605-
if mask.ndim not in {3, 4, 5}:
606-
msg = f"Masks must have shape [B, H, W], [B, T, H, W] or [B, T, 1, H, W], got shape {mask.shape}."
605+
if mask.ndim not in {2, 3, 4}:
606+
msg = f"Ground truth mask must have shape [H, W] or [N, H, W] or [N, 1, H, W] got shape {mask.shape}."
607607
raise ValueError(msg)
608-
if mask.ndim == 5:
609-
if mask.shape[2] != 1:
610-
msg = f"Masks must have 1 channel, got {mask.shape[2]}."
608+
if mask.ndim == 2:
609+
mask = mask.unsqueeze(0)
610+
if mask.ndim == 4:
611+
if mask.shape[1] != 1:
612+
msg = f"Ground truth mask must have 1 channel, got {mask.shape[1]}."
611613
raise ValueError(msg)
612-
mask = mask.squeeze(2)
613-
614+
mask = mask.squeeze(1)
614615
return Mask(mask, dtype=torch.bool)
615616

616617
@staticmethod

src/anomalib/models/video/ai_vad/lightning_model.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
# SPDX-License-Identifier: Apache-2.0
88

99
import logging
10-
from dataclasses import replace
1110
from typing import Any
1211

1312
from lightning.pytorch.utilities.types import STEP_OUTPUT
@@ -146,13 +145,7 @@ def validation_step(self, batch: VideoBatch, *args, **kwargs) -> STEP_OUTPUT:
146145
del args, kwargs # Unused arguments.
147146

148147
predictions = self.model(batch.image)
149-
150-
return replace(
151-
batch,
152-
pred_score=predictions.pred_score,
153-
anomaly_map=predictions.anomaly_map,
154-
pred_mask=predictions.pred_mask,
155-
)
148+
return batch.update(pred_score=predictions.pred_score, anomaly_map=predictions.anomaly_map)
156149

157150
@property
158151
def trainer_arguments(self) -> dict[str, Any]:

0 commit comments

Comments
 (0)