Commit 4f4fd43

Authored by samet-akcay, jiaxian1-intel (Jiaxiang Jiang), and AlexanderBarabanov
🚀 Add FUVAS Video Anomaly Detection Model (#2654)
* 🚀 Add `FUVAS` Video Anomaly Detection Model (#2627)
  * fuvas integrated
  * fuvas integrated with changemd
  * Update CHANGELOG.md
  * Rename fuvas to Fuvas
  * Address linting errors in feature extractor
  * Fix the linting errors in torch model
  * Update docstrings
  * use torchvision feature extractor
  * fuvas readme and change np to torch
  * fuvas readme modify
  * update readme and python annotation
  * add fvcore dependencies
  * test model change
* 🚀 Add FUVAS Video Anomaly Detection Model (#2652)
  * Fix pre-commit
  * Skip video anomaly detection models for now
* 🔒 Fix bandit and semgrep issues of Fuvas video anomaly detection model. (#2655)
  * Address bandit and semgrep issues
  * Update src/anomalib/models/__init__.py
* 🗑️ Remove `TaskType` from `FUVAS` Video Anomaly Detection Model (#2658)
  * remove task type from fuvas

Signed-off-by: Samet Akcay <[email protected]>
Co-authored-by: Jiaxiang Jiang <[email protected]>
Co-authored-by: Alexander Barabanov <[email protected]>

1 parent: b500b22 · commit: 4f4fd43

File tree: 9 files changed (+508 −6 lines)

CHANGELOG.md

+2
```diff
@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

 ### Added

+- 🚀 Add new SOTA video Anomaly detection module FUVAS
+
 - 🚀 Add VAD dataset by @abc-125 in https://github.com/openvinotoolkit/anomalib/pull/2603

 ### Removed
```

pyproject.toml

+1
```diff
@@ -56,6 +56,7 @@ core = [
     # torch.onnx.errors.UnsupportedOperatorError: Exporting the operator
     # 'aten::_native_multi_head_attention' to ONNX opset version 14 is not supported
     "open-clip-torch>=2.23.0,<2.26.1",
+    "fvcore",
 ]
 openvino = ["openvino>=2024.0", "nncf>=2.10.0", "onnx>=1.16.0"]
 vlm = ["ollama>=0.4.0", "openai", "python-dotenv","transformers"]
```

src/anomalib/models/__init__.py

+23-2
```diff
@@ -75,7 +75,15 @@
     VlmAd,
     WinClip,
 )
-from .video import AiVad
+from .video import AiVad, Fuvas
+
+# Whitelist of allowed modules for dynamic imports
+ALLOWED_MODULES = {
+    "anomalib.models",
+    "anomalib.models.image",
+    "anomalib.models.video",
+    "anomalib.models.components",
+}


 class UnknownModelError(ModuleNotFoundError):
@@ -103,6 +111,7 @@ class UnknownModelError(ModuleNotFoundError):
     "VlmAd",
     "WinClip",
     "AiVad",
+    "Fuvas",
 ]

 logger = logging.getLogger(__name__)
@@ -262,7 +271,19 @@ def get_model(model: DictConfig | str | dict | Namespace, *args, **kwdargs) -> A
         model = OmegaConf.create(model)
     try:
         if len(model.class_path.split(".")) > 1:
-            module = import_module(".".join(model.class_path.split(".")[:-1]))
+            # Security check: Only allow imports from whitelisted modules
+            module_path = ".".join(model.class_path.split(".")[:-1])
+            if module_path not in ALLOWED_MODULES:
+                logger.error(
+                    f"Module import from '{module_path}' is not allowed. "
+                    f"Only imports from {ALLOWED_MODULES} are permitted.",
+                )
+                msg = f"Module import from '{module_path}' is not allowed."
+                raise UnknownModelError(msg)
+
+            # Use a whitelist approach to prevent arbitrary code execution
+            # nosemgrep: python.lang.security.audit.non-literal-import.non-literal-import
+            module = import_module(module_path)
         else:
             module = import_module("anomalib.models")
     except ModuleNotFoundError as exception:
```
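For reference, a minimal sketch of how this guard behaves at call time; the dict form follows `get_model`'s signature above, and the blocked `os.system` path is a hypothetical example of what the whitelist rejects:

```python
from anomalib.models import UnknownModelError, get_model

# Allowed: "anomalib.models.video" is in ALLOWED_MODULES, so the
# Fuvas class is imported dynamically and instantiated.
model = get_model({"class_path": "anomalib.models.video.Fuvas"})

# Blocked: a module prefix outside the whitelist raises UnknownModelError
# before any import happens (hypothetical malicious class_path).
try:
    get_model({"class_path": "os.system"})
except UnknownModelError:
    print("rejected by the module whitelist")
```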

src/anomalib/models/video/__init__.py

+3-2
```diff
@@ -27,9 +27,10 @@
 - :class:`AiVad`: AI-based Video Anomaly Detection
 """

-# Copyright (C) 2023-2024 Intel Corporation
+# Copyright (C) 2023-2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

 from .ai_vad import AiVad
+from .fuvas import Fuvas

-__all__ = ["AiVad"]
+__all__ = ["AiVad", "Fuvas"]
```
src/anomalib/models/video/fuvas/README.md

+48 (new file)

# [ICASSP 2025] FUVAS: Few-shot Unsupervised Video Anomaly Segmentation via Low-Rank Factorization of Spatio-Temporal Features

## 📝 Description

This folder contains the FUVAS video anomaly detection model, which supports both transformer-based and CNN-based backbones.

## 💡 Examples

The following example shows how to train the FUVAS model on the UCSDped dataset.

<details>
<summary>Training the Fuvas model on the UCSDped video dataset</summary>

```python
# Import the necessary modules
from anomalib.data import UCSDped
from anomalib.models import Fuvas
from anomalib.engine import Engine

# Load the UCSDped datamodule, model and engine.
datamodule = UCSDped()
model = Fuvas()
engine = Engine()

# Train the model
engine.train(model, datamodule)
```

</details>
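To obtain test metrics like the table below after training, a minimal sketch, assuming the standard `Engine.test` API and the datamodule's default test split:

```python
# Evaluate on the test split; reports frame- and pixel-level metrics.
results = engine.test(model=model, datamodule=datamodule)
print(results)
```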
## Example output

| Test metric  | DataLoader 0       |
| ------------ | ------------------ |
| frame_AUROC  | 0.9135797023773193 |
| mean_F1Score | 0.9350237846374512 |
| pixel_AUROC  | 0.9756277996063232 |
## Citation

```bibtex
@inproceedings{icassp2025fuvas,
  title={FUVAS: Few-shot Unsupervised Video Anomaly Segmentation via Low-Rank Factorization of Spatio-Temporal Features},
  author={Jiang, Jiaxiang and Ndiour, Ibrahima J and Subedar, Mahesh and Tickoo, Omesh},
  booktitle={ICASSP 2025-2025 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  pages={1--5},
  year={2025},
  organization={IEEE}
}
```
src/anomalib/models/video/fuvas/__init__.py

+29 (new file)

```python
"""FUVAS: Few-shot Unsupervised Video Anomaly Segmentation via Low-Rank Factorization of Spatio-Temporal Features.

The FUVAS model extracts deep features from video clips using a pre-trained 3D CNN/transformer
backbone and fits a PCA-based reconstruction model to detect anomalies. The model computes
feature reconstruction errors to identify anomalous frames and regions in videos.

Example:
    >>> from anomalib.models.video import Fuvas
    >>> model = Fuvas(
    ...     backbone="x3d_s",
    ...     layer="blocks.4"
    ... )

The model can be used with video anomaly detection datasets supported in anomalib.

Notes:
    The model implementation is available in the ``lightning_model`` module.

See Also:
    :class:`anomalib.models.video.fuvas.lightning_model.Fuvas`:
        Lightning implementation of the FUVAS model.
"""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .lightning_model import Fuvas

__all__ = ["Fuvas"]
```
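With the package exporting ``Fuvas``, the model can also be resolved by name; a small sketch, assuming the name-based lookup that `get_model` performs against the registered model list:

```python
from anomalib.models import get_model

# Name-based lookup resolves "fuvas" to the Fuvas class with default arguments.
model = get_model("fuvas")
print(type(model).__name__)  # "Fuvas"
```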
src/anomalib/models/video/fuvas/lightning_model.py

+195 (new file)

```python
"""FUVAS: Few-shot Unsupervised Video Anomaly Segmentation via Low-Rank Factorization of Spatio-Temporal Features.

This module provides a PyTorch Lightning implementation of the FUVAS model for
video anomaly detection and segmentation. The model extracts deep features from video clips
using a pre-trained 3D CNN/transformer backbone and fits a PCA-based reconstruction model
to detect anomalies.

Paper: https://ieeexplore.ieee.org/abstract/document/10887597

Example:
    >>> from anomalib.models.video import Fuvas
    >>> model = Fuvas(
    ...     backbone="x3d_s",
    ...     layer="blocks.4",
    ...     pre_trained=True
    ... )

Notes:
    The model uses a pre-trained backbone to extract features and fits a PCA
    transformation during training. No gradient updates are performed on the backbone.
    Anomaly detection is based on feature reconstruction error.

See Also:
    :class:`anomalib.models.video.fuvas.torch_model.FUVASModel`:
        PyTorch implementation of the FUVAS model.
"""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Any

import torch
from lightning.pytorch.utilities.types import STEP_OUTPUT

from anomalib import LearningType
from anomalib.data import Batch
from anomalib.metrics import Evaluator
from anomalib.models.components import AnomalibModule, MemoryBankMixin
from anomalib.post_processing import PostProcessor
from anomalib.pre_processing import PreProcessor
from anomalib.visualization import Visualizer

from .torch_model import FUVASModel

logger = logging.getLogger(__name__)


class Fuvas(MemoryBankMixin, AnomalibModule):
    """FUVAS Lightning Module.

    Args:
        backbone (str): Name of the backbone 3D CNN/transformer network.
            Defaults to ``"x3d_s"``.
        layer (str): Name of the layer to extract features from the backbone.
            Defaults to ``"blocks.4"``.
        pre_trained (bool, optional): Whether to use a pre-trained backbone.
            Defaults to ``True``.
        spatial_pool (bool, optional): Whether to use spatial pooling on features.
            Defaults to ``True``.
        pooling_kernel_size (int, optional): Kernel size for pooling features.
            Defaults to ``1``.
        pca_level (float, optional): Ratio of variance to preserve in PCA.
            Must be between 0 and 1.
            Defaults to ``0.98``.
        pre_processor (PreProcessor | bool, optional): Pre-processor to use.
            If ``True``, uses the default pre-processor.
            If ``False``, no pre-processing is performed.
            Defaults to ``True``.
        post_processor (PostProcessor | bool, optional): Post-processor to use.
            If ``True``, uses the default post-processor.
            If ``False``, no post-processing is performed.
            Defaults to ``True``.
        evaluator (Evaluator | bool, optional): Evaluator to use.
            If ``True``, uses the default evaluator.
            If ``False``, no evaluation is performed.
            Defaults to ``True``.
        visualizer (Visualizer | bool, optional): Visualizer to use.
            If ``True``, uses the default visualizer.
            If ``False``, no visualization is performed.
            Defaults to ``True``.
    """

    def __init__(
        self,
        backbone: str = "x3d_s",
        layer: str = "blocks.4",
        pre_trained: bool = True,
        spatial_pool: bool = True,
        pooling_kernel_size: int = 1,
        pca_level: float = 0.98,
        pre_processor: PreProcessor | bool = True,
        post_processor: PostProcessor | bool = True,
        evaluator: Evaluator | bool = True,
        visualizer: Visualizer | bool = True,
    ) -> None:
        super().__init__(
            pre_processor=pre_processor,
            post_processor=post_processor,
            evaluator=evaluator,
            visualizer=visualizer,
        )

        self.model: FUVASModel = FUVASModel(
            backbone=backbone,
            pre_trained=pre_trained,
            layer=layer,
            pooling_kernel_size=pooling_kernel_size,
            n_comps=pca_level,
            spatial_pool=spatial_pool,
        )
        self.embeddings: list[torch.Tensor] = []

    @staticmethod
    def configure_optimizers() -> None:  # pylint: disable=arguments-differ
        """Configure optimizers for training.

        Returns:
            None: FUVAS doesn't require optimization.
        """
        return

    def training_step(self, batch: Batch, *args, **kwargs) -> torch.Tensor:
        """Extract features from the input batch during training.

        Args:
            batch (Batch): Input batch containing video clips.
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            torch.Tensor: Dummy loss tensor for compatibility.
        """
        del args, kwargs  # These variables are not used.

        # Ensure batch.image is a tensor
        if batch.image is None or not isinstance(batch.image, torch.Tensor):
            msg = "Expected batch.image to be a tensor, but got None or non-tensor type"
            raise ValueError(msg)

        embedding = self.model.get_features(batch.image)[0].squeeze()
        self.embeddings.append(embedding)

        # Return a dummy loss tensor
        return torch.tensor(0.0, requires_grad=True, device=self.device)

    def fit(self) -> None:
        """Fit the PCA transformation to the embeddings.

        The method aggregates embeddings collected during training and fits
        the PCA transformation used for anomaly scoring.
        """
        logger.info("Aggregating the embedding extracted from the training set.")
        embeddings = torch.vstack(self.embeddings)

        logger.info("Fitting a PCA to dataset.")
        self.model.fit(embeddings)

    def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
        """Compute predictions for the input batch during validation.

        Args:
            batch (Batch): Input batch containing video clips.
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            STEP_OUTPUT: Dictionary containing anomaly scores and maps.
        """
        del args, kwargs  # These variables are not used.

        predictions = self.model(batch.image)
        return batch.update(pred_score=predictions.pred_score, anomaly_map=predictions.anomaly_map)

    @property
    def trainer_arguments(self) -> dict[str, Any]:
        """Get FUVAS-specific trainer arguments.

        Returns:
            dict[str, Any]: Dictionary of trainer arguments:
                - ``gradient_clip_val`` (int): Disable gradient clipping
                - ``max_epochs`` (int): Train for one epoch only
                - ``num_sanity_val_steps`` (int): Skip validation sanity checks
        """
        return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0}

    @property
    def learning_type(self) -> LearningType:
        """Get the learning type of the model.

        Returns:
            LearningType: The model uses one-class learning.
        """
        return LearningType.ONE_CLASS
```
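To make the collect-then-fit flow above concrete, here is a self-contained toy sketch of the same pattern; random features and a hand-rolled PCA stand in for `FUVASModel`'s actual feature extractor and fit, and all shapes are illustrative:

```python
import torch

# Toy stand-in for the memory-bank pattern used above:
# training_step appends per-clip embeddings, fit() stacks them and runs PCA.
embeddings: list[torch.Tensor] = []
for _ in range(8):  # 8 "batches" of 4 clips with 192-dim pooled features
    embeddings.append(torch.randn(4, 192))

bank = torch.vstack(embeddings)  # (32, 192) memory bank
mean = bank.mean(dim=0)
_, s, vh = torch.linalg.svd(bank - mean, full_matrices=False)

# Keep enough components to preserve ~98% of variance (pca_level=0.98).
ratio = (s**2).cumsum(0) / (s**2).sum()
k = int((ratio < 0.98).sum()) + 1
components = vh[:k]  # (k, 192) PCA basis

# Anomaly score = feature reconstruction error against the PCA subspace.
x = torch.randn(1, 192)  # one test clip's pooled feature
proj = (x - mean) @ components.T @ components + mean
score = torch.linalg.vector_norm(x - proj)
print(f"reconstruction error: {score.item():.4f}")
```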
