diff --git a/docs/Makefile b/docs/Makefile index d0c3cbf102..7f7c18fca8 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -3,7 +3,7 @@ # You can set these variables from the command line, and also # from the environment for the first two. -SPHINXOPTS ?= +SPHINXOPTS = -j auto SPHINXBUILD ?= sphinx-build SOURCEDIR = source BUILDDIR = build diff --git a/docs/requirements.txt b/docs/requirements.txt deleted file mode 100644 index 02603dc810..0000000000 --- a/docs/requirements.txt +++ /dev/null @@ -1,8 +0,0 @@ -myst-parser -nbsphinx -pandoc -sphinx -sphinx_autodoc_typehints -sphinx_book_theme -sphinx-copybutton -sphinx_design diff --git a/docs/source/conf.py b/docs/source/conf.py index 16e79a59af..890bb5100b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -10,26 +10,25 @@ # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from __future__ import annotations - import sys from pathlib import Path -# Define the path to your module using Path -module_path = Path(__file__).parent.parent / "src" +# Define paths +project_root = Path(__file__).parent.parent.parent +module_path = project_root / "src" +examples_path = project_root / "examples" -# Insert the path to sys.path +# Insert paths to sys.path sys.path.insert(0, str(module_path.resolve())) +sys.path.insert(0, str(project_root.resolve())) project = "Anomalib" -copyright = "2023, Intel OpenVINO" # noqa: A001 -author = "Intel OpenVINO" -release = "2022" +copyright = "Intel Corporation" # noqa: A001 +author = "Intel Corporation" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration - extensions = [ "sphinx.ext.autodoc", "sphinx.ext.mathjax", @@ -39,20 +38,49 @@ "sphinx.ext.napoleon", "sphinx_autodoc_typehints", "sphinx_copybutton", + "sphinx.ext.intersphinx", + "sphinx.ext.autosectionlabel", ] +# MyST configuration myst_enable_extensions = [ "colon_fence", - # other MyST extensions... 
+ "linkify", + "substitution", + "tasklist", + "deflist", + "fieldlist", ] + +# Add separate setting for eval-rst +myst_enable_eval_rst = True + +# Notebook handling nbsphinx_allow_errors = True +nbsphinx_execute = "auto" # Execute notebooks during build +nbsphinx_timeout = 300 # Timeout in seconds + +# Templates and patterns templates_path = ["_templates"] -exclude_patterns: list[str] = [] +exclude_patterns: list[str] = [ + "_build", + "**.ipynb_checkpoints", + "**/.pytest_cache", + "**/.git", + "**/.github", + "**/.venv", + "**/*.egg-info", + "**/build", + "**/dist", +] # Automatic exclusion of prompts from the copies # https://sphinx-copybutton.readthedocs.io/en/latest/use.html#automatic-exclusion-of-prompts-from-the-copies copybutton_exclude = ".linenos, .gp, .go" +# Enable section anchors for cross-referencing +autosectionlabel_prefix_document = True + # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output @@ -65,3 +93,13 @@ "text": "Anomalib", }, } + +# Add references to example files +html_context = {"examples_path": str(examples_path)} + +# External documentation references +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), + "torch": ("https://pytorch.org/docs/stable", None), + "lightning": ("https://lightning.ai/docs/pytorch/stable/", None), +} diff --git a/docs/source/index.md b/docs/source/index.md index eea06f1275..46199a425d 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -68,19 +68,11 @@ Learn more about anomalib API and CLI. Learn how to use anomalib for your anomaly detection tasks. ::: -:::{grid-item-card} {octicon}`telescope` Topic Guide -:link: markdown/guides/topic/index -:link-type: doc - -Learn more about the internals of anomalib. -::: - :::{grid-item-card} {octicon}`code` Developer Guide :link: markdown/guides/developer/index :link-type: doc Learn how to develop and contribute to anomalib. -::: :::: @@ -98,7 +90,6 @@ markdown/get_started/migration markdown/guides/reference/index markdown/guides/how_to/index -markdown/guides/topic/index markdown/guides/developer/index ``` diff --git a/docs/source/markdown/guides/reference/callbacks/index.md b/docs/source/markdown/guides/reference/callbacks/index.md index 4fdeaedfff..1c76a19fcb 100644 --- a/docs/source/markdown/guides/reference/callbacks/index.md +++ b/docs/source/markdown/guides/reference/callbacks/index.md @@ -1,8 +1,90 @@ # Callbacks +```{grid} 2 +:gutter: 2 + +:::{card} {octicon}`download` Model Checkpoint +:link: checkpoint +:link-type: ref + +Save and manage model checkpoints during training. +::: + +:::{card} {octicon}`graph` Graph Logger +:link: graph-logger +:link-type: ref + +Log model computation graphs for visualization. +::: + +:::{card} {octicon}`package` Load Model +:link: load-model +:link-type: ref + +Load pre-trained models and weights. +::: + +:::{card} {octicon}`table` Tile Configuration +:link: tile-configuration +:link-type: ref + +Configure and manage image tiling settings. +::: + +:::{card} {octicon}`clock` Timer +:link: timer +:link-type: ref + +Track and measure execution times during training. +::: +``` + +(checkpoint)= + +## {octicon}`download` Model Checkpoint + +```{eval-rst} +.. automodule:: anomalib.callbacks.checkpoint + :members: + :show-inheritance: +``` + +(graph-logger)= + +## {octicon}`graph` Graph Logger + +```{eval-rst} +.. 
automodule:: anomalib.callbacks.graph
+    :members:
+    :show-inheritance:
+```
+
+(load-model)=
+
+## {octicon}`package` Load Model
+
+```{eval-rst}
+.. automodule:: anomalib.callbacks.model_loader
+    :members:
+    :show-inheritance:
+```
+
+(tile-configuration)=
+
+## {octicon}`table` Tile Configuration
+
+```{eval-rst}
+.. automodule:: anomalib.callbacks.tiler_configuration
+    :members:
+    :show-inheritance:
+```
+
+(timer)=
+
+## {octicon}`clock` Timer
+
 ```{eval-rst}
-.. automodule:: anomalib.callbacks
+.. automodule:: anomalib.callbacks.timer
     :members:
-    :exclude-members: get_visualization_callbacks
     :show-inheritance:
 ```
diff --git a/docs/source/markdown/guides/reference/cli/index.md b/docs/source/markdown/guides/reference/cli/index.md
index 69f183f142..5c047b2b74 100644
--- a/docs/source/markdown/guides/reference/cli/index.md
+++ b/docs/source/markdown/guides/reference/cli/index.md
@@ -1,7 +1,7 @@
 # CLI
 
 ```{eval-rst}
-.. automodule:: anomalib.cli
+.. automodule:: anomalib.cli.cli
     :members:
     :show-inheritance:
 ```
diff --git a/docs/source/markdown/guides/reference/data/base/datamodule.md b/docs/source/markdown/guides/reference/data/base/datamodule.md
deleted file mode 100644
index 2c48711943..0000000000
--- a/docs/source/markdown/guides/reference/data/base/datamodule.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# Base Datamodules
-
-```{eval-rst}
-.. automodule:: anomalib.data.base.datamodule
-    :members:
-    :show-inheritance:
-```
diff --git a/docs/source/markdown/guides/reference/data/base/dataset.md b/docs/source/markdown/guides/reference/data/base/dataset.md
deleted file mode 100644
index 38ba53fc41..0000000000
--- a/docs/source/markdown/guides/reference/data/base/dataset.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# Base Dataset
-
-```{eval-rst}
-.. automodule:: anomalib.data.base.dataset
-    :members:
-    :show-inheritance:
-```
diff --git a/docs/source/markdown/guides/reference/data/base/depth.md b/docs/source/markdown/guides/reference/data/base/depth.md
deleted file mode 100644
index f179b07f60..0000000000
--- a/docs/source/markdown/guides/reference/data/base/depth.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# Base Depth Data
-
-```{eval-rst}
-.. automodule:: anomalib.data.base.depth
-    :members:
-    :show-inheritance:
-```
diff --git a/docs/source/markdown/guides/reference/data/base/index.md b/docs/source/markdown/guides/reference/data/base/index.md
deleted file mode 100644
index efe077950b..0000000000
--- a/docs/source/markdown/guides/reference/data/base/index.md
+++ /dev/null
@@ -1,43 +0,0 @@
-# Base Data
-
-::::{grid}
-
-:::{grid-item-card} {octicon}`copy` Base Dataset
-:link: ./dataset
-:link-type: doc
-
-Learn more about base anomalib dataset
-:::
-
-:::{grid-item-card} {octicon}`file-media` Base Datamodule
-:link: ./datamodule
-:link-type: doc
-
-Learn more about base anomalib datamodule
-:::
-
-:::{grid-item-card} {octicon}`video` Video
-:link: ./video
-:link-type: doc
-
-Learn more about base anomalib video data
-:::
-
-:::{grid-item-card} {octicon}`database` Depth
-:link: ./depth/
-:link-type: doc
-
-Learn more about base anomalib depth data
-:::
-
-::::
-
-```{toctree}
-:caption: Base Data
-:hidden:
-
-./dataset
-./datamodule
-./depth
-./video
-```
diff --git a/docs/source/markdown/guides/reference/data/base/video.md b/docs/source/markdown/guides/reference/data/base/video.md
deleted file mode 100644
index bb9284d583..0000000000
--- a/docs/source/markdown/guides/reference/data/base/video.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# Base Video Data
-
-```{eval-rst}
-.. 
automodule:: anomalib.data.base.video - :members: - :show-inheritance: -``` diff --git a/docs/source/markdown/guides/reference/data/dataclasses/generic.md b/docs/source/markdown/guides/reference/data/dataclasses/generic.md new file mode 100644 index 0000000000..e435911394 --- /dev/null +++ b/docs/source/markdown/guides/reference/data/dataclasses/generic.md @@ -0,0 +1,111 @@ +# Generic Dataclasses + +The generic dataclasses module provides the foundational data structures and validation logic used throughout Anomalib. These classes are designed to be flexible and type-safe, serving as the base for both PyTorch and NumPy implementations. + +```{eval-rst} +.. currentmodule:: anomalib.data.dataclasses.generic +``` + +## Core Concepts + +### Type Variables + +The module uses several type variables to ensure type safety across different implementations: + +- `ImageT`: Type variable for image data (PyTorch Image/Video or NumPy array) +- `T`: Type variable for tensor-like data (PyTorch Tensor or NumPy array) +- `MaskT`: Type variable for mask data (PyTorch Mask or NumPy array) +- `PathT`: Type variable for path data (string or list of strings) + +## Base Classes + +### InputFields + +```{eval-rst} +.. autoclass:: _InputFields + :members: + :show-inheritance: +``` + +### ImageInputFields + +```{eval-rst} +.. autoclass:: _ImageInputFields + :members: + :show-inheritance: +``` + +### VideoInputFields + +```{eval-rst} +.. autoclass:: _VideoInputFields + :members: + :show-inheritance: +``` + +### DepthInputFields + +```{eval-rst} +.. autoclass:: _DepthInputFields + :members: + :show-inheritance: +``` + +### OutputFields + +```{eval-rst} +.. autoclass:: _OutputFields + :members: + :show-inheritance: +``` + +## Mixins + +### UpdateMixin + +```{eval-rst} +.. autoclass:: UpdateMixin + :members: + :show-inheritance: +``` + +### BatchIterateMixin + +```{eval-rst} +.. autoclass:: BatchIterateMixin + :members: + :show-inheritance: +``` + +## Generic Classes + +### GenericItem + +```{eval-rst} +.. autoclass:: _GenericItem + :members: + :show-inheritance: +``` + +### GenericBatch + +```{eval-rst} +.. autoclass:: _GenericBatch + :members: + :show-inheritance: +``` + +## Field Validation + +### FieldDescriptor + +```{eval-rst} +.. autoclass:: FieldDescriptor + :members: + :show-inheritance: +``` + +## See Also + +- {doc}`torch` +- {doc}`numpy` diff --git a/docs/source/markdown/guides/reference/data/dataclasses/index.md b/docs/source/markdown/guides/reference/data/dataclasses/index.md new file mode 100644 index 0000000000..d762045abe --- /dev/null +++ b/docs/source/markdown/guides/reference/data/dataclasses/index.md @@ -0,0 +1,70 @@ +# Data Classes + +Anomalib's dataclasses provide type-safe data containers with automatic validation. They support both PyTorch and NumPy backends for flexible data handling. 
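+
+As a quick, illustrative sketch of how these containers are used (the class and
+field names follow the reference pages linked below, but treat the exact import
+path and signatures as assumptions and check the API docs):
+
+```python
+import torch
+
+from anomalib.data import ImageItem  # assumed import path; see the PyTorch page
+
+# Fields are validated on construction, so shape and dtype problems surface
+# immediately instead of deep inside a model.
+item = ImageItem(
+    image=torch.rand(3, 256, 256),
+    gt_label=torch.tensor(0),
+    image_path="path/to/image.png",
+)
+
+# Torch-based items mirror the NumPy classes; ``to_numpy`` (from ToNumpyMixin,
+# documented on the PyTorch page) is assumed here as the conversion helper.
+numpy_item = item.to_numpy()
+```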
+ +::::{grid} 1 2 2 3 +:gutter: 3 +:padding: 2 +:class-container: landing-grid + +:::{grid-item-card} {octicon}`package` Generic Classes +:link: generic +:link-type: doc +:class-card: custom-card + +Base dataclasses that define common data structures and validation logic: + +- Generic Item/Batch +- Input/Output Fields +- Validation Mixins + ++++ +[Learn More »](generic) +::: + +:::{grid-item-card} {octicon}`cpu` PyTorch Classes +:link: torch +:link-type: doc +:class-card: custom-card + +PyTorch tensor-based implementations: + +- Image, Video, Depth Items +- Batch Processing Support +- Type-safe Validation + ++++ +[Learn More »](torch) +::: + +:::{grid-item-card} {octicon}`database` NumPy Classes +:link: numpy +:link-type: doc +:class-card: custom-card + +NumPy array-based implementations: + +- Efficient Data Processing +- Array-based Containers +- Conversion Utilities + ++++ +[Learn More »](numpy) +::: +:::: + +## Documentation + +For detailed documentation and examples, see: + +- {doc}`Generic Base Classes ` +- {doc}`PyTorch Classes ` +- {doc}`NumPy Classes ` + +```{toctree} +:hidden: + +generic +torch +numpy +``` diff --git a/docs/source/markdown/guides/reference/data/dataclasses/numpy.md b/docs/source/markdown/guides/reference/data/dataclasses/numpy.md new file mode 100644 index 0000000000..a6f622ca6a --- /dev/null +++ b/docs/source/markdown/guides/reference/data/dataclasses/numpy.md @@ -0,0 +1,93 @@ +# Numpy Dataclasses + +The numpy dataclasses module provides numpy-based implementations of the generic dataclasses used in Anomalib. These classes are designed to work with numpy arrays for efficient data handling and processing in anomaly detection tasks. + +```{eval-rst} +.. currentmodule:: anomalib.data.dataclasses.numpy +``` + +## Overview + +The module includes several categories of dataclasses: + +- **Base Classes**: Generic numpy-based data structures +- **Image Classes**: Specialized for image data processing +- **Video Classes**: Designed for video data handling +- **Depth Classes**: Specific to depth-based anomaly detection + +## Base Classes + +### NumpyItem + +```{eval-rst} +.. autoclass:: NumpyItem + :members: + :show-inheritance: +``` + +### NumpyBatch + +```{eval-rst} +.. autoclass:: NumpyBatch + :members: + :show-inheritance: +``` + +## Image Classes + +### NumpyImageItem + +```{eval-rst} +.. autoclass:: NumpyImageItem + :members: + :show-inheritance: +``` + +### NumpyImageBatch + +```{eval-rst} +.. autoclass:: NumpyImageBatch + :members: + :show-inheritance: +``` + +## Video Classes + +### NumpyVideoItem + +```{eval-rst} +.. autoclass:: NumpyVideoItem + :members: + :show-inheritance: +``` + +### NumpyVideoBatch + +```{eval-rst} +.. autoclass:: NumpyVideoBatch + :members: + :show-inheritance: +``` + +## Depth Classes + +### NumpyDepthItem + +```{eval-rst} +.. autoclass:: NumpyDepthItem + :members: + :show-inheritance: +``` + +### NumpyDepthBatch + +```{eval-rst} +.. autoclass:: NumpyDepthBatch + :members: + :show-inheritance: +``` + +## See Also + +- {doc}`../index` +- {doc}`../torch` diff --git a/docs/source/markdown/guides/reference/data/dataclasses/torch.md b/docs/source/markdown/guides/reference/data/dataclasses/torch.md new file mode 100644 index 0000000000..c8010dbad4 --- /dev/null +++ b/docs/source/markdown/guides/reference/data/dataclasses/torch.md @@ -0,0 +1,109 @@ +# Torch Dataclasses + +The torch dataclasses module provides PyTorch-based implementations of the generic dataclasses used in Anomalib. 
These classes are designed to work with PyTorch tensors for efficient data handling and processing in anomaly detection tasks. + +```{eval-rst} +.. currentmodule:: anomalib.data.dataclasses.torch +``` + +## Overview + +The module includes several categories of dataclasses: + +- **Base Classes**: Generic PyTorch-based data structures +- **Image Classes**: Specialized for image data processing +- **Video Classes**: Designed for video data handling +- **Depth Classes**: Specific to depth-based anomaly detection + +## Base Classes + +### DatasetItem + +```{eval-rst} +.. autoclass:: DatasetItem + :members: + :show-inheritance: +``` + +### Batch + +```{eval-rst} +.. autoclass:: Batch + :members: + :show-inheritance: +``` + +### InferenceBatch + +```{eval-rst} +.. autoclass:: InferenceBatch + :members: + :show-inheritance: +``` + +### ToNumpyMixin + +```{eval-rst} +.. autoclass:: ToNumpyMixin + :members: + :show-inheritance: +``` + +## Image Classes + +### ImageItem + +```{eval-rst} +.. autoclass:: ImageItem + :members: + :show-inheritance: +``` + +### ImageBatch + +```{eval-rst} +.. autoclass:: ImageBatch + :members: + :show-inheritance: +``` + +## Video Classes + +### VideoItem + +```{eval-rst} +.. autoclass:: VideoItem + :members: + :show-inheritance: +``` + +### VideoBatch + +```{eval-rst} +.. autoclass:: VideoBatch + :members: + :show-inheritance: +``` + +## Depth Classes + +### DepthItem + +```{eval-rst} +.. autoclass:: DepthItem + :members: + :show-inheritance: +``` + +### DepthBatch + +```{eval-rst} +.. autoclass:: DepthBatch + :members: + :show-inheritance: +``` + +## See Also + +- {doc}`../index` +- {doc}`../numpy` diff --git a/docs/source/markdown/guides/reference/data/datamodules/base/image.md b/docs/source/markdown/guides/reference/data/datamodules/base/image.md new file mode 100644 index 0000000000..aa08a3a351 --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/base/image.md @@ -0,0 +1,7 @@ +# Image Base Datamodule + +```{eval-rst} +.. automodule:: anomalib.data.datamodules.base.image + :members: + :show-inheritance: +``` diff --git a/docs/source/markdown/guides/reference/data/datamodules/base/index.md b/docs/source/markdown/guides/reference/data/datamodules/base/index.md new file mode 100644 index 0000000000..f85477c722 --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/base/index.md @@ -0,0 +1,21 @@ +# Base Datamodules + +Base DataModules provide core functionality for specific data types that can be extended by other DataModules. + +::::{grid} 2 +:gutter: 2 + +:::{card} Image DataModule +:link: image +:link-type: doc + +Base DataModule for image-based anomaly detection tasks. +::: + +:::{card} Video DataModule +:link: video +:link-type: doc + +Base DataModule for video-based anomaly detection tasks. +::: +:::: diff --git a/docs/source/markdown/guides/reference/data/datamodules/base/video.md b/docs/source/markdown/guides/reference/data/datamodules/base/video.md new file mode 100644 index 0000000000..4e6522f5a6 --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/base/video.md @@ -0,0 +1,7 @@ +# Video Base Datamodule + +```{eval-rst} +.. 
automodule:: anomalib.data.datamodules.base.video + :members: + :show-inheritance: +``` diff --git a/docs/source/markdown/guides/reference/data/datamodules/depth/folder_3d.md b/docs/source/markdown/guides/reference/data/datamodules/depth/folder_3d.md new file mode 100644 index 0000000000..1b29484123 --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/depth/folder_3d.md @@ -0,0 +1,7 @@ +# Folder Datamodule + +```{eval-rst} +.. automodule:: anomalib.data.datamodules.depth.folder_3d + :members: + :show-inheritance: +``` diff --git a/docs/source/markdown/guides/reference/data/datamodules/depth/index.md b/docs/source/markdown/guides/reference/data/datamodules/depth/index.md new file mode 100644 index 0000000000..9e8ce32a4e --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/depth/index.md @@ -0,0 +1,31 @@ +# Depth Datamodules + +Anomalib provides datamodules for handling depth-based anomaly detection datasets. These datamodules are designed to work with both RGB and depth information for 3D anomaly detection tasks. + +## Available Datamodules + +```{grid} 2 +:gutter: 2 + +:::{grid-item-card} MVTec 3D +:link: mvtec_3d +:link-type: doc + +MVTec 3D-AD dataset datamodule for unsupervised 3D anomaly detection and localization. +::: + +:::{grid-item-card} Folder 3D +:link: folder_3d +:link-type: doc + +Custom folder-based 3D datamodule for organizing your own depth-based anomaly detection dataset. +::: +``` + +```{toctree} +:hidden: +:maxdepth: 1 + +mvtec_3d +folder_3d +``` diff --git a/docs/source/markdown/guides/reference/data/datamodules/depth/mvtec_3d.md b/docs/source/markdown/guides/reference/data/datamodules/depth/mvtec_3d.md new file mode 100644 index 0000000000..ae5c2cd842 --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/depth/mvtec_3d.md @@ -0,0 +1,7 @@ +# MVTec 3D Datamodule + +```{eval-rst} +.. automodule:: anomalib.data.datamodules.depth.mvtec_3d + :members: + :show-inheritance: +``` diff --git a/docs/source/markdown/guides/reference/data/datamodules/image.md b/docs/source/markdown/guides/reference/data/datamodules/image.md new file mode 100644 index 0000000000..c79a299a92 --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/image.md @@ -0,0 +1,60 @@ +# Image Datamodules + +Image datamodules in Anomalib are designed to handle image-based anomaly detection datasets. They provide a standardized interface for loading and processing image data for both training and inference. + +## Available Datamodules + +```{grid} 3 +:gutter: 2 + +:::{grid-item-card} BTech +:link: anomalib.data.datamodules.image.BTech +:link-type: doc + +Surface defect detection in steel manufacturing. +::: + +:::{grid-item-card} Datumaro +:link: anomalib.data.datamodules.image.Datumaro +:link-type: doc + +Dataset format compatible with Intel Geti™. +::: + +:::{grid-item-card} Folder +:link: anomalib.data.datamodules.image.Folder +:link-type: doc + +Custom folder-based dataset organization. +::: + +:::{grid-item-card} Kolektor +:link: anomalib.data.datamodules.image.Kolektor +:link-type: doc + +Surface defect detection in electrical commutators. +::: + +:::{grid-item-card} MVTec +:link: anomalib.data.datamodules.image.MVTec +:link-type: doc + +Industrial anomaly detection benchmark. +::: + +:::{grid-item-card} Visa +:link: anomalib.data.datamodules.image.Visa +:link-type: doc + +Visual inspection of surface anomalies. +::: +``` + +## API Reference + +```{eval-rst} +.. 
automodule:: anomalib.data + :members: BTech, Datumaro, Folder, Kolektor, MVTec, Visa + :undoc-members: + :show-inheritance: +``` diff --git a/docs/source/markdown/guides/reference/data/datamodules/image/btech.md b/docs/source/markdown/guides/reference/data/datamodules/image/btech.md new file mode 100644 index 0000000000..5abeb7c4c8 --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/image/btech.md @@ -0,0 +1,7 @@ +# Btech Datamodule + +```{eval-rst} +.. automodule:: anomalib.data.datamodules.image.btech + :members: + :show-inheritance: +``` diff --git a/docs/source/markdown/guides/reference/data/datamodules/image/datumaro.md b/docs/source/markdown/guides/reference/data/datamodules/image/datumaro.md new file mode 100644 index 0000000000..a26a8e4e9b --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/image/datumaro.md @@ -0,0 +1,7 @@ +# Datumaro Datamodule + +```{eval-rst} +.. automodule:: anomalib.data.datamodules.image.datumaro + :members: + :show-inheritance: +``` diff --git a/docs/source/markdown/guides/reference/data/datamodules/image/folder.md b/docs/source/markdown/guides/reference/data/datamodules/image/folder.md new file mode 100644 index 0000000000..918d5a22e3 --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/image/folder.md @@ -0,0 +1,7 @@ +# Folder Datamodule + +```{eval-rst} +.. automodule:: anomalib.data.datamodules.image.folder + :members: + :show-inheritance: +``` diff --git a/docs/source/markdown/guides/reference/data/datamodules/image/index.md b/docs/source/markdown/guides/reference/data/datamodules/image/index.md new file mode 100644 index 0000000000..50e1c4a86d --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/image/index.md @@ -0,0 +1,63 @@ +# Image Datamodules + +Anomalib provides various datamodules for handling image-based anomaly detection datasets. These datamodules support both standard image datasets and custom folder structures. + +## Available Datamodules + +```{grid} 3 +:gutter: 2 + +:::{grid-item-card} BTech +:link: btech +:link-type: doc + +BTech dataset datamodule for surface defect detection. +::: + +:::{grid-item-card} Datumaro +:link: datumaro +:link-type: doc + +Datumaro format datamodule (compatible with Intel Geti™). +::: + +:::{grid-item-card} Folder +:link: folder +:link-type: doc + +Custom folder-based datamodule for organizing your own image dataset. +::: + +:::{grid-item-card} Kolektor +:link: kolektor +:link-type: doc + +Kolektor Surface-Defect dataset datamodule. +::: + +:::{grid-item-card} MVTec +:link: mvtec +:link-type: doc + +MVTec AD dataset datamodule for unsupervised anomaly detection. +::: + +:::{grid-item-card} Visa +:link: visa +:link-type: doc + +Visual Inspection of Surface Anomalies (VisA) dataset datamodule. +::: +``` + +```{toctree} +:hidden: +:maxdepth: 1 + +btech +datumaro +folder +kolektor +mvtec +visa +``` diff --git a/docs/source/markdown/guides/reference/data/datamodules/image/kolektor.md b/docs/source/markdown/guides/reference/data/datamodules/image/kolektor.md new file mode 100644 index 0000000000..e86a321ea2 --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/image/kolektor.md @@ -0,0 +1,7 @@ +# Kolektor Datamodule + +```{eval-rst} +.. 
automodule:: anomalib.data.datamodules.image.kolektor + :members: + :show-inheritance: +``` diff --git a/docs/source/markdown/guides/reference/data/datamodules/image/mvtec.md b/docs/source/markdown/guides/reference/data/datamodules/image/mvtec.md new file mode 100644 index 0000000000..3ef6847c0d --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/image/mvtec.md @@ -0,0 +1,7 @@ +# MVTec Datamodule + +```{eval-rst} +.. automodule:: anomalib.data.datamodules.image.mvtec + :members: + :show-inheritance: +``` diff --git a/docs/source/markdown/guides/reference/data/datamodules/image/visa.md b/docs/source/markdown/guides/reference/data/datamodules/image/visa.md new file mode 100644 index 0000000000..0c0bed8b45 --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/image/visa.md @@ -0,0 +1,7 @@ +# Visa Datamodule + +```{eval-rst} +.. automodule:: anomalib.data.datamodules.image.visa + :members: + :show-inheritance: +``` diff --git a/docs/source/markdown/guides/reference/data/datamodules/index.md b/docs/source/markdown/guides/reference/data/datamodules/index.md new file mode 100644 index 0000000000..699c4e7b3c --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/index.md @@ -0,0 +1,107 @@ +# Datamodules + +Anomalib provides various datamodules for different types of data modalities. These datamodules are organized into three main categories: + +## Image Datamodules + +```{grid} 3 +:gutter: 2 + +:::{grid-item-card} BTech +:link: image/btech +:link-type: doc + +BTech dataset datamodule for surface defect detection. +::: + +:::{grid-item-card} Datumaro +:link: image/datumaro +:link-type: doc + +Datumaro format datamodule (compatible with Intel Geti™). +::: + +:::{grid-item-card} Folder +:link: image/folder +:link-type: doc + +Custom folder-based datamodule for organizing your own image dataset. +::: + +:::{grid-item-card} Kolektor +:link: image/kolektor +:link-type: doc + +Kolektor Surface-Defect dataset datamodule. +::: + +:::{grid-item-card} MVTec +:link: image/mvtec +:link-type: doc + +MVTec AD dataset datamodule for unsupervised anomaly detection. +::: + +:::{grid-item-card} Visa +:link: image/visa +:link-type: doc + +Visual Inspection of Surface Anomalies (VisA) dataset datamodule. +::: +``` + +## Video Datamodules + +```{grid} 3 +:gutter: 2 + +:::{grid-item-card} Avenue +:link: video/avenue +:link-type: doc + +CUHK Avenue dataset datamodule for video anomaly detection. +::: + +:::{grid-item-card} ShanghaiTech +:link: video/shanghaitech +:link-type: doc + +ShanghaiTech dataset datamodule for video anomaly detection. +::: + +:::{grid-item-card} UCSDped +:link: video/ucsdped +:link-type: doc + +UCSD Pedestrian dataset datamodule for video anomaly detection. +::: +``` + +```{toctree} +:hidden: +:maxdepth: 1 + +depth/index +image/index +video/index +``` + +## Depth Datamodules + +```{grid} 2 +:gutter: 2 + +:::{grid-item-card} MVTec 3D +:link: depth/mvtec_3d +:link-type: doc + +MVTec 3D-AD dataset datamodule for unsupervised 3D anomaly detection and localization. +::: + +:::{grid-item-card} Folder 3D +:link: depth/folder_3d +:link-type: doc + +Custom folder-based 3D datamodule for organizing your own depth-based anomaly detection dataset. 
+::: +``` diff --git a/docs/source/markdown/guides/reference/data/datamodules/video.md b/docs/source/markdown/guides/reference/data/datamodules/video.md new file mode 100644 index 0000000000..1899e91f51 --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/video.md @@ -0,0 +1,39 @@ +# Video Datamodules + +Video datamodules in Anomalib are designed to handle video-based anomaly detection datasets. They provide a standardized interface for loading and processing video data for both training and inference. + +## Available Datamodules + +```{grid} 3 +:gutter: 2 + +:::{grid-item-card} Avenue +:link: anomalib.data.Avenue +:link-type: doc + +CUHK Avenue dataset for video anomaly detection. +::: + +:::{grid-item-card} ShanghaiTech +:link: anomalib.data.ShanghaiTech +:link-type: doc + +ShanghaiTech dataset for video anomaly detection. +::: + +:::{grid-item-card} UCSDped +:link: anomalib.data.UCSDped +:link-type: doc + +UCSD Pedestrian dataset for video anomaly detection. +::: +``` + +## API Reference + +```{eval-rst} +.. automodule:: anomalib.data + :members: Avenue, ShanghaiTech, UCSDped + :undoc-members: + :show-inheritance: +``` diff --git a/docs/source/markdown/guides/reference/data/datamodules/video/avenue.md b/docs/source/markdown/guides/reference/data/datamodules/video/avenue.md new file mode 100644 index 0000000000..3f4cf842a0 --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/video/avenue.md @@ -0,0 +1,7 @@ +#  Avenue Datamodule + +```{eval-rst} +.. automodule:: anomalib.data.datamodules.video.avenue + :members: + :show-inheritance: +``` diff --git a/docs/source/markdown/guides/reference/data/datamodules/video/index.md b/docs/source/markdown/guides/reference/data/datamodules/video/index.md new file mode 100644 index 0000000000..0be1043020 --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/video/index.md @@ -0,0 +1,39 @@ +# Video Datamodules + +Anomalib provides datamodules for handling video-based anomaly detection datasets. These datamodules are specifically designed to work with video sequences and support various video anomaly detection benchmarks. + +## Available Datamodules + +```{grid} 3 +:gutter: 2 + +:::{grid-item-card} Avenue +:link: avenue +:link-type: doc + +CUHK Avenue dataset datamodule for video anomaly detection. +::: + +:::{grid-item-card} ShanghaiTech +:link: shanghaitech +:link-type: doc + +ShanghaiTech dataset datamodule for video anomaly detection. +::: + +:::{grid-item-card} UCSDped +:link: ucsdped +:link-type: doc + +UCSD Pedestrian dataset datamodule for video anomaly detection. +::: +``` + +```{toctree} +:hidden: +:maxdepth: 1 + +avenue +shanghaitech +ucsdped +``` diff --git a/docs/source/markdown/guides/reference/data/datamodules/video/shanghaitech.md b/docs/source/markdown/guides/reference/data/datamodules/video/shanghaitech.md new file mode 100644 index 0000000000..fa919af894 --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/video/shanghaitech.md @@ -0,0 +1,7 @@ +# ShanghaiTech Datamodule + +```{eval-rst} +.. automodule:: anomalib.data.datamodules.video.shanghaitech + :members: + :show-inheritance: +``` diff --git a/docs/source/markdown/guides/reference/data/datamodules/video/ucsdped.md b/docs/source/markdown/guides/reference/data/datamodules/video/ucsdped.md new file mode 100644 index 0000000000..dffac0534b --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/video/ucsdped.md @@ -0,0 +1,7 @@ +# UCSDped Datamodule + +```{eval-rst} +.. 
automodule:: anomalib.data.datamodules.video.ucsd_ped + :members: + :show-inheritance: +``` diff --git a/docs/source/markdown/guides/reference/data/depth/folder_3d.md b/docs/source/markdown/guides/reference/data/depth/folder_3d.md deleted file mode 100644 index 3b0f93280d..0000000000 --- a/docs/source/markdown/guides/reference/data/depth/folder_3d.md +++ /dev/null @@ -1,7 +0,0 @@ -# Folder 3D Data - -```{eval-rst} -.. automodule:: anomalib.data.depth.folder_3d - :members: - :show-inheritance: -``` diff --git a/docs/source/markdown/guides/reference/data/depth/index.md b/docs/source/markdown/guides/reference/data/depth/index.md deleted file mode 100644 index ca63ee078d..0000000000 --- a/docs/source/markdown/guides/reference/data/depth/index.md +++ /dev/null @@ -1,27 +0,0 @@ -# Depth Data - -::::{grid} - -:::{grid-item-card} Folder 3D -:link: ./folder_3d -:link-type: doc - -Learn more about custom folder 3D dataset. -::: - -:::{grid-item-card} MVTec 3D -:link: ./mvtec_3d -:link-type: doc - -Learn more about MVTec 3D dataset -::: - -:::: - -```{toctree} -:caption: Depth -:hidden: - -./folder_3d -./mvtec_3d -``` diff --git a/docs/source/markdown/guides/reference/data/depth/mvtec_3d.md b/docs/source/markdown/guides/reference/data/depth/mvtec_3d.md deleted file mode 100644 index dfcf4fd814..0000000000 --- a/docs/source/markdown/guides/reference/data/depth/mvtec_3d.md +++ /dev/null @@ -1,7 +0,0 @@ -# MVTec 3D Data - -```{eval-rst} -.. automodule:: anomalib.data.depth.mvtec_3d - :members: - :show-inheritance: -``` diff --git a/docs/source/markdown/guides/reference/data/image/btech.md b/docs/source/markdown/guides/reference/data/image/btech.md deleted file mode 100644 index 92199ccd1b..0000000000 --- a/docs/source/markdown/guides/reference/data/image/btech.md +++ /dev/null @@ -1,7 +0,0 @@ -#  BTech Data - -```{eval-rst} -.. automodule:: anomalib.data.image.btech - :members: - :show-inheritance: -``` diff --git a/docs/source/markdown/guides/reference/data/image/folder.md b/docs/source/markdown/guides/reference/data/image/folder.md deleted file mode 100644 index 307262b9c4..0000000000 --- a/docs/source/markdown/guides/reference/data/image/folder.md +++ /dev/null @@ -1,7 +0,0 @@ -# Folder Data - -```{eval-rst} -.. automodule:: anomalib.data.image.folder - :members: - :show-inheritance: -``` diff --git a/docs/source/markdown/guides/reference/data/image/index.md b/docs/source/markdown/guides/reference/data/image/index.md deleted file mode 100644 index 2525d0d914..0000000000 --- a/docs/source/markdown/guides/reference/data/image/index.md +++ /dev/null @@ -1,51 +0,0 @@ -# Image Data - -::::{grid} - -:::{grid-item-card} BTech -:link: ./btech -:link-type: doc - -Learn more about BTech dataset. -::: - -:::{grid-item-card} Folder -:link: ./folder -:link-type: doc - -Learn more about custom folder dataset. -::: - -:::{grid-item-card} Kolektor -:link: ./kolektor -:link-type: doc - -Learn more about Kolektor dataset. -::: - -:::{grid-item-card} MVTec 2D -:link: ./mvtec -:link-type: doc - -Learn more about MVTec 2D dataset -::: - -:::{grid-item-card} Visa -:link: ./visa -:link-type: doc - -Learn more about Visa dataset. 
-::: - -:::: - -```{toctree} -:caption: Image -:hidden: - -./btech -./folder -./kolektor -./mvtec -./visa -``` diff --git a/docs/source/markdown/guides/reference/data/image/kolektor.md b/docs/source/markdown/guides/reference/data/image/kolektor.md deleted file mode 100644 index ace9d62127..0000000000 --- a/docs/source/markdown/guides/reference/data/image/kolektor.md +++ /dev/null @@ -1,7 +0,0 @@ -# Kolektor Data - -```{eval-rst} -.. automodule:: anomalib.data.image.kolektor - :members: - :show-inheritance: -``` diff --git a/docs/source/markdown/guides/reference/data/image/mvtec.md b/docs/source/markdown/guides/reference/data/image/mvtec.md deleted file mode 100644 index c0cbb77735..0000000000 --- a/docs/source/markdown/guides/reference/data/image/mvtec.md +++ /dev/null @@ -1,7 +0,0 @@ -# MVTec Data - -```{eval-rst} -.. automodule:: anomalib.data.image.mvtec - :members: - :show-inheritance: -``` diff --git a/docs/source/markdown/guides/reference/data/image/visa.md b/docs/source/markdown/guides/reference/data/image/visa.md deleted file mode 100644 index 43bfa7c7fb..0000000000 --- a/docs/source/markdown/guides/reference/data/image/visa.md +++ /dev/null @@ -1,7 +0,0 @@ -# Visa Data - -```{eval-rst} -.. automodule:: anomalib.data.image.visa - :members: - :show-inheritance: -``` diff --git a/docs/source/markdown/guides/reference/data/index.md b/docs/source/markdown/guides/reference/data/index.md index e646d06f91..26635a0380 100644 --- a/docs/source/markdown/guides/reference/data/index.md +++ b/docs/source/markdown/guides/reference/data/index.md @@ -1,46 +1,73 @@ # Data -Anomalib data can be categorized into four main types: base, image, video, and depth. Image, video and depth datasets are based on the base dataset and datamodule implementations. +A comprehensive data handling pipeline with modular components for anomaly detection tasks. -::::{grid} +::::{grid} 1 2 2 3 +:gutter: 3 +:padding: 2 +:class-container: landing-grid -:::{grid-item-card} {octicon}`copy` Base Classes -:link: ./base/index +:::{grid-item-card} {octicon}`package` Data Classes +:link: ./dataclasses/index :link-type: doc +:class-card: custom-card -Learn more about base anomalib data interfaces. +Core data structures that define how data is represented and validated throughout the pipeline. Features type-safe containers, dual backend support, and automatic validation. + ++++ +[Learn more »](./dataclasses/index) ::: -:::{grid-item-card} {octicon}`file-media` Image -:link: ./image/index +:::{grid-item-card} {octicon}`database` Datasets +:link: ./datasets/index :link-type: doc +:class-card: custom-card + +Ready-to-use PyTorch Dataset implementations of standard benchmark datasets (MVTec, BTech) and support for custom datasets across multiple modalities (Image, Video, Depth). -Learn more about anomalib image datasets. ++++ +[Learn more »](./datasets/index) ::: -:::{grid-item-card} {octicon}`video` Video -:link: ./video/index +:::{grid-item-card} {octicon}`workflow` Data Modules +:link: ./datamodules/index :link-type: doc +:class-card: custom-card + +Lightning implementations of these PyTorch datasets that provide automated data loading, train/val/test splitting, and distributed training support through the PyTorch Lightning DataModule interface. -Learn more about anomalib video datasets. 
++++ +[Learn more »](./datamodules/index) ::: +:::: + +## Additional Resources -:::{grid-item-card} {octicon}`database` Depth -:link: ./depth/index +::::{grid} 2 2 2 2 +:gutter: 2 +:padding: 1 + +:::{grid-item-card} {octicon}`tools` Data Utils +:link: ./utils/index :link-type: doc -Learn more about anomalib depth datasets. +Helper functions and utilities for data processing and augmentation. ::: +:::{grid-item-card} {octicon}`book` Tutorials +:link: ../tutorials/index +:link-type: doc + +Step-by-step guides on using the data components. +::: :::: ```{toctree} -:caption: Data +:caption: Data Components :hidden: -./base/index -./image/index -./video/index -./depth/index +./dataclasses/index +./datasets/index +./datamodules/index ./utils/index ``` diff --git a/docs/source/markdown/guides/reference/data/utils/index.md b/docs/source/markdown/guides/reference/data/utils/index.md index 7a7fc97efa..1e388c2080 100644 --- a/docs/source/markdown/guides/reference/data/utils/index.md +++ b/docs/source/markdown/guides/reference/data/utils/index.md @@ -1,6 +1,6 @@ # Data Utils -::::{grid} 1 2 2 2 +::::{grid} 1 3 3 3 :margin: 1 1 0 0 :gutter: 1 @@ -11,13 +11,6 @@ Learn more about anomalib API and CLI. ::: -:::{grid-item-card} {octicon}`question` Data Transforms -:link: transforms -:link-type: doc - -Learn how to use anomalib for your anomaly detection tasks. -::: - :::{grid-item-card} {octicon}`telescope` Tiling :link: tiling :link-type: doc diff --git a/docs/source/markdown/guides/reference/data/utils/transforms.md b/docs/source/markdown/guides/reference/data/utils/transforms.md deleted file mode 100644 index e59e9ae3d6..0000000000 --- a/docs/source/markdown/guides/reference/data/utils/transforms.md +++ /dev/null @@ -1,7 +0,0 @@ -# Data Transforms - -```{eval-rst} -.. automodule:: anomalib.data.utils.transforms - :members: - :show-inheritance: -``` diff --git a/docs/source/markdown/guides/reference/data/video/avenue.md b/docs/source/markdown/guides/reference/data/video/avenue.md deleted file mode 100644 index 2dd00dce4d..0000000000 --- a/docs/source/markdown/guides/reference/data/video/avenue.md +++ /dev/null @@ -1,7 +0,0 @@ -# Avenue Data - -```{eval-rst} -.. automodule:: anomalib.data.video.avenue - :members: - :show-inheritance: -``` diff --git a/docs/source/markdown/guides/reference/data/video/index.md b/docs/source/markdown/guides/reference/data/video/index.md deleted file mode 100644 index 9f357053fa..0000000000 --- a/docs/source/markdown/guides/reference/data/video/index.md +++ /dev/null @@ -1,35 +0,0 @@ -# Video Data - -::::{grid} - -:::{grid-item-card} Avenue -:link: ./avenue -:link-type: doc - -Learn more about Avenue dataset. -::: - -:::{grid-item-card} Shanghai Tech -:link: ./shanghaitech -:link-type: doc - -Learn more about Shanghai Tech dataset. -::: - -:::{grid-item-card} UCSD -:link: ./ucsd_ped -:link-type: doc - -Learn more about UCSD Ped1 and Ped2 datasets. -::: - -:::: - -```{toctree} -:caption: Image -:hidden: - -./avenue -./shanghaitech -./ucsd_ped -``` diff --git a/docs/source/markdown/guides/reference/data/video/shanghaitech.md b/docs/source/markdown/guides/reference/data/video/shanghaitech.md deleted file mode 100644 index 38b9ea77c0..0000000000 --- a/docs/source/markdown/guides/reference/data/video/shanghaitech.md +++ /dev/null @@ -1,7 +0,0 @@ -# Shanghai Tech Data - -```{eval-rst} -.. 
automodule:: anomalib.data.video.shanghaitech - :members: - :show-inheritance: -``` diff --git a/docs/source/markdown/guides/reference/data/video/ucsd_ped.md b/docs/source/markdown/guides/reference/data/video/ucsd_ped.md deleted file mode 100644 index 0236868341..0000000000 --- a/docs/source/markdown/guides/reference/data/video/ucsd_ped.md +++ /dev/null @@ -1,7 +0,0 @@ -# UCSD Data - -```{eval-rst} -.. automodule:: anomalib.data.video.ucsd_ped - :members: - :show-inheritance: -``` diff --git a/docs/source/markdown/guides/reference/deploy/index.md b/docs/source/markdown/guides/reference/deploy/index.md index 58dee6829c..463dcd80b4 100644 --- a/docs/source/markdown/guides/reference/deploy/index.md +++ b/docs/source/markdown/guides/reference/deploy/index.md @@ -1,4 +1,4 @@ -# Deployment +# Inference ```{eval-rst} .. automodule:: anomalib.deploy diff --git a/docs/source/markdown/guides/reference/engine/index.md b/docs/source/markdown/guides/reference/engine/index.md index 629a8bdd0b..25b4251880 100644 --- a/docs/source/markdown/guides/reference/engine/index.md +++ b/docs/source/markdown/guides/reference/engine/index.md @@ -1,7 +1,8 @@ # Engine ```{eval-rst} -.. automodule:: anomalib.engine +.. currentmodule:: anomalib.engine.engine +.. autoclass:: Engine :members: :show-inheritance: ``` diff --git a/docs/source/markdown/guides/reference/index.md b/docs/source/markdown/guides/reference/index.md index 435569f5f2..b5931cf29c 100644 --- a/docs/source/markdown/guides/reference/index.md +++ b/docs/source/markdown/guides/reference/index.md @@ -2,86 +2,138 @@ This section contains the API and CLI reference for anomalib. -::::{grid} 1 2 2 3 -:margin: 1 1 0 0 -:gutter: 1 +## Core Components + +::::{grid} 2 2 2 3 +:gutter: 2 +:padding: 1 :::{grid-item-card} {octicon}`database` Data :link: ./data/index :link-type: doc -Learn more about anomalib datamodules. +Core component for data handling and datasets. ::: :::{grid-item-card} {octicon}`dependabot` Models :link: ./models/index :link-type: doc -Learn more about image and video models. +Anomaly detection model implementations. ::: :::{grid-item-card} {octicon}`gear` Engine :link: ./engine/index :link-type: doc -Learn more about anomalib Engine. +Core training and inference engine. +::: +:::: + +## Processing & Analysis + +::::{grid} 2 2 2 3 +:gutter: 2 +:padding: 1 + +:::{grid-item-card} {octicon}`filter` Pre-processing +:link: ./pre_processing/index +:link-type: doc + +Data preparation and augmentation. +::: + +:::{grid-item-card} {octicon}`filter` Post-processing +:link: ./post_processing/index +:link-type: doc + +Anomaly map processing and thresholding. ::: :::{grid-item-card} {octicon}`meter` Metrics :link: ./metrics/index :link-type: doc -Learn more about anomalib metrics +Performance evaluation metrics. ::: +:::: + +## Framework Components + +::::{grid} 2 2 2 3 +:gutter: 2 +:padding: 1 :::{grid-item-card} {octicon}`graph` Loggers :link: ./loggers/index :link-type: doc -Learn more about anomalib loggers +Experiment logging and tracking. ::: :::{grid-item-card} {octicon}`gear` Callbacks :link: ./callbacks/index :link-type: doc -Learn more about anomalib callbacks +Training callbacks and hooks. ::: -:::{grid-item-card} {octicon}`code-square` CLI -:link: ./cli/index +:::{grid-item-card} {octicon}`workflow` Pipelines +:link: ./pipelines/index :link-type: doc -Learn more about anomalib CLI +Training and optimization pipelines. 
::: -:::{grid-item-card} {octicon}`cpu` Deployment -:link: ./deploy/index +:::{grid-item-card} {octicon}`image` Visualization +:link: ./visualization/index :link-type: doc -Learn more about anomalib CLI +Result visualization tools. ::: -:::{grid-item-card} {octicon}`workflow` Pipelines -:link: ./pipelines/index +:::{grid-item-card} {octicon}`tools` Utils +:link: ./utils/index :link-type: doc -Learn more about anomalib hpo, sweep and benchmarking pipelines +Utility functions and helpers. ::: +:::{grid-item-card} {octicon}`terminal` CLI +:link: ./cli/index +:link-type: doc + +Command line interface tools. +::: +:::: + +::::{grid} 1 +:gutter: 2 +:padding: 1 + +:::{grid-item-card} {octicon}`cpu` Inference +:link: ./deploy/index +:link-type: doc + +Model inference and optimization. +::: :::: ```{toctree} -:caption: Data +:caption: Reference :hidden: ./data/index ./models/index ./engine/index +./pre_processing/index +./post_processing/index ./metrics/index ./loggers/index ./callbacks/index +./pipelines/index +./visualization/index +./utils/index ./cli/index ./deploy/index -./pipelines/index ``` diff --git a/docs/source/markdown/guides/reference/loggers/index.md b/docs/source/markdown/guides/reference/loggers/index.md index 6f89dc102c..de1ce52213 100644 --- a/docs/source/markdown/guides/reference/loggers/index.md +++ b/docs/source/markdown/guides/reference/loggers/index.md @@ -1,8 +1,73 @@ # Loggers +```{grid} 2 +:gutter: 2 + +:::{card} Comet Logger +:link: comet-logger +:link-type: ref + +Monitor your experiments with Comet's comprehensive ML platform. +::: + +:::{card} Wandb Logger +:link: wandb-logger +:link-type: ref + +Track and visualize your ML experiments with Weights & Biases. +::: + +:::{card} Tensorboard Logger +:link: tensorboard-logger +:link-type: ref + +Visualize your training metrics with TensorBoard. +::: + +:::{card} MLFlow Logger +:link: mlflow-logger +:link-type: ref + +Track and manage your ML lifecycle with MLflow. +::: +``` + +(comet-logger)= + +## Comet Logger + +```{eval-rst} +.. automodule:: anomalib.loggers.comet + :members: + :show-inheritance: +``` + +(wandb-logger)= + +## Wandb Logger + +```{eval-rst} +.. automodule:: anomalib.loggers.wandb + :members: + :show-inheritance: +``` + +(tensorboard-logger)= + +## Tensorboard Logger + +```{eval-rst} +.. automodule:: anomalib.loggers.tensorboard + :members: + :show-inheritance: +``` + +(mlflow-logger)= + +## MLFlow Logger + ```{eval-rst} -.. automodule:: anomalib.loggers +.. automodule:: anomalib.loggers.mlflow :members: - :exclude-members: get_experiment_logger, configure_logger :show-inheritance: ``` diff --git a/docs/source/markdown/guides/reference/models/image/fre.md b/docs/source/markdown/guides/reference/models/image/fre.md new file mode 100644 index 0000000000..180f8d3775 --- /dev/null +++ b/docs/source/markdown/guides/reference/models/image/fre.md @@ -0,0 +1,13 @@ +# FRE + +```{eval-rst} +.. automodule:: anomalib.models.image.fre.lightning_model + :members: + :show-inheritance: +``` + +```{eval-rst} +.. 
automodule:: anomalib.models.image.fre.torch_model + :members: + :show-inheritance: +``` diff --git a/docs/source/markdown/guides/reference/models/image/index.md b/docs/source/markdown/guides/reference/models/image/index.md index a872a2c7b2..cabd819860 100644 --- a/docs/source/markdown/guides/reference/models/image/index.md +++ b/docs/source/markdown/guides/reference/models/image/index.md @@ -67,6 +67,13 @@ EfficientAD: Accurate Visual Anomaly Detection at Millisecond-Level Latencies FastFlow: Unsupervised Anomaly Detection and Localization via 2D Normalizing Flows ::: +:::{grid-item-card} {material-regular}`model_training;1.5em` FRE +:link: ./fre +:link-type: doc + +FRE: A Fast Method For Anomaly Detection And Segmentation +::: + :::{grid-item-card} {material-regular}`model_training;1.5em` GANomaly :link: ./ganomaly :link-type: doc @@ -109,6 +116,13 @@ Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection U-Flow: A U-shaped Normalizing Flow for Anomaly Detection with Unsupervised Threshold ::: +:::{grid-item-card} {material-regular}`model_training;1.5em` VLM-AD +:link: ./vlm_ad +:link-type: doc + +VLM-AD: Vision-Language Model for Anomaly Detection +::: + :::{grid-item-card} {material-regular}`model_training;1.5em` WinCLIP :link: ./winclip :link-type: doc @@ -130,6 +144,7 @@ WinCLIP: Zero-/Few-Shot Anomaly Classification and Segmentation ./dsr ./efficient_ad ./fastflow +./fre ./ganomaly ./padim ./patchcore diff --git a/docs/source/markdown/guides/reference/models/image/vlm_ad.md b/docs/source/markdown/guides/reference/models/image/vlm_ad.md new file mode 100644 index 0000000000..3869f74ebc --- /dev/null +++ b/docs/source/markdown/guides/reference/models/image/vlm_ad.md @@ -0,0 +1,8 @@ +# VLM-AD + +```{eval-rst} +.. automodule:: anomalib.models.image.vlm_ad + :members: + :show-inheritance: + :special-members: __all__ +``` diff --git a/docs/source/markdown/guides/reference/models/index.md b/docs/source/markdown/guides/reference/models/index.md index a8ad7ffa9d..bb705c403e 100644 --- a/docs/source/markdown/guides/reference/models/index.md +++ b/docs/source/markdown/guides/reference/models/index.md @@ -8,21 +8,30 @@ :link: ./components/index :link-type: doc -Learn more about components to design your own anomaly detection models. +Core building blocks and utilities for creating custom anomaly detection models, including feature extractors, anomaly scoring functions, and visualization tools. + ++++ +[Learn more »](./components/index) ::: :::{grid-item-card} {octicon}`file-media` Image Models :link: ./image/index :link-type: doc -Learn more about image anomaly detection models. +Collection of state-of-the-art deep learning models for detecting anomalies in images, including both reconstruction and embedding-based approaches. + ++++ +[Learn more »](./image/index) ::: :::{grid-item-card} {octicon}`video` Video Models :link: ./video/index :link-type: doc -Learn more about video anomaly detection models. +Advanced models designed specifically for anomaly detection in video sequences, leveraging temporal information and motion patterns. 
+
++++
+[Learn more »](./video/index)
 :::

 ::::
diff --git a/docs/source/markdown/guides/reference/post_processing/index.md b/docs/source/markdown/guides/reference/post_processing/index.md
new file mode 100644
index 0000000000..02bdb4d638
--- /dev/null
+++ b/docs/source/markdown/guides/reference/post_processing/index.md
@@ -0,0 +1,46 @@
+# Post-processing
+
+::::{grid} 1 2 2 2
+:gutter: 3
+:padding: 2
+
+:::{grid-item-card} {octicon}`gear` Base Post-processor
+:link: base-post-processor
+:link-type: ref
+
+Base class for post-processing.
+
++++
+[Learn more »](base-post-processor)
+:::
+
+:::{grid-item-card} {octicon}`gear` One-class Post-processor
+:link: one-class-post-processor
+:link-type: ref
+
+Post-processor for one-class anomaly detection.
+
++++
+[Learn more »](one-class-post-processor)
+:::
+::::
+
+(base-post-processor)=
+
+## Base Post-processor
+
+```{eval-rst}
+.. automodule:: anomalib.post_processing.base
+    :members:
+    :show-inheritance:
+```
+
+(one-class-post-processor)=
+
+## One-class Post-processor
+
+```{eval-rst}
+.. automodule:: anomalib.post_processing.one_class
+    :members:
+    :show-inheritance:
+```
diff --git a/docs/source/markdown/guides/reference/pre_processing/index.md b/docs/source/markdown/guides/reference/pre_processing/index.md
new file mode 100644
index 0000000000..23738e9fe5
--- /dev/null
+++ b/docs/source/markdown/guides/reference/pre_processing/index.md
@@ -0,0 +1,7 @@
+# Pre-processing
+
+```{eval-rst}
+.. automodule:: anomalib.pre_processing
+    :members:
+    :show-inheritance:
+```
diff --git a/docs/source/markdown/guides/topic/index.md b/docs/source/markdown/guides/topic/index.md
deleted file mode 100644
index bd8a29d718..0000000000
--- a/docs/source/markdown/guides/topic/index.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# Topic Guide
-
-This section contains design documents and other internals of anomalib.
-
-```{warning}
-This section is under construction 🚧
-```
diff --git a/pyproject.toml b/pyproject.toml
index efdad6e41c..5d72ebd91b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,7 +68,7 @@ loggers = [
 ]
 notebooks = ["gitpython", "ipykernel", "ipywidgets", "notebook"]
 docs = [
-    "myst-parser",
+    "myst-parser[linkify]",
     "nbsphinx",
     "pandoc",
     "sphinx",
diff --git a/src/anomalib/__init__.py b/src/anomalib/__init__.py
index 281e5df759..dde3a9da26 100644
--- a/src/anomalib/__init__.py
+++ b/src/anomalib/__init__.py
@@ -1,4 +1,35 @@
-"""Anomalib library for research and benchmarking."""
+"""Anomalib library for research and benchmarking.
+
+This library provides tools and utilities for anomaly detection research and
+benchmarking. The key components include:
+
+    - Multiple state-of-the-art anomaly detection models
+    - Standardized training and evaluation pipelines
+    - Support for various data formats and tasks
+    - Visualization and analysis tools
+    - Benchmarking utilities
+
+Example:
+    >>> from anomalib.data import MVTec
+    >>> from anomalib.engine import Engine
+    >>> from anomalib.models import Padim
+    >>> # Create a model and train it through the Engine
+    >>> model, datamodule, engine = Padim(), MVTec(), Engine()
+    >>> engine.fit(model, datamodule=datamodule)
+    >>> # Generate predictions
+    >>> predictions = engine.predict(model, datamodule=datamodule)
+
+The library supports:
+    - Classification and segmentation tasks
+    - One-class, zero-shot, and few-shot learning
+    - Multiple input formats (images, videos)
+    - Custom dataset integration
+    - Extensive configuration options
+
+Note:
+    The library is designed for both research and production use cases,
+    with a focus on reproducibility and ease of use.
+""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -9,7 +38,24 @@ class LearningType(str, Enum): - """Learning type defining how the model learns from the dataset samples.""" + """Learning type defining how the model learns from the dataset samples. + + This enum defines the different learning paradigms supported by anomalib models: + + - ``ONE_CLASS``: Model learns from a single class of normal samples + - ``ZERO_SHOT``: Model learns without any task-specific training samples + - ``FEW_SHOT``: Model learns from a small number of training samples + + Example: + >>> from anomalib import LearningType + >>> learning_type = LearningType.ONE_CLASS + >>> print(learning_type) + 'one_class' + + Note: + The learning type affects how the model is trained and what kind of data + it expects during training. + """ ONE_CLASS = "one_class" ZERO_SHOT = "zero_shot" @@ -17,7 +63,26 @@ class LearningType(str, Enum): class TaskType(str, Enum): - """Task type used when generating predictions on the dataset.""" + """Task type defining the model's prediction output format. + + This enum defines the different task types supported by anomalib models: + + - ``CLASSIFICATION``: Model predicts anomaly scores at the image level + - ``SEGMENTATION``: Model predicts pixel-wise anomaly scores and masks + + Example: + >>> from anomalib import TaskType + >>> task_type = TaskType.CLASSIFICATION + >>> print(task_type) + 'classification' + + Note: + The task type determines: + - The model architecture and output format + - Required ground truth annotation format + - Evaluation metrics used + - Visualization methods available + """ CLASSIFICATION = "classification" SEGMENTATION = "segmentation" diff --git a/src/anomalib/callbacks/__init__.py b/src/anomalib/callbacks/__init__.py index 38c9537ca4..087e36d620 100644 --- a/src/anomalib/callbacks/__init__.py +++ b/src/anomalib/callbacks/__init__.py @@ -1,6 +1,40 @@ -"""Callbacks for Anomalib models.""" +"""Callbacks for Anomalib models. -# Copyright (C) 2022 Intel Corporation +This module provides various callbacks used in Anomalib for model training, logging, and optimization. +The callbacks include model checkpointing, graph logging, model loading, tiler configuration, and timing. 
+ +The module exports the following callbacks: + +- :class:`ModelCheckpoint`: Save model checkpoints during training +- :class:`GraphLogger`: Log model computation graphs +- :class:`LoadModelCallback`: Load pre-trained model weights +- :class:`TilerConfigurationCallback`: Configure image tiling settings +- :class:`TimerCallback`: Track training/inference timing + +Example: + Get default callbacks based on configuration: + + >>> from anomalib.callbacks import get_callbacks + >>> from omegaconf import DictConfig + >>> config = DictConfig({"trainer": {}, "project": {"path": "/tmp"}}) + >>> callbacks = get_callbacks(config) + >>> isinstance(callbacks, list) + True + + Use callbacks in trainer: + + >>> import lightning.pytorch as pl + >>> trainer = pl.Trainer(callbacks=callbacks) + +See Also: + - :mod:`anomalib.callbacks.checkpoint`: Model checkpoint callback + - :mod:`anomalib.callbacks.graph`: Graph logging callback + - :mod:`anomalib.callbacks.model_loader`: Model loading callback + - :mod:`anomalib.callbacks.tiler_configuration`: Tiler configuration callback + - :mod:`anomalib.callbacks.timer`: Timer callback +""" + +# Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import logging @@ -31,13 +65,51 @@ def get_callbacks(config: DictConfig | ListConfig | Namespace) -> list[Callback]: - """Return base callbacks for all the lightning models. + """Get default callbacks for Anomalib models based on configuration. + + This function returns a list of callbacks based on the provided configuration. + It automatically adds: + + - Model loading callback if checkpoint path is specified + - NNCF optimization callback if NNCF optimization is enabled Args: - config (DictConfig | ListConfig | Namespace): Model config + config (DictConfig | ListConfig | Namespace): Configuration object containing model and training settings. + Expected to have the following structure: + + .. code-block:: yaml + + trainer: + ckpt_path: Optional[str] # Path to model checkpoint + optimization: + nncf: + apply: bool # Whether to apply NNCF optimization + # Other NNCF config options + project: + path: str # Project directory path + + Returns: + list[Callback]: List of PyTorch Lightning callbacks to be used during training. + May include: + + - :class:`LoadModelCallback`: For loading model checkpoints + - :class:`NNCFCallback`: For neural network compression + - Other default callbacks + + Example: + >>> from omegaconf import DictConfig + >>> config = DictConfig({ + ... "trainer": {"ckpt_path": None}, + ... "project": {"path": "/tmp"}, + ... "optimization": {"nncf": {"apply": False}} + ... }) + >>> callbacks = get_callbacks(config) + >>> isinstance(callbacks, list) + True - Return: - (list[Callback]): List of callbacks. + Note: + NNCF is imported dynamically only when required since it conflicts with + some kornia JIT operations. """ logger.info("Loading the callbacks") diff --git a/src/anomalib/callbacks/checkpoint.py b/src/anomalib/callbacks/checkpoint.py index 7d7b4bb7d5..30114b1b99 100644 --- a/src/anomalib/callbacks/checkpoint.py +++ b/src/anomalib/callbacks/checkpoint.py @@ -1,4 +1,27 @@ -"""Anomalib Model Checkpoint Callback.""" +"""Anomalib Model Checkpoint Callback. + +This module provides the :class:`ModelCheckpoint` callback that extends PyTorch Lightning's +:class:`~lightning.pytorch.callbacks.ModelCheckpoint` to support zero-shot and few-shot learning scenarios. 
+ +The callback enables checkpoint saving without requiring training steps, which is particularly useful for +zero-shot and few-shot learning models where the training process may only involve validation. + +Example: + Create and use a checkpoint callback: + + >>> from anomalib.callbacks import ModelCheckpoint + >>> checkpoint_callback = ModelCheckpoint( + ... dirpath="checkpoints", + ... filename="best", + ... monitor="val_loss" + ... ) + >>> from lightning.pytorch import Trainer + >>> trainer = Trainer(callbacks=[checkpoint_callback]) + +Note: + This callback is particularly important for zero-shot and few-shot models where + traditional training-based checkpoint saving strategies may not be appropriate. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -11,29 +34,61 @@ class ModelCheckpoint(LightningCheckpoint): - """Anomalib Model Checkpoint Callback. - - This class overrides the Lightning ModelCheckpoint callback to enable saving checkpoints without running any - training steps. This is useful for zero-/few-shot models, where the fit sequence only consists of validation. - - To enable saving checkpoints without running any training steps, we need to override two checks which are being - called in the ``on_validation_end`` method of the parent class: - - ``_should_save_on_train_epoch_end``: This method checks whether the checkpoint should be saved at the end of a - training epoch, or at the end of the validation sequence. We modify this method to default to saving at the end - of the validation sequence when the model is of zero- or few-shot type, unless ``save_on_train_epoch_end`` is - specifically set by the user. - - ``_should_skip_saving_checkpoint``: This method checks whether the checkpoint should be saved at all. We modify - this method to allow saving during both the ``FITTING`` and ``VALIDATING`` states. In addition, we allow saving - if the global step has not changed since the last checkpoint, but only for zero- and few-shot models. This is - needed because both the last global step and the last checkpoint remain unchanged during zero-/few-shot - training, which would otherwise prevent saving checkpoints during validation. + """Custom ModelCheckpoint callback for Anomalib. + + This callback extends PyTorch Lightning's + :class:`~lightning.pytorch.callbacks.ModelCheckpoint` to enable checkpoint saving + without requiring training steps. This is particularly useful for zero-shot and few-shot + learning models where the training process may only involve validation. + + The callback overrides two key methods from the parent class: + + 1. :meth:`_should_save_on_train_epoch_end`: Controls whether checkpoints are saved at the end + of training epochs or validation sequences. For zero-shot and few-shot models, it defaults + to saving at validation end unless explicitly configured otherwise. + + 2. :meth:`_should_skip_saving_checkpoint`: Determines if checkpoint saving should be skipped. + Modified to: + + - Allow saving during both ``FITTING`` and ``VALIDATING`` states + - Permit saving even when global step hasn't changed (for zero-shot/few-shot models) + - Maintain standard checkpoint skipping conditions (``fast_dev_run``, sanity checking) + + Example: + Create and use a checkpoint callback: + + >>> from anomalib.callbacks import ModelCheckpoint + >>> # Create a checkpoint callback + >>> checkpoint_callback = ModelCheckpoint( + ... dirpath="checkpoints", + ... filename="best", + ... monitor="val_loss" + ... 
) + >>> # Use it with Lightning Trainer + >>> from lightning.pytorch import Trainer + >>> trainer = Trainer(callbacks=[checkpoint_callback]) + + Note: + All arguments from PyTorch Lightning's :class:`~lightning.pytorch.callbacks.ModelCheckpoint` are supported. + See :class:`~lightning.pytorch.callbacks.ModelCheckpoint` for details. """ def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool: - """Checks whether the checkpoint should be saved. + """Determine if checkpoint saving should be skipped. + + Args: + trainer (:class:`~lightning.pytorch.Trainer`): PyTorch Lightning trainer instance. + + Returns: + bool: ``True`` if checkpoint saving should be skipped, ``False`` otherwise. - Overrides the parent method to allow saving during both the ``FITTING`` and ``VALIDATING`` states, and to allow - saving when the global step and last_global_step_saved are both 0 (only for zero-/few-shot models). + Note: + The method considers the following conditions: + + - Skips if ``fast_dev_run`` is enabled + - Skips if not in ``FITTING`` or ``VALIDATING`` state + - Skips during sanity checking + - For non-zero/few-shot models, skips if global step hasn't changed """ is_zero_or_few_shot = trainer.lightning_module.learning_type in {LearningType.ZERO_SHOT, LearningType.FEW_SHOT} return ( @@ -44,10 +99,20 @@ def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool: ) def _should_save_on_train_epoch_end(self, trainer: Trainer) -> bool: - """Checks whether the checkpoint should be saved at the end of a training epoch or validation sequence. + """Determine if checkpoint should be saved at training epoch end. + + Args: + trainer (:class:`~lightning.pytorch.Trainer`): PyTorch Lightning trainer instance. + + Returns: + bool: ``True`` if checkpoint should be saved at training epoch end, ``False`` otherwise. + + Note: + The method follows this decision flow: - Overrides the parent method to default to saving at the end of the validation sequence when the model is of - zero- or few-shot type, unless ``save_on_train_epoch_end`` is specifically set by the user. + - Returns user-specified value if ``_save_on_train_epoch_end`` is set + - For zero/few-shot models, defaults to ``False`` (save at validation end) + - Otherwise, follows parent class behavior """ if self._save_on_train_epoch_end is not None: return self._save_on_train_epoch_end diff --git a/src/anomalib/callbacks/graph.py b/src/anomalib/callbacks/graph.py index 38864245f6..e73b1b9cdf 100644 --- a/src/anomalib/callbacks/graph.py +++ b/src/anomalib/callbacks/graph.py @@ -1,6 +1,38 @@ -"""Log model graph to respective logger.""" +"""Graph logging callback for model visualization. -# Copyright (C) 2022 Intel Corporation +This module provides the :class:`GraphLogger` callback for visualizing model architectures in various logging backends. +The callback supports TensorBoard, Comet, and Weights & Biases (W&B) logging. + +The callback automatically detects which logger is being used and +handles the graph logging appropriately for each backend. 
+ +Example: + Log model graph to TensorBoard: + + >>> from anomalib.callbacks import GraphLogger + >>> from anomalib.loggers import AnomalibTensorBoardLogger + >>> from anomalib.engine import Engine + >>> logger = AnomalibTensorBoardLogger() + >>> callbacks = [GraphLogger()] + >>> engine = Engine(logger=logger, callbacks=callbacks) + + Log model graph to Comet: + + >>> from anomalib.callbacks import GraphLogger + >>> from anomalib.loggers import AnomalibCometLogger + >>> from anomalib.engine import Engine + >>> logger = AnomalibCometLogger() + >>> callbacks = [GraphLogger()] + >>> engine = Engine(logger=logger, callbacks=callbacks) + +Note: + For TensorBoard and Comet, the graph is logged at the end of training. + For W&B, the graph is logged at the start of training but requires one backward pass + to be populated. This means it may not work for models that don't require training + (e.g., :class:`PaDiM`). +""" + +# Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import torch @@ -12,34 +44,45 @@ class GraphLogger(Callback): """Log model graph to respective logger. - Examples: - Log model graph to Tensorboard + This callback logs the model architecture graph to the configured logger. It supports multiple + logging backends including TensorBoard, Comet, and Weights & Biases (W&B). + + The callback automatically detects which logger is being used and handles the graph logging + appropriately for each backend. + + Example: + Create and use a graph logger: >>> from anomalib.callbacks import GraphLogger >>> from anomalib.loggers import AnomalibTensorBoardLogger - >>> from anomalib.engine import Engine - ... + >>> from lightning.pytorch import Trainer >>> logger = AnomalibTensorBoardLogger() - >>> callbacks = [GraphLogger()] - >>> engine = Engine(logger=logger, callbacks=callbacks) + >>> graph_logger = GraphLogger() + >>> trainer = Trainer(logger=logger, callbacks=[graph_logger]) - Log model graph to Comet - - >>> from anomalib.loggers import AnomalibCometLogger - >>> from anomalib.engine import Engine - ... - >>> logger = AnomalibCometLogger() - >>> callbacks = [GraphLogger()] - >>> engine = Engine(logger=logger, callbacks=callbacks) + Note: + - For TensorBoard and Comet, the graph is logged at the end of training + - For W&B, the graph is logged at the start of training but requires one backward pass + to be populated. This means it may not work for models that don't require training + (e.g., :class:`PaDiM`) """ @staticmethod def on_train_start(trainer: Trainer, pl_module: LightningModule) -> None: - """Log model graph to respective logger. + """Log model graph to respective logger at training start. + + This method is called automatically at the start of training. For W&B logger, + it sets up model watching with graph logging enabled. Args: - trainer: Trainer object which contans reference to loggers. - pl_module: LightningModule object which is logged. + trainer (Trainer): PyTorch Lightning trainer instance containing logger references. + pl_module (LightningModule): Lightning module instance to be logged. 
+ + Example: + >>> from anomalib.callbacks import GraphLogger + >>> callback = GraphLogger() + >>> # Called automatically by trainer + >>> # callback.on_train_start(trainer, model) """ for logger in trainer.loggers: if isinstance(logger, AnomalibWandbLogger): @@ -50,11 +93,21 @@ def on_train_start(trainer: Trainer, pl_module: LightningModule) -> None: @staticmethod def on_train_end(trainer: Trainer, pl_module: LightningModule) -> None: - """Unwatch model if configured for wandb and log it model graph in Tensorboard if specified. + """Log model graph at training end and cleanup. + + This method is called automatically at the end of training. It: + - Logs the model graph for TensorBoard and Comet loggers + - Unwatches the model for W&B logger Args: - trainer: Trainer object which contans reference to loggers. - pl_module: LightningModule object which is logged. + trainer (Trainer): PyTorch Lightning trainer instance containing logger references. + pl_module (LightningModule): Lightning module instance to be logged. + + Example: + >>> from anomalib.callbacks import GraphLogger + >>> callback = GraphLogger() + >>> # Called automatically by trainer + >>> # callback.on_train_end(trainer, model) """ for logger in trainer.loggers: if isinstance(logger, AnomalibCometLogger | AnomalibTensorBoardLogger): diff --git a/src/anomalib/callbacks/model_loader.py b/src/anomalib/callbacks/model_loader.py index 8c688b3127..f977882106 100644 --- a/src/anomalib/callbacks/model_loader.py +++ b/src/anomalib/callbacks/model_loader.py @@ -1,6 +1,26 @@ -"""Callback that loads model weights from the state dict.""" +"""Model loader callback. -# Copyright (C) 2022 Intel Corporation +This module provides the :class:`LoadModelCallback` for loading pre-trained model weights from a state dict. + +The callback loads model weights from a specified path when inference begins. This is useful for loading +pre-trained models for inference or fine-tuning. + +Example: + Load pre-trained weights and create a trainer: + + >>> from anomalib.callbacks import LoadModelCallback + >>> from anomalib.engine import Engine + >>> from anomalib.models import Padim + >>> model = Padim() + >>> callbacks = [LoadModelCallback(weights_path="path/to/weights.pt")] + >>> engine = Engine(model=model, callbacks=callbacks) + +Note: + The weights file should be a PyTorch state dict saved with either a ``.pt`` or ``.pth`` extension. + The state dict should contain a ``"state_dict"`` key with the model weights. +""" + +# Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import logging @@ -14,14 +34,29 @@ class LoadModelCallback(Callback): - """Callback that loads the model weights from the state dict. + """Callback that loads model weights from a state dict. + + This callback loads pre-trained model weights from a specified path when inference begins. + The weights are loaded into the model's state dict using the device specified by the model. + + Args: + weights_path (str): Path to the model weights file (``.pt`` or ``.pth``). + The file should contain a state dict with a ``"state_dict"`` key. Examples: + Create a callback and use it with a trainer: + >>> from anomalib.callbacks import LoadModelCallback >>> from anomalib.engine import Engine - ... 
- >>> callbacks = [LoadModelCallback(weights_path="path/to/weights.pt")] - >>> engine = Engine(callbacks=callbacks) + >>> from anomalib.models import Padim + >>> model = Padim() + >>> # Create callback with path to weights + >>> callback = LoadModelCallback(weights_path="path/to/weights.pt") + >>> # Use callback with engine + >>> engine = Engine(model=model, callbacks=[callback]) + + Note: + The callback automatically handles device mapping when loading weights. """ def __init__(self, weights_path: str) -> None: @@ -30,7 +65,18 @@ def __init__(self, weights_path: str) -> None: def setup(self, trainer: Trainer, pl_module: AnomalibModule, stage: str | None = None) -> None: """Call when inference begins. - Loads the model weights from ``weights_path`` into the PyTorch module. + This method is called by PyTorch Lightning when inference begins. It loads the model + weights from the specified path into the module's state dict. + + Args: + trainer (Trainer): PyTorch Lightning trainer instance. + pl_module (AnomalibModule): The module to load weights into. + stage (str | None, optional): Current stage of execution. Defaults to ``None``. + + Note: + The weights are loaded using ``torch.load`` with automatic device mapping based on + the module's device. The state dict is expected to have a ``"state_dict"`` key + containing the model weights. """ del trainer, stage # These variables are not used. diff --git a/src/anomalib/callbacks/nncf/__init__.py b/src/anomalib/callbacks/nncf/__init__.py index 074a1bd861..6691729144 100644 --- a/src/anomalib/callbacks/nncf/__init__.py +++ b/src/anomalib/callbacks/nncf/__init__.py @@ -1,4 +1,4 @@ """Integration NNCF.""" -# Copyright (C) 2022 Intel Corporation +# Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/callbacks/nncf/callback.py b/src/anomalib/callbacks/nncf/callback.py index ce45f0a866..2372d1b972 100644 --- a/src/anomalib/callbacks/nncf/callback.py +++ b/src/anomalib/callbacks/nncf/callback.py @@ -1,6 +1,15 @@ -"""Callbacks for NNCF optimization.""" +"""NNCF optimization callback. -# Copyright (C) 2022 Intel Corporation +This module provides the `NNCFCallback` for optimizing neural networks using Intel's Neural Network +Compression Framework (NNCF). The callback handles model compression techniques like quantization +and pruning. + +Note: + The callback assumes that the Lightning module contains a 'model' attribute which is the + PyTorch module to be compressed. +""" + +# Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import subprocess # nosec B404 @@ -19,15 +28,45 @@ class NNCFCallback(Callback): - """Callback for NNCF compression. + """Callback for NNCF model compression. - Assumes that the pl module contains a 'model' attribute, which is - the PyTorch module that must be compressed. + This callback handles the compression of PyTorch models using NNCF during training. + It supports various compression techniques like quantization and pruning. Args: - config (dict): NNCF Configuration - export_dir (Str): Path where the export `onnx` and the OpenVINO `xml` and `bin` IR are saved. - If None model will not be exported. + config (dict): NNCF configuration dictionary that specifies the compression + parameters and algorithms to be applied. See the NNCF documentation for + details on configuration options. + export_dir (str | None, optional): Directory path where the exported models will be saved. 
+ If provided, the following files will be exported: + + - ONNX model file (`model_nncf.onnx`) + - OpenVINO IR files (`model_nncf.xml` and `model_nncf.bin`) + + If ``None``, model export will be skipped. Defaults to ``None``. + + Examples: + Configure NNCF quantization: + + >>> nncf_config = { + ... "input_info": {"sample_size": [1, 3, 224, 224]}, + ... "compression": {"algorithm": "quantization"} + ... } + >>> callback = NNCFCallback(config=nncf_config, export_dir="./compressed_models") + >>> trainer = pl.Trainer(callbacks=[callback]) + + Note: + - The callback assumes that the Lightning module contains a ``model`` attribute which is the + PyTorch module to be compressed. + - The compression is initialized using the validation dataloader since it contains both normal + and anomalous samples, unlike the training set which only has normal samples. + - Model export requires OpenVINO's Model Optimizer (``mo``) to be available in the system PATH. + + See Also: + - :class:`lightning.pytorch.Callback`: Base callback class + - :class:`nncf.NNCFConfig`: NNCF configuration class + - :func:`nncf.torch.register_default_init_args`: Register initialization arguments + - :func:`anomalib.callbacks.nncf.utils.wrap_nncf_model`: Wrap model for NNCF compression """ def __init__(self, config: dict, export_dir: str | None = None) -> None: @@ -36,10 +75,15 @@ def __init__(self, config: dict, export_dir: str | None = None) -> None: self.nncf_ctrl: CompressionAlgorithmController | None = None def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str | None = None) -> None: - """Call when fit or test begins. + """Initialize NNCF compression when training begins. - Takes the pytorch model and wraps it using the compression controller - so that it is ready for nncf fine-tuning. + This method is called when training or testing begins. It wraps the PyTorch model + using the NNCF compression controller to prepare it for compression during training. + + Args: + trainer (pl.Trainer): PyTorch Lightning trainer instance + pl_module (pl.LightningModule): The Lightning module containing the model to compress + stage (str | None, optional): Current stage of training. Defaults to ``None``. """ del stage # `stage` variable is not used. @@ -66,9 +110,17 @@ def on_train_batch_start( batch_idx: int, unused: int = 0, ) -> None: - """Call when the train batch begins. + """Prepare compression before each training batch. + + Called at the beginning of each training batch to update the compression + scheduler for the next step. - Prepare compression method to continue training the model in the next step. + Args: + trainer (pl.Trainer): PyTorch Lightning trainer instance + pl_module (pl.LightningModule): The Lightning module being trained + batch (Any): Current batch of data + batch_idx (int): Index of current batch + unused (int, optional): Unused parameter. Defaults to ``0``. """ del trainer, pl_module, batch, batch_idx, unused # These variables are not used. @@ -76,9 +128,14 @@ def on_train_batch_start( self.nncf_ctrl.scheduler.step() def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - """Call when the train epoch starts. + """Prepare compression before each training epoch. - Prepare compression method to continue training the model in the next epoch. + Called at the beginning of each training epoch to update the compression + scheduler for the next epoch. 
+
+        Args:
+            trainer (pl.Trainer): PyTorch Lightning trainer instance
+            pl_module (pl.LightningModule): The Lightning module being trained
         """
         del trainer, pl_module  # `trainer` and `pl_module` variables are not used.

@@ -86,9 +143,20 @@
     def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        """Call when the train ends.
+        """Export the compressed model when training ends.
+
+        This method handles the export of the compressed model to ONNX format and
+        optionally converts it to OpenVINO IR format if the export directory is specified.
+
+        Args:
+            trainer (pl.Trainer): PyTorch Lightning trainer instance
+            pl_module (pl.LightningModule): The trained Lightning module

-        Exports onnx model and if compression controller is not None, uses the onnx model to generate the OpenVINO IR.
+        Note:
+            - Requires OpenVINO's Model Optimizer (``mo``) to be available in the system PATH
+            - Creates the export directory if it doesn't exist
+            - Exports ONNX model as ``model_nncf.onnx``
+            - Converts ONNX to OpenVINO IR format using ``mo``
         """
         del trainer, pl_module  # `trainer` and `pl_module` variables are not used.
diff --git a/src/anomalib/callbacks/nncf/utils.py b/src/anomalib/callbacks/nncf/utils.py
index 99f1db6aaa..6f0a783a77 100644
--- a/src/anomalib/callbacks/nncf/utils.py
+++ b/src/anomalib/callbacks/nncf/utils.py
@@ -1,6 +1,17 @@
-"""Utils for NNCf optimization."""
+"""Utilities for Neural Network Compression Framework (NNCF) optimization.

-# Copyright (C) 2022 Intel Corporation
+This module provides utility functions and classes for working with Intel's Neural Network
+Compression Framework (NNCF). It includes functionality for model initialization, state
+management, and configuration handling.
+
+The module contains:
+
+- ``InitLoader``: A data loader class for NNCF initialization
+- Functions for wrapping PyTorch models with NNCF compression
+- Utilities for handling NNCF model states and configurations
+"""
+
+# Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

 import logging
@@ -24,19 +35,67 @@

 class InitLoader(PTInitializingDataLoader):
-    """Initializing data loader for NNCF to be used with unsupervised training algorithms."""
+    """Initializing data loader for NNCF to be used with unsupervised training algorithms.
+
+    This class extends NNCF's ``PTInitializingDataLoader`` to handle unsupervised training data.
+    It provides methods for iterating through the data and extracting inputs for model initialization.
+
+    Args:
+        data_loader (DataLoader): PyTorch ``DataLoader`` containing the initialization data.
+
+    Examples:
+        Create an initialization loader from a PyTorch dataloader whose batches are
+        dictionaries with an ``"image"`` key:
+
+        >>> import torch
+        >>> from torch.utils.data import DataLoader
+        >>> dataset = [{"image": torch.randn(3, 32, 32)} for _ in range(10)]
+        >>> dataloader = DataLoader(dataset)
+        >>> init_loader = InitLoader(dataloader)
+
+        Iterate through the loader:

+        >>> for batch in init_loader:
+        ...     assert isinstance(batch, torch.Tensor)
+        ...     assert batch.shape[1:] == (3, 32, 32)
+
+    Note:
+        The loader expects the dataloader to return dictionaries with an ``"image"`` key
+        containing the input tensor.
+    """

     def __init__(self, data_loader: DataLoader) -> None:
         super().__init__(data_loader)
         self._data_loader_iter: Iterator

     def __iter__(self) -> "InitLoader":
-        """Create iterator for dataloader."""
+        """Create iterator for dataloader.
+
+        Returns:
+            InitLoader: Self reference for iteration.
+
+        Example:
+            >>> from torch.utils.data import DataLoader
+            >>> loader = InitLoader(DataLoader([{"image": torch.randn(3, 32, 32)}]))
+            >>> iterator = iter(loader)
+            >>> isinstance(iterator, InitLoader)
+            True
+        """
         self._data_loader_iter = iter(self._data_loader)
         return self

     def __next__(self) -> torch.Tensor:
-        """Return next item from dataloader iterator."""
+        """Return next item from dataloader iterator.
+
+        Returns:
+            torch.Tensor: Next image tensor from the dataloader.
+
+        Example:
+            >>> from torch.utils.data import DataLoader
+            >>> loader = InitLoader(DataLoader([{"image": torch.randn(3, 32, 32)}]))
+            >>> batch = next(iter(loader))
+            >>> isinstance(batch, torch.Tensor)
+            True
+        """
         loaded_item = next(self._data_loader_iter)
         return loaded_item["image"]

@@ -44,9 +103,20 @@
     def get_inputs(dataloader_output: dict[str, str | torch.Tensor]) -> tuple[tuple, dict]:
         """Get input to model.

+        Args:
+            dataloader_output (dict[str, str | torch.Tensor]): Output from the dataloader
+                containing the input tensor.
+
         Returns:
-            (dataloader_output,), {}: tuple[tuple, dict]: The current model call to be made during
-            the initialization process
+            tuple[tuple, dict]: A tuple containing:
+                - A tuple with the dataloader output
+                - An empty dict for additional arguments
+
+        Example:
+            >>> output = {"image": torch.randn(1, 3, 32, 32)}
+            >>> args, kwargs = InitLoader.get_inputs(output)
+            >>> isinstance(args, tuple) and isinstance(kwargs, dict)
+            True
         """
         return (dataloader_output,), {}

@@ -54,10 +124,15 @@
     def get_target(_) -> None:  # noqa: ANN001
         """Return structure for ground truth in loss criterion based on dataloader output.

-        This implementation does not do anything and is a placeholder.
+        This implementation is a placeholder that returns ``None`` since ground truth
+        is not used in unsupervised training.

         Returns:
-            None
+            None: Always returns ``None`` as targets are not used.
+
+        Example:
+            >>> InitLoader.get_target(None) is None
+            True
         """
         return

@@ -68,13 +143,32 @@
 def wrap_nncf_model(
     model: nn.Module,
     config: dict,
     dataloader: DataLoader,
     init_state_dict: dict,
 ) -> tuple[CompressionAlgorithmController, NNCFNetwork]:
-    """Wrap model by NNCF.
+    """Wrap PyTorch model with NNCF compression.

-    :param model: Anomalib model.
-    :param config: NNCF config.
-    :param dataloader: Dataloader for initialization of NNCF model.
-    :param init_state_dict: Opti
-    :return: compression controller, compressed model
+    Args:
+        model (nn.Module): Anomalib model to be compressed.
+        config (dict): NNCF configuration dictionary.
+        dataloader (DataLoader): DataLoader for NNCF model initialization.
+        init_state_dict (dict): Initial state dictionary for model initialization.
+
+    Returns:
+        tuple[CompressionAlgorithmController, NNCFNetwork]: A tuple containing:
+            - The compression controller
+            - The compressed model
+
+    Warning:
+        Either ``dataloader`` or ``init_state_dict`` must be provided for proper quantizer initialization.
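+
+    Note:
+        The ``dataloader`` is wrapped in an :class:`InitLoader` for quantizer
+        initialization, so its batches are expected to be dictionaries with an
+        ``"image"`` key, as in the example below.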
+
+    Example:
+        >>> import torch
+        >>> import torch.nn as nn
+        >>> from torch.utils.data import DataLoader
+        >>> model = nn.Linear(10, 2)
+        >>> config = {"input_info": {"sample_size": [1, 10]}}
+        >>> dataloader = DataLoader([{"image": torch.randn(10)} for _ in range(100)])
+        >>> controller, compressed = wrap_nncf_model(model, config, dataloader, {})
+        >>> isinstance(compressed, NNCFNetwork)
+        True
     """
     nncf_config = NNCFConfig.from_dict(config)
@@ -109,16 +203,53 @@
 def is_state_nncf(state: dict) -> bool:
-    """Check if state is the result of NNCF-compressed model."""
+    """Check if state is the result of NNCF-compressed model.
+
+    Args:
+        state (dict): Model state dictionary to check.
+
+    Returns:
+        bool: ``True`` if the state is from an NNCF-compressed model, ``False`` otherwise.
+
+    Example:
+        >>> state = {"meta": {"nncf_enable_compression": True}}
+        >>> is_state_nncf(state)
+        True
+        >>> state = {"meta": {}}
+        >>> is_state_nncf(state)
+        False
+    """
     return bool(state.get("meta", {}).get("nncf_enable_compression", False))

 def compose_nncf_config(nncf_config: dict, enabled_options: list[str]) -> dict:
-    """Compose NNCf config by selected options.
+    """Compose NNCF config by selected options.
+
+    This function merges different parts of the NNCF configuration based on enabled options.
+    It supports ordered application of configuration parts through the ``order_of_parts`` field.
+
+    Args:
+        nncf_config (dict): Base NNCF configuration dictionary.
+        enabled_options (list[str]): List of enabled optimization options.

-    :param nncf_config:
-    :param enabled_options:
-    :return: config
+    Returns:
+        dict: Composed NNCF configuration.
+
+    Raises:
+        TypeError: If ``order_of_parts`` is not a list.
+        ValueError: If an enabled option is not in ``order_of_parts``.
+        KeyError: If ``base`` part or any enabled option is missing from config.
+        RuntimeError: If there's an error during config merging.
+
+    Example:
+        >>> config = {
+        ...     "base": {"epochs": 1},
+        ...     "quantization": {"epochs": 2},
+        ...     "order_of_parts": ["quantization"]
+        ... }
+        >>> result = compose_nncf_config(config, ["quantization"])
+        >>> result["epochs"]
+        2
     """
     optimisation_parts = nncf_config
     optimisation_parts_to_choose = []
@@ -169,14 +300,26 @@
 def merge_dicts_and_lists_b_into_a(
     a: dict[Any, Any] | list[Any],
     b: dict[Any, Any] | list[Any],
 ) -> dict[Any, Any] | list[Any]:
-    """Merge dict configs.
+    """Merge two configuration dictionaries or lists.
+
+    This function provides the public interface for merging configurations.
+    It delegates to the internal ``_merge_dicts_and_lists_b_into_a`` function.

     Args:
-        a (dict[Any, Any] | list[Any]): First dict or list.
-        b (dict[Any, Any] | list[Any]): Second dict or list.
+        a (dict[Any, Any] | list[Any]): First dictionary or list to merge.
+        b (dict[Any, Any] | list[Any]): Second dictionary or list to merge into first.

     Returns:
-        dict[Any, Any] | list[Any]: Merged dict or list.
+        dict[Any, Any] | list[Any]: Merged configuration.
+
+    Example:
+        >>> a = {"x": 1, "y": [1, 2]}
+        >>> b = {"y": [3], "z": 2}
+        >>> result = merge_dicts_and_lists_b_into_a(a, b)
+        >>> result["y"]
+        [1, 2, 3]
+        >>> result["z"]
+        2
     """
     return _merge_dicts_and_lists_b_into_a(a, b, "")

@@ -186,30 +329,37 @@
 def _merge_dicts_and_lists_b_into_a(
     a: dict[Any, Any] | list[Any],
     b: dict[Any, Any] | list[Any],
     cur_key: int | str | None = None,
 ) -> dict[Any, Any] | list[Any]:
-    """Merge dict configs.
+    """Recursively merge two configuration dictionaries or lists.
- * works with usual dicts and lists and derived types - * supports merging of lists (by concatenating the lists) - * makes recursive merging for dict + dict case - * overwrites when merging scalar into scalar - Note that we merge b into a (whereas Config makes merge a into b), - since otherwise the order of list merging is counter-intuitive. + This function implements the following merge behavior: + - Works with standard dicts, lists and their derived types + - Merges lists by concatenation + - Performs recursive merging for nested dictionaries + - Overwrites scalar values when merging Args: - a (dict[Any, Any] | list[Any]): First dict or list. - b (dict[Any, Any] | list[Any]): Second dict or list. - cur_key (int | str | None, optional): key for current level of recursion. Defaults to None. + a (dict[Any, Any] | list[Any]): First dictionary or list to merge. + b (dict[Any, Any] | list[Any]): Second dictionary or list to merge into first. + cur_key (int | str | None, optional): Current key in recursive merge. Defaults to None. Returns: - dict[Any, Any] | list[Any]: Merged dict or list. + dict[Any, Any] | list[Any]: Merged configuration. + + Raises: + TypeError: If inputs are not dictionaries or lists, or if types are incompatible. + + Example: + >>> a = {"x": {"y": [1]}} + >>> b = {"x": {"y": [2]}} + >>> result = _merge_dicts_and_lists_b_into_a(a, b) + >>> result["x"]["y"] + [1, 2] """ def _err_str(_a: dict | list, _b: dict | list, _key: int | str | None = None) -> str: _key_str = "of whole structures" if _key is None else f"during merging for key=`{_key}`" return ( - f"Error in merging parts of config: different types {_key_str}," - f" type(a) = {type(_a)}," - f" type(b) = {type(_b)}" + f"Error in merging parts of config: different types {_key_str}, type(a) = {type(_a)}, type(b) = {type(_b)}" ) if not (isinstance(a, dict | list)): diff --git a/src/anomalib/callbacks/tiler_configuration.py b/src/anomalib/callbacks/tiler_configuration.py index f44a4d679f..9e3d92d5d7 100644 --- a/src/anomalib/callbacks/tiler_configuration.py +++ b/src/anomalib/callbacks/tiler_configuration.py @@ -1,6 +1,32 @@ -"""Tiler Callback.""" +"""Tiler configuration callback. -# Copyright (C) 2022 Intel Corporation +This module provides the :class:`TilerConfigurationCallback` for configuring image tiling operations +in Anomalib models. Tiling allows processing large images by splitting them into smaller tiles, +which is useful when dealing with high-resolution images that don't fit in GPU memory. + +The callback configures tiling parameters such as tile size, stride, and upscaling mode for +models that support tiling operations. + +Example: + Configure tiling with custom parameters: + + >>> from anomalib.callbacks import TilerConfigurationCallback + >>> from anomalib.data.utils.tiler import ImageUpscaleMode + >>> callback = TilerConfigurationCallback( + ... enable=True, + ... tile_size=512, + ... stride=256, + ... mode=ImageUpscaleMode.PADDING + ... ) + >>> from lightning.pytorch import Trainer + >>> trainer = Trainer(callbacks=[callback]) + +Note: + The model must support tiling operations for this callback to work. + It will raise a :exc:`ValueError` if used with a model that doesn't support tiling. +""" + +# Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from collections.abc import Sequence @@ -15,7 +41,53 @@ class TilerConfigurationCallback(Callback): - """Tiler Configuration Callback.""" + """Callback for configuring image tiling operations. 
+ + This callback configures the tiling operation for models that support it. Tiling is useful + when working with high-resolution images that need to be processed in smaller chunks. + + Args: + enable (bool): Whether to enable tiling operation. Defaults to ``False``. + tile_size (int | Sequence): Size of each tile. Can be a single integer for square tiles + or a sequence of two integers for rectangular tiles. Defaults to ``256``. + stride (int | Sequence | None): Stride between tiles. Can be a single integer or a sequence + of two integers. If ``None``, uses ``tile_size``. Defaults to ``None``. + remove_border_count (int): Number of pixels to remove from the image border before + tiling. Useful for removing artifacts at image boundaries. Defaults to ``0``. + mode (ImageUpscaleMode): Method to use when combining overlapping tiles. + Options are defined in :class:`~anomalib.data.utils.tiler.ImageUpscaleMode`. + Defaults to ``ImageUpscaleMode.PADDING``. + + Examples: + Create a basic tiling configuration: + + >>> callback = TilerConfigurationCallback(enable=True) + + Configure tiling with custom tile size and stride: + + >>> callback = TilerConfigurationCallback( + ... enable=True, + ... tile_size=512, + ... stride=256 + ... ) + + Use rectangular tiles with custom upscale mode: + + >>> from anomalib.data.utils.tiler import ImageUpscaleMode + >>> callback = TilerConfigurationCallback( + ... enable=True, + ... tile_size=(512, 256), + ... mode=ImageUpscaleMode.AVERAGE + ... ) + + Raises: + ValueError: If used with a model that doesn't support tiling operations. + + Note: + - The model must have a ``tiler`` attribute to support tiling operations + - Smaller stride values result in more overlap between tiles but increase computation + - The upscale mode affects how overlapping regions are combined + """ def __init__( self, @@ -25,21 +97,7 @@ def __init__( remove_border_count: int = 0, mode: ImageUpscaleMode = ImageUpscaleMode.PADDING, ) -> None: - """Set tiling configuration from the command line. - - Args: - enable (bool): Boolean to enable tiling operation. - Defaults to False. - tile_size ([int | Sequence]): Tile size. - Defaults to 256. - stride ([int | Sequence]): Stride to move tiles on the image. - remove_border_count (int, optional): Number of pixels to remove from the image before - tiling. Defaults to 0. - mode (str, optional): Up-scaling mode when untiling overlapping tiles. - Defaults to "padding". - tile_count (SupportsIndex, optional): Number of random tiles to sample from the image. - Defaults to 4. - """ + """Initialize tiling configuration.""" self.enable = enable self.tile_size = tile_size self.stride = stride @@ -49,14 +107,18 @@ def __init__( def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str | None = None) -> None: """Set Tiler object within Anomalib Model. + This method is called by PyTorch Lightning during setup. It configures the tiling + parameters if tiling is enabled and the model supports it. + Args: - trainer (pl.Trainer): PyTorch Lightning Trainer - pl_module (pl.LightningModule): Anomalib Model that inherits pl LightningModule. - stage (str | None, optional): fit, validate, test or predict. Defaults to None. + trainer (pl.Trainer): PyTorch Lightning Trainer instance. + pl_module (pl.LightningModule): The Anomalib model being trained/tested. + stage (str | None, optional): Current stage - ``"fit"``, ``"validate"``, + ``"test"`` or ``"predict"``. Defaults to ``None``. 
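+
+        Example:
+            A minimal sketch of how the trainer invokes this hook; ``trainer`` and
+            ``model`` are assumed to exist, with the model exposing a ``tiler``
+            attribute:
+
+            >>> callback = TilerConfigurationCallback(enable=True, tile_size=256)
+            >>> callback.setup(trainer, model, stage="fit")  # doctest: +SKIP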
Raises: - ValueError: When Anomalib Model doesn't contain ``Tiler`` object, it means the model - doesn not support tiling operation. + ValueError: If tiling is enabled but the model doesn't support tiling operations + (i.e., doesn't have a ``tiler`` attribute). """ del trainer, stage # These variables are not used. diff --git a/src/anomalib/callbacks/timer.py b/src/anomalib/callbacks/timer.py index 3cbf516dc0..04554f1ef7 100644 --- a/src/anomalib/callbacks/timer.py +++ b/src/anomalib/callbacks/timer.py @@ -1,6 +1,27 @@ -"""Callback to measure training and testing time of a PyTorch Lightning module.""" +"""Timer callback. -# Copyright (C) 2022 Intel Corporation +This module provides the :class:`TimerCallback` for measuring training and testing time of +Anomalib models. The callback tracks execution time and calculates throughput metrics. + +Example: + Add timer callback to track performance: + + >>> from anomalib.callbacks import TimerCallback + >>> from lightning.pytorch import Trainer + >>> callback = TimerCallback() + >>> trainer = Trainer(callbacks=[callback]) + + The callback will automatically log: + - Total training time when training completes + - Total testing time and throughput (FPS) when testing completes + +Note: + - The callback handles both single and multiple test dataloaders + - Throughput is calculated as total number of images / total testing time + - Batch size is included in throughput logging for reference +""" + +# Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import logging @@ -13,63 +34,80 @@ class TimerCallback(Callback): - """Callback that measures the training and testing time of a PyTorch Lightning module. + """Callback for measuring model training and testing time. + + This callback tracks execution time metrics: + - Training time: Total time taken for model training + - Testing time: Total time taken for model testing + - Testing throughput: Images processed per second during testing + + Example: + Add timer to track performance: - Examples: >>> from anomalib.callbacks import TimerCallback - >>> from anomalib.engine import Engine - ... - >>> callbacks = [TimerCallback()] - >>> engine = Engine(callbacks=callbacks) + >>> from lightning.pytorch import Trainer + >>> callback = TimerCallback() + >>> trainer = Trainer(callbacks=[callback]) + + Note: + - The callback automatically handles both single and multiple test dataloaders + - Throughput is calculated as: ``num_test_images / testing_time`` + - All metrics are logged using the logger specified in the trainer """ def __init__(self) -> None: + """Initialize timer callback. + + The callback initializes: + - ``start``: Timestamp for tracking execution segments + - ``num_images``: Counter for total test images + """ self.start: float self.num_images: int = 0 def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Call when fit begins. + """Called when fit begins. - Sets the start time to the time training started. + Records the start time of the training process. Args: - trainer (Trainer): PyTorch Lightning trainer. - pl_module (LightningModule): Current training module. + trainer (Trainer): PyTorch Lightning trainer instance + pl_module (LightningModule): The current training module - Returns: - None + Note: + The trainer and module arguments are not used but kept for callback signature compatibility """ del trainer, pl_module # These variables are not used. 
- self.start = time.time() def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Call when fit ends. + """Called when fit ends. - Prints the time taken for training. + Calculates and logs the total training time. Args: - trainer (Trainer): PyTorch Lightning trainer. - pl_module (LightningModule): Current training module. + trainer (Trainer): PyTorch Lightning trainer instance + pl_module (LightningModule): The current training module - Returns: - None + Note: + The trainer and module arguments are not used but kept for callback signature compatibility """ del trainer, pl_module # Unused arguments. logger.info("Training took %5.2f seconds", (time.time() - self.start)) def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Call when the test begins. + """Called when test begins. - Sets the start time to the time testing started. - Goes over all the test dataloaders and adds the number of images in each. + Records test start time and counts total number of test images. Args: - trainer (Trainer): PyTorch Lightning trainer. - pl_module (LightningModule): Current training module. + trainer (Trainer): PyTorch Lightning trainer instance + pl_module (LightningModule): The current training module - Returns: - None + Note: + - Records start timestamp for testing phase + - Counts total images across all test dataloaders if multiple are present + - The module argument is not used but kept for callback signature compatibility """ del pl_module # Unused argument. @@ -84,16 +122,19 @@ def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None: self.num_images += len(dataloader.dataset) def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Call when the test ends. + """Called when test ends. - Prints the time taken for testing and the throughput in frames per second. + Calculates and logs testing time and throughput metrics. Args: - trainer (Trainer): PyTorch Lightning trainer. - pl_module (LightningModule): Current training module. - - Returns: - None + trainer (Trainer): PyTorch Lightning trainer instance + pl_module (LightningModule): The current training module + + Note: + - Calculates total testing time + - Computes throughput in frames per second (FPS) + - Logs batch size along with throughput for reference + - The module argument is not used but kept for callback signature compatibility """ del pl_module # Unused argument. diff --git a/src/anomalib/callbacks/visualizer.py b/src/anomalib/callbacks/visualizer.py index 9b0b78dfa0..a118c93ae4 100644 --- a/src/anomalib/callbacks/visualizer.py +++ b/src/anomalib/callbacks/visualizer.py @@ -1,6 +1,28 @@ -"""Visualizer Callback. +"""Visualizer callback. -This is assigned by Anomalib Engine internally. +This module provides the :class:`_VisualizationCallback` for generating and managing visualizations +in Anomalib. This callback is assigned by the Anomalib Engine internally. + +The callback handles: +- Generating visualizations during model testing and prediction +- Saving visualizations to disk +- Showing visualizations interactively +- Logging visualizations to various logging backends + +Example: + Create visualization callback with multiple visualizers:: + + >>> from anomalib.utils.visualization import ImageVisualizer, MetricsVisualizer + >>> visualizers = [ImageVisualizer(), MetricsVisualizer()] + >>> visualization_callback = _VisualizationCallback( + ... visualizers=visualizers, + ... save=True, + ... root="results/images" + ... 
) + +Note: + This callback is used internally by the Anomalib Engine and should not be + instantiated directly by users. """ # Copyright (C) 2024 Intel Corporation @@ -17,11 +39,7 @@ from anomalib.loggers import AnomalibWandbLogger from anomalib.loggers.base import ImageLoggerBase from anomalib.models import AnomalibModule -from anomalib.utils.visualization import ( - BaseVisualizer, - GeneratorResult, - VisualizationStep, -) +from anomalib.utils.visualization import BaseVisualizer, GeneratorResult, VisualizationStep logger = logging.getLogger(__name__) @@ -29,32 +47,35 @@ class _VisualizationCallback(Callback): """Callback for visualization that is used internally by the Engine. + This callback handles the generation and management of visualizations during model + testing and prediction. It supports saving, showing, and logging visualizations + to various backends. + Args: - visualizers (BaseVisualizer | list[BaseVisualizer]): - Visualizer objects that are used for computing the visualizations. Defaults to None. - save (bool, optional): Save the image. Defaults to False. - root (Path | None, optional): The path to save the images. Defaults to None. - log (bool, optional): Log the images into the loggers. Defaults to False. - show (bool, optional): Show the images. Defaults to False. - - Example: - >>> visualizers = [ImageVisualizer(), MetricsVisualizer()] - >>> visualization_callback = _VisualizationCallback( - ... visualizers=visualizers, - ... save=True, - ... root="results/images" - ... ) + visualizers (BaseVisualizer | list[BaseVisualizer]): Visualizer objects that + are used for computing the visualizations. + save (bool, optional): Save the visualizations. Defaults to ``False``. + root (Path | None, optional): The path to save the visualizations. Defaults to ``None``. + log (bool, optional): Log the visualizations to the loggers. Defaults to ``False``. + show (bool, optional): Show the visualizations. Defaults to ``False``. + + Examples: + Create visualization callback with multiple visualizers:: - CLI - $ anomalib train --model Padim --data MVTec \ - --visualization.visualizers ImageVisualizer \ - --visualization.visualizers+=MetricsVisualizer - or - $ anomalib train --model Padim --data MVTec \ - --visualization.visualizers '[ImageVisualizer, MetricsVisualizer]' + >>> from anomalib.utils.visualization import ImageVisualizer, MetricsVisualizer + >>> visualizers = [ImageVisualizer(), MetricsVisualizer()] + >>> visualization_callback = _VisualizationCallback( + ... visualizers=visualizers, + ... save=True, + ... root="results/images" + ... ) + + Note: + This callback is used internally by the Anomalib Engine and should not be + instantiated directly by users. Raises: - ValueError: Incase `root` is None and `save` is True. + ValueError: If ``root`` is ``None`` and ``save`` is ``True``. """ def __init__( @@ -83,6 +104,30 @@ def on_test_batch_end( batch_idx: int, dataloader_idx: int = 0, ) -> None: + """Generate visualizations at the end of a test batch. + + Args: + trainer (Trainer): PyTorch Lightning trainer instance. + pl_module (AnomalibModule): The current module being tested. + outputs (STEP_OUTPUT | None): Outputs from the test step. + batch (Any): Current batch of data. + batch_idx (int): Index of the current batch. + dataloader_idx (int, optional): Index of the dataloader. Defaults to 0. + + Example: + Generate visualizations for a test batch:: + + >>> from anomalib.utils.visualization import ImageVisualizer + >>> callback = _VisualizationCallback( + ... 
visualizers=ImageVisualizer(), + ... save=True, + ... root="results/images" + ... ) + >>> callback.on_test_batch_end(trainer, model, outputs, batch, 0) + + Raises: + ValueError: If ``save`` is ``True`` but ``file_name`` is ``None``. + """ for generator in self.generators: if generator.visualize_on == VisualizationStep.BATCH: for result in generator( @@ -115,6 +160,26 @@ def on_test_batch_end( self._add_to_logger(result, pl_module, trainer) def on_test_end(self, trainer: Trainer, pl_module: AnomalibModule) -> None: + """Generate visualizations at the end of testing. + + Args: + trainer (Trainer): PyTorch Lightning trainer instance. + pl_module (AnomalibModule): The module that was tested. + + Example: + Generate visualizations at the end of testing:: + + >>> from anomalib.utils.visualization import MetricsVisualizer + >>> callback = _VisualizationCallback( + ... visualizers=MetricsVisualizer(), + ... save=True, + ... root="results/metrics" + ... ) + >>> callback.on_test_end(trainer, model) + + Raises: + ValueError: If ``save`` is ``True`` but ``file_name`` is ``None``. + """ for generator in self.generators: if generator.visualize_on == VisualizationStep.STAGE_END: for result in generator(trainer=trainer, pl_module=pl_module): @@ -141,9 +206,31 @@ def on_predict_batch_end( batch_idx: int, dataloader_idx: int = 0, ) -> None: + """Generate visualizations at the end of a prediction batch. + + Args: + trainer (Trainer): PyTorch Lightning trainer instance. + pl_module (AnomalibModule): The module being used for prediction. + outputs (STEP_OUTPUT | None): Outputs from the prediction step. + batch (Any): Current batch of data. + batch_idx (int): Index of the current batch. + dataloader_idx (int, optional): Index of the dataloader. Defaults to 0. + + Note: + This method calls :meth:`on_test_batch_end` internally. + """ return self.on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) def on_predict_end(self, trainer: Trainer, pl_module: AnomalibModule) -> None: + """Generate visualizations at the end of prediction. + + Args: + trainer (Trainer): PyTorch Lightning trainer instance. + pl_module (AnomalibModule): The module that was used for prediction. + + Note: + This method calls :meth:`on_test_end` internally. + """ return self.on_test_end(trainer, pl_module) @staticmethod @@ -152,12 +239,21 @@ def _add_to_logger( module: AnomalibModule, trainer: Trainer, ) -> None: - """Add image to logger. + """Add visualization to logger. Args: - result (GeneratorResult): Output from the generators. + result (GeneratorResult): Output from the visualization generators. module (AnomalibModule): LightningModule from which the global step is extracted. - trainer (Trainer): Trainer object. + trainer (Trainer): Trainer object containing the loggers. + + Example: + Add visualization to logger:: + + >>> result = generator.generate(...) # Generate visualization + >>> _VisualizationCallback._add_to_logger(result, model, trainer) + + Raises: + ValueError: If ``file_name`` is ``None`` when attempting to log. 
""" # Store names of logger and the logger in a dict available_loggers = { diff --git a/src/anomalib/cli/__init__.py b/src/anomalib/cli/__init__.py index 78b54e5988..0c5fd02f80 100644 --- a/src/anomalib/cli/__init__.py +++ b/src/anomalib/cli/__init__.py @@ -1,6 +1,6 @@ """Anomalib CLI.""" -# Copyright (C) 2022 Intel Corporation +# Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from .cli import AnomalibCLI diff --git a/src/anomalib/cli/cli.py b/src/anomalib/cli/cli.py index 2bb61d7af5..88a9bf9fc7 100644 --- a/src/anomalib/cli/cli.py +++ b/src/anomalib/cli/cli.py @@ -1,4 +1,8 @@ -"""Anomalib CLI.""" +"""Anomalib Command Line Interface. + +This module provides the `AnomalibCLI` class for configuring and running Anomalib from the command line. +The CLI supports configuration via both command line arguments and configuration files (.yaml or .json). +""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -38,16 +42,30 @@ class AnomalibCLI: - """Implementation of a fully configurable CLI tool for anomalib. + """Implementation of a fully configurable CLI tool for Anomalib. + + This class provides a flexible command-line interface that can be configured through + both CLI arguments and configuration files. It supports various subcommands for + training, testing, and exporting models. + + Args: + args (Sequence[str] | None): Command line arguments. Defaults to None. + run (bool): Whether to run the subcommand immediately. Defaults to True. + + Examples: + Run from command line: + + >>> import sys + >>> sys.argv = ["anomalib", "train", "--model", "Padim", "--data", "MVTec"] + + Run programmatically: - The advantage of this tool is its flexibility to configure the pipeline - from both the CLI and a configuration file (.yaml or .json). It is even - possible to use both the CLI and a configuration file simultaneously. - For more details, the reader could refer to PyTorch Lightning CLI - documentation. + >>> from anomalib.cli import AnomalibCLI + >>> cli = AnomalibCLI(["train", "--model", "Padim", "--data", "MVTec"], run=False) - ``save_config_kwargs`` is set to ``overwrite=True`` so that the - ``SaveConfigCallback`` overwrites the config if it already exists. + Note: + The CLI supports both YAML and JSON configuration files. Configuration can be + provided via both files and command line arguments simultaneously. """ def __init__(self, args: Sequence[str] | None = None, run: bool = True) -> None: diff --git a/src/anomalib/cli/install.py b/src/anomalib/cli/install.py index d114c8e168..755b856b50 100644 --- a/src/anomalib/cli/install.py +++ b/src/anomalib/cli/install.py @@ -1,4 +1,8 @@ -"""Anomalib install subcommand code.""" +"""Anomalib installation subcommand. + +This module provides the `anomalib_install` function for installing Anomalib and its dependencies. +It supports installing different dependency sets based on the user's needs. 
+""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -9,11 +13,7 @@ from rich.console import Console from rich.logging import RichHandler -from anomalib.cli.utils.installation import ( - get_requirements, - get_torch_install_args, - parse_requirements, -) +from anomalib.cli.utils.installation import get_requirements, get_torch_install_args, parse_requirements logger = logging.getLogger("pip") logger.setLevel(logging.WARNING) # setLevel: CRITICAL, ERROR, WARNING, INFO, DEBUG, NOTSET @@ -29,15 +29,27 @@ def anomalib_install(option: str = "full", verbose: bool = False) -> int: """Install Anomalib requirements. + This function handles the installation of Anomalib dependencies based on the + specified option. It can install the full package or specific dependency sets. + Args: - option (str | None): Optional-dependency to install requirements for. - verbose (bool): Set pip logger level to INFO + option (str): Optional-dependency to install requirements for. + Options: "full", "core", "dev", "loggers", "notebooks", "openvino". + Defaults to "full". + verbose (bool): Set pip logger level to INFO. Defaults to False. + + Examples: + Install full package:: + >>> anomalib_install("full") + + Install core dependencies only:: + >>> anomalib_install("core") Raises: - ValueError: When the task is not supported. + ValueError: When the option is not supported. Returns: - int: Status code of the pip install command. + int: Status code of the pip install command (0 for success). """ from pip._internal.commands import create_command diff --git a/src/anomalib/cli/pipelines.py b/src/anomalib/cli/pipelines.py index ba6030491b..5a937f56d9 100644 --- a/src/anomalib/cli/pipelines.py +++ b/src/anomalib/cli/pipelines.py @@ -1,4 +1,8 @@ -"""Subcommand for pipelines.""" +"""Anomalib pipeline subcommands. + +This module provides functionality for managing and running Anomalib pipelines through +the CLI. It includes support for benchmarking and other pipeline operations. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -22,14 +26,39 @@ def pipeline_subcommands() -> dict[str, dict[str, str]]: - """Return subcommands for pipelines.""" + """Get available pipeline subcommands. + + Returns: + dict[str, dict[str, str]]: Dictionary mapping subcommand names to their descriptions. + + Example: + Pipeline subcommands are available only if the pipelines are installed:: + + >>> pipeline_subcommands() + { + 'benchmark': { + 'description': 'Run benchmarking pipeline for model evaluation' + } + } + """ if PIPELINE_REGISTRY is not None: return {name: {"description": get_short_docstring(pipeline)} for name, pipeline in PIPELINE_REGISTRY.items()} return {} def run_pipeline(args: Namespace) -> None: - """Run pipeline.""" + """Run a pipeline with the provided arguments. + + Args: + args (Namespace): Arguments for the pipeline, including the subcommand + and configuration. + + Raises: + ValueError: If pipelines are not available in the current installation. + + Note: + This feature is experimental and may change or be removed in future versions. + """ logger.warning("This feature is experimental. 
It may change or be removed in the future.") if PIPELINE_REGISTRY is not None: subcommand = args.subcommand diff --git a/src/anomalib/cli/utils/__init__.py b/src/anomalib/cli/utils/__init__.py index 028c972728..fbe47ff661 100644 --- a/src/anomalib/cli/utils/__init__.py +++ b/src/anomalib/cli/utils/__init__.py @@ -1,6 +1,6 @@ """Anomalib CLI Utils.""" -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from .help_formatter import CustomHelpFormatter diff --git a/src/anomalib/cli/utils/help_formatter.py b/src/anomalib/cli/utils/help_formatter.py index 4535011b09..f31ea1174b 100644 --- a/src/anomalib/cli/utils/help_formatter.py +++ b/src/anomalib/cli/utils/help_formatter.py @@ -1,6 +1,10 @@ -"""Custom Help Formatters for Anomalib CLI.""" +"""Custom help formatters for Anomalib CLI. -# Copyright (C) 2023 Intel Corporation +This module provides custom help formatting functionality for the Anomalib CLI, +including rich text formatting and customized help output for different verbosity levels. +""" + +# Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import argparse @@ -38,13 +42,25 @@ def get_short_docstring(component: type) -> str: - """Get the short description from the docstring. + """Get the short description from a component's docstring. Args: - component (type): The component to get the docstring from + component (type): The component to extract the docstring from. Returns: - str: The short description + str: The short description from the docstring, or empty string if no docstring. + + Example: + >>> class MyClass: + ... '''My class description. + ... + ... More details here. + ... ''' + ... pass + + >>> output = get_short_docstring(MyClass) + >>> print(output) + My class description. """ if component.__doc__ is None: return "" @@ -53,12 +69,15 @@ def get_verbosity_subcommand() -> dict: - """Parse command line arguments and returns a dictionary of key-value pairs. Returns: - A dictionary containing the parsed command line arguments. + dict: Dictionary containing: + - subcommand: The subcommand being run + - help: Whether help was requested + - verbosity: Verbosity level (0-2) - Examples: + Example: >>> import sys >>> sys.argv = ['anomalib', 'train', '-h', '-v'] >>> get_verbosity_subcommand() @@ -79,12 +98,18 @@ def get_intro() -> Markdown: - """Return a Markdown object containing the introduction text for Anomalib CLI Guide. - - The introduction text includes a brief description of the guide and links to the Github repository and documentation + """Get the introduction text for the Anomalib CLI guide. Returns: - A Markdown object containing the introduction text for Anomalib CLI Guide. + Markdown: A Markdown object containing the introduction text with links + to the Github repository and documentation. + + Example: + >>> intro = get_intro() + >>> intro.markup.startswith("# Anomalib CLI Guide") + True """ intro_markdown = ( "# Anomalib CLI Guide\n\n" @@ -96,15 +121,44 @@ def get_verbose_usage(subcommand: str = "train") -> str: - """Return a string containing verbose usage information for the specified subcommand. + """Get verbose usage information for a subcommand. 
+ + This function generates a formatted string containing usage instructions for running + an Anomalib CLI subcommand with different verbosity levels. The instructions show + how to access more detailed help information using the -v and -vv flags. Args: - ---- - subcommand (str): The name of the subcommand to get verbose usage information for. Defaults to "train". + subcommand (str, optional): The subcommand to get usage information for. + Defaults to "train". Returns: - ------- - str: A string containing verbose usage information for the specified subcommand. + str: A formatted string containing verbose usage information with examples + showing different verbosity levels. + + Example: + Get usage information for the "train" subcommand: + + >>> usage = get_verbose_usage("train") + >>> print(usage) # doctest: +NORMALIZE_WHITESPACE + To get more overridable argument information, run the command below. + ```python + # Verbosity Level 1 + anomalib train [optional_arguments] -h -v + # Verbosity Level 2 + anomalib train [optional_arguments] -h -vv + ``` + + Get usage for a different subcommand: + + >>> usage = get_verbose_usage("export") + >>> print(usage) # doctest: +NORMALIZE_WHITESPACE + To get more overridable argument information, run the command below. + ```python + # Verbosity Level 1 + anomalib export [optional_arguments] -h -v + # Verbosity Level 2 + anomalib export [optional_arguments] -h -vv + ``` """ return ( "To get more overridable argument information, run the command below.\n" @@ -118,29 +172,49 @@ def get_cli_usage_docstring(component: object | None) -> str | None: - r"""Get the cli usage from the docstring. + """Extract CLI usage instructions from a component's docstring. + + This function searches for a "CLI Usage:" section in the component's docstring and + extracts its contents. The section should be delimited by either double newlines + or the end of the docstring. Args: - ---- - component (Optional[object]): The component to get the docstring from + component: The object to extract the CLI usage from. Can be None. Returns: - ------- - Optional[str]: The quick-start guide as Markdown format. + The CLI usage instructions as a string with normalized whitespace, or None if: + - The component is None + - The component has no docstring + - The docstring has no "CLI Usage:" section - Example: - ------- - component.__doc__ = ''' - - - CLI Usage: - 1. First Step. - 2. Second Step. - - - ''' - >>> get_cli_usage_docstring(component) - "1. First Step.\n2. Second Step." + Examples: + A docstring with CLI usage section: + + >>> class MyComponent: + ... '''My component description. + ... + ... CLI Usage: + ... 1. Run this command + ... 2. Then this command + ... + ... Other sections... + ... ''' + >>> component = MyComponent() + >>> print(get_cli_usage_docstring(component)) + 1. Run this command + 2. Then this command + + A docstring without CLI usage returns None: + + >>> class NoUsage: + ... 
'''Just a description''' + >>> print(get_cli_usage_docstring(NoUsage())) + None + + None input returns None: + + >>> print(get_cli_usage_docstring(None)) + None """ if component is None or component.__doc__ is None or "CLI Usage" not in component.__doc__: return None @@ -154,16 +228,34 @@ def get_cli_usage_docstring(component: object | None) -> str | None: return None -def render_guide(subcommand: str | None = None) -> list: +def render_guide(subcommand: str | None = None) -> list[Panel | Markdown]: """Render a guide for the specified subcommand. + This function generates a formatted guide containing usage instructions and examples + for a given CLI subcommand. + Args: - ---- - subcommand (Optional[str]): The subcommand to render the guide for. + subcommand: The subcommand to render the guide for. If None or not found in + DOCSTRING_USAGE, returns an empty list. Returns: - ------- - list: A list of contents to be displayed in the guide. + A list containing rich formatting elements (Panel, Markdown) to be displayed + in the guide. + + Examples: + >>> # Empty list for invalid subcommand + >>> render_guide("invalid") + [] + + >>> # Guide with intro and usage for valid subcommand + >>> guide = render_guide("train") + >>> len(guide) > 0 + True + + Notes: + - The guide includes an introduction section from `get_intro()` + - For valid subcommands, adds CLI usage from docstrings and verbose usage info + - Usage is formatted in a Panel with "Quick-Start" title """ if subcommand is None or subcommand not in DOCSTRING_USAGE: return [] @@ -183,19 +275,28 @@ class CustomHelpFormatter(RichHelpFormatter, DefaultHelpFormatter): This formatter extends the RichHelpFormatter and DefaultHelpFormatter classes to provide a more detailed and customizable help output for Anomalib CLI. + Args: + *args: Variable length argument list passed to parent classes. + **kwargs: Arbitrary keyword arguments passed to parent classes. + Attributes: - verbosity_level : int - The level of verbosity for the help output. - subcommand : str | None - The subcommand to render the guide for. - - Methods: - add_usage(usage, actions, *args, **kwargs) - Add usage information to the help output. - add_argument(action) - Add an argument to the help output. - format_help() - Format the help output. + verbosity_dict (dict): Dictionary containing verbosity level and subcommand. + verbosity_level (int): The level of verbosity for the help output. + subcommand (str | None): The subcommand to render the guide for. + + Example: + >>> from argparse import ArgumentParser + >>> parser = ArgumentParser(formatter_class=CustomHelpFormatter) + >>> _ = parser.add_argument('--test') + >>> help_text = parser.format_help() + >>> isinstance(help_text, str) + True + + Note: + The formatter supports different verbosity levels: + - Level 0: Shows only quick-start guide + - Level 1: Shows required arguments + - Level 2+: Shows all arguments """ verbosity_dict = get_verbosity_subcommand() @@ -205,16 +306,20 @@ class CustomHelpFormatter(RichHelpFormatter, DefaultHelpFormatter): def add_usage(self, usage: str | None, actions: list, *args, **kwargs) -> None: """Add usage information to the formatter. - Args: - ---- - usage (str | None): A string describing the usage of the program. - actions (list): An list of argparse.Action objects. - *args (Any): Additional positional arguments to pass to the superclass method. - **kwargs (Any): Additional keyword arguments to pass to the superclass method. 
+ Filters the actions shown in the usage section based on verbosity level + and required arguments for the current subcommand. - Returns: - ------- - None + Args: + usage: A string describing the usage of the program. + actions: A list of argparse.Action objects. + *args: Additional positional arguments passed to parent method. + **kwargs: Additional keyword arguments passed to parent method. + + Example: + >>> formatter = CustomHelpFormatter(prog="anomalib") + >>> formatter.add_usage("usage:", [], groups=[]) """ if self.subcommand in REQUIRED_ARGUMENTS: if self.verbosity_level == 0: @@ -227,12 +332,25 @@ def add_argument(self, action: argparse.Action) -> None: """Add an argument to the help formatter. - If the verbose level is set to 0, the argument is not added. - If the verbose level is set to 1 and the argument is not in the non-skip list, the argument is not added. + Controls which arguments are displayed based on verbosity level and + whether they are required for the current subcommand. Args: - ---- - action (argparse.Action): The action to add to the help formatter. + action: The argparse.Action object to potentially add to the help output. + + Example: + >>> from argparse import ArgumentParser + >>> parser = ArgumentParser() + >>> action = parser.add_argument('--test') + >>> formatter = CustomHelpFormatter(prog="anomalib") + >>> formatter.add_argument(action) + + Note: + - At verbosity level 0, no arguments are shown + - At verbosity level 1, only required arguments are shown + - At higher verbosity levels, all arguments are shown """ if self.subcommand in REQUIRED_ARGUMENTS: if self.verbosity_level == 0: @@ -242,13 +360,25 @@ def add_argument(self, action: argparse.Action) -> None: return super().add_argument(action) def format_help(self) -> str: - """Format the help message for the current command and returns it as a string. - The help message includes information about the command's arguments and options, - as well as any additional information provided by the command's help guide. Returns: - str: A string containing the formatted help message. + """Format the complete help message. + + Generates a formatted help message that includes command arguments, options, + and additional guide information based on the current verbosity level. + + Returns: + str: The formatted help message as a string. + + Example: + >>> formatter = CustomHelpFormatter(prog="anomalib") + >>> help_text = formatter.format_help() + >>> isinstance(help_text, str) + True + + Note: + The output format depends on verbosity level: + - Level 0-1: Shows quick-start guide for supported subcommands + - Level 1+: Includes argument section in a panel + - All levels: Maintains consistent spacing and formatting """ with self.console.capture() as capture: section = self._root_section diff --git a/src/anomalib/cli/utils/installation.py b/src/anomalib/cli/utils/installation.py index 01c2f9d288..a9df2dff6e 100644 --- a/src/anomalib/cli/utils/installation.py +++ b/src/anomalib/cli/utils/installation.py @@ -1,4 +1,8 @@ -"""Anomalib installation util functions.""" +"""Anomalib installation utilities. + +This module provides utilities for managing Anomalib package installation, +including dependency resolution and hardware-specific package selection. 
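+
+Example:
+    A minimal sketch chaining the helpers defined below (requirement pins are
+    illustrative only):
+
+    ```python
+    from pkg_resources import Requirement
+
+    reqs = [Requirement.parse("torch==2.1.1"), Requirement.parse("onnx>=1.8.1")]
+    torch_req, others = parse_requirements(reqs, skip_torch=False)
+    # install_args can then be passed to a ``pip install`` command
+    install_args = get_torch_install_args(torch_req) if torch_req else []
+    ```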
+""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -25,20 +29,32 @@ def get_requirements(module: str = "anomalib") -> dict[str, list[Requirement]]: - """Get requirements of module from importlib.metadata. + """Get package requirements from importlib.metadata. - This function returns list of required packages from importlib_metadata. + Args: + module (str): Name of the module to get requirements for. Defaults to "anomalib". + + Returns: + dict[str, list[Requirement]]: Dictionary mapping requirement groups to their + package requirements. Example: - >>> get_requirements("anomalib") + ```python + get_requirements("anomalib") + # Returns: { "base": ["jsonargparse==4.27.1", ...], "core": ["torch==2.1.1", ...], ... } - - Returns: - dict[str, list[Requirement]]: List of required packages for each optional-extras. + ``` + + Test: + >>> result = get_requirements("anomalib") + >>> isinstance(result, dict) + True + >>> all(isinstance(v, list) for v in result.values()) + True """ requirement_list: list[str] | None = requires(module) extra_requirement: dict[str, list[Requirement]] = {} @@ -62,26 +78,37 @@ def parse_requirements( requirements: list[Requirement], skip_torch: bool = False, ) -> tuple[str | None, list[str]]: - """Parse requirements and returns torch and other requirements. + """Parse requirements into torch and other requirements. Args: - requirements (list[Requirement]): List of requirements. + requirements (list[Requirement]): List of requirements to parse. skip_torch (bool): Whether to skip torch requirement. Defaults to False. - Raises: - ValueError: If torch requirement is not found. + Returns: + tuple[str | None, list[str]]: Tuple containing: + - Torch requirement string or None if skipped + - List of other requirement strings - Examples: - >>> requirements = [ - ... Requirement.parse("torch==1.13.0"), - ... Requirement.parse("onnx>=1.8.1"), - ... ] - >>> parse_requirements(requirements=requirements) - (Requirement.parse("torch==1.13.0"), - Requirement.parse("onnx>=1.8.1")) + Raises: + ValueError: If torch requirement is not found and skip_torch is False. - Returns: - tuple[str, list[str], list[str]]: Tuple of torch and other requirements. + Example: + ```python + requirements = [ + Requirement.parse("torch==1.13.0"), + Requirement.parse("onnx>=1.8.1"), + ] + parse_requirements(requirements) + # Returns: ('torch==1.13.0', ['onnx>=1.8.1']) + ``` + + Test: + >>> reqs = [Requirement.parse("torch==1.13.0"), Requirement.parse("onnx>=1.8.1")] + >>> torch_req, other_reqs = parse_requirements(reqs) + >>> torch_req == "torch==1.13.0" + True + >>> other_reqs == ["onnx>=1.8.1"] + True """ torch_requirement: str | None = None other_requirements: list[str] = [] @@ -115,17 +142,27 @@ def parse_requirements( def get_cuda_version() -> str | None: """Get CUDA version installed on the system. - Examples: - >>> # Assume that CUDA version is 11.2 - >>> get_cuda_version() - "11.2" - - >>> # Assume that CUDA is not installed on the system - >>> get_cuda_version() - None - Returns: - str | None: CUDA version installed on the system. + str | None: CUDA version string (e.g., "11.8") or None if not found. + + Example: + ```python + # System with CUDA 11.8 installed + get_cuda_version() + # Returns: "11.8" + + # System without CUDA + get_cuda_version() + # Returns: None + ``` + + Test: + >>> version = get_cuda_version() + >>> version is None or isinstance(version, str) + True + >>> if version is not None: + ... 
assert version.count('.') == 1 and all(part.isdigit() for part in version.split('.')) """ # 1. Check CUDA_HOME Environment variable cuda_home = os.environ.get("CUDA_HOME", "/usr/local/cuda") @@ -157,30 +194,20 @@ def get_cuda_version() -> str | None: def update_cuda_version_with_available_torch_cuda_build(cuda_version: str, torch_version: str) -> str: - """Update the installed CUDA version with the highest supported CUDA version by PyTorch. + """Update CUDA version to match PyTorch's supported versions. Args: cuda_version (str): The installed CUDA version. torch_version (str): The PyTorch version. - Raises: - Warning: If the installed CUDA version is not supported by PyTorch. - - Examples: - >>> update_cuda_version_with_available_torch_cuda_builds("11.1", "1.13.0") - "11.6" - - >>> update_cuda_version_with_available_torch_cuda_builds("11.7", "1.13.0") - "11.7" - - >>> update_cuda_version_with_available_torch_cuda_builds("11.8", "1.13.0") - "11.7" - - >>> update_cuda_version_with_available_torch_cuda_builds("12.1", "2.0.1") - "11.8" - Returns: - str: The updated CUDA version. + str: The updated CUDA version that's compatible with PyTorch. + + Example: + ```python + update_cuda_version_with_available_torch_cuda_build("12.1", "2.0.1") + # Returns: "11.8" # PyTorch 2.0.1 only supports up to CUDA 11.8 + ``` """ max_supported_cuda = max(AVAILABLE_TORCH_VERSIONS[torch_version]["cuda"]) min_supported_cuda = min(AVAILABLE_TORCH_VERSIONS[torch_version]["cuda"]) @@ -204,63 +231,58 @@ def get_cuda_suffix(cuda_version: str) -> str: """Get CUDA suffix for PyTorch versions. Args: - cuda_version (str): CUDA version installed on the system. - - Note: - The CUDA version of PyTorch is not always the same as the CUDA version - that is installed on the system. For example, the latest PyTorch - version (1.10.0) supports CUDA 11.3, but the latest CUDA version - that is available for download is 11.2. Therefore, we need to use - the latest available CUDA version for PyTorch instead of the CUDA - version that is installed on the system. Therefore, this function - shoudl be regularly updated to reflect the latest available CUDA. - - Examples: - >>> get_cuda_suffix(cuda_version="11.2") - "cu112" - - >>> get_cuda_suffix(cuda_version="11.8") - "cu118" + cuda_version (str): CUDA version string (e.g., "11.8"). Returns: - str: CUDA suffix for PyTorch or mmX version. + str: CUDA suffix for PyTorch (e.g., "cu118"). + + Example: + ```python + get_cuda_suffix("11.8") + # Returns: "cu118" + ``` + + Test: + >>> get_cuda_suffix("11.8") + 'cu118' + >>> get_cuda_suffix("12.1") + 'cu121' """ return f"cu{cuda_version.replace('.', '')}" def get_hardware_suffix(with_available_torch_build: bool = False, torch_version: str | None = None) -> str: - """Get hardware suffix for PyTorch or mmX versions. + """Get hardware suffix for PyTorch package names. Args: - with_available_torch_build (bool): Whether to use the latest available - PyTorch build or not. If True, the latest available PyTorch build - will be used. If False, the installed PyTorch build will be used. - Defaults to False. - torch_version (str | None): PyTorch version. This is only used when the - ``with_available_torch_build`` is True. - - Examples: - >>> # Assume that CUDA version is 11.2 - >>> get_hardware_suffix() - "cu112" - - >>> # Assume that CUDA is not installed on the system - >>> get_hardware_suffix() - "cpu" - - Assume that that installed CUDA version is 12.1. - However, the latest available CUDA version for PyTorch v2.0 is 11.8. 
- Therefore, we use 11.8 instead of 12.1. This is because PyTorch does not - support CUDA 12.1 yet. In this case, we could correct the CUDA version - by setting `with_available_torch_build` to True. - - >>> cuda_version = get_cuda_version() - "12.1" - >>> get_hardware_suffix(with_available_torch_build=True, torch_version="2.0.1") - "cu118" + with_available_torch_build (bool): Whether to use available PyTorch builds + to determine the suffix. Defaults to False. + torch_version (str | None): PyTorch version to check against. Required if + with_available_torch_build is True. Returns: - str: Hardware suffix for PyTorch or mmX version. + str: Hardware suffix (e.g., "cu118" or "cpu"). + + Raises: + ValueError: If torch_version is not provided when with_available_torch_build is True. + + Example: + ```python + # System with CUDA 11.8 + get_hardware_suffix() + # Returns: "cu118" + + # System without CUDA + get_hardware_suffix() + # Returns: "cpu" + ``` + + Test: + >>> suffix = get_hardware_suffix() + >>> isinstance(suffix, str) + True + >>> suffix in {'cpu'} or suffix.startswith('cu') + True """ cuda_version = get_cuda_version() if cuda_version: @@ -277,26 +299,38 @@ def get_hardware_suffix(with_available_torch_build: bool = False, torch_version: def get_torch_install_args(requirement: str | Requirement) -> list[str]: - """Get the install arguments for Torch requirement. - - This function will return the install arguments for the Torch requirement - and its corresponding torchvision requirement. + """Get pip install arguments for PyTorch packages. Args: - requirement (str | Requirement): The torch requirement. + requirement (str | Requirement): The torch requirement specification. + + Returns: + list[str]: List of pip install arguments. Raises: RuntimeError: If the OS is not supported. Example: - >>> from pkg_resources import Requirement - >>> requriment = "torch>=1.13.0" - >>> get_torch_install_args(requirement) - ['--extra-index-url', 'https://download.pytorch.org/whl/cpu', - 'torch>=1.13.0', 'torchvision==0.14.0'] - - Returns: - list[str]: The install arguments. + ```python + requirement = "torch>=2.0.0" + get_torch_install_args(requirement) + # Returns: + [ + '--extra-index-url', + 'https://download.pytorch.org/whl/cu118', + 'torch>=2.0.0', + 'torchvision==0.15.1' + ] + ``` + + Test: + >>> args = get_torch_install_args("torch>=2.0.0") + >>> isinstance(args, list) + True + >>> all(isinstance(arg, str) for arg in args) + True + >>> any('torch' in arg for arg in args) + True """ if isinstance(requirement, str): requirement = Requirement.parse(requirement) diff --git a/src/anomalib/cli/utils/openvino.py b/src/anomalib/cli/utils/openvino.py index 50a894c304..00c2fda1ae 100644 --- a/src/anomalib/cli/utils/openvino.py +++ b/src/anomalib/cli/utils/openvino.py @@ -1,6 +1,10 @@ -"""Utils for OpenVINO parser.""" +"""OpenVINO CLI utilities. -# Copyright (C) 2023 Intel Corporation +This module provides utilities for adding OpenVINO-specific arguments to the Anomalib CLI. +It handles the integration of OpenVINO Model Optimizer parameters into the command line interface. +""" + +# Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import logging @@ -18,7 +22,43 @@ def add_openvino_export_arguments(parser: ArgumentParser) -> None: - """Add OpenVINO arguments to parser under --mo key.""" + """Add OpenVINO Model Optimizer arguments to the parser. + + This function adds OpenVINO-specific export arguments to the parser under the `ov_args` prefix. 
+ If OpenVINO is not installed, it logs an informational message and skips adding the arguments. + + The function adds Model Optimizer arguments like data_type, mean_values, etc. as optional + parameters that can be used during model export to OpenVINO format. + + Args: + parser (ArgumentParser): The argument parser to add OpenVINO arguments to. + This should be an instance of jsonargparse.ArgumentParser. + + Examples: + Add OpenVINO arguments to a parser: + + >>> from jsonargparse import ArgumentParser + >>> parser = ArgumentParser() + >>> add_openvino_export_arguments(parser) + + The parser will now accept OpenVINO arguments like: + + >>> # parser.parse_args(['--ov_args.data_type', 'FP16']) + >>> # parser.parse_args(['--ov_args.mean_values', '[123.675,116.28,103.53]']) + + Notes: + - Requires OpenVINO to be installed to add the arguments + - Automatically skips redundant arguments that are handled elsewhere: + - help + - input_model + - output_dir + - Arguments are added under the 'ov_args' prefix for namespacing + - All OpenVINO arguments are made optional + + See Also: + - OpenVINO Model Optimizer docs: https://docs.openvino.ai/latest/openvino_docs_MO_DG_Deep_Learning_Model_Optimizer_DevGuide.html + - OpenVINO Python API: https://docs.openvino.ai/latest/api/python_api.html + """ if get_common_cli_parser is not None: group = parser.add_argument_group("OpenVINO Model Optimizer arguments (optional)") ov_parser = get_common_cli_parser() diff --git a/src/anomalib/data/__init__.py b/src/anomalib/data/__init__.py index 9c9be7eb5b..3f7389647f 100644 --- a/src/anomalib/data/__init__.py +++ b/src/anomalib/data/__init__.py @@ -1,4 +1,28 @@ -"""Anomalib Datasets.""" +"""Anomalib Datasets. + +This module provides datasets and data modules for anomaly detection tasks. + +The module contains: + - Data classes for representing different types of data (images, videos, etc.) + - Dataset classes for loading and processing data + - Data modules for use with PyTorch Lightning + - Helper functions for data loading and validation + +Example: + >>> from anomalib.data import get_datamodule + >>> from omegaconf import DictConfig + >>> config = DictConfig({ + ... "data": { + ... "class_path": "MVTec", + ... "init_args": { + ... "root": "./datasets/MVTec", + ... "category": "bottle", + ... "image_size": (256, 256) + ... } + ... } + ... }) + >>> datamodule = get_datamodule(config) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -51,17 +75,34 @@ ) -class UnknownDatamoduleError(ModuleNotFoundError): ... +class UnknownDatamoduleError(ModuleNotFoundError): + """Raised when a datamodule cannot be found.""" def get_datamodule(config: DictConfig | ListConfig | dict) -> AnomalibDataModule: - """Get Anomaly Datamodule. + """Get Anomaly Datamodule from config. Args: - config (DictConfig | ListConfig | dict): Configuration of the anomaly model. + config: Configuration for the anomaly model. Can be either: + - DictConfig from OmegaConf + - ListConfig from OmegaConf + - Python dictionary Returns: - PyTorch Lightning DataModule + PyTorch Lightning DataModule configured according to the input. + + Raises: + UnknownDatamoduleError: If the specified datamodule cannot be found. + + Example: + >>> from omegaconf import DictConfig + >>> config = DictConfig({ + ... "data": { + ... "class_path": "MVTec", + ... "init_args": {"root": "./datasets/MVTec"} + ... } + ... 
}) + >>> datamodule = get_datamodule(config) """ logger.info("Loading the datamodule") diff --git a/src/anomalib/data/dataclasses/__init__.py b/src/anomalib/data/dataclasses/__init__.py index a7f8516ae5..f0f08e3e54 100644 --- a/src/anomalib/data/dataclasses/__init__.py +++ b/src/anomalib/data/dataclasses/__init__.py @@ -1,45 +1,82 @@ """Anomalib dataclasses. -This module provides a collection of dataclasses used throughout the Anomalib library -for representing and managing various types of data related to anomaly detection tasks. +This module provides a collection of dataclasses used throughout the Anomalib +library for representing and managing various types of data related to anomaly +detection tasks. The dataclasses are organized into two main categories: -1. Numpy-based dataclasses for handling numpy array data. -2. Torch-based dataclasses for handling PyTorch tensor data. +1. Numpy-based dataclasses for handling numpy array data +2. Torch-based dataclasses for handling PyTorch tensor data -Key components: +Key Components +-------------- -Numpy Dataclasses: - ``NumpyImageItem``: Represents a single image item as numpy arrays. - ``NumpyImageBatch``: Represents a batch of image data as numpy arrays. - ``NumpyVideoItem``: Represents a single video item as numpy arrays. - ``NumpyVideoBatch``: Represents a batch of video data as numpy arrays. +Numpy Dataclasses +~~~~~~~~~~~~~~~~~ -Torch Dataclasses: - ``Batch``: Base class for torch-based batch data. - ``DatasetItem``: Base class for torch-based dataset items. - ``DepthItem``: Represents a single depth data item. - ``DepthBatch``: Represents a batch of depth data. - ``ImageItem``: Represents a single image item as torch tensors. - ``ImageBatch``: Represents a batch of image data as torch tensors. - ``VideoItem``: Represents a single video item as torch tensors. - ``VideoBatch``: Represents a batch of video data as torch tensors. - ``InferenceBatch``: Specialized batch class for inference results. 
+- :class:`NumpyImageItem`: Single image item as numpy arrays + - Data shape: ``(H, W, C)`` or ``(H, W)`` for grayscale + - Labels: Binary classification (0: normal, 1: anomalous) + - Masks: Binary segmentation masks ``(H, W)`` + +- :class:`NumpyImageBatch`: Batch of image data as numpy arrays + - Data shape: ``(N, H, W, C)`` or ``(N, H, W)`` for grayscale + - Labels: ``(N,)`` binary labels + - Masks: ``(N, H, W)`` binary masks + +- :class:`NumpyVideoItem`: Single video item as numpy arrays + - Data shape: ``(T, H, W, C)`` or ``(T, H, W)`` for grayscale + - Labels: Binary classification per video + - Masks: ``(T, H, W)`` temporal segmentation masks + +- :class:`NumpyVideoBatch`: Batch of video data as numpy arrays + - Data shape: ``(N, T, H, W, C)`` or ``(N, T, H, W)`` for grayscale + - Labels: ``(N,)`` binary labels + - Masks: ``(N, T, H, W)`` batch of temporal masks + +Torch Dataclasses +~~~~~~~~~~~~~~~~~ + +- :class:`Batch`: Base class for torch-based batch data +- :class:`DatasetItem`: Base class for torch-based dataset items +- :class:`DepthItem`: Single depth data item + - RGB image: ``(3, H, W)`` + - Depth map: ``(H, W)`` +- :class:`DepthBatch`: Batch of depth data + - RGB images: ``(N, 3, H, W)`` + - Depth maps: ``(N, H, W)`` +- :class:`ImageItem`: Single image as torch tensors + - Data shape: ``(C, H, W)`` +- :class:`ImageBatch`: Batch of images as torch tensors + - Data shape: ``(N, C, H, W)`` +- :class:`VideoItem`: Single video as torch tensors + - Data shape: ``(T, C, H, W)`` +- :class:`VideoBatch`: Batch of videos as torch tensors + - Data shape: ``(N, T, C, H, W)`` +- :class:`InferenceBatch`: Specialized batch for inference results + - Predictions: Scores, labels, anomaly maps and masks These dataclasses provide a structured way to handle various types of data in anomaly detection tasks, ensuring type consistency and easy data manipulation across different components of the Anomalib library. + +Example: +-------- +>>> from anomalib.data.dataclasses import ImageItem +>>> import torch +>>> item = ImageItem( +... image=torch.rand(3, 224, 224), +... gt_label=torch.tensor(0), +... image_path="path/to/image.jpg" +... ) +>>> item.image.shape torch.Size([3, 224, 224]) """ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from .numpy import ( - NumpyImageBatch, - NumpyImageItem, - NumpyVideoBatch, - NumpyVideoItem, -) +from .numpy import NumpyImageBatch, NumpyImageItem, NumpyVideoBatch, NumpyVideoItem from .torch import ( Batch, DatasetItem, diff --git a/src/anomalib/data/dataclasses/generic.py b/src/anomalib/data/dataclasses/generic.py index 5f9dca9dc9..3ee82153fd 100644 --- a/src/anomalib/data/dataclasses/generic.py +++ b/src/anomalib/data/dataclasses/generic.py @@ -4,6 +4,30 @@ to define and validate various types of data fields used in Anomalib. The dataclasses are designed to be flexible and extensible, allowing for easy customization and validation of input and output data. + +The module contains several key components: + +- Field descriptors for validation +- Base input field classes for images, videos and depth data +- Output field classes for predictions +- Mixins for updating and batch iteration +- Generic item and batch classes + +Example: + >>> from anomalib.data.dataclasses import _InputFields + >>> from torchvision.tv_tensors import Image, Mask + >>> + >>> class MyInput(_InputFields[int, Image, Mask, str]): + ... def validate_image(self, image): + ... return image + ... # Implement other validation methods + ... 
+ >>> input_data = MyInput( + ... image=torch.rand(3,224,224), + ... gt_label=1, + ... gt_mask=None, + ... mask_path=None + ... ) """ # Copyright (C) 2024 Intel Corporation @@ -37,12 +61,21 @@ class FieldDescriptor(Generic[Value]): validated before being set. This allows validation of the input data not only when it is first set, but also when it is updated. - Attributes: - validator_name (str | None): The name of the validator method to be - called when setting the value. - Defaults to ``None``. - default (Value | None): The default value for the field. + Args: + validator_name: Name of the validator method to call when setting value. Defaults to ``None``. + default: Default value for the field. Defaults to ``None``. + + Example: + >>> class MyClass: + ... field = FieldDescriptor(validator_name="validate_field") + ... def validate_field(self, value): + ... return value + ... + >>> obj = MyClass() + >>> obj.field = 42 + >>> obj.field + 42 """ def __init__(self, validator_name: str | None = None, default: Value | None = None) -> None: @@ -51,15 +84,26 @@ def __init__(self, validator_name: str | None = None, default: Value | None = No self.default = default def __set_name__(self, owner: type[Instance], name: str) -> None: - """Set the name of the descriptor.""" + """Set the name of the descriptor. + + Args: + owner: Class that owns the descriptor + name: Name of the descriptor in the owner class + """ self.name = name def __get__(self, instance: Instance | None, owner: type[Instance]) -> Value | None: """Get the value of the descriptor. + Args: + instance: Instance the descriptor is accessed from + owner: Class that owns the descriptor + Returns: - - The default value if available and if the instance is None (method is called from class). - - The value of the attribute if the instance is not None (method is called from instance). + Default value if instance is None, otherwise the stored value + + Raises: + AttributeError: If no default value and field is not optional """ if instance is None: if self.default is not None or self.is_optional(owner): @@ -71,7 +115,11 @@ def __get__(self, instance: Instance | None, owner: type[Instance]) -> Value | N def __set__(self, instance: object, value: Value) -> None: """Set the value of the descriptor. - First calls the validator method if available, then sets the value of the attribute. + First calls the validator method if available, then sets the value. + + Args: + instance: Instance to set the value on + value: Value to set """ if self.validator_name is not None: validator = getattr(instance, self.validator_name) @@ -79,7 +127,17 @@ def __set__(self, instance: object, value: Value) -> None: instance.__dict__[self.name] = value def get_types(self, owner: type[Instance]) -> tuple[type, ...]: - """Get the types of the descriptor.""" + """Get the types of the descriptor. + + Args: + owner: Class that owns the descriptor + + Returns: + Tuple of valid types for this field + + Raises: + TypeError: If types cannot be determined + """ try: types = get_args(get_type_hints(owner)[self.name]) return get_args(types[0]) if hasattr(types[0], "__args__") else (types[0],) @@ -88,7 +146,14 @@ def get_types(self, owner: type[Instance]) -> tuple[type, ...]: raise TypeError(msg) from e def is_optional(self, owner: type[Instance]) -> bool: - """Check if the descriptor is optional.""" + """Check if the descriptor is optional. 
+ + Args: + owner: Class that owns the descriptor + + Returns: + True if field can be None, False otherwise + """ return NoneType in self.get_types(owner) @@ -96,36 +161,28 @@ def is_optional(self, owner: type[Instance]) -> bool: class _InputFields(Generic[T, ImageT, MaskT, PathT], ABC): """Generic dataclass that defines the standard input fields for Anomalib. - This abstract base class provides a structure for input data used in Anomalib, - a library for anomaly detection in images and videos. It defines common fields - used across various anomaly detection tasks and data types in Anomalib. + This abstract base class provides a structure for input data used in Anomalib. + It defines common fields used across various anomaly detection tasks and data + types. - Subclasses must implement the abstract validation methods to define the - specific validation logic for each field based on the requirements of different - Anomalib models and data processing pipelines. - - Examples: - Assuming a concrete implementation `DummyInput`: - - >>> class DummyInput(_InputFields[int, Image, Mask, str]): - ... # Implement actual validation + Attributes: + image: Input image or video + gt_label: Ground truth label + gt_mask: Ground truth segmentation mask + mask_path: Path to mask file - >>> # Create an input instance - >>> input_item = DummyInput( - ... image=torch.rand(3, 224, 224), + Example: + >>> class MyInput(_InputFields[int, Image, Mask, str]): + ... def validate_image(self, image): + ... return image + ... # Implement other validation methods + ... + >>> input_data = MyInput( + ... image=torch.rand(3,224,224), ... gt_label=1, - ... gt_mask=torch.rand(224, 224) > 0.5, - ... mask_path="path/to/mask.png" + ... gt_mask=None, + ... mask_path=None ... ) - - >>> # Access fields - >>> image = input_item.image - >>> label = input_item.gt_label - - Note: - This is an abstract base class and is not intended to be instantiated - directly. Concrete subclasses should implement all required validation - methods. """ image: FieldDescriptor[ImageT] = FieldDescriptor(validator_name="validate_image") @@ -136,25 +193,65 @@ class _InputFields(Generic[T, ImageT, MaskT, PathT], ABC): @staticmethod @abstractmethod def validate_image(image: ImageT) -> ImageT: - """Validate the image.""" + """Validate the image. + + Args: + image: Input image to validate + + Returns: + Validated image + + Raises: + NotImplementedError: Must be implemented by subclass + """ raise NotImplementedError @staticmethod @abstractmethod def validate_gt_mask(gt_mask: MaskT) -> MaskT | None: - """Validate the ground truth mask.""" + """Validate the ground truth mask. + + Args: + gt_mask: Ground truth mask to validate + + Returns: + Validated mask or None + + Raises: + NotImplementedError: Must be implemented by subclass + """ raise NotImplementedError @staticmethod @abstractmethod def validate_mask_path(mask_path: PathT) -> PathT | None: - """Validate the mask path.""" + """Validate the mask path. + + Args: + mask_path: Path to mask file to validate + + Returns: + Validated path or None + + Raises: + NotImplementedError: Must be implemented by subclass + """ raise NotImplementedError @staticmethod @abstractmethod def validate_gt_label(gt_label: T) -> T | None: - """Validate the ground truth label.""" + """Validate the ground truth label. 
+ + Args: + gt_label: Ground truth label to validate + + Returns: + Validated label or None + + Raises: + NotImplementedError: Must be implemented by subclass + """ raise NotImplementedError @@ -163,35 +260,17 @@ class _ImageInputFields(Generic[PathT], ABC): """Generic dataclass for image-specific input fields in Anomalib. This class extends standard input fields with an ``image_path`` attribute for - image-based anomaly detection tasks. It allows Anomalib to work efficiently - with disk-stored image datasets, facilitating custom data loading strategies. - - The ``image_path`` field uses a ``FieldDescriptor`` with a validation method. - Subclasses must implement ``validate_image_path`` to ensure path validity - according to specific Anomalib model or dataset requirements. - - This class is designed to complement ``_InputFields`` for comprehensive - image-based anomaly detection input in Anomalib. - - Examples: - Assuming a concrete implementation ``DummyImageInput``: - >>> class DummyImageInput(_ImageInputFields): - ... def validate_image_path(self, image_path): - ... return image_path # Implement actual validation - ... # Implement other required methods - - >>> # Create an image input instance - >>> image_input = DummyImageInput( - ... image_path="path/to/image.jpg" - ... ) + image-based anomaly detection tasks. - >>> # Access image-specific field - >>> path = image_input.image_path - - Note: - This is an abstract base class and is not intended to be instantiated - directly. Concrete subclasses should implement all required validation - methods. + Attributes: + image_path: Path to input image file + + Example: + >>> class MyImageInput(_ImageInputFields[str]): + ... def validate_image_path(self, path): + ... return path + ... + >>> input_data = MyImageInput(image_path="path/to/image.jpg") """ image_path: FieldDescriptor[PathT | None] = FieldDescriptor(validator_name="validate_image_path") @@ -199,7 +278,17 @@ class _ImageInputFields(Generic[PathT], ABC): @staticmethod @abstractmethod def validate_image_path(image_path: PathT) -> PathT | None: - """Validate the image path.""" + """Validate the image path. + + Args: + image_path: Path to validate + + Returns: + Validated path or None + + Raises: + NotImplementedError: Must be implemented by subclass + """ raise NotImplementedError @@ -207,45 +296,29 @@ def validate_image_path(image_path: PathT) -> PathT | None: class _VideoInputFields(Generic[T, ImageT, MaskT, PathT], ABC): """Generic dataclass that defines the video input fields for Anomalib. - This class extends standard input fields with attributes specific to video-based - anomaly detection tasks. It includes fields for original images, video paths, - target frames, frame sequences, and last frames. - - Each field uses a ``FieldDescriptor`` with a corresponding validation method. - Subclasses must implement these abstract validation methods to ensure data - consistency with Anomalib's video processing requirements. - - This class is designed to work alongside other input field classes to provide - comprehensive support for video-based anomaly detection in Anomalib. - - Examples: - Assuming a concrete implementation ``DummyVideoInput``: - - >>> class DummyVideoInput(_VideoInputFields): - ... def validate_original_image(self, original_image): - ... return original_image # Implement actual validation - ... # Implement other required methods + This class extends standard input fields with attributes specific to + video-based anomaly detection tasks. 
- >>> # Create a video input instance - >>> video_input = DummyVideoInput( - ... original_image=torch.rand(3, 224, 224), - ... video_path="path/to/video.mp4", + Attributes: + original_image: Original frame from video + video_path: Path to input video file + target_frame: Frame number to process + frames: Sequence of video frames + last_frame: Last frame in sequence + + Example: + >>> class MyVideoInput(_VideoInputFields[int, Image, Mask, str]): + ... def validate_original_image(self, image): + ... return image + ... # Implement other validation methods + ... + >>> input_data = MyVideoInput( + ... original_image=torch.rand(3,224,224), + ... video_path="video.mp4", ... target_frame=10, - ... frames=torch.rand(3, 224, 224), - ... last_frame=torch.rand(3, 224, 224) + ... frames=None, + ... last_frame=None ... ) - - >>> # Access video-specific fields - >>> original_image = video_input.original_image - >>> path = video_input.video_path - >>> target_frame = video_input.target_frame - >>> frames = video_input.frames - >>> last_frame = video_input.last_frame - - Note: - This is an abstract base class and is not intended to be instantiated - directly. Concrete subclasses should implement all required validation - methods. """ original_image: FieldDescriptor[ImageT | None] = FieldDescriptor(validator_name="validate_original_image") @@ -257,31 +330,81 @@ class _VideoInputFields(Generic[T, ImageT, MaskT, PathT], ABC): @staticmethod @abstractmethod def validate_original_image(original_image: ImageT) -> ImageT | None: - """Validate the original image.""" + """Validate the original image. + + Args: + original_image: Image to validate + + Returns: + Validated image or None + + Raises: + NotImplementedError: Must be implemented by subclass + """ raise NotImplementedError @staticmethod @abstractmethod def validate_video_path(video_path: PathT) -> PathT | None: - """Validate the video path.""" + """Validate the video path. + + Args: + video_path: Path to validate + + Returns: + Validated path or None + + Raises: + NotImplementedError: Must be implemented by subclass + """ raise NotImplementedError @staticmethod @abstractmethod def validate_target_frame(target_frame: T) -> T | None: - """Validate the target frame.""" + """Validate the target frame. + + Args: + target_frame: Frame number to validate + + Returns: + Validated frame number or None + + Raises: + NotImplementedError: Must be implemented by subclass + """ raise NotImplementedError @staticmethod @abstractmethod def validate_frames(frames: T) -> T | None: - """Validate the frames.""" + """Validate the frames. + + Args: + frames: Frame sequence to validate + + Returns: + Validated frames or None + + Raises: + NotImplementedError: Must be implemented by subclass + """ raise NotImplementedError @staticmethod @abstractmethod def validate_last_frame(last_frame: T) -> T | None: - """Validate the last frame.""" + """Validate the last frame. + + Args: + last_frame: Frame to validate + + Returns: + Validated frame or None + + Raises: + NotImplementedError: Must be implemented by subclass + """ raise NotImplementedError @@ -289,41 +412,26 @@ def validate_last_frame(last_frame: T) -> T | None: class _DepthInputFields(Generic[T, PathT], _ImageInputFields[PathT], ABC): """Generic dataclass that defines the depth input fields for Anomalib. - This class extends the standard input fields with a ``depth_map`` and - ``depth_path`` attribute for depth-based anomaly detection tasks. 
It allows - Anomalib to work efficiently with depth-based anomaly detection tasks, - facilitating custom data loading strategies. - - The ``depth_map`` and ``depth_path`` fields use a ``FieldDescriptor`` with - corresponding validation methods. Subclasses must implement these abstract - validation methods to ensure data consistency with Anomalib's depth processing - requirements. - - Examples: - Assuming a concrete implementation ``DummyDepthInput``: - - >>> class DummyDepthInput(_DepthInputFields): - ... def validate_depth_map(self, depth_map): - ... return depth_map # Implement actual validation - ... def validate_depth_path(self, depth_path): - ... return depth_path # Implement actual validation - ... # Implement other required methods - - >>> # Create a depth input instance - >>> depth_input = DummyDepthInput( - ... image_path="path/to/image.jpg", - ... depth_map=torch.rand(224, 224), - ... depth_path="path/to/depth.png" - ... ) + This class extends standard input fields with depth-specific attributes for + depth-based anomaly detection tasks. - >>> # Access depth-specific fields - >>> depth_map = depth_input.depth_map - >>> depth_path = depth_input.depth_path - - Note: - This is an abstract base class and is not intended to be instantiated - directly. Concrete subclasses should implement all required validation - methods. + Attributes: + depth_map: Depth map image + depth_path: Path to depth map file + + Example: + >>> class MyDepthInput(_DepthInputFields[torch.Tensor, str]): + ... def validate_depth_map(self, depth): + ... return depth + ... def validate_depth_path(self, path): + ... return path + ... # Implement other validation methods + ... + >>> input_data = MyDepthInput( + ... image_path="rgb.jpg", + ... depth_map=torch.rand(224,224), + ... depth_path="depth.png" + ... ) """ depth_map: FieldDescriptor[T | None] = FieldDescriptor(validator_name="validate_depth_map") @@ -332,13 +440,33 @@ class _DepthInputFields(Generic[T, PathT], _ImageInputFields[PathT], ABC): @staticmethod @abstractmethod def validate_depth_map(depth_map: ImageT) -> ImageT | None: - """Validate the depth map.""" + """Validate the depth map. + + Args: + depth_map: Depth map to validate + + Returns: + Validated depth map or None + + Raises: + NotImplementedError: Must be implemented by subclass + """ raise NotImplementedError @staticmethod @abstractmethod def validate_depth_path(depth_path: PathT) -> PathT | None: - """Validate the depth path.""" + """Validate the depth path. + + Args: + depth_path: Path to validate + + Returns: + Validated path or None + + Raises: + NotImplementedError: Must be implemented by subclass + """ raise NotImplementedError @@ -347,43 +475,28 @@ class _OutputFields(Generic[T, MaskT, PathT], ABC): """Generic dataclass that defines the standard output fields for Anomalib. This class defines the standard output fields used in Anomalib, including - anomaly maps, predicted scores, predicted masks, and predicted labels. - - Each field uses a ``FieldDescriptor`` with a corresponding validation method. - Subclasses must implement these abstract validation methods to ensure data - consistency with Anomalib's anomaly detection tasks. - - Examples: - Assuming a concrete implementation ``DummyOutput``: - - >>> class DummyOutput(_OutputFields): - ... def validate_anomaly_map(self, anomaly_map): - ... return anomaly_map # Implement actual validation - ... def validate_pred_score(self, pred_score): - ... return pred_score # Implement actual validation - ... 
def validate_pred_mask(self, pred_mask): - ... return pred_mask # Implement actual validation - ... def validate_pred_label(self, pred_label): - ... return pred_label # Implement actual validation - - >>> # Create an output instance with predictions - >>> output = DummyOutput( - ... anomaly_map=torch.rand(224, 224), + anomaly maps, predicted scores, masks, and labels. + + Attributes: + anomaly_map: Predicted anomaly heatmap + pred_score: Predicted anomaly score + pred_mask: Predicted segmentation mask + pred_label: Predicted label + explanation: Path to explanation visualization + + Example: + >>> class MyOutput(_OutputFields[float, Mask, str]): + ... def validate_anomaly_map(self, amap): + ... return amap + ... # Implement other validation methods + ... + >>> output = MyOutput( + ... anomaly_map=torch.rand(224,224), ... pred_score=0.7, - ... pred_mask=torch.rand(224, 224) > 0.5, - ... pred_label=1 + ... pred_mask=None, + ... pred_label=1, + ... explanation=None ... ) - - >>> # Access individual fields - >>> anomaly_map = output.anomaly_map - >>> score = output.pred_score - >>> mask = output.pred_mask - >>> label = output.pred_label - - Note: - This is an abstract base class and is not intended to be instantiated - directly. Concrete subclasses should implement all required validation - methods. """ anomaly_map: FieldDescriptor[MaskT | None] = FieldDescriptor(validator_name="validate_anomaly_map") @@ -395,70 +508,118 @@ class _OutputFields(Generic[T, MaskT, PathT], ABC): @staticmethod @abstractmethod def validate_anomaly_map(anomaly_map: MaskT) -> MaskT | None: - """Validate the anomaly map.""" + """Validate the anomaly map. + + Args: + anomaly_map: Anomaly map to validate + + Returns: + Validated anomaly map or None + + Raises: + NotImplementedError: Must be implemented by subclass + """ raise NotImplementedError @staticmethod @abstractmethod def validate_pred_score(pred_score: T) -> T | None: - """Validate the predicted score.""" + """Validate the predicted score. + + Args: + pred_score: Score to validate + + Returns: + Validated score or None + + Raises: + NotImplementedError: Must be implemented by subclass + """ raise NotImplementedError @staticmethod @abstractmethod def validate_pred_mask(pred_mask: MaskT) -> MaskT | None: - """Validate the predicted mask.""" + """Validate the predicted mask. + + Args: + pred_mask: Mask to validate + + Returns: + Validated mask or None + + Raises: + NotImplementedError: Must be implemented by subclass + """ raise NotImplementedError @staticmethod @abstractmethod def validate_pred_label(pred_label: T) -> T | None: - """Validate the predicted label.""" + """Validate the predicted label. + + Args: + pred_label: Label to validate + + Returns: + Validated label or None + + Raises: + NotImplementedError: Must be implemented by subclass + """ raise NotImplementedError @staticmethod @abstractmethod def validate_explanation(explanation: PathT) -> PathT | None: - """Validate the explanation.""" - raise NotImplementedError + """Validate the explanation. + Args: + explanation: Explanation to validate -@dataclass -class UpdateMixin: - """Mixin class for dataclasses that allows for in-place replacement of attributes. - - This mixin class provides a method for updating dataclass instances in place or - by creating a new instance. It ensures that the updated instance is reinitialized - by calling the ``__post_init__`` method if it exists. 
- - Examples: - Assuming a dataclass `DummyItem` that uses UpdateMixin: - - >>> item = DummyItem(image=torch.rand(3, 224, 224), label=0) - - >>> # In-place update - >>> item.update(label=1, pred_score=0.9) - >>> print(item.label, item.pred_score) - 1 0.9 + Returns: + Validated explanation or None - >>> # Create a new instance with updates - >>> new_item = item.update(in_place=False, image=torch.rand(3, 224, 224)) - >>> print(id(item) != id(new_item)) - True + Raises: + NotImplementedError: Must be implemented by subclass + """ + raise NotImplementedError - >>> # Update with multiple fields - >>> item.update(label=2, pred_score=0.8, anomaly_map=torch.rand(224, 224)) - The `update` method can be used to modify single or multiple fields, either - in-place or by creating a new instance. This flexibility is particularly useful - in data processing pipelines and when working with model predictions in Anomalib. +@dataclass +class UpdateMixin: + """Mixin class for dataclasses that allows for in-place replacement of attrs. + + This mixin provides methods for updating dataclass instances in place or by + creating a new instance. + + Example: + >>> @dataclass + ... class MyItem(UpdateMixin): + ... field1: int + ... field2: str + ... + >>> item = MyItem(field1=1, field2="a") + >>> item.update(field1=2) # In-place update + >>> item.field1 + 2 + >>> new_item = item.update(in_place=False, field2="b") + >>> new_item.field2 + 'b' """ def update(self, in_place: bool = True, **changes) -> Any: # noqa: ANN401 - """Replace fields in place and call __post_init__ to reinitialize the instance. + """Replace fields in place and call __post_init__ to reinitialize. - Parameters: - changes (dict): A dictionary of field names and their new values. + Args: + in_place: Whether to modify in place or return new instance + **changes: Field names and new values to update + + Returns: + Updated instance (self if in_place=True, new instance otherwise) + + Raises: + TypeError: If instance is not a dataclass """ if not is_dataclass(self): msg = "replace can only be used with dataclass instances" @@ -483,43 +644,23 @@ class _GenericItem( ): """Generic dataclass for a single item in Anomalib datasets. - This class combines input and output fields for anomaly detection tasks, - providing a comprehensive representation of a single data item. It inherits - from ``_InputFields`` for standard input data and ``_OutputFields`` for - prediction results. - - The class also includes the ``UpdateMixin``, allowing for easy updates of - field values. This is particularly useful during data processing pipelines - and when working with model predictions. - - By using generic types, this class can accommodate various data types used - in different Anomalib models and datasets, ensuring flexibility and - reusability across the library. - - Examples: - Assuming a concrete implementation ``DummyItem``: + This class combines input and output fields for anomaly detection tasks. + It inherits from ``_InputFields`` for standard input data and + ``_OutputFields`` for prediction results. - >>> class DummyItem(_GenericItem): + Example: + >>> class MyItem(_GenericItem[int, Image, Mask, str]): ... def validate_image(self, image): - ... return image # Implement actual validation - ... # Implement other required methods - - >>> # Create a generic item instance - >>> item = DummyItem( - ... image=torch.rand(3, 224, 224), + ... return image + ... # Implement other validation methods + ... + >>> item = MyItem( + ... image=torch.rand(3,224,224), ... 
gt_label=0, ... pred_score=0.3, - ... anomaly_map=torch.rand(224, 224) + ... anomaly_map=torch.rand(224, 224) ... ) - - >>> # Access and update fields - >>> image = item.image - >>> item.update(pred_score=0.8, pred_label=1) - - Note: - This is an abstract base class and is not intended to be instantiated - directly. Concrete subclasses should implement all required validation - methods. + >>> item = item.update(pred_score=0.8) """ @@ -533,43 +674,19 @@ class _GenericBatch( """Generic dataclass for a batch of items in Anomalib datasets. This class represents a batch of data items, combining both input and output - fields for anomaly detection tasks. It inherits from ``_InputFields`` for - input data and ``_OutputFields`` for prediction results, allowing it to - handle both training data and model outputs. - - The class includes the ``UpdateMixin``, enabling easy updates of field values - across the entire batch. This is particularly useful for in-place modifications - during data processing or when updating predictions. - - Examples: - Assuming a concrete implementation ``DummyBatch``: + fields for anomaly detection tasks. - >>> class DummyBatch(_GenericBatch): + Example: + >>> class MyBatch(_GenericBatch[int, Image, Mask, str]): ... def validate_image(self, image): - ... return image # Implement actual validation - ... # Implement other required methods - - >>> # Create a batch with input data - >>> batch = DummyBatch( - ... image=torch.rand(32, 3, 224, 224), - ... gt_label=torch.randint(0, 2, (32,)) - ... ) - - >>> # Update the entire batch with new predictions - >>> batch.update( - ... pred_score=torch.rand(32), - ... anomaly_map=torch.rand(32, 224, 224) + ... return image + ... # Implement other validation methods + ... + >>> batch = MyBatch( + ... image=torch.rand(32, 3, 224, 224), + ... gt_label=torch.zeros(32), + ... pred_score=torch.rand(32) ... ) - - >>> # Access individual fields - >>> images = batch.image - >>> labels = batch.gt_label - >>> predictions = batch.pred_score - - Note: - This is an abstract base class and is not intended to be instantiated - directly. Concrete subclasses should implement all required validation - methods. """ @@ -581,55 +698,54 @@ class BatchIterateMixin(Generic[ItemT]): """Mixin class for iterating over batches of items in Anomalib datasets. This class provides functionality to iterate over individual items within a - batch, convert batches to lists of items, and determine batch sizes. It's - designed to work with Anomalib's batch processing pipelines. - - The mixin requires subclasses to define an ``item_class`` attribute, which - specifies the class used for individual items in the batch. This ensures - type consistency when iterating or converting batches. + batch and convert batches to lists of items. - Key features include: - - Iteration over batch items - - Conversion of batches to lists of individual items - - Batch size determination - - A class method for collating individual items into a batch - - Examples: - Assuming a subclass `DummyBatch` with `DummyItem` as its item_class: - - >>> batch = DummyBatch(images=[...], labels=[...]) + Attributes: + item_class: Class to use for individual items in the batch + + Example: + >>> @dataclass + ... class MyBatch(BatchIterateMixin): + ... item_class = MyItem + ... image: torch.Tensor + ... + >>> batch = MyBatch(image=torch.rand(32, 3, 224, 224)) >>> for item in batch: - ...
process_item(item) # Iterate over items - - >>> item_list = batch.items # Convert batch to list of items - >>> type(item_list[0]) - - - >>> batch_size = len(batch) # Get batch size - - >>> items = [DummyItem(...) for _ in range(5)] - >>> new_batch = DummyBatch.collate(items) # Collate items into a batch - - This mixin enhances batch handling capabilities in Anomalib, facilitating - efficient data processing and model interactions. + ... process_item(item) + >>> items = batch.items # Convert to list of items """ item_class: ClassVar[Callable] def __init_subclass__(cls, **kwargs) -> None: - """Ensure that the subclass has the required attributes.""" + """Ensure that the subclass has the required attributes. + + Args: + **kwargs: Additional keyword arguments + + Raises: + AttributeError: If item_class is not defined + """ super().__init_subclass__(**kwargs) if not (hasattr(cls, "item_class") or issubclass(cls, ABC)): msg = f"{cls.__name__} must have an 'item_class' attribute." raise AttributeError(msg) def __iter__(self) -> Iterator[ItemT]: - """Iterate over the batch.""" + """Iterate over the batch. + + Yields: + Individual items from the batch + """ yield from self.items @property def items(self) -> list[ItemT]: - """Convert the batch to a list of DatasetItem objects.""" + """Convert the batch to a list of DatasetItem objects. + + Returns: + List of individual items from the batch + """ batch_dict = asdict(self) return [ self.item_class( @@ -639,22 +755,40 @@ def items(self) -> list[ItemT]: ] def __len__(self) -> int: - """Get the batch size.""" + """Get the batch size. + + Returns: + Number of items in batch + """ return self.batch_size @property def batch_size(self) -> int: - """Get the batch size.""" + """Get the batch size. + + Returns: + Number of items in batch + + Raises: + AttributeError: If image attribute is not set + """ try: image = getattr(self, "image") # noqa: B009 return len(image) except (KeyError, AttributeError) as e: - msg = "Cannot determine batch size because 'image' attribute has not been set." + msg = "Cannot determine batch size because 'image' attribute not set." raise AttributeError(msg) from e @classmethod def collate(cls: type["BatchIterateMixin"], items: list[ItemT]) -> "BatchIterateMixin": - """Convert a list of DatasetItem objects to a Batch object.""" + """Convert a list of DatasetItem objects to a Batch object. + + Args: + items: List of items to collate into a batch + + Returns: + New batch containing the items + """ keys = [key for key, value in asdict(items[0]).items() if value is not None] out_dict = {key: default_collate([getattr(item, key) for item in items]) for key in keys} return cls(**out_dict) diff --git a/src/anomalib/data/dataclasses/numpy/__init__.py b/src/anomalib/data/dataclasses/numpy/__init__.py index 717e3d6c6e..7b1520d424 100644 --- a/src/anomalib/data/dataclasses/numpy/__init__.py +++ b/src/anomalib/data/dataclasses/numpy/__init__.py @@ -1,17 +1,53 @@ """Numpy-based dataclasses for Anomalib. -This module provides numpy-based implementations of the generic dataclasses -used in Anomalib. These classes are designed to work with numpy arrays for -efficient data handling and processing in anomaly detection tasks. +This module provides numpy-based implementations of the generic dataclasses used in +Anomalib. These classes are designed to work with numpy arrays for efficient data +handling and processing in anomaly detection tasks. 
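Returning to ``BatchIterateMixin`` above before the numpy classes: ``collate`` and ``items`` are near-inverses of each other, and the round trip is easy to demonstrate with plain dicts. A minimal sketch (the field names here are hypothetical; ``default_collate`` is the same torch helper used by ``collate``):

```python
import torch
from torch.utils.data import default_collate

# four stand-in items with an ``image`` field, as the mixin's batch_size contract expects
items = [
    {"image": torch.rand(3, 224, 224), "pred_score": torch.tensor(0.5)}
    for _ in range(4)
]

# what ``collate`` does: stack each field along a new leading batch dimension
batch = default_collate(items)
assert batch["image"].shape == (4, 3, 224, 224)

# roughly what ``items`` does in reverse: slice the batch back into single items
first = {key: value[0] for key, value in batch.items()}
assert first["image"].shape == (3, 224, 224)
```

Note that ``collate`` above also skips fields that are ``None`` on the first item, so partially populated items collate cleanly.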
The module includes the following main classes: -- NumpyItem: Represents a single item in Anomalib datasets using numpy arrays. -- NumpyBatch: Represents a batch of items in Anomalib datasets using numpy arrays. -- NumpyImageItem: Represents a single image item with additional image-specific fields. -- NumpyImageBatch: Represents a batch of image items with batch operations. -- NumpyVideoItem: Represents a single video item with video-specific fields. -- NumpyVideoBatch: Represents a batch of video items with video-specific operations. +- :class:`NumpyItem`: Base class representing a single item in Anomalib datasets + using numpy arrays. Contains the common input and output fields like ``image``, + ``gt_label``, ``pred_score``, and ``anomaly_map``. + +- :class:`NumpyBatch`: Base class representing a batch of items in Anomalib + datasets using numpy arrays. Provides batch operations and collation + functionality. + +- :class:`NumpyImageItem`: Specialized class for image data that extends + :class:`NumpyItem` with image-specific fields like ``image_path``. + +- :class:`NumpyImageBatch`: Specialized batch class for image data that extends + :class:`NumpyBatch` with image-specific batch operations and collation. + +- :class:`NumpyVideoItem`: Specialized class for video data that extends + :class:`NumpyItem` with video-specific fields like ``video_path`` and + ``frames``. + +- :class:`NumpyVideoBatch`: Specialized batch class for video data that extends + :class:`NumpyBatch` with video-specific batch operations and collation. + +Example: + Create and use a numpy image item: + + >>> from anomalib.data.dataclasses.numpy import NumpyImageItem + >>> import numpy as np + >>> item = NumpyImageItem( + ... image=np.random.rand(224, 224, 3), + ... gt_label=0, + ... image_path="path/to/image.jpg" + ... ) + >>> item.image.shape + (224, 224, 3) + +Note: + - All classes in this module use numpy arrays internally for efficient data + handling + - The batch classes provide automatic collation of items into batches suitable + for model input + - The classes are designed to be compatible with Anomalib's data pipeline and + model interfaces """ # Copyright (C) 2024 Intel Corporation @@ -21,4 +57,11 @@ from .image import NumpyImageBatch, NumpyImageItem from .video import NumpyVideoBatch, NumpyVideoItem -__all__ = ["NumpyBatch", "NumpyItem", "NumpyImageBatch", "NumpyImageItem", "NumpyVideoBatch", "NumpyVideoItem"] +__all__ = [ + "NumpyBatch", + "NumpyItem", + "NumpyImageBatch", + "NumpyImageItem", + "NumpyVideoBatch", + "NumpyVideoItem", +] diff --git a/src/anomalib/data/dataclasses/numpy/base.py b/src/anomalib/data/dataclasses/numpy/base.py index a27496f697..3a3fb6ef86 100644 --- a/src/anomalib/data/dataclasses/numpy/base.py +++ b/src/anomalib/data/dataclasses/numpy/base.py @@ -1,8 +1,12 @@ """Numpy-based dataclasses for Anomalib. -This module provides numpy-based implementations of the generic dataclasses -used in Anomalib. These classes are designed to work with numpy arrays for -efficient data handling and processing in anomaly detection tasks. +This module provides numpy-based implementations of the generic dataclasses used in +Anomalib. These classes are designed to work with :class:`numpy.ndarray` objects +for efficient data handling and processing in anomaly detection tasks.
+ +The module contains two main classes: + - :class:`NumpyItem`: For single data items + - :class:`NumpyBatch`: For batched data items """ # Copyright (C) 2024 Intel Corporation @@ -19,10 +23,18 @@ class NumpyItem(_GenericItem[np.ndarray, np.ndarray, np.ndarray, str]): """Dataclass for a single item in Anomalib datasets using numpy arrays. - This class extends _GenericItem for numpy-based data representation. It includes - both input data (e.g., images, labels) and output data (e.g., predictions, - anomaly maps) as numpy arrays. It is suitable for numpy-based processing - pipelines in Anomalib. + This class extends :class:`_GenericItem` for numpy-based data representation. + It includes both input data (e.g., images, labels) and output data (e.g., + predictions, anomaly maps) as numpy arrays. + + The class uses the following type parameters: + - Image: :class:`numpy.ndarray` + - Label: :class:`numpy.ndarray` + - Mask: :class:`numpy.ndarray` + - Path: :class:`str` + + This implementation is suitable for numpy-based processing pipelines in + Anomalib where GPU acceleration is not required. """ @@ -30,7 +42,16 @@ class NumpyItem(_GenericItem[np.ndarray, np.ndarray, np.ndarray, str]): class NumpyBatch(_GenericBatch[np.ndarray, np.ndarray, np.ndarray, list[str]]): """Dataclass for a batch of items in Anomalib datasets using numpy arrays. - This class extends _GenericBatch for batches of numpy-based data. It represents - multiple data points for batch processing in anomaly detection tasks. It includes - an additional dimension for batch size in all tensor-like fields. + This class extends :class:`_GenericBatch` for batches of numpy-based data. + It represents multiple data points for batch processing in anomaly detection + tasks. + + The class uses the following type parameters: + - Image: :class:`numpy.ndarray` with shape ``(B, H, W, C)`` + - Label: :class:`numpy.ndarray` with shape ``(B,)`` + - Mask: :class:`numpy.ndarray` with shape ``(B, H, W)`` + - Path: :class:`list` of :class:`str` + + Where ``B`` represents the batch dimension that is prepended to all + tensor-like fields. """ diff --git a/src/anomalib/data/dataclasses/numpy/depth.py b/src/anomalib/data/dataclasses/numpy/depth.py index f8bd924c84..2cdf77d1e8 100644 --- a/src/anomalib/data/dataclasses/numpy/depth.py +++ b/src/anomalib/data/dataclasses/numpy/depth.py @@ -1,4 +1,27 @@ -"""Numpy-based depth dataclasses for Anomalib.""" +"""Numpy-based depth dataclasses for Anomalib. + +This module provides numpy-based implementations of depth-specific dataclasses used in +Anomalib. These classes are designed to work with depth data represented as numpy arrays +for anomaly detection tasks. + +The module contains two main classes: + - :class:`NumpyDepthItem`: For single depth data items + - :class:`NumpyDepthBatch`: For batched depth data items + +Example: + Create and use a numpy depth item: + + >>> from anomalib.data.dataclasses.numpy.depth import NumpyDepthItem + >>> import numpy as np + >>> item = NumpyDepthItem( + ... image=np.random.rand(224, 224, 3), + ... depth_map=np.random.rand(224, 224), + ... gt_label=0, + ... depth_path="path/to/depth.png" + ... ) + >>> item.depth_map.shape + (224, 224) +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -20,9 +43,25 @@ class NumpyDepthItem( ): """Dataclass for a single depth item in Anomalib datasets using numpy arrays. - This class combines _DepthInputFields and NumpyItem for depth-based anomaly detection.
- It includes depth-specific fields and validation methods to ensure proper formatting - for Anomalib's depth-based models. + This class combines :class:`_DepthInputFields` and :class:`NumpyItem` for + depth-based anomaly detection. It includes depth-specific fields and validation + methods to ensure proper formatting for Anomalib's depth-based models. + + The class uses the following type parameters: + - Image: :class:`numpy.ndarray` with shape ``(H, W, C)`` + - Depth: :class:`numpy.ndarray` with shape ``(H, W)`` + - Label: :class:`numpy.ndarray` + - Path: :class:`str` + + Example: + >>> import numpy as np + >>> from anomalib.data.dataclasses.numpy.depth import NumpyDepthItem + >>> item = NumpyDepthItem( + ... image=np.random.rand(224, 224, 3), + ... depth_map=np.random.rand(224, 224), + ... gt_label=0, + ... depth_path="path/to/depth.png" + ... ) """ @@ -32,6 +71,20 @@ class NumpyDepthBatch( _DepthInputFields[np.ndarray, list[str]], NumpyBatch, ): - """Dataclass for a batch of depth items in Anomalib datasets using numpy arrays.""" + """Dataclass for a batch of depth items in Anomalib datasets using numpy arrays. + + This class extends :class:`NumpyBatch` for batches of depth-based data. It + represents multiple depth data points for batch processing in anomaly detection + tasks. + + The class uses the following type parameters: + - Image: :class:`numpy.ndarray` with shape ``(B, H, W, C)`` + - Depth: :class:`numpy.ndarray` with shape ``(B, H, W)`` + - Label: :class:`numpy.ndarray` with shape ``(B,)`` + - Path: :class:`list` of :class:`str` + + Where ``B`` represents the batch dimension that is prepended to all + tensor-like fields. + """ item_class = NumpyDepthItem diff --git a/src/anomalib/data/dataclasses/numpy/image.py b/src/anomalib/data/dataclasses/numpy/image.py index ad71bb4bd8..774e44ff70 100644 --- a/src/anomalib/data/dataclasses/numpy/image.py +++ b/src/anomalib/data/dataclasses/numpy/image.py @@ -1,4 +1,26 @@ -"""Numpy-based image dataclasses for Anomalib.""" +"""Numpy-based image dataclasses for Anomalib. + +This module provides numpy-based implementations of image-specific dataclasses used in +Anomalib. These classes are designed to work with image data represented as numpy arrays +for anomaly detection tasks. + +The module contains two main classes: + - :class:`NumpyImageItem`: For single image data items + - :class:`NumpyImageBatch`: For batched image data items + +Example: + Create and use a numpy image item:: + + >>> from anomalib.data.dataclasses.numpy import NumpyImageItem + >>> import numpy as np + >>> item = NumpyImageItem( + ... image=np.random.rand(224, 224, 3), + ... gt_label=0, + ... image_path="path/to/image.jpg" + ... ) + >>> item.image.shape + (224, 224, 3) +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -18,25 +40,26 @@ class NumpyImageItem( ): """Dataclass for a single image item in Anomalib datasets using numpy arrays. - This class combines _ImageInputFields and NumpyItem for image-based anomaly detection. - It includes image-specific fields and validation methods to ensure proper formatting - for Anomalib's image-based models. + This class combines :class:`_ImageInputFields` and :class:`NumpyItem` for + image-based anomaly detection. It includes image-specific fields and validation + methods to ensure proper formatting for Anomalib's image-based models.
- Examples: + The class uses the following type parameters: + - Image: :class:`numpy.ndarray` with shape ``(H, W, C)`` + - Label: :class:`numpy.ndarray` + - Mask: :class:`numpy.ndarray` with shape ``(H, W)`` + - Path: :class:`str` + + Example: + >>> import numpy as np + >>> from anomalib.data.dataclasses.numpy import NumpyImageItem >>> item = NumpyImageItem( - ... image=np.random.rand(224, 224, 3), - ... gt_label=np.array(1), - ... gt_mask=np.random.rand(224, 224) > 0.5, - ... anomaly_map=np.random.rand(224, 224), - ... pred_score=np.array(0.7), - ... pred_label=np.array(1), + ... image=np.random.rand(224, 224, 3), + ... gt_label=0, ... image_path="path/to/image.jpg" ... ) - - >>> # Access fields - >>> image = item.image - >>> label = item.gt_label - >>> path = item.image_path + >>> item.image.shape + (224, 224, 3) """ @@ -49,29 +72,29 @@ class NumpyImageBatch( ): """Dataclass for a batch of image items in Anomalib datasets using numpy arrays. - This class combines BatchIterateMixin, _ImageInputFields, and NumpyBatch for batches - of image data. It supports batch operations and iteration over individual NumpyImageItems. - It ensures proper formatting for Anomalib's image-based models. + This class combines :class:`BatchIterateMixin`, :class:`_ImageInputFields`, and + :class:`NumpyBatch` for batches of image data. It supports batch operations and + iteration over individual :class:`NumpyImageItem` instances. - Examples: - >>> batch = NumpyImageBatch( - ... image=np.random.rand(32, 224, 224, 3), - ... gt_label=np.random.randint(0, 2, (32,)), - ... gt_mask=np.random.rand(32, 224, 224) > 0.5, - ... anomaly_map=np.random.rand(32, 224, 224), - ... pred_score=np.random.rand(32), - ... pred_label=np.random.randint(0, 2, (32,)), - ... image_path=["path/to/image_{}.jpg".format(i) for i in range(32)] - ... ) + The class uses the following type parameters: + - Image: :class:`numpy.ndarray` with shape ``(B, H, W, C)`` + - Label: :class:`numpy.ndarray` with shape ``(B,)`` + - Mask: :class:`numpy.ndarray` with shape ``(B, H, W)`` + - Path: :class:`list` of :class:`str` - >>> # Access batch fields - >>> images = batch.image - >>> labels = batch.gt_label - >>> paths = batch.image_path + Where ``B`` represents the batch dimension that is prepended to all tensor-like + fields. - >>> # Iterate over items in the batch - >>> for item in batch: - ... process_item(item) + Example: + >>> import numpy as np + >>> from anomalib.data.dataclasses.numpy import NumpyImageBatch + >>> batch = NumpyImageBatch( + ... image=np.random.rand(32, 224, 224, 3), + ... gt_label=np.zeros(32), + ... image_path=[f"path/to/image_{i}.jpg" for i in range(32)] + ... ) + >>> batch.image.shape + (32, 224, 224, 3) """ item_class = NumpyImageItem diff --git a/src/anomalib/data/dataclasses/numpy/video.py b/src/anomalib/data/dataclasses/numpy/video.py index 34998c00d1..ec5820ef0e 100644 --- a/src/anomalib/data/dataclasses/numpy/video.py +++ b/src/anomalib/data/dataclasses/numpy/video.py @@ -1,4 +1,27 @@ -"""Numpy-based video dataclasses for Anomalib.""" +"""Numpy-based video dataclasses for Anomalib. + +This module provides numpy-based implementations of video-specific dataclasses used in +Anomalib. These classes are designed to work with video data represented as numpy +arrays for anomaly detection tasks.
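Before the video classes, a short usage sketch of the image batch iteration just documented (field names follow the corrected examples above; omitting the remaining optional fields at construction is an assumption):

```python
import numpy as np
from anomalib.data.dataclasses.numpy import NumpyImageBatch

batch = NumpyImageBatch(
    image=np.random.rand(8, 224, 224, 3),   # (B, H, W, C)
    gt_label=np.zeros(8, dtype=np.int64),   # (B,)
    image_path=[f"images/{i:03d}.png" for i in range(8)],
)

# BatchIterateMixin yields one NumpyImageItem per batch element
for item in batch:
    assert item.image.shape == (224, 224, 3)
```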
+ +The module contains two main classes: + - :class:`NumpyVideoItem`: For single video data items + - :class:`NumpyVideoBatch`: For batched video data items + +Example: + Create and use a numpy video item: + + >>> from anomalib.data.dataclasses.numpy import NumpyVideoItem + >>> import numpy as np + >>> item = NumpyVideoItem( + ... image=np.random.rand(16, 224, 224, 3), # (T, H, W, C) + ... gt_label=0, + ... video_path="path/to/video.mp4" + ... ) + >>> item.image.shape + (16, 224, 224, 3) +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -20,9 +43,17 @@ class NumpyVideoItem( ): """Dataclass for a single video item in Anomalib datasets using numpy arrays. - This class combines _VideoInputFields and NumpyItem for video-based anomaly detection. - It includes video-specific fields and validation methods to ensure proper formatting - for Anomalib's video-based models. + This class combines :class:`_VideoInputFields` and :class:`NumpyItem` for + video-based anomaly detection. It includes video-specific fields and validation + methods to ensure proper formatting for Anomalib's video-based models. + + The class uses the following type parameters: + - Video: :class:`numpy.ndarray` with shape ``(T, H, W, C)`` + - Label: :class:`numpy.ndarray` + - Mask: :class:`numpy.ndarray` with shape ``(T, H, W)`` + - Path: :class:`str` + + Where ``T`` represents the temporal dimension (number of frames). """ @@ -35,9 +66,17 @@ class NumpyVideoBatch( ): """Dataclass for a batch of video items in Anomalib datasets using numpy arrays. - This class combines BatchIterateMixin, _VideoInputFields, and NumpyBatch for batches - of video data. It supports batch operations and iteration over individual NumpyVideoItems. - It ensures proper formatting for Anomalib's video-based models. + This class combines :class:`BatchIterateMixin`, :class:`_VideoInputFields`, and + :class:`NumpyBatch` for batches of video data. It supports batch operations + and iteration over individual :class:`NumpyVideoItem` instances. + + The class uses the following type parameters: + - Video: :class:`numpy.ndarray` with shape ``(B, T, H, W, C)`` + - Label: :class:`numpy.ndarray` with shape ``(B,)`` + - Mask: :class:`numpy.ndarray` with shape ``(B, T, H, W)`` + - Path: :class:`list` of :class:`str` + + Where ``B`` represents the batch dimension and ``T`` the temporal dimension. """ item_class = NumpyVideoItem diff --git a/src/anomalib/data/dataclasses/torch/base.py b/src/anomalib/data/dataclasses/torch/base.py index 77b0cc5022..9139b42f86 100644 --- a/src/anomalib/data/dataclasses/torch/base.py +++ b/src/anomalib/data/dataclasses/torch/base.py @@ -1,8 +1,8 @@ """Torch-based dataclasses for Anomalib. -This module provides PyTorch-based implementations of the generic dataclasses -used in Anomalib. These classes are designed to work with PyTorch tensors for -efficient data handling and processing in anomaly detection tasks. +This module provides PyTorch-based implementations of the generic dataclasses used +in Anomalib. These classes are designed to work with PyTorch tensors for efficient +data handling and processing in anomaly detection tasks. These classes extend the generic dataclasses defined in the Anomalib framework, providing concrete implementations that use PyTorch tensors for tensor-like data. @@ -24,7 +24,18 @@ class InferenceBatch(NamedTuple): - """Batch for use in torch and inference models.
+ + Args: + pred_score (torch.Tensor | None): Predicted anomaly scores. + Defaults to ``None``. + pred_label (torch.Tensor | None): Predicted anomaly labels. + Defaults to ``None``. + anomaly_map (torch.Tensor | None): Generated anomaly maps. + Defaults to ``None``. + pred_mask (torch.Tensor | None): Predicted anomaly masks. + Defaults to ``None``. + """ pred_score: torch.Tensor | None = None pred_label: torch.Tensor | None = None @@ -36,9 +47,9 @@ class InferenceBatch(NamedTuple): class ToNumpyMixin(Generic[NumpyT]): """Mixin for converting torch-based dataclasses to numpy. - This mixin provides functionality to convert PyTorch tensor data to numpy arrays. - It requires the subclass to define a 'numpy_class' attribute specifying the - corresponding numpy-based class. + This mixin provides functionality to convert PyTorch tensor data to numpy + arrays. It requires the subclass to define a ``numpy_class`` attribute + specifying the corresponding numpy-based class. Examples: >>> from anomalib.dataclasses.numpy import NumpyImageItem @@ -47,8 +58,11 @@ class ToNumpyMixin(Generic[NumpyT]): ... numpy_class = NumpyImageItem ... image: torch.Tensor ... gt_label: torch.Tensor - - >>> torch_item = TorchImageItem(image=torch.rand(3, 224, 224), gt_label=torch.tensor(1)) + ... + >>> torch_item = TorchImageItem( + ... image=torch.rand(3, 224, 224), + ... gt_label=torch.tensor(1) + ... ) >>> numpy_item = torch_item.to_numpy() >>> isinstance(numpy_item, NumpyImageItem) True @@ -57,14 +71,25 @@ class ToNumpyMixin(Generic[NumpyT]): numpy_class: ClassVar[Callable] def __init_subclass__(cls, **kwargs) -> None: - """Ensure that the subclass has the required attributes.""" + """Ensure that the subclass has the required attributes. + + Args: + **kwargs: Additional keyword arguments passed to the parent class. + + Raises: + AttributeError: If the subclass does not define ``numpy_class``. + """ super().__init_subclass__(**kwargs) if not hasattr(cls, "numpy_class"): msg = f"{cls.__name__} must have a 'numpy_class' attribute." raise AttributeError(msg) def to_numpy(self) -> NumpyT: - """Convert the batch to a NumpyBatch object.""" + """Convert the batch to a NumpyBatch object. + + Returns: + NumpyT: The converted numpy batch object. + """ batch_dict = asdict(self) for key, value in batch_dict.items(): if isinstance(value, torch.Tensor): @@ -76,41 +101,37 @@ def to_numpy(self) -> NumpyT: @dataclass class DatasetItem(Generic[ImageT], _GenericItem[torch.Tensor, ImageT, Mask, str]): - """Base dataclass for individual items in Anomalib datasets using PyTorch tensors. + """Base dataclass for individual items in Anomalib datasets using PyTorch. - This class extends the generic _GenericItem class to provide a PyTorch-specific - implementation for single data items in Anomalib datasets. It is designed to - handle various types of data (e.g., images, labels, masks) represented as + This class extends the generic ``_GenericItem`` class to provide a + PyTorch-specific implementation for single data items in Anomalib datasets. + It handles various types of data (e.g., images, labels, masks) represented as PyTorch tensors. The class uses generic types to allow flexibility in the image representation, - which can vary depending on the specific use case (e.g., standard images, video clips). - - Attributes: - Inherited from _GenericItem, with PyTorch tensor and Mask types. + which can vary depending on the specific use case (e.g., standard images, + video clips). 
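Stepping back to ``ToNumpyMixin.to_numpy`` above: the conversion itself is small, and a hedged sketch of the same idea may help when reading the mixin. This is roughly the first half of ``to_numpy``; the helper name is hypothetical, and the real method feeds the resulting dict into the class named by ``numpy_class``:

```python
import dataclasses

import torch


def tensor_fields_to_numpy(item) -> dict:
    """Return a dataclass item's fields with torch tensors replaced by numpy arrays."""
    fields = dataclasses.asdict(item)
    return {
        key: value.cpu().numpy() if isinstance(value, torch.Tensor) else value
        for key, value in fields.items()
    }
```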
Note: This class is typically subclassed to create more specific item types - (e.g., ImageItem, VideoItem) with additional fields and methods. + (e.g., ``ImageItem``, ``VideoItem``) with additional fields and methods. """ @dataclass class Batch(Generic[ImageT], _GenericBatch[torch.Tensor, ImageT, Mask, list[str]]): - """Base dataclass for batches of items in Anomalib datasets using PyTorch tensors. + """Base dataclass for batches of items in Anomalib datasets using PyTorch. - This class extends the generic _GenericBatch class to provide a PyTorch-specific - implementation for batches of data in Anomalib datasets. It is designed to - handle collections of data items (e.g., multiple images, labels, masks) + This class extends the generic ``_GenericBatch`` class to provide a + PyTorch-specific implementation for batches of data in Anomalib datasets. + It handles collections of data items (e.g., multiple images, labels, masks) represented as PyTorch tensors. The class uses generic types to allow flexibility in the image representation, - which can vary depending on the specific use case (e.g., standard images, video clips). - - Attributes: - Inherited from _GenericBatch, with PyTorch tensor and Mask types. + which can vary depending on the specific use case (e.g., standard images, + video clips). Note: This class is typically subclassed to create more specific batch types - (e.g., ImageBatch, VideoBatch) with additional fields and methods. + (e.g., ``ImageBatch``, ``VideoBatch``) with additional fields and methods. """ diff --git a/src/anomalib/data/dataclasses/torch/depth.py b/src/anomalib/data/dataclasses/torch/depth.py index 209d5eaf9d..ff9aac8f61 100644 --- a/src/anomalib/data/dataclasses/torch/depth.py +++ b/src/anomalib/data/dataclasses/torch/depth.py @@ -26,13 +26,22 @@ class DepthItem( _DepthInputFields[torch.Tensor, str], DatasetItem[Image], ): - """Dataclass for individual depth items in Anomalib datasets using PyTorch tensors. + """Dataclass for individual depth items in Anomalib datasets using PyTorch. - This class represents a single depth item in Anomalib datasets using PyTorch tensors. - It combines the functionality of ToNumpyMixin, _DepthInputFields, and DatasetItem - to handle depth data, including depth maps, labels, and metadata. + This class represents a single depth item in Anomalib datasets using PyTorch + tensors. It combines the functionality of ``ToNumpyMixin``, + ``_DepthInputFields``, and ``DatasetItem`` to handle depth data, including + depth maps, labels, and metadata. + + Args: + image (torch.Tensor): Image tensor of shape ``(C, H, W)``. + gt_label (torch.Tensor): Ground truth label tensor. + depth_map (torch.Tensor): Depth map tensor of shape ``(H, W)``. + image_path (str): Path to the source image file. + depth_path (str): Path to the depth map file. Examples: + >>> import torch >>> item = DepthItem( ... image=torch.rand(3, 224, 224), ... gt_label=torch.tensor(1), @@ -40,7 +49,7 @@ class DepthItem( ... image_path="path/to/image.jpg", ... depth_path="path/to/depth.png" ... ) >>> print(item.image.shape, item.depth_map.shape) torch.Size([3, 224, 224]) torch.Size([224, 224]) """ @@ -55,13 +63,22 @@ class DepthBatch( _DepthInputFields[torch.Tensor, list[str]], Batch[Image], ): - """Dataclass for batches of depth items in Anomalib datasets using PyTorch tensors. + """Dataclass for batches of depth items in Anomalib datasets using PyTorch. + + This class represents a batch of depth items in Anomalib datasets using + PyTorch tensors.
It combines the functionality of ``BatchIterateMixin``, + ``_DepthInputFields``, and ``Batch`` to handle batches of depth data, + including depth maps, labels, and metadata. - This class represents a batch of depth items in Anomalib datasets using PyTorch tensors. - It combines the functionality of BatchIterateMixin, _DepthInputFields, and Batch - to handle batches of depth data, including depth maps, labels, and metadata. + Args: + image (torch.Tensor): Batch of images of shape ``(B, C, H, W)``. + gt_label (torch.Tensor): Batch of ground truth labels of shape ``(B,)``. + depth_map (torch.Tensor): Batch of depth maps of shape ``(B, H, W)``. + image_path (list[str]): List of paths to the source image files. + depth_path (list[str]): List of paths to the depth map files. Examples: + >>> import torch >>> batch = DepthBatch( ... image=torch.rand(32, 3, 224, 224), ... gt_label=torch.randint(0, 2, (32,)), @@ -69,10 +86,10 @@ class DepthBatch( ... image_path=["path/to/image_{}.jpg".format(i) for i in range(32)], ... depth_path=["path/to/depth_{}.png".format(i) for i in range(32)] ... ) >>> print(batch.image.shape, batch.depth_map.shape) torch.Size([32, 3, 224, 224]) torch.Size([32, 224, 224]) >>> for item in batch: ... print(item.image.shape, item.depth_map.shape) torch.Size([3, 224, 224]) torch.Size([224, 224]) diff --git a/src/anomalib/data/dataclasses/torch/image.py b/src/anomalib/data/dataclasses/torch/image.py index 3f9cdcc9f0..162fb70042 100644 --- a/src/anomalib/data/dataclasses/torch/image.py +++ b/src/anomalib/data/dataclasses/torch/image.py @@ -3,6 +3,23 @@ This module provides PyTorch-based implementations of the generic dataclasses used in Anomalib for image data. These classes are designed to work with PyTorch tensors for efficient data handling and processing in anomaly detection tasks. + +The module contains two main classes: + - :class:`ImageItem`: For single image data items + - :class:`ImageBatch`: For batched image data items + +Example: + Create and use a torch image item:: + + >>> from anomalib.data.dataclasses.torch import ImageItem + >>> import torch + >>> item = ImageItem( + ... image=torch.rand(3, 224, 224), + ... gt_label=torch.tensor(0), + ... image_path="path/to/image.jpg" + ... ) + >>> item.image.shape + torch.Size([3, 224, 224]) """ # Copyright (C) 2024 Intel Corporation @@ -25,36 +42,33 @@ class ImageItem( _ImageInputFields[str], DatasetItem[Image], ): - """Dataclass for individual image items in Anomalib datasets using PyTorch tensors. - - This class combines the functionality of ToNumpyMixin, _ImageInputFields, and - DatasetItem to represent single image data points in Anomalib. It includes - image-specific fields and provides methods for data validation and conversion - to numpy format. + """Dataclass for individual image items in Anomalib datasets using PyTorch. - The class is designed to work with PyTorch tensors and includes fields for - the image data, ground truth labels and masks, anomaly maps, and related metadata. + This class combines :class:`_ImageInputFields` and :class:`DatasetItem` for + image-based anomaly detection. It includes image-specific fields and validation + methods to ensure proper formatting for Anomalib's image-based models. - Attributes: - Inherited from _ImageInputFields and DatasetItem.
+ The class uses the following type parameters: + - Image: :class:`torch.Tensor` with shape ``(C, H, W)`` + - Label: :class:`torch.Tensor` + - Mask: :class:`torch.Tensor` with shape ``(H, W)`` + - Path: :class:`str` - Methods: - Inherited from ToNumpyMixin, including to_numpy() for conversion to numpy format. - - Examples: + Example: + >>> import torch + >>> from anomalib.data.dataclasses.torch import ImageItem >>> item = ImageItem( ... image=torch.rand(3, 224, 224), - ... gt_label=torch.tensor(1), - ... gt_mask=torch.rand(224, 224) > 0.5, + ... gt_label=torch.tensor(0), ... image_path="path/to/image.jpg" ... ) - - >>> print(item.image.shape) + >>> item.image.shape torch.Size([3, 224, 224]) + Convert to numpy format: >>> numpy_item = item.to_numpy() - >>> print(type(numpy_item)) - + >>> type(numpy_item).__name__ + 'NumpyImageItem' """ numpy_class = NumpyImageItem @@ -68,34 +82,39 @@ class ImageBatch( _ImageInputFields[list[str]], Batch[Image], ): - """Dataclass for batches of image items in Anomalib datasets using PyTorch tensors. + """Dataclass for batches of image items in Anomalib datasets using PyTorch. - This class combines the functionality of ``ToNumpyMixin``, ``BatchIterateMixin``, - ``_ImageInputFields``, and ``Batch`` to represent collections of image data points in Anomalib. - It includes image-specific fields and provides methods for batch operations, - iteration over individual items, and conversion to numpy format. + This class combines :class:`_ImageInputFields` and :class:`Batch` for batches + of image data. It includes image-specific fields and methods for batch + operations and iteration. - The class is designed to work with PyTorch tensors and includes fields for - batches of image data, ground truth labels and masks, anomaly maps, and related metadata. + The class uses the following type parameters: + - Image: :class:`torch.Tensor` with shape ``(B, C, H, W)`` + - Label: :class:`torch.Tensor` with shape ``(B,)`` + - Mask: :class:`torch.Tensor` with shape ``(B, H, W)`` + - Path: :class:`list` of :class:`str` - Examples: + Where ``B`` represents the batch dimension. + + Example: + >>> import torch + >>> from anomalib.data.dataclasses.torch import ImageBatch >>> batch = ImageBatch( ... image=torch.rand(32, 3, 224, 224), ... gt_label=torch.randint(0, 2, (32,)), - ... gt_mask=torch.rand(32, 224, 224) > 0.5, - ... image_path=["path/to/image_{}.jpg".format(i) for i in range(32)] + ... image_path=[f"path/to/image_{i}.jpg" for i in range(32)] ... ) - - >>> print(batch.image.shape) + >>> batch.image.shape torch.Size([32, 3, 224, 224]) + Iterate over batch: >>> for item in batch: - ... print(item.image.shape) - torch.Size([3, 224, 224]) + ... assert item.image.shape == torch.Size([3, 224, 224]) + Convert to numpy format: >>> numpy_batch = batch.to_numpy() - >>> print(type(numpy_batch)) - + >>> type(numpy_batch).__name__ + 'NumpyImageBatch' """ item_class = ImageItem diff --git a/src/anomalib/data/dataclasses/torch/video.py b/src/anomalib/data/dataclasses/torch/video.py index 324fb45ca1..baad5ee118 100644 --- a/src/anomalib/data/dataclasses/torch/video.py +++ b/src/anomalib/data/dataclasses/torch/video.py @@ -3,6 +3,23 @@ This module provides PyTorch-based implementations of the generic dataclasses used in Anomalib for video data. These classes are designed to work with PyTorch tensors for efficient data handling and processing in anomaly detection tasks. 
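Before the video dataclasses, the image classes above compose into a short end-to-end sketch (it mirrors the doctests in this file; constructing the batch without the remaining optional fields is an assumption):

```python
import torch
from anomalib.data.dataclasses.torch import ImageBatch

batch = ImageBatch(
    image=torch.rand(4, 3, 224, 224),
    gt_label=torch.randint(0, 2, (4,)),
    image_path=[f"path/to/image_{i}.jpg" for i in range(4)],
)

# iterate items, then convert the whole batch to its numpy counterpart
for item in batch:
    assert item.image.shape == torch.Size([3, 224, 224])

numpy_batch = batch.to_numpy()
assert type(numpy_batch).__name__ == "NumpyImageBatch"
```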
+ +The module contains two main classes: + - :class:`VideoItem`: For single video data items + - :class:`VideoBatch`: For batched video data items + +Example: + Create and use a torch video item:: + + >>> from anomalib.data.dataclasses.torch import VideoItem + >>> import torch + >>> item = VideoItem( + ... image=torch.rand(10, 3, 224, 224), # 10 frames + ... gt_label=torch.tensor(0), + ... video_path="path/to/video.mp4" + ... ) + >>> item.image.shape + torch.Size([10, 3, 224, 224]) """ # Copyright (C) 2024 Intel Corporation @@ -27,26 +44,36 @@ class VideoItem( _VideoInputFields[torch.Tensor, Video, Mask, str], DatasetItem[Video], ): - """Dataclass for individual video items in Anomalib datasets using PyTorch tensors. + """Dataclass for individual video items in Anomalib datasets using PyTorch. + + This class combines :class:`_VideoInputFields` and :class:`DatasetItem` for + video-based anomaly detection. It includes video-specific fields and + validation methods to ensure proper formatting for Anomalib's video-based + models. - This class represents a single video item in Anomalib datasets using PyTorch tensors. - It combines the functionality of ToNumpyMixin, _VideoInputFields, and DatasetItem - to handle video data, including frames, labels, masks, and metadata. + The class uses the following type parameters: + - Video: :class:`torch.Tensor` with shape ``(T, C, H, W)`` + - Label: :class:`torch.Tensor` + - Mask: :class:`torch.Tensor` with shape ``(T, H, W)`` + - Path: :class:`str` - Examples: + Where ``T`` represents the temporal dimension (number of frames). + + Example: + >>> import torch + >>> from anomalib.data.dataclasses.torch import VideoItem >>> item = VideoItem( ... image=torch.rand(10, 3, 224, 224), # 10 frames - ... gt_label=torch.tensor(1), - ... gt_mask=torch.rand(10, 224, 224) > 0.5, + ... gt_label=torch.tensor(0), ... video_path="path/to/video.mp4" ... ) - - >>> print(item.image.shape) + >>> item.image.shape torch.Size([10, 3, 224, 224]) + Convert to numpy format: >>> numpy_item = item.to_numpy() - >>> print(type(numpy_item)) - + >>> type(numpy_item).__name__ + 'NumpyVideoItem' """ numpy_class = NumpyVideoItem @@ -65,30 +92,39 @@ class VideoBatch( _VideoInputFields[torch.Tensor, Video, Mask, list[str]], Batch[Video], ): - """Dataclass for batches of video items in Anomalib datasets using PyTorch tensors. + """Dataclass for batches of video items in Anomalib datasets using PyTorch. - This class represents a batch of video items in Anomalib datasets using PyTorch tensors. - It combines the functionality of ToNumpyMixin, BatchIterateMixin, _VideoInputFields, - and Batch to handle batches of video data, including frames, labels, masks, and metadata. + This class represents batches of video data for batch processing in anomaly + detection tasks. It combines functionality from multiple mixins to handle + batched video data efficiently. - Examples: + The class uses the following type parameters: + - Video: :class:`torch.Tensor` with shape ``(B, T, C, H, W)`` + - Label: :class:`torch.Tensor` with shape ``(B,)`` + - Mask: :class:`torch.Tensor` with shape ``(B, T, H, W)`` + - Path: :class:`list` of :class:`str` + + Where ``B`` represents the batch dimension and ``T`` the temporal dimension. + + Example: + >>> import torch + >>> from anomalib.data.dataclasses.torch import VideoBatch >>> batch = VideoBatch( - ... image=torch.rand(32, 10, 3, 224, 224), # 32 videos, 10 frames each + ... image=torch.rand(32, 10, 3, 224, 224), # 32 videos, 10 frames ... 
gt_label=torch.randint(0, 2, (32,)), - ... gt_mask=torch.rand(32, 10, 224, 224) > 0.5, - ... video_path=["path/to/video_{}.mp4".format(i) for i in range(32)] + ... video_path=["video_{}.mp4".format(i) for i in range(32)] ... ) - - >>> print(batch.image.shape) + >>> batch.image.shape torch.Size([32, 10, 3, 224, 224]) - >>> for item in batch: - ... print(item.image.shape) + Iterate over items in batch: + >>> next(iter(batch)).image.shape torch.Size([10, 3, 224, 224]) + Convert to numpy format: >>> numpy_batch = batch.to_numpy() - >>> print(type(numpy_batch)) - + >>> type(numpy_batch).__name__ + 'NumpyVideoBatch' """ item_class = VideoItem diff --git a/src/anomalib/data/datamodules/base/image.py b/src/anomalib/data/datamodules/base/image.py index 5c28cd4557..87bbfc17c6 100644 --- a/src/anomalib/data/datamodules/base/image.py +++ b/src/anomalib/data/datamodules/base/image.py @@ -1,4 +1,26 @@ -"""Anomalib datamodule base class.""" +"""Base Anomalib data module. + +This module provides the base data module class used across Anomalib. It handles +dataset splitting, validation set creation, and dataloader configuration. + +The module contains: + - :class:`AnomalibDataModule`: Base class for all Anomalib data modules + +Example: + Create a datamodule from a config file:: + + >>> from anomalib.data import AnomalibDataModule + >>> data_config = "configs/data/mvtec.yaml" + >>> datamodule = AnomalibDataModule.from_config(config_path=data_config) + + Override config with additional arguments:: + + >>> override_kwargs = {"data.train_batch_size": 8} + >>> datamodule = AnomalibDataModule.from_config( + ... config_path=data_config, + ... **override_kwargs + ... ) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -28,19 +50,29 @@ class AnomalibDataModule(LightningDataModule, ABC): """Base Anomalib data module. + This class extends PyTorch Lightning's ``LightningDataModule`` to provide + common functionality for anomaly detection datasets. + Args: - train_batch_size (int): Batch size used by the train dataloader. - eval_batch_size (int): Batch size used by the val and test dataloaders. - num_workers (int): Number of workers used by the train, val and test dataloaders. - val_split_mode (ValSplitMode): Determines how the validation split is obtained. - Options: [none, same_as_test, from_test, synthetic] - val_split_ratio (float): Fraction of the train or test images held our for validation. - test_split_mode (Optional[TestSplitMode], optional): Determines how the test split is obtained. - Options: [none, from_dir, synthetic]. + train_batch_size (int): Batch size for training dataloader + eval_batch_size (int): Batch size for validation/test dataloaders + num_workers (int): Number of workers for all dataloaders + val_split_mode (ValSplitMode | str): Method to obtain validation set. + Options: + - ``none``: No validation set + - ``same_as_test``: Use test set as validation + - ``from_test``: Sample from test set + - ``synthetic``: Generate synthetic anomalies + val_split_ratio (float): Fraction of data to use for validation + test_split_mode (TestSplitMode | str | None): Method to obtain test set. + Options: + - ``none``: No test split + - ``from_dir``: Use separate test directory + - ``synthetic``: Generate synthetic anomalies Defaults to ``None``. - test_split_ratio (float): Fraction of the train images held out for testing. + test_split_ratio (float | None): Fraction of data to use for testing. Defaults to ``None``. 
- seed (int | None, optional): Seed used during random subset splitting. + seed (int | None): Random seed for reproducible splitting. Defaults to ``None``. """ @@ -72,18 +104,25 @@ def __init__( self._samples: DataFrame | None = None self._category: str = "" - self._is_setup = False # flag to track if setup has been called from the trainer + self._is_setup = False # flag to track if setup has been called @property def name(self) -> str: - """Name of the datamodule.""" + """Name of the datamodule. + + Returns: + str: Class name of the datamodule + """ return self.__class__.__name__ def setup(self, stage: str | None = None) -> None: """Set up train, validation and test data. + This method handles the data splitting logic based on the configured + modes. + Args: - stage: str | None: Train/Val/Test stages. + stage (str | None): Current stage (fit/validate/test/predict). Defaults to ``None``. """ has_subset = any(hasattr(self, subset) for subset in ["train_data", "val_data", "test_data"]) @@ -92,37 +131,57 @@ def setup(self, stage: str | None = None) -> None: self._create_test_split() self._create_val_split() if isinstance(stage, TrainerFn): - # only set the flag if the stage is a TrainerFn, which means the setup has been called from a trainer + # only set flag if called from trainer self._is_setup = True @abstractmethod def _setup(self, _stage: str | None = None) -> None: """Set up the datasets and perform dynamic subset splitting. - This method may be overridden in subclass for custom splitting behaviour. + This method should be implemented by subclasses to define dataset-specific + setup logic. Note: - The stage argument is not used here. This is because, for a given instance of an AnomalibDataModule - subclass, all three subsets are created at the first call of setup(). This is to accommodate the subset - splitting behaviour of anomaly tasks, where the validation set is usually extracted from the test set, and - the test set must therefore be created as early as the `fit` stage. + The ``stage`` argument is not used since all subsets are created on + first call to accommodate validation set extraction from test set. + Args: + _stage (str | None): Current stage (unused). + Defaults to ``None``. + + Raises: + NotImplementedError: When not implemented by subclass """ raise NotImplementedError @property def category(self) -> str: - """Get the category of the datamodule.""" + """Get dataset category name. + + Returns: + str: Name of the current category + """ return self._category @category.setter def category(self, category: str) -> None: - """Set the category of the datamodule.""" + """Set dataset category name. + + Args: + category (str): Category name to set + """ self._category = category @property def task(self) -> TaskType: - """Get the task type of the datamodule.""" + """Get the task type. + + Returns: + TaskType: Type of anomaly task (classification/segmentation) + + Raises: + AttributeError: If no datasets have been set up yet + """ if hasattr(self, "train_data"): return self.train_data.task if hasattr(self, "val_data"): @@ -133,19 +192,26 @@ def task(self) -> TaskType: raise AttributeError(msg) def _create_test_split(self) -> None: - """Obtain the test set based on the settings in the config.""" + """Create the test split based on configured mode. + + This handles splitting normal/anomalous samples and optionally creating + synthetic anomalies. 
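The split machinery above is driven entirely by constructor arguments, so a configuration sketch may be useful (``MVTec`` stands in here for any concrete datamodule; the exact keyword support is assumed from the base-class signature):

```python
from anomalib.data import MVTec  # assumed concrete subclass of AnomalibDataModule

datamodule = MVTec(
    category="bottle",
    train_batch_size=32,
    eval_batch_size=32,
    val_split_mode="from_test",  # carve the validation set out of the test set
    val_split_ratio=0.5,
)
datamodule.setup()  # triggers _create_test_split() and _create_val_split()
```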
+ """ if self.test_data.has_normal: - # split the test data into normal and anomalous so these can be processed separately + # split test data into normal and anomalous normal_test_data, self.test_data = split_by_label(self.test_data) elif self.test_split_mode != TestSplitMode.NONE: - # when the user did not provide any normal images for testing, we sample some from the training set, - # except when the user explicitly requested no test splitting. + # sample normal images from training set if none provided logger.info( - "No normal test images found. Sampling from training set using a split ratio of %0.2f", + "No normal test images found. Sampling from training set using ratio of %0.2f", self.test_split_ratio, ) if self.test_split_ratio is not None: - self.train_data, normal_test_data = random_split(self.train_data, self.test_split_ratio, seed=self.seed) + self.train_data, normal_test_data = random_split( + self.train_data, + self.test_split_ratio, + seed=self.seed, + ) if self.test_split_mode == TestSplitMode.FROM_DIR: self.test_data += normal_test_data @@ -156,9 +222,13 @@ def _create_test_split(self) -> None: raise ValueError(msg) def _create_val_split(self) -> None: - """Obtain the validation set based on the settings in the config.""" + """Create validation split based on configured mode. + + This handles sampling from train/test sets and optionally creating + synthetic anomalies. + """ if self.val_split_mode == ValSplitMode.FROM_TRAIN: - # randomly sampled from train set + # randomly sample from train set self.train_data, self.val_data = random_split( self.train_data, self.val_split_ratio, @@ -166,7 +236,7 @@ def _create_val_split(self) -> None: seed=self.seed, ) elif self.val_split_mode == ValSplitMode.FROM_TEST: - # randomly sampled from test set + # randomly sample from test set self.test_data, self.val_data = random_split( self.test_data, self.val_split_ratio, @@ -174,18 +244,26 @@ def _create_val_split(self) -> None: seed=self.seed, ) elif self.val_split_mode == ValSplitMode.SAME_AS_TEST: - # equal to test set + # use test set as validation self.val_data = self.test_data elif self.val_split_mode == ValSplitMode.SYNTHETIC: - # converted from random training sample - self.train_data, normal_val_data = random_split(self.train_data, self.val_split_ratio, seed=self.seed) + # create synthetic anomalies from training samples + self.train_data, normal_val_data = random_split( + self.train_data, + self.val_split_ratio, + seed=self.seed, + ) self.val_data = SyntheticAnomalyDataset.from_dataset(normal_val_data) elif self.val_split_mode != ValSplitMode.NONE: msg = f"Unknown validation split mode: {self.val_split_mode}" raise ValueError(msg) def train_dataloader(self) -> TRAIN_DATALOADERS: - """Get train dataloader.""" + """Get training dataloader. + + Returns: + DataLoader: Training dataloader + """ return DataLoader( dataset=self.train_data, shuffle=True, @@ -195,7 +273,11 @@ def train_dataloader(self) -> TRAIN_DATALOADERS: ) def val_dataloader(self) -> EVAL_DATALOADERS: - """Get validation dataloader.""" + """Get validation dataloader. + + Returns: + DataLoader: Validation dataloader + """ return DataLoader( dataset=self.val_data, shuffle=False, @@ -205,7 +287,11 @@ def val_dataloader(self) -> EVAL_DATALOADERS: ) def test_dataloader(self) -> EVAL_DATALOADERS: - """Get test dataloader.""" + """Get test dataloader. 
+ + Returns: + DataLoader: Test dataloader + """ return DataLoader( dataset=self.test_data, shuffle=False, @@ -215,7 +301,13 @@ ) def predict_dataloader(self) -> EVAL_DATALOADERS: - """Use the test dataloader for inference unless overridden.""" + """Get prediction dataloader. + + By default uses the test dataloader. + + Returns: + DataLoader: Prediction dataloader + """ return self.test_dataloader() @classmethod @@ -224,27 +316,31 @@ def from_config( config_path: str | Path, **kwargs, ) -> "AnomalibDataModule": - """Create a datamodule instance from the configuration. + """Create datamodule instance from config file. Args: - config_path (str | Path): Path to the data configuration file. - **kwargs (dict): Additional keyword arguments. + config_path (str | Path): Path to config file + **kwargs: Additional args to override config Returns: - AnomalibDataModule: Datamodule instance. + AnomalibDataModule: Instantiated datamodule + + Raises: + FileNotFoundError: If config file not found + ValueError: If instantiated object is not AnomalibDataModule Example: - The following example shows how to get datamodule from mvtec.yaml: + Load from config file:: - .. code-block:: python - >>> data_config = "configs/data/mvtec.yaml" - >>> datamodule = AnomalibDataModule.from_config(config_path=data_config) + >>> config_path = "configs/data/mvtec.yaml" + >>> datamodule = AnomalibDataModule.from_config(config_path) - The following example shows overriding the configuration file with additional keyword arguments: + Override config values:: - .. code-block:: python - >>> override_kwargs = {"data.train_batch_size": 8} - >>> datamodule = AnomalibDataModule.from_config(config_path=data_config, **override_kwargs) + >>> override_kwargs = {"data.train_batch_size": 8} + >>> datamodule = AnomalibDataModule.from_config( + ... config_path, + ... **override_kwargs + ... ) """ from jsonargparse import ArgumentParser diff --git a/src/anomalib/data/datamodules/base/video.py b/src/anomalib/data/datamodules/base/video.py index 3bc7af6772..3e86d4f09b 100644 --- a/src/anomalib/data/datamodules/base/video.py +++ b/src/anomalib/data/datamodules/base/video.py @@ -1,4 +1,18 @@ -"""Base Video Data Module.""" +"""Base Video Data Module. + +This module provides the base data module class for video datasets in Anomalib. +It extends :class:`AnomalibDataModule` with video-specific functionality. + +The module contains: + - :class:`AnomalibVideoDataModule`: Base class for all video data modules + +Example: + Create a video datamodule from a config file:: + + >>> from anomalib.data import AnomalibVideoDataModule + >>> data_config = "configs/data/ucsd_ped.yaml" + >>> datamodule = AnomalibVideoDataModule.from_config(config_path=data_config) +""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -9,17 +23,33 @@ class AnomalibVideoDataModule(AnomalibDataModule): - """Base class for video data modules.""" + """Base class for video data modules. + + This class extends :class:`AnomalibDataModule` to handle video datasets. + Unlike image datasets, video datasets do not support dynamic test split + assignment or synthetic anomaly generation. + """ def _create_test_split(self) -> None: - """Video datamodules do not support dynamic assignment of the test split.""" + """Video datamodules do not support dynamic assignment of test split. + + Video datasets typically come with predefined train/test splits due to + temporal dependencies between frames.
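A configuration sketch for the video case (``UCSDped`` is used here as an example video datamodule; its exact constructor arguments are assumed). Because the test split is predefined, only the validation strategy is configurable, and as the ``_setup`` docstring below notes, ``synthetic`` is rejected:

```python
from anomalib.data import UCSDped  # assumed video datamodule

datamodule = UCSDped(
    category="UCSDped2",
    val_split_mode="same_as_test",  # "synthetic" would raise ValueError for video data
)
```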
+ """ def _setup(self, _stage: str | None = None) -> None: - """Set up the datasets and perform dynamic subset splitting. + """Set up video datasets and perform validation split. + + This method initializes the train and test datasets and creates the + validation split if specified. It ensures that both train and test + datasets are properly defined and configured. - This method may be overridden in subclass for custom splitting behaviour. + Args: + _stage: Current stage of training. Defaults to ``None``. - Video datamodules are not compatible with synthetic anomaly generation. + Raises: + ValueError: If ``train_data`` or ``test_data`` is ``None``. + ValueError: If ``val_split_mode`` is set to ``SYNTHETIC``. """ if self.train_data is None: msg = "self.train_data cannot be None." diff --git a/src/anomalib/data/datamodules/depth/__init__.py b/src/anomalib/data/datamodules/depth/__init__.py index b7f24ab8d1..0f4c5199a7 100644 --- a/src/anomalib/data/datamodules/depth/__init__.py +++ b/src/anomalib/data/datamodules/depth/__init__.py @@ -1,6 +1,6 @@ """Anomalib Depth Data Modules.""" -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from enum import Enum diff --git a/src/anomalib/data/datamodules/depth/folder_3d.py b/src/anomalib/data/datamodules/depth/folder_3d.py index f475c26bd8..fd1fd7afff 100644 --- a/src/anomalib/data/datamodules/depth/folder_3d.py +++ b/src/anomalib/data/datamodules/depth/folder_3d.py @@ -1,6 +1,18 @@ -"""Custom Folder Datamodule. +"""Custom Folder Datamodule for 3D data. -This script creates a custom datamodule from a folder. +This module provides a custom datamodule for handling 3D data organized in folders. +The datamodule supports RGB and depth image pairs for anomaly detection tasks. + +Example: + Create a folder 3D datamodule:: + + >>> from anomalib.data import Folder3D + >>> datamodule = Folder3D( + ... name="my_dataset", + ... root="path/to/dataset", + ... normal_dir="normal", + ... abnormal_dir="abnormal" + ... ) """ # Copyright (C) 2022-2024 Intel Corporation @@ -14,47 +26,45 @@ class Folder3D(AnomalibDataModule): - """Folder DataModule. + """Folder DataModule for 3D data. + + This class extends :class:`AnomalibDataModule` to handle datasets containing + RGB and depth image pairs organized in folders. Args: - name (str): Name of the dataset. This is used to name the datamodule, especially when logging/saving. - normal_dir (str | Path): Name of the directory containing normal images. - root (str | Path | None): Path to the root folder containing normal and abnormal dirs. - Defaults to ``None``. - abnormal_dir (str | Path | None): Name of the directory containing abnormal images. - Defaults to ``abnormal``. - normal_test_dir (str | Path | None, optional): Path to the directory containing normal images for the test - dataset. - Defaults to ``None``. - mask_dir (str | Path | None, optional): Path to the directory containing the mask annotations. - Defaults to ``None``. - normal_depth_dir (str | Path | None, optional): Path to the directory containing - normal depth images for the test dataset. Normal test depth images will be a split of `normal_dir` - abnormal_depth_dir (str | Path | None, optional): Path to the directory containing - abnormal depth images for the test dataset. - normal_test_depth_dir (str | Path | None, optional): Path to the directory containing - normal depth images for the test dataset. Normal test images will be a split of `normal_dir` - if `None`. Defaults to None. 
- normal_split_ratio (float, optional): Ratio to split normal training images and add to the - test set in case test set doesn't contain any normal images. - Defaults to 0.2. - extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the - directory. Defaults to None. + name (str): Name of the dataset used for logging and saving. + normal_dir (str | Path): Directory containing normal RGB images. + root (str | Path): Root folder containing normal and abnormal dirs. + abnormal_dir (str | Path | None, optional): Directory containing abnormal + RGB images. Defaults to ``None``. + normal_test_dir (str | Path | None, optional): Directory containing normal + RGB images for testing. Defaults to ``None``. + mask_dir (str | Path | None, optional): Directory containing mask + annotations. Defaults to ``None``. + normal_depth_dir (str | Path | None, optional): Directory containing + normal depth images. Defaults to ``None``. + abnormal_depth_dir (str | Path | None, optional): Directory containing + abnormal depth images. Defaults to ``None``. + normal_test_depth_dir (str | Path | None, optional): Directory containing + normal depth images for testing. If ``None``, uses split from + ``normal_dir``. Defaults to ``None``. + extensions (tuple[str, ...] | None, optional): Image file extensions to + read. Defaults to ``None``. train_batch_size (int, optional): Training batch size. Defaults to ``32``. - eval_batch_size (int, optional): Test batch size. + eval_batch_size (int, optional): Evaluation batch size. Defaults to ``32``. - num_workers (int, optional): Number of workers. + num_workers (int, optional): Number of workers for data loading. Defaults to ``8``. - test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained. - Defaults to ``TestSplitMode.FROM_DIR``. - test_split_ratio (float): Fraction of images from the train set that will be reserved for testing. + test_split_mode (TestSplitMode | str, optional): Method to create test + set. Defaults to ``TestSplitMode.FROM_DIR``. + test_split_ratio (float, optional): Fraction of data for testing. Defaults to ``0.2``. - val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. - Defaults to ``ValSplitMode.FROM_TEST``. - val_split_ratio (float): Fraction of train or test images that will be reserved for validation. + val_split_mode (ValSplitMode | str, optional): Method to create validation + set. Defaults to ``ValSplitMode.FROM_TEST``. + val_split_ratio (float, optional): Fraction of data for validation. Defaults to ``0.5``. - seed (int | None, optional): Seed used during random subset splitting. + seed (int | None, optional): Random seed for splitting. Defaults to ``None``. """ @@ -101,6 +111,12 @@ def __init__( self.extensions = extensions def _setup(self, _stage: str | None = None) -> None: + """Set up train and test datasets. + + Args: + _stage (str | None, optional): Stage of setup. Not used. + Defaults to ``None``. + """ self.train_data = Folder3DDataset( name=self.name, split=Split.TRAIN, @@ -131,8 +147,9 @@ def _setup(self, _stage: str | None = None) -> None: @property def name(self) -> str: - """Name of the datamodule. + """Get name of the datamodule. - Folder3D datamodule overrides the name property to provide a custom name. + Returns: + str: Name of the datamodule. 
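The depth-related arguments documented in the ``Folder3D`` hunk above are easy to miss in the short module example. Here is an illustrative setup that pairs RGB and depth folders; every directory name is hypothetical:

```python
# Hypothetical directory layout: RGB folders plus matching depth folders.
from anomalib.data import Folder3D

datamodule = Folder3D(
    name="my_dataset",
    root="path/to/dataset",
    normal_dir="normal",                  # normal RGB images
    abnormal_dir="abnormal",              # abnormal RGB images
    normal_depth_dir="normal_depth",      # depth maps paired with normal_dir
    abnormal_depth_dir="abnormal_depth",  # depth maps paired with abnormal_dir
    mask_dir="mask",                      # optional pixel-level annotations
)
datamodule.setup()
```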
""" return self._name diff --git a/src/anomalib/data/datamodules/depth/mvtec_3d.py b/src/anomalib/data/datamodules/depth/mvtec_3d.py index 400b1d3139..afdf981d96 100644 --- a/src/anomalib/data/datamodules/depth/mvtec_3d.py +++ b/src/anomalib/data/datamodules/depth/mvtec_3d.py @@ -1,19 +1,30 @@ -"""MVTec 3D-AD Datamodule (CC BY-NC-SA 4.0). +"""MVTec 3D-AD Datamodule. -Description: - This script contains PyTorch Dataset, Dataloader and PyTorch Lightning DataModule for the MVTec 3D-AD dataset. - If the dataset is not on the file system, the script downloads and extracts the dataset and create PyTorch data - objects. +This module provides a PyTorch Lightning DataModule for the MVTec 3D-AD dataset. +The dataset contains RGB and depth image pairs for anomaly detection tasks. + +Example: + Create a MVTec3D datamodule:: + + >>> from anomalib.data import MVTec3D + >>> datamodule = MVTec3D( + ... root="./datasets/MVTec3D", + ... category="bagel" + ... ) License: - MVTec 3D-AD dataset is released under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International - License (CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/). + MVTec 3D-AD dataset is released under the Creative Commons + Attribution-NonCommercial-ShareAlike 4.0 International License + (CC BY-NC-SA 4.0). + https://creativecommons.org/licenses/by-nc-sa/4.0/ Reference: - - Paul Bergmann, Xin Jin, David Sattlegger, Carsten Steger: The MVTec 3D-AD Dataset for Unsupervised 3D Anomaly - Detection and Localization in: Proceedings of the 17th International Joint Conference on Computer Vision, - Imaging and Computer Graphics Theory and Applications - Volume 5: VISAPP, 202-213, 2022, DOI: 10.5220/ - 0010865000003124. + Paul Bergmann, Xin Jin, David Sattlegger, Carsten Steger: + The MVTec 3D-AD Dataset for Unsupervised 3D Anomaly Detection and + Localization. In: Proceedings of the 17th International Joint Conference + on Computer Vision, Imaging and Computer Graphics Theory and Applications + - Volume 5: VISAPP, 202-213, 2022. + DOI: 10.5220/0010865000003124 """ # Copyright (C) 2022-2024 Intel Corporation @@ -38,28 +49,28 @@ class MVTec3D(AnomalibDataModule): - """MVTec Datamodule. + """MVTec 3D-AD Datamodule. Args: - root (Path | str): Path to the root of the dataset + root (Path | str): Path to the root of the dataset. Defaults to ``"./datasets/MVTec3D"``. - category (str): Category of the MVTec dataset (e.g. "bottle" or "cable"). - Defaults to ``bagel``. + category (str): Category of the MVTec3D dataset (e.g. ``"bottle"`` or + ``"cable"``). Defaults to ``"bagel"``. train_batch_size (int, optional): Training batch size. Defaults to ``32``. eval_batch_size (int, optional): Test batch size. Defaults to ``32``. - num_workers (int, optional): Number of workers. + num_workers (int, optional): Number of workers for data loading. Defaults to ``8``. - test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained. + test_split_mode (TestSplitMode | str): Method to create test set. Defaults to ``TestSplitMode.FROM_DIR``. - test_split_ratio (float): Fraction of images from the train set that will be reserved for testing. + test_split_ratio (float): Fraction of data to use for testing. Defaults to ``0.2``. - val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. + val_split_mode (ValSplitMode | str): Method to create validation set. Defaults to ``ValSplitMode.SAME_AS_TEST``. 
- val_split_ratio (float): Fraction of train or test images that will be reserved for validation. + val_split_ratio (float): Fraction of data to use for validation. Defaults to ``0.5``. - seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. + seed (int | None, optional): Random seed for reproducibility. Defaults to ``None``. """ @@ -91,6 +102,12 @@ def __init__( self.category = category def _setup(self, _stage: str | None = None) -> None: + """Set up the datasets. + + Args: + _stage (str | None, optional): Stage of setup. Not used. + Defaults to ``None``. + """ self.train_data = MVTec3DDataset( split=Split.TRAIN, root=self.root, diff --git a/src/anomalib/data/datamodules/image/__init__.py b/src/anomalib/data/datamodules/image/__init__.py index 69221f863c..deb98863ed 100644 --- a/src/anomalib/data/datamodules/image/__init__.py +++ b/src/anomalib/data/datamodules/image/__init__.py @@ -1,4 +1,24 @@ -"""Anomalib Image Data Modules.""" +"""Anomalib Image Data Modules. + +This module contains data modules for loading and processing image datasets for +anomaly detection. The following data modules are available: + +- ``BTech``: BTech Surface Defect Dataset +- ``Datumaro``: Dataset in Datumaro format (Intel Geti™ export) +- ``Folder``: Custom folder structure with normal/abnormal images +- ``Kolektor``: Kolektor Surface-Defect Dataset +- ``MVTec``: MVTec Anomaly Detection Dataset +- ``Visa``: Visual Anomaly (VisA) Dataset + +Example: + Load the MVTec dataset:: + + >>> from anomalib.data import MVTec + >>> datamodule = MVTec( + ... root="./datasets/MVTec", + ... category="bottle" + ... ) +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -14,7 +34,19 @@ class ImageDataFormat(str, Enum): - """Supported Image Dataset Types.""" + """Supported Image Dataset Types. + + The following dataset formats are supported: + + - ``BTECH``: BTech Surface Defect Dataset + - ``DATUMARO``: Dataset in Datumaro format + - ``FOLDER``: Custom folder structure + - ``FOLDER_3D``: Custom folder structure for 3D images + - ``KOLEKTOR``: Kolektor Surface-Defect Dataset + - ``MVTEC``: MVTec AD Dataset + - ``MVTEC_3D``: MVTec 3D AD Dataset + - ``VISA``: Visual Anomaly (VisA) Dataset + """ BTECH = "btech" DATUMARO = "datumaro" diff --git a/src/anomalib/data/datamodules/image/btech.py b/src/anomalib/data/datamodules/image/btech.py index 4ec0527f16..367b6a1489 100644 --- a/src/anomalib/data/datamodules/image/btech.py +++ b/src/anomalib/data/datamodules/image/btech.py @@ -1,9 +1,38 @@ """BTech Data Module. -This script contains PyTorch Lightning DataModule for the BTech dataset. +This module provides a PyTorch Lightning DataModule for the BTech dataset. If the +dataset is not available locally, it will be downloaded and extracted +automatically. -If the dataset is not on the file system, the script downloads and -extracts the dataset and create PyTorch data objects. +Example: + Create a BTech datamodule:: + + >>> from anomalib.data import BTech + >>> datamodule = BTech( + ... root="./datasets/BTech", + ... category="01" + ... ) + +Notes: + The dataset will be automatically downloaded and converted to the required + format when first used. The directory structure after preparation will be:: + + datasets/ + └── BTech/ + ├── 01/ + ├── 02/ + └── 03/ + +License: + BTech dataset is released under the Creative Commons + Attribution-NonCommercial-ShareAlike 4.0 International License + (CC BY-NC-SA 4.0).
+ https://creativecommons.org/licenses/by-nc-sa/4.0/ + +Reference: + Mishra, Pankaj, et al. "BTAD—A Large Scale Dataset and Benchmark for + Real-World Industrial Anomaly Detection." Pattern Recognition 136 (2024): + 109542. """ # Copyright (C) 2022-2024 Intel Corporation @@ -33,63 +62,70 @@ class BTech(AnomalibDataModule): """BTech Lightning Data Module. Args: - root (Path | str): Path to the BTech dataset. + root (Path | str): Path to the root of the dataset. Defaults to ``"./datasets/BTech"``. - category (str): Name of the BTech category. + category (str): Category of the BTech dataset (e.g. ``"01"``, ``"02"``, + or ``"03"``). Defaults to ``"01"``. train_batch_size (int, optional): Training batch size. Defaults to ``32``. - eval_batch_size (int, optional): Eval batch size. + eval_batch_size (int, optional): Test batch size. Defaults to ``32``. num_workers (int, optional): Number of workers. Defaults to ``8``. - test_split_mode (TestSplitMode, optional): Setting that determines how the testing subset is obtained. + test_split_mode (TestSplitMode): Setting that determines how the testing + subset is obtained. Defaults to ``TestSplitMode.FROM_DIR``. - test_split_ratio (float, optional): Fraction of images from the train set that will be reserved for testing. + test_split_ratio (float): Fraction of images from the train set that will + be reserved for testing. Defaults to ``0.2``. - val_split_mode (ValSplitMode, optional): Setting that determines how the validation subset is obtained. + val_split_mode (ValSplitMode): Setting that determines how the validation + subset is obtained. Defaults to ``ValSplitMode.SAME_AS_TEST``. - val_split_ratio (float, optional): Fraction of train or test images that will be reserved for validation. + val_split_ratio (float): Fraction of train or test images that will be + reserved for validation. Defaults to ``0.5``. - seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. + seed (int | None, optional): Seed which may be set to a fixed value for + reproducibility. Defaults to ``None``. - Examples: - To create the BTech datamodule, we need to instantiate the class, and call the ``setup`` method. - - >>> from anomalib.data import BTech - >>> datamodule = BTech( - ... root="./datasets/BTech", - ... category="01", - ... train_batch_size=32, - ... eval_batch_size=32, - ... num_workers=8, - ... ) - >>> datamodule.setup() - - To get the train dataloader and the first batch of data: - - >>> i, data = next(enumerate(datamodule.train_dataloader())) - >>> data.keys() - dict_keys(['image']) - >>> data["image"].shape - torch.Size([32, 3, 256, 256]) - - To access the validation dataloader and the first batch of data: - - >>> i, data = next(enumerate(datamodule.val_dataloader())) - >>> data.keys() - dict_keys(['image_path', 'label', 'mask_path', 'image', 'mask']) - >>> data["image"].shape, data["mask"].shape - (torch.Size([32, 3, 256, 256]), torch.Size([32, 256, 256])) - - Similarly, to access the test dataloader and the first batch of data: - - >>> i, data = next(enumerate(datamodule.test_dataloader())) - >>> data.keys() - dict_keys(['image_path', 'label', 'mask_path', 'image', 'mask']) - >>> data["image"].shape, data["mask"].shape - (torch.Size([32, 3, 256, 256]), torch.Size([32, 256, 256])) + Example: + To create the BTech datamodule, instantiate the class and call + ``setup``:: + + >>> from anomalib.data import BTech + >>> datamodule = BTech( + ... root="./datasets/BTech", + ... category="01", + ... train_batch_size=32, + ... 
eval_batch_size=32, + ... num_workers=8, + ... ) + >>> datamodule.setup() + + Get the train dataloader and first batch:: + + >>> i, data = next(enumerate(datamodule.train_dataloader())) + >>> data.keys() + dict_keys(['image']) + >>> data["image"].shape + torch.Size([32, 3, 256, 256]) + + Access the validation dataloader and first batch:: + + >>> i, data = next(enumerate(datamodule.val_dataloader())) + >>> data.keys() + dict_keys(['image_path', 'label', 'mask_path', 'image', 'mask']) + >>> data["image"].shape, data["mask"].shape + (torch.Size([32, 3, 256, 256]), torch.Size([32, 256, 256])) + + Access the test dataloader and first batch:: + + >>> i, data = next(enumerate(datamodule.test_dataloader())) + >>> data.keys() + dict_keys(['image_path', 'label', 'mask_path', 'image', 'mask']) + >>> data["image"].shape, data["mask"].shape + (torch.Size([32, 3, 256, 256]), torch.Size([32, 256, 256])) """ def __init__( @@ -134,34 +170,26 @@ def _setup(self, _stage: str | None = None) -> None: def prepare_data(self) -> None: """Download the dataset if not available. - This method checks if the specified dataset is available in the file system. - If not, it downloads and extracts the dataset into the appropriate directory. + This method checks if the specified dataset is available in the file + system. If not, it downloads and extracts the dataset into the + appropriate directory. Example: Assume the dataset is not available on the file system. - Here's how the directory structure looks before and after calling the - `prepare_data` method: - - Before: - - .. code-block:: bash + Here's how the directory structure looks before and after calling + ``prepare_data``:: + # Before $ tree datasets datasets ├── dataset1 └── dataset2 - Calling the method: - - .. code-block:: python - - >> datamodule = BTech(root="./datasets/BTech", category="01") - >> datamodule.prepare_data() - - After: - - .. code-block:: bash + # Calling prepare_data + >>> datamodule = BTech(root="./datasets/BTech", category="01") + >>> datamodule.prepare_data() + # After $ tree datasets datasets ├── dataset1 @@ -178,9 +206,12 @@ def prepare_data(self) -> None: # rename folder and convert images logger.info("Renaming the dataset directory") - shutil.move(src=str(self.root.parent / "BTech_Dataset_transformed"), dst=str(self.root)) - logger.info("Convert the bmp formats to png to have consistent image extensions") - for filename in tqdm(self.root.glob("**/*.bmp"), desc="Converting bmp to png"): + shutil.move( + src=str(self.root.parent / "BTech_Dataset_transformed"), + dst=str(self.root), + ) + logger.info("Convert the bmp formats to png for consistent extensions") + for filename in tqdm(self.root.glob("**/*.bmp"), desc="Converting"): image = cv2.imread(str(filename)) cv2.imwrite(str(filename.with_suffix(".png")), image) filename.unlink() diff --git a/src/anomalib/data/datamodules/image/datumaro.py b/src/anomalib/data/datamodules/image/datumaro.py index fb37bc7ee7..8865ad7c91 100644 --- a/src/anomalib/data/datamodules/image/datumaro.py +++ b/src/anomalib/data/datamodules/image/datumaro.py @@ -1,6 +1,38 @@ """DataModule for Datumaro format. -Note: This currently only works for annotations exported from Intel Geti™. +This module provides a PyTorch Lightning DataModule for datasets in Datumaro +format. Currently only supports annotations exported from Intel Geti™. + +Example: + Create a Datumaro datamodule:: + + >>> from pathlib import Path + >>> from anomalib.data import Datumaro + >>> datamodule = Datumaro( + ... root="./datasets/datumaro", + ... 
train_batch_size=32, + ... eval_batch_size=32, + ... num_workers=8, + ... ) + >>> datamodule.setup() + >>> i, data = next(enumerate(datamodule.train_dataloader())) + >>> data.keys() + dict_keys(['image_path', 'label', 'image']) + +Notes: + The directory structure should be organized as follows:: + + root/ + ├── annotations/ + │ ├── train.json + │ └── test.json + └── images/ + ├── train/ + │ ├── image1.jpg + │ └── image2.jpg + └── test/ + ├── image3.jpg + └── image4.jpg """ # Copyright (C) 2024 Intel Corporation @@ -17,46 +49,41 @@ class Datumaro(AnomalibDataModule): """Datumaro datamodule. Args: - root (str | Path): Path to the dataset root directory. - train_batch_size (int): Batch size for training dataloader. + root (Path | str): Path to the dataset root directory. + train_batch_size (int, optional): Training batch size. Defaults to ``32``. - eval_batch_size (int): Batch size for evaluation dataloader. + eval_batch_size (int, optional): Test batch size. Defaults to ``32``. - num_workers (int): Number of workers for dataloaders. + num_workers (int, optional): Number of workers. Defaults to ``8``. - image_size (tuple[int, int], optional): Size to which input images should be resized. - Defaults to ``None``. - transform (Transform, optional): Transforms that should be applied to the input images. - Defaults to ``None``. - train_transform (Transform, optional): Transforms that should be applied to the input images during training. - Defaults to ``None``. - eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation. - Defaults to ``None``. - test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained. + test_split_mode (TestSplitMode): Setting that determines how the testing + subset is obtained. Defaults to ``TestSplitMode.FROM_DIR``. - test_split_ratio (float): Fraction of images from the train set that will be reserved for testing. + test_split_ratio (float): Fraction of images from the train set that will + be reserved for testing. Defaults to ``0.2``. - val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. + val_split_mode (ValSplitMode): Setting that determines how the validation + subset is obtained. Defaults to ``ValSplitMode.SAME_AS_TEST``. - val_split_ratio (float): Fraction of train or test images that will be reserved for validation. + val_split_ratio (float): Fraction of train or test images that will be + reserved for validation. Defaults to ``0.5``. - seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. - Defualts to ``None``. - - Examples: - To create a Datumaro datamodule + seed (int | None, optional): Seed which may be set to a fixed value for + reproducibility. + Defaults to ``None``. - >>> from pathlib import Path - >>> from torchvision.transforms.v2 import Resize - >>> root = Path("path/to/dataset") - >>> datamodule = Datumaro(root, transform=Resize((256, 256))) + Example: + >>> from anomalib.data import Datumaro + >>> datamodule = Datumaro( + ... root="./datasets/datumaro", + ... train_batch_size=32, + ... eval_batch_size=32, + ... num_workers=8, + ... 
) >>> datamodule.setup() >>> i, data = next(enumerate(datamodule.train_dataloader())) >>> data.keys() dict_keys(['image_path', 'label', 'image']) - - >>> data["image"].shape - torch.Size([32, 3, 256, 256]) """ def __init__( diff --git a/src/anomalib/data/datamodules/image/folder.py b/src/anomalib/data/datamodules/image/folder.py index bd3c3fedd0..9cb2d0e430 100644 --- a/src/anomalib/data/datamodules/image/folder.py +++ b/src/anomalib/data/datamodules/image/folder.py @@ -1,6 +1,32 @@ """Custom Folder Data Module. -This script creates a custom Lightning DataModule from a folder. +This script creates a custom Lightning DataModule from a folder containing normal +and abnormal images. + +Example: + Create a folder datamodule:: + + >>> from anomalib.data import Folder + >>> datamodule = Folder( + ... name="custom_folder", + ... root="./datasets/custom", + ... normal_dir="good", + ... abnormal_dir="defect" + ... ) + +Notes: + The directory structure should be organized as follows:: + + root/ + ├── normal_dir/ + │ ├── image1.png + │ └── image2.png + ├── abnormal_dir/ + │ ├── image3.png + │ └── image4.png + └── mask_dir/ + ├── mask3.png + └── mask4.png """ # Copyright (C) 2022-2024 Intel Corporation @@ -18,92 +44,62 @@ class Folder(AnomalibDataModule): """Folder DataModule. Args: - name (str): Name of the dataset. This is used to name the datamodule, especially when logging/saving. - normal_dir (str | Path | Sequence): Name of the directory containing normal images. - root (str | Path | None): Path to the root folder containing normal and abnormal dirs. - Defaults to ``None``. - abnormal_dir (str | Path | None | Sequence): Name of the directory containing abnormal images. - Defaults to ``None``. - normal_test_dir (str | Path | Sequence | None, optional): Path to the directory containing - normal images for the test dataset. - Defaults to ``None``. - mask_dir (str | Path | Sequence | None, optional): Path to the directory containing - the mask annotations. - Defaults to ``None``. - normal_split_ratio (float, optional): Ratio to split normal training images and add to the - test set in case test set doesn't contain any normal images. - Defaults to 0.2. - extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the - directory. + name (str): Name of the dataset. Used for logging/saving. + normal_dir (str | Path | Sequence): Directory containing normal images. + root (str | Path | None): Root folder containing normal and abnormal + directories. Defaults to ``None``. + abnormal_dir (str | Path | None | Sequence): Directory containing + abnormal images. Defaults to ``None``. + normal_test_dir (str | Path | Sequence | None): Directory containing + normal test images. Defaults to ``None``. + mask_dir (str | Path | Sequence | None): Directory containing mask + annotations. Defaults to ``None``. + normal_split_ratio (float): Ratio to split normal training images for + test set when no normal test images exist. + Defaults to ``0.2``. + extensions (tuple[str, ...] | None): Image extensions to include. Defaults to ``None``. - train_batch_size (int, optional): Training batch size. + train_batch_size (int): Training batch size. Defaults to ``32``. - eval_batch_size (int, optional): Validation, test and predict batch size. + eval_batch_size (int): Validation/test batch size. Defaults to ``32``. - num_workers (int, optional): Number of workers. + num_workers (int): Number of workers for data loading. Defaults to ``8``. 
- test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained. + test_split_mode (TestSplitMode): Method to obtain test subset. Defaults to ``TestSplitMode.FROM_DIR``. - test_split_ratio (float): Fraction of images from the train set that will be reserved for testing. + test_split_ratio (float): Fraction of train images for testing. Defaults to ``0.2``. - val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. + val_split_mode (ValSplitMode): Method to obtain validation subset. Defaults to ``ValSplitMode.FROM_TEST``. - val_split_ratio (float): Fraction of train or test images that will be reserved for validation. + val_split_ratio (float): Fraction of images for validation. Defaults to ``0.5``. - seed (int | None, optional): Seed used during random subset splitting. + seed (int | None): Random seed for splitting. Defaults to ``None``. - Examples: - The following code demonstrates how to use the ``Folder`` datamodule. Assume that the dataset is structured - as follows: - - .. code-block:: bash - - $ tree sample_dataset - sample_dataset - ├── colour - │ ├── 00.jpg - │ ├── ... - │ └── x.jpg - ├── crack - │ ├── 00.jpg - │ ├── ... - │ └── y.jpg - ├── good - │ ├── ... - │ └── z.jpg - ├── LICENSE - └── mask - ├── colour - │ ├── ... - │ └── x.jpg - └── crack - ├── ... - └── y.jpg - - .. code-block:: python - - folder_datamodule = Folder( - root=dataset_root, - normal_dir="good", - abnormal_dir="crack", - mask_dir=dataset_root / "mask" / "crack", - ) - folder_datamodule.setup() - - To access the training images, - - .. code-block:: python - - >> i, data = next(enumerate(folder_datamodule.train_dataloader())) - >> print(data.keys(), data["image"].shape) - - To access the test images, - - .. code-block:: python - - >> i, data = next(enumerate(folder_datamodule.test_dataloader())) - >> print(data.keys(), data["image"].shape) + Example: + Create and setup a folder datamodule:: + + >>> from anomalib.data import Folder + >>> datamodule = Folder( + ... name="custom", + ... root="./datasets/custom", + ... normal_dir="good", + ... abnormal_dir="defect", + ... mask_dir="mask" + ... ) + >>> datamodule.setup() + + Get a batch from train dataloader:: + + >>> batch = next(iter(datamodule.train_dataloader())) + >>> batch.keys() + dict_keys(['image', 'label', 'mask', 'image_path', 'mask_path']) + + Get a batch from test dataloader:: + + >>> batch = next(iter(datamodule.test_dataloader())) + >>> batch.keys() + dict_keys(['image', 'label', 'mask', 'image_path', 'mask_path']) """ def __init__( @@ -172,8 +168,9 @@ def _setup(self, _stage: str | None = None) -> None: @property def name(self) -> str: - """Name of the datamodule. + """Get name of the datamodule. - Folder datamodule overrides the name property to provide a custom name. + Returns: + Name of the datamodule. """ return self._name diff --git a/src/anomalib/data/datamodules/image/kolektor.py b/src/anomalib/data/datamodules/image/kolektor.py index fe767c3a94..980e0ac4b4 100644 --- a/src/anomalib/data/datamodules/image/kolektor.py +++ b/src/anomalib/data/datamodules/image/kolektor.py @@ -1,17 +1,20 @@ """Kolektor Surface-Defect Data Module. Description: - This script provides a PyTorch DataModule for the Kolektor - Surface-Defect dataset. The dataset can be accessed at `Kolektor Surface-Defect Dataset `_. + This script provides a PyTorch DataModule for the Kolektor Surface-Defect + dataset. The dataset can be accessed at `Kolektor Surface-Defect Dataset + `_. 
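One detail of the ``Folder`` hunk above that neither example exercises: the ``*_dir`` arguments are typed ``str | Path | Sequence``, so several source folders can feed the same subset. A sketch with hypothetical paths:

```python
# Folder accepts a sequence of directories for normal_dir and friends;
# images from every listed folder are merged into the same subset.
from anomalib.data import Folder

datamodule = Folder(
    name="multi_source",
    root="./datasets/custom",
    normal_dir=["good_batch_1", "good_batch_2"],  # hypothetical folders
    abnormal_dir="defect",
)
datamodule.setup()
```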
License: - The Kolektor Surface-Defect dataset is released under the Creative Commons Attribution-NonCommercial-ShareAlike - 4.0 International License (CC BY-NC-SA 4.0). For more details, visit - `Creative Commons License `_. + The Kolektor Surface-Defect dataset is released under the Creative Commons + Attribution-NonCommercial-ShareAlike 4.0 International License + (CC BY-NC-SA 4.0). For more details, visit `Creative Commons License + `_. Reference: - Tabernik, Domen, Samo Šela, Jure Skvarč, and Danijel Skočaj. "Segmentation-based deep-learning approach - for surface-defect detection." Journal of Intelligent Manufacturing 31, no. 3 (2020): 759-776. + Tabernik, Domen, Samo Šela, Jure Skvarč, and Danijel Skočaj. + "Segmentation-based deep-learning approach for surface-defect detection." + Journal of Intelligent Manufacturing 31, no. 3 (2020): 759-776. """ # Copyright (C) 2023-2024 Intel Corporation @@ -35,26 +38,45 @@ class Kolektor(AnomalibDataModule): - """Kolektor Datamodule. + """Kolektor Surface-Defect DataModule. Args: - root (Path | str): Path to the root of the dataset + root (Path | str): Path to the root of the dataset. + Defaults to ``"./datasets/kolektor"``. train_batch_size (int, optional): Training batch size. Defaults to ``32``. eval_batch_size (int, optional): Test batch size. Defaults to ``32``. num_workers (int, optional): Number of workers. Defaults to ``8``. - test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained. - Defaults to ``TestSplitMode.FROM_DIR`` - test_split_ratio (float): Fraction of images from the train set that will be reserved for testing. - Defaults to ``0.2`` - val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. - Defaults to ``ValSplitMode.SAME_AS_TEST`` - val_split_ratio (float): Fraction of train or test images that will be reserved for validation. - Defaults to ``0.5`` - seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. + test_split_mode (TestSplitMode): Setting that determines how the testing + subset is obtained. + Defaults to ``TestSplitMode.FROM_DIR``. + test_split_ratio (float): Fraction of images from the train set that will + be reserved for testing. + Defaults to ``0.2``. + val_split_mode (ValSplitMode): Setting that determines how the validation + subset is obtained. + Defaults to ``ValSplitMode.SAME_AS_TEST``. + val_split_ratio (float): Fraction of train or test images that will be + reserved for validation. + Defaults to ``0.5``. + seed (int | None, optional): Seed which may be set to a fixed value for + reproducibility. Defaults to ``None``. + + Example: + >>> from anomalib.data import Kolektor + >>> datamodule = Kolektor( + ... root="./datasets/kolektor", + ... train_batch_size=32, + ... eval_batch_size=32, + ... num_workers=8, + ... ) + >>> datamodule.setup() + >>> i, data = next(enumerate(datamodule.train_dataloader())) + >>> data.keys() + dict_keys(['image', 'label', 'mask', 'image_path', 'mask_path']) """ def __init__( @@ -95,17 +117,16 @@ def _setup(self, _stage: str | None = None) -> None: def prepare_data(self) -> None: """Download the dataset if not available. - This method checks if the specified dataset is available in the file system. - If not, it downloads and extracts the dataset into the appropriate directory. + This method checks if the specified dataset is available in the file + system. If not, it downloads and extracts the dataset into the + appropriate directory. 
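The various ``prepare_data`` methods in this diff all implement the same download-once, verify-hash, extract pattern (compare the ``DownloadInfo`` entry with its ``hashsum`` in the MVTec hunk further down). A generic sketch of that pattern; the helper below is illustrative, not anomalib's actual download utility:

```python
# Illustrative download-verify-extract helper; anomalib's real utilities
# add progress reporting and safe-extraction checks on top of this idea.
import hashlib
import tarfile
import urllib.request
from pathlib import Path

def download_and_extract(url: str, sha256: str, root: Path) -> None:
    root.mkdir(parents=True, exist_ok=True)
    archive = root / url.rsplit("/", 1)[-1]
    if not archive.exists():
        urllib.request.urlretrieve(url, archive)  # download only once
    digest = hashlib.sha256(archive.read_bytes()).hexdigest()
    if digest != sha256:
        msg = f"Hash mismatch for {archive}: got {digest}"
        raise ValueError(msg)
    with tarfile.open(archive) as tar:
        tar.extractall(path=root)
```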
Example: Assume the dataset is not available on the file system. - Here's how the directory structure looks before and after calling the - `prepare_data` method: - - Before: + Here's how the directory structure looks before and after calling + the ``prepare_data`` method: - .. code-block:: bash + Before:: $ tree datasets datasets @@ -114,14 +135,10 @@ def prepare_data(self) -> None: Calling the method: - .. code-block:: python - - >> datamodule = Kolektor(root="./datasets/kolektor") - >> datamodule.prepare_data() - - After: + >>> datamodule = Kolektor(root="./datasets/kolektor") + >>> datamodule.prepare_data() - .. code-block:: bash + After:: $ tree datasets datasets diff --git a/src/anomalib/data/datamodules/image/mvtec.py b/src/anomalib/data/datamodules/image/mvtec.py index 9e7b2fce89..b412e38c04 100644 --- a/src/anomalib/data/datamodules/image/mvtec.py +++ b/src/anomalib/data/datamodules/image/mvtec.py @@ -1,25 +1,45 @@ """MVTec AD Data Module. -Description: - This script contains PyTorch Lightning DataModule for the MVTec AD dataset. - If the dataset is not on the file system, the script downloads and extracts - the dataset and create PyTorch data objects. +This module provides a PyTorch Lightning DataModule for the MVTec AD dataset. If +the dataset is not available locally, it will be downloaded and extracted +automatically. + +Example: + Create a MVTec datamodule:: + + >>> from anomalib.data import MVTec + >>> datamodule = MVTec( + ... root="./datasets/mvtec", + ... category="bottle" + ... ) + +Notes: + The dataset will be automatically downloaded and converted to the required + format when first used. The directory structure after preparation will be:: + + datasets/ + └── mvtec/ + ├── bottle/ + ├── cable/ + └── ... License: MVTec AD dataset is released under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License - (CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/). - -References: - - Paul Bergmann, Kilian Batzner, Michael Fauser, David Sattlegger, Carsten Steger: - The MVTec Anomaly Detection Dataset: A Comprehensive Real-World Dataset for - Unsupervised Anomaly Detection; in: International Journal of Computer Vision - 129(4):1038-1059, 2021, DOI: 10.1007/s11263-020-01400-4. - - - Paul Bergmann, Michael Fauser, David Sattlegger, Carsten Steger: MVTec AD — - A Comprehensive Real-World Dataset for Unsupervised Anomaly Detection; - in: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), - 9584-9592, 2019, DOI: 10.1109/CVPR.2019.00982. + (CC BY-NC-SA 4.0). + https://creativecommons.org/licenses/by-nc-sa/4.0/ + +Reference: + Paul Bergmann, Kilian Batzner, Michael Fauser, David Sattlegger, + Carsten Steger: The MVTec Anomaly Detection Dataset: A Comprehensive + Real-World Dataset for Unsupervised Anomaly Detection; in: International + Journal of Computer Vision 129(4):1038-1059, 2021, + DOI: 10.1007/s11263-020-01400-4. + + Paul Bergmann, Michael Fauser, David Sattlegger, Carsten Steger: MVTec AD — + A Comprehensive Real-World Dataset for Unsupervised Anomaly Detection; + in: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), + 9584-9592, 2019, DOI: 10.1109/CVPR.2019.00982. 
""" # Copyright (C) 2022-2024 Intel Corporation @@ -37,8 +57,8 @@ DOWNLOAD_INFO = DownloadInfo( name="mvtec", - url="https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094" - "/mvtec_anomaly_detection.tar.xz", + url="https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/" + "download/420938113-1629952094/mvtec_anomaly_detection.tar.xz", hashsum="cf4313b13603bec67abb49ca959488f7eedce2a9f7795ec54446c649ac98cd3d", ) @@ -49,53 +69,54 @@ class MVTec(AnomalibDataModule): Args: root (Path | str): Path to the root of the dataset. Defaults to ``"./datasets/MVTec"``. - category (str): Category of the MVTec dataset (e.g. "bottle" or "cable"). - Defaults to ``"bottle"``. + category (str): Category of the MVTec dataset (e.g. ``"bottle"`` or + ``"cable"``). Defaults to ``"bottle"``. train_batch_size (int, optional): Training batch size. Defaults to ``32``. eval_batch_size (int, optional): Test batch size. Defaults to ``32``. num_workers (int, optional): Number of workers. Defaults to ``8``. - test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained. + test_split_mode (TestSplitMode): Method to create test set. Defaults to ``TestSplitMode.FROM_DIR``. - test_split_ratio (float): Fraction of images from the train set that will be reserved for testing. + test_split_ratio (float): Fraction of data to use for testing. Defaults to ``0.2``. - val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. + val_split_mode (ValSplitMode): Method to create validation set. Defaults to ``ValSplitMode.SAME_AS_TEST``. - val_split_ratio (float): Fraction of train or test images that will be reserved for validation. + val_split_ratio (float): Fraction of data to use for validation. Defaults to ``0.5``. - seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. - Defualts to ``None``. - - Examples: - To create an MVTec AD datamodule with default settings: + seed (int | None, optional): Seed for reproducibility. + Defaults to ``None``. - >>> datamodule = MVTec() - >>> datamodule.setup() - >>> i, data = next(enumerate(datamodule.train_dataloader())) - >>> data.keys() - dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask']) + Example: + Create MVTec datamodule with default settings:: - >>> data["image"].shape - torch.Size([32, 3, 256, 256]) + >>> datamodule = MVTec() + >>> datamodule.setup() + >>> i, data = next(enumerate(datamodule.train_dataloader())) + >>> data.keys() + dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask']) - To change the category of the dataset: + >>> data["image"].shape + torch.Size([32, 3, 256, 256]) - >>> datamodule = MVTec(category="cable") + Change the category:: - MVTec AD dataset does not provide a validation set. If you would like - to use a separate validation set, you can use the ``val_split_mode`` and - ``val_split_ratio`` arguments to create a validation set. + >>> datamodule = MVTec(category="cable") - >>> datamodule = MVTec(val_split_mode=ValSplitMode.FROM_TEST, val_split_ratio=0.1) + Create validation set from test data:: - This will subsample the test set by 10% and use it as the validation set. - If you would like to create a validation set synthetically that would - not change the test set, you can use the ``ValSplitMode.SYNTHETIC`` option. + >>> datamodule = MVTec( + ... val_split_mode=ValSplitMode.FROM_TEST, + ... val_split_ratio=0.1 + ... 
) - >>> datamodule = MVTec(val_split_mode=ValSplitMode.SYNTHETIC, val_split_ratio=0.2) + Create synthetic validation set:: + >>> datamodule = MVTec( + ... val_split_mode=ValSplitMode.SYNTHETIC, + ... val_split_ratio=0.2 + ... ) """ def __init__( @@ -131,11 +152,12 @@ def _setup(self, _stage: str | None = None) -> None: This method may be overridden in subclass for custom splitting behaviour. Note: - The stage argument is not used here. This is because, for a given instance of an AnomalibDataModule - subclass, all three subsets are created at the first call of setup(). This is to accommodate the subset - splitting behaviour of anomaly tasks, where the validation set is usually extracted from the test set, and - the test set must therefore be created as early as the `fit` stage. - + The stage argument is not used here. This is because, for a given + instance of an AnomalibDataModule subclass, all three subsets are + created at the first call of setup(). This is to accommodate the + subset splitting behaviour of anomaly tasks, where the validation set + is usually extracted from the test set, and the test set must + therefore be created as early as the `fit` stage. """ self.train_data = MVTecDataset( split=Split.TRAIN, @@ -151,42 +173,26 @@ def _setup(self, _stage: str | None = None) -> None: def prepare_data(self) -> None: """Download the dataset if not available. - This method checks if the specified dataset is available in the file system. - If not, it downloads and extracts the dataset into the appropriate directory. + This method checks if the specified dataset is available in the file + system. If not, it downloads and extracts the dataset into the + appropriate directory. Example: - Assume the dataset is not available on the file system. - Here's how the directory structure looks before and after calling the - `prepare_data` method: - - Before: - - .. code-block:: bash - - $ tree datasets - datasets - ├── dataset1 - └── dataset2 - - Calling the method: - - .. code-block:: python - - >> datamodule = MVTec(root="./datasets/MVTec", category="bottle") - >> datamodule.prepare_data() + Assume the dataset is not available on the file system:: - After: + >>> datamodule = MVTec( + ... root="./datasets/MVTec", + ... category="bottle" + ... ) + >>> datamodule.prepare_data() - .. code-block:: bash + Directory structure after download:: - $ tree datasets - datasets - ├── dataset1 - ├── dataset2 - └── MVTec - ├── bottle - ├── ... - └── zipper + datasets/ + └── MVTec/ + ├── bottle/ + ├── cable/ + └── ... """ if (self.root / self.category).is_dir(): logger.info("Found the dataset.") diff --git a/src/anomalib/data/datamodules/image/visa.py b/src/anomalib/data/datamodules/image/visa.py index 553d0dcc03..c359eb7600 100644 --- a/src/anomalib/data/datamodules/image/visa.py +++ b/src/anomalib/data/datamodules/image/visa.py @@ -1,19 +1,41 @@ """Visual Anomaly (VisA) Data Module. -Description: - This script contains PyTorch Lightning DataModule for the Visual Anomal - (VisA) dataset. If the dataset is not on the file system, the script - downloads and extracts the dataset and create PyTorch data objects. +This module provides a PyTorch Lightning DataModule for the Visual Anomaly (VisA) +dataset. If the dataset is not available locally, it will be downloaded and +extracted automatically. + +Example: + Create a VisA datamodule:: + + >>> from anomalib.data import Visa + >>> datamodule = Visa( + ... root="./datasets/visa", + ... category="capsules" + ... 
) + +Notes: + The dataset will be automatically downloaded and converted to the required + format when first used. The directory structure after preparation will be:: + + datasets/ + └── visa/ + ├── visa_pytorch/ + │ ├── candle/ + │ ├── capsules/ + │ └── ... + └── VisA_20220922.tar License: The VisA dataset is released under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License - (CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/). + (CC BY-NC-SA 4.0). + https://creativecommons.org/licenses/by-nc-sa/4.0/ Reference: - - Zou, Y., Jeong, J., Pemula, L., Zhang, D., & Dabeer, O. (2022). SPot-the-Difference - Self-supervised Pre-training for Anomaly Detection and Segmentation. In European - Conference on Computer Vision (pp. 392-408). Springer, Cham. + Zou, Y., Jeong, J., Pemula, L., Zhang, D., & Dabeer, O. (2022). + SPot-the-Difference Self-supervised Pre-training for Anomaly Detection + and Segmentation. In European Conference on Computer Vision (pp. 392-408). + Springer, Cham. """ # Copyright (C) 2022-2024 Intel Corporation @@ -46,25 +68,25 @@ class Visa(AnomalibDataModule): """VisA Datamodule. Args: - root (Path | str): Path to the root of the dataset + root (Path | str): Path to the root of the dataset. Defaults to ``"./datasets/visa"``. - category (str): Category of the Visa dataset such as ``candle``. - Defaults to ``"candle"``. + category (str): Category of the VisA dataset (e.g. ``"candle"``). + Defaults to ``"capsules"``. train_batch_size (int, optional): Training batch size. Defaults to ``32``. eval_batch_size (int, optional): Test batch size. Defaults to ``32``. - num_workers (int, optional): Number of workers. + num_workers (int, optional): Number of workers for data loading. Defaults to ``8``. - test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained. + test_split_mode (TestSplitMode | str): Method to create test set. Defaults to ``TestSplitMode.FROM_DIR``. - test_split_ratio (float): Fraction of images from the train set that will be reserved for testing. + test_split_ratio (float): Fraction of data to use for testing. Defaults to ``0.2``. - val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. + val_split_mode (ValSplitMode | str): Method to create validation set. Defaults to ``ValSplitMode.SAME_AS_TEST``. - val_split_ratio (float): Fraction of train or test images that will be reserved for validation. - Defatuls to ``0.5``. - seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. + val_split_ratio (float): Fraction of data to use for validation. + Defaults to ``0.5``. + seed (int | None, optional): Random seed for reproducibility. Defaults to ``None``. """ @@ -109,61 +131,32 @@ def _setup(self, _stage: str | None = None) -> None: ) def prepare_data(self) -> None: - """Download the dataset if not available. - - This method checks if the specified dataset is available in the file system. - If not, it downloads and extracts the dataset into the appropriate directory. - - Example: - Assume the dataset is not available on the file system. - Here's how the directory structure looks before and after calling the - `prepare_data` method: - - Before: - - .. code-block:: bash - - $ tree datasets - datasets - ├── dataset1 - └── dataset2 - - Calling the method: - - .. code-block:: python - - >> datamodule = Visa() - >> datamodule.prepare_data() - - After: - - .. 
code-block:: bash - - $ tree datasets - datasets - ├── dataset1 - ├── dataset2 - └── visa - ├── candle - ├── ... - ├── pipe_fryum - │ ├── Data - │ └── image_anno.csv - ├── split_csv - │ ├── 1cls.csv - │ ├── 2cls_fewshot.csv - │ └── 2cls_highshot.csv - ├── VisA_20220922.tar - └── visa_pytorch - ├── candle - ├── ... - ├── pcb4 - └── pipe_fryum - - ``prepare_data`` ensures that the dataset is converted to MVTec - format. ``visa_pytorch`` is the directory that contains the dataset - in the MVTec format. ``visa`` is the directory that contains the - original dataset. + """Download and prepare the dataset if not available. + + This method checks if the dataset exists and is properly formatted. + If not, it downloads and prepares the data in the following steps: + + 1. If the processed dataset exists (``visa_pytorch/{category}``), do + nothing + 2. If the raw dataset exists but isn't processed, apply the train/test + split + 3. If the dataset doesn't exist, download, extract, and process it + + The final directory structure will be:: + + datasets/ + └── visa/ + ├── visa_pytorch/ + │ ├── candle/ + │ │ ├── train/ + │ │ │ └── good/ + │ │ ├── test/ + │ │ │ ├── good/ + │ │ │ └── bad/ + │ │ └── ground_truth/ + │ │ └── bad/ + │ └── ... + └── VisA_20220922.tar """ if (self.split_root / self.category).is_dir(): # dataset is available, and split has been applied @@ -181,7 +174,7 @@ def prepare_data(self) -> None: def apply_cls1_split(self) -> None: """Apply the 1-class subset splitting using the fixed split in the csv file. - adapted from https://github.com/amazon-science/spot-diff + Adapted from https://github.com/amazon-science/spot-diff. """ logger.info("preparing data") categories = [
- If the dataset is not already present on the file system, the DataModule class will download and - extract the dataset, converting the .mat mask files to .png format. +This module provides a PyTorch Lightning DataModule for the CUHK Avenue dataset. If +the dataset is not already present on the file system, the DataModule class will +download and extract the dataset, converting the ``.mat`` mask files to ``.png`` +format. + +Example: + Create an Avenue datamodule:: + + >>> from anomalib.data import Avenue + >>> datamodule = Avenue( + ... root="./datasets/avenue", + ... clip_length_in_frames=2, + ... frames_between_clips=1, + ... ) + >>> datamodule.setup() + >>> i, data = next(enumerate(datamodule.train_dataloader())) + >>> data.keys() + dict_keys(['image', 'video_path', 'frames', 'last_frame', 'original_image']) + +Notes: + The directory structure after preparation will be:: + + root/ + ├── ground_truth_demo/ + │ ├── ground_truth_show.m + │ ├── Readme.txt + │ ├── testing_label_mask/ + │ └── testing_videos/ + ├── testing_videos/ + │ ├── ... + │ └── 21.avi + ├── testing_vol/ + │ ├── ... + │ └── vol21.mat + ├── training_videos/ + │ ├── ... + │ └── 16.avi + └── training_vol/ + ├── ... + └── vol16.mat + +License: + The CUHK Avenue dataset is released for academic research only. For licensing + details, see the original dataset website. Reference: - - Lu, Cewu, Jianping Shi, and Jiaya Jia. "Abnormal event detection at 150 fps in Matlab." - In Proceedings of the IEEE International Conference on Computer Vision, 2013. + Lu, Cewu, Jianping Shi, and Jiaya Jia. "Abnormal event detection at 150 fps + in Matlab." In Proceedings of the IEEE International Conference on Computer + Vision, 2013. """ - # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -45,73 +85,47 @@ class Avenue(AnomalibVideoDataModule): """Avenue DataModule class. Args: - root (Path | str): Path to the root of the dataset - Defaults to ``./datasets/avenue``. - gt_dir (Path | str): Path to the ground truth files - Defaults to ``./datasets/avenue/ground_truth_demo``. - clip_length_in_frames (int, optional): Number of video frames in each clip. + root (Path | str): Path to the root of the dataset. + Defaults to ``"./datasets/avenue"``. + gt_dir (Path | str): Path to the ground truth files. + Defaults to ``"./datasets/avenue/ground_truth_demo"``. + clip_length_in_frames (int): Number of video frames in each clip. Defaults to ``2``. - frames_between_clips (int, optional): Number of frames between each consecutive video clip. + frames_between_clips (int): Number of frames between consecutive clips. Defaults to ``1``. - target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval - Defaults to ``VideoTargetFrame.LAST``. - train_batch_size (int, optional): Training batch size. + target_frame (VideoTargetFrame | str): Target frame in clip for ground + truth. Defaults to ``VideoTargetFrame.LAST``. + train_batch_size (int): Training batch size. Defaults to ``32``. - eval_batch_size (int, optional): Test batch size. + eval_batch_size (int): Test batch size. Defaults to ``32``. - num_workers (int, optional): Number of workers. + num_workers (int): Number of workers. Defaults to ``8``. - val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. - Defaults to ``ValSplitMode.FROM_TEST``. - val_split_ratio (float): Fraction of train or test images that will be reserved for validation. 
+ val_split_mode (ValSplitMode | str): How validation subset is obtained. + Defaults to ``ValSplitMode.SAME_AS_TEST``. + val_split_ratio (float): Fraction of data reserved for validation. Defaults to ``0.5``. - seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. + seed (int | None): Seed for reproducibility. Defaults to ``None``. - Examples: - To create a DataModule for Avenue dataset with default parameters: - - .. code-block:: python - - datamodule = Avenue() - datamodule.setup() - - i, data = next(enumerate(datamodule.train_dataloader())) - data.keys() - # Output: dict_keys(['image', 'video_path', 'frames', 'last_frame', 'original_image']) - - i, data = next(enumerate(datamodule.test_dataloader())) - data.keys() - # Output: dict_keys(['image', 'mask', 'video_path', 'frames', 'last_frame', 'original_image', 'label']) - - data["image"].shape - # Output: torch.Size([32, 2, 3, 256, 256]) - - Note that it is important to note that the dataloader returns a batch of clips, where each clip is a sequence of - frames. The number of frames in each clip is determined by the ``clip_length_in_frames`` parameter. The - ``frames_between_clips`` parameter determines the number of frames between each consecutive clip. The - ``target_frame`` parameter determines which frame in the clip is used for ground truth retrieval. For example, - if ``clip_length_in_frames=2``, ``frames_between_clips=1`` and ``target_frame=VideoTargetFrame.LAST``, then the - dataloader will return a batch of clips where each clip contains two consecutive frames from the video. The - second frame in each clip will be used as the ground truth for the first frame in the clip. The following code - shows how to create a dataloader for classification: - - .. code-block:: python - - datamodule = Avenue( - clip_length_in_frames=2, - frames_between_clips=1, - target_frame=VideoTargetFrame.LAST - ) - datamodule.setup() - - i, data = next(enumerate(datamodule.train_dataloader())) - data.keys() - # Output: dict_keys(['image', 'video_path', 'frames', 'last_frame', 'original_image']) - - data["image"].shape - # Output: torch.Size([32, 2, 3, 256, 256]) - + Example: + Create a dataloader for classification:: + + >>> datamodule = Avenue( + ... clip_length_in_frames=2, + ... frames_between_clips=1, + ... target_frame=VideoTargetFrame.LAST + ... ) + >>> datamodule.setup() + >>> i, data = next(enumerate(datamodule.train_dataloader())) + >>> data["image"].shape + torch.Size([32, 2, 3, 256, 256]) + + Notes: + The dataloader returns batches of clips, where each clip contains + ``clip_length_in_frames`` consecutive frames. ``frames_between_clips`` + determines frame spacing between clips. ``target_frame`` specifies which + frame provides ground truth. """ def __init__( @@ -165,54 +179,35 @@ def _setup(self, _stage: str | None = None) -> None: def prepare_data(self) -> None: """Download the dataset if not available. - This method checks if the specified dataset is available in the file system. - If not, it downloads and extracts the dataset into the appropriate directory. + This method checks if the specified dataset is available in the file + system. If not, it downloads and extracts the dataset into the appropriate + directory. Example: - Assume the dataset is not available on the file system. - Here's how the directory structure looks before and after calling the - `prepare_data` method: - - Before: - - .. code-block:: bash - - $ tree datasets - datasets - ├── dataset1 - └── dataset2 - - Calling the method: - - .. 
code-block:: python - - >> datamodule = Avenue() - >> datamodule.prepare_data() + Assume the dataset is not available on the file system:: - After: + >>> datamodule = Avenue() + >>> datamodule.prepare_data() - .. code-block:: bash + The directory structure after preparation will be:: - $ tree datasets - datasets - ├── dataset1 - ├── dataset2 - └── avenue - ├── ground_truth_demo + datasets/ + └── avenue/ + ├── ground_truth_demo/ │ ├── ground_truth_show.m │ ├── Readme.txt - │ ├── testing_label_mask - │ └── testing_videos - ├── testing_videos + │ ├── testing_label_mask/ + │ └── testing_videos/ + ├── testing_videos/ │ ├── ... │ └── 21.avi - ├── testing_vol + ├── testing_vol/ │ ├── ... │ └── vol21.mat - ├── training_videos + ├── training_videos/ │ ├── ... │ └── 16.avi - └── training_vol + └── training_vol/ ├── ... └── vol16.mat """ @@ -235,10 +230,11 @@ def prepare_data(self) -> None: @staticmethod def _convert_masks(gt_dir: Path) -> None: - """Convert mask files to .png. + """Convert mask files from ``.mat`` to ``.png`` format. - The masks in the Avenue datasets are provided as matlab (.mat) files. To speed up data loading, we convert the - masks into a sepaarte .png file for every video frame in the dataset. + The masks in the Avenue datasets are provided as matlab (``.mat``) files. + To speed up data loading, we convert the masks into a separate ``.png`` + file for every video frame in the dataset. Args: gt_dir (Path): Ground truth folder of the dataset. diff --git a/src/anomalib/data/datamodules/video/shanghaitech.py b/src/anomalib/data/datamodules/video/shanghaitech.py index f5e5cd0036..babd338fc0 100644 --- a/src/anomalib/data/datamodules/video/shanghaitech.py +++ b/src/anomalib/data/datamodules/video/shanghaitech.py @@ -1,16 +1,45 @@ """ShanghaiTech Campus Data Module. -Description: - This module contains PyTorch Lightning DataModule for the ShanghaiTech Campus dataset. - If the dataset is not on the file system, the DataModule class downloads and - extracts the dataset and converts video files to a format that is readable by pyav. +This module provides a PyTorch Lightning DataModule for the ShanghaiTech Campus +dataset. If the dataset is not available locally, it will be downloaded and +extracted automatically. The video files are also converted to a format readable +by pyav. + +Example: + Create a ShanghaiTech datamodule:: + + >>> from anomalib.data import ShanghaiTech + >>> datamodule = ShanghaiTech( + ... root="./datasets/shanghaitech", + ... scene=1, + ... clip_length_in_frames=2, + ... frames_between_clips=1, + ... ) + >>> datamodule.setup() + >>> i, data = next(enumerate(datamodule.train_dataloader())) + >>> data.keys() + dict_keys(['image', 'video_path', 'frames', 'label']) + +Notes: + The directory structure after preparation will be:: + + root/ + ├── testing/ + │ ├── frames/ + │ ├── test_frame_mask/ + │ └── test_pixel_mask/ + └── training/ + ├── frames/ + ├── converted_videos/ + └── videos/ License: ShanghaiTech Campus Dataset is released under the BSD 2-Clause License. Reference: - - W. Liu and W. Luo, D. Lian and S. Gao. "Future Frame Prediction for Anomaly Detection -- A New Baseline." - IEEE Conference on Computer Vision and Pattern Recognition (CVPR). 2018. + Liu, W., Luo, W., Lian, D., & Gao, S. (2018). Future frame prediction for + anomaly detection--a new baseline. In Proceedings of the IEEE conference on + computer vision and pattern recognition (pp. 6536-6545). 
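The clip arguments shared by the video datamodules above (``clip_length_in_frames``, ``frames_between_clips``, ``target_frame``) define a sliding window over each video. Assuming ``frames_between_clips`` is the stride between clip start frames, which is how torchvision-style clip sampling behaves, the per-video clip count works out as follows:

```python
# Sliding-window clip count under the stride interpretation above.
def num_clips(n_frames: int, clip_length: int, stride: int) -> int:
    """Full clips of `clip_length` frames, one starting every `stride` frames."""
    if n_frames < clip_length:
        return 0
    return (n_frames - clip_length) // stride + 1

# A 100-frame video with 2-frame clips and stride 1 yields 99 overlapping
# clips; with target_frame=LAST, the second frame of each clip carries the
# ground truth for that clip.
assert num_clips(100, 2, 1) == 99
```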
""" # Copyright (C) 2023-2024 Intel Corporation @@ -39,17 +68,31 @@ class ShanghaiTech(AnomalibVideoDataModule): """ShanghaiTech DataModule class. Args: - root (Path | str): Path to the root of the dataset - scene (int): Index of the dataset scene (category) in range [1, 13] - clip_length_in_frames (int, optional): Number of video frames in each clip. - frames_between_clips (int, optional): Number of frames between each consecutive video clip. - target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval - train_batch_size (int, optional): Training batch size. Defaults to 32. - eval_batch_size (int, optional): Test batch size. Defaults to 32. - num_workers (int, optional): Number of workers. Defaults to 8. - val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. - val_split_ratio (float): Fraction of train or test images that will be reserved for validation. - seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. + root (Path | str): Path to the root directory of the dataset. + Defaults to ``"./datasets/shanghaitech"``. + scene (int): Scene index in range [1, 13]. + Defaults to ``1``. + clip_length_in_frames (int): Number of frames in each video clip. + Defaults to ``2``. + frames_between_clips (int): Number of frames between consecutive clips. + Defaults to ``1``. + target_frame (VideoTargetFrame): Specifies which frame in the clip should + be used for ground truth. + Defaults to ``VideoTargetFrame.LAST``. + train_batch_size (int): Training batch size. + Defaults to ``32``. + eval_batch_size (int): Test batch size. + Defaults to ``32``. + num_workers (int): Number of workers for data loading. + Defaults to ``8``. + val_split_mode (ValSplitMode): Setting that determines how validation + subset is obtained. + Defaults to ``ValSplitMode.SAME_AS_TEST``. + val_split_ratio (float): Fraction of train or test images that will be + reserved for validation. + Defaults to ``0.5``. + seed (int | None): Random seed for reproducibility. + Defaults to ``None``. """ def __init__( @@ -125,19 +168,25 @@ def prepare_data(self) -> None: @staticmethod def _convert_training_videos(video_folder: Path, target_folder: Path) -> None: - """Re-code the training videos to ensure correct reading of frames by torchvision. + """Re-code training videos for correct frame reading by torchvision. - The encoding of the raw video files in the ShanghaiTech dataset causes some problems when - reading the frames using pyav. To prevent this, we read the frames from the video files using opencv, - and write them to a new video file that can be parsed correctly with pyav. + The encoding of the raw video files in the ShanghaiTech dataset causes + issues when reading frames using pyav. To prevent this, frames are read + using opencv and written to new video files that can be parsed correctly + with pyav. Args: - video_folder (Path): Path to the folder of training videos. - target_folder (Path): File system location where the converted videos will be stored. + video_folder (Path): Path to the folder containing training videos. + target_folder (Path): Path where converted videos will be stored. 
""" training_videos = sorted(video_folder.glob("*")) for video_idx, video_path in enumerate(training_videos): - logger.info("Converting training video %s (%i/%i)...", video_path.name, video_idx + 1, len(training_videos)) + logger.info( + "Converting training video %s (%i/%i)...", + video_path.name, + video_idx + 1, + len(training_videos), + ) file_name = video_path.name target_path = target_folder / file_name convert_video(video_path, target_path, codec="XVID") diff --git a/src/anomalib/data/datamodules/video/ucsd_ped.py b/src/anomalib/data/datamodules/video/ucsd_ped.py index e08bfd1ca6..e4bd9cf15e 100644 --- a/src/anomalib/data/datamodules/video/ucsd_ped.py +++ b/src/anomalib/data/datamodules/video/ucsd_ped.py @@ -1,4 +1,9 @@ -"""UCSD Pedestrian Data Module.""" +"""UCSD Pedestrian Data Module. + +This module provides a PyTorch Lightning data module for the UCSD Pedestrian dataset. +The dataset consists of surveillance videos of pedestrians, with anomalies defined as +non-pedestrian entities like cars, bikes, etc. +""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -22,20 +27,35 @@ class UCSDped(AnomalibVideoDataModule): - """UCSDped DataModule class. + """UCSD Pedestrian DataModule Class. Args: - root (Path | str): Path to the root of the dataset - category (str): Sub-category of the dataset, e.g. "UCSDped1" or "UCSDped2" - clip_length_in_frames (int, optional): Number of video frames in each clip. - frames_between_clips (int, optional): Number of frames between each consecutive video clip. - target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval - train_batch_size (int, optional): Training batch size. Defaults to 32. - eval_batch_size (int, optional): Test batch size. Defaults to 32. - num_workers (int, optional): Number of workers. Defaults to 8. - val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. - val_split_ratio (float): Fraction of train or test images that will be reserved for validation. - seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. + root (Path | str): Path to the root directory where the dataset will be + downloaded and extracted. Defaults to ``"./datasets/ucsd"``. + category (str): Dataset subcategory. Must be either ``"UCSDped1"`` or + ``"UCSDped2"``. Defaults to ``"UCSDped2"``. + clip_length_in_frames (int): Number of frames in each video clip. + Defaults to ``2``. + frames_between_clips (int): Number of frames between consecutive video + clips. Defaults to ``10``. + target_frame (VideoTargetFrame): Specifies which frame in the clip should + be used for ground truth. Defaults to ``VideoTargetFrame.LAST``. + train_batch_size (int): Batch size for training. Defaults to ``8``. + eval_batch_size (int): Batch size for validation and testing. + Defaults to ``8``. + num_workers (int): Number of workers for data loading. Defaults to ``8``. + val_split_mode (ValSplitMode): Determines how validation set is created. + Defaults to ``ValSplitMode.SAME_AS_TEST``. + val_split_ratio (float): Fraction of data to use for validation. + Must be between 0 and 1. Defaults to ``0.5``. + seed (int | None): Random seed for reproducibility. Defaults to ``None``. 
+
+    Example:
+        >>> datamodule = UCSDped(root="./datasets/ucsd")
+        >>> datamodule.prepare_data()  # Download the dataset if needed
+        >>> datamodule.setup()  # Create the train/val/test datasets
+        >>> train_loader = datamodule.train_dataloader()
+        >>> val_loader = datamodule.val_dataloader()
+        >>> test_loader = datamodule.test_dataloader()
     """
 
     def __init__(
@@ -69,6 +89,11 @@ def __init__(
         self.target_frame = VideoTargetFrame(target_frame)
 
     def _setup(self, _stage: str | None = None) -> None:
+        """Set up train and test datasets.
+
+        Args:
+            _stage (str | None): Stage for Lightning. Can be "fit" or "test".
+        """
         self.train_data = UCSDpedDataset(
             clip_length_in_frames=self.clip_length_in_frames,
             frames_between_clips=self.frames_between_clips,
@@ -88,7 +113,11 @@ def _setup(self, _stage: str | None = None) -> None:
         )
 
     def prepare_data(self) -> None:
-        """Download the dataset if not available."""
+        """Download and extract the dataset if not already available.
+
+        The method checks if the dataset directory exists. If not, it downloads
+        and extracts the dataset to the specified root directory.
+        """
         if (self.root / self.category).is_dir():
             logger.info("Found the dataset.")
         else:
diff --git a/src/anomalib/data/datasets/__init__.py b/src/anomalib/data/datasets/__init__.py
index 32e3995ea5..7011b7373a 100644
--- a/src/anomalib/data/datasets/__init__.py
+++ b/src/anomalib/data/datasets/__init__.py
@@ -1,4 +1,37 @@
-"""Torch Dataset Implementations of Anomalib Datasets."""
+"""PyTorch Dataset implementations for anomaly detection.
+
+This module provides dataset implementations for various anomaly detection tasks:
+
+Base Classes:
+    - ``AnomalibDataset``: Base class for all Anomalib datasets
+    - ``AnomalibDepthDataset``: Base class for 3D/depth datasets
+    - ``AnomalibVideoDataset``: Base class for video datasets
+
+Depth Datasets:
+    - ``Folder3DDataset``: Custom RGB-D dataset from folder structure
+    - ``MVTec3DDataset``: MVTec 3D AD dataset with industrial objects
+
+Image Datasets:
+    - ``BTechDataset``: BTech dataset containing industrial objects
+    - ``DatumaroDataset``: Dataset in Datumaro format (Intel Geti™ export)
+    - ``FolderDataset``: Custom dataset from folder structure
+    - ``KolektorDataset``: Kolektor surface defect dataset
+    - ``MVTecDataset``: MVTec AD dataset with industrial objects
+    - ``VisaDataset``: Visual Inspection of Surface Anomalies dataset
+
+Video Datasets:
+    - ``AvenueDataset``: CUHK Avenue dataset for abnormal event detection
+    - ``ShanghaiTechDataset``: ShanghaiTech Campus surveillance dataset
+    - ``UCSDpedDataset``: UCSD Pedestrian dataset for anomaly detection
+
+Example:
+    >>> from anomalib.data.datasets import MVTecDataset
+    >>> dataset = MVTecDataset(
+    ...     root="./datasets/MVTec",
+    ...     category="bottle",
+    ...     split="train"
+    ... )
+"""
 
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
diff --git a/src/anomalib/data/datasets/base/__init__.py b/src/anomalib/data/datasets/base/__init__.py
index b39af32f4c..5a72b52378 100644
--- a/src/anomalib/data/datasets/base/__init__.py
+++ b/src/anomalib/data/datasets/base/__init__.py
@@ -1,4 +1,15 @@
-"""Base Classes for Torch Datasets."""
+"""Base Classes for Torch Datasets.
+
+This module contains the base dataset classes used in anomalib for different data
+modalities:
+
+- ``AnomalibDataset``: Base class for image datasets
+- ``AnomalibVideoDataset``: Base class for video datasets
+- ``AnomalibDepthDataset``: Base class for depth/3D datasets
+
+These classes extend PyTorch's Dataset class with additional functionality specific
+to anomaly detection tasks.
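+
+Example:
+    Import the base classes (a minimal sketch; these classes are abstract and
+    are meant to be subclassed rather than instantiated directly)::
+
+        >>> from anomalib.data.datasets.base import (
+        ...     AnomalibDataset,
+        ...     AnomalibDepthDataset,
+        ...     AnomalibVideoDataset,
+        ... )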
+""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/data/datasets/base/depth.py b/src/anomalib/data/datasets/base/depth.py index 5dd4683b6c..d15bcdac1b 100644 --- a/src/anomalib/data/datasets/base/depth.py +++ b/src/anomalib/data/datasets/base/depth.py @@ -1,4 +1,8 @@ -"""Base Depth Dataset.""" +"""Base Depth Dataset. + +This module implements the base depth dataset class for anomaly detection tasks that +use RGB-D (RGB + Depth) data. +""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -22,9 +26,22 @@ class AnomalibDepthDataset(AnomalibDataset, ABC): """Base depth anomalib dataset class. + This class extends ``AnomalibDataset`` to handle RGB-D data for anomaly + detection tasks. It supports both classification and segmentation tasks. + Args: - transform (Transform, optional): Transforms that should be applied to the input images. + transform (Transform | None, optional): Transforms to be applied to the + input images and depth maps. If ``None``, no transforms are applied. Defaults to ``None``. + + Example: + >>> from anomalib.data.datasets import AnomalibDepthDataset + >>> dataset = AnomalibDepthDataset(transform=None) + >>> item = dataset[0] + >>> item.image.shape + torch.Size([3, H, W]) + >>> item.depth_map.shape + torch.Size([1, H, W]) """ def __init__(self, transform: Transform | None = None) -> None: @@ -33,13 +50,24 @@ def __init__(self, transform: Transform | None = None) -> None: self.transform = transform def __getitem__(self, index: int) -> DepthItem: - """Return rgb image, depth image and mask. + """Get dataset item for the given index. Args: - index (int): Index of the item to be returned. + index (int): Index of the item to retrieve. Returns: - dict[str, str | torch.Tensor]: Dictionary containing the image, depth image and mask. + DepthItem: Dataset item containing the following fields: + - image (Tensor): RGB image + - depth_map (Tensor): Depth map + - gt_mask (Tensor | None): Ground truth mask for segmentation + - gt_label (int): Ground truth label (0: normal, 1: anomalous) + - image_path (str): Path to the RGB image + - depth_path (str): Path to the depth map + - mask_path (str | None): Path to the ground truth mask + + Raises: + ValueError: If the task type is neither classification nor + segmentation. """ image_path = self.samples.iloc[index].image_path mask_path = self.samples.iloc[index].mask_path @@ -83,5 +111,9 @@ def __getitem__(self, index: int) -> DepthItem: @property def collate_fn(self) -> Callable: - """Return the collate function for depth batches.""" + """Get the collate function for creating depth batches. + + Returns: + Callable: Collate function that creates ``DepthBatch`` objects. + """ return DepthBatch.collate diff --git a/src/anomalib/data/datasets/base/image.py b/src/anomalib/data/datasets/base/image.py index 9bc8c45e74..4c5267ea2c 100644 --- a/src/anomalib/data/datasets/base/image.py +++ b/src/anomalib/data/datasets/base/image.py @@ -1,4 +1,30 @@ -"""Anomalib dataset base class.""" +"""Anomalib dataset base class. + +This module provides the base dataset class for Anomalib datasets. The dataset is based on a +dataframe that contains the information needed by the dataloader to load each dataset item +into memory. + +The samples dataframe must be set from the subclass using the setter of the ``samples`` +property. 
+ +The DataFrame must include at least the following columns: + - ``split`` (str): The subset to which the dataset item is assigned (e.g., 'train', + 'test'). + - ``image_path`` (str): Path to the file system location where the image is stored. + - ``label_index`` (int): Index of the anomaly label, typically 0 for 'normal' and 1 for + 'anomalous'. + - ``mask_path`` (str, optional): Path to the ground truth masks (for anomalous images + only). Required if task is 'segmentation'. + +Example DataFrame: + >>> df = pd.DataFrame({ + ... 'image_path': ['path/to/image.png'], + ... 'label': ['anomalous'], + ... 'label_index': [1], + ... 'mask_path': ['path/to/mask.png'], + ... 'split': ['train'] + ... }) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -26,34 +52,27 @@ class AnomalibDataset(Dataset, ABC): - """Anomalib dataset. + """Base class for Anomalib datasets. - The dataset is based on a dataframe that contains the information needed by the dataloader to load each of - the dataset items into memory. + The dataset is designed to work with image-based anomaly detection tasks. It supports + both classification and segmentation tasks. - The samples dataframe must be set from the subclass using the setter of the `samples` property. - - The DataFrame must, at least, include the following columns: - - `split` (str): The subset to which the dataset item is assigned (e.g., 'train', 'test'). - - `image_path` (str): Path to the file system location where the image is stored. - - `label_index` (int): Index of the anomaly label, typically 0 for 'normal' and 1 for 'anomalous'. - - `mask_path` (str, optional): Path to the ground truth masks (for the anomalous images only). - Required if task is 'segmentation'. + Args: + transform (Transform | None, optional): Transforms to be applied to the input images. + Defaults to ``None``. - Example DataFrame: - +---+-------------------+-----------+-------------+------------------+-------+ - | | image_path | label | label_index | mask_path | split | - +---+-------------------+-----------+-------------+------------------+-------+ - | 0 | path/to/image.png | anomalous | 1 | path/to/mask.png | train | - +---+-------------------+-----------+-------------+------------------+-------+ + Example: + >>> from torchvision.transforms.v2 import Resize + >>> dataset = AnomalibDataset(transform=Resize((256, 256))) + >>> len(dataset) # Get dataset length + 100 + >>> item = dataset[0] # Get first item + >>> item.image.shape + torch.Size([3, 256, 256]) Note: - The example above is illustrative and may need to be adjusted based on the specific dataset structure. - - Args: - task (str): Task type, either 'classification' or 'segmentation' - transform (Transform, optional): Transforms that should be applied to the input images. - Defaults to ``None``. + This is an abstract base class. Subclasses must implement the required methods and + set the samples DataFrame. """ def __init__(self, transform: Transform | None = None) -> None: @@ -64,7 +83,17 @@ def __init__(self, transform: Transform | None = None) -> None: @property def name(self) -> str: - """Name of the dataset.""" + """Get the name of the dataset. + + Returns: + str: Name of the dataset derived from the class name, with 'Dataset' suffix + removed if present. 
+ + Example: + >>> dataset = AnomalibDataset() + >>> dataset.name + 'Anomalib' + """ class_name = self.__class__.__name__ # Remove the `_dataset` suffix from the class name @@ -74,16 +103,35 @@ def name(self) -> str: return class_name def __len__(self) -> int: - """Get length of the dataset.""" + """Get length of the dataset. + + Returns: + int: Number of samples in the dataset. + + Raises: + RuntimeError: If samples DataFrame is not set. + """ return len(self.samples) def subsample(self, indices: Sequence[int], inplace: bool = False) -> "AnomalibDataset": - """Subsamples the dataset at the provided indices. + """Create a subset of the dataset using the provided indices. Args: indices (Sequence[int]): Indices at which the dataset is to be subsampled. - inplace (bool): When true, the subsampling will be performed on the instance itself. - Defaults to ``False``. + inplace (bool, optional): When true, modify the instance itself. Defaults to + ``False``. + + Returns: + AnomalibDataset: Subsampled dataset. + + Raises: + ValueError: If duplicate indices are provided. + + Example: + >>> dataset = AnomalibDataset() + >>> subset = dataset.subsample([0, 1, 2]) + >>> len(subset) + 3 """ if len(set(indices)) != len(indices): msg = "No duplicates allowed in indices." @@ -94,21 +142,41 @@ def subsample(self, indices: Sequence[int], inplace: bool = False) -> "AnomalibD @property def samples(self) -> DataFrame: - """Get the samples dataframe.""" + """Get the samples DataFrame. + + Returns: + DataFrame: DataFrame containing dataset samples. + + Raises: + RuntimeError: If samples DataFrame has not been set. + """ if self._samples is None: msg = ( - "Dataset does not have a samples dataframe. Ensure that a dataframe has been assigned to " - "`dataset.samples`." + "Dataset does not have a samples dataframe. Ensure that a dataframe has " + "been assigned to `dataset.samples`." ) raise RuntimeError(msg) return self._samples @samples.setter def samples(self, samples: DataFrame) -> None: - """Overwrite the samples with a new dataframe. + """Set the samples DataFrame. Args: - samples (DataFrame): DataFrame with new samples. + samples (DataFrame): DataFrame containing dataset samples. + + Raises: + TypeError: If samples is not a pandas DataFrame. + ValueError: If required columns are missing. + FileNotFoundError: If any image paths do not exist. + + Example: + >>> df = pd.DataFrame({ + ... 'image_path': ['image.png'], + ... 'split': ['train'] + ... }) + >>> dataset = AnomalibDataset() + >>> dataset.samples = df """ # validate the passed samples by checking the if not isinstance(samples, DataFrame): @@ -127,37 +195,69 @@ def samples(self, samples: DataFrame) -> None: @property def category(self) -> str | None: - """Get the category of the dataset.""" + """Get the category of the dataset. + + Returns: + str | None: Dataset category if set, else None. + """ return self._category @category.setter def category(self, category: str) -> None: - """Set the category of the dataset.""" + """Set the category of the dataset. + + Args: + category (str): Category to assign to the dataset. + """ self._category = category @property def has_normal(self) -> bool: - """Check if the dataset contains any normal samples.""" + """Check if the dataset contains normal samples. + + Returns: + bool: True if dataset contains normal samples, False otherwise. 
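+
+        Example:
+            A sketch assuming ``samples`` has been populated by a subclass::
+
+                >>> dataset.has_normal  # doctest: +SKIP
+                True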
+ """ return LabelName.NORMAL in list(self.samples.label_index) @property def has_anomalous(self) -> bool: - """Check if the dataset contains any anomalous samples.""" + """Check if the dataset contains anomalous samples. + + Returns: + bool: True if dataset contains anomalous samples, False otherwise. + """ return LabelName.ABNORMAL in list(self.samples.label_index) @property def task(self) -> TaskType: - """Infer the task type from the dataset.""" + """Get the task type from the dataset. + + Returns: + TaskType: Type of task (classification or segmentation). + + Raises: + ValueError: If task type is unknown. + """ return TaskType(self.samples.attrs["task"]) def __getitem__(self, index: int) -> DatasetItem: - """Get dataset item for the index ``index``. + """Get dataset item for the given index. Args: index (int): Index to get the item. Returns: - DatasetItem: DatasetItem instance containing image and ground truth (if available). + DatasetItem: Dataset item containing image and ground truth (if available). + + Raises: + ValueError: If task type is unknown. + + Example: + >>> dataset = AnomalibDataset() + >>> item = dataset[0] + >>> isinstance(item.image, torch.Tensor) + True """ image_path = self.samples.iloc[index].image_path mask_path = self.samples.iloc[index].mask_path @@ -198,6 +298,14 @@ def __add__(self, other_dataset: "AnomalibDataset") -> "AnomalibDataset": Returns: AnomalibDataset: Concatenated dataset. + + Raises: + TypeError: If datasets are not of the same type. + + Example: + >>> dataset1 = AnomalibDataset() + >>> dataset2 = AnomalibDataset() + >>> combined = dataset1 + dataset2 """ if not isinstance(other_dataset, self.__class__): msg = "Cannot concatenate datasets that are not of the same type." @@ -208,9 +316,13 @@ def __add__(self, other_dataset: "AnomalibDataset") -> "AnomalibDataset": @property def collate_fn(self) -> Callable: - """Get the collate function for the items returned by this dataset. + """Get the collate function for batching dataset items. + + Returns: + Callable: Collate function from ImageBatch. - By default, the dataset is an image dataset, so we will return the ImageBatch's collate function. - Other dataset types should override this property. + Note: + By default, this returns ImageBatch's collate function. Override this property + for other dataset types. """ return ImageBatch.collate diff --git a/src/anomalib/data/datasets/base/video.py b/src/anomalib/data/datasets/base/video.py index 4b8366aae4..2e675aa717 100644 --- a/src/anomalib/data/datasets/base/video.py +++ b/src/anomalib/data/datasets/base/video.py @@ -1,4 +1,21 @@ -"""Base Torch Video Dataset.""" +"""Base Torch Video Dataset. + +This module implements the base video dataset class for anomaly detection tasks that +use video data. The dataset is designed to work with video clips and supports both +classification and segmentation tasks. + +Example: + >>> from anomalib.data.datasets import AnomalibVideoDataset + >>> dataset = AnomalibVideoDataset( + ... clip_length_in_frames=8, + ... frames_between_clips=1, + ... transform=None, + ... target_frame="last" + ... ) + >>> item = dataset[0] + >>> item.image.shape + torch.Size([C, H, W]) +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -22,7 +39,14 @@ class VideoTargetFrame(str, Enum): """Target frame for a video-clip. - Used in multi-frame models to determine which frame's ground truth information will be used. + Used in multi-frame models to determine which frame's ground truth information + will be used. 
+
+    Attributes:
+        FIRST: Use the first frame in the clip as target
+        LAST: Use the last frame in the clip as target
+        MID: Use the middle frame in the clip as target
+        ALL: Use all frames in the clip as target
     """
 
     FIRST = "first"
@@ -34,13 +58,30 @@ class VideoTargetFrame(str, Enum):
 class AnomalibVideoDataset(AnomalibDataset, ABC):
     """Base video anomalib dataset class.
 
+    This class extends ``AnomalibDataset`` to handle video data for anomaly
+    detection tasks. It supports both classification and segmentation tasks.
+
     Args:
         clip_length_in_frames (int): Number of video frames in each clip.
-        frames_between_clips (int): Number of frames between each consecutive video clip.
-        transform (Transform, optional): Transforms that should be applied to the input clips.
-            Defaults to ``None``.
-        target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval.
+        frames_between_clips (int): Number of frames between each consecutive
+            video clip.
+        transform (Transform | None, optional): Transforms to be applied to the
+            input clips. Defaults to ``None``.
+        target_frame (VideoTargetFrame, optional): Specifies the target frame in
+            the video clip, used for ground truth retrieval.
             Defaults to ``VideoTargetFrame.LAST``.
+
+    Example:
+        >>> from torchvision.transforms.v2 import Resize
+        >>> dataset = AnomalibVideoDataset(
+        ...     clip_length_in_frames=8,
+        ...     frames_between_clips=1,
+        ...     transform=Resize((256, 256)),
+        ...     target_frame="last"
+        ... )
+        >>> item = dataset[0]
+        >>> item.image.shape
+        torch.Size([C, H, W])
     """
 
     def __init__(
@@ -62,7 +103,14 @@ def __init__(
         self.target_frame = target_frame
 
     def __len__(self) -> int:
-        """Get length of the dataset."""
+        """Get length of the dataset.
+
+        Returns:
+            int: Number of clips in the dataset.
+
+        Raises:
+            TypeError: If ``self.indexer`` is not an instance of ``ClipsIndexer``.
+        """
         if not isinstance(self.indexer, ClipsIndexer):
             msg = "self.indexer must be an instance of ClipsIndexer."
             raise TypeError(msg)
@@ -70,7 +118,11 @@
     @property
     def samples(self) -> DataFrame:
-        """Get the samples dataframe."""
+        """Get the samples dataframe.
+
+        Returns:
+            DataFrame: DataFrame containing dataset samples.
+        """
         return super().samples
 
     @samples.setter
@@ -89,7 +141,10 @@ def samples(self, samples: DataFrame) -> None:
     def _setup_clips(self) -> None:
         """Compute the video and frame indices of the subvideos.
 
-        Should be called after each change to self._samples
+        Should be called after each change to ``self._samples``.
+
+        Raises:
+            TypeError: If ``self.indexer_cls`` is not callable.
         """
         if not callable(self.indexer_cls):
             msg = "self.indexer_cls must be callable."
@@ -105,13 +160,13 @@
     def _select_targets(self, item: VideoItem) -> VideoItem:
         """Select the target frame from the clip.
 
         Args:
-            item (DatasetItem): Item containing the clip information.
+            item (VideoItem): Item containing the clip information.
+
+        Returns:
+            VideoItem: Selected item from the clip.
 
         Raises:
             ValueError: If the target frame is not one of the supported options.
-
-        Returns:
-            DatasetItem: Selected item from the clip.
         """
         if self.target_frame == VideoTargetFrame.FIRST:
             idx = 0
@@ -134,13 +189,17 @@
     def __getitem__(self, index: int) -> VideoItem:
-        """Get the dataset item for the index ``index``.
+        """Get the dataset item for the index.
 
         Args:
            index (int): Index of the item to be returned.
Returns: - DatasetItem: Dictionary containing the mask, clip and file system information. + VideoItem: Dataset item containing the mask, clip and file system + information. + + Raises: + TypeError: If ``self.indexer`` is not an instance of ``ClipsIndexer``. """ if not isinstance(self.indexer, ClipsIndexer): msg = "self.indexer must be an instance of ClipsIndexer." @@ -169,5 +228,9 @@ def __getitem__(self, index: int) -> VideoItem: @property def collate_fn(self) -> Callable: - """Return the collate function for video batches.""" + """Return the collate function for video batches. + + Returns: + Callable: Collate function for creating video batches. + """ return VideoBatch.collate diff --git a/src/anomalib/data/datasets/depth/__init__.py b/src/anomalib/data/datasets/depth/__init__.py index 7d7c5361ee..f77d0ead0d 100644 --- a/src/anomalib/data/datasets/depth/__init__.py +++ b/src/anomalib/data/datasets/depth/__init__.py @@ -1,4 +1,26 @@ -"""Torch Dataset Implementations of Anomalib Depth Datasets.""" +"""Torch Dataset Implementations of Anomalib Depth Datasets. + +This module provides dataset implementations for working with RGB-D (depth) data in +anomaly detection tasks. The following datasets are available: + +- ``Folder3DDataset``: Custom dataset for loading RGB-D data from a folder structure +- ``MVTec3DDataset``: Implementation of the MVTec 3D-AD dataset + +Example: + >>> from anomalib.data.datasets import Folder3DDataset + >>> dataset = Folder3DDataset( + ... name="custom", + ... root="datasets/custom", + ... normal_dir="normal", + ... normal_depth_dir="normal_depth" + ... ) + + >>> from anomalib.data.datasets import MVTec3DDataset + >>> dataset = MVTec3DDataset( + ... root="datasets/MVTec3D", + ... category="bagel" + ... ) +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/data/datasets/depth/folder_3d.py b/src/anomalib/data/datasets/depth/folder_3d.py index 0e5247c7bc..5e5d15b3b8 100644 --- a/src/anomalib/data/datasets/depth/folder_3d.py +++ b/src/anomalib/data/datasets/depth/folder_3d.py @@ -1,6 +1,29 @@ -"""Custom Folder Dataset. - -This script creates a custom dataset from a folder. +"""Custom Folder Dataset for 3D anomaly detection. + +This module provides a custom dataset class that loads RGB-D data from a folder +structure. The dataset supports both classification and segmentation tasks. + +The folder structure should contain RGB images and their corresponding depth maps. +The dataset can be configured with separate directories for: + +- Normal training samples +- Normal test samples (optional) +- Abnormal test samples (optional) +- Mask annotations (optional, for segmentation) +- Depth maps for each image type + +Example: + >>> from pathlib import Path + >>> from anomalib.data.datasets import Folder3DDataset + >>> dataset = Folder3DDataset( + ... name="custom", + ... root="datasets/custom", + ... normal_dir="normal", + ... abnormal_dir="abnormal", + ... normal_depth_dir="normal_depth", + ... abnormal_depth_dir="abnormal_depth", + ... mask_dir="ground_truth" + ... ) """ # Copyright (C) 2024 Intel Corporation @@ -18,38 +41,43 @@ class Folder3DDataset(AnomalibDepthDataset): - """Folder dataset. + """Dataset class for loading RGB-D data from a custom folder structure. Args: - name (str): Name of the dataset. - transform (Transform): Transforms that should be applied to the input images. - normal_dir (str | Path): Path to the directory containing normal images. - root (str | Path | None): Root folder of the dataset. 
- Defaults to ``None``. - abnormal_dir (str | Path | None, optional): Path to the directory containing abnormal images. - Defaults to ``None``. - normal_test_dir (str | Path | None, optional): Path to the directory containing - normal images for the test dataset. - Defaults to ``None``. - mask_dir (str | Path | None, optional): Path to the directory containing - the mask annotations. - Defaults to ``None``. - normal_depth_dir (str | Path | None, optional): Path to the directory containing - normal depth images for the test dataset. Normal test depth images will be a split of `normal_dir` + name (str): Name of the dataset + normal_dir (str | Path): Path to directory containing normal images + root (str | Path | None, optional): Root directory of the dataset. Defaults to ``None``. - abnormal_depth_dir (str | Path | None, optional): Path to the directory containing abnormal depth images for - the test dataset. - Defaults to ``None``. - normal_test_depth_dir (str | Path | None, optional): Path to the directory containing - normal depth images for the test dataset. Normal test images will be a split of `normal_dir` if `None`. - Defaults to ``None``. - transform (Transform, optional): Transforms that should be applied to the input images. - Defaults to ``None``. - split (str | Split | None): Fixed subset split that follows from folder structure on file system. - Choose from [Split.FULL, Split.TRAIN, Split.TEST] - Defaults to ``None``. - extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the directory. + abnormal_dir (str | Path | None, optional): Path to directory containing + abnormal images. Defaults to ``None``. + normal_test_dir (str | Path | None, optional): Path to directory + containing normal test images. If not provided, normal test images + will be split from ``normal_dir``. Defaults to ``None``. + mask_dir (str | Path | None, optional): Path to directory containing + ground truth masks. Required for segmentation. Defaults to ``None``. + normal_depth_dir (str | Path | None, optional): Path to directory + containing depth maps for normal images. Defaults to ``None``. + abnormal_depth_dir (str | Path | None, optional): Path to directory + containing depth maps for abnormal images. Defaults to ``None``. + normal_test_depth_dir (str | Path | None, optional): Path to directory + containing depth maps for normal test images. Defaults to ``None``. + transform (Transform | None, optional): Transforms to apply to the images. Defaults to ``None``. + split (str | Split | None, optional): Dataset split to load. + One of ``["train", "test", "full"]``. Defaults to ``None``. + extensions (tuple[str, ...] | None, optional): Image file extensions to + include. Defaults to ``None``. + + Example: + >>> dataset = Folder3DDataset( + ... name="custom", + ... root="./datasets/custom", + ... normal_dir="train/good", + ... abnormal_dir="test/defect", + ... mask_dir="test/ground_truth", + ... normal_depth_dir="train/good_depth", + ... abnormal_depth_dir="test/defect_depth" + ... ) """ def __init__( @@ -96,9 +124,10 @@ def __init__( @property def name(self) -> str: - """Name of the dataset. + """Get dataset name. - Folder3D dataset overrides the name property to provide a custom name. + Returns: + str: Name of the dataset """ return self._name @@ -115,35 +144,38 @@ def make_folder3d_dataset( split: str | Split | None = None, extensions: tuple[str, ...] | None = None, ) -> DataFrame: - """Make Folder Dataset. + """Create a dataset by collecting files from a folder structure. 
+ + The function creates a DataFrame containing paths to RGB images, depth maps, + and masks (if available) along with their corresponding labels. Args: - normal_dir (str | Path): Path to the directory containing normal images. - root (str | Path | None): Path to the root directory of the dataset. - Defaults to ``None``. - abnormal_dir (str | Path | None, optional): Path to the directory containing abnormal images. - Defaults to ``None``. - normal_test_dir (str | Path | None, optional): Path to the directory containing normal images for the test - dataset. Normal test images will be a split of `normal_dir` if `None`. - Defaults to ``None``. - mask_dir (str | Path | None, optional): Path to the directory containing the mask annotations. - Defaults to ``None``. - normal_depth_dir (str | Path | None, optional): Path to the directory containing - normal depth images for the test dataset. Normal test depth images will be a split of `normal_dir` - Defaults to ``None``. - abnormal_depth_dir (str | Path | None, optional): Path to the directory containing abnormal depth images for - the test dataset. - Defaults to ``None``. - normal_test_depth_dir (str | Path | None, optional): Path to the directory containing normal depth images for - the test dataset. Normal test images will be a split of `normal_dir` if `None`. - Defaults to ``None``. - split (str | Split | None, optional): Dataset split (ie., Split.FULL, Split.TRAIN or Split.TEST). - Defaults to ``None``. - extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the directory. + normal_dir (str | Path): Directory containing normal images + root (str | Path | None, optional): Root directory. Defaults to ``None``. + abnormal_dir (str | Path | None, optional): Directory containing abnormal + images. Defaults to ``None``. + normal_test_dir (str | Path | None, optional): Directory containing + normal test images. Defaults to ``None``. + mask_dir (str | Path | None, optional): Directory containing ground truth + masks. Defaults to ``None``. + normal_depth_dir (str | Path | None, optional): Directory containing + depth maps for normal images. Defaults to ``None``. + abnormal_depth_dir (str | Path | None, optional): Directory containing + depth maps for abnormal images. Defaults to ``None``. + normal_test_depth_dir (str | Path | None, optional): Directory containing + depth maps for normal test images. Defaults to ``None``. + split (str | Split | None, optional): Dataset split to return. Defaults to ``None``. + extensions (tuple[str, ...] | None, optional): Image file extensions to + include. Defaults to ``None``. Returns: - DataFrame: an output dataframe containing samples for the requested split (ie., train or test) + DataFrame: Dataset samples with columns for paths and labels + + Raises: + ValueError: If ``normal_dir`` is not a directory + FileNotFoundError: If depth maps or mask files are missing + MisMatchError: If depth maps don't match their RGB images """ normal_dir = validate_and_resolve_path(normal_dir, root) abnormal_dir = validate_and_resolve_path(abnormal_dir, root) if abnormal_dir else None diff --git a/src/anomalib/data/datasets/depth/mvtec_3d.py b/src/anomalib/data/datasets/depth/mvtec_3d.py index 6dd8ed3752..52873a0e8d 100644 --- a/src/anomalib/data/datasets/depth/mvtec_3d.py +++ b/src/anomalib/data/datasets/depth/mvtec_3d.py @@ -1,19 +1,20 @@ -"""MVTec 3D-AD Datamodule (CC BY-NC-SA 4.0). +"""MVTec 3D-AD Datamodule. 
-Description: - This script contains PyTorch Dataset, Dataloader and PyTorch Lightning DataModule for the MVTec 3D-AD dataset. - If the dataset is not on the file system, the script downloads and extracts the dataset and create PyTorch data - objects. +This module provides PyTorch Dataset, Dataloader and PyTorch Lightning DataModule for +the MVTec 3D-AD dataset. If the dataset is not available locally, it will be +downloaded and extracted automatically. License: - MVTec 3D-AD dataset is released under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International - License (CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/). + MVTec 3D-AD dataset is released under the Creative Commons + Attribution-NonCommercial-ShareAlike 4.0 International License (CC BY-NC-SA 4.0) + https://creativecommons.org/licenses/by-nc-sa/4.0/ Reference: - - Paul Bergmann, Xin Jin, David Sattlegger, Carsten Steger: The MVTec 3D-AD Dataset for Unsupervised 3D Anomaly - Detection and Localization in: Proceedings of the 17th International Joint Conference on Computer Vision, - Imaging and Computer Graphics Theory and Applications - Volume 5: VISAPP, 202-213, 2022, DOI: 10.5220/ - 0010865000003124. + Paul Bergmann, Xin Jin, David Sattlegger, Carsten Steger: The MVTec 3D-AD + Dataset for Unsupervised 3D Anomaly Detection and Localization. In: Proceedings + of the 17th International Joint Conference on Computer Vision, Imaging and + Computer Graphics Theory and Applications - Volume 5: VISAPP, 202-213, 2022 + DOI: 10.5220/0010865000003124 """ # Copyright (C) 2024 Intel Corporation @@ -30,21 +31,40 @@ from anomalib.data.utils import LabelName, Split, validate_path IMG_EXTENSIONS = [".png", ".PNG", ".tiff"] -CATEGORIES = ("bagel", "cable_gland", "carrot", "cookie", "dowel", "foam", "peach", "potato", "rope", "tire") +CATEGORIES = ( + "bagel", + "cable_gland", + "carrot", + "cookie", + "dowel", + "foam", + "peach", + "potato", + "rope", + "tire", +) class MVTec3DDataset(AnomalibDepthDataset): """MVTec 3D dataset class. Args: - root (Path | str): Path to the root of the dataset + root (Path | str): Path to the root of the dataset. Defaults to ``"./datasets/MVTec3D"``. - category (str): Sub-category of the dataset, e.g. 'bagel' + category (str): Category name, e.g. ``"bagel"``. Defaults to ``"bagel"``. - transform (Transform, optional): Transforms that should be applied to the input images. - Defaults to ``None``. - split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST + transform (Transform, optional): Transforms applied to input images. Defaults to ``None``. + split (str | Split | None): Dataset split - usually ``Split.TRAIN`` or + ``Split.TEST``. Defaults to ``None``. + + Example: + >>> from pathlib import Path + >>> dataset = MVTec3DDataset( + ... root=Path("./datasets/MVTec3D"), + ... category="bagel", + ... split="train" + ... ) """ def __init__( @@ -58,7 +78,11 @@ def __init__( self.root_category = Path(root) / Path(category) self.split = split - self.samples = make_mvtec_3d_dataset(self.root_category, split=self.split, extensions=IMG_EXTENSIONS) + self.samples = make_mvtec_3d_dataset( + self.root_category, + split=self.split, + extensions=IMG_EXTENSIONS, + ) def make_mvtec_3d_dataset( @@ -66,45 +90,44 @@ def make_mvtec_3d_dataset( split: str | Split | None = None, extensions: Sequence[str] | None = None, ) -> DataFrame: - """Create MVTec 3D-AD samples by parsing the MVTec AD data file structure. 
+ """Create MVTec 3D-AD samples by parsing the data directory structure. - The files are expected to follow this structure: - - `path/to/dataset/split/category/image_filename.png` - - `path/to/dataset/ground_truth/category/mask_filename.png` + The files are expected to follow this structure:: - This function creates a DataFrame to store the parsed information. The DataFrame follows this format: + path/to/dataset/split/category/image_filename.png + path/to/dataset/ground_truth/category/mask_filename.png - +---+---------------+-------+---------+---------------+---------------------------------------+-------------+ - | | path | split | label | image_path | mask_path | label_index | - +---+---------------+-------+---------+---------------+---------------------------------------+-------------+ - | 0 | datasets/name | test | defect | filename.png | ground_truth/defect/filename_mask.png | 1 | - +---+---------------+-------+---------+---------------+---------------------------------------+-------------+ + The function creates a DataFrame with the following format:: + + +---+---------------+-------+---------+---------------+--------------------+ + | | path | split | label | image_path | mask_path | + +---+---------------+-------+---------+---------------+--------------------+ + | 0 | datasets/name | test | defect | filename.png | defect/mask.png | + +---+---------------+-------+---------+---------------+--------------------+ Args: - root (Path): Path to the dataset. - split (str | Split | None, optional): Dataset split (e.g., 'train' or 'test'). - Defaults to ``None``. - extensions (Sequence[str] | None, optional): List of file extensions to be included in the dataset. + root (Path | str): Path to the dataset root directory. + split (str | Split | None, optional): Dataset split (e.g., ``"train"`` or + ``"test"``). Defaults to ``None``. + extensions (Sequence[str] | None, optional): List of valid file extensions. Defaults to ``None``. - Examples: - The following example shows how to get training samples from the MVTec 3D-AD 'bagel' category: + Returns: + DataFrame: DataFrame containing the dataset samples. + Example: >>> from pathlib import Path - >>> root = Path('./MVTec3D') - >>> category = 'bagel' - >>> path = root / category - >>> print(path) - PosixPath('MVTec3D/bagel') - - >>> samples = create_mvtec_3d_ad_samples(path, split='train') - >>> print(samples.head()) - path split label image_path mask_path label_index - MVTec3D/bagel train good MVTec3D/bagel/train/good/rgb/105.png MVTec3D/bagel/ground_truth/good/gt/105.png 0 - MVTec3D/bagel train good MVTec3D/bagel/train/good/rgb/017.png MVTec3D/bagel/ground_truth/good/gt/017.png 0 - - Returns: - DataFrame: An output DataFrame containing the samples of the dataset. + >>> root = Path("./datasets/MVTec3D/bagel") + >>> samples = make_mvtec_3d_dataset(root, split="train") + >>> samples.head() + path split label image_path mask_path + 0 MVTec3D train good train/good/rgb/105.png gt/105.png + 1 MVTec3D train good train/good/rgb/017.png gt/017.png + + Raises: + RuntimeError: If no images are found in the root directory. + MisMatchError: If there is a mismatch between images and their + corresponding mask/depth files. 
""" if extensions is None: extensions = IMG_EXTENSIONS @@ -115,7 +138,10 @@ def make_mvtec_3d_dataset( msg = f"Found 0 images in {root}" raise RuntimeError(msg) - samples = DataFrame(samples_list, columns=["path", "split", "label", "type", "file_name"]) + samples = DataFrame( + samples_list, + columns=["path", "split", "label", "type", "file_name"], + ) # Modify image_path column by converting to absolute path samples.loc[(samples.type == "rgb"), "image_path"] = ( @@ -159,9 +185,11 @@ def make_mvtec_3d_dataset( .all() ) if not mismatch_masks: - msg = """Mismatch between anomalous images and ground truth masks. Make sure the mask files - in 'ground_truth' folder follow the same naming convention as the anomalous images in - the dataset (e.g. image: '000.png', mask: '000.png' or '000_mask.png').""" + msg = ( + "Mismatch between anomalous images and ground truth masks. Ensure mask " + "files in 'ground_truth' folder follow the same naming convention as " + "the anomalous images (e.g. image: '000.png', mask: '000.png')." + ) raise MisMatchError(msg) mismatch_depth = ( @@ -170,9 +198,11 @@ def make_mvtec_3d_dataset( .all() ) if not mismatch_depth: - msg = """Mismatch between anomalous images and depth images. Make sure the mask files in - 'xyz' folder follow the same naming convention as the anomalous images in the dataset - (e.g. image: '000.png', depth: '000.tiff').""" + msg = ( + "Mismatch between anomalous images and depth images. Ensure depth " + "files in 'xyz' folder follow the same naming convention as the " + "anomalous images (e.g. image: '000.png', depth: '000.tiff')." + ) raise MisMatchError(msg) # infer the task type diff --git a/src/anomalib/data/datasets/image/__init__.py b/src/anomalib/data/datasets/image/__init__.py index b7749dad18..e319b8a36f 100644 --- a/src/anomalib/data/datasets/image/__init__.py +++ b/src/anomalib/data/datasets/image/__init__.py @@ -1,4 +1,23 @@ -"""Torch Dataset Implementations of Anomalib Image Datasets.""" +"""PyTorch Dataset implementations for anomaly detection in images. + +This module provides dataset implementations for various image anomaly detection +datasets: + +- ``BTechDataset``: BTech dataset containing industrial objects +- ``DatumaroDataset``: Dataset in Datumaro format (Intel Geti™ export) +- ``FolderDataset``: Custom dataset from folder structure +- ``KolektorDataset``: Kolektor surface defect dataset +- ``MVTecDataset``: MVTec AD dataset with industrial objects +- ``VisaDataset``: Visual Inspection of Surface Anomalies dataset + +Example: + >>> from anomalib.data.datasets import MVTecDataset + >>> dataset = MVTecDataset( + ... root="./datasets/MVTec", + ... category="bottle", + ... split="train" + ... ) +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/data/datasets/image/btech.py b/src/anomalib/data/datasets/image/btech.py index 3078c99e12..04e4278491 100644 --- a/src/anomalib/data/datasets/image/btech.py +++ b/src/anomalib/data/datasets/image/btech.py @@ -1,9 +1,21 @@ """BTech Dataset. -This script contains PyTorch Dataset for the BTech dataset. - -If the dataset is not on the file system, the script downloads and -extracts the dataset and create PyTorch data objects. +This module provides PyTorch Dataset implementation for the BTech dataset. The +dataset will be downloaded and extracted automatically if not found locally. + +The dataset contains 3 categories of industrial objects with both normal and +anomalous samples. 
Each category includes RGB images and pixel-level ground truth
+masks for anomaly segmentation.
+
+License:
+    BTech dataset is released under the Creative Commons
+    Attribution-NonCommercial-ShareAlike 4.0 International License
+    (CC BY-NC-SA 4.0) https://creativecommons.org/licenses/by-nc-sa/4.0/
+
+Reference:
+    Mishra, P., Verk, R., Fornasier, D., Piciarelli, C., & Foresti, G. L.
+    (2021). VT-ADL: A Vision Transformer Network for Image Anomaly Detection
+    and Localization. In 2021 IEEE 30th International Symposium on Industrial
+    Electronics (ISIE).
 """
 
 # Copyright (C) 2024 Intel Corporation
@@ -22,41 +34,38 @@
 class BTechDataset(AnomalibDataset):
-    """Btech Dataset class.
+    """BTech dataset class.
+
+    Dataset class for loading and processing BTech dataset images. Supports both
+    classification and segmentation tasks.
 
     Args:
-        root: Path to the BTech dataset
-        category: Name of the BTech category.
-        transform (Transform, optional): Transforms that should be applied to the input images.
+        root (Path | str): Path to root directory containing the dataset.
+        category (str): Category name, must be one of ``CATEGORIES``.
+        transform (Transform | None, optional): Transforms to apply to the images.
            Defaults to ``None``.
-        split: 'train', 'val' or 'test'
-        create_validation_set: Create a validation subset in addition to the train and test subsets
+        split (str | Split | None, optional): Dataset split - usually
+            ``Split.TRAIN`` or ``Split.TEST``.
            Defaults to ``None``.
 
-    Examples:
-        >>> from anomalib.data.image.btech import BTechDataset
-        >>> from anomalib.data.utils.transforms import get_transforms
-        >>> transform = get_transforms(image_size=256)
+    Example:
+        >>> from pathlib import Path
+        >>> from anomalib.data.datasets import BTechDataset
         >>> dataset = BTechDataset(
-        ...     transform=transform,
-        ...     root='./datasets/BTech',
-        ...     category='01',
+        ...     root=Path("./datasets/btech"),
+        ...     category="01",
+        ...     split="train"
        ... )
        >>> dataset[0].keys()
-        >>> dataset.setup()
        dict_keys(['image'])

        >>> dataset.split = "test"
        >>> dataset[0].keys()
        dict_keys(['image', 'image_path', 'label'])

-        >>> dataset.split = "train"
-        >>> dataset[0].keys()
-        dict_keys(['image'])
-
+        >>> # For segmentation task
        >>> dataset.split = "test"
        >>> dataset[0].keys()
        dict_keys(['image_path', 'label', 'mask_path', 'image', 'mask'])
-
        >>> dataset[0]["image"].shape, dataset[0]["mask"].shape
        (torch.Size([3, 256, 256]), torch.Size([256, 256]))
    """
@@ -80,36 +89,35 @@ def make_btech_dataset(path: Path, split: str | Split | None = None) -> DataFram
 
     The files are expected to follow the structure:
 
-    .. code-block:: bash
+    .. code-block:: bash
 
-        path/to/dataset/split/category/image_filename.png
-        path/to/dataset/ground_truth/category/mask_filename.png
+        path/to/dataset/
+        ├── split/
+        │   └── category/
+        │       └── image_filename.png
+        └── ground_truth/
+            └── category/
+                └── mask_filename.png
 
     Args:
-        path (Path): Path to dataset
-        split (str | Split | None, optional): Dataset split (ie., either train or test).
-            Defaults to ``None``.
+        path (Path): Path to dataset directory.
+        split (str | Split | None, optional): Dataset split - usually
+            ``Split.TRAIN`` or ``Split.TEST``. Defaults to ``None``.
 
     Example:
-        The following example shows how to get training samples from BTech 01 category:
-
-        ..
code-block:: python - - >>> root = Path('./BTech') - >>> category = '01' - >>> path = root / category - >>> path - PosixPath('BTech/01') - - >>> samples = make_btech_dataset(path, split='train') - >>> samples.head() - path split label image_path mask_path label_index - 0 BTech/01 train 01 BTech/01/train/ok/105.bmp BTech/01/ground_truth/ok/105.png 0 - 1 BTech/01 train 01 BTech/01/train/ok/017.bmp BTech/01/ground_truth/ok/017.png 0 - ... + >>> from pathlib import Path + >>> path = Path("./datasets/btech/01") + >>> samples = make_btech_dataset(path, split="train") + >>> samples.head() + path split label image_path mask_path label_index + 0 BTech/01 train ok BTech/01/train/ok/105.bmp BTech/01/gt/ok/105.png 0 + 1 BTech/01 train ok BTech/01/train/ok/017.bmp BTech/01/gt/ok/017.png 0 Returns: - DataFrame: an output dataframe containing samples for the requested split (ie., train or test) + DataFrame: DataFrame containing samples for the requested split. + + Raises: + RuntimeError: If no images are found in the dataset directory. """ path = validate_path(path) diff --git a/src/anomalib/data/datasets/image/datumaro.py b/src/anomalib/data/datasets/image/datumaro.py index 9335f0a4b4..e6a65c0c54 100644 --- a/src/anomalib/data/datasets/image/datumaro.py +++ b/src/anomalib/data/datasets/image/datumaro.py @@ -1,6 +1,30 @@ """Dataloader for Datumaro format. -Note: This currently only works for annotations exported from Intel Geti™. +This module provides PyTorch Dataset implementation for loading images and +annotations in Datumaro format. Currently only supports annotations exported from +Intel Geti™. + +The dataset expects the following directory structure:: + + dataset/ + ├── annotations/ + │ └── default.json + └── images/ + └── default/ + ├── image1.jpg + ├── image2.jpg + └── ... + +The ``default.json`` file contains image paths and label annotations in Datumaro +format. + +Example: + >>> from pathlib import Path + >>> from anomalib.data.datasets import DatumaroDataset + >>> dataset = DatumaroDataset( + ... root=Path("./datasets/datumaro"), + ... split="train" + ... ) """ # Copyright (C) 2024 Intel Corporation @@ -16,39 +40,33 @@ from anomalib.data.utils import LabelName, Split -def make_datumaro_dataset(root: str | Path, split: str | Split | None = None) -> pd.DataFrame: - """Make Datumaro Dataset. - - Assumes the following directory structure: - - dataset - ├── annotations - │ └── default.json - └── images - └── default - ├── image1.jpg - ├── image2.jpg - └── ... +def make_datumaro_dataset( + root: str | Path, + split: str | Split | None = None, +) -> pd.DataFrame: + """Create a DataFrame of image samples from a Datumaro dataset. Args: root (str | Path): Path to the dataset root directory. - split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST. - Defaults to ``None``. - - Examples: - >>> root = Path("path/to/dataset") - >>> samples = make_datumaro_dataset(root) - >>> samples.head() - image_path label label_index split mask_path - 0 path/to/dataset... Normal 0 Split.TRAIN - 1 path/to/dataset... Normal 0 Split.TRAIN - 2 path/to/dataset... Normal 0 Split.TRAIN - 3 path/to/dataset... Normal 0 Split.TRAIN - 4 path/to/dataset... Normal 0 Split.TRAIN - + split (str | Split | None, optional): Dataset split to load. Usually + ``Split.TRAIN`` or ``Split.TEST``. Defaults to ``None``. Returns: - DataFrame: an output dataframe containing samples for the requested split (ie., train or test). 
+ pd.DataFrame: DataFrame containing samples with columns: + - ``image_path``: Path to the image file + - ``label``: Class label name + - ``label_index``: Numeric label index + - ``split``: Dataset split + - ``mask_path``: Path to mask file (empty for classification) + + Example: + >>> root = Path("./datasets/datumaro") + >>> samples = make_datumaro_dataset(root) + >>> samples.head() # doctest: +NORMALIZE_WHITESPACE + image_path label label_index split mask_path + 0 path/... Normal 0 Split.TRAIN + 1 path/... Normal 0 Split.TRAIN + 2 path/... Normal 0 Split.TRAIN """ annotation_file = Path(root) / "annotations" / "default.json" with annotation_file.open() as f: @@ -67,7 +85,7 @@ def make_datumaro_dataset(root: str | Path, split: str | Split | None = None) -> "label": label, "label_index": label_index, "split": None, - "mask_path": "", # mask is provided in the annotation file and is not on disk. + "mask_path": "", # mask is provided in annotation file }) samples_df = pd.DataFrame( samples, @@ -75,7 +93,7 @@ def make_datumaro_dataset(root: str | Path, split: str | Split | None = None) -> index=range(len(samples)), ) # Create test/train split - # By default assign all "Normal" samples to train and all "Anomalous" samples to test + # By default assign all "Normal" samples to train and all "Anomalous" to test samples_df.loc[samples_df["label_index"] == LabelName.NORMAL, "split"] = Split.TRAIN samples_df.loc[samples_df["label_index"] == LabelName.ABNORMAL, "split"] = Split.TEST @@ -90,30 +108,24 @@ def make_datumaro_dataset(root: str | Path, split: str | Split | None = None) -> class DatumaroDataset(AnomalibDataset): - """Datumaro dataset class. + """Dataset class for loading Datumaro format datasets. Args: - task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``. root (str | Path): Path to the dataset root directory. - transform (Transform, optional): Transforms that should be applied to the input images. + transform (Transform | None, optional): Transforms to apply to the images. Defaults to ``None``. - split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST - Defaults to ``None``. - - - Examples: - .. code-block:: python - - from anomalib.data.image.datumaro import DatumaroDataset - from torchvision.transforms.v2 import Resize - - dataset = DatumaroDataset(root=root, - task="classification", - transform=Resize((256, 256)), - ) - print(dataset[0].keys()) - # Output: dict_keys(['dm_format_version', 'infos', 'categories', 'items']) - + split (str | Split | None, optional): Dataset split to load. Usually + ``Split.TRAIN`` or ``Split.TEST``. Defaults to ``None``. + + Example: + >>> from pathlib import Path + >>> from torchvision.transforms.v2 import Resize + >>> from anomalib.data.datasets import DatumaroDataset + >>> dataset = DatumaroDataset( + ... root=Path("./datasets/datumaro"), + ... transform=Resize((256, 256)), + ... split="train" + ... ) """ def __init__( diff --git a/src/anomalib/data/datasets/image/folder.py b/src/anomalib/data/datasets/image/folder.py index 08e01d85c2..dc64e06af8 100644 --- a/src/anomalib/data/datasets/image/folder.py +++ b/src/anomalib/data/datasets/image/folder.py @@ -1,6 +1,22 @@ """Custom Folder Dataset. -This script creates a custom PyTorch Dataset from a folder. +This module provides a custom PyTorch Dataset implementation for loading images +from a folder structure. The dataset supports both classification and +segmentation tasks. 
+ +The folder structure should contain normal images and optionally abnormal images, +test images, and mask annotations. + +Example: + >>> from pathlib import Path + >>> from anomalib.data.datasets import FolderDataset + >>> dataset = FolderDataset( + ... name="custom", + ... root="datasets/custom", + ... normal_dir="normal", + ... abnormal_dir="abnormal", + ... mask_dir="ground_truth" + ... ) """ # Copyright (C) 2024 Intel Corporation @@ -19,49 +35,55 @@ class FolderDataset(AnomalibDataset): - """Folder dataset. - - This class is used to create a dataset from a folder. The class utilizes the Torch Dataset class. + """Dataset class for loading images from a custom folder structure. Args: - name (str): Name of the dataset. This is used to name the datamodule, especially when logging/saving. - transform (Transform, optional): Transforms that should be applied to the input images. - Defaults to ``None``. - normal_dir (str | Path | Sequence): Path to the directory containing normal images. - root (str | Path | None): Root folder of the dataset. - Defaults to ``None``. - abnormal_dir (str | Path | Sequence | None, optional): Path to the directory containing abnormal images. + name (str): Name of the dataset. Used for logging/saving. + normal_dir (str | Path | Sequence): Path to directory containing normal + images. + transform (Transform | None, optional): Transforms to apply to the images. Defaults to ``None``. - normal_test_dir (str | Path | Sequence | None, optional): Path to the directory containing - normal images for the test dataset. + root (str | Path | None, optional): Root directory of the dataset. Defaults to ``None``. - mask_dir (str | Path | Sequence | None, optional): Path to the directory containing - the mask annotations. + abnormal_dir (str | Path | Sequence | None, optional): Path to directory + containing abnormal images. Defaults to ``None``. + normal_test_dir (str | Path | Sequence | None, optional): Path to + directory containing normal test images. If not provided, normal test + images will be split from ``normal_dir``. Defaults to ``None``. + mask_dir (str | Path | Sequence | None, optional): Path to directory + containing ground truth masks. Required for segmentation. Defaults to ``None``. - split (str | Split | None): Fixed subset split that follows from folder structure on file system. - Choose from [Split.FULL, Split.TRAIN, Split.TEST] - Defaults to ``None``. - extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the directory. + split (str | Split | None, optional): Dataset split to load. + Choose from ``Split.FULL``, ``Split.TRAIN``, ``Split.TEST``. Defaults to ``None``. + extensions (tuple[str, ...] | None, optional): Image file extensions to + include. Defaults to ``None``. Examples: - Assume that we would like to use this ``FolderDataset`` to create a dataset from a folder for a classification - task. We could first create the transforms, + Create a classification dataset: >>> from anomalib.data.utils import InputNormalizationMethod, get_transforms - >>> transform = get_transforms(image_size=256, normalization=InputNormalizationMethod.NONE) - - We could then create the dataset as follows, - - .. code-block:: python - - folder_dataset_classification_train = FolderDataset( - normal_dir=dataset_root / "good", - abnormal_dir=dataset_root / "crack", - split="train", - transform=transform, - ) - + >>> transform = get_transforms( + ... image_size=256, + ... normalization=InputNormalizationMethod.NONE + ... 
) + >>> dataset = FolderDataset( + ... name="custom", + ... normal_dir="datasets/custom/good", + ... abnormal_dir="datasets/custom/defect", + ... split="train", + ... transform=transform + ... ) + + Create a segmentation dataset: + + >>> dataset = FolderDataset( + ... name="custom", + ... normal_dir="datasets/custom/good", + ... abnormal_dir="datasets/custom/defect", + ... mask_dir="datasets/custom/ground_truth", + ... split="test" + ... ) """ def __init__( @@ -99,9 +121,10 @@ def __init__( @property def name(self) -> str: - """Name of the dataset. + """Get dataset name. - Folder dataset overrides the name property to provide a custom name. + Returns: + str: Name of the dataset """ return self._name @@ -115,64 +138,62 @@ def make_folder_dataset( split: str | Split | None = None, extensions: tuple[str, ...] | None = None, ) -> DataFrame: - """Make Folder Dataset. + """Create a dataset from a folder structure. Args: - normal_dir (str | Path | Sequence): Path to the directory containing normal images. - root (str | Path | None): Path to the root directory of the dataset. + normal_dir (str | Path | Sequence): Path to directory containing normal + images. + root (str | Path | None, optional): Root directory of the dataset. Defaults to ``None``. - abnormal_dir (str | Path | Sequence | None, optional): Path to the directory containing abnormal images. + abnormal_dir (str | Path | Sequence | None, optional): Path to directory + containing abnormal images. Defaults to ``None``. + normal_test_dir (str | Path | Sequence | None, optional): Path to + directory containing normal test images. If not provided, normal test + images will be split from ``normal_dir``. Defaults to ``None``. + mask_dir (str | Path | Sequence | None, optional): Path to directory + containing ground truth masks. Required for segmentation. Defaults to ``None``. - normal_test_dir (str | Path | Sequence | None, optional): Path to the directory containing normal images for - the test dataset. Normal test images will be a split of `normal_dir` if `None`. - Defaults to ``None``. - mask_dir (str | Path | Sequence | None, optional): Path to the directory containing the mask annotations. - Defaults to ``None``. - split (str | Split | None, optional): Dataset split (ie., Split.FULL, Split.TRAIN or Split.TEST). - Defaults to ``None``. - extensions (tuple[str, ...] | None, optional): Type of the image extensions to read from the directory. + split (str | Split | None, optional): Dataset split to load. + Choose from ``Split.FULL``, ``Split.TRAIN``, ``Split.TEST``. Defaults to ``None``. + extensions (tuple[str, ...] | None, optional): Image file extensions to + include. Defaults to ``None``. Returns: - DataFrame: an output dataframe containing samples for the requested split (ie., train or test). + DataFrame: Dataset samples with columns for image paths, labels, splits + and mask paths (for segmentation). Examples: - Assume that we would like to use this ``make_folder_dataset`` to create a dataset from a folder. - We could then create the dataset as follows, - - .. code-block:: python - - folder_df = make_folder_dataset( - normal_dir=dataset_root / "good", - abnormal_dir=dataset_root / "crack", - split="train", - ) - folder_df.head() - - .. 
code-block:: bash - - image_path label label_index mask_path split - 0 ./toy/good/00.jpg DirType.NORMAL 0 Split.TRAIN - 1 ./toy/good/01.jpg DirType.NORMAL 0 Split.TRAIN - 2 ./toy/good/02.jpg DirType.NORMAL 0 Split.TRAIN - 3 ./toy/good/03.jpg DirType.NORMAL 0 Split.TRAIN - 4 ./toy/good/04.jpg DirType.NORMAL 0 Split.TRAIN + Create a classification dataset: + + >>> folder_df = make_folder_dataset( + ... normal_dir="datasets/custom/good", + ... abnormal_dir="datasets/custom/defect", + ... split="train" + ... ) + >>> folder_df.head() + image_path label label_index mask_path split + 0 ./good/00.png DirType.NORMAL 0 Split.TRAIN + 1 ./good/01.png DirType.NORMAL 0 Split.TRAIN + 2 ./good/02.png DirType.NORMAL 0 Split.TRAIN + 3 ./good/03.png DirType.NORMAL 0 Split.TRAIN + 4 ./good/04.png DirType.NORMAL 0 Split.TRAIN """ def _resolve_path_and_convert_to_list(path: str | Path | Sequence[str | Path] | None) -> list[Path]: """Convert path to list of paths. Args: - path (str | Path | Sequence | None): Path to replace with Sequence[str | Path]. + path (str | Path | Sequence | None): Path to convert. + + Returns: + list[Path]: List of resolved paths. Examples: >>> _resolve_path_and_convert_to_list("dir") [Path("path/to/dir")] >>> _resolve_path_and_convert_to_list(["dir1", "dir2"]) [Path("path/to/dir1"), Path("path/to/dir2")] - - Returns: - list[Path]: The result of path replaced by Sequence[str | Path]. """ if isinstance(path, Sequence) and not isinstance(path, str): return [validate_and_resolve_path(dir_path, root) for dir_path in path] @@ -232,15 +253,17 @@ def _resolve_path_and_convert_to_list(path: str | Path | Sequence[str | Path] | .apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1) .all() ): - msg = """Mismatch between anomalous images and mask images. Make sure the mask files " - "folder follow the same naming convention as the anomalous images in the dataset " - "(e.g. image: '000.png', mask: '000.png').""" + msg = """Mismatch between anomalous images and mask images. Make sure + the mask files folder follow the same naming convention as the + anomalous images in the dataset (e.g. image: '000.png', + mask: '000.png').""" raise MisMatchError(msg) else: samples["mask_path"] = "" - # remove all the rows with temporal image samples that have already been assigned + # remove all the rows with temporal image samples that have already been + # assigned samples = samples.loc[ (samples.label == DirType.NORMAL) | (samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST) ] @@ -253,7 +276,10 @@ def _resolve_path_and_convert_to_list(path: str | Path | Sequence[str | Path] | # By default, all the normal samples are assigned as train. # and all the abnormal samples are test. samples.loc[(samples.label == DirType.NORMAL), "split"] = Split.TRAIN - samples.loc[(samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST), "split"] = Split.TEST + samples.loc[ + (samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST), + "split", + ] = Split.TEST # infer the task type samples.attrs["task"] = "classification" if (samples["mask_path"] == "").all() else "segmentation" diff --git a/src/anomalib/data/datasets/image/kolektor.py b/src/anomalib/data/datasets/image/kolektor.py index 410d2191cf..a5ddfe6d97 100644 --- a/src/anomalib/data/datasets/image/kolektor.py +++ b/src/anomalib/data/datasets/image/kolektor.py @@ -1,17 +1,20 @@ """Kolektor Surface-Defect Dataset. 
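Before the Kolektor hunk continues below, one detail of the `folder.py` changes above is worth illustrating: the task type is inferred from whether any mask paths were assigned. A minimal sketch of that rule on toy data (not the real loader):

```python
import pandas as pd

samples = pd.DataFrame({
    "image_path": ["good/000.png", "crack/000.png"],
    "mask_path": ["", "ground_truth/crack/000.png"],
})

# Same rule as make_folder_dataset: no masks at all -> classification,
# otherwise -> segmentation.
task = "classification" if (samples["mask_path"] == "").all() else "segmentation"
print(task)  # segmentation
```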
Description: - This script provides a PyTorch Dataset for the Kolektor - Surface-Defect dataset. The dataset can be accessed at `Kolektor Surface-Defect Dataset `_. + This module provides a PyTorch Dataset implementation for the Kolektor + Surface-Defect dataset. The dataset can be accessed at `Kolektor + Surface-Defect Dataset `_. License: - The Kolektor Surface-Defect dataset is released under the Creative Commons Attribution-NonCommercial-ShareAlike - 4.0 International License (CC BY-NC-SA 4.0). For more details, visit - `Creative Commons License `_. + The Kolektor Surface-Defect dataset is released under the Creative Commons + Attribution-NonCommercial-ShareAlike 4.0 International License + (CC BY-NC-SA 4.0). For more details, visit `Creative Commons License + `_. Reference: - Tabernik, Domen, Samo Šela, Jure Skvarč, and Danijel Skočaj. "Segmentation-based deep-learning approach - for surface-defect detection." Journal of Intelligent Manufacturing 31, no. 3 (2020): 759-776. + Tabernik, Domen, Samo Šela, Jure Skvarč, and Danijel Skočaj. + "Segmentation-based deep-learning approach for surface-defect detection." + Journal of Intelligent Manufacturing 31, no. 3 (2020): 759-776. """ # Copyright (C) 2024 Intel Corporation @@ -34,13 +37,20 @@ class KolektorDataset(AnomalibDataset): """Kolektor dataset class. Args: - task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation`` - root (Path | str): Path to the root of the dataset - Defaults to ``./datasets/kolektor``. - transform (Transform, optional): Transforms that should be applied to the input images. - Defaults to ``None``. - split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST - Defaults to ``None``. + root (Path | str): Path to the root of the dataset. + Defaults to ``"./datasets/kolektor"``. + transform (Transform | None, optional): Transforms that should be applied + to the input images. Defaults to ``None``. + split (str | Split | None, optional): Split of the dataset, usually + ``Split.TRAIN`` or ``Split.TEST``. Defaults to ``None``. + + Example: + >>> from pathlib import Path + >>> from anomalib.data.datasets import KolektorDataset + >>> dataset = KolektorDataset( + ... root=Path("./datasets/kolektor"), + ... split="train" + ... ) """ def __init__( @@ -53,7 +63,11 @@ def __init__( self.root = root self.split = split - self.samples = make_kolektor_dataset(self.root, train_split_ratio=0.8, split=self.split) + self.samples = make_kolektor_dataset( + self.root, + train_split_ratio=0.8, + split=self.split, + ) def make_kolektor_dataset( @@ -64,40 +78,40 @@ def make_kolektor_dataset( """Create Kolektor samples by parsing the Kolektor data file structure. 
The files are expected to follow this structure: - - Image files: `path/to/dataset/item/image_filename.jpg`, `path/to/dataset/kos01/Part0.jpg` - - Mask files: `path/to/dataset/item/mask_filename.bmp`, `path/to/dataset/kos01/Part0_label.bmp` - - This function creates a DataFrame to store the parsed information in the following format: - - +---+-------------------+--------+-------+---------+-----------------------+------------------------+-------------+ - | | path | item | split | label | image_path | mask_path | label_index | - +---+-------------------+--------+-------+---------+-----------------------+------------------------+-------------+ - | 0 | KolektorSDD | kos01 | test | Bad | /path/to/image_file | /path/to/mask_file | 1 | - +---+-------------------+--------+-------+---------+-----------------------+------------------------+-------------+ + - Image files: ``path/to/dataset/item/image_filename.jpg`` + - Mask files: ``path/to/dataset/item/mask_filename.bmp`` + + Example file paths: + - ``path/to/dataset/kos01/Part0.jpg`` + - ``path/to/dataset/kos01/Part0_label.bmp`` + + This function creates a DataFrame with the following columns: + - ``path``: Base path to dataset + - ``item``: Item/component name + - ``split``: Dataset split (train/test) + - ``label``: Class label (Good/Bad) + - ``image_path``: Path to image file + - ``mask_path``: Path to mask file + - ``label_index``: Numeric label (0=good, 1=bad) Args: - root (Path): Path to the dataset. - train_split_ratio (float, optional): Ratio for splitting good images into train/test sets. - Defaults to ``0.8``. - split (str | Split | None, optional): Dataset split (either 'train' or 'test'). + root (str | Path): Path to the dataset root directory. + train_split_ratio (float, optional): Ratio for splitting good images into + train/test sets. Defaults to ``0.8``. + split (str | Split | None, optional): Dataset split (train/test). Defaults to ``None``. Returns: - pandas.DataFrame: An output DataFrame containing the samples of the dataset. + DataFrame: DataFrame containing the dataset samples. 
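As a rough illustration of what `train_split_ratio=0.8` means here, the sketch below holds out the remaining 20% of good images for testing. The actual sampling in `make_kolektor_dataset` may differ in details such as seeding, so treat this as the idea rather than the implementation.

```python
import pandas as pd

good = pd.DataFrame({"image_path": [f"kos01/Part{i}.jpg" for i in range(10)]})
good["split"] = "train"

# Hold out (1 - ratio) of the good images for the test split.
train_split_ratio = 0.8
test_samples = good.sample(frac=1 - train_split_ratio, random_state=42)
good.loc[test_samples.index, "split"] = "test"
print(good["split"].value_counts())
```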
Example: - The following example shows how to get training samples from the Kolektor Dataset: - >>> from pathlib import Path - >>> root = Path('./KolektorSDD/') - >>> samples = create_kolektor_samples(root, train_split_ratio=0.8) + >>> root = Path('./datasets/kolektor') + >>> samples = make_kolektor_dataset(root, train_split_ratio=0.8) >>> samples.head() - path item split label image_path mask_path label_index - 0 KolektorSDD kos01 train Good KolektorSDD/kos01/Part0.jpg KolektorSDD/kos01/Part0_label.bmp 0 - 1 KolektorSDD kos01 train Good KolektorSDD/kos01/Part1.jpg KolektorSDD/kos01/Part1_label.bmp 0 - 2 KolektorSDD kos01 train Good KolektorSDD/kos01/Part2.jpg KolektorSDD/kos01/Part2_label.bmp 0 - 3 KolektorSDD kos01 test Good KolektorSDD/kos01/Part3.jpg KolektorSDD/kos01/Part3_label.bmp 0 - 4 KolektorSDD kos01 train Good KolektorSDD/kos01/Part4.jpg KolektorSDD/kos01/Part4_label.bmp 0 + path item split label image_path mask_path label_index + 0 kolektor kos01 train Good kos01/Part0.jpg Part0.bmp 0 + 1 kolektor kos01 train Good kos01/Part1.jpg Part1.bmp 0 """ root = validate_path(root) @@ -145,7 +159,17 @@ def make_kolektor_dataset( samples.loc[test_samples.index, "split"] = "test" # Reorder columns - samples = samples[["path", "item", "split", "label", "image_path", "mask_path", "label_index"]] + samples = samples[ + [ + "path", + "item", + "split", + "label", + "image_path", + "mask_path", + "label_index", + ] + ] # assert that the right mask files are associated with the right test images if not ( @@ -153,9 +177,10 @@ def make_kolektor_dataset( .apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1) .all() ): - msg = """Mismatch between anomalous images and ground truth masks. Make sure the mask files - follow the same naming convention as the anomalous images in the dataset - (e.g. image: 'Part0.jpg', mask: 'Part0_label.bmp').""" + msg = """Mismatch between anomalous images and ground truth masks. Make + sure the mask files follow the same naming convention as the anomalous + images in the dataset (e.g. image: 'Part0.jpg', mask: + 'Part0_label.bmp').""" raise MisMatchError(msg) # infer the task type @@ -175,14 +200,11 @@ def is_mask_anomalous(path: str) -> int: path (str): Path to the mask file. Returns: - int: 1 if the mask shows defects, 0 otherwise. + int: ``1`` if the mask shows defects, ``0`` otherwise. Example: - Assume that the following image is a mask for a defective image. - Then the function will return 1. - - >>> from anomalib.data.image.kolektor import is_mask_anomalous - >>> path = './KolektorSDD/kos01/Part0_label.bmp' + >>> from anomalib.data.datasets.image.kolektor import is_mask_anomalous + >>> path = './datasets/kolektor/kos01/Part0_label.bmp' >>> is_mask_anomalous(path) 1 """ diff --git a/src/anomalib/data/datasets/image/mvtec.py b/src/anomalib/data/datasets/image/mvtec.py index c07cdf34e4..63b95bee61 100644 --- a/src/anomalib/data/datasets/image/mvtec.py +++ b/src/anomalib/data/datasets/image/mvtec.py @@ -1,25 +1,27 @@ """MVTec AD Dataset. -Description: - This script contains PyTorch Dataset for the MVTec AD dataset. - If the dataset is not on the file system, the script downloads and extracts - the dataset and create PyTorch data objects. +This module provides PyTorch Dataset implementation for the MVTec AD dataset. The +dataset will be downloaded and extracted automatically if not found locally. + +The dataset contains 15 categories of industrial objects with both normal and +anomalous samples. 
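Stepping back to the `is_mask_anomalous` helper documented in the Kolektor hunk above: its contract (return ``1`` when the mask shows a defect) could plausibly be implemented as below. This is an assumption-laden sketch, defect pixels are assumed to be non-zero, and the real helper may read the file differently.

```python
import numpy as np
from PIL import Image

def is_mask_anomalous(path: str) -> int:
    """Return 1 if the mask has any non-zero (defect) pixel, else 0."""
    mask = np.asarray(Image.open(path))
    return int(mask.any())
```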
Each category includes RGB images and pixel-level ground truth +masks for anomaly segmentation. License: MVTec AD dataset is released under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License - (CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/). - -References: - - Paul Bergmann, Kilian Batzner, Michael Fauser, David Sattlegger, Carsten Steger: - The MVTec Anomaly Detection Dataset: A Comprehensive Real-World Dataset for - Unsupervised Anomaly Detection; in: International Journal of Computer Vision - 129(4):1038-1059, 2021, DOI: 10.1007/s11263-020-01400-4. - - - Paul Bergmann, Michael Fauser, David Sattlegger, Carsten Steger: MVTec AD — - A Comprehensive Real-World Dataset for Unsupervised Anomaly Detection; - in: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), - 9584-9592, 2019, DOI: 10.1109/CVPR.2019.00982. + (CC BY-NC-SA 4.0) https://creativecommons.org/licenses/by-nc-sa/4.0/ + +Reference: + Bergmann, P., Batzner, K., Fauser, M., Sattlegger, D., & Steger, C. (2021). + The MVTec Anomaly Detection Dataset: A Comprehensive Real-World Dataset for + Unsupervised Anomaly Detection. International Journal of Computer Vision, + 129(4), 1038-1059. + + Bergmann, P., Fauser, M., Sattlegger, D., & Steger, C. (2019). MVTec AD — + A Comprehensive Real-World Dataset for Unsupervised Anomaly Detection. In + IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), + 9584-9592. """ # Copyright (C) 2024 Intel Corporation @@ -58,49 +60,46 @@ class MVTecDataset(AnomalibDataset): """MVTec dataset class. + Dataset class for loading and processing MVTec AD dataset images. Supports + both classification and segmentation tasks. + Args: - root (Path | str): Path to the root of the dataset. - Defaults to ``./datasets/MVTec``. - category (str): Sub-category of the dataset, e.g. 'bottle' - Defaults to ``bottle``. - transform (Transform, optional): Transforms that should be applied to the input images. - Defaults to ``None``. - split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST + root (Path | str): Path to root directory containing the dataset. + Defaults to ``"./datasets/MVTec"``. + category (str): Category name, must be one of ``CATEGORIES``. + Defaults to ``"bottle"``. + transform (Transform | None, optional): Transforms to apply to the images. Defaults to ``None``. + split (str | Split | None, optional): Dataset split - usually + ``Split.TRAIN`` or ``Split.TEST``. Defaults to ``None``. - Examples: - .. code-block:: python - - from anomalib.data.image.mvtec import MVTecDataset - from anomalib.data.utils.transforms import get_transforms - - transform = get_transforms(image_size=256) - dataset = MVTecDataset( - task="classification", - transform=transform, - root='./datasets/MVTec', - category='zipper', - ) - dataset.setup() - print(dataset[0].keys()) - # Output: dict_keys(['image_path', 'label', 'image']) - - When the task is segmentation, the dataset will also contain the mask: + Example: + >>> from pathlib import Path + >>> from anomalib.data.datasets import MVTecDataset + >>> dataset = MVTecDataset( + ... root=Path("./datasets/MVTec"), + ... category="bottle", + ... split="train" + ... ) - .. 
code-block:: python + For classification tasks, each sample contains: - dataset.task = "segmentation" - dataset.setup() - print(dataset[0].keys()) - # Output: dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask']) + >>> sample = dataset[0] + >>> list(sample.keys()) + ['image_path', 'label', 'image'] - The image is a torch tensor of shape (C, H, W) and the mask is a torch tensor of shape (H, W). + For segmentation tasks, samples also include mask paths and masks: - .. code-block:: python + >>> dataset.task = "segmentation" + >>> sample = dataset[0] + >>> list(sample.keys()) + ['image_path', 'label', 'image', 'mask_path', 'mask'] - print(dataset[0]["image"].shape, dataset[0]["mask"].shape) - # Output: (torch.Size([3, 256, 256]), torch.Size([256, 256])) + Images are PyTorch tensors with shape ``(C, H, W)``, masks have shape + ``(H, W)``: + >>> sample["image"].shape, sample["mask"].shape + (torch.Size([3, 256, 256]), torch.Size([256, 256])) """ def __init__( @@ -115,7 +114,11 @@ def __init__( self.root_category = Path(root) / Path(category) self.category = category self.split = split - self.samples = make_mvtec_dataset(self.root_category, split=self.split, extensions=IMG_EXTENSIONS) + self.samples = make_mvtec_dataset( + self.root_category, + split=self.split, + extensions=IMG_EXTENSIONS, + ) def make_mvtec_dataset( @@ -123,47 +126,39 @@ def make_mvtec_dataset( split: str | Split | None = None, extensions: Sequence[str] | None = None, ) -> DataFrame: - """Create MVTec AD samples by parsing the MVTec AD data file structure. + """Create MVTec AD samples by parsing the data directory structure. The files are expected to follow the structure: - path/to/dataset/split/category/image_filename.png - path/to/dataset/ground_truth/category/mask_filename.png - - This function creates a dataframe to store the parsed information based on the following format: - - +---+---------------+-------+---------+---------------+---------------------------------------+-------------+ - | | path | split | label | image_path | mask_path | label_index | - +===+===============+=======+=========+===============+=======================================+=============+ - | 0 | datasets/name | test | defect | filename.png | ground_truth/defect/filename_mask.png | 1 | - +---+---------------+-------+---------+---------------+---------------------------------------+-------------+ + ``path/to/dataset/split/category/image_filename.png`` + ``path/to/dataset/ground_truth/category/mask_filename.png`` Args: - root (Path): Path to dataset - split (str | Split | None, optional): Dataset split (ie., either train or test). + root (Path | str): Path to dataset root directory + split (str | Split | None, optional): Dataset split (train or test) Defaults to ``None``. - extensions (Sequence[str] | None, optional): List of file extensions to be included in the dataset. + extensions (Sequence[str] | None, optional): Valid file extensions Defaults to ``None``. 
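The mask/image consistency check in this hunk (see the validation further below) relies on a simple stem containment test, which is what lets both `000.png` and `000_mask.png` style mask names pass. A standalone sketch of that test with invented paths:

```python
from pathlib import Path

pairs = [
    ("test/broken_large/000.png", "ground_truth/broken_large/000_mask.png"),
    ("test/broken_large/001.png", "ground_truth/broken_large/001.png"),
]

# 'in' matches stem '000' inside '000_mask', so both naming styles pass.
print(all(Path(img).stem in Path(mask).stem for img, mask in pairs))  # True
```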
- Examples: - The following example shows how to get training samples from MVTec AD bottle category: - - >>> root = Path('./MVTec') - >>> category = 'bottle' - >>> path = root / category - >>> path - PosixPath('MVTec/bottle') - - >>> samples = make_mvtec_dataset(path, split='train', split_ratio=0.1, seed=0) + Returns: + DataFrame: Dataset samples with columns: + - path: Base path to dataset + - split: Dataset split (train/test) + - label: Class label + - image_path: Path to image file + - mask_path: Path to mask file (if available) + - label_index: Numeric label (0=normal, 1=abnormal) + + Example: + >>> root = Path("./datasets/MVTec/bottle") + >>> samples = make_mvtec_dataset(root, split="train") >>> samples.head() - path split label image_path mask_path label_index - 0 MVTec/bottle train good MVTec/bottle/train/good/105.png MVTec/bottle/ground_truth/good/105_mask.png 0 - 1 MVTec/bottle train good MVTec/bottle/train/good/017.png MVTec/bottle/ground_truth/good/017_mask.png 0 - 2 MVTec/bottle train good MVTec/bottle/train/good/137.png MVTec/bottle/ground_truth/good/137_mask.png 0 - 3 MVTec/bottle train good MVTec/bottle/train/good/152.png MVTec/bottle/ground_truth/good/152_mask.png 0 - 4 MVTec/bottle train good MVTec/bottle/train/good/109.png MVTec/bottle/ground_truth/good/109_mask.png 0 + path split label image_path mask_path label_index + 0 datasets/MVTec/bottle train good [...]/good/105.png 0 + 1 datasets/MVTec/bottle train good [...]/good/017.png 0 - Returns: - DataFrame: an output dataframe containing the samples of the dataset. + Raises: + RuntimeError: If no valid images are found + MisMatchError: If anomalous images and masks don't match """ if extensions is None: extensions = IMG_EXTENSIONS @@ -185,8 +180,14 @@ def make_mvtec_dataset( samples.label_index = samples.label_index.astype(int) # separate masks from samples - mask_samples = samples.loc[samples.split == "ground_truth"].sort_values(by="image_path", ignore_index=True) - samples = samples[samples.split != "ground_truth"].sort_values(by="image_path", ignore_index=True) + mask_samples = samples.loc[samples.split == "ground_truth"].sort_values( + by="image_path", + ignore_index=True, + ) + samples = samples[samples.split != "ground_truth"].sort_values( + by="image_path", + ignore_index=True, + ) # assign mask paths to anomalous test images samples["mask_path"] = "" @@ -199,11 +200,17 @@ def make_mvtec_dataset( abnormal_samples = samples.loc[samples.label_index == LabelName.ABNORMAL] if ( len(abnormal_samples) - and not abnormal_samples.apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1).all() + and not abnormal_samples.apply( + lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, + axis=1, + ).all() ): - msg = """Mismatch between anomalous images and ground truth masks. Make sure t - he mask files in 'ground_truth' folder follow the same naming convention as the - anomalous images in the dataset (e.g. image: '000.png', mask: '000.png' or '000_mask.png').""" + msg = ( + "Mismatch between anomalous images and ground truth masks. Make sure " + "mask files in 'ground_truth' folder follow the same naming " + "convention as the anomalous images (e.g. image: '000.png', " + "mask: '000.png' or '000_mask.png')." 
+ ) raise MisMatchError(msg) # infer the task type diff --git a/src/anomalib/data/datasets/image/visa.py b/src/anomalib/data/datasets/image/visa.py index 70ee5352aa..fa182bfc19 100644 --- a/src/anomalib/data/datasets/image/visa.py +++ b/src/anomalib/data/datasets/image/visa.py @@ -1,19 +1,23 @@ """Visual Anomaly (VisA) Dataset. -Description: - This script contains PyTorch Dataset for the Visual Anomal - (VisA) dataset. If the dataset is not on the file system, the script - downloads and extracts the dataset and create PyTorch data objects. +This module provides PyTorch Dataset implementation for the Visual Anomaly (VisA) +dataset. The dataset will be downloaded and extracted automatically if not found +locally. + +The dataset contains 12 categories of industrial objects with both normal and +anomalous samples. Each category includes RGB images and pixel-level ground truth +masks for anomaly segmentation. License: The VisA dataset is released under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License - (CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/). + (CC BY-NC-SA 4.0) https://creativecommons.org/licenses/by-nc-sa/4.0/ Reference: - - Zou, Y., Jeong, J., Pemula, L., Zhang, D., & Dabeer, O. (2022). SPot-the-Difference - Self-supervised Pre-training for Anomaly Detection and Segmentation. In European - Conference on Computer Vision (pp. 392-408). Springer, Cham. + Zou, Y., Jeong, J., Pemula, L., Zhang, D., & Dabeer, O. (2022). + SPot-the-Difference Self-supervised Pre-training for Anomaly Detection and + Segmentation. In European Conference on Computer Vision (pp. 392-408). + Springer, Cham. """ # Copyright (C) 2024 Intel Corporation @@ -47,35 +51,28 @@ class VisaDataset(AnomalibDataset): """VisA dataset class. + Dataset class for loading and processing Visual Anomaly (VisA) dataset images. + Supports both classification and segmentation tasks. + Args: - root (str | Path): Path to the root of the dataset - category (str): Sub-category of the dataset, e.g. 'candle' - transform (Transform, optional): Transforms that should be applied to the input images. - Defaults to ``None``. - split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST + root (str | Path): Path to root directory containing the dataset. + category (str): Category name, must be one of ``CATEGORIES``. + transform (Transform | None, optional): Transforms to apply to the images. Defaults to ``None``. - - Examples: - To create a Visa dataset for classification: - - .. code-block:: python - - from anomalib.data.image.visa import VisaDataset - from anomalib.data.utils.transforms import get_transforms - - transform = get_transforms(image_size=256) - dataset = VisaDataset( - transform=transform, - split="train", - root="./datasets/visa/visa_pytorch/", - category="candle", - ) - dataset.setup() - dataset[0].keys() - - # Output - dict_keys(['image_path', 'label', 'image', 'mask']) - + split (str | Split | None, optional): Dataset split - usually + ``Split.TRAIN`` or ``Split.TEST``. Defaults to ``None``. + + Example: + >>> from pathlib import Path + >>> from anomalib.data.datasets import VisaDataset + >>> dataset = VisaDataset( + ... root=Path("./datasets/visa"), + ... category="candle", + ... split="train" + ... 
) + >>> item = dataset[0] + >>> item.keys() + dict_keys(['image_path', 'label', 'image', 'mask']) """ def __init__( @@ -89,4 +86,8 @@ def __init__( self.root_category = Path(root) / category self.split = split - self.samples = make_mvtec_dataset(self.root_category, split=self.split, extensions=EXTENSIONS) + self.samples = make_mvtec_dataset( + self.root_category, + split=self.split, + extensions=EXTENSIONS, + ) diff --git a/src/anomalib/data/datasets/video/__init__.py b/src/anomalib/data/datasets/video/__init__.py index 189841257a..94b08fd445 100644 --- a/src/anomalib/data/datasets/video/__init__.py +++ b/src/anomalib/data/datasets/video/__init__.py @@ -1,4 +1,19 @@ -"""Torch Dataset Implementations of Anomalib Video Datasets.""" +"""PyTorch Dataset implementations for anomaly detection in videos. + +This module provides dataset implementations for various video anomaly detection +datasets: + +- ``AvenueDataset``: CUHK Avenue dataset for abnormal event detection +- ``ShanghaiTechDataset``: ShanghaiTech Campus surveillance dataset +- ``UCSDpedDataset``: UCSD Pedestrian dataset for anomaly detection + +Example: + >>> from anomalib.data.datasets import AvenueDataset + >>> dataset = AvenueDataset( + ... root="./datasets/avenue", + ... split="train" + ... ) +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/data/datasets/video/avenue.py b/src/anomalib/data/datasets/video/avenue.py index 03c07404a5..67c0b51efd 100644 --- a/src/anomalib/data/datasets/video/avenue.py +++ b/src/anomalib/data/datasets/video/avenue.py @@ -1,13 +1,41 @@ """CUHK Avenue Dataset. -Description: - This script contains PyTorch Dataset for the CUHK Avenue dataset. - If the dataset is not already present on the file system, the DataModule class will download and - extract the dataset, converting the .mat mask files to .png format. +This module provides PyTorch Dataset implementation for the CUHK Avenue dataset +for abnormal event detection. The dataset contains surveillance videos with both +normal and abnormal events. + +If the dataset is not already present on the file system, the DataModule class +will download and extract the dataset, converting the .mat mask files to .png +format. + +Example: + Create a dataset for training: + + >>> from anomalib.data.datasets import AvenueDataset + >>> dataset = AvenueDataset( + ... root="./datasets/avenue", + ... split="train" + ... ) + >>> dataset.setup() + >>> dataset[0].keys() + dict_keys(['image', 'mask', 'video_path', 'frames', 'last_frame', + 'original_image', 'label']) + + Create an image dataset by setting ``clip_length_in_frames=1``: + + >>> dataset = AvenueDataset( + ... root="./datasets/avenue", + ... split="test", + ... clip_length_in_frames=1 + ... ) + >>> dataset.setup() + >>> dataset[0]["image"].shape + torch.Size([3, 256, 256]) Reference: - - Lu, Cewu, Jianping Shi, and Jiaya Jia. "Abnormal event detection at 150 fps in Matlab." - In Proceedings of the IEEE International Conference on Computer Vision, 2013. + Lu, Cewu, Jianping Shi, and Jiaya Jia. "Abnormal event detection at 150 fps + in Matlab." In Proceedings of the IEEE International Conference on Computer + Vision, 2013. """ # Copyright (C) 2024 Intel Corporation @@ -31,58 +59,36 @@ class AvenueDataset(AnomalibVideoDataset): - """Avenue Dataset class. + """CUHK Avenue dataset class. Args: - split (Split): Split of the dataset, usually Split.TRAIN or Split.TEST - root (Path | str): Path to the root of the dataset - Defaults to ``./datasets/avenue``. 
- gt_dir (Path | str): Path to the ground truth files - Defaults to ``./datasets/avenue/ground_truth_demo``. - clip_length_in_frames (int, optional): Number of video frames in each clip. - Defaults to ``2``. - frames_between_clips (int, optional): Number of frames between each consecutive video clip. - Defaults to ``1``. - target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval. - Defaults to ``VideoTargetFrame.LAST``. - transform (Transform, optional): Transforms that should be applied to the input images. - Defaults to ``None``. - - Examples: - To create an Avenue dataset to train a model: - - .. code-block:: python - - dataset = AvenueDataset( - transform=transform, - split="test", - root="./datasets/avenue/", - ) - - dataset.setup() - dataset[0].keys() - - # Output: dict_keys(['image', 'mask', 'video_path', 'frames', 'last_frame', 'original_image', 'label']) - - Avenue video dataset can also be used as an image dataset if you set the clip length to 1. This means that each - video frame will be treated as a separate sample. This is useful for training an image model on the - Avenue dataset. The following code shows how to create an image dataset: - - .. code-block:: python + split (Split): Dataset split - usually ``Split.TRAIN`` or ``Split.TEST`` + root (Path | str, optional): Path to the root directory containing the + dataset. Defaults to ``"./datasets/avenue"``. + gt_dir (Path | str, optional): Path to the ground truth directory. + Defaults to ``"./datasets/avenue/ground_truth_demo"``. + clip_length_in_frames (int, optional): Number of frames in each video + clip. Defaults to ``2``. + frames_between_clips (int, optional): Number of frames between + consecutive video clips. Defaults to ``1``. + target_frame (VideoTargetFrame, optional): Target frame in the video + clip for ground truth retrieval. Defaults to + ``VideoTargetFrame.LAST``. + transform (Transform | None, optional): Transforms to apply to the input + images. Defaults to ``None``. - dataset = AvenueDataset( - transform=transform, - split="test", - root="./datasets/avenue/", - clip_length_in_frames=1, - ) - - dataset.setup() - dataset[0].keys() - # Output: dict_keys(['image', 'video_path', 'frames', 'last_frame', 'original_image', 'label']) - - dataset[0]["image"].shape - # Output: torch.Size([3, 256, 256]) + Example: + Create a dataset for testing: + + >>> dataset = AvenueDataset( + ... root="./datasets/avenue", + ... split="test", + ... transform=transform + ... ) + >>> dataset.setup() + >>> dataset[0].keys() + dict_keys(['image', 'mask', 'video_path', 'frames', 'last_frame', + 'original_image', 'label']) """ def __init__( @@ -109,33 +115,36 @@ def __init__( self.samples = make_avenue_dataset(self.root, self.gt_dir, self.split) -def make_avenue_dataset(root: Path, gt_dir: Path, split: Split | str | None = None) -> DataFrame: +def make_avenue_dataset( + root: Path, + gt_dir: Path, + split: Split | str | None = None, +) -> DataFrame: """Create CUHK Avenue dataset by parsing the file structure. The files are expected to follow the structure: - - path/to/dataset/[training_videos|testing_videos]/video_filename.avi - - path/to/ground_truth/mask_filename.mat + path/to/dataset/[training_videos|testing_videos]/video_filename.avi + path/to/ground_truth/mask_filename.mat Args: - root (Path): Path to dataset - gt_dir (Path): Path to the ground truth - split (Split | str | None = None, optional): Dataset split (ie., either train or test). 
+ root (Path): Path to dataset root directory + gt_dir (Path): Path to ground truth directory + split (Split | str | None, optional): Dataset split (train/test). Defaults to ``None``. Example: - The following example shows how to get testing samples from Avenue dataset: + Get testing samples from Avenue dataset: - >>> root = Path('./avenue') - >>> gt_dir = Path('./avenue/masks') - >>> samples = make_avenue_dataset(path, gt_dir, split='test') + >>> root = Path("./avenue") + >>> gt_dir = Path("./avenue/masks") + >>> samples = make_avenue_dataset(root, gt_dir, split="test") >>> samples.head() - root folder image_path mask_path split - 0 ./avenue testing_videos ./avenue/training_videos/01.avi ./avenue/masks/01_label.mat test - 1 ./avenue testing_videos ./avenue/training_videos/02.avi ./avenue/masks/01_label.mat test - ... + root folder image_path mask_path split + 0 ./avenue testing 01.avi 01_label.mat test + 1 ./avenue testing 02.avi 02_label.mat test Returns: - DataFrame: an output dataframe containing samples for the requested split (ie., train or test) + DataFrame: Dataframe containing samples for the requested split """ root = validate_path(root) @@ -166,17 +175,28 @@ def make_avenue_dataset(root: Path, gt_dir: Path, split: Split | str | None = No class AvenueClipsIndexer(ClipsIndexer): - """Clips class for Avenue dataset.""" + """Clips indexer class for Avenue dataset. + + This class handles retrieving video clips and corresponding masks from the + Avenue dataset. + """ def get_mask(self, idx: int) -> np.ndarray | None: - """Retrieve the masks from the file system.""" + """Retrieve masks from the file system. + + Args: + idx (int): Index of the clip + + Returns: + np.ndarray | None: Array of masks if available, else None + """ video_idx, frames_idx = self.get_clip_location(idx) matfile = self.mask_paths[video_idx] if matfile == "": # no gt masks available for this clip return None frames = self.clips[video_idx][frames_idx] - # read masks from .png files if available, othwerise from mat files. + # read masks from .png files if available, otherwise from mat files mask_folder = Path(matfile).with_suffix("") if mask_folder.exists(): mask_frames = sorted(mask_folder.glob("*")) diff --git a/src/anomalib/data/datasets/video/shanghaitech.py b/src/anomalib/data/datasets/video/shanghaitech.py index 424a13e9e6..c1fad64c20 100644 --- a/src/anomalib/data/datasets/video/shanghaitech.py +++ b/src/anomalib/data/datasets/video/shanghaitech.py @@ -1,16 +1,62 @@ """ShanghaiTech Campus Dataset. -Description: - This script contains PyTorch Dataset for the ShanghaiTech Campus dataset. - If the dataset is not on the file system, the DataModule class downloads and - extracts the dataset and converts video files to a format that is readable by pyav. +This module provides PyTorch Dataset implementation for the ShanghaiTech Campus +dataset for abnormal event detection. The dataset contains surveillance videos +with both normal and abnormal events. + +If the dataset is not already present on the file system, the DataModule class +will download and extract the dataset, converting the video files to a format +readable by pyav. + +The dataset expects the following directory structure:: + + root/ + ├── training/ + │ └── converted_videos/ + │ ├── 01_001.avi + │ ├── 01_002.avi + │ └── ... + └── testing/ + ├── frames/ + │ ├── 01_0014/ + │ │ ├── 000001.jpg + │ │ └── ... + │ └── ... + └── test_pixel_mask/ + ├── 01_0014.npy + └── ... 
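One behaviour of `AvenueClipsIndexer.get_mask` above deserves a note: it prefers a folder of extracted `.png` masks when a directory with the `.mat` file's stem exists, and only falls back to the `.mat` file otherwise. A minimal sketch of that lookup order, with a hypothetical path:

```python
from pathlib import Path

def mask_source(matfile: str) -> str:
    """Prefer a folder of extracted .png masks over the original .mat file."""
    mask_folder = Path(matfile).with_suffix("")  # e.g. .../01_label.mat -> .../01_label
    return "png folder" if mask_folder.exists() else "mat file"

print(mask_source("./datasets/avenue/ground_truth_demo/01_label.mat"))
```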
+ +Example: + Create a dataset for training: + + >>> from anomalib.data.datasets import ShanghaiTechDataset + >>> from anomalib.data.utils import Split + >>> dataset = ShanghaiTechDataset( + ... root="./datasets/shanghaitech", + ... scene=1, + ... split=Split.TRAIN + ... ) + >>> dataset[0].keys() + dict_keys(['image', 'video_path', 'frames', 'last_frame', 'original_image']) + + Create a test dataset: + + >>> dataset = ShanghaiTechDataset( + ... root="./datasets/shanghaitech", + ... scene=1, + ... split=Split.TEST + ... ) + >>> dataset[0].keys() + dict_keys(['image', 'mask', 'video_path', 'frames', 'last_frame', + 'original_image', 'label']) License: ShanghaiTech Campus Dataset is released under the BSD 2-Clause License. Reference: - - W. Liu and W. Luo, D. Lian and S. Gao. "Future Frame Prediction for Anomaly Detection -- A New Baseline." - IEEE Conference on Computer Vision and Pattern Recognition (CVPR). 2018. + Liu, W., Luo, W., Lian, D., & Gao, S. (2018). Future frame prediction for + anomaly detection--a new baseline. In Proceedings of the IEEE conference on + computer vision and pattern recognition (pp. 6536-6545). """ # Copyright (C) 2024 Intel Corporation @@ -34,14 +80,28 @@ class ShanghaiTechDataset(AnomalibVideoDataset): """ShanghaiTech Dataset class. Args: - split (Split): Split of the dataset, usually Split.TRAIN or Split.TEST - root (Path | str): Path to the root of the dataset - scene (int): Index of the dataset scene (category) in range [1, 13] - clip_length_in_frames (int, optional): Number of video frames in each clip. - frames_between_clips (int, optional): Number of frames between each consecutive video clip. - target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval. - transform (Transform, optional): Transforms that should be applied to the input images. - Defaults to ``None``. + split (Split): Dataset split - either ``Split.TRAIN`` or ``Split.TEST`` + root (Path | str): Path to the root directory containing the dataset. + Defaults to ``"./datasets/shanghaitech"``. + scene (int): Index of the dataset scene (category) in range [1, 13]. + Defaults to ``1``. + clip_length_in_frames (int, optional): Number of frames in each video + clip. Defaults to ``2``. + frames_between_clips (int, optional): Number of frames between each + consecutive video clip. Defaults to ``1``. + target_frame (VideoTargetFrame): Specifies which frame in the clip to use + for ground truth retrieval. Defaults to ``VideoTargetFrame.LAST``. + transform (Transform | None, optional): Transforms to apply to the input + images. Defaults to ``None``. + + Example: + >>> from anomalib.data.datasets import ShanghaiTechDataset + >>> from anomalib.data.utils import Split + >>> dataset = ShanghaiTechDataset( + ... root="./datasets/shanghaitech", + ... scene=1, + ... split=Split.TRAIN + ... ) """ def __init__( @@ -69,28 +129,42 @@ def __init__( class ShanghaiTechTrainClipsIndexer(ClipsIndexer): - """Clips indexer for ShanghaiTech dataset. + """Clips indexer for ShanghaiTech training dataset. - The train and test subsets of the ShanghaiTech dataset use different file formats, so separate - clips indexer implementations are needed. + The train and test subsets use different file formats, so separate clips + indexer implementations are needed. """ @staticmethod def get_mask(idx: int) -> torch.Tensor | None: - """No masks available for training set.""" + """No masks available for training set. + + Args: + idx (int): Index of the clip. 
+ + Returns: + None: Training set has no masks. + """ del idx # Unused argument return None class ShanghaiTechTestClipsIndexer(ClipsIndexer): - """Clips indexer for the test set of the ShanghaiTech Campus dataset. + """Clips indexer for ShanghaiTech test dataset. - The train and test subsets of the ShanghaiTech dataset use different file formats, so separate - clips indexer implementations are needed. + The train and test subsets use different file formats, so separate clips + indexer implementations are needed. """ def get_mask(self, idx: int) -> torch.Tensor | None: - """Retrieve the masks from the file system.""" + """Retrieve the masks from the file system. + + Args: + idx (int): Index of the clip. + + Returns: + torch.Tensor | None: Ground truth mask if available, else None. + """ video_idx, frames_idx = self.get_clip_location(idx) mask_file = self.mask_paths[video_idx] if mask_file == "": # no gt masks available for this clip @@ -107,19 +181,24 @@ def _compute_frame_pts(self) -> None: n_frames = len(list(Path(video_path).glob("*.jpg"))) self.video_pts.append(torch.Tensor(range(n_frames))) - self.video_fps = [None] * len(self.video_paths) # fps information cannot be inferred from folder structure + # fps information cannot be inferred from folder structure + self.video_fps = [None] * len(self.video_paths) def get_clip(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any], int]: """Get a subclip from a list of videos. Args: - idx (int): index of the subclip. Must be between 0 and num_clips(). + idx (int): Index of the subclip. Must be between 0 and num_clips(). Returns: - video (torch.Tensor) - audio (torch.Tensor) - info (Dict) - video_idx (int): index of the video in `video_paths` + tuple containing: + - video (torch.Tensor): Video clip tensor + - audio (torch.Tensor): Empty audio tensor + - info (dict): Empty info dictionary + - video_idx (int): Index of the video in video_paths + + Raises: + IndexError: If idx is out of range. """ if idx >= self.num_clips(): msg = f"Index {idx} out of range ({self.num_clips()} number of clips)" @@ -139,29 +218,41 @@ def get_clip(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any] def make_shanghaitech_dataset(root: Path, scene: int, split: Split | str | None = None) -> DataFrame: """Create ShanghaiTech dataset by parsing the file structure. - The files are expected to follow the structure: - path/to/dataset/[training_videos|testing_videos]/video_filename.avi - path/to/ground_truth/mask_filename.mat + The files are expected to follow the structure:: + + root/ + ├── training/ + │ └── converted_videos/ + │ ├── 01_001.avi + │ └── ... + └── testing/ + ├── frames/ + │ ├── 01_0014/ + │ │ ├── 000001.jpg + │ │ └── ... + │ └── ... + └── test_pixel_mask/ + ├── 01_0014.npy + └── ... Args: - root (Path): Path to dataset - scene (int): Index of the dataset scene (category) in range [1, 13] - split (Split | str | None, optional): Dataset split (ie., either train or test). Defaults to None. + root (Path): Path to dataset root directory. + scene (int): Index of the dataset scene (category) in range [1, 13]. + split (Split | str | None, optional): Dataset split (train or test). + Defaults to ``None``. - Example: - The following example shows how to get testing samples from ShanghaiTech dataset: + Returns: + DataFrame: DataFrame containing samples for the requested split. 
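The test-set indexer above cannot infer fps from folders of extracted frames, so it synthesises one presentation timestamp per `.jpg` file. A standalone sketch of that counting step, mirroring `_compute_frame_pts` (the folder path is hypothetical):

```python
from pathlib import Path

import torch

def frame_pts(video_path: str) -> torch.Tensor:
    """One timestamp per extracted .jpg frame, as in _compute_frame_pts."""
    n_frames = len(list(Path(video_path).glob("*.jpg")))
    return torch.Tensor(range(n_frames))

print(frame_pts("./datasets/shanghaitech/testing/frames/01_0014"))
```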
- >>> root = Path('./shanghaiTech') + Example: + >>> from pathlib import Path + >>> root = Path('./shanghaitech') >>> scene = 1 - >>> samples = make_avenue_dataset(path, scene, split='test') + >>> samples = make_shanghaitech_dataset(root, scene, split='test') >>> samples.head() - root image_path split mask_path - 0 shanghaitech shanghaitech/testing/frames/01_0014 test shanghaitech/testing/test_pixel_mask/01_0014.npy - 1 shanghaitech shanghaitech/testing/frames/01_0015 test shanghaitech/testing/test_pixel_mask/01_0015.npy - ... - - Returns: - DataFrame: an output dataframe containing samples for the requested split (ie., train or test) + root image_path split mask_path + 0 shanghaitech shanghaitech/testing/frames/01_0014 test ...01_0014.npy + 1 shanghaitech shanghaitech/testing/frames/01_0015 test ...01_0015.npy """ scene_prefix = str(scene).zfill(2) diff --git a/src/anomalib/data/datasets/video/ucsd_ped.py b/src/anomalib/data/datasets/video/ucsd_ped.py index 5a619be3f1..ffc2ab8c18 100644 --- a/src/anomalib/data/datasets/video/ucsd_ped.py +++ b/src/anomalib/data/datasets/video/ucsd_ped.py @@ -1,4 +1,62 @@ -"""UCSD Pedestrian Dataset.""" +"""UCSD Pedestrian Dataset. + +This module provides PyTorch Dataset implementation for the UCSD Pedestrian +dataset for abnormal event detection. The dataset contains surveillance videos +with both normal and abnormal events. + +The dataset expects the following directory structure:: + + root/ + ├── UCSDped1/ + │ ├── Train/ + │ │ ├── Train001/ + │ │ │ ├── 001.tif + │ │ │ └── ... + │ │ └── ... + │ └── Test/ + │ ├── Test001/ + │ │ ├── 001.tif + │ │ └── ... + │ ├── Test001_gt/ + │ │ ├── 001.bmp + │ │ └── ... + │ └── ... + └── UCSDped2/ + ├── Train/ + └── Test/ + +Example: + Create a dataset for training: + + >>> from anomalib.data.datasets import UCSDpedDataset + >>> from anomalib.data.utils import Split + >>> dataset = UCSDpedDataset( + ... root="./datasets/ucsdped", + ... category="UCSDped1", + ... split=Split.TRAIN + ... ) + >>> dataset[0].keys() + dict_keys(['image', 'video_path', 'frames', 'last_frame', 'original_image']) + + Create a test dataset: + + >>> dataset = UCSDpedDataset( + ... root="./datasets/ucsdped", + ... category="UCSDped1", + ... split=Split.TEST + ... ) + >>> dataset[0].keys() + dict_keys(['image', 'mask', 'video_path', 'frames', 'last_frame', + 'original_image', 'label']) + +License: + UCSD Pedestrian Dataset is released under the BSD 2-Clause License. + +Reference: + Mahadevan, V., Li, W., Bhalodia, V., & Vasconcelos, N. (2010). Anomaly + detection in crowded scenes. In IEEE Conference on Computer Vision and + Pattern Recognition (CVPR), 2010. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -25,14 +83,31 @@ class UCSDpedDataset(AnomalibVideoDataset): """UCSDped Dataset class. Args: - root (Path | str): Path to the root of the dataset - category (str): Sub-category of the dataset, e.g. "UCSDped1" or "UCSDped2" - split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST + root (Path | str): Path to the root of the dataset. + category (str): Sub-category of the dataset, must be one of ``CATEGORIES``. + split (str | Split | None): Dataset split - usually ``Split.TRAIN`` or + ``Split.TEST``. clip_length_in_frames (int, optional): Number of video frames in each clip. - frames_between_clips (int, optional): Number of frames between each consecutive video clip. - target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval. 
- transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``2``. + frames_between_clips (int, optional): Number of frames between each + consecutive video clip. Defaults to ``10``. + target_frame (VideoTargetFrame): Specifies the target frame in the video + clip, used for ground truth retrieval. Defaults to + ``VideoTargetFrame.LAST``. + transform (Transform | None, optional): Transforms to apply to the images. Defaults to ``None``. + + Example: + >>> from pathlib import Path + >>> from anomalib.data.datasets import UCSDpedDataset + >>> dataset = UCSDpedDataset( + ... root=Path("./datasets/ucsdped"), + ... category="UCSDped1", + ... split="train" + ... ) + >>> dataset[0].keys() + dict_keys(['image', 'video_path', 'frames', 'last_frame', + 'original_image']) """ def __init__( @@ -62,7 +137,14 @@ class UCSDpedClipsIndexer(ClipsIndexer): """Clips class for UCSDped dataset.""" def get_mask(self, idx: int) -> np.ndarray | None: - """Retrieve the masks from the file system.""" + """Retrieve the masks from the file system. + + Args: + idx (int): Index of the clip. + + Returns: + np.ndarray | None: Stack of mask frames if available, None otherwise. + """ video_idx, frames_idx = self.get_clip_location(idx) mask_folder = self.mask_paths[video_idx] if mask_folder == "": # no gt masks available for this clip @@ -87,13 +169,18 @@ def get_clip(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any] """Get a subclip from a list of videos. Args: - idx (int): index of the subclip. Must be between 0 and num_clips(). + idx (int): Index of the subclip. Must be between 0 and num_clips(). Returns: - video (torch.Tensor) - audio (torch.Tensor) - info (dict) - video_idx (int): index of the video in `video_paths` + tuple[torch.Tensor, torch.Tensor, dict[str, Any], int]: Tuple + containing: + - video frames tensor + - empty audio tensor + - empty info dict + - video index + + Raises: + IndexError: If ``idx`` is out of range. """ if idx >= self.num_clips(): msg = f"Index {idx} out of range ({self.num_clips()} number of clips)" @@ -113,16 +200,19 @@ def get_clip(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any] def make_ucsd_dataset(path: Path, split: str | Split | None = None) -> DataFrame: """Create UCSD Pedestrian dataset by parsing the file structure. - The files are expected to follow the structure: + The files are expected to follow the structure:: + path/to/dataset/category/split/video_id/image_filename.tif path/to/dataset/category/split/video_id_gt/mask_filename.bmp Args: - path (Path): Path to dataset - split (str | Split | None, optional): Dataset split (ie., either train or test). Defaults to None. + path (Path): Path to dataset. + split (str | Split | None, optional): Dataset split (ie., either train or + test). Defaults to ``None``. Example: - The following example shows how to get testing samples from UCSDped2 category: + The following example shows how to get testing samples from UCSDped2 + category: >>> root = Path('./UCSDped') >>> category = 'UCSDped2' @@ -132,13 +222,11 @@ def make_ucsd_dataset(path: Path, split: str | Split | None = None) -> DataFrame >>> samples = make_ucsd_dataset(path, split='test') >>> samples.head() - root folder image_path mask_path split - 0 UCSDped/UCSDped2 Test UCSDped/UCSDped2/Test/Test001 UCSDped/UCSDped2/Test/Test001_gt test - 1 UCSDped/UCSDped2 Test UCSDped/UCSDped2/Test/Test002 UCSDped/UCSDped2/Test/Test002_gt test - ... 
+ root folder image_path mask_path + 0 UCSDped/UCSDped2 Test UCSDped/UCSDped2/Test/Test001 UCSDped/... Returns: - DataFrame: an output dataframe containing samples for the requested split (ie., train or test) + DataFrame: Output dataframe containing samples for the requested split. """ path = validate_path(path) folders = [filename for filename in sorted(path.glob("*/*")) if filename.is_dir()] diff --git a/src/anomalib/data/errors.py b/src/anomalib/data/errors.py index 97c956663c..7909bc9659 100644 --- a/src/anomalib/data/errors.py +++ b/src/anomalib/data/errors.py @@ -1,14 +1,35 @@ -"""Custom Exception Class for Mismatch Detection (MisMatchError).""" +"""Custom exceptions for anomalib data validation. + +This module provides custom exception classes for handling data validation errors +in anomalib. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 class MisMatchError(Exception): - """Exception raised when a mismatch is detected. + """Exception raised when a data mismatch is detected. + + This exception is raised when there is a mismatch between expected and actual + data formats or values during validation. + + Args: + message (str): Custom error message. Defaults to "Mismatch detected." Attributes: message (str): Explanation of the error. + + Examples: + >>> raise MisMatchError() # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + MisMatchError: Mismatch detected. + >>> raise MisMatchError("Image dimensions do not match") + ... # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + MisMatchError: Image dimensions do not match """ def __init__(self, message: str = "") -> None: diff --git a/src/anomalib/data/image/datumaro.py b/src/anomalib/data/image/datumaro.py deleted file mode 100644 index b4836990ec..0000000000 --- a/src/anomalib/data/image/datumaro.py +++ /dev/null @@ -1,226 +0,0 @@ -"""Dataloader for Datumaro format. - -Note: This currently only works for annotations exported from Intel Geti™. -""" - -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -import json -from pathlib import Path - -import pandas as pd -from torchvision.transforms.v2 import Transform - -from anomalib import TaskType -from anomalib.data.base import AnomalibDataModule, AnomalibDataset -from anomalib.data.utils import LabelName, Split, TestSplitMode, ValSplitMode - - -def make_datumaro_dataset(root: str | Path, split: str | Split | None = None) -> pd.DataFrame: - """Make Datumaro Dataset. - - Assumes the following directory structure: - - dataset - ├── annotations - │ └── default.json - └── images - └── default - ├── image1.jpg - ├── image2.jpg - └── ... - - Args: - root (str | Path): Path to the dataset root directory. - split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST. - Defaults to ``None``. - - Examples: - >>> root = Path("path/to/dataset") - >>> samples = make_datumaro_dataset(root) - >>> samples.head() - image_path label label_index split mask_path - 0 path/to/dataset... Normal 0 Split.TRAIN - 1 path/to/dataset... Normal 0 Split.TRAIN - 2 path/to/dataset... Normal 0 Split.TRAIN - 3 path/to/dataset... Normal 0 Split.TRAIN - 4 path/to/dataset... Normal 0 Split.TRAIN - - - Returns: - DataFrame: an output dataframe containing samples for the requested split (ie., train or test). 
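Looking back at the `errors.py` hunk above: `MisMatchError` is the repository's data-validation error, and a typical call site might look like the sketch below. The helper name and message are invented for illustration; only the exception and its import path come from the diff.

```python
from anomalib.data.errors import MisMatchError

def check_one_mask_per_image(images: list[str], masks: list[str]) -> None:
    # Hypothetical validation helper: raise when the counts disagree.
    if len(images) != len(masks):
        msg = f"Expected one mask per image, got {len(images)} images and {len(masks)} masks."
        raise MisMatchError(msg)
```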
- """ - annotation_file = Path(root) / "annotations" / "default.json" - with annotation_file.open() as f: - annotations = json.load(f) - - categories = annotations["categories"] - categories = {idx: label["name"] for idx, label in enumerate(categories["label"]["labels"])} - - samples = [] - for item in annotations["items"]: - image_path = Path(root) / "images" / "default" / item["image"]["path"] - label_index = item["annotations"][0]["label_id"] - label = categories[label_index] - samples.append({ - "image_path": str(image_path), - "label": label, - "label_index": label_index, - "split": None, - "mask_path": "", # mask is provided in the annotation file and is not on disk. - }) - samples_df = pd.DataFrame( - samples, - columns=["image_path", "label", "label_index", "split", "mask_path"], - index=range(len(samples)), - ) - # Create test/train split - # By default assign all "Normal" samples to train and all "Anomalous" samples to test - samples_df.loc[samples_df["label_index"] == LabelName.NORMAL, "split"] = Split.TRAIN - samples_df.loc[samples_df["label_index"] == LabelName.ABNORMAL, "split"] = Split.TEST - - # Get the data frame for the split. - if split: - samples_df = samples_df[samples_df.split == split].reset_index(drop=True) - - return samples_df - - -class DatumaroDataset(AnomalibDataset): - """Datumaro dataset class. - - Args: - task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``. - root (str | Path): Path to the dataset root directory. - transform (Transform, optional): Transforms that should be applied to the input images. - Defaults to ``None``. - split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST - Defaults to ``None``. - - - Examples: - .. code-block:: python - - from anomalib.data.image.datumaro import DatumaroDataset - from torchvision.transforms.v2 import Resize - - dataset = DatumaroDataset(root=root, - task="classification", - transform=Resize((256, 256)), - ) - print(dataset[0].keys()) - # Output: dict_keys(['dm_format_version', 'infos', 'categories', 'items']) - - """ - - def __init__( - self, - task: TaskType, - root: str | Path, - transform: Transform | None = None, - split: str | Split | None = None, - ) -> None: - super().__init__(task, transform) - self.split = split - self.samples = make_datumaro_dataset(root, split) - - -class Datumaro(AnomalibDataModule): - """Datumaro datamodule. - - Args: - root (str | Path): Path to the dataset root directory. - train_batch_size (int): Batch size for training dataloader. - Defaults to ``32``. - eval_batch_size (int): Batch size for evaluation dataloader. - Defaults to ``32``. - num_workers (int): Number of workers for dataloaders. - Defaults to ``8``. - task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``. - Defaults to ``TaskType.CLASSIFICATION``. Currently only supports classification. - image_size (tuple[int, int], optional): Size to which input images should be resized. - Defaults to ``None``. - transform (Transform, optional): Transforms that should be applied to the input images. - Defaults to ``None``. - train_transform (Transform, optional): Transforms that should be applied to the input images during training. - Defaults to ``None``. - eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation. - Defaults to ``None``. - test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained. - Defaults to ``TestSplitMode.FROM_DIR``. 
- test_split_ratio (float): Fraction of images from the train set that will be reserved for testing. - Defaults to ``0.2``. - val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. - Defaults to ``ValSplitMode.SAME_AS_TEST``. - val_split_ratio (float): Fraction of train or test images that will be reserved for validation. - Defaults to ``0.5``. - seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. - Defualts to ``None``. - - Examples: - To create a Datumaro datamodule - - >>> from pathlib import Path - >>> from torchvision.transforms.v2 import Resize - >>> root = Path("path/to/dataset") - >>> datamodule = Datumaro(root, transform=Resize((256, 256))) - >>> datamodule.setup() - >>> i, data = next(enumerate(datamodule.train_dataloader())) - >>> data.keys() - dict_keys(['image_path', 'label', 'image']) - - >>> data["image"].shape - torch.Size([32, 3, 256, 256]) - """ - - def __init__( - self, - root: str | Path, - train_batch_size: int = 32, - eval_batch_size: int = 32, - num_workers: int = 8, - task: TaskType = TaskType.CLASSIFICATION, - image_size: tuple[int, int] | None = None, - transform: Transform | None = None, - train_transform: Transform | None = None, - eval_transform: Transform | None = None, - test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR, - test_split_ratio: float = 0.5, - val_split_mode: ValSplitMode | str = ValSplitMode.FROM_TEST, - val_split_ratio: float = 0.5, - seed: int | None = None, - ) -> None: - if task != TaskType.CLASSIFICATION: - msg = "Datumaro dataloader currently only supports classification task." - raise ValueError(msg) - super().__init__( - train_batch_size=train_batch_size, - eval_batch_size=eval_batch_size, - num_workers=num_workers, - val_split_mode=val_split_mode, - val_split_ratio=val_split_ratio, - test_split_mode=test_split_mode, - test_split_ratio=test_split_ratio, - image_size=image_size, - transform=transform, - train_transform=train_transform, - eval_transform=eval_transform, - seed=seed, - ) - self.root = root - self.task = task - - def _setup(self, _stage: str | None = None) -> None: - self.train_data = DatumaroDataset( - task=self.task, - root=self.root, - transform=self.train_transform, - split=Split.TRAIN, - ) - self.test_data = DatumaroDataset( - task=self.task, - root=self.root, - transform=self.eval_transform, - split=Split.TEST, - ) diff --git a/src/anomalib/data/predict.py b/src/anomalib/data/predict.py index 06c743b88f..e53ef2b52f 100644 --- a/src/anomalib/data/predict.py +++ b/src/anomalib/data/predict.py @@ -1,4 +1,16 @@ -"""Inference Dataset.""" +"""Dataset for performing inference on images. + +This module provides a dataset class for loading and preprocessing images for +inference in anomaly detection tasks. + +Example: + >>> from pathlib import Path + >>> from anomalib.data import PredictDataset + >>> dataset = PredictDataset(path="path/to/images") + >>> item = dataset[0] + >>> item.image.shape # doctest: +SKIP + torch.Size([3, 256, 256]) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -14,14 +26,27 @@ class PredictDataset(Dataset): - """Inference Dataset to perform prediction. + """Dataset for performing inference on images. Args: - path (str | Path): Path to an image or image-folder. - transform (A.Compose | None, optional): Transform object describing the transforms that are - applied to the inputs. 
- image_size (int | tuple[int, int] | None, optional): Target image size - to resize the original image. Defaults to None. + path (str | Path): Path to an image or directory containing images. + transform (Transform | None, optional): Transform object describing the + transforms to be applied to the inputs. Defaults to ``None``. + image_size (int | tuple[int, int], optional): Target size to which input + images will be resized. If int, a square image of that size will be + created. Defaults to ``(256, 256)``. + + Examples: + >>> from pathlib import Path + >>> dataset = PredictDataset( + ... path=Path("path/to/images"), + ... image_size=(224, 224), + ... ) + >>> len(dataset) # doctest: +SKIP + 10 + >>> item = dataset[0] # doctest: +SKIP + >>> item.image.shape # doctest: +SKIP + torch.Size([3, 224, 224]) """ def __init__( @@ -37,11 +62,22 @@ def __init__( self.image_size = image_size def __len__(self) -> int: - """Get the number of images in the given path.""" + """Get number of images in dataset. + + Returns: + int: Number of images in the dataset. + """ return len(self.image_filenames) def __getitem__(self, index: int) -> ImageItem: - """Get the image based on the `index`.""" + """Get image item at specified index. + + Args: + index (int): Index of the image to retrieve. + + Returns: + ImageItem: Object containing the loaded image and its metadata. + """ image_filename = self.image_filenames[index] image = read_image(image_filename, as_tensor=True) if self.transform: @@ -54,5 +90,10 @@ def __getitem__(self, index: int) -> ImageItem: @property def collate_fn(self) -> Callable: - """Get the collate function.""" + """Get collate function for creating batches. + + Returns: + Callable: Function that collates multiple ``ImageItem`` instances into + a batch. + """ return ImageBatch.collate diff --git a/src/anomalib/data/transforms/center_crop.py b/src/anomalib/data/transforms/center_crop.py index 88b8655aae..880acd5484 100644 --- a/src/anomalib/data/transforms/center_crop.py +++ b/src/anomalib/data/transforms/center_crop.py @@ -1,4 +1,17 @@ -"""Custom Torchvision transforms for Anomalib.""" +"""Custom Torchvision transforms for Anomalib. + +This module provides custom center crop transforms that are compatible with ONNX +export. + +Example: + >>> import torch + >>> from anomalib.data.transforms.center_crop import ExportableCenterCrop + >>> transform = ExportableCenterCrop(size=(224, 224)) + >>> image = torch.randn(3, 256, 256) + >>> output = transform(image) + >>> output.shape + torch.Size([3, 224, 224]) +""" # Original Code # Copyright (c) Soumith Chintala 2016 @@ -29,14 +42,17 @@ def _center_crop_compute_crop_anchor( ) -> tuple[int, int]: """Compute the anchor point for center-cropping. - This function is a modified version of the torchvision.transforms.functional._center_crop_compute_crop_anchor - function. The original function uses `round` to compute the anchor point, which is not compatible with ONNX. + This function is a modified version of the torchvision center crop anchor + computation that is compatible with ONNX export. Args: - crop_height (int): Desired height of the crop. - crop_width (int): Desired width of the crop. - image_height (int): Height of the input image. - image_width (int): Width of the input image. 
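The ``PredictDataset`` documented above pairs with a standard PyTorch ``DataLoader`` through its ``collate_fn`` property. A minimal sketch follows; the ``./images`` folder is a placeholder, and ``batch.image`` assumes the ``ImageBatch`` produced by the collate function exposes the stacked images under that attribute.

from pathlib import Path

from torch.utils.data import DataLoader

from anomalib.data import PredictDataset

# Placeholder folder of inference images
dataset = PredictDataset(path=Path("./images"), image_size=(256, 256))
loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)
for batch in loader:  # each batch is an ImageBatch
    print(batch.image.shape)  # e.g. torch.Size([4, 3, 256, 256])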
+ crop_height (int): Desired height of the crop + crop_width (int): Desired width of the crop + image_height (int): Height of the input image + image_width (int): Width of the input image + + Returns: + tuple[int, int]: Tuple containing the top and left crop anchor points """ crop_top = torch.tensor((image_height - crop_height) / 2.0).round().int().item() crop_left = torch.tensor((image_width - crop_width) / 2.0).round().int().item() @@ -46,11 +62,21 @@ def _center_crop_compute_crop_anchor( def center_crop_image(image: torch.Tensor, output_size: list[int]) -> torch.Tensor: """Apply center-cropping to an input image. - Uses the modified anchor point computation function to compute the anchor point for center-cropping. + Uses the modified anchor point computation function to ensure ONNX + compatibility. Args: - image (torch.Tensor): Input image to be center-cropped. - output_size (list[int]): Desired output size of the crop. + image (torch.Tensor): Input image tensor to be center-cropped + output_size (list[int]): Desired output size ``[height, width]`` + + Returns: + torch.Tensor: Center-cropped image tensor + + Example: + >>> image = torch.randn(3, 256, 256) + >>> output = center_crop_image(image, [224, 224]) + >>> output.shape + torch.Size([3, 224, 224]) """ crop_height, crop_width = _center_crop_parse_output_size(output_size) shape = image.shape @@ -59,22 +85,45 @@ def center_crop_image(image: torch.Tensor, output_size: list[int]) -> torch.Tens image_height, image_width = shape[-2:] if crop_height > image_height or crop_width > image_width: - padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) + padding_ltrb = _center_crop_compute_padding( + crop_height, + crop_width, + image_height, + image_width, + ) image = pad(image, _parse_pad_padding(padding_ltrb), value=0.0) image_height, image_width = image.shape[-2:] if crop_width == image_width and crop_height == image_height: return image - crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width) - return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)] + crop_top, crop_left = _center_crop_compute_crop_anchor( + crop_height, + crop_width, + image_height, + image_width, + ) + return image[ + ..., + crop_top : (crop_top + crop_height), + crop_left : (crop_left + crop_width), + ] class ExportableCenterCrop(Transform): - """Transform that applies center-cropping to an input image and allows to be exported to ONNX. + """Transform that applies center-cropping with ONNX export support. Args: - size (int | tuple[int, int]): Desired output size of the crop. + size (int | tuple[int, int]): Desired output size. If int, creates a + square crop of size ``(size, size)``. If tuple, creates a + rectangular crop of size ``(height, width)``. + + Example: + >>> transform = ExportableCenterCrop(224) + >>> image = torch.randn(3, 256, 256) + >>> output = transform(image) + >>> output.shape + torch.Size([3, 224, 224]) """ def __init__(self, size: int | tuple[int, int]) -> None: @@ -82,6 +131,14 @@ def __init__(self, size: int | tuple[int, int]) -> None: self.size = list(size) if isinstance(size, tuple) else [size, size] def _transform(self, inpt: torch.Tensor, params: dict[str, Any]) -> torch.Tensor: - """Apply the transform.""" + """Apply the center crop transform. 
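Since ONNX compatibility is the stated purpose of ``ExportableCenterCrop``, a sketch of an export round-trip may be useful. This is illustrative only and assumes the transform traces cleanly when wrapped in a module; the wrapper class, file name, and shapes are all arbitrary.

import torch

from anomalib.data.transforms.center_crop import ExportableCenterCrop


class CropWrapper(torch.nn.Module):
    # Thin wrapper so the transform can be traced by the ONNX exporter
    def __init__(self) -> None:
        super().__init__()
        self.crop = ExportableCenterCrop(size=224)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.crop(x)


torch.onnx.export(CropWrapper(), torch.randn(1, 3, 256, 256), "center_crop.onnx")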
+ + Args: + inpt (torch.Tensor): Input tensor to transform + params (dict[str, Any]): Transform parameters (unused) + + Returns: + torch.Tensor: Center-cropped output tensor + """ del params return center_crop_image(inpt, output_size=self.size) diff --git a/src/anomalib/data/transforms/multi_random_choice.py b/src/anomalib/data/transforms/multi_random_choice.py index 1d507c17a2..f19c5fa483 100644 --- a/src/anomalib/data/transforms/multi_random_choice.py +++ b/src/anomalib/data/transforms/multi_random_choice.py @@ -1,4 +1,24 @@ -"""Multi random choice transform.""" +"""Multi random choice transform. + +This transform randomly applies multiple transforms from a list of transforms. + +Example: + >>> import torchvision.transforms.v2 as v2 + >>> transforms = [ + ... v2.RandomHorizontalFlip(p=1.0), + ... v2.ColorJitter(brightness=0.5), + ... v2.RandomRotation(10), + ... ] + >>> # Apply 1-2 random transforms with equal probability + >>> transform = MultiRandomChoice(transforms, num_transforms=2) + >>> # Always apply exactly 2 transforms with custom probabilities + >>> transform = MultiRandomChoice( + ... transforms, + ... probabilities=[0.5, 0.3, 0.2], + ... num_transforms=2, + ... fixed_num_transforms=True + ... ) +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -15,17 +35,22 @@ class MultiRandomChoice(v2.Transform): This transform does not support torchscript. Args: - transforms (sequence or torch.nn.Module): List of transformations to choose from. - probabilities (list[float] | None, optional): Probability of each transform being picked. - If None (default), all transforms have equal probability. If provided, probabilities - will be normalized to sum to 1. - num_transforms (int): Maximum number of transforms to apply at once. + transforms: List of transformations to choose from. + probabilities: Probability of each transform being picked. If ``None`` + (default), all transforms have equal probability. If provided, + probabilities will be normalized to sum to 1. + num_transforms: Maximum number of transforms to apply at once. Defaults to ``1``. - fixed_num_transforms (bool): If ``True``, always applies exactly ``num_transforms`` transforms. - If ``False``, randomly picks between 1 and ``num_transforms``. - Defaults to ``False``. + fixed_num_transforms: If ``True``, always applies exactly + ``num_transforms`` transforms. If ``False``, randomly picks between + 1 and ``num_transforms``. Defaults to ``False``. + + Raises: + TypeError: If ``transforms`` is not a sequence of callables. + ValueError: If length of ``probabilities`` does not match length of + ``transforms``. - Examples: + Example: >>> import torchvision.transforms.v2 as v2 >>> transforms = [ ... v2.RandomHorizontalFlip(p=1.0), @@ -34,7 +59,6 @@ class MultiRandomChoice(v2.Transform): ... ] >>> # Apply 1-2 random transforms with equal probability >>> transform = MultiRandomChoice(transforms, num_transforms=2) - >>> # Always apply exactly 2 transforms with custom probabilities >>> transform = MultiRandomChoice( ... transforms, @@ -71,7 +95,14 @@ def __init__( self.fixed_num_transforms = fixed_num_transforms def forward(self, *inputs: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, ...]: - """Apply randomly selected transforms to the input.""" + """Apply randomly selected transforms to the input. + + Args: + *inputs: Input tensors to transform. + + Returns: + Transformed tensor(s). 
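The ``MultiRandomChoice`` docstring examples above construct the transform but stop short of applying it. For completeness, a minimal application sketch, with an arbitrary input shape:

import torch
import torchvision.transforms.v2 as v2

from anomalib.data.transforms.multi_random_choice import MultiRandomChoice

transform = MultiRandomChoice(
    transforms=[v2.RandomHorizontalFlip(p=1.0), v2.ColorJitter(brightness=0.5)],
    num_transforms=2,
)
image = torch.rand(3, 224, 224)
augmented = transform(image)  # applies 1 or 2 randomly chosen transforms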
+ """ # First determine number of transforms to apply num_transforms = ( self.num_transforms if self.fixed_num_transforms else int(torch.randint(self.num_transforms, (1,)) + 1) diff --git a/src/anomalib/data/utils/__init__.py b/src/anomalib/data/utils/__init__.py index 570c45af4a..d70b9721f7 100644 --- a/src/anomalib/data/utils/__init__.py +++ b/src/anomalib/data/utils/__init__.py @@ -1,4 +1,23 @@ -"""Helper utilities for data.""" +"""Helper utilities for data. + +This module provides various utility functions for data handling in Anomalib. + +The utilities are organized into several categories: + +- Image handling: Functions for reading, writing and processing images +- Box handling: Functions for converting between masks and bounding boxes +- Path handling: Functions for validating and resolving file paths +- Dataset splitting: Functions for splitting datasets into train/val/test +- Data generation: Functions for generating synthetic data like Perlin noise +- Download utilities: Functions for downloading and extracting datasets + +Example: + >>> from anomalib.data.utils import read_image, generate_perlin_noise + >>> # Read an image + >>> image = read_image("path/to/image.jpg") + >>> # Generate Perlin noise + >>> noise = generate_perlin_noise(256, 256) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/data/utils/boxes.py b/src/anomalib/data/utils/boxes.py index ade9563e55..a2bed89342 100644 --- a/src/anomalib/data/utils/boxes.py +++ b/src/anomalib/data/utils/boxes.py @@ -1,4 +1,8 @@ -"""Helper functions for processing bounding box detections and annotations.""" +"""Helper functions for processing bounding box detections and annotations. + +This module provides utility functions for converting between different bounding box +formats and handling bounding box operations. +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -12,21 +16,37 @@ def masks_to_boxes( masks: torch.Tensor, anomaly_maps: torch.Tensor | None = None, ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: - """Convert a batch of segmentation masks to bounding box coordinates. + """Convert batch of segmentation masks to bounding box coordinates. Args: - masks (torch.Tensor): Input tensor of shape (B, 1, H, W), (B, H, W) or (H, W) - anomaly_maps (Tensor | None, optional): Anomaly maps of shape (B, 1, H, W), (B, H, W) or (H, W) which are - used to determine an anomaly score for the converted bounding boxes. + masks: Input tensor of masks. Can be one of: + - shape ``(B, 1, H, W)`` + - shape ``(B, H, W)`` + - shape ``(H, W)`` + anomaly_maps: Optional anomaly maps. Can be one of: + - shape ``(B, 1, H, W)`` + - shape ``(B, H, W)`` + - shape ``(H, W)`` + Used to determine anomaly scores for converted bounding boxes. Returns: - list[torch.Tensor]: A list of length B where each element is a tensor of shape (N, 4) - containing the bounding box coordinates of the objects in the masks in xyxy format. - list[torch.Tensor]: A list of length B where each element is a tensor of length (N) - containing an anomaly score for each of the converted boxes. 
+ Tuple containing: + - List of length ``B`` where each element is tensor of shape ``(N, 4)`` + containing bounding box coordinates in ``xyxy`` format + - List of length ``B`` where each element is tensor of length ``N`` + containing anomaly scores for each converted box + + Examples: + >>> import torch + >>> masks = torch.zeros((2, 1, 32, 32)) + >>> masks[0, 0, 10:20, 15:25] = 1 # Add box in first image + >>> boxes, scores = masks_to_boxes(masks) + >>> boxes[0] # Coordinates for first image + tensor([[15., 10., 24., 19.]]) """ height, width = masks.shape[-2:] - masks = masks.view((-1, 1, height, width)).float() # reshape to (B, 1, H, W) and cast to float + # reshape to (B, 1, H, W) and cast to float + masks = masks.view((-1, 1, height, width)).float() if anomaly_maps is not None: anomaly_maps = anomaly_maps.view((-1,) + masks.shape[-2:]) @@ -57,16 +77,22 @@ def masks_to_boxes( def boxes_to_masks(boxes: list[torch.Tensor], image_size: tuple[int, int]) -> torch.Tensor: - """Convert bounding boxes to segmentations masks. + """Convert bounding boxes to segmentation masks. Args: - boxes (list[torch.Tensor]): A list of length B where each element is a tensor of shape (N, 4) - containing the bounding box coordinates of the regions of interest in xyxy format. - image_size (tuple[int, int]): Image size of the output masks in (H, W) format. + boxes: List of length ``B`` where each element is tensor of shape ``(N, 4)`` + containing bounding box coordinates in ``xyxy`` format + image_size: Output mask size as ``(H, W)`` Returns: - Tensor: torch.Tensor of shape (B, H, W) in which each slice is a binary mask showing the pixels contained by a - bounding box. + Binary masks of shape ``(B, H, W)`` where pixels contained within boxes + are set to 1 + + Examples: + >>> boxes = [torch.tensor([[10, 15, 20, 25]])] # One box in first image + >>> masks = boxes_to_masks(boxes, (32, 32)) + >>> masks.shape + torch.Size([1, 32, 32]) """ masks = torch.zeros((len(boxes), *image_size)).to(boxes[0].device) for im_idx, im_boxes in enumerate(boxes): @@ -77,19 +103,25 @@ def boxes_to_masks(boxes: list[torch.Tensor], image_size: tuple[int, int]) -> to def boxes_to_anomaly_maps(boxes: torch.Tensor, scores: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor: - """Convert bounding box coordinates to anomaly heatmaps. + """Convert bounding boxes and scores to anomaly heatmaps. Args: - boxes (list[torch.Tensor]): A list of length B where each element is a tensor of shape (N, 4) - containing the bounding box coordinates of the regions of interest in xyxy format. - scores (list[torch.Tensor]): A list of length B where each element is a 1D tensor of length N - containing the anomaly scores for each region of interest. - image_size (tuple[int, int]): Image size of the output masks in (H, W) format. + boxes: List of length ``B`` where each element is tensor of shape ``(N, 4)`` + containing bounding box coordinates in ``xyxy`` format + scores: List of length ``B`` where each element is 1D tensor of length ``N`` + containing anomaly scores for each box + image_size: Output heatmap size as ``(H, W)`` Returns: - Tensor: torch.Tensor of shape (B, H, W). The pixel locations within each bounding box are collectively - assigned the anomaly score of the bounding box. In the case of overlapping bounding boxes, - the highest score is used. + Anomaly heatmaps of shape ``(B, H, W)``. Pixels within each box are set to + that box's anomaly score. For overlapping boxes, the highest score is used. 
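To make the relationship between the two conversions above concrete, here is a round-trip sketch under the inclusive-coordinate convention shown in the examples; the explicit int cast is a defensive assumption, since ``masks_to_boxes`` returns float coordinates.

import torch

from anomalib.data.utils.boxes import boxes_to_masks, masks_to_boxes

masks = torch.zeros((1, 1, 32, 32))
masks[0, 0, 10:20, 15:25] = 1
boxes, _ = masks_to_boxes(masks)  # [[15., 10., 24., 19.]]
recovered = boxes_to_masks([b.int() for b in boxes], (32, 32))  # integer coords
print(recovered.shape)              # torch.Size([1, 32, 32])
print(recovered[0, 15, 20].item())  # 1.0 -- pixel inside the original region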
+ + Examples: + >>> boxes = [torch.tensor([[10, 15, 20, 25]])] # One box + >>> scores = [torch.tensor([0.9])] # Score for the box + >>> maps = boxes_to_anomaly_maps(boxes, scores, (32, 32)) + >>> maps[0, 20, 15] # Point inside box + tensor(0.9000) """ anomaly_maps = torch.zeros((len(boxes), *image_size)).to(boxes[0].device) for im_idx, (im_boxes, im_scores) in enumerate(zip(boxes, scores, strict=False)): @@ -102,15 +134,21 @@ def boxes_to_anomaly_maps(boxes: torch.Tensor, scores: torch.Tensor, image_size: def scale_boxes(boxes: torch.Tensor, image_size: torch.Size, new_size: torch.Size) -> torch.Tensor: - """Scale bbox coordinates to a new image size. + """Scale bounding box coordinates to a new image size. Args: - boxes (torch.Tensor): Boxes of shape (N, 4) - (x1, y1, x2, y2). - image_size (Size): Size of the original image in which the bbox coordinates were retrieved. - new_size (Size): New image size to which the bbox coordinates will be scaled. + boxes: Boxes of shape ``(N, 4)`` in ``(x1, y1, x2, y2)`` format + image_size: Original image size the boxes were computed for + new_size: Target image size to scale boxes to Returns: - Tensor: Updated boxes of shape (N, 4) - (x1, y1, x2, y2). + Scaled boxes of shape ``(N, 4)`` in ``(x1, y1, x2, y2)`` format + + Examples: + >>> boxes = torch.tensor([[10, 15, 20, 25]]) + >>> scaled = scale_boxes(boxes, (32, 32), (64, 64)) + >>> scaled + tensor([[20., 30., 40., 50.]]) """ scale = torch.Tensor([*new_size]) / torch.Tensor([*image_size]) return boxes * scale.repeat(2).to(boxes.device) diff --git a/src/anomalib/data/utils/download.py b/src/anomalib/data/utils/download.py index 7df5da1403..698b2d11bf 100644 --- a/src/anomalib/data/utils/download.py +++ b/src/anomalib/data/utils/download.py @@ -1,4 +1,10 @@ -"""Helper to show progress bars with `urlretrieve`, check hash of file.""" +"""Helper functions for downloading datasets with progress bars and hash verification. + +This module provides utilities for: +- Showing progress bars during downloads with ``urlretrieve`` +- Verifying file hashes +- Safely extracting compressed files +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -23,7 +29,14 @@ @dataclass class DownloadInfo: - """Info needed to download a dataset from a url.""" + """Information needed to download a dataset from a URL. + + Args: + name: Name of the dataset + url: URL to download the dataset from + hashsum: Expected hash value of the downloaded file + filename: Optional filename to save as. If not provided, extracts from URL + """ name: str url: str @@ -32,98 +45,45 @@ class DownloadInfo: class DownloadProgressBar(tqdm): - """Create progress bar for urlretrieve. Subclasses `tqdm`. - - For information about the parameters in constructor, refer to `tqdm`'s documentation. - - Args: - iterable (Iterable | None): Iterable to decorate with a progressbar. - Leave blank to manually manage the updates. - desc (str | None): Prefix for the progressbar. - total (int | float | None): The number of expected iterations. If unspecified, - len(iterable) is used if possible. If float("inf") or as a last - resort, only basic progress statistics are displayed - (no ETA, no progressbar). - If `gui` is True and this parameter needs subsequent updating, - specify an initial arbitrary large positive number, - e.g. 9e9. - leave (bool | None): upon termination of iteration. If `None`, will leave only if `position` is `0`. 
- file (io.TextIOWrapper | io.StringIO | None): Specifies where to output the progress messages - (default: sys.stderr). Uses `file.write(str)` and - `file.flush()` methods. For encoding, see - `write_bytes`. - ncols (int | None): The width of the entire output message. If specified, - dynamically resizes the progressbar to stay within this bound. - If unspecified, attempts to use environment width. The - fallback is a meter width of 10 and no limit for the counter and - statistics. If 0, will not print any meter (only stats). - mininterval (float | None): Minimum progress display update interval [default: 0.1] seconds. - maxinterval (float | None): Maximum progress display update interval [default: 10] seconds. - Automatically adjusts `miniters` to correspond to `mininterval` - after long display update lag. Only works if `dynamic_miniters` - or monitor thread is enabled. - miniters (int | float | None): Minimum progress display update interval, in iterations. - If 0 and `dynamic_miniters`, will automatically adjust to equal - `mininterval` (more CPU efficient, good for tight loops). - If > 0, will skip display of specified number of iterations. - Tweak this and `mininterval` to get very efficient loops. - If your progress is erratic with both fast and slow iterations - (network, skipping items, etc) you should set miniters=1. - use_ascii (str | bool | None): If unspecified or False, use unicode (smooth blocks) to fill - the meter. The fallback is to use ASCII characters " 123456789#". - disable (bool | None): Whether to disable the entire progressbar wrapper - [default: False]. If set to None, disable on non-TTY. - unit (str | None): String that will be used to define the unit of each iteration - [default: it]. - unit_scale (int | float | bool): If 1 or True, the number of iterations will be reduced/scaled - automatically and a metric prefix following the - International System of Units standard will be added - (kilo, mega, etc.) [default: False]. If any other non-zero - number, will scale `total` and `n`. - dynamic_ncols (bool | None): If set, constantly alters `ncols` and `nrows` to the - environment (allowing for window resizes) [default: False]. - smoothing (float | None): Exponential moving average smoothing factor for speed estimates - (ignored in GUI mode). Ranges from 0 (average speed) to 1 - (current/instantaneous speed) [default: 0.3]. - bar_format (str | None): Specify a custom bar string formatting. May impact performance. - [default: '{l_bar}{bar}{r_bar}'], where - l_bar='{desc}: {percentage:3.0f}%|' and - r_bar='| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, ' - '{rate_fmt}{postfix}]' - Possible vars: l_bar, bar, r_bar, n, n_fmt, total, total_fmt, - percentage, elapsed, elapsed_s, ncols, nrows, desc, unit, - rate, rate_fmt, rate_noinv, rate_noinv_fmt, - rate_inv, rate_inv_fmt, postfix, unit_divisor, - remaining, remaining_s, eta. - Note that a trailing ": " is automatically removed after {desc} - if the latter is empty. - initial (int | float | None): The initial counter value. Useful when restarting a progress - bar [default: 0]. If using float, consider specifying `{n:.3f}` - or similar in `bar_format`, or specifying `unit_scale`. - position (int | None): Specify the line offset to print this bar (starting from 0) - Automatic if unspecified. - Useful to manage multiple bars at once (eg, from threads). - postfix (dict | None): Specify additional stats to display at the end of the bar. - Calls `set_postfix(**postfix)` if possible (dict). 
- unit_divisor (float | None): [default: 1000], ignored unless `unit_scale` is True. - write_bytes (bool | None): If (default: None) and `file` is unspecified, - bytes will be written in Python 2. If `True` will also write - bytes. In all other cases will default to unicode. - lock_args (tuple | None): Passed to `refresh` for intermediate output - (initialisation, iterating, and updating). - nrows (int | None): The screen height. If specified, hides nested bars - outside this bound. If unspecified, attempts to use environment height. - The fallback is 20. - colour (str | None): Bar colour (e.g. 'green', '#00ff00'). - delay (float | None): Don't display until [default: 0] seconds have elapsed. - gui (bool | None): WARNING: internal parameter - do not use. - Use tqdm.gui.tqdm(...) instead. If set, will attempt to use - matplotlib animations for a graphical output [default: False]. + """Progress bar for ``urlretrieve`` downloads. + Subclasses ``tqdm`` to provide a progress bar during file downloads. Example: - >>> with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1]) as p_bar: - >>> urllib.request.urlretrieve(url, filename=output_path, reporthook=p_bar.update_to) + >>> url = "https://example.com/file.zip" + >>> output_path = "file.zip" + >>> with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, + ... desc=url.split('/')[-1]) as p_bar: + ... urlretrieve(url, filename=output_path, + ... reporthook=p_bar.update_to) + + Args: + iterable: Iterable to decorate with a progressbar + desc: Prefix for the progressbar + total: Expected number of iterations + leave: Whether to leave the progress bar after completion + file: Output stream for progress messages + ncols: Width of the progress bar + mininterval: Minimum update interval in seconds + maxinterval: Maximum update interval in seconds + miniters: Minimum progress display update interval in iterations + use_ascii: Whether to use ASCII characters for the progress bar + disable: Whether to disable the progress bar + unit: Unit of measurement + unit_scale: Whether to scale units automatically + dynamic_ncols: Whether to adapt to terminal resizes + smoothing: Exponential moving average smoothing factor + bar_format: Custom progress bar format string + initial: Initial counter value + position: Line offset for printing + postfix: Additional stats to display + unit_divisor: Unit divisor for scaling + write_bytes: Whether to write bytes + lock_args: Arguments passed to refresh + nrows: Screen height + colour: Bar color + delay: Display delay in seconds + gui: Whether to use matplotlib animations """ def __init__( @@ -187,17 +147,21 @@ def __init__( ) self.total: int | float | None - def update_to(self, chunk_number: int = 1, max_chunk_size: int = 1, total_size: int | None = None) -> None: - """Progress bar hook for tqdm. + def update_to( + self, + chunk_number: int = 1, + max_chunk_size: int = 1, + total_size: int | None = None, + ) -> None: + """Update progress bar based on download progress. - Based on https://stackoverflow.com/a/53877507 - The implementor does not have to bother about passing parameters to this as it gets them from urlretrieve. - However the context needs a few parameters. Refer to the example. + This method is used as a callback for ``urlretrieve`` to update the + progress bar during downloads. Args: - chunk_number (int, optional): The current chunk being processed. Defaults to 1. - max_chunk_size (int, optional): Maximum size of each chunk. Defaults to 1. 
- total_size (int, optional): Total download size. Defaults to None. + chunk_number: Current chunk being processed + max_chunk_size: Maximum size of each chunk + total_size: Total download size """ if total_size is not None: self.total = total_size @@ -205,14 +169,13 @@ def update_to(self, chunk_number: int = 1, max_chunk_size: int = 1, total_size: def is_file_potentially_dangerous(file_name: str) -> bool: - """Check if a file is potentially dangerous. + """Check if a file path contains potentially dangerous patterns. Args: - file_name (str): Filename. + file_name: Path to check Returns: - bool: True if the member is potentially dangerous, False otherwise. - + ``True`` if the path matches unsafe patterns, ``False`` otherwise """ # Some example criteria. We could expand this. unsafe_patterns = ["/etc/", "/root/"] @@ -220,13 +183,12 @@ def is_file_potentially_dangerous(file_name: str) -> bool: def safe_extract(tar_file: TarFile, root: Path, members: list[TarInfo]) -> None: - """Extract safe members from a tar archive. + """Safely extract members from a tar archive. Args: - tar_file (TarFile): TarFile object. - root (Path): Root directory where the dataset will be stored. - members (List[TarInfo]): List of safe members to be extracted. - + tar_file: TarFile object to extract from + root: Root directory for extraction + members: List of safe members to extract """ for member in members: # check if the file already exists @@ -238,14 +200,14 @@ def generate_hash(file_path: str | Path, algorithm: str = "sha256") -> str: """Generate a hash of a file using the specified algorithm. Args: - file_path (str | Path): Path to the file to hash. - algorithm (str): The hashing algorithm to use (e.g., 'sha256', 'sha3_512'). + file_path: Path to the file to hash + algorithm: Hashing algorithm to use (e.g. 'sha256', 'sha3_512') Returns: - str: The hexadecimal hash string of the file. + Hexadecimal hash string of the file Raises: - ValueError: If the specified hashing algorithm is not supported. + ValueError: If the specified hashing algorithm is not supported """ # Get the hashing algorithm. try: @@ -264,30 +226,37 @@ def generate_hash(file_path: str | Path, algorithm: str = "sha256") -> str: def check_hash(file_path: Path, expected_hash: str, algorithm: str = "sha256") -> None: - """Raise value error if hash does not match the calculated hash of the file. + """Verify that a file's hash matches the expected value. Args: - file_path (Path): Path to file. - expected_hash (str): Expected hash of the file. - algorithm (str): Hashing algorithm to use ('sha256', 'sha3_512', etc.). + file_path: Path to file to check + expected_hash: Expected hash value + algorithm: Hashing algorithm to use + + Raises: + ValueError: If the calculated hash does not match the expected hash """ # Compare the calculated hash with the expected hash calculated_hash = generate_hash(file_path, algorithm) if calculated_hash != expected_hash: msg = ( - f"Calculated hash {calculated_hash} of downloaded file {file_path} does not match the required hash " - f"{expected_hash}." + f"Calculated hash {calculated_hash} of downloaded file {file_path} " + f"does not match the required hash {expected_hash}." ) raise ValueError(msg) def extract(file_name: Path, root: Path) -> None: - """Extract a dataset. + """Extract a compressed dataset file. + + Supports .zip, .tar, .gz, .xz and .tgz formats. Args: - file_name (Path): Path of the file to be extracted. - root (Path): Root directory where the dataset will be stored. 
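Tying the pieces of this module together, a sketch of a dataset entry and a single call; the URL and hash are placeholders, and ``download_and_extract`` (documented in the next hunk) verifies the hash before extracting.

from pathlib import Path

from anomalib.data.utils.download import DownloadInfo, download_and_extract

info = DownloadInfo(
    name="toy-dataset",
    url="https://example.com/toy-dataset.tar.gz",  # placeholder URL
    hashsum="...",  # expected sha256 of the archive, elided here
)
download_and_extract(Path("./datasets/toy-dataset"), info)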
+ file_name: Path of the file to extract
+ root: Root directory for extraction
+
+ Raises:
+ ValueError: If the file format is not recognized
 """
 logger.info(f"Extracting dataset into {root} folder.")
@@ -317,12 +286,15 @@ def download_and_extract(root: Path, info: DownloadInfo) -> None:
 """Download and extract a dataset.

 Args:
- root (Path): Root directory where the dataset will be stored.
- info (DownloadInfo): Info needed to download the dataset.
+ root: Root directory where the dataset will be stored
+ info: Download information for the dataset
+
+ Raises:
+ RuntimeError: If the URL scheme is not http(s)
 """
 root.mkdir(parents=True, exist_ok=True)

- # save the compressed file in the specified root directory, using the same file name as on the server
+ # save the compressed file in the specified root directory
 downloaded_file_path = root / info.filename if info.filename else root / info.url.split("/")[-1]

 if downloaded_file_path.exists():
@@ -350,16 +322,17 @@ def is_within_directory(directory: Path, target: Path) -> bool:
 """Check if a target path is located within a given directory.

 Args:
- directory (Path): path of the parent directory
- target (Path): path of the target
+ directory: Path of the parent directory
+ target: Path to check

 Returns:
- (bool): True if the target is within the directory, False otherwise
+ ``True`` if target is within directory, ``False`` otherwise
 """
 abs_directory = directory.resolve()
 abs_target = target.resolve()

- # TODO(djdameln): Replace with pathlib is_relative_to after switching to Python 3.10
+ # TODO(djdameln): Replace with pathlib is_relative_to after switching to
+ # Python 3.10
 # CVS-122655
 prefix = os.path.commonprefix([abs_directory, abs_target])
 return prefix == str(abs_directory)
diff --git a/src/anomalib/data/utils/generators/__init__.py b/src/anomalib/data/utils/generators/__init__.py
index c46f30d08e..c9c0410c03 100644
--- a/src/anomalib/data/utils/generators/__init__.py
+++ b/src/anomalib/data/utils/generators/__init__.py
@@ -1,6 +1,26 @@
-"""Utilities to generate synthetic data."""
+"""Utilities to generate synthetic data.

-# Copyright (C) 2022 Intel Corporation
+This module provides utilities for generating synthetic data for anomaly detection.
+The utilities include:
+
+- Perlin noise generation: Functions for creating Perlin noise patterns
+- Anomaly generation: Classes for generating synthetic anomalies
+
+Example:
+ >>> from anomalib.data.utils.generators import generate_perlin_noise
+ >>> # Generate 256x256 Perlin noise
+ >>> noise = generate_perlin_noise(256, 256)
+ >>> print(noise.shape)
+ torch.Size([256, 256])
+
+ >>> import torch
+ >>> from anomalib.data.utils.generators import PerlinAnomalyGenerator
+ >>> # Create anomaly generator
+ >>> generator = PerlinAnomalyGenerator()
+ >>> # Apply it to an image to obtain the augmented image and anomaly mask
+ >>> image = torch.randn(3, 256, 256)
+ >>> augmented, mask = generator(image)
+"""
+
+# Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

 from .perlin import PerlinAnomalyGenerator, generate_perlin_noise
diff --git a/src/anomalib/data/utils/generators/perlin.py b/src/anomalib/data/utils/generators/perlin.py
index 052d565121..acdbcb56ef 100644
--- a/src/anomalib/data/utils/generators/perlin.py
+++ b/src/anomalib/data/utils/generators/perlin.py
@@ -1,4 +1,26 @@
-"""Perlin noise-based synthetic anomaly generator."""
+"""Perlin noise-based synthetic anomaly generator.
+
+This module provides functionality to generate synthetic anomalies using Perlin noise
+patterns. 
The generator can create both noise-based and image-based anomalies with +configurable parameters. + +Example: + >>> from anomalib.data.utils.generators.perlin import generate_perlin_noise + >>> import torch + >>> # Generate 256x256 noise with default random scale + >>> noise = generate_perlin_noise(256, 256) + >>> print(noise.shape) + torch.Size([256, 256]) + + >>> # Generate 512x512 noise with fixed scale + >>> noise = generate_perlin_noise(512, 512, scale=(8, 8)) + >>> print(noise.shape) + torch.Size([512, 512]) + + >>> # Generate noise on GPU if available + >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + >>> noise = generate_perlin_noise(128, 128, device=device) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -21,21 +43,25 @@ def generate_perlin_noise( ) -> torch.Tensor: """Generate a Perlin noise pattern. - This function generates a Perlin noise pattern using a grid-based gradient noise approach. - The noise is generated by interpolating between randomly generated gradient vectors at grid vertices. - The interpolation uses a quintic curve for smooth transitions. + This function generates a Perlin noise pattern using a grid-based gradient noise + approach. The noise is generated by interpolating between randomly generated + gradient vectors at grid vertices. The interpolation uses a quintic curve for + smooth transitions. Args: - height: Desired height of the noise pattern - width: Desired width of the noise pattern - scale: Tuple of (scale_x, scale_y) for noise granularity. If None, random scales will be used. - Larger scales produce coarser noise patterns, while smaller scales produce finer patterns. - device: Device to generate the noise on. If None, uses current default device + height: Desired height of the noise pattern. + width: Desired width of the noise pattern. + scale: Tuple of ``(scale_x, scale_y)`` for noise granularity. If ``None``, + random scales will be used. Larger scales produce coarser noise patterns, + while smaller scales produce finer patterns. + device: Device to generate the noise on. If ``None``, uses current default + device. Returns: - Tensor of shape [height, width] containing the noise pattern, with values roughly in [-1, 1] range + torch.Tensor: Tensor of shape ``[height, width]`` containing the noise + pattern, with values roughly in ``[-1, 1]`` range. - Examples: + Example: >>> # Generate 256x256 noise with default random scale >>> noise = generate_perlin_noise(256, 256) >>> print(noise.shape) @@ -128,7 +154,23 @@ def fade(t: torch.Tensor) -> torch.Tensor: class PerlinAnomalyGenerator(v2.Transform): """Perlin noise-based synthetic anomaly generator. - Examples: + This class provides functionality to generate synthetic anomalies using Perlin + noise patterns. It can also use real anomaly source images for more realistic + anomaly generation. + + Args: + anomaly_source_path: Optional path to directory containing anomaly source + images. If provided, these images will be used instead of Perlin noise + patterns. + probability: Probability of applying the anomaly transformation to an image. + Default: ``0.5``. + blend_factor: Factor determining how much of the anomaly to blend with the + original image. Can be a float or a tuple of ``(min, max)``. Default: + ``(0.2, 1.0)``. + rotation_range: Range of rotation angles in degrees for the Perlin noise + pattern. Default: ``(-90, 90)``. 
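One common use of the raw noise, sketched here for orientation: thresholding it into a binary anomaly mask. Only ``generate_perlin_noise`` comes from this diff; the 0.5 threshold is illustrative.

from anomalib.data.utils.generators import generate_perlin_noise

noise = generate_perlin_noise(256, 256)  # values roughly in [-1, 1]
mask = (noise > 0.5).float()             # illustrative threshold
print(mask.shape, mask.mean())           # mask size and anomalous-pixel fraction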
+
+ Example:
 >>> # Single image usage with default parameters
 >>> transform = PerlinAnomalyGenerator()
 >>> image = torch.randn(3, 256, 256) # [C, H, W]
@@ -215,13 +257,15 @@ def generate_perturbation(
 """Generate perturbed image and mask.

 Args:
- height: Height of the output image
- width: Width of the output image
- device: Device to generate the perturbation on
- anomaly_source_path: Optional path to source image for anomaly
+ height: Height of the output image.
+ width: Width of the output image.
+ device: Device to generate the perturbation on.
+ anomaly_source_path: Optional path to source image for anomaly.

 Returns:
- tuple[torch.Tensor, torch.Tensor]: Perturbation and mask tensors
+ tuple[torch.Tensor, torch.Tensor]: Tuple containing:
+ - Perturbation tensor of shape ``[H, W, C]``
+ - Mask tensor of shape ``[H, W, 1]``
 """
 # Generate perlin noise
 perlin_noise = generate_perlin_noise(height, width, device=device)
@@ -265,7 +309,19 @@ def _transform_image(
 w: int,
 device: torch.device,
 ) -> tuple[torch.Tensor, torch.Tensor]:
- """Transform a single image."""
+ """Transform a single image.
+
+ Args:
+ img: Input image tensor of shape ``[C, H, W]``.
+ h: Height of the image.
+ w: Width of the image.
+ device: Device to perform the transformation on.
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: Tuple containing:
+ - Augmented image tensor of shape ``[C, H, W]``
+ - Mask tensor of shape ``[1, H, W]``
+ """
 if torch.rand(1, device=device) > self.probability:
 return img, torch.zeros((1, h, w), device=device)

@@ -295,7 +351,17 @@ def _transform_image(
 return augmented_img, mask

 def forward(self, img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
- """Apply augmentation using the mask for single image or batch."""
+ """Apply augmentation using the mask for single image or batch.
+
+ Args:
+ img: Input image tensor of shape ``[C, H, W]`` or batch tensor of shape
+ ``[B, C, H, W]``.
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: Tuple containing:
+ - Augmented image tensor of same shape as input
+ - Mask tensor of shape ``[1, H, W]`` or ``[B, 1, H, W]``
+ """
 device = img.device
 is_batch = len(img.shape) == 4
diff --git a/src/anomalib/data/utils/image.py b/src/anomalib/data/utils/image.py
index 64a27724cc..0f0ed0c255 100644
--- a/src/anomalib/data/utils/image.py
+++ b/src/anomalib/data/utils/image.py
@@ -1,4 +1,25 @@
-"""Image Utils."""
+"""Image utilities for reading, writing and processing images.
+
+This module provides various utility functions for handling images in Anomalib:
+
+- Reading images in various formats (RGB, grayscale, depth)
+- Writing images to disk
+- Converting between different image formats
+- Processing images (padding, resizing etc.)
+- Handling image filenames and paths
+
+Example:
+ >>> from anomalib.data.utils import read_image
+ >>> # Read image as numpy array
+ >>> image = read_image("image.jpg")
+ >>> print(type(image))
+ <class 'numpy.ndarray'>
+
+ >>> # Read image as tensor
+ >>> image = read_image("image.jpg", as_tensor=True)
+ >>> print(type(image))
+ <class 'torch.Tensor'>
+"""

 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -25,25 +46,22 @@

 def is_image_file(filename: str | Path) -> bool:
- """Check if the filename is an image file.
+ """Check if the filename has a valid image extension.

 Args:
- filename (str | Path): Filename to check.
+ filename (str | Path): Path to file to check

 Returns:
- bool: True if the filename is an image file. 
+ bool: ``True`` if filename has valid image extension Examples: - >>> is_image_file("000.png") + >>> is_image_file("image.jpg") True - >>> is_image_file("002.JPEG") + >>> is_image_file("image.png") True - >>> is_image_file("009.tiff") - True - - >>> is_image_file("002.avi") + >>> is_image_file("image.txt") False """ filename = Path(filename) @@ -51,39 +69,31 @@ def is_image_file(filename: str | Path) -> bool: def get_image_filename(filename: str | Path) -> Path: - """Get image filename. + """Get validated image filename. Args: - filename (str | Path): Filename to check. + filename (str | Path): Path to image file Returns: - Path: Image filename. - - Examples: - Assume that we have the following files in the directory: - - .. code-block:: bash + Path: Validated path to image file - $ ls - 000.png 001.jpg 002.JPEG 003.tiff 004.png 005.txt - - >>> get_image_filename("000.png") - PosixPath('000.png') + Raises: + FileNotFoundError: If file does not exist + ValueError: If file is not an image - >>> get_image_filename("001.jpg") - PosixPath('001.jpg') + Examples: + >>> get_image_filename("image.jpg") + PosixPath('image.jpg') - >>> get_image_filename("009.tiff") + >>> get_image_filename("missing.jpg") Traceback (most recent call last): - File "", line 1, in - File "", line 18, in get_image_filename - FileNotFoundError: File not found: 009.tiff + ... + FileNotFoundError: File not found: missing.jpg - >>> get_image_filename("005.txt") + >>> get_image_filename("text.txt") Traceback (most recent call last): - File "", line 1, in - File "", line 18, in get_image_filename - ValueError: ``filename`` is not an image file. 005.txt + ... + ValueError: ``filename`` is not an image file: text.txt """ filename = Path(filename) @@ -98,31 +108,25 @@ def get_image_filename(filename: str | Path) -> Path: def get_image_filenames_from_dir(path: str | Path) -> list[Path]: - """Get image filenames from directory. + """Get list of image filenames from directory. Args: - path (str | Path): Path to image directory. - - Raises: - ValueError: When ``path`` is not a directory. + path (str | Path): Path to directory containing images Returns: - list[Path]: Image filenames. + list[Path]: List of paths to image files - Examples: - Assume that we have the following files in the directory: - $ ls - 000.png 001.jpg 002.JPEG 003.tiff 004.png 005.png + Raises: + ValueError: If path is not a directory or no images found - >>> get_image_filenames_from_dir(".") - [PosixPath('000.png'), PosixPath('001.jpg'), PosixPath('002.JPEG'), - PosixPath('003.tiff'), PosixPath('004.png'), PosixPath('005.png')] + Examples: + >>> get_image_filenames_from_dir("images/") + [PosixPath('images/001.jpg'), PosixPath('images/002.png')] - >>> get_image_filenames_from_dir("009.tiff") + >>> get_image_filenames_from_dir("empty/") Traceback (most recent call last): - File "", line 1, in - File "", line 18, in get_image_filenames_from_dir - ValueError: ``path`` is not a directory: 009.tiff + ... + ValueError: Found 0 images in empty/ """ path = Path(path) if not path.is_dir(): @@ -139,50 +143,26 @@ def get_image_filenames_from_dir(path: str | Path) -> list[Path]: def get_image_filenames(path: str | Path, base_dir: str | Path | None = None) -> list[Path]: - """Get image filenames. + """Get list of image filenames from path. Args: - path (str | Path): Path to image or image-folder. - base_dir (Path): Base directory to restrict file access. 
+ path (str | Path): Path to image file or directory + base_dir (str | Path | None): Base directory to restrict file access Returns: - list[Path]: List of image filenames. + list[Path]: List of paths to image files Examples: - Assume that we have the following files in the directory: + >>> get_image_filenames("image.jpg") + [PosixPath('image.jpg')] - .. code-block:: bash + >>> get_image_filenames("images/") + [PosixPath('images/001.jpg'), PosixPath('images/002.png')] - $ tree images - images - ├── bad - │ ├── 003.png - │ └── 004.jpg - └── good - ├── 000.png - └── 001.tiff - - We can get the image filenames with various ways: - - >>> get_image_filenames("images/bad/003.png") - PosixPath('/home/sakcay/Projects/anomalib/images/bad/003.png')] - - It is possible to recursively get the image filenames from a directory: - - >>> get_image_filenames("images") - [PosixPath('/home/sakcay/Projects/anomalib/images/bad/003.png'), - PosixPath('/home/sakcay/Projects/anomalib/images/bad/004.jpg'), - PosixPath('/home/sakcay/Projects/anomalib/images/good/001.tiff'), - PosixPath('/home/sakcay/Projects/anomalib/images/good/000.png')] - - If we want to restrict the file access to a specific directory, - we can use ``base_dir`` argument. - - >>> get_image_filenames("images", base_dir="images/bad") + >>> get_image_filenames("images/", base_dir="allowed/") Traceback (most recent call last): - File "", line 1, in - File "", line 18, in get_image_filenames - ValueError: Access denied: Path is outside the allowed directory. + ... + ValueError: Access denied: Path is outside the allowed directory """ path = validate_path(path, base_dir) image_filenames: list[Path] = [] @@ -199,24 +179,23 @@ def get_image_filenames(path: str | Path, base_dir: str | Path | None = None) -> def duplicate_filename(path: str | Path) -> Path: - """Check and duplicate filename. - - This function checks the path and adds a suffix if it already exists on the file system. + """Add numeric suffix to filename if it already exists. Args: - path (str | Path): Input Path + path (str | Path): Path to file + + Returns: + Path: Path with numeric suffix if original exists Examples: - >>> path = Path("datasets/MVTec/bottle/test/broken_large/000.png") - >>> path.exists() - True + >>> duplicate_filename("image.jpg") # File doesn't exist + PosixPath('image.jpg') - If we pass this to ``duplicate_filename`` function we would get the following: - >>> duplicate_filename(path) - PosixPath('datasets/MVTec/bottle/test/broken_large/000_1.png') + >>> duplicate_filename("exists.jpg") # File exists + PosixPath('exists_1.jpg') - Returns: - Path: Duplicated output path. + >>> duplicate_filename("exists.jpg") # Both exist + PosixPath('exists_2.jpg') """ path = Path(path) @@ -234,54 +213,36 @@ def duplicate_filename(path: str | Path) -> Path: def generate_output_image_filename(input_path: str | Path, output_path: str | Path) -> Path: - """Generate an output filename to save the inference image. - - This function generates an output filaname by checking the input and output filenames. Input path is - the input to infer, and output path is the path to save the output predictions specified by the user. - - The function expects ``input_path`` to always be a file, not a directory. ``output_path`` could be a - filename or directory. If it is a filename, the function checks if the specified filename exists on - the file system. If yes, the function calls ``duplicate_filename`` to duplicate the filename to avoid - overwriting the existing file. 
If ``output_path`` is a directory, this function adds the parent and - filenames of ``input_path`` to ``output_path``. + """Generate output filename for inference image. Args: - input_path (str | Path): Path to the input image to infer. - output_path (str | Path): Path to output to save the predictions. - Could be a filename or a directory. + input_path (str | Path): Path to input image + output_path (str | Path): Path to save output (file or directory) - Examples: - >>> input_path = Path("datasets/MVTec/bottle/test/broken_large/000.png") - >>> output_path = Path("datasets/MVTec/bottle/test/broken_large/000.png") - >>> generate_output_image_filename(input_path, output_path) - PosixPath('datasets/MVTec/bottle/test/broken_large/000_1.png') - - >>> input_path = Path("datasets/MVTec/bottle/test/broken_large/000.png") - >>> output_path = Path("results/images") - >>> generate_output_image_filename(input_path, output_path) - PosixPath('results/images/broken_large/000.png') + Returns: + Path: Generated output filename Raises: - ValueError: When the ``input_path`` is not a file. + ValueError: If input_path is not a file - Returns: - Path: The output filename to save the output predictions from the inferencer. + Examples: + >>> generate_output_image_filename("input.jpg", "output.jpg") + PosixPath('output.jpg') # or output_1.jpg if exists + + >>> generate_output_image_filename("dir/input.jpg", "outdir") + PosixPath('outdir/dir/input.jpg') """ input_path = validate_path(input_path) output_path = validate_path(output_path, should_exist=False) - # Input validation: Check if input_path is a valid directory or file - if input_path.is_file() is False: - msg = "input_path is expected to be a file to generate a proper output filename." + if not input_path.is_file(): + msg = "input_path is expected to be a file" raise ValueError(msg) - # If the output is a directory, then add parent directory name - # and filename to the path. This is to ensure we do not overwrite - # images and organize based on the categories. if output_path.is_dir(): output_image_filename = output_path / input_path.parent.name / input_path.name elif output_path.is_file() and output_path.exists(): - msg = f"{output_path} already exists. Renaming the file to avoid overwriting." + msg = f"{output_path} already exists. Renaming to avoid overwriting." logger.warning(msg) output_image_filename = duplicate_filename(output_path) else: @@ -293,32 +254,28 @@ def generate_output_image_filename(input_path: str | Path, output_path: str | Pa def get_image_height_and_width(image_size: int | Sequence[int]) -> tuple[int, int]: - """Get image height and width from ``image_size`` variable. + """Get height and width from image size parameter. Args: - image_size (int | Sequence[int] | None, optional): Input image size. + image_size (int | Sequence[int]): Single int for square, or (H,W) sequence + + Returns: + tuple[int, int]: Image height and width Raises: - ValueError: Image size not None, int or Sequence of values. + TypeError: If image_size is not int or sequence of ints Examples: - >>> get_image_height_and_width(image_size=256) + >>> get_image_height_and_width(256) (256, 256) - >>> get_image_height_and_width(image_size=(256, 256)) - (256, 256) + >>> get_image_height_and_width((480, 640)) + (480, 640) - >>> get_image_height_and_width(image_size=(256, 256, 3)) - (256, 256) - - >>> get_image_height_and_width(image_size=256.) 
+ >>> get_image_height_and_width(256.0) Traceback (most recent call last): - File "", line 1, in - File "", line 18, in get_image_height_and_width - ValueError: ``image_size`` could be either int or tuple[int, int] - - Returns: - tuple[int | None, int | None]: A tuple containing image height and width values. + ... + TypeError: ``image_size`` could be either int or tuple[int, int] """ if isinstance(image_size, int): height_and_width = (image_size, image_size) @@ -332,41 +289,44 @@ def get_image_height_and_width(image_size: int | Sequence[int]) -> tuple[int, in def read_image(path: str | Path, as_tensor: bool = False) -> torch.Tensor | np.ndarray: - """Read image from disk in RGB format. + """Read RGB image from disk. Args: - path (str, Path): path to the image file - as_tensor (bool, optional): If True, returns the image as a tensor. Defaults to False. + path (str | Path): Path to image file + as_tensor (bool): If ``True``, return torch.Tensor. Defaults to ``False`` - Example: - >>> image = read_image("test_image.jpg") + Returns: + torch.Tensor | np.ndarray: Image as tensor or array, normalized to [0,1] + + Examples: + >>> image = read_image("image.jpg") >>> type(image) - >>> - >>> image = read_image("test_image.jpg", as_tensor=True) + + >>> image = read_image("image.jpg", as_tensor=True) >>> type(image) - - Returns: - image as numpy array """ image = Image.open(path).convert("RGB") return to_dtype(to_image(image), torch.float32, scale=True) if as_tensor else np.array(image) / 255.0 def read_mask(path: str | Path, as_tensor: bool = False) -> torch.Tensor | np.ndarray: - """Read mask from disk. + """Read grayscale mask from disk. Args: - path (str, Path): path to the mask file - as_tensor (bool, optional): If True, returns the mask as a tensor. Defaults to False. + path (str | Path): Path to mask file + as_tensor (bool): If ``True``, return torch.Tensor. Defaults to ``False`` + + Returns: + torch.Tensor | np.ndarray: Mask as tensor or array - Example: - >>> mask = read_mask("test_mask.png") + Examples: + >>> mask = read_mask("mask.png") >>> type(mask) - >>> - >>> mask = read_mask("test_mask.png", as_tensor=True) + + >>> mask = read_mask("mask.png", as_tensor=True) >>> type(mask) """ @@ -375,34 +335,40 @@ def read_mask(path: str | Path, as_tensor: bool = False) -> torch.Tensor | np.nd def read_depth_image(path: str | Path) -> np.ndarray: - """Read tiff depth image from disk. + """Read depth image from TIFF file. Args: - path (str, Path): path to the image file - - Example: - >>> image = read_depth_image("test_image.tiff") + path (str | Path): Path to TIFF depth image Returns: - image as numpy array + np.ndarray: Depth image array + + Examples: + >>> depth = read_depth_image("depth.tiff") + >>> type(depth) + """ path = path if isinstance(path, str) else str(path) return tiff.imread(path) def pad_nextpow2(batch: torch.Tensor) -> torch.Tensor: - """Compute required padding from input size and return padded images. + """Pad images to next power of 2 size. - Finds the largest dimension and computes a square image of dimensions that are of the power of 2. - In case the image dimension is odd, it returns the image with an extra padding on one side. + Finds largest dimension and pads to square power-of-2 size. Handles odd sizes. 
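The readers documented here share one return convention worth showing side by side; a small sketch with assumed exports and illustrative file names:

```python
import numpy as np
import torch

from anomalib.data.utils import read_image, read_mask  # assumed exports

# NumPy: float array in [0, 1], channel-last, shape (H, W, 3).
image_np: np.ndarray = read_image("image.jpg")

# Tensor: float32 in [0, 1], channel-first, shape (3, H, W).
image_t: torch.Tensor = read_image("image.jpg", as_tensor=True)

# Ground-truth masks follow the same dual return convention.
mask_t = read_mask("mask.png", as_tensor=True)
```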
Args: - batch (torch.Tensor): Input images + batch (torch.Tensor): Batch of images to pad Returns: - batch: Padded batch + torch.Tensor: Padded image batch + + Examples: + >>> x = torch.randn(1, 3, 127, 128) + >>> padded = pad_nextpow2(x) + >>> padded.shape + torch.Size([1, 3, 128, 128]) """ - # find the largest dimension l_dim = 2 ** math.ceil(math.log(max(*batch.shape[-2:]), 2)) padding_w = [math.ceil((l_dim - batch.shape[-2]) / 2), math.floor((l_dim - batch.shape[-2]) / 2)] padding_h = [math.ceil((l_dim - batch.shape[-1]) / 2), math.floor((l_dim - batch.shape[-1]) / 2)] @@ -410,11 +376,15 @@ def pad_nextpow2(batch: torch.Tensor) -> torch.Tensor: def show_image(image: np.ndarray | Figure, title: str = "Image") -> None: - """Show an image on the screen. + """Display image in window. Args: - image (np.ndarray | Figure): Image that will be shown in the window. - title (str, optional): Title that will be given to that window. Defaults to "Image". + image (np.ndarray | Figure): Image or matplotlib figure to display + title (str): Window title. Defaults to "Image" + + Examples: + >>> img = read_image("image.jpg") + >>> show_image(img, title="My Image") """ if isinstance(image, Figure): image = figure_to_array(image) @@ -425,13 +395,18 @@ def show_image(image: np.ndarray | Figure, title: str = "Image") -> None: def save_image(filename: Path | str, image: np.ndarray | Figure, root: Path | None = None) -> None: - """Save an image to the file system. + """Save image to disk. Args: - filename (Path | str): Path or filename to which the image will be saved. - image (np.ndarray | Figure): Image that will be saved to the file system. - root (Path, optional): Root directory to save the image. If provided, the top level directory of an absolute - filename will be overwritten. Defaults to None. + filename (Path | str): Output filename + image (np.ndarray | Figure): Image or matplotlib figure to save + root (Path | None): Optional root dir to save under. Defaults to None + + Examples: + >>> img = read_image("input.jpg") + >>> save_image("output.jpg", img) + + >>> save_image("subdir/output.jpg", img, root=Path("results")) """ if isinstance(image, Figure): image = figure_to_array(image) @@ -453,13 +428,21 @@ def save_image(filename: Path | str, image: np.ndarray | Figure, root: Path | No def figure_to_array(fig: Figure) -> np.ndarray: - """Convert a matplotlib figure to a numpy array. + """Convert matplotlib figure to numpy array. Args: - fig (Figure): Matplotlib figure. + fig (Figure): Matplotlib figure to convert Returns: - np.ndarray: Numpy array containing the image. + np.ndarray: RGB image array + + Examples: + >>> import matplotlib.pyplot as plt + >>> fig = plt.figure() + >>> plt.plot([1, 2, 3]) + >>> img = figure_to_array(fig) + >>> type(img) + """ fig.canvas.draw() # convert figure to np.ndarray for saving via visualizer diff --git a/src/anomalib/data/utils/label.py b/src/anomalib/data/utils/label.py index 28908c8169..ce12b8bfb2 100644 --- a/src/anomalib/data/utils/label.py +++ b/src/anomalib/data/utils/label.py @@ -1,4 +1,20 @@ -"""Label name enum class.""" +"""Label name enumeration class. + +This module defines an enumeration class for labeling data in anomaly detection tasks. 
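Since ``save_image`` accepts a matplotlib ``Figure`` directly and applies ``figure_to_array`` internally, a quick sketch of both entry points (import path assumed; not part of the patch):

```python
from pathlib import Path

import matplotlib.pyplot as plt

from anomalib.data.utils.image import figure_to_array, save_image  # assumed path

fig = plt.figure()
plt.plot([1, 2, 3])

# Explicit conversion, e.g. for further numpy post-processing.
array = figure_to_array(fig)  # RGB array of the rendered canvas

# Or save the figure directly, nesting it under a root directory.
save_image("plots/line.png", fig, root=Path("results"))
```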
+The labels are represented as integers, where: + +- ``NORMAL`` (0): Represents normal/good samples +- ``ABNORMAL`` (1): Represents anomalous/defective samples + +Example: + >>> from anomalib.data.utils.label import LabelName + >>> label = LabelName.NORMAL + >>> label.value + 0 + >>> label = LabelName.ABNORMAL + >>> label.value + 1 +""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -7,7 +23,16 @@ class LabelName(int, Enum): - """Name of label.""" + """Enumeration class for labeling data in anomaly detection. + + This class inherits from both ``int`` and ``Enum`` to create an integer-based + enumeration. This allows for easy comparison and conversion between label + names and their corresponding integer values. + + Attributes: + NORMAL (int): Label value 0, representing normal/good samples + ABNORMAL (int): Label value 1, representing anomalous/defective samples + """ NORMAL = 0 ABNORMAL = 1 diff --git a/src/anomalib/data/utils/path.py b/src/anomalib/data/utils/path.py index 7bc61b27fe..4889cab0ec 100644 --- a/src/anomalib/data/utils/path.py +++ b/src/anomalib/data/utils/path.py @@ -1,4 +1,23 @@ -"""Path Utils.""" +"""Path utilities for handling file paths in anomalib. + +This module provides utilities for: + +- Validating and resolving file paths +- Checking path length and character restrictions +- Converting between path types +- Handling file extensions +- Managing directory types for anomaly detection + +Example: + >>> from anomalib.data.utils.path import validate_path + >>> path = validate_path("./datasets/MVTec/bottle/train/good/000.png") + >>> print(path) + PosixPath('/abs/path/to/anomalib/datasets/MVTec/bottle/train/good/000.png') + + >>> from anomalib.data.utils.path import DirType + >>> print(DirType.NORMAL) + normal +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -12,7 +31,17 @@ class DirType(str, Enum): - """Dir type names.""" + """Directory type names for organizing anomaly detection datasets. + + Attributes: + NORMAL: Directory containing normal/good samples for training + ABNORMAL: Directory containing anomalous/defective samples + NORMAL_TEST: Directory containing normal test samples + NORMAL_DEPTH: Directory containing depth maps for normal samples + ABNORMAL_DEPTH: Directory containing depth maps for abnormal samples + NORMAL_TEST_DEPTH: Directory containing depth maps for normal test samples + MASK: Directory containing ground truth segmentation masks + """ NORMAL = "normal" ABNORMAL = "abnormal" @@ -24,13 +53,18 @@ class DirType(str, Enum): def _check_and_convert_path(path: str | Path) -> Path: - """Check an input path, and convert to Pathlib object. + """Check and convert input path to pathlib object. Args: - path (str | Path): Input path. + path: Input path as string or Path object Returns: - Path: Output path converted to pathlib object. + Path object of the input path + + Example: + >>> path = _check_and_convert_path("./datasets/example.png") + >>> isinstance(path, Path) + True """ if not isinstance(path, Path): path = Path(path) @@ -42,16 +76,25 @@ def _prepare_files_labels( path_type: str, extensions: tuple[str, ...] | None = None, ) -> tuple[list, list]: - """Return a list of filenames and list corresponding labels. + """Get lists of filenames and corresponding labels from a directory. Args: - path (str | Path): Path to the directory containing images. - path_type (str): Type of images in the provided path ("normal", "abnormal", "normal_test") - extensions (tuple[str, ...] 
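Because both enums in this patch subclass their primitive type, they compare transparently with raw values; a tiny sketch:

```python
from anomalib.data.utils.label import LabelName
from anomalib.data.utils.path import DirType

# int-backed: comparisons and casts against plain integers just work.
assert LabelName.ABNORMAL == 1
assert int(LabelName.NORMAL) == 0

# str-backed: equally transparent against plain strings, which keeps
# dataframe filtering and config parsing straightforward.
assert DirType.NORMAL == "normal"
assert DirType.MASK.value == "mask"
```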
| None, optional): Type of the image extensions to read from the - directory. + path: Path to directory containing images + path_type: Type of images ("normal", "abnormal", "normal_test") + extensions: Allowed file extensions. Defaults to ``IMG_EXTENSIONS`` Returns: - List, List: Filenames of the images provided in the paths, labels of the images provided in the paths + Tuple containing: + - List of image filenames + - List of corresponding labels + + Raises: + RuntimeError: If no valid images found or extensions don't start with dot + + Example: + >>> files, labels = _prepare_files_labels("./normal", "normal", (".png",)) + >>> len(files) == len(labels) + True """ path = _check_and_convert_path(path) if extensions is None: @@ -79,14 +122,19 @@ def _prepare_files_labels( def resolve_path(folder: str | Path, root: str | Path | None = None) -> Path: - """Combine root and folder and returns the absolute path. - - This allows users to pass either a root directory and relative paths, or absolute paths to each of the - image sources. This function makes sure that the samples dataframe always contains absolute paths. + """Combine root and folder paths into absolute path. Args: - folder (str | Path | None): Folder location containing image or mask data. - root (str | Path | None): Root directory for the dataset. + folder: Folder location containing image or mask data + root: Optional root directory for the dataset + + Returns: + Absolute path combining root and folder + + Example: + >>> path = resolve_path("subdir", "/root") + >>> path.is_absolute() + True """ folder = Path(folder) if folder.is_absolute(): @@ -102,40 +150,37 @@ def resolve_path(folder: str | Path, root: str | Path | None = None) -> Path: def is_path_too_long(path: str | Path, max_length: int = 512) -> bool: - r"""Check if the path contains too long input. + """Check if path exceeds maximum allowed length. Args: - path (str | Path): Path to check. - max_length (int): Maximum length a path can be before it is considered too long. - Defaults to ``512``. + path: Path to check + max_length: Maximum allowed path length. Defaults to ``512`` Returns: - bool: True if the path contains too long input, False otherwise. + ``True`` if path is too long, ``False`` otherwise - Examples: - >>> contains_too_long_input("./datasets/MVTec/bottle/train/good/000.png") + Example: + >>> is_path_too_long("short_path.txt") False - - >>> contains_too_long_input("./datasets/MVTec/bottle/train/good/000.png" + "a" * 4096) + >>> is_path_too_long("a" * 1000) True """ return len(str(path)) > max_length def contains_non_printable_characters(path: str | Path) -> bool: - r"""Check if the path contains non-printable characters. + r"""Check if path contains non-printable characters. Args: - path (str | Path): Path to check. + path: Path to check Returns: - bool: True if the path contains non-printable characters, False otherwise. + ``True`` if path contains non-printable chars, ``False`` otherwise - Examples: - >>> contains_non_printable_characters("./datasets/MVTec/bottle/train/good/000.png") + Example: + >>> contains_non_printable_characters("normal.txt") False - - >>> contains_non_printable_characters("./datasets/MVTec/bottle/train/good/000.png\0") + >>> contains_non_printable_characters("test\x00.txt") True """ printable_pattern = re.compile(r"^[\x20-\x7E]+$") @@ -148,46 +193,27 @@ def validate_path( should_exist: bool = True, extensions: tuple[str, ...] | None = None, ) -> Path: - """Validate the path. + """Validate path for existence, permissions and extension. 
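The two guards documented here are pure string predicates, so they can run before any file-system access; a quick sketch:

```python
from anomalib.data.utils.path import (
    contains_non_printable_characters,
    is_path_too_long,
)

candidate = "./datasets/MVTec/bottle/train/good/000.png"
assert not is_path_too_long(candidate, max_length=512)
assert not contains_non_printable_characters(candidate)

# An embedded NUL byte is caught before the path is ever opened.
assert contains_non_printable_characters(candidate + "\x00")
```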
Args: - path (str | Path): Path to validate. - base_dir (str | Path): Base directory to restrict file access. - should_exist (bool): If True, do not raise an exception if the path does not exist. - extensions (tuple[str, ...] | None): Accepted extensions for the path. An exception is raised if the - path does not have one of the accepted extensions. If None, no check is performed. Defaults to None. + path: Path to validate + base_dir: Base directory to restrict file access + should_exist: If ``True``, verify path exists + extensions: Allowed file extensions Returns: - Path: Validated path. - - Examples: - >>> validate_path("./datasets/MVTec/bottle/train/good/000.png") - PosixPath('/abs/path/to/anomalib/datasets/MVTec/bottle/train/good/000.png') - - >>> validate_path("./datasets/MVTec/bottle/train/good/000.png", base_dir="./datasets/MVTec") - PosixPath('/abs/path/to/anomalib/datasets/MVTec/bottle/train/good/000.png') - - >>> validate_path("/path/to/unexisting/file") - Traceback (most recent call last): - File "", line 1, in - File "", line 18, in validate_path - FileNotFoundError: Path does not exist: /path/to/unexisting/file - - Accessing a file without read permission should raise PermissionError: - - .. note:: - - Note that, we are using ``/usr/local/bin`` directory as an example here. - If this directory does not exist on your system, this will raise - ``FileNotFoundError`` instead of ``PermissionError``. You could change - the directory to any directory that you do not have read permission. - - >>> validate_path("/bin/bash", base_dir="/bin/") - Traceback (most recent call last): - File "", line 1, in - File "", line 18, in validate_path - PermissionError: Read permission denied for the file: /usr/local/bin - + Validated Path object + + Raises: + TypeError: If path is invalid type + ValueError: If path is too long or has invalid characters/extension + FileNotFoundError: If path doesn't exist when required + PermissionError: If path lacks required permissions + + Example: + >>> path = validate_path("./datasets/image.png", extensions=(".png",)) + >>> path.suffix + '.png' """ # Check if the path is of an appropriate type if not isinstance(path, str | Path): @@ -222,7 +248,7 @@ def validate_path( # Check if the path has one of the accepted extensions if extensions is not None and path.suffix not in extensions: - msg = f"Path extension is not accepted. Accepted extensions: {extensions}. Path: {path}" + msg = f"Path extension is not accepted. Accepted: {extensions}. Path: {path}" raise ValueError(msg) return path @@ -233,14 +259,19 @@ def validate_and_resolve_path( root: str | Path | None = None, base_dir: str | Path | None = None, ) -> Path: - """Validate and resolve the path. + """Validate and resolve path by combining validation and resolution. Args: - folder (str | Path): Folder location containing image or mask data. - root (str | Path | None): Root directory for the dataset. - base_dir (str | Path | None): Base directory to restrict file access. + folder: Folder location containing image or mask data + root: Root directory for the dataset + base_dir: Base directory to restrict file access Returns: - Path: Validated and resolved path. 
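Putting the pieces of ``validate_path`` together, a hedged sketch; the dataset paths are illustrative and must exist for the default ``should_exist=True``:

```python
from anomalib.data.utils.path import validate_path

# Resolve, existence-check, sandbox, and extension-check in one call.
path = validate_path(
    "./datasets/MVTec/bottle/train/good/000.png",
    base_dir="./datasets/MVTec",   # reject paths escaping this root
    extensions=(".png", ".jpg"),   # reject unexpected file types
)

# should_exist=False skips existence/permission checks, which suits
# validating an output location before it is created.
out = validate_path("results/heatmap.png", should_exist=False)
```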
+ Validated and resolved absolute Path + + Example: + >>> path = validate_and_resolve_path("subdir", "/root") + >>> path.is_absolute() + True """ return validate_path(resolve_path(folder, root), base_dir) diff --git a/src/anomalib/data/utils/split.py b/src/anomalib/data/utils/split.py index fe085ea1cf..db872a19b7 100644 --- a/src/anomalib/data/utils/split.py +++ b/src/anomalib/data/utils/split.py @@ -1,11 +1,29 @@ -"""Dataset Split Utils. +"""Dataset splitting utilities. -This module contains function in regards to splitting normal images in training set, -and creating validation sets from test sets. +This module provides functions for splitting datasets in anomaly detection tasks: -These function are useful - - when the test set does not contain any normal images. - - when the dataset doesn't have a validation set. +- Splitting normal images into training and validation sets +- Creating validation sets from test sets +- Label-aware splitting to maintain class distributions +- Random splitting with optional seed for reproducibility + +These utilities are particularly useful when: + +- The test set lacks normal images +- The dataset needs a validation set +- Class balance needs to be maintained during splits + +Example: + >>> from anomalib.data.utils.split import random_split + >>> # Split dataset with 80/20 ratio + >>> train_set, val_set = random_split(dataset, split_ratio=0.2) + >>> len(train_set), len(val_set) + (800, 200) + + >>> # Label-aware split preserving class distributions + >>> splits = random_split(dataset, [0.7, 0.2, 0.1], label_aware=True) + >>> len(splits) + 3 """ # Copyright (C) 2022-2024 Intel Corporation @@ -26,7 +44,13 @@ class Split(str, Enum): - """Split of a subset.""" + """Dataset split type. + + Attributes: + TRAIN: Training split + VAL: Validation split + TEST: Test split + """ TRAIN = "train" VAL = "val" @@ -34,7 +58,13 @@ class Split(str, Enum): class TestSplitMode(str, Enum): - """Splitting mode used to obtain subset.""" + """Mode used to obtain test split. + + Attributes: + NONE: No test split + FROM_DIR: Test split from directory + SYNTHETIC: Synthetic test split + """ NONE = "none" FROM_DIR = "from_dir" @@ -42,7 +72,15 @@ class TestSplitMode(str, Enum): class ValSplitMode(str, Enum): - """Splitting mode used to obtain validation subset.""" + """Mode used to obtain validation split. + + Attributes: + NONE: No validation split + SAME_AS_TEST: Use same split as test + FROM_TRAIN: Split from training set + FROM_TEST: Split from test set + SYNTHETIC: Synthetic validation split + """ NONE = "none" SAME_AS_TEST = "same_as_test" @@ -51,14 +89,21 @@ class ValSplitMode(str, Enum): SYNTHETIC = "synthetic" -def concatenate_datasets(datasets: Sequence["data.AnomalibDataset"]) -> "data.AnomalibDataset": - """Concatenate multiple datasets into a single dataset object. +def concatenate_datasets( + datasets: Sequence["data.AnomalibDataset"], +) -> "data.AnomalibDataset": + """Concatenate multiple datasets into a single dataset. Args: - datasets (Sequence[AnomalibDataset]): Sequence of at least two datasets. + datasets: Sequence of at least two datasets to concatenate Returns: - AnomalibDataset: Dataset that contains the combined samples of all input datasets. 
+ Combined dataset containing samples from all input datasets + + Example: + >>> combined = concatenate_datasets([dataset1, dataset2]) + >>> len(combined) == len(dataset1) + len(dataset2) + True """ concat_dataset = datasets[0] for dataset in datasets[1:]: @@ -72,16 +117,26 @@ def random_split( label_aware: bool = False, seed: int | None = None, ) -> list["data.AnomalibDataset"]: - """Perform a random split of a dataset. + """Randomly split a dataset into multiple subsets. Args: - dataset (AnomalibDataset): Source dataset - split_ratio (Union[float, Sequence[float]]): Fractions of the splits that will be produced. The values in the - sequence must sum to 1. If a single value is passed, the ratio will be converted to - [1-split_ratio, split_ratio]. - label_aware (bool): When True, the relative occurrence of the different class labels of the source dataset will - be maintained in each of the subsets. - seed (int | None, optional): Seed that can be passed if results need to be reproducible + dataset: Source dataset to split + split_ratio: Split ratios that must sum to 1. If single float ``x`` is + provided, splits into ``[1-x, x]`` + label_aware: If ``True``, maintains class label distributions in splits + seed: Random seed for reproducibility + + Returns: + List of dataset splits based on provided ratios + + Example: + >>> splits = random_split(dataset, [0.7, 0.3], seed=42) + >>> len(splits) + 2 + >>> # Label-aware splitting + >>> splits = random_split(dataset, 0.2, label_aware=True) + >>> len(splits) + 2 """ if isinstance(split_ratio, float): split_ratio = [1 - split_ratio, split_ratio] @@ -128,8 +183,24 @@ def random_split( return [concatenate_datasets(subset) for subset in subsets] -def split_by_label(dataset: "data.AnomalibDataset") -> tuple["data.AnomalibDataset", "data.AnomalibDataset"]: - """Split the dataset into the normal and anomalous subsets.""" +def split_by_label( + dataset: "data.AnomalibDataset", +) -> tuple["data.AnomalibDataset", "data.AnomalibDataset"]: + """Split dataset into normal and anomalous subsets. + + Args: + dataset: Dataset to split by label + + Returns: + Tuple containing: + - Dataset with only normal samples (label 0) + - Dataset with only anomalous samples (label 1) + + Example: + >>> normal, anomalous = split_by_label(dataset) + >>> len(normal) + len(anomalous) == len(dataset) + True + """ samples = dataset.samples normal_indices = samples[samples.label_index == 0].index anomalous_indices = samples[samples.label_index == 1].index diff --git a/src/anomalib/data/utils/synthetic.py b/src/anomalib/data/utils/synthetic.py index c4b52d5b35..fb347aa157 100644 --- a/src/anomalib/data/utils/synthetic.py +++ b/src/anomalib/data/utils/synthetic.py @@ -1,6 +1,22 @@ """Dataset that generates synthetic anomalies. -This dataset can be used when there is a lack of real anomalous data. +This module provides functionality to generate synthetic anomalies when real +anomalous data is scarce or unavailable. It includes: + +- A dataset class that generates synthetic anomalies from normal images +- Functions to convert normal samples into synthetic anomalous samples +- Perlin noise-based anomaly generation +- Temporary file management for synthetic data + +Example: + >>> from anomalib.data.utils.synthetic import SyntheticAnomalyDataset + >>> # Create synthetic dataset from normal samples + >>> synthetic_dataset = SyntheticAnomalyDataset( + ... transform=transforms, + ... source_samples=normal_samples + ... 
) + >>> len(synthetic_dataset) # 50/50 normal/anomalous split + 200 """ # Copyright (C) 2022-2024 Intel Corporation @@ -34,16 +50,36 @@ def make_synthetic_dataset( mask_dir: Path, anomalous_ratio: float = 0.5, ) -> DataFrame: - """Convert a set of normal samples into a mixed set of normal and synthetic anomalous samples. + """Convert normal samples into a mixed set with synthetic anomalies. - The synthetic images will be saved to the file system in the specified root directory under /images. - For the synthetic anomalous images, the masks will be saved under /ground_truth. + The function generates synthetic anomalous images and their corresponding + masks by applying Perlin noise-based perturbations to normal images. Args: - source_samples (DataFrame): Normal images that will be used as source for the synthetic anomalous images. - image_dir (Path): Directory to which the synthetic anomalous image files will be written. - mask_dir (Path): Directory to which the ground truth anomaly masks will be written. - anomalous_ratio (float): Fraction of source samples that will be converted into anomalous samples. + source_samples: DataFrame containing normal images used as source for + synthetic anomalies. Must contain columns: ``image_path``, + ``label``, ``label_index``, ``mask_path``, and ``split``. + image_dir: Directory where synthetic anomalous images will be saved. + mask_dir: Directory where ground truth anomaly masks will be saved. + anomalous_ratio: Fraction of source samples to convert to anomalous + samples. Defaults to ``0.5``. + + Returns: + DataFrame containing both normal and synthetic anomalous samples. + + Raises: + ValueError: If source samples contain any anomalous images. + NotADirectoryError: If ``image_dir`` or ``mask_dir`` is not a directory. + + Example: + >>> df = make_synthetic_dataset( + ... source_samples=normal_df, + ... image_dir=Path("./synthetic/images"), + ... mask_dir=Path("./synthetic/masks"), + ... anomalous_ratio=0.3 + ... ) + >>> len(df[df.label == "abnormal"]) # 30% are anomalous + 30 """ if 1 in source_samples.label_index.to_numpy(): msg = "All source images must be normal." @@ -66,19 +102,20 @@ def make_synthetic_dataset( anomalous_samples = anomalous_samples.reset_index(drop=True) # initialize augmenter - augmenter = PerlinAnomalyGenerator(anomaly_source_path="./datasets/dtd", probability=1.0, blend_factor=(0.01, 0.2)) + augmenter = PerlinAnomalyGenerator( + anomaly_source_path="./datasets/dtd", + probability=1.0, + blend_factor=(0.01, 0.2), + ) def augment(sample: Series) -> Series: - """Apply synthetic anomalous augmentation to a sample from a dataframe. - - Reads an image, applies the augmentations, writes the augmented image and corresponding mask to the file system, - and returns a new Series object with the updates labels and file locations. + """Apply synthetic anomalous augmentation to a sample. Args: - sample (Series): DataFrame row containing info about the image that will be augmented. + sample: DataFrame row containing image information. Returns: - Series: DataFrame row with updated information about the augmented image. + Series containing updated information about the augmented image. """ # read and transform image image = read_image(sample.image_path, as_tensor=True) @@ -110,11 +147,26 @@ def augment(sample: Series) -> Series: class SyntheticAnomalyDataset(AnomalibDataset): - """Dataset which reads synthetically generated anomalous images from a temporary folder. + """Dataset for generating and managing synthetic anomalies. 
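For context on how the split utilities above are typically combined, a sketch with a placeholder ``dataset`` (any ``AnomalibDataset`` instance; not part of the patch):

```python
from anomalib.data.utils.split import random_split, split_by_label

# Hold out 20% of training data as validation, preserving the
# normal/anomalous ratio in both subsets.
train_subset, val_subset = random_split(
    dataset,          # placeholder AnomalibDataset
    split_ratio=0.2,
    label_aware=True,
    seed=42,
)

# Separate a mixed test set into its normal and anomalous parts.
normal_set, anomalous_set = split_by_label(dataset)
```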
+ + The dataset creates synthetic anomalous images by applying Perlin + noise-based perturbations to normal images. The synthetic images are + stored in a temporary directory that is cleaned up when the dataset + object is deleted. Args: - transform (A.Compose): Transform object describing the transforms that are applied to the inputs. - source_samples (DataFrame): Normal samples to which the anomalous augmentations will be applied. + transform: Transform object describing the transforms applied to inputs. + source_samples: DataFrame containing normal samples used as source for + synthetic anomalies. + + Example: + >>> transform = Compose([...]) + >>> dataset = SyntheticAnomalyDataset( + ... transform=transform, + ... source_samples=normal_df + ... ) + >>> len(dataset) # 50/50 normal/anomalous split + 100 """ def __init__(self, transform: Compose, source_samples: DataFrame) -> None: @@ -122,7 +174,7 @@ def __init__(self, transform: Compose, source_samples: DataFrame) -> None: self.source_samples = source_samples - # Files will be written to a temporary directory in the workdir, which is cleaned up after code execution + # Files will be written to a temporary directory in the workdir root = Path(ROOT) root.mkdir(parents=True, exist_ok=True) @@ -134,21 +186,40 @@ def __init__(self, transform: Compose, source_samples: DataFrame) -> None: self.im_dir.mkdir() self.mask_dir.mkdir() - self._cleanup = True # flag that determines if temp dir is cleaned up when instance is deleted - self.samples = make_synthetic_dataset(self.source_samples, self.im_dir, self.mask_dir, 0.5) + self._cleanup = True # flag that determines if temp dir is cleaned up + self.samples = make_synthetic_dataset( + self.source_samples, + self.im_dir, + self.mask_dir, + 0.5, + ) @classmethod - def from_dataset(cls: type["SyntheticAnomalyDataset"], dataset: AnomalibDataset) -> "SyntheticAnomalyDataset": - """Create a synthetic anomaly dataset from an existing dataset of normal images. + def from_dataset( + cls: type["SyntheticAnomalyDataset"], + dataset: AnomalibDataset, + ) -> "SyntheticAnomalyDataset": + """Create synthetic dataset from existing dataset of normal images. Args: - dataset (AnomalibDataset): Dataset consisting of only normal images that will be converrted to a synthetic - anomalous dataset with a 50/50 normal anomalous split. + dataset: Dataset containing only normal images to convert into a + synthetic dataset with 50/50 normal/anomalous split. + + Returns: + New synthetic anomaly dataset. + + Example: + >>> normal_dataset = Dataset(...) + >>> synthetic = SyntheticAnomalyDataset.from_dataset(normal_dataset) """ return cls(transform=dataset.transform, source_samples=dataset.samples) def __copy__(self) -> "SyntheticAnomalyDataset": - """Return a shallow copy of the dataset object and prevents cleanup when original object is deleted.""" + """Return shallow copy and prevent cleanup of original. + + Returns: + Shallow copy of the dataset object. + """ cls = self.__class__ new = cls.__new__(cls) new.__dict__.update(self.__dict__) @@ -156,7 +227,14 @@ def __copy__(self) -> "SyntheticAnomalyDataset": return new def __deepcopy__(self, _memo: dict) -> "SyntheticAnomalyDataset": - """Return a deep copy of the dataset object and prevents cleanup when original object is deleted.""" + """Return deep copy and prevent cleanup of original. + + Args: + _memo: Memo dictionary used by deepcopy. + + Returns: + Deep copy of the dataset object. 
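And the one-liner most callers reach for, again with a placeholder normal-only dataset:

```python
from anomalib.data.utils.synthetic import SyntheticAnomalyDataset

# normal_dataset is a placeholder AnomalibDataset of good images only;
# the result is a 50/50 normal/synthetic-anomalous dataset backed by a
# temporary directory that is removed when the object is deleted.
synthetic = SyntheticAnomalyDataset.from_dataset(normal_dataset)
```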
+ """ cls = self.__class__ new = cls.__new__(cls) for key, value in self.__dict__.items(): @@ -165,6 +243,6 @@ def __deepcopy__(self, _memo: dict) -> "SyntheticAnomalyDataset": return new def __del__(self) -> None: - """Make sure the temporary directory is cleaned up when the dataset object is deleted.""" + """Clean up temporary directory when dataset object is deleted.""" if self._cleanup: shutil.rmtree(self.root) diff --git a/src/anomalib/data/utils/tiler.py b/src/anomalib/data/utils/tiler.py index 2c1e949e45..430763da1f 100644 --- a/src/anomalib/data/utils/tiler.py +++ b/src/anomalib/data/utils/tiler.py @@ -1,4 +1,28 @@ -"""Image Tiler.""" +"""Image tiling utilities for processing large images. + +This module provides functionality to: + +- Tile large images into smaller patches for efficient processing +- Support overlapping and non-overlapping tiling strategies +- Reconstruct original images from tiles +- Handle upscaling and downscaling with padding or interpolation + +Example: + >>> from anomalib.data.utils.tiler import Tiler + >>> import torch + >>> # Create tiler with 256x256 tiles and 128 stride + >>> tiler = Tiler(tile_size=256, stride=128) + >>> # Create sample 512x512 image + >>> image = torch.rand(1, 3, 512, 512) + >>> # Generate tiles + >>> tiles = tiler.tile(image) + >>> tiles.shape + torch.Size([9, 3, 256, 256]) + >>> # Reconstruct image from tiles + >>> reconstructed = tiler.untile(tiles) + >>> reconstructed.shape + torch.Size([1, 3, 512, 512]) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -14,39 +38,50 @@ class ImageUpscaleMode(str, Enum): - """Type of mode when upscaling image.""" + """Mode for upscaling images. + + Attributes: + PADDING: Upscale by padding with zeros + INTERPOLATION: Upscale using interpolation + """ PADDING = "padding" INTERPOLATION = "interpolation" class StrideSizeError(Exception): - """StrideSizeError to raise exception when stride size is greater than the tile size.""" + """Error raised when stride size exceeds tile size.""" def compute_new_image_size(image_size: tuple, tile_size: tuple, stride: tuple) -> tuple: - """Check if image size is divisible by tile size and stride. - - If not divisible, it resizes the image size to make it divisible. + """Compute new image size that is divisible by tile size and stride. Args: - image_size (tuple): Original image size - tile_size (tuple): Tile size - stride (tuple): Stride + image_size: Original image size as ``(height, width)`` + tile_size: Tile size as ``(height, width)`` + stride: Stride size as ``(height, width)`` + + Returns: + tuple: New image size divisible by tile size and stride Examples: - >>> compute_new_image_size(image_size=(512, 512), tile_size=(256, 256), stride=(128, 128)) + >>> compute_new_image_size((512, 512), (256, 256), (128, 128)) (512, 512) - - >>> compute_new_image_size(image_size=(512, 512), tile_size=(222, 222), stride=(111, 111)) + >>> compute_new_image_size((512, 512), (222, 222), (111, 111)) (555, 555) - - Returns: - tuple: Updated image size that is divisible by tile size and stride. """ def __compute_new_edge_size(edge_size: int, tile_size: int, stride: int) -> int: - """Resize within the edge level.""" + """Compute new edge size that is divisible by tile size and stride. 
+ + Args: + edge_size: Original edge size + tile_size: Tile size for this edge + stride: Stride size for this edge + + Returns: + int: New edge size + """ if (edge_size - tile_size) % stride != 0: edge_size = (ceil((edge_size - tile_size) / stride) * stride) + tile_size @@ -58,27 +93,30 @@ def __compute_new_edge_size(edge_size: int, tile_size: int, stride: int) -> int: return resized_h, resized_w -def upscale_image(image: torch.Tensor, size: tuple, mode: ImageUpscaleMode = ImageUpscaleMode.PADDING) -> torch.Tensor: - """Upscale image to the desired size via either padding or interpolation. +def upscale_image( + image: torch.Tensor, + size: tuple, + mode: ImageUpscaleMode = ImageUpscaleMode.PADDING, +) -> torch.Tensor: + """Upscale image to desired size using padding or interpolation. Args: - image (torch.Tensor): Image - size (tuple): tuple to which image is upscaled. - mode (str, optional): Upscaling mode. Defaults to "padding". + image: Input image tensor + size: Target size as ``(height, width)`` + mode: Upscaling mode, either ``"padding"`` or ``"interpolation"`` + + Returns: + torch.Tensor: Upscaled image Examples: >>> image = torch.rand(1, 3, 512, 512) - >>> image = upscale_image(image, size=(555, 555), mode="padding") - >>> image.shape + >>> upscaled = upscale_image(image, (555, 555), "padding") + >>> upscaled.shape torch.Size([1, 3, 555, 555]) - >>> image = torch.rand(1, 3, 512, 512) - >>> image = upscale_image(image, size=(555, 555), mode="interpolation") - >>> image.shape + >>> upscaled = upscale_image(image, (555, 555), "interpolation") + >>> upscaled.shape torch.Size([1, 3, 555, 555]) - - Returns: - Tensor: Upscaled image. """ image_h, image_w = image.shape[2:] resize_h, resize_w = size @@ -102,22 +140,22 @@ def downscale_image( size: tuple, mode: ImageUpscaleMode = ImageUpscaleMode.PADDING, ) -> torch.Tensor: - """Opposite of upscaling. This image downscales image to a desired size. + """Downscale image to desired size. Args: - image (torch.Tensor): Input image - size (tuple): Size to which image is down scaled. - mode (str, optional): Downscaling mode. Defaults to "padding". + image: Input image tensor + size: Target size as ``(height, width)`` + mode: Downscaling mode, either ``"padding"`` or ``"interpolation"`` + + Returns: + torch.Tensor: Downscaled image Examples: >>> x = torch.rand(1, 3, 512, 512) - >>> y = upscale_image(image, upscale_size=(555, 555), mode="padding") - >>> y = downscale_image(y, size=(512, 512), mode='padding') - >>> torch.allclose(x, y) + >>> y = upscale_image(x, (555, 555), "padding") + >>> z = downscale_image(y, (512, 512), "padding") + >>> torch.allclose(x, z) True - - Returns: - Tensor: Downscaled image """ input_h, input_w = size if mode == ImageUpscaleMode.PADDING: @@ -129,29 +167,40 @@ def downscale_image( class Tiler: - """Tile Image into (non)overlapping Patches. Images are tiled in order to efficiently process large images. + """Tile images into overlapping or non-overlapping patches. 
+ + This class provides functionality to: + - Split large images into smaller tiles for efficient processing + - Support overlapping tiles with configurable stride + - Remove border pixels from tiles before reconstruction + - Reconstruct original image from processed tiles Args: - tile_size: Tile dimension for each patch - stride: Stride length between patches - remove_border_count: Number of border pixels to be removed from tile before untiling - mode: Upscaling mode for image resize.Supported formats: padding, interpolation + tile_size: Size of tiles as int or ``(height, width)`` + stride: Stride between tiles as int or ``(height, width)``. + If ``None``, uses tile_size (non-overlapping) + remove_border_count: Number of border pixels to remove from tiles + mode: Upscaling mode for resizing, either ``"padding"`` or + ``"interpolation"`` Examples: >>> import torch >>> from torchvision import transforms >>> from skimage.data import camera - >>> tiler = Tiler(tile_size=256,stride=128) + >>> # Create tiler for 256x256 tiles with 128 stride + >>> tiler = Tiler(tile_size=256, stride=128) + >>> # Convert test image to tensor >>> image = transforms.ToTensor()(camera()) + >>> # Generate tiles >>> tiles = tiler.tile(image) >>> image.shape, tiles.shape (torch.Size([3, 512, 512]), torch.Size([9, 3, 256, 256])) - >>> # Perform your operations on the tiles. + >>> # Process tiles here... - >>> # Untile the patches to reconstruct the image - >>> reconstructed_image = tiler.untile(tiles) - >>> reconstructed_image.shape + >>> # Reconstruct image from tiles + >>> reconstructed = tiler.untile(tiles) + >>> reconstructed.shape torch.Size([1, 3, 512, 512]) """ @@ -173,16 +222,11 @@ def __init__( self.mode = mode if self.stride_h > self.tile_size_h or self.stride_w > self.tile_size_w: - msg = ( - "Larger stride size than kernel size produces unreliable tiling results. " - "Please ensure stride size is less than or equal than tiling size." - ) - raise StrideSizeError( - msg, - ) + msg = "Stride size larger than tile size produces unreliable results. Ensure stride size <= tile size." + raise StrideSizeError(msg) if self.mode not in {ImageUpscaleMode.PADDING, ImageUpscaleMode.INTERPOLATION}: - msg = f"Unknown tiling mode {self.mode}. Available modes are padding and interpolation" + msg = f"Unknown mode {self.mode}. Available modes: padding and interpolation" raise ValueError(msg) self.batch_size: int @@ -202,64 +246,70 @@ def __init__( @staticmethod def validate_size_type(parameter: int | Sequence) -> tuple[int, ...]: - """Validate size type and return tuple of form [tile_h, tile_w]. + """Validate and convert size parameter to tuple. Args: - parameter (int | Sequence): input tile size parameter. + parameter: Size as int or sequence of ``(height, width)`` Returns: - tuple[int, ...]: Validated tile size in tuple form. + tuple: Validated size as ``(height, width)`` + + Raises: + TypeError: If parameter type is invalid + ValueError: If parameter length is not 2 """ if isinstance(parameter, int): output = (parameter, parameter) elif isinstance(parameter, Sequence): output = (parameter[0], parameter[1]) else: - msg = f"Unknown type {type(parameter)} for tile or stride size. Could be int or Sequence type." + msg = f"Invalid type {type(parameter)} for tile/stride size. Must be int or Sequence." raise TypeError(msg) if len(output) != 2: - msg = f"Length of the size type must be 2 for height and width. Got {len(output)} instead." 
+ msg = f"Size must have length 2, got {len(output)}" raise ValueError(msg) return output def __random_tile(self, image: torch.Tensor) -> torch.Tensor: - """Randomly crop tiles from the given image. + """Randomly crop tiles from image. Args: - image: input image to be cropped + image: Input image tensor - Returns: Randomly cropped tiles from the image + Returns: + torch.Tensor: Stack of random tiles """ return torch.vstack([T.RandomCrop(self.tile_size_h)(image) for i in range(self.random_tile_count)]) def __unfold(self, tensor: torch.Tensor) -> torch.Tensor: - """Unfolds tensor into tiles. - - This is the core function to perform tiling operation. + """Unfold tensor into tiles. Args: - tensor: Input tensor from which tiles are generated. + tensor: Input tensor to tile - Returns: Generated tiles + Returns: + torch.Tensor: Generated tiles """ - # identify device type based on input tensor device = tensor.device - - # extract and calculate parameters batch, channels, image_h, image_w = tensor.shape self.num_patches_h = int((image_h - self.tile_size_h) / self.stride_h) + 1 self.num_patches_w = int((image_w - self.tile_size_w) / self.stride_w) + 1 - # create an empty torch tensor for output tiles = torch.zeros( - (self.num_patches_h, self.num_patches_w, batch, channels, self.tile_size_h, self.tile_size_w), + ( + self.num_patches_h, + self.num_patches_w, + batch, + channels, + self.tile_size_h, + self.tile_size_w, + ), device=device, ) - # fill-in output tensor with spatial patches extracted from the image for (tile_i, tile_j), (loc_i, loc_j) in zip( product(range(self.num_patches_h), range(self.num_patches_w)), product( @@ -275,33 +325,30 @@ def __unfold(self, tensor: torch.Tensor) -> torch.Tensor: loc_j : (loc_j + self.tile_size_w), ] - # rearrange the tiles in order [tile_count * batch, channels, tile_height, tile_width] tiles = tiles.permute(2, 0, 1, 3, 4, 5) return tiles.contiguous().view(-1, channels, self.tile_size_h, self.tile_size_w) def __fold(self, tiles: torch.Tensor) -> torch.Tensor: - """Fold the tiles back into the original tensor. - - This is the core method to reconstruct the original image from its tiled version. + """Fold tiles back into original tensor. Args: - tiles: Tiles from the input image, generated via __unfold method. + tiles: Tiles generated by ``__unfold()`` Returns: - Output that is the reconstructed version of the input tensor. + torch.Tensor: Reconstructed tensor """ - # number of channels differs between image and anomaly map, so infer from input tiles. 
_, num_channels, tile_size_h, tile_size_w = tiles.shape scale_h, scale_w = (tile_size_h / self.tile_size_h), (tile_size_w / self.tile_size_w) - # identify device type based on input tensor device = tiles.device - # calculate tile size after borders removed reduced_tile_h = tile_size_h - (2 * self.remove_border_count) reduced_tile_w = tile_size_w - (2 * self.remove_border_count) - # reconstructed image dimension - image_size = (self.batch_size, num_channels, int(self.resized_h * scale_h), int(self.resized_w * scale_w)) + image_size = ( + self.batch_size, + num_channels, + int(self.resized_h * scale_h), + int(self.resized_w * scale_w), + ) - # rearrange input tiles in format [tile_count, batch, channel, tile_h, tile_w] tiles = tiles.contiguous().view( self.batch_size, self.num_patches_h, @@ -314,7 +361,6 @@ def __fold(self, tiles: torch.Tensor) -> torch.Tensor: tiles = tiles.contiguous().view(self.batch_size, num_channels, -1, tile_size_h, tile_size_w) tiles = tiles.permute(2, 0, 1, 3, 4) - # remove tile borders by defined count tiles = tiles[ :, :, @@ -323,13 +369,10 @@ def __fold(self, tiles: torch.Tensor) -> torch.Tensor: self.remove_border_count : reduced_tile_w + self.remove_border_count, ] - # create tensors to store intermediate results and outputs img = torch.zeros(image_size, device=device) lookup = torch.zeros(image_size, device=device) ones = torch.ones(reduced_tile_h, reduced_tile_w, device=device) - # reconstruct image by adding patches to their respective location and - # create a lookup for patch count in every location for patch, (loc_i, loc_j) in zip( tiles, product( @@ -346,36 +389,44 @@ def __fold(self, tiles: torch.Tensor) -> torch.Tensor: ), strict=True, ): - img[:, :, loc_i : (loc_i + reduced_tile_h), loc_j : (loc_j + reduced_tile_w)] += patch - lookup[:, :, loc_i : (loc_i + reduced_tile_h), loc_j : (loc_j + reduced_tile_w)] += ones + img[ + :, + :, + loc_i : (loc_i + reduced_tile_h), + loc_j : (loc_j + reduced_tile_w), + ] += patch + lookup[ + :, + :, + loc_i : (loc_i + reduced_tile_h), + loc_j : (loc_j + reduced_tile_w), + ] += ones - # divide the reconstucted image by the lookup to average out the values img = torch.divide(img, lookup) - # alternative way of removing nan values (isnan not supported by openvino) img[img != img] = 0 # noqa: PLR0124 return img def tile(self, image: torch.Tensor, use_random_tiling: bool = False) -> torch.Tensor: - """Tiles an input image to either overlapping, non-overlapping or random patches. + """Tile input image into patches. Args: - image: Input image to tile. - use_random_tiling: If True, randomly crops tiles from the image. - If False, tiles the image in a regular grid. + image: Input image tensor + use_random_tiling: If ``True``, randomly crop tiles. + If ``False``, tile in regular grid. + + Returns: + torch.Tensor: Generated tiles Examples: - >>> from anomalib.data.utils.tiler import Tiler - >>> tiler = Tiler(tile_size=512,stride=256) - >>> image = torch.rand(size=(2, 3, 1024, 1024)) - >>> image.shape - torch.Size([2, 3, 1024, 1024]) + >>> tiler = Tiler(tile_size=512, stride=256) + >>> image = torch.rand(2, 3, 1024, 1024) >>> tiles = tiler.tile(image) >>> tiles.shape torch.Size([18, 3, 512, 512]) - Returns: - Tiles generated from the image. 
+ Raises: + ValueError: If tile size exceeds image size """ if image.dim() == 3: image = image.unsqueeze(0) @@ -383,13 +434,8 @@ def tile(self, image: torch.Tensor, use_random_tiling: bool = False) -> torch.Te self.batch_size, self.num_channels, self.input_h, self.input_w = image.shape if self.input_h < self.tile_size_h or self.input_w < self.tile_size_w: - msg = ( - f"One of the edges of the tile size {self.tile_size_h, self.tile_size_w} is larger than " - f"that of the image {self.input_h, self.input_w}." - ) - raise ValueError( - msg, - ) + msg = f"Tile size {self.tile_size_h, self.tile_size_w} exceeds image size {self.input_h, self.input_w}" + raise ValueError(msg) self.resized_h, self.resized_w = compute_new_image_size( image_size=(self.input_h, self.input_w), @@ -402,31 +448,25 @@ def tile(self, image: torch.Tensor, use_random_tiling: bool = False) -> torch.Te return self.__random_tile(image) if use_random_tiling else self.__unfold(image) def untile(self, tiles: torch.Tensor) -> torch.Tensor: - """Untiles patches to reconstruct the original input image. + """Reconstruct image from tiles. - If patches, are overlapping patches, the function averages the overlapping pixels, - and return the reconstructed image. + For overlapping tiles, averages overlapping regions. Args: - tiles: Tiles from the input image, generated via tile().. + tiles: Tiles generated by ``tile()`` + + Returns: + torch.Tensor: Reconstructed image Examples: - >>> from anomalib.data.utils.tiler import Tiler - >>> tiler = Tiler(tile_size=512,stride=256) - >>> image = torch.rand(size=(2, 3, 1024, 1024)) - >>> image.shape - torch.Size([2, 3, 1024, 1024]) + >>> tiler = Tiler(tile_size=512, stride=256) + >>> image = torch.rand(2, 3, 1024, 1024) >>> tiles = tiler.tile(image) - >>> tiles.shape - torch.Size([18, 3, 512, 512]) - >>> reconstructed_image = tiler.untile(tiles) - >>> reconstructed_image.shape + >>> reconstructed = tiler.untile(tiles) + >>> reconstructed.shape torch.Size([2, 3, 1024, 1024]) - >>> torch.equal(image, reconstructed_image) + >>> torch.equal(image, reconstructed) True - - Returns: - Output that is the reconstructed version of the input tensor. """ image = self.__fold(tiles) return downscale_image(image=image, size=(self.input_h, self.input_w), mode=self.mode) diff --git a/src/anomalib/data/utils/video.py b/src/anomalib/data/utils/video.py index cc3d839dfa..4bd5c360ba 100644 --- a/src/anomalib/data/utils/video.py +++ b/src/anomalib/data/utils/video.py @@ -1,4 +1,24 @@ -"""Video utils.""" +"""Video utilities for processing video data in anomaly detection. + +This module provides utilities for: + +- Indexing video clips and their corresponding masks +- Converting videos between different codecs +- Handling video frames and clips in PyTorch format + +Example: + >>> from anomalib.data.utils.video import ClipsIndexer + >>> # Create indexer for video files and masks + >>> indexer = ClipsIndexer( + ... video_paths=["video1.mp4", "video2.mp4"], + ... mask_paths=["mask1.mp4", "mask2.mp4"], + ... clip_length_in_frames=16 + ... ) + >>> # Get video clip with metadata + >>> video_item = indexer.get_item(0) + >>> video_item.image.shape # (16, 3, H, W) + torch.Size([16, 3, 256, 256]) +""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -15,15 +35,25 @@ class ClipsIndexer(VideoClips, ABC): - """Extension of torchvision's VideoClips class that also returns the masks for each clip. + """Extension of torchvision's VideoClips class for video and mask indexing. 
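A round-trip sketch of ``tile``/``untile``, using a scaling step as a stand-in for per-tile inference:

```python
import torch

from anomalib.data.utils.tiler import Tiler

tiler = Tiler(tile_size=256, stride=128)
images = torch.rand(2, 3, 512, 512)

tiles = tiler.tile(images)          # torch.Size([18, 3, 256, 256])
processed = tiles * 0.5             # stand-in for per-tile model output
restored = tiler.untile(processed)  # overlaps are averaged back together
assert restored.shape == images.shape
```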
- Subclasses should implement the get_mask method. By default, the class inherits the functionality of VideoClips, - which assumes that video_paths is a list of video files. If custom behaviour is required (e.g. video_paths is a list - of folders with single-frame images), the subclass should implement at least get_clip and _compute_frame_pts. + This class extends ``VideoClips`` to handle both video frames and their + corresponding mask annotations. It provides functionality to: + + - Index and retrieve video clips + - Access corresponding mask frames + - Track frame indices and video metadata + + Subclasses must implement the ``get_mask`` method. The default implementation + assumes ``video_paths`` contains video files. For custom data formats + (e.g., image sequences), subclasses should override ``get_clip`` and + ``_compute_frame_pts``. Args: - video_paths (list[str]): List of video paths that make up the dataset. - mask_paths (list[str]): List of paths to the masks for each video in the dataset. + video_paths: List of paths to video files in the dataset + mask_paths: List of paths to mask files corresponding to each video + clip_length_in_frames: Number of frames in each clip. Defaults to ``2`` + frames_between_clips: Stride between consecutive clips. Defaults to ``1`` """ def __init__( @@ -42,18 +72,40 @@ def __init__( self.mask_paths = mask_paths def last_frame_idx(self, video_idx: int) -> int: - """Return the index of the last frame for a given video.""" + """Get index of the last frame in a video. + + Args: + video_idx: Index of the video in the dataset + + Returns: + Index of the last frame + """ return self.clips[video_idx][-1][-1].item() @abstractmethod def get_mask(self, idx: int) -> torch.Tensor | None: - """Return the masks for the given index.""" + """Get masks for the clip at the given index. + + Args: + idx: Index of the clip + + Returns: + Tensor containing mask frames, or None if no masks exist + """ raise NotImplementedError def get_item(self, idx: int) -> VideoItem: - """Return a dictionary containing the clip, mask, video path and frame indices.""" + """Get video clip and metadata at the given index. + + Args: + idx: Index of the clip to retrieve + + Returns: + VideoItem containing the clip frames, masks, path and metadata + """ with warnings.catch_warnings(): - # silence warning caused by bug in torchvision, see https://github.com/pytorch/vision/issues/5787 + # silence warning caused by bug in torchvision + # see https://github.com/pytorch/vision/issues/5787 warnings.simplefilter("ignore") clip, _, _, _ = self.get_clip(idx) @@ -71,12 +123,15 @@ def get_item(self, idx: int) -> VideoItem: def convert_video(input_path: Path, output_path: Path, codec: str = "MP4V") -> None: - """Convert video file to a different codec. + """Convert a video file to use a different codec. + + Creates the output directory if it doesn't exist. Reads the input video + frame by frame and writes to a new file using the specified codec. Args: - input_path (Path): Path to the input video. - output_path (Path): Path to the target output video. - codec (str): fourcc code of the codec that will be used for compression of the output file. + input_path: Path to the input video file + output_path: Path where the converted video will be saved + codec: FourCC code for the desired output codec. 
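A sketch of the minimal subclass contract described above; the file paths are hypothetical and must point at real videos for ``VideoClips`` to index them:

```python
import torch

from anomalib.data.utils.video import ClipsIndexer

class UnannotatedClipsIndexer(ClipsIndexer):
    """Hypothetical indexer for videos without pixel-level annotations."""

    def get_mask(self, idx: int) -> torch.Tensor | None:
        # No ground-truth masks exist for this data, so return None.
        return None

indexer = UnannotatedClipsIndexer(
    video_paths=["video1.mp4"],  # hypothetical files
    mask_paths=[""],
    clip_length_in_frames=16,
)
item = indexer.get_item(0)  # VideoItem with frames, path and frame indices
```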
Defaults to ``"MP4V"`` """ if not output_path.parent.exists(): output_path.parent.mkdir(parents=True) @@ -89,7 +144,12 @@ def convert_video(input_path: Path, output_path: Path, codec: str = "MP4V") -> N frame_width = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH)) frame_height = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT)) fps = int(video_reader.get(cv2.CAP_PROP_FPS)) - video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (frame_width, frame_height)) + video_writer = cv2.VideoWriter( + str(output_path), + fourcc, + fps, + (frame_width, frame_height), + ) # read frames success, frame = video_reader.read() diff --git a/src/anomalib/data/validators/numpy/__init__.py b/src/anomalib/data/validators/numpy/__init__.py index 759f7322bd..2ac929c9c5 100644 --- a/src/anomalib/data/validators/numpy/__init__.py +++ b/src/anomalib/data/validators/numpy/__init__.py @@ -1,4 +1,30 @@ -"""Anomalib Numpy data validators.""" +"""Anomalib Numpy data validators. + +This module provides validators for numpy array data used in Anomalib. The validators +ensure data consistency and correctness for various data types: + +- Image data: Single images and batches +- Video data: Single videos and batches +- Depth data: Single depth maps and batches + +The validators check: + - Array shapes and dimensions + - Data types + - Value ranges + - Label formats + - Mask properties + +Example: + Validate a numpy image batch:: + + >>> from anomalib.data.validators import NumpyImageBatchValidator + >>> validator = NumpyImageBatchValidator() + >>> validator(images=images, labels=labels, masks=masks) + +Note: + The validators are used internally by the data modules to ensure data + consistency before processing. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/data/validators/numpy/depth.py b/src/anomalib/data/validators/numpy/depth.py index 89d7726182..f0c6eb7724 100644 --- a/src/anomalib/data/validators/numpy/depth.py +++ b/src/anomalib/data/validators/numpy/depth.py @@ -1,4 +1,32 @@ -"""Validate numpy depth data.""" +"""Validate numpy depth data. + +This module provides validators for depth data stored as numpy arrays. The validators +ensure data consistency and correctness for depth maps and batches of depth maps. + +The validators check: + - Array shapes and dimensions + - Data types + - Value ranges + - Label formats + - Mask properties + +Example: + Validate a single depth map:: + + >>> from anomalib.data.validators import NumpyDepthValidator + >>> validator = NumpyDepthValidator() + >>> validator.validate_image(depth_map) + + Validate a batch of depth maps:: + + >>> from anomalib.data.validators import NumpyDepthBatchValidator + >>> validator = NumpyDepthBatchValidator() + >>> validator(depth_maps=depth_maps, labels=labels, masks=masks) + +Note: + The validators are used internally by the data modules to ensure data + consistency before processing depth map data. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -11,31 +39,87 @@ class NumpyDepthValidator: - """Validate numpy.ndarray data for depth images.""" + """Validate numpy depth data. + + This class provides validation methods for depth data stored as numpy arrays. + It ensures data consistency and correctness for depth maps and associated + metadata. 
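A short usage sketch for ``convert_video`` (paths are illustrative; the parent directory of the output is created on demand):

```python
from pathlib import Path

from anomalib.data.utils.video import convert_video

convert_video(
    input_path=Path("videos/raw.avi"),
    output_path=Path("videos/converted/raw.mp4"),
    codec="MP4V",  # any FourCC code accepted by OpenCV
)
```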
+ + The validator checks: + - Array shapes and dimensions + - Data types + - Value ranges + - Label formats + - Mask properties + - Path validity + + Example: + Validate a depth map and associated metadata:: + + >>> from anomalib.data.validators import NumpyDepthValidator + >>> validator = NumpyDepthValidator() + >>> depth_map = np.random.rand(256, 256).astype(np.float32) + >>> validated_map = validator.validate_depth_map(depth_map) + """ @staticmethod def validate_image(image: np.ndarray) -> np.ndarray: - """Validate the image array.""" + """Validate image array. + + Args: + image (np.ndarray): Input image to validate. + + Returns: + np.ndarray: Validated image array. + """ return NumpyImageValidator.validate_image(image) @staticmethod def validate_gt_label(label: int | np.ndarray | None) -> np.ndarray | None: - """Validate the ground truth label.""" + """Validate ground truth label. + + Args: + label (int | np.ndarray | None): Input label to validate. + + Returns: + np.ndarray | None: Validated label. + """ return NumpyImageValidator.validate_gt_label(label) @staticmethod def validate_gt_mask(mask: np.ndarray | None) -> np.ndarray | None: - """Validate the ground truth mask.""" + """Validate ground truth mask. + + Args: + mask (np.ndarray | None): Input mask to validate. + + Returns: + np.ndarray | None: Validated mask. + """ return NumpyImageValidator.validate_gt_mask(mask) @staticmethod def validate_mask_path(mask_path: str | None) -> str | None: - """Validate the mask path.""" + """Validate mask path. + + Args: + mask_path (str | None): Path to mask file. + + Returns: + str | None: Validated mask path. + """ return NumpyImageValidator.validate_mask_path(mask_path) @staticmethod def validate_anomaly_map(anomaly_map: np.ndarray | None) -> np.ndarray | None: - """Validate the anomaly map.""" + """Validate anomaly map. + + Args: + anomaly_map (np.ndarray | None): Input anomaly map to validate. + + Returns: + np.ndarray | None: Validated anomaly map. + """ return NumpyImageValidator.validate_anomaly_map(anomaly_map) @staticmethod @@ -43,27 +127,76 @@ def validate_pred_score( pred_score: np.ndarray | float | None, anomaly_map: np.ndarray | None = None, ) -> np.ndarray | None: - """Validate the prediction score.""" + """Validate prediction score. + + Args: + pred_score (np.ndarray | float | None): Input prediction score. + anomaly_map (np.ndarray | None, optional): Associated anomaly map. + Defaults to None. + + Returns: + np.ndarray | None: Validated prediction score. + """ return NumpyImageValidator.validate_pred_score(pred_score, anomaly_map) @staticmethod def validate_pred_mask(pred_mask: np.ndarray | None) -> np.ndarray | None: - """Validate the prediction mask.""" + """Validate prediction mask. + + Args: + pred_mask (np.ndarray | None): Input prediction mask to validate. + + Returns: + np.ndarray | None: Validated prediction mask. + """ return NumpyImageValidator.validate_pred_mask(pred_mask) @staticmethod def validate_pred_label(pred_label: np.ndarray | None) -> np.ndarray | None: - """Validate the prediction label.""" + """Validate prediction label. + + Args: + pred_label (np.ndarray | None): Input prediction label to validate. + + Returns: + np.ndarray | None: Validated prediction label. + """ return NumpyImageValidator.validate_pred_label(pred_label) @staticmethod def validate_image_path(image_path: str | None) -> str | None: - """Validate the image path.""" + """Validate image path. + + Args: + image_path (str | None): Path to image file. 
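Tying the depth validator together, a small sketch using the export shown in the module docstring above:

```python
import numpy as np

from anomalib.data.validators import NumpyDepthValidator

depth = np.random.rand(256, 256).astype(np.float32)

# Shape and dtype are checked; the docstring promises a float32 result.
validated = NumpyDepthValidator.validate_depth_map(depth)
assert validated.shape == (256, 256)

# None passes through unchanged, so optional fields need no special-casing.
assert NumpyDepthValidator.validate_depth_map(None) is None
```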
+ + Returns: + str | None: Validated image path. + """ return NumpyImageValidator.validate_image_path(image_path) @staticmethod def validate_depth_map(depth_map: np.ndarray | None) -> np.ndarray | None: - """Validate the depth map.""" + """Validate depth map array. + + Ensures the depth map has correct dimensions and data type. + + Args: + depth_map (np.ndarray | None): Input depth map to validate. + + Returns: + np.ndarray | None: Validated depth map as float32. + + Raises: + TypeError: If depth map is not a numpy array. + ValueError: If depth map dimensions are invalid. + + Example: + >>> depth_map = np.random.rand(256, 256).astype(np.float32) + >>> validated = NumpyDepthValidator.validate_depth_map(depth_map) + >>> validated.shape + (256, 256) + """ if depth_map is None: return None if not isinstance(depth_map, np.ndarray): @@ -79,66 +212,185 @@ def validate_depth_map(depth_map: np.ndarray | None) -> np.ndarray | None: @staticmethod def validate_depth_path(depth_path: str | None) -> str | None: - """Validate the depth path.""" + """Validate depth map file path. + + Args: + depth_path (str | None): Path to depth map file. + + Returns: + str | None: Validated depth map path. + """ return validate_path(depth_path) if depth_path else None @staticmethod def validate_explanation(explanation: str | None) -> str | None: - """Validate the explanation.""" + """Validate explanation string. + + Args: + explanation (str | None): Input explanation to validate. + + Returns: + str | None: Validated explanation string. + """ return NumpyImageValidator.validate_explanation(explanation) class NumpyDepthBatchValidator: - """Validate numpy.ndarray data for batches of depth images.""" + """Validate numpy depth data batches. + + This class provides validation methods for batches of depth data stored as numpy arrays. + It ensures data consistency and correctness for batches of depth maps and associated + metadata. + + The validator checks: + - Array shapes and dimensions + - Data types + - Value ranges + - Label formats + - Mask properties + - Path validity + + Example: + Validate a batch of depth maps and associated metadata:: + + >>> from anomalib.data.validators import NumpyDepthBatchValidator + >>> validator = NumpyDepthBatchValidator() + >>> depth_maps = np.random.rand(32, 256, 256).astype(np.float32) + >>> labels = np.zeros(32) + >>> masks = np.zeros((32, 256, 256)) + >>> validator.validate_depth_map(depth_maps) + >>> validator.validate_gt_label(labels) + >>> validator.validate_gt_mask(masks) + """ @staticmethod def validate_image(image: np.ndarray) -> np.ndarray: - """Validate the image batch array.""" + """Validate image batch array. + + Args: + image (np.ndarray): Input image batch to validate. + + Returns: + np.ndarray: Validated image batch array. + """ return NumpyImageBatchValidator.validate_image(image) @staticmethod def validate_gt_label(gt_label: np.ndarray | Sequence[int] | None) -> np.ndarray | None: - """Validate the ground truth label batch.""" + """Validate ground truth label batch. + + Args: + gt_label (np.ndarray | Sequence[int] | None): Input label batch to validate. + + Returns: + np.ndarray | None: Validated label batch. + """ return NumpyImageBatchValidator.validate_gt_label(gt_label) @staticmethod def validate_gt_mask(gt_mask: np.ndarray | None) -> np.ndarray | None: - """Validate the ground truth mask batch.""" + """Validate ground truth mask batch. + + Args: + gt_mask (np.ndarray | None): Input mask batch to validate. + + Returns: + np.ndarray | None: Validated mask batch. 
+ """ return NumpyImageBatchValidator.validate_gt_mask(gt_mask) @staticmethod def validate_mask_path(mask_path: Sequence[str] | None) -> list[str] | None: - """Validate the mask paths for a batch.""" + """Validate mask file paths for a batch. + + Args: + mask_path (Sequence[str] | None): Sequence of mask file paths to validate. + + Returns: + list[str] | None: Validated mask file paths. + """ return NumpyImageBatchValidator.validate_mask_path(mask_path) @staticmethod def validate_anomaly_map(anomaly_map: np.ndarray | None) -> np.ndarray | None: - """Validate the anomaly map batch.""" + """Validate anomaly map batch. + + Args: + anomaly_map (np.ndarray | None): Input anomaly map batch to validate. + + Returns: + np.ndarray | None: Validated anomaly map batch. + """ return NumpyImageBatchValidator.validate_anomaly_map(anomaly_map) @staticmethod def validate_pred_score(pred_score: np.ndarray | None) -> np.ndarray | None: - """Validate the prediction scores for a batch.""" + """Validate prediction scores for a batch. + + Args: + pred_score (np.ndarray | None): Input prediction scores to validate. + + Returns: + np.ndarray | None: Validated prediction scores. + """ return NumpyImageBatchValidator.validate_pred_score(pred_score) @staticmethod def validate_pred_mask(pred_mask: np.ndarray | None) -> np.ndarray | None: - """Validate the prediction mask batch.""" + """Validate prediction mask batch. + + Args: + pred_mask (np.ndarray | None): Input prediction mask batch to validate. + + Returns: + np.ndarray | None: Validated prediction mask batch. + """ return NumpyImageBatchValidator.validate_pred_mask(pred_mask) @staticmethod def validate_pred_label(pred_label: np.ndarray | None) -> np.ndarray | None: - """Validate the prediction label batch.""" + """Validate prediction label batch. + + Args: + pred_label (np.ndarray | None): Input prediction label batch to validate. + + Returns: + np.ndarray | None: Validated prediction label batch. + """ return NumpyImageBatchValidator.validate_pred_label(pred_label) @staticmethod def validate_image_path(image_path: list[str] | None) -> list[str] | None: - """Validate the image paths for a batch.""" + """Validate image file paths for a batch. + + Args: + image_path (list[str] | None): List of image file paths to validate. + + Returns: + list[str] | None: Validated image file paths. + """ return NumpyImageBatchValidator.validate_image_path(image_path) @staticmethod def validate_depth_map(depth_map: np.ndarray | None) -> np.ndarray | None: - """Validate the depth map batch.""" + """Validate depth map batch. + + Args: + depth_map (np.ndarray | None): Input depth map batch to validate. + + Returns: + np.ndarray | None: Validated depth map batch as float32. + + Raises: + TypeError: If depth map batch is not a numpy array. + ValueError: If depth map batch dimensions are invalid. + + Example: + >>> depth_maps = np.random.rand(32, 256, 256).astype(np.float32) + >>> validated = NumpyDepthBatchValidator.validate_depth_map(depth_maps) + >>> validated.shape + (32, 256, 256) + """ if depth_map is None: return None if not isinstance(depth_map, np.ndarray): @@ -154,7 +406,17 @@ def validate_depth_map(depth_map: np.ndarray | None) -> np.ndarray | None: @staticmethod def validate_depth_path(depth_path: list[str] | None) -> list[str] | None: - """Validate the depth paths for a batch.""" + """Validate depth map file paths for a batch. + + Args: + depth_path (list[str] | None): List of depth map file paths to validate. 
+ + Returns: + list[str] | None: Validated depth map file paths. + + Raises: + TypeError: If depth_path is not a list of strings. + """ if depth_path is None: return None if not isinstance(depth_path, list): @@ -164,5 +426,12 @@ def validate_depth_path(depth_path: list[str] | None) -> list[str] | None: @staticmethod def validate_explanation(explanation: list[str] | None) -> list[str] | None: - """Validate the explanations for a batch.""" + """Validate explanation strings for a batch. + + Args: + explanation (list[str] | None): List of explanation strings to validate. + + Returns: + list[str] | None: Validated explanation strings. + """ return NumpyImageBatchValidator.validate_explanation(explanation) diff --git a/src/anomalib/data/validators/numpy/image.py b/src/anomalib/data/validators/numpy/image.py index 455ecde2b0..579ca2cf01 100644 --- a/src/anomalib/data/validators/numpy/image.py +++ b/src/anomalib/data/validators/numpy/image.py @@ -1,4 +1,32 @@ -"""Validate numpy image data.""" +"""Validate numpy image data. + +This module provides validators for image data stored as numpy arrays. The validators +ensure data consistency and correctness for images and batches of images. + +The validators check: + - Array shapes and dimensions + - Data types + - Value ranges + - Label formats + - Mask properties + +Example: + Validate a single image:: + + >>> from anomalib.data.validators import NumpyImageValidator + >>> validator = NumpyImageValidator() + >>> validator.validate_image(image) + + Validate a batch of images:: + + >>> from anomalib.data.validators import NumpyImageBatchValidator + >>> validator = NumpyImageBatchValidator() + >>> validator(images=images, labels=labels, masks=masks) + +Note: + The validators are used internally by the data modules to ensure data + consistency before processing image data. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -10,33 +38,71 @@ class NumpyImageValidator: - """Validate numpy.ndarray data for images.""" + """Validate numpy array data for images. + + This class provides validation methods for image data stored as numpy arrays. + It ensures data consistency and correctness for images and associated metadata. + + The validator checks: + - Array shapes and dimensions + - Data types + - Value ranges + - Label formats + - Mask properties + - Path validity + + Example: + Validate an image and associated metadata:: + + >>> from anomalib.data.validators import NumpyImageValidator + >>> validator = NumpyImageValidator() + >>> image = np.random.rand(256, 256, 3) + >>> validated_image = validator.validate_image(image) + >>> label = 1 + >>> validated_label = validator.validate_gt_label(label) + >>> mask = np.random.randint(0, 2, (256, 256)) + >>> validated_mask = validator.validate_gt_mask(mask) + + Note: + The validator is used internally by the data modules to ensure data + consistency before processing. + """ @staticmethod def validate_image(image: np.ndarray) -> np.ndarray: """Validate the image array. + Validates and normalizes input image arrays. Handles both RGB and grayscale + images, and converts between channel-first and channel-last formats. + Args: - image (np.ndarray): Input image array. + image (``np.ndarray``): Input image array to validate. Returns: - np.ndarray: Validated image array. + ``np.ndarray``: Validated image array in channel-last format (H,W,C). Raises: - TypeError: If the input is not a numpy.ndarray. - ValueError: If the image array does not have the correct shape. 
- - Examples: - >>> import numpy as np - >>> from anomalib.data.validators.numpy.image import NumpyImageValidator - >>> rgb_image = np.random.rand(256, 256, 3) - >>> validated_rgb = NumpyImageValidator.validate_image(rgb_image) - >>> validated_rgb.shape - (256, 256, 3) - >>> gray_image = np.random.rand(256, 256) - >>> validated_gray = NumpyImageValidator.validate_image(gray_image) - >>> validated_gray.shape - (256, 256, 1) + TypeError: If ``image`` is not a numpy array. + ValueError: If ``image`` dimensions or channels are invalid. + + Example: + Validate RGB and grayscale images:: + + >>> import numpy as np + >>> from anomalib.data.validators import NumpyImageValidator + >>> rgb_image = np.random.rand(256, 256, 3) + >>> validated_rgb = NumpyImageValidator.validate_image(rgb_image) + >>> validated_rgb.shape + (256, 256, 3) + >>> gray_image = np.random.rand(256, 256) + >>> validated_gray = NumpyImageValidator.validate_image(gray_image) + >>> validated_gray.shape + (256, 256, 1) + + Note: + - 2D arrays are treated as grayscale and expanded to 3D + - Channel-first arrays (C,H,W) are converted to channel-last (H,W,C) + - Output is always float32 type """ if not isinstance(image, np.ndarray): msg = f"Image must be a numpy.ndarray, got {type(image)}." @@ -64,27 +130,36 @@ def validate_image(image: np.ndarray) -> np.ndarray: def validate_gt_label(label: int | np.ndarray | None) -> np.ndarray | None: """Validate the ground truth label. + Validates and normalizes input labels to boolean numpy arrays. + Args: - label (int | np.ndarray | None): Input ground truth label. + label (``int`` | ``np.ndarray`` | ``None``): Input ground truth label. Returns: - np.ndarray | None: Validated ground truth label as a boolean array, or None. + ``np.ndarray`` | ``None``: Validated label as boolean array, or None. Raises: - TypeError: If the input is neither an integer nor a numpy.ndarray. - ValueError: If the label shape or dtype is invalid. - - Examples: - >>> import numpy as np - >>> from anomalib.data.validators.numpy.image import NumpyImageValidator - >>> label_int = 1 - >>> validated_label = NumpyImageValidator.validate_gt_label(label_int) - >>> validated_label - array(True) - >>> label_array = np.array(0) - >>> validated_label = NumpyImageValidator.validate_gt_label(label_array) - >>> validated_label - array(False) + TypeError: If ``label`` is not an integer or numpy array. + ValueError: If ``label`` shape is not scalar. + + Example: + Validate integer and array labels:: + + >>> import numpy as np + >>> from anomalib.data.validators import NumpyImageValidator + >>> label_int = 1 + >>> validated_label = NumpyImageValidator.validate_gt_label(label_int) + >>> validated_label + array(True) + >>> label_array = np.array(0) + >>> validated_label = NumpyImageValidator.validate_gt_label(label_array) + >>> validated_label + array(False) + + Note: + - Integer inputs are converted to numpy arrays + - Output is always boolean type + - None inputs return None """ if label is None: return None @@ -105,23 +180,32 @@ def validate_gt_label(label: int | np.ndarray | None) -> np.ndarray | None: def validate_gt_mask(mask: np.ndarray | None) -> np.ndarray | None: """Validate the ground truth mask. + Validates and normalizes input mask arrays. + Args: - mask (np.ndarray | None): Input ground truth mask. + mask (``np.ndarray`` | ``None``): Input ground truth mask. Returns: - np.ndarray | None: Validated ground truth mask, or None. + ``np.ndarray`` | ``None``: Validated mask as boolean array, or None. 
Raises: - TypeError: If the input is not a numpy.ndarray. - ValueError: If the mask shape is invalid. - - Examples: - >>> import numpy as np - >>> from anomalib.data.validators.numpy.image import NumpyImageValidator - >>> mask = np.random.randint(0, 2, (224, 224)) - >>> validated_mask = NumpyImageValidator.validate_gt_mask(mask) - >>> validated_mask.shape - (224, 224) + TypeError: If ``mask`` is not a numpy array. + ValueError: If ``mask`` dimensions are invalid. + + Example: + Validate a binary mask:: + + >>> import numpy as np + >>> from anomalib.data.validators import NumpyImageValidator + >>> mask = np.random.randint(0, 2, (224, 224)) + >>> validated_mask = NumpyImageValidator.validate_gt_mask(mask) + >>> validated_mask.shape + (224, 224) + + Note: + - 3D masks with shape (H,W,1) are squeezed to (H,W) + - Output is always boolean type + - None inputs return None """ if mask is None: return None @@ -142,23 +226,32 @@ def validate_gt_mask(mask: np.ndarray | None) -> np.ndarray | None: def validate_anomaly_map(anomaly_map: np.ndarray | None) -> np.ndarray | None: """Validate the anomaly map. + Validates and normalizes input anomaly map arrays. + Args: - anomaly_map (np.ndarray | None): Input anomaly map. + anomaly_map (``np.ndarray`` | ``None``): Input anomaly map. Returns: - np.ndarray | None: Validated anomaly map, or None. + ``np.ndarray`` | ``None``: Validated anomaly map as float32 array, or None. Raises: - TypeError: If the input is not a numpy.ndarray. - ValueError: If the anomaly map shape is invalid. - - Examples: - >>> import numpy as np - >>> from anomalib.data.validators.numpy.image import NumpyImageValidator - >>> anomaly_map = np.random.rand(224, 224) - >>> validated_map = NumpyImageValidator.validate_anomaly_map(anomaly_map) - >>> validated_map.shape - (224, 224) + TypeError: If ``anomaly_map`` is not a numpy array. + ValueError: If ``anomaly_map`` dimensions are invalid. + + Example: + Validate an anomaly map:: + + >>> import numpy as np + >>> from anomalib.data.validators import NumpyImageValidator + >>> anomaly_map = np.random.rand(224, 224) + >>> validated_map = NumpyImageValidator.validate_anomaly_map(anomaly_map) + >>> validated_map.shape + (224, 224) + + Note: + - 3D maps with shape (1,H,W) are squeezed to (H,W) + - Output is always float32 type + - None inputs return None """ if anomaly_map is None: return None @@ -180,17 +273,22 @@ def validate_image_path(image_path: str | None) -> str | None: """Validate the image path. Args: - image_path (str | None): Input image path. + image_path (``str`` | ``None``): Input image path. Returns: - str | None: Validated image path, or None. + ``str`` | ``None``: Validated image path, or None. - Examples: - >>> from anomalib.data.validators.numpy.image import NumpyImageValidator - >>> path = "/path/to/image.jpg" - >>> validated_path = NumpyImageValidator.validate_image_path(path) - >>> validated_path == path - True + Example: + Validate an image path:: + + >>> from anomalib.data.validators import NumpyImageValidator + >>> path = "/path/to/image.jpg" + >>> validated_path = NumpyImageValidator.validate_image_path(path) + >>> validated_path == path + True + + Note: + Returns None if input is None. """ return validate_path(image_path) if image_path else None @@ -199,17 +297,22 @@ def validate_mask_path(mask_path: str | None) -> str | None: """Validate the mask path. Args: - mask_path (str | None): Input mask path. + mask_path (``str`` | ``None``): Input mask path. Returns: - str | None: Validated mask path, or None. 
+ ``str`` | ``None``: Validated mask path, or None. - Examples: - >>> from anomalib.data.validators.numpy.image import NumpyImageValidator - >>> path = "/path/to/mask.png" - >>> validated_path = NumpyImageValidator.validate_mask_path(path) - >>> validated_path == path - True + Example: + Validate a mask path:: + + >>> from anomalib.data.validators import NumpyImageValidator + >>> path = "/path/to/mask.png" + >>> validated_path = NumpyImageValidator.validate_mask_path(path) + >>> validated_path == path + True + + Note: + Returns None if input is None. """ return validate_path(mask_path) if mask_path else None @@ -220,28 +323,37 @@ def validate_pred_score( ) -> np.ndarray | None: """Validate the prediction score. + Validates and normalizes prediction scores to float32 numpy arrays. + Args: - pred_score (np.ndarray | float | None): Input prediction score. - anomaly_map (np.ndarray | None): Input anomaly map. + pred_score (``np.ndarray`` | ``float`` | ``None``): Input prediction score. + anomaly_map (``np.ndarray`` | ``None``): Input anomaly map. Returns: - np.ndarray | None: Validated prediction score as a float32 array, or None. + ``np.ndarray`` | ``None``: Validated score as float32 array, or None. Raises: - TypeError: If the input is neither a float, numpy.ndarray, nor None. - ValueError: If the prediction score is not a scalar. - - Examples: - >>> import numpy as np - >>> from anomalib.data.validators.numpy.image import NumpyImageValidator - >>> score = 0.8 - >>> validated_score = NumpyImageValidator.validate_pred_score(score) - >>> validated_score - array(0.8, dtype=float32) - >>> score_array = np.array(0.7) - >>> validated_score = NumpyImageValidator.validate_pred_score(score_array) - >>> validated_score - array(0.7, dtype=float32) + TypeError: If ``pred_score`` cannot be converted to numpy array. + ValueError: If ``pred_score`` is not scalar. + + Example: + Validate prediction scores:: + + >>> import numpy as np + >>> from anomalib.data.validators import NumpyImageValidator + >>> score = 0.8 + >>> validated_score = NumpyImageValidator.validate_pred_score(score) + >>> validated_score + array(0.8, dtype=float32) + >>> score_array = np.array(0.7) + >>> validated_score = NumpyImageValidator.validate_pred_score(score_array) + >>> validated_score + array(0.7, dtype=float32) + + Note: + - If input is None and anomaly_map provided, returns max of anomaly_map + - Output is always float32 type + - None inputs with no anomaly_map return None """ if pred_score is None: return np.amax(anomaly_map) if anomaly_map is not None else None @@ -263,19 +375,26 @@ def validate_pred_score( def validate_pred_mask(pred_mask: np.ndarray | None) -> np.ndarray | None: """Validate the prediction mask. + Validates and normalizes prediction mask arrays. + Args: - pred_mask (np.ndarray | None): Input prediction mask. + pred_mask (``np.ndarray`` | ``None``): Input prediction mask. Returns: - np.ndarray | None: Validated prediction mask, or None. + ``np.ndarray`` | ``None``: Validated mask as boolean array, or None. 
- Examples: - >>> import numpy as np - >>> from anomalib.data.validators.numpy.image import NumpyImageValidator - >>> mask = np.random.randint(0, 2, (224, 224)) - >>> validated_mask = NumpyImageValidator.validate_pred_mask(mask) - >>> validated_mask.shape - (224, 224) + Example: + Validate a prediction mask:: + + >>> import numpy as np + >>> from anomalib.data.validators import NumpyImageValidator + >>> mask = np.random.randint(0, 2, (224, 224)) + >>> validated_mask = NumpyImageValidator.validate_pred_mask(mask) + >>> validated_mask.shape + (224, 224) + + Note: + Uses same validation as ground truth masks. """ return NumpyImageValidator.validate_gt_mask(pred_mask) # We can reuse the gt_mask validation @@ -283,23 +402,31 @@ def validate_pred_mask(pred_mask: np.ndarray | None) -> np.ndarray | None: def validate_pred_label(pred_label: np.ndarray | None) -> np.ndarray | None: """Validate the prediction label. + Validates and normalizes prediction labels to boolean numpy arrays. + Args: - pred_label (np.ndarray | None): Input prediction label. + pred_label (``np.ndarray`` | ``None``): Input prediction label. Returns: - np.ndarray | None: Validated prediction label as a boolean array, or None. + ``np.ndarray`` | ``None``: Validated label as boolean array, or None. Raises: - TypeError: If the input is not a numpy.ndarray. - ValueError: If the prediction label is not a scalar. - - Examples: - >>> import numpy as np - >>> from anomalib.data.validators.numpy.image import NumpyImageValidator - >>> label = np.array(1) - >>> validated_label = NumpyImageValidator.validate_pred_label(label) - >>> validated_label - array(True) + TypeError: If ``pred_label`` cannot be converted to numpy array. + ValueError: If ``pred_label`` is not scalar. + + Example: + Validate a prediction label:: + + >>> import numpy as np + >>> from anomalib.data.validators import NumpyImageValidator + >>> label = np.array(1) + >>> validated_label = NumpyImageValidator.validate_pred_label(label) + >>> validated_label + array(True) + + Note: + - Output is always boolean type + - None inputs return None """ if pred_label is None: return None @@ -317,20 +444,28 @@ def validate_pred_label(pred_label: np.ndarray | None) -> np.ndarray | None: @staticmethod def validate_explanation(explanation: str | None) -> str | None: - """Validate the explanation. + """Validate the explanation string. Args: - explanation (str | None): Input explanation. + explanation (``str`` | ``None``): Input explanation string. Returns: - str | None: Validated explanation, or None. + ``str`` | ``None``: Validated explanation string, or None. - Examples: - >>> from anomalib.dataclasses.validators import ImageValidator - >>> explanation = "The image has a crack on the wall." - >>> validated_explanation = ImageValidator.validate_explanation(explanation) - >>> validated_explanation == explanation - True + Raises: + TypeError: If ``explanation`` is not a string. + + Example: + Validate an explanation string:: + + >>> from anomalib.dataclasses.validators import ImageValidator + >>> explanation = "The image has a crack on the wall." + >>> validated = ImageValidator.validate_explanation(explanation) + >>> validated == explanation + True + + Note: + Returns None if input is None. """ if explanation is None: return None @@ -341,41 +476,80 @@ def validate_explanation(explanation: str | None) -> str | None: class NumpyImageBatchValidator: - """Validate numpy.ndarray data for batches of images.""" + """Validate batches of image data stored as numpy arrays. 
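One consequence of the batch handling described here is that a single image is promoted to a one-element batch; a quick sketch:

import numpy as np

from anomalib.data.validators import NumpyImageBatchValidator

single = np.zeros((224, 224, 3))
validated = NumpyImageBatchValidator.validate_image(single)
print(validated.shape)  # expected (1, 224, 224, 3) per the examples above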
+ + This class provides validation methods for batches of image data stored as numpy arrays. + It ensures data consistency and correctness for images and associated metadata. + + The validator checks: + - Array shapes and dimensions + - Data types + - Value ranges + - Label formats + - Mask properties + - Path validity + + Example: + Validate a batch of images and associated metadata:: + + >>> from anomalib.data.validators import NumpyImageBatchValidator + >>> validator = NumpyImageBatchValidator() + >>> images = np.random.rand(32, 256, 256, 3) + >>> labels = np.zeros(32) + >>> masks = np.zeros((32, 256, 256)) + >>> validator.validate_image(images) + >>> validator.validate_gt_label(labels) + >>> validator.validate_gt_mask(masks) + """ @staticmethod def validate_image(image: np.ndarray) -> np.ndarray: """Validate the image batch array. + This method validates batches of images stored as numpy arrays. It handles: + - Single images and batches + - Grayscale and RGB images + - Channel-first and channel-last formats + - Type conversion to float32 + Args: - image (np.ndarray): Input image batch array. + image (``np.ndarray``): Input image batch array. Returns: - np.ndarray: Validated image batch array. + ``np.ndarray``: Validated image batch array in [N,H,W,C] format. Raises: - TypeError: If the input is not a numpy.ndarray. - ValueError: If the image batch array does not have the correct shape. + TypeError: If ``image`` is not a numpy array. + ValueError: If ``image`` shape is invalid. Examples: - >>> import numpy as np - >>> from anomalib.data.validators.numpy.image import NumpyImageBatchValidator - >>> batch = np.random.rand(32, 224, 224, 3) - >>> validated_batch = NumpyImageBatchValidator.validate_image(batch) - >>> validated_batch.shape - (32, 224, 224, 3) - >>> grayscale_batch = np.random.rand(32, 224, 224) - >>> validated_grayscale = NumpyImageBatchValidator.validate_image(grayscale_batch) - >>> validated_grayscale.shape - (32, 224, 224, 1) - >>> torch_style_batch = np.random.rand(32, 3, 224, 224) - >>> validated_torch_style = NumpyImageBatchValidator.validate_image(torch_style_batch) - >>> validated_torch_style.shape - (32, 224, 224, 3) - >>> single_image = np.zeros((224, 224, 3)) - >>> validated_single = NumpyImageBatchValidator.validate_image(single_image) - >>> validated_single.shape - (1, 224, 224, 3) + Validate RGB batch:: + + >>> batch = np.random.rand(32, 224, 224, 3) + >>> validated = NumpyImageBatchValidator.validate_image(batch) + >>> validated.shape + (32, 224, 224, 3) + + Validate grayscale batch:: + + >>> gray = np.random.rand(32, 224, 224) + >>> validated = NumpyImageBatchValidator.validate_image(gray) + >>> validated.shape + (32, 224, 224, 1) + + Validate channel-first batch:: + + >>> chf = np.random.rand(32, 3, 224, 224) + >>> validated = NumpyImageBatchValidator.validate_image(chf) + >>> validated.shape + (32, 224, 224, 3) + + Validate single image:: + + >>> img = np.zeros((224, 224, 3)) + >>> validated = NumpyImageBatchValidator.validate_image(img) + >>> validated.shape + (1, 224, 224, 3) """ # Check if the image is a numpy array if not isinstance(image, np.ndarray): @@ -410,27 +584,37 @@ def validate_image(image: np.ndarray) -> np.ndarray: def validate_gt_label(gt_label: np.ndarray | Sequence[int] | None) -> np.ndarray | None: """Validate the ground truth label batch. + This method validates batches of ground truth labels. 
It handles: + - Numpy arrays and sequences of integers + - Type conversion to boolean + - Shape validation + Args: - gt_label (np.ndarray | Sequence[int] | None): Input ground truth label batch. + gt_label (``np.ndarray`` | ``Sequence[int]`` | ``None``): Input ground truth label + batch. Returns: - np.ndarray | None: Validated ground truth label batch as a boolean array, or None. + ``np.ndarray`` | ``None``: Validated ground truth label batch as boolean array, + or ``None``. Raises: - TypeError: If the input is not a numpy.ndarray or Sequence[int]. - ValueError: If the label batch shape is invalid. + TypeError: If ``gt_label`` is not a numpy array or sequence of integers. + ValueError: If ``gt_label`` shape is invalid. Examples: - >>> import numpy as np - >>> from anomalib.data.validators.numpy.image import NumpyImageBatchValidator - >>> labels = np.array([0, 1, 1, 0]) - >>> validated_labels = NumpyImageBatchValidator.validate_gt_label(labels) - >>> validated_labels - array([False, True, True, False]) - >>> list_labels = [1, 0, 1, 1] - >>> validated_list = NumpyImageBatchValidator.validate_gt_label(list_labels) - >>> validated_list - array([ True, False, True, True]) + Validate numpy array labels:: + + >>> labels = np.array([0, 1, 1, 0]) + >>> validated = NumpyImageBatchValidator.validate_gt_label(labels) + >>> validated + array([False, True, True, False]) + + Validate list labels:: + + >>> labels = [1, 0, 1, 1] + >>> validated = NumpyImageBatchValidator.validate_gt_label(labels) + >>> validated + array([ True, False, True, True]) """ if gt_label is None: return None @@ -448,29 +632,38 @@ def validate_gt_label(gt_label: np.ndarray | Sequence[int] | None) -> np.ndarray def validate_gt_mask(gt_mask: np.ndarray | None) -> np.ndarray | None: """Validate the ground truth mask batch. + This method validates batches of ground truth masks. It handles: + - Channel-first and channel-last formats + - Type conversion to boolean + - Shape validation + Args: - gt_mask (np.ndarray | None): Input ground truth mask batch. + gt_mask (``np.ndarray`` | ``None``): Input ground truth mask batch. Returns: - np.ndarray | None: Validated ground truth mask batch as a boolean array, or None. + ``np.ndarray`` | ``None``: Validated ground truth mask batch as boolean array, + or ``None``. Raises: - TypeError: If the input is not a numpy.ndarray. - ValueError: If the mask batch shape is invalid. + TypeError: If ``gt_mask`` is not a numpy array. + ValueError: If ``gt_mask`` shape is invalid. 
Examples: - >>> import numpy as np - >>> from anomalib.data.validators.numpy.image import NumpyImageBatchValidator - >>> masks = np.random.randint(0, 2, (4, 224, 224)) - >>> validated_masks = NumpyImageBatchValidator.validate_gt_mask(masks) - >>> validated_masks.shape - (4, 224, 224) - >>> validated_masks.dtype - dtype('bool') - >>> torch_style_masks = np.random.randint(0, 2, (4, 1, 224, 224)) - >>> validated_torch_style = NumpyImageBatchValidator.validate_gt_mask(torch_style_masks) - >>> validated_torch_style.shape - (4, 224, 224, 1) + Validate channel-last masks:: + + >>> masks = np.random.randint(0, 2, (4, 224, 224)) + >>> validated = NumpyImageBatchValidator.validate_gt_mask(masks) + >>> validated.shape + (4, 224, 224) + >>> validated.dtype + dtype('bool') + + Validate channel-first masks:: + + >>> masks = np.random.randint(0, 2, (4, 1, 224, 224)) + >>> validated = NumpyImageBatchValidator.validate_gt_mask(masks) + >>> validated.shape + (4, 224, 224, 1) """ if gt_mask is None: return None @@ -495,26 +688,26 @@ def validate_gt_mask(gt_mask: np.ndarray | None) -> np.ndarray | None: def validate_mask_path(mask_path: Sequence[str] | None) -> list[str] | None: """Validate the mask paths for a batch. + This method validates sequences of mask file paths. It handles: + - Type conversion to strings + - Path sequence validation + Args: - mask_path (Sequence[str] | None): Input sequence of mask paths. + mask_path (``Sequence[str]`` | ``None``): Input sequence of mask paths. Returns: - list[str] | None: Validated list of mask paths, or None. + ``list[str]`` | ``None``: Validated list of mask paths, or ``None``. Raises: - TypeError: If the input is not a sequence of strings. - ValueError: If the number of paths doesn't match the batch size. + TypeError: If ``mask_path`` is not a sequence of strings. Examples: - >>> from anomalib.data.validators.numpy.image import NumpyImageBatchValidator - >>> paths = ['mask1.png', 'mask2.png', 'mask3.png', 'mask4.png'] - >>> validated_paths = NumpyImageBatchValidator.validate_mask_path(paths) - >>> validated_paths - ['mask1.png', 'mask2.png', 'mask3.png', 'mask4.png'] - >>> NumpyImageBatchValidator.validate_mask_path(['mask1.png', 'mask2.png'], 4) - Traceback (most recent call last): - ... - ValueError: Invalid length for mask_path. Got length 2 for batch size 4. + Validate list of paths:: + + >>> paths = ['mask1.png', 'mask2.png', 'mask3.png'] + >>> validated = NumpyImageBatchValidator.validate_mask_path(paths) + >>> validated + ['mask1.png', 'mask2.png', 'mask3.png'] """ if mask_path is None: return None @@ -527,29 +720,37 @@ def validate_mask_path(mask_path: Sequence[str] | None) -> list[str] | None: def validate_anomaly_map(anomaly_map: np.ndarray | None) -> np.ndarray | None: """Validate the anomaly map batch. + This method validates batches of anomaly maps. It handles: + - Channel-first and channel-last formats + - Type conversion to float32 + - Shape validation + Args: - anomaly_map (np.ndarray | None): Input anomaly map batch. + anomaly_map (``np.ndarray`` | ``None``): Input anomaly map batch. Returns: - np.ndarray | None: Validated anomaly map batch, or None. + ``np.ndarray`` | ``None``: Validated anomaly map batch, or ``None``. Raises: - TypeError: If the input is not a numpy.ndarray. - ValueError: If the anomaly map batch shape is invalid. + TypeError: If ``anomaly_map`` is not a numpy array. + ValueError: If ``anomaly_map`` shape is invalid. 
Examples: - >>> import numpy as np - >>> from anomalib.data.validators.numpy.image import NumpyImageBatchValidator - >>> anomaly_maps = np.random.rand(4, 224, 224) - >>> validated_maps = NumpyImageBatchValidator.validate_anomaly_map(anomaly_maps) - >>> validated_maps.shape - (4, 224, 224) - >>> validated_maps.dtype - dtype('float32') - >>> torch_style_maps = np.random.rand(4, 1, 224, 224) - >>> validated_torch_style = NumpyImageBatchValidator.validate_anomaly_map(torch_style_maps) - >>> validated_torch_style.shape - (4, 224, 224, 1) + Validate channel-last maps:: + + >>> maps = np.random.rand(4, 224, 224) + >>> validated = NumpyImageBatchValidator.validate_anomaly_map(maps) + >>> validated.shape + (4, 224, 224) + >>> validated.dtype + dtype('float32') + + Validate channel-first maps:: + + >>> maps = np.random.rand(4, 1, 224, 224) + >>> validated = NumpyImageBatchValidator.validate_anomaly_map(maps) + >>> validated.shape + (4, 224, 224, 1) """ if anomaly_map is None: return None @@ -568,30 +769,38 @@ def validate_anomaly_map(anomaly_map: np.ndarray | None) -> np.ndarray | None: def validate_pred_score(pred_score: np.ndarray | None) -> np.ndarray | None: """Validate the prediction scores for a batch. + This method validates batches of prediction scores. It handles: + - 1D and 2D arrays + - Type conversion to float32 + - Shape validation + Args: - pred_score (np.ndarray | None): Input prediction score batch. + pred_score (``np.ndarray`` | ``None``): Input prediction score batch. Returns: - np.ndarray | None: Validated prediction score batch, or None. + ``np.ndarray`` | ``None``: Validated prediction score batch, or ``None``. Raises: - TypeError: If the input is not a numpy.ndarray. - ValueError: If the prediction score batch is not 1-dimensional or 2-dimensional. + TypeError: If ``pred_score`` is not a numpy array. + ValueError: If ``pred_score`` shape is invalid. Examples: - >>> import numpy as np - >>> from anomalib.data.validators.numpy.image import NumpyImageBatchValidator - >>> scores = np.array([0.1, 0.8, 0.3, 0.6]) - >>> validated_scores = NumpyImageBatchValidator.validate_pred_score(scores) - >>> validated_scores - array([0.1, 0.8, 0.3, 0.6], dtype=float32) - >>> scores_2d = np.array([[0.1], [0.8], [0.3], [0.6]]) - >>> validated_scores_2d = NumpyImageBatchValidator.validate_pred_score(scores_2d) - >>> validated_scores_2d - array([[0.1], - [0.8], - [0.3], - [0.6]], dtype=float32) + Validate 1D scores:: + + >>> scores = np.array([0.1, 0.8, 0.3, 0.6]) + >>> validated = NumpyImageBatchValidator.validate_pred_score(scores) + >>> validated + array([0.1, 0.8, 0.3, 0.6], dtype=float32) + + Validate 2D scores:: + + >>> scores = np.array([[0.1], [0.8], [0.3], [0.6]]) + >>> validated = NumpyImageBatchValidator.validate_pred_score(scores) + >>> validated + array([[0.1], + [0.8], + [0.3], + [0.6]], dtype=float32) """ if pred_score is None: return None @@ -608,29 +817,37 @@ def validate_pred_score(pred_score: np.ndarray | None) -> np.ndarray | None: def validate_pred_mask(pred_mask: np.ndarray | None) -> np.ndarray | None: """Validate the prediction mask batch. + This method validates batches of prediction masks. It handles: + - Channel-first and channel-last formats + - Type conversion to boolean + - Shape validation + Args: - pred_mask (np.ndarray | None): Input prediction mask batch. + pred_mask (``np.ndarray`` | ``None``): Input prediction mask batch. Returns: - np.ndarray | None: Validated prediction mask batch, or None. 
+ ``np.ndarray`` | ``None``: Validated prediction mask batch, or ``None``. Raises: - TypeError: If the input is not a numpy.ndarray. - ValueError: If the prediction mask batch shape is invalid. + TypeError: If ``pred_mask`` is not a numpy array. + ValueError: If ``pred_mask`` shape is invalid. Examples: - >>> import numpy as np - >>> from anomalib.data.validators.numpy.image import NumpyImageBatchValidator - >>> masks = np.random.randint(0, 2, (4, 224, 224)) - >>> validated_masks = NumpyImageBatchValidator.validate_pred_mask(masks) - >>> validated_masks.shape - (4, 224, 224) - >>> validated_masks.dtype - dtype('bool') - >>> torch_style_masks = np.random.randint(0, 2, (4, 1, 224, 224)) - >>> validated_torch_style = NumpyImageBatchValidator.validate_pred_mask(torch_style_masks) - >>> validated_torch_style.shape - (4, 224, 224, 1) + Validate channel-last masks:: + + >>> masks = np.random.randint(0, 2, (4, 224, 224)) + >>> validated = NumpyImageBatchValidator.validate_pred_mask(masks) + >>> validated.shape + (4, 224, 224) + >>> validated.dtype + dtype('bool') + + Validate channel-first masks:: + + >>> masks = np.random.randint(0, 2, (4, 1, 224, 224)) + >>> validated = NumpyImageBatchValidator.validate_pred_mask(masks) + >>> validated.shape + (4, 224, 224, 1) """ return NumpyImageBatchValidator.validate_gt_mask(pred_mask) @@ -638,30 +855,39 @@ def validate_pred_mask(pred_mask: np.ndarray | None) -> np.ndarray | None: def validate_pred_label(pred_label: np.ndarray | None) -> np.ndarray | None: """Validate the prediction label batch. + This method validates batches of prediction labels. It handles: + - 1D and 2D arrays + - Type conversion to boolean + - Shape validation + Args: - pred_label (np.ndarray | None): Input prediction label batch. + pred_label (``np.ndarray`` | ``None``): Input prediction label batch. Returns: - np.ndarray | None: Validated prediction label batch as a boolean array, or None. + ``np.ndarray`` | ``None``: Validated prediction label batch as boolean array, + or ``None``. Raises: - TypeError: If the input is not a numpy.ndarray. - ValueError: If the prediction label batch is not 1-dimensional or 2-dimensional. + TypeError: If ``pred_label`` is not a numpy array. + ValueError: If ``pred_label`` shape is invalid. Examples: - >>> import numpy as np - >>> from anomalib.data.validators.numpy.image import NumpyImageBatchValidator - >>> labels = np.array([0, 1, 1, 0]) - >>> validated_labels = NumpyImageBatchValidator.validate_pred_label(labels) - >>> validated_labels - array([False, True, True, False]) - >>> labels_2d = np.array([[0], [1], [1], [0]]) - >>> validated_labels_2d = NumpyImageBatchValidator.validate_pred_label(labels_2d) - >>> validated_labels_2d - array([[False], - [ True], - [ True], - [False]]) + Validate 1D labels:: + + >>> labels = np.array([0, 1, 1, 0]) + >>> validated = NumpyImageBatchValidator.validate_pred_label(labels) + >>> validated + array([False, True, True, False]) + + Validate 2D labels:: + + >>> labels = np.array([[0], [1], [1], [0]]) + >>> validated = NumpyImageBatchValidator.validate_pred_label(labels) + >>> validated + array([[False], + [ True], + [ True], + [False]]) """ if pred_label is None: return None @@ -677,23 +903,33 @@ def validate_pred_label(pred_label: np.ndarray | None) -> np.ndarray | None: def validate_image_path(image_path: list[str] | None) -> list[str] | None: """Validate the image paths for a batch. + This method validates lists of image file paths. 
It handles: + - Type conversion to strings + - Path list validation + Args: - image_path (list[str] | None): Input list of image paths. + image_path (``list[str]`` | ``None``): Input list of image paths. Returns: - list[str] | None: Validated list of image paths, or None. + ``list[str]`` | ``None``: Validated list of image paths, or ``None``. Raises: - TypeError: If the input is not a list of strings. + TypeError: If ``image_path`` is not a list. Examples: - >>> from anomalib.data.validators.numpy.image import NumpyImageBatchValidator - >>> paths = ['image1.jpg', 'image2.jpg', 'image3.jpg'] - >>> validated_paths = NumpyImageBatchValidator.validate_image_path(paths) - >>> validated_paths - ['image1.jpg', 'image2.jpg', 'image3.jpg'] - >>> NumpyImageBatchValidator.validate_image_path(['image1.jpg', 2, 'image3.jpg']) - ['image1.jpg', '2', 'image3.jpg'] + Validate list of paths:: + + >>> paths = ['image1.jpg', 'image2.jpg', 'image3.jpg'] + >>> validated = NumpyImageBatchValidator.validate_image_path(paths) + >>> validated + ['image1.jpg', 'image2.jpg', 'image3.jpg'] + + Validate mixed type paths:: + + >>> paths = ['image1.jpg', 2, 'image3.jpg'] + >>> validated = NumpyImageBatchValidator.validate_image_path(paths) + >>> validated + ['image1.jpg', '2', 'image3.jpg'] """ if image_path is None: return None @@ -706,21 +942,26 @@ def validate_image_path(image_path: list[str] | None) -> list[str] | None: def validate_explanation(explanation: list[str] | None) -> list[str] | None: """Validate the explanations for a batch. + This method validates lists of explanation strings. It handles: + - Type conversion to strings + - List validation + Args: - explanation (list[str] | None): Input list of explanations. + explanation (``list[str]`` | ``None``): Input list of explanations. Returns: - list[str] | None: Validated list of explanations, or None. + ``list[str]`` | ``None``: Validated list of explanations, or ``None``. Raises: - TypeError: If the input is not a list of strings. + TypeError: If ``explanation`` is not a list. Examples: - >>> from anomalib.data.validators.torch.image import ImageBatchValidator - >>> explanations = ["The image has a crack on the wall.", "The image has a dent on the car."] - >>> validated_explanations = ImageBatchValidator.validate_explanation(explanations) - >>> print(validated_explanations) - ['The image has a crack on the wall.', 'The image has a dent on the car.'] + Validate list of explanations:: + + >>> explanations = ["The image has a crack.", "The image has a dent."] + >>> validated = NumpyImageBatchValidator.validate_explanation(explanations) + >>> validated + ['The image has a crack.', 'The image has a dent.'] """ if explanation is None: return None diff --git a/src/anomalib/data/validators/numpy/video.py b/src/anomalib/data/validators/numpy/video.py index e12682881b..05eb42e910 100644 --- a/src/anomalib/data/validators/numpy/video.py +++ b/src/anomalib/data/validators/numpy/video.py @@ -1,4 +1,32 @@ -"""Validate numpy video data.""" +"""Validate numpy video data. + +This module provides validators for video data stored as numpy arrays. The validators +ensure data consistency and correctness for videos and batches of videos. 
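A minimal video sketch, assuming NumpyVideoValidator is exposed as documented:

import numpy as np

from anomalib.data.validators import NumpyVideoValidator

validator = NumpyVideoValidator()
clip = np.random.rand(10, 224, 224, 3)           # [T, H, W, C]
validated_clip = validator.validate_image(clip)  # float32, [T, H, W, C]
print(validated_clip.shape, validated_clip.dtype)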
+ +The validators check: + - Array shapes and dimensions + - Data types + - Value ranges + - Label formats + - Mask properties + +Example: + Validate a single video:: + + >>> from anomalib.data.validators import NumpyVideoValidator + >>> validator = NumpyVideoValidator() + >>> validator.validate_image(video) + + Validate a batch of videos:: + + >>> from anomalib.data.validators import NumpyVideoBatchValidator + >>> validator = NumpyVideoBatchValidator() + >>> validator(videos=videos, labels=labels, masks=masks) + +Note: + The validators are used internally by the data modules to ensure data + consistency before processing video data. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -11,29 +39,62 @@ class NumpyVideoValidator: - """Validate numpy.ndarray data for videos.""" + """Validate numpy array data for videos. + + This class provides validation methods for video data stored as numpy arrays. + It ensures data consistency and correctness for videos and associated metadata. + + The validator checks: + - Array shapes and dimensions + - Data types + - Value ranges + - Label formats + - Mask properties + - Path validity + + Example: + Validate a video and associated metadata:: + + >>> from anomalib.data.validators import NumpyVideoValidator + >>> validator = NumpyVideoValidator() + >>> video = np.random.rand(10, 224, 224, 3) # [T, H, W, C] + >>> validated_video = validator.validate_image(video) + >>> label = 1 + >>> validated_label = validator.validate_gt_label(label) + >>> mask = np.random.randint(0, 2, (10, 224, 224)) # [T, H, W] + >>> validated_mask = validator.validate_gt_mask(mask) + + Note: + The validator is used internally by the data modules to ensure data + consistency before processing. + """ @staticmethod def validate_image(image: np.ndarray) -> np.ndarray: """Validate the video array. + Validates and normalizes input video arrays. Handles both RGB and grayscale + videos, and ensures proper time dimension. + Args: - image (np.ndarray): Input video array to validate. + image (``np.ndarray``): Input video array to validate. Returns: - np.ndarray: Validated video array as float32 with an added time dimension if not present. + ``np.ndarray``: Validated video array in format [T, H, W, C] as float32. Raises: - TypeError: If the input is not a numpy array. - ValueError: If the array dimensions or channel count are invalid. + TypeError: If ``image`` is not a numpy array. + ValueError: If ``image`` dimensions or channels are invalid. Example: - >>> import numpy as np - >>> validator = NumpyVideoValidator() - >>> video = np.random.rand(10, 224, 224, 3) # [T, H, W, C] - >>> validated_video = validator.validate_image(video) - >>> print(validated_video.shape, validated_video.dtype) - (10, 224, 224, 3) float32 + Validate RGB video:: + + >>> import numpy as np + >>> validator = NumpyVideoValidator() + >>> video = np.random.rand(10, 224, 224, 3) # [T, H, W, C] + >>> validated_video = validator.validate_image(video) + >>> print(validated_video.shape, validated_video.dtype) + (10, 224, 224, 3) float32 """ if not isinstance(image, np.ndarray): msg = f"Video must be a numpy.ndarray, got {type(image)}." @@ -58,14 +119,15 @@ def validate_gt_label(label: int | np.ndarray | None) -> np.ndarray | None: """Validate the ground truth label. Args: - label (int | np.ndarray | None): Input label to validate. + label (``int`` | ``np.ndarray`` | ``None``): Input label to validate. Returns: - np.ndarray | None: Validated label as boolean numpy array, or None if input is None. 
+ ``np.ndarray`` | ``None``: Validated label as boolean numpy array, or None if + input is None. Raises: - TypeError: If the input is not an integer or numpy array. - ValueError: If the label is not a scalar. + TypeError: If ``label`` is not an integer or numpy array. + ValueError: If ``label`` is not a scalar. Example: >>> validator = NumpyVideoValidator() @@ -94,14 +156,15 @@ def validate_gt_mask(mask: np.ndarray | None) -> np.ndarray | None: """Validate the ground truth mask. Args: - mask (np.ndarray | None): Input mask to validate. + mask (``np.ndarray`` | ``None``): Input mask to validate. Returns: - np.ndarray | None: Validated mask as boolean numpy array, or None if input is None. + ``np.ndarray`` | ``None``: Validated mask as boolean numpy array, or None if + input is None. Raises: - TypeError: If the input is not a numpy array. - ValueError: If the mask dimensions or channel count are invalid. + TypeError: If ``mask`` is not a numpy array. + ValueError: If ``mask`` dimensions or channel count are invalid. Example: >>> import numpy as np @@ -129,10 +192,10 @@ def validate_mask_path(mask_path: str | None) -> str | None: """Validate the mask path. Args: - mask_path (str | None): Input mask path to validate. + mask_path (``str`` | ``None``): Input mask path to validate. Returns: - str | None: Validated mask path, or None if input is None. + ``str`` | ``None``: Validated mask path, or None if input is None. Example: >>> validator = NumpyVideoValidator() @@ -148,14 +211,15 @@ def validate_anomaly_map(anomaly_map: np.ndarray | None) -> np.ndarray | None: """Validate the anomaly map. Args: - anomaly_map (np.ndarray | None): Input anomaly map to validate. + anomaly_map (``np.ndarray`` | ``None``): Input anomaly map to validate. Returns: - np.ndarray | None: Validated anomaly map as float32 numpy array, or None if input is None. + ``np.ndarray`` | ``None``: Validated anomaly map as float32 numpy array, or + None if input is None. Raises: - TypeError: If the input is not a numpy array. - ValueError: If the anomaly map dimensions or channel count are invalid. + TypeError: If ``anomaly_map`` is not a numpy array. + ValueError: If ``anomaly_map`` dimensions or channel count are invalid. Example: >>> import numpy as np @@ -183,14 +247,16 @@ def validate_pred_score(pred_score: np.ndarray | float | None) -> np.ndarray | N """Validate the prediction score. Args: - pred_score (np.ndarray | float | None): Input prediction score to validate. + pred_score (``np.ndarray`` | ``float`` | ``None``): Input prediction score to + validate. Returns: - np.ndarray | None: Validated prediction score as float32 numpy array, or None if input is None. + ``np.ndarray`` | ``None``: Validated prediction score as float32 numpy array, + or None if input is None. Raises: - TypeError: If the input is not a float or numpy array. - ValueError: If the prediction score is not a scalar. + TypeError: If ``pred_score`` is not a float or numpy array. + ValueError: If ``pred_score`` is not a scalar. Example: >>> validator = NumpyVideoValidator() @@ -216,10 +282,11 @@ def validate_pred_mask(pred_mask: np.ndarray | None) -> np.ndarray | None: """Validate the prediction mask. Args: - pred_mask (np.ndarray | None): Input prediction mask to validate. + pred_mask (``np.ndarray`` | ``None``): Input prediction mask to validate. Returns: - np.ndarray | None: Validated prediction mask as boolean numpy array, or None if input is None. + ``np.ndarray`` | ``None``: Validated prediction mask as boolean numpy array, + or None if input is None. 
Example: >>> import numpy as np @@ -236,13 +303,15 @@ def validate_pred_label(pred_label: np.ndarray | None) -> np.ndarray | None: """Validate the prediction label. Args: - pred_label (np.ndarray | None): Input prediction label to validate. + pred_label (``np.ndarray`` | ``None``): Input prediction label to validate. Returns: - np.ndarray | None: Validated prediction label as boolean numpy array, or None if input is None. + ``np.ndarray`` | ``None``: Validated prediction label as boolean numpy array, + or None if input is None. Raises: - ValueError: If the input cannot be converted to a numpy array or is not a scalar. + ValueError: If ``pred_label`` cannot be converted to a numpy array or is not + a scalar. Example: >>> import numpy as np @@ -271,10 +340,10 @@ def validate_video_path(video_path: str | None) -> str | None: """Validate the video path. Args: - video_path (str | None): Input video path to validate. + video_path (``str`` | ``None``): Input video path to validate. Returns: - str | None: Validated video path, or None if input is None. + ``str`` | ``None``: Validated video path, or None if input is None. Example: >>> validator = NumpyVideoValidator() @@ -290,14 +359,14 @@ def validate_original_image(original_image: np.ndarray | None) -> np.ndarray | N """Validate the original video. Args: - original_image (np.ndarray | None): Input original video to validate. + original_image (``np.ndarray`` | ``None``): Input original video to validate. Returns: - np.ndarray | None: Validated original video, or None if input is None. + ``np.ndarray`` | ``None``: Validated original video, or None if input is None. Raises: - TypeError: If the input is not a numpy array. - ValueError: If the original video dimensions or channel count are invalid. + TypeError: If ``original_image`` is not a numpy array. + ValueError: If ``original_image`` dimensions or channel count are invalid. Example: >>> import numpy as np @@ -325,14 +394,14 @@ def validate_target_frame(target_frame: int | None) -> int | None: """Validate the target frame index. Args: - target_frame (int | None): Input target frame index to validate. + target_frame (``int`` | ``None``): Input target frame index to validate. Returns: - int | None: Validated target frame index, or None if input is None. + ``int`` | ``None``: Validated target frame index, or None if input is None. Raises: - TypeError: If the input is not an integer. - ValueError: If the target frame index is negative. + TypeError: If ``target_frame`` is not an integer. + ValueError: If ``target_frame`` is negative. Example: >>> validator = NumpyVideoValidator() @@ -358,17 +427,45 @@ def validate_explanation(explanation: str | None) -> str | None: class NumpyVideoBatchValidator: - """Validate numpy.ndarray data for batches of videos.""" + """Validate numpy array data for batches of videos. + + This class provides validation methods for batches of video data stored as numpy arrays. + It ensures data consistency and correctness for video batches and associated metadata. 
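The target-frame check described above is a plain bounds check on the index; a sketch under the documented rules:

from anomalib.data.validators import NumpyVideoValidator

frame_idx = NumpyVideoValidator.validate_target_frame(5)  # non-negative int passes
# A negative index raises ValueError; a non-integer raises TypeError, per the above.
print(frame_idx)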
+ + The validator checks: + - Array shapes and dimensions + - Data types + - Value ranges + - Label formats + - Mask properties + - Path validity + + Example: + Validate a batch of videos and associated metadata:: + + >>> from anomalib.data.validators import NumpyVideoBatchValidator + >>> validator = NumpyVideoBatchValidator() + >>> videos = np.random.rand(32, 10, 224, 224, 3) # [N, T, H, W, C] + >>> labels = np.zeros(32) + >>> masks = np.zeros((32, 10, 224, 224)) + >>> validated_videos = validator.validate_image(videos) + >>> validated_labels = validator.validate_gt_label(labels) + >>> validated_masks = validator.validate_gt_mask(masks) + + Note: + The validator is used internally by the data modules to ensure data + consistency before processing. + """ @staticmethod def validate_image(image: np.ndarray) -> np.ndarray: """Validate the video batch array. Args: - image (np.ndarray): Input video batch array to validate. + image (``np.ndarray``): Input video batch array to validate. Returns: - np.ndarray: Validated video batch array as float32. + ``np.ndarray``: Validated video batch array as float32. Raises: TypeError: If the input is not a numpy array. @@ -402,10 +499,12 @@ def validate_gt_label(gt_label: np.ndarray | Sequence[int] | None) -> np.ndarray """Validate the ground truth label batch. Args: - gt_label (np.ndarray | Sequence[int] | None): Input ground truth label batch to validate. + gt_label (``np.ndarray`` | ``Sequence[int]`` | ``None``): Input ground truth + label batch to validate. Returns: - np.ndarray | None: Validated ground truth label batch as boolean numpy array, or None if input is None. + ``np.ndarray`` | ``None``: Validated ground truth label batch as boolean numpy + array, or None if input is None. Raises: TypeError: If the input is not a numpy array or sequence of integers. @@ -436,10 +535,11 @@ def validate_gt_mask(gt_mask: np.ndarray | None) -> np.ndarray | None: """Validate the ground truth mask batch. Args: - gt_mask (np.ndarray | None): Input ground truth mask batch to validate. + gt_mask (``np.ndarray`` | ``None``): Input ground truth mask batch to validate. Returns: - np.ndarray | None: Validated ground truth mask batch as boolean numpy array, or None if input is None. + ``np.ndarray`` | ``None``: Validated ground truth mask batch as boolean numpy + array, or None if input is None. Raises: TypeError: If the input is not a numpy array. @@ -471,10 +571,10 @@ def validate_mask_path(mask_path: Sequence[str] | None) -> list[str] | None: """Validate the mask paths for a batch. Args: - mask_path (Sequence[str] | None): Input mask paths to validate. + mask_path (``Sequence[str]`` | ``None``): Input mask paths to validate. Returns: - list[str] | None: Validated mask paths, or None if input is None. + ``list[str]`` | ``None``: Validated mask paths, or None if input is None. Example: >>> validator = NumpyVideoBatchValidator() @@ -490,10 +590,11 @@ def validate_anomaly_map(anomaly_map: np.ndarray | None) -> np.ndarray | None: """Validate the anomaly map batch. Args: - anomaly_map (np.ndarray | None): Input anomaly map batch to validate. + anomaly_map (``np.ndarray`` | ``None``): Input anomaly map batch to validate. Returns: - np.ndarray | None: Validated anomaly map batch as float32 numpy array, or None if input is None. + ``np.ndarray`` | ``None``: Validated anomaly map batch as float32 numpy array, + or None if input is None. Raises: TypeError: If the input is not a numpy array. 
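And for batches of clips, a sketch mirroring the class example above:

import numpy as np

from anomalib.data.validators import NumpyVideoBatchValidator

videos = np.random.rand(4, 10, 224, 224, 3)  # [N, T, H, W, C]
labels = np.array([0, 1, 0, 1])
validated_videos = NumpyVideoBatchValidator.validate_image(videos)    # float32 per the docstrings
validated_labels = NumpyVideoBatchValidator.validate_gt_label(labels) # boolean array
print(validated_videos.dtype, validated_labels)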
@@ -525,10 +626,11 @@ def validate_pred_score(pred_score: np.ndarray | None) -> np.ndarray | None:
         """Validate the prediction scores for a batch.

         Args:
-            pred_score (np.ndarray | None): Input prediction scores to validate.
+            pred_score (``np.ndarray`` | ``None``): Input prediction scores to validate.

         Returns:
-            np.ndarray | None: Validated prediction scores as float32 numpy array, or None if input is None.
+            ``np.ndarray`` | ``None``: Validated prediction scores as float32 numpy array,
+            or None if input is None.

         Raises:
             TypeError: If the input is not a numpy array.
@@ -557,10 +659,11 @@ def validate_pred_mask(pred_mask: np.ndarray | None) -> np.ndarray | None:
         """Validate the prediction mask batch.

         Args:
-            pred_mask (np.ndarray | None): Input prediction mask batch to validate.
+            pred_mask (``np.ndarray`` | ``None``): Input prediction mask batch to validate.

         Returns:
-            np.ndarray | None: Validated prediction mask batch as boolean numpy array, or None if input is None.
+            ``np.ndarray`` | ``None``: Validated prediction mask batch as boolean numpy
+            array, or None if input is None.

         Example:
             >>> import numpy as np
@@ -577,10 +680,12 @@ def validate_pred_label(pred_label: np.ndarray | None) -> np.ndarray | None:
         """Validate the prediction label batch.

         Args:
-            pred_label (np.ndarray | None): Input prediction label batch to validate.
+            pred_label (``np.ndarray`` | ``None``): Input prediction label batch to
+                validate.

         Returns:
-            np.ndarray | None: Validated prediction label batch as boolean numpy array, or None if input is None.
+            ``np.ndarray`` | ``None``: Validated prediction label batch as boolean numpy
+            array, or None if input is None.

         Raises:
             TypeError: If the input is not a numpy array.
@@ -609,10 +714,10 @@ def validate_video_path(video_path: list[str] | None) -> list[str] | None:
         """Validate the video paths for a batch.

         Args:
-            video_path (list[str] | None): Input video paths to validate.
+            video_path (``list[str]`` | ``None``): Input video paths to validate.

         Returns:
-            list[str] | None: Validated video paths, or None if input is None.
+            ``list[str]`` | ``None``: Validated video paths, or None if input is None.

         Example:
             >>> validator = NumpyVideoBatchValidator()
@@ -628,10 +733,12 @@ def validate_original_image(original_image: np.ndarray | None) -> np.ndarray | N
         """Validate the original video batch.

         Args:
-            original_image (np.ndarray | None): Input original video batch to validate.
+            original_image (``np.ndarray`` | ``None``): Input original video batch to
+                validate.

         Returns:
-            np.ndarray | None: Validated original video batch, or None if input is None.
+            ``np.ndarray`` | ``None``: Validated original video batch, or None if input is
+            None.

         Raises:
             TypeError: If the input is not a numpy array.
@@ -666,10 +773,12 @@ def validate_target_frame(target_frame: np.ndarray | None) -> np.ndarray | None:
         """Validate the target frame indices for a batch.

         Args:
-            target_frame (np.ndarray | None): Input target frame indices to validate.
+            target_frame (``np.ndarray`` | ``None``): Input target frame indices to
+                validate.

         Returns:
-            np.ndarray | None: Validated target frame indices, or None if input is None.
+            ``np.ndarray`` | ``None``: Validated target frame indices, or None if input is
+            None.

         Raises:
             TypeError: If the input is not a numpy array of integers.
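To round off the numpy video batch API before the diff moves on to the path validators: the metadata fields follow the same per-field pattern. A short sketch, again assuming only what the docstrings above state — the paths and frame indices are hypothetical:

```python
import numpy as np

from anomalib.data.validators import NumpyVideoBatchValidator

validator = NumpyVideoBatchValidator()

video_paths = ["clips/clip_1.avi", "clips/clip_2.avi"]  # hypothetical paths
target_frames = np.array([4, 9])  # one integer frame index per clip

validated_paths = validator.validate_video_path(video_paths)       # -> list[str]
validated_frames = validator.validate_target_frame(target_frames)  # validated indices
# A non-integer array raises TypeError, per the docstring above.
```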
diff --git a/src/anomalib/data/validators/path.py b/src/anomalib/data/validators/path.py
index 0ee5080710..36fcac6221 100644
--- a/src/anomalib/data/validators/path.py
+++ b/src/anomalib/data/validators/path.py
@@ -1,4 +1,35 @@
-"""Validate IO path data."""
+"""Validate IO path data.
+
+This module provides validators for file system paths. The validators ensure path
+consistency and correctness.
+
+The validators check:
+    - Path types (str vs Path objects)
+    - Path string formatting
+    - Batch size consistency
+    - None handling
+
+Example:
+    Validate a single path::
+
+        >>> from anomalib.data.validators import validate_path
+        >>> path = "/path/to/file.jpg"
+        >>> validated = validate_path(path)
+        >>> validated == path
+        True
+
+    Validate a batch of paths::
+
+        >>> from anomalib.data.validators import validate_batch_path
+        >>> paths = ["/path/1.jpg", "/path/2.jpg"]
+        >>> validated = validate_batch_path(paths, batch_size=2)
+        >>> len(validated)
+        2
+
+Note:
+    The validators are used internally by the data modules to ensure path
+    consistency before processing.
+"""

 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -10,24 +41,37 @@
 def validate_path(path: str | Path) -> str:
     """Validate a single input path.

+    This function validates and normalizes file system paths. It accepts string paths or
+    ``pathlib.Path`` objects and converts them to string format.
+
     Args:
-        path: The input path to validate. Can be None, a string, or a Path object.
+        path (``str`` | ``Path``): Input path to validate. Can be a string path or
+            ``pathlib.Path`` object.

     Returns:
-        - None if the input is None
-        - A string representing the validated path
+        ``str``: The validated path as a string.

     Raises:
-        TypeError: If the input is not None, a string, or a Path object.
+        TypeError: If ``path`` is not a string or ``Path`` object.

     Examples:
-        >>> validate_path(None)
-        None
-        >>> validate_path("/path/to/file.png")
-        '/path/to/file.png'
-        >>> from pathlib import Path
-        >>> validate_path(Path("/path/to/file.png"))
-        '/path/to/file.png'
+        Validate a string path::
+
+            >>> validate_path("/path/to/file.png")
+            '/path/to/file.png'
+
+        Validate a Path object::
+
+            >>> from pathlib import Path
+            >>> validate_path(Path("/path/to/file.png"))
+            '/path/to/file.png'
+
+        Invalid input raises TypeError::
+
+            >>> validate_path(123)
+            Traceback (most recent call last):
+                ...
+            TypeError: Path must be None, a string, or Path object, got <class 'int'>.
     """
     if isinstance(path, str | Path):
         return str(path)
@@ -41,28 +85,51 @@ def validate_batch_path(
 ) -> list[str] | None:
     """Validate a batch of input paths.

+    This function validates and normalizes a sequence of file system paths. It accepts a
+    sequence of string paths or ``pathlib.Path`` objects and converts them to a list of
+    string paths. Optionally checks if the number of paths matches an expected batch size.
+
     Args:
-        paths: A sequence of paths to validate, or None.
-        batch_size: The expected number of paths. Defaults to None, in which case no batch size check is performed.
+        paths (``Sequence[str | Path] | None``): A sequence of paths to validate, or
+            ``None``. Each path can be a string or ``pathlib.Path`` object.
+        batch_size (``int | None``, optional): The expected number of paths. If specified,
+            validates that the number of paths matches this value. Defaults to ``None``,
+            in which case no batch size check is performed.

     Returns:
-        - None if the input is None
-        - A list of strings representing validated paths
+        ``list[str] | None``: A list of validated paths as strings, or ``None`` if the
+        input is ``None``.

     Raises:
-        TypeError: If the input is not None or a sequence of strings or Path objects.
-        ValueError: If a batch_size is specified and the number of paths doesn't match it.
+        TypeError: If ``paths`` is not ``None`` or a sequence of strings/``Path`` objects.
+        ValueError: If ``batch_size`` is specified and the number of paths doesn't match.

     Examples:
-        >>> paths = ["/path/to/file1.png", Path("/path/to/file2.png")]
-        >>> validate_batch_path(paths, batch_size=2)
-        ['/path/to/file1.png', '/path/to/file2.png']
-        >>> validate_batch_path(paths)  # Without specifying batch_size
-        ['/path/to/file1.png', '/path/to/file2.png']
-        >>> validate_batch_path(paths, batch_size=3)
-        Traceback (most recent call last):
-            ...
-        ValueError: Number of paths (2) does not match the specified batch size (3).
+        Validate a list of paths with batch size check::
+
+            >>> from pathlib import Path
+            >>> paths = ["/path/to/file1.png", Path("/path/to/file2.png")]
+            >>> validate_batch_path(paths, batch_size=2)
+            ['/path/to/file1.png', '/path/to/file2.png']
+
+        Validate without batch size check::
+
+            >>> validate_batch_path(paths)  # Without specifying batch_size
+            ['/path/to/file1.png', '/path/to/file2.png']
+
+        Batch size mismatch raises ValueError::
+
+            >>> validate_batch_path(paths, batch_size=3)
+            Traceback (most recent call last):
+                ...
+            ValueError: Number of paths (2) does not match the specified batch size (3).
+
+        Invalid input type raises TypeError::
+
+            >>> validate_batch_path("not_a_sequence")
+            Traceback (most recent call last):
+                ...
+            TypeError: Paths must be None or a sequence of strings or Path objects...
     """
     if paths is None:
         return None
diff --git a/src/anomalib/data/validators/torch/__init__.py b/src/anomalib/data/validators/torch/__init__.py
index 14253a93c7..8a654e282b 100644
--- a/src/anomalib/data/validators/torch/__init__.py
+++ b/src/anomalib/data/validators/torch/__init__.py
@@ -1,4 +1,33 @@
-"""Anomalib Torch data validators."""
+"""Validate PyTorch tensor data.
+
+This module provides validators for data stored as PyTorch tensors. The validators
+ensure data consistency and correctness for images, videos, depth maps and their
+batches.
+
+The validators check:
+    - Tensor shapes and dimensions
+    - Data types
+    - Value ranges
+    - Label formats
+    - Mask properties
+
+Example:
+    Validate a single image::
+
+        >>> from anomalib.data.validators import ImageValidator
+        >>> validator = ImageValidator()
+        >>> validator.validate_image(image)
+
+    Validate a batch of images::
+
+        >>> from anomalib.data.validators import ImageBatchValidator
+        >>> validator = ImageBatchValidator()
+        >>> validator(images=images, labels=labels, masks=masks)
+
+Note:
+    The validators are used internally by the data modules to ensure data
+    consistency before processing.
+"""

 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
diff --git a/src/anomalib/data/validators/torch/depth.py b/src/anomalib/data/validators/torch/depth.py
index 6869769ad6..d20f6ffaa4 100644
--- a/src/anomalib/data/validators/torch/depth.py
+++ b/src/anomalib/data/validators/torch/depth.py
@@ -1,4 +1,33 @@
-"""Validate torch depth data."""
+"""Validate PyTorch tensor data for depth maps.
+
+This module provides validators for depth data stored as PyTorch tensors. The validators
+ensure data consistency and correctness for depth maps and their batches.
+
+The validators check:
+    - Tensor shapes and dimensions
+    - Data types
+    - Value ranges
+    - Label formats
+    - Mask properties
+    - Path validity
+
+Example:
+    Validate a single depth map::
+
+        >>> from anomalib.data.validators import DepthValidator
+        >>> validator = DepthValidator()
+        >>> validator.validate_depth_map(depth_map)
+
+    Validate a batch of depth maps::
+
+        >>> from anomalib.data.validators import DepthBatchValidator
+        >>> validator = DepthBatchValidator()
+        >>> validator(depth_maps=depth_maps, labels=labels, masks=masks)
+
+Note:
+    The validators are used internally by the data modules to ensure data
+    consistency before processing depth data.
+"""

 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -15,29 +44,65 @@


 class DepthValidator:
-    """Validate torch.Tensor data for depth images."""
+    """Validate torch.Tensor data for depth images.
+
+    This class provides validation methods for depth data stored as PyTorch tensors.
+    It ensures data consistency and correctness for depth maps and associated metadata.
+
+    The validator checks:
+        - Tensor shapes and dimensions
+        - Data types
+        - Value ranges
+        - Label formats
+        - Mask properties
+        - Path validity
+
+    Example:
+        Validate a depth map and associated metadata::
+
+            >>> from anomalib.data.validators import DepthValidator
+            >>> validator = DepthValidator()
+            >>> depth_map = torch.rand(224, 224)  # [H, W]
+            >>> validated_map = validator.validate_depth_map(depth_map)
+            >>> label = 1
+            >>> validated_label = validator.validate_gt_label(label)
+            >>> mask = torch.randint(0, 2, (1, 224, 224))  # [1, H, W]
+            >>> validated_mask = validator.validate_gt_mask(mask)
+
+    Note:
+        The validator is used internally by the data modules to ensure data
+        consistency before processing.
+    """

     @staticmethod
     def validate_image(image: torch.Tensor) -> Image:
         """Validate the image tensor.

+        This method validates and normalizes input image tensors. It handles:
+            - RGB images only
+            - Channel-first format [C, H, W]
+            - Type conversion to float32
+            - Value range normalization
+
         Args:
-            image (torch.Tensor): Input image tensor.
+            image (``torch.Tensor``): Input image tensor to validate.

         Returns:
-            Image: Validated image as a torchvision Image object.
+            ``Image``: Validated image as a torchvision Image object.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the image tensor does not have the correct shape.
-
-        Examples:
-            >>> import torch
-            >>> from anomalib.data.validators import DepthValidator
-            >>> image = torch.rand(3, 256, 256)
-            >>> validated_image = DepthValidator.validate_image(image)
-            >>> validated_image.shape
-            torch.Size([3, 256, 256])
+            TypeError: If ``image`` is not a torch.Tensor.
+            ValueError: If ``image`` dimensions or channels are invalid.
+
+        Example:
+            Validate RGB image::
+
+                >>> import torch
+                >>> from anomalib.data.validators import DepthValidator
+                >>> image = torch.rand(3, 256, 256)  # [C, H, W]
+                >>> validated = DepthValidator.validate_image(image)
+                >>> validated.shape
+                torch.Size([3, 256, 256])
         """
         if not isinstance(image, torch.Tensor):
             msg = f"Image must be a torch.Tensor, got {type(image)}."
@@ -54,27 +119,33 @@ def validate_image(image: torch.Tensor) -> Image:
     def validate_gt_label(label: int | torch.Tensor | None) -> torch.Tensor | None:
         """Validate the ground truth label.

+        This method validates and normalizes input labels. It handles:
+            - Integer and tensor inputs
+            - Type conversion to boolean
+            - Scalar values only
+
         Args:
-            label (int | torch.Tensor | None): Input ground truth label.
+            label (``int`` | ``torch.Tensor`` | ``None``): Input ground truth label.

         Returns:
-            torch.Tensor | None: Validated ground truth label as a boolean tensor, or None.
+            ``torch.Tensor`` | ``None``: Validated ground truth label as boolean tensor.

         Raises:
-            TypeError: If the input is neither an integer nor a torch.Tensor.
-            ValueError: If the label shape or dtype is invalid.
-
-        Examples:
-            >>> import torch
-            >>> from anomalib.data.validators import DepthValidator
-            >>> label_int = 1
-            >>> validated_label = DepthValidator.validate_gt_label(label_int)
-            >>> validated_label
-            tensor(True)
-            >>> label_tensor = torch.tensor(0)
-            >>> validated_label = DepthValidator.validate_gt_label(label_tensor)
-            >>> validated_label
-            tensor(False)
+            TypeError: If ``label`` is neither an integer nor a torch.Tensor.
+            ValueError: If ``label`` shape is invalid.
+
+        Example:
+            Validate integer and tensor labels::
+
+                >>> from anomalib.data.validators import DepthValidator
+                >>> label_int = 1
+                >>> validated = DepthValidator.validate_gt_label(label_int)
+                >>> validated
+                tensor(True)
+                >>> label_tensor = torch.tensor(0)
+                >>> validated = DepthValidator.validate_gt_label(label_tensor)
+                >>> validated
+                tensor(False)
         """
         if label is None:
             return None
@@ -95,25 +166,33 @@ def validate_gt_label(label: int | torch.Tensor | None) -> torch.Tensor | None:
     def validate_gt_mask(mask: torch.Tensor | None) -> Mask | None:
         """Validate the ground truth mask.

+        This method validates and normalizes input masks. It handles:
+            - 2D and 3D inputs
+            - Single-channel masks
+            - Type conversion to boolean
+            - Channel dimension squeezing
+
         Args:
-            mask (torch.Tensor | None): Input ground truth mask.
+            mask (``torch.Tensor`` | ``None``): Input ground truth mask.

         Returns:
-            Mask | None: Validated ground truth mask, or None.
+            ``Mask`` | ``None``: Validated ground truth mask as torchvision Mask.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the mask shape is invalid.
-
-        Examples:
-            >>> import torch
-            >>> from anomalib.data.validators import DepthValidator
-            >>> mask = torch.randint(0, 2, (1, 224, 224))
-            >>> validated_mask = DepthValidator.validate_gt_mask(mask)
-            >>> isinstance(validated_mask, Mask)
-            True
-            >>> validated_mask.shape
-            torch.Size([224, 224])
+            TypeError: If ``mask`` is not a torch.Tensor.
+            ValueError: If ``mask`` dimensions or channels are invalid.
+
+        Example:
+            Validate binary segmentation mask::
+
+                >>> import torch
+                >>> from anomalib.data.validators import DepthValidator
+                >>> mask = torch.randint(0, 2, (1, 224, 224))  # [1, H, W]
+                >>> validated = DepthValidator.validate_gt_mask(mask)
+                >>> isinstance(validated, Mask)
+                True
+                >>> validated.shape
+                torch.Size([224, 224])
         """
         if mask is None:
             return None
@@ -134,18 +213,22 @@ def validate_gt_mask(mask: torch.Tensor | None) -> Mask | None:
     def validate_image_path(image_path: str | None) -> str | None:
         """Validate the image path.

+        This method validates input image file paths.
+
         Args:
-            image_path (str | None): Input image path.
+            image_path (``str`` | ``None``): Input image path to validate.

         Returns:
-            str | None: Validated image path, or None.
+            ``str`` | ``None``: Validated image path, or None.

-        Examples:
-            >>> from anomalib.data.validators import DepthValidator
-            >>> path = "/path/to/image.jpg"
-            >>> validated_path = DepthValidator.validate_image_path(path)
-            >>> validated_path == path
-            True
+        Example:
+            Validate image file path::
+
+                >>> from anomalib.data.validators import DepthValidator
+                >>> path = "/path/to/image.jpg"
+                >>> validated = DepthValidator.validate_image_path(path)
+                >>> validated == path
+                True
         """
         return validate_path(image_path) if image_path else None

@@ -153,23 +236,30 @@ def validate_image_path(image_path: str | None) -> str | None:
     def validate_depth_map(depth_map: torch.Tensor | None) -> torch.Tensor | None:
         """Validate the depth map.

+        This method validates and normalizes input depth maps. It handles:
+            - 2D and 3D inputs
+            - Single and multi-channel depth maps
+            - Type conversion to float32
+
         Args:
-            depth_map (torch.Tensor | None): Input depth map.
+            depth_map (``torch.Tensor`` | ``None``): Input depth map to validate.

         Returns:
-            torch.Tensor | None: Validated depth map, or None.
+            ``torch.Tensor`` | ``None``: Validated depth map as float32 tensor.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the depth map shape is invalid.
-
-        Examples:
-            >>> import torch
-            >>> from anomalib.data.validators import DepthValidator
-            >>> depth_map = torch.rand(224, 224)
-            >>> validated_map = DepthValidator.validate_depth_map(depth_map)
-            >>> validated_map.shape
-            torch.Size([224, 224])
+            TypeError: If ``depth_map`` is not a torch.Tensor.
+            ValueError: If ``depth_map`` dimensions or channels are invalid.
+
+        Example:
+            Validate single-channel depth map::
+
+                >>> import torch
+                >>> from anomalib.data.validators import DepthValidator
+                >>> depth_map = torch.rand(224, 224)  # [H, W]
+                >>> validated = DepthValidator.validate_depth_map(depth_map)
+                >>> validated.shape
+                torch.Size([224, 224])
         """
         if depth_map is None:
             return None
@@ -188,18 +278,22 @@ def validate_depth_map(depth_map: torch.Tensor | None) -> torch.Tensor | None:
     def validate_depth_path(depth_path: str | None) -> str | None:
         """Validate the depth path.

+        This method validates input depth map file paths.
+
         Args:
-            depth_path (str | None): Input depth path.
+            depth_path (``str`` | ``None``): Input depth path to validate.

         Returns:
-            str | None: Validated depth path, or None.
+            ``str`` | ``None``: Validated depth path, or None.

-        Examples:
-            >>> from anomalib.data.validators import DepthValidator
-            >>> path = "/path/to/depth.png"
-            >>> validated_path = DepthValidator.validate_depth_path(path)
-            >>> validated_path == path
-            True
+        Example:
+            Validate depth map file path::
+
+                >>> from anomalib.data.validators import DepthValidator
+                >>> path = "/path/to/depth.png"
+                >>> validated = DepthValidator.validate_depth_path(path)
+                >>> validated == path
+                True
         """
         return validate_path(depth_path) if depth_path else None

@@ -235,28 +329,62 @@ def validate_explanation(explanation: str | None) -> str | None:


 class DepthBatchValidator:
-    """Validate torch.Tensor data for batches of depth images."""
+    """Validate torch.Tensor data for batches of depth images.
+
+    This class provides validation methods for batches of depth data stored as PyTorch tensors.
+    It ensures data consistency and correctness for depth maps and associated metadata.
+
+    The validator checks:
+        - Tensor shapes and dimensions
+        - Data types
+        - Value ranges
+        - Label formats
+        - Mask properties
+        - Path validity
+
+    Example:
+        Validate a batch of depth maps and associated metadata::
+
+            >>> from anomalib.data.validators import DepthBatchValidator
+            >>> validator = DepthBatchValidator()
+            >>> depth_maps = torch.rand(32, 224, 224)  # [N, H, W]
+            >>> labels = torch.zeros(32)
+            >>> masks = torch.zeros((32, 224, 224))
+            >>> validated_maps = validator.validate_depth_map(depth_maps)
+            >>> validated_labels = validator.validate_gt_label(labels)
+            >>> validated_masks = validator.validate_gt_mask(masks)
+
+    Note:
+        The validator is used internally by the data modules to ensure data
+        consistency before processing.
+    """

     @staticmethod
     def validate_image(image: torch.Tensor) -> Image:
         """Validate the image tensor for a batch.

+        This method validates batches of images stored as PyTorch tensors. It handles:
+            - Channel-first format [N, C, H, W]
+            - RGB images only
+            - Type conversion to float32
+            - Value range normalization
+
         Args:
-            image (torch.Tensor): Input image tensor.
+            image (``torch.Tensor``): Input image tensor to validate.

         Returns:
-            Image: Validated image as a torchvision Image object.
+            ``Image``: Validated image as a torchvision Image object.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the image tensor does not have the correct shape.
+            TypeError: If ``image`` is not a torch.Tensor.
+            ValueError: If ``image`` dimensions or channels are invalid.

-        Examples:
+        Example:
             >>> import torch
             >>> from anomalib.data.validators import DepthBatchValidator
-            >>> image = torch.rand(32, 3, 256, 256)
-            >>> validated_image = DepthBatchValidator.validate_image(image)
-            >>> validated_image.shape
+            >>> image = torch.rand(32, 3, 256, 256)  # [N, C, H, W]
+            >>> validated = DepthBatchValidator.validate_image(image)
+            >>> validated.shape
             torch.Size([32, 3, 256, 256])
         """
         if not isinstance(image, torch.Tensor):
@@ -274,22 +402,29 @@ def validate_gt_label(gt_label: torch.Tensor | Sequence[int] | None) -> torch.Te
         """Validate the ground truth label for a batch.

+        This method validates ground truth labels for batches. It handles:
+            - Conversion to boolean tensor
+            - Batch dimension validation
+            - None inputs
+
         Args:
-            gt_label (torch.Tensor | Sequence[int] | None): Input ground truth label.
+            gt_label (``torch.Tensor`` | ``Sequence[int]`` | ``None``): Input ground truth
+                label to validate.

         Returns:
-            torch.Tensor | None: Validated ground truth label as a boolean tensor, or None.
+            ``torch.Tensor`` | ``None``: Validated ground truth label as a boolean tensor,
+            or None.

         Raises:
-            TypeError: If the input is not a sequence of integers or a torch.Tensor.
-            ValueError: If the ground truth label does not match the expected batch size or data type.
+            TypeError: If ``gt_label`` is not a sequence of integers or torch.Tensor.
+            ValueError: If ``gt_label`` does not match expected batch size or data type.

-        Examples:
+        Example:
             >>> import torch
             >>> from anomalib.data.validators import DepthBatchValidator
             >>> gt_label = torch.tensor([0, 1, 1, 0])
-            >>> validated_label = DepthBatchValidator.validate_gt_label(gt_label)
-            >>> print(validated_label)
+            >>> validated = DepthBatchValidator.validate_gt_label(gt_label)
+            >>> print(validated)
             tensor([False, True, True, False])
         """
         return ImageBatchValidator.validate_gt_label(gt_label)
@@ -298,22 +433,28 @@ def validate_gt_mask(gt_mask: torch.Tensor | None) -> Mask | None:
         """Validate the ground truth mask for a batch.

+        This method validates ground truth masks for batches. It handles:
+            - Batch dimension validation
+            - Binary mask values
+            - None inputs
+
         Args:
-            gt_mask (torch.Tensor | None): Input ground truth mask.
+            gt_mask (``torch.Tensor`` | ``None``): Input ground truth mask to validate.

         Returns:
-            Mask | None: Validated ground truth mask as a torchvision Mask object, or None.
+            ``Mask`` | ``None``: Validated ground truth mask as a torchvision Mask object,
+            or None.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the ground truth mask does not have the correct shape or batch size.
+            TypeError: If ``gt_mask`` is not a torch.Tensor.
+            ValueError: If ``gt_mask`` shape or batch size is invalid.

-        Examples:
+        Example:
             >>> import torch
             >>> from anomalib.data.validators import DepthBatchValidator
-            >>> gt_mask = torch.randint(0, 2, (4, 224, 224))
-            >>> validated_mask = DepthBatchValidator.validate_gt_mask(gt_mask)
-            >>> print(validated_mask.shape)
+            >>> gt_mask = torch.randint(0, 2, (4, 224, 224))  # [N, H, W]
+            >>> validated = DepthBatchValidator.validate_gt_mask(gt_mask)
+            >>> print(validated.shape)
             torch.Size([4, 224, 224])
         """
         return ImageBatchValidator.validate_gt_mask(gt_mask)
@@ -322,21 +463,26 @@ def validate_mask_path(mask_path: Sequence[str] | None) -> list[str] | None:
         """Validate the mask paths for a batch.

+        This method validates file paths for batches of mask images. It handles:
+            - Path existence validation
+            - Batch size consistency
+            - None inputs
+
         Args:
-            mask_path (Sequence[str] | None): Input sequence of mask paths.
+            mask_path (``Sequence[str]`` | ``None``): Input sequence of mask paths.

         Returns:
-            list[str] | None: Validated list of mask paths, or None.
+            ``list[str]`` | ``None``: Validated list of mask paths, or None.

         Raises:
-            TypeError: If the input is not a sequence of strings.
-            ValueError: If the number of mask paths does not match the expected batch size.
+            TypeError: If ``mask_path`` is not a sequence of strings.
+            ValueError: If number of paths does not match expected batch size.

-        Examples:
+        Example:
             >>> from anomalib.data.validators import DepthBatchValidator
-            >>> mask_paths = ["path/to/mask_1.png", "path/to/mask_2.png"]
-            >>> validated_paths = DepthBatchValidator.validate_mask_path(mask_paths)
-            >>> print(validated_paths)
+            >>> paths = ["path/to/mask_1.png", "path/to/mask_2.png"]
+            >>> validated = DepthBatchValidator.validate_mask_path(paths)
+            >>> print(validated)
             ['path/to/mask_1.png', 'path/to/mask_2.png']
         """
         return ImageBatchValidator.validate_mask_path(mask_path)
@@ -345,20 +491,25 @@ def validate_image_path(image_path: list[str] | None) -> list[str] | None:
         """Validate the image paths for a batch.

+        This method validates file paths for batches of images. It handles:
+            - Path existence validation
+            - Batch size consistency
+            - None inputs
+
         Args:
-            image_path (list[str] | None): Input list of image paths.
+            image_path (``list[str]`` | ``None``): Input list of image paths.

         Returns:
-            list[str] | None: Validated list of image paths, or None.
+            ``list[str]`` | ``None``: Validated list of image paths, or None.

         Raises:
-            TypeError: If the input is not a list of strings.
+            TypeError: If ``image_path`` is not a list of strings.

-        Examples:
+        Example:
             >>> from anomalib.data.validators import DepthBatchValidator
-            >>> image_paths = ["path/to/image_1.jpg", "path/to/image_2.jpg"]
-            >>> validated_paths = DepthBatchValidator.validate_image_path(image_paths)
-            >>> print(validated_paths)
+            >>> paths = ["path/to/image_1.jpg", "path/to/image_2.jpg"]
+            >>> validated = DepthBatchValidator.validate_image_path(paths)
+            >>> print(validated)
             ['path/to/image_1.jpg', 'path/to/image_2.jpg']
         """
         return ImageBatchValidator.validate_image_path(image_path)
@@ -367,22 +518,28 @@ def validate_depth_map(depth_map: torch.Tensor | None) -> torch.Tensor | None:
         """Validate the depth map for a batch.

+        This method validates batches of depth maps. It handles:
+            - Single-channel and RGB depth maps
+            - Batch dimension validation
+            - Type conversion to float32
+            - None inputs
+
         Args:
-            depth_map (torch.Tensor | None): Input depth map.
+            depth_map (``torch.Tensor`` | ``None``): Input depth map to validate.

         Returns:
-            torch.Tensor | None: Validated depth map, or None.
+            ``torch.Tensor`` | ``None``: Validated depth map as float32, or None.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the depth map shape is invalid or doesn't match the batch size.
+            TypeError: If ``depth_map`` is not a torch.Tensor.
+            ValueError: If ``depth_map`` shape is invalid or the batch size does not match.

-        Examples:
+        Example:
             >>> import torch
             >>> from anomalib.data.validators import DepthBatchValidator
-            >>> depth_map = torch.rand(4, 224, 224)
-            >>> validated_map = DepthBatchValidator.validate_depth_map(depth_map)
-            >>> print(validated_map.shape)
+            >>> depth_map = torch.rand(4, 224, 224)  # [N, H, W]
+            >>> validated = DepthBatchValidator.validate_depth_map(depth_map)
+            >>> print(validated.shape)
             torch.Size([4, 224, 224])
         """
         if depth_map is None:
@@ -402,20 +559,25 @@ def validate_depth_path(depth_path: list[str] | None) -> list[str] | None:
         """Validate the depth paths for a batch.

+        This method validates file paths for batches of depth maps. It handles:
+            - Path existence validation
+            - Batch size consistency
+            - None inputs
+
         Args:
-            depth_path (list[str] | None): Input list of depth paths.
+            depth_path (``list[str]`` | ``None``): Input list of depth paths.

         Returns:
-            list[str] | None: Validated list of depth paths, or None.
+            ``list[str]`` | ``None``: Validated list of depth paths, or None.

         Raises:
-            TypeError: If the input is not a list of strings.
+            TypeError: If ``depth_path`` is not a list of strings.

-        Examples:
+        Example:
             >>> from anomalib.data.validators import DepthBatchValidator
-            >>> depth_paths = ["path/to/depth_1.png", "path/to/depth_2.png"]
-            >>> validated_paths = DepthBatchValidator.validate_depth_path(depth_paths)
-            >>> print(validated_paths)
+            >>> paths = ["path/to/depth_1.png", "path/to/depth_2.png"]
+            >>> validated = DepthBatchValidator.validate_depth_path(paths)
+            >>> print(validated)
             ['path/to/depth_1.png', 'path/to/depth_2.png']
         """
         if depth_path is None:
diff --git a/src/anomalib/data/validators/torch/image.py b/src/anomalib/data/validators/torch/image.py
index c9a8ac07cb..06a729ff92 100644
--- a/src/anomalib/data/validators/torch/image.py
+++ b/src/anomalib/data/validators/torch/image.py
@@ -1,4 +1,33 @@
-"""Validate torch image data."""
+"""Validate PyTorch tensor data for images.
+
+This module provides validators for image data stored as PyTorch tensors. The validators
+ensure data consistency and correctness for images and their batches.
+
+The validators check:
+    - Tensor shapes and dimensions
+    - Data types
+    - Value ranges
+    - Label formats
+    - Mask properties
+    - Path validity
+
+Example:
+    Validate a single image::
+
+        >>> from anomalib.data.validators import ImageValidator
+        >>> validator = ImageValidator()
+        >>> validator.validate_image(image)
+
+    Validate a batch of images::
+
+        >>> from anomalib.data.validators import ImageBatchValidator
+        >>> validator = ImageBatchValidator()
+        >>> validator(images=images, labels=labels, masks=masks)
+
+Note:
+    The validators are used internally by the data modules to ensure data
+    consistency before processing.
+"""

 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -14,26 +43,60 @@


 class ImageValidator:
-    """Validate torch.Tensor data for images."""
+    """Validate torch.Tensor data for images.
+
+    This class provides validation methods for image data stored as PyTorch tensors.
+    It ensures data consistency and correctness for images and associated metadata.
+
+    The validator checks:
+        - Tensor shapes and dimensions
+        - Data types
+        - Value ranges
+        - Label formats
+        - Mask properties
+        - Path validity
+
+    Example:
+        Validate an image and associated metadata::
+
+            >>> from anomalib.data.validators import ImageValidator
+            >>> validator = ImageValidator()
+            >>> image = torch.rand(3, 224, 224)  # [C, H, W]
+            >>> validated_image = validator.validate_image(image)
+            >>> label = 1
+            >>> validated_label = validator.validate_gt_label(label)
+            >>> mask = torch.randint(0, 2, (1, 224, 224))  # [1, H, W]
+            >>> validated_mask = validator.validate_gt_mask(mask)
+
+    Note:
+        The validator is used internally by the data modules to ensure data
+        consistency before processing.
+    """

     @staticmethod
     def validate_image(image: torch.Tensor) -> torch.Tensor:
         """Validate the image tensor.

+        This method validates and normalizes input image tensors. It handles:
+            - RGB images only
+            - Channel-first format [C, H, W]
+            - Type conversion to float32
+            - Value range normalization
+
         Args:
-            image (torch.Tensor): Input image tensor.
+            image (``torch.Tensor``): Input image tensor to validate.

         Returns:
-            torch.Tensor: Validated image tensor.
+            ``torch.Tensor``: Validated image tensor in [C, H, W] format.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the image tensor does not have the correct shape.
+            TypeError: If ``image`` is not a torch.Tensor.
+            ValueError: If ``image`` dimensions or channels are invalid.

-        Examples:
+        Example:
             >>> import torch
             >>> from anomalib.data.validators import ImageValidator
-            >>> image = torch.rand(3, 256, 256)
+            >>> image = torch.rand(3, 256, 256)  # [C, H, W]
             >>> validated_image = ImageValidator.validate_image(image)
             >>> validated_image.shape
             torch.Size([3, 256, 256])
@@ -53,19 +116,26 @@ def validate_image(image: torch.Tensor) -> torch.Tensor:
     def validate_gt_label(label: int | torch.Tensor | None) -> torch.Tensor | None:
         """Validate the ground truth label.

+        This method validates and normalizes input labels. It handles:
+            - Integer and tensor inputs
+            - Type conversion to boolean
+            - Output is always boolean type
+            - None inputs return None
+
         Args:
-            label (int | torch.Tensor | None): Input ground truth label.
+            label (``int`` | ``torch.Tensor`` | ``None``): Input ground truth label.

         Returns:
-            torch.Tensor | None: Validated ground truth label as a boolean tensor, or None.
+            ``torch.Tensor`` | ``None``: Validated ground truth label as a boolean
+            tensor, or None.

         Raises:
-            TypeError: If the input is neither an integer nor a torch.Tensor.
-            ValueError: If the label shape or dtype is invalid.
+            TypeError: If ``label`` is neither an integer nor a torch.Tensor.
+            ValueError: If ``label`` shape or dtype is invalid.

-        Examples:
+        Example:
             >>> import torch
-            >>> from anomalib.dataclasses.validators import ImageValidator
+            >>> from anomalib.data.validators import ImageValidator
             >>> label_int = 1
             >>> validated_label = ImageValidator.validate_gt_label(label_int)
             >>> validated_label
@@ -94,20 +164,27 @@ def validate_gt_label(label: int | torch.Tensor | None) -> torch.Tensor | None:
     def validate_gt_mask(mask: torch.Tensor | None) -> Mask | None:
         """Validate the ground truth mask.

+        This method validates and normalizes input masks. It handles:
+            - Single channel masks only
+            - [H, W] and [1, H, W] formats
+            - Type conversion to boolean
+            - None inputs return None
+
         Args:
-            mask (torch.Tensor | None): Input ground truth mask.
+            mask (``torch.Tensor`` | ``None``): Input ground truth mask.

         Returns:
-            Mask | None: Validated ground truth mask, or None.
+            ``Mask`` | ``None``: Validated ground truth mask as a torchvision Mask
+            object, or None.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the mask shape is invalid.
+            TypeError: If ``mask`` is not a torch.Tensor.
+            ValueError: If ``mask`` dimensions or channels are invalid.

-        Examples:
+        Example:
             >>> import torch
-            >>> from anomalib.dataclasses.validators import ImageValidator
-            >>> mask = torch.randint(0, 2, (1, 224, 224))
+            >>> from anomalib.data.validators import ImageValidator
+            >>> mask = torch.randint(0, 2, (1, 224, 224))  # [1, H, W]
             >>> validated_mask = ImageValidator.validate_gt_mask(mask)
             >>> isinstance(validated_mask, Mask)
             True
@@ -133,20 +210,27 @@ def validate_gt_mask(mask: torch.Tensor | None) -> Mask | None:
     def validate_anomaly_map(anomaly_map: torch.Tensor | None) -> Mask | None:
         """Validate the anomaly map.

+        This method validates and normalizes input anomaly maps. It handles:
+            - Single channel maps only
+            - [H, W] and [1, H, W] formats
+            - Type conversion to float32
+            - None inputs return None
+
         Args:
-            anomaly_map (torch.Tensor | None): Input anomaly map.
+            anomaly_map (``torch.Tensor`` | ``None``): Input anomaly map.

         Returns:
-            Mask | None: Validated anomaly map as a Mask, or None.
+            ``Mask`` | ``None``: Validated anomaly map as a torchvision Mask object,
+            or None.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the anomaly map shape is invalid.
+            TypeError: If ``anomaly_map`` is not a torch.Tensor.
+            ValueError: If ``anomaly_map`` dimensions or channels are invalid.

-        Examples:
+        Example:
             >>> import torch
-            >>> from anomalib.dataclasses.validators import ImageValidator
-            >>> anomaly_map = torch.rand(1, 224, 224)
+            >>> from anomalib.data.validators import ImageValidator
+            >>> anomaly_map = torch.rand(1, 224, 224)  # [1, H, W]
             >>> validated_map = ImageValidator.validate_anomaly_map(anomaly_map)
             >>> isinstance(validated_map, Mask)
             True
@@ -173,14 +257,16 @@ def validate_anomaly_map(anomaly_map: torch.Tensor | None) -> Mask | None:
     def validate_image_path(image_path: str | None) -> str | None:
         """Validate the image path.

+        This method validates input image file paths.
+
         Args:
-            image_path (str | None): Input image path.
+            image_path (``str`` | ``None``): Input image path to validate.

         Returns:
-            str | None: Validated image path, or None.
+            ``str`` | ``None``: Validated image path, or None.

-        Examples:
-            >>> from anomalib.dataclasses.validators import ImageValidator
+        Example:
+            >>> from anomalib.data.validators import ImageValidator
             >>> path = "/path/to/image.jpg"
             >>> validated_path = ImageValidator.validate_image_path(path)
             >>> validated_path == path
@@ -192,14 +278,16 @@ def validate_image_path(image_path: str | None) -> str | None:
     def validate_mask_path(mask_path: str | None) -> str | None:
         """Validate the mask path.

+        This method validates input mask file paths.
+
         Args:
-            mask_path (str | None): Input mask path.
+            mask_path (``str`` | ``None``): Input mask path to validate.

         Returns:
-            str | None: Validated mask path, or None.
+            ``str`` | ``None``: Validated mask path, or None.

-        Examples:
-            >>> from anomalib.dataclasses.validators import ImageValidator
+        Example:
+            >>> from anomalib.data.validators import ImageValidator
             >>> path = "/path/to/mask.png"
             >>> validated_path = ImageValidator.validate_mask_path(path)
             >>> validated_path == path
@@ -213,17 +301,24 @@ def validate_pred_score(
     ) -> torch.Tensor | None:
         """Validate the prediction score.

+        This method validates and normalizes prediction scores. It handles:
+            - Float, numpy array and tensor inputs
+            - Type conversion to float32
+            - None inputs return None
+
         Args:
-            pred_score (torch.Tensor | float | None): Input prediction score.
+            pred_score (``torch.Tensor`` | ``np.ndarray`` | ``float`` | ``None``):
+                Input prediction score.

         Returns:
-            torch.Tensor | None: Validated prediction score as a float32 tensor, or None.
+            ``torch.Tensor`` | ``None``: Validated prediction score as a float32
+            tensor, or None.

         Raises:
-            TypeError: If the input is neither a float, torch.Tensor, nor None.
-            ValueError: If the prediction score is not a scalar.
+            TypeError: If ``pred_score`` cannot be converted to a tensor.
+            ValueError: If ``pred_score`` is not a scalar.

-        Examples:
+        Example:
             >>> import torch
             >>> from anomalib.data.validators import ImageValidator
             >>> score = 0.8
@@ -234,9 +329,6 @@ def validate_pred_score(
             >>> validated_score = ImageValidator.validate_pred_score(score_tensor)
             >>> validated_score
             tensor(0.7000)
-            >>> validated_score = ImageValidator.validate_pred_score(None)
-            >>> validated_score is None
-            True
         """
         if pred_score is None:
             return None
@@ -254,17 +346,23 @@ def validate_pred_score(
     def validate_pred_mask(pred_mask: torch.Tensor | None) -> Mask | None:
         """Validate the prediction mask.

+        This method validates and normalizes prediction masks. It handles:
+            - Single channel masks only
+            - [H, W] and [1, H, W] formats
+            - Type conversion to boolean
+            - None inputs return None
+
         Args:
-            pred_mask (torch.Tensor | None): Input prediction mask.
+            pred_mask (``torch.Tensor`` | ``None``): Input prediction mask.

         Returns:
-            Mask | None: Validated prediction mask, or None.
+            ``Mask`` | ``None``: Validated prediction mask as a torchvision Mask
+            object, or None.

-
-        Examples:
+        Example:
             >>> import torch
-            >>> from anomalib.dataclasses.validators import ImageValidator
-            >>> mask = torch.randint(0, 2, (1, 224, 224))
+            >>> from anomalib.data.validators import ImageValidator
+            >>> mask = torch.randint(0, 2, (1, 224, 224))  # [1, H, W]
             >>> validated_mask = ImageValidator.validate_pred_mask(mask)
             >>> isinstance(validated_mask, Mask)
             True
@@ -277,19 +375,26 @@ def validate_pred_mask(pred_mask: torch.Tensor | None) -> Mask | None:
     def validate_pred_label(pred_label: torch.Tensor | np.ndarray | float | None) -> torch.Tensor | None:
         """Validate the prediction label.

+        This method validates and normalizes prediction labels. It handles:
+            - Float, numpy array and tensor inputs
+            - Type conversion to boolean
+            - None inputs return None
+
         Args:
-            pred_label (torch.Tensor | None): Input prediction label.
+            pred_label (``torch.Tensor`` | ``np.ndarray`` | ``float`` | ``None``):
+                Input prediction label.

         Returns:
-            torch.Tensor | None: Validated prediction label as a boolean tensor, or None.
+            ``torch.Tensor`` | ``None``: Validated prediction label as a boolean
+            tensor, or None.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the prediction label is not a scalar.
+            TypeError: If ``pred_label`` cannot be converted to a tensor.
+            ValueError: If ``pred_label`` is not a scalar.

-        Examples:
+        Example:
             >>> import torch
-            >>> from anomalib.dataclasses.validators import ImageValidator
+            >>> from anomalib.data.validators import ImageValidator
             >>> label = torch.tensor(1)
             >>> validated_label = ImageValidator.validate_pred_label(label)
             >>> validated_label
@@ -311,19 +416,24 @@ def validate_pred_label(pred_label: torch.Tensor | np.ndarray | float | None) ->
     @staticmethod
     def validate_explanation(explanation: str | None) -> str | None:
-        """Validate the explanation.
+        """Validate the explanation string.
+
+        This method validates explanation strings.

         Args:
-            explanation (str | None): Input explanation.
+            explanation (``str`` | ``None``): Input explanation string.

         Returns:
-            str | None: Validated explanation, or None.
+            ``str`` | ``None``: Validated explanation string, or None.
+
+        Raises:
+            TypeError: If ``explanation`` is not a string.

-        Examples:
-            >>> from anomalib.dataclasses.validators import ImageValidator
+        Example:
+            >>> from anomalib.data.validators import ImageValidator
             >>> explanation = "The image has a crack on the wall."
-            >>> validated_explanation = ImageValidator.validate_explanation(explanation)
-            >>> validated_explanation == explanation
+            >>> validated = ImageValidator.validate_explanation(explanation)
+            >>> validated == explanation
             True
         """
         if explanation is None:
@@ -335,29 +445,65 @@ def validate_explanation(explanation: str | None) -> str | None:


 class ImageBatchValidator:
-    """Validate torch.Tensor data for batches of images."""
+    """Validate torch.Tensor data for batches of images.
+
+    This class provides validation methods for batches of image data stored as PyTorch tensors.
+    It ensures data consistency and correctness for images and associated metadata.
+
+    The validator checks:
+        - Tensor shapes and dimensions
+        - Data types
+        - Value ranges
+        - Label formats
+        - Mask properties
+        - Path validity
+
+    Example:
+        Validate a batch of images and associated metadata::
+
+            >>> from anomalib.data.validators import ImageBatchValidator
+            >>> validator = ImageBatchValidator()
+            >>> images = torch.rand(32, 3, 256, 256)  # [N, C, H, W]
+            >>> labels = torch.zeros(32)
+            >>> masks = torch.zeros((32, 256, 256))
+            >>> validator.validate_image(images)
+            >>> validator.validate_gt_label(labels)
+            >>> validator.validate_gt_mask(masks)
+
+    Note:
+        The validator is used internally by the data modules to ensure data
+        consistency before processing.
+    """

     @staticmethod
     def validate_image(image: torch.Tensor) -> Image:
         """Validate the image for a batch.

+        This method validates batches of images stored as PyTorch tensors. It handles:
+            - Single images and batches
+            - RGB images only
+            - Channel-first format [N, C, H, W]
+            - Type conversion to float32
+
         Args:
-            image (torch.Tensor): Input image tensor.
+            image (``torch.Tensor``): Input image tensor to validate.

         Returns:
-            Image: Validated image as a torchvision Image object.
+            ``Image``: Validated image as a torchvision Image object.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the image tensor does not have the correct shape or number of channels.
-
-        Examples:
-            >>> import torch
-            >>> from anomalib.data.validators.torch.image import ImageBatchValidator
-            >>> image = torch.rand(32, 3, 224, 224)
-            >>> validated_image = ImageBatchValidator.validate_image(image)
-            >>> print(validated_image.shape)
-            torch.Size([32, 3, 224, 224])
+            TypeError: If ``image`` is not a torch.Tensor.
+            ValueError: If ``image`` dimensions or channels are invalid.
+
+        Example:
+            Validate RGB batch::
+
+                >>> import torch
+                >>> from anomalib.data.validators import ImageBatchValidator
+                >>> image = torch.rand(32, 3, 224, 224)  # [N, C, H, W]
+                >>> validated = ImageBatchValidator.validate_image(image)
+                >>> validated.shape
+                torch.Size([32, 3, 224, 224])
         """
         if not isinstance(image, torch.Tensor):
             msg = f"Image must be a torch.Tensor, got {type(image)}."
@@ -376,23 +522,32 @@ def validate_gt_label(gt_label: torch.Tensor | Sequence[int] | None) -> torch.Te
         """Validate the ground truth label for a batch.

+        This method validates batches of ground truth labels. It handles:
+            - Conversion to torch.Tensor if needed
+            - Type conversion to boolean
+            - Shape validation
+
         Args:
-            gt_label (torch.Tensor | Sequence[int] | None): Input ground truth label.
+            gt_label (``torch.Tensor`` | ``Sequence[int]`` | ``None``): Input ground truth
+                label.

         Returns:
-            torch.Tensor | None: Validated ground truth label as a boolean tensor, or None.
+            ``torch.Tensor`` | ``None``: Validated ground truth label as a boolean tensor,
+            or None.

         Raises:
-            TypeError: If the input is not a sequence of integers or a torch.Tensor.
-            ValueError: If the ground truth label does not match the expected batch size or data type.
+            TypeError: If ``gt_label`` is not a sequence of integers or torch.Tensor.
+            ValueError: If ``gt_label`` shape or data type is invalid.
+
+        Example:
+            Validate ground truth labels::
+
+                >>> import torch
+                >>> from anomalib.data.validators import ImageBatchValidator
+                >>> gt_label = torch.tensor([0, 1, 1, 0])
+                >>> validated = ImageBatchValidator.validate_gt_label(gt_label)
+                >>> validated
+                tensor([False, True, True, False])
         """
         if gt_label is None:
             return None
@@ -413,23 +568,31 @@ def validate_gt_mask(gt_mask: torch.Tensor | None) -> Mask | None:
         """Validate the ground truth mask for a batch.

+        This method validates batches of ground truth masks. It handles:
+            - Single masks and batches
+            - Shape normalization
+            - Type conversion to boolean
+
         Args:
-            gt_mask (torch.Tensor | None): Input ground truth mask.
+            gt_mask (``torch.Tensor`` | ``None``): Input ground truth mask.

         Returns:
-            Mask | None: Validated ground truth mask as a torchvision Mask object, or None.
+            ``Mask`` | ``None``: Validated ground truth mask as a torchvision Mask object,
+            or None.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the ground truth mask does not have the correct shape or batch size.
-
-        Examples:
-            >>> import torch
-            >>> from anomalib.data.validators.torch.image import ImageBatchValidator
-            >>> gt_mask = torch.randint(0, 2, (4, 224, 224))
-            >>> validated_mask = ImageBatchValidator.validate_gt_mask(gt_mask)
-            >>> print(validated_mask.shape)
-            torch.Size([4, 224, 224])
+            TypeError: If ``gt_mask`` is not a torch.Tensor.
+            ValueError: If ``gt_mask`` shape is invalid.
+
+        Example:
+            Validate ground truth masks::
+
+                >>> import torch
+                >>> from anomalib.data.validators import ImageBatchValidator
+                >>> gt_mask = torch.randint(0, 2, (4, 224, 224))
+                >>> validated = ImageBatchValidator.validate_gt_mask(gt_mask)
+                >>> validated.shape
+                torch.Size([4, 224, 224])
         """
         if gt_mask is None:
             return None
@@ -452,22 +615,25 @@ def validate_mask_path(mask_path: Sequence[str] | None) -> list[str] | None:
         """Validate the mask paths for a batch.

+        This method validates batches of mask file paths.
+
         Args:
-            mask_path (Sequence[str] | None): Input sequence of mask paths.
+            mask_path (``Sequence[str]`` | ``None``): Input sequence of mask paths.

         Returns:
-            list[str] | None: Validated list of mask paths, or None.
+            ``list[str]`` | ``None``: Validated list of mask paths, or None.

         Raises:
-            TypeError: If the input is not a sequence of strings.
-            ValueError: If the number of mask paths does not match the expected batch size.
-
-        Examples:
-            >>> from anomalib.data.validators.torch.image import ImageBatchValidator
-            >>> mask_paths = ["path/to/mask_1.png", "path/to/mask_2.png"]
-            >>> validated_paths = ImageBatchValidator.validate_mask_path(mask_paths)
-            >>> print(validated_paths)
-            ['path/to/mask_1.png', 'path/to/mask_2.png']
+            TypeError: If ``mask_path`` is not a sequence of strings.
+
+        Example:
+            Validate mask paths::
+
+                >>> from anomalib.data.validators import ImageBatchValidator
+                >>> mask_paths = ["path/to/mask_1.png", "path/to/mask_2.png"]
+                >>> validated = ImageBatchValidator.validate_mask_path(mask_paths)
+                >>> validated
+                ['path/to/mask_1.png', 'path/to/mask_2.png']
         """
         if mask_path is None:
             return None
@@ -480,22 +646,29 @@ def validate_anomaly_map(anomaly_map: torch.Tensor | np.ndarray | None) -> Mask
         """Validate the anomaly map for a batch.

+        This method validates batches of anomaly maps. It handles:
+            - Conversion from numpy arrays
+            - Shape normalization
+            - Type conversion to float32
+
         Args:
-            anomaly_map (torch.Tensor | np.ndarray | None): Input anomaly map.
+            anomaly_map (``torch.Tensor`` | ``np.ndarray`` | ``None``): Input anomaly map.

         Returns:
-            Mask | None: Validated anomaly map as a torchvision Mask object, or None.
+            ``Mask`` | ``None``: Validated anomaly map as a torchvision Mask object, or None.

         Raises:
-            ValueError: If the anomaly map cannot be converted to a torch.Tensor or has an invalid shape.
+            ValueError: If ``anomaly_map`` cannot be converted to tensor or has invalid shape.

-        Examples:
-            >>> import torch
-            >>> from anomalib.data.validators.torch.image import ImageBatchValidator
-            >>> anomaly_map = torch.rand(4, 224, 224)
-            >>> validated_map = ImageBatchValidator.validate_anomaly_map(anomaly_map)
-            >>> print(validated_map.shape)
-            torch.Size([4, 224, 224])
+        Example:
+            Validate anomaly maps::
+
+                >>> import torch
+                >>> from anomalib.data.validators import ImageBatchValidator
+                >>> anomaly_map = torch.rand(4, 224, 224)
+                >>> validated = ImageBatchValidator.validate_anomaly_map(anomaly_map)
+                >>> validated.shape
+                torch.Size([4, 224, 224])
         """
         if anomaly_map is None:
             return None
@@ -523,27 +696,31 @@
     ) -> torch.Tensor | None:
         """Validate the prediction scores for a batch.

+        This method validates batches of prediction scores. It handles:
+            - Conversion from numpy arrays and sequences
+            - Type conversion to float32
+
         Args:
-            pred_score (torch.Tensor | Sequence[float] | None): Input prediction scores.
+            pred_score (``torch.Tensor`` | ``Sequence[float]`` | ``None``): Input prediction
+                scores.

         Returns:
-            torch.Tensor | None: Validated prediction scores as a float32 tensor, or None.
+            ``torch.Tensor`` | ``None``: Validated prediction scores as float32 tensor,
+            or None.

         Raises:
-            TypeError: If the input is neither a sequence of floats, torch.Tensor, nor None.
-            ValueError: If the prediction scores are not a 1-dimensional tensor or sequence.
-
-        Examples:
-            >>> import torch
-            >>> from anomalib.data.validators.torch.image import ImageBatchValidator
-            >>> scores = [0.8, 0.7, 0.9]
-            >>> validated_scores = ImageBatchValidator.validate_pred_score(scores)
-            >>> validated_scores
-            tensor([0.8000, 0.7000, 0.9000])
-            >>> score_tensor = torch.tensor([0.8, 0.7, 0.9])
-            >>> validated_scores = ImageBatchValidator.validate_pred_score(score_tensor)
-            >>> validated_scores
-            tensor([0.8000, 0.7000, 0.9000])
+            TypeError: If ``pred_score`` is not a valid input type.
+            ValueError: If ``pred_score`` cannot be converted to tensor.
+
+        Example:
+            Validate prediction scores::
+
+                >>> import torch
+                >>> from anomalib.data.validators import ImageBatchValidator
+                >>> scores = [0.8, 0.7, 0.9]
+                >>> validated = ImageBatchValidator.validate_pred_score(scores)
+                >>> validated
+                tensor([0.8000, 0.7000, 0.9000])
         """
         if pred_score is None:
             return None
@@ -563,19 +740,25 @@ def validate_pred_mask(pred_mask: torch.Tensor | None) -> Mask | None:
         """Validate the prediction mask for a batch.

+        This method validates batches of prediction masks using the same logic as ground
+        truth masks.
+
         Args:
-            pred_mask (torch.Tensor | None): Input prediction mask.
+            pred_mask (``torch.Tensor`` | ``None``): Input prediction mask.

         Returns:
-            Mask | None: Validated prediction mask as a torchvision Mask object, or None.
-
-        Examples:
-            >>> import torch
-            >>> from anomalib.data.validators.torch.image import ImageBatchValidator
-            >>> pred_mask = torch.randint(0, 2, (4, 224, 224))
-            >>> validated_mask = ImageBatchValidator.validate_pred_mask(pred_mask)
-            >>> print(validated_mask.shape)
-            torch.Size([4, 224, 224])
+            ``Mask`` | ``None``: Validated prediction mask as a torchvision Mask object,
+            or None.
+
+        Example:
+            Validate prediction masks::
+
+                >>> import torch
+                >>> from anomalib.data.validators import ImageBatchValidator
+                >>> pred_mask = torch.randint(0, 2, (4, 224, 224))
+                >>> validated = ImageBatchValidator.validate_pred_mask(pred_mask)
+                >>> validated.shape
+                torch.Size([4, 224, 224])
         """
         return ImageBatchValidator.validate_gt_mask(pred_mask)  # We can reuse the gt_mask validation
@@ -583,23 +766,30 @@ def validate_pred_label(pred_label: torch.Tensor | None) -> torch.Tensor | None:
         """Validate the prediction label for a batch.

+        This method validates batches of prediction labels. It handles:
+            - Shape normalization
+            - Type conversion to boolean
+
         Args:
-            pred_label (torch.Tensor | None): Input prediction label.
+            pred_label (``torch.Tensor`` | ``None``): Input prediction label.

         Returns:
-            torch.Tensor | None: Validated prediction label as a boolean tensor, or None.
+            ``torch.Tensor`` | ``None``: Validated prediction label as boolean tensor,
+            or None.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the prediction label has an invalid shape.
-
-        Examples:
-            >>> import torch
-            >>> from anomalib.data.validators.torch.image import ImageBatchValidator
-            >>> pred_label = torch.tensor([[1], [0], [1], [1]])
-            >>> validated_label = ImageBatchValidator.validate_pred_label(pred_label)
-            >>> print(validated_label)
-            tensor([ True, False, True, True])
+            TypeError: If ``pred_label`` is not a torch.Tensor.
+            ValueError: If ``pred_label`` has invalid shape.
+
+        Example:
+            Validate prediction labels::
+
+                >>> import torch
+                >>> from anomalib.data.validators import ImageBatchValidator
+                >>> pred_label = torch.tensor([[1], [0], [1], [1]])
+                >>> validated = ImageBatchValidator.validate_pred_label(pred_label)
+                >>> validated
+                tensor([ True, False, True, True])
         """
         if pred_label is None:
             return None
@@ -625,21 +815,25 @@ def validate_image_path(image_path: list[str] | None) -> list[str] | None:
         """Validate the image paths for a batch.

+        This method validates batches of image file paths.
+
         Args:
-            image_path (list[str] | None): Input list of image paths.
+            image_path (``list[str]`` | ``None``): Input list of image paths.

         Returns:
-            list[str] | None: Validated list of image paths, or None.
+            ``list[str]`` | ``None``: Validated list of image paths, or None.

         Raises:
-            TypeError: If the input is not a list of strings.
-
-        Examples:
-            >>> from anomalib.data.validators.torch.image import ImageBatchValidator
-            >>> image_paths = ["path/to/image_1.jpg", "path/to/image_2.jpg"]
-            >>> validated_paths = ImageBatchValidator.validate_image_path(image_paths)
-            >>> print(validated_paths)
-            ['path/to/image_1.jpg', 'path/to/image_2.jpg']
+            TypeError: If ``image_path`` is not a list of strings.
+
+        Example:
+            Validate image paths::
+
+                >>> from anomalib.data.validators import ImageBatchValidator
+                >>> image_paths = ["path/to/image_1.jpg", "path/to/image_2.jpg"]
+                >>> validated = ImageBatchValidator.validate_image_path(image_paths)
+                >>> validated
+                ['path/to/image_1.jpg', 'path/to/image_2.jpg']
         """
         if image_path is None:
             return None
@@ -652,21 +846,25 @@ def validate_image_path(image_path: list[str] | None) -> list[str] | None:
     def validate_explanation(explanation: list[str] | None) -> list[str] | None:
         """Validate the explanations for a batch.

+        This method validates batches of explanation strings.
+
         Args:
-            explanation (list[str] | None): Input list of explanations.
+            explanation (``list[str]`` | ``None``): Input list of explanations.

         Returns:
-            list[str] | None: Validated list of explanations, or None.
+            ``list[str]`` | ``None``: Validated list of explanations, or None.

         Raises:
-            TypeError: If the input is not a list of strings.
-
-        Examples:
-            >>> from anomalib.data.validators.torch.image import ImageBatchValidator
-            >>> explanations = ["The image has a crack on the wall.", "The image has a dent on the car."]
-            >>> validated_explanations = ImageBatchValidator.validate_explanation(explanations)
-            >>> print(validated_explanations)
-            ['The image has a crack on the wall.', 'The image has a dent on the car.']
+            TypeError: If ``explanation`` is not a list of strings.
+
+        Example:
+            Validate explanations::
+
+                >>> from anomalib.data.validators import ImageBatchValidator
+                >>> explanations = ["Crack on wall", "Dent on car"]
+                >>> validated = ImageBatchValidator.validate_explanation(explanations)
+                >>> validated
+                ['Crack on wall', 'Dent on car']
         """
         if explanation is None:
             return None
diff --git a/src/anomalib/data/validators/torch/video.py b/src/anomalib/data/validators/torch/video.py
index bcfca62451..0f93325f7a 100644
--- a/src/anomalib/data/validators/torch/video.py
+++ b/src/anomalib/data/validators/torch/video.py
@@ -1,4 +1,33 @@
-"""Validate torch video data."""
+"""Validate PyTorch tensor data for videos.
+
+This module provides validators for video data stored as PyTorch tensors. The validators
+ensure data consistency and correctness for videos and their batches.
+
+The validators check:
+    - Tensor shapes and dimensions
+    - Data types
+    - Value ranges
+    - Label formats
+    - Mask properties
+    - Path validity
+
+Example:
+    Validate a single video::
+
+        >>> from anomalib.data.validators import VideoValidator
+        >>> validator = VideoValidator()
+        >>> validator.validate_image(video)
+
+    Validate a batch of videos::
+
+        >>> from anomalib.data.validators import VideoBatchValidator
+        >>> validator = VideoBatchValidator()
+        >>> validator(videos=videos, labels=labels, masks=masks)
+
+Note:
+    The validators are used internally by the data modules to ensure data
+    consistency before processing video data.
+"""

 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -12,36 +41,72 @@


 class VideoValidator:
-    """Validate torch.Tensor data for videos."""
+    """Validate torch.Tensor data for videos.
+
+    This class provides static methods to validate video data and related metadata stored as
+    PyTorch tensors. The validators ensure data consistency and correctness by checking
+    tensor shapes, dimensions, data types, and value ranges.
+
+    The validator methods handle:
+        - Video tensors
+        - Ground truth labels and masks
+        - Prediction scores, labels and masks
+        - Video paths and metadata
+        - Frame indices and timing information
+
+    Each validation method performs thorough checks and returns properly formatted data
+    ready for use in video processing pipelines.
+
+    Example:
+        >>> import torch
+        >>> from anomalib.data.validators import VideoValidator
+        >>> video = torch.rand(10, 3, 256, 256)  # 10 frames, RGB
+        >>> validator = VideoValidator()
+        >>> validated_video = validator.validate_image(video)
+        >>> validated_video.shape
+        torch.Size([10, 3, 256, 256])
+    """

     @staticmethod
     def validate_image(image: torch.Tensor) -> torch.Tensor:
-        """Validate the video tensor.
+        """Validate a video tensor.
+
+        Validates and normalizes video tensors, handling both single and multi-frame cases.
+        Checks tensor type, dimensions, and channel count.

         Args:
-            image (Image): Input tensor.
+            image (torch.Tensor): Input video tensor with shape either:
+                - ``[C, H, W]`` for single frame
+                - ``[T, C, H, W]`` for multiple frames
+                where ``C`` is channels (1 or 3), ``H`` height, ``W`` width,
+                and ``T`` number of frames.

         Returns:
-            Image: Validated tensor.
+            torch.Tensor: Validated and normalized video tensor.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the video tensor does not have the correct shape.
+            TypeError: If ``image`` is not a ``torch.Tensor``.
+            ValueError: If tensor dimensions or channel count are invalid.

         Examples:
             >>> import torch
             >>> from anomalib.data.validators import VideoValidator
-            >>> video = torch.rand(10, 3, 256, 256)  # 10 frames, RGB
-            >>> validated_video = VideoValidator.validate_image(video)
-            >>> validated_video.shape
+            >>> # Multi-frame RGB video
+            >>> video = torch.rand(10, 3, 256, 256)
+            >>> validated = VideoValidator.validate_image(video)
+            >>> validated.shape
             torch.Size([10, 3, 256, 256])
-            >>> single_frame_rgb = torch.rand(3, 256, 256)  # Single RGB frame
-            >>> validated_single_frame_rgb = VideoValidator.validate_image(single_frame_rgb)
-            >>> validated_single_frame_rgb.shape
+
+            >>> # Single RGB frame
+            >>> frame = torch.rand(3, 256, 256)
+            >>> validated = VideoValidator.validate_image(frame)
+            >>> validated.shape
             torch.Size([1, 3, 256, 256])
-            >>> single_frame_gray = torch.rand(1, 256, 256)  # Single grayscale frame
-            >>> validated_single_frame_gray = VideoValidator.validate_image(single_frame_gray)
-            >>> validated_single_frame_gray.shape
+
+            >>> # Single grayscale frame
+            >>> gray = torch.rand(1, 256, 256)
+            >>> validated = VideoValidator.validate_image(gray)
+            >>> validated.shape
             torch.Size([1, 1, 256, 256])
         """
         if not isinstance(image, torch.Tensor):
@@ -64,28 +129,36 @@ def validate_image(image: torch.Tensor) -> torch.Tensor:

     @staticmethod
     def validate_gt_label(label: int | torch.Tensor | None) -> torch.Tensor | None:
-        """Validate the ground truth label.
+        """Validate ground truth label.
+
+        Validates and converts ground truth labels to boolean tensors.

         Args:
-            label (int | torch.Tensor | None): Input ground truth label.
+            label (int | torch.Tensor | None): Input label as either:
+                - Integer (0 or 1)
+                - Boolean tensor
+                - Integer tensor
+                - ``None``

         Returns:
-            torch.Tensor | None: Validated ground truth label as a boolean tensor, or None.
+            torch.Tensor | None: Validated boolean tensor label or ``None``.

         Raises:
-            TypeError: If the input is neither an integer nor a torch.Tensor.
-            ValueError: If the label shape or dtype is invalid.
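The single-frame normalization described above is easy to see end to end; a minimal sketch using only `VideoValidator.validate_image` as documented in this hunk:

```python
import torch

from anomalib.data.validators import VideoValidator

clip = torch.rand(10, 3, 256, 256)  # [T, C, H, W]: multi-frame RGB clip
frame = torch.rand(3, 256, 256)     # [C, H, W]: a single RGB frame

# Multi-frame input passes through unchanged; single frames gain a
# temporal dimension so downstream code always sees [T, C, H, W].
print(VideoValidator.validate_image(clip).shape)   # torch.Size([10, 3, 256, 256])
print(VideoValidator.validate_image(frame).shape)  # torch.Size([1, 3, 256, 256])
```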
+            TypeError: If ``label`` is not an integer, tensor or ``None``.
+            ValueError: If label shape or dtype is invalid.

         Examples:
             >>> import torch
             >>> from anomalib.data.validators import VideoValidator
-            >>> label_int = 1
-            >>> validated_label = VideoValidator.validate_gt_label(label_int)
-            >>> validated_label
+            >>> # Integer label
+            >>> validated = VideoValidator.validate_gt_label(1)
+            >>> validated
             tensor(True)
-            >>> label_tensor = torch.tensor([0, 0], dtype=torch.int32)
-            >>> validated_label = VideoValidator.validate_gt_label(label_tensor)
-            >>> validated_label
+
+            >>> # Tensor label
+            >>> label = torch.tensor([0, 0], dtype=torch.int32)
+            >>> validated = VideoValidator.validate_gt_label(label)
+            >>> validated
             tensor([False, False])
         """
         if label is None:
             return None
@@ -102,26 +175,33 @@ def validate_gt_label(label: int | torch.Tensor | None) -> torch.Tensor | None:

     @staticmethod
     def validate_gt_mask(mask: torch.Tensor | None) -> Mask | None:
-        """Validate the ground truth mask.
+        """Validate ground truth mask.
+
+        Validates and converts ground truth masks to boolean Mask objects.

         Args:
-            mask (torch.Tensor | None): Input ground truth mask.
+            mask (torch.Tensor | None): Input mask tensor with shape either:
+                - ``[H, W]`` for single frame
+                - ``[T, H, W]`` for multiple frames
+                - ``[T, 1, H, W]`` for multiple frames with channel dimension
+                where ``H`` is height, ``W`` width, and ``T`` number of frames.

         Returns:
-            Mask | None: Validated ground truth mask, or None.
+            Mask | None: Validated boolean mask or ``None``.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the mask shape is invalid.
+            TypeError: If ``mask`` is not a ``torch.Tensor`` or ``None``.
+            ValueError: If mask shape is invalid.

         Examples:
             >>> import torch
             >>> from anomalib.data.validators import VideoValidator
-            >>> mask = torch.randint(0, 2, (10, 1, 224, 224))  # 10 frames
-            >>> validated_mask = VideoValidator.validate_gt_mask(mask)
-            >>> isinstance(validated_mask, Mask)
+            >>> # Multi-frame mask
+            >>> mask = torch.randint(0, 2, (10, 1, 224, 224))
+            >>> validated = VideoValidator.validate_gt_mask(mask)
+            >>> isinstance(validated, Mask)
             True
-            >>> validated_mask.shape
+            >>> validated.shape
             torch.Size([10, 224, 224])
         """
         if mask is None:
             return None
@@ -141,26 +221,32 @@ def validate_gt_mask(mask: torch.Tensor | None) -> Mask | None:

     @staticmethod
     def validate_anomaly_map(anomaly_map: torch.Tensor | None) -> Mask | None:
-        """Validate the anomaly map.
+        """Validate anomaly map.
+
+        Validates and converts anomaly maps to float32 Mask objects.

         Args:
-            anomaly_map (torch.Tensor | None): Input anomaly map.
+            anomaly_map (torch.Tensor | None): Input anomaly map tensor with shape either:
+                - ``[T, H, W]`` for multiple frames
+                - ``[T, 1, H, W]`` for multiple frames with channel dimension
+                where ``H`` is height, ``W`` width, and ``T`` number of frames.

         Returns:
-            Mask | None: Validated anomaly map as a Mask, or None.
+            Mask | None: Validated float32 mask or ``None``.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the anomaly map shape is invalid.
+            TypeError: If ``anomaly_map`` is not a ``torch.Tensor`` or ``None``.
+            ValueError: If anomaly map shape is invalid.
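The shape handling described above (a singleton channel dimension is squeezed away) can be checked directly; a small sketch using the documented `validate_gt_label` and `validate_gt_mask` calls:

```python
import torch

from anomalib.data.validators import VideoValidator

label = torch.tensor([0, 1], dtype=torch.int32)
mask = torch.randint(0, 2, (10, 1, 224, 224))  # [T, 1, H, W]

print(VideoValidator.validate_gt_label(label))       # tensor([False,  True])
print(VideoValidator.validate_gt_mask(mask).shape)   # torch.Size([10, 224, 224])
```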
         Examples:
             >>> import torch
             >>> from anomalib.data.validators import VideoValidator
-            >>> anomaly_map = torch.rand(10, 1, 224, 224)  # 10 frames
-            >>> validated_map = VideoValidator.validate_anomaly_map(anomaly_map)
-            >>> isinstance(validated_map, Mask)
+            >>> # Multi-frame anomaly map
+            >>> amap = torch.rand(10, 1, 224, 224)
+            >>> validated = VideoValidator.validate_anomaly_map(amap)
+            >>> isinstance(validated, Mask)
             True
-            >>> validated_map.shape
+            >>> validated.shape
             torch.Size([10, 224, 224])
         """
         if anomaly_map is None:
@@ -181,38 +267,38 @@ def validate_anomaly_map(anomaly_map: torch.Tensor | None) -> Mask | None:

     @staticmethod
     def validate_video_path(video_path: str | None) -> str | None:
-        """Validate the video path.
+        """Validate video file path.

         Args:
-            video_path (str | None): Input video path.
+            video_path (str | None): Input video file path or ``None``.

         Returns:
-            str | None: Validated video path, or None.
+            str | None: Validated video path or ``None``.

         Examples:
             >>> from anomalib.data.validators import VideoValidator
             >>> path = "/path/to/video.mp4"
-            >>> validated_path = VideoValidator.validate_video_path(path)
-            >>> validated_path == path
+            >>> validated = VideoValidator.validate_video_path(path)
+            >>> validated == path
             True
         """
         return validate_path(video_path) if video_path else None

     @staticmethod
     def validate_mask_path(mask_path: str | None) -> str | None:
-        """Validate the mask path.
+        """Validate mask file path.

         Args:
-            mask_path (str | None): Input mask path.
+            mask_path (str | None): Input mask file path or ``None``.

         Returns:
-            str | None: Validated mask path, or None.
+            str | None: Validated mask path or ``None``.

         Examples:
             >>> from anomalib.data.validators import VideoValidator
             >>> path = "/path/to/mask.mp4"
-            >>> validated_path = VideoValidator.validate_mask_path(path)
-            >>> validated_path == path
+            >>> validated = VideoValidator.validate_mask_path(path)
+            >>> validated == path
             True
         """
         return validate_path(mask_path) if mask_path else None
@@ -222,25 +308,27 @@ def validate_pred_score(
         pred_score: torch.Tensor | float | None,
         anomaly_map: torch.Tensor | None = None,
     ) -> torch.Tensor | None:
-        """Validate the prediction score.
+        """Validate prediction score.
+
+        Validates prediction scores and optionally computes them from anomaly maps.

         Args:
-            pred_score (torch.Tensor | float | None): Input prediction score.
-            anomaly_map (torch.Tensor | None): Input anomaly map.
+            pred_score (torch.Tensor | float | None): Input prediction score or ``None``.
+            anomaly_map (torch.Tensor | None): Optional anomaly map to compute score from.

         Returns:
-            torch.Tensor | None: Validated prediction score as a float32 tensor, or None.
+            torch.Tensor | None: Validated float32 prediction score or ``None``.

         Raises:
-            TypeError: If the input is neither a float, torch.Tensor, nor None.
-            ValueError: If the prediction score is not a scalar.
+            TypeError: If ``pred_score`` is not a float, tensor or ``None``.
+            ValueError: If prediction score is not a scalar.

         Examples:
             >>> import torch
             >>> from anomalib.data.validators import VideoValidator
             >>> score = 0.8
-            >>> validated_score = VideoValidator.validate_pred_score(score)
-            >>> validated_score
+            >>> validated = VideoValidator.validate_pred_score(score)
+            >>> validated
             tensor(0.8000)
         """
         if pred_score is None:
             return None
@@ -261,46 +349,46 @@ def validate_pred_score(

     @staticmethod
     def validate_pred_mask(pred_mask: torch.Tensor | None) -> Mask | None:
-        """Validate the prediction mask.
+        """Validate prediction mask.

         Args:
-            pred_mask (torch.Tensor | None): Input prediction mask.
+            pred_mask (torch.Tensor | None): Input prediction mask tensor or ``None``.

         Returns:
-            Mask | None: Validated prediction mask, or None.
+            Mask | None: Validated prediction mask or ``None``.

         Examples:
             >>> import torch
             >>> from anomalib.data.validators import VideoValidator
-            >>> mask = torch.randint(0, 2, (10, 1, 224, 224))  # 10 frames
-            >>> validated_mask = VideoValidator.validate_pred_mask(mask)
-            >>> isinstance(validated_mask, Mask)
+            >>> mask = torch.randint(0, 2, (10, 1, 224, 224))
+            >>> validated = VideoValidator.validate_pred_mask(mask)
+            >>> isinstance(validated, Mask)
             True
-            >>> validated_mask.shape
+            >>> validated.shape
             torch.Size([10, 224, 224])
         """
         return VideoValidator.validate_gt_mask(pred_mask)  # We can reuse the gt_mask validation

     @staticmethod
     def validate_pred_label(pred_label: torch.Tensor | None) -> torch.Tensor | None:
-        """Validate the prediction label.
+        """Validate prediction label.

         Args:
-            pred_label (torch.Tensor | None): Input prediction label.
+            pred_label (torch.Tensor | None): Input prediction label or ``None``.

         Returns:
-            torch.Tensor | None: Validated prediction label as a boolean tensor, or None.
+            torch.Tensor | None: Validated boolean prediction label or ``None``.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the prediction label is not a scalar.
+            TypeError: If ``pred_label`` is not a ``torch.Tensor``.
+            ValueError: If prediction label is not a scalar.

         Examples:
             >>> import torch
             >>> from anomalib.data.validators import VideoValidator
             >>> label = torch.tensor(1)
-            >>> validated_label = VideoValidator.validate_pred_label(label)
-            >>> validated_label
+            >>> validated = VideoValidator.validate_pred_label(label)
+            >>> validated
             tensor(True)
         """
         if pred_label is None:
             return None
@@ -319,29 +407,33 @@ def validate_pred_label(pred_label: torch.Tensor | None) -> torch.Tensor | None:

     @staticmethod
     def validate_original_image(original_image: torch.Tensor | Video | None) -> torch.Tensor | Video | None:
-        """Validate the original video or image.
+        """Validate original video or image.

         Args:
-            original_image (torch.Tensor | Video | None): Input original video or image.
+            original_image (torch.Tensor | Video | None): Input original video/image or
+                ``None``.

         Returns:
-            torch.Tensor | Video | None: Validated original video or image.
+            torch.Tensor | Video | None: Validated original video/image or ``None``.

         Raises:
-            TypeError: If the input is not a torch.Tensor or torchvision Video object.
-            ValueError: If the tensor does not have the correct shape.
+            TypeError: If input is not a ``torch.Tensor`` or ``Video``.
+            ValueError: If tensor shape is invalid.

         Examples:
             >>> import torch
             >>> from torchvision.tv_tensors import Video
             >>> from anomalib.data.validators import VideoValidator
-            >>> video = Video(torch.rand(10, 3, 224, 224))  # 10 frames
-            >>> validated_video = VideoValidator.validate_original_image(video)
-            >>> validated_video.shape
+            >>> # Video tensor
+            >>> video = Video(torch.rand(10, 3, 224, 224))
+            >>> validated = VideoValidator.validate_original_image(video)
+            >>> validated.shape
             torch.Size([10, 3, 224, 224])
-            >>> image = torch.rand(3, 256, 256)  # Single image
-            >>> validated_image = VideoValidator.validate_original_image(image)
-            >>> validated_image.shape
+
+            >>> # Single image
+            >>> image = torch.rand(3, 256, 256)
+            >>> validated = VideoValidator.validate_original_image(image)
+            >>> validated.shape
             torch.Size([3, 256, 256])
         """
         if original_image is None:
@@ -369,22 +461,22 @@ def validate_original_image(original_image: torch.Tensor | Video | None) -> torc

     @staticmethod
     def validate_target_frame(target_frame: int | None) -> int | None:
-        """Validate the target frame index.
+        """Validate target frame index.

         Args:
-            target_frame (int | None): Input target frame index.
+            target_frame (int | None): Input target frame index or ``None``.

         Returns:
-            int | None: Validated target frame index, or None.
+            int | None: Validated target frame index or ``None``.

         Raises:
-            TypeError: If the input is not an integer.
-            ValueError: If the target frame index is negative.
+            TypeError: If ``target_frame`` is not an integer.
+            ValueError: If target frame index is negative.

         Examples:
             >>> from anomalib.data.validators import VideoValidator
-            >>> validated_frame = VideoValidator.validate_target_frame(31)
-            >>> print(validated_frame)
+            >>> validated = VideoValidator.validate_target_frame(31)
+            >>> print(validated)
             31
         """
         if target_frame is None:
             return None
@@ -399,24 +491,24 @@ def validate_target_frame(target_frame: int | None) -> int | None:

     @staticmethod
     def validate_frames(frames: torch.Tensor | None) -> torch.Tensor | None:
-        """Validate the frames tensor.
+        """Validate frames tensor.

         Args:
-            frames (torch.Tensor | None): Input frames tensor or frame indices.
+            frames (torch.Tensor | None): Input frames tensor or frame indices or ``None``.

         Returns:
-            torch.Tensor | None: Validated frames tensor, or None.
+            torch.Tensor | None: Validated frames tensor or ``None``.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the frames tensor is not a 1D tensor of indices.
+            TypeError: If ``frames`` is not a ``torch.Tensor``.
+            ValueError: If frames tensor is not a 1D tensor of indices.

         Examples:
             >>> import torch
             >>> from anomalib.data.validators import VideoValidator
-            >>> frame_indices = torch.tensor([0, 5, 10])
-            >>> validated_indices = VideoValidator.validate_frames(frame_indices)
-            >>> validated_indices
+            >>> indices = torch.tensor([0, 5, 10])
+            >>> validated = VideoValidator.validate_frames(indices)
+            >>> validated
             tensor([0, 5, 10])
         """
         if frames is None:
             return None
@@ -442,30 +534,36 @@ def validate_frames(frames: torch.Tensor | None) -> torch.Tensor | None:

     @staticmethod
     def validate_last_frame(last_frame: torch.Tensor | int | float | None) -> torch.Tensor | int | None:
-        """Validate the last frame index.
+        """Validate last frame index.

         Args:
-            last_frame (torch.Tensor | int | float | None): Input last frame index.
+            last_frame (torch.Tensor | int | float | None): Input last frame index or
+                ``None``.

         Returns:
-            torch.Tensor | int | None: Validated last frame index, or None.
+            torch.Tensor | int | None: Validated last frame index or ``None``.
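The frame-index validators above follow the same pattern; a quick sketch using only the calls documented in this hunk:

```python
import torch

from anomalib.data.validators import VideoValidator

# Non-negative indices pass through; invalid types/values raise.
print(VideoValidator.validate_target_frame(31))                  # 31
print(VideoValidator.validate_frames(torch.tensor([0, 5, 10])))  # tensor([0, 5, 10])
```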
         Raises:
-            TypeError: If the input is not a torch.Tensor, int, or float.
-            ValueError: If the last frame index is negative.
+            TypeError: If ``last_frame`` is not a tensor, int, or float.
+            ValueError: If last frame index is negative.

         Examples:
             >>> from anomalib.data.validators import VideoValidator
-            >>> validated_frame = VideoValidator.validate_last_frame(5)
-            >>> print(validated_frame)
+            >>> # Integer input
+            >>> validated = VideoValidator.validate_last_frame(5)
+            >>> print(validated)
             5
-            >>> validated_float = VideoValidator.validate_last_frame(5.7)
-            >>> print(validated_float)
+
+            >>> # Float input
+            >>> validated = VideoValidator.validate_last_frame(5.7)
+            >>> print(validated)
             5
+
+            >>> # Tensor input
             >>> import torch
             >>> tensor_frame = torch.tensor(10.3)
-            >>> validated_tensor = VideoValidator.validate_last_frame(tensor_frame)
-            >>> print(validated_tensor)
+            >>> validated = VideoValidator.validate_last_frame(tensor_frame)
+            >>> print(validated)
             tensor(10)
         """
         if last_frame is None:
             return None
@@ -495,27 +593,45 @@ def validate_explanation(explanation: str | None) -> str | None:


 class VideoBatchValidator:
-    """Validate torch.Tensor data for video batches."""
+    """Validate ``torch.Tensor`` data for video batches.
+
+    This class provides static methods to validate various video batch data types including
+    tensors, masks, labels, paths and more. Each method performs thorough validation of
+    its input and returns the validated data in the correct format.
+    """

     @staticmethod
     def validate_image(image: Video) -> Video:
         """Validate the video batch tensor.

+        Validates that the input video batch tensor has the correct dimensions, number of
+        channels and data type. Converts the tensor to float32 and scales values to [0,1]
+        range.
+
         Args:
-            image (Video): Input video batch tensor.
+            image (Video): Input video batch tensor. Should be either:
+                - Shape ``(B,C,H,W)`` for single frame images
+                - Shape ``(B,T,C,H,W)`` for multi-frame videos
+                Where:
+                - ``B`` is batch size
+                - ``T`` is number of frames
+                - ``C`` is number of channels (1 or 3)
+                - ``H`` is height
+                - ``W`` is width

         Returns:
-            Video: Validated video batch tensor.
+            Video: Validated and normalized video batch tensor.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the tensor does not have the correct dimensions or number of channels.
+            TypeError: If ``image`` is not a ``torch.Tensor``.
+            ValueError: If tensor dimensions or channel count are invalid.

         Examples:
             >>> import torch
             >>> from torchvision.tv_tensors import Video
             >>> from anomalib.data.validators import VideoBatchValidator
-            >>> video_batch = Video(torch.rand(2, 10, 3, 224, 224))  # 2 videos, 10 frames each
+            >>> # Create sample video batch with 2 videos, 10 frames each
+            >>> video_batch = Video(torch.rand(2, 10, 3, 224, 224))
             >>> validated_batch = VideoBatchValidator.validate_image(video_batch)
             >>> print(validated_batch.shape)
             torch.Size([2, 10, 3, 224, 224])
@@ -544,14 +660,18 @@ def validate_image(image: Video) -> Video:
     def validate_gt_label(label: torch.Tensor | None) -> torch.Tensor | None:
         """Validate the ground truth labels for a batch.

+        Validates that the input ground truth labels have the correct data type and
+        format. Converts labels to boolean type.
+
         Args:
-            label (torch.Tensor | None): Input ground truth labels.
+            label (torch.Tensor | None): Input ground truth labels. Should be a 1D tensor
+                of boolean or integer values.

         Returns:
-            torch.Tensor | None: Validated ground truth labels.
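Batch validation mirrors the single-sample case with an extra batch dimension; a minimal sketch using the `VideoBatchValidator.validate_image` call documented above (shapes illustrative):

```python
import torch
from torchvision.tv_tensors import Video

from anomalib.data.validators import VideoBatchValidator

# [B, T, C, H, W]: a batch of 2 clips with 10 RGB frames each.
batch = Video(torch.rand(2, 10, 3, 224, 224))
validated = VideoBatchValidator.validate_image(batch)
print(validated.shape)  # torch.Size([2, 10, 3, 224, 224])
```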
+            torch.Tensor | None: Validated ground truth labels as boolean tensor.

         Raises:
-            TypeError: If the input is not a torch.Tensor or has an invalid dtype.
+            TypeError: If ``label`` is not a ``torch.Tensor`` or has invalid dtype.

         Examples:
             >>> import torch
@@ -575,25 +695,35 @@ def validate_gt_label(label: torch.Tensor | None) -> torch.Tensor | None:
     def validate_gt_mask(mask: torch.Tensor | None) -> Mask | None:
         """Validate the ground truth masks for a batch.

+        Validates that the input ground truth masks have the correct shape and format.
+        Converts masks to boolean type.
+
         Args:
-            mask (torch.Tensor | None): Input ground truth masks.
+            mask (torch.Tensor | None): Input ground truth masks. Should be one of:
+                - Shape ``(H,W)`` for single mask
+                - Shape ``(N,H,W)`` for batch of masks
+                - Shape ``(N,1,H,W)`` for batch with channel dimension

         Returns:
-            Mask | None: Validated ground truth masks.
+            Mask | None: Validated ground truth masks as boolean tensor.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the mask has an invalid shape.
+            TypeError: If ``mask`` is not a ``torch.Tensor``.
+            ValueError: If mask shape is invalid.

         Examples:
             >>> import torch
             >>> from anomalib.data.validators import VideoBatchValidator
-            >>> gt_masks = torch.rand(10, 224, 224) > 0.5  # 10 frames each
+            >>> # Create 10 frame masks
+            >>> gt_masks = torch.rand(10, 224, 224) > 0.5
             >>> validated_masks = VideoBatchValidator.validate_gt_mask(gt_masks)
             >>> print(validated_masks.shape)
             torch.Size([10, 224, 224])
-            >>> single_frame_masks = torch.rand(4, 456, 256) > 0.5  # 4 single-frame images
-            >>> validated_single_frame = VideoBatchValidator.validate_gt_mask(single_frame_masks)
+            >>> # Create 4 single-frame masks
+            >>> single_frame_masks = torch.rand(4, 456, 256) > 0.5
+            >>> validated_single_frame = VideoBatchValidator.validate_gt_mask(
+            ...     single_frame_masks
+            ... )
             >>> print(validated_single_frame.shape)
             torch.Size([4, 456, 256])
         """
@@ -618,14 +748,17 @@ def validate_gt_mask(mask: torch.Tensor | None) -> Mask | None:
     def validate_mask_path(mask_path: list[str] | None) -> list[str] | None:
         """Validate the mask paths for a batch.

+        Validates that the input mask paths are in the correct format.
+
         Args:
-            mask_path (list[str] | None): Input mask paths.
+            mask_path (list[str] | None): Input mask paths. Should be a list of strings
+                containing valid file paths.

         Returns:
             list[str] | None: Validated mask paths.

         Raises:
-            TypeError: If the input is not a list of strings.
+            TypeError: If ``mask_path`` is not a list of strings.

         Examples:
             >>> from anomalib.data.validators import VideoBatchValidator
@@ -640,20 +773,31 @@ def validate_mask_path(mask_path: list[str] | None) -> list[str] | None:
     def validate_anomaly_map(anomaly_map: torch.Tensor | None) -> Mask | None:
         """Validate the anomaly maps for a batch.

+        Validates that the input anomaly maps have the correct shape and format.
+        Converts maps to float32 type.
+
         Args:
-            anomaly_map (torch.Tensor | None): Input anomaly maps.
+            anomaly_map (torch.Tensor | None): Input anomaly maps. Should be either:
+                - Shape ``(B,T,H,W)`` for single channel maps
+                - Shape ``(B,T,1,H,W)`` for explicit single channel
+                Where:
+                - ``B`` is batch size
+                - ``T`` is number of frames
+                - ``H`` is height
+                - ``W`` is width

         Returns:
-            Mask | None: Validated anomaly maps.
+            Mask | None: Validated anomaly maps as float32 tensor.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the anomaly map has an invalid shape.
+            TypeError: If ``anomaly_map`` is not a ``torch.Tensor``.
+            ValueError: If anomaly map shape is invalid.

         Examples:
             >>> import torch
             >>> from anomalib.data.validators import VideoBatchValidator
-            >>> anomaly_maps = torch.rand(2, 10, 224, 224)  # 2 videos, 10 frames each
+            >>> # Create maps for 2 videos with 10 frames each
+            >>> anomaly_maps = torch.rand(2, 10, 224, 224)
             >>> validated_maps = VideoBatchValidator.validate_anomaly_map(anomaly_maps)
             >>> print(validated_maps.shape)
             torch.Size([2, 10, 224, 224])
@@ -680,15 +824,21 @@ def validate_pred_score(
     ) -> torch.Tensor | None:
         """Validate the prediction scores for a batch.

+        Validates that the input prediction scores have the correct format. If no scores
+        are provided but an anomaly map is given, computes scores from the map.
+
         Args:
-            pred_score (torch.Tensor | None): Input prediction scores.
-            anomaly_map (torch.Tensor | None): Input anomaly map (optional).
+            pred_score (torch.Tensor | None): Input prediction scores. Should be a 1D
+                tensor of float values.
+            anomaly_map (torch.Tensor | None, optional): Input anomaly map used to compute
+                scores if ``pred_score`` is None.

         Returns:
-            torch.Tensor | None: Validated prediction scores.
+            torch.Tensor | None: Validated prediction scores as float32 tensor.

         Raises:
-            ValueError: If the prediction scores have an invalid shape or cannot be converted to a tensor.
+            ValueError: If prediction scores have invalid shape or cannot be converted to
+                tensor.

         Examples:
             >>> import torch
@@ -717,8 +867,11 @@ def validate_pred_score(
     def validate_pred_mask(pred_mask: torch.Tensor | None) -> Mask | None:
         """Validate the prediction masks for a batch.

+        Validates prediction masks using the same logic as ground truth masks.
+
         Args:
-            pred_mask (torch.Tensor | None): Input prediction masks.
+            pred_mask (torch.Tensor | None): Input prediction masks. Should follow same
+                format as ground truth masks.

         Returns:
             Mask | None: Validated prediction masks.

         Examples:
             >>> import torch
             >>> from anomalib.data.validators import VideoBatchValidator
-            >>> pred_masks = torch.rand(2, 10, 224, 224) > 0.5  # 2 videos, 10 frames each
+            >>> # Create masks for 2 videos with 10 frames each
+            >>> pred_masks = torch.rand(2, 10, 224, 224) > 0.5
             >>> validated_masks = VideoBatchValidator.validate_pred_mask(pred_masks)
             >>> print(validated_masks.shape)
             torch.Size([2, 10, 224, 224])
@@ -737,14 +891,19 @@ def validate_pred_mask(pred_mask: torch.Tensor | None) -> Mask | None:
     def validate_pred_label(pred_label: torch.Tensor | None) -> torch.Tensor | None:
         """Validate the prediction labels for a batch.

+        Validates that the input prediction labels have the correct format and converts
+        them to boolean type.
+
         Args:
-            pred_label (torch.Tensor | None): Input prediction labels.
+            pred_label (torch.Tensor | None): Input prediction labels. Should be a 1D
+                tensor of boolean or numeric values.

         Returns:
-            torch.Tensor | None: Validated prediction labels.
+            torch.Tensor | None: Validated prediction labels as boolean tensor.

         Raises:
-            ValueError: If the prediction labels have an invalid shape or cannot be converted to a tensor.
+            ValueError: If prediction labels have invalid shape or cannot be converted to
+                tensor.

         Examples:
             >>> import torch
@@ -771,22 +930,37 @@ def validate_pred_label(pred_label: torch.Tensor | None) -> torch.Tensor | None:
     def validate_original_image(original_image: torch.Tensor | Video | None) -> torch.Tensor | Video | None:
         """Validate the original videos for a batch.

+        Validates that the input videos have the correct dimensions and channel count.
+        Adds temporal dimension to single frame inputs.
+
         Args:
-            original_image (torch.Tensor | Video | None): Input original videos.
+            original_image (torch.Tensor | Video | None): Input original videos. Should be
+                either:
+                - Shape ``(B,C,H,W)`` for single frame images
+                - Shape ``(B,T,C,H,W)`` for multi-frame videos
+                Where:
+                - ``B`` is batch size
+                - ``T`` is number of frames
+                - ``C`` is number of channels (must be 3)
+                - ``H`` is height
+                - ``W`` is width

         Returns:
             torch.Tensor | Video | None: Validated original videos.

         Raises:
-            TypeError: If the input is not a torch.Tensor or torchvision Video object.
-            ValueError: If the video has an invalid shape or number of channels.
+            TypeError: If input is not a ``torch.Tensor`` or ``torchvision.Video``.
+            ValueError: If video has invalid shape or channel count.

         Examples:
             >>> import torch
             >>> from torchvision.tv_tensors import Video
             >>> from anomalib.data.validators import VideoBatchValidator
-            >>> original_videos = Video(torch.rand(2, 10, 3, 224, 224))  # 2 videos, 10 frames each
-            >>> validated_videos = VideoBatchValidator.validate_original_image(original_videos)
+            >>> # Create 2 videos with 10 frames each
+            >>> original_videos = Video(torch.rand(2, 10, 3, 224, 224))
+            >>> validated_videos = VideoBatchValidator.validate_original_image(
+            ...     original_videos
+            ... )
             >>> print(validated_videos.shape)
             torch.Size([2, 10, 3, 224, 224])
         """
@@ -817,14 +991,17 @@ def validate_original_image(original_image: torch.Tensor | Video | None) -> torc
     def validate_video_path(video_path: list[str] | None) -> list[str] | None:
         """Validate the video paths for a batch.

+        Validates that the input video paths are in the correct format.
+
         Args:
-            video_path (list[str] | None): Input video paths.
+            video_path (list[str] | None): Input video paths. Should be a list of strings
+                containing valid file paths.

         Returns:
             list[str] | None: Validated video paths.

         Raises:
-            TypeError: If the input is not a list of strings.
+            TypeError: If ``video_path`` is not a list of strings.

         Examples:
             >>> from anomalib.data.validators import VideoBatchValidator
@@ -839,15 +1016,18 @@ def validate_video_path(video_path: list[str] | None) -> list[str] | None:
     def validate_target_frame(target_frame: torch.Tensor | None) -> torch.Tensor | None:
         """Validate the target frame indices for a batch.

+        Validates that the input target frame indices are non-negative integers.
+
         Args:
-            target_frame (torch.Tensor | None): Input target frame indices.
+            target_frame (torch.Tensor | None): Input target frame indices. Should be a
+                1D tensor of non-negative integers.

         Returns:
-            torch.Tensor | None: Validated target frame indices.
+            torch.Tensor | None: Validated target frame indices as int64 tensor.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the target frame indices are invalid.
+            TypeError: If ``target_frame`` is not a ``torch.Tensor``.
+            ValueError: If target frame indices are invalid.

         Examples:
             >>> import torch
@@ -874,15 +1054,20 @@ def validate_target_frame(target_frame: torch.Tensor | None) -> torch.Tensor | N
     def validate_frames(frames: torch.Tensor | None) -> torch.Tensor | None:
         """Validate the frame indices for a batch.

+        Validates that the input frame indices are non-negative integers and converts
+        them to the correct shape.
+
         Args:
-            frames (torch.Tensor | None): Input frame indices.
+            frames (torch.Tensor | None): Input frame indices. Should be either:
+                - Shape ``(N,)`` for 1D tensor
+                - Shape ``(N,1)`` for 2D tensor

         Returns:
-            torch.Tensor | None: Validated frame indices.
+            torch.Tensor | None: Validated frame indices as 1D int64 tensor.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the frame indices are invalid.
+            TypeError: If ``frames`` is not a ``torch.Tensor``.
+            ValueError: If frame indices are invalid.

         Examples:
             >>> import torch
@@ -911,21 +1096,26 @@ def validate_last_frame(last_frame: torch.Tensor | None) -> torch.Tensor | None:
         """Validate the last frame indices for a batch.

+        Validates that the input last frame indices are non-negative integers.
+
         Args:
-            last_frame (torch.Tensor | None): Input last frame indices.
+            last_frame (torch.Tensor | None): Input last frame indices. Should be a 1D
+                tensor of non-negative numeric values.

         Returns:
-            torch.Tensor | None: Validated last frame indices.
+            torch.Tensor | None: Validated last frame indices as int64 tensor.

         Raises:
-            TypeError: If the input is not a torch.Tensor.
-            ValueError: If the last frame indices are invalid.
+            TypeError: If ``last_frame`` is not a ``torch.Tensor``.
+            ValueError: If last frame indices are invalid.

         Examples:
             >>> import torch
             >>> from anomalib.data.validators import VideoBatchValidator
             >>> last_frames = torch.tensor([9.5, 12.2, 15.8, 10.0])
-            >>> validated_last_frames = VideoBatchValidator.validate_last_frame(last_frames)
+            >>> validated_last_frames = VideoBatchValidator.validate_last_frame(
+            ...     last_frames
+            ... )
             >>> print(validated_last_frames)
             tensor([ 9, 12, 15, 10])
         """
diff --git a/src/anomalib/deploy/__init__.py b/src/anomalib/deploy/__init__.py
index e2bec10b1f..358f130b3e 100644
--- a/src/anomalib/deploy/__init__.py
+++ b/src/anomalib/deploy/__init__.py
@@ -1,4 +1,26 @@
-"""Functions for Inference and model deployment."""
+"""Functions for model inference and deployment.
+
+This module provides functionality for deploying trained anomaly detection models
+and performing inference. It includes:
+
+- Model export utilities for converting models to different formats
+- Inference classes for making predictions:
+    - :class:`Inferencer`: Base inferencer interface
+    - :class:`TorchInferencer`: For PyTorch models
+    - :class:`OpenVINOInferencer`: For OpenVINO IR models
+
+Example:
+    >>> from anomalib.deploy import TorchInferencer
+    >>> model = TorchInferencer(path="path/to/model.pt")
+    >>> predictions = model.predict(image="path/to/image.jpg")
+
+    The prediction contains anomaly maps and scores:
+
+    >>> predictions.anomaly_map  # doctest: +SKIP
+    tensor([[0.1, 0.2, ...]])
+    >>> predictions.pred_score  # doctest: +SKIP
+    tensor(0.86)
+"""

 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
diff --git a/src/anomalib/deploy/export.py b/src/anomalib/deploy/export.py
index 69e508396f..a65f093869 100644
--- a/src/anomalib/deploy/export.py
+++ b/src/anomalib/deploy/export.py
@@ -1,4 +1,23 @@
-"""Utilities for optimization and OpenVINO conversion."""
+"""Utilities for optimization and OpenVINO conversion.
+
+This module provides functionality for exporting and optimizing anomaly detection
+models to different formats like ONNX, OpenVINO IR and PyTorch.
+
+Example:
+    Export a model to ONNX format:
+
+    >>> from anomalib.deploy import ExportType
+    >>> export_type = ExportType.ONNX
+    >>> export_type
+    'onnx'
+
+    Export with OpenVINO compression:
+
+    >>> from anomalib.deploy import CompressionType
+    >>> compression = CompressionType.INT8_PTQ
+    >>> compression
+    'int8_ptq'
+"""

 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -12,14 +31,18 @@
 class ExportType(str, Enum):
     """Model export type.

-    Examples:
+    Supported export formats for anomaly detection models.
+
+    Attributes:
+        ONNX: Export model to ONNX format
+        OPENVINO: Export model to OpenVINO IR format
+        TORCH: Export model to PyTorch format
+
+    Example:
         >>> from anomalib.deploy import ExportType
-        >>> ExportType.ONNX
+        >>> export_type = ExportType.ONNX
+        >>> export_type
         'onnx'
-        >>> ExportType.OPENVINO
-        'openvino'
-        >>> ExportType.TORCH
-        'torch'
     """

     ONNX = "onnx"
@@ -31,20 +54,22 @@ class CompressionType(str, Enum):
     """Model compression type when exporting to OpenVINO.

     Attributes:
-        FP16 (str): Weight compression (FP16). All weights are converted to FP16.
-        INT8 (str): Weight compression (INT8). All weights are quantized to INT8,
-            but are dequantized to floating point before inference.
-        INT8_PTQ (str): Full integer post-training quantization (INT8).
-            All weights and operations are quantized to INT8. Inference is done
-            in INT8 precision.
-        INT8_ACQ (str): Accuracy-control quantization (INT8). Weights and
+        FP16: Weight compression to FP16 precision. All weights are converted
+            to FP16.
+        INT8: Weight compression to INT8 precision. All weights are quantized
+            to INT8, but are dequantized to floating point before inference.
+        INT8_PTQ: Full integer post-training quantization to INT8 precision.
+            All weights and operations are quantized to INT8. Inference is
+            performed in INT8 precision.
+        INT8_ACQ: Accuracy-control quantization to INT8 precision. Weights and
             operations are quantized to INT8, except those that would degrade
-            quality of the model more than is acceptable. Inference is done in
-            a mixed precision.
+            model quality beyond an acceptable threshold. Inference uses mixed
+            precision.
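Because both enums above subclass ``str``, members compare equal to their string values, which is what makes the doctest output ``'onnx'`` and ``'int8_ptq'`` work as plain strings; a small sketch:

```python
from anomalib.deploy import CompressionType, ExportType

# str-enum members behave like their underlying strings.
assert ExportType.ONNX == "onnx"
assert ExportType.OPENVINO == "openvino"
assert CompressionType.INT8_PTQ == "int8_ptq"

# FP16/INT8 compress weights only; INT8_PTQ also quantizes operations.
print([member.value for member in CompressionType])
```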
- Examples: + Example: >>> from anomalib.deploy import CompressionType - >>> CompressionType.INT8_PTQ + >>> compression = CompressionType.INT8_PTQ + >>> compression 'int8_ptq' """ diff --git a/src/anomalib/deploy/inferencers/__init__.py b/src/anomalib/deploy/inferencers/__init__.py index f47ece3425..00e73e0003 100644 --- a/src/anomalib/deploy/inferencers/__init__.py +++ b/src/anomalib/deploy/inferencers/__init__.py @@ -1,6 +1,19 @@ -"""Inferencers for Torch and OpenVINO.""" +"""Inferencers for performing inference with anomaly detection models. -# Copyright (C) 2022 Intel Corporation +This module provides inferencer classes for running inference with trained models +using different backends: + +- :class:`Inferencer`: Base class defining the inferencer interface +- :class:`TorchInferencer`: For inference with PyTorch models +- :class:`OpenVINOInferencer`: For optimized inference with OpenVINO + +Example: + >>> from anomalib.deploy import TorchInferencer + >>> model = TorchInferencer(path="path/to/model.pt") + >>> predictions = model.predict(image="path/to/image.jpg") +""" + +# Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from .base_inferencer import Inferencer diff --git a/src/anomalib/deploy/inferencers/base_inferencer.py b/src/anomalib/deploy/inferencers/base_inferencer.py index b549b32a19..e53d8a487f 100644 --- a/src/anomalib/deploy/inferencers/base_inferencer.py +++ b/src/anomalib/deploy/inferencers/base_inferencer.py @@ -1,4 +1,11 @@ -"""Base Inferencer for Torch and OpenVINO.""" +"""Base Inferencer for Torch and OpenVINO. + +This module provides the base inferencer class that defines the interface for +performing inference with anomaly detection models. + +The base class is used by both the PyTorch and OpenVINO inferencers to ensure +a consistent API across different backends. +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -20,49 +27,118 @@ class Inferencer(ABC): - """Abstract class for the inference. + """Abstract base class for performing inference with anomaly detection models. + + This class defines the interface that must be implemented by concrete + inferencer classes for different backends (PyTorch, OpenVINO). - This is used by both Torch and OpenVINO inference. + Example: + >>> from anomalib.deploy import TorchInferencer + >>> model = TorchInferencer(path="path/to/model.pt") + >>> predictions = model.predict(image="path/to/image.jpg") """ @abstractmethod def load_model(self, path: str | Path) -> Any: # noqa: ANN401 - """Load Model.""" + """Load a model from the specified path. + + Args: + path (str | Path): Path to the model file. + + Returns: + Any: Loaded model instance. + + Raises: + NotImplementedError: This is an abstract method. + """ raise NotImplementedError @abstractmethod def pre_process(self, image: np.ndarray) -> np.ndarray | torch.Tensor: - """Pre-process.""" + """Pre-process an input image. + + Args: + image (np.ndarray): Input image to pre-process. + + Returns: + np.ndarray | torch.Tensor: Pre-processed image. + + Raises: + NotImplementedError: This is an abstract method. + """ raise NotImplementedError @abstractmethod def forward(self, image: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: - """Forward-Pass input to model.""" + """Perform a forward pass on the model. + + Args: + image (np.ndarray | torch.Tensor): Pre-processed input image. + + Returns: + np.ndarray | torch.Tensor: Model predictions. + + Raises: + NotImplementedError: This is an abstract method. 
+ """ raise NotImplementedError @abstractmethod def post_process(self, predictions: np.ndarray | torch.Tensor, metadata: dict[str, Any] | None) -> dict[str, Any]: - """Post-Process.""" + """Post-process model predictions. + + Args: + predictions (np.ndarray | torch.Tensor): Raw model predictions. + metadata (dict[str, Any] | None): Metadata used for post-processing. + + Returns: + dict[str, Any]: Post-processed predictions. + + Raises: + NotImplementedError: This is an abstract method. + """ raise NotImplementedError @abstractmethod def predict(self, image: str | Path | np.ndarray | torch.Tensor) -> ImageResult: - """Predict.""" + """Run inference on an image. + + Args: + image (str | Path | np.ndarray | torch.Tensor): Input image. + + Returns: + ImageResult: Prediction results. + + Raises: + NotImplementedError: This is an abstract method. + """ raise NotImplementedError @staticmethod def _superimpose_segmentation_mask(metadata: dict, anomaly_map: np.ndarray, image: np.ndarray) -> np.ndarray: - """Superimpose segmentation mask on top of image. + """Superimpose segmentation mask on an image. Args: - metadata (dict): Metadata of the image which contains the image size. - anomaly_map (np.ndarray): Anomaly map which is used to extract segmentation mask. - image (np.ndarray): Image on which segmentation mask is to be superimposed. + metadata (dict): Image metadata containing the image dimensions. + anomaly_map (np.ndarray): Anomaly map used to extract segmentation. + image (np.ndarray): Image on which to superimpose the mask. Returns: - np.ndarray: Image with segmentation mask superimposed. + np.ndarray: Image with superimposed segmentation mask. + + Example: + >>> image = np.zeros((100, 100, 3)) + >>> anomaly_map = np.zeros((100, 100)) + >>> metadata = {"image_shape": (100, 100)} + >>> result = Inferencer._superimpose_segmentation_mask( + ... metadata, + ... anomaly_map, + ... image, + ... ) + >>> result.shape + (100, 100, 3) """ - pred_mask = compute_mask(anomaly_map, 0.5) # assumes predictions are normalized. + pred_mask = compute_mask(anomaly_map, 0.5) # assumes normalized preds image_height = metadata["image_shape"][0] image_width = metadata["image_shape"][1] pred_mask = cv2.resize(pred_mask, (image_width, image_height)) @@ -72,13 +148,18 @@ def _superimpose_segmentation_mask(metadata: dict, anomaly_map: np.ndarray, imag return image def __call__(self, image: np.ndarray) -> ImageResult: - """Call predict on the Image. + """Call predict on an image. Args: - image (np.ndarray): Input Image + image (np.ndarray): Input image. Returns: ImageResult: Prediction results to be visualized. + + Example: + >>> model = Inferencer() # doctest: +SKIP + >>> image = np.zeros((100, 100, 3)) + >>> predictions = model(image) # doctest: +SKIP """ return self.predict(image) @@ -88,17 +169,29 @@ def _normalize( metadata: dict | DictConfig, anomaly_maps: torch.Tensor | np.ndarray | None = None, ) -> tuple[np.ndarray | torch.Tensor | None, float]: - """Apply normalization and resizes the image. + """Normalize predictions using min-max normalization. Args: - pred_scores (Tensor | np.float32): Predicted anomaly score - metadata (dict | DictConfig): Meta data. Post-processing step sometimes requires - additional meta data such as image shape. This variable comprises such info. - anomaly_maps (Tensor | np.ndarray | None): Predicted raw anomaly map. + pred_scores (torch.Tensor | np.float32): Predicted anomaly scores. + metadata (dict | DictConfig): Metadata containing normalization + parameters. 
+            anomaly_maps (torch.Tensor | np.ndarray | None): Raw anomaly maps.
+                Defaults to None.

         Returns:
-            tuple[np.ndarray | torch.Tensor | None, float]: Post processed predictions that are ready to be
-                visualized and predicted scores.
+            tuple[np.ndarray | torch.Tensor | None, float]: Normalized predictions
+                and scores.
+
+        Example:
+            >>> scores = torch.tensor(0.5)
+            >>> metadata = {
+            ...     "image_threshold": 0.5,
+            ...     "pred_scores.min": 0.0,
+            ...     "pred_scores.max": 1.0
+            ... }
+            >>> maps, norm_scores = Inferencer._normalize(scores, metadata)
+            >>> norm_scores
+            0.5
         """
         # min max normalization
         if "pred_scores.min" in metadata and "pred_scores.max" in metadata:
@@ -118,15 +211,21 @@ def _normalize(

         return anomaly_maps, float(pred_scores)

-    def _load_metadata(self, path: str | Path | dict | None = None) -> dict | DictConfig:  # noqa: PLR6301
-        """Load the meta data from the given path.
+    @staticmethod
+    def _load_metadata(path: str | Path | dict | None = None) -> dict | DictConfig:
+        """Load metadata from a file.

         Args:
-            path (str | Path | dict | None, optional): Path to JSON file containing the metadata.
-                If no path is provided, it returns an empty dict. Defaults to None.
+            path (str | Path | dict | None): Path to metadata file. If None,
+                returns empty dict. Defaults to None.

         Returns:
-            dict | DictConfig: Dictionary containing the metadata.
+            dict | DictConfig: Loaded metadata.
+
+        Example:
+            >>> model = Inferencer()  # doctest: +SKIP
+            >>> metadata = model._load_metadata("path/to/metadata.json")
+            ...  # doctest: +SKIP
         """
         metadata: dict[str, float | np.ndarray | torch.Tensor] | DictConfig = {}
         if path is not None:
diff --git a/src/anomalib/deploy/inferencers/openvino_inferencer.py b/src/anomalib/deploy/inferencers/openvino_inferencer.py
index 61b4a3d0ee..07afc52535 100644
--- a/src/anomalib/deploy/inferencers/openvino_inferencer.py
+++ b/src/anomalib/deploy/inferencers/openvino_inferencer.py
@@ -1,4 +1,50 @@
-"""OpenVINO Inferencer implementation."""
+"""OpenVINO Inferencer for optimized model inference.
+
+This module provides the OpenVINO inferencer implementation for running optimized
+inference with OpenVINO IR models.
+
+Example:
+    Assume we have OpenVINO IR model files in the following structure:
+
+    .. code-block:: bash
+
+        $ tree weights
+        ./weights
+        ├── model.bin
+        ├── model.xml
+        └── metadata.json
+
+    Create an OpenVINO inferencer:
+
+    >>> from anomalib.deploy import OpenVINOInferencer
+    >>> inferencer = OpenVINOInferencer(
+    ...     path="weights/model.xml",
+    ...     device="CPU"
+    ... )
+
+    Make predictions:
+
+    >>> # From image path
+    >>> prediction = inferencer.predict("path/to/image.jpg")
+
+    >>> # From PIL Image
+    >>> from PIL import Image
+    >>> image = Image.open("path/to/image.jpg")
+    >>> prediction = inferencer.predict(image)
+
+    >>> # From numpy array
+    >>> import numpy as np
+    >>> image = np.random.rand(224, 224, 3)
+    >>> prediction = inferencer.predict(image)
+
+    The prediction result contains anomaly maps and scores:
+
+    >>> prediction.anomaly_map  # doctest: +SKIP
+    array([[0.1, 0.2, ...]], dtype=float32)
+
+    >>> prediction.pred_score  # doctest: +SKIP
+    0.86
+"""

 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -18,67 +64,25 @@


 class OpenVINOInferencer:
-    """OpenVINO implementation for the inference.
+    """OpenVINO inferencer for optimized model inference.

     Args:
-        path (str | Path): Path to the openvino onnx, xml or bin file.
-        metadata (str | Path | dict, optional): Path to metadata file or a dict object defining the
-            metadata.
-            Defaults to ``None``.
-        device (str | None, optional): Device to run the inference on (AUTO, CPU, GPU, NPU).
-            Defaults to ``AUTO``.
-        task (TaskType | None, optional): Task type.
-            Defaults to ``None``.
-        config (dict | None, optional): Configuration parameters for the inference
+        path (str | Path | tuple[bytes, bytes]): Path to OpenVINO IR files
+            (``.xml`` and ``.bin``) or ONNX model, or tuple of xml/bin data as
+            bytes.
+        device (str | None, optional): Inference device.
+            Options: ``"AUTO"``, ``"CPU"``, ``"GPU"``, ``"NPU"``.
+            Defaults to ``"AUTO"``.
+        config (dict | None, optional): OpenVINO configuration parameters.
             Defaults to ``None``.

-    Examples:
-        Assume that we have an OpenVINO IR model and metadata files in the following structure:
-
-        .. code-block:: bash
-
-            $ tree weights
-            ./weights
-            ├── model.bin
-            ├── model.xml
-            └── metadata.json
-
-        We could then create ``OpenVINOInferencer`` as follows:
-
-        >>> from anomalib.deploy.inferencers import OpenVINOInferencer
-        >>> inferencer = OpenVINOInferencer(
-        ...     path="weights/model.xml",
-        ...     metadata="weights/metadata.json",
-        ...     device="CPU",
+    Example:
+        >>> from anomalib.deploy import OpenVINOInferencer
+        >>> model = OpenVINOInferencer(
+        ...     path="model.xml",
+        ...     device="CPU"
         ... )
-
-        This will ensure that the model is loaded on the ``CPU`` device and the
-        metadata is loaded from the ``metadata.json`` file. To make a prediction,
-        we can simply call the ``predict`` method:
-
-        >>> prediction = inferencer.predict(image="path/to/image.jpg")
-
-        Alternatively we can also pass the image as a PIL image or numpy array:
-
-        >>> from PIL import Image
-        >>> image = Image.open("path/to/image.jpg")
-        >>> prediction = inferencer.predict(image=image)
-
-        >>> import numpy as np
-        >>> image = np.random.rand(224, 224, 3)
-        >>> prediction = inferencer.predict(image=image)
-
-        ``prediction`` will be an ``ImageResult`` object containing the prediction
-        results. For example, to visualize the heatmap, we can do the following:
-
-        >>> from matplotlib import pyplot as plt
-        >>> plt.imshow(result.heatmap)
-
-        It is also possible to visualize the true and predicted masks if the
-        task is ``TaskType.SEGMENTATION``:
-
-        >>> plt.imshow(result.gt_mask)
-        >>> plt.imshow(result.pred_mask)
+        >>> prediction = model.predict("test.jpg")
     """

     def __init__(
@@ -92,20 +96,24 @@ def __init__(
             raise ImportError(msg)

         self.device = device
         self.config = config
         self.input_blob, self.output_blob, self.model = self.load_model(path)

     def load_model(self, path: str | Path | tuple[bytes, bytes]) -> tuple[Any, Any, Any]:
-        """Load the OpenVINO model.
+        """Load OpenVINO model from file or bytes.

         Args:
-            path (str | Path | tuple[bytes, bytes]): Path to the onnx or xml and bin files
-                or tuple of .xml and .bin data as bytes.
+            path (str | Path | tuple[bytes, bytes]): Path to model files or model
+                data as bytes tuple.

         Returns:
-            [tuple[str, str, ExecutableNetwork]]: Input and Output blob names
-                together with the Executable network.
+            tuple[Any, Any, Any]: Tuple containing:
+                - Input blob
+                - Output blob
+                - Compiled model
+
+        Raises:
+            ValueError: If model path has invalid extension.
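Since the ``config`` dict is forwarded to the compiled model (see the ``core.compile_model`` call below), OpenVINO runtime properties can be passed straight through; a hedged sketch with hypothetical paths, assuming a standard OpenVINO property:

```python
from anomalib.deploy import OpenVINOInferencer

# "weights/model.xml" is a placeholder path; PERFORMANCE_HINT is a standard
# OpenVINO runtime property forwarded verbatim to compile_model.
inferencer = OpenVINOInferencer(
    path="weights/model.xml",
    device="CPU",
    config={"PERFORMANCE_HINT": "LATENCY"},
)
prediction = inferencer.predict("path/to/image.jpg")
```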
""" import openvino as ov @@ -131,7 +139,11 @@ def load_model(self, path: str | Path | tuple[bytes, bytes]) -> tuple[Any, Any, cache_folder.mkdir(exist_ok=True) core.set_property({"CACHE_DIR": cache_folder}) - compile_model = core.compile_model(model=model, device_name=self.device, config=self.config) + compile_model = core.compile_model( + model=model, + device_name=self.device, + config=self.config, + ) input_blob = compile_model.input(0) output_blob = compile_model.output(0) @@ -140,13 +152,13 @@ def load_model(self, path: str | Path | tuple[bytes, bytes]) -> tuple[Any, Any, @staticmethod def pre_process(image: np.ndarray) -> np.ndarray: - """Pre-process the input image by applying transformations. + """Pre-process input image. Args: image (np.ndarray): Input image. Returns: - np.ndarray: pre-processed image. + np.ndarray: Pre-processed image with shape (N,C,H,W). """ # Normalize numpy array to range [0, 1] if image.dtype != np.float32: @@ -164,27 +176,29 @@ def pre_process(image: np.ndarray) -> np.ndarray: @staticmethod def post_process(predictions: OVDict) -> dict: - """Convert OpenVINO output dictionary to NumpyBatch.""" + """Convert OpenVINO predictions to dictionary. + + Args: + predictions (OVDict): Raw predictions from OpenVINO model. + + Returns: + dict: Dictionary of prediction tensors. + """ names = [next(iter(name)) for name in predictions.names()] values = predictions.to_tuple() return dict(zip(names, values, strict=False)) - def predict( - self, - image: str | Path | np.ndarray, - ) -> NumpyImageBatch: - """Perform a prediction for a given input image. - - The main workflow is (i) pre-processing, (ii) forward-pass, (iii) post-process. + def predict(self, image: str | Path | np.ndarray) -> NumpyImageBatch: + """Run inference on an input image. Args: - image (Union[str, np.ndarray]): Input image whose output is to be predicted. - It could be either a path to image or numpy array itself. - - metadata: Metadata information such as shape, threshold. + image (str | Path | np.ndarray): Input image as file path or array. Returns: - ImageResult: Prediction results to be visualized. + NumpyImageBatch: Batch containing the predictions. + + Raises: + TypeError: If image input is invalid type. """ # Convert file path or string to image if necessary if isinstance(image, str | Path): diff --git a/src/anomalib/deploy/inferencers/torch_inferencer.py b/src/anomalib/deploy/inferencers/torch_inferencer.py index ed4283ad82..a8ef5c0c8f 100644 --- a/src/anomalib/deploy/inferencers/torch_inferencer.py +++ b/src/anomalib/deploy/inferencers/torch_inferencer.py @@ -1,4 +1,37 @@ -"""Torch inference implementations.""" +"""PyTorch inferencer for running inference with trained anomaly detection models. + +This module provides the PyTorch inferencer implementation for running inference +with trained PyTorch models. 
+ +Example: + Assume we have a PyTorch model saved as a ``.pt`` file: + + >>> from anomalib.deploy import TorchInferencer + >>> model = TorchInferencer(path="path/to/model.pt", device="cpu") + + Make predictions: + + >>> # From image path + >>> prediction = model.predict("path/to/image.jpg") + + >>> # From PIL Image + >>> from PIL import Image + >>> image = Image.open("path/to/image.jpg") + >>> prediction = model.predict(image) + + >>> # From torch tensor + >>> import torch + >>> image = torch.rand(3, 224, 224) + >>> prediction = model.predict(image) + + The prediction result contains anomaly maps and scores: + + >>> prediction.anomaly_map # doctest: +SKIP + tensor([[0.1, 0.2, ...]]) + + >>> prediction.pred_score # doctest: +SKIP + tensor(0.86) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -13,39 +46,23 @@ class TorchInferencer: - """PyTorch implementation for the inference. + """PyTorch inferencer for anomaly detection models. Args: - path (str | Path): Path to Torch model weights. - device (str): Device to use for inference. Options are ``auto``, - ``cpu``, ``cuda``. - Defaults to ``auto``. - - Examples: - Assume that we have a Torch ``pt`` model and metadata files in the - following structure: - - >>> from anomalib.deploy.inferencers import TorchInferencer - >>> inferencer = TorchInferencer(path="path/to/torch/model.pt", device="cpu") - - This will ensure that the model is loaded on the ``CPU`` device. To make - a prediction, we can simply call the ``predict`` method: - - >>> from anomalib.data.utils import read_image - >>> image = read_image("path/to/image.jpg") - >>> result = inferencer.predict(image) - - ``result`` will be an ``PredictBatch`` object containing the prediction - results. For example, to visualize the heatmap, we can do the following: - - >>> from matplotlib import pyplot as plt - >>> plt.imshow(result.heatmap) - - It is also possible to visualize the true and predicted masks if the - task is ``TaskType.SEGMENTATION``: - - >>> plt.imshow(result.gt_mask) - >>> plt.imshow(result.pred_mask) + path (str | Path): Path to the PyTorch model weights file. + device (str, optional): Device to use for inference. + Options are ``"auto"``, ``"cpu"``, ``"cuda"``, ``"gpu"``. + Defaults to ``"auto"``. + + Example: + >>> from anomalib.deploy import TorchInferencer + >>> model = TorchInferencer(path="path/to/model.pt") + >>> predictions = model.predict(image="path/to/image.jpg") + + Raises: + ValueError: If an invalid device is specified. + ValueError: If the model file has an unknown extension. + KeyError: If the checkpoint file does not contain a model. """ def __init__( @@ -63,10 +80,19 @@ def _get_device(device: str) -> torch.device: """Get the device to use for inference. Args: - device (str): Device to use for inference. Options are auto, cpu, cuda. + device (str): Device to use for inference. + Options are ``"auto"``, ``"cpu"``, ``"cuda"``, ``"gpu"``. Returns: - torch.device: Device to use for inference. + torch.device: PyTorch device object. + + Raises: + ValueError: If an invalid device is specified. + + Example: + >>> model = TorchInferencer(path="path/to/model.pt", device="cpu") + >>> model.device + device(type='cpu') """ if device not in {"auto", "cpu", "cuda", "gpu"}: msg = f"Unknown device {device}" @@ -79,19 +105,28 @@ def _get_device(device: str) -> torch.device: return torch.device(device) def _load_checkpoint(self, path: str | Path) -> dict: - """Load the checkpoint. + """Load the model checkpoint. 
Args: - path (str | Path): Path to the torch ckpt file. + path (str | Path): Path to the PyTorch checkpoint file. Returns: dict: Dictionary containing the model and metadata. + + Raises: + ValueError: If the model file has an unknown extension. + + Example: + >>> model = TorchInferencer(path="path/to/model.pt") + >>> checkpoint = model._load_checkpoint("path/to/model.pt") + >>> isinstance(checkpoint, dict) + True """ if isinstance(path, str): path = Path(path) if path.suffix not in {".pt", ".pth"}: - msg = f"Unknown torch checkpoint file format {path.suffix}. Make sure you save the Torch model." + msg = f"Unknown PyTorch checkpoint format {path.suffix}. Make sure you save the PyTorch model." raise ValueError(msg) return torch.load(path, map_location=self.device) @@ -100,32 +135,43 @@ def load_model(self, path: str | Path) -> nn.Module: """Load the PyTorch model. Args: - path (str | Path): Path to the Torch model. + path (str | Path): Path to the PyTorch model file. Returns: - (nn.Module): Torch model. + nn.Module: Loaded PyTorch model in evaluation mode. + + Raises: + KeyError: If the checkpoint file does not contain a model. + + Example: + >>> model = TorchInferencer(path="path/to/model.pt") + >>> isinstance(model.model, nn.Module) + True """ checkpoint = self._load_checkpoint(path) if "model" not in checkpoint: - msg = "``model`` is not found in the checkpoint. Please check the checkpoint file." + msg = "``model`` not found in checkpoint. Please check the checkpoint file." raise KeyError(msg) model = checkpoint["model"] model.eval() return model.to(self.device) - def predict( - self, - image: str | Path | torch.Tensor, - ) -> ImageBatch: - """Perform a prediction for a given input image. + def predict(self, image: str | Path | torch.Tensor) -> ImageBatch: + """Predict anomalies for an input image. Args: - image (Union[str, np.ndarray]): Input image whose output is to be predicted. - It could be either a path to image or the tensor itself. + image (str | Path | torch.Tensor): Input image to predict. + Can be a file path or PyTorch tensor. Returns: - ImageResult: Prediction results to be visualized. + ImageBatch: Prediction results containing anomaly maps and scores. + + Example: + >>> model = TorchInferencer(path="path/to/model.pt") + >>> predictions = model.predict("path/to/image.jpg") + >>> predictions.anomaly_map.shape # doctest: +SKIP + torch.Size([1, 256, 256]) """ if isinstance(image, str | Path): image = read_image(image, as_tensor=True) @@ -139,13 +185,20 @@ def predict( ) def pre_process(self, image: torch.Tensor) -> torch.Tensor: - """Pre process the input image. + """Pre-process the input image. Args: - image (torch.Tensor): Input image + image (torch.Tensor): Input image tensor. Returns: - Tensor: pre-processed image. + torch.Tensor: Pre-processed image tensor. + + Example: + >>> model = TorchInferencer(path="path/to/model.pt") + >>> image = torch.rand(3, 224, 224) + >>> processed = model.pre_process(image) + >>> processed.shape + torch.Size([1, 3, 224, 224]) """ if image.dim() == 3: image = image.unsqueeze(0) # model expects [B, C, H, W] diff --git a/src/anomalib/engine/__init__.py b/src/anomalib/engine/__init__.py index 3bb239f5a5..e887d4f7bb 100644 --- a/src/anomalib/engine/__init__.py +++ b/src/anomalib/engine/__init__.py @@ -1,4 +1,27 @@ -"""Anomalib engine.""" +"""Engine module for training and evaluating anomaly detection models. + +This module provides functionality for training and evaluating anomaly detection +models. 
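Since ``_load_checkpoint`` only accepts ``.pt``/``.pth`` files and ``load_model`` requires a ``"model"`` key, a compatible checkpoint can be produced as below. A minimal sketch; ``nn.Linear`` stands in for a trained anomaly model:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in for a trained model object
torch.save({"model": model}, "model.pt")  # ".pt" or ".pth" suffix is required

# weights_only=False is needed on recent torch to unpickle a full nn.Module.
checkpoint = torch.load("model.pt", weights_only=False)
assert "model" in checkpoint  # a missing key would raise KeyError in load_model
```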
The main component is the :class:`Engine` class which handles: + +- Model training and validation +- Metrics computation and logging +- Checkpointing and model export +- Distributed training support + +Example: + Create and use an engine: + + >>> from anomalib.data import MVTec + >>> from anomalib.engine import Engine + >>> from anomalib.models import Padim + >>> engine = Engine() + >>> model, datamodule = Padim(), MVTec() + >>> engine.train(model=model, datamodule=datamodule) # doctest: +SKIP + >>> engine.test(model=model, datamodule=datamodule) # doctest: +SKIP + + Trainer arguments can be passed directly to the engine: + + >>> engine = Engine(max_epochs=10, accelerator="gpu") # doctest: +SKIP +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/engine/engine.py b/src/anomalib/engine/engine.py index a548dd23e4..ac577ad37a 100644 --- a/src/anomalib/engine/engine.py +++ b/src/anomalib/engine/engine.py @@ -1,4 +1,29 @@ -"""Implements custom trainer for Anomalib.""" +"""Implements custom trainer for Anomalib. + +This module provides the core training engine for Anomalib models. The Engine class +wraps PyTorch Lightning's Trainer with additional functionality specific to anomaly +detection tasks. + +The engine handles: +- Model training and validation +- Metrics computation and logging +- Checkpointing and model export +- Distributed training support + +Example: + Create and use an engine: + + >>> from anomalib.data import MVTec + >>> from anomalib.engine import Engine + >>> from anomalib.models import Padim + >>> engine = Engine() + >>> model, datamodule = Padim(), MVTec() + >>> engine.train(model=model, datamodule=datamodule) # doctest: +SKIP + + Trainer arguments can be passed directly to the engine: + + >>> engine = Engine(max_epochs=10, accelerator="gpu") # doctest: +SKIP +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -27,28 +52,33 @@ class UnassignedError(Exception): - """Unassigned error.""" + """Raised when a required component is not assigned.""" class _TrainerArgumentsCache: - """Cache arguments. + """Cache arguments for PyTorch Lightning Trainer. - Since the Engine class accepts PyTorch Lightning Trainer arguments, we store these arguments using this class - before the trainer is instantiated. + Since the Engine class accepts PyTorch Lightning Trainer arguments, we store + these arguments using this class before the trainer is instantiated. Args: - (**kwargs): Trainer arguments that are cached + **kwargs: Trainer arguments that are cached. Example: + >>> from omegaconf import OmegaConf >>> conf = OmegaConf.load("config.yaml") - >>> cache = _TrainerArgumentsCache(**conf.trainer) + >>> cache = _TrainerArgumentsCache(**conf.trainer) >>> cache.args { ... 'max_epochs': 100, 'val_check_interval': 0 } - >>> model = Padim(layers=["layer1", "layer2", "layer3"], input_size=(256, 256), backbone="resnet18") + >>> model = Padim( + ... layers=["layer1", "layer2", "layer3"], + ... input_size=(256, 256), + ... backbone="resnet18", + ... ) >>> cache.update(model) Overriding max_epochs from 100 with 1 for Padim Overriding val_check_interval from 0 with 1.0 for Padim @@ -64,10 +94,10 @@ def __init__(self, **kwargs) -> None: self._cached_args = {**kwargs} def update(self, model: AnomalibModule) -> None: - """Replace cached arguments with arguments retrieved from the model. + """Replace cached arguments with arguments from the model. Args: - model (AnomalibModule): The model used for training + model (AnomalibModule): The model used for training.
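The training workflow sketched in the doctests above can be written out end to end. A hedged sketch: ``Padim`` and ``MVTec`` are one possible model/datamodule pairing, and their names may differ between anomalib versions:

```python
from anomalib.data import MVTec
from anomalib.engine import Engine
from anomalib.models import Padim

datamodule = MVTec()  # downloads/loads the MVTec AD dataset
model = Padim()

# Extra keyword arguments are forwarded to the underlying Lightning Trainer.
engine = Engine(max_epochs=1, default_root_dir="results")
engine.fit(model=model, datamodule=datamodule)
engine.test(model=model, datamodule=datamodule)
```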
""" for key, value in model.trainer_arguments.items(): if key in self._cached_args and self._cached_args[key] != value: @@ -77,35 +107,52 @@ def update(self, model: AnomalibModule) -> None: self._cached_args[key] = value def requires_update(self, model: AnomalibModule) -> bool: + """Check if the cache needs to be updated. + + Args: + model (AnomalibModule): Model to check against. + + Returns: + bool: True if cache needs update, False otherwise. + """ return any(self._cached_args.get(key, None) != value for key, value in model.trainer_arguments.items()) @property def args(self) -> dict[str, Any]: + """Get the cached arguments. + + Returns: + dict[str, Any]: Dictionary of cached trainer arguments. + """ return self._cached_args class Engine: - """Anomalib Engine. - - .. note:: + """Anomalib Engine for training and evaluating anomaly detection models. - Refer to PyTorch Lightning's Trainer for a list of parameters for - details on other Trainer parameters. + The Engine class wraps PyTorch Lightning's Trainer with additional + functionality specific to anomaly detection tasks. Args: - callbacks (list[Callback]): Add a callback or list of callbacks. - normalization (NORMALIZATION, optional): Normalization method. - Defaults to NormalizationMethod.MIN_MAX. - threshold (THRESHOLD): - Thresholding method. Defaults to "F1AdaptiveThreshold". - image_metrics (list[str] | str | dict[str, dict[str, Any]] | None, optional): Image metrics to be used for - evaluation. Defaults to None. - pixel_metrics (list[str] | str | dict[str, dict[str, Any]] | None, optional): Pixel metrics to be used for - evaluation. Defaults to None. - default_root_dir (str, optional): Default root directory for the trainer. - The results will be saved in this directory. - Defaults to ``results``. - **kwargs: PyTorch Lightning Trainer arguments. + callbacks (list[Callback] | None, optional): Add a callback or list of + callbacks. Defaults to None. + logger (Logger | Iterable[Logger] | bool | None, optional): Logger (or + iterable collection of loggers) to use. Defaults to None. + default_root_dir (str | Path, optional): Default path for saving trainer + outputs. Defaults to "results". + **kwargs: Additional arguments passed to PyTorch Lightning Trainer. + + Example: + >>> from anomalib.engine import Engine + >>> engine = Engine() + >>> engine.train() # doctest: +SKIP + >>> engine.test() # doctest: +SKIP + + With custom configuration: + + >>> from anomalib.config import Config + >>> config = Config(path="config.yaml") + >>> engine = Engine(config=config) # doctest: +SKIP """ def __init__( diff --git a/src/anomalib/loggers/__init__.py b/src/anomalib/loggers/__init__.py index 8c47306ddc..953a5974ca 100644 --- a/src/anomalib/loggers/__init__.py +++ b/src/anomalib/loggers/__init__.py @@ -1,4 +1,26 @@ -"""Load PyTorch Lightning Loggers.""" +"""Logging configuration and PyTorch Lightning logger integrations. + +This module provides logging utilities and integrations with various logging frameworks +for use with anomaly detection models. 
The main components are: + +- Console logging configuration via ``configure_logger()`` +- Integration with logging frameworks: + - Comet ML via :class:`AnomalibCometLogger` + - MLflow via :class:`AnomalibMLFlowLogger` + - TensorBoard via :class:`AnomalibTensorBoardLogger` + - Weights & Biases via :class:`AnomalibWandbLogger` + +Example: + Configure console logging: + + >>> from anomalib.loggers import configure_logger + >>> configure_logger(level="INFO") + + Use a specific logger: + + >>> from anomalib.loggers import AnomalibTensorBoardLogger + >>> logger = AnomalibTensorBoardLogger(log_dir="logs") +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -7,10 +29,7 @@ from rich.logging import RichHandler -__all__ = [ - "configure_logger", - "get_experiment_logger", -] +__all__ = ["configure_logger"] try: from .comet import AnomalibCometLogger # noqa: F401 @@ -31,13 +50,23 @@ def configure_logger(level: int | str = logging.INFO) -> None: - """Get console logger by name. + """Configure console logging with consistent formatting. + + This function sets up console logging with a standardized format and rich + tracebacks. It configures both the root logger and PyTorch Lightning logger + to use the same formatting. Args: - level (int | str, optional): Logger Level. Defaults to logging.INFO. + level (int | str): Logging level to use. Can be either a string name like + ``"INFO"`` or an integer constant like ``logging.INFO``. Defaults to + ``logging.INFO``. - Returns: - Logger: The expected logger. + Example: + >>> from anomalib.loggers import configure_logger + >>> configure_logger(level="DEBUG") # doctest: +SKIP + >>> logger = logging.getLogger("my_logger") + >>> logger.info("Test message") # doctest: +SKIP + 2024-01-01 12:00:00 - my_logger - INFO - Test message """ if isinstance(level, str): level = logging.getLevelName(level) diff --git a/src/anomalib/loggers/base.py b/src/anomalib/loggers/base.py index 485ae07f49..7ced9afced 100644 --- a/src/anomalib/loggers/base.py +++ b/src/anomalib/loggers/base.py @@ -1,4 +1,21 @@ -"""Base logger for image logging consistency across all loggers used in anomalib.""" +"""Base logger for image logging consistency across all loggers used in anomalib. + +This module provides a base class that defines a common interface for logging images +across different logging backends used in anomalib. + +Example: + Create a custom image logger: + + >>> class CustomImageLogger(ImageLoggerBase): + ... def add_image(self, image, name=None): + ... # Custom implementation + ... pass + + Use the logger: + + >>> logger = CustomImageLogger() + >>> logger.add_image(image_array, name="test_image") # doctest: +SKIP +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -10,9 +27,28 @@ class ImageLoggerBase: - """Adds a common interface for logging the images.""" + """Base class that provides a common interface for logging images. + + This abstract base class ensures consistent image logging functionality across + different logger implementations in anomalib. + + All custom image loggers should inherit from this class and implement the + ``add_image`` method. + """ @abstractmethod def add_image(self, image: np.ndarray | Figure, name: str | None = None, **kwargs) -> None: - """Interface to log images in the respective loggers.""" + """Log an image using the respective logger implementation. + + Args: + image: Image to be logged, can be either a numpy array or matplotlib + Figure + name: Name/title of the image. 
Defaults to ``None`` + **kwargs: Additional keyword arguments passed to the specific logger + implementation + + Raises: + NotImplementedError: This is an abstract method that must be + implemented by subclasses + """ raise NotImplementedError diff --git a/src/anomalib/loggers/comet.py b/src/anomalib/loggers/comet.py index d946d9036f..4ca51ea101 100644 --- a/src/anomalib/loggers/comet.py +++ b/src/anomalib/loggers/comet.py @@ -1,4 +1,24 @@ -"""comet logger with add image interface.""" +"""Comet logger with image logging capabilities. + +This module provides a Comet logger implementation that adds an interface for +logging images. It extends both the base image logger and PyTorch Lightning's +Comet logger. + +Example: + >>> from anomalib.loggers import AnomalibCometLogger + >>> from anomalib.engine import Engine + >>> comet_logger = AnomalibCometLogger() # doctest: +SKIP + >>> engine = Engine(logger=comet_logger) # doctest: +SKIP + + Log an image: + >>> import numpy as np + >>> image = np.random.rand(32, 32, 3) # doctest: +SKIP + >>> comet_logger.add_image( + ... image=image, + ... name="test_image", + ... global_step=0 + ... ) # doctest: +SKIP +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -16,75 +36,52 @@ class AnomalibCometLogger(ImageLoggerBase, CometLogger): - """Logger for comet. + """Logger for Comet ML with image logging capabilities. - Adds interface for ``add_image`` in the logger rather than calling the - experiment object. - - .. note:: - Same as the CometLogger provided by PyTorch Lightning and the doc string - is reproduced below. - - Track your parameters, metrics, source code and more using - `Comet `_. - - Install it with pip: - - .. code-block:: bash - - pip install comet-ml - - Comet requires either an API Key (online mode) or a local directory path - (offline mode). + This logger extends PyTorch Lightning's CometLogger with an interface for + logging images. It inherits from both :class:`ImageLoggerBase` and + :class:`CometLogger`. Args: - api_key: Required in online mode. API key, found on Comet.ml. If not - given, this will be loaded from the environment variable - COMET_API_KEY or ~/.comet.config if either exists. - Defaults to ``None``. - save_dir: Required in offline mode. The path for the directory to save - local comet logs. If given, this also sets the directory for saving - checkpoints. + api_key: API key found on Comet.ml. If not provided, will be loaded from + ``COMET_API_KEY`` environment variable or ``~/.comet.config``. + Required for online mode. Defaults to ``None``. - project_name: Optional. Send your experiment to a specific project. - Otherwise will be sent to Uncategorized Experiments. - If the project name does not already exist, Comet.ml will create a - new project. + save_dir: Directory path to save local comet logs. Required for offline + mode. Also sets checkpoint directory if provided. Defaults to ``None``. - rest_api_key: Optional. Rest API key found in Comet.ml settings. - This is used to determine version number + project_name: Project name for the experiment. Creates new project if + doesn't exist. Defaults to ``None``. - experiment_name: Optional. String representing the name for this - particular experiment on Comet.ml. + rest_api_key: Rest API key from Comet.ml settings. Used for version + tracking. Defaults to ``None``. - experiment_key: Optional. If set, restores from existing experiment. + experiment_name: Name for this experiment on Comet.ml. Defaults to ``None``. 
- offline: If api_key and save_dir are both given, this determines whether - the experiment will be in online or offline mode. This is useful if - you use save_dir to control the checkpoints directory and have a - ~/.comet.config file but still want to run offline experiments. + experiment_key: Key to restore existing experiment. Defaults to ``None``. - prefix: A string to put at the beginning of metric keys. + offline: Force offline mode even with API key. Useful when using + ``save_dir`` for checkpoints with ``~/.comet.config``. + Defaults to ``False``. + prefix: String to prepend to metric keys. Defaults to ``""``. - kwargs: Additional arguments like `workspace`, `log_code`, etc. used by - :class:`CometExperiment` can be passed as keyword arguments in this - logger. + **kwargs: Additional arguments passed to :class:`CometExperiment` + (e.g. ``workspace``, ``log_code``). Raises: - ModuleNotFoundError: - If required Comet package is not installed on the device. - MisconfigurationException: - If neither ``api_key`` nor ``save_dir`` are passed as arguments. + ModuleNotFoundError: If ``comet-ml`` package is not installed. + MisconfigurationException: If neither ``api_key`` nor ``save_dir`` + provided. Example: >>> from anomalib.loggers import AnomalibCometLogger - >>> from anomalib.engine import Engine - ... - >>> comet_logger = AnomalibCometLogger() - >>> engine = Engine(logger=comet_logger) + >>> comet_logger = AnomalibCometLogger( + ... project_name="anomaly_detection" + ... ) # doctest: +SKIP - See Also: - - `Comet Documentation `__ + Note: + For more details, see the `Comet Documentation + `_ """ def __init__( @@ -114,13 +111,36 @@ def __init__( @rank_zero_only def add_image(self, image: np.ndarray | Figure, name: str | None = None, **kwargs) -> None: - """Interface to add image to comet logger. + """Log an image to Comet. Args: - image (np.ndarray | Figure): Image to log. - name (str | None): The tag of the image + image: Image to log, either numpy array or matplotlib figure. + name: Name/tag for the image. Defaults to ``None``. - kwargs: Accepts only `global_step` (int). The step at which to log the image. + **kwargs: Must contain ``global_step`` (int) indicating the step at + which to log the image. + + Raises: + ValueError: If ``global_step`` not provided in kwargs. + + Example: + >>> import numpy as np + >>> from matplotlib.figure import Figure + >>> logger = AnomalibCometLogger() # doctest: +SKIP + >>> # Log numpy array + >>> image_array = np.random.rand(32, 32, 3) # doctest: +SKIP + >>> logger.add_image( + ... image=image_array, + ... name="test_image", + ... global_step=0 + ... ) # doctest: +SKIP + >>> # Log matplotlib figure + >>> fig = Figure() # doctest: +SKIP + >>> logger.add_image( + ... image=fig, + ... name="test_figure", + ... global_step=1 + ... ) # doctest: +SKIP """ if "global_step" not in kwargs: msg = "`global_step` is required for comet logger" diff --git a/src/anomalib/loggers/mlflow.py b/src/anomalib/loggers/mlflow.py index f6ec089586..7b647b3108 100644 --- a/src/anomalib/loggers/mlflow.py +++ b/src/anomalib/loggers/mlflow.py @@ -1,4 +1,23 @@ -"""MLFlow logger with add image interface.""" +"""MLFlow logger with image logging capabilities. + +This module provides an MLFlow logger implementation that adds an interface for +logging images. It extends both the base image logger and PyTorch Lightning's +MLFlow logger. 
+ +Example: + >>> from anomalib.loggers import AnomalibMLFlowLogger + >>> from anomalib.engine import Engine + >>> mlflow_logger = AnomalibMLFlowLogger() + >>> engine = Engine(logger=mlflow_logger) # doctest: +SKIP + + Log an image: + >>> import numpy as np + >>> image = np.random.rand(32, 32, 3) # doctest: +SKIP + >>> mlflow_logger.add_image( + ... image=image, + ... name="test_image" + ... ) # doctest: +SKIP +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -15,54 +34,47 @@ class AnomalibMLFlowLogger(ImageLoggerBase, MLFlowLogger): - """Logger for MLFlow. + """Logger for MLFlow with image logging capabilities. - Adds interface for ``add_image`` in the logger rather than calling the - experiment object. - - .. note:: - Same as the MLFlowLogger provided by PyTorch Lightning and the doc string is reproduced below. - - Track your parameters, metrics, source code and more using - `MLFlow `_. - - Install it with pip: - - .. code-block:: bash - - pip install mlflow + This logger extends PyTorch Lightning's MLFlowLogger with an interface for + logging images. It inherits from both :class:`ImageLoggerBase` and + :class:`MLFlowLogger`. Args: - experiment_name: The name of the experiment. - run_name: Name of the new run. - The `run_name` is internally stored as a ``mlflow.runName`` tag. - If the ``mlflow.runName`` tag has already been set in `tags`, the value is overridden by the `run_name`. - tracking_uri: Address of local or remote tracking server. - If not provided, defaults to `MLFLOW_TRACKING_URI` environment variable if set, otherwise it falls - back to `file:`. - save_dir: A path to a local directory where the MLflow runs get saved. - Defaults to `./mlruns` if `tracking_uri` is not provided. - Has no effect if `tracking_uri` is provided. - log_model: Log checkpoints created by `ModelCheckpoint` as MLFlow artifacts. - - - if ``log_model == 'all'``, checkpoints are logged during training. - - if ``log_model == True``, checkpoints are logged at the end of training, \ - except when `save_top_k == -1` which also logs every checkpoint during training. - - if ``log_model == False`` (default), no checkpoint is logged. - - prefix: A string to put at the beginning of metric keys. Defaults to ``''``. - kwargs: Additional arguments like `tags`, `artifact_location` etc. used by - `MLFlowExperiment` can be passed as keyword arguments in this logger. + experiment_name: Name of the experiment. If not provided, defaults to + ``"anomalib_logs"``. + run_name: Name of the new run. The ``run_name`` is internally stored as + a ``mlflow.runName`` tag. If the ``mlflow.runName`` tag has already + been set in ``tags``, the value is overridden by the ``run_name``. + tracking_uri: Address of local or remote tracking server. If not provided, + defaults to ``MLFLOW_TRACKING_URI`` environment variable if set, + otherwise falls back to ``file:``. + save_dir: Path to local directory where MLflow runs are saved. Defaults + to ``"./mlruns"`` if ``tracking_uri`` is not provided. Has no effect + if ``tracking_uri`` is provided. + log_model: Log checkpoints created by ``ModelCheckpoint`` as MLFlow + artifacts: + + - if ``"all"``: checkpoints are logged during training + - if ``True``: checkpoints are logged at end of training (except when + ``save_top_k == -1`` which logs every checkpoint during training) + - if ``False`` (default): no checkpoints are logged + + prefix: String to prepend to metric keys. Defaults to ``""``. 
+ **kwargs: Additional arguments like ``tags``, ``artifact_location`` etc. + used by ``MLFlowExperiment``. Example: >>> from anomalib.loggers import AnomalibMLFlowLogger >>> from anomalib.engine import Engine - ... - >>> mlflow_logger = AnomalibMLFlowLogger() - >>> engine = Engine(logger=mlflow_logger) + >>> mlflow_logger = AnomalibMLFlowLogger( + ... experiment_name="my_experiment", + ... run_name="my_run" + ... ) # doctest: +SKIP + >>> engine = Engine(logger=mlflow_logger) # doctest: +SKIP See Also: - - `MLFlow Documentation `_. + - `MLFlow Documentation `_ """ def __init__( @@ -87,14 +99,14 @@ def __init__( @rank_zero_only def add_image(self, image: np.ndarray | Figure, name: str | None = None, **kwargs) -> None: - """Interface to log images in the mlflow loggers. + """Log images to MLflow. Args: - image (np.ndarray | Figure): Image to log. - name (str | None): The tag of the image defaults to ``None``. - kwargs: Additional keyword arguments that are only used if `image` is of type Figure. - These arguments are passed directly to the method that saves the figure. - If `image` is a NumPy array, `kwargs` has no effect. + image: Image to log, can be either a numpy array or matplotlib Figure. + name: Name/title of the image. Defaults to ``None``. + **kwargs: Additional keyword arguments passed to the MLflow logging + method when ``image`` is a Figure. Has no effect when ``image`` + is a numpy array. """ # Need to call different functions of `Experiment` for Figure vs np.ndarray if isinstance(image, Figure): diff --git a/src/anomalib/loggers/tensorboard.py b/src/anomalib/loggers/tensorboard.py index 3d02e457ac..60bd965f15 100644 --- a/src/anomalib/loggers/tensorboard.py +++ b/src/anomalib/loggers/tensorboard.py @@ -1,4 +1,24 @@ -"""Tensorboard logger with add image interface.""" +"""TensorBoard logger with image logging capabilities. + +This module provides a TensorBoard logger implementation that adds an interface for +logging images. It extends both the base image logger and PyTorch Lightning's +TensorBoard logger. + +Example: + >>> from anomalib.loggers import AnomalibTensorBoardLogger + >>> from anomalib.engine import Engine + >>> tensorboard_logger = AnomalibTensorBoardLogger("logs") + >>> engine = Engine(logger=tensorboard_logger) # doctest: +SKIP + + Log an image: + >>> import numpy as np + >>> image = np.random.rand(32, 32, 3) # doctest: +SKIP + >>> tensorboard_logger.add_image( + ... image=image, + ... name="test_image", + ... global_step=0 + ... ) # doctest: +SKIP +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -18,48 +38,43 @@ class AnomalibTensorBoardLogger(ImageLoggerBase, TensorBoardLogger): - """Logger for tensorboard. + """Logger for TensorBoard with image logging capabilities. - Adds interface for `add_image` in the logger rather than calling the experiment object. + This logger extends PyTorch Lightning's TensorBoardLogger with an interface + for logging images. It inherits from both :class:`ImageLoggerBase` and + :class:`TensorBoardLogger`. - .. note:: - Same as the Tensorboard Logger provided by PyTorch Lightning and the doc string is reproduced below. - - Logs are saved to - ``os.path.join(save_dir, name, version)``. This is the default logger in Lightning, it comes - preinstalled. + Args: + save_dir: Directory path where logs will be saved. The final path will be + ``os.path.join(save_dir, name, version)``. + name: Name of the experiment. If it is an empty string, no + per-experiment subdirectory is used. 
Defaults to ``"default"``. + version: Version of the experiment. If not specified, the logger checks + the save directory for existing versions and assigns the next + available one. If a string is provided, it is used as the + run-specific subdirectory name. Otherwise ``"version_${version}"`` is + used. Defaults to ``None``. + log_graph: If ``True``, adds the computational graph to TensorBoard. This + requires that the model has defined the ``example_input_array`` + attribute. Defaults to ``False``. + default_hp_metric: If ``True``, enables a placeholder metric with key + ``hp_metric`` when ``log_hyperparams`` is called without a metric. + Defaults to ``True``. + prefix: String to prepend to metric keys. Defaults to ``""``. + **kwargs: Additional arguments like ``comment``, ``filename_suffix``, + etc. used by :class:`SummaryWriter`. Example: - >>> from anomalib.engine import Engine >>> from anomalib.loggers import AnomalibTensorBoardLogger - ... - >>> logger = AnomalibTensorBoardLogger("tb_logs", name="my_model") - >>> engine = Engine(logger=logger) - - Args: - save_dir (str): Save directory - name (str | None): Experiment name. Defaults to ``'default'``. - If it is the empty string then no per-experiment subdirectory is used. - Default: ``'default'``. - version (int | str | None): Experiment version. If version is not - specified the logger inspects the save directory for existing - versions, then automatically assigns the next available version. - If it is a string then it is used as the run-specific subdirectory - name, otherwise ``'version_${version}'`` is used. - Defaults to ``None`` - log_graph (bool): Adds the computational graph to tensorboard. This - requires that the user has defined the `self.example_input_array` - attribute in their model. - Defaults to ``False``. - default_hp_metric (bool): Enables a placeholder metric with key - ``hp_metric`` when ``log_hyperparams`` is called without a metric - (otherwise calls to log_hyperparams without a metric are ignored). - Defaults to ``True``. - prefix (str): A string to put at the beginning of metric keys. - Defaults to ``''``. - **kwargs: Additional arguments like `comment`, `filename_suffix`, etc. - used by :class:`SummaryWriter` can be passed as keyword arguments in - this logger. + >>> from anomalib.engine import Engine + >>> logger = AnomalibTensorBoardLogger( + ... save_dir="logs", + ... name="my_experiment" + ... ) # doctest: +SKIP + >>> engine = Engine(logger=logger) # doctest: +SKIP + + See Also: + - `TensorBoard Documentation `_ """ def __init__( @@ -85,13 +100,18 @@ def __init__( @rank_zero_only def add_image(self, image: np.ndarray | Figure, name: str | None = None, **kwargs) -> None: - """Interface to add image to tensorboard logger. + """Log images to TensorBoard. Args: - image (np.ndarray | Figure): Image to log - name (str | None): The tag of the image - Defaults to ``None``. - kwargs: Accepts only `global_step` (int). The step at which to log the image. + image: Image to log, can be either a numpy array or matplotlib + Figure. + name: Name/title of the image. Defaults to ``None``. + **kwargs: Must contain ``global_step`` (int) indicating the step at + which to log the image. Additional keyword arguments are passed + to the TensorBoard logging method. + + Raises: + ValueError: If ``global_step`` is not provided in ``kwargs``. 
""" if "global_step" not in kwargs: msg = "`global_step` is required for tensorboard logger" diff --git a/src/anomalib/loggers/wandb.py b/src/anomalib/loggers/wandb.py index ff41a0949e..a53ad2e82d 100644 --- a/src/anomalib/loggers/wandb.py +++ b/src/anomalib/loggers/wandb.py @@ -1,4 +1,23 @@ -"""wandb logger with add image interface.""" +"""Weights & Biases logger with image logging capabilities. + +This module provides a Weights & Biases logger implementation that adds an +interface for logging images. It extends both the base image logger and PyTorch +Lightning's WandbLogger. + +Example: + >>> from anomalib.loggers import AnomalibWandbLogger + >>> from anomalib.engine import Engine + >>> wandb_logger = AnomalibWandbLogger() # doctest: +SKIP + >>> engine = Engine(logger=wandb_logger) # doctest: +SKIP + + Log an image: + >>> import numpy as np + >>> image = np.random.rand(32, 32, 3) # doctest: +SKIP + >>> wandb_logger.add_image( + ... image=image, + ... name="test_image" + ... ) # doctest: +SKIP +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -23,68 +42,59 @@ class AnomalibWandbLogger(ImageLoggerBase, WandbLogger): - """Logger for wandb. + """Logger for Weights & Biases with image logging capabilities. - Adds interface for `add_image` in the logger rather than calling the experiment object. - - .. note:: - Same as the wandb Logger provided by PyTorch Lightning and the doc string is reproduced below. - - Log using `Weights and Biases `_. - - Install it with pip: - - .. code-block:: bash - - $ pip install wandb + This logger extends PyTorch Lightning's WandbLogger with an interface for + logging images. It inherits from both :class:`ImageLoggerBase` and + :class:`WandbLogger`. Args: - name: Display name for the run. - Defaults to ``None``. + name: Display name for the run. Defaults to ``None``. save_dir: Path where data is saved (wandb dir by default). - Defaults to ``None``. + Defaults to ``"."``. version: Sets the version, mainly used to resume a previous run. + Defaults to ``None``. offline: Run offline (data can be streamed later to wandb servers). Defaults to ``False``. - dir: Alias for save_dir. + dir: Alias for ``save_dir``. Defaults to ``None``. id: Sets the version, mainly used to resume a previous run. Defaults to ``None``. anonymous: Enables or explicitly disables anonymous logging. Defaults to ``None``. - version: Same as id. - Defaults to ``None``. project: The name of the project to which this run will belong. Defaults to ``None``. log_model: Save checkpoints in wandb dir to upload on W&B servers. Defaults to ``False``. - experiment: WandB experiment object. Automatically set when creating a run. - Defaults to ``None``. + experiment: WandB experiment object. Automatically set when creating a + run. Defaults to ``None``. prefix: A string to put at the beginning of metric keys. - Defaults to ``''``. - **kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc. + Defaults to ``""``. + checkpoint_name: Name of the checkpoint to save. + Defaults to ``None``. + **kwargs: Additional arguments passed to :func:`wandb.init` like + ``entity``, ``group``, ``tags``, etc. Raises: - ImportError: - If required WandB package is not installed on the device. - MisconfigurationException: - If both ``log_model`` and ``offline``is set to ``True``. + ImportError: If required WandB package is not installed. + MisconfigurationException: If both ``log_model`` and ``offline`` are + set to ``True``. 
Example: >>> from anomalib.loggers import AnomalibWandbLogger >>> from anomalib.engine import Engine - ... - >>> wandb_logger = AnomalibWandbLogger() - >>> engine = Engine(logger=wandb_logger) + >>> wandb_logger = AnomalibWandbLogger( + ... project="my_project", + ... name="my_run" + ... ) # doctest: +SKIP + >>> engine = Engine(logger=wandb_logger) # doctest: +SKIP - .. note:: - When logging manually through `wandb.log` or `trainer.logger.experiment.log`, - make sure to use `commit=False` so the logging step does not increase. + Note: + When logging manually through ``wandb.log`` or + ``trainer.logger.experiment.log``, make sure to use ``commit=False`` + so the logging step does not increase. See Also: - - `Tutorial `__ - on how to use W&B with PyTorch Lightning - - `W&B Documentation `__ - + - `W&B Documentation `_ """ def __init__( @@ -122,13 +132,14 @@ def __init__( @rank_zero_only def add_image(self, image: np.ndarray | Figure, name: str | None = None, **kwargs) -> None: - """Interface to add image to wandb logger. + """Log an image to Weights & Biases. Args: - image (np.ndarray | Figure): Image to log - name (str | None): The tag of the image - Defaults to ``None``. - kwargs: Additional arguments to `wandb.Image` + image: Image to log, can be either a numpy array or matplotlib + Figure. + name: Name/title of the image. Defaults to ``None``. + **kwargs: Additional keyword arguments passed to + :class:`wandb.Image`. Currently unused. """ del kwargs # Unused argument. @@ -137,10 +148,11 @@ def add_image(self, image: np.ndarray | Figure, name: str | None = None, **kwarg @rank_zero_only def save(self) -> None: - """Upload images to wandb server. + """Upload images to Weights & Biases server. - .. note:: - There is a limit on the number of images that can be logged together to the `wandb` server. + Note: + There is a limit on the number of images that can be logged together + to the W&B server. """ super().save() if len(self.image_list) > 1: diff --git a/src/anomalib/metrics/__init__.py b/src/anomalib/metrics/__init__.py index 6f606bc826..d04cdcba68 100644 --- a/src/anomalib/metrics/__init__.py +++ b/src/anomalib/metrics/__init__.py @@ -1,4 +1,41 @@ -"""Custom anomaly evaluation metrics.""" +"""Custom metrics for evaluating anomaly detection models. + +This module provides various metrics for evaluating anomaly detection performance: + +- Area Under Curve (AUC) metrics: + - ``AUROC``: Area Under Receiver Operating Characteristic curve + - ``AUPR``: Area Under Precision-Recall curve + - ``AUPRO``: Area Under Per-Region Overlap curve + - ``AUPIMO``: Area Under Per-Image Overlap curve + +- F1-score metrics: + - ``F1Score``: Standard F1 score + - ``F1Max``: Maximum F1 score across thresholds + +- Threshold metrics: + - ``F1AdaptiveThreshold``: Finds optimal threshold by maximizing F1 score + - ``ManualThreshold``: Uses manually specified threshold + +- Other metrics: + - ``AnomalibMetric``: Base class for custom metrics + - ``AnomalyScoreDistribution``: Analyzes score distributions + - ``BinaryPrecisionRecallCurve``: Computes precision-recall curves + - ``Evaluator``: Combines multiple metrics for evaluation + - ``MinMax``: Normalizes scores to [0,1] range + - ``PRO``: Per-Region Overlap score + - ``PIMO``: Per-Image Overlap score + +Example: + >>> import torch + >>> from anomalib.metrics import AUROC, F1Score + >>> auroc = AUROC() + >>> f1 = F1Score() + >>> labels = torch.tensor([0, 1, 0, 1]) + >>> scores = torch.tensor([0.1, 0.9, 0.2, 0.8]) + >>> auroc(scores, labels) + tensor(1.)
+ >>> f1(scores, labels) + tensor(1.) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/metrics/anomaly_score_distribution.py b/src/anomalib/metrics/anomaly_score_distribution.py index d95e863bbf..bbce60e069 100644 --- a/src/anomalib/metrics/anomaly_score_distribution.py +++ b/src/anomalib/metrics/anomaly_score_distribution.py @@ -1,4 +1,28 @@ -"""Module that computes the parameters of the normal data distribution of the training set.""" +"""Compute statistics of anomaly score distributions. + +This module provides the ``AnomalyScoreDistribution`` class which computes mean +and standard deviation statistics of anomaly scores from normal training data. +Statistics are computed for both image-level and pixel-level scores. + +The class tracks: + - Image-level statistics: Mean and std of image anomaly scores + - Pixel-level statistics: Mean and std of pixel anomaly maps + +Example: + >>> from anomalib.metrics import AnomalyScoreDistribution + >>> import torch + >>> # Create sample data + >>> scores = torch.tensor([0.1, 0.2, 0.15]) # Image anomaly scores + >>> maps = torch.tensor([[0.1, 0.2], [0.15, 0.25]]) # Pixel anomaly maps + >>> # Initialize and compute stats + >>> dist = AnomalyScoreDistribution() + >>> dist.update(anomaly_scores=scores, anomaly_maps=maps) + >>> image_mean, image_std, pixel_mean, pixel_std = dist.compute() + +Note: + The input scores and maps are log-transformed before computing statistics. + Both image-level scores and pixel-level maps are optional inputs. +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -8,9 +32,30 @@ class AnomalyScoreDistribution(Metric): - """Mean and standard deviation of the anomaly scores of normal training data.""" + """Compute distribution statistics of anomaly scores. + + This class tracks and computes the mean and standard deviation of anomaly + scores from the normal samples in the training set. Statistics are computed + for both image-level scores and pixel-level anomaly maps. + + The metric maintains internal state to accumulate scores and maps across + batches before computing final statistics. + + Example: + >>> dist = AnomalyScoreDistribution() + >>> # Update with batch of scores + >>> scores = torch.tensor([0.1, 0.2, 0.3]) + >>> dist.update(anomaly_scores=scores) + >>> # Compute statistics + >>> img_mean, img_std, pix_mean, pix_std = dist.compute() + """ def __init__(self, **kwargs) -> None: + """Initialize the metric states. + + Args: + **kwargs: Additional arguments passed to parent class. + """ super().__init__(**kwargs) self.anomaly_maps: list[torch.Tensor] = [] self.anomaly_scores: list[torch.Tensor] = [] @@ -32,7 +77,14 @@ def update( anomaly_maps: torch.Tensor | None = None, **kwargs, ) -> None: - """Update the precision-recall curve metric.""" + """Update the internal state with new scores and maps. + + Args: + *args: Unused positional arguments. + anomaly_scores: Batch of image-level anomaly scores. + anomaly_maps: Batch of pixel-level anomaly maps. + **kwargs: Unused keyword arguments. + """ del args, kwargs # These variables are not used. if anomaly_maps is not None: @@ -41,7 +93,15 @@ def update( self.anomaly_scores.append(anomaly_scores) def compute(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute stats.""" + """Compute distribution statistics from accumulated scores and maps.
+ + Returns: + tuple containing: + - image_mean: Mean of log-transformed image anomaly scores + - image_std: Standard deviation of log-transformed image scores + - pixel_mean: Mean of log-transformed pixel anomaly maps + - pixel_std: Standard deviation of log-transformed pixel maps + """ anomaly_scores = torch.hstack(self.anomaly_scores) anomaly_scores = torch.log(anomaly_scores) diff --git a/src/anomalib/metrics/aupr.py b/src/anomalib/metrics/aupr.py index 5856a1ae5f..06cf2d9852 100644 --- a/src/anomalib/metrics/aupr.py +++ b/src/anomalib/metrics/aupr.py @@ -1,4 +1,34 @@ -"""Implementation of AUROC metric based on TorchMetrics.""" +"""Area Under the Precision-Recall Curve (AUPR) metric. + +This module provides the ``AUPR`` class which computes the area under the +precision-recall curve for evaluating anomaly detection performance. + +The AUPR score summarizes the trade-off between precision and recall across +different thresholds. It is particularly useful for imbalanced datasets where +anomalies are rare. + +Example: + >>> from anomalib.metrics import AUPR + >>> import torch + >>> # Create sample data + >>> labels = torch.tensor([0, 0, 1, 1]) # Binary labels + >>> scores = torch.tensor([0.1, 0.2, 0.8, 0.9]) # Anomaly scores + >>> # Initialize and compute AUPR + >>> metric = AUPR() + >>> aupr_score = metric(scores, labels) + >>> aupr_score + tensor(1.) + +The metric can also be updated incrementally with batches: + + >>> for batch_scores, batch_labels in dataloader: + ... metric.update(batch_scores, batch_labels) + >>> final_score = metric.compute() + +Note: + The AUPR score ranges from 0 to 1, with 1 indicating perfect ranking of + anomalies above normal samples. +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/metrics/aupro.py b/src/anomalib/metrics/aupro.py index 0b5ac69d58..9f80c01686 100644 --- a/src/anomalib/metrics/aupro.py +++ b/src/anomalib/metrics/aupro.py @@ -1,4 +1,52 @@ -"""Implementation of AUPRO score based on TorchMetrics.""" +"""Area Under Per-Region Overlap (AUPRO) metric. + +This module provides the ``AUPRO`` class which computes the area under the +per-region overlap curve for evaluating anomaly segmentation performance. + +The AUPRO score measures how well predicted anomaly maps overlap with ground truth +anomaly regions. It is computed by: + +1. Performing connected component analysis on ground truth masks +2. Computing per-region ROC curves for each component +3. Averaging the curves and computing area under the curve up to a FPR limit + +Example: + >>> from anomalib.metrics import AUPRO + >>> import torch + >>> # Create sample data + >>> labels = torch.randint(0, 2, (1, 10, 5), dtype=torch.float32) # Binary segmentation masks + >>> scores = torch.rand_like(labels) # Anomaly segmentation maps + >>> # Initialize and compute AUPRO + >>> metric = AUPRO(fpr_limit=0.3) + >>> aupro_score = metric(scores, labels) + >>> aupro_score + tensor(0.4321) + +The metric can also be updated incrementally with batches: + + >>> for batch_scores, batch_labels in dataloader: + ... metric.update(batch_scores, batch_labels) + >>> final_score = metric.compute() + +Args: + dist_sync_on_step (bool): Synchronize metric state across processes at each + ``forward()`` before returning the value at the step. + Defaults to ``False``. + process_group (Any | None): Specify the process group on which + synchronization is called. Defaults to ``None`` (entire world).
+ dist_sync_fn (Callable | None): Callback that performs the allgather + operation on the metric state. When ``None``, DDP will be used. + Defaults to ``None``. + fpr_limit (float): Limit for the false positive rate. + Defaults to ``0.3``. + num_thresholds (int | None): Number of thresholds to use for computing the + ROC curve. When ``None``, uses thresholds from torchmetrics. + Defaults to ``None``. + +Note: + The AUPRO score ranges from 0 to 1, with 1 indicating perfect overlap between + predictions and ground truth regions. +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -24,30 +72,32 @@ class _AUPRO(Metric): """Area under per region overlap (AUPRO) Metric. Args: - dist_sync_on_step (bool): Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. Default: ``False`` - process_group (Optional[Any]): Specify the process group on which synchronization is called. - Default: ``None`` (which selects the entire world) - dist_sync_fn (Optional[Callable]): Callback that performs the allgather operation on the metric state. - When ``None``, DDP will be used to perform the allgather. - Default: ``None`` - fpr_limit (float): Limit for the false positive rate. Defaults to ``0.3``. - num_thresholds (int): Number of thresholds to use for computing the roc curve. Defaults to ``None``. - If ``None``, the roc curve is computed with the thresholds returned by - ``torchmetrics.functional.classification.thresholds``. + dist_sync_on_step (bool): Synchronize metric state across processes at + each ``forward()`` before returning the value at the step. + Defaults to ``False``. + process_group (Any | None): Specify the process group on which + synchronization is called. Defaults to ``None`` (entire world). + dist_sync_fn (Callable | None): Callback that performs the allgather + operation on the metric state. When ``None``, DDP will be used. + Defaults to ``None``. + fpr_limit (float): Limit for the false positive rate. + Defaults to ``0.3``. + num_thresholds (int | None): Number of thresholds to use for computing + the ROC curve. When ``None``, uses thresholds from torchmetrics. + Defaults to ``None``. Examples: >>> import torch >>> from anomalib.metrics import AUPRO - ... - >>> labels = torch.randint(low=0, high=2, size=(1, 10, 5), dtype=torch.float32) + >>> # Create sample data + >>> labels = torch.randint(0, 2, (1, 10, 5), dtype=torch.float32) >>> preds = torch.rand_like(labels) - ... + >>> # Initialize and compute >>> aupro = AUPRO(fpr_limit=0.3) >>> aupro(preds, labels) tensor(0.4321) - Increasing the fpr_limit will increase the AUPRO value: + Increasing the ``fpr_limit`` will increase the AUPRO value: >>> aupro = AUPRO(fpr_limit=0.7) >>> aupro(preds, labels) @@ -59,11 +109,13 @@ class _AUPRO(Metric): full_state_update: bool = False preds: list[torch.Tensor] target: list[torch.Tensor] - # When not None, the computation is performed in constant-memory by computing the roc curve - # for fixed thresholds buckets/thresholds. - # Warning: The thresholds are evenly distributed between the min and max predictions - # if all predictions are inside [0, 1]. Otherwise, the thresholds are evenly distributed between 0 and 1. - # This warning can be removed when https://github.com/Lightning-AI/torchmetrics/issues/1526 is fixed + # When not None, the computation is performed in constant-memory by computing + # the roc curve for fixed thresholds buckets/thresholds. 
+ # Warning: The thresholds are evenly distributed between the min and max + # predictions if all predictions are inside [0, 1]. Otherwise, the thresholds + # are evenly distributed between 0 and 1. + # This warning can be removed when + # https://github.com/Lightning-AI/torchmetrics/issues/1526 is fixed # and the roc curve is computed with deactivated formatting num_thresholds: int | None @@ -100,8 +152,8 @@ def perform_cca(self) -> torch.Tensor: """Perform the Connected Component Analysis on the self.target tensor. Raises: - ValueError: ValueError is raised if self.target doesn't conform with requirements imposed by kornia for - connected component analysis. + ValueError: ValueError is raised if self.target doesn't conform with + requirements imposed by kornia for connected component analysis. Returns: Tensor: Components labeled from 0 to N. @@ -111,8 +163,8 @@ def perform_cca(self) -> torch.Tensor: # check and prepare target for labeling via kornia if target.min() < 0 or target.max() > 1: msg = ( - "kornia.contrib.connected_components expects input to lie in the interval [0, 1], " - f"but found interval was [{target.min()}, {target.max()}]." + "kornia.contrib.connected_components expects input to lie in the " + f"interval [0, 1], but found [{target.min()}, {target.max()}]." ) raise ValueError( msg, @@ -127,20 +179,28 @@ def compute_pro( target: torch.Tensor, preds: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - """Compute the pro/fpr value-pairs until the fpr specified by self.fpr_limit. + """Compute the pro/fpr value-pairs until the fpr specified by fpr_limit. - It leverages the fact that the overlap corresponds to the tpr, and thus computes the overall - PRO curve by aggregating per-region tpr/fpr values produced by ROC-construction. + It leverages the fact that the overlap corresponds to the tpr, and thus + computes the overall PRO curve by aggregating per-region tpr/fpr values + produced by ROC-construction. + + Args: + cca (torch.Tensor): Connected components tensor + target (torch.Tensor): Ground truth tensor + preds (torch.Tensor): Model predictions tensor Returns: - tuple[torch.Tensor, torch.Tensor]: tuple containing final fpr and tpr values. + tuple[torch.Tensor, torch.Tensor]: Tuple containing final fpr and tpr + values. """ if self.num_thresholds is not None: - # binary_roc is applying a sigmoid on the predictions before computing the roc curve - # when some predictions are out of [0, 1], the binning between min and max predictions - # cannot be applied in that case. This can be removed when - # https://github.com/Lightning-AI/torchmetrics/issues/1526 is fixed and - # the roc curve is computed with deactivated formatting. + # binary_roc is applying a sigmoid on the predictions before computing + # the roc curve when some predictions are out of [0, 1], the binning + # between min and max predictions cannot be applied in that case. + # This can be removed when + # https://github.com/Lightning-AI/torchmetrics/issues/1526 is fixed + # and the roc curve is computed with deactivated formatting. if torch.all((preds >= 0) * (preds <= 1)): thresholds = thresholds_between_min_and_max(preds, self.num_thresholds, self.device) @@ -163,10 +223,12 @@ def compute_pro( fpr = torch.zeros(output_size, device=preds.device, dtype=torch.float) new_idx = torch.arange(0, output_size, device=preds.device, dtype=torch.float) - # Loop over the labels, computing per-region tpr/fpr curves, and aggregating them. 
- # Note that, since the groundtruth is different for every all to `roc`, we also get - # different/unique tpr/fpr curves (i.e. len(_fpr_idx) is different for every call). - # We therefore need to resample per-region curves to a fixed sampling ratio (defined above). + # Loop over the labels, computing per-region tpr/fpr curves, and + # aggregating them. Note that, since the groundtruth is different for + # every call to `roc`, we also get different/unique tpr/fpr curves + # (i.e. len(_fpr_idx) is different for every call). + # We therefore need to resample per-region curves to a fixed sampling + # ratio (defined above). labels = cca.unique()[1:] # 0 is background background = cca == 0 _fpr: torch.Tensor _tpr: torch.Tensor for label in labels: interp: bool = False new_idx[-1] = output_size - 1 mask = cca == label - # Need to calculate label-wise roc on union of background & mask, as otherwise we wrongly consider other - # label in labels as FPs. We also don't need to return the thresholds + # Need to calculate label-wise roc on union of background & mask, as + # otherwise we wrongly consider other labels in labels as FPs. + # We also don't need to return the thresholds _fpr, _tpr = binary_roc( preds=preds[background | mask], target=mask[background | mask], @@ -190,8 +253,9 @@ _fpr_limit = self.fpr_limit _fpr_idx = torch.where(_fpr <= _fpr_limit)[0] - # if computed roc curve is not specified sufficiently close to self.fpr_limit, - # we include the closest higher tpr/fpr pair and linearly interpolate the tpr/fpr point at self.fpr_limit + # if computed roc curve is not specified sufficiently close to + # self.fpr_limit, we include the closest higher tpr/fpr pair and + # linearly interpolate the tpr/fpr point at self.fpr_limit if not torch.allclose(_fpr[_fpr_idx].max(), self.fpr_limit): _tmp_idx = torch.searchsorted(_fpr, self.fpr_limit) _fpr_idx = torch.cat([_fpr_idx, _tmp_idx.unsqueeze_(0)]) @@ -225,7 +289,8 @@ def _compute(self) -> tuple[torch.Tensor, torch.Tensor]: Perform the Connected Component Analysis first then compute the PRO curve. Returns: - tuple[torch.Tensor, torch.Tensor]: tuple containing final fpr and tpr values. + tuple[torch.Tensor, torch.Tensor]: Tuple containing final fpr and tpr + values. """ cca = self.perform_cca().flatten() target = dim_zero_cat(self.target).flatten() @@ -234,7 +299,7 @@ def _compute(self) -> tuple[torch.Tensor, torch.Tensor]: return self.compute_pro(cca=cca, target=target, preds=preds) def compute(self) -> torch.Tensor: - """Fist compute PRO curve, then compute and scale area under the curve. + """First compute PRO curve, then compute and scale area under the curve. Returns: Tensor: Value of the AUPRO metric @@ -248,7 +313,8 @@ def generate_figure(self) -> tuple[Figure, str]: """Generate a figure containing the PRO curve and the AUPRO. Returns: - tuple[Figure, str]: Tuple containing both the figure and the figure title to be used for logging + tuple[Figure, str]: Tuple containing both the figure and the figure + title to be used for logging """ fpr, tpr = self._compute() aupro = self.compute() @@ -287,7 +353,7 @@ def interp1d(old_x: torch.Tensor, old_y: torch.Tensor, new_x: torch.Tensor) -> t # to preserve order, but we actually want the preceeding index. idx -= 1 # we clamp the index, because the number of intervals = old_x.size(0) -1, - # and the left neighbour should hence be at most number of intervals -1, i.e.
old_x.size(0) - 2 + # and the left neighbour should hence be at most number of intervals -1, + # i.e. old_x.size(0) - 2 idx = torch.clamp(idx, 0, old_x.size(0) - 2) # perform actual linear interpolation diff --git a/src/anomalib/metrics/auroc.py b/src/anomalib/metrics/auroc.py index 183da7a4f0..05fd87889c 100644 --- a/src/anomalib/metrics/auroc.py +++ b/src/anomalib/metrics/auroc.py @@ -1,4 +1,38 @@ -"""Implementation of AUROC metric based on TorchMetrics.""" +"""Area Under the Receiver Operating Characteristic (AUROC) metric. + +This module provides the ``AUROC`` class which computes the area under the ROC +curve for evaluating anomaly detection performance. + +The AUROC score summarizes the trade-off between true positive rate (TPR) and +false positive rate (FPR) across different thresholds. It measures how well the +model can distinguish between normal and anomalous samples. + +Example: + >>> from anomalib.metrics import AUROC + >>> import torch + >>> # Create sample data + >>> labels = torch.tensor([0, 0, 1, 1]) # Binary labels + >>> scores = torch.tensor([0.1, 0.2, 0.8, 0.9]) # Anomaly scores + >>> # Initialize and compute AUROC + >>> metric = AUROC() + >>> auroc_score = metric(scores, labels) + >>> auroc_score + tensor(1.) + +The metric can also be updated incrementally with batches: + + >>> for batch_scores, batch_labels in dataloader: + ... metric.update(batch_scores, batch_labels) + >>> final_score = metric.compute() + +Once computed, the ROC curve can be visualized: + + >>> figure, title = metric.generate_figure() + +Note: + The AUROC score ranges from 0 to 1, with 1 indicating perfect ranking of + anomalies above normal samples. +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -16,13 +50,17 @@ class _AUROC(BinaryROC): """Area under the ROC curve. + This class computes the area under the receiver operating characteristic + curve, which plots the true positive rate against the false positive rate + at various thresholds. + Examples: + To compute the metric for a set of predictions and ground truth targets: + >>> import torch >>> from anomalib.metrics import AUROC - ... >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.92, 0.03]) >>> target = torch.tensor([0, 0, 1, 1, 0]) - ... >>> auroc = AUROC() >>> auroc(preds, target) tensor(0.6667) @@ -34,16 +72,16 @@ class _AUROC(BinaryROC): >>> auroc.compute() tensor(0.6667) - To plot the ROC curve, use the ``generate_figure`` method: + To plot the ROC curve: - >>> fig, title = auroc.generate_figure() + >>> figure, title = auroc.generate_figure() """ def compute(self) -> torch.Tensor: """First compute ROC curve, then compute area under the curve. Returns: - Tensor: Value of the AUROC metric + torch.Tensor: Value of the AUROC metric """ tpr: torch.Tensor fpr: torch.Tensor @@ -52,21 +90,23 @@ def compute(self) -> torch.Tensor: return auc(fpr, tpr, reorder=True) def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: - """Update state with new values. + """Update state with new predictions and targets. - Need to flatten new values as ROC expects them in this format for binary classification. + Need to flatten new values as ROC expects them in this format for binary + classification.
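The flatten-then-score behaviour described above can be checked by hand with the torchmetrics functional API. A sketch, not anomalib's code path; ``binary_auroc`` is the functional counterpart of the class-based metric:

```python
import torch
from torchmetrics.functional.classification import binary_auroc

preds = torch.tensor([[0.13, 0.26], [0.08, 0.92]])  # e.g. 2x2 pixel anomaly map
target = torch.tensor([[0, 0], [1, 1]])

# Flatten both tensors so the binary ROC sees one score per pixel.
score = binary_auroc(preds.flatten(), target.flatten())
print(score)  # tensor(0.5000): one of two anomalous pixels outranks the normals
```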
Args: - preds (torch.Tensor): predictions of the model - target (torch.Tensor): ground truth targets + preds (torch.Tensor): Predictions from the model + target (torch.Tensor): Ground truth target labels """ super().update(preds.flatten(), target.flatten()) def _compute(self) -> tuple[torch.Tensor, torch.Tensor]: - """Compute fpr/tpr value pairs. + """Compute false positive rate and true positive rate value pairs. Returns: - Tuple containing Tensors for fpr and tpr + tuple[torch.Tensor, torch.Tensor]: Tuple containing tensors for FPR + and TPR values """ tpr: torch.Tensor fpr: torch.Tensor @@ -74,10 +114,14 @@ def _compute(self) -> tuple[torch.Tensor, torch.Tensor]: return (fpr, tpr) def generate_figure(self) -> tuple[Figure, str]: - """Generate a figure containing the ROC curve, the baseline and the AUROC. + """Generate a figure showing the ROC curve. + + The figure includes the ROC curve, a baseline representing random + performance, and the AUROC score. Returns: - tuple[Figure, str]: Tuple containing both the figure and the figure title to be used for logging + tuple[Figure, str]: Tuple containing both the figure and the figure + title to be used for logging """ fpr, tpr = self._compute() auroc = self.compute() diff --git a/src/anomalib/metrics/base.py b/src/anomalib/metrics/base.py index 041e45a334..040fd00ece 100644 --- a/src/anomalib/metrics/base.py +++ b/src/anomalib/metrics/base.py @@ -1,37 +1,17 @@ -"""Base classes for metrics in Anomalib.""" +"""Base classes for metrics in Anomalib. -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -from collections.abc import Sequence +This module provides base classes for implementing metrics in Anomalib: -from torchmetrics import Metric, MetricCollection - -from anomalib.data import Batch +- ``AnomalibMetric``: Base class that makes torchmetrics compatible with Anomalib +- ``create_anomalib_metric``: Factory function to create Anomalib metrics +The ``AnomalibMetric`` class adds batch processing capabilities to any torchmetrics +metric. It allows metrics to be updated directly with ``Batch`` objects instead of +requiring individual tensors. -class AnomalibMetric: - """Base class for metrics in Anomalib. - - This class is designed to make any torchmetrics metric compatible with the - Anomalib framework. An Anomalib version of any torchmetrics metric can be created - by inheriting from this class and the desired torchmetrics metric. For example, to - create an Anomalib version of the BinaryF1Score metric, the user can create a new - class that inherits from AnomalibMetric and BinaryF1Score. - - The AnomalibMetric class adds the ability to update the metric with a Batch - object instead of individual prediction and target tensors. To use this feature, - the user must provide a list of fields as constructor arguments when instantiating - the metric. When the metric is updated with a Batch object, it will extract the - values of these fields from the Batch object and pass them to the `update` method - of the metric. - - Args: - fields (Sequence[str]): List of field names to extract from the Batch object. - prefix (str): Prefix to add to the metric name. Defaults to an empty string. - **kwargs: Variable keyword arguments that can be passed to the parent class. 
+Example: + Create a custom F1 score metric:: - Examples: >>> from torchmetrics.classification import BinaryF1Score >>> from anomalib.metrics import AnomalibMetric >>> from anomalib.data import ImageBatch @@ -40,29 +20,85 @@ class that inherits from AnomalibMetric and BinaryF1Score. >>> class F1Score(AnomalibMetric, BinaryF1Score): ... pass ... + >>> # Create metric specifying batch fields to use >>> f1_score = F1Score(fields=["pred_label", "gt_label"]) >>> + >>> # Create sample batch >>> batch = ImageBatch( ... image=torch.rand(4, 3, 256, 256), ... pred_label=torch.tensor([0, 0, 0, 1]), - ... gt_label=torch.tensor([0, 0, 1, 1])), + ... gt_label=torch.tensor([0, 0, 1, 1]) ... ) >>> - >>> # The AnomalibMetric class allows us to update the metric by passing a Batch - >>> # object directly. + >>> # Update metric with batch directly >>> f1_score.update(batch) >>> f1_score.compute() tensor(0.6667) - >>> - >>> # specifying the field names allows us to distinguish between image and - >>> # pixel metrics. - >>> image_f1_score = F1Score(fields=["pred_label", "gt_label"], prefix="image_") - >>> pixel_f1_score = F1Score(fields=[pred_mask", "gt_mask"], prefix="pixel_") + + Use factory function to create metric:: + + >>> from anomalib.metrics import create_anomalib_metric + >>> F1Score = create_anomalib_metric(BinaryF1Score) + >>> f1_score = F1Score(fields=["pred_label", "gt_label"]) +""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Sequence + +from torchmetrics import Metric, MetricCollection + +from anomalib.data import Batch + + +class AnomalibMetric: + """Base class for metrics in Anomalib. + + Makes any torchmetrics metric compatible with the Anomalib framework by adding + batch processing capabilities. Subclasses must inherit from both this class + and a torchmetrics metric. + + The class enables updating metrics with ``Batch`` objects instead of + individual tensors. It extracts the specified fields from the batch and + passes them to the underlying metric's update method. + + Args: + fields (Sequence[str] | None): Names of fields to extract from batch. + If None, uses class's ``default_fields``. Required if no defaults. + prefix (str): Prefix added to metric name. Defaults to "". + **kwargs: Additional arguments passed to parent metric class. + + Raises: + ValueError: If no fields are specified and class has no defaults. + + Example: + Create image and pixel-level F1 metrics:: + + >>> from torchmetrics.classification import BinaryF1Score + >>> class F1Score(AnomalibMetric, BinaryF1Score): + ... pass + ... + >>> # Image-level metric using pred_label and gt_label + >>> image_f1 = F1Score( + ... fields=["pred_label", "gt_label"], + ... prefix="image_" + ... ) + >>> # Pixel-level metric using pred_mask and gt_mask + >>> pixel_f1 = F1Score( + ... fields=["pred_mask", "gt_mask"], + ... prefix="pixel_" + ... 
) """ default_fields: Sequence[str] - def __init__(self, fields: Sequence[str] | None = None, prefix: str = "", **kwargs) -> None: + def __init__( + self, + fields: Sequence[str] | None = None, + prefix: str = "", + **kwargs, + ) -> None: fields = fields or getattr(self, "default_fields", None) if fields is None: msg = ( @@ -76,7 +112,7 @@ def __init__(self, fields: Sequence[str] | None = None, prefix: str = "", **kwar super().__init__(**kwargs) def __init_subclass__(cls, **kwargs) -> None: - """Check that the subclass implements the torchmetrics.Metric interface.""" + """Check that subclass implements torchmetrics.Metric interface.""" del kwargs assert issubclass( cls, @@ -84,7 +120,16 @@ def __init_subclass__(cls, **kwargs) -> None: ), "AnomalibMetric must be a subclass of torchmetrics.Metric or torchmetrics.MetricCollection" def update(self, batch: Batch, *args, **kwargs) -> None: - """Update the metric with the specified fields from the Batch object.""" + """Update metric with values from batch fields. + + Args: + batch (Batch): Batch object containing required fields. + *args: Additional positional arguments passed to parent update. + **kwargs: Additional keyword arguments passed to parent update. + + Raises: + ValueError: If batch is missing any required fields. + """ for key in self.fields: if getattr(batch, key, None) is None: msg = f"Batch object is missing required field: {key}" @@ -96,32 +141,29 @@ def update(self, batch: Batch, *args, **kwargs) -> None: def create_anomalib_metric(metric_cls: type) -> type: """Create an Anomalib version of a torchmetrics metric. - This function creates an Anomalib version of a torchmetrics metric by inheriting - from the AnomalibMetric class and the specified torchmetrics metric class. The - resulting class will have the same name as the input metric class and will inherit - from both AnomalibMetric and the input metric class. + Factory function that creates a new class inheriting from both + ``AnomalibMetric`` and the input metric class. The resulting class has + batch processing capabilities while maintaining the original metric's + functionality. Args: - metric_cls (Callable): The torchmetrics metric class to wrap. + metric_cls (type): torchmetrics metric class to wrap. Returns: - AnomalibMetric: An Anomalib version of the input metric class. + type: New class inheriting from ``AnomalibMetric`` and input class. - Examples: - >>> from torchmetrics.classification import BinaryF1Score - >>> from anomalib.metrics import create_anomalib_metric - >>> - >>> F1Score = create_anomalib_metric(BinaryF1Score) - >>> # This is equivalent to the following class definition: - >>> # class F1Score(AnomalibMetric, BinaryF1Score): ... - >>> - >>> f1_score = F1Score(fields=["pred_label", "gt_label"]) - >>> - >>> # The AnomalibMetric class allows us to update the metric by passing a Batch - >>> # object directly. - >>> f1_score.update(batch) - >>> f1_score.compute() - tensor(0.6667) + Raises: + AssertionError: If input class is not a torchmetrics.Metric subclass. + + Example: + Create F1 score metric:: + + >>> from torchmetrics.classification import BinaryF1Score + >>> F1Score = create_anomalib_metric(BinaryF1Score) + >>> f1_score = F1Score(fields=["pred_label", "gt_label"]) + >>> f1_score.update(batch) # Can update with batch directly + >>> f1_score.compute() + tensor(0.6667) """ assert issubclass(metric_cls, Metric), "The wrapped metric must be a subclass of torchmetrics.Metric." 
     return type(metric_cls.__name__, (AnomalibMetric, metric_cls), {})
diff --git a/src/anomalib/metrics/binning.py b/src/anomalib/metrics/binning.py
index b56c234800..7b36299791 100644
--- a/src/anomalib/metrics/binning.py
+++ b/src/anomalib/metrics/binning.py
@@ -1,4 +1,22 @@
-"""Binning functions for metrics."""
+"""Binning functions for metrics.
+
+This module provides utility functions for generating threshold values used in
+various metrics calculations.
+
+Example:
+    >>> import torch
+    >>> from anomalib.metrics.binning import thresholds_between_min_and_max
+    >>> preds = torch.tensor([0.1, 0.5, 0.8])
+    >>> thresholds = thresholds_between_min_and_max(preds, num_thresholds=3)
+    >>> thresholds
+    tensor([0.1000, 0.4500, 0.8000])
+
+    Generate thresholds between 0 and 1:
+    >>> from anomalib.metrics.binning import thresholds_between_0_and_1
+    >>> thresholds = thresholds_between_0_and_1(num_thresholds=3)
+    >>> thresholds
+    tensor([0.0000, 0.5000, 1.0000])
+"""

 # Copyright (C) 2023-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -12,29 +30,48 @@ def thresholds_between_min_and_max(
     num_thresholds: int = 100,
     device: torch.device | None = None,
 ) -> torch.Tensor:
-    """Threshold values between min and max of the predictions.
+    """Generate evenly spaced threshold values between min and max predictions.

     Args:
-        preds (torch.Tensor): Predictions.
-        num_thresholds (int, optional): Number of thresholds to generate. Defaults to 100.
-        device (torch_device | None, optional): Device to use for computation. Defaults to None.
+        preds (torch.Tensor): Input tensor containing predictions or scores.
+        num_thresholds (int, optional): Number of threshold values to generate.
+            Defaults to ``100``.
+        device (torch.device | None, optional): Device on which to place the
+            output tensor. If ``None``, uses PyTorch's current default device.
+            Defaults to ``None``.

     Returns:
-        Tensor:
-            Array of size ``num_thresholds`` that contains evenly spaced values
-            between ``preds.min()`` and ``preds.max()`` on ``device``.
+        torch.Tensor: A 1D tensor of size ``num_thresholds`` containing evenly
+            spaced values between ``preds.min()`` and ``preds.max()``.
+
+    Example:
+        >>> preds = torch.tensor([0.1, 0.3, 0.5, 0.7, 0.9])
+        >>> thresholds = thresholds_between_min_and_max(preds, num_thresholds=3)
+        >>> thresholds
+        tensor([0.1000, 0.5000, 0.9000])
     """
     return linspace(start=preds.min(), end=preds.max(), steps=num_thresholds, device=device)


-def thresholds_between_0_and_1(num_thresholds: int = 100, device: torch.device | None = None) -> torch.Tensor:
-    """Threshold values between 0 and 1.
+def thresholds_between_0_and_1(
+    num_thresholds: int = 100,
+    device: torch.device | None = None,
+) -> torch.Tensor:
+    """Generate evenly spaced threshold values between 0 and 1.

     Args:
-        num_thresholds (int, optional): Number of thresholds to generate. Defaults to 100.
-        device (torch_device | None, optional): Device to use for computation. Defaults to None.
+        num_thresholds (int, optional): Number of threshold values to generate.
+            Defaults to ``100``.
+        device (torch.device | None, optional): Device on which to place the
+            output tensor. Defaults to ``None``.

     Returns:
-        Tensor: Threshold values between 0 and 1.
+        torch.Tensor: A 1D tensor of size ``num_thresholds`` containing evenly
+            spaced values between ``0`` and ``1``.
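+
+    A quick sketch of the device placement (assuming a CPU-only setup):
+
+    >>> thresholds_between_0_and_1(num_thresholds=2, device=torch.device("cpu")).device
+    device(type='cpu')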
+ + Example: + >>> thresholds = thresholds_between_0_and_1(num_thresholds=5) + >>> thresholds + tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]) """ return linspace(start=0, end=1, steps=num_thresholds, device=device) diff --git a/src/anomalib/metrics/evaluator.py b/src/anomalib/metrics/evaluator.py index 460a2a4b0b..76e1893b27 100644 --- a/src/anomalib/metrics/evaluator.py +++ b/src/anomalib/metrics/evaluator.py @@ -1,4 +1,51 @@ -"""Evaluator module for LightningModule.""" +"""Evaluator module for LightningModule. + +The Evaluator module computes and logs metrics during validation and test steps. +Each ``AnomalibModule`` should have an Evaluator module as a submodule to compute +and log metrics. An Evaluator module can be passed to the ``AnomalibModule`` as a +parameter during initialization. When no Evaluator module is provided, the +``AnomalibModule`` will use a default Evaluator module that logs a default set of +metrics. + +Args: + val_metrics (Sequence[AnomalibMetric] | AnomalibMetric | None, optional): + Validation metrics. Defaults to ``None``. + test_metrics (Sequence[AnomalibMetric] | AnomalibMetric | None, optional): + Test metrics. Defaults to ``None``. + compute_on_cpu (bool, optional): Whether to compute metrics on CPU. + Defaults to ``True``. + +Example: + >>> from anomalib.metrics import F1Score, AUROC + >>> from anomalib.data import ImageBatch + >>> import torch + >>> + >>> # Initialize metrics with fields to use from batch + >>> f1_score = F1Score(fields=["pred_label", "gt_label"]) + >>> auroc = AUROC(fields=["pred_score", "gt_label"]) + >>> + >>> # Create evaluator with test metrics + >>> evaluator = Evaluator(test_metrics=[f1_score, auroc]) + >>> + >>> # Create sample batch + >>> batch = ImageBatch( + ... image=torch.rand(4, 3, 256, 256), + ... pred_label=torch.tensor([0, 0, 1, 1]), + ... gt_label=torch.tensor([0, 0, 1, 1]), + ... pred_score=torch.tensor([0.1, 0.2, 0.8, 0.9]) + ... ) + >>> + >>> # Update metrics with batch + >>> evaluator.on_test_batch_end(None, None, None, batch, 0) + >>> + >>> # Compute and log metrics at end of epoch + >>> evaluator.on_test_epoch_end(None, None) + +Note: + The evaluator will automatically move metrics to CPU for computation if + ``compute_on_cpu=True`` and only one device is used. For multi-GPU training, + ``compute_on_cpu`` is automatically set to ``False``. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/metrics/f1_score.py b/src/anomalib/metrics/f1_score.py index ab85c0cc03..68a753dc7b 100644 --- a/src/anomalib/metrics/f1_score.py +++ b/src/anomalib/metrics/f1_score.py @@ -1,4 +1,31 @@ -"""F1 Score and F1Max Metrics for Binary Classification Tasks.""" +"""F1 Score and F1Max metrics for binary classification tasks. 
+
+This module provides two metrics for evaluating binary classification performance:
+
+- ``F1Score``: Standard F1 score metric that computes the harmonic mean of
+  precision and recall at a fixed threshold
+- ``F1Max``: Maximum F1 score metric that finds the optimal threshold by
+  computing F1 scores across different thresholds
+
+Example:
+    >>> from anomalib.metrics import F1Score, F1Max
+    >>> import torch
+    >>> # Create sample data
+    >>> preds = torch.tensor([0.1, 0.4, 0.35, 0.8])
+    >>> target = torch.tensor([0, 0, 1, 1])
+    >>> # Compute standard F1 score
+    >>> f1 = F1Score()
+    >>> f1.update(preds > 0.5, target)
+    >>> f1.compute()
+    tensor(0.6667)
+    >>> # Compute maximum F1 score
+    >>> f1_max = F1Max()
+    >>> f1_max.update(preds, target)
+    >>> f1_max.compute()
+    tensor(0.8000)
+    >>> f1_max.threshold
+    tensor(0.3500)
+"""

 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -13,56 +40,60 @@

 class F1Score(AnomalibMetric, BinaryF1Score):
-    """Wrapper to add AnomalibMetric functionality to F1Score metric."""
+    """Wrapper to add AnomalibMetric functionality to F1Score metric.
+
+    This class wraps the torchmetrics ``BinaryF1Score`` to make it compatible
+    with Anomalib's batch processing capabilities.
+
+    Example:
+        >>> from anomalib.metrics import F1Score
+        >>> import torch
+        >>> # Create metric
+        >>> f1 = F1Score()
+        >>> # Create sample data
+        >>> preds = torch.tensor([0, 0, 1, 1])
+        >>> target = torch.tensor([0, 1, 1, 1])
+        >>> # Update and compute
+        >>> f1.update(preds, target)
+        >>> f1.compute()
+        tensor(0.8000)
+    """


 class _F1Max(Metric):
-    """F1Max Metric for Computing the Maximum F1 Score.
+    """F1Max metric for computing the maximum F1 score.

-    This class is designed to calculate the maximum F1 score from the precision-
-    recall curve for binary classification tasks. The F1 score is a harmonic
-    mean of precision and recall, offering a balance between these two metrics.
-    The maximum F1 score (F1-Max) is particularly useful in scenarios where an
-    optimal balance between precision and recall is desired, such as in
-    imbalanced datasets or when both false positives and false negatives carry
-    significant costs.
+    This class calculates the maximum F1 score by varying the classification
+    threshold. The F1 score is the harmonic mean of precision and recall,
+    providing a balanced metric for imbalanced datasets.

-    After computing the F1Max score, the class also identifies and stores the
-    threshold that yields this maximum F1 score, which providing insight into
-    the optimal point for the classification decision.
+    After computing the maximum F1 score, the class stores the threshold that
+    achieved this score in the ``threshold`` attribute.

     Args:
-        **kwargs: Variable keyword arguments that can be passed to the parent class.
+        **kwargs: Additional arguments passed to the parent ``Metric`` class.

     Attributes:
-        full_state_update (bool): Indicates whether the metric requires updating
-            the entire state. Set to False for this metric as it calculates the
-            F1 score based on the current state without needing historical data.
+        full_state_update (bool): Whether to update entire state on each batch.
+            Set to ``False`` as metric only needs current batch.
         precision_recall_curve (BinaryPrecisionRecallCurve): Utility to compute
-            precision and recall values across different thresholds.
+            precision-recall values across thresholds.
+        threshold (torch.Tensor): Threshold value that yields maximum F1 score.

-    Examples:
+    Example:
         >>> from anomalib.metrics import F1Max
         >>> import torch
-
+        >>> # Create metric
+        >>> f1_max = F1Max()
+        >>> # Create sample data
         >>> preds = torch.tensor([0.1, 0.4, 0.35, 0.8])
         >>> target = torch.tensor([0, 0, 1, 1])
-
-        >>> f1_max = F1Max()
+        >>> # Update and compute
         >>> f1_max.update(preds, target)
-
-        >>> optimal_f1_score = f1_max.compute()
-        >>> print(f"Optimal F1 Score: {f1_max_score}")
-        >>> print(f"Optimal Threshold: {f1_max.threshold}")
-
-    Note:
-        - Use `update` method to input predictions and target labels.
-        - Use `compute` method to calculate the maximum F1 score after all
-          updates.
-        - Use `reset` method to clear the current state and prepare for a new
-          set of calculations.
+        >>> f1_max.compute()
+        tensor(0.8000)
+        >>> f1_max.threshold
+        tensor(0.3500)
     """

     full_state_update: bool = False
@@ -75,19 +106,27 @@ def __init__(self, **kwargs) -> None:
         self.threshold: torch.Tensor

     def update(self, preds: torch.Tensor, target: torch.Tensor, *args, **kwargs) -> None:
-        """Update the precision-recall curve metric."""
+        """Update the precision-recall curve with new predictions and targets.
+
+        Args:
+            preds (torch.Tensor): Predicted scores or probabilities.
+            target (torch.Tensor): Ground truth binary labels.
+            *args: Additional positional arguments (unused).
+            **kwargs: Additional keyword arguments (unused).
+        """
         del args, kwargs  # These variables are not used.
         self.precision_recall_curve.update(preds, target)

     def compute(self) -> torch.Tensor:
-        """Compute the value of the optimal F1 score.
-
-        Compute the F1 scores while varying the threshold. Store the optimal
-        threshold as attribute and return the maximum value of the F1 score.
+        """Compute the maximum F1 score across all thresholds.
+
+        Computes F1 scores at different thresholds using the precision-recall
+        curve. Stores the threshold that achieves maximum F1 score in the
+        ``threshold`` attribute.

         Returns:
-            Value of the F1 score at the optimal threshold.
+            torch.Tensor: Maximum F1 score value.
         """
         precision: torch.Tensor
         recall: torch.Tensor
@@ -99,9 +138,30 @@ def compute(self) -> torch.Tensor:
         return torch.max(f1_score)

     def reset(self) -> None:
-        """Reset the metric."""
+        """Reset the metric state."""
         self.precision_recall_curve.reset()


 class F1Max(AnomalibMetric, _F1Max):  # type: ignore[misc]
-    """Wrapper to add AnomalibMetric functionality to F1Max metric."""
+    """Wrapper to add AnomalibMetric functionality to F1Max metric.
+
+    This class wraps the internal ``_F1Max`` metric to make it compatible with
+    Anomalib's batch processing capabilities.
+
+    Example:
+        >>> from anomalib.metrics import F1Max
+        >>> from anomalib.data import ImageBatch
+        >>> import torch
+        >>> # Create metric with batch fields
+        >>> f1_max = F1Max(fields=["pred_score", "gt_label"])
+        >>> # Create sample batch
+        >>> batch = ImageBatch(
+        ...     image=torch.rand(4, 3, 32, 32),
+        ...     pred_score=torch.tensor([0.1, 0.4, 0.35, 0.8]),
+        ...     gt_label=torch.tensor([0, 0, 1, 1])
+        ... )
+        >>> # Update and compute
+        >>> f1_max.update(batch)
+        >>> f1_max.compute()
+        tensor(0.8000)
+    """
diff --git a/src/anomalib/metrics/min_max.py b/src/anomalib/metrics/min_max.py
index 8456174ec9..4ca3e71245 100644
--- a/src/anomalib/metrics/min_max.py
+++ b/src/anomalib/metrics/min_max.py
@@ -1,4 +1,28 @@
-"""Module that tracks the min and max values of the observations in each batch."""
+"""Module that tracks the min and max values of the observations in each batch.
+ +This module provides the ``MinMax`` metric class which tracks the minimum and +maximum values seen across batches of data. This is useful for normalizing +predictions or monitoring value ranges during training. + +Example: + >>> from anomalib.metrics import MinMax + >>> import torch + >>> # Create sample predictions + >>> predictions = torch.tensor([0.0807, 0.6329, 0.0559, 0.9860, 0.3595]) + >>> # Initialize and compute min/max + >>> minmax = MinMax() + >>> min_val, max_val = minmax(predictions) + >>> min_val, max_val + (tensor(0.0559), tensor(0.9860)) + + The metric can be updated incrementally with new batches: + + >>> new_predictions = torch.tensor([0.3251, 0.3169, 0.3072, 0.6247, 0.9999]) + >>> minmax.update(new_predictions) + >>> min_val, max_val = minmax.compute() + >>> min_val, max_val + (tensor(0.0559), tensor(0.9999)) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -8,29 +32,35 @@ class MinMax(Metric): - """Track the min and max values of the observations in each batch. + """Track minimum and maximum values across batches. + + This metric maintains running minimum and maximum values across all batches + it processes. It is useful for tasks like normalization or monitoring the + range of values during training. Args: - full_state_update (bool, optional): Whether to update the state with the - new values. - Defaults to ``True``. - kwargs: Any keyword arguments. + full_state_update (bool, optional): Whether to update the internal state + with each new batch. Defaults to ``True``. + kwargs: Additional keyword arguments passed to the parent class. - Examples: + Attributes: + min (torch.Tensor): Running minimum value seen across all batches + max (torch.Tensor): Running maximum value seen across all batches + + Example: >>> from anomalib.metrics import MinMax >>> import torch - ... - >>> predictions = torch.tensor([0.0807, 0.6329, 0.0559, 0.9860, 0.3595]) + >>> # Create metric >>> minmax = MinMax() - >>> minmax(predictions) - (tensor(0.0559), tensor(0.9860)) - - It is possible to update the minmax values with a new tensor of predictions. - - >>> new_predictions = torch.tensor([0.3251, 0.3169, 0.3072, 0.6247, 0.9999]) - >>> minmax.update(new_predictions) - >>> minmax.compute() - (tensor(0.0559), tensor(0.9999)) + >>> # Update with batches + >>> batch1 = torch.tensor([0.1, 0.2, 0.3]) + >>> batch2 = torch.tensor([0.2, 0.4, 0.5]) + >>> minmax.update(batch1) + >>> minmax.update(batch2) + >>> # Get final min/max values + >>> min_val, max_val = minmax.compute() + >>> min_val, max_val + (tensor(0.1000), tensor(0.5000)) """ full_state_update: bool = True @@ -44,12 +74,24 @@ def __init__(self, **kwargs) -> None: self.max = torch.tensor(float("-inf")) def update(self, predictions: torch.Tensor, *args, **kwargs) -> None: - """Update the min and max values.""" + """Update running min and max values with new predictions. + + Args: + predictions (torch.Tensor): New tensor of values to include in min/max + tracking + *args: Additional positional arguments (unused) + **kwargs: Additional keyword arguments (unused) + """ del args, kwargs # These variables are not used. self.max = torch.max(self.max, torch.max(predictions)) self.min = torch.min(self.min, torch.min(predictions)) def compute(self) -> tuple[torch.Tensor, torch.Tensor]: - """Return min and max values.""" + """Compute final minimum and maximum values. 
+ + Returns: + tuple[torch.Tensor, torch.Tensor]: Tuple containing the (min, max) + values tracked across all batches + """ return self.min, self.max diff --git a/src/anomalib/metrics/pimo/__init__.py b/src/anomalib/metrics/pimo/__init__.py index 174f546e4d..f84a8da5d5 100644 --- a/src/anomalib/metrics/pimo/__init__.py +++ b/src/anomalib/metrics/pimo/__init__.py @@ -1,4 +1,22 @@ -"""Per-Image Metrics.""" +"""Per-Image Metrics for anomaly detection. + +This module provides metrics for evaluating anomaly detection performance on a +per-image basis. The metrics include: + +- ``PIMO``: Per-Image Metric Optimization for anomaly detection +- ``AUPIMO``: Area Under PIMO curve +- ``ThresholdMethod``: Methods for determining optimal thresholds +- ``PIMOResult``: Container for PIMO metric results +- ``AUPIMOResult``: Container for AUPIMO metric results + +The implementation is based on the original work from: +https://github.com/jpcbertoldo/aupimo + +Example: + >>> from anomalib.metrics.pimo import PIMO, AUPIMO + >>> pimo = PIMO() # doctest: +SKIP + >>> aupimo = AUPIMO() # doctest: +SKIP +""" # Original Code # https://github.com/jpcbertoldo/aupimo diff --git a/src/anomalib/metrics/pimo/_validate.py b/src/anomalib/metrics/pimo/_validate.py index f0ba7af4bf..4ca27ede6b 100644 --- a/src/anomalib/metrics/pimo/_validate.py +++ b/src/anomalib/metrics/pimo/_validate.py @@ -1,7 +1,26 @@ -"""Utils for validating arguments and results. +"""Utilities for validating arguments and results. -TODO(jpcbertoldo): Move validations to a common place and reuse them across the codebase. -https://github.com/openvinotoolkit/anomalib/issues/2093 +This module provides validation functions for various inputs and outputs used in +the PIMO metrics. The functions check for correct data types, shapes, ranges and +other constraints. + +The validation functions include: + +- Threshold validation (number, bounds, etc) +- Rate validation (ranges, curves, etc) +- Tensor validation (anomaly maps, masks, etc) +- Binary classification curve validation +- Score validation +- Ground truth validation + +TODO(jpcbertoldo): Move validations to a common place and reuse them across the +codebase. https://github.com/openvinotoolkit/anomalib/issues/2093 + +Example: + >>> from anomalib.metrics.pimo._validate import is_rate + >>> is_rate(0.5, zero_ok=True, one_ok=True) # No error + >>> is_rate(-0.1, zero_ok=True, one_ok=True) # Raises ValueError + ValueError: Expected rate to be in [0, 1], but got -0.1. """ # Original Code @@ -22,7 +41,20 @@ def is_num_thresholds_gte2(num_thresholds: int) -> None: - """Validate the number of thresholds is a positive integer >= 2.""" + """Validate that the number of thresholds is a positive integer >= 2. + + Args: + num_thresholds: Number of thresholds to validate. + + Raises: + TypeError: If ``num_thresholds`` is not an integer. + ValueError: If ``num_thresholds`` is less than 2. + + Example: + >>> is_num_thresholds_gte2(5) # No error + >>> is_num_thresholds_gte2(1) # Raises ValueError + ValueError: Expected the number of thresholds to be larger than 1, but got 1 + """ if not isinstance(num_thresholds, int): msg = f"Expected the number of thresholds to be an integer, but got {type(num_thresholds)}" raise TypeError(msg) @@ -33,7 +65,25 @@ def is_num_thresholds_gte2(num_thresholds: int) -> None: def is_same_shape(*args) -> None: - """Works both for tensors and ndarrays.""" + """Validate that all arguments have the same shape. + + Works for both tensors and ndarrays. 
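+
+    Shapes are compared as plain tuples, so mixed input is fine (a minimal
+    sketch):
+
+    >>> import numpy as np
+    >>> import torch
+    >>> is_same_shape(torch.zeros(2, 3), np.zeros((2, 3)))  # No error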
+ + Args: + *args: Variable number of tensors or ndarrays to compare shapes. + + Raises: + ValueError: If arguments have different shapes. + + Example: + >>> import torch + >>> t1 = torch.zeros(2, 3) + >>> t2 = torch.ones(2, 3) + >>> is_same_shape(t1, t2) # No error + >>> t3 = torch.zeros(3, 2) + >>> is_same_shape(t1, t3) # Raises ValueError + ValueError: Expected arguments to have the same shape, but got [(2, 3), (3, 2)] + """ assert len(args) > 0 shapes = sorted({tuple(arg.shape) for arg in args}) if len(shapes) > 1: @@ -42,12 +92,21 @@ def is_same_shape(*args) -> None: def is_rate(rate: float | int, zero_ok: bool, one_ok: bool) -> None: - """Validates a rate parameter. + """Validate a rate parameter. Args: - rate (float | int): The rate to be validated. - zero_ok (bool): Flag indicating if rate can be 0. - one_ok (bool): Flag indicating if rate can be 1. + rate: The rate value to validate. + zero_ok: Whether 0.0 is an acceptable value. + one_ok: Whether 1.0 is an acceptable value. + + Raises: + TypeError: If ``rate`` is not a float or int. + ValueError: If ``rate`` is outside [0,1] or equals 0/1 when not allowed. + + Example: + >>> is_rate(0.5, zero_ok=True, one_ok=True) # No error + >>> is_rate(0.0, zero_ok=False, one_ok=True) # Raises ValueError + ValueError: Rate cannot be 0. """ if not isinstance(rate, float | int): msg = f"Expected rate to be a float or int, but got {type(rate)}." @@ -67,10 +126,19 @@ def is_rate(rate: float | int, zero_ok: bool, one_ok: bool) -> None: def is_rate_range(bounds: tuple[float, float]) -> None: - """Validates the range of rates within the bounds. + """Validate that rate bounds form a valid range. Args: - bounds (tuple[float, float]): The lower and upper bounds of the rates. + bounds: Tuple of (lower, upper) rate bounds. + + Raises: + TypeError: If ``bounds`` is not a tuple of length 2. + ValueError: If bounds are invalid or lower >= upper. + + Example: + >>> is_rate_range((0.1, 0.9)) # No error + >>> is_rate_range((0.9, 0.1)) # Raises ValueError + ValueError: Expected the upper bound to be larger than the lower bound """ if not isinstance(bounds, tuple): msg = f"Expected the bounds to be a tuple, but got {type(bounds)}" @@ -90,7 +158,22 @@ def is_rate_range(bounds: tuple[float, float]) -> None: def is_valid_threshold(thresholds: Tensor) -> None: - """Validate that the thresholds are valid and monotonically increasing.""" + """Validate that thresholds are valid and monotonically increasing. + + Args: + thresholds: Tensor of threshold values. + + Raises: + TypeError: If ``thresholds`` is not a floating point Tensor. + ValueError: If ``thresholds`` is not 1D or not strictly increasing. + + Example: + >>> thresholds = torch.tensor([0.1, 0.2, 0.3]) + >>> is_valid_threshold(thresholds) # No error + >>> bad_thresholds = torch.tensor([0.3, 0.2, 0.1]) + >>> is_valid_threshold(bad_thresholds) # Raises ValueError + ValueError: Expected thresholds to be strictly increasing + """ if not isinstance(thresholds, Tensor): msg = f"Expected thresholds to be an Tensor, but got {type(thresholds)}" raise TypeError(msg) @@ -110,6 +193,20 @@ def is_valid_threshold(thresholds: Tensor) -> None: def validate_threshold_bounds(threshold_bounds: tuple[float, float]) -> None: + """Validate threshold bounds form a valid range. + + Args: + threshold_bounds: Tuple of (lower, upper) threshold bounds. + + Raises: + TypeError: If bounds are not floats or not a tuple of length 2. + ValueError: If upper <= lower. 
+ + Example: + >>> validate_threshold_bounds((0.1, 0.9)) # No error + >>> validate_threshold_bounds((0.9, 0.1)) # Raises ValueError + ValueError: Expected the upper bound to be greater than the lower bound + """ if not isinstance(threshold_bounds, tuple): msg = f"Expected threshold bounds to be a tuple, but got {type(threshold_bounds)}." raise TypeError(msg) @@ -134,6 +231,21 @@ def validate_threshold_bounds(threshold_bounds: tuple[float, float]) -> None: def is_anomaly_maps(anomaly_maps: Tensor) -> None: + """Validate anomaly maps tensor. + + Args: + anomaly_maps: Tensor of shape (N, H, W) containing anomaly scores. + + Raises: + ValueError: If tensor does not have 3 dimensions. + TypeError: If tensor is not floating point. + + Example: + >>> maps = torch.randn(10, 32, 32) + >>> is_anomaly_maps(maps) # No error + >>> bad_maps = torch.zeros(10, 32, 32, dtype=torch.long) + >>> is_anomaly_maps(bad_maps) # Raises TypeError + """ if anomaly_maps.ndim != 3: msg = f"Expected anomaly maps have 3 dimensions (N, H, W), but got {anomaly_maps.ndim} dimensions" raise ValueError(msg) @@ -147,6 +259,22 @@ def is_anomaly_maps(anomaly_maps: Tensor) -> None: def is_masks(masks: Tensor) -> None: + """Validate ground truth mask tensor. + + Args: + masks: Binary tensor of shape (N, H, W) containing ground truth labels. + + Raises: + ValueError: If tensor does not have 3 dimensions or contains non-binary + values. + TypeError: If tensor has invalid dtype. + + Example: + >>> masks = torch.zeros(10, 32, 32, dtype=torch.bool) + >>> is_masks(masks) # No error + >>> bad_masks = torch.ones(10, 32, 32) * 2 + >>> is_masks(bad_masks) # Raises ValueError + """ if masks.ndim != 3: msg = f"Expected masks have 3 dimensions (N, H, W), but got {masks.ndim} dimensions" raise ValueError(msg) @@ -155,8 +283,8 @@ def is_masks(masks: Tensor) -> None: pass elif masks.dtype.is_floating_point: msg = ( - "Expected masks to be an integer or boolean Tensor with ground truth labels, " - f"but got Tensor with dtype {masks.dtype}" + "Expected masks to be an integer or boolean Tensor with ground truth " + f"labels, but got Tensor with dtype {masks.dtype}" ) raise TypeError(msg) else: @@ -165,13 +293,35 @@ def is_masks(masks: Tensor) -> None: masks_unique_vals = torch.unique(masks) if torch.any((masks_unique_vals != 0) & (masks_unique_vals != 1)): msg = ( - "Expected masks to be a *binary* Tensor with ground truth labels, " - f"but got Tensor with unique values {sorted(masks_unique_vals)}" + "Expected masks to be a *binary* Tensor with ground truth " + f"labels, but got Tensor with unique values " + f"{sorted(masks_unique_vals)}" ) raise ValueError(msg) -def is_binclf_curves(binclf_curves: Tensor, valid_thresholds: Tensor | None) -> None: +def is_binclf_curves( + binclf_curves: Tensor, + valid_thresholds: Tensor | None, +) -> None: + """Validate binary classification curves tensor. + + Args: + binclf_curves: Tensor of shape (N, T, 2, 2) containing confusion matrices + for N images and T thresholds. + valid_thresholds: Optional tensor of T threshold values. + + Raises: + ValueError: If tensor has wrong shape or invalid values. + TypeError: If tensor has wrong dtype. + RuntimeError: If number of thresholds doesn't match. 
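+
+    Note:
+        The last two axes hold the confusion matrix in ``[[TN, FP], [FN, TP]]``
+        layout (true class first, then predicted class).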
+ + Example: + >>> curves = torch.zeros(10, 5, 2, 2, dtype=torch.int64) + >>> is_binclf_curves(curves, None) # No error + >>> bad_curves = torch.zeros(10, 5, 3, 2, dtype=torch.int64) + >>> is_binclf_curves(bad_curves, None) # Raises ValueError + """ if binclf_curves.ndim != 4: msg = f"Expected binclf curves to be 4D, but got {binclf_curves.ndim}D" raise ValueError(msg) @@ -188,13 +338,13 @@ def is_binclf_curves(binclf_curves: Tensor, valid_thresholds: Tensor | None) -> msg = "Expected binclf curves to have non-negative values, but got negative values." raise ValueError(msg) - neg = binclf_curves[:, :, 0, :].sum(axis=-1) # (num_images, num_thresholds) + neg = binclf_curves[:, :, 0, :].sum(dim=-1) # (num_images, num_thresholds) if (neg != neg[:, :1]).any(): msg = "Expected binclf curves to have the same number of negatives per image for every thresh." raise ValueError(msg) - pos = binclf_curves[:, :, 1, :].sum(axis=-1) # (num_images, num_thresholds) + pos = binclf_curves[:, :, 1, :].sum(dim=-1) # (num_images, num_thresholds) if (pos != pos[:, :1]).any(): msg = "Expected binclf curves to have the same number of positives per image for every thresh." @@ -205,13 +355,29 @@ def is_binclf_curves(binclf_curves: Tensor, valid_thresholds: Tensor | None) -> if binclf_curves.shape[1] != valid_thresholds.shape[0]: msg = ( - "Expected the binclf curves to have as many confusion matrices as the thresholds sequence, " - f"but got {binclf_curves.shape[1]} and {valid_thresholds.shape[0]}" + "Expected the binclf curves to have as many confusion matrices as " + f"the thresholds sequence, but got {binclf_curves.shape[1]} and " + f"{valid_thresholds.shape[0]}" ) raise RuntimeError(msg) def is_images_classes(images_classes: Tensor) -> None: + """Validate image-level ground truth labels tensor. + + Args: + images_classes: Binary tensor of shape (N,) containing image labels. + + Raises: + ValueError: If tensor is not 1D or contains non-binary values. + TypeError: If tensor has invalid dtype. + + Example: + >>> classes = torch.zeros(10, dtype=torch.bool) + >>> is_images_classes(classes) # No error + >>> bad_classes = torch.ones(10) * 2 + >>> is_images_classes(bad_classes) # Raises ValueError + """ if images_classes.ndim != 1: msg = f"Expected image classes to be 1D, but got {images_classes.ndim}D." raise ValueError(msg) @@ -220,8 +386,9 @@ def is_images_classes(images_classes: Tensor) -> None: pass elif images_classes.dtype.is_floating_point: msg = ( - "Expected image classes to be an integer or boolean Tensor with ground truth labels, " - f"but got Tensor with dtype {images_classes.dtype}" + "Expected image classes to be an integer or boolean Tensor with " + f"ground truth labels, but got Tensor with dtype " + f"{images_classes.dtype}" ) raise TypeError(msg) else: @@ -230,20 +397,38 @@ def is_images_classes(images_classes: Tensor) -> None: unique_vals = torch.unique(images_classes) if torch.any((unique_vals != 0) & (unique_vals != 1)): msg = ( - "Expected image classes to be a *binary* Tensor with ground truth labels, " - f"but got Tensor with unique values {sorted(unique_vals)}" + "Expected image classes to be a *binary* Tensor with ground " + f"truth labels, but got Tensor with unique values " + f"{sorted(unique_vals)}" ) raise ValueError(msg) def is_rates(rates: Tensor, nan_allowed: bool) -> None: + """Validate rates tensor. + + Args: + rates: Tensor of shape (N,) containing rate values in [0,1]. + nan_allowed: Whether NaN values are allowed. 
+ + Raises: + ValueError: If tensor is not 1D, contains values outside [0,1], or has + NaN when not allowed. + TypeError: If tensor is not floating point. + + Example: + >>> rates = torch.tensor([0.1, 0.5, 0.9]) + >>> is_rates(rates, nan_allowed=False) # No error + >>> bad_rates = torch.tensor([0.1, float('nan'), 0.9]) + >>> is_rates(bad_rates, nan_allowed=False) # Raises ValueError + """ if rates.ndim != 1: msg = f"Expected rates to be 1D, but got {rates.ndim}D." raise ValueError(msg) if not rates.dtype.is_floating_point: msg = f"Expected rates to have dtype of float type, but got {rates.dtype}." - raise ValueError(msg) + raise TypeError(msg) isnan_mask = torch.isnan(rates) if nan_allowed: @@ -266,7 +451,28 @@ def is_rates(rates: Tensor, nan_allowed: bool) -> None: raise ValueError(msg) -def is_rate_curve(rate_curve: Tensor, nan_allowed: bool, decreasing: bool) -> None: +def is_rate_curve( + rate_curve: Tensor, + nan_allowed: bool, + decreasing: bool, +) -> None: + """Validate rate curve tensor. + + Args: + rate_curve: Tensor of shape (N,) containing rate values. + nan_allowed: Whether NaN values are allowed. + decreasing: Whether curve should be monotonically decreasing. + + Raises: + ValueError: If curve is not monotonic in specified direction. + + Example: + >>> curve = torch.tensor([0.9, 0.5, 0.1]) + >>> is_rate_curve(curve, nan_allowed=False, decreasing=True) # No error + >>> bad_curve = torch.tensor([0.1, 0.5, 0.9]) + >>> is_rate_curve(bad_curve, nan_allowed=False, decreasing=True) + ValueError: Expected rate curve to be monotonically decreasing + """ is_rates(rate_curve, nan_allowed=nan_allowed) diffs = torch.diff(rate_curve) @@ -281,14 +487,34 @@ def is_rate_curve(rate_curve: Tensor, nan_allowed: bool, decreasing: bool) -> No raise ValueError(msg) -def is_per_image_rate_curves(rate_curves: Tensor, nan_allowed: bool, decreasing: bool | None) -> None: +def is_per_image_rate_curves( + rate_curves: Tensor, + nan_allowed: bool, + decreasing: bool | None, +) -> None: + """Validate per-image rate curves tensor. + + Args: + rate_curves: Tensor of shape (N, T) containing rate curves for N images. + nan_allowed: Whether NaN values are allowed. + decreasing: Whether curves should be monotonically decreasing. + + Raises: + ValueError: If curves have invalid values or wrong monotonicity. + TypeError: If tensor has wrong dtype. + + Example: + >>> curves = torch.zeros(10, 5) # 10 images, 5 thresholds + >>> is_per_image_rate_curves(curves, nan_allowed=False, decreasing=None) + >>> # No error + """ if rate_curves.ndim != 2: msg = f"Expected per-image rate curves to be 2D, but got {rate_curves.ndim}D." raise ValueError(msg) if not rate_curves.dtype.is_floating_point: msg = f"Expected per-image rate curves to have dtype of float type, but got {rate_curves.dtype}." - raise ValueError(msg) + raise TypeError(msg) isnan_mask = torch.isnan(rate_curves) if nan_allowed: @@ -313,7 +539,7 @@ def is_per_image_rate_curves(rate_curves: Tensor, nan_allowed: bool, decreasing: if decreasing is None: return - diffs = torch.diff(rate_curves, axis=1) + diffs = torch.diff(rate_curves, dim=1) diffs_valid = diffs[~torch.isnan(diffs)] if nan_allowed else diffs if decreasing and (diffs_valid > 0).any(): @@ -332,15 +558,30 @@ def is_per_image_rate_curves(rate_curves: Tensor, nan_allowed: bool, decreasing: def is_scores_batch(scores_batch: torch.Tensor) -> None: - """scores_batch (torch.Tensor): floating (N, D).""" + """Validate batch of anomaly scores. + + Args: + scores_batch: Floating point tensor of shape (N, D). 
+
+    Raises:
+        TypeError: If tensor is not floating point.
+        ValueError: If tensor is not 2D.
+
+    Example:
+        >>> scores = torch.randn(10, 5)  # 10 samples, 5 features
+        >>> is_scores_batch(scores)  # No error
+        >>> bad_scores = torch.randn(10)  # 1D tensor
+        >>> is_scores_batch(bad_scores)  # Raises ValueError
+    """
     if not isinstance(scores_batch, torch.Tensor):
         msg = f"Expected `scores_batch` to be an torch.Tensor, but got {type(scores_batch)}"
         raise TypeError(msg)

     if not scores_batch.dtype.is_floating_point:
         msg = (
-            "Expected `scores_batch` to be an floating torch.Tensor with anomaly scores_batch,"
-            f" but got torch.Tensor with dtype {scores_batch.dtype}"
+            "Expected `scores_batch` to be a floating torch.Tensor with "
+            f"anomaly scores, but got torch.Tensor with dtype "
+            f"{scores_batch.dtype}"
         )
         raise TypeError(msg)
@@ -350,15 +591,29 @@ def is_scores_batch(scores_batch: torch.Tensor) -> None:

 def is_gts_batch(gts_batch: torch.Tensor) -> None:
-    """gts_batch (torch.Tensor): boolean (N, D)."""
+    """Validate batch of ground truth labels.
+
+    Args:
+        gts_batch: Boolean tensor of shape (N, D).
+
+    Raises:
+        TypeError: If tensor is not boolean.
+        ValueError: If tensor is not 2D.
+
+    Example:
+        >>> gts = torch.zeros(10, 5, dtype=torch.bool)
+        >>> is_gts_batch(gts)  # No error
+        >>> bad_gts = torch.zeros(10, dtype=torch.bool)
+        >>> is_gts_batch(bad_gts)  # Raises ValueError
+    """
     if not isinstance(gts_batch, torch.Tensor):
         msg = f"Expected `gts_batch` to be an torch.Tensor, but got {type(gts_batch)}"
         raise TypeError(msg)

     if gts_batch.dtype != torch.bool:
         msg = (
-            "Expected `gts_batch` to be an boolean torch.Tensor with anomaly scores_batch,"
-            f" but got torch.Tensor with dtype {gts_batch.dtype}"
+            "Expected `gts_batch` to be a boolean torch.Tensor with ground "
+            f"truth labels, but got torch.Tensor with dtype {gts_batch.dtype}"
         )
         raise TypeError(msg)
@@ -368,6 +623,20 @@ def is_gts_batch(gts_batch: torch.Tensor) -> None:

 def has_at_least_one_anomalous_image(masks: torch.Tensor) -> None:
+    """Validate presence of at least one anomalous image.
+
+    Args:
+        masks: Binary tensor of shape (N, H, W) containing ground truth masks.
+
+    Raises:
+        ValueError: If no anomalous images are found.
+
+    Example:
+        >>> masks = torch.ones(10, 32, 32, dtype=torch.bool)  # All anomalous
+        >>> has_at_least_one_anomalous_image(masks)  # No error
+        >>> normal_masks = torch.zeros(10, 32, 32, dtype=torch.bool)
+        >>> has_at_least_one_anomalous_image(normal_masks)  # Raises ValueError
+    """
     is_masks(masks)
     image_classes = images_classes_from_masks(masks)
     if (image_classes == 1).sum() == 0:
@@ -376,6 +645,20 @@ def has_at_least_one_anomalous_image(masks: torch.Tensor) -> None:

 def has_at_least_one_normal_image(masks: torch.Tensor) -> None:
+    """Validate presence of at least one normal image.
+
+    Args:
+        masks: Binary tensor of shape (N, H, W) containing ground truth masks.
+
+    Raises:
+        ValueError: If no normal images are found.
+ + Example: + >>> masks = torch.zeros(10, 32, 32, dtype=torch.bool) # All normal + >>> has_at_least_one_normal_image(masks) # No error + >>> anomalous_masks = torch.ones(10, 32, 32, dtype=torch.bool) + >>> has_at_least_one_normal_image(anomalous_masks) # Raises ValueError + """ is_masks(masks) image_classes = images_classes_from_masks(masks) if (image_classes == 0).sum() == 0: @@ -383,16 +666,53 @@ def has_at_least_one_normal_image(masks: torch.Tensor) -> None: raise ValueError(msg) -def joint_validate_thresholds_shared_fpr(thresholds: torch.Tensor, shared_fpr: torch.Tensor) -> None: +def joint_validate_thresholds_shared_fpr( + thresholds: torch.Tensor, + shared_fpr: torch.Tensor, +) -> None: + """Validate matching dimensions between thresholds and shared FPR. + + Args: + thresholds: Tensor of threshold values. + shared_fpr: Tensor of shared false positive rates. + + Raises: + ValueError: If tensors have different lengths. + + Example: + >>> t = torch.linspace(0, 1, 5) + >>> fpr = torch.zeros(5) + >>> joint_validate_thresholds_shared_fpr(t, fpr) # No error + >>> bad_fpr = torch.zeros(4) + >>> joint_validate_thresholds_shared_fpr(t, bad_fpr) # Raises ValueError + """ if thresholds.shape[0] != shared_fpr.shape[0]: msg = ( - "Expected `thresholds` and `shared_fpr` to have the same number of elements, " - f"but got {thresholds.shape[0]} != {shared_fpr.shape[0]}" + "Expected `thresholds` and `shared_fpr` to have the same number of " + f"elements, but got {thresholds.shape[0]} != {shared_fpr.shape[0]}" ) raise ValueError(msg) -def is_per_image_tprs(per_image_tprs: torch.Tensor, image_classes: torch.Tensor) -> None: +def is_per_image_tprs( + per_image_tprs: torch.Tensor, + image_classes: torch.Tensor, +) -> None: + """Validate per-image true positive rates. + + Args: + per_image_tprs: Tensor of TPR values for each image. + image_classes: Binary tensor indicating normal (0) or anomalous (1) + images. + + Raises: + ValueError: If TPRs have invalid values or wrong monotonicity. + + Example: + >>> tprs = torch.zeros(10, 5) # 10 images, 5 thresholds + >>> classes = torch.zeros(10, dtype=torch.bool) + >>> is_per_image_tprs(tprs, classes) # No error + """ is_images_classes(image_classes) # general validations is_per_image_rate_curves( diff --git a/src/anomalib/metrics/pimo/binary_classification_curve.py b/src/anomalib/metrics/pimo/binary_classification_curve.py index 1a80944041..063090c1a7 100644 --- a/src/anomalib/metrics/pimo/binary_classification_curve.py +++ b/src/anomalib/metrics/pimo/binary_classification_curve.py @@ -1,8 +1,26 @@ """Binary classification curve (numpy-only implementation). -A binary classification (binclf) matrix (TP, FP, FN, TN) is evaluated at multiple thresholds. - -The thresholds are shared by all instances/images, but their binclf are computed independently for each instance/image. +This module provides functionality to compute binary classification matrices at +multiple thresholds. The thresholds are shared across all instances/images, but +binary classification metrics are computed independently for each instance/image. + +The binary classification matrix contains: +- True Positives (TP) +- False Positives (FP) +- False Negatives (FN) +- True Negatives (TN) + +Example: + >>> import torch + >>> from anomalib.metrics.pimo.binary_classification_curve import ( + ... binary_classification_curve + ... 
) + >>> scores = torch.rand(10, 100) # 10 images, 100 pixels each + >>> gts = torch.randint(0, 2, (10, 100)).bool() # Binary ground truth + >>> thresholds = torch.linspace(0, 1, 10) # 10 thresholds + >>> curves = binary_classification_curve(scores, gts, thresholds) + >>> curves.shape + torch.Size([10, 10, 2, 2]) """ # Original Code @@ -26,31 +44,36 @@ class ThresholdMethod(Enum): - """Sequence of thresholds to use.""" + """Methods for selecting threshold sequences. + + Available methods: + - ``GIVEN``: Use provided thresholds + - ``MINMAX_LINSPACE``: Linear spacing between min and max scores + - ``MEAN_FPR_OPTIMIZED``: Optimize based on mean false positive rate + """ - GIVEN: str = "given" - MINMAX_LINSPACE: str = "minmax-linspace" - MEAN_FPR_OPTIMIZED: str = "mean-fpr-optimized" + GIVEN = "given" + MINMAX_LINSPACE = "minmax-linspace" + MEAN_FPR_OPTIMIZED = "mean-fpr-optimized" def _binary_classification_curve(scores: np.ndarray, gts: np.ndarray, thresholds: np.ndarray) -> np.ndarray: - """One binary classification matrix at each threshold. + """Compute binary classification matrices at multiple thresholds. + + This implementation is optimized for CPU performance compared to torchmetrics + alternatives when using pre-defined thresholds. - In the case where the thresholds are given (i.e. not considering all possible thresholds based on the scores), - this weird-looking function is faster than the two options in `torchmetrics` on the CPU: - - `_binary_precision_recall_curve_update_vectorized` - - `_binary_precision_recall_curve_update_loop` - (both in module `torchmetrics.functional.classification.precision_recall_curve` in `torchmetrics==1.1.0`). - Note: VALIDATION IS NOT DONE HERE. Make sure to validate the arguments before calling this function. + Note: + Arguments must be validated before calling this function. Args: - scores (np.ndarray): Anomaly scores (D,). - gts (np.ndarray): Binary (bool) ground truth of shape (D,). - thresholds (np.ndarray): Sequence of thresholds in ascending order (K,). + scores: Anomaly scores of shape ``(D,)`` + gts: Binary ground truth of shape ``(D,)`` + thresholds: Sequence of thresholds in ascending order ``(K,)`` Returns: - np.ndarray: Binary classification matrix curve (K, 2, 2) - Details: `anomalib.metrics.per_image.binclf_curve_numpy.binclf_multiple_curves`. 
+ Binary classification matrix curve of shape ``(K, 2, 2)`` + containing TP, FP, FN, TN counts at each threshold """ num_th = len(thresholds) @@ -58,7 +81,8 @@ def _binary_classification_curve(scores: np.ndarray, gts: np.ndarray, thresholds scores_positives = scores[gts] # the sorting is very important for the algorithm to work and the speedup scores_positives = np.sort(scores_positives) - # variable updated in the loop; start counting with lowest thresh ==> everything is predicted as positive + # variable updated in the loop; start counting with lowest thresh ==> + # everything is predicted as positive num_pos = current_count_tp = scores_positives.size tps = np.empty((num_th,), dtype=np.int64) @@ -92,7 +116,7 @@ def score_less_than_thresh(score: float, thresh: float) -> bool: fns = num_pos * np.ones((num_th,), dtype=np.int64) - tps tns = num_neg * np.ones((num_th,), dtype=np.int64) - fps - # sequence of dimensions is (thresholds, true class, predicted class) (see docstring) + # sequence of dimensions is (thresholds, true class, predicted class) return np.stack( [ np.stack([tns, fps], axis=-1), @@ -107,48 +131,45 @@ def binary_classification_curve( gts_batch: torch.Tensor, thresholds: torch.Tensor, ) -> torch.Tensor: - """Returns a binary classification matrix at each threshold for each image in the batch. + """Compute binary classification matrices for a batch of images. - This is a wrapper around `_binary_classification_curve`. - Validation of the arguments is done here (not in the actual implementation functions). + This is a wrapper around :func:`_binary_classification_curve` that handles + input validation and batching. - Note: predicted as positive condition is `score >= thresh`. + Note: + Predicted positives are determined by ``score >= thresh`` Args: - scores_batch (torch.Tensor): Anomaly scores (N, D,). - gts_batch (torch.Tensor): Binary (bool) ground truth of shape (N, D,). - thresholds (torch.Tensor): Sequence of thresholds in ascending order (K,). + scores_batch: Anomaly scores of shape ``(N, D)`` + gts_batch: Binary ground truth of shape ``(N, D)`` + thresholds: Sequence of thresholds in ascending order ``(K,)`` Returns: - torch.Tensor: Binary classification matrix curves (N, K, 2, 2) - - The last two dimensions are the confusion matrix (ground truth, predictions) - So for each thresh it gives: - - `tp`: `[... , 1, 1]` - - `fp`: `[... , 0, 1]` - - `fn`: `[... , 1, 0]` - - `tn`: `[... , 0, 0]` - - `t` is for `true` and `f` is for `false`, `p` is for `positive` and `n` is for `negative`, so: - - `tp` stands for `true positive` - - `fp` stands for `false positive` - - `fn` stands for `false negative` - - `tn` stands for `true negative` - - The numbers in each confusion matrix are the counts (not the ratios). - - Counts are relative to each instance (i.e. from 0 to D, e.g. the total is the number of pixels in the image). - - Thresholds are shared across all instances, so all confusion matrices, for instance, - at position [:, 0, :, :] are relative to the 1st threshold in `thresholds`. - - Thresholds are sorted in ascending order. + Binary classification matrix curves of shape ``(N, K, 2, 2)`` + where: + + - ``[..., 1, 1]``: True Positives (TP) + - ``[..., 0, 1]``: False Positives (FP) + - ``[..., 1, 0]``: False Negatives (FN) + - ``[..., 0, 0]``: True Negatives (TN) + + The counts are per-instance (e.g. number of pixels in each image). + Thresholds are shared across instances. 
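+
+        As a tiny hand-checkable sketch: one image with scores ``[0.2, 0.8]``,
+        ground truth ``[False, True]`` and the single threshold ``0.5``
+        (predicted positive means ``score >= thresh``) gives TP=1, FP=0, FN=0,
+        TN=1, i.e. the matrix ``[[1, 0], [0, 1]]``.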
+ + Example: + >>> scores = torch.rand(10, 100) # 10 images, 100 pixels each + >>> gts = torch.randint(0, 2, (10, 100)).bool() + >>> thresholds = torch.linspace(0, 1, 10) + >>> curves = binary_classification_curve(scores, gts, thresholds) + >>> curves.shape + torch.Size([10, 10, 2, 2]) """ _validate.is_scores_batch(scores_batch) _validate.is_gts_batch(gts_batch) _validate.is_same_shape(scores_batch, gts_batch) _validate.is_valid_threshold(thresholds) - # TODO(ashwinvaidya17): this is kept as numpy for now because it is much faster. + # TODO(ashwinvaidya17): this is kept as numpy for now because it is much + # faster. # TEMP-0 result = np.vectorize(_binary_classification_curve, signature="(n),(n),(k)->(k,2,2)")( scores_batch.detach().cpu().numpy(), @@ -159,7 +180,18 @@ def binary_classification_curve( def _get_linspaced_thresholds(anomaly_maps: torch.Tensor, num_thresholds: int) -> torch.Tensor: - """Get thresholds linearly spaced between the min and max of the anomaly maps.""" + """Get linearly spaced thresholds between min and max anomaly scores. + + Args: + anomaly_maps: Anomaly score maps + num_thresholds: Number of thresholds to generate + + Returns: + Linearly spaced thresholds of shape ``(num_thresholds,)`` + + Raises: + ValueError: If threshold bounds are invalid + """ _validate.is_num_thresholds_gte2(num_thresholds) # this operation can be a bit expensive thresh_low, thresh_high = thresh_bounds = (anomaly_maps.min().item(), anomaly_maps.max().item()) @@ -178,45 +210,42 @@ def threshold_and_binary_classification_curve( thresholds: torch.Tensor | None = None, num_thresholds: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """Return thresholds and binary classification matrix at each threshold for each image in the batch. + """Get thresholds and binary classification matrices for a batch of images. Args: - anomaly_maps (torch.Tensor): Anomaly score maps of shape (N, H, W) - masks (torch.Tensor): Binary ground truth masks of shape (N, H, W) - threshold_choice (str, optional): Sequence of thresholds to use. Defaults to THRESH_SEQUENCE_MINMAX_LINSPACE. - thresholds (torch.Tensor, optional): Sequence of thresholds to use. - Only applicable when threshold_choice is THRESH_SEQUENCE_GIVEN. - num_thresholds (int, optional): Number of thresholds between the min and max of the anomaly maps. - Only applicable when threshold_choice is THRESH_SEQUENCE_MINMAX_LINSPACE. + anomaly_maps: Anomaly score maps of shape ``(N, H, W)`` + masks: Binary ground truth masks of shape ``(N, H, W)`` + threshold_choice: Method for selecting thresholds. Defaults to + ``MINMAX_LINSPACE`` + thresholds: Sequence of thresholds to use. Only used when + ``threshold_choice`` is ``GIVEN`` + num_thresholds: Number of thresholds between min and max scores. Only + used when ``threshold_choice`` is ``MINMAX_LINSPACE`` Returns: - tuple[torch.Tensor, torch.Tensor]: - [0] Thresholds of shape (K,) and dtype is the same as `anomaly_maps.dtype`. - - [1] Binary classification matrices of shape (N, K, 2, 2) - - N: number of images/instances - K: number of thresholds - - The last two dimensions are the confusion matrix (ground truth, predictions) - So for each thresh it gives: - - `tp`: `[... , 1, 1]` - - `fp`: `[... , 0, 1]` - - `fn`: `[... , 1, 0]` - - `tn`: `[... 
, 0, 0]` - - `t` is for `true` and `f` is for `false`, `p` is for `positive` and `n` is for `negative`, so: - - `tp` stands for `true positive` - - `fp` stands for `false positive` - - `fn` stands for `false negative` - - `tn` stands for `true negative` - - The numbers in each confusion matrix are the counts of pixels in the image (not the ratios). - - Thresholds are shared across all images, so all confusion matrices, for instance, - at position [:, 0, :, :] are relative to the 1st threshold in `thresholds`. - - Thresholds are sorted in ascending order. + Tuple containing: + + - Thresholds of shape ``(K,)`` with same dtype as ``anomaly_maps`` + - Binary classification matrices of shape ``(N, K, 2, 2)`` where: + + - ``[..., 1, 1]``: True Positives (TP) + - ``[..., 0, 1]``: False Positives (FP) + - ``[..., 1, 0]``: False Negatives (FN) + - ``[..., 0, 0]``: True Negatives (TN) + + The counts are per-instance pixel counts. Thresholds are shared across + instances and sorted in ascending order. + + Example: + >>> maps = torch.rand(10, 32, 32) # 10 images + >>> masks = torch.randint(0, 2, (10, 32, 32)).bool() + >>> thresh, curves = threshold_and_binary_classification_curve( + ... maps, + ... masks, + ... num_thresholds=10, + ... ) + >>> thresh.shape, curves.shape + (torch.Size([10]), torch.Size([10, 10, 2, 2])) """ threshold_choice = ThresholdMethod(threshold_choice) _validate.is_anomaly_maps(anomaly_maps) @@ -255,7 +284,7 @@ def threshold_and_binary_classification_curve( # keep the batch dimension and flatten the rest scores_batch = anomaly_maps.reshape(anomaly_maps.shape[0], -1) - gts_batch = masks.reshape(masks.shape[0], -1).to(bool) # make sure it is boolean + gts_batch = masks.reshape(masks.shape[0], -1).to(dtype=torch.bool) binclf_curves = binary_classification_curve(scores_batch, gts_batch, thresholds) @@ -264,8 +293,9 @@ def threshold_and_binary_classification_curve( try: _validate.is_binclf_curves(binclf_curves, valid_thresholds=thresholds) - # these two validations cannot be done in `_validate.binclf_curves` because it does not have access to the - # original shapes of `anomaly_maps` + # these two validations cannot be done in `_validate.binclf_curves` + # because it does not have access to the original shapes of + # `anomaly_maps` if binclf_curves.shape[0] != num_images: msg = ( "Expected `binclf_curves` to have the same number of images as `anomaly_maps`, " @@ -281,54 +311,72 @@ def threshold_and_binary_classification_curve( def per_image_tpr(binclf_curves: torch.Tensor) -> torch.Tensor: - """True positive rates (TPR) for image for each thresh. + """Compute True Positive Rate (TPR) for each image at each threshold. TPR = TP / P = TP / (TP + FN) - TP: true positives - FM: false negatives - P: positives (TP + FN) + Where: + - TP: True Positives + - FN: False Negatives + - P: Total Positives (TP + FN) Args: - binclf_curves (torch.Tensor): Binary classification matrix curves (N, K, 2, 2). See `per_image_binclf_curve`. + binclf_curves: Binary classification curves of shape ``(N, K, 2, 2)`` + See :func:`binary_classification_curve` Returns: - torch.Tensor: shape (N, K), dtype float64 - N: number of images - K: number of thresholds - - Thresholds are sorted in ascending order, so TPR is in descending order. + TPR values of shape ``(N, K)`` and dtype ``float64`` where: + - N: number of images + - K: number of thresholds + + TPR is in descending order since thresholds are sorted ascending. + TPR will be NaN for normal images (P = 0). 
+ + Example: + >>> curves = torch.randint(0, 10, (5, 10, 2, 2)) # 5 imgs, 10 thresh + >>> tpr = per_image_tpr(curves) + >>> tpr.shape + torch.Size([5, 10]) """ # shape: (num images, num thresholds) tps = binclf_curves[..., 1, 1] - pos = binclf_curves[..., 1, :].sum(axis=2) # 2 was the 3 originally + pos = binclf_curves[..., 1, :].sum(dim=2) # tprs will be nan if pos == 0 (normal image), which is expected return tps.to(torch.float64) / pos.to(torch.float64) def per_image_fpr(binclf_curves: torch.Tensor) -> torch.Tensor: - """False positive rates (TPR) for image for each thresh. + """Compute False Positive Rate (FPR) for each image at each threshold. FPR = FP / N = FP / (FP + TN) - FP: false positives - TN: true negatives - N: negatives (FP + TN) + Where: + - FP: False Positives + - TN: True Negatives + - N: Total Negatives (FP + TN) Args: - binclf_curves (torch.Tensor): Binary classification matrix curves (N, K, 2, 2). See `per_image_binclf_curve`. + binclf_curves: Binary classification curves of shape ``(N, K, 2, 2)`` + See :func:`binary_classification_curve` Returns: - torch.Tensor: shape (N, K), dtype float64 - N: number of images - K: number of thresholds - - Thresholds are sorted in ascending order, so FPR is in descending order. + FPR values of shape ``(N, K)`` and dtype ``float64`` where: + - N: number of images + - K: number of thresholds + + FPR is in descending order since thresholds are sorted ascending. + FPR will be NaN for fully anomalous images (N = 0). + + Example: + >>> curves = torch.randint(0, 10, (5, 10, 2, 2)) # 5 imgs, 10 thresh + >>> fpr = per_image_fpr(curves) + >>> fpr.shape + torch.Size([5, 10]) """ # shape: (num images, num thresholds) fps = binclf_curves[..., 0, 1] - neg = binclf_curves[..., 0, :].sum(axis=2) # 2 was the 3 originally + neg = binclf_curves[..., 0, :].sum(dim=2) # it can be `nan` if an anomalous image is fully covered by the mask return fps.to(torch.float64) / neg.to(torch.float64) diff --git a/src/anomalib/metrics/pimo/dataclasses.py b/src/anomalib/metrics/pimo/dataclasses.py index 3eaa04cd12..636261f1be 100644 --- a/src/anomalib/metrics/pimo/dataclasses.py +++ b/src/anomalib/metrics/pimo/dataclasses.py @@ -1,4 +1,23 @@ -"""Dataclasses for PIMO metrics.""" +"""Dataclasses for PIMO metrics. + +This module provides dataclasses for storing and manipulating PIMO (Per-Image +Metric Optimization) and AUPIMO (Area Under PIMO) results. + +The dataclasses include: + +- ``PIMOResult``: Container for PIMO curve data and metadata +- ``AUPIMOResult``: Container for AUPIMO curve data and metadata + +Example: + >>> from anomalib.metrics.pimo.dataclasses import PIMOResult + >>> import torch + >>> thresholds = torch.linspace(0, 1, 10) + >>> shared_fpr = torch.linspace(1, 0, 10) # Decreasing FPR + >>> per_image_tprs = torch.rand(5, 10) # 5 images, 10 thresholds + >>> result = PIMOResult(thresholds, shared_fpr, per_image_tprs) + >>> result.num_images + 5 +""" # Based on the code: https://github.com/jpcbertoldo/aupimo # @@ -16,19 +35,31 @@ class PIMOResult: """Per-Image Overlap (PIMO, pronounced pee-mo) curve. - This interface gathers the PIMO curve data and metadata and provides several utility methods. + This class stores PIMO curve data and metadata and provides utility methods + for analysis. 
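The ``(N, K, 2, 2)`` layout documented above can be sanity-checked with plain tensors. Below is a minimal torch-only sketch (synthetic scores and masks, made up for illustration; `binary_classification_curve` itself is not imported) that rebuilds the confusion-matrix layout by brute force and derives per-image TPR/FPR the same way `per_image_tpr` and `per_image_fpr` do:

```python
import torch

# Synthetic mini-batch: 2 "images" of 5 pixels each, 3 ascending thresholds.
scores = torch.tensor([[0.1, 0.4, 0.6, 0.8, 0.9],
                       [0.2, 0.3, 0.5, 0.7, 0.95]])
gts = torch.tensor([[False, False, True, True, True],
                    [False, True, False, True, True]])
thresholds = torch.tensor([0.25, 0.5, 0.75])

# Rebuild the documented (N, K, 2, 2) layout by brute force.
curves = torch.empty((2, 3, 2, 2), dtype=torch.int64)
for k, thresh in enumerate(thresholds):
    preds = scores >= thresh  # "predicted positive" condition from the docstrings
    curves[:, k, 1, 1] = (preds & gts).sum(dim=1)    # TP
    curves[:, k, 0, 1] = (preds & ~gts).sum(dim=1)   # FP
    curves[:, k, 1, 0] = (~preds & gts).sum(dim=1)   # FN
    curves[:, k, 0, 0] = (~preds & ~gts).sum(dim=1)  # TN

# Per-image rates, mirroring `per_image_tpr` / `per_image_fpr`:
tpr = curves[..., 1, 1] / curves[..., 1, :].sum(dim=2)  # descending across k
fpr = curves[..., 0, 1] / curves[..., 0, :].sum(dim=2)
print(tpr.shape, fpr.shape)  # torch.Size([2, 3]) torch.Size([2, 3])
```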
diff --git a/src/anomalib/metrics/pimo/dataclasses.py b/src/anomalib/metrics/pimo/dataclasses.py
index 3eaa04cd12..636261f1be 100644
--- a/src/anomalib/metrics/pimo/dataclasses.py
+++ b/src/anomalib/metrics/pimo/dataclasses.py
@@ -1,4 +1,23 @@
-"""Dataclasses for PIMO metrics."""
+"""Dataclasses for PIMO metrics.
+
+This module provides dataclasses for storing and manipulating PIMO (Per-Image
+Overlap) and AUPIMO (Area Under PIMO) results.
+
+The dataclasses include:
+
+- ``PIMOResult``: Container for PIMO curve data and metadata
+- ``AUPIMOResult``: Container for AUPIMO curve data and metadata
+
+Example:
+    >>> from anomalib.metrics.pimo.dataclasses import PIMOResult
+    >>> import torch
+    >>> thresholds = torch.linspace(0, 1, 10)
+    >>> shared_fpr = torch.linspace(1, 0, 10)  # Decreasing FPR
+    >>> per_image_tprs = torch.rand(5, 10)  # 5 images, 10 thresholds
+    >>> result = PIMOResult(thresholds, shared_fpr, per_image_tprs)
+    >>> result.num_images
+    5
+"""

 # Based on the code: https://github.com/jpcbertoldo/aupimo
 #
@@ -16,19 +35,31 @@
 class PIMOResult:
     """Per-Image Overlap (PIMO, pronounced pee-mo) curve.

-    This interface gathers the PIMO curve data and metadata and provides several utility methods.
+    This class stores PIMO curve data and metadata and provides utility methods
+    for analysis.

     Notation:
-        - N: number of images
-        - K: number of thresholds
-        - FPR: False Positive Rate
-        - TPR: True Positive Rate
-
-    Attributes:
-        thresholds (torch.Tensor): sequence of K (monotonically increasing) thresholds used to compute the PIMO curve
-        shared_fpr (torch.Tensor): K values of the shared FPR metric at the corresponding thresholds
-        per_image_tprs (torch.Tensor): for each of the N images, the K values of in-image TPR at the corresponding
-            thresholds
+        - ``N``: number of images
+        - ``K``: number of thresholds
+        - ``FPR``: False Positive Rate
+        - ``TPR``: True Positive Rate
+
+    Args:
+        thresholds: Sequence of ``K`` monotonically increasing thresholds used
+            to compute the PIMO curve. Shape: ``(K,)``
+        shared_fpr: ``K`` values of the shared FPR metric at corresponding
+            thresholds. Shape: ``(K,)``
+        per_image_tprs: For each of the ``N`` images, the ``K`` values of
+            in-image TPR at corresponding thresholds. Shape: ``(N, K)``
+
+    Example:
+        >>> import torch
+        >>> thresholds = torch.linspace(0, 1, 10)
+        >>> shared_fpr = torch.linspace(1, 0, 10)  # Decreasing FPR
+        >>> per_image_tprs = torch.rand(5, 10)  # 5 images, 10 thresholds
+        >>> result = PIMOResult(thresholds, shared_fpr, per_image_tprs)
+        >>> result.num_images
+        5
     """

     # data
@@ -38,25 +69,40 @@ class PIMOResult:

     @property
     def num_threshsholds(self) -> int:
-        """Number of thresholds."""
+        """Get number of thresholds.
+
+        Returns:
+            Number of thresholds used in the PIMO curve.
+        """
         return self.thresholds.shape[0]

     @property
     def num_images(self) -> int:
-        """Number of images."""
+        """Get number of images.
+
+        Returns:
+            Number of images in the dataset.
+        """
         return self.per_image_tprs.shape[0]

     @property
     def image_classes(self) -> torch.Tensor:
-        """Image classes (0: normal, 1: anomalous).
+        """Get image classes (0: normal, 1: anomalous).

-        Deduced from the per-image TPRs.
-        If any TPR value is not NaN, the image is considered anomalous.
+        The class is deduced from the per-image TPRs. If any TPR value is not
+        NaN, the image is considered anomalous.
+
+        Returns:
+            Tensor of shape ``(N,)`` containing image classes.
         """
         return (~torch.isnan(self.per_image_tprs)).any(dim=1).to(torch.int32)

     def __post_init__(self) -> None:
-        """Validate the inputs for the result object are consistent."""
+        """Validate inputs for result object consistency.
+
+        Raises:
+            TypeError: If inputs are invalid or have inconsistent shapes.
+        """
         try:
             _validate.is_valid_threshold(self.thresholds)
             _validate.is_rate_curve(self.shared_fpr, nan_allowed=False, decreasing=True)  # is_shared_apr
@@ -68,53 +114,74 @@ def __post_init__(self) -> None:

         if self.thresholds.shape != self.shared_fpr.shape:
             msg = (
-                f"Invalid {self.__class__.__name__} object. Attributes have inconsistent shapes: "
+                f"Invalid {self.__class__.__name__} object. "
+                f"Attributes have inconsistent shapes: "
                 f"{self.thresholds.shape=} != {self.shared_fpr.shape=}."
             )
             raise TypeError(msg)

         if self.thresholds.shape[0] != self.per_image_tprs.shape[1]:
             msg = (
-                f"Invalid {self.__class__.__name__} object. Attributes have inconsistent shapes: "
+                f"Invalid {self.__class__.__name__} object. "
+                f"Attributes have inconsistent shapes: "
                 f"{self.thresholds.shape[0]=} != {self.per_image_tprs.shape[1]=}."
             )
             raise TypeError(msg)

     def thresh_at(self, fpr_level: float) -> tuple[int, float, float]:
-        """Return the threshold at the given shared FPR.
+        """Get threshold at given shared FPR level.

-        See `anomalib.metrics.per_image.pimo_numpy.thresh_at_shared_fpr_level` for details.
+        For details see
+        :func:`anomalib.metrics.per_image.pimo_numpy.thresh_at_shared_fpr_level`.

         Args:
-            fpr_level (float): shared FPR level
+            fpr_level: Target shared FPR level to find threshold for.

         Returns:
-            tuple[int, float, float]:
-                [0] index of the threshold
-                [1] threshold
-                [2] the actual shared FPR value at the returned threshold
+            Tuple containing:
+                - Index of the threshold
+                - Threshold value
+                - Actual shared FPR value at returned threshold
+
+        Example:
+            >>> result = PIMOResult(...)  # doctest: +SKIP
+            >>> idx, thresh, fpr = result.thresh_at(0.1)  # doctest: +SKIP
         """
-        return functional.thresh_at_shared_fpr_level(
+        idx, thresh, fpr = functional.thresh_at_shared_fpr_level(
             self.thresholds,
             self.shared_fpr,
             fpr_level,
         )
+        return idx, thresh, float(fpr)


 @dataclass
 class AUPIMOResult:
-    """Area Under the Per-Image Overlap (AUPIMO, pronounced a-u-pee-mo) curve.
-
-    This interface gathers the AUPIMO data and metadata and provides several utility methods.
-
-    Attributes:
-        fpr_lower_bound (float): [metadata] LOWER bound of the FPR integration range
-        fpr_upper_bound (float): [metadata] UPPER bound of the FPR integration range
-        num_thresholds (int): [metadata] number of thresholds used to effectively compute AUPIMO;
-            should not be confused with the number of thresholds used to compute the PIMO curve
-        thresh_lower_bound (float): LOWER threshold bound --> corresponds to the UPPER FPR bound
-        thresh_upper_bound (float): UPPER threshold bound --> corresponds to the LOWER FPR bound
-        aupimos (torch.Tensor): values of AUPIMO scores (1 per image)
+    """Area Under Per-Image Overlap (AUPIMO, pronounced a-u-pee-mo) curve.
+
+    This class stores AUPIMO data and metadata and provides utility methods for
+    analysis.
+
+    Args:
+        fpr_lower_bound: Lower bound of the FPR integration range.
+        fpr_upper_bound: Upper bound of the FPR integration range.
+        num_thresholds: Number of thresholds used to compute AUPIMO. Note this
+            is different from thresholds used for PIMO curve.
+        thresh_lower_bound: Lower threshold bound (corresponds to upper FPR).
+        thresh_upper_bound: Upper threshold bound (corresponds to lower FPR).
+        aupimos: AUPIMO scores, one per image. Shape: ``(N,)``
+
+    Example:
+        >>> import torch
+        >>> aupimos = torch.rand(5)  # 5 images
+        >>> result = AUPIMOResult(  # doctest: +SKIP
+        ...     fpr_lower_bound=0.0,
+        ...     fpr_upper_bound=0.3,
+        ...     num_thresholds=100,
+        ...     thresh_lower_bound=0.5,
+        ...     thresh_upper_bound=0.9,
+        ...     aupimos=aupimos
+        ... )
     """

     # metadata
@@ -129,51 +196,83 @@ class AUPIMOResult:

     @property
     def num_images(self) -> int:
-        """Number of images."""
+        """Get number of images.
+
+        Returns:
+            Number of images in dataset.
+        """
         return self.aupimos.shape[0]

     @property
     def num_normal_images(self) -> int:
-        """Number of normal images."""
+        """Get number of normal images.
+
+        Returns:
+            Count of images with class 0 (normal).
+        """
         return int((self.image_classes == 0).sum())

     @property
     def num_anomalous_images(self) -> int:
-        """Number of anomalous images."""
+        """Get number of anomalous images.
+
+        Returns:
+            Count of images with class 1 (anomalous).
+        """
         return int((self.image_classes == 1).sum())

     @property
     def image_classes(self) -> torch.Tensor:
-        """Image classes (0: normal, 1: anomalous)."""
-        # if an instance has `nan` aupimo it's because it's a normal image
+        """Get image classes (0: normal, 1: anomalous).
+
+        An image is considered normal if its AUPIMO score is NaN.
+
+        Returns:
+            Tensor of shape ``(N,)`` containing image classes.
+        """
         return self.aupimos.isnan().to(torch.int32)

     @property
     def fpr_bounds(self) -> tuple[float, float]:
-        """Lower and upper bounds of the FPR integration range."""
+        """Get FPR integration range bounds.
+
+        Returns:
+            Tuple of (lower bound, upper bound) for FPR range.
+        """
         return self.fpr_lower_bound, self.fpr_upper_bound

     @property
     def thresh_bounds(self) -> tuple[float, float]:
-        """Lower and upper bounds of the threshold integration range.
+        """Get threshold integration range bounds.
+
+        Note:
+            Bounds correspond to FPR bounds in reverse order:
+            - ``fpr_lower_bound`` -> ``thresh_upper_bound``
+            - ``fpr_upper_bound`` -> ``thresh_lower_bound``

-        Recall: they correspond to the FPR bounds in reverse order.
-        I.e.:
-            fpr_lower_bound --> thresh_upper_bound
-            fpr_upper_bound --> thresh_lower_bound
+        Returns:
+            Tuple of (lower bound, upper bound) for threshold range.
         """
         return self.thresh_lower_bound, self.thresh_upper_bound

     def __post_init__(self) -> None:
-        """Validate the inputs for the result object are consistent."""
+        """Validate inputs for result object consistency.
+
+        Raises:
+            TypeError: If inputs are invalid.
+        """
         try:
-            _validate.is_rate_range((self.fpr_lower_bound, self.fpr_upper_bound))
-            # TODO(jpcbertoldo): warn when it's too low (use parameters from the numpy code)  # noqa: TD003
+            _validate.is_rate_range(
+                (self.fpr_lower_bound, self.fpr_upper_bound),
+            )
+            # TODO(jpcbertoldo): warn when too low (use numpy code params)  # noqa: TD003
             if self.num_thresholds is not None:
                 _validate.is_num_thresholds_gte2(self.num_thresholds)

-            _validate.is_rates(self.aupimos, nan_allowed=True)  # validate is_aupimos
+            _validate.is_rates(self.aupimos, nan_allowed=True)

-            _validate.validate_threshold_bounds((self.thresh_lower_bound, self.thresh_upper_bound))
+            _validate.validate_threshold_bounds(
+                (self.thresh_lower_bound, self.thresh_upper_bound),
+            )
         except (TypeError, ValueError) as ex:
             msg = f"Invalid inputs for {self.__class__.__name__} object. Cause: {ex}."
@@ -187,19 +286,37 @@ def from_pimo_result(
         num_thresholds_auc: int,
         aupimos: torch.Tensor,
     ) -> "AUPIMOResult":
-        """Return an AUPIMO result object from a PIMO result object.
+        """Create AUPIMO result from PIMO result.

         Args:
-            pimo_result: PIMO result object
-            fpr_bounds: lower and upper bounds of the FPR integration range
-            num_thresholds_auc: number of thresholds used to effectively compute AUPIMO;
-                NOT the number of thresholds used to compute the PIMO curve!
-            aupimos: AUPIMO scores
+            pimo_result: Source PIMO result object.
+            fpr_bounds: Tuple of (lower, upper) bounds for FPR range.
+            num_thresholds_auc: Number of thresholds for AUPIMO computation.
+                Note this differs from PIMO curve thresholds.
+            aupimos: AUPIMO scores, one per image.
+
+        Returns:
+            New AUPIMO result object.
+
+        Raises:
+            TypeError: If inputs are invalid or inconsistent.
+
+        Example:
+            >>> pimo_result = PIMOResult(...)  # doctest: +SKIP
+            >>> aupimos = torch.rand(5)  # 5 images
+            >>> result = AUPIMOResult.from_pimo_result(  # doctest: +SKIP
+            ...     pimo_result=pimo_result,
+            ...     fpr_bounds=(0.0, 0.3),
+            ...     num_thresholds_auc=100,
+            ...     aupimos=aupimos
+            ... )
         """
         if pimo_result.per_image_tprs.shape[0] != aupimos.shape[0]:
             msg = (
-                f"Invalid {cls.__name__} object. Attributes have inconsistent shapes: "
-                f"there are {pimo_result.per_image_tprs.shape[0]} PIMO curves but {aupimos.shape[0]} AUPIMO scores."
+                f"Invalid {cls.__name__} object. "
+                f"Attributes have inconsistent shapes: "
+                f"there are {pimo_result.per_image_tprs.shape[0]} PIMO curves "
+                f"but {aupimos.shape[0]} AUPIMO scores."
             )
             raise TypeError(msg)

@@ -212,10 +329,10 @@ def from_pimo_result(
             raise TypeError(msg)

         fpr_lower_bound, fpr_upper_bound = fpr_bounds
-        # recall: fpr upper/lower bounds are the same as the thresh lower/upper bounds
+        # recall: fpr upper/lower bounds are same as thresh lower/upper bounds
         _, thresh_lower_bound, __ = pimo_result.thresh_at(fpr_upper_bound)
         _, thresh_upper_bound, __ = pimo_result.thresh_at(fpr_lower_bound)
-        # `_` is the threshold's index, `__` is the actual fpr value
+        # `_` is threshold's index, `__` is actual fpr value

         return cls(
             fpr_lower_bound=fpr_lower_bound,
             fpr_upper_bound=fpr_upper_bound,
diff --git a/src/anomalib/metrics/pimo/functional.py b/src/anomalib/metrics/pimo/functional.py
index 7eac07b1bd..e3d930de05 100644
--- a/src/anomalib/metrics/pimo/functional.py
+++ b/src/anomalib/metrics/pimo/functional.py
@@ -1,6 +1,33 @@
 """Per-Image Overlap curve (PIMO, pronounced pee-mo) and its area under the curve (AUPIMO).

-Details: `anomalib.metrics.per_image.pimo`.
+This module provides functions for computing PIMO curves and AUPIMO scores for
+anomaly detection evaluation.
+
+The PIMO curve plots True Positive Rate (TPR) values for each image across
+multiple anomaly score thresholds. The thresholds are indexed by a shared False
+Positive Rate (FPR) measure computed on normal images.
+
+The AUPIMO score is the area under a PIMO curve within specified FPR bounds,
+normalized to the range [0,1].
+
+See Also:
+    :mod:`anomalib.metrics.per_image.pimo` for detailed documentation.
+
+Example:
+    >>> import torch
+    >>> anomaly_maps = torch.rand(10, 32, 32)  # 10 images of 32x32
+    >>> masks = torch.randint(0, 2, (10, 32, 32))  # Binary masks
+    >>> thresholds, shared_fpr, per_image_tprs, classes = pimo_curves(
+    ...     anomaly_maps,
+    ...     masks,
+    ...     num_thresholds=100
+    ... )
+    >>> results = aupimo_scores(
+    ...     anomaly_maps,
+    ...     masks,
+    ...     num_thresholds=100,
+    ...     fpr_bounds=(1e-5, 1e-4)
+    ... )
 """

 # Original Code
@@ -33,31 +60,40 @@ def pimo_curves(
     masks: torch.Tensor,
     num_thresholds: int,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    """Compute the Per-IMage Overlap (PIMO, pronounced pee-mo) curves.
+    """Compute the Per-IMage Overlap (PIMO) curves.

-    PIMO is a curve of True Positive Rate (TPR) values on each image across multiple anomaly score thresholds.
-    The anomaly score thresholds are indexed by a (cross-image shared) value of False Positive Rate (FPR) measure on
-    the normal images.
-
-    Details: `anomalib.metrics.per_image.pimo`.
-
-    Args' notation:
-        N: number of images
-        H: image height
-        W: image width
-        K: number of thresholds
+    PIMO curves plot True Positive Rate (TPR) values for each image across
+    multiple anomaly score thresholds. The thresholds are indexed by a shared
+    False Positive Rate (FPR) measure computed on normal images.

     Args:
-        anomaly_maps: floating point anomaly score maps of shape (N, H, W)
-        masks: binary (bool or int) ground truth masks of shape (N, H, W)
-        num_thresholds: number of thresholds to compute (K)
+        anomaly_maps: Anomaly score maps of shape ``(N, H, W)`` where:
+            - ``N``: number of images
+            - ``H``: image height
+            - ``W``: image width
+        masks: Binary ground truth masks of shape ``(N, H, W)``
+        num_thresholds: Number of thresholds ``K`` to compute

     Returns:
-        tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-            [0] thresholds of shape (K,) in ascending order
-            [1] shared FPR values of shape (K,) in descending order (indices correspond to the thresholds)
-            [2] per-image TPR curves of shape (N, K), axis 1 in descending order (indices correspond to the thresholds)
-            [3] image classes of shape (N,) with values 0 (normal) or 1 (anomalous)
+        tuple containing:
+            - thresholds: Shape ``(K,)`` in ascending order
+            - shared_fpr: Shape ``(K,)`` in descending order
+            - per_image_tprs: Shape ``(N, K)`` in descending order
+            - image_classes: Shape ``(N,)`` with values 0 (normal) or 1
+              (anomalous)
+
+    Raises:
+        ValueError: If inputs are invalid or have inconsistent shapes
+        RuntimeError: If per-image FPR curves from normal images are invalid
+
+    Example:
+        >>> anomaly_maps = torch.rand(10, 32, 32)  # 10 images of 32x32
+        >>> masks = torch.randint(0, 2, (10, 32, 32))  # Binary masks
+        >>> thresholds, shared_fpr, per_image_tprs, classes = pimo_curves(
+        ...     anomaly_maps,
+        ...     masks,
+        ...     num_thresholds=100
+        ... )
     """
     # validate the strings are valid
     _validate.is_num_thresholds_gte2(num_thresholds)
@@ -69,10 +105,11 @@ def pimo_curves(

     image_classes = images_classes_from_masks(masks)

-    # the thresholds are computed here so that they can be restrained to the normal images
-    # therefore getting a better resolution in terms of FPR quantization
-    # otherwise the function `binclf_curve_numpy.per_image_binclf_curve` would have the range of thresholds
-    # computed from all the images (normal + anomalous)
+    # the thresholds are computed here so that they can be restrained to the
+    # normal images, therefore getting a better resolution in terms of FPR
+    # quantization; otherwise the function
+    # `binclf_curve_numpy.per_image_binclf_curve` would have the range of
+    # thresholds computed from all the images (normal + anomalous)
     thresholds = _get_linspaced_thresholds(
         anomaly_maps[image_classes == 0],
         num_thresholds,
@@ -109,7 +146,7 @@ def pimo_curves(
     return thresholds, shared_fpr, per_image_tprs, image_classes


-# =========================================== AUPIMO ===========================================
+# =========================================== AUPIMO =====================================


 def aupimo_scores(
@@ -119,34 +156,47 @@ def aupimo_scores(
     fpr_bounds: tuple[float, float] = (1e-5, 1e-4),
     force: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
-    """Compute the PIMO curves and their Area Under the Curve (i.e. AUPIMO) scores.
-
-    Scores are computed from the integration of the PIMO curves within the given FPR bounds, then normalized to [0, 1].
-    It can be thought of as the average TPR of the PIMO curves within the given FPR bounds.
-
-    Details: `anomalib.metrics.per_image.pimo`.
-
-    Args' notation:
-        N: number of images
-        H: image height
-        W: image width
-        K: number of thresholds
+    """Compute PIMO curves and their Area Under the Curve (AUPIMO) scores.
+
+    AUPIMO scores are computed by integrating PIMO curves within specified FPR
+    bounds and normalizing to [0,1]. The score represents the average TPR within
+    the FPR bounds.

     Args:
-        anomaly_maps: floating point anomaly score maps of shape (N, H, W)
-        masks: binary (bool or int) ground truth masks of shape (N, H, W)
-        num_thresholds: number of thresholds to compute (K)
-        fpr_bounds: lower and upper bounds of the FPR integration range
-        force: whether to force the computation despite bad conditions
+        anomaly_maps: Anomaly score maps of shape ``(N, H, W)`` where:
+            - ``N``: number of images
+            - ``H``: image height
+            - ``W``: image width
+        masks: Binary ground truth masks of shape ``(N, H, W)``
+        num_thresholds: Number of thresholds ``K`` to compute
+        fpr_bounds: Lower and upper bounds of FPR integration range
+        force: Whether to force computation despite bad conditions

     Returns:
-        tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-            [0] thresholds of shape (K,) in ascending order
-            [1] shared FPR values of shape (K,) in descending order (indices correspond to the thresholds)
-            [2] per-image TPR curves of shape (N, K), axis 1 in descending order (indices correspond to the thresholds)
-            [3] image classes of shape (N,) with values 0 (normal) or 1 (anomalous)
-            [4] AUPIMO scores of shape (N,) in [0, 1]
-            [5] number of points used in the AUC integration
+        tuple containing:
+            - thresholds: Shape ``(K,)`` in ascending order
+            - shared_fpr: Shape ``(K,)`` in descending order
+            - per_image_tprs: Shape ``(N, K)`` in descending order
+            - image_classes: Shape ``(N,)`` with values 0 (normal) or 1
+              (anomalous)
+            - aupimo_scores: Shape ``(N,)`` in range [0,1]
+            - num_points: Number of points used in AUC integration
+
+    Raises:
+        ValueError: If inputs are invalid
+        RuntimeError: If PIMO curves are invalid or integration range has too few
+            points
+
+    Example:
+        >>> anomaly_maps = torch.rand(10, 32, 32)  # 10 images of 32x32
+        >>> masks = torch.randint(0, 2, (10, 32, 32))  # Binary masks
+        >>> results = aupimo_scores(
+        ...     anomaly_maps,
+        ...     masks,
+        ...     num_thresholds=100,
+        ...     fpr_bounds=(1e-5, 1e-4)
+        ... )
+        >>> thresholds, shared_fpr, tprs, classes, scores, n_points = results
     """
     _validate.is_rate_range(fpr_bounds)
@@ -186,8 +236,9 @@ def aupimo_scores(
         rtol=(rtol := 1e-2),
     ):
         logger.warning(
-            "The lower bound of the shared FPR integration range is not exactly achieved. "
-            f"Expected {fpr_lower_bound} but got {fpr_lower_bound_defacto}, which is not within {rtol=}.",
+            "The lower bound of the shared FPR integration range is not exactly "
+            f"achieved. Expected {fpr_lower_bound} but got "
+            f"{fpr_lower_bound_defacto}, which is not within {rtol=}.",
         )

     if not torch.isclose(
@@ -196,8 +247,9 @@ def aupimo_scores(
         rtol=rtol,
     ):
         logger.warning(
-            "The upper bound of the shared FPR integration range is not exactly achieved. "
-            f"Expected {fpr_upper_bound} but got {fpr_upper_bound_defacto}, which is not within {rtol=}.",
+            "The upper bound of the shared FPR integration range is not exactly "
+            f"achieved. Expected {fpr_upper_bound} but got "
+            f"{fpr_upper_bound_defacto}, which is not within {rtol=}.",
         )

     # reminder: fpr lower/upper bound is threshold upper/lower bound (reversed)
@@ -207,9 +259,10 @@ def aupimo_scores(
     # deal with edge cases
     if thresh_lower_bound_idx >= thresh_upper_bound_idx:
         msg = (
-            "The thresholds corresponding to the given `fpr_bounds` are not valid because "
-            "they matched the same threshold or the are in the wrong order. "
-            f"FPR upper/lower = threshold lower/upper = {thresh_lower_bound_idx} and {thresh_upper_bound_idx}."
+            "The thresholds corresponding to the given `fpr_bounds` are not "
+            "valid because they matched the same threshold or they are in the "
+            "wrong order. FPR upper/lower = threshold lower/upper = "
+            f"{thresh_lower_bound_idx} and {thresh_upper_bound_idx}."
         )
         raise RuntimeError(msg)

@@ -217,11 +270,13 @@ def aupimo_scores(
     shared_fpr_bounded: torch.Tensor = shared_fpr[thresh_lower_bound_idx : (thresh_upper_bound_idx + 1)]
     per_image_tprs_bounded: torch.Tensor = per_image_tprs[:, thresh_lower_bound_idx : (thresh_upper_bound_idx + 1)]

-    # `shared_fpr` and `tprs` are in descending order; `flip()` reverts to ascending order
+    # `shared_fpr` and `tprs` are in descending order; `flip()` reverts to
+    # ascending order
     shared_fpr_bounded = torch.flip(shared_fpr_bounded, dims=[0])
     per_image_tprs_bounded = torch.flip(per_image_tprs_bounded, dims=[1])

-    # the log's base does not matter because it's a constant factor canceled by normalization factor
+    # the log's base does not matter because it's a constant factor canceled by
+    # the normalization factor
     shared_fpr_bounded_log = torch.log(shared_fpr_bounded)

     # deal with edge cases
@@ -229,8 +284,8 @@ def aupimo_scores(

     if invalid_shared_fpr.all():
         msg = (
-            "Cannot compute AUPIMO because the shared fpr integration range is invalid). "
-            "Try increasing the number of thresholds."
+            "Cannot compute AUPIMO because the shared fpr integration range is "
+            "invalid. Try increasing the number of thresholds."
         )
         raise RuntimeError(msg)

@@ -248,9 +303,9 @@ def aupimo_scores(

     if num_points_integral <= 30:
         msg = (
-            "Cannot compute AUPIMO because the shared fpr integration range doesn't have enough points. "
-            f"Found {num_points_integral} points in the integration range. "
-            "Try increasing `num_thresholds`."
+            "Cannot compute AUPIMO because the shared fpr integration range "
+            f"doesn't have enough points. Found {num_points_integral} points in "
+            "the integration range. Try increasing `num_thresholds`."
         )
         if not force:
             raise RuntimeError(msg)
@@ -259,21 +314,22 @@ def aupimo_scores(

     if num_points_integral < 300:
         logger.warning(
-            "The AUPIMO may be inaccurate because the shared fpr integration range doesn't have enough points. "
-            f"Found {num_points_integral} points in the integration range. "
-            "Try increasing `num_thresholds`.",
+            "The AUPIMO may be inaccurate because the shared fpr integration "
+            f"range doesn't have enough points. Found {num_points_integral} "
+            "points in the integration range. Try increasing `num_thresholds`.",
         )

     aucs: torch.Tensor = torch.trapezoid(per_image_tprs_bounded, x=shared_fpr_bounded_log, axis=1)

-    # normalize, then clip(0, 1) makes sure that the values are in [0, 1] in case of numerical errors
+    # normalize, then clip(0, 1) makes sure that the values are in [0, 1] in
+    # case of numerical errors
     normalization_factor = aupimo_normalizing_factor(fpr_bounds)
     aucs = (aucs / normalization_factor).clip(0, 1)

-    return thresholds, shared_fpr, per_image_tprs, image_classes, aucs, num_points_integral
+    return (thresholds, shared_fpr, per_image_tprs, image_classes, aucs, num_points_integral)


-# =========================================== AUX ===========================================
+# =========================================== AUX =====================================


 def thresh_at_shared_fpr_level(
@@ -284,20 +340,32 @@ def thresh_at_shared_fpr_level(
     """Return the threshold and its index at the given shared FPR level.

     Three cases are possible:
-        - fpr_level == 0: the lowest threshold that achieves 0 FPR is returned
-        - fpr_level == 1: the highest threshold that achieves 1 FPR is returned
-        - 0 < fpr_level < 1: the threshold that achieves the closest (higher or lower) FPR to `fpr_level` is returned
+        - ``fpr_level == 0``: lowest threshold achieving 0 FPR is returned
+        - ``fpr_level == 1``: highest threshold achieving 1 FPR is returned
+        - ``0 < fpr_level < 1``: threshold achieving closest FPR is returned

     Args:
-        thresholds: thresholds at which the shared FPR was computed.
-        shared_fpr: shared FPR values.
-        fpr_level: shared FPR value at which to get the threshold.
+        thresholds: Thresholds at which shared FPR was computed
+        shared_fpr: Shared FPR values
+        fpr_level: Shared FPR value at which to get threshold

     Returns:
-        tuple[int, float, float]:
-            [0] index of the threshold
-            [1] threshold
-            [2] the actual shared FPR value at the returned threshold
+        tuple containing:
+            - index: Index of the threshold
+            - threshold: Threshold value
+            - actual_fpr: Actual shared FPR value at returned threshold
+
+    Raises:
+        ValueError: If inputs are invalid or FPR level is out of range
+
+    Example:
+        >>> thresholds = torch.linspace(0, 1, 100)
+        >>> shared_fpr = torch.linspace(1, 0, 100)  # Decreasing FPR
+        >>> idx, thresh, fpr = thresh_at_shared_fpr_level(
+        ...     thresholds,
+        ...     shared_fpr,
+        ...     fpr_level=0.5
+        ... )
     """
     _validate.is_valid_threshold(thresholds)
     _validate.is_rate_curve(shared_fpr, nan_allowed=False, decreasing=True)
@@ -308,20 +376,21 @@ def thresh_at_shared_fpr_level(

     if fpr_level < shared_fpr_min:
         msg = (
-            "Invalid `fpr_level` because it's out of the range of `shared_fpr` = "
-            f"[{shared_fpr_min}, {shared_fpr_max}], and got {fpr_level}."
+            "Invalid `fpr_level` because it's out of the range of `shared_fpr` "
+            f"= [{shared_fpr_min}, {shared_fpr_max}], and got {fpr_level}."
         )
         raise ValueError(msg)

     if fpr_level > shared_fpr_max:
         msg = (
-            "Invalid `fpr_level` because it's out of the range of `shared_fpr` = "
-            f"[{shared_fpr_min}, {shared_fpr_max}], and got {fpr_level}."
+            "Invalid `fpr_level` because it's out of the range of `shared_fpr` "
+            f"= [{shared_fpr_min}, {shared_fpr_max}], and got {fpr_level}."
         )
         raise ValueError(msg)

     # fpr_level == 0 or 1 are special case
-    # because there may be multiple solutions, and the chosen should their MINIMUM/MAXIMUM respectively
+    # because there may be multiple solutions, and the chosen one should be
+    # their MINIMUM/MAXIMUM respectively

     if fpr_level == 0.0:
         index = torch.min(torch.where(shared_fpr == fpr_level)[0])
@@ -338,16 +407,21 @@ def thresh_at_shared_fpr_level(


 def aupimo_normalizing_factor(fpr_bounds: tuple[float, float]) -> float:
-    """Constant that normalizes the AUPIMO integral to 0-1 range.
+    """Compute constant that normalizes AUPIMO integral to 0-1 range.

-    It is the maximum possible value from the integral in AUPIMO's definition.
-    It corresponds to assuming a constant function T_i: thresh --> 1.
+    The factor is the maximum possible value from the integral in AUPIMO's
+    definition. It corresponds to assuming a constant function T_i: thresh --> 1.

     Args:
-        fpr_bounds: lower and upper bounds of the FPR integration range.
+        fpr_bounds: Lower and upper bounds of FPR integration range

     Returns:
-        float: the normalization factor (>0).
+        float: Normalization factor (>0)
+
+    Example:
+        >>> factor = aupimo_normalizing_factor((1e-5, 1e-4))
+        >>> print(f"{factor:.3f}")
+        2.303
     """
     _validate.is_rate_range(fpr_bounds)
     fpr_lower_bound, fpr_upper_bound = fpr_bounds
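The "FPR bounds map to threshold bounds in reverse order" remark above is easy to verify with a toy curve. A torch-only sketch using the closest-FPR rule that `thresh_at_shared_fpr_level` documents for ``0 < fpr_level < 1`` (the linear curves here are synthetic, chosen so that ``fpr = 1 - thresh``):

```python
import torch

# Ascending thresholds, descending shared FPR.
thresholds = torch.linspace(0.0, 1.0, steps=101)
shared_fpr = torch.linspace(1.0, 0.0, steps=101)

def thresh_at(fpr_level: float) -> float:
    # closest-FPR rule documented for 0 < fpr_level < 1
    index = int(torch.argmin(torch.abs(shared_fpr - fpr_level)))
    return float(thresholds[index])

fpr_lower_bound, fpr_upper_bound = 0.1, 0.3
thresh_lower_bound = thresh_at(fpr_upper_bound)  # 0.7: upper FPR -> lower thresh
thresh_upper_bound = thresh_at(fpr_lower_bound)  # 0.9: lower FPR -> upper thresh
assert thresh_lower_bound < thresh_upper_bound
```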
diff --git a/src/anomalib/metrics/pimo/pimo.py b/src/anomalib/metrics/pimo/pimo.py
index ef3e22ed3c..4c367f0a40 100644
--- a/src/anomalib/metrics/pimo/pimo.py
+++ b/src/anomalib/metrics/pimo/pimo.py
@@ -1,35 +1,50 @@
-"""Per-Image Overlap curve (PIMO, pronounced pee-mo) and its area under the curve (AUPIMO).
-
-# PIMO
-
-PIMO is a curve of True Positive Rate (TPR) values on each image across multiple anomaly score thresholds.
-The anomaly score thresholds are indexed by a (shared) valued of False Positive Rate (FPR) measure on the normal images.
-
-Each *anomalous* image has its own curve such that the X-axis is shared by all of them.
-
-At a given threshold:
-    X-axis: Shared FPR (may vary)
-        1. Log of the Average of per-image FPR on normal images.
-        SEE NOTE BELOW.
-    Y-axis: per-image TP Rate (TPR), or "Overlap" between the ground truth and the predicted masks.
-
-*** Note about other shared FPR alternatives ***
-The shared FPR metric can be made harder by using the cross-image max (or high-percentile) FPRs instead of the mean.
-Rationale: this will further punish models that have exceptional FPs in normal images.
-So far there is only one shared FPR metric implemented but others will be added in the future.
-
-# AUPIMO
-
-`AUPIMO` is the area under each `PIMO` curve with bounded integration range in terms of shared FPR.
-
-# Disclaimer
-
-This module implements torch interfaces to access the numpy code in `pimo_numpy.py`.
-Tensors are converted to numpy arrays and then passed and validated in the numpy code.
-The results are converted back to tensors and eventually wrapped in an dataclass object.
-
-Validations will preferably happen in ndarray so the numpy code can be reused without torch,
-so often times the Tensor arguments will be converted to ndarray and then validated.
+"""Per-Image Overlap curve (PIMO) and its area under the curve (AUPIMO).
+
+This module provides metrics for evaluating anomaly detection performance using
+Per-Image Overlap (PIMO) curves and their area under the curve (AUPIMO).
+
+PIMO Curves
+-----------
+PIMO curves plot True Positive Rate (TPR) values for each image across multiple
+anomaly score thresholds. The thresholds are indexed by a shared False Positive
+Rate (FPR) measure computed on normal images.
+
+Each anomalous image has its own curve with:
+
+- X-axis: Shared FPR (logarithmic average of per-image FPR on normal images)
+- Y-axis: Per-image TPR ("Overlap" between ground truth and predicted masks)
+
+Note on Shared FPR
+------------------
+The shared FPR metric can be made stricter by using cross-image max or high
+percentile FPRs instead of mean. This further penalizes models with exceptional
+false positives in normal images. Currently only mean FPR is implemented.
+
+AUPIMO Score
+------------
+AUPIMO is the area under each PIMO curve within bounded FPR integration range.
+The score is normalized to [0,1].
+
+Implementation Notes
+--------------------
+This module implements PyTorch interfaces to the numpy implementation in
+``pimo_numpy.py``. Tensors are converted to numpy arrays for computation and
+validation, then converted back to tensors and wrapped in dataclass objects.
+
+Example:
+    >>> import torch
+    >>> from anomalib.metrics.pimo import PIMO
+    >>> metric = PIMO(num_thresholds=10)
+    >>> anomaly_maps = torch.rand(5, 32, 32)  # 5 images
+    >>> masks = torch.randint(0, 2, (5, 32, 32))  # Binary masks
+    >>> metric.update(anomaly_maps, masks)
+    >>> result = metric.compute()
+    >>> result.num_images
+    5
+
+See Also:
+    - :class:`PIMOResult`: Container for PIMO curve data
+    - :class:`AUPIMOResult`: Container for AUPIMO score data
 """

 # Original Code
@@ -53,34 +68,35 @@


 class _PIMO(Metric):
-    """Per-IMage Overlap (PIMO, pronounced pee-mo) curves.
-
-    This torchmetrics interface is a wrapper around the functional interface, which is a wrapper around the numpy code.
-    The tensors are converted to numpy arrays and then passed and validated in the numpy code.
-    The results are converted back to tensors and wrapped in an dataclass object.
-
-    PIMO is a curve of True Positive Rate (TPR) values on each image across multiple anomaly score thresholds.
-    The anomaly score thresholds are indexed by a (cross-image shared) value of False Positive Rate (FPR) measure on
-    the normal images.
-
-    Details: `anomalib.metrics.per_image.pimo`.
-
-    Notation:
-        N: number of images
-        H: image height
-        W: image width
-        K: number of thresholds
-
-    Attributes:
-        anomaly_maps: floating point anomaly score maps of shape (N, H, W)
-        masks: binary (bool or int) ground truth masks of shape (N, H, W)
+    """Per-Image Overlap (PIMO) curve metric.
+
+    This metric computes PIMO curves which plot True Positive Rate (TPR) values
+    for each image across multiple anomaly score thresholds. The thresholds are
+    indexed by a shared False Positive Rate (FPR) measure on normal images.

     Args:
-        num_thresholds: number of thresholds to compute (K)
-        binclf_algorithm: algorithm to compute the binary classifier curve (see `binclf_curve_numpy.Algorithm`)
+        num_thresholds: Number of thresholds to compute (K). Must be >= 2.

-    Returns:
-        PIMOResult: PIMO curves dataclass object. See `PIMOResult` for details.
+    Attributes:
+        anomaly_maps: List of anomaly score maps, each of shape ``(N, H, W)``
+        masks: List of binary ground truth masks, each of shape ``(N, H, W)``
+        is_differentiable: Whether metric is differentiable
+        higher_is_better: Whether higher values are better
+        full_state_update: Whether to update full state
+
+    Example:
+        >>> import torch
+        >>> metric = _PIMO(num_thresholds=10)
+        >>> anomaly_maps = torch.rand(5, 32, 32)  # 5 images
+        >>> masks = torch.randint(0, 2, (5, 32, 32))  # Binary masks
+        >>> metric.update(anomaly_maps, masks)
+        >>> result = metric.compute()
+        >>> result.num_images
+        5
+
+    Note:
+        This metric stores all predictions and targets in memory, which may
+        require significant memory for large datasets.
     """

     is_differentiable: bool = False
@@ -95,35 +111,47 @@ class _PIMO(Metric):

     @property
     def _is_empty(self) -> bool:
-        """Return True if the metric has not been updated yet."""
+        """Check if metric has been updated.
+
+        Returns:
+            bool: True if no updates have been made yet.
+        """
         return len(self.anomaly_maps) == 0

     @property
     def num_images(self) -> int:
-        """Number of images."""
+        """Get total number of images.
+
+        Returns:
+            int: Total number of images across all batches.
+        """
         return sum(am.shape[0] for am in self.anomaly_maps)

     @property
     def image_classes(self) -> torch.Tensor:
-        """Image classes (0: normal, 1: anomalous)."""
+        """Get image classes (0: normal, 1: anomalous).
+
+        Returns:
+            torch.Tensor: Binary tensor of image classes.
+        """
         return functional.images_classes_from_masks(self.masks)

     def __init__(self, num_thresholds: int) -> None:
-        """Per-Image Overlap (PIMO) curve.
+        """Initialize PIMO metric.

         Args:
-            num_thresholds: number of thresholds used to compute the PIMO curve (K)
+            num_thresholds: Number of thresholds for curve computation (K).
+                Must be >= 2.
         """
         super().__init__()

         logger.warning(
-            f"Metric `{self.__class__.__name__}` will save all targets and predictions in buffer."
-            " For large datasets this may lead to large memory footprint.",
+            f"Metric `{self.__class__.__name__}` will save all targets and "
+            "predictions in buffer. For large datasets this may lead to large "
+            "memory footprint.",
         )

-        # the options below are, redundantly, validated here to avoid reaching
-        # an error later in the execution
-
+        # Validate options early to avoid later errors
         _validate.is_num_thresholds_gte2(num_thresholds)
         self.num_thresholds = num_thresholds

@@ -131,11 +159,15 @@ def __init__(self, num_thresholds: int) -> None:
         self.add_state("masks", default=[], dist_reduce_fx="cat")

     def update(self, anomaly_maps: torch.Tensor, masks: torch.Tensor) -> None:
-        """Update lists of anomaly maps and masks.
+        """Update metric state with new predictions and targets.

         Args:
-            anomaly_maps (torch.Tensor): predictions of the model (ndim == 2, float)
-            masks (torch.Tensor): ground truth masks (ndim == 2, binary)
+            anomaly_maps: Model predictions as float tensors of shape
+                ``(N, H, W)``
+            masks: Ground truth binary masks of shape ``(N, H, W)``
+
+        Raises:
+            ValueError: If inputs have invalid shapes or types
         """
         _validate.is_anomaly_maps(anomaly_maps)
         _validate.is_masks(masks)
@@ -144,12 +176,13 @@ def update(self, anomaly_maps: torch.Tensor, masks: torch.Tensor) -> None:
         self.masks.append(masks)

     def compute(self) -> PIMOResult:
-        """Compute the PIMO curves.
-
-        Call the functional interface `pimo_curves()`, which is a wrapper around the numpy code.
+        """Compute PIMO curves from accumulated data.

         Returns:
-            PIMOResult: PIMO curves dataclass object. See `PIMOResult` for details.
+            PIMOResult: Container with curve data and metadata.
+
+        Raises:
+            RuntimeError: If no data has been added via update()
         """
         if self._is_empty:
             msg = "No anomaly maps and masks have been added yet. Please call `update()` first."
@@ -170,7 +203,7 @@ def compute(self) -> PIMOResult:


 class PIMO(AnomalibMetric, _PIMO):  # type: ignore[misc]
-    """Wrapper to add AnomalibMetric functionality to PIMO metric."""
+    """Wrapper adding AnomalibMetric functionality to PIMO metric."""

     default_fields = ("anomaly_map", "gt_mask")

@@ -178,32 +211,29 @@ class PIMO(AnomalibMetric, _PIMO):  # type: ignore[misc]
 class _AUPIMO(_PIMO):
     """Area Under the Per-Image Overlap (PIMO) curve.

-    This torchmetrics interface is a wrapper around the functional interface, which is a wrapper around the numpy code.
-    The tensors are converted to numpy arrays and then passed and validated in the numpy code.
-    The results are converted back to tensors and wrapped in an dataclass object.
-
-    Scores are computed from the integration of the PIMO curves within the given FPR bounds, then normalized to [0, 1].
-    It can be thought of as the average TPR of the PIMO curves within the given FPR bounds.
-
-    Details: `anomalib.metrics.per_image.pimo`.
-
-    Notation:
-        N: number of images
-        H: image height
-        W: image width
-        K: number of thresholds
-
-    Attributes:
-        anomaly_maps: floating point anomaly score maps of shape (N, H, W)
-        masks: binary (bool or int) ground truth masks of shape (N, H, W)
+    This metric computes both PIMO curves and their area under the curve
+    (AUPIMO). AUPIMO scores are computed by integrating PIMO curves within
+    specified FPR bounds and normalizing to [0,1].

     Args:
-        num_thresholds: number of thresholds to compute (K)
-        fpr_bounds: lower and upper bounds of the FPR integration range
-        force: whether to force the computation despite bad conditions
-
-    Returns:
-        tuple[PIMOResult, AUPIMOResult]: PIMO and AUPIMO results dataclass objects. See `PIMOResult` and `AUPIMOResult`.
+        num_thresholds: Number of thresholds for curve computation. Default:
+            300,000
+        fpr_bounds: Lower and upper FPR integration bounds as ``(min, max)``.
+            Default: ``(1e-5, 1e-4)``
+        return_average: If True, return mean AUPIMO score across anomalous
+            images. If False, return individual scores. Default: True
+        force: If True, compute scores even in suboptimal conditions.
+            Default: False
+
+    Example:
+        >>> import torch
+        >>> metric = _AUPIMO(num_thresholds=10)
+        >>> anomaly_maps = torch.rand(5, 32, 32)  # 5 images
+        >>> masks = torch.randint(0, 2, (5, 32, 32))  # Binary masks
+        >>> metric.update(anomaly_maps, masks)
+        >>> pimo_result, aupimo_result = metric.compute()
+        >>> aupimo_result.num_images
+        5
     """

     fpr_bounds: tuple[float, float]
@@ -212,21 +242,25 @@ class _AUPIMO(_PIMO):

     @staticmethod
     def normalizing_factor(fpr_bounds: tuple[float, float]) -> float:
-        """Constant that normalizes the AUPIMO integral to 0-1 range.
+        """Get normalization factor for AUPIMO integral.

-        It is the maximum possible value from the integral in AUPIMO's definition.
-        It corresponds to assuming a constant function T_i: thresh --> 1.
+        The factor normalizes the integral to [0,1] range. It represents the
+        maximum possible integral value, assuming a constant TPR of 1.

         Args:
-            fpr_bounds: lower and upper bounds of the FPR integration range.
+            fpr_bounds: FPR integration bounds as ``(min, max)``

         Returns:
-            float: the normalization factor (>0).
+            float: Normalization factor (>0)
         """
         return functional.aupimo_normalizing_factor(fpr_bounds)

     def __repr__(self) -> str:
-        """Show the metric name and its integration bounds."""
+        """Get string representation with integration bounds.
+
+        Returns:
+            str: Metric name and FPR bounds
+        """
         lower, upper = self.fpr_bounds
         return f"{self.__class__.__name__}([{lower:.2g}, {upper:.2g}])"

@@ -237,34 +271,34 @@ def __init__(
         return_average: bool = True,
         force: bool = False,
     ) -> None:
-        """Area Under the Per-Image Overlap (PIMO) curve.
+        """Initialize AUPIMO metric.

         Args:
-            num_thresholds: [passed to parent `PIMO`] number of thresholds used to compute the PIMO curve
-            fpr_bounds: lower and upper bounds of the FPR integration range
-            return_average: if True, return the average AUPIMO score; if False, return all the individual AUPIMO scores
-            force: if True, force the computation of the AUPIMO scores even in bad conditions (e.g. few points)
+            num_thresholds: Number of thresholds for curve computation
+            fpr_bounds: FPR integration bounds as ``(min, max)``
+            return_average: If True, return mean score across anomalous images
+            force: If True, compute scores even in suboptimal conditions
         """
         super().__init__(num_thresholds=num_thresholds)

-        # other validations are done in PIMO.__init__()
-
         _validate.is_rate_range(fpr_bounds)
         self.fpr_bounds = fpr_bounds
         self.return_average = return_average
         self.force = force

     def compute(self, force: bool | None = None) -> tuple[PIMOResult, AUPIMOResult]:  # type: ignore[override]
-        """Compute the PIMO curves and their Area Under the curve (AUPIMO) scores.
-
-        Call the functional interface `aupimo_scores()`, which is a wrapper around the numpy code.
+        """Compute PIMO curves and AUPIMO scores.

         Args:
-            force: if given (not None), override the `force` attribute.
+            force: If provided, override instance ``force`` setting

         Returns:
-            tuple[PIMOResult, AUPIMOResult]: PIMO curves and AUPIMO scores dataclass objects.
-                See `PIMOResult` and `AUPIMOResult` for details.
+            tuple: Contains:
+                - PIMOResult: PIMO curve data
+                - AUPIMOResult: AUPIMO score data
+
+        Raises:
+            RuntimeError: If no data has been added via update()
         """
         if self._is_empty:
             msg = "No anomaly maps and masks have been added yet. Please call `update()` first."
@@ -305,6 +339,6 @@ def compute(self, force: bool | None = None) -> tuple[PIMOResult, AUPIMOResult]:


 class AUPIMO(AnomalibMetric, _AUPIMO):  # type: ignore[misc]
-    """Wrapper to add AnomalibMetric functionality to AUPIMO metric."""
+    """Wrapper adding AnomalibMetric functionality to AUPIMO metric."""

     default_fields = ("anomaly_map", "gt_mask")
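The normalizing factor documented above can be checked by hand. Assuming, as these docstrings state, that AUPIMO integrates TPR against the log of the shared FPR and normalizes by the maximum possible integral (constant TPR of 1), the factor is ``log(upper / lower)``. A small sketch verifying the ``2.303`` value from the example and replaying the trapezoidal step on synthetic data:

```python
import math
import torch

fpr_bounds = (1e-5, 1e-4)
factor = math.log(fpr_bounds[1] / fpr_bounds[0])  # integral of 1 d(log fpr)
print(f"{factor:.3f}")  # 2.303, matching the docstring example

# The same integration step on synthetic ascending data: a constant TPR of
# 0.8 should normalize back to ~0.8 for every image.
shared_fpr = torch.logspace(-5, -4, steps=50)
tprs = torch.full((3, 50), 0.8)  # 3 images, 50 integration points
aucs = torch.trapezoid(tprs, x=torch.log(shared_fpr), dim=1)
print((aucs / factor).clip(0, 1))  # tensor([0.8000, 0.8000, 0.8000])
```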
diff --git a/src/anomalib/metrics/pimo/utils.py b/src/anomalib/metrics/pimo/utils.py
index f0cac45657..5f162461d4 100644
--- a/src/anomalib/metrics/pimo/utils.py
+++ b/src/anomalib/metrics/pimo/utils.py
@@ -1,4 +1,16 @@
-"""Torch-oriented interfaces for `utils.py`."""
+"""Utility functions for PIMO metrics.
+
+This module provides utility functions for working with PIMO (Per-Image
+Overlap) metrics in PyTorch.
+
+Example:
+    >>> import torch
+    >>> masks = torch.zeros(3, 32, 32)  # 3 normal images
+    >>> masks[1, 10:20, 10:20] = 1  # Add anomaly to middle image
+    >>> classes = images_classes_from_masks(masks)
+    >>> classes
+    tensor([0, 1, 0])
+"""

 # Original Code
 # https://github.com/jpcbertoldo/aupimo
 #
@@ -15,5 +27,28 @@


 def images_classes_from_masks(masks: torch.Tensor) -> torch.Tensor:
-    """Deduce the image classes from the masks."""
+    """Deduce binary image classes from ground truth masks.
+
+    Determines if each image contains any anomalous pixels (class 1) or is
+    completely normal (class 0).
+
+    Args:
+        masks: Binary ground truth masks of shape ``(N, H, W)`` where:
+            - ``N``: number of images
+            - ``H``: image height
+            - ``W``: image width
+            Values should be 0 (normal) or 1 (anomalous).
+
+    Returns:
+        torch.Tensor: Binary tensor of shape ``(N,)`` containing image-level
+        classes where:
+            - 0: normal image (no anomalous pixels)
+            - 1: anomalous image (contains anomalous pixels)
+
+    Example:
+        >>> masks = torch.zeros(3, 32, 32)  # 3 normal images
+        >>> masks[1, 10:20, 10:20] = 1  # Add anomaly to middle image
+        >>> images_classes_from_masks(masks)
+        tensor([0, 1, 0])
+    """
     return (masks == 1).any(axis=(1, 2)).to(torch.int32)
diff --git a/src/anomalib/metrics/plotting_utils.py b/src/anomalib/metrics/plotting_utils.py
index 0a32dfea29..8c6c7adf34 100644
--- a/src/anomalib/metrics/plotting_utils.py
+++ b/src/anomalib/metrics/plotting_utils.py
@@ -1,4 +1,8 @@
-"""Helper functions to generate ROC-style plots of various metrics."""
+"""Helper functions to generate ROC-style plots of various metrics.
+
+This module provides utility functions for generating ROC-style plots and other
+visualization helpers used by metrics in Anomalib.
+"""

 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -21,27 +25,45 @@ def plot_figure(
     title: str,
     sample_points: int = 1000,
 ) -> tuple[Figure, Axis]:
-    """Generate a simple, ROC-style plot, where x_vals is plotted against y_vals.
+    """Generate a ROC-style plot with x values plotted against y values.

-    Note that a subsampling is applied if > sample_points are present in x/y, as matplotlib plotting draws
-    every single plot which takes very long, especially for high-resolution segmentations.
+    The function creates a matplotlib figure with a single axis showing the curve
+    defined by ``x_vals`` and ``y_vals``. If the number of points exceeds
+    ``sample_points``, the data is subsampled to improve plotting performance.

     Args:
-        x_vals (torch.Tensor): x values to plot
-        y_vals (torch.Tensor): y values to plot
-        auc (torch.Tensor): normalized area under the curve spanned by x_vals, y_vals
-        xlim (tuple[float, float]): displayed range for x-axis
-        ylim (tuple[float, float]): displayed range for y-axis
-        xlabel (str): label of x axis
-        ylabel (str): label of y axis
-        loc (str): string-based legend location, for details see
+        x_vals (torch.Tensor): Values to plot on x-axis.
+        y_vals (torch.Tensor): Values to plot on y-axis.
+        auc (torch.Tensor): Area under curve value to display in legend.
+        xlim (tuple[float, float]): Display range for x-axis as ``(min, max)``.
+        ylim (tuple[float, float]): Display range for y-axis as ``(min, max)``.
+        xlabel (str): Label for x-axis.
+        ylabel (str): Label for y-axis.
+        loc (str): Legend location. See matplotlib documentation for valid values:
             https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.legend.html
-        title (str): title of the plot
-        sample_points (int): number of sampling points to subsample x_vals/y_vals with
-            Defaults to ``1000``.
+        title (str): Title of the plot.
+        sample_points (int, optional): Maximum number of points to plot. Data will
+            be subsampled if it exceeds this value. Defaults to ``1000``.

     Returns:
-        tuple[Figure, Axis]: Figure and the contained Axis
+        tuple[Figure, Axis]: Tuple containing the figure and its main axis.
+
+    Example:
+        >>> import torch
+        >>> x = torch.linspace(0, 1, 100)
+        >>> y = x ** 2
+        >>> auc = torch.tensor(0.5)
+        >>> fig, ax = plot_figure(
+        ...     x_vals=x,
+        ...     y_vals=y,
+        ...     auc=auc,
+        ...     xlim=(0, 1),
+        ...     ylim=(0, 1),
+        ...     xlabel="False Positive Rate",
+        ...     ylabel="True Positive Rate",
+        ...     loc="lower right",
+        ...     title="ROC Curve",
+        ... )
     """
     fig, axis = plt.subplots()
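The subsampling described in the `plot_figure` docstring can be sketched in a few lines. This is a minimal torch-only illustration of the idea (keep at most ``sample_points`` evenly spaced points); the exact scheme used internally may differ:

```python
import torch

x_vals = torch.linspace(0, 1, 100_000)  # e.g. a high-resolution curve
y_vals = x_vals ** 2
sample_points = 1000

if x_vals.numel() > sample_points:
    # Evenly spaced indices over the full range, including both endpoints.
    idx = torch.linspace(0, x_vals.numel() - 1, sample_points).long()
    x_vals, y_vals = x_vals[idx], y_vals[idx]

print(x_vals.shape)  # torch.Size([1000])
```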
+ + Returns: + tuple[Tensor, Tensor, Tensor | None]: Tuple containing: + - Flattened predictions + - Flattened targets + - Adjusted thresholds + """ preds = preds.flatten() target = target.flatten() if ignore_index is not None: @@ -39,11 +86,12 @@ def _binary_precision_recall_curve_format( def update(self, preds: Tensor, target: Tensor) -> None: """Update metric state with new predictions and targets. - Unlike the base class, this accepts raw predictions and targets. + Unlike the base class, this method accepts raw predictions without + applying sigmoid normalization. Args: - preds (Tensor): Predicted probabilities - target (Tensor): Ground truth labels + preds (Tensor): Raw predicted scores or probabilities + target (Tensor): Ground truth binary labels (0 or 1) """ preds, target, _ = BinaryPrecisionRecallCurve._binary_precision_recall_curve_format( preds, diff --git a/src/anomalib/metrics/pro.py b/src/anomalib/metrics/pro.py index d05d8def0d..2d4ff22d01 100644 --- a/src/anomalib/metrics/pro.py +++ b/src/anomalib/metrics/pro.py @@ -1,4 +1,23 @@ -"""Implementation of PRO metric based on TorchMetrics.""" +"""Implementation of PRO metric based on TorchMetrics. + +This module provides the Per-Region Overlap (PRO) metric for evaluating anomaly +segmentation performance. The PRO metric computes the macro average of the +per-region overlap between predicted anomaly masks and ground truth masks. + +Example: + >>> import torch + >>> from anomalib.metrics import PRO + >>> # Create sample predictions and targets + >>> preds = torch.rand(2, 1, 32, 32) # Batch of 2 images + >>> target = torch.zeros(2, 1, 32, 32) + >>> target[0, 0, 10:20, 10:20] = 1 # Add anomalous region + >>> # Initialize metric + >>> pro = PRO() + >>> # Update metric state + >>> pro.update(preds, target) + >>> # Compute PRO score + >>> score = pro.compute() +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -17,35 +36,31 @@ class _PRO(Metric): """Per-Region Overlap (PRO) Score. This metric computes the macro average of the per-region overlap between the - predicted anomaly masks and the ground truth masks. + predicted anomaly masks and the ground truth masks. It first identifies + connected components in the ground truth mask and then computes the overlap + between each component and the predicted mask. Args: - threshold (float): Threshold used to binarize the predictions. + threshold (float, optional): Threshold used to binarize the predictions. Defaults to ``0.5``. - kwargs: Additional arguments to the TorchMetrics base class. + kwargs: Additional arguments passed to the TorchMetrics base class. - Example: - Import the metric from the package: + Attributes: + target (list[torch.Tensor]): List storing ground truth masks from batches + preds (list[torch.Tensor]): List storing predicted masks from batches + threshold (float): Threshold for binarizing predictions + Example: >>> import torch >>> from anomalib.metrics import PRO - - Create random ``preds`` and ``labels`` tensors: - - >>> labels = torch.randint(low=0, high=2, size=(1, 10, 5), dtype=torch.float32) - >>> preds = torch.rand_like(labels) - - Compute the PRO score for labels and preds: - + >>> # Create random predictions and targets + >>> preds = torch.rand(2, 1, 32, 32) # Batch of 2 images + >>> target = torch.zeros(2, 1, 32, 32) + >>> target[0, 0, 10:20, 10:20] = 1 # Add anomalous region + >>> # Initialize and compute PRO score >>> pro = PRO(threshold=0.5) - >>> pro.update(preds, labels) - >>> pro.compute() - tensor(0.5433) - - .. 
note:: - Note that the example above shows random predictions and labels. - Therefore, the PRO score above may not be reproducible. - + >>> pro.update(preds, target) + >>> score = pro.compute() """ target: list[torch.Tensor] @@ -59,47 +74,66 @@ def __init__(self, threshold: float = 0.5, **kwargs) -> None: self.add_state("target", default=[], dist_reduce_fx="cat") def update(self, predictions: torch.Tensor, targets: torch.Tensor) -> None: - """Compute the PRO score for the current batch. + """Update metric state with new predictions and targets. Args: - predictions (torch.Tensor): Predicted anomaly masks (Bx1xHxW) - targets (torch.Tensor): Ground truth anomaly masks (Bx1xHxW) + predictions (torch.Tensor): Predicted anomaly masks of shape + ``(B, 1, H, W)`` where B is batch size + targets (torch.Tensor): Ground truth anomaly masks of shape + ``(B, 1, H, W)`` Example: - To update the metric state for the current batch, use the ``update`` method: - - >>> pro.update(preds, labels) + >>> pro = PRO() + >>> # Assuming preds and target are properly shaped tensors + >>> pro.update(preds, target) """ self.target.append(targets) self.preds.append(predictions) def compute(self) -> torch.Tensor: - """Compute the macro average of the PRO score across all regions in all batches. + """Compute the macro average PRO score across all regions. - Example: - To compute the metric based on the state accumulated from multiple batches, use the ``compute`` method: + Returns: + torch.Tensor: Scalar tensor containing the PRO score averaged across + all regions in all batches - >>> pro.compute() - tensor(0.5433) + Example: + >>> pro = PRO() + >>> # After updating with several batches + >>> score = pro.compute() + >>> print(f"PRO Score: {score:.4f}") """ target = dim_zero_cat(self.target) preds = dim_zero_cat(self.preds) - target = target.unsqueeze(1).type(torch.float) # kornia expects N1HW and FloatTensor format + # kornia expects N1HW format and float dtype + target = target.unsqueeze(1).type(torch.float) comps = connected_components_gpu(target) if target.is_cuda else connected_components_cpu(target) return pro_score(preds, comps, threshold=self.threshold) -def pro_score(predictions: torch.Tensor, comps: torch.Tensor, threshold: float = 0.5) -> torch.Tensor: +def pro_score( + predictions: torch.Tensor, + comps: torch.Tensor, + threshold: float = 0.5, +) -> torch.Tensor: """Calculate the PRO score for a batch of predictions. Args: - predictions (torch.Tensor): Predicted anomaly masks (Bx1xHxW) - comps: (torch.Tensor): Labeled connected components (BxHxW). The components should be labeled from 0 to N - threshold (float): When predictions are passed as float, the threshold is used to binarize the predictions. + predictions (torch.Tensor): Predicted anomaly masks of shape + ``(B, 1, H, W)`` + comps (torch.Tensor): Labeled connected components of shape ``(B, H, W)``. + Components should be labeled from 0 to N + threshold (float, optional): Threshold for binarizing float predictions. + Defaults to ``0.5`` Returns: - torch.Tensor: Scalar value representing the average PRO score for the input batch. 
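To build intuition for why the macro average over regions matters, compare pixel-level recall with a per-region average on a toy mask. A torch-only sketch (illustrative numbers, not taken from the diff): pixel recall is dominated by the large region, while the per-region view, which is what `pro_score` computes via connected components, exposes the missed small region.

```python
import torch

# Ground truth with two anomalous regions of very different sizes.
target = torch.zeros(32, 32)
target[2:4, 2:4] = 1      # small region: 4 pixels
target[10:30, 10:30] = 1  # large region: 400 pixels

# Prediction that finds the large region but misses the small one.
preds = torch.zeros(32, 32)
preds[10:30, 10:30] = 1.0

pixel_recall = (preds * target).sum() / target.sum()

recall_small = preds[2:4, 2:4].sum() / target[2:4, 2:4].sum()          # 0.0
recall_large = preds[10:30, 10:30].sum() / target[10:30, 10:30].sum()  # 1.0

print(f"pixel recall: {pixel_recall:.3f}")                       # ~0.990, hides the miss
print(f"per-region:   {(recall_small + recall_large) / 2:.3f}")  # 0.500
```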
+ torch.Tensor: Scalar tensor containing the average PRO score + + Example: + >>> # Assuming predictions and components are properly shaped tensors + >>> score = pro_score(predictions, components, threshold=0.5) + >>> print(f"PRO Score: {score:.4f}") """ if predictions.dtype == torch.float: predictions = predictions > threshold @@ -113,9 +147,10 @@ def pro_score(predictions: torch.Tensor, comps: torch.Tensor, threshold: float = if n_comps == 1: # only background return torch.Tensor([1.0]) - # Even though ignore_index is set to 0, the final average computed with "macro" - # takes the entire length of the tensor into account. That's why we need to manually - # subtract 1 from the number of components after taking the sum + # Even though ignore_index is set to 0, the final average computed with + # "macro" takes the entire length of the tensor into account. That's why we + # need to manually subtract 1 from the number of components after taking the + # sum recall_tensor = recall( preds.flatten(), comps.flatten(), @@ -128,4 +163,8 @@ def pro_score(predictions: torch.Tensor, comps: torch.Tensor, threshold: float = class PRO(AnomalibMetric, _PRO): # type: ignore[misc] - """Wrapper to add AnomalibMetric functionality to PRO metric.""" + """Wrapper to add AnomalibMetric functionality to PRO metric. + + This class inherits from both ``AnomalibMetric`` and ``_PRO`` to combine + Anomalib's metric functionality with the PRO score computation. + """ diff --git a/src/anomalib/metrics/threshold/__init__.py b/src/anomalib/metrics/threshold/__init__.py index 13d3bf3288..8534a21930 100644 --- a/src/anomalib/metrics/threshold/__init__.py +++ b/src/anomalib/metrics/threshold/__init__.py @@ -1,4 +1,22 @@ -"""Thresholding metrics.""" +"""Thresholding metrics for anomaly detection. + +This module provides various thresholding techniques to convert anomaly scores into +binary predictions. + +Available Thresholds: + - ``Threshold``: Abstract base class for implementing threshold metrics + - ``BaseThreshold``: Deprecated alias of ``Threshold`` kept for backward compatibility + - ``F1AdaptiveThreshold``: Automatically finds optimal threshold by maximizing + F1 score + - ``ManualThreshold``: Allows manual setting of threshold value + +Example: + >>> import torch + >>> from anomalib.metrics.threshold import ManualThreshold + >>> threshold = ManualThreshold(default_value=0.5) + >>> threshold(torch.tensor([0.1, 0.6, 0.3, 0.9]), torch.tensor([0, 1, 0, 1])) + tensor(0.5000) +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/metrics/threshold/base.py b/src/anomalib/metrics/threshold/base.py index eef57789cd..3f100e09bf 100644 --- a/src/anomalib/metrics/threshold/base.py +++ b/src/anomalib/metrics/threshold/base.py @@ -1,4 +1,28 @@ -"""Base class for thresholding metrics.""" +"""Base classes for thresholding metrics. + +This module provides base classes for implementing threshold-based metrics for +anomaly detection. The main classes are: + +- ``Threshold``: Abstract base class for all threshold metrics +- ``BaseThreshold``: Deprecated alias for ``Threshold`` class + +Example: + >>> from anomalib.metrics.threshold import Threshold + >>> class MyThreshold(Threshold): + ... def __init__(self): + ... super().__init__() + ... self.add_state("scores", default=[]) + ... + ... def update(self, scores): + ... self.scores.append(scores) + ... + ... def compute(self): + ...
return torch.tensor(0.5) + >>> threshold = MyThreshold() + >>> threshold.update(torch.tensor([0.1, 0.9])) + >>> threshold.compute() + tensor(0.5) +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -10,43 +34,76 @@ class Threshold(Metric): - """Base class for thresholding metrics. + """Abstract base class for thresholding metrics. + + This class serves as the foundation for implementing threshold-based metrics + in anomaly detection. It inherits from ``torchmetrics.Metric`` and defines + a common interface for threshold computation and state updates. - This class serves as the foundation for all threshold-based metrics in the system. - It inherits from torchmetrics.Metric and provides a common interface for - threshold computation and updates. + Subclasses must implement: + - ``compute()``: Calculate and return the threshold value + - ``update()``: Update internal state with new data - Subclasses should implement the `compute` and `update` methods to define - specific threshold calculation logic. + Example: + >>> class MyThreshold(Threshold): + ... def __init__(self): + ... super().__init__() + ... self.add_state("scores", default=[]) + ... + ... def update(self, scores): + ... self.scores.append(scores) + ... + ... def compute(self): + ... return torch.tensor(0.5) """ def __init__(self, **kwargs) -> None: + """Initialize the threshold metric. + + Args: + **kwargs: Keyword arguments passed to parent ``Metric`` class. + """ super().__init__(**kwargs) def compute(self) -> torch.Tensor: # noqa: PLR6301 - """Compute the threshold. + """Compute the threshold value. Returns: - Value of the optimal threshold. + torch.Tensor: Optimal threshold value. + + Raises: + NotImplementedError: If not implemented by subclass. """ msg = "Subclass of Threshold must implement the compute method" raise NotImplementedError(msg) def update(self, *args, **kwargs) -> None: # noqa: ARG002, PLR6301 - """Update the metric state. + """Update the metric state with new data. Args: - *args: Any positional arguments. - **kwargs: Any keyword arguments. + *args: Positional arguments specific to subclass implementation. + **kwargs: Keyword arguments specific to subclass implementation. + + Raises: + NotImplementedError: If not implemented by subclass. """ msg = "Subclass of Threshold must implement the update method" raise NotImplementedError(msg) class BaseThreshold(Threshold): - """Alias for Threshold class for backward compatibility.""" + """Deprecated alias for ``Threshold`` class. + + .. deprecated:: 0.4.0 + Use ``Threshold`` instead. This class will be removed in a future version. + """ def __init__(self, **kwargs) -> None: + """Initialize with deprecation warning. + + Args: + **kwargs: Keyword arguments passed to parent ``Threshold`` class. + """ warnings.warn( "BaseThreshold is deprecated and will be removed in a future version. Use Threshold instead.", DeprecationWarning, diff --git a/src/anomalib/metrics/threshold/f1_adaptive_threshold.py b/src/anomalib/metrics/threshold/f1_adaptive_threshold.py index cb2ba1cd19..1d1461ddaf 100644 --- a/src/anomalib/metrics/threshold/f1_adaptive_threshold.py +++ b/src/anomalib/metrics/threshold/f1_adaptive_threshold.py @@ -1,4 +1,30 @@ -"""Implementation of F1AdaptiveThreshold based on TorchMetrics.""" +"""F1 adaptive threshold metric for anomaly detection. + +This module provides the ``F1AdaptiveThreshold`` class which automatically finds +the optimal threshold value by maximizing the F1 score on validation data. + +The threshold is computed by: +1. 
Computing precision-recall curve across multiple thresholds +2. Calculating F1 score at each threshold point +3. Selecting threshold that yields maximum F1 score + +Example: + >>> from anomalib.metrics import F1AdaptiveThreshold + >>> import torch + >>> # Create sample data + >>> labels = torch.tensor([0, 0, 0, 1, 1]) # Binary labels + >>> scores = torch.tensor([2.3, 1.6, 2.6, 7.9, 3.3]) # Anomaly scores + >>> # Initialize and compute threshold + >>> threshold = F1AdaptiveThreshold(default_value=0.5) + >>> optimal_threshold = threshold(scores, labels) + >>> optimal_threshold + tensor(3.3000) + +Note: + The validation set should contain both normal and anomalous samples for + reliable threshold computation. A warning is logged if no anomalous samples + are found. +""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -15,30 +41,31 @@ class F1AdaptiveThreshold(BinaryPrecisionRecallCurve, Threshold): - """Anomaly Score Threshold. + """Adaptive threshold that maximizes F1 score. - This class computes/stores the threshold that determines the anomalous label - given anomaly scores. It initially computes the adaptive threshold to find - the optimal f1_score and stores the computed adaptive threshold value. + This class computes and stores the optimal threshold for converting anomaly + scores to binary predictions by maximizing the F1 score on validation data. Args: - default_value: Default value of the threshold. + default_value: Initial threshold value used before computation. Defaults to ``0.5``. + **kwargs: Additional arguments passed to parent classes. - Examples: - To find the best threshold that maximizes the F1 score, we could run the - following: + Attributes: + value (torch.Tensor): Current threshold value. + Example: >>> from anomalib.metrics import F1AdaptiveThreshold >>> import torch - ... - >>> labels = torch.tensor([0, 0, 0, 1, 1]) - >>> preds = torch.tensor([2.3, 1.6, 2.6, 7.9, 3.3]) - ... - >>> adaptive_threshold = F1AdaptiveThreshold(default_value=0.5) - >>> threshold = adaptive_threshold(preds, labels) - >>> threshold - tensor(3.3000) + >>> # Create validation data + >>> labels = torch.tensor([0, 0, 1, 1]) # 2 normal, 2 anomalous + >>> scores = torch.tensor([0.1, 0.2, 0.8, 0.9]) # Anomaly scores + >>> # Initialize threshold + >>> threshold = F1AdaptiveThreshold() + >>> # Compute optimal threshold + >>> optimal_value = threshold(scores, labels) + >>> print(f"Optimal threshold: {optimal_value:.4f}") + Optimal threshold: 0.8000 """ def __init__(self, default_value: float = 0.5, **kwargs) -> None: @@ -48,13 +75,18 @@ def __init__(self, default_value: float = 0.5, **kwargs) -> None: self.value = torch.tensor(default_value) def compute(self) -> torch.Tensor: - """Compute the threshold that yields the optimal F1 score. + """Compute optimal threshold by maximizing F1 score. - Compute the F1 scores while varying the threshold. Store the optimal - threshold as attribute and return the maximum value of the F1 score. + Calculates precision-recall curve and corresponding thresholds, then + finds the threshold that maximizes the F1 score. Returns: - Value of the F1 score at the optimal threshold. + torch.Tensor: Optimal threshold value. + + Warning: + If validation set contains no anomalous samples, the threshold will + default to the maximum anomaly score, which may lead to poor + performance.
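The selection rule in ``compute`` can be reproduced by hand. A short sketch that scans the candidate thresholds (the unique scores) for the four-sample example above and confirms that F1 peaks at 0.8, which is the value the metric stores:

```python
import torch

scores = torch.tensor([0.1, 0.2, 0.8, 0.9])
labels = torch.tensor([0, 0, 1, 1])

# Evaluate F1 at each candidate threshold.
for t in scores.unique():
    preds = (scores >= t).int()
    tp = ((preds == 1) & (labels == 1)).sum().float()
    fp = ((preds == 1) & (labels == 0)).sum().float()
    fn = ((preds == 0) & (labels == 1)).sum().float()
    f1 = 2 * tp / (2 * tp + fp + fn)
    print(f"threshold={t:.2f}  F1={f1:.3f}")
# threshold=0.80 gives F1=1.000, so the adaptive threshold is 0.8.
```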
""" precision: torch.Tensor recall: torch.Tensor @@ -62,9 +94,11 @@ def compute(self) -> torch.Tensor: if not any(1 in batch for batch in self.target): msg = ( - "The validation set does not contain any anomalous images. As a result, the adaptive threshold will " - "take the value of the highest anomaly score observed in the normal validation images, which may lead " - "to poor predictions. For a more reliable adaptive threshold computation, please add some anomalous " + "The validation set does not contain any anomalous images. As a " + "result, the adaptive threshold will take the value of the " + "highest anomaly score observed in the normal validation images, " + "which may lead to poor predictions. For a more reliable " + "adaptive threshold computation, please add some anomalous " "images to the validation set." ) logging.warning(msg) @@ -80,9 +114,9 @@ def compute(self) -> torch.Tensor: return self.value def __repr__(self) -> str: - """Return threshold value within the string representation. + """Return string representation including current threshold value. Returns: - str: String representation of the class. + str: String in format "ClassName(value=X.XX)" """ return f"{super().__repr__()} (value={self.value:.2f})" diff --git a/src/anomalib/metrics/threshold/manual_threshold.py b/src/anomalib/metrics/threshold/manual_threshold.py index e42860db01..d7179be7df 100644 --- a/src/anomalib/metrics/threshold/manual_threshold.py +++ b/src/anomalib/metrics/threshold/manual_threshold.py @@ -1,4 +1,28 @@ -"""Container to hold manual threshold values for image and pixel metrics.""" +"""Manual threshold metric for anomaly detection. + +This module provides the ``ManualThreshold`` class which allows setting a fixed +threshold value for converting anomaly scores to binary predictions. + +The threshold value is manually specified and remains constant regardless of the +input data. + +Example: + >>> from anomalib.metrics import ManualThreshold + >>> import torch + >>> # Create sample data + >>> labels = torch.tensor([0, 0, 1, 1]) # Binary labels + >>> scores = torch.tensor([0.1, 0.3, 0.7, 0.9]) # Anomaly scores + >>> # Initialize with fixed threshold + >>> threshold = ManualThreshold(default_value=0.5) + >>> # Threshold remains constant + >>> threshold(scores, labels) + tensor(0.5000) + +Note: + Unlike adaptive thresholds, this metric does not optimize the threshold value + based on the data. The threshold remains fixed at the manually specified + value. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/__init__.py b/src/anomalib/models/__init__.py index 1e383530d0..ed4b78c05e 100644 --- a/src/anomalib/models/__init__.py +++ b/src/anomalib/models/__init__.py @@ -1,4 +1,45 @@ -"""Load Anomaly Model.""" +"""Anomaly detection models. + +This module contains all the anomaly detection models available in anomalib. 
+ +Example: + >>> from anomalib.data import MVTec + >>> from anomalib.models import Padim + >>> from anomalib.engine import Engine + + >>> # Initialize model and datamodule + >>> datamodule = MVTec() + >>> model = Padim() + + >>> # Train using the engine + >>> engine = Engine() + >>> engine.fit(model=model, datamodule=datamodule) + +The module provides both image and video anomaly detection models: + +Image Models: + - CFA (:class:`anomalib.models.image.Cfa`) + - Cflow (:class:`anomalib.models.image.Cflow`) + - CSFlow (:class:`anomalib.models.image.Csflow`) + - DFKDE (:class:`anomalib.models.image.Dfkde`) + - DFM (:class:`anomalib.models.image.Dfm`) + - DRAEM (:class:`anomalib.models.image.Draem`) + - DSR (:class:`anomalib.models.image.Dsr`) + - EfficientAd (:class:`anomalib.models.image.EfficientAd`) + - FastFlow (:class:`anomalib.models.image.Fastflow`) + - FRE (:class:`anomalib.models.image.Fre`) + - GANomaly (:class:`anomalib.models.image.Ganomaly`) + - PaDiM (:class:`anomalib.models.image.Padim`) + - PatchCore (:class:`anomalib.models.image.Patchcore`) + - Reverse Distillation (:class:`anomalib.models.image.ReverseDistillation`) + - STFPM (:class:`anomalib.models.image.Stfpm`) + - UFlow (:class:`anomalib.models.image.Uflow`) + - VLM-AD (:class:`anomalib.models.image.VlmAd`) + - WinCLIP (:class:`anomalib.models.image.WinClip`) + +Video Models: + - AI-VAD (:class:`anomalib.models.video.AiVad`) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -65,33 +106,51 @@ class UnknownModelError(ModuleNotFoundError): def convert_snake_to_pascal_case(snake_case: str) -> str: - """Convert snake_case to PascalCase. + """Convert snake_case string to PascalCase. + + This function takes a string in snake_case format (words separated by underscores) + and converts it to PascalCase format (each word capitalized and concatenated). Args: - snake_case (str): Input string in snake_case + snake_case (str): Input string in snake_case format (e.g. ``"efficient_ad"``) Returns: - str: Output string in PascalCase + str: Output string in PascalCase format (e.g. ``"EfficientAd"``) Examples: - >>> _convert_snake_to_pascal_case("efficient_ad") - EfficientAd - - >>> _convert_snake_to_pascal_case("patchcore") - Patchcore + >>> convert_snake_to_pascal_case("efficient_ad") + 'EfficientAd' + >>> convert_snake_to_pascal_case("patchcore") + 'Patchcore' + >>> convert_snake_to_pascal_case("reverse_distillation") + 'ReverseDistillation' """ return "".join(word.capitalize() for word in snake_case.split("_")) def get_available_models() -> set[str]: - """Get set of available models. + """Get set of available anomaly detection models. + + Returns a set of model names in snake_case format that are available in the + anomalib library. This includes both image and video anomaly detection models. Returns: - set[str]: List of available models. + set[str]: Set of available model names in snake_case format (e.g. + ``'efficient_ad'``, ``'padim'``, etc.) Example: - >>> get_available_models() - ['ai_vad', 'cfa', 'cflow', 'csflow', 'dfkde', 'dfm', 'draem', 'efficient_ad', 'fastflow', ...] 
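The registry lookups described here rely on round-tripping between PascalCase class names and snake_case keys. A standalone sketch of both directions; `pascal_to_snake` below is a plausible stand-in for anomalib's `convert_to_snake_case`, not the library's exact implementation:

```python
import re


def snake_to_pascal(name: str) -> str:
    """Mirror of convert_snake_to_pascal_case above."""
    return "".join(word.capitalize() for word in name.split("_"))


def pascal_to_snake(name: str) -> str:
    """Plausible stand-in for anomalib's convert_to_snake_case (assumption)."""
    # Insert an underscore before each capital that follows a lowercase letter or digit.
    return re.sub(r"(?<=[a-z0-9])([A-Z])", r"_\1", name).lower()


assert snake_to_pascal("efficient_ad") == "EfficientAd"
assert pascal_to_snake("EfficientAd") == "efficient_ad"
assert pascal_to_snake("Patchcore") == "patchcore"
```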
+ Get all available models: + + >>> from anomalib.models import get_available_models + >>> models = get_available_models() + >>> print(sorted(list(models))) # doctest: +NORMALIZE_WHITESPACE + ['ai_vad', 'cfa', 'cflow', 'csflow', 'dfkde', 'dfm', 'draem', 'dsr', + 'efficient_ad', 'fastflow', 'fre', 'ganomaly', 'padim', 'patchcore', + 'reverse_distillation', 'stfpm', 'uflow', 'vlm_ad', 'winclip'] + + Note: + The returned model names can be used with :func:`get_model` to instantiate + the corresponding model class. """ return { convert_to_snake_case(cls.__name__) @@ -101,16 +160,32 @@ def _get_model_class_by_name(name: str) -> type[AnomalibModule]: - """Retrieves an anomaly model based on its name. + """Retrieve an anomaly model class based on its name. + + This internal function takes a model name and returns the corresponding model class. + The name matching is case-insensitive and supports both snake_case and PascalCase + formats. Args: - name (str): The name of the model to retrieve. The name is case insensitive. + name (str): Name of the model to retrieve. Can be in snake_case (e.g. + ``"efficient_ad"``) or PascalCase (e.g. ``"EfficientAd"``). The name is + case-insensitive. Raises: - UnknownModelError: If the model is not found. + UnknownModelError: If no model is found matching the provided name. The error + message includes the list of available models. Returns: - type[AnomalibModule]: Anomaly Model + type[AnomalibModule]: Model class that inherits from ``AnomalibModule``. + + Examples: + >>> from anomalib.models import _get_model_class_by_name + >>> model_class = _get_model_class_by_name("padim") + >>> model_class.__name__ + 'Padim' + >>> model_class = _get_model_class_by_name("efficient_ad") + >>> model_class.__name__ + 'EfficientAd' """ logger.info("Loading the model.") model_class: type[AnomalibModule] | None = None @@ -127,27 +202,53 @@ def _get_model_class_by_name(name: str) -> type[AnomalibModule]: def get_model(model: DictConfig | str | dict | Namespace, *args, **kwdargs) -> AnomalibModule: - """Get Anomaly Model. + """Get an anomaly detection model instance. + + This function instantiates an anomaly detection model based on the provided + configuration or model name. It supports multiple ways of model specification + including string names, dictionaries and OmegaConf configurations. Args: - model (DictConfig | str): Can either be a configuration or a string. - *args: Variable length argument list for model init. - **kwdargs: Arbitrary keyword arguments for model init. + model (DictConfig | str | dict | Namespace): Model specification that can be: + - A string with model name (e.g. ``"padim"``, ``"efficient_ad"``) + - A dictionary with ``class_path`` and optional ``init_args`` + - An OmegaConf DictConfig with similar structure as dict + - A Namespace object with similar structure as dict + *args: Variable length argument list passed to model initialization. + **kwdargs: Arbitrary keyword arguments passed to model initialization. - Examples: - >>> get_model("Padim") - >>> get_model("efficient_ad") - >>> get_model("Patchcore", input_size=(100, 100)) - >>> get_model({"class_path": "Padim"}) - >>> get_model({"class_path": "Patchcore"}, input_size=(100, 100)) - >>> get_model({"class_path": "Padim", "init_args": {"input_size": (100, 100)}}) - >>> get_model({"class_path": "anomalib.models.Padim", "init_args": {"input_size": (100, 100)}}}) + Returns: + AnomalibModule: Instantiated anomaly detection model.
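To make the dispatch concrete, here is one way a ``class_path``/``init_args`` mapping can be resolved to an instance. This is a simplified illustration of the pattern; `instantiate_from_config` is a hypothetical helper name, not anomalib's actual code:

```python
import importlib
from typing import Any


def instantiate_from_config(config: dict[str, Any]) -> Any:
    """Resolve a {"class_path": ..., "init_args": ...} mapping to an instance."""
    class_path: str = config["class_path"]
    init_args: dict[str, Any] = config.get("init_args", {})
    if "." in class_path:
        # Fully qualified path, e.g. "anomalib.models.Padim".
        module_name, class_name = class_path.rsplit(".", 1)
        cls = getattr(importlib.import_module(module_name), class_name)
    else:
        # Bare class name: look it up in the models namespace.
        cls = getattr(importlib.import_module("anomalib.models"), class_path)
    return cls(**init_args)


model = instantiate_from_config({"class_path": "Padim"})
```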
Raises: - TypeError: If unsupported type is passed. + TypeError: If ``model`` argument is of unsupported type. + UnknownModelError: If specified model class cannot be found. - Returns: - AnomalibModule: Anomaly Model + Examples: + Get model by name: + + >>> model = get_model("padim") + >>> model = get_model("efficient_ad") + >>> model = get_model("patchcore", input_size=(100, 100)) + + Get model using dictionary config: + + >>> model = get_model({"class_path": "Padim"}) + >>> model = get_model( + ... {"class_path": "Patchcore"}, + ... input_size=(100, 100) + ... ) + >>> model = get_model({ + ... "class_path": "Padim", + ... "init_args": {"input_size": (100, 100)} + ... }) + + Get model using fully qualified path: + + >>> model = get_model({ + ... "class_path": "anomalib.models.Padim", + ... "init_args": {"input_size": (100, 100)} + ... }) """ _model: AnomalibModule if isinstance(model, str): diff --git a/src/anomalib/models/components/__init__.py b/src/anomalib/models/components/__init__.py index 762345a93d..2bc64990bd 100644 --- a/src/anomalib/models/components/__init__.py +++ b/src/anomalib/models/components/__init__.py @@ -1,4 +1,38 @@ -"""Components used within the models.""" +"""Components used within the anomaly detection models. + +This module provides various components that are used across different anomaly +detection models in the library. + +Components: + Base Components: + - ``AnomalibModule``: Base module for all anomaly detection models + - ``BufferListMixin``: Mixin for managing lists of buffers + - ``DynamicBufferMixin``: Mixin for dynamic buffer management + - ``MemoryBankMixin``: Mixin for memory bank functionality + + Dimensionality Reduction: + - ``PCA``: Principal Component Analysis + - ``SparseRandomProjection``: Random projection with sparse matrices + + Feature Extraction: + - ``TimmFeatureExtractor``: Feature extractor using timm models + - ``TorchFXFeatureExtractor``: Feature extractor using TorchFX + + Image Processing: + - ``GaussianBlur2d``: 2D Gaussian blur filter + + Sampling: + - ``KCenterGreedy``: K-center greedy sampling algorithm + + Statistical Methods: + - ``GaussianKDE``: Gaussian kernel density estimation + - ``MultiVariateGaussian``: Multivariate Gaussian distribution + +Example: + >>> from anomalib.models.components import GaussianKDE + >>> kde = GaussianKDE() + >>> # Use components in anomaly detection models +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/components/base/__init__.py b/src/anomalib/models/components/base/__init__.py index 250eec5045..a010551fe1 100644 --- a/src/anomalib/models/components/base/__init__.py +++ b/src/anomalib/models/components/base/__init__.py @@ -1,4 +1,21 @@ -"""Base classes for all anomaly components.""" +"""Base classes for all anomaly components. + +This module provides the foundational classes used across anomalib's model +components. These include: + +- ``AnomalibModule``: Base class for all anomaly detection modules +- ``BufferListMixin``: Mixin for managing lists of model buffers +- ``DynamicBufferMixin``: Mixin for handling dynamic model buffers +- ``MemoryBankMixin``: Mixin for models requiring feature memory banks + +Example: + >>> from anomalib.models.components.base import AnomalibModule + >>> class MyAnomalyModel(AnomalibModule): + ... def __init__(self): + ... super().__init__() + ... def forward(self, x): + ... 
return x +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/components/base/anomalib_module.py b/src/anomalib/models/components/base/anomalib_module.py index 3fd5557032..b5fc6a57cf 100644 --- a/src/anomalib/models/components/base/anomalib_module.py +++ b/src/anomalib/models/components/base/anomalib_module.py @@ -1,4 +1,41 @@ -"""Base Anomaly Module for Training Task.""" +"""Base Anomaly Module for Training Task. + +This module provides the foundational class for all anomaly detection models in +anomalib. The ``AnomalibModule`` class extends PyTorch Lightning's +``LightningModule`` and provides common functionality for training, validation, +testing and inference of anomaly detection models. + +The class handles: +- Model initialization and setup +- Pre-processing of input data +- Post-processing of model outputs +- Evaluation metrics computation +- Visualization of results +- Model export capabilities + +Example: + Create a custom anomaly detection model: + + >>> from anomalib.models.components.base import AnomalibModule + >>> class MyModel(AnomalibModule): + ... def __init__(self, **kwargs): + ... super().__init__(**kwargs) + ... self.model = torch.nn.Linear(10, 1) + ... + ... def training_step(self, batch, batch_idx): + ... return self.model(batch) + + Create model with custom components: + + >>> from anomalib.pre_processing import PreProcessor + >>> from anomalib.post_processing import PostProcessor + >>> model = MyModel( + ... pre_processor=PreProcessor(), + ... post_processor=PostProcessor(), + ... evaluator=True, + ... visualizer=True + ... ) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -33,9 +70,56 @@ class AnomalibModule(ExportMixin, pl.LightningModule, ABC): - """AnomalibModule to train, validate, predict and test images. - - Acts as a base class for all the Anomaly Modules in the library. + """Base class for all anomaly detection modules in anomalib. + + This class provides the core functionality for training, validation, testing + and inference of anomaly detection models. It handles data pre-processing, + post-processing, evaluation and visualization. + + Args: + pre_processor (PreProcessor | bool, optional): Pre-processor instance or + flag to use default. Defaults to ``True``. + post_processor (PostProcessor | bool, optional): Post-processor instance + or flag to use default. Defaults to ``True``. + evaluator (Evaluator | bool, optional): Evaluator instance or flag to use + default. Defaults to ``True``. + visualizer (Visualizer | bool, optional): Visualizer instance or flag to + use default. Defaults to ``True``. + + Attributes: + model (nn.Module): PyTorch model to be trained + loss (nn.Module): Loss function for training + callbacks (list[Callback]): List of callbacks + pre_processor (PreProcessor | None): Component for pre-processing inputs + post_processor (PostProcessor | None): Component for post-processing + outputs + evaluator (Evaluator | None): Component for computing metrics + visualizer (Visualizer | None): Component for visualization + + Example: + Create a model with default components: + + >>> model = AnomalibModule() + + Create a model with custom components: + + >>> from anomalib.pre_processing import PreProcessor + >>> from anomalib.post_processing import PostProcessor + >>> model = AnomalibModule( + ... pre_processor=PreProcessor(), + ... post_processor=PostProcessor(), + ... evaluator=True, + ... visualizer=True + ...
) + + Disable certain components: + + >>> model = AnomalibModule( + ... pre_processor=False, + ... post_processor=False, + ... evaluator=False, + ... visualizer=False + ... ) """ def __init__( @@ -63,31 +147,52 @@ def __init__( @property def name(self) -> str: - """Name of the model.""" + """Get name of the model. + + Returns: + str: Name of the model class + """ return self.__class__.__name__ def setup(self, stage: str | None = None) -> None: - """Calls the _setup method to build the model if the model is not already built.""" + """Set up the model if not already done. + + This method ensures the model is built by calling ``_setup()`` if needed. + + Args: + stage (str | None, optional): Current stage of training. + Defaults to ``None``. + """ if getattr(self, "model", None) is None or not self._is_setup: self._setup() if isinstance(stage, TrainerFn): - # only set the flag if the stage is a TrainerFn, which means the setup has been called from a trainer + # only set the flag if the stage is a TrainerFn, which means the + # setup has been called from a trainer self._is_setup = True def _setup(self) -> None: - """The _setup method is used to build the torch model dynamically or adjust something about them. + """Set up the model architecture. + + This method should be overridden by subclasses to build their model + architecture. It is called by ``setup()`` when the model needs to be + initialized. - The model implementer may override this method to build the model. This is useful when the model cannot be set - in the `__init__` method because it requires some information or data that is not available at the time of - initialization. + This is useful when the model cannot be fully initialized in ``__init__`` + because it requires data-dependent parameters. """ def configure_callbacks(self) -> Sequence[Callback] | Callback: - """Configure default callbacks for AnomalibModule. + """Configure callbacks for the model. Returns: - List of callbacks that includes the pre-processor, post-processor, evaluator, - and visualizer if they are available and inherit from Callback. + Sequence[Callback] | Callback: List of callbacks including components + that inherit from ``Callback`` + + Example: + >>> model = AnomalibModule() + >>> callbacks = model.configure_callbacks() + >>> isinstance(callbacks, (Sequence, Callback)) + True """ callbacks: list[Callback] = [] callbacks.extend( @@ -98,15 +203,27 @@ def configure_callbacks(self) -> Sequence[Callback] | Callback: return callbacks def forward(self, batch: torch.Tensor, *args, **kwargs) -> InferenceBatch: - """Perform the forward-pass by passing input tensor to the module. + """Perform forward pass through the model pipeline. + + The input batch is passed through: + 1. Pre-processor (if configured) + 2. Model + 3. Post-processor (if configured) Args: - batch (dict[str, str | torch.Tensor]): Input batch. - *args: Arguments. - **kwargs: Keyword arguments. + batch (torch.Tensor): Input batch + *args: Additional positional arguments (unused) + **kwargs: Additional keyword arguments (unused) Returns: - Tensor: Output tensor from the model. + InferenceBatch: Processed batch with model predictions + + Example: + >>> model = AnomalibModule() + >>> batch = torch.randn(1, 3, 256, 256) + >>> output = model(batch) + >>> isinstance(output, InferenceBatch) + True """ del args, kwargs # These variables are not used. 
batch = self.pre_processor(batch) if self.pre_processor else batch @@ -119,35 +236,38 @@ def predict_step( batch_idx: int, dataloader_idx: int = 0, ) -> STEP_OUTPUT: - """Step function called during :meth:`~lightning.pytorch.trainer.Trainer.predict`. + """Perform prediction step. - By default, it calls :meth:`~lightning.pytorch.core.lightning.LightningModule.forward`. - Override to add any processing logic. + This method is called during the predict stage of training. By default, + it calls the validation step. Args: - batch (Any): Current batch - batch_idx (int): Index of current batch - dataloader_idx (int): Index of the current dataloader + batch (Batch): Input batch + batch_idx (int): Index of the batch + dataloader_idx (int, optional): Index of the dataloader. + Defaults to ``0``. - Return: - Predicted output + Returns: + STEP_OUTPUT: Model predictions """ del dataloader_idx # These variables are not used. return self.validation_step(batch, batch_idx) def test_step(self, batch: Batch, batch_idx: int, *args, **kwargs) -> STEP_OUTPUT: - """Calls validation_step for anomaly map/score calculation. + """Perform test step. + + This method is called during the test stage of training. By default, + it calls the predict step. Args: - batch (Batch): Input batch - batch_idx (int): Batch index - args: Arguments. - kwargs: Keyword arguments. + batch (Batch): Input batch + batch_idx (int): Index of the batch + *args: Additional positional arguments (unused) + **kwargs: Additional keyword arguments (unused) Returns: - Dictionary containing images, features, true labels and masks. - These are required in `validation_epoch_end` for feature concatenation. + STEP_OUTPUT: Model predictions """ del args, kwargs # These variables are not used. @@ -156,26 +276,43 @@ def test_step(self, batch: Batch, batch_idx: int, *args, **kwargs) -> STEP_OUTPU @property @abstractmethod def trainer_arguments(self) -> dict[str, Any]: - """Arguments used to override the trainer parameters so as to train the model correctly.""" + """Get trainer arguments specific to this model. + + Returns: + dict[str, Any]: Dictionary of trainer arguments + + Raises: + NotImplementedError: If not implemented by subclass + """ raise NotImplementedError @property @abstractmethod def learning_type(self) -> LearningType: - """Learning type of the model.""" + """Get learning type of the model. + + Returns: + LearningType: Type of learning (e.g. one-class, supervised) + + Raises: + NotImplementedError: If not implemented by subclass + """ raise NotImplementedError def _resolve_pre_processor(self, pre_processor: PreProcessor | bool) -> PreProcessor | None: - """Resolve and validate which pre-processor to use.. + """Resolve and validate the pre-processor configuration. 
Args: - pre_processor: Pre-processor configuration - - True -> use default pre-processor - - False -> no pre-processor - - PreProcessor -> use the provided pre-processor + pre_processor (PreProcessor | bool): Pre-processor configuration + - ``True`` -> use default pre-processor + - ``False`` -> no pre-processor + - ``PreProcessor`` -> use provided pre-processor Returns: - Configured pre-processor + PreProcessor | None: Configured pre-processor + + Raises: + TypeError: If pre_processor is invalid type """ if isinstance(pre_processor, PreProcessor): return pre_processor @@ -186,39 +323,22 @@ def _resolve_pre_processor(self, pre_processor: PreProcessor | bool) -> PreProce @classmethod def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> PreProcessor: - """Configure the pre-processor. + """Configure the default pre-processor. - The default pre-processor resizes images to 256x256 and normalizes using ImageNet statistics. - Individual models can override this method to provide custom transforms and pre-processing pipelines. + The default pre-processor resizes images and normalizes using ImageNet + statistics. Args: - image_size (tuple[int, int] | None, optional): Target size for resizing images. - If None, defaults to (256, 256). Defaults to None. - **kwargs (Any): Additional keyword arguments (unused). + image_size (tuple[int, int] | None, optional): Target size for + resizing. Defaults to ``(256, 256)``. Returns: - PreProcessor: Configured pre-processor instance. - - Examples: - Get default pre-processor with custom image size: - - >>> preprocessor = AnomalibModule.configure_pre_processor(image_size=(512, 512)) - - Create model with custom pre-processor: + PreProcessor: Configured pre-processor - >>> from torchvision.transforms.v2 import RandomHorizontalFlip - >>> custom_transform = Compose([ - ... Resize((256, 256), antialias=True), - ... CenterCrop((224, 224)), - ... RandomHorizontalFlip(p=0.5), - ... Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - ... ]) - >>> preprocessor.train_transform = custom_transform - >>> model = PatchCore(pre_processor=preprocessor) - - Disable pre-processing: - - >>> model = PatchCore(pre_processor=False) + Example: + >>> preprocessor = AnomalibModule.configure_pre_processor((512, 512)) + >>> isinstance(preprocessor, PreProcessor) + True """ image_size = image_size or (256, 256) return PreProcessor( @@ -229,16 +349,19 @@ def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> P ) def _resolve_post_processor(self, post_processor: PostProcessor | bool) -> PostProcessor | None: - """Resolve and validate which post-processor to use. + """Resolve and validate the post-processor configuration. 
Args: - post_processor: Post-processor configuration - - True -> use default post-processor - - False -> no post-processor - - PostProcessor -> use the provided post-processor + post_processor (PostProcessor | bool): Post-processor configuration + - ``True`` -> use default post-processor + - ``False`` -> no post-processor + - ``PostProcessor`` -> use provided post-processor Returns: - Configured post-processor + PostProcessor | None: Configured post-processor + + Raises: + TypeError: If post_processor is invalid type """ if isinstance(post_processor, PostProcessor): return post_processor @@ -248,41 +371,43 @@ def _resolve_post_processor(self, post_processor: PostProcessor | bool) -> PostP raise TypeError(msg) def configure_post_processor(self) -> PostProcessor | None: - """Configure the default post-processor based on the learning type. + """Configure the default post-processor. Returns: - PostProcessor: Configured post-processor instance. + PostProcessor | None: Configured post-processor based on learning type Raises: - NotImplementedError: If no default post-processor is available for the model's learning type. - - Examples: - Get default post-processor: - - >>> post_processor = AnomalibModule.configure_post_processor() - - Create model with custom post-processor: - - >>> custom_post_processor = CustomPostProcessor() - >>> model = PatchCore(post_processor=custom_post_processor) + NotImplementedError: If no default post-processor exists for the + model's learning type - Disable post-processing: - - >>> model = PatchCore(post_processor=False) + Example: + >>> model = AnomalibModule() + >>> post_processor = model.configure_post_processor() + >>> isinstance(post_processor, PostProcessor) + True """ if self.learning_type == LearningType.ONE_CLASS: return OneClassPostProcessor() msg = ( - f"No default post-processor available for model with learning type {self.learning_type}. " - "Please override the configure_post_processor method in the model implementation." + f"No default post-processor available for model with learning type " + f"{self.learning_type}. Please override configure_post_processor." ) raise NotImplementedError(msg) def _resolve_evaluator(self, evaluator: Evaluator | bool) -> Evaluator | None: - """Resolve the evaluator to be used in the model. + """Resolve and validate the evaluator configuration. + + Args: + evaluator (Evaluator | bool): Evaluator configuration + - ``True`` -> use default evaluator + - ``False`` -> no evaluator + - ``Evaluator`` -> use provided evaluator - If the evaluator is set to True, the default evaluator will be used. If the evaluator is set to False, no - evaluator will be used. If the evaluator is an instance of Evaluator, it will be used as the evaluator. + Returns: + Evaluator | None: Configured evaluator + + Raises: + TypeError: If evaluator is invalid type """ if isinstance(evaluator, Evaluator): return evaluator @@ -293,9 +418,18 @@ def _resolve_evaluator(self, evaluator: Evaluator | bool) -> Evaluator | None: @staticmethod def configure_evaluator() -> Evaluator: - """Default evaluator. + """Configure the default evaluator. - Override in subclass for model-specific evaluator behaviour. + The default evaluator includes metrics for both image-level and + pixel-level evaluation. 
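All of the ``_resolve_*`` helpers follow the same three-way convention: an instance is passed through, ``True`` builds the default, ``False`` disables the component. A generic distillation of that pattern for illustration only; anomalib implements each resolver separately rather than via a shared helper like this:

```python
from collections.abc import Callable
from typing import TypeVar

T = TypeVar("T")


def resolve_component(
    value: T | bool,
    component_type: type[T],
    default_factory: Callable[[], T],
) -> T | None:
    """Instance -> passed through, True -> default instance, False -> None."""
    if isinstance(value, component_type):
        return value
    if isinstance(value, bool):
        return default_factory() if value else None
    msg = f"Expected {component_type.__name__} or bool, got {type(value)}"
    raise TypeError(msg)
```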
+ + Returns: + Evaluator: Configured evaluator with default metrics + + Example: + >>> evaluator = AnomalibModule.configure_evaluator() + >>> isinstance(evaluator, Evaluator) + True """ image_auroc = AUROC(fields=["pred_score", "gt_label"], prefix="image_") image_f1score = F1Score(fields=["pred_label", "gt_label"], prefix="image_") @@ -305,16 +439,19 @@ def configure_evaluator() -> Evaluator: return Evaluator(test_metrics=test_metrics) def _resolve_visualizer(self, visualizer: Visualizer | bool) -> Visualizer | None: - """Resolve and validate which visualizer to use. + """Resolve and validate the visualizer configuration. Args: - visualizer: Visualizer configuration - - True -> use default visualizer - - False -> no visualizer - - Visualizer -> use the provided visualizer + visualizer (Visualizer | bool): Visualizer configuration + - ``True`` -> use default visualizer + - ``False`` -> no visualizer + - ``Visualizer`` -> use provided visualizer Returns: - Configured visualizer + Visualizer | None: Configured visualizer + + Raises: + TypeError: If visualizer is invalid type """ if isinstance(visualizer, Visualizer): return visualizer @@ -327,46 +464,28 @@ def _resolve_visualizer(self, visualizer: Visualizer | bool) -> Visualizer | Non def configure_visualizer(cls) -> ImageVisualizer: """Configure the default visualizer. - By default, this method returns an ImageVisualizer instance, which is suitable for - visualizing image-based anomaly detection results. However, the visualizer can be - customized based on your needs - for example, using VideoVisualizer for video data - or implementing a custom visualizer for specific visualization requirements. - Returns: - Visualizer: Configured visualizer instance (ImageVisualizer by default). - - Examples: - Get default ImageVisualizer: + ImageVisualizer: Default image visualizer instance + Example: >>> visualizer = AnomalibModule.configure_visualizer() - - Create model with VideoVisualizer: - - >>> from custom_module import VideoVisualizer - >>> video_visualizer = VideoVisualizer() - >>> model = PatchCore(visualizer=video_visualizer) - - Create model with custom visualizer: - - >>> class CustomVisualizer(Visualizer): - ... def __init__(self): - ... super().__init__() - ... # Custom visualization logic - >>> custom_visualizer = CustomVisualizer() - >>> model = PatchCore(visualizer=custom_visualizer) - - Disable visualization: - - >>> model = PatchCore(visualizer=False) + >>> isinstance(visualizer, ImageVisualizer) + True """ return ImageVisualizer() @property def input_size(self) -> tuple[int, int] | None: - """Return the effective input size of the model. + """Get the effective input size of the model. - The effective input size is the size of the input tensor after the transform has been applied. If the transform - is not set, or if the transform does not change the shape of the input tensor, this method will return None. + Returns: + tuple[int, int] | None: Height and width of model input after + pre-processing, or ``None`` if size cannot be determined + + Example: + >>> model = AnomalibModule() + >>> model.input_size # Returns size after pre-processing + (256, 256) """ transform = self.pre_processor.predict_transform if self.pre_processor else None if transform is None: @@ -381,27 +500,29 @@ def from_config( config_path: str | Path, **kwargs, ) -> "AnomalibModule": - """Create a model instance from the configuration. + """Create a model instance from a configuration file. Args: - config_path (str | Path): Path to the model configuration file. 
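Since the default evaluator only wires up image-level AUROC and F1, a common customization is to add pixel-level metrics. A hedged sketch: the ``AUROC``/``F1Score``/``Evaluator`` signatures mirror the diff above, but the ``anomaly_map``/``gt_mask`` field names and the ``Evaluator`` import path are assumptions about anomalib's batch schema rather than facts from this diff.

```python
from anomalib.metrics import AUROC, Evaluator, F1Score
from anomalib.models import Padim

# Image-level metrics, mirroring the defaults shown above.
image_auroc = AUROC(fields=["pred_score", "gt_label"], prefix="image_")
image_f1 = F1Score(fields=["pred_label", "gt_label"], prefix="image_")

# Extra pixel-level metric (field names assumed, not confirmed by the diff).
pixel_auroc = AUROC(fields=["anomaly_map", "gt_mask"], prefix="pixel_")

evaluator = Evaluator(test_metrics=[image_auroc, image_f1, pixel_auroc])
model = Padim(evaluator=evaluator)
```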
- **kwargs (dict): Additional keyword arguments. + config_path (str | Path): Path to the model configuration file + **kwargs: Additional arguments to override config values Returns: - AnomalibModule: model instance. - - Example: - The following example shows how to get model from patchcore.yaml: - - .. code-block:: python - >>> model_config = "configs/model/patchcore.yaml" - >>> model = AnomalibModule.from_config(config_path=model_config) + AnomalibModule: Instantiated model - The following example shows overriding the configuration file with additional keyword arguments: + Raises: + FileNotFoundError: If config file does not exist + ValueError: If instantiated model is not AnomalibModule - .. code-block:: python - >>> override_kwargs = {"model.pre_trained": False} - >>> model = AnomalibModule.from_config(config_path=model_config, **override_kwargs) + Example: + >>> model = AnomalibModule.from_config("configs/model/patchcore.yaml") + >>> isinstance(model, AnomalibModule) + True + + Override config values: + >>> model = AnomalibModule.from_config( + ... "configs/model/patchcore.yaml", + ... model__backbone="resnet18" + ... ) """ from jsonargparse import ActionConfigFile, ArgumentParser from lightning.pytorch import Trainer diff --git a/src/anomalib/models/components/base/buffer_list.py b/src/anomalib/models/components/base/buffer_list.py index f236c2e361..212880d481 100644 --- a/src/anomalib/models/components/base/buffer_list.py +++ b/src/anomalib/models/components/base/buffer_list.py @@ -1,4 +1,45 @@ -"""Buffer List Mixin.""" +"""Buffer List Mixin. + +This mixin allows registering a list of tensors as buffers in a PyTorch module. + +Example: + >>> # Create a module that uses the buffer list mixin + >>> class MyModule(BufferListMixin, nn.Module): + ... def __init__(self): + ... super().__init__() + ... tensor_list = [torch.ones(3) * i for i in range(3)] + ... self.register_buffer_list("my_buffer_list", tensor_list) + ... + >>> # Initialize the module + >>> module = MyModule() + ... + >>> # The buffer list can be accessed as a regular attribute + >>> module.my_buffer_list + [ + tensor([0., 0., 0.]), + tensor([1., 1., 1.]), + tensor([2., 2., 2.]) + ] + ... + >>> # Update the buffer list with new tensors + >>> new_tensor_list = [torch.ones(3) * i + 10 for i in range(3)] + >>> module.register_buffer_list("my_buffer_list", new_tensor_list) + >>> module.my_buffer_list + [ + tensor([10., 10., 10.]), + tensor([11., 11., 11.]), + tensor([12., 12., 12.]) + ] + ... + >>> # Move to GPU - device placement is handled automatically + >>> module.cuda() + >>> module.my_buffer_list + [ + tensor([10., 10., 10.], device='cuda:0'), + tensor([11., 11., 11.], device='cuda:0'), + tensor([12., 12., 12.], device='cuda:0') + ] +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -8,55 +49,37 @@ class BufferListMixin(nn.Module): - """Buffer List Mixin. - - This mixin is used to allow registering a list of tensors as buffers in a pytorch module. - - Example: - >>> class MyModule(BufferListMixin, nn.Module): - ... def __init__(self): - ... super().__init__() - ... tensor_list = [torch.ones(3) * i for i in range(3)] - ... 
self.register_buffer_list("my_buffer_list", tensor_list) - >>> module = MyModule() - >>> # The buffer list can be accessed as a regular attribute - >>> module.my_buffer_list - [ - tensor([0., 0., 0.]), - tensor([1., 1., 1.]), - tensor([2., 2., 2.]) - ] - >>> # We can update the buffer list at any time - >>> new_tensor_list = [torch.ones(3) * i + 10 for i in range(3)] - >>> module.register_buffer_list("my_buffer_list", new_tensor_list) - >>> module.my_buffer_list - [ - tensor([10., 10., 10.]), - tensor([11., 11., 11.]), - tensor([12., 12., 12.]) - ] - >>> # Move to GPU. Since the tensors are registered as buffers, device placement is handled automatically - >>> module.cuda() - >>> module.my_buffer_list - [ - tensor([10., 10., 10.], device='cuda:0'), - tensor([11., 11., 11.], device='cuda:0'), - tensor([12., 12., 12.], device='cuda:0') - ] + """Mixin class that enables registering lists of tensors as module buffers. + + This mixin extends PyTorch modules to support registering lists of tensors as + buffers, which are automatically handled during device placement and state + dict operations. """ - def register_buffer_list(self, name: str, values: list[torch.Tensor], persistent: bool = True, **kwargs) -> None: - """Register a list of tensors as buffers in a pytorch module. + def register_buffer_list( + self, + name: str, + values: list[torch.Tensor], + persistent: bool = True, + **kwargs, + ) -> None: + """Register a list of tensors as buffers in a PyTorch module. - Each tensor is registered as a buffer with the name `_name_i` where `i` is the index of the tensor in the list. - To update and retrieve the list of tensors, we dynamically assign a descriptor attribute to the class. + Each tensor is registered as a buffer with the name ``_name_i`` where ``i`` + is the index of the tensor in the list. The list can be accessed and + updated using the original ``name``. Args: - name (str): Name of the buffer list. - values (list[torch.Tensor]): List of tensors to register as buffers. - persistent (bool, optional): Whether the buffers should be saved as part of the module state_dict. - Defaults to True. - **kwargs: Additional keyword arguments to pass to `torch.nn.Module.register_buffer`. + name (str): + Name of the buffer list. + values (list[torch.Tensor]): + List of tensors to register as buffers. + persistent (bool, optional): + Whether the buffers should be saved as part of the module + state_dict. Defaults to ``True``. + **kwargs: + Additional keyword arguments to pass to + ``torch.nn.Module.register_buffer``. """ for i, value in enumerate(values): self.register_buffer(f"_{name}_{i}", value, persistent=persistent, **kwargs) @@ -65,31 +88,41 @@ def register_buffer_list(self, name: str, values: list[torch.Tensor], persistent class BufferListDescriptor: - """Buffer List Descriptor. + """Descriptor class for managing lists of buffer tensors. - This descriptor is used to allow registering a list of tensors as buffers in a pytorch module. + This descriptor provides the functionality to access and modify lists of + tensors that are registered as buffers in a PyTorch module. Args: - name (str): Name of the buffer list. - length (int): Length of the buffer list. + name (str): + Name of the buffer list. + length (int): + Length of the buffer list. 
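Why route the list through individual ``register_buffer`` calls at all? Plain Python lists of tensors are invisible to ``nn.Module``, so they neither move with ``.to()`` nor appear in the ``state_dict``. A small plain-PyTorch demonstration, independent of the mixin:

```python
import torch
from torch import nn


class PlainList(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.bank = [torch.ones(3)]  # plain attribute: invisible to nn.Module


class BufferList(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        for i, tensor in enumerate([torch.ones(3)]):
            self.register_buffer(f"_bank_{i}", tensor)  # tracked by nn.Module


print(list(PlainList().state_dict()))   # [] -- the list would be silently lost
print(list(BufferList().state_dict()))  # ['_bank_0'] -- saved and moved with .to()
```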
""" def __init__(self, name: str, length: int) -> None: self.name = name self.length = length - def __get__(self, instance: object, object_type: type | None = None) -> list[torch.Tensor]: + def __get__( + self, + instance: object, + object_type: type | None = None, + ) -> list[torch.Tensor]: """Get the list of tensors. - Each element of the buffer list is stored as a buffer with the name `name_i` where `i` is the index of the - element in the list. We use list comprehension to retrieve the list of tensors. + Retrieves the list of tensors stored as individual buffers with names + ``_name_i`` where ``i`` is the index. Args: - instance (object): Instance of the class. - object_type (Any, optional): Type of the class. Defaults to None. + instance (object): + Instance of the class. + object_type (type | None, optional): + Type of the class. Defaults to ``None``. Returns: - list[torch.Tensor]: Contents of the buffer list. + list[torch.Tensor]: + List of tensor buffers. """ del object_type return [getattr(instance, f"_{self.name}_{i}") for i in range(self.length)] @@ -97,11 +130,13 @@ def __get__(self, instance: object, object_type: type | None = None) -> list[tor def __set__(self, instance: object, values: list[torch.Tensor]) -> None: """Set the list of tensors. - Assigns a new list of tensors to the buffer list by updating the individual buffer attributes. + Updates the individual buffer attributes with new tensor values. Args: - instance (object): Instance of the class. - values (list[torch.Tensor]): List of tensors to set. + instance (object): + Instance of the class. + values (list[torch.Tensor]): + List of tensors to set as buffers. """ for i, value in enumerate(values): setattr(instance, f"_{self.name}_{i}", value) diff --git a/src/anomalib/models/components/base/dynamic_buffer.py b/src/anomalib/models/components/base/dynamic_buffer.py index e1c6ad6bd6..d34ac94283 100644 --- a/src/anomalib/models/components/base/dynamic_buffer.py +++ b/src/anomalib/models/components/base/dynamic_buffer.py @@ -1,4 +1,27 @@ -"""Dynamic Buffer Mixin.""" +"""Dynamic Buffer Mixin. + +This mixin class enables loading state dictionaries with mismatched tensor shapes +by dynamically resizing buffers to match the loaded state. + +Example: + >>> import torch + >>> from torch import nn + >>> class MyModule(DynamicBufferMixin, nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.register_buffer("buffer", torch.ones(3)) + ... + >>> module = MyModule() + >>> # Original buffer shape is (3,) + >>> module.buffer + tensor([1., 1., 1.]) + >>> # Load state dict with different buffer shape (5,) + >>> new_state = {"buffer": torch.ones(5)} + >>> module.load_state_dict(new_state) + >>> # Buffer is automatically resized + >>> module.buffer + tensor([1., 1., 1., 1., 1.]) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -10,19 +33,26 @@ class DynamicBufferMixin(nn.Module, ABC): - """This mixin allows loading variables from the state dict even in the case of shape mismatch.""" + """Mixin that enables loading state dicts with mismatched tensor shapes. + + This mixin class extends ``nn.Module`` to allow loading state dictionaries + even when the shapes of tensors in the state dict do not match the shapes + of the module's buffers. When loading a state dict, the mixin automatically + resizes any mismatched buffers to match the shapes in the state dict. + """ def get_tensor_attribute(self, attribute_name: str) -> torch.Tensor: - """Get attribute of the tensor given the name. 
+ """Get a tensor attribute by name. Args: - attribute_name (str): Name of the tensor + attribute_name (str): Name of the tensor attribute to retrieve Raises: - ValueError: `attribute_name` is not a torch Tensor + ValueError: If the attribute with name ``attribute_name`` is not a + ``torch.Tensor`` Returns: - Tensor: torch.Tensor attribute + torch.Tensor: The tensor attribute """ attribute = getattr(self, attribute_name) if isinstance(attribute, torch.Tensor): @@ -32,14 +62,17 @@ def get_tensor_attribute(self, attribute_name: str) -> torch.Tensor: raise ValueError(msg) def _load_from_state_dict(self, state_dict: dict, prefix: str, *args) -> None: - """Resizes the local buffers to match those stored in the state dict. + """Load a state dictionary, resizing buffers if shapes don't match. - Overrides method from parent class. + This method overrides the parent class implementation to handle tensor + shape mismatches when loading state dictionaries. It resizes any local + buffers whose shapes don't match the corresponding tensors in the state + dict. Args: - state_dict (dict): State dictionary containing weights - prefix (str): Prefix of the weight file. - *args: Variable length argument list. + state_dict (dict): Dictionary containing state to load + prefix (str): Prefix to prepend to parameter/buffer names + *args: Additional arguments passed to parent implementation """ persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} local_buffers = {k: v for k, v in persistent_buffers.items() if v is not None} diff --git a/src/anomalib/models/components/base/export_mixin.py b/src/anomalib/models/components/base/export_mixin.py index baaf07ec95..96fae750fa 100644 --- a/src/anomalib/models/components/base/export_mixin.py +++ b/src/anomalib/models/components/base/export_mixin.py @@ -1,4 +1,38 @@ -"""Mixin for exporting models to disk.""" +"""Mixin for exporting anomaly detection models to disk. + +This mixin provides functionality to export models to various formats: +- PyTorch (.pt) +- ONNX (.onnx) +- OpenVINO IR (.xml/.bin) + +The mixin supports different compression types for OpenVINO exports: +- FP16 compression +- INT8 quantization +- Post-training quantization (PTQ) +- Accuracy-aware quantization (ACQ) + +Example: + Export a trained model to different formats: + + >>> from anomalib.models import Patchcore + >>> from anomalib.data import Visa + >>> from anomalib.deploy.export import CompressionType + ... + >>> # Initialize and train model + >>> model = Patchcore() + >>> datamodule = Visa() + >>> # Export to PyTorch format + >>> model.to_torch("./exports") + >>> # Export to ONNX + >>> model.to_onnx("./exports", input_size=(224, 224)) + >>> # Export to OpenVINO with INT8 quantization + >>> model.to_openvino( + ... "./exports", + ... input_size=(224, 224), + ... compression_type=CompressionType.INT8_PTQ, + ... datamodule=datamodule + ... ) +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -29,7 +63,16 @@ class ExportMixin: - """This mixin allows exporting models to torch and ONNX/OpenVINO.""" + """Mixin class that enables exporting models to various formats. + + This mixin provides methods to export models to PyTorch (.pt), ONNX (.onnx), + and OpenVINO IR (.xml/.bin) formats. For OpenVINO exports, it supports + different compression types including FP16, INT8, PTQ and ACQ. 
+ + The mixin requires the host class to have: + - A ``model`` attribute of type ``nn.Module`` + - A ``device`` attribute of type ``torch.device`` + """ model: nn.Module device: torch.device @@ -38,37 +81,22 @@ def to_torch( self, export_root: Path | str, ) -> Path: - """Export AnomalibModel to torch. + """Export model to PyTorch format. Args: - export_root (Path): Path to the output folder. - transform (Transform, optional): Input transforms used for the model. If not provided, the transform is - taken from the model. - Defaults to ``None``. - post_processor (nn.Module, optional): Post-processing module to apply to the model output. - Defaults to ``None``. + export_root (Path | str): Path to the output folder Returns: - Path: Path to the exported pytorch model. + Path: Path to the exported PyTorch model (.pt file) Examples: - Assume that we have a model to train and we want to export it to torch format. + Export a trained model to PyTorch format: - >>> from anomalib.data import Visa >>> from anomalib.models import Patchcore - >>> from anomalib.engine import Engine - ... - >>> datamodule = Visa() >>> model = Patchcore() - >>> engine = Engine() - ... - >>> engine.fit(model, datamodule) - - Now that we have a model trained, we can export it to torch format. - - >>> model.to_torch( - ... export_root="path/to/export", - ... ) + >>> # Train model... + >>> model.to_torch("./exports") + PosixPath('./exports/weights/torch/model.pt') """ export_root = _create_export_root(export_root, ExportType.TORCH) pt_model_path = export_root / "model.pt" @@ -83,41 +111,29 @@ def to_onnx( export_root: Path | str, input_size: tuple[int, int] | None = None, ) -> Path: - """Export model to onnx. + """Export model to ONNX format. Args: - export_root (Path): Path to the root folder of the exported model. - input_size (tuple[int, int] | None, optional): Image size used as the input for onnx converter. - Defaults to None. - transform (Transform, optional): Input transforms used for the model. If not provided, the transform is - taken from the model. - Defaults to ``None``. - post_processor (nn.Module, optional): Post-processing module to apply to the model output. - Defaults to ``None``. + export_root (Path | str): Path to the output folder + input_size (tuple[int, int] | None): Input image dimensions (height, width). + If ``None``, uses dynamic input shape. Defaults to ``None`` Returns: - Path: Path to the exported onnx model. + Path: Path to the exported ONNX model (.onnx file) Examples: - Export the Lightning Model to ONNX: + Export model with fixed input size: >>> from anomalib.models import Patchcore - >>> from anomalib.data import Visa - ... - >>> datamodule = Visa() >>> model = Patchcore() - ... - >>> model.to_onnx( - ... export_root="path/to/export", - ... transform=datamodule.test_data.transform, - ... ) + >>> # Train model... + >>> model.to_onnx("./exports", input_size=(224, 224)) + PosixPath('./exports/weights/onnx/model.onnx') - Using Custom Transforms: - This example shows how to use a custom ``Compose`` object for the ``transform`` argument. + Export model with dynamic input size: - >>> model.to_onnx( - ... export_root="path/to/export", - ... 
) + >>> model.to_onnx("./exports") + PosixPath('./exports/weights/onnx/model.onnx') """ export_root = _create_export_root(export_root, ExportType.ONNX) input_shape = torch.zeros((1, 3, *input_size)) if input_size else torch.zeros((1, 3, 1, 1)) @@ -133,7 +149,7 @@ def to_onnx( output_names = [name for name, value in self.eval()(input_shape)._asdict().items() if value is not None] torch.onnx.export( self, - input_shape.to(self.device), + (input_shape.to(self.device),), str(onnx_path), opset_version=14, dynamic_axes=dynamic_axes, @@ -153,78 +169,46 @@ def to_openvino( ov_args: dict[str, Any] | None = None, task: TaskType | None = None, ) -> Path: - """Convert onnx model to OpenVINO IR. + """Export model to OpenVINO IR format. Args: - export_root (Path): Path to the export folder. - input_size (tuple[int, int] | None, optional): Input size of the model. Used for adding metadata to the IR. - Defaults to None. - transform (Transform, optional): Input transforms used for the model. If not provided, the transform is - taken from the model. - Defaults to ``None``. - compression_type (CompressionType, optional): Compression type for better inference performance. - Defaults to ``None``. - datamodule (AnomalibDataModule | None, optional): Lightning datamodule. - Must be provided if ``CompressionType.INT8_PTQ`` or ``CompressionType.INT8_ACQ`` is selected. - Defaults to ``None``. - metric (Metric | None, optional): Metric to measure quality loss when quantizing. - Must be provided if ``CompressionType.INT8_ACQ`` is selected and must return higher value for better - performance of the model. - Defaults to ``None``. - ov_args (dict | None): Model optimizer arguments for OpenVINO model conversion. - Defaults to ``None``. - task (TaskType | None): Task type. - Defaults to ``None``. + export_root (Path | str): Path to the output folder + input_size (tuple[int, int] | None): Input image dimensions (height, width). + If ``None``, uses dynamic input shape. Defaults to ``None`` + compression_type (CompressionType | None): Type of compression to apply. + Options: ``FP16``, ``INT8``, ``INT8_PTQ``, ``INT8_ACQ``. + Defaults to ``None`` + datamodule (AnomalibDataModule | None): DataModule for quantization. + Required for ``INT8_PTQ`` and ``INT8_ACQ``. Defaults to ``None`` + metric (Metric | None): Metric for accuracy-aware quantization. + Required for ``INT8_ACQ``. Defaults to ``None`` + ov_args (dict[str, Any] | None): OpenVINO model optimizer arguments. + Defaults to ``None`` + task (TaskType | None): Task type (classification/segmentation). + Defaults to ``None`` Returns: - Path: Path to the exported onnx model. + Path: Path to the exported OpenVINO model (.xml file) Raises: - ModuleNotFoundError: If OpenVINO is not installed. - - Returns: - Path: Path to the exported OpenVINO IR. + ModuleNotFoundError: If OpenVINO is not installed + ValueError: If required arguments for quantization are missing Examples: - Export the Lightning Model to OpenVINO IR: - This example demonstrates how to export the Lightning Model to OpenVINO IR. + Export model with FP16 compression: - >>> from anomalib.models import Patchcore - >>> from anomalib.data import Visa - ... - >>> datamodule = Visa() - >>> model = Patchcore() - ... >>> model.to_openvino( - ... export_root="path/to/export", - ... transform=datamodule.test_data.transform, - ... task=datamodule.test_data.task + ... "./exports", + ... input_size=(224, 224), + ... compression_type=CompressionType.FP16 ... 
) - Export and Quantize the Model (OpenVINO IR): - This example demonstrates how to export and quantize the model to OpenVINO IR. + Export with INT8 post-training quantization: - >>> from anomalib.models import Patchcore - >>> from anomalib.data import Visa - >>> datamodule = Visa() - >>> model = Patchcore() >>> model.to_openvino( - ... export_root="path/to/export", + ... "./exports", ... compression_type=CompressionType.INT8_PTQ, - ... datamodule=datamodule, - ... task=datamodule.test_data.task - ... ) - - Using Custom Transforms: - This example shows how to use a custom ``Transform`` object for the ``transform`` argument. - - >>> from torchvision.transforms.v2 import Resize - >>> transform = Resize(224, 224) - ... - >>> model.to_openvino( - ... export_root="path/to/export", - ... transform=transform, - ... task="segmentation", + ... datamodule=datamodule ... ) """ if not module_available("openvino"): @@ -257,23 +241,25 @@ def _compress_ov_model( metric: Metric | None = None, task: TaskType | None = None, ) -> "CompiledModel": - """Compress OpenVINO model with NNCF. - - model (CompiledModel): Model already exported to OpenVINO format. - compression_type (CompressionType, optional): Compression type for better inference performance. - Defaults to ``None``. - datamodule (AnomalibDataModule | None, optional): Lightning datamodule. - Must be provided if ``CompressionType.INT8_PTQ`` or ``CompressionType.INT8_ACQ`` is selected. - Defaults to ``None``. - metric (Metric | str | None, optional): Metric to measure quality loss when quantizing. - Must be provided if ``CompressionType.INT8_ACQ`` is selected and must return higher value for better - performance of the model. - Defaults to ``None``. - task (TaskType | None): Task type. - Defaults to ``None``. + """Compress OpenVINO model using NNCF. + + Args: + model (CompiledModel): OpenVINO model to compress + compression_type (CompressionType | None): Type of compression to apply. + Defaults to ``None`` + datamodule (AnomalibDataModule | None): DataModule for quantization. + Required for ``INT8_PTQ`` and ``INT8_ACQ``. Defaults to ``None`` + metric (Metric | None): Metric for accuracy-aware quantization. + Required for ``INT8_ACQ``. Defaults to ``None`` + task (TaskType | None): Task type (classification/segmentation). + Defaults to ``None`` Returns: - model (CompiledModel): Model in the OpenVINO format compressed with NNCF quantization. + CompiledModel: Compressed OpenVINO model + + Raises: + ModuleNotFoundError: If NNCF is not installed + ValueError: If compression type is not recognized """ if not module_available("nncf"): logger.exception("Could not find NCCF. Please check NNCF installation.") @@ -298,15 +284,18 @@ def _post_training_quantization_ov( model: "CompiledModel", datamodule: AnomalibDataModule | None = None, ) -> "CompiledModel": - """Post-Training Quantization model with NNCF. + """Apply post-training quantization to OpenVINO model. - model (CompiledModel): Model already exported to OpenVINO format. - datamodule (AnomalibDataModule | None, optional): Lightning datamodule. - Must be provided if ``CompressionType.INT8_PTQ`` or ``CompressionType.INT8_ACQ`` is selected. - Defaults to ``None``. + Args: + model (CompiledModel): OpenVINO model to quantize + datamodule (AnomalibDataModule | None): DataModule for calibration. + Must contain at least 300 images. Defaults to ``None`` Returns: - model (CompiledModel): Quantized model. 
+ CompiledModel: Quantized OpenVINO model + + Raises: + ValueError: If datamodule is not provided """ import nncf @@ -336,21 +325,24 @@ def _accuracy_control_quantization_ov( metric: Metric | None = None, task: TaskType | None = None, ) -> "CompiledModel": - """Accuracy-Control Quantization with NNCF. - - model (CompiledModel): Model already exported to OpenVINO format. - datamodule (AnomalibDataModule | None, optional): Lightning datamodule. - Must be provided if ``CompressionType.INT8_PTQ`` or ``CompressionType.INT8_ACQ`` is selected. - Defaults to ``None``. - metric (Metric | None, optional): Metric to measure quality loss when quantizing. - Must be provided if ``CompressionType.INT8_ACQ`` is selected and must return higher value for better - performance of the model. - Defaults to ``None``. - task (TaskType | None): Task type. - Defaults to ``None``. + """Apply accuracy-aware quantization to OpenVINO model. + + Args: + model (CompiledModel): OpenVINO model to quantize + datamodule (AnomalibDataModule | None): DataModule for calibration + and validation. Must contain at least 300 images. + Defaults to ``None`` + metric (Metric | None): Metric to measure accuracy during quantization. + Higher values should indicate better performance. + Defaults to ``None`` + task (TaskType | None): Task type (classification/segmentation). + Defaults to ``None`` Returns: - model (CompiledModel): Quantized model. + CompiledModel: Quantized OpenVINO model + + Raises: + ValueError: If datamodule or metric is not provided """ import nncf @@ -393,14 +385,14 @@ def val_fn(nncf_model: "CompiledModel", validation_data: Iterable) -> float: def _create_export_root(export_root: str | Path, export_type: ExportType) -> Path: - """Create export directory. + """Create directory structure for model export. Args: - export_root (str | Path): Path to the root folder of the exported model. - export_type (ExportType): Mode to export the model. Torch, ONNX or OpenVINO. + export_root (str | Path): Root directory for exports + export_type (ExportType): Type of export (torch/onnx/openvino) Returns: - Path: Path to the export directory. + Path: Created directory path """ export_root = Path(export_root) / "weights" / export_type.value export_root.mkdir(parents=True, exist_ok=True) diff --git a/src/anomalib/models/components/base/memory_bank_module.py b/src/anomalib/models/components/base/memory_bank_module.py index 501e8dc11a..07eae880bc 100644 --- a/src/anomalib/models/components/base/memory_bank_module.py +++ b/src/anomalib/models/components/base/memory_bank_module.py @@ -1,4 +1,25 @@ -"""Memory Bank Module.""" +"""Memory Bank Module. + +This module provides a mixin class for implementing memory bank-based anomaly +detection models. Memory banks store reference features or embeddings that are +used to detect anomalies by comparing test samples against the stored references. + +The mixin ensures proper initialization and fitting of the memory bank before +validation or inference. + +Example: + Create a custom memory bank model: + + >>> from anomalib.models.components.base import MemoryBankMixin + >>> class MyMemoryModel(MemoryBankMixin): + ... def __init__(self): + ... super().__init__() + ... self.memory = [] + ... + ... def fit(self): + ... # Implement memory bank population logic + ... self.memory = [1, 2, 3] +""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -12,8 +33,16 @@ class MemoryBankMixin(nn.Module): """Memory Bank Lightning Module. 
- This module is used to implement memory bank lightning modules. - It checks if the model is fitted before validation starts. + This mixin class provides functionality for memory bank-based models that need + to store and compare against reference features/embeddings. It ensures the + memory bank is properly fitted before validation or inference begins. + + The mixin tracks the fitting status via a persistent buffer ``_is_fitted`` + and automatically triggers the fitting process when needed. + + Attributes: + device (torch.device): Device where the model/tensors reside + _is_fitted (torch.Tensor): Boolean tensor tracking if model is fitted """ def __init__(self, *args, **kwargs) -> None: @@ -24,21 +53,33 @@ def __init__(self, *args, **kwargs) -> None: @abstractmethod def fit(self) -> None: - """Fit the model to the data.""" + """Fit the memory bank model to the training data. + + This method should be implemented by subclasses to define how the memory + bank is populated with reference features/embeddings. + + Raises: + NotImplementedError: If the subclass does not implement this method + """ msg = ( - f"fit method not implemented for {self.__class__.__name__}. " - "To use a memory-bank module, implement ``fit.``" + f"fit method not implemented for {self.__class__.__name__}. To use a memory-bank module, implement ``fit``." ) raise NotImplementedError(msg) def on_validation_start(self) -> None: - """Ensure that the model is fitted before validation starts.""" + """Ensure memory bank is fitted before validation. + + This hook automatically fits the memory bank if it hasn't been fitted yet. + """ if not self._is_fitted: self.fit() self._is_fitted = torch.tensor([True], device=self.device) def on_train_epoch_end(self) -> None: - """Ensure that the model is fitted before validation starts.""" + """Ensure memory bank is fitted after training. + + This hook automatically fits the memory bank if it hasn't been fitted yet. + """ if not self._is_fitted: self.fit() self._is_fitted = torch.tensor([True], device=self.device) diff --git a/src/anomalib/models/components/classification/__init__.py b/src/anomalib/models/components/classification/__init__.py index 253db6aee6..0f3e735a99 100644 --- a/src/anomalib/models/components/classification/__init__.py +++ b/src/anomalib/models/components/classification/__init__.py @@ -1,4 +1,21 @@ -"""Classification modules.""" +"""Classification modules for anomaly detection. + +This module provides classification components used in anomaly detection models. + +Classes: + KDEClassifier: Kernel Density Estimation based classifier for anomaly + detection. + FeatureScalingMethod: Enum class defining feature scaling methods for + KDE classifier. + +Example: + >>> from anomalib.models.components.classification import KDEClassifier + >>> from anomalib.models.components.classification import FeatureScalingMethod + >>> # Create KDE classifier with unit-norm feature scaling + >>> classifier = KDEClassifier( + ... feature_scaling_method=FeatureScalingMethod.NORM + ... ) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/components/classification/kde_classifier.py b/src/anomalib/models/components/classification/kde_classifier.py index d50e5cca31..0c068c74cb 100644 --- a/src/anomalib/models/components/classification/kde_classifier.py +++ b/src/anomalib/models/components/classification/kde_classifier.py @@ -1,4 +1,27 @@ -"""Kernel Density Estimation Classifier.
+ +This module provides a classifier based on kernel density estimation (KDE) for +anomaly detection. The classifier fits a KDE model to feature embeddings and uses +it to compute anomaly probabilities. + +Example: + >>> from anomalib.models.components.classification import KDEClassifier + >>> from anomalib.models.components.classification import FeatureScalingMethod + >>> # Create classifier with default settings + >>> classifier = KDEClassifier() + >>> # Create classifier with custom settings + >>> classifier = KDEClassifier( + ... n_pca_components=32, + ... feature_scaling_method=FeatureScalingMethod.NORM, + ... max_training_points=50000 + ... ) + >>> # Fit classifier on embeddings + >>> embeddings = torch.randn(1000, 512) # Example embeddings + >>> classifier.fit(embeddings) + >>> # Get anomaly probabilities for new samples + >>> new_embeddings = torch.randn(10, 512) + >>> probabilities = classifier.predict(new_embeddings) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -16,21 +39,43 @@ class FeatureScalingMethod(str, Enum): - """Determines how the feature embeddings are scaled.""" + """Feature scaling methods for KDE classifier. + + The scaling method determines how feature embeddings are normalized before + being passed to the KDE model. + + Attributes: + NORM: Scale features to unit vector length (L2 normalization) + SCALE: Scale features by maximum length observed during training + (preserves relative magnitudes) + """ NORM = "norm" # scale to unit vector length - SCALE = "scale" # scale to max length observed in training (preserve relative magnitude) + SCALE = "scale" # scale to max length observed in training class KDEClassifier(nn.Module): """Classification module for KDE-based anomaly detection. + This classifier uses kernel density estimation to model the distribution of + normal samples in feature space. It first applies dimensionality reduction + via PCA, then fits a Gaussian KDE model to the reduced features. + Args: - n_pca_components (int, optional): Number of PCA components. Defaults to 16. - feature_scaling_method (FeatureScalingMethod, optional): Scaling method applied to features before passing to - KDE. Options are `norm` (normalize to unit vector length) and `scale` (scale to max length observed in - training). - max_training_points (int, optional): Maximum number of training points to fit the KDE model. Defaults to 40000. + n_pca_components: Number of PCA components to retain. Lower values reduce + computational cost but may lose information. + Defaults to 16. + feature_scaling_method: Method used to scale features before KDE. + Options are ``norm`` (unit vector) or ``scale`` (max length). + Defaults to ``FeatureScalingMethod.SCALE``. + max_training_points: Maximum number of points used to fit the KDE model. + If more points are provided, a random subset is selected. + Defaults to 40000. + + Attributes: + pca_model: PCA model for dimensionality reduction + kde_model: Gaussian KDE model for density estimation + max_length: Maximum feature length observed during training """ def __init__( @@ -56,15 +101,20 @@ def pre_process( feature_stack: torch.Tensor, max_length: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """Pre-process the CNN features. + """Pre-process feature embeddings before KDE. + + Scales the features according to the specified scaling method. Args: - feature_stack (torch.Tensor): Features extracted from CNN - max_length (Tensor | None): Used to unit normalize the feature_stack vector. 
If ``max_len`` is not - provided, the length is calculated from the ``feature_stack``. Defaults to None. + feature_stack: Features extracted from the model, shape (N, D) + max_length: Maximum feature length for scaling. If ``None``, computed + from ``feature_stack``. Defaults to None. Returns: - (Tuple): Stacked features and length + tuple: (scaled_features, max_length) + + Raises: + RuntimeError: If unknown scaling method is specified """ if max_length is None: max_length = torch.max(torch.linalg.norm(feature_stack, ord=2, dim=1)) @@ -79,13 +129,21 @@ def pre_process( return feature_stack, max_length def fit(self, embeddings: torch.Tensor) -> bool: - """Fit a kde model to embeddings. + """Fit the KDE classifier to training embeddings. + + Applies PCA, scales the features, and fits the KDE model. Args: - embeddings (torch.Tensor): Input embeddings to fit the model. + embeddings: Training embeddings of shape (N, D) Returns: - Boolean confirming whether the training is successful. + bool: True if fitting succeeded, False if insufficient samples + + Example: + >>> classifier = KDEClassifier() + >>> embeddings = torch.randn(1000, 512) + >>> success = classifier.fit(embeddings) + >>> assert success """ if embeddings.shape[0] < self.n_pca_components: logger.info("Not enough features to commit. Not making a model.") @@ -109,17 +167,17 @@ def fit(self, embeddings: torch.Tensor) -> bool: return True def compute_kde_scores(self, features: torch.Tensor, as_log_likelihood: bool | None = False) -> torch.Tensor: - """Compute the KDE scores. + """Compute KDE scores for input features. - The scores calculated from the KDE model are converted to densities. If `as_log_likelihood` is set to true then - the log of the scores are calculated. + Transforms features via PCA and scaling, then computes KDE scores. Args: - features (torch.Tensor): Features to which the PCA model is fit. - as_log_likelihood (bool | None, optional): If true, gets log likelihood scores. Defaults to False. + features: Input features of shape (N, D) + as_log_likelihood: If True, returns log of KDE scores. + Defaults to False. Returns: - (torch.Tensor): Score + torch.Tensor: KDE scores of shape (N,) """ features = self.pca_model.transform(features) features, _ = self.pre_process(features, self.max_length) @@ -136,28 +194,50 @@ def compute_kde_scores(self, features: torch.Tensor, as_log_likelihood: bool | N @staticmethod def compute_probabilities(scores: torch.Tensor) -> torch.Tensor: - """Convert density scores to anomaly probabilities (see https://www.desmos.com/calculator/ifju7eesg7). + """Convert density scores to anomaly probabilities. + + Uses sigmoid function to map scores to [0,1] range. + See https://www.desmos.com/calculator/ifju7eesg7 Args: - scores (torch.Tensor): density of an image. + scores: Density scores of shape (N,) Returns: - probability that image with {density} is anomalous + torch.Tensor: Anomaly probabilities of shape (N,) """ return 1 / (1 + torch.exp(0.05 * (scores - 12))) def predict(self, features: torch.Tensor) -> torch.Tensor: - """Predicts the probability that the features belong to the anomalous class. + """Predict anomaly probabilities for input features. + + Computes KDE scores and converts them to probabilities. Args: - features (torch.Tensor): Feature from which the output probabilities are detected. 
+ features: Input features of shape (N, D) Returns: - Detection probabilities + torch.Tensor: Anomaly probabilities of shape (N,) + + Example: + >>> classifier = KDEClassifier() + >>> _ = classifier.fit(torch.randn(1000, 512)) + >>> features = torch.randn(10, 512) + >>> probs = classifier.predict(features) + >>> assert probs.shape == (10,) + >>> assert (probs >= 0).all() and (probs <= 1).all() """ scores = self.compute_kde_scores(features, as_log_likelihood=True) return self.compute_probabilities(scores) def forward(self, features: torch.Tensor) -> torch.Tensor: - """Make predictions on extracted features.""" + """Forward pass of the classifier. + + Equivalent to calling ``predict()``. + + Args: + features: Input features of shape (N, D) + + Returns: + torch.Tensor: Anomaly probabilities of shape (N,) + """ return self.predict(features) diff --git a/src/anomalib/models/components/cluster/__init__.py b/src/anomalib/models/components/cluster/__init__.py index e3ce0455af..74f7601204 100644 --- a/src/anomalib/models/components/cluster/__init__.py +++ b/src/anomalib/models/components/cluster/__init__.py @@ -1,4 +1,23 @@ -"""Clustering algorithm implementations using PyTorch.""" +"""Clustering algorithm implementations using PyTorch. + +This module provides clustering algorithms implemented in PyTorch for anomaly +detection tasks. + +Classes: + GaussianMixture: Gaussian Mixture Model for density estimation and clustering. + KMeans: K-Means clustering algorithm. + +Example: + >>> import torch + >>> from anomalib.models.components.cluster import GaussianMixture, KMeans + >>> # Create and fit a GMM + >>> gmm = GaussianMixture(n_components=3) + >>> features = torch.randn(100, 10) # Example features + >>> gmm.fit(features) + >>> # Create and fit KMeans + >>> kmeans = KMeans(n_clusters=5) + >>> labels, centers = kmeans.fit(features) +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/components/cluster/gmm.py b/src/anomalib/models/components/cluster/gmm.py index b7f94693b2..cfbb653991 100644 --- a/src/anomalib/models/components/cluster/gmm.py +++ b/src/anomalib/models/components/cluster/gmm.py @@ -1,4 +1,8 @@ -"""Pytorch implementation of Gaussian Mixture Model.""" +"""PyTorch implementation of Gaussian Mixture Model. + +This module provides a PyTorch-based implementation of Gaussian Mixture Model (GMM) +for clustering data into multiple Gaussian distributions. +""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -16,38 +20,43 @@ class GaussianMixture(DynamicBufferMixin): - """Gaussian Mixture Model. + """Gaussian Mixture Model for clustering data into Gaussian distributions. Args: - n_components (int): Number of components. - n_iter (int): Maximum number of iterations to perform. - Defaults to ``100``. - tol (float): Convergence threshold. - Defaults to ``1e-3``. + n_components (int): Number of Gaussian components to fit. + n_iter (int, optional): Maximum number of EM iterations. Defaults to 100. + tol (float, optional): Convergence threshold for log-likelihood. + Defaults to 1e-3. + + Attributes: + means (torch.Tensor): Means of the Gaussian components. + Shape: ``(n_components, n_features)``. + covariances (torch.Tensor): Covariance matrices of components. + Shape: ``(n_components, n_features, n_features)``. + weights (torch.Tensor): Mixing weights of components. + Shape: ``(n_components,)``.
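# --- Editor's illustration (not part of the diff) ----------------------------
# The GaussianMixture hunks below document fit(), predict() and
# score_samples(). A minimal sketch of the anomaly-detection use of
# score_samples(): threshold the log-likelihood of test samples under a
# mixture fitted on nominal features. The synthetic data and the 1% quantile
# threshold are illustrative assumptions, not anomalib defaults.
import torch

from anomalib.models.components.cluster import GaussianMixture

normal_features = torch.randn(500, 8)  # stand-in for nominal embeddings
test_features = torch.randn(20, 8)

gmm = GaussianMixture(n_components=4)
gmm.fit(normal_features)

# Samples with unusually low log-likelihood under the fitted mixture are
# flagged as anomalous.
scores = gmm.score_samples(test_features)
threshold = gmm.score_samples(normal_features).quantile(0.01)
is_anomalous = scores < threshold
# ------------------------------------------------------------------------------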
Example: - The following examples shows how to fit a Gaussian Mixture Model to some data and get the cluster means and - predicted labels and log-likelihood scores of the data. - - .. code-block:: python - - >>> import torch - >>> from anomalib.models.components.cluster import GaussianMixture - >>> model = GaussianMixture(n_components=2) - >>> data = torch.tensor( - ... [ - ... [2, 1], [2, 2], [2, 3], - ... [7, 5], [8, 5], [9, 5], - ... ] - ... ).float() - >>> model.fit(data) - >>> model.means # get the means of the gaussians - tensor([[8., 5.], - [2., 2.]]) - >>> model.predict(data) # get the predicted cluster label of each sample - tensor([1, 1, 1, 0, 0, 0]) - >>> model.score_samples(data) # get the log-likelihood score of each sample - tensor([3.8295, 4.5795, 3.8295, 3.8295, 4.5795, 3.8295]) + >>> import torch + >>> from anomalib.models.components.cluster import GaussianMixture + >>> # Create synthetic data with two clusters + >>> data = torch.tensor([ + ... [2, 1], [2, 2], [2, 3], # Cluster 1 + ... [7, 5], [8, 5], [9, 5], # Cluster 2 + ... ]).float() + >>> # Initialize and fit GMM + >>> model = GaussianMixture(n_components=2) + >>> model.fit(data) + >>> # Get cluster means + >>> model.means + tensor([[8., 5.], + [2., 2.]]) + >>> # Predict cluster assignments + >>> model.predict(data) + tensor([1, 1, 1, 0, 0, 0]) + >>> # Get log-likelihood scores + >>> model.score_samples(data) + tensor([3.8295, 4.5795, 3.8295, 3.8295, 4.5795, 3.8295]) """ def __init__(self, n_components: int, n_iter: int = 100, tol: float = 1e-3) -> None: @@ -65,10 +74,11 @@ def __init__(self, n_components: int, n_iter: int = 100, tol: float = 1e-3) -> N self.weights: torch.Tensor def fit(self, data: torch.Tensor) -> None: - """Fit the model to the data. + """Fit the GMM to the input data using EM algorithm. Args: - data (Tensor): Data to fit the model to. Tensor of shape (n_samples, n_features). + data (torch.Tensor): Input data to fit the model to. + Shape: ``(n_samples, n_features)``. """ self._initialize_parameters_kmeans(data) @@ -88,41 +98,50 @@ def fit(self, data: torch.Tensor) -> None: if not converged: logger.warning( - f"GMM did not converge after {self.n_iter} iterations. \ - Consider increasing the number of iterations.", + f"GMM did not converge after {self.n_iter} iterations. Consider increasing the number of iterations.", ) def _initialize_parameters_kmeans(self, data: torch.Tensor) -> None: - """Initialize parameters with K-means. + """Initialize GMM parameters using K-means clustering. Args: - data (Tensor): Data to fit the model to. Tensor of shape (n_samples, n_features). + data (torch.Tensor): Input data for initialization. + Shape: ``(n_samples, n_features)``. """ labels, _ = KMeans(n_clusters=self.n_components).fit(data) resp = one_hot(labels, num_classes=self.n_components).float() self._m_step(data, resp) def _e_step(self, data: torch.Tensor) -> torch.Tensor: - """Perform the E-step to estimate the responsibilities of the gaussians. + """Perform E-step to compute responsibilities and log-likelihood. Args: - data (Tensor): Data to fit the model to. Tensor of shape (n_samples, n_features). + data (torch.Tensor): Input data. + Shape: ``(n_samples, n_features)``. Returns: - Tensor: log probability of the data given the gaussians. - Tensor: Tensor of shape (n_samples, n_components) containing the responsibilities. + tuple[torch.Tensor, torch.Tensor]: Tuple containing: + - Mean log-likelihood of the data + - Responsibilities for each component. 
+ Shape: ``(n_samples, n_components)`` """ weighted_log_prob = self._estimate_weighted_log_prob(data) log_prob_norm = torch.logsumexp(weighted_log_prob, axis=1) - log_resp = weighted_log_prob - torch.logsumexp(weighted_log_prob, dim=1, keepdim=True) + log_resp = weighted_log_prob - torch.logsumexp( + weighted_log_prob, + dim=1, + keepdim=True, + ) return torch.mean(log_prob_norm), torch.exp(log_resp) def _m_step(self, data: torch.Tensor, resp: torch.Tensor) -> None: - """Perform the M-step to update the parameters of the gaussians. + """Perform M-step to update GMM parameters. Args: - data (Tensor): Data to fit the model to. Tensor of shape (n_samples, n_features). - resp (Tensor): Tensor of shape (n_samples, n_components) containing the responsibilities. + data (torch.Tensor): Input data. + Shape: ``(n_samples, n_features)``. + resp (torch.Tensor): Responsibilities from E-step. + Shape: ``(n_samples, n_components)``. """ cluster_counts = resp.sum(axis=0) # number of points in each cluster self.weights = resp.mean(axis=0) # new weights @@ -130,22 +149,37 @@ def _m_step(self, data: torch.Tensor, resp: torch.Tensor) -> None: diff = data.unsqueeze(0) - self.means.unsqueeze(1) weighted_diff = diff * resp.T.unsqueeze(-1) - covariances = torch.bmm(weighted_diff.transpose(-2, -1), diff) / cluster_counts.view(-1, 1, 1) + covariances = torch.bmm( + weighted_diff.transpose(-2, -1), + diff, + ) / cluster_counts.view(-1, 1, 1) # Add a small constant for numerical stability - self.covariances = covariances + torch.eye(data.shape[1], device=data.device) * 1e-6 # new covariances + self.covariances = ( + covariances + + torch.eye( + data.shape[1], + device=data.device, + ) + * 1e-6 + ) def _estimate_weighted_log_prob(self, data: torch.Tensor) -> torch.Tensor: - """Estimate the log probability of the data given the gaussian parameters. + """Estimate weighted log probabilities for each component. Args: - data (Tensor): Data to fit the model to. Tensor of shape (n_samples, n_features). + data (torch.Tensor): Input data. + Shape: ``(n_samples, n_features)``. Returns: - Tensor: Tensor of shape (n_samples, n_components) containing the log-probabilities of each sample. + torch.Tensor: Weighted log probabilities. + Shape: ``(n_samples, n_components)``. """ log_prob = torch.stack( [ - MultivariateNormal(self.means[comp], self.covariances[comp]).log_prob(data) + MultivariateNormal( + self.means[comp], + self.covariances[comp], + ).log_prob(data) for comp in range(self.n_components) ], dim=1, @@ -153,24 +187,28 @@ def _estimate_weighted_log_prob(self, data: torch.Tensor) -> torch.Tensor: return log_prob + torch.log(self.weights) def score_samples(self, data: torch.Tensor) -> torch.Tensor: - """Assign a likelihood score to each sample in the data. + """Compute per-sample likelihood scores. Args: - data (Tensor): Samples to assign scores to. Tensor of shape (n_samples, n_features). + data (torch.Tensor): Input samples to score. + Shape: ``(n_samples, n_features)``. Returns: - Tensor: Tensor of shape (n_samples,) containing the log-likelihood score of each sample. + torch.Tensor: Log-likelihood scores. + Shape: ``(n_samples,)``. """ return torch.logsumexp(self._estimate_weighted_log_prob(data), dim=1) def predict(self, data: torch.Tensor) -> torch.Tensor: - """Predict the cluster labels of the data. + """Predict cluster assignments for the input data. Args: - data (Tensor): Samples to assign to clusters. Tensor of shape (n_samples, n_features). + data (torch.Tensor): Input samples. + Shape: ``(n_samples, n_features)``. 
Returns: - Tensor: Tensor of shape (n_samples,) containing the predicted cluster label of each sample. + torch.Tensor: Predicted cluster labels. + Shape: ``(n_samples,)``. """ _, resp = self._e_step(data) return torch.argmax(resp, axis=1) diff --git a/src/anomalib/models/components/cluster/kmeans.py b/src/anomalib/models/components/cluster/kmeans.py index 908a3e3fae..b8f5f05c90 100644 --- a/src/anomalib/models/components/cluster/kmeans.py +++ b/src/anomalib/models/components/cluster/kmeans.py @@ -1,4 +1,23 @@ -"""KMeans clustering algorithm implementation using PyTorch.""" +"""PyTorch implementation of K-means clustering algorithm. + +This module provides a PyTorch-based implementation of the K-means clustering +algorithm for partitioning data into ``k`` distinct clusters. + +Example: + >>> import torch + >>> from anomalib.models.components.cluster import KMeans + >>> # Create synthetic data + >>> data = torch.tensor([ + ... [1.0, 2.0], [1.5, 1.8], [1.2, 2.2], # Cluster 1 + ... [4.0, 4.0], [4.2, 4.1], [3.8, 4.2], # Cluster 2 + ... ]) + >>> # Initialize and fit KMeans + >>> kmeans = KMeans(n_clusters=2) + >>> labels, centers = kmeans.fit(data) + >>> # Predict cluster for new points + >>> new_points = torch.tensor([[1.1, 2.1], [4.0, 4.1]]) + >>> predictions = kmeans.predict(new_points) +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -7,11 +26,27 @@ class KMeans: - """Initialize the KMeans object. + """K-means clustering algorithm implementation. Args: - n_clusters (int): The number of clusters to create. - max_iter (int, optional)): The maximum number of iterations to run the algorithm. Defaults to 10. + n_clusters (int): Number of clusters to partition the data into. + max_iter (int, optional): Maximum number of iterations for the clustering + algorithm. Defaults to 10. + + Attributes: + cluster_centers_ (torch.Tensor): Coordinates of cluster centers after + fitting. Shape: ``(n_clusters, n_features)``. + labels_ (torch.Tensor): Cluster labels for the training data after + fitting. Shape: ``(n_samples,)``. + + Example: + >>> import torch + >>> from anomalib.models.components.cluster import KMeans + >>> kmeans = KMeans(n_clusters=3) + >>> data = torch.randn(100, 5) # 100 samples, 5 features + >>> labels, centers = kmeans.fit(data) + >>> print(f"Cluster assignments shape: {labels.shape}") + >>> print(f"Cluster centers shape: {centers.shape}") """ def __init__(self, n_clusters: int, max_iter: int = 10) -> None: @@ -22,15 +57,26 @@ def fit(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Fit the K-means algorithm to the input data. Args: - inputs (torch.Tensor): Input data of shape (batch_size, n_features). + inputs (torch.Tensor): Input data to cluster. + Shape: ``(n_samples, n_features)``. Returns: - tuple: A tuple containing the labels of the input data with respect to the identified clusters - and the cluster centers themselves. The labels have a shape of (batch_size,) and the - cluster centers have a shape of (n_clusters, n_features). + tuple[torch.Tensor, torch.Tensor]: Tuple containing: + - labels: Cluster assignments for each input point. + Shape: ``(n_samples,)`` + - cluster_centers: Coordinates of the cluster centers. + Shape: ``(n_clusters, n_features)`` Raises: - ValueError: If the number of clusters is less than or equal to 0. + ValueError: If ``n_clusters`` is less than or equal to 0. 
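# --- Editor's illustration (not part of the diff) ----------------------------
# The surrounding fit() hunk documents an assign-then-update loop. This
# standalone sketch (hypothetical helper, plain PyTorch) shows one Lloyd
# iteration built from the same ops as the hunk's context code: cdist for
# pairwise distances, argmin for the assignment, and a per-cluster mean for
# the centroid update.
import torch


def kmeans_step(
    inputs: torch.Tensor,
    centers: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run one K-means iteration; inputs has shape (n_samples, n_features)."""
    distances = torch.cdist(inputs, centers)  # (n_samples, n_clusters)
    labels = torch.argmin(distances, dim=1)  # nearest-centroid assignment
    new_centers = centers.clone()
    for j in range(centers.shape[0]):
        mask = labels == j
        if mask.any():  # leave empty clusters untouched
            new_centers[j] = inputs[mask].mean(dim=0)
    return labels, new_centers
# ------------------------------------------------------------------------------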
+ + Example: + >>> kmeans = KMeans(n_clusters=2) + >>> data = torch.tensor([[1.0, 2.0], [4.0, 5.0], [1.2, 2.1]]) + >>> labels, centers = kmeans.fit(data) + >>> # Count the points assigned to each cluster + >>> counts = [(labels == i).sum().item() for i in range(2)] + >>> print(f"Number of points in each cluster: {counts}") """ batch_size, _ = inputs.shape @@ -46,25 +92,36 @@ def fit(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # Assign each data point to the closest centroid self.labels_ = torch.argmin(distances, dim=1) - # Update the centroids to be the mean of the data points assigned to them + # Update the centroids to be the mean of the data points assigned for j in range(self.n_clusters): mask = self.labels_ == j if mask.any(): self.cluster_centers_[j] = inputs[mask].mean(dim=0) - # this line returns labels and centoids of the results + return self.labels_, self.cluster_centers_ def predict(self, inputs: torch.Tensor) -> torch.Tensor: - """Predict the labels of input data based on the fitted model. + """Predict cluster labels for input data. Args: - inputs (torch.Tensor): Input data of shape (batch_size, n_features). + inputs (torch.Tensor): Input data to assign to clusters. + Shape: ``(n_samples, n_features)``. Returns: - torch.Tensor: The predicted labels of the input data with respect to the identified clusters. + torch.Tensor: Predicted cluster labels. + Shape: ``(n_samples,)``. Raises: - AttributeError: If the KMeans object has not been fitted to input data. + AttributeError: If called before fitting the model. + + Example: + >>> kmeans = KMeans(n_clusters=2) + >>> # First fit the model + >>> train_data = torch.tensor([[1.0, 2.0], [4.0, 5.0]]) + >>> _ = kmeans.fit(train_data) + >>> # Then predict on new data + >>> new_data = torch.tensor([[1.1, 2.1], [3.9, 4.8]]) + >>> predictions = kmeans.predict(new_data) """ distances = torch.cdist(inputs, self.cluster_centers_) return torch.argmin(distances, dim=1) diff --git a/src/anomalib/models/components/dimensionality_reduction/__init__.py b/src/anomalib/models/components/dimensionality_reduction/__init__.py index d69c691bf0..62260edda8 100644 --- a/src/anomalib/models/components/dimensionality_reduction/__init__.py +++ b/src/anomalib/models/components/dimensionality_reduction/__init__.py @@ -1,6 +1,28 @@ -"""Algorithms for decomposition and dimensionality reduction.""" +"""Dimensionality reduction and decomposition algorithms for feature processing. -# Copyright (C) 2022 Intel Corporation +This module provides implementations of dimensionality reduction techniques used +in anomaly detection models. + +Classes: + PCA: Principal Component Analysis for linear dimensionality reduction. + SparseRandomProjection: Random projection using sparse random matrices. + +Example: + >>> import torch + >>> from anomalib.models.components.dimensionality_reduction import PCA + >>> # Create and fit PCA + >>> pca = PCA(n_components=10) + >>> features = torch.randn(100, 50) # 100 samples, 50 features + >>> reduced_features = pca.fit_transform(features) + >>> # Use SparseRandomProjection + >>> from anomalib.models.components.dimensionality_reduction import ( + ... SparseRandomProjection + ...
) + >>> projector = SparseRandomProjection(eps=0.1) + >>> projected_features = projector.fit(features).transform(features) +""" + +# Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from .pca import PCA diff --git a/src/anomalib/models/components/dimensionality_reduction/pca.py b/src/anomalib/models/components/dimensionality_reduction/pca.py index 3e9bd4bb65..55fa679243 100644 --- a/src/anomalib/models/components/dimensionality_reduction/pca.py +++ b/src/anomalib/models/components/dimensionality_reduction/pca.py @@ -1,4 +1,20 @@ -"""Principle Component Analysis (PCA) with PyTorch.""" +"""Principal Component Analysis (PCA) implementation using PyTorch. + +This module provides a PyTorch-based implementation of Principal Component Analysis +for dimensionality reduction. + +Example: + >>> import torch + >>> from anomalib.models.components import PCA + >>> # Create sample data + >>> data = torch.randn(100, 10) # 100 samples, 10 features + >>> # Initialize PCA with 3 components + >>> pca = PCA(n_components=3) + >>> # Fit and transform the data + >>> transformed_data = pca.fit_transform(data) + >>> print(transformed_data.shape) + torch.Size([100, 3]) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -9,31 +25,34 @@ class PCA(DynamicBufferMixin): - """Principle Component Analysis (PCA). + """Principal Component Analysis (PCA) for dimensionality reduction. Args: - n_components (float): Number of components. Can be either integer number of components - or a ratio between 0-1. + n_components (int | float): Number of components to keep. If float between + 0 and 1, represents the variance ratio to preserve. If int, represents + the exact number of components to keep. + + Attributes: + singular_vectors (torch.Tensor): Right singular vectors from SVD. + singular_values (torch.Tensor): Singular values from SVD. + mean (torch.Tensor): Mean of the training data. + num_components (torch.Tensor): Number of components kept. Example: >>> import torch >>> from anomalib.models.components import PCA - - Create a PCA model with 2 components: - - >>> pca = PCA(n_components=2) - - Create a random embedding and fit a PCA model. - - >>> embedding = torch.rand(1000, 5).cuda() - >>> pca = PCA(n_components=2) - >>> pca.fit(embedding) - - Apply transformation: - - >>> transformed = pca.transform(embedding) - >>> transformed.shape - torch.Size([1000, 2]) + >>> # Create sample data + >>> data = torch.randn(100, 10) # 100 samples, 10 features + >>> # Initialize with fixed number of components + >>> pca = PCA(n_components=3) + >>> pca.fit(data) + >>> # Transform new data + >>> transformed = pca.transform(data) + >>> print(transformed.shape) + torch.Size([100, 3]) + >>> # Initialize with variance ratio + >>> pca = PCA(n_components=0.95) # Keep 95% of variance + >>> pca.fit(data) """ def __init__(self, n_components: int | float) -> None: @@ -50,18 +69,21 @@ def __init__(self, n_components: int | float) -> None: self.num_components: torch.Tensor def fit(self, dataset: torch.Tensor) -> None: - """Fits the PCA model to the dataset. + """Fit the PCA model to the dataset. Args: - dataset (torch.Tensor): Input dataset to fit the model. + dataset (torch.Tensor): Input dataset of shape ``(n_samples, + n_features)``.
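# --- Editor's illustration (not part of the diff) ----------------------------
# The PCA docstring above says a float n_components in (0, 1) is treated as a
# variance ratio. This sketch (hypothetical helper) shows one way such a rule
# can be evaluated from the singular values: keep the smallest k whose
# cumulative explained variance reaches the requested ratio. The authoritative
# selection logic lives in PCA.fit().
import torch


def components_for_ratio(singular_values: torch.Tensor, ratio: float) -> int:
    """Smallest number of components preserving `ratio` of the variance."""
    variances = singular_values**2
    cumulative = torch.cumsum(variances, dim=0) / variances.sum()
    # Count entries strictly below the target ratio; one more component is
    # needed to reach it.
    return int((cumulative < ratio).sum().item()) + 1
# ------------------------------------------------------------------------------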
Example: - >>> pca.fit(embedding) - >>> pca.singular_vectors - tensor([9.6053, 9.2763], device='cuda:0') - - >>> pca.mean - tensor([0.4859, 0.4959, 0.4906, 0.5010, 0.5042], device='cuda:0') + >>> data = torch.randn(100, 10) + >>> pca = PCA(n_components=3) + >>> pca.fit(data) + >>> # Access fitted attributes + >>> print(pca.singular_vectors.shape) + torch.Size([10, 3]) + >>> print(pca.mean.shape) + torch.Size([10]) """ mean = dataset.mean(dim=0) dataset -= mean @@ -81,19 +103,22 @@ def fit(self, dataset: torch.Tensor) -> None: self.mean = mean def fit_transform(self, dataset: torch.Tensor) -> torch.Tensor: - """Fit and transform PCA to dataset. + """Fit the model and transform the input dataset. Args: - dataset (torch.Tensor): Dataset to which the PCA if fit and transformed + dataset (torch.Tensor): Input dataset of shape ``(n_samples, + n_features)``. Returns: - Transformed dataset + torch.Tensor: Transformed dataset of shape ``(n_samples, + n_components)``. Example: - >>> pca.fit_transform(embedding) - >>> transformed_embedding = pca.fit_transform(embedding) - >>> transformed_embedding.shape - torch.Size([1000, 2]) + >>> data = torch.randn(100, 10) + >>> pca = PCA(n_components=3) + >>> transformed = pca.fit_transform(data) + >>> print(transformed.shape) + torch.Size([100, 3]) """ mean = dataset.mean(dim=0) dataset -= mean @@ -107,54 +132,66 @@ def fit_transform(self, dataset: torch.Tensor) -> torch.Tensor: return torch.matmul(dataset, self.singular_vectors) def transform(self, features: torch.Tensor) -> torch.Tensor: - """Transform the features based on singular vectors calculated earlier. + """Transform features using the fitted PCA model. Args: - features (torch.Tensor): Input features + features (torch.Tensor): Input features of shape ``(n_samples, + n_features)``. Returns: - Transformed features + torch.Tensor: Transformed features of shape ``(n_samples, + n_components)``. Example: - >>> pca.transform(embedding) - >>> transformed_embedding = pca.transform(embedding) - - >>> embedding.shape - torch.Size([1000, 5]) - # - >>> transformed_embedding.shape - torch.Size([1000, 2]) + >>> data = torch.randn(100, 10) + >>> pca = PCA(n_components=3) + >>> pca.fit(data) + >>> new_data = torch.randn(50, 10) + >>> transformed = pca.transform(new_data) + >>> print(transformed.shape) + torch.Size([50, 3]) """ features -= self.mean return torch.matmul(features, self.singular_vectors) def inverse_transform(self, features: torch.Tensor) -> torch.Tensor: - """Inverses the transformed features. + """Inverse transform features back to original space. Args: - features (torch.Tensor): Transformed features + features (torch.Tensor): Transformed features of shape ``(n_samples, + n_components)``. Returns: - Inverse features + torch.Tensor: Reconstructed features of shape ``(n_samples, + n_features)``. Example: - >>> inverse_embedding = pca.inverse_transform(transformed_embedding) - >>> inverse_embedding.shape - torch.Size([1000, 5]) + >>> data = torch.randn(100, 10) + >>> pca = PCA(n_components=3) + >>> transformed = pca.fit_transform(data) + >>> reconstructed = pca.inverse_transform(transformed) + >>> print(reconstructed.shape) + torch.Size([100, 10]) """ return torch.matmul(features, self.singular_vectors.transpose(-2, -1)) def forward(self, features: torch.Tensor) -> torch.Tensor: - """Transform the features. + """Transform features (alias for transform method). Args: - features (torch.Tensor): Input features + features (torch.Tensor): Input features of shape ``(n_samples, + n_features)``. 
Returns: - Transformed features + torch.Tensor: Transformed features of shape ``(n_samples, + n_components)``. Example: - >>> pca(embedding).shape - torch.Size([1000, 2]) + >>> data = torch.randn(100, 10) + >>> pca = PCA(n_components=3) + >>> pca.fit(data) + >>> transformed = pca(data) # Using forward + >>> print(transformed.shape) + torch.Size([100, 3]) """ return self.transform(features) diff --git a/src/anomalib/models/components/dimensionality_reduction/random_projection.py b/src/anomalib/models/components/dimensionality_reduction/random_projection.py index cfa6ecad30..083103273a 100644 --- a/src/anomalib/models/components/dimensionality_reduction/random_projection.py +++ b/src/anomalib/models/components/dimensionality_reduction/random_projection.py @@ -1,6 +1,18 @@ """Random Sparse Projector. -Sparse Random Projection using PyTorch Operations +This module provides a PyTorch implementation of Sparse Random Projection for +dimensionality reduction. + +Example: + >>> import torch + >>> from anomalib.models.components import SparseRandomProjection + >>> # Create sample data + >>> data = torch.randn(100, 50) # 100 samples, 50 features + >>> # Initialize projector + >>> projector = SparseRandomProjection(eps=0.1) + >>> # Fit the projector, then transform the data + >>> projected_data = projector.fit(data).transform(data) + >>> print(projected_data.shape) """ # Copyright (C) 2022-2024 Intel Corporation @@ -12,40 +24,43 @@ class NotFittedError(ValueError, AttributeError): - """Raise Exception if estimator is used before fitting.""" + """Exception raised when model is used before fitting.""" class SparseRandomProjection: """Sparse Random Projection using PyTorch operations. + This class implements sparse random projection for dimensionality reduction + using PyTorch. The implementation is based on the paper by Li et al. [1]_. + Args: eps (float, optional): Minimum distortion rate parameter for calculating - Johnson-Lindenstrauss minimum dimensions. - Defaults to ``0.1``. - random_state (int | None, optional): Uses the seed to set the random - state for sample_without_replacement function. - Defaults to ``None``. - - Example: - To fit and transform the embedding tensor, use the following code: - - .. code-block:: python - - import torch - from anomalib.models.components import SparseRandomProjection - - sparse_embedding = torch.rand(1000, 5).cuda() - model = SparseRandomProjection(eps=0.1) - - Fit the model and transform the embedding tensor: - - .. code-block:: python + Johnson-Lindenstrauss minimum dimensions. Defaults to ``0.1``. + random_state (int | None, optional): Seed for random number generation. + Used for reproducible results. Defaults to ``None``. - model.fit(sparse_embedding) - projected_embedding = model.transform(sparse_embedding) + Attributes: + n_components (int): Number of components in the projected space. + sparse_random_matrix (torch.Tensor): Random projection matrix. + eps (float): Minimum distortion rate. + random_state (int | None): Random seed. - print(projected_embedding.shape) - # Output: torch.Size([1000, 5920]) + Example: + >>> import torch + >>> from anomalib.models.components import SparseRandomProjection + >>> # Create sample data + >>> data = torch.randn(100, 50) # 100 samples, 50 features + >>> # Initialize and fit projector + >>> projector = SparseRandomProjection(eps=0.1) + >>> projector.fit(data) + >>> # Transform data + >>> projected = projector.transform(data) + >>> print(projected.shape) + + References: + .. [1] P. Li, T. Hastie and K.
Church, "Very Sparse Random Projections," + KDD '06, 2006. + https://web.stanford.edu/~hastie/Papers/Ping/KDD06_rp.pdf """ def __init__(self, eps: float = 0.1, random_state: int | None = None) -> None: @@ -55,15 +70,20 @@ def __init__(self, eps: float = 0.1, random_state: int | None = None) -> None: self.random_state = random_state def _sparse_random_matrix(self, n_features: int) -> torch.Tensor: - """Random sparse matrix. Based on https://web.stanford.edu/~hastie/Papers/Ping/KDD06_rp.pdf. + """Generate a sparse random matrix for projection. + + Implements the sparse random matrix generation described in [1]_. Args: - n_features (int): Dimentionality of the original source space + n_features (int): Dimensionality of the original source space. Returns: - Tensor: Sparse matrix of shape (n_components, n_features). - The generated Gaussian random matrix is in CSR (compressed sparse row) - format. + torch.Tensor: Sparse matrix of shape ``(n_components, n_features)``. + The matrix is stored in dense format for GPU compatibility. + + References: + .. [1] P. Li, T. Hastie and K. Church, "Very Sparse Random + Projections," KDD '06, 2006. """ # Density 'auto'. Factorize density density = 1 / np.sqrt(n_features) @@ -100,28 +120,40 @@ def _sparse_random_matrix(self, n_features: int) -> torch.Tensor: @staticmethod def _johnson_lindenstrauss_min_dim(n_samples: int, eps: float = 0.1) -> int | np.integer: - """Find a 'safe' number of components to randomly project to. + """Find a 'safe' number of components for random projection. - Ref eqn 2.1 https://cseweb.ucsd.edu/~dasgupta/papers/jl.pdf + Implements the Johnson-Lindenstrauss lemma to determine the minimum number + of components needed to approximately preserve distances. Args: - n_samples (int): Number of samples used to compute safe components - eps (float, optional): Minimum distortion rate. Defaults to 0.1. + n_samples (int): Number of samples in the dataset. + eps (float, optional): Minimum distortion rate. Defaults to ``0.1``. + + Returns: + int: Minimum number of components required. + + References: + .. [1] Dasgupta, S. and Gupta, A., "An elementary proof of a theorem + of Johnson and Lindenstrauss," Random Struct. Algor., 22: 60-65, + 2003. """ denominator = (eps**2 / 2) - (eps**3 / 3) return (4 * np.log(n_samples) / denominator).astype(np.int64) def fit(self, embedding: torch.Tensor) -> "SparseRandomProjection": - """Generate sparse matrix from the embedding tensor. + """Fit the random projection matrix to the data. Args: - embedding (torch.Tensor): embedding tensor for generating embedding + embedding (torch.Tensor): Input tensor of shape + ``(n_samples, n_features)``. Returns: - (SparseRandomProjection): Return self to be used as + SparseRandomProjection: The fitted projector. - >>> model = SparseRandomProjection() - >>> model = model.fit() + Example: + >>> projector = SparseRandomProjection() + >>> data = torch.randn(100, 50) + >>> projector = projector.fit(data) """ n_samples, n_features = embedding.shape device = embedding.device @@ -137,20 +169,25 @@ def fit(self, embedding: torch.Tensor) -> "SparseRandomProjection": return self def transform(self, embedding: torch.Tensor) -> torch.Tensor: - """Project the data by using matrix product with the random matrix. + """Project the data using the random projection matrix. Args: - embedding (torch.Tensor): Embedding of shape (n_samples, n_features) - The input data to project into a smaller dimensional space + embedding (torch.Tensor): Input tensor of shape + ``(n_samples, n_features)``. 
Returns: - projected_embedding (torch.Tensor): Sparse matrix of shape - (n_samples, n_components) Projected array. + torch.Tensor: Projected tensor of shape + ``(n_samples, n_components)``. + + Raises: + NotFittedError: If transform is called before fitting. Example: - >>> projected_embedding = model.transform(embedding) - >>> projected_embedding.shape - torch.Size([1000, 5920]) + >>> projector = SparseRandomProjection() + >>> data = torch.randn(100, 50) + >>> projector.fit(data) + >>> projected = projector.transform(data) + >>> print(projected.shape) """ if self.sparse_random_matrix is None: msg = "`fit()` has not been called on SparseRandomProjection yet." diff --git a/src/anomalib/models/components/feature_extractors/__init__.py b/src/anomalib/models/components/feature_extractors/__init__.py index 5092056967..be57c40936 100644 --- a/src/anomalib/models/components/feature_extractors/__init__.py +++ b/src/anomalib/models/components/feature_extractors/__init__.py @@ -1,4 +1,30 @@ -"""Feature extractors.""" +"""Feature extractors for deep learning models. + +This module provides feature extraction utilities and classes for extracting +features from images using various backbone architectures. + +Classes: + TimmFeatureExtractor: Feature extractor using timm models. + TorchFXFeatureExtractor: Feature extractor using TorchFX for graph capture. + BackboneParams: Configuration parameters for backbone models. + +Functions: + dryrun_find_featuremap_dims: Utility to find feature map dimensions. + +Example: + >>> import torch + >>> from anomalib.models.components.feature_extractors import ( + ... TimmFeatureExtractor + ... ) + >>> # Create feature extractor + >>> feature_extractor = TimmFeatureExtractor( + ... backbone="resnet18", + ... layers=['layer1', 'layer2'] + ... ) + >>> # Extract features + >>> images = torch.rand(32, 3, 256, 256) + >>> features = feature_extractor(images) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/components/feature_extractors/timm.py b/src/anomalib/models/components/feature_extractors/timm.py index ae81dfb2c4..adb1d41153 100644 --- a/src/anomalib/models/components/feature_extractors/timm.py +++ b/src/anomalib/models/components/feature_extractors/timm.py @@ -1,6 +1,24 @@ -"""Feature Extractor. - -This script extracts features from a CNN network +"""Feature extractor using timm models. + +This module provides a feature extractor implementation that leverages the timm +library to extract intermediate features from various CNN architectures. + +Example: + >>> import torch + >>> from anomalib.models.components.feature_extractors import ( + ... TimmFeatureExtractor + ... ) + >>> # Initialize feature extractor + >>> extractor = TimmFeatureExtractor( + ... backbone="resnet18", + ... layers=["layer1", "layer2", "layer3"] + ... ) + >>> # Extract features from input + >>> inputs = torch.randn(32, 3, 256, 256) + >>> features = extractor(inputs) + >>> # Access features by layer name + >>> print(features["layer1"].shape) + torch.Size([32, 64, 64, 64]) """ # Copyright (C) 2022-2024 Intel Corporation @@ -17,31 +35,44 @@ class TimmFeatureExtractor(nn.Module): - """Extract features from a CNN. + """Extract intermediate features from timm models. Args: - backbone (nn.Module): The backbone to which the feature extraction hooks are attached. - layers (Iterable[str]): List of layer names of the backbone to which the hooks are attached. - pre_trained (bool): Whether to use a pre-trained backbone. Defaults to True.
- requires_grad (bool): Whether to require gradients for the backbone. Defaults to False. - Models like ``stfpm`` use the feature extractor model as a trainable network. In such cases gradient - computation is required. + backbone (str): Name of the timm model architecture to use as backbone. + Can include custom weights URI in format ``name__AT__uri``. + layers (Sequence[str]): Names of layers from which to extract features. + pre_trained (bool, optional): Whether to use pre-trained weights. + Defaults to ``True``. + requires_grad (bool, optional): Whether to compute gradients for the + backbone. Required for training models like STFPM. Defaults to + ``False``. + + Attributes: + backbone (str): Name of the backbone model. + layers (list[str]): Layer names for feature extraction. + idx (list[int]): Indices mapping layer names to model outputs. + requires_grad (bool): Whether gradients are computed. + feature_extractor (nn.Module): The underlying timm model. + out_dims (list[int]): Output dimensions for each extracted layer. Example: - .. code-block:: python - - import torch - from anomalib.models.components.feature_extractors import TimmFeatureExtractor - - model = TimmFeatureExtractor(model="resnet18", layers=['layer1', 'layer2', 'layer3']) - input = torch.rand((32, 3, 256, 256)) - features = model(input) - - print([layer for layer in features.keys()]) - # Output: ['layer1', 'layer2', 'layer3'] - - print([feature.shape for feature in features.values()]() - # Output: [torch.Size([32, 64, 64, 64]), torch.Size([32, 128, 32, 32]), torch.Size([32, 256, 16, 16])] + >>> import torch + >>> from anomalib.models.components.feature_extractors import ( + ... TimmFeatureExtractor + ... ) + >>> # Create extractor + >>> model = TimmFeatureExtractor( + ... backbone="resnet18", + ... layers=["layer1", "layer2"] + ... ) + >>> # Extract features + >>> inputs = torch.randn(1, 3, 224, 224) + >>> features = model(inputs) + >>> # Print shapes + >>> for name, feat in features.items(): + ... print(f"{name}: {feat.shape}") + layer1: torch.Size([1, 64, 56, 56]) + layer2: torch.Size([1, 128, 28, 28]) """ def __init__( @@ -78,10 +109,14 @@ def __init__( self._features = {layer: torch.empty(0) for layer in self.layers} def _map_layer_to_idx(self) -> list[int]: - """Map set of layer names to indices of model. + """Map layer names to their indices in the model's output. Returns: - list[int]: Feature map extracted from the CNN. + list[int]: Indices corresponding to the requested layer names. + + Note: + If a requested layer is not found in the model, it is removed from + ``self.layers`` and a warning is logged. """ idx = [] model = timm.create_model( @@ -90,7 +125,8 @@ def _map_layer_to_idx(self) -> list[int]: features_only=True, exportable=True, ) - # model.feature_info.info returns list of dicts containing info, inside which "module" contains layer name + # model.feature_info.info returns list of dicts containing info, + # inside which "module" contains layer name layer_names = [info["module"] for info in model.feature_info.info] for layer in self.layers: try: @@ -104,21 +140,29 @@ def _map_layer_to_idx(self) -> list[int]: return idx def forward(self, inputs: torch.Tensor) -> dict[str, torch.Tensor]: - """Forward-pass input tensor into the CNN. + """Extract features from the input tensor. Args: - inputs (torch.Tensor): Input tensor + inputs (torch.Tensor): Input tensor of shape + ``(batch_size, channels, height, width)``. 
Returns: - Feature map extracted from the CNN + dict[str, torch.Tensor]: Dictionary mapping layer names to their + feature tensors. Example: - .. code-block:: python - - model = TimmFeatureExtractor(model="resnet50", layers=['layer3']) - input = torch.rand((32, 3, 256, 256)) - features = model.forward(input) - + >>> import torch + >>> from anomalib.models.components.feature_extractors import ( + ... TimmFeatureExtractor + ... ) + >>> model = TimmFeatureExtractor( + ... backbone="resnet18", + ... layers=["layer1"] + ... ) + >>> inputs = torch.randn(1, 3, 224, 224) + >>> features = model(inputs) + >>> features["layer1"].shape + torch.Size([1, 64, 56, 56]) """ if self.requires_grad: features = dict(zip(self.layers, self.feature_extractor(inputs), strict=True)) diff --git a/src/anomalib/models/components/feature_extractors/torchfx.py b/src/anomalib/models/components/feature_extractors/torchfx.py index 600f2a961d..355d611d10 100644 --- a/src/anomalib/models/components/feature_extractors/torchfx.py +++ b/src/anomalib/models/components/feature_extractors/torchfx.py @@ -1,4 +1,52 @@ -"""Feature Extractor based on TorchFX.""" +"""Feature Extractor based on TorchFX. + +This module provides a feature extractor implementation that leverages TorchFX to +extract intermediate features from CNN architectures. + +Example: + >>> import torch + >>> from anomalib.models.components.feature_extractors import ( + ... TorchFXFeatureExtractor + ... ) + >>> # Initialize with torchvision model + >>> from torchvision.models.efficientnet import EfficientNet_B5_Weights + >>> extractor = TorchFXFeatureExtractor( + ... backbone="efficientnet_b5", + ... return_nodes=["features.6.8"], + ... weights=EfficientNet_B5_Weights.DEFAULT + ... ) + >>> # Extract features + >>> inputs = torch.rand((32, 3, 256, 256)) + >>> features = extractor(inputs) + >>> print([layer for layer in features.keys()]) + ['features.6.8'] + >>> print([feature.shape for feature in features.values()]) + [torch.Size([32, 304, 8, 8])] + + With custom models: + >>> # Initialize with custom model + >>> extractor = TorchFXFeatureExtractor( + ... "path.to.CustomModel", + ... ["linear_relu_stack.3"], + ... weights="path/to/weights.pth" + ... ) + >>> inputs = torch.randn(1, 1, 28, 28) + >>> features = extractor(inputs) + >>> print([layer for layer in features.keys()]) + ['linear_relu_stack.3'] + + With model instances: + >>> # Initialize with model instance + >>> from timm import create_model + >>> model = create_model("resnet18", pretrained=True) + >>> extractor = TorchFXFeatureExtractor(model, ["layer1"]) + >>> inputs = torch.rand((32, 3, 256, 256)) + >>> features = extractor(inputs) + >>> print([layer for layer in features.keys()]) + ['layer1'] + >>> print([feature.shape for feature in features.values()]) + [torch.Size([32, 64, 64, 64])] +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -16,91 +64,64 @@ @dataclass class BackboneParams: - """Used for serializing the backbone.""" + """Used for serializing the backbone. + + Args: + class_path (str | type[nn.Module]): Path to the backbone class or the + class itself. + init_args (dict): Dictionary of initialization arguments for the backbone. + Defaults to empty dict. + """ class_path: str | type[nn.Module] init_args: dict = field(default_factory=dict) class TorchFXFeatureExtractor(nn.Module): - """Extract features from a CNN. + """Extract features from a CNN using TorchFX. 
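Editor's note: reviewers may wonder where valid `layers` values come from. `_map_layer_to_idx` above resolves them against `model.feature_info.info`, so the same lookup can be used interactively to list the hookable names. A small sketch; the exact list printed for resnet18 is an assumption about current timm behaviour.

```python
import timm

# The "module" key is the same one _map_layer_to_idx reads from
# model.feature_info.info to resolve layer names to indices.
model = timm.create_model("resnet18", pretrained=False, features_only=True)
print([info["module"] for info in model.feature_info.info])
# expected for resnet18: ['act1', 'layer1', 'layer2', 'layer3', 'layer4']
```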
Args: - backbone (str | BackboneParams | dict | nn.Module): The backbone to which the feature extraction hooks are - attached. If the name is provided, the model is loaded from torchvision. Otherwise, the model class can be - provided and it will try to load the weights from the provided weights file. Last, an instance of nn.Module - can also be passed directly. - return_nodes (Iterable[str]): List of layer names of the backbone to which the hooks are attached. - You can find the names of these nodes by using ``get_graph_node_names`` function. - weights (str | WeightsEnum | None): Weights enum to use for the model. Torchvision models require - ``WeightsEnum``. These enums are defined in ``torchvision.models.``. You can pass the weights - path for custom models. - requires_grad (bool): Models like ``stfpm`` use the feature extractor for training. In such cases we should - set ``requires_grad`` to ``True``. Default is ``False``. - tracer_kwargs (dict | None): a dictionary of keyword arguments for NodePathTracer (which passes them onto - it's parent class torch.fx.Tracer). Can be used to allow not tracing through a list of problematic - modules, by passing a list of `leaf_modules` as one of the `tracer_kwargs`. + backbone (str | BackboneParams | dict | nn.Module): The backbone to which + the feature extraction hooks are attached. If a string name is + provided, the model is loaded from torchvision. Otherwise, the model + class can be provided and it will try to load the weights from the + provided weights file. Lastly, an instance of nn.Module can also be + passed directly. + return_nodes (list[str]): List of layer names of the backbone to which + the hooks are attached. You can find the names of these nodes by + using the ``get_graph_node_names`` function. + weights (str | WeightsEnum | None): Weights enum to use for the model. + Torchvision models require ``WeightsEnum``. These enums are defined + in ``torchvision.models.``. You can pass the weights path for + custom models. Defaults to ``None``. + requires_grad (bool): Models like ``stfpm`` use the feature extractor for + training. In such cases we should set ``requires_grad`` to ``True``. + Defaults to ``False``. + tracer_kwargs (dict | None): Dictionary of keyword arguments for + NodePathTracer (which passes them onto its parent class + torch.fx.Tracer). Can be used to allow not tracing through a list of + problematic modules, by passing a list of ``leaf_modules`` as one of + the ``tracer_kwargs``. Defaults to ``None``. + + Attributes: + feature_extractor (GraphModule): The TorchFX feature extractor module. Example: - With torchvision models: - - .. code-block:: python - - import torch - from anomalib.models.components.feature_extractors import TorchFXFeatureExtractor - from torchvision.models.efficientnet import EfficientNet_B5_Weights - - feature_extractor = TorchFXFeatureExtractor( - backbone="efficientnet_b5", - return_nodes=["features.6.8"], - weights=EfficientNet_B5_Weights.DEFAULT - ) - - input = torch.rand((32, 3, 256, 256)) - features = feature_extractor(input) - - print([layer for layer in features.keys()]) - # Output: ["features.6.8"] - - print([feature.shape for feature in features.values()]) - # Output: [torch.Size([32, 304, 8, 8])] - - With custom models: - - ..
code-block:: python - - import torch - from anomalib.models.components.feature_extractors import TorchFXFeatureExtractor - - feature_extractor = TorchFXFeatureExtractor( - "path.to.CustomModel", ["linear_relu_stack.3"], weights="path/to/weights.pth" - ) - - input = torch.randn(1, 1, 28, 28) - features = feature_extractor(input) - - print([layer for layer in features.keys()]) - # Output: ["linear_relu_stack.3"] - - with model instances: - - .. code-block:: python - - import torch - from anomalib.models.components.feature_extractors import TorchFXFeatureExtractor - from timm import create_model - - model = create_model("resnet18", pretrained=True) - feature_extractor = TorchFXFeatureExtractor(model, ["layer1"]) - - input = torch.rand((32, 3, 256, 256)) - features = feature_extractor(input) - - print([layer for layer in features.keys()]) - # Output: ["layer1"] - - print([feature.shape for feature in features.values()]) - # Output: [torch.Size([32, 64, 64, 64])] + >>> import torch + >>> from anomalib.models.components.feature_extractors import ( + ... TorchFXFeatureExtractor + ... ) + >>> # Initialize with torchvision model + >>> extractor = TorchFXFeatureExtractor( + ... backbone="resnet18", + ... return_nodes=["layer1", "layer2"] + ... ) + >>> # Extract features + >>> inputs = torch.randn(1, 3, 224, 224) + >>> features = extractor(inputs) + >>> # Access features by layer name + >>> print(features["layer1"].shape) + torch.Size([1, 64, 56, 56]) """ def __init__( @@ -136,26 +157,25 @@ def initialize_feature_extractor( requires_grad: bool = False, tracer_kwargs: dict | None = None, ) -> GraphModule: - """Extract features from a CNN. + """Initialize the feature extractor. Args: - backbone (BackboneParams | nn.Module): The backbone to which the feature extraction hooks are attached. - If the name is provided for BackboneParams, the model is loaded from torchvision. Otherwise, the model - class can be provided and it will try to load the weights from the provided weights file. Last, an - instance of the model can be provided as well, which will be used as-is. - return_nodes (Iterable[str]): List of layer names of the backbone to which the hooks are attached. - You can find the names of these nodes by using ``get_graph_node_names`` function. - weights (str | WeightsEnum | None): Weights enum to use for the model. Torchvision models require - ``WeightsEnum``. These enums are defined in ``torchvision.models.``. You can pass the weights - path for custom models. - requires_grad (bool): Models like ``stfpm`` use the feature extractor for training. In such cases we should - set ``requires_grad`` to ``True``. Default is ``False``. - tracer_kwargs (dict | None): a dictionary of keyword arguments for NodePathTracer (which passes them onto - it's parent class torch.fx.Tracer). Can be used to allow not tracing through a list of problematic - modules, by passing a list of `leaf_modules` as one of the `tracer_kwargs`. + backbone (BackboneParams | nn.Module): The backbone to which the + feature extraction hooks are attached. + return_nodes (list[str]): List of layer names to extract features + from. + weights (str | WeightsEnum | None): Model weights specification. + Defaults to ``None``. + requires_grad (bool): Whether to compute gradients. Defaults to + ``False``. + tracer_kwargs (dict | None): Additional arguments for the tracer. + Defaults to ``None``. Returns: - Feature Extractor based on TorchFX. + GraphModule: Initialized feature extractor. + + Raises: + TypeError: If weights format is invalid. 
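Editor's note: the docstrings above refer readers to ``get_graph_node_names`` without showing it. A minimal sketch of the intended workflow for picking `return_nodes`; the torchvision API shown here is real, only the choice of backbone is illustrative.

```python
from torchvision.models import resnet18
from torchvision.models.feature_extraction import get_graph_node_names

# Node names can differ between train and eval graphs; choose return_nodes
# from the list matching how the extractor will actually be run.
train_nodes, eval_nodes = get_graph_node_names(resnet18())
print(eval_nodes[:5])
```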
""" if isinstance(backbone, nn.Module): backbone_model = backbone @@ -167,7 +187,10 @@ class can be provided and it will try to load the weights from the provided weig backbone_model = backbone_class(**backbone.init_args) if isinstance(weights, WeightsEnum): # torchvision models - feature_extractor = create_feature_extractor(model=backbone_model, return_nodes=return_nodes) + feature_extractor = create_feature_extractor( + model=backbone_model, + return_nodes=return_nodes, + ) elif weights is not None: if not isinstance(weights, str): msg = "Weights should point to a path" @@ -178,7 +201,11 @@ class can be provided and it will try to load the weights from the provided weig model_weights = model_weights["state_dict"] backbone_model.load_state_dict(model_weights) - feature_extractor = create_feature_extractor(backbone_model, return_nodes, tracer_kwargs=tracer_kwargs) + feature_extractor = create_feature_extractor( + backbone_model, + return_nodes, + tracer_kwargs=tracer_kwargs, + ) if not requires_grad: feature_extractor.eval() @@ -191,26 +218,30 @@ class can be provided and it will try to load the weights from the provided weig def _get_backbone_class(backbone: str) -> Callable[..., nn.Module]: """Get the backbone class from the provided path. - If only the model name is provided, it will try to load the model from torchvision. - - Example: - >>> from anomalib.models.components.feature_extractors import TorchFXFeatureExtractor - >>> TorchFXFeatureExtractor._get_backbone_class("efficientnet_b5") - torchvision.models.efficientnet.EfficientNet> - - >>> TorchFXFeatureExtractor._get_backbone_class("path.to.CustomModel") - + If only the model name is provided, it will try to load the model from + torchvision. Args: backbone (str): Path to the backbone class. Returns: - Backbone class. + Callable[..., nn.Module]: Backbone class. + + Raises: + ModuleNotFoundError: If backbone cannot be found. + + Example: + >>> from anomalib.models.components.feature_extractors import ( + ... TorchFXFeatureExtractor + ... ) + >>> # Get torchvision model + >>> cls = TorchFXFeatureExtractor._get_backbone_class( + ... "efficientnet_b5" + ... ) + >>> # Get custom model + >>> cls = TorchFXFeatureExtractor._get_backbone_class( + ... "path.to.CustomModel" + ... ) """ try: if len(backbone.split(".")) > 1: @@ -222,12 +253,18 @@ def _get_backbone_class(backbone: str) -> Callable[..., nn.Module]: backbone_class = getattr(models, backbone) except ModuleNotFoundError as exception: msg = f"Backbone {backbone} not found in torchvision.models nor in {backbone} module." - raise ModuleNotFoundError( - msg, - ) from exception + raise ModuleNotFoundError(msg) from exception return backbone_class def forward(self, inputs: torch.Tensor) -> dict[str, torch.Tensor]: - """Extract features from the input.""" + """Extract features from the input. + + Args: + inputs (torch.Tensor): Input tensor. + + Returns: + dict[str, torch.Tensor]: Dictionary mapping layer names to their + feature tensors. + """ return self.feature_extractor(inputs) diff --git a/src/anomalib/models/components/feature_extractors/utils.py b/src/anomalib/models/components/feature_extractors/utils.py index 71e50f7361..e1d56c3265 100644 --- a/src/anomalib/models/components/feature_extractors/utils.py +++ b/src/anomalib/models/components/feature_extractors/utils.py @@ -1,4 +1,30 @@ -"""Utility functions to manipulate feature extractors.""" +"""Utility functions to manipulate feature extractors. 
+ +This module provides utility functions for working with feature extractors, +including functions to analyze feature map dimensions. + +Example: + >>> import torch + >>> from anomalib.models.components.feature_extractors import ( + ... TimmFeatureExtractor, + ... dryrun_find_featuremap_dims + ... ) + >>> # Create feature extractor + >>> extractor = TimmFeatureExtractor( + ... backbone="resnet18", + ... layers=["layer1", "layer2"] + ... ) + >>> # Get feature dimensions + >>> dims = dryrun_find_featuremap_dims( + ... extractor, + ... input_size=(256, 256), + ... layers=["layer1", "layer2"] + ... ) + >>> print(dims["layer1"]["num_features"]) # Number of channels + 64 + >>> print(dims["layer1"]["resolution"]) # Feature map height, width + torch.Size([64, 64]) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -14,16 +40,42 @@ def dryrun_find_featuremap_dims( input_size: tuple[int, int], layers: list[str], ) -> dict[str, dict[str, int | tuple[int, int]]]: - """Dry run an empty image of `input_size` size to get the featuremap tensors' dimensions (num_features, resolution). + """Get feature map dimensions by running an empty tensor through the model. + + Performs a forward pass with an empty tensor to determine the output + dimensions of specified feature maps. + + Args: + feature_extractor: Feature extraction model, either a ``TimmFeatureExtractor`` + or ``GraphModule``. + input_size: Tuple of ``(height, width)`` specifying input image dimensions. + layers: List of layer names from which to extract features. Returns: - tuple[int, int]: maping of `layer -> dimensions dict` - Each `dimension dict` has two keys: `num_features` (int) and `resolution`(tuple[int, int]). + Dictionary mapping layer names to dimension information. For each layer, + returns a dictionary with: + - ``num_features``: Number of feature channels (int) + - ``resolution``: Spatial dimensions as a ``(height, width)`` + ``torch.Size`` (a tuple subclass) + + Example: + >>> extractor = TimmFeatureExtractor("resnet18", layers=["layer1"]) + >>> dims = dryrun_find_featuremap_dims( + ... extractor, + ... input_size=(256, 256), + ... layers=["layer1"] + ... ) + >>> print(dims["layer1"]["num_features"]) # channels + 64 + >>> print(dims["layer1"]["resolution"]) # (height, width) + torch.Size([64, 64]) """ device = next(feature_extractor.parameters()).device dryrun_input = torch.empty(1, 3, *input_size).to(device) dryrun_features = feature_extractor(dryrun_input) return { - layer: {"num_features": dryrun_features[layer].shape[1], "resolution": dryrun_features[layer].shape[2:]} + layer: { + "num_features": dryrun_features[layer].shape[1], + "resolution": dryrun_features[layer].shape[2:], + } for layer in layers } diff --git a/src/anomalib/models/components/filters/__init__.py b/src/anomalib/models/components/filters/__init__.py index 340daa47f2..c632383437 100644 --- a/src/anomalib/models/components/filters/__init__.py +++ b/src/anomalib/models/components/filters/__init__.py @@ -1,4 +1,20 @@ -"""Implements filters used by models.""" +"""Filters used by anomaly detection models. + +This module provides filter implementations that can be used for image +preprocessing and feature enhancement in anomaly detection models. + +Classes: + GaussianBlur2d: 2D Gaussian blur filter implementation.
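Editor's note: as a companion to the `dryrun_find_featuremap_dims` hunk above, here is a standalone sketch of the same empty-tensor probe against a raw timm backbone, with no anomalib imports; the level names are placeholders, not anomalib layer names.

```python
import timm
import torch

# Probe a backbone with one empty image and read off, per feature level,
# the channel count and spatial resolution, mirroring the dry-run utility.
model = timm.create_model("resnet18", pretrained=False, features_only=True)
with torch.no_grad():
    feats = model(torch.empty(1, 3, 256, 256))
dims = {
    f"level_{i}": {"num_features": f.shape[1], "resolution": tuple(f.shape[2:])}
    for i, f in enumerate(feats)
}
print(dims)
```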
+ +Example: + >>> import torch + >>> from anomalib.models.components.filters import GaussianBlur2d + >>> # Create a Gaussian blur filter + >>> blur = GaussianBlur2d(kernel_size=3, sigma=1.0) + >>> # Apply blur to input tensor + >>> input_tensor = torch.randn(1, 3, 256, 256) + >>> blurred = blur(input_tensor) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/components/filters/blur.py b/src/anomalib/models/components/filters/blur.py index 986214707d..cfe1640e04 100644 --- a/src/anomalib/models/components/filters/blur.py +++ b/src/anomalib/models/components/filters/blur.py @@ -1,4 +1,17 @@ -"""Gaussian blurring via pytorch.""" +"""Gaussian blurring implementation using PyTorch. + +This module provides a 2D Gaussian blur filter implementation that pre-computes +the Gaussian kernel during initialization for efficiency. + +Example: + >>> import torch + >>> from anomalib.models.components.filters import GaussianBlur2d + >>> # Create a Gaussian blur filter + >>> blur = GaussianBlur2d(sigma=1.0, channels=3) + >>> # Apply blur to input tensor + >>> input_tensor = torch.randn(1, 3, 256, 256) + >>> blurred = blur(input_tensor) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -14,21 +27,53 @@ def compute_kernel_size(sigma_val: float) -> int: """Compute kernel size from sigma value. + The kernel size is calculated as 2 * (4 * sigma + 0.5) + 1 to ensure it + captures the significant part of the Gaussian distribution. + Args: - sigma_val (float): Sigma value. + sigma_val (float): Standard deviation value for the Gaussian kernel. Returns: - int: Kernel size. + int: Computed kernel size (always odd). + + Example: + >>> compute_kernel_size(1.0) + 9 + >>> compute_kernel_size(2.0) + 17 """ return 2 * int(4.0 * sigma_val + 0.5) + 1 class GaussianBlur2d(nn.Module): - """Compute GaussianBlur in 2d. + """2D Gaussian blur filter with pre-computed kernel. + + Unlike some implementations, this class pre-computes the Gaussian kernel + during initialization rather than computing it during the forward pass. + This approach is more efficient but requires specifying the number of + input channels upfront. - Makes use of kornia functions, but most notably the kernel is not computed - during the forward pass, and does not depend on the input size. As a caveat, - the number of channels that are expected have to be provided during initialization. + Args: + sigma (float | tuple[float, float]): Standard deviation(s) for the + Gaussian kernel. If a single float is provided, it's used for both + dimensions. + channels (int): Number of input channels. Defaults to 1. + kernel_size (int | tuple[int, int] | None): Size of the Gaussian + kernel. If ``None``, computed from sigma. Defaults to ``None``. + normalize (bool): Whether to normalize the kernel so its elements sum + to 1. Defaults to ``True``. + border_type (str): Padding mode for border handling. Options are + 'reflect', 'replicate', etc. Defaults to "reflect". + padding (str): Padding strategy. Either 'same' or 'valid'. + Defaults to "same". + + Example: + >>> import torch + >>> blur = GaussianBlur2d(sigma=1.0, channels=3) + >>> x = torch.randn(1, 3, 64, 64) + >>> output = blur(x) + >>> output.shape + torch.Size([1, 3, 64, 64]) """ def __init__( @@ -40,17 +85,6 @@ def __init__( border_type: str = "reflect", padding: str = "same", ) -> None: - """Initialize model, setup kernel etc.. 
- - Args: - sigma (float | tuple[float, float]): standard deviation to use for constructing the Gaussian kernel. - channels (int): channels of the input. Defaults to 1. - kernel_size (int | tuple[int, int] | None): size of the Gaussian kernel to use. Defaults to None. - normalize (bool, optional): Whether to normalize the kernel or not (i.e. all elements sum to 1). - Defaults to True. - border_type (str, optional): Border type to use for padding of the input. Defaults to "reflect". - padding (str, optional): Type of padding to apply. Defaults to "same". - """ super().__init__() sigma = sigma if isinstance(sigma, tuple) else (sigma, sigma) self.channels = channels @@ -74,13 +108,22 @@ def __init__( self.padding_shape = _compute_padding([self.height, self.width]) def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: - """Blur the input with the computed Gaussian. + """Apply Gaussian blur to input tensor. Args: - input_tensor (torch.Tensor): Input tensor to be blurred. + input_tensor (torch.Tensor): Input tensor of shape + ``(B, C, H, W)``. Returns: - Tensor: Blurred output tensor. + torch.Tensor: Blurred output tensor. If padding is 'same', + output shape matches input. If 'valid', output is smaller. + + Example: + >>> blur = GaussianBlur2d(sigma=1.0, channels=1) + >>> x = torch.ones(1, 1, 5, 5) + >>> output = blur(x) + >>> output.shape + torch.Size([1, 1, 5, 5]) """ batch, channel, height, width = input_tensor.size() diff --git a/src/anomalib/models/components/flow/__init__.py b/src/anomalib/models/components/flow/__init__.py index dca2e7b9e6..f343c8dd38 100644 --- a/src/anomalib/models/components/flow/__init__.py +++ b/src/anomalib/models/components/flow/__init__.py @@ -1,6 +1,23 @@ -"""All In One Block Layer.""" +"""Flow components used in anomaly detection models. -# Copyright (C) 2022 Intel Corporation +This module provides flow-based components that can be used in anomaly detection +models. These components help model complex data distributions and transformations. + +Classes: + AllInOneBlock: A block that combines multiple flow operations into a single + transformation. + +Example: + >>> import torch + >>> from anomalib.models.components.flow import AllInOneBlock + >>> # Create flow block (a subnet constructor is required) + >>> def subnet_fc(dims_in, dims_out): + ...     return torch.nn.Linear(dims_in, dims_out) + >>> flow = AllInOneBlock(dims_in=[(64,)], subnet_constructor=subnet_fc) + >>> # Apply flow transformation + >>> x = torch.randn(10, 64) + >>> y, logdet = flow(x) +""" + +# Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from .all_in_one_block import AllInOneBlock diff --git a/src/anomalib/models/components/flow/all_in_one_block.py b/src/anomalib/models/components/flow/all_in_one_block.py index 6c6713add8..647306a23b 100644 --- a/src/anomalib/models/components/flow/all_in_one_block.py +++ b/src/anomalib/models/components/flow/all_in_one_block.py @@ -1,4 +1,74 @@ -"""All In One Block Layer.""" +r"""All In One Block Layer. + +This module provides an invertible block that combines multiple flow operations: +affine coupling, permutation, and global affine transformation. + +The block performs the following computation: + +.. math:: + + y = V R \; \Psi(s_\mathrm{global}) \odot \mathrm{Coupling} + \Big(R^{-1} V^{-1} x\Big)+ t_\mathrm{global} + +where: + +- :math:`V` is an optional learned householder reflection matrix +- :math:`R` is a permutation matrix +- :math:`\Psi` is an activation function for global scaling +- The coupling operation splits input :math:`x` into :math:`x_1, x_2` and outputs + :math:`u = \mathrm{concat}(u_1, u_2)` where: + + ..
math:: + + u_1 &= x_1 \odot \exp \Big( \alpha \; \mathrm{tanh}\big( s(x_2) \big)\Big) + + t(x_2) \\ + u_2 &= x_2 + +Example: + >>> import torch + >>> from anomalib.models.components.flow import AllInOneBlock + >>> # Create flow block + >>> def subnet_fc(c_in, c_out): + ... return torch.nn.Sequential( + ... torch.nn.Linear(c_in, 128), + ... torch.nn.ReLU(), + ... torch.nn.Linear(128, c_out) + ... ) + >>> flow = AllInOneBlock( + ... dims_in=[(64,)], + ... subnet_constructor=subnet_fc + ... ) + >>> # Apply flow transformation + >>> x = torch.randn(10, 64) + >>> y, logdet = flow(x) + >>> print(y[0].shape) + torch.Size([10, 64]) + +Args: + dims_in (list[tuple[int]]): Dimensions of input tensor(s) + dims_c (list[tuple[int]], optional): Dimensions of conditioning tensor(s). + Defaults to None. + subnet_constructor (Callable, optional): Function that constructs the subnet, + called as ``f(channels_in, channels_out)``. Defaults to None. + affine_clamping (float, optional): Clamping value for affine coupling. + Defaults to 2.0. + gin_block (bool, optional): Use GIN coupling from Sorrenson et al, 2019. + Defaults to False. + global_affine_init (float, optional): Initial value for global affine + scaling. Defaults to 1.0. + global_affine_type (str, optional): Type of activation for global affine + scaling. One of ``'SIGMOID'``, ``'SOFTPLUS'``, ``'EXP'``. + Defaults to ``'SOFTPLUS'``. + permute_soft (bool, optional): Use soft permutation matrix from SO(N). + Defaults to False. + learned_householder_permutation (int, optional): Number of learned + householder reflections. Defaults to 0. + reverse_permutation (bool, optional): Apply inverse permutation before block. + Defaults to False. + +Raises: + ValueError: If ``subnet_constructor`` is None or dimensions are invalid. +""" # Copyright (c) https://github.com/vislearn/FrEIA # SPDX-License-Identifier: MIT @@ -20,90 +90,77 @@ def _global_scale_sigmoid_activation(input_tensor: torch.Tensor) -> torch.Tensor: - """Global scale sigmoid activation. + """Apply sigmoid activation for global scaling. Args: input_tensor (torch.Tensor): Input tensor Returns: - Tensor: Sigmoid activation + torch.Tensor: Scaled tensor after sigmoid activation """ return 10 * torch.sigmoid(input_tensor - 2.0) def _global_scale_softplus_activation(input_tensor: torch.Tensor) -> torch.Tensor: - """Global scale softplus activation. + """Apply softplus activation for global scaling. Args: input_tensor (torch.Tensor): Input tensor Returns: - Tensor: Softplus activation + torch.Tensor: Scaled tensor after softplus activation """ softplus = nn.Softplus(beta=0.5) return 0.1 * softplus(input_tensor) def _global_scale_exp_activation(input_tensor: torch.Tensor) -> torch.Tensor: - """Global scale exponential activation. + """Apply exponential activation for global scaling. Args: input_tensor (torch.Tensor): Input tensor Returns: - Tensor: Exponential activation + torch.Tensor: Scaled tensor after exponential activation """ return torch.exp(input_tensor) class AllInOneBlock(InvertibleModule): - r"""Module combining the most common operations in a normalizing flow or similar model. + r"""Module combining common operations in normalizing flows. - It combines affine coupling, permutation, and global affine transformation - ('ActNorm'). It can also be used as GIN coupling block, perform learned - householder permutations, and use an inverted pre-permutation. The affine - transformation includes a soft clamping mechanism, first used in Real-NVP. 
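Editor's note: a tiny numerical sketch may make the soft clamping concrete. Here `alpha` plays the role of ``affine_clamping`` (default 2.0, as documented above): the multiplicative factor stays within fixed bounds no matter how extreme the subnet output is.

```python
import torch

# The coupling multiplies by exp(alpha * tanh(s)), so the factor is bounded
# to [exp(-alpha), exp(alpha)] and cannot explode for large |s|.
alpha = 2.0
s = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
print(torch.exp(alpha * torch.tanh(s)))
# approximately tensor([0.1353, 0.2180, 1.0000, 4.5872, 7.3891])
```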
- The block as a whole performs the following computation: + This block combines affine coupling, permutation, and global affine + transformation ('ActNorm'). It supports: - .. math:: - - y = V R \; \Psi(s_\mathrm{global}) \odot \mathrm{Coupling}\Big(R^{-1} V^{-1} x\Big)+ t_\mathrm{global} - - - The inverse pre-permutation of x (i.e. :math:`R^{-1} V^{-1}`) is optional (see - ``reverse_permutation`` below). - - The learned householder reflection matrix - :math:`V` is also optional all together (see ``learned_householder_permutation`` - below). - - For the coupling, the input is split into :math:`x_1, x_2` along - the channel dimension. Then the output of the coupling operation is the - two halves :math:`u = \mathrm{concat}(u_1, u_2)`. - - .. math:: - - u_1 &= x_1 \odot \exp \Big( \alpha \; \mathrm{tanh}\big( s(x_2) \big)\Big) + t(x_2) \\ - u_2 &= x_2 - - Because :math:`\mathrm{tanh}(s) \in [-1, 1]`, this clamping mechanism prevents - exploding values in the exponential. The hyperparameter :math:`\alpha` can be adjusted. + - GIN coupling blocks + - Learned householder permutations + - Inverted pre-permutation + - Soft clamping mechanism from Real-NVP Args: - subnet_constructor: class or callable ``f``, called as ``f(channels_in, channels_out)`` and - should return a torch.nn.Module. Predicts coupling coefficients :math:`s, t`. - affine_clamping: clamp the output of the multiplicative coefficients before - exponentiation to +/- ``affine_clamping`` (see :math:`\alpha` above). - gin_block: Turn the block into a GIN block from Sorrenson et al, 2019. - Makes it so that the coupling operations as a whole is volume preserving. - global_affine_init: Initial value for the global affine scaling :math:`s_\mathrm{global}`. - global_affine_init: ``'SIGMOID'``, ``'SOFTPLUS'``, or ``'EXP'``. Defines the activation to be used - on the beta for the global affine scaling (:math:`\Psi` above). - permute_soft: bool, whether to sample the permutation matrix :math:`R` from :math:`SO(N)`, - or to use hard permutations instead. Note, ``permute_soft=True`` is very slow - when working with >512 dimensions. - learned_householder_permutation: Int, if >0, turn on the matrix :math:`V` above, that represents - multiple learned householder reflections. Slow if large number. - Dubious whether it actually helps network performance. - reverse_permutation: Reverse the permutation before the block, as introduced by Putzky - et al, 2019. Turns on the :math:`R^{-1} V^{-1}` pre-multiplication above. + dims_in (list[tuple[int]]): Dimensions of input tensor(s) + dims_c (list[tuple[int]], optional): Dimensions of conditioning + tensor(s). Defaults to None. + subnet_constructor (Callable, optional): Function that constructs the + subnet, called as ``f(channels_in, channels_out)``. Defaults to None. + affine_clamping (float, optional): Clamping value for affine coupling. + Defaults to 2.0. + gin_block (bool, optional): Use GIN coupling from Sorrenson et al, 2019. + Defaults to False. + global_affine_init (float, optional): Initial value for global affine + scaling. Defaults to 1.0. + global_affine_type (str, optional): Type of activation for global affine + scaling. One of ``'SIGMOID'``, ``'SOFTPLUS'``, ``'EXP'``. + Defaults to ``'SOFTPLUS'``. + permute_soft (bool, optional): Use soft permutation matrix from SO(N). + Defaults to False. + learned_householder_permutation (int, optional): Number of learned + householder reflections. Defaults to 0. + reverse_permutation (bool, optional): Apply inverse permutation before + block. 
Defaults to False. + + Raises: + ValueError: If ``subnet_constructor`` is None or dimensions are invalid. """ def __init__( @@ -215,7 +272,11 @@ def __init__( self.last_jac = None def _construct_householder_permutation(self) -> torch.Tensor: - """Compute a permutation matrix from the reflection vectors that are learned internally as nn.Parameters.""" + """Compute permutation matrix from learned reflection vectors. + + Returns: + torch.Tensor: Constructed permutation matrix + """ w = self.w_0 for vk in self.vk_householder: w = torch.mm(w, torch.eye(self.in_channels).to(w.device) - 2 * torch.ger(vk, vk) / torch.dot(vk, vk)) @@ -225,16 +286,15 @@ def _construct_householder_permutation(self) -> torch.Tensor: return w def _permute(self, x: torch.Tensor, rev: bool = False) -> tuple[Any, float | torch.Tensor]: - """Perform the permutation and scaling after the coupling operation. - - Returns transformed outputs and the LogJacDet of the scaling operation. + """Perform permutation and scaling after coupling operation. Args: x (torch.Tensor): Input tensor rev (bool, optional): Reverse the permutation. Defaults to False. Returns: - tuple[Any, float | torch.Tensor]: Transformed outputs and the LogJacDet of the scaling operation. + tuple[Any, float | torch.Tensor]: Transformed outputs and LogJacDet + of scaling """ if self.GIN: scale = 1.0 @@ -249,9 +309,16 @@ def _permute(self, x: torch.Tensor, rev: bool = False) -> tuple[Any, float | tor return (self.permute_function(x * scale + self.global_offset, self.w_perm), perm_log_jac) def _pre_permute(self, x: torch.Tensor, rev: bool = False) -> torch.Tensor: - """Permute before the coupling block. + """Permute before coupling block. + + Only used if ``reverse_permutation`` is True. + + Args: + x (torch.Tensor): Input tensor + rev (bool, optional): Reverse the permutation. Defaults to False. - It is only used if reverse_permutation is set. + Returns: + torch.Tensor: Permuted tensor """ if rev: return self.permute_function(x, self.w_perm) @@ -261,9 +328,13 @@ def _pre_permute(self, x: torch.Tensor, rev: bool = False) -> torch.Tensor: def _affine(self, x: torch.Tensor, a: torch.Tensor, rev: bool = False) -> tuple[Any, torch.Tensor]: """Perform affine coupling operation. - Given the passive half, and the pre-activation outputs of the - coupling subnetwork, perform the affine coupling operation. - Returns both the transformed inputs and the LogJacDet. + Args: + x (torch.Tensor): Input tensor (passive half) + a (torch.Tensor): Coupling network outputs + rev (bool, optional): Reverse the operation. Defaults to False. + + Returns: + tuple[Any, torch.Tensor]: Transformed tensor and LogJacDet """ # the entire coupling coefficient tensor is scaled down by a # factor of ten for stability and easier initialization. @@ -286,7 +357,18 @@ def forward( rev: bool = False, jac: bool = True, ) -> tuple[tuple[torch.Tensor], torch.Tensor]: - """See base class docstring.""" + """Forward pass through the invertible block. + + Args: + x (torch.Tensor): Input tensor + c (list, optional): Conditioning tensors. Defaults to None. + rev (bool, optional): Reverse the flow. Defaults to False. + jac (bool, optional): Compute Jacobian determinant. Defaults to True. + + Returns: + tuple[tuple[torch.Tensor], torch.Tensor]: Tuple of (output tensors, + LogJacDet) + """ del jac # Unused argument. if c is None: @@ -332,12 +414,12 @@ def forward( @staticmethod def output_dims(input_dims: list[tuple[int]]) -> list[tuple[int]]: - """Output dimensions of the layer. 
+ """Get output dimensions of the layer. Args: - input_dims (list[tuple[int]]): Input dimensions. + input_dims (list[tuple[int]]): Input dimensions Returns: - list[tuple[int]]: Output dimensions. + list[tuple[int]]: Output dimensions """ return input_dims diff --git a/src/anomalib/models/components/layers/__init__.py b/src/anomalib/models/components/layers/__init__.py index b2937cfe0c..131f2b2258 100644 --- a/src/anomalib/models/components/layers/__init__.py +++ b/src/anomalib/models/components/layers/__init__.py @@ -1,6 +1,23 @@ -"""Neural network layers.""" +"""Neural network layers used in anomaly detection models. -# Copyright (C) 2022 Intel Corporation +This module provides custom neural network layer implementations that can be used +as building blocks in anomaly detection models. + +Classes: + SSPCAB: Spatial-Spectral Pixel-Channel Attention Block layer that combines + spatial and channel attention mechanisms. + +Example: + >>> import torch + >>> from anomalib.models.components.layers import SSPCAB + >>> # Create attention layer + >>> attention = SSPCAB(in_channels=64) + >>> # Apply attention to input tensor + >>> input_tensor = torch.randn(1, 64, 32, 32) + >>> output = attention(input_tensor) +""" + +# Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from .sspcab import SSPCAB diff --git a/src/anomalib/models/components/layers/sspcab.py b/src/anomalib/models/components/layers/sspcab.py index ee8ce4e8b5..95d9acfa68 100644 --- a/src/anomalib/models/components/layers/sspcab.py +++ b/src/anomalib/models/components/layers/sspcab.py @@ -1,6 +1,21 @@ -"""SSPCAB: Self-Supervised Predictive Convolutional Attention Block for reconstruction-based models. - -Paper https://arxiv.org/abs/2111.09099 +"""SSPCAB: Self-Supervised Predictive Convolutional Attention Block. + +This module implements the SSPCAB architecture from the paper: +"SSPCAB: Self-Supervised Predictive Convolutional Attention Block for +Reconstruction-Based Anomaly Detection" +(https://arxiv.org/abs/2111.09099) + +The SSPCAB combines masked convolutions with channel attention to learn +spatial-spectral feature representations for anomaly detection. + +Example: + >>> import torch + >>> from anomalib.models.components.layers import SSPCAB + >>> # Create SSPCAB layer + >>> sspcab = SSPCAB(in_channels=64) + >>> # Apply attention to input tensor + >>> x = torch.randn(1, 64, 32, 32) + >>> output = sspcab(x) """ # Copyright (C) 2022-2024 Intel Corporation @@ -14,9 +29,23 @@ class AttentionModule(nn.Module): """Squeeze and excitation block that acts as the attention module in SSPCAB. + This module applies channel attention through global average pooling followed + by two fully connected layers with non-linearities. + Args: - channels (int): Number of input channels. - reduction_ratio (int): Reduction ratio of the attention module. + in_channels (int): Number of input channels. + reduction_ratio (int, optional): Reduction ratio for the intermediate + layer. The intermediate layer will have ``in_channels // + reduction_ratio`` channels. Defaults to 8. 
+ + Example: + >>> import torch + >>> from anomalib.models.components.layers.sspcab import AttentionModule + >>> attention = AttentionModule(in_channels=64) + >>> x = torch.randn(1, 64, 32, 32) + >>> output = attention(x) + >>> output.shape + torch.Size([1, 64, 32, 32]) """ def __init__(self, in_channels: int, reduction_ratio: int = 8) -> None: @@ -27,7 +56,15 @@ def __init__(self, in_channels: int, reduction_ratio: int = 8) -> None: self.fc2 = nn.Linear(out_channels, in_channels) def forward(self, inputs: torch.Tensor) -> torch.Tensor: - """Forward pass through the attention module.""" + """Forward pass through the attention module. + + Args: + inputs (torch.Tensor): Input tensor of shape + ``(batch_size, channels, height, width)``. + + Returns: + torch.Tensor: Attended output tensor of same shape as input. + """ # reduce feature map to 1d vector through global average pooling avg_pooled = inputs.mean(dim=(2, 3)) @@ -42,30 +79,78 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: class SSPCAB(nn.Module): - """SSPCAB block. + """Self-Supervised Predictive Convolutional Attention Block. + + This module combines masked convolutions with channel attention to capture + spatial and channel dependencies in the feature maps. Args: in_channels (int): Number of input channels. - kernel_size (int): Size of the receptive fields of the masked convolution kernel. - dilation (int): Dilation factor of the masked convolution kernel. - reduction_ratio (int): Reduction ratio of the attention module. + kernel_size (int, optional): Size of the receptive fields of the masked + convolution kernel. Defaults to 1. + dilation (int, optional): Dilation factor of the masked convolution + kernel. Defaults to 1. + reduction_ratio (int, optional): Reduction ratio of the attention module. + Defaults to 8. 
+ + Example: + >>> import torch + >>> from anomalib.models.components.layers import SSPCAB + >>> sspcab = SSPCAB(in_channels=64, kernel_size=3) + >>> x = torch.randn(1, 64, 32, 32) + >>> output = sspcab(x) + >>> output.shape + torch.Size([1, 64, 32, 32]) """ - def __init__(self, in_channels: int, kernel_size: int = 1, dilation: int = 1, reduction_ratio: int = 8) -> None: + def __init__( + self, + in_channels: int, + kernel_size: int = 1, + dilation: int = 1, + reduction_ratio: int = 8, + ) -> None: super().__init__() self.pad = kernel_size + dilation self.crop = kernel_size + 2 * dilation + 1 - self.masked_conv1 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size) - self.masked_conv2 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size) - self.masked_conv3 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size) - self.masked_conv4 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size) - - self.attention_module = AttentionModule(in_channels=in_channels, reduction_ratio=reduction_ratio) + self.masked_conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + ) + self.masked_conv2 = nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + ) + self.masked_conv3 = nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + ) + self.masked_conv4 = nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + ) + + self.attention_module = AttentionModule( + in_channels=in_channels, + reduction_ratio=reduction_ratio, + ) def forward(self, inputs: torch.Tensor) -> torch.Tensor: - """Forward pass through the SSPCAB block.""" + """Forward pass through the SSPCAB block. + + Args: + inputs (torch.Tensor): Input tensor of shape + ``(batch_size, channels, height, width)``. + + Returns: + torch.Tensor: Output tensor of same shape as input. + """ # compute masked convolution padded = F.pad(inputs, (self.pad,) * 4) masked_out = torch.zeros_like(inputs) diff --git a/src/anomalib/models/components/sampling/__init__.py b/src/anomalib/models/components/sampling/__init__.py index 47c842123f..28c3df81c7 100644 --- a/src/anomalib/models/components/sampling/__init__.py +++ b/src/anomalib/models/components/sampling/__init__.py @@ -1,6 +1,23 @@ -"""Sampling methods.""" +"""Sampling methods for anomaly detection models. -# Copyright (C) 2022 Intel Corporation +This module provides sampling techniques used in anomaly detection models to +select representative samples from datasets. + +Classes: + KCenterGreedy: K-center greedy sampling algorithm that selects diverse and + representative samples. 
+ +Example: + >>> import torch + >>> from anomalib.models.components.sampling import KCenterGreedy + >>> # Create sampler + >>> sampler = KCenterGreedy() + >>> # Sample from feature embeddings + >>> features = torch.randn(100, 512) # 100 samples with 512 dimensions + >>> selected_idx = sampler.select_coreset(features, n=10) +""" + +# Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from .k_center_greedy import KCenterGreedy diff --git a/src/anomalib/models/components/sampling/k_center_greedy.py b/src/anomalib/models/components/sampling/k_center_greedy.py index d7ca314f33..51e29239c7 100644 --- a/src/anomalib/models/components/sampling/k_center_greedy.py +++ b/src/anomalib/models/components/sampling/k_center_greedy.py @@ -1,7 +1,9 @@ """k-Center Greedy Method. Returns points that minimizes the maximum distance of any point to a center. -- https://arxiv.org/abs/1708.00489 + +Reference: + - https://arxiv.org/abs/1708.00489 """ # Copyright (C) 2022-2024 Intel Corporation @@ -15,16 +17,28 @@ class KCenterGreedy: - """Implements k-center-greedy method. + """k-center-greedy method for coreset selection. + + This class implements the k-center-greedy method to select a coreset from an + embedding space. The method aims to minimize the maximum distance between any + point and its nearest center. Args: - embedding (torch.Tensor): Embedding vector extracted from a CNN - sampling_ratio (float): Ratio to choose coreset size from the embedding size. + embedding (torch.Tensor): Embedding tensor extracted from a CNN. + sampling_ratio (float): Ratio to determine coreset size from embedding size. + + Attributes: + embedding (torch.Tensor): Input embedding tensor. + coreset_size (int): Size of the coreset to be selected. + model (SparseRandomProjection): Dimensionality reduction model. + features (torch.Tensor): Transformed features after dimensionality reduction. + min_distances (torch.Tensor): Minimum distances to cluster centers. + n_observations (int): Number of observations in the embedding. Example: - >>> embedding.shape - torch.Size([219520, 1536]) - >>> sampler = KCenterGreedy(embedding=embedding) + >>> import torch + >>> embedding = torch.randn(219520, 1536) + >>> sampler = KCenterGreedy(embedding=embedding, sampling_ratio=0.001) >>> sampled_idxs = sampler.select_coreset_idxs() >>> coreset = embedding[sampled_idxs] >>> coreset.shape @@ -41,14 +55,14 @@ def __init__(self, embedding: torch.Tensor, sampling_ratio: float) -> None: self.n_observations = self.embedding.shape[0] def reset_distances(self) -> None: - """Reset minimum distances.""" + """Reset minimum distances to None.""" self.min_distances = None def update_distances(self, cluster_centers: list[int]) -> None: - """Update min distances given cluster centers. + """Update minimum distances given cluster centers. Args: - cluster_centers (list[int]): indices of cluster centers + cluster_centers (list[int]): Indices of cluster centers. """ if cluster_centers: centers = self.features[cluster_centers] @@ -61,12 +75,13 @@ def update_distances(self, cluster_centers: list[int]) -> None: self.min_distances = torch.minimum(self.min_distances, distance) def get_new_idx(self) -> int: - """Get index value of a sample. - - Based on minimum distance of the cluster + """Get index of the next sample based on maximum minimum distance. Returns: - int: Sample index + int: Index of the selected sample. + + Raises: + TypeError: If `self.min_distances` is not a torch.Tensor. 
""" if isinstance(self.min_distances, torch.Tensor): idx = int(torch.argmax(self.min_distances).item()) @@ -77,13 +92,18 @@ def get_new_idx(self) -> int: return idx def select_coreset_idxs(self, selected_idxs: list[int] | None = None) -> list[int]: - """Greedily form a coreset to minimize the maximum distance of a cluster. + """Greedily form a coreset to minimize maximum distance to cluster centers. Args: - selected_idxs: index of samples already selected. Defaults to an empty set. + selected_idxs (list[int] | None, optional): Indices of pre-selected + samples. Defaults to None. Returns: - indices of samples selected to minimize distance to cluster centers + list[int]: Indices of samples selected to minimize distance to cluster + centers. + + Raises: + ValueError: If a newly selected index is already in `selected_idxs`. """ if selected_idxs is None: selected_idxs = [] @@ -113,15 +133,16 @@ def sample_coreset(self, selected_idxs: list[int] | None = None) -> torch.Tensor """Select coreset from the embedding. Args: - selected_idxs: index of samples already selected. Defaults to an empty set. + selected_idxs (list[int] | None, optional): Indices of pre-selected + samples. Defaults to None. Returns: - Tensor: Output coreset + torch.Tensor: Selected coreset. Example: - >>> embedding.shape - torch.Size([219520, 1536]) - >>> sampler = KCenterGreedy(...) + >>> import torch + >>> embedding = torch.randn(219520, 1536) + >>> sampler = KCenterGreedy(embedding=embedding, sampling_ratio=0.001) >>> coreset = sampler.sample_coreset() >>> coreset.shape torch.Size([219, 1536]) diff --git a/src/anomalib/models/components/stats/__init__.py b/src/anomalib/models/components/stats/__init__.py index c65aef1caf..60f5f340fe 100644 --- a/src/anomalib/models/components/stats/__init__.py +++ b/src/anomalib/models/components/stats/__init__.py @@ -1,6 +1,26 @@ -"""Statistical functions.""" +"""Statistical functions for anomaly detection models. -# Copyright (C) 2022 Intel Corporation +This module provides statistical methods used in anomaly detection models for +density estimation and probability modeling. + +Classes: + GaussianKDE: Gaussian kernel density estimation for non-parametric density + estimation. + MultiVariateGaussian: Multivariate Gaussian distribution for parametric + density modeling. + +Example: + >>> import torch + >>> from anomalib.models.components.stats import GaussianKDE + >>> # Create density estimator + >>> kde = GaussianKDE() + >>> # Fit and evaluate density + >>> features = torch.randn(100, 10) # 100 samples, 10 dimensions + >>> kde.fit(features) + >>> density = kde.predict(features) +""" + +# Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from .kde import GaussianKDE diff --git a/src/anomalib/models/components/stats/kde.py b/src/anomalib/models/components/stats/kde.py index d9bae9ec81..9903277a18 100644 --- a/src/anomalib/models/components/stats/kde.py +++ b/src/anomalib/models/components/stats/kde.py @@ -1,4 +1,18 @@ -"""Gaussian Kernel Density Estimation.""" +"""Gaussian Kernel Density Estimation. + +This module implements non-parametric density estimation using Gaussian kernels. +The bandwidth is selected automatically using Scott's rule. 
+ +Example: + >>> import torch + >>> from anomalib.models.components.stats import GaussianKDE + >>> # Create density estimator + >>> kde = GaussianKDE() + >>> # Fit and evaluate density + >>> features = torch.randn(100, 10) # 100 samples, 10 dimensions + >>> kde.fit(features) + >>> density = kde.predict(features) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -13,8 +27,25 @@ class GaussianKDE(DynamicBufferMixin): """Gaussian Kernel Density Estimation. + Estimates probability density using a Gaussian kernel function. The bandwidth + is selected automatically using Scott's rule. + Args: - dataset (Tensor | None, optional): Dataset on which to fit the KDE model. Defaults to None. + dataset (torch.Tensor | None, optional): Dataset on which to fit the KDE + model. If provided, the model will be fitted immediately. + Defaults to ``None``. + + Example: + >>> import torch + >>> from anomalib.models.components.stats import GaussianKDE + >>> features = torch.randn(100, 10) # 100 samples, 10 dimensions + >>> # Initialize and fit in one step + >>> kde = GaussianKDE(dataset=features) + >>> # Or fit later + >>> kde = GaussianKDE() + >>> kde.fit(features) + >>> # Get density estimates + >>> density = kde(features) """ def __init__(self, dataset: torch.Tensor | None = None) -> None: @@ -32,12 +63,22 @@ def __init__(self, dataset: torch.Tensor | None = None) -> None: self.norm = torch.Tensor() def forward(self, features: torch.Tensor) -> torch.Tensor: - """Get the KDE estimates from the feature map. + """Compute KDE estimates for the input features. Args: - features (torch.Tensor): Feature map extracted from the CNN + features (torch.Tensor): Feature tensor of shape ``(N, D)`` where + ``N`` is the number of samples and ``D`` is the dimension. - Returns: KDE Estimates + Returns: + torch.Tensor: Density estimates for each input sample, shape ``(N,)``. + + Example: + >>> kde = GaussianKDE() + >>> features = torch.randn(100, 10) + >>> kde.fit(features) + >>> estimates = kde(features) + >>> estimates.shape + torch.Size([100]) """ features = torch.matmul(features, self.bw_transform) @@ -50,13 +91,19 @@ def forward(self, features: torch.Tensor) -> torch.Tensor: return estimate def fit(self, dataset: torch.Tensor) -> None: - """Fit a KDE model to the input dataset. + """Fit the KDE model to the input dataset. + + Computes the bandwidth matrix using Scott's rule and transforms the data + accordingly. Args: - dataset (torch.Tensor): Input dataset. + dataset (torch.Tensor): Input dataset of shape ``(N, D)`` where ``N`` + is the number of samples and ``D`` is the dimension. - Returns: - None + Example: + >>> kde = GaussianKDE() + >>> features = torch.randn(100, 10) + >>> kde.fit(features) """ num_samples, dimension = dataset.shape @@ -83,10 +130,17 @@ def cov(tensor: torch.Tensor) -> torch.Tensor: """Calculate the unbiased covariance matrix. Args: - tensor (torch.Tensor): Input tensor from which covariance matrix is computed. + tensor (torch.Tensor): Input tensor of shape ``(D, N)`` where ``D`` + is the dimension and ``N`` is the number of samples. Returns: - Output covariance matrix. + torch.Tensor: Covariance matrix of shape ``(D, D)``. 
+ + Example: + >>> x = torch.randn(5, 100) # 5 dimensions, 100 samples + >>> cov_matrix = GaussianKDE.cov(x) + >>> cov_matrix.shape + torch.Size([5, 5]) """ mean = torch.mean(tensor, dim=1) tensor -= mean[:, None] diff --git a/src/anomalib/models/components/stats/multi_variate_gaussian.py b/src/anomalib/models/components/stats/multi_variate_gaussian.py index b05edfb827..3a3b05faed 100644 --- a/src/anomalib/models/components/stats/multi_variate_gaussian.py +++ b/src/anomalib/models/components/stats/multi_variate_gaussian.py @@ -1,4 +1,20 @@ -"""Multi Variate Gaussian Distribution.""" +"""Multi Variate Gaussian Distribution. + +This module implements parametric density estimation using a multivariate Gaussian +distribution. It estimates the mean and covariance matrix from input features. + +Example: + >>> import torch + >>> from anomalib.models.components.stats import MultiVariateGaussian + >>> # Create distribution estimator + >>> mvg = MultiVariateGaussian() + >>> # Fit distribution to features + >>> features = torch.randn(100, 64, 32, 32) # B x C x H x W + >>> mean, inv_cov = mvg.fit(features) + >>> # Access distribution parameters + >>> print(mean.shape) # [64, 1024] + >>> print(inv_cov.shape) # [1024, 64, 64] +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -12,9 +28,24 @@ class MultiVariateGaussian(DynamicBufferMixin, nn.Module): - """Multi Variate Gaussian Distribution.""" + """Multi Variate Gaussian Distribution. + + Estimates a multivariate Gaussian distribution by computing the mean and + covariance matrix from input feature embeddings. The distribution parameters + are stored as buffers. + + Example: + >>> import torch + >>> from anomalib.models.components.stats import MultiVariateGaussian + >>> mvg = MultiVariateGaussian() + >>> features = torch.randn(100, 64, 32, 32) # B x C x H x W + >>> mean, inv_cov = mvg.fit(features) + >>> print(mean.shape) # [64, 1024] + >>> print(inv_cov.shape) # [1024, 64, 64] + """ def __init__(self) -> None: + """Initialize empty buffers for mean and inverse covariance.""" super().__init__() self.register_buffer("mean", torch.empty(0)) @@ -31,33 +62,27 @@ def _cov( ddof: int | None = None, aweights: torch.Tensor | None = None, ) -> torch.Tensor: - """Estimates covariance matrix like numpy.cov. + """Estimate covariance matrix similar to numpy.cov. Args: - observations (torch.Tensor): A 1-D or 2-D array containing multiple variables and observations. - Each row of `m` represents a variable, and each column a single - observation of all those variables. Also see `rowvar` below. - rowvar (bool): If `rowvar` is True (default), then each row represents a - variable, with observations in the columns. Otherwise, the relationship - is transposed: each column represents a variable, while the rows - contain observations. Defaults to False. - bias (bool): Default normalization (False) is by ``(N - 1)``, where ``N`` is the - number of observations given (unbiased estimate). If `bias` is True, - then normalization is by ``N``. These values can be overridden by using - the keyword ``ddof`` in numpy versions >= 1.5. Defaults to False - ddof (int | None): If not ``None`` the default value implied by `bias` is overridden. - Note that ``ddof=1`` will return the unbiased estimate, even if both - `fweights` and `aweights` are specified, and ``ddof=0`` will return - the simple average. See the notes for the details. The default value - is ``None``. - aweights (torch.Tensor): 1-D array of observation vector weights. 
These relative weights are - typically large for observations considered "important" and smaller for - observations considered less "important". If ``ddof=0`` the array of - weights can be used to assign probabilities to observation vectors. (Default value = None) - + observations: A 1-D or 2-D tensor containing multiple variables and + observations. Each row represents a variable, and each column a + single observation of all variables if ``rowvar=True``. The + relationship is transposed if ``rowvar=False``. + rowvar: If ``True``, each row represents a variable. If ``False``, + each column represents a variable. Defaults to ``False``. + bias: If ``False`` (default), normalize by ``(N-1)`` for unbiased + estimate. If ``True``, normalize by ``N``. Can be overridden by + ``ddof``. + ddof: Delta degrees of freedom. If not ``None``, overrides ``bias``. + ``ddof=1`` gives unbiased estimate, ``ddof=0`` gives simple + average. + aweights: Optional 1-D tensor of observation weights. Larger weights + indicate more "important" observations. If ``ddof=0``, weights + are treated as observation probabilities. Returns: - The covariance matrix of the variables. + Covariance matrix of the variables. """ # ensure at least 2D if observations.dim() == 1: @@ -75,7 +100,7 @@ def _cov( if weights is not None: if not torch.is_tensor(weights): - weights = torch.tensor(weights, dtype=torch.float) # pylint: disable=not-callable + weights = torch.tensor(weights, dtype=torch.float) weights_sum = torch.sum(weights) avg = torch.sum(observations * (weights / weights_sum)[:, None], 0) else: @@ -101,13 +126,20 @@ def _cov( return covariance.squeeze() def forward(self, embedding: torch.Tensor) -> list[torch.Tensor]: - """Calculate multivariate Gaussian distribution. + """Calculate multivariate Gaussian distribution parameters. + + Computes the mean and inverse covariance matrix from input feature + embeddings. A small regularization term (0.01) is added to the diagonal + of the covariance matrix for numerical stability. Args: - embedding (torch.Tensor): CNN features whose dimensionality is reduced via either random sampling or PCA. + embedding: Input tensor of shape ``(B, C, H, W)`` containing CNN + feature embeddings. Returns: - mean and inverse covariance of the multi-variate gaussian distribution that fits the features. + List containing: + - Mean tensor of shape ``(C, H*W)`` + - Inverse covariance tensor of shape ``(H*W, C, C)`` """ device = embedding.device @@ -125,12 +157,16 @@ def forward(self, embedding: torch.Tensor) -> list[torch.Tensor]: return [self.mean, self.inv_covariance] def fit(self, embedding: torch.Tensor) -> list[torch.Tensor]: - """Fit multi-variate gaussian distribution to the input embedding. + """Fit multivariate Gaussian distribution to input embeddings. + + Convenience method that calls ``forward()`` to compute distribution + parameters. Args: - embedding (torch.Tensor): Embedding vector extracted from CNN. + embedding: Input tensor of shape ``(B, C, H, W)`` containing CNN + feature embeddings. Returns: - Mean and the covariance of the embedding. + List containing the mean and inverse covariance tensors. """ return self.forward(embedding) diff --git a/src/anomalib/models/image/__init__.py b/src/anomalib/models/image/__init__.py index c8ce0987b2..388c6002a7 100644 --- a/src/anomalib/models/image/__init__.py +++ b/src/anomalib/models/image/__init__.py @@ -1,6 +1,39 @@ -"""Anomalib Image Models.""" +"""Anomalib Image Models. 
-# Copyright (C) 2023 Intel Corporation
+This module contains implementations of various deep learning models for image-based
+anomaly detection.
+
+Example:
+    >>> from anomalib.models.image import Padim, Patchcore
+    >>> # Initialize a model
+    >>> model = Padim()  # doctest: +SKIP
+    >>> # Train on normal images
+    >>> model.fit(["normal1.jpg", "normal2.jpg"])  # doctest: +SKIP
+    >>> # Get predictions
+    >>> predictions = model.predict("test.jpg")  # doctest: +SKIP
+
+Available Models:
+    - :class:`Cfa`: Coupled-hypersphere-based Feature Adaptation
+    - :class:`Cflow`: Conditional Normalizing Flow
+    - :class:`Csflow`: Cross-Scale Flow
+    - :class:`Dfkde`: Deep Feature Kernel Density Estimation
+    - :class:`Dfm`: Deep Feature Modeling
+    - :class:`Draem`: Discriminatively Trained Reconstruction Embedding
+    - :class:`Dsr`: Dual Subspace Re-projection
+    - :class:`EfficientAd`: Efficient Anomaly Detection
+    - :class:`Fastflow`: Fast Flow
+    - :class:`Fre`: Feature Reconstruction Error
+    - :class:`Ganomaly`: Generative Adversarial Networks
+    - :class:`Padim`: Patch Distribution Modeling
+    - :class:`Patchcore`: Patch Core
+    - :class:`ReverseDistillation`: Reverse Knowledge Distillation
+    - :class:`Stfpm`: Student-Teacher Feature Pyramid Matching
+    - :class:`Uflow`: U-shaped Normalizing Flow
+    - :class:`VlmAd`: Vision Language Model Anomaly Detection
+    - :class:`WinClip`: Zero-/Few-Shot CLIP-based Detection
+"""
+
+# Copyright (C) 2023-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

 from .cfa import Cfa
diff --git a/src/anomalib/models/image/cfa/__init__.py b/src/anomalib/models/image/cfa/__init__.py
index def95441cb..962612f974 100644
--- a/src/anomalib/models/image/cfa/__init__.py
+++ b/src/anomalib/models/image/cfa/__init__.py
@@ -1,8 +1,23 @@
-"""Implementatation of the CFA Model.
+"""Implementation of the CFA (Coupled-hypersphere-based Feature Adaptation) model.

-CFA: Coupled-hypersphere-based Feature Adaptation for Target-Oriented Anomaly Localization
+This module provides the CFA model for target-oriented anomaly localization. CFA
+learns discriminative features by adapting them to coupled hyperspheres in the
+feature space.

-Paper https://arxiv.org/abs/2206.04325
+The model uses a teacher-student architecture where the teacher network extracts
+features from normal samples to guide the student network in learning
+anomaly-sensitive representations.
+
+Paper: https://arxiv.org/abs/2206.04325
+
+Example:
+    >>> from anomalib.models.image import Cfa
+    >>> # Initialize the model
+    >>> model = Cfa()
+    >>> # Train on normal samples
+    >>> model.fit(normal_samples)
+    >>> # Get anomaly predictions
+    >>> predictions = model.predict(test_samples)
 """

 # Copyright (C) 2022-2024 Intel Corporation
diff --git a/src/anomalib/models/image/cfa/anomaly_map.py b/src/anomalib/models/image/cfa/anomaly_map.py
index 5c35881c83..8f65c21f9c 100644
--- a/src/anomalib/models/image/cfa/anomaly_map.py
+++ b/src/anomalib/models/image/cfa/anomaly_map.py
@@ -1,4 +1,18 @@
-"""Anomaly Map Generator for the CFA model implementation."""
+"""Anomaly Map Generator for the CFA model implementation.
+
+This module provides functionality to generate anomaly heatmaps from distance
+features computed by the CFA model.
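
The model list above is usually paired with anomalib's `Engine` for end-to-end training. A hedged sketch; the `MVTecAD` datamodule name and the default arguments are assumptions that vary across anomalib versions:

```python
from anomalib.data import MVTecAD  # name assumed; older releases expose MVTec
from anomalib.engine import Engine
from anomalib.models.image import Padim

datamodule = MVTecAD(category="bottle")  # normal images only for training
model = Padim()
engine = Engine()
engine.fit(model=model, datamodule=datamodule)  # one-class training
predictions = engine.predict(model=model, datamodule=datamodule)
```
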
+
+Example:
+    >>> import torch
+    >>> from anomalib.models.image.cfa.anomaly_map import AnomalyMapGenerator
+    >>> # Initialize generator
+    >>> generator = AnomalyMapGenerator(num_nearest_neighbors=3)
+    >>> # Generate anomaly map
+    >>> distance = torch.rand(1, 1024, 1)  # batch x pixels x 1, non-negative
+    >>> scale = (32, 32)  # height x width
+    >>> anomaly_map = generator(distance=distance, scale=scale)
+"""

 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -12,7 +26,24 @@

 class AnomalyMapGenerator(nn.Module):
-    """Generate Anomaly Heatmap."""
+    """Generate anomaly heatmaps from distance features.
+
+    The generator computes anomaly scores based on k-nearest neighbor distances
+    and applies Gaussian smoothing to produce the final heatmap.
+
+    Args:
+        num_nearest_neighbors (int): Number of nearest neighbors to consider
+            when computing anomaly scores.
+        sigma (int, optional): Standard deviation for Gaussian smoothing.
+            Defaults to ``4``.
+
+    Example:
+        >>> import torch
+        >>> generator = AnomalyMapGenerator(num_nearest_neighbors=3)
+        >>> distance = torch.rand(1, 1024, 1)  # batch x pixels x 1, non-negative
+        >>> scale = (32, 32)  # height x width
+        >>> anomaly_map = generator(distance=distance, scale=scale)
+    """

     def __init__(
         self,
@@ -24,16 +55,17 @@ def __init__(
         self.sigma = sigma

     def compute_score(self, distance: torch.Tensor, scale: tuple[int, int]) -> torch.Tensor:
-        """Compute score based on the distance.
+        """Compute anomaly scores from distance features.

         Args:
-            distance (torch.Tensor): Distance tensor computed using target oriented
-                features.
-            scale (tuple[int, int]): Height and width of the largest feature
-                map.
+            distance (torch.Tensor): Distance tensor of shape
+                ``(batch_size, num_pixels, 1)``.
+            scale (tuple[int, int]): Height and width of the feature map used
+                to reshape the scores.

         Returns:
-            Tensor: Score value.
+            torch.Tensor: Anomaly scores of shape
+                ``(batch_size, 1, height, width)``.
         """
         distance = torch.sqrt(distance)
         distance = distance.topk(self.num_nearest_neighbors, largest=False).values  # noqa: PD011
@@ -48,14 +80,17 @@ def compute_anomaly_map(
         score: torch.Tensor,
         image_size: tuple[int, int] | torch.Size | None = None,
     ) -> torch.Tensor:
-        """Compute anomaly map based on the score.
+        """Generate smoothed anomaly map from scores.

         Args:
-            score (torch.Tensor): Score tensor.
-            image_size (tuple[int, int] | torch.Size | None, optional): Size of the input image.
+            score (torch.Tensor): Anomaly scores of shape
+                ``(batch_size, 1, height, width)``.
+            image_size (tuple[int, int] | torch.Size | None, optional): Target
+                size for upsampling the anomaly map. Defaults to ``None``.

         Returns:
-            Tensor: Anomaly map.
+            torch.Tensor: Smoothed anomaly map of shape
+                ``(batch_size, 1, height, width)``.
         """
         anomaly_map = score.mean(dim=1, keepdim=True)
         if image_size is not None:
@@ -65,16 +100,27 @@ def compute_anomaly_map(
         return gaussian_blur(anomaly_map)  # pylint: disable=not-callable

     def forward(self, **kwargs) -> torch.Tensor:
-        """Return anomaly map.
+        """Generate anomaly map from input features.
+
+        The method expects ``distance`` and ``scale`` as required inputs, with
+        optional ``image_size`` for upsampling.
+
+        Args:
+            **kwargs: Keyword arguments containing:
+                - distance (torch.Tensor): Distance features
+                - scale (tuple[int, int]): Feature map scale
+                - image_size (tuple[int, int] | torch.Size, optional):
+                    Target size for upsampling

         Raises:
-            ``distance`` and ``scale`` keys are not found.
+            ValueError: If required arguments are missing.
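
The `compute_score` recipe documented above (square root of the squared distances, k smallest per pixel, mean, reshape to the feature grid) can be reproduced in a few lines. A self-contained sketch with assumed shapes, not the model's exact tensor layout:

```python
import torch

batch, num_pixels, num_memory, k = 2, 1024, 16, 3
squared = torch.rand(batch, num_pixels, num_memory)   # non-negative squared distances
distance = squared.sqrt()
knn = distance.topk(k, dim=-1, largest=False).values  # k nearest memory items per pixel
score = knn.mean(dim=-1).reshape(batch, 1, 32, 32)    # fold pixels back to a 32 x 32 grid
print(score.shape)  # torch.Size([2, 1, 32, 32])
```
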
Returns: - Tensor: Anomaly heatmap. + torch.Tensor: Anomaly heatmap of shape + ``(batch_size, 1, height, width)``. """ if not ("distance" in kwargs and "scale" in kwargs): - msg = f"Expected keys `distance` and `scale. Found {kwargs.keys()}" + msg = f"Expected keys `distance` and `scale`. Found {kwargs.keys()}" raise ValueError(msg) distance: torch.Tensor = kwargs["distance"] diff --git a/src/anomalib/models/image/cfa/lightning_model.py b/src/anomalib/models/image/cfa/lightning_model.py index 9eed15b6a7..650a3277d4 100644 --- a/src/anomalib/models/image/cfa/lightning_model.py +++ b/src/anomalib/models/image/cfa/lightning_model.py @@ -1,8 +1,11 @@ -"""Lightning Implementatation of the CFA Model. +"""Lightning Implementation of the CFA Model. -CFA: Coupled-hypersphere-based Feature Adaptation for Target-Oriented Anomaly Localization +CFA: Coupled-hypersphere-based Feature Adaptation for Target-Oriented Anomaly +Localization. -Paper https://arxiv.org/abs/2206.04325 +Paper: https://arxiv.org/abs/2206.04325 + +This implementation uses PyTorch Lightning for training and inference. """ # Copyright (C) 2022-2024 Intel Corporation @@ -31,24 +34,35 @@ class Cfa(AnomalibModule): - """CFA: Coupled-hypersphere-based Feature Adaptation for Target-Oriented Anomaly Localization. + """CFA Lightning Module. + + The CFA model performs anomaly detection and localization using coupled + hypersphere-based feature adaptation. Args: - backbone (str): Backbone CNN network + backbone (str): Name of the backbone CNN network. Defaults to ``"wide_resnet50_2"``. - gamma_c (int, optional): gamma_c value from the paper. + gamma_c (int, optional): Centroid loss weight parameter. Defaults to ``1``. - gamma_d (int, optional): gamma_d value from the paper. + gamma_d (int, optional): Distance loss weight parameter. Defaults to ``1``. - num_nearest_neighbors (int): Number of nearest neighbors. + num_nearest_neighbors (int): Number of nearest neighbors to consider. Defaults to ``3``. - num_hard_negative_features (int): Number of hard negative features. + num_hard_negative_features (int): Number of hard negative features to use. Defaults to ``3``. - radius (float): Radius of the hypersphere to search the soft boundary. + radius (float): Radius of the hypersphere for soft boundary search. Defaults to ``1e-5``. - pre_processor (PreProcessor, optional): Pre-processor for the model. - This is used to pre-process the input data before it is passed to the model. - Defaults to ``None``. + pre_processor (PreProcessor | bool, optional): Pre-processor instance or + boolean flag. + Defaults to ``True``. + post_processor (PostProcessor | bool, optional): Post-processor instance or + boolean flag. + Defaults to ``True``. + evaluator (Evaluator | bool, optional): Evaluator instance or boolean flag. + Defaults to ``True``. + visualizer (Visualizer | bool, optional): Visualizer instance or boolean + flag. + Defaults to ``True``. """ def __init__( @@ -86,19 +100,23 @@ def __init__( ) def on_train_start(self) -> None: - """Initialize the centroid for the memory bank computation.""" + """Initialize the centroid for memory bank computation. + + This method is called at the start of training to compute the initial + centroid using the training data. + """ self.model.initialize_centroid(data_loader=self.trainer.datamodule.train_dataloader()) def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: - """Perform the training step for the CFA model. + """Perform a training step. Args: - batch (Batch): Batch input. - *args: Arguments. 
- **kwargs: Keyword arguments. + batch (Batch): Input batch containing images and metadata. + *args: Additional positional arguments (unused). + **kwargs: Additional keyword arguments (unused). Returns: - STEP_OUTPUT: Loss value. + STEP_OUTPUT: Dictionary containing the loss value. """ del args, kwargs # These variables are not used. @@ -107,15 +125,15 @@ def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: return {"loss": loss} def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: - """Perform the validation step for the CFA model. + """Perform a validation step. Args: - batch (Batch): Input batch. - *args: Arguments. - **kwargs: Keyword arguments. + batch (Batch): Input batch containing images and metadata. + *args: Additional positional arguments (unused). + **kwargs: Additional keyword arguments (unused). Returns: - dict: Anomaly map computed by the model. + STEP_OUTPUT: Batch object updated with model predictions. """ del args, kwargs # These variables are not used. @@ -124,12 +142,16 @@ def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: @staticmethod def backward(loss: torch.Tensor, *args, **kwargs) -> None: - """Perform backward-pass for the CFA model. + """Perform backward pass. Args: - loss (torch.Tensor): Loss value. - *args: Arguments. - **kwargs: Keyword arguments. + loss (torch.Tensor): Computed loss value. + *args: Additional positional arguments (unused). + **kwargs: Additional keyword arguments (unused). + + Note: + Uses ``retain_graph=True`` due to computational graph requirements. + See CVS-122673 for more details. """ del args, kwargs # These variables are not used. @@ -139,14 +161,24 @@ def backward(loss: torch.Tensor, *args, **kwargs) -> None: @property def trainer_arguments(self) -> dict[str, Any]: - """CFA specific trainer arguments.""" + """Get CFA-specific trainer arguments. + + Returns: + dict[str, Any]: Dictionary containing trainer configuration: + - ``gradient_clip_val``: Set to ``0`` to disable gradient clipping + - ``num_sanity_val_steps``: Set to ``0`` to skip validation sanity + checks + """ return {"gradient_clip_val": 0, "num_sanity_val_steps": 0} def configure_optimizers(self) -> torch.optim.Optimizer: - """Configure optimizers for the CFA Model. + """Configure the optimizer. Returns: - Optimizer: Adam optimizer for each decoder + torch.optim.Optimizer: AdamW optimizer configured with: + - Learning rate: ``1e-3`` + - Weight decay: ``5e-4`` + - AMSGrad: ``True`` """ return torch.optim.AdamW( params=self.model.parameters(), @@ -157,9 +189,9 @@ def configure_optimizers(self) -> torch.optim.Optimizer: @property def learning_type(self) -> LearningType: - """Return the learning type of the model. + """Get the learning type. Returns: - LearningType: Learning type of the model. + LearningType: Indicates this is a one-class classification model. """ return LearningType.ONE_CLASS diff --git a/src/anomalib/models/image/cfa/loss.py b/src/anomalib/models/image/cfa/loss.py index 91b9d270f6..13e67a66d2 100644 --- a/src/anomalib/models/image/cfa/loss.py +++ b/src/anomalib/models/image/cfa/loss.py @@ -1,4 +1,23 @@ -"""Loss function for the Cfa Model Implementation.""" +"""Loss function for the CFA (Coupled-hypersphere-based Feature Adaptation) model. + +This module implements the loss function used to train the CFA model for anomaly +detection. The loss consists of two components: + 1. Attraction loss that pulls normal samples inside a hypersphere + 2. 
Repulsion loss that pushes anomalous samples outside the hypersphere + +Example: + >>> import torch + >>> from anomalib.models.image.cfa.loss import CfaLoss + >>> # Initialize loss function + >>> loss_fn = CfaLoss( + ... num_nearest_neighbors=3, + ... num_hard_negative_features=3, + ... radius=0.5 + ... ) + >>> # Compute loss on distance tensor + >>> distance = torch.randn(2, 1024, 1) # batch x pixels x 1 + >>> loss = loss_fn(distance) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -8,12 +27,28 @@ class CfaLoss(nn.Module): - """Cfa Loss. + """Loss function for the CFA model. + + The loss encourages normal samples to lie within a hypersphere while pushing + anomalous samples outside. It uses k-nearest neighbors to identify the closest + samples and hard negative mining to find challenging anomalous examples. Args: - num_nearest_neighbors (int): Number of nearest neighbors. - num_hard_negative_features (int): Number of hard negative features. - radius (float): Radius of the hypersphere to search the soft boundary. + num_nearest_neighbors (int): Number of nearest neighbors to consider for + the attraction loss component. + num_hard_negative_features (int): Number of hard negative features to use + for the repulsion loss component. + radius (float): Initial radius of the hypersphere that defines the + decision boundary between normal and anomalous samples. + + Example: + >>> loss_fn = CfaLoss( + ... num_nearest_neighbors=3, + ... num_hard_negative_features=3, + ... radius=0.5 + ... ) + >>> distance = torch.randn(2, 1024, 1) # batch x pixels x 1 + >>> loss = loss_fn(distance) """ def __init__(self, num_nearest_neighbors: int, num_hard_negative_features: int, radius: float) -> None: @@ -23,13 +58,22 @@ def __init__(self, num_nearest_neighbors: int, num_hard_negative_features: int, self.radius = torch.ones(1, requires_grad=True) * radius def forward(self, distance: torch.Tensor) -> torch.Tensor: - """Compute the CFA loss. + """Compute the CFA loss given distance features. + + The loss has two components: + 1. Attraction loss (`l_att`): Encourages normal samples to lie within + the hypersphere by penalizing distances greater than `radius`. + 2. Repulsion loss (`l_rep`): Pushes anomalous samples outside the + hypersphere by penalizing distances less than `radius + margin`. Args: - distance (torch.Tensor): Distance computed using target oriented features. + distance (torch.Tensor): Distance tensor of shape + ``(batch_size, num_pixels, 1)`` computed using target-oriented + features. Returns: - Tensor: CFA loss. + torch.Tensor: Scalar loss value combining attraction and repulsion + components. """ num_neighbors = self.num_nearest_neighbors + self.num_hard_negative_features distance = distance.topk(num_neighbors, largest=False).values # noqa: PD011 diff --git a/src/anomalib/models/image/cfa/torch_model.py b/src/anomalib/models/image/cfa/torch_model.py index e36d53050e..799f273a02 100644 --- a/src/anomalib/models/image/cfa/torch_model.py +++ b/src/anomalib/models/image/cfa/torch_model.py @@ -1,8 +1,35 @@ -"""Torch Implementatation of the CFA Model. - -CFA: Coupled-hypersphere-based Feature Adaptation for Target-Oriented Anomaly Localization - -Paper https://arxiv.org/abs/2206.04325 +"""Torch Implementation of the CFA Model. + +CFA: Coupled-hypersphere-based Feature Adaptation for Target-Oriented Anomaly +Localization. + +This module provides the PyTorch implementation of the CFA model for anomaly +detection and localization. 
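
To make the two `CfaLoss` terms above concrete, here is a heavily simplified sketch of the attraction/repulsion split; the exact reductions, squared-radius handling, and margin value in the real loss differ, so treat this only as an illustration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
radius, margin = 0.5, 0.1
d_nearest = torch.rand(2, 1024, 3)        # distances to the k nearest memory items
d_hard_negative = torch.rand(2, 1024, 3)  # distances to hard-negative items

l_att = F.relu(d_nearest - radius**2).mean()                  # pull normals inside
l_rep = F.relu(radius**2 - d_hard_negative + margin).mean()   # push negatives out
loss = l_att + l_rep
print(float(loss))
```
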
The model learns discriminative features by adapting +them to coupled hyperspheres in the feature space. + +The model consists of: + - A backbone CNN feature extractor + - A descriptor network that generates target-oriented features + - A memory bank that stores prototypical normal features + - An anomaly map generator for localization + +Paper: https://arxiv.org/abs/2206.04325 + +Example: + >>> import torch + >>> from anomalib.models.image.cfa.torch_model import CfaModel + >>> # Initialize model + >>> model = CfaModel( + ... backbone="resnet18", + ... gamma_c=1, + ... gamma_d=1, + ... num_nearest_neighbors=3, + ... num_hard_negative_features=3, + ... radius=0.5 + ... ) + >>> # Forward pass + >>> x = torch.randn(32, 3, 256, 256) + >>> predictions = model(x) """ # Copyright (C) 2022-2024 Intel Corporation @@ -30,18 +57,23 @@ def get_return_nodes(backbone: str) -> list[str]: - """Get the return nodes for a given backbone. + """Get the return nodes for feature extraction from a backbone network. Args: - backbone (str): The name of the backbone. Must be one of - {"resnet18", "wide_resnet50_2", "vgg19_bn", "efficientnet_b5"}. + backbone (str): Name of the backbone CNN. Must be one of + ``{"resnet18", "wide_resnet50_2", "vgg19_bn", "efficientnet_b5"}``. Raises: - NotImplementedError: If the backbone is "efficientnet_b5". - ValueError: If the backbone is not one of the supported backbones. + NotImplementedError: If ``backbone`` is "efficientnet_b5". + ValueError: If ``backbone`` is not one of the supported backbones. Returns: - list[str]: A list of return nodes for the given backbone. + list[str]: List of layer names to extract features from. + + Example: + >>> nodes = get_return_nodes("resnet18") + >>> print(nodes) + ['layer1', 'layer2', 'layer3'] """ if backbone == "efficientnet_b5": msg = "EfficientNet feature extractor has not implemented yet." @@ -61,18 +93,24 @@ def get_return_nodes(backbone: str) -> list[str]: # TODO(samet-akcay): Replace this with the new torchfx feature extractor. # CVS-122673 def get_feature_extractor(backbone: str, return_nodes: list[str]) -> GraphModule: - """Get the feature extractor from the backbone CNN. + """Create a feature extractor from a backbone CNN. Args: - backbone (str): Backbone CNN network - return_nodes (list[str]): A list of return nodes for the given backbone. + backbone (str): Name of the backbone CNN network. + return_nodes (list[str]): List of layer names to extract features from. Raises: - NotImplementedError: When the backbone is efficientnet_b5 - ValueError: When the backbone is not supported + NotImplementedError: When ``backbone`` is efficientnet_b5. + ValueError: When ``backbone`` is not supported. Returns: - GraphModule: Feature extractor. + GraphModule: Feature extractor module. + + Example: + >>> nodes = ["layer1", "layer2", "layer3"] + >>> extractor = get_feature_extractor("resnet18", nodes) + >>> x = torch.randn(1, 3, 224, 224) + >>> features = extractor(x) """ model = getattr(torchvision.models, backbone)(pretrained=True) feature_extractor = create_feature_extractor(model=model, return_nodes=return_nodes) @@ -84,13 +122,29 @@ def get_feature_extractor(backbone: str, return_nodes: list[str]) -> GraphModule class CfaModel(DynamicBufferMixin): """Torch implementation of the CFA Model. + The model learns discriminative features by adapting them to coupled + hyperspheres in the feature space. It uses a teacher-student architecture + where the teacher network extracts features from normal samples to guide the + student network. 
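
The `get_feature_extractor` helper above wraps torchvision's graph-based feature extraction. A runnable standalone sketch (using `weights=None` to avoid a download, whereas the real helper loads pretrained weights):

```python
import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

model = torchvision.models.resnet18(weights=None)
extractor = create_feature_extractor(model, return_nodes=["layer1", "layer2", "layer3"])
features = extractor(torch.randn(1, 3, 224, 224))
for name, feature in features.items():
    print(name, tuple(feature.shape))
# layer1 (1, 64, 56, 56) / layer2 (1, 128, 28, 28) / layer3 (1, 256, 14, 14)
```
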
+ Args: - backbone (str): Backbone CNN network. - gamma_c (int): gamma_c parameter from the paper. - gamma_d (int): gamma_d parameter from the paper. - num_nearest_neighbors (int): Number of nearest neighbors. - num_hard_negative_features (int): Number of hard negative features. - radius (float): Radius of the hypersphere to search the soft boundary. + backbone (str): Name of the backbone CNN network. + gamma_c (int): Weight for centroid loss. + gamma_d (int): Weight for distance loss. + num_nearest_neighbors (int): Number of nearest neighbors for score + computation. + num_hard_negative_features (int): Number of hard negative features to use. + radius (float): Initial radius of the hypersphere decision boundary. + + Example: + >>> model = CfaModel( + ... backbone="resnet18", + ... gamma_c=1, + ... gamma_d=1, + ... num_nearest_neighbors=3, + ... num_hard_negative_features=3, + ... radius=0.5 + ... ) """ def __init__( @@ -124,10 +178,18 @@ def __init__( ) def get_scale(self, input_size: tuple[int, int] | torch.Size) -> torch.Size: - """Get the scale of the feature map. + """Get the scale of the feature maps. Args: - input_size (tuple[int, int]): Input size of the image tensor. + input_size (tuple[int, int] | torch.Size): Input image dimensions + (height, width). + + Returns: + torch.Size: Feature map dimensions. + + Example: + >>> model = CfaModel(...) + >>> scale = model.get_scale((256, 256)) """ feature_map_metadata = dryrun_find_featuremap_dims( feature_extractor=self.feature_extractor, @@ -148,13 +210,20 @@ def get_scale(self, input_size: tuple[int, int] | torch.Size) -> torch.Size: return scale def initialize_centroid(self, data_loader: DataLoader) -> None: - """Initialize the Centroid of the Memory Bank. + """Initialize the centroid of the memory bank. - Args: - data_loader (DataLoader): Train Dataloader. + Computes the average feature representation of normal samples to + initialize the memory bank centroids. - Returns: - Tensor: Memory Bank. + Args: + data_loader (DataLoader): DataLoader containing normal training + samples. + + Example: + >>> from torch.utils.data import DataLoader + >>> model = CfaModel(...) + >>> train_loader = DataLoader(...) + >>> model.initialize_centroid(train_loader) """ device = next(self.feature_extractor.parameters()).device with torch.no_grad(): @@ -179,14 +248,19 @@ def initialize_centroid(self, data_loader: DataLoader) -> None: self.memory_bank = rearrange(self.memory_bank, "h w -> w h") def compute_distance(self, target_oriented_features: torch.Tensor) -> torch.Tensor: - """Compute distance using target oriented features. + """Compute distances between features and memory bank centroids. Args: - target_oriented_features (torch.Tensor): Target oriented features computed - using the descriptor. + target_oriented_features (torch.Tensor): Features from the descriptor + network. Returns: - Tensor: Distance tensor. + torch.Tensor: Distance tensor. + + Example: + >>> model = CfaModel(...) + >>> features = torch.randn(32, 256, 32, 32) # B x C x H x W + >>> distances = model.compute_distance(features) """ if target_oriented_features.ndim == 4: target_oriented_features = rearrange(target_oriented_features, "b c h w -> b (h w) c") @@ -197,16 +271,22 @@ def compute_distance(self, target_oriented_features: torch.Tensor) -> torch.Tens return features + centers - f_c def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: - """Forward pass. + """Forward pass through the model. Args: - input_tensor (torch.Tensor): Input tensor. 
+ input_tensor (torch.Tensor): Input image tensor. Raises: ValueError: When the memory bank is not initialized. Returns: - Tensor: Loss or anomaly map depending on the train/eval mode. + torch.Tensor | InferenceBatch: During training, returns distance + tensor. During inference, returns anomaly predictions. + + Example: + >>> model = CfaModel(...) + >>> x = torch.randn(32, 3, 256, 256) + >>> predictions = model(x) """ if self.memory_bank.ndim == 0: msg = "Memory bank is not initialized. Run `initialize_centroid` method first." @@ -233,7 +313,20 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: class Descriptor(nn.Module): - """Descriptor module.""" + """Descriptor network that generates target-oriented features. + + Args: + gamma_d (int): Weight for distance loss. + backbone (str): Name of the backbone CNN network. + + Raises: + ValueError: If ``backbone`` is not supported. + + Example: + >>> descriptor = Descriptor(gamma_d=1, backbone="resnet18") + >>> features = [torch.randn(32, 64, 64, 64)] + >>> target_features = descriptor(features) + """ def __init__(self, gamma_d: int, backbone: str) -> None: super().__init__() @@ -252,7 +345,20 @@ def __init__(self, gamma_d: int, backbone: str) -> None: self.layer = CoordConv2d(in_channels=dim, out_channels=out_channels, kernel_size=1) def forward(self, features: list[torch.Tensor] | dict[str, torch.Tensor]) -> torch.Tensor: - """Forward pass.""" + """Forward pass through the descriptor network. + + Args: + features (list[torch.Tensor] | dict[str, torch.Tensor]): Features + from the backbone network. + + Returns: + torch.Tensor: Target-oriented features. + + Example: + >>> descriptor = Descriptor(gamma_d=1, backbone="resnet18") + >>> features = [torch.randn(32, 64, 64, 64)] + >>> target_features = descriptor(features) + """ if isinstance(features, dict): features = list(features.values()) @@ -273,13 +379,36 @@ def forward(self, features: list[torch.Tensor] | dict[str, torch.Tensor]) -> tor class CoordConv2d(nn.Conv2d): - """CoordConv layer as in the paper. + """CoordConv layer that adds coordinate channels to input features. + + Implementation based on the paper "An Intriguing Failing of Convolutional + Neural Networks and the CoordConv Solution". MIT License Copyright (c) 2018 Walsvid - Link to the paper: https://arxiv.org/abs/1807.03247 - Link to the PyTorch implementation: https://github.com/walsvid/CoordConv + Paper: https://arxiv.org/abs/1807.03247 + Code: https://github.com/walsvid/CoordConv + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (_size_2_t): Size of the convolution kernel. + stride (_size_2_t, optional): Stride of the convolution. + Defaults to ``1``. + padding (str | _size_2_t, optional): Padding added to input. + Defaults to ``0``. + dilation (_size_2_t, optional): Dilation of the kernel. + Defaults to ``1``. + groups (int, optional): Number of blocked connections. Defaults to ``1``. + bias (bool, optional): If True, adds learnable bias. Defaults to ``True``. + with_r (bool, optional): If True, adds radial coordinate channel. + Defaults to ``False``. + + Example: + >>> conv = CoordConv2d(64, 128, kernel_size=3) + >>> x = torch.randn(32, 64, 32, 32) + >>> out = conv(x) """ def __init__( @@ -309,7 +438,7 @@ def __init__( # Create conv layer on top of add_coords layer. 
        self.conv2d = nn.Conv2d(
-            in_channels=in_channels + 2 + int(with_r),  # 2 for rank-2 tensor, 1 for r if with_r
+            in_channels=in_channels + 2 + int(with_r),  # 2 for rank-2, 1 for r
             out_channels=out_channels,
             kernel_size=kernel_size,
             stride=stride,
@@ -319,27 +448,42 @@ def __init__(
             bias=bias,
         )

-    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:  # pylint: disable=arguments-renamed
-        """Forward pass.
+    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
+        """Forward pass through the CoordConv layer.

         Args:
             input_tensor (torch.Tensor): Input tensor.

         Returns:
-            Tensor: Output tensor after applying the CoordConv layer.
+            torch.Tensor: Output tensor after applying coordinates and
+                convolution.
+
+        Example:
+            >>> conv = CoordConv2d(64, 128, kernel_size=3)
+            >>> x = torch.randn(32, 64, 32, 32)
+            >>> out = conv(x)
         """
         out = self.add_coords(input_tensor)
         return self.conv2d(out)

 class AddCoords(nn.Module):
-    """Add coords to a tensor.
+    """Module that adds coordinate channels to input tensor.

     MIT License Copyright (c) 2018 Walsvid

-    Link to the paper: https://arxiv.org/abs/1807.03247
-    Link to the PyTorch implementation: https://github.com/walsvid/CoordConv
+    Paper: https://arxiv.org/abs/1807.03247
+    Code: https://github.com/walsvid/CoordConv
+
+    Args:
+        with_r (bool, optional): If True, adds radial coordinate channel.
+            Defaults to ``False``.
+
+    Example:
+        >>> coord_adder = AddCoords()
+        >>> x = torch.randn(32, 64, 32, 32)
+        >>> out = coord_adder(x)  # adds x,y coordinate channels
     """

     def __init__(self, with_r: bool = False) -> None:
@@ -347,13 +491,18 @@ def __init__(self, with_r: bool = False) -> None:
         self.with_r = with_r

     def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
-        """Forward pass.
+        """Add coordinate channels to input tensor.

         Args:
-            input_tensor (torch.Tensor): Input tensor
+            input_tensor (torch.Tensor): Input tensor.

         Returns:
-            Tensor: Output tensor with added coordinates.
+            torch.Tensor: Tensor with added coordinate channels.
+
+        Example:
+            >>> coord_adder = AddCoords()
+            >>> x = torch.randn(32, 64, 32, 32)
+            >>> out = coord_adder(x)  # adds x,y coordinate channels
         """
         # NOTE: This is a modified version of the original implementation,
         # which only supports rank 2 tensors.
diff --git a/src/anomalib/models/image/cflow/__init__.py b/src/anomalib/models/image/cflow/__init__.py
index d6d4bfde71..61fa9838a5 100644
--- a/src/anomalib/models/image/cflow/__init__.py
+++ b/src/anomalib/models/image/cflow/__init__.py
@@ -1,4 +1,26 @@
-"""Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows."""
+"""Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows.
+
+This module provides the implementation of CFLOW model for anomaly detection.
+CFLOW uses conditional normalizing flows to model the distribution of normal
+samples in the feature space.
+
+Example:
+    >>> from anomalib.models.image.cflow import Cflow
+    >>> # Initialize the model
+    >>> model = Cflow(
+    ...     backbone="resnet18",
+    ...     fiber_batch_size=64,
+    ...     condition_vector=128,
+    ...     coupling_blocks=4,
+    ...     clamp_alpha=1.9,
+    ...     permute_soft=False
+    ...
) + >>> # Forward pass + >>> x = torch.randn(32, 3, 256, 256) + >>> predictions = model(x) + +Paper: https://arxiv.org/abs/2107.12571 +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/image/cflow/anomaly_map.py b/src/anomalib/models/image/cflow/anomaly_map.py index b212ddcc36..f13b940df8 100644 --- a/src/anomalib/models/image/cflow/anomaly_map.py +++ b/src/anomalib/models/image/cflow/anomaly_map.py @@ -1,4 +1,25 @@ -"""Anomaly Map Generator for CFlow model implementation.""" +"""Anomaly Map Generator for CFlow model implementation. + +This module provides the anomaly map generation functionality for the CFlow model. +The generator takes feature distributions from multiple layers and combines them +into a single anomaly heatmap. + +Example: + >>> from anomalib.models.image.cflow.anomaly_map import AnomalyMapGenerator + >>> import torch + >>> # Initialize generator + >>> pool_layers = ["layer1", "layer2", "layer3"] + >>> generator = AnomalyMapGenerator(pool_layers=pool_layers) + >>> # Generate anomaly map + >>> distribution = [torch.randn(32, 64) for _ in range(3)] + >>> height = [32, 16, 8] + >>> width = [32, 16, 8] + >>> anomaly_map = generator( + ... distribution=distribution, + ... height=height, + ... width=width + ... ) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -12,7 +33,27 @@ class AnomalyMapGenerator(nn.Module): - """Generate Anomaly Heatmap.""" + """Generate anomaly heatmap from layer-wise feature distributions. + + The generator combines likelihood estimations from multiple feature layers into + a single anomaly heatmap by upsampling and aggregating the scores. + + Args: + pool_layers (Sequence[str]): Names of pooling layers from which to extract + features. + + Example: + >>> pool_layers = ["layer1", "layer2", "layer3"] + >>> generator = AnomalyMapGenerator(pool_layers=pool_layers) + >>> distribution = [torch.randn(32, 64) for _ in range(3)] + >>> height = [32, 16, 8] + >>> width = [32, 16, 8] + >>> anomaly_map = generator( + ... distribution=distribution, + ... height=height, + ... width=width + ... ) + """ def __init__( self, @@ -29,17 +70,22 @@ def compute_anomaly_map( width: list[int], image_size: tuple[int, int] | torch.Size | None, ) -> torch.Tensor: - """Compute the layer map based on likelihood estimation. + """Compute anomaly map from layer-wise likelihood distributions. + + The method normalizes likelihood scores from each layer, upsamples them to + a common size, and combines them into a final anomaly map. Args: - distribution (list[torch.Tensor]): List of likelihoods for each layer. - height (list[int]): List of heights of the feature maps. - width (list[int]): List of widths of the feature maps. - image_size (tuple[int, int] | torch.Size | None): Size of the input image. + distribution (list[torch.Tensor]): List of likelihood distributions for + each layer. + height (list[int]): List of feature map heights for each layer. + width (list[int]): List of feature map widths for each layer. + image_size (tuple[int, int] | torch.Size | None): Target size for the + output anomaly map. If None, keeps the original size. Returns: - Final Anomaly Map - + torch.Tensor: Anomaly map tensor where higher values indicate higher + likelihood of anomaly. 
""" layer_maps: list[torch.Tensor] = [] for layer_idx in range(len(self.pool_layers)): @@ -65,20 +111,36 @@ def compute_anomaly_map( return score_map.max() - score_map def forward(self, **kwargs: list[torch.Tensor] | list[int] | list[list]) -> torch.Tensor: - """Return anomaly_map. + """Generate anomaly map from input feature distributions. - Expects `distribution`, `height` and 'width' keywords to be passed explicitly + The method expects keyword arguments containing the feature distributions + and corresponding spatial dimensions. + + Args: + **kwargs: Keyword arguments containing: + - distribution (list[torch.Tensor]): Feature distributions + - height (list[int]): Feature map heights + - width (list[int]): Feature map widths + - image_size (tuple[int, int] | torch.Size | None, optional): + Target output size Example: - >>> anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(hparams.model.input_size), - >>> pool_layers=pool_layers) - >>> output = self.anomaly_map_generator(distribution=dist, height=height, width=width) + >>> generator = AnomalyMapGenerator(pool_layers=["layer1", "layer2"]) + >>> distribution = [torch.randn(32, 64) for _ in range(2)] + >>> height = [32, 16] + >>> width = [32, 16] + >>> anomaly_map = generator( + ... distribution=distribution, + ... height=height, + ... width=width + ... ) Raises: - ValueError: `distribution`, `height` and 'width' keys are not found + KeyError: If required arguments `distribution`, `height` or `width` + are missing. Returns: - torch.Tensor: anomaly map + torch.Tensor: Generated anomaly map. """ if not ("distribution" in kwargs and "height" in kwargs and "width" in kwargs): msg = f"Expected keys `distribution`, `height` and `width`. Found {kwargs.keys()}" diff --git a/src/anomalib/models/image/cflow/lightning_model.py b/src/anomalib/models/image/cflow/lightning_model.py index 4dd9c25850..34f7937c61 100644 --- a/src/anomalib/models/image/cflow/lightning_model.py +++ b/src/anomalib/models/image/cflow/lightning_model.py @@ -1,9 +1,16 @@ -"""Cflow. +"""CFLOW - Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows. -Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows. +This module implements the CFLOW model for anomaly detection. CFLOW uses conditional +normalizing flows to model the distribution of normal data and detect anomalies in +real-time. -For more details, see the paper: `Real-Time Unsupervised Anomaly Detection via -Conditional Normalizing Flows `_. +The model consists of: + - A CNN backbone encoder to extract features + - Multiple decoders using normalizing flows to model feature distributions + - Positional encoding to capture spatial information + +Paper: `Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows +`_ """ # Copyright (C) 2022-2024 Intel Corporation @@ -34,29 +41,40 @@ class Cflow(AnomalibModule): - """PL Lightning Module for the CFLOW algorithm. + """PyTorch Lightning implementation of the CFLOW model. + + The model uses a pre-trained CNN backbone to extract features, followed by + conditional normalizing flow decoders to model the distribution of normal data. Args: - backbone (str, optional): Backbone CNN architecture. + backbone (str, optional): Name of the backbone CNN network. Defaults to ``"wide_resnet50_2"``. - layers (Sequence[str], optional): Layers to extract features from. - Defaults to ``("layer2", "layer3", "layer4")``. - pre_trained (bool, optional): Whether to use pre-trained weights. - Defaults to ``True``. 
- fiber_batch_size (int, optional): Fiber batch size. - Defaults to ``64``. - decoder (str, optional): Decoder architecture. + layers (Sequence[str], optional): List of layer names to extract features + from. Defaults to ``("layer2", "layer3", "layer4")``. + pre_trained (bool, optional): If True, use pre-trained weights for the + backbone. Defaults to ``True``. + fiber_batch_size (int, optional): Batch size for processing individual + fibers. Defaults to ``64``. + decoder (str, optional): Type of normalizing flow decoder to use. Defaults to ``"freia-cflow"``. - condition_vector (int, optional): Condition vector size. + condition_vector (int, optional): Dimension of the condition vector. Defaults to ``128``. - coupling_blocks (int, optional): Number of coupling blocks. + coupling_blocks (int, optional): Number of coupling blocks in the flow. Defaults to ``8``. - clamp_alpha (float, optional): Clamping value for the alpha parameter. - Defaults to ``1.9``. - permute_soft (bool, optional): Whether to use soft permutation. + clamp_alpha (float, optional): Clamping value for the alpha parameter in + flows. Defaults to ``1.9``. + permute_soft (bool, optional): If True, use soft permutation in flows. Defaults to ``False``. - lr (float, optional): Learning rate. + lr (float, optional): Learning rate for the optimizer. Defaults to ``0.0001``. + pre_processor (PreProcessor | bool, optional): Pre-processing module. + Defaults to ``True``. + post_processor (PostProcessor | bool, optional): Post-processing module. + Defaults to ``True``. + evaluator (Evaluator | bool, optional): Evaluation module. + Defaults to ``True``. + visualizer (Visualizer | bool, optional): Visualization module. + Defaults to ``True``. """ def __init__( @@ -95,15 +113,18 @@ def __init__( permute_soft=permute_soft, ) self.automatic_optimization = False - # TODO(ashwinvaidya17): LR should be part of optimizer in config.yaml since cflow has custom optimizer. - # CVS-122670 + # TODO(ashwinvaidya17): LR should be part of optimizer in config.yaml since # noqa: TD003 + # cflow has custom optimizer. CVS-122670 self.learning_rate = lr def configure_optimizers(self) -> Optimizer: """Configure optimizers for each decoder. + Creates an Adam optimizer for all decoder parameters with the specified + learning rate. + Returns: - Optimizer: Adam optimizer for each decoder + Optimizer: Adam optimizer instance configured for the decoders. """ decoders_parameters = [] for decoder_idx in range(len(self.model.pool_layers)): @@ -115,20 +136,24 @@ def configure_optimizers(self) -> Optimizer: ) def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: - """Perform the training step of CFLOW. + """Perform a training step of the CFLOW model. - For each batch, decoder layers are trained with a dynamic fiber batch size. - Training step is performed manually as multiple training steps are involved - per batch of input images + The training process involves: + 1. Extract features using the encoder + 2. Process features in fiber batches + 3. Apply positional encoding + 4. Train decoders using normalizing flows Args: - batch (Batch): Input batch - *args: Arguments. - **kwargs: Keyword arguments. 
+ batch (Batch): Input batch containing images + *args: Additional arguments (unused) + **kwargs: Additional keyword arguments (unused) Returns: - Loss value for the batch + STEP_OUTPUT: Dictionary containing the average loss for the batch + Raises: + ValueError: If the fiber batch size is too large for the input size """ del args, kwargs # These variables are not used. @@ -190,21 +215,20 @@ def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: return {"loss": avg_loss} def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: - """Perform the validation step of CFLOW. + """Perform a validation step of the CFLOW model. - Similar to the training step, encoder features - are extracted from the CNN for each batch, and anomaly - map is computed. + The validation process: + 1. Extracts features using the encoder + 2. Computes anomaly maps using the trained decoders + 3. Updates the batch with predictions Args: - batch (Batch): Input batch - *args: Arguments. - **kwargs: Keyword arguments. + batch (Batch): Input batch containing images + *args: Additional arguments (unused) + **kwargs: Additional keyword arguments (unused) Returns: - Dictionary containing images, anomaly maps, true labels and masks. - These are required in `validation_epoch_end` for feature concatenation. - + STEP_OUTPUT: Batch updated with model predictions """ del args, kwargs # These variables are not used. @@ -213,14 +237,20 @@ def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: @property def trainer_arguments(self) -> dict[str, Any]: - """C-FLOW specific trainer arguments.""" + """Get CFLOW-specific trainer arguments. + + Returns: + dict[str, Any]: Dictionary containing trainer arguments: + - gradient_clip_val: 0 + - num_sanity_val_steps: 0 + """ return {"gradient_clip_val": 0, "num_sanity_val_steps": 0} @property def learning_type(self) -> LearningType: - """Return the learning type of the model. + """Get the learning type of the model. Returns: - LearningType: Learning type of the model. + LearningType: ONE_CLASS learning type """ return LearningType.ONE_CLASS diff --git a/src/anomalib/models/image/cflow/torch_model.py b/src/anomalib/models/image/cflow/torch_model.py index dcfdcfa7fc..f83639d6b4 100644 --- a/src/anomalib/models/image/cflow/torch_model.py +++ b/src/anomalib/models/image/cflow/torch_model.py @@ -1,4 +1,32 @@ -"""PyTorch model for CFlow model implementation.""" +"""PyTorch model for the CFLOW anomaly detection model. + +This module provides the PyTorch implementation of the CFLOW model for anomaly +detection. The model uses conditional normalizing flows to model the distribution +of normal data in the feature space. + +The model consists of: + - A CNN backbone encoder to extract features + - Multiple decoders using normalizing flows to model feature distributions + - Positional encoding to capture spatial information + +Example: + >>> import torch + >>> from anomalib.models.image.cflow.torch_model import CflowModel + >>> # Initialize the model + >>> model = CflowModel( + ... backbone="resnet18", + ... layers=["layer1", "layer2", "layer3"], + ... fiber_batch_size=64, + ... decoder="freia-cflow", + ... condition_vector=128, + ... coupling_blocks=8, + ... clamp_alpha=1.9, + ... permute_soft=False + ... 
) + >>> # Forward pass + >>> x = torch.randn(32, 3, 256, 256) + >>> predictions = model(x) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -20,22 +48,31 @@ class CflowModel(nn.Module): """CFLOW: Conditional Normalizing Flows. Args: - backbone (str): Backbone CNN architecture. - layers (Sequence[str]): Layers to extract features from. - pre_trained (bool): Whether to use pre-trained weights. - Defaults to ``True``. - fiber_batch_size (int): Fiber batch size. - Defaults to ``64``. - decoder (str): Decoder architecture. + backbone (str): Name of the backbone CNN network to use as feature + extractor. + layers (Sequence[str]): Names of layers from which to extract features. + pre_trained (bool, optional): Whether to use pre-trained weights for the + backbone. Defaults to ``True``. + fiber_batch_size (int, optional): Batch size for processing feature + fibers. Defaults to ``64``. + decoder (str, optional): Type of decoder architecture to use. Defaults to ``"freia-cflow"``. - condition_vector (int): Condition vector size. - Defaults to ``128``. - coupling_blocks (int): Number of coupling blocks. - Defaults to ``8``. - clamp_alpha (float): Clamping value for the alpha parameter. - Defaults to ``1.9``. - permute_soft (bool): Whether to use soft permutation. - Defaults to ``False``. + condition_vector (int, optional): Size of the condition vector for the + normalizing flows. Defaults to ``128``. + coupling_blocks (int, optional): Number of coupling blocks in the + normalizing flows. Defaults to ``8``. + clamp_alpha (float, optional): Clamping value for the alpha parameter in + the flows. Defaults to ``1.9``. + permute_soft (bool, optional): Whether to use soft permutation in the + flows. Defaults to ``False``. + + Example: + >>> model = CflowModel( + ... backbone="resnet18", + ... layers=["layer1", "layer2", "layer3"] + ... ) + >>> x = torch.randn(32, 3, 256, 256) + >>> predictions = model(x) """ def __init__( @@ -84,14 +121,25 @@ def __init__( self.anomaly_map_generator = AnomalyMapGenerator(pool_layers=self.pool_layers) def forward(self, images: torch.Tensor) -> InferenceBatch: - """Forward-pass images into the network to extract encoder features and compute probability. + """Forward pass through the model. + + The method extracts features using the encoder, processes them through + normalizing flows, and generates anomaly predictions. Args: - images: Batch of images. + images (torch.Tensor): Input images of shape + ``(batch_size, channels, height, width)``. Returns: - Predicted anomaly maps. - + InferenceBatch: Batch containing predicted anomaly scores and maps. + The anomaly maps have shape ``(batch_size, 1, height, width)``. + + Example: + >>> x = torch.randn(32, 3, 256, 256) + >>> model = CflowModel(backbone="resnet18", layers=["layer1"]) + >>> predictions = model(x) + >>> predictions.anomaly_map.shape + torch.Size([32, 1, 256, 256]) """ self.encoder.eval() self.decoders.eval() diff --git a/src/anomalib/models/image/cflow/utils.py b/src/anomalib/models/image/cflow/utils.py index 636bfed1c9..a8e653d1a8 100644 --- a/src/anomalib/models/image/cflow/utils.py +++ b/src/anomalib/models/image/cflow/utils.py @@ -1,4 +1,12 @@ -"""Helper functions for CFlow implementation.""" +"""Helper functions for CFlow implementation. 
+ +This module provides utility functions used by the CFlow model implementation, +including: + +- Log likelihood estimation +- 2D positional encoding generation +- Subnet and decoder network creation +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -17,33 +25,50 @@ def get_logp(dim_feature_vector: int, p_u: torch.Tensor, logdet_j: torch.Tensor) -> torch.Tensor: - """Return the log likelihood estimation. + """Calculate the log likelihood estimation. Args: - dim_feature_vector (int): Dimensions of the condition vector - p_u (torch.Tensor): Random variable u - logdet_j (torch.Tensor): log of determinant of jacobian returned from the invertable decoder + dim_feature_vector (int): Dimension of the feature vector + p_u (torch.Tensor): Random variable ``u`` sampled from the base distribution + logdet_j (torch.Tensor): Log determinant of the Jacobian returned from the + invertible decoder Returns: - Tensor: Log probability + torch.Tensor: Log probability estimation + + Example: + >>> dim = 128 + >>> p_u = torch.randn(32, dim) + >>> logdet_j = torch.zeros(32) + >>> logp = get_logp(dim, p_u, logdet_j) """ ln_sqrt_2pi = -np.log(np.sqrt(2 * np.pi)) # ln(sqrt(2*pi)) return dim_feature_vector * ln_sqrt_2pi - 0.5 * torch.sum(p_u**2, 1) + logdet_j def positional_encoding_2d(condition_vector: int, height: int, width: int) -> torch.Tensor: - """Create embedding to store relative position of the feature vector using sine and cosine functions. + """Create 2D positional encoding using sine and cosine functions. + + Creates an embedding to store relative position of feature vectors using + sinusoidal functions at different frequencies. Args: - condition_vector (int): Length of the condition vector - height (int): H of the positions - width (int): W of the positions + condition_vector (int): Length of the condition vector (must be multiple + of 4) + height (int): Height of the positions grid + width (int): Width of the positions grid Raises: - ValueError: Cannot generate encoding with conditional vector length not as multiple of 4 + ValueError: If ``condition_vector`` is not a multiple of 4 Returns: - Tensor: condition_vector x HEIGHT x WIDTH position matrix + torch.Tensor: Position encoding of shape + ``(condition_vector, height, width)`` + + Example: + >>> encoding = positional_encoding_2d(128, 32, 32) + >>> encoding.shape + torch.Size([128, 32, 32]) """ if condition_vector % 4 != 0: msg = f"Cannot use sin/cos positional encoding with odd dimension (got dim={condition_vector})" @@ -70,14 +95,21 @@ def positional_encoding_2d(condition_vector: int, height: int, width: int) -> to def subnet_fc(dims_in: int, dims_out: int) -> nn.Sequential: - """Subnetwork which predicts the affine coefficients. + """Create a feed-forward subnetwork that predicts affine coefficients. Args: - dims_in (int): input dimensions - dims_out (int): output dimensions + dims_in (int): Input dimensions + dims_out (int): Output dimensions Returns: - nn.Sequential: Feed-forward subnetwork + nn.Sequential: Feed-forward subnetwork with ReLU activation + + Example: + >>> net = subnet_fc(64, 128) + >>> x = torch.randn(32, 64) + >>> out = net(x) + >>> out.shape + torch.Size([32, 128]) """ return nn.Sequential(nn.Linear(dims_in, 2 * dims_in), nn.ReLU(), nn.Linear(2 * dims_in, dims_out)) @@ -89,19 +121,28 @@ def cflow_head( n_features: int, permute_soft: bool = False, ) -> SequenceINN: - """Create invertible decoder network. + """Create an invertible decoder network for CFlow. 
     Args:
-        condition_vector (int): length of the condition vector
-        coupling_blocks (int): number of coupling blocks to build the decoder
-        clamp_alpha (float): clamping value to avoid exploding values
-        n_features (int): number of decoder features
-        permute_soft (bool): Whether to sample the permutation matrix :math:`R` from :math:`SO(N)`,
-            or to use hard permutations instead. Note, ``permute_soft=True`` is very slow
-            when working with >512 dimensions.
+        condition_vector (int): Length of the condition vector
+        coupling_blocks (int): Number of coupling blocks in the decoder
+        clamp_alpha (float): Clamping value to avoid exploding values
+        n_features (int): Number of decoder features
+        permute_soft (bool, optional): Whether to sample the permutation matrix
+            from SO(N) (True) or use hard permutations (False). Note that
+            ``permute_soft=True`` is very slow for >512 dimensions.
+            Defaults to False.

     Returns:
-        SequenceINN: decoder network block
+        SequenceINN: Invertible decoder network
+
+    Example:
+        >>> decoder = cflow_head(
+        ...     condition_vector=128,
+        ...     coupling_blocks=4,
+        ...     clamp_alpha=1.9,
+        ...     n_features=256
+        ... )
     """
     coder = SequenceINN(n_features)
     logger.info("CNF coder: %d", n_features)
diff --git a/src/anomalib/models/image/csflow/__init__.py b/src/anomalib/models/image/csflow/__init__.py
index f53d606823..3f516195c8 100644
--- a/src/anomalib/models/image/csflow/__init__.py
+++ b/src/anomalib/models/image/csflow/__init__.py
@@ -1,4 +1,25 @@
-"""Fully Convolutional Cross-Scale-Flows for Image-based Defect Detection."""
+"""Implementation of the CS-Flow model for anomaly detection.
+
+The CS-Flow model, short for Cross-Scale-Flows, is a fully convolutional approach
+for image-based defect detection. It leverages normalizing flows across multiple
+scales of the input image to model the distribution of normal (non-defective)
+samples.
+
+The model architecture consists of:
+    - A feature extraction backbone
+    - Multiple normalizing flow blocks operating at different scales
+    - Cross-scale connections to capture multi-scale dependencies
+
+Example:
+    >>> from anomalib.models.image.csflow import Csflow
+    >>> model = Csflow()
+
+Reference:
+    Rudolph, Marco, et al. "Fully Convolutional Cross-Scale-Flows for
+    Image-based Defect Detection." Proceedings of the IEEE/CVF Winter
+    Conference on Applications of Computer Vision. 2022.
+"""

 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
diff --git a/src/anomalib/models/image/csflow/anomaly_map.py b/src/anomalib/models/image/csflow/anomaly_map.py
index 8a80f3cdfb..800ee21bb3 100644
--- a/src/anomalib/models/image/csflow/anomaly_map.py
+++ b/src/anomalib/models/image/csflow/anomaly_map.py
@@ -1,4 +1,22 @@
-"""Anomaly Map Generator for CS-Flow model."""
+"""Anomaly Map Generator for CS-Flow model.
+
+This module provides functionality to generate anomaly maps from the CS-Flow model's
+outputs. The generator can operate in two modes:
+
+1. ``ALL`` - Combines anomaly scores from all scales (default)
+2. ``MAX`` - Uses only the largest scale as mentioned in the paper
+
+The anomaly maps are generated by computing the mean of squared z-scores across
+channels and upsampling to the input dimensions.
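
The mean-of-squared-z-scores recipe just described can be written out directly. A sketch for the `ALL` mode (product over upsampled per-scale scores); the shapes are illustrative:

```python
import torch
import torch.nn.functional as F

input_hw = (256, 256)
z_dist = [torch.randn(2, 64, s, s) for s in (32, 16, 8)]  # z-scores per scale
anomaly_map = torch.ones(2, 1, *input_hw)
for z in z_dist:
    score = z.pow(2).mean(dim=1, keepdim=True)  # mean of squared z over channels
    anomaly_map *= F.interpolate(score, size=input_hw, mode="bilinear", align_corners=False)
print(anomaly_map.shape)  # torch.Size([2, 1, 256, 256])
```
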
+
+Example:
+    >>> import torch
+    >>> generator = AnomalyMapGenerator(input_dims=(3, 256, 256))
+    >>> z_dist = [torch.randn(2, 64, 32, 32) for _ in range(3)]
+    >>> anomaly_map = generator(z_dist)
+    >>> anomaly_map.shape
+    torch.Size([2, 1, 256, 256])
+"""

 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

@@ -11,19 +29,31 @@


 class AnomalyMapMode(str, Enum):
-    """Generate anomaly map from all the scales or the max."""
+    """Mode for generating anomaly maps.
+
+    The mode determines how the anomaly scores from different scales are combined:
+
+    - ``ALL``: Combines scores from all scales by multiplication
+    - ``MAX``: Uses only the score from the largest scale
+    """

     ALL = "all"
     MAX = "max"


 class AnomalyMapGenerator(nn.Module):
-    """Anomaly Map Generator for CS-Flow model.
+    """Generate anomaly maps from CS-Flow model outputs.

     Args:
-        input_dims (tuple[int, int, int]): Input dimensions.
-        mode (AnomalyMapMode): Anomaly map mode.
+        input_dims (tuple[int, int, int]): Input dimensions in the format
+            ``(channels, height, width)``.
+        mode (AnomalyMapMode, optional): Mode for generating anomaly maps.
+            Defaults to ``AnomalyMapMode.ALL``.
+
+    Example:
+        >>> generator = AnomalyMapGenerator((3, 256, 256))
+        >>> z_dist = [torch.randn(1, 64, 32, 32) for _ in range(3)]
+        >>> anomaly_map = generator(z_dist)
     """

     def __init__(self, input_dims: tuple[int, int, int], mode: AnomalyMapMode = AnomalyMapMode.ALL) -> None:
@@ -32,17 +62,23 @@ def __init__(self, input_dims: tuple[int, int, int], mode: AnomalyMapMode = Anom
         self.input_dims = input_dims

     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        """Get anomaly maps by taking mean of the z-distributions across channels.
-
-        By default it computes anomaly maps for all the scales as it gave better performance on initial tests.
-        Use ``AnomalyMapMode.MAX`` for the largest scale as mentioned in the paper.
+        """Generate anomaly maps from z-distributions.

         Args:
-            inputs (torch.Tensor): z-distributions for the three scales.
-            mode (AnomalyMapMode): Anomaly map mode.
+            inputs (list[torch.Tensor]): List of z-distributions from different
+                scales, where each element has shape ``(batch_size, channels,
+                height, width)``.

         Returns:
-            Tensor: Anomaly maps.
+            torch.Tensor: Anomaly maps with shape ``(batch_size, 1, height,
+                width)``, where height and width match the input dimensions.
+
+        Example:
+            >>> z_dist = [torch.randn(2, 64, 32, 32) for _ in range(3)]
+            >>> generator = AnomalyMapGenerator((3, 256, 256))
+            >>> maps = generator(z_dist)
+            >>> maps.shape
+            torch.Size([2, 1, 256, 256])
         """
         anomaly_map: torch.Tensor
         if self.mode == AnomalyMapMode.ALL:
diff --git a/src/anomalib/models/image/csflow/lightning_model.py b/src/anomalib/models/image/csflow/lightning_model.py
index 8e9994631a..c1e7f47951 100644
--- a/src/anomalib/models/image/csflow/lightning_model.py
+++ b/src/anomalib/models/image/csflow/lightning_model.py
@@ -1,6 +1,10 @@
 """Fully Convolutional Cross-Scale-Flows for Image-based Defect Detection.

-https://arxiv.org/pdf/2110.02855.pdf
+Paper: https://arxiv.org/pdf/2110.02855.pdf
+
+This module provides the CS-Flow model implementation for anomaly detection.
+CS-Flow uses normalizing flows across multiple scales to model the distribution
+of normal images and detect anomalies.
 """

 # Copyright (C) 2022-2024 Intel Corporation
@@ -29,17 +33,41 @@


 class Csflow(AnomalibModule):
-    """Fully Convolutional Cross-Scale-Flows for Image-based Defect Detection.
+    """CS-Flow Lightning Model for anomaly detection.
+ + CS-Flow uses normalizing flows across multiple scales to model the distribution + of normal images. During inference, it assigns anomaly scores based on the + likelihood of test samples under the learned distribution. Args: - n_coupling_blocks (int): Number of coupling blocks in the model. + n_coupling_blocks (int, optional): Number of coupling blocks in the model. Defaults to ``4``. - cross_conv_hidden_channels (int): Number of hidden channels in the cross convolution. - Defaults to ``1024``. - clamp (int): Clamp value for glow layer. - Defaults to ``3``. - num_channels (int): Number of channels in the model. - Defaults to ``3``. + cross_conv_hidden_channels (int, optional): Number of hidden channels in + the cross convolution layer. Defaults to ``1024``. + clamp (int, optional): Clamping value for the affine coupling layers in + the Glow model. Defaults to ``3``. + num_channels (int, optional): Number of input image channels. + Defaults to ``3`` for RGB images. + pre_processor (PreProcessor | bool, optional): Preprocessing module or + flag to enable default preprocessing. Defaults to ``True``. + post_processor (PostProcessor | bool, optional): Post-processing module or + flag to enable default post-processing. Defaults to ``True``. + evaluator (Evaluator | bool, optional): Evaluation module or flag to + enable default evaluation. Defaults to ``True``. + visualizer (Visualizer | bool, optional): Visualization module or flag to + enable default visualization. Defaults to ``True``. + + Raises: + ValueError: If ``input_size`` is not provided during initialization. + + Example: + >>> from anomalib.models.image.csflow import Csflow + >>> model = Csflow( + ... n_coupling_blocks=4, + ... cross_conv_hidden_channels=1024, + ... clamp=3, + ... num_channels=3 + ... ) """ def __init__( @@ -79,15 +107,22 @@ def __init__( self.loss = CsFlowLoss() def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: - """Perform the training step of CS-Flow. + """Perform a training step of CS-Flow model. Args: - batch (Batch): Input batch - args: Arguments. - kwargs: Keyword arguments. + batch (Batch): Input batch containing images and targets + *args: Additional positional arguments (unused) + **kwargs: Additional keyword arguments (unused) Returns: - Loss value + STEP_OUTPUT: Dictionary containing the loss value + + Example: + >>> batch = Batch(image=torch.randn(32, 3, 256, 256)) + >>> model = Csflow() + >>> output = model.training_step(batch) + >>> output["loss"] + tensor(...) """ del args, kwargs # These variables are not used. @@ -97,15 +132,21 @@ def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: return {"loss": loss} def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: - """Perform the validation step for CS Flow. + """Perform a validation step of CS-Flow model. Args: - batch (Batch): Input batch - args: Arguments. - kwargs: Keyword arguments. + batch (Batch): Input batch containing images and targets + *args: Additional positional arguments (unused) + **kwargs: Additional keyword arguments (unused) Returns: - dict[str, torch.Tensor]: Dictionary containing the anomaly map, scores, etc. + STEP_OUTPUT: Dictionary containing predictions including anomaly maps + and scores + + Example: + >>> batch = Batch(image=torch.randn(32, 3, 256, 256)) + >>> model = Csflow() + >>> predictions = model.validation_step(batch) """ del args, kwargs # These variables are not used. 
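The training and validation hooks above are not usually called by hand; anomalib's trainer wrapper drives them. A minimal end-to-end sketch, assuming the standard ``anomalib.engine.Engine`` and ``anomalib.data.MVTec`` APIs (the dataset root and hyperparameters are illustrative, not values from this PR):

    from anomalib.data import MVTec
    from anomalib.engine import Engine
    from anomalib.models import Csflow

    # Illustrative dataset location; point this at a local MVTec AD copy.
    datamodule = MVTec(root="./datasets/MVTec", category="bottle")
    model = Csflow(n_coupling_blocks=4, clamp=3)

    # Engine.fit() invokes the training_step/validation_step defined above.
    engine = Engine(max_epochs=1)
    engine.fit(model=model, datamodule=datamodule)
    engine.test(model=model, datamodule=datamodule)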
@@ -114,14 +155,26 @@ def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: @property def trainer_arguments(self) -> dict[str, Any]: - """CS-Flow-specific trainer arguments.""" + """Get CS-Flow-specific trainer arguments. + + Returns: + dict[str, Any]: Dictionary containing trainer arguments: + - gradient_clip_val: Maximum gradient norm for clipping + - num_sanity_val_steps: Number of validation steps to run before + training + """ return {"gradient_clip_val": 1, "num_sanity_val_steps": 0} def configure_optimizers(self) -> torch.optim.Optimizer: - """Configure optimizers. + """Configure the Adam optimizer for CS-Flow. Returns: - Optimizer: Adam optimizer + torch.optim.Optimizer: Configured Adam optimizer with specific + hyperparameters + + Example: + >>> model = Csflow() + >>> optimizer = model.configure_optimizers() """ return torch.optim.Adam( self.parameters(), @@ -133,9 +186,9 @@ def configure_optimizers(self) -> torch.optim.Optimizer: @property def learning_type(self) -> LearningType: - """Return the learning type of the model. + """Get the learning type of the model. Returns: - LearningType: Learning type of the model. + LearningType: The learning type, which is ONE_CLASS for CS-Flow """ return LearningType.ONE_CLASS diff --git a/src/anomalib/models/image/csflow/loss.py b/src/anomalib/models/image/csflow/loss.py index 2e5d1da8ff..a5156567f1 100644 --- a/src/anomalib/models/image/csflow/loss.py +++ b/src/anomalib/models/image/csflow/loss.py @@ -1,4 +1,18 @@ -"""Loss function for the CS-Flow Model Implementation.""" +"""Loss function for the CS-Flow Model Implementation. + +This module implements the loss function used in the CS-Flow model for anomaly +detection. The loss combines the squared L2 norm of the latent space +representations with the log-determinant of the Jacobian from the normalizing +flows. + +Example: + >>> import torch + >>> from anomalib.models.image.csflow.loss import CsFlowLoss + >>> criterion = CsFlowLoss() + >>> z_dist = [torch.randn(2, 64, 32, 32) for _ in range(3)] + >>> jacobians = torch.randn(2) + >>> loss = criterion(z_dist, jacobians) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -8,18 +22,31 @@ class CsFlowLoss(nn.Module): - """Loss function for the CS-Flow Model Implementation.""" + """Loss function for the CS-Flow model. + + The loss is computed as the mean of the squared L2 norm of the latent space + representations minus the log-determinant of the Jacobian, normalized by the + dimensionality of the latent space. + """ @staticmethod - def forward(z_dist: torch.Tensor, jacobians: torch.Tensor) -> torch.Tensor: - """Compute the loss CS-Flow. + def forward(z_dist: list[torch.Tensor], jacobians: torch.Tensor) -> torch.Tensor: + """Compute the CS-Flow loss. Args: - z_dist (torch.Tensor): Latent space image mappings from NF. - jacobians (torch.Tensor): Jacobians of the distribution + z_dist (list[torch.Tensor]): List of latent space tensors from each + scale of the normalizing flow. Each tensor has shape + ``(batch_size, channels, height, width)``. + jacobians (torch.Tensor): Log-determinant of the Jacobian matrices + from the normalizing flows. Shape: ``(batch_size,)``. Returns: - Loss value + torch.Tensor: Scalar loss value averaged over the batch. 
+
+        Example:
+            >>> criterion = CsFlowLoss()
+            >>> z_dist = [torch.randn(2, 64, 32, 32) for _ in range(3)]
+            >>> jacobians = torch.randn(2)
+            >>> loss = criterion(z_dist, jacobians)
         """
-        z_dist = torch.cat([z_dist[i].reshape(z_dist[i].shape[0], -1) for i in range(len(z_dist))], dim=1)
-        return torch.mean(0.5 * torch.sum(z_dist**2, dim=(1,)) - jacobians) / z_dist.shape[1]
+        concatenated = torch.cat([z_dist[i].reshape(z_dist[i].shape[0], -1) for i in range(len(z_dist))], dim=1)
+        return torch.mean(0.5 * torch.sum(concatenated**2, dim=(1,)) - jacobians) / concatenated.shape[1]
diff --git a/src/anomalib/models/image/csflow/torch_model.py b/src/anomalib/models/image/csflow/torch_model.py
index a4703d9b4c..fd067450e3 100644
--- a/src/anomalib/models/image/csflow/torch_model.py
+++ b/src/anomalib/models/image/csflow/torch_model.py
@@ -1,5 +1,14 @@
-"""PyTorch model for CS-Flow implementation."""
+"""PyTorch model for CS-Flow implementation.

+This module contains the PyTorch implementation of CS-Flow model for anomaly detection.
+The model uses cross-scale coupling layers to learn the distribution of normal images
+and detect anomalies based on the likelihood of test images under this distribution.
+
+The implementation is based on the paper:
+    Fully Convolutional Cross-Scale-Flows for Image-based Defect Detection
+    Marco Rudolph, Tom Wehrbein, Bodo Rosenhahn, Bastian Wandt
+    https://arxiv.org/abs/2110.02855
+"""

 # Original Code
 # Copyright (c) 2021 marco-rudolph
@@ -27,21 +36,34 @@


 class CrossConvolutions(nn.Module):
-    """Cross convolution for the three scales.
+    """Cross convolution module for processing features at three scales.
+
+    This module applies convolutions across three different scales of features,
+    with connections between scales via up/downsampling operations.

     Args:
         in_channels (int): Number of input channels.
-        channels (int): Number of output channels in the hidden convolution and the upscaling layers.
-        channels_hidden (int, optional): Number of input channels in the hidden convolution layers.
+        channels (int): Number of output channels in convolution layers.
+        channels_hidden (int, optional): Number of channels in hidden layers.
             Defaults to ``512``.
-        kernel_size (int, optional): Kernel size of the convolution layers.
+        kernel_size (int, optional): Size of convolution kernels.
             Defaults to ``3``.
-        leaky_slope (float, optional): Slope of the leaky ReLU activation.
+        leaky_slope (float, optional): Negative slope for leaky ReLU.
             Defaults to ``0.1``.
         batch_norm (bool, optional): Whether to use batch normalization.
             Defaults to ``False``.
-        use_gamma (bool, optional): Whether to use gamma parameters for the cross convolutions.
+        use_gamma (bool, optional): Whether to use learnable gamma parameters.
             Defaults to ``True``.
+
+    Example:
+        >>> cross_conv = CrossConvolutions(64, 128)
+        >>> scale0 = torch.randn(1, 64, 32, 32)
+        >>> scale1 = torch.randn(1, 64, 16, 16)
+        >>> scale2 = torch.randn(1, 64, 8, 8)
+        >>> out0, out1, out2 = cross_conv(scale0, scale1, scale2)
+        >>> out0.shape, out1.shape, out2.shape
+        (torch.Size([1, 128, 32, 32]), torch.Size([1, 128, 16, 16]),
+        torch.Size([1, 128, 8, 8]))
     """

     def __init__(
@@ -161,14 +183,21 @@
         self.leaky_relu = nn.LeakyReLU(self.leaky_slope)

     def forward(self, scale0: int, scale1: int, scale2: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Apply the cross convolution to the three scales.
+        """Apply cross-scale convolutions to input features.

-        This block is represented in figure 4 of the paper.
+ Processes features at three scales with cross-connections between scales via + up/downsampling operations. This implements the architecture shown in Figure 4 + of the CS-Flow paper. + + Args: + scale0 (torch.Tensor): Features at original scale. + scale1 (torch.Tensor): Features at 1/2 scale. + scale2 (torch.Tensor): Features at 1/4 scale. Returns: - tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tensors indicating scale and transform parameters - as a single tensor for each scale. The scale parameters are the first part across channel dimension - and the transform parameters are the second. + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Processed features at three + scales. Each tensor contains scale and transform parameters concatenated + along the channel dimension. """ # Increase the number of channels to hidden channel length via convolutions and apply leaky ReLU. out0 = self.conv_scale0_0(scale0) @@ -206,11 +235,23 @@ def forward(self, scale0: int, scale1: int, scale2: int) -> tuple[torch.Tensor, class ParallelPermute(InvertibleModule): - """Permutes input vector in a random but fixed way. + """Permutes input vectors in a random but fixed way. + + This module applies a fixed random permutation to the channels of each input + tensor. The permutation is deterministic for a given seed. Args: - dim (list[tuple[int]]): Dimension of the input vector. - seed (float | None=None): Seed for the random permutation. + dims_in (list[tuple[int]]): List of input tensor dimensions. + seed (int | None, optional): Random seed for permutation. + Defaults to ``None``. + + Example: + >>> permute = ParallelPermute([(3, 32, 32), (3, 16, 16)], seed=42) + >>> x1 = torch.randn(1, 3, 32, 32) + >>> x2 = torch.randn(1, 3, 16, 16) + >>> y1, y2 = permute([x1, x2])[0] + >>> y1.shape, y2.shape + (torch.Size([1, 3, 32, 32]), torch.Size([1, 3, 16, 16])) """ def __init__(self, dims_in: list[tuple[int]], seed: int | None = None) -> None: @@ -229,13 +270,13 @@ def __init__(self, dims_in: list[tuple[int]], seed: int | None = None) -> None: self.perm_inv.append(perm_inv) def get_random_perm(self, index: int) -> tuple[torch.Tensor, torch.Tensor]: - """Return a random permutation of the channels for each input. + """Generate random permutation and its inverse for given input index. Args: - index (int): index of the input + index (int): Index of input tensor. Returns: - tuple[torch.Tensor, torch.Tensor]: permutation and inverse permutation + tuple[torch.Tensor, torch.Tensor]: Permutation and inverse permutation tensors. """ perm = np.random.default_rng(self.seed).permutation(self.in_channels[index]) perm_inv = np.zeros_like(perm) @@ -253,17 +294,17 @@ def forward( rev: bool = False, jac: bool = True, ) -> tuple[list[torch.Tensor], float]: - """Apply the permutation to the input. + """Apply permutation or inverse permutation to inputs. Args: - input_tensor: list of input tensors - rev: if True, applies the inverse permutation + input_tensor (list[torch.Tensor]): List of input tensors. + rev (bool, optional): If ``True``, applies inverse permutation. Defaults to ``False``. - jac: (unused) if True, computes the log determinant of the Jacobian + jac (bool, optional): Unused. Required for interface compatibility. Defaults to ``True``. Returns: - tuple[torch.Tensor, torch.Tensor]: output tensor and log determinant of the Jacobian + tuple[list[torch.Tensor], float]: Permuted tensors and log determinant (0). """ del jac # Unused argument. 
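``ParallelPermute`` is a bijection: calling it again with ``rev=True`` applies the inverse permutation, so a round trip recovers the input exactly and contributes zero to the log-Jacobian. A small sketch against the interface documented above (shapes are illustrative):

    import torch

    # Two input branches with different spatial sizes, fixed seed.
    perm = ParallelPermute([(6, 32, 32), (6, 16, 16)], seed=42)
    x = [torch.randn(1, 6, 32, 32), torch.randn(1, 6, 16, 16)]

    y, logdet = perm(x)            # fixed random channel permutation; logdet == 0
    x_back, _ = perm(y, rev=True)  # inverse permutation

    assert all(torch.equal(a, b) for a, b in zip(x, x_back))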
@@ -274,18 +315,39 @@ def forward( @staticmethod def output_dims(input_dims: list[tuple[int]]) -> list[tuple[int]]: - """Return the output dimensions of the module.""" + """Return output dimensions of the module. + + Args: + input_dims (list[tuple[int]]): List of input dimensions. + + Returns: + list[tuple[int]]: List of output dimensions (same as input). + """ return input_dims class ParallelGlowCouplingLayer(InvertibleModule): - """Coupling block that follows the GLOW design but is applied to all the scales in parallel. + """Coupling block following GLOW design applied to multiple scales in parallel. + + This module implements an invertible coupling layer that processes multiple scales + simultaneously, following the GLOW architecture design. Args: - dims_in (list[tuple[int]]): list of dimensions of the input tensors - subnet_args (dict): arguments of the subnet - clamp (float): clamp value for the output of the subnet + dims_in (list[tuple[int]]): List of input tensor dimensions. + subnet_args (dict): Arguments for subnet construction. + clamp (float, optional): Clamp value for outputs. Defaults to ``5.0``. + + Example: + >>> coupling = ParallelGlowCouplingLayer( + ... [(6, 32, 32), (6, 16, 16)], + ... {"channels_hidden": 64} + ... ) + >>> x1 = torch.randn(1, 6, 32, 32) + >>> x2 = torch.randn(1, 6, 16, 16) + >>> y1, y2 = coupling([x1, x2])[0] + >>> y1.shape, y2.shape + (torch.Size([1, 6, 32, 32]), torch.Size([1, 6, 16, 16])) """ def __init__(self, dims_in: list[tuple[int]], subnet_args: dict, clamp: float = 5.0) -> None: @@ -305,13 +367,27 @@ def __init__(self, dims_in: list[tuple[int]], subnet_args: dict, clamp: float = self.cross_convolution2 = CrossConvolutions(self.split_len2, self.split_len1 * 2, **subnet_args) def exp(self, input_tensor: torch.Tensor) -> torch.Tensor: - """Exponentiates the input and, optionally, clamps it to avoid numerical issues.""" + """Exponentiates input with optional clamping. + + Args: + input_tensor (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Exponentiated tensor, optionally clamped. + """ if self.clamp > 0: return torch.exp(self.log_e(input_tensor)) return torch.exp(input_tensor) def log_e(self, input_tensor: torch.Tensor) -> torch.Tensor: - """Return log of input. And optionally clamped to avoid numerical issues.""" + """Compute log with optional clamping. + + Args: + input_tensor (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Log of input, optionally clamped. + """ if self.clamp > 0: return self.clamp * 0.636 * torch.atan(input_tensor / self.clamp) return input_tensor @@ -322,7 +398,19 @@ def forward( rev: bool = False, jac: bool = True, ) -> tuple[list[torch.Tensor], torch.Tensor]: - """Apply GLOW coupling for the three scales.""" + """Apply GLOW coupling transformation to inputs at multiple scales. + + Args: + input_tensor (list[torch.Tensor]): List of input tensors at different scales. + rev (bool, optional): If ``True``, applies inverse transformation. + Defaults to ``False``. + jac (bool, optional): Unused. Required for interface compatibility. + Defaults to ``True``. + + Returns: + tuple[list[torch.Tensor], torch.Tensor]: Transformed tensors and log + determinant of Jacobian. + """ del jac # Unused argument. # Even channel split. The two splits are used by cross-scale convolution to compute scale and transform @@ -406,18 +494,40 @@ def forward( @staticmethod def output_dims(input_dims: list[tuple[int]]) -> list[tuple[int]]: - """Output dimensions of the module.""" + """Return output dimensions of the module. 
+ + Args: + input_dims (list[tuple[int]]): List of input dimensions. + + Returns: + list[tuple[int]]: List of output dimensions (same as input). + """ return input_dims class CrossScaleFlow(nn.Module): """Cross scale coupling layer. + This module implements the cross-scale flow architecture that couples features + across multiple scales. + Args: - input_dims (tuple[int, int, int]): Input dimensions of the module. - n_coupling_blocks (int): Number of coupling blocks. - clamp (float): Clamp value for the inputs. - corss_conv_hidden_channels (int): Number of hidden channels in the cross convolution. + input_dims (tuple[int, int, int]): Input dimensions (C, H, W). + n_coupling_blocks (int): Number of coupling blocks to use. + clamp (float): Clamping value for coupling layers. + cross_conv_hidden_channels (int): Hidden channels in cross convolutions. + + Example: + >>> flow = CrossScaleFlow((3, 256, 256), 4, 3.0, 64) + >>> x = [ + ... torch.randn(1, 304, 8, 8), + ... torch.randn(1, 304, 4, 4), + ... torch.randn(1, 304, 2, 2) + ... ] + >>> z, jac = flow(x) + >>> [zi.shape for zi in z] + [torch.Size([1, 304, 8, 8]), torch.Size([1, 304, 4, 4]), + torch.Size([1, 304, 2, 2])] """ def __init__( @@ -436,6 +546,11 @@ def __init__( self.graph = self._create_graph() def _create_graph(self) -> GraphINN: + """Create the invertible neural network graph. + + Returns: + GraphINN: Constructed invertible neural network. + """ nodes: list[Node] = [] # 304 is the number of features extracted from EfficientNet-B5 feature extractor input_nodes = [ @@ -481,25 +596,35 @@ def _create_graph(self) -> GraphINN: return GraphINN(nodes) def forward(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """Forward pass. + """Forward pass through the flow model. Args: inputs (torch.Tensor): Input tensor. Returns: - tuple[torch.Tensor, torch.Tensor]: Output tensor and log determinant of Jacobian. + tuple[torch.Tensor, torch.Tensor]: Output tensor and log determinant + of Jacobian. """ return self.graph(inputs) class MultiScaleFeatureExtractor(nn.Module): - """Multi-scale feature extractor. + """Multi-scale feature extractor using EfficientNet-B5. - Uses 36th layer of EfficientNet-B5 to extract features. + This module extracts features at multiple scales using the 36th layer of + EfficientNet-B5. Args: - n_scales (int): Number of scales for input image. - input_size (tuple[int, int]): Size of input image. + n_scales (int): Number of scales to extract features at. + input_size (tuple[int, int]): Input image size (H, W). + + Example: + >>> extractor = MultiScaleFeatureExtractor(3, (256, 256)) + >>> x = torch.randn(1, 3, 256, 256) + >>> features = extractor(x) + >>> [f.shape for f in features] + [torch.Size([1, 304, 8, 8]), torch.Size([1, 304, 4, 4]), + torch.Size([1, 304, 2, 2])] """ def __init__(self, n_scales: int, input_size: tuple[int, int]) -> None: @@ -514,13 +639,13 @@ def __init__(self, n_scales: int, input_size: tuple[int, int]) -> None: ) def forward(self, input_tensor: torch.Tensor) -> list[torch.Tensor]: - """Extract features at three scales. + """Extract features at multiple scales. Args: input_tensor (torch.Tensor): Input images. Returns: - list[torch.Tensor]: List of tensors containing features at three scales. + list[torch.Tensor]: List of feature tensors at different scales. """ output = [] for scale in range(self.n_scales): @@ -539,17 +664,27 @@ def forward(self, input_tensor: torch.Tensor) -> list[torch.Tensor]: class CsFlowModel(nn.Module): - """CS Flow Module. 
+    """CS-Flow model for anomaly detection.
+
+    This module implements the complete CS-Flow model that learns the distribution
+    of normal images using cross-scale coupling layers.

     Args:
-        input_size (tuple[int, int]): Input image size.
-        cross_conv_hidden_channels (int): Number of hidden channels in the cross convolution.
-        n_coupling_blocks (int): Number of coupling blocks.
+        input_size (tuple[int, int]): Input image size (H, W).
+        cross_conv_hidden_channels (int): Hidden channels in cross convolutions.
+        n_coupling_blocks (int, optional): Number of coupling blocks.
             Defaults to ``4``.
-        clamp (float): Clamp value for the coupling blocks.
+        clamp (int, optional): Clamping value for coupling layers.
             Defaults to ``3``.
-        num_channels (int): Number of channels in the input image.
+        num_channels (int, optional): Number of input image channels.
             Defaults to ``3``.
+
+    Example:
+        >>> model = CsFlowModel((256, 256), 64)
+        >>> _ = model.eval()  # forward returns InferenceBatch in eval mode only
+        >>> x = torch.randn(1, 3, 256, 256)
+        >>> output = model(x)
+        >>> isinstance(output, InferenceBatch)
+        True
     """

     def __init__(
diff --git a/src/anomalib/models/image/dfkde/__init__.py b/src/anomalib/models/image/dfkde/__init__.py
index 9930fcea71..948b252887 100644
--- a/src/anomalib/models/image/dfkde/__init__.py
+++ b/src/anomalib/models/image/dfkde/__init__.py
@@ -1,4 +1,24 @@
-"""Deep Feature Kernel Density Estimation model."""
+"""Deep Feature Kernel Density Estimation (DFKDE) model for anomaly detection.
+
+The DFKDE model extracts deep features from images using a pre-trained CNN backbone
+and fits a kernel density estimation on these features to model the distribution
+of normal samples. During inference, samples with low likelihood under this
+distribution are flagged as anomalous.
+
+Example:
+    >>> from anomalib.models.image import Dfkde
+    >>> model = Dfkde()
+
+The model can be used with any of the supported datasets and task modes in
+anomalib.
+
+Notes:
+    The model implementation is available in the ``lightning_model`` module.
+
+See Also:
+    :class:`anomalib.models.image.dfkde.lightning_model.Dfkde`:
+        Lightning implementation of the DFKDE model.
+"""

 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
diff --git a/src/anomalib/models/image/dfkde/lightning_model.py b/src/anomalib/models/image/dfkde/lightning_model.py
index 666fb5507d..a437d9a244 100644
--- a/src/anomalib/models/image/dfkde/lightning_model.py
+++ b/src/anomalib/models/image/dfkde/lightning_model.py
@@ -1,4 +1,27 @@
-"""DFKDE: Deep Feature Kernel Density Estimation."""
+"""DFKDE: Deep Feature Kernel Density Estimation.
+
+This module provides a PyTorch Lightning implementation of the DFKDE model for
+anomaly detection. The model extracts deep features from images using a
+pre-trained CNN backbone and fits a kernel density estimation on these features
+to model the distribution of normal samples.
+
+Example:
+    >>> from anomalib.models.image import Dfkde
+    >>> model = Dfkde(
+    ...     backbone="resnet18",
+    ...     layers=("layer4",),
+    ...     pre_trained=True
+    ... )
+
+Notes:
+    The model uses a pre-trained backbone to extract features and fits a KDE
+    classifier on the embeddings during training. No gradient updates are
+    performed on the backbone.
+
+See Also:
+    :class:`anomalib.models.image.dfkde.torch_model.DfkdeModel`:
+        PyTorch implementation of the DFKDE model.
+"""

 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -25,21 +48,40 @@


 class Dfkde(MemoryBankMixin, AnomalibModule):
-    """DFKDE: Deep Feature Kernel Density Estimation.
+ """DFKDE Lightning Module. Args: - backbone (str): Pre-trained model backbone. + backbone (str): Name of the backbone CNN to use for feature extraction. Defaults to ``"resnet18"``. - layers (Sequence[str], optional): Layers to extract features from. + layers (Sequence[str]): Layers from which to extract features. Defaults to ``("layer4",)``. - pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + pre_trained (bool): Whether to use pre-trained weights. Defaults to ``True``. - n_pca_components (int, optional): Number of PCA components. - Defaults to ``16``. - feature_scaling_method (FeatureScalingMethod, optional): Feature scaling method. + n_pca_components (int): Number of principal components for dimensionality + reduction. Defaults to ``16``. + feature_scaling_method (FeatureScalingMethod): Method to scale features. Defaults to ``FeatureScalingMethod.SCALE``. - max_training_points (int, optional): Number of training points to fit the KDE model. - Defaults to ``40000``. + max_training_points (int): Maximum number of points to use for KDE + fitting. Defaults to ``40000``. + pre_processor (PreProcessor | bool): Pre-processor object or flag. + Defaults to ``True``. + post_processor (PostProcessor | bool): Post-processor object or flag. + Defaults to ``True``. + evaluator (Evaluator | bool): Evaluator object or flag. + Defaults to ``True``. + visualizer (Visualizer | bool): Visualizer object or flag. + Defaults to ``True``. + + Example: + >>> from anomalib.models.image import Dfkde + >>> from anomalib.models.components.classification import ( + ... FeatureScalingMethod + ... ) + >>> model = Dfkde( + ... backbone="resnet18", + ... layers=("layer4",), + ... feature_scaling_method=FeatureScalingMethod.SCALE + ... ) """ def __init__( @@ -79,15 +121,15 @@ def configure_optimizers() -> None: # pylint: disable=arguments-differ return def training_step(self, batch: Batch, *args, **kwargs) -> None: - """Perform the training step of DFKDE. For each batch, features are extracted from the CNN. + """Extract features from the CNN for each training batch. Args: - batch (batch: Batch): Batch containing image filename, image, label and mask - args: Arguments. - kwargs: Keyword arguments. + batch (Batch): Input batch containing images and metadata. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. Returns: - Deep CNN features. + torch.Tensor: Dummy tensor for Lightning compatibility. """ del args, kwargs # These variables are not used. @@ -98,24 +140,22 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None: return torch.tensor(0.0, requires_grad=True, device=self.device) def fit(self) -> None: - """Fit a KDE Model to the embedding collected from the training set.""" + """Fit KDE model to collected embeddings from the training set.""" embeddings = torch.vstack(self.embeddings) logger.info("Fitting a KDE model to the embedding collected from the training set.") self.model.classifier.fit(embeddings) def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: - """Perform the validation step of DFKDE. - - Similar to the training step, features are extracted from the CNN for each batch. + """Perform validation by computing anomaly scores. Args: - batch (Batch): Input batch - args: Arguments. - kwargs: Keyword arguments. + batch (Batch): Input batch containing images and metadata. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. 
Returns: - Dictionary containing probability, prediction and ground truth values. + STEP_OUTPUT: Dictionary containing predictions and batch info. """ del args, kwargs # These variables are not used. @@ -124,21 +164,29 @@ def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: @property def trainer_arguments(self) -> dict[str, Any]: - """Return DFKDE-specific trainer arguments.""" + """Get DFKDE-specific trainer arguments. + + Returns: + dict[str, Any]: Dictionary of trainer arguments. + """ return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0} @property def learning_type(self) -> LearningType: - """Return the learning type of the model. + """Get the learning type. Returns: - LearningType: Learning type of the model. + LearningType: Learning type of the model (ONE_CLASS). """ return LearningType.ONE_CLASS @staticmethod def configure_evaluator() -> Evaluator: - """Default evaluator for DFKE.""" + """Configure the default evaluator for DFKDE. + + Returns: + Evaluator: Evaluator object with image-level AUROC and F1 metrics. + """ image_auroc = AUROC(fields=["pred_score", "gt_label"], prefix="image_") image_f1score = F1Score(fields=["pred_label", "gt_label"], prefix="image_") test_metrics = [image_auroc, image_f1score] diff --git a/src/anomalib/models/image/dfkde/torch_model.py b/src/anomalib/models/image/dfkde/torch_model.py index 4dc5fd58fe..deca22aedd 100644 --- a/src/anomalib/models/image/dfkde/torch_model.py +++ b/src/anomalib/models/image/dfkde/torch_model.py @@ -1,4 +1,27 @@ -"""Normality model of DFKDE.""" +"""PyTorch model for Deep Feature Kernel Density Estimation (DFKDE). + +This module provides a PyTorch implementation of the DFKDE model for anomaly +detection. The model extracts deep features from images using a pre-trained CNN +backbone and fits a kernel density estimation on these features to model the +distribution of normal samples. + +Example: + >>> import torch + >>> from anomalib.models.image.dfkde.torch_model import DfkdeModel + >>> model = DfkdeModel( + ... backbone="resnet18", + ... layers=["layer4"], + ... pre_trained=True + ... ) + >>> batch = torch.randn(32, 3, 224, 224) + >>> features = model(batch) # Returns features during training + >>> predictions = model(batch) # Returns scores during inference + +Notes: + The model uses a pre-trained backbone to extract features and fits a KDE + classifier on the embeddings during training. No gradient updates are + performed on the backbone. +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -18,19 +41,34 @@ class DfkdeModel(nn.Module): - """Normality Model for the DFKDE algorithm. + """Deep Feature Kernel Density Estimation model for anomaly detection. + + The model extracts deep features from images using a pre-trained CNN backbone + and fits a kernel density estimation on these features to model the + distribution of normal samples. Args: - backbone (str): Pre-trained model backbone. - layers (Sequence[str]): Layers to extract features from. - pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + backbone (str): Name of the pre-trained model backbone from timm. + layers (Sequence[str]): Names of layers to extract features from. + pre_trained (bool, optional): Whether to use pre-trained backbone weights. Defaults to ``True``. - n_pca_components (int, optional): Number of PCA components. - Defaults to ``16``. - feature_scaling_method (FeatureScalingMethod, optional): Feature scaling method. 
-            Defaults to ``FeatureScalingMethod.SCALE``.
-        max_training_points (int, optional): Number of training points to fit the KDE model.
-            Defaults to ``40000``.
+        n_pca_components (int, optional): Number of components for PCA dimension
+            reduction. Defaults to ``16``.
+        feature_scaling_method (FeatureScalingMethod, optional): Method used to
+            scale features before KDE. Defaults to
+            ``FeatureScalingMethod.SCALE``.
+        max_training_points (int, optional): Maximum number of points used to fit
+            the KDE model. Defaults to ``40000``.
+
+    Example:
+        >>> import torch
+        >>> model = DfkdeModel(
+        ...     backbone="resnet18",
+        ...     layers=["layer4"],
+        ...     pre_trained=True
+        ... )
+        >>> batch = torch.randn(32, 3, 224, 224)
+        >>> features = model(batch)
     """

     def __init__(
@@ -53,13 +91,21 @@
         )

     def get_features(self, batch: torch.Tensor) -> torch.Tensor:
-        """Extract features from the pretrained network.
+        """Extract features from the pre-trained backbone network.

         Args:
-            batch (torch.Tensor): Image batch.
+            batch (torch.Tensor): Batch of input images with shape
+                ``(N, C, H, W)``.

         Returns:
-            Tensor: torch.Tensor containing extracted features.
+            torch.Tensor: Concatenated features from specified layers, flattened
+                to shape ``(N, D)`` where ``D`` is the total feature dimension.
+
+        Example:
+            >>> batch = torch.randn(32, 3, 224, 224)
+            >>> features = model.get_features(batch)
+            >>> features.shape
+            torch.Size([32, 512])  # Depends on backbone and layers
         """
         self.feature_extractor.eval()
         layer_outputs = self.feature_extractor(batch)
@@ -70,13 +116,27 @@
         return torch.cat(list(layer_outputs.values())).detach()

     def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch:
-        """Prediction by normality model.
+        """Extract features during training or compute anomaly scores during inference.

         Args:
-            batch (torch.Tensor): Input images.
+            batch (torch.Tensor): Batch of input images with shape
+                ``(N, C, H, W)``.

         Returns:
-            Tensor: Predictions
+            torch.Tensor | InferenceBatch: During training, returns extracted
+                features as a tensor. During inference, returns an
+                ``InferenceBatch`` containing anomaly scores.
+
+        Example:
+            >>> batch = torch.randn(32, 3, 224, 224)
+            >>> # Training mode
+            >>> model.train()
+            >>> features = model(batch)
+            >>> # Inference mode
+            >>> model.eval()
+            >>> predictions = model(batch)
+            >>> predictions.pred_score.shape
+            torch.Size([32])
         """
         # 1. apply feature extraction
         features = self.get_features(batch)
diff --git a/src/anomalib/models/image/dfm/__init__.py b/src/anomalib/models/image/dfm/__init__.py
index c003420afc..2aba3c62d4 100644
--- a/src/anomalib/models/image/dfm/__init__.py
+++ b/src/anomalib/models/image/dfm/__init__.py
@@ -1,4 +1,24 @@
-"""Deep Feature Extraction (DFM) model."""
+"""Deep Feature Modeling (DFM) model for anomaly detection.
+
+The DFM model extracts deep features from images using a pre-trained CNN backbone
+and fits a PCA transformation followed by a Gaussian density model on these
+features to model the distribution of normal samples. During inference, samples
+with high feature reconstruction error or low likelihood are flagged as anomalous.
+
+Example:
+    >>> from anomalib.models.image import Dfm
+    >>> model = Dfm()
+
+The model can be used with any of the supported datasets and task modes in
+anomalib.
+
+Notes:
+    The model implementation is available in the ``lightning_model`` module.
+
+See Also:
+    :class:`anomalib.models.image.dfm.lightning_model.Dfm`:
+        Lightning implementation of the DFM model.
+""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/image/dfm/lightning_model.py b/src/anomalib/models/image/dfm/lightning_model.py index 1bdad50e1e..f380dab428 100644 --- a/src/anomalib/models/image/dfm/lightning_model.py +++ b/src/anomalib/models/image/dfm/lightning_model.py @@ -1,6 +1,28 @@ -"""DFM: Deep Feature Modeling. - -https://arxiv.org/abs/1909.11786 +"""Deep Feature Modeling (DFM) for anomaly detection. + +This module provides a PyTorch Lightning implementation of the DFM model for +anomaly detection. The model extracts deep features from images using a +pre-trained CNN backbone and fits a Gaussian model on these features to detect +anomalies. + +Paper: https://arxiv.org/abs/1909.11786 + +Example: + >>> from anomalib.models.image import Dfm + >>> model = Dfm( + ... backbone="resnet50", + ... layer="layer3", + ... pre_trained=True + ... ) + +Notes: + The model uses a pre-trained backbone to extract features and fits a PCA + transformation followed by a Gaussian model during training. No gradient + updates are performed on the backbone. + +See Also: + :class:`anomalib.models.image.dfm.torch_model.DFMModel`: + PyTorch implementation of the DFM model. """ # Copyright (C) 2022-2024 Intel Corporation @@ -26,24 +48,40 @@ class Dfm(MemoryBankMixin, AnomalibModule): - """DFM: Deep Featured Kernel Density Estimation. + """DFM Lightning Module. Args: - backbone (str): Backbone CNN network + backbone (str): Name of the backbone CNN network. Defaults to ``"resnet50"``. - layer (str): Layer to extract features from the backbone CNN + layer (str): Name of the layer to extract features from the backbone. Defaults to ``"layer3"``. - pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + pre_trained (bool, optional): Whether to use a pre-trained backbone. Defaults to ``True``. - pooling_kernel_size (int, optional): Kernel size to pool features extracted from the CNN. + pooling_kernel_size (int, optional): Kernel size for pooling features. Defaults to ``4``. - pca_level (float, optional): Ratio from which number of components for PCA are calculated. + pca_level (float, optional): Ratio of variance to preserve in PCA. + Must be between 0 and 1. Defaults to ``0.97``. - score_type (str, optional): Scoring type. Options are `fre` and `nll`. - Defaults to ``fre``. - pre_processor (PreProcessor, optional): Pre-processor for the model. - This is used to pre-process the input data before it is passed to the model. - Defaults to ``None``. + score_type (str, optional): Type of anomaly score to compute. + Options are ``"fre"`` (feature reconstruction error) or + ``"nll"`` (negative log-likelihood). + Defaults to ``"fre"``. + pre_processor (PreProcessor | bool, optional): Pre-processor to use. + If ``True``, uses the default pre-processor. + If ``False``, no pre-processing is performed. + Defaults to ``True``. + post_processor (PostProcessor | bool, optional): Post-processor to use. + If ``True``, uses the default post-processor. + If ``False``, no post-processing is performed. + Defaults to ``True``. + evaluator (Evaluator | bool, optional): Evaluator to use. + If ``True``, uses the default evaluator. + If ``False``, no evaluation is performed. + Defaults to ``True``. + visualizer (Visualizer | bool, optional): Visualizer to use. + If ``True``, uses the default visualizer. + If ``False``, no visualization is performed. + Defaults to ``True``. 
""" def __init__( @@ -79,21 +117,23 @@ def __init__( @staticmethod def configure_optimizers() -> None: # pylint: disable=arguments-differ - """DFM doesn't require optimization, therefore returns no optimizers.""" + """Configure optimizers for training. + + Returns: + None: DFM doesn't require optimization. + """ return def training_step(self, batch: Batch, *args, **kwargs) -> None: - """Perform the training step of DFM. - - For each batch, features are extracted from the CNN. + """Extract features from the input batch during training. Args: - batch (Batch): Input batch - args: Arguments. - kwargs: Keyword arguments. + batch (Batch): Input batch containing images. + *args: Additional positional arguments (unused). + **kwargs: Additional keyword arguments (unused). Returns: - Deep CNN features. + torch.Tensor: Dummy loss tensor for compatibility. """ del args, kwargs # These variables are not used. @@ -104,7 +144,11 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None: return torch.tensor(0.0, requires_grad=True, device=self.device) def fit(self) -> None: - """Fit a PCA transformation and a Gaussian model to dataset.""" + """Fit the PCA transformation and Gaussian model to the embeddings. + + The method aggregates embeddings collected during training and fits + both the PCA transformation and Gaussian model used for scoring. + """ logger.info("Aggregating the embedding extracted from the training set.") embeddings = torch.vstack(self.embeddings) @@ -112,17 +156,15 @@ def fit(self) -> None: self.model.fit(embeddings) def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: - """Perform the validation step of DFM. - - Similar to the training step, features are extracted from the CNN for each batch. + """Compute predictions for the input batch during validation. Args: - batch (Batch): Input batch - args: Arguments. - kwargs: Keyword arguments. + batch (Batch): Input batch containing images. + *args: Additional positional arguments (unused). + **kwargs: Additional keyword arguments (unused). Returns: - Dictionary containing FRE anomaly scores and anomaly maps. + STEP_OUTPUT: Dictionary containing anomaly scores and maps. """ del args, kwargs # These variables are not used. @@ -131,14 +173,21 @@ def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: @property def trainer_arguments(self) -> dict[str, Any]: - """Return DFM-specific trainer arguments.""" + """Get DFM-specific trainer arguments. + + Returns: + dict[str, Any]: Dictionary of trainer arguments: + - ``gradient_clip_val`` (int): Disable gradient clipping + - ``max_epochs`` (int): Train for one epoch only + - ``num_sanity_val_steps`` (int): Skip validation sanity checks + """ return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0} @property def learning_type(self) -> LearningType: - """Return the learning type of the model. + """Get the learning type of the model. Returns: - LearningType: Learning type of the model. + LearningType: The model uses one-class learning. """ return LearningType.ONE_CLASS diff --git a/src/anomalib/models/image/dfm/torch_model.py b/src/anomalib/models/image/dfm/torch_model.py index 520cbf8196..7ad516e35f 100644 --- a/src/anomalib/models/image/dfm/torch_model.py +++ b/src/anomalib/models/image/dfm/torch_model.py @@ -1,4 +1,26 @@ -"""PyTorch model for DFM model implementation.""" +"""PyTorch model for Deep Feature Modeling (DFM). + +This module provides a PyTorch implementation of the DFM model for anomaly +detection. 
The model extracts deep features from images using a pre-trained CNN
+backbone and fits a Gaussian model on these features to detect anomalies.
+
+Example:
+    >>> import torch
+    >>> from anomalib.models.image.dfm.torch_model import DFMModel
+    >>> model = DFMModel(
+    ...     backbone="resnet18",
+    ...     layer="layer4",
+    ...     pre_trained=True
+    ... )
+    >>> batch = torch.randn(32, 3, 224, 224)
+    >>> features = model(batch)  # Returns features during training
+    >>> predictions = model(batch)  # Returns scores during inference
+
+Notes:
+    The model uses a pre-trained backbone to extract features and fits a PCA
+    transformation followed by a Gaussian model during training. No gradient
+    updates are performed on the backbone.
+"""

 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -14,9 +36,20 @@


 class SingleClassGaussian(DynamicBufferMixin):
-    """Model Gaussian distribution over a set of points."""
+    """Model Gaussian distribution over a set of points.
+
+    This class fits a single Gaussian distribution to a set of feature vectors
+    and computes likelihood scores for new samples.
+
+    Example:
+        >>> gaussian = SingleClassGaussian()
+        >>> features = torch.randn(128, 100)  # 100 samples of 128 dimensions
+        >>> gaussian.fit(features)
+        >>> scores = gaussian.score_samples(features.T)
+    """

     def __init__(self) -> None:
+        """Initialize Gaussian model with empty buffers."""
         super().__init__()
         self.register_buffer("mean_vec", torch.Tensor())
         self.register_buffer("u_mat", torch.Tensor())
@@ -29,16 +62,14 @@

     def fit(self, dataset: torch.Tensor) -> None:
         """Fit a Gaussian model to dataset X.

-        Covariance matrix is not calculated directly using:
-        ``C = X.X^T``
-        Instead, it is represented in terms of the Singular Value Decomposition of X:
-        ``X = U.S.V^T``
-        Hence,
-        ``C = U.S^2.U^T``
-        This simplifies the calculation of the log-likelihood without requiring full matrix inversion.
+        Covariance matrix is not calculated directly using ``C = X.X^T``.
+        Instead, it is represented using SVD of X: ``X = U.S.V^T``.
+        Hence, ``C = U.S^2.U^T``. This simplifies the calculation of the
+        log-likelihood without requiring full matrix inversion.

         Args:
-            dataset (torch.Tensor): Input dataset to fit the model.
+            dataset (torch.Tensor): Input dataset to fit the model with shape
+                ``(n_features, n_samples)``.
         """
         num_samples = dataset.shape[1]
         self.mean_vec = torch.mean(dataset, dim=1, device=dataset.device)
@@ -46,43 +77,57 @@
         self.u_mat, self.sigma_mat, _ = torch.linalg.svd(data_centered, full_matrices=False)

     def score_samples(self, features: torch.Tensor) -> torch.Tensor:
-        """Compute the NLL (negative log likelihood) scores.
+        """Compute the negative log likelihood (NLL) scores.

         Args:
-            features (torch.Tensor): semantic features on which density modeling is performed.
+            features (torch.Tensor): Semantic features on which density modeling
+                is performed with shape ``(n_samples, n_features)``.

         Returns:
-            nll (torch.Tensor): Torch tensor of scores
+            torch.Tensor: NLL scores for each sample.
         """
         features_transformed = torch.matmul(features - self.mean_vec, self.u_mat / self.sigma_mat)
         return torch.sum(features_transformed * features_transformed, dim=1) + 2 * torch.sum(torch.log(self.sigma_mat))

     def forward(self, dataset: torch.Tensor) -> None:
-        """Provide the same functionality as `fit`.
+        """Fit the model to the input dataset.

         Transforms the input dataset based on singular values calculated earlier.
Args: - dataset (torch.Tensor): Input dataset + dataset (torch.Tensor): Input dataset with shape + ``(n_features, n_samples)``. """ self.fit(dataset) class DFMModel(nn.Module): - """Model for the DFM algorithm. + """Deep Feature Modeling (DFM) model for anomaly detection. + + The model extracts deep features from images using a pre-trained CNN backbone + and fits a Gaussian model on these features to detect anomalies. Args: - backbone (str): Pre-trained model backbone. + backbone (str): Pre-trained model backbone from timm. layer (str): Layer from which to extract features. - pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + pre_trained (bool, optional): Whether to use pre-trained backbone. Defaults to ``True``. - pooling_kernel_size (int, optional): Kernel size to pool features extracted from the CNN. + pooling_kernel_size (int, optional): Kernel size to pool features. Defaults to ``4``. - n_comps (float, optional): Ratio from which number of components for PCA are calculated. + n_comps (float, optional): Ratio for PCA components calculation. Defaults to ``0.97``. - score_type (str, optional): Scoring type. Options are `fre` and `nll`. Anomaly - Defaults to ``fre``. Segmentation is supported with `fre` only. - If using `nll`, set `task` in config.yaml to classification Defaults to ``classification``. + score_type (str, optional): Scoring type - ``fre`` or ``nll``. + Defaults to ``fre``. Segmentation supported with ``fre`` only. + For ``nll``, set task to classification. + + Example: + >>> model = DFMModel( + ... backbone="resnet18", + ... layer="layer4", + ... pre_trained=True + ... ) + >>> input_tensor = torch.randn(32, 3, 224, 224) + >>> output = model(input_tensor) """ def __init__( @@ -109,10 +154,11 @@ def __init__( ).eval() def fit(self, dataset: torch.Tensor) -> None: - """Fit a pca transformation and a Gaussian model to dataset. + """Fit PCA and Gaussian model to dataset. Args: - dataset (torch.Tensor): Input dataset to fit the model. + dataset (torch.Tensor): Input dataset with shape + ``(n_samples, n_features)``. """ self.pca_model.fit(dataset) if self.score_type == "nll": @@ -120,17 +166,19 @@ def fit(self, dataset: torch.Tensor) -> None: self.gaussian_model.fit(features_reduced.T) def score(self, features: torch.Tensor, feature_shapes: tuple) -> torch.Tensor: - """Compute scores. + """Compute anomaly scores. Scores are either PCA-based feature reconstruction error (FRE) scores or - the Gaussian density-based NLL scores + Gaussian density-based NLL scores. Args: - features (torch.Tensor): semantic features on which PCA and density modeling is performed. - feature_shapes (tuple): shape of `features` tensor. Used to generate anomaly map of correct shape. + features (torch.Tensor): Features for scoring with shape + ``(n_samples, n_features)``. + feature_shapes (tuple): Shape of features tensor for anomaly map. Returns: - score (torch.Tensor): numpy array of scores + tuple[torch.Tensor, Optional[torch.Tensor]]: Tuple containing + (scores, anomaly_maps). Anomaly maps are None for NLL scoring. """ feats_projected = self.pca_model.transform(features) if self.score_type == "nll": @@ -150,10 +198,12 @@ def get_features(self, batch: torch.Tensor) -> torch.Tensor: """Extract features from the pretrained network. Args: - batch (torch.Tensor): Image batch. + batch (torch.Tensor): Input images with shape + ``(batch_size, channels, height, width)``. Returns: - Tensor: torch.Tensor containing extracted features. 
+            Union[torch.Tensor, Tuple[torch.Tensor, torch.Size]]: Features during
+                training, or tuple of (features, feature_shapes) during inference.
         """
         self.feature_extractor.eval()
         features = self.feature_extractor(batch)[self.layer]
@@ -165,13 +215,16 @@
         return features if self.training else (features, feature_shapes)

     def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch:
-        """Compute score from input images.
+        """Compute anomaly predictions from input images.

         Args:
-            batch (torch.Tensor): Input images
+            batch (torch.Tensor): Input images with shape
+                ``(batch_size, channels, height, width)``.

         Returns:
-            Tensor: Scores
+            Union[torch.Tensor, InferenceBatch]: Model predictions. During
+                training returns features tensor. During inference returns
+                ``InferenceBatch`` with prediction scores and anomaly maps.
         """
         feature_vector, feature_shapes = self.get_features(batch)
         pred_score, anomaly_map = self.score(feature_vector.view(feature_vector.shape[:2]), feature_shapes)
diff --git a/src/anomalib/models/image/draem/__init__.py b/src/anomalib/models/image/draem/__init__.py
index 4c8b06fa1d..945f5c9016 100644
--- a/src/anomalib/models/image/draem/__init__.py
+++ b/src/anomalib/models/image/draem/__init__.py
@@ -1,4 +1,23 @@
-"""DRAEM model."""
+"""DRAEM (Discriminatively trained Reconstruction Anomaly Embedding Model).
+
+The DRAEM model uses a dual-branch architecture with a reconstruction branch and
+a segmentation branch to detect and localize anomalies. It is trained using
+synthetic anomalies generated by augmenting normal images.
+
+Example:
+    >>> from anomalib.models.image import Draem
+    >>> model = Draem()
+
+The model can be used with any of the supported datasets and task modes in
+anomalib.
+
+Notes:
+    The model implementation is available in the ``lightning_model`` module.
+
+See Also:
+    :class:`anomalib.models.image.draem.lightning_model.Draem`:
+        Lightning implementation of the DRAEM model.
+"""

 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
diff --git a/src/anomalib/models/image/draem/lightning_model.py b/src/anomalib/models/image/draem/lightning_model.py
index 84b143f3f5..98a0568ce2 100644
--- a/src/anomalib/models/image/draem/lightning_model.py
+++ b/src/anomalib/models/image/draem/lightning_model.py
@@ -1,6 +1,13 @@
-"""DRÆM - A discriminatively trained reconstruction embedding for surface anomaly detection.
+"""DRÆM.
+
+A discriminatively trained reconstruction embedding for surface anomaly
+detection.

 Paper https://arxiv.org/abs/2108.07610
+
+This module implements the DRÆM model for surface anomaly detection. DRÆM uses a
+discriminatively trained reconstruction embedding approach to detect anomalies by
+comparing input images with their reconstructions.
 """

 # Copyright (C) 2022-2024 Intel Corporation
@@ -30,19 +37,38 @@


 class Draem(AnomalibModule):
-    """DRÆM: A discriminatively trained reconstruction embedding for surface anomaly detection.
+    """DRÆM.
+
+    A discriminatively trained reconstruction embedding for
+    surface anomaly detection.
+
+    The model consists of two main components:
+    1. A reconstruction network that learns to reconstruct normal images
+    2. A discriminative network that learns to identify anomalous regions

     Args:
-        enable_sspcab (bool): Enable SSPCAB training.
+        enable_sspcab (bool, optional): Enable SSPCAB training.
             Defaults to ``False``.
-        sspcab_lambda (float): SSPCAB loss weight.
+        sspcab_lambda (float, optional): Weight factor for SSPCAB loss.
Defaults to ``0.1``. - anomaly_source_path (str | None): Path to folder that contains the anomaly source images. Random noise will - be used if left empty. - Defaults to ``None``. - pre_processor (PreProcessor, optional): Pre-processor for the model. - This is used to pre-process the input data before it is passed to the model. + anomaly_source_path (str | None, optional): Path to directory containing + anomaly source images. If ``None``, random noise is used. Defaults to ``None``. + beta (float | tuple[float, float], optional): Blend factor for anomaly + generation. If tuple, represents range for random sampling. + Defaults to ``(0.1, 1.0)``. + pre_processor (PreProcessor | bool, optional): Pre-processor instance or + flag to use default. + Defaults to ``True``. + post_processor (PostProcessor | bool, optional): Post-processor instance + or flag to use default. + Defaults to ``True``. + evaluator (Evaluator | bool, optional): Evaluator instance or flag to + use default. + Defaults to ``True``. + visualizer (Visualizer | bool, optional): Visualizer instance or flag to + use default. + Defaults to ``True``. """ def __init__( @@ -75,22 +101,30 @@ def __init__( self.sspcab_lambda = sspcab_lambda def setup_sspcab(self) -> None: - """Prepare the model for the SSPCAB training step by adding forward hooks for the SSPCAB layer activations.""" + """Set up SSPCAB forward hooks. + + Prepares the model for SSPCAB training by adding forward hooks to capture + layer activations from specific points in the network. + """ def get_activation(name: str) -> Callable: - """Retrieve the activations. + """Create a hook function to retrieve layer activations. Args: - name (str): Identifier for the retrieved activations. + name (str): Identifier for storing the activation in the + activation dictionary. + + Returns: + Callable: Hook function that stores layer activations. """ def hook(_, __, output: torch.Tensor) -> None: # noqa: ANN001 - """Create hook for retrieving the activations. + """Store layer activations during forward pass. Args: - _: Placeholder for the module input. - __: Placeholder for the module output. - output (torch.Tensor): The output tensor of the module. + _: Unused module argument. + __: Unused input argument. + output (torch.Tensor): Output tensor from the layer. """ self.sspcab_activations[name] = output @@ -100,18 +134,20 @@ def hook(_, __, output: torch.Tensor) -> None: # noqa: ANN001 self.model.reconstructive_subnetwork.encoder.block5.register_forward_hook(get_activation("output")) def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: - """Perform the training step of DRAEM. + """Perform training step for DRAEM. - Feeds the original image and the simulated anomaly - image through the network and computes the training loss. + The step consists of: + 1. Generating simulated anomalies + 2. Computing reconstructions and predictions + 3. Calculating the loss Args: - batch (Batch): Batch containing image filename, image, label and mask - args: Arguments. - kwargs: Keyword arguments. + batch (Batch): Input batch containing images and metadata. + args: Additional positional arguments (unused). + kwargs: Additional keyword arguments (unused). Returns: - Loss dictionary + STEP_OUTPUT: Dictionary containing the training loss. """ del args, kwargs # These variables are not used. 
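The ``setup_sspcab`` hooks documented above use PyTorch's standard forward-hook idiom to capture intermediate activations. A self-contained sketch of that same pattern (the ``nn.Conv2d`` stand-in is illustrative, not DRAEM's actual encoder block):

    import torch
    from torch import nn

    activations: dict[str, torch.Tensor] = {}

    def get_activation(name: str):
        # Return a hook that stores the layer's output under `name`.
        def hook(_module: nn.Module, _inputs: tuple, output: torch.Tensor) -> None:
            activations[name] = output
        return hook

    block = nn.Conv2d(3, 8, kernel_size=3, padding=1)
    block.register_forward_hook(get_activation("input"))

    _ = block(torch.randn(1, 3, 16, 16))
    assert activations["input"].shape == torch.Size([1, 8, 16, 16])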
@@ -133,15 +169,17 @@ def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: return {"loss": loss} def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: - """Perform the validation step of DRAEM. The Softmax predictions of the anomalous class are used as anomaly map. + """Perform validation step for DRAEM. + + Uses softmax predictions of the anomalous class as anomaly maps. Args: - batch (Batch): Batch of input images - args: Arguments. - kwargs: Keyword arguments. + batch (Batch): Input batch containing images and metadata. + args: Additional positional arguments (unused). + kwargs: Additional keyword arguments (unused). Returns: - Dictionary to which predicted anomaly maps have been added. + STEP_OUTPUT: Dictionary containing predictions and metadata. """ del args, kwargs # These variables are not used. @@ -150,27 +188,49 @@ def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: @property def trainer_arguments(self) -> dict[str, Any]: - """Return DRÆM-specific trainer arguments.""" + """Get DRÆM-specific trainer arguments. + + Returns: + dict[str, Any]: Dictionary containing trainer arguments: + - gradient_clip_val: ``0`` + - num_sanity_val_steps: ``0`` + """ return {"gradient_clip_val": 0, "num_sanity_val_steps": 0} def configure_optimizers(self) -> torch.optim.Optimizer: - """Configure the Adam optimizer.""" + """Configure optimizer and learning rate scheduler. + + Returns: + tuple[list[Adam], list[MultiStepLR]]: Tuple containing optimizer and + scheduler lists. + """ optimizer = torch.optim.Adam(params=self.model.parameters(), lr=0.0001) scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[400, 600], gamma=0.1) return [optimizer], [scheduler] @property def learning_type(self) -> LearningType: - """Return the learning type of the model. + """Get the learning type of the model. Returns: - LearningType: Learning type of the model. + LearningType: The learning type (``LearningType.ONE_CLASS``). """ return LearningType.ONE_CLASS @staticmethod def configure_transforms(image_size: tuple[int, int] | None = None) -> Transform: - """Default transform for DRAEM. Normalization is not needed as the images are scaled to [0, 1] in Dataset.""" + """Configure default transforms for DRAEM. + + Note: + Normalization is not needed as images are scaled to [0, 1] in Dataset. + + Args: + image_size (tuple[int, int] | None, optional): Target size for image + resizing. Defaults to ``(256, 256)``. + + Returns: + Transform: Composed transform including resizing. + """ image_size = image_size or (256, 256) return Compose( [ diff --git a/src/anomalib/models/image/draem/loss.py b/src/anomalib/models/image/draem/loss.py index 1cef702e15..2e65d97ecf 100644 --- a/src/anomalib/models/image/draem/loss.py +++ b/src/anomalib/models/image/draem/loss.py @@ -1,4 +1,19 @@ -"""Loss function for the DRAEM model implementation.""" +"""Loss function for the DRAEM model implementation. + +This module implements the loss function used to train the DRAEM model for anomaly +detection. The loss combines L2 reconstruction loss, focal loss for anomaly +segmentation, and structural similarity (SSIM) loss. 
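To make the three-term composition described above concrete, here is a hedged, dependency-free sketch. The real `DraemLoss` delegates the focal and SSIM terms to library implementations; the hand-written focal term below is illustrative only, and the SSIM term (weighted by a factor of 2 in the model) is noted but omitted.

```python
# Illustrative sketch of combining DRAEM-style loss terms.
import torch
import torch.nn.functional as F

def draem_style_loss(input_image, reconstruction, anomaly_mask, prediction, gamma: float = 2.0):
    l2 = F.mse_loss(reconstruction, input_image)
    # Focal loss over 2-class pixel predictions, written out by hand.
    log_p = F.log_softmax(prediction, dim=1)
    ce = F.nll_loss(log_p, anomaly_mask.squeeze(1).long(), reduction="none")
    focal = ((1 - torch.exp(-ce)) ** gamma * ce).mean()
    return l2 + focal  # the actual model also adds 2 * ssim_loss

loss = draem_style_loss(
    torch.randn(2, 3, 64, 64),            # input images
    torch.randn(2, 3, 64, 64),            # reconstructions
    torch.randint(0, 2, (2, 1, 64, 64)),  # ground-truth anomaly masks
    torch.randn(2, 2, 64, 64),            # 2-class pixel predictions
)
```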
+ +Example: + >>> import torch + >>> from anomalib.models.image.draem.loss import DraemLoss + >>> criterion = DraemLoss() + >>> input_image = torch.randn(8, 3, 256, 256) + >>> reconstruction = torch.randn(8, 3, 256, 256) + >>> anomaly_mask = torch.randint(0, 2, (8, 1, 256, 256)) + >>> prediction = torch.randn(8, 2, 256, 256) + >>> loss = criterion(input_image, reconstruction, anomaly_mask, prediction) +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -11,11 +26,20 @@ class DraemLoss(nn.Module): """Overall loss function of the DRAEM model. - The total loss consists of the sum of the L2 loss and Focal loss between the reconstructed image and the input - image, and the Structural Similarity loss between the predicted and GT anomaly masks. + The total loss consists of three components: + 1. L2 loss between the reconstructed and input images + 2. Focal loss between predicted and ground truth anomaly masks + 3. Structural Similarity (SSIM) loss between reconstructed and input images + + The final loss is computed as: ``loss = l2_loss + ssim_loss + focal_loss`` + + Example: + >>> criterion = DraemLoss() + >>> loss = criterion(input_image, reconstruction, anomaly_mask, prediction) """ def __init__(self) -> None: + """Initialize loss components with default parameters.""" super().__init__() self.l2_loss = nn.modules.loss.MSELoss() @@ -29,7 +53,21 @@ def forward( anomaly_mask: torch.Tensor, prediction: torch.Tensor, ) -> torch.Tensor: - """Compute the loss over a batch for the DRAEM model.""" + """Compute the combined loss over a batch for the DRAEM model. + + Args: + input_image: Original input images of shape + ``(batch_size, num_channels, height, width)`` + reconstruction: Reconstructed images from the model of shape + ``(batch_size, num_channels, height, width)`` + anomaly_mask: Ground truth anomaly masks of shape + ``(batch_size, 1, height, width)`` + prediction: Model predictions of shape + ``(batch_size, num_classes, height, width)`` + + Returns: + torch.Tensor: Combined loss value + """ l2_loss_val = self.l2_loss(reconstruction, input_image) focal_loss_val = self.focal_loss(prediction, anomaly_mask.squeeze(1).long()) ssim_loss_val = self.ssim_loss(reconstruction, input_image) * 2 diff --git a/src/anomalib/models/image/draem/torch_model.py b/src/anomalib/models/image/draem/torch_model.py index 3ce080aca5..5ef1d7eba6 100644 --- a/src/anomalib/models/image/draem/torch_model.py +++ b/src/anomalib/models/image/draem/torch_model.py @@ -1,4 +1,10 @@ -"""PyTorch model for the DRAEM model implementation.""" +"""PyTorch model for the DRAEM model implementation. + +The DRAEM model consists of two sub-networks: +1. A reconstructive sub-network that learns to reconstruct input images +2. A discriminative sub-network that detects anomalies by comparing original and + reconstructed images +""" # Original Code # Copyright (c) 2021 VitjanZ @@ -17,11 +23,15 @@ class DraemModel(nn.Module): - """DRAEM PyTorch model consisting of the reconstructive and discriminative sub networks. + """DRAEM PyTorch model with reconstructive and discriminative sub-networks. Args: - sspcab (bool): Enable SSPCAB training. - Defaults to ``False``. + sspcab (bool, optional): Enable SSPCAB training. Defaults to ``False``. 
+ + Example: + >>> model = DraemModel(sspcab=True) + >>> input_tensor = torch.randn(32, 3, 256, 256) + >>> reconstruction, prediction = model(input_tensor) """ def __init__(self, sspcab: bool = False) -> None: @@ -30,14 +40,27 @@ def __init__(self, sspcab: bool = False) -> None: self.discriminative_subnetwork = DiscriminativeSubNetwork(in_channels=6, out_channels=2) def forward(self, batch: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor] | InferenceBatch: - """Compute the reconstruction and anomaly mask from an input image. + """Forward pass through both sub-networks. Args: - batch (torch.Tensor): batch of input images + batch (torch.Tensor): Input batch of images of shape + ``(batch_size, channels, height, width)`` Returns: - Predicted confidence values of the anomaly mask. During training the reconstructed input images are - returned as well. + During training: + tuple: Tuple containing: + - Reconstructed images + - Predicted anomaly masks + During inference: + InferenceBatch: Contains anomaly map and prediction score + + Example: + >>> model = DraemModel() + >>> batch = torch.randn(32, 3, 256, 256) + >>> reconstruction, prediction = model(batch) # Training mode + >>> model.eval() + >>> output = model(batch) # Inference mode + >>> assert isinstance(output, InferenceBatch) """ reconstruction = self.reconstructive_subnetwork(batch) concatenated_inputs = torch.cat([batch, reconstruction], axis=1) @@ -51,17 +74,21 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, tor class ReconstructiveSubNetwork(nn.Module): - """Autoencoder model that encodes and reconstructs the input image. + """Autoencoder model for image reconstruction. Args: - in_channels (int): Number of input channels. - Defaults to ``3``. - out_channels (int): Number of output channels. - Defaults to ``3``. - base_width (int): Base dimensionality of the layers of the autoencoder. - Defaults to ``128``. - sspcab (bool): Enable SSPCAB training. - Defaults to ``False``. + in_channels (int, optional): Number of input channels. Defaults to ``3``. + out_channels (int, optional): Number of output channels. Defaults to ``3``. + base_width (int, optional): Base dimensionality of layers. Defaults to + ``128``. + sspcab (bool, optional): Enable SSPCAB training. Defaults to ``False``. + + Example: + >>> subnet = ReconstructiveSubNetwork(in_channels=3, base_width=64) + >>> input_tensor = torch.randn(32, 3, 256, 256) + >>> output = subnet(input_tensor) + >>> output.shape + torch.Size([32, 3, 256, 256]) """ def __init__( @@ -76,28 +103,37 @@ def __init__( self.decoder = DecoderReconstructive(base_width, out_channels=out_channels) def forward(self, batch: torch.Tensor) -> torch.Tensor: - """Encode and reconstruct the input images. + """Encode and reconstruct input images. Args: - batch (torch.Tensor): Batch of input images + batch (torch.Tensor): Batch of input images of shape + ``(batch_size, channels, height, width)`` Returns: - Batch of reconstructed images. + torch.Tensor: Batch of reconstructed images of same shape as input """ encoded = self.encoder(batch) return self.decoder(encoded) class DiscriminativeSubNetwork(nn.Module): - """Discriminative model that predicts the anomaly mask from the original image and its reconstruction. + """Discriminative model for anomaly mask prediction. + + Compares original images with their reconstructions to predict anomaly masks. Args: - in_channels (int): Number of input channels. - Defaults to ``3``. - out_channels (int): Number of output channels. 
- Defaults to ``3``. - base_width (int): Base dimensionality of the layers of the autoencoder. - Defaults to ``64``. + in_channels (int, optional): Number of input channels. Defaults to ``3``. + out_channels (int, optional): Number of output channels. Defaults to ``3``. + base_width (int, optional): Base dimensionality of layers. Defaults to + ``64``. + + Example: + >>> subnet = DiscriminativeSubNetwork(in_channels=6, out_channels=2) + >>> # Concatenated original and reconstructed images + >>> input_tensor = torch.randn(32, 6, 256, 256) + >>> output = subnet(input_tensor) + >>> output.shape + torch.Size([32, 2, 256, 256]) """ def __init__(self, in_channels: int = 3, out_channels: int = 3, base_width: int = 64) -> None: @@ -106,25 +142,32 @@ def __init__(self, in_channels: int = 3, out_channels: int = 3, base_width: int self.decoder_segment = DecoderDiscriminative(base_width, out_channels=out_channels) def forward(self, batch: torch.Tensor) -> torch.Tensor: - """Generate the predicted anomaly masks for a batch of input images. + """Generate predicted anomaly masks. Args: - batch (torch.Tensor): Batch of inputs consisting of the concatenation of the original images - and their reconstructions. + batch (torch.Tensor): Concatenated original and reconstructed images of + shape ``(batch_size, channels*2, height, width)`` Returns: - Activations of the output layer corresponding to the normal and anomalous class scores on the pixel level. + torch.Tensor: Pixel-level class scores for normal and anomalous regions """ act1, act2, act3, act4, act5, act6 = self.encoder_segment(batch) return self.decoder_segment(act1, act2, act3, act4, act5, act6) class EncoderDiscriminative(nn.Module): - """Encoder part of the discriminator network. + """Encoder component of the discriminator network. Args: - in_channels (int): Number of input channels. - base_width (int): Base dimensionality of the layers of the autoencoder. + in_channels (int): Number of input channels + base_width (int): Base dimensionality of the layers + + Example: + >>> encoder = EncoderDiscriminative(in_channels=6, base_width=64) + >>> input_tensor = torch.randn(32, 6, 256, 256) + >>> outputs = encoder(input_tensor) + >>> len(outputs) # Returns 6 activation maps + 6 """ def __init__(self, in_channels: int, base_width: int) -> None: @@ -188,14 +231,14 @@ def forward( self, batch: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Convert the inputs to the salient space by running them through the encoder network. + """Convert inputs to salient space through encoder network. Args: - batch (torch.Tensor): Batch of inputs consisting of the concatenation of the original images - and their reconstructions. + batch (torch.Tensor): Input batch of concatenated original and + reconstructed images Returns: - Computed feature maps for each of the layers in the encoder sub network. + tuple: Contains 6 activation tensors from each encoder block """ act1 = self.block1(batch) mp1 = self.mp1(act1) @@ -212,12 +255,19 @@ def forward( class DecoderDiscriminative(nn.Module): - """Decoder part of the discriminator network. + """Decoder component of the discriminator network. Args: - base_width (int): Base dimensionality of the layers of the autoencoder. - out_channels (int): Number of output channels. - Defaults to ``1``. + base_width (int): Base dimensionality of the layers + out_channels (int, optional): Number of output channels. 
Defaults to ``1`` + + Example: + >>> decoder = DecoderDiscriminative(base_width=64, out_channels=2) + >>> # Create 6 mock encoder activations (channels grow, resolution halves) + >>> acts = [torch.randn(32, c, 256 >> i, 256 >> i) + ... for i, c in enumerate([64, 128, 256, 512, 512, 512])] + >>> output = decoder(*acts) + >>> output.shape + torch.Size([32, 2, 256, 256]) """ def __init__(self, base_width: int, out_channels: int = 1) -> None: @@ -309,18 +359,18 @@ def forward( act5: torch.Tensor, act6: torch.Tensor, ) -> torch.Tensor: - """Compute predicted anomaly class scores from the intermediate outputs of the encoder sub network. + """Compute predicted anomaly scores from encoder activations. Args: - act1 (torch.Tensor): Encoder activations of the first block of convolutional layers. - act2 (torch.Tensor): Encoder activations of the second block of convolutional layers. - act3 (torch.Tensor): Encoder activations of the third block of convolutional layers. - act4 (torch.Tensor): Encoder activations of the fourth block of convolutional layers. - act5 (torch.Tensor): Encoder activations of the fifth block of convolutional layers. - act6 (torch.Tensor): Encoder activations of the sixth block of convolutional layers. + act1 (torch.Tensor): First block encoder activations + act2 (torch.Tensor): Second block encoder activations + act3 (torch.Tensor): Third block encoder activations + act4 (torch.Tensor): Fourth block encoder activations + act5 (torch.Tensor): Fifth block encoder activations + act6 (torch.Tensor): Sixth block encoder activations Returns: - Predicted anomaly class scores per pixel. + torch.Tensor: Predicted anomaly scores per pixel """ up_b = self.up_b(act6) cat_b = torch.cat((up_b, act5), dim=1) @@ -346,13 +396,19 @@ def forward( class EncoderReconstructive(nn.Module): - """Encoder part of the reconstructive network. + """Encoder component of the reconstructive network. Args: - in_channels (int): Number of input channels. - base_width (int): Base dimensionality of the layers of the autoencoder. - sspcab (bool): Enable SSPCAB training. - Defaults to ``False``. + in_channels (int): Number of input channels + base_width (int): Base dimensionality of the layers + sspcab (bool, optional): Enable SSPCAB training. Defaults to ``False`` + + Example: + >>> encoder = EncoderReconstructive(in_channels=3, base_width=64) + >>> input_tensor = torch.randn(32, 3, 256, 256) + >>> output = encoder(input_tensor) + >>> output.shape + torch.Size([32, 512, 16, 16]) """ def __init__(self, in_channels: int, base_width: int, sspcab: bool = False) -> None: @@ -406,13 +462,14 @@ def __init__(self, in_channels: int, base_width: int, sspcab: bool = False) -> N ) def forward(self, batch: torch.Tensor) -> torch.Tensor: - """Encode a batch of input images to the salient space. + """Encode input images to the salient space. Args: - batch (torch.Tensor): Batch of input images. + batch (torch.Tensor): Input batch of images of shape + ``(batch_size, channels, height, width)`` Returns: - Feature maps extracted from the bottleneck layer. + torch.Tensor: Feature maps from the bottleneck layer """ act1 = self.block1(batch) mp1 = self.mp1(act1) @@ -426,12 +483,18 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor: class DecoderReconstructive(nn.Module): - """Decoder part of the reconstructive network. + """Decoder component of the reconstructive network. Args: - base_width (int): Base dimensionality of the layers of the autoencoder. - out_channels (int): Number of output channels. - Defaults to ``1``.
+ base_width (int): Base dimensionality of the layers + out_channels (int, optional): Number of output channels. Defaults to ``1`` + + Example: + >>> decoder = DecoderReconstructive(base_width=64, out_channels=3) + >>> input_tensor = torch.randn(32, 512, 16, 16) + >>> output = decoder(input_tensor) + >>> output.shape + torch.Size([32, 3, 256, 256]) """ def __init__(self, base_width: int, out_channels: int = 1) -> None: @@ -501,13 +564,14 @@ def __init__(self, base_width: int, out_channels: int = 1) -> None: self.fin_out = nn.Sequential(nn.Conv2d(base_width, out_channels, kernel_size=3, padding=1)) def forward(self, act5: torch.Tensor) -> torch.Tensor: - """Reconstruct the image from the activations of the bottleneck layer. + """Reconstruct image from bottleneck features. Args: - act5 (torch.Tensor): Activations of the bottleneck layer. + act5 (torch.Tensor): Activations from the bottleneck layer of shape + ``(batch_size, channels, height, width)`` Returns: - Batch of reconstructed images. + torch.Tensor: Reconstructed images of same size as original input """ up1 = self.up1(act5) db1 = self.db1(up1) diff --git a/src/anomalib/models/image/dsr/__init__.py b/src/anomalib/models/image/dsr/__init__.py index 54e53d5d6f..e54bfc7c82 100644 --- a/src/anomalib/models/image/dsr/__init__.py +++ b/src/anomalib/models/image/dsr/__init__.py @@ -1,4 +1,23 @@ -"""DSR model.""" +"""DSR (Dual Subspace Re-Projection) model. + +DSR is an anomaly detection model that re-projects quantized feature maps onto +two learned subspaces and detects anomalies by comparing the input image with +its subspace reconstructions. + +Example: + >>> from anomalib.models.image import Dsr + >>> model = Dsr() + +The model can be used with any of the supported datasets and task modes in +anomalib. + +Notes: + The model implementation is available in the ``lightning_model`` module. + +See Also: + :class:`anomalib.models.image.dsr.lightning_model.Dsr`: + Lightning implementation of the DSR model. +""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/image/dsr/anomaly_generator.py b/src/anomalib/models/image/dsr/anomaly_generator.py index 2d1d5c4a75..b4c884a9db 100644 --- a/src/anomalib/models/image/dsr/anomaly_generator.py +++ b/src/anomalib/models/image/dsr/anomaly_generator.py @@ -1,4 +1,15 @@ -"""Anomaly generator for the DSR model implementation.""" +"""Anomaly generator for the DSR model implementation. + +This module implements an anomaly generator that creates synthetic anomalies +using Perlin noise. The generator is used during the second phase of DSR model +training to create anomalous samples. + +Example: + >>> from anomalib.models.image.dsr.anomaly_generator import DsrAnomalyGenerator + >>> generator = DsrAnomalyGenerator(p_anomalous=0.5) + >>> batch = torch.randn(8, 3, 256, 256) + >>> masks = generator.augment_batch(batch) +""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -11,14 +22,21 @@ class DsrAnomalyGenerator(nn.Module): - """Anomaly generator of the DSR model. + """Anomaly generator for the DSR model. - The anomaly is generated using a Perlin noise generator on the two quantized representations of an image. - This generator is only used during the second phase of training! The third phase requires generating - smudges over the input images. + The generator creates synthetic anomalies by applying Perlin noise to images.
+ It is used during the second phase of DSR model training. The third phase + uses a different approach with smudge-based anomalies. Args: - p_anomalous (float, optional): Probability to generate an anomalous image. + p_anomalous (float, optional): Probability of generating an anomalous + image. Defaults to ``0.5``. + + Example: + >>> generator = DsrAnomalyGenerator(p_anomalous=0.7) + >>> batch = torch.randn(4, 3, 256, 256) + >>> masks = generator.augment_batch(batch) + >>> assert masks.shape == (4, 1, 256, 256) """ def __init__( @@ -32,14 +50,21 @@ def __init__( self.rot = v2.RandomAffine(degrees=(-90, 90)) def generate_anomaly(self, height: int, width: int) -> Tensor: - """Generate an anomalous mask. + """Generate an anomalous mask using Perlin noise. Args: - height (int): Height of generated mask. - width (int): Width of generated mask. + height (int): Height of the mask to generate. + width (int): Width of the mask to generate. Returns: - Tensor: Generated mask. + Tensor: Binary mask of shape ``(1, height, width)`` where ``1`` + indicates anomalous regions. + + Example: + >>> generator = DsrAnomalyGenerator() + >>> mask = generator.generate_anomaly(256, 256) + >>> assert mask.shape == (1, 256, 256) + >>> assert torch.all((mask >= 0) & (mask <= 1)) """ min_perlin_scale = 0 perlin_scale = 6 @@ -59,13 +84,23 @@ def generate_anomaly(self, height: int, width: int) -> Tensor: return mask.unsqueeze(0) # Add channel dimension [1, H, W] def augment_batch(self, batch: Tensor) -> Tensor: - """Generate anomalous augmentations for a batch of input images. + """Generate anomalous masks for a batch of images. Args: - batch (Tensor): Batch of input images + batch (Tensor): Input batch of images of shape + ``(batch_size, channels, height, width)``. Returns: - Tensor: Ground truth masks corresponding to the anomalous perturbations. + Tensor: Batch of binary masks of shape + ``(batch_size, 1, height, width)`` where ``1`` indicates + anomalous regions. + + Example: + >>> generator = DsrAnomalyGenerator() + >>> batch = torch.randn(8, 3, 256, 256) + >>> masks = generator.augment_batch(batch) + >>> assert masks.shape == (8, 1, 256, 256) + >>> assert torch.all((masks >= 0) & (masks <= 1)) """ batch_size, _, height, width = batch.shape diff --git a/src/anomalib/models/image/dsr/lightning_model.py b/src/anomalib/models/image/dsr/lightning_model.py index dd80e88ba7..86392b76a9 100644 --- a/src/anomalib/models/image/dsr/lightning_model.py +++ b/src/anomalib/models/image/dsr/lightning_model.py @@ -1,6 +1,33 @@ """DSR - A Dual Subspace Re-Projection Network for Surface Anomaly Detection. -Paper https://link.springer.com/chapter/10.1007/978-3-031-19821-2_31 +This module implements the DSR model for surface anomaly detection. DSR uses a dual +subspace re-projection approach to detect anomalies by comparing input images with +their reconstructions in two different subspaces. + +The model consists of three training phases: +1. A discrete model pre-training phase (using pre-trained weights) +2. Training of the main reconstruction and anomaly detection modules +3. Training of the upsampling module + +Paper: https://link.springer.com/chapter/10.1007/978-3-031-19821-2_31 + +Example: + >>> from anomalib.models.image import Dsr + >>> model = Dsr( + ... latent_anomaly_strength=0.2, + ... upsampling_train_ratio=0.7 + ... ) + +The model can be used with any of the supported datasets and task modes in +anomalib. 
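As a rough illustration of the thresholded-noise idea behind `generate_anomaly` above: the real generator draws Perlin noise at a random scale and applies a random rotation; the sketch below substitutes plain upsampled Gaussian noise, and the scale and threshold values are assumptions chosen only to keep the example short.

```python
# Hedged sketch: sample low-resolution noise, upsample to image size, and
# threshold to obtain a binary anomaly mask of shape (1, H, W).
import torch
import torch.nn.functional as F

def make_mask(height: int, width: int, scale: int = 8, threshold: float = 0.5) -> torch.Tensor:
    low_res = torch.randn(1, 1, scale, scale)
    noise = F.interpolate(low_res, size=(height, width), mode="bilinear", align_corners=False)
    noise = (noise - noise.min()) / (noise.max() - noise.min())  # rescale to [0, 1]
    return (noise > threshold).float().squeeze(0)

mask = make_mask(256, 256)
assert mask.shape == (1, 256, 256)
```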
+ +Notes: + The model requires pre-trained weights for the discrete model which are + downloaded automatically during training. + +See Also: + :class:`anomalib.models.image.dsr.torch_model.DsrModel`: + PyTorch implementation of the DSR model architecture. """ # Copyright (C) 2023-2024 Intel Corporation @@ -33,7 +60,8 @@ WEIGHTS_DOWNLOAD_INFO = DownloadInfo( name="vq_model_pretrained_128_4096.pckl", - url="https://github.com/openvinotoolkit/anomalib/releases/download/dsr_pretrained_weights/dsr_vq_model_pretrained.zip", + url="https://github.com/openvinotoolkit/anomalib/releases/download/" + "dsr_pretrained_weights/dsr_vq_model_pretrained.zip", hashsum="52fe7504ec8e9df70b4382f287ab26269dcfe000cd7a7e146a52c6f146f34afb", ) @@ -41,12 +69,33 @@ class Dsr(AnomalibModule): """DSR: A Dual Subspace Re-Projection Network for Surface Anomaly Detection. + The model uses a dual subspace approach with three training phases: + 1. Pre-trained discrete model (loaded from weights) + 2. Training of reconstruction and anomaly detection modules + 3. Training of the upsampling module for final anomaly map generation + Args: - latent_anomaly_strength (float): Strength of the generated anomalies in the latent space. Defaults to 0.2 - upsampling_train_ratio (float): Ratio of training steps for the upsampling module. Defaults to 0.7 - pre_processor (PreProcessor, optional): Pre-processor for the model. - This is used to pre-process the input data before it is passed to the model. - Defaults to ``None``. + latent_anomaly_strength (float, optional): Strength of the generated + anomalies in the latent space. Defaults to ``0.2``. + upsampling_train_ratio (float, optional): Ratio of training steps for + the upsampling module. Defaults to ``0.7``. + pre_processor (PreProcessor | bool, optional): Pre-processor instance or + flag to use default. Defaults to ``True``. + post_processor (PostProcessor | bool, optional): Post-processor instance + or flag to use default. Defaults to ``True``. + evaluator (Evaluator | bool, optional): Evaluator instance or flag to + use default. Defaults to ``True``. + visualizer (Visualizer | bool, optional): Visualizer instance or flag to + use default. Defaults to ``True``. + + Example: + >>> from anomalib.models.image import Dsr + >>> model = Dsr( + ... latent_anomaly_strength=0.2, + ... upsampling_train_ratio=0.7 + ... ) + >>> model.trainer_arguments + {'num_sanity_val_steps': 0} """ def __init__( @@ -78,7 +127,17 @@ def __init__( @staticmethod def prepare_pretrained_model() -> Path: - """Download pre-trained models if they don't exist.""" + """Download pre-trained models if they don't exist. + + Returns: + Path: Path to the downloaded pre-trained model weights. + + Example: + >>> model = Dsr() + >>> weights_path = model.prepare_pretrained_model() + >>> weights_path.name + 'vq_model_pretrained_128_4096.pckl' + """ pretrained_models_dir = Path("./pre_trained/") if not (pretrained_models_dir / "vq_model_pretrained_128_4096.pckl").is_file(): download_and_extract(pretrained_models_dir, WEIGHTS_DOWNLOAD_INFO) @@ -92,7 +151,16 @@ def configure_optimizers( Does not train the discrete model (phase 1) Returns: - dict[str, torch.optim.Optimizer | torch.optim.lr_scheduler.LRScheduler]: Dictionary of optimizers + dict[str, torch.optim.Optimizer | torch.optim.lr_scheduler.LRScheduler]: + Dictionary containing optimizers and schedulers. 
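A minimal sketch of the tuple-of-optimizers pattern `configure_optimizers` returns, with placeholder conv layers standing in for DSR's phase-2 networks and phase-3 upsampling module; the learning rate and milestone values here are illustrative, not DSR's actual settings.

```python
# Hedged sketch of a Lightning-style two-optimizer configuration: one
# optimizer (plus scheduler) for the phase-2 modules, one for the phase-3
# upsampling module. The conv layers are placeholders, not DSR's networks.
import torch
from torch import nn

phase2_modules = nn.Conv2d(3, 3, 3, padding=1)
upsampling_module = nn.Conv2d(3, 3, 3, padding=1)

opt_phase2 = torch.optim.Adam(phase2_modules.parameters(), lr=2e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(opt_phase2, milestones=[100], gamma=0.1)
opt_phase3 = torch.optim.Adam(upsampling_module.parameters(), lr=2e-4)

# Returned from configure_optimizers(); training_step later retrieves both
# via self.optimizers(), as in ``ph1_opt, ph2_opt = self.optimizers()``.
optimizers = ({"optimizer": opt_phase2, "lr_scheduler": scheduler}, {"optimizer": opt_phase3})
```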
+ + Example: + >>> model = Dsr() + >>> optimizers = model.configure_optimizers() + >>> isinstance(optimizers, tuple) + True + >>> len(optimizers) + 2 """ num_steps = max( self.trainer.max_steps // len(self.trainer.datamodule.train_dataloader()), @@ -126,19 +194,34 @@ def on_train_epoch_start(self) -> None: def training_step(self, batch: Batch) -> STEP_OUTPUT: """Training Step of DSR. - Feeds the original image and the simulated anomaly mask during first phase. During - second phase, feeds a generated anomalous image to train the upsampling module. + During the first phase, feeds the original image and simulated anomaly + mask. During second phase, feeds a generated anomalous image to train + the upsampling module. Args: - batch (Batch): Batch containing image filename, image, label and mask + batch (Batch): Input batch containing image, label and mask Returns: - STEP_OUTPUT: Loss dictionary + STEP_OUTPUT: Dictionary containing the loss value + + Example: + >>> from anomalib.data import Batch + >>> model = Dsr() + >>> batch = Batch( + ... image=torch.randn(8, 3, 256, 256), + ... label=torch.zeros(8) + ... ) + >>> output = model.training_step(batch) + >>> isinstance(output, dict) + True + >>> "loss" in output + True """ ph1_opt, ph2_opt = self.optimizers() if self.current_epoch < self.second_phase: - # we are not yet training the upsampling module: we are only using the first optimizer + # we are not yet training the upsampling module: we are only using + # the first optimizer input_image = batch.image # Create anomaly masks anomaly_mask = self.quantized_anomaly_generator.augment_batch(input_image) @@ -185,12 +268,23 @@ def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: The Softmax predictions of the anomalous class are used as anomaly map. Args: - batch (Batch): Batch of input images - *args: unused - **kwargs: unused + batch (Batch): Input batch containing image, label and mask + *args: Additional positional arguments (unused) + **kwargs: Additional keyword arguments (unused) Returns: - STEP_OUTPUT: Dictionary to which predicted anomaly maps have been added. + STEP_OUTPUT: Dictionary containing predictions and batch information + + Example: + >>> from anomalib.data import Batch + >>> model = Dsr() + >>> batch = Batch( + ... image=torch.randn(8, 3, 256, 256), + ... label=torch.zeros(8) + ... ) + >>> output = model.validation_step(batch) + >>> isinstance(output, Batch) + True """ del args, kwargs # These variables are not used. @@ -199,7 +293,16 @@ def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: @property def trainer_arguments(self) -> dict[str, Any]: - """Required trainer arguments.""" + """Required trainer arguments. + + Returns: + dict[str, Any]: Dictionary of trainer arguments + + Example: + >>> model = Dsr() + >>> model.trainer_arguments + {'num_sanity_val_steps': 0} + """ return {"num_sanity_val_steps": 0} @property @@ -208,12 +311,33 @@ def learning_type(self) -> LearningType: Returns: LearningType: Learning type of the model. + + Example: + >>> model = Dsr() + >>> model.learning_type + <LearningType.ONE_CLASS: 'one_class'> """ return LearningType.ONE_CLASS @staticmethod def configure_transforms(image_size: tuple[int, int] | None = None) -> Transform: - """Default transform for DSR. Normalization is not needed as the images are scaled to [0, 1] in Dataset.""" + """Configure default transforms for DSR. + + Normalization is not needed as the images are scaled to [0, 1] in Dataset. + + Args: + image_size (tuple[int, int] | None, optional): Input image size.
+ Defaults to ``(256, 256)``. + + Returns: + Transform: Composed transforms + + Example: + >>> model = Dsr() + >>> transforms = model.configure_transforms((512, 512)) + >>> isinstance(transforms, Transform) + True + """ image_size = image_size or (256, 256) return Compose( [ diff --git a/src/anomalib/models/image/dsr/loss.py b/src/anomalib/models/image/dsr/loss.py index f1020b9d34..07a9a14578 100644 --- a/src/anomalib/models/image/dsr/loss.py +++ b/src/anomalib/models/image/dsr/loss.py @@ -1,4 +1,22 @@ -"""Loss function for the DSR model implementation.""" +"""Loss functions for the DSR model implementation. + +This module contains the loss functions used in the second and third training +phases of the DSR model. + +Example: + >>> from anomalib.models.image.dsr.loss import DsrSecondStageLoss + >>> loss_fn = DsrSecondStageLoss() + >>> loss = loss_fn( + ... recon_nq_hi=recon_nq_hi, + ... recon_nq_lo=recon_nq_lo, + ... qu_hi=qu_hi, + ... qu_lo=qu_lo, + ... input_image=input_image, + ... gen_img=gen_img, + ... seg=seg, + ... anomaly_mask=anomaly_mask + ... ) +""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -8,13 +26,27 @@ class DsrSecondStageLoss(nn.Module): - """Overall loss function of the second training phase of the DSR model. - - The total loss consists of: - - MSE loss between non-anomalous quantized input image and anomalous subspace-reconstructed - non-quantized input (hi and lo) - - MSE loss between input image and reconstructed image through object-specific decoder, - - Focal loss between computed segmentation mask and ground truth mask. + """Loss function for the second training phase of the DSR model. + + The total loss is a combination of: + - MSE loss between non-anomalous quantized input image and anomalous + subspace-reconstructed non-quantized input (hi and lo features) + - MSE loss between input image and reconstructed image through + object-specific decoder + - Focal loss between computed segmentation mask and ground truth mask + + Example: + >>> loss_fn = DsrSecondStageLoss() + >>> loss = loss_fn( + ... recon_nq_hi=recon_nq_hi, + ... recon_nq_lo=recon_nq_lo, + ... qu_hi=qu_hi, + ... qu_lo=qu_lo, + ... input_image=input_image, + ... gen_img=gen_img, + ... seg=seg, + ... anomaly_mask=anomaly_mask + ... ) """ def __init__(self) -> None: @@ -34,20 +66,33 @@ def forward( seg: Tensor, anomaly_mask: Tensor, ) -> Tensor: - """Compute the loss over a batch for the DSR model. + """Compute the combined loss over a batch. Args: recon_nq_hi (Tensor): Reconstructed non-quantized hi feature recon_nq_lo (Tensor): Reconstructed non-quantized lo feature qu_hi (Tensor): Non-defective quantized hi feature qu_lo (Tensor): Non-defective quantized lo feature - input_image (Tensor): Original image + input_image (Tensor): Original input image gen_img (Tensor): Object-specific decoded image - seg (Tensor): Computed anomaly map - anomaly_mask (Tensor): Ground truth anomaly map + seg (Tensor): Computed anomaly segmentation map + anomaly_mask (Tensor): Ground truth anomaly mask Returns: - Tensor: Total loss + Tensor: Total combined loss value + + Example: + >>> loss_fn = DsrSecondStageLoss() + >>> loss = loss_fn( + ... recon_nq_hi=torch.randn(32, 64, 32, 32), + ... recon_nq_lo=torch.randn(32, 64, 32, 32), + ... qu_hi=torch.randn(32, 64, 32, 32), + ... qu_lo=torch.randn(32, 64, 32, 32), + ... input_image=torch.randn(32, 3, 256, 256), + ... gen_img=torch.randn(32, 3, 256, 256), + ... seg=torch.randn(32, 2, 256, 256), + ... 
anomaly_mask=torch.randint(0, 2, (32, 1, 256, 256)) + ... ) """ l2_loss_hi_val = self.l2_loss(recon_nq_hi, qu_hi) l2_loss_lo_val = self.l2_loss(recon_nq_lo, qu_lo) @@ -57,9 +102,17 @@ class DsrThirdStageLoss(nn.Module): - """Overall loss function of the third training phase of the DSR model. + """Loss function for the third training phase of the DSR model. + + The loss consists of a focal loss between the computed segmentation mask + and the ground truth mask. - The loss consists of a focal loss between the computed segmentation mask and the ground truth mask. + Example: + >>> loss_fn = DsrThirdStageLoss() + >>> loss = loss_fn( + ... pred_mask=pred_mask, + ... true_mask=true_mask + ... ) """ def __init__(self) -> None: @@ -68,13 +121,20 @@ def __init__(self) -> None: self.focal_loss = FocalLoss(alpha=1, reduction="mean") def forward(self, pred_mask: Tensor, true_mask: Tensor) -> Tensor: - """Compute the loss over a batch for the DSR model. + """Compute the focal loss between predicted and true masks. Args: - pred_mask (Tensor): Computed anomaly map - true_mask (Tensor): Ground truth anomaly map + pred_mask (Tensor): Computed anomaly segmentation map + true_mask (Tensor): Ground truth anomaly mask Returns: - Tensor: Total loss + Tensor: Focal loss value + + Example: + >>> loss_fn = DsrThirdStageLoss() + >>> loss = loss_fn( + ... pred_mask=torch.randn(32, 2, 256, 256), + ... true_mask=torch.randint(0, 2, (32, 1, 256, 256)) + ... ) """ return self.focal_loss(pred_mask, true_mask.squeeze(1).long()) diff --git a/src/anomalib/models/image/dsr/torch_model.py b/src/anomalib/models/image/dsr/torch_model.py index 4fe036ea5c..a55fb6cd27 100644 --- a/src/anomalib/models/image/dsr/torch_model.py +++ b/src/anomalib/models/image/dsr/torch_model.py @@ -1,4 +1,30 @@ -"""PyTorch model for the DSR model implementation.""" +"""PyTorch model for the DSR model implementation. + +This module implements the PyTorch model for Dual Subspace Re-Projection (DSR). +DSR is an anomaly detection model that uses a discrete latent model, image +reconstruction network, subspace restriction modules, anomaly detection module +and upsampling module to detect anomalies in images. + +The model works by: +1. Encoding input images into quantized feature maps +2. Reconstructing images using a general appearance decoder +3. Detecting anomalies by comparing reconstructed and original images + +Example: + >>> from anomalib.models.image.dsr.torch_model import DsrModel + >>> model = DsrModel() + >>> input_tensor = torch.randn(32, 3, 256, 256) + >>> output = model(input_tensor) + >>> output["anomaly_map"].shape + torch.Size([32, 256, 256]) + +Notes: + The model implementation is based on the original DSR paper and code. + Original code: https://github.com/VitjanZ/DSR_anomaly_detection + +References: + - Original paper: https://link.springer.com/chapter/10.1007/978-3-031-19821-2_31 +""" # Original Code # Copyright (c) 2022 VitjanZ @@ -26,12 +52,24 @@ class DsrModel(nn.Module): subspace restriction modules, anomaly detection module and upsampling module. Args: - embedding_dim (int): Dimension of codebook embeddings. - num_embeddings (int): Number of embeddings. - latent_anomaly_strength (float): Strength of the generated anomalies in the latent space. + embedding_dim (int): Dimension of codebook embeddings. Defaults to + ``128``. + num_embeddings (int): Number of embeddings in codebook. Defaults to + ``4096``. + latent_anomaly_strength (float): Strength of the generated anomalies in + latent space. Defaults to ``0.2``.
num_hiddens (int): Number of output channels in residual layers. - num_residual_layers (int): Number of residual layers. - num_residual_hiddens (int): Number of intermediate channels. + Defaults to ``128``. + num_residual_layers (int): Number of residual layers. Defaults to ``2``. + num_residual_hiddens (int): Number of intermediate channels in residual + layers. Defaults to ``64``. + + Example: + >>> model = DsrModel() + >>> input_tensor = torch.randn(32, 3, 256, 256) + >>> output = model(input_tensor) + >>> output["anomaly_map"].shape + torch.Size([32, 256, 256]) """ def __init__( @@ -83,7 +121,13 @@ def __init__( parameters.requires_grad = False def load_pretrained_discrete_model_weights(self, ckpt: Path, device: torch.device | str | None = None) -> None: - """Load pre-trained model weights.""" + """Load pre-trained model weights from checkpoint file. + + Args: + ckpt (Path): Path to checkpoint file containing model weights. + device (torch.device | str | None, optional): Device to load weights + to. Defaults to ``None``. + """ self.discrete_latent_model.load_state_dict(torch.load(ckpt, map_location=device)) def forward( @@ -91,28 +135,47 @@ def forward( batch: torch.Tensor, anomaly_map_to_generate: torch.Tensor | None = None, ) -> dict[str, torch.Tensor] | InferenceBatch: - """Compute the anomaly mask from an input image. + """Forward pass through the model. Args: - batch (torch.Tensor): Batch of input images. - anomaly_map_to_generate (torch.Tensor | None): anomaly map to use to generate quantized defects. - If not training phase 2, should be None. + batch (torch.Tensor): Input batch of images. + anomaly_map_to_generate (torch.Tensor | None, optional): Anomaly map + to use for generating quantized defects. Should be ``None`` if + not in training phase 2. Defaults to ``None``. Returns: - dict[str, torch.Tensor]: + dict[str, torch.Tensor] | InferenceBatch: Output depends on mode: + If testing: - - "anomaly_map": Upsampled anomaly map - - "pred_score": Image score + - ``anomaly_map``: Upsampled anomaly map + - ``pred_score``: Image anomaly score + If training phase 2: - - "recon_feat_hi": Reconstructed non-quantized hi features of defect (F~_hi) - - "recon_feat_lo": Reconstructed non-quantized lo features of defect (F~_lo) - - "embedding_bot": Quantized features of non defective img (Q_hi) - - "embedding_top": Quantized features of non defective img (Q_lo) - - "obj_spec_image": Object-specific-decoded image (I_spc) - - "anomaly_map": Predicted segmentation mask (M) - - "true_mask": Resized ground-truth anomaly map (M_gt) + - ``recon_feat_hi``: Reconstructed non-quantized hi features + (F~_hi) + - ``recon_feat_lo``: Reconstructed non-quantized lo features + (F~_lo) + - ``embedding_bot``: Quantized features of non defective img + (Q_hi) + - ``embedding_top``: Quantized features of non defective img + (Q_lo) + - ``obj_spec_image``: Object-specific-decoded image (I_spc) + - ``anomaly_map``: Predicted segmentation mask (M) + - ``true_mask``: Resized ground-truth anomaly map (M_gt) + If training phase 3: - - "anomaly_map": Reconstructed anomaly map + - ``anomaly_map``: Reconstructed anomaly map + + Raises: + RuntimeError: If ``anomaly_map_to_generate`` is provided when not in + training mode. 
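A short sketch of using the `load_pretrained_discrete_model_weights` method documented above; the checkpoint filename and directory mirror what `prepare_pretrained_model` downloads earlier in this diff, and the file-existence guard keeps the snippet safe to run before that download has happened.

```python
# Hedged sketch: load the pre-trained discrete (VQ) model weights into a
# DsrModel instance. The path matches prepare_pretrained_model's location.
from pathlib import Path

import torch
from anomalib.models.image.dsr.torch_model import DsrModel

model = DsrModel()
ckpt = Path("./pre_trained/vq_model_pretrained_128_4096.pckl")
if ckpt.is_file():  # weights are otherwise downloaded automatically during training
    model.load_pretrained_discrete_model_weights(ckpt, device=torch.device("cpu"))
```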
+ + Example: + >>> model = DsrModel() + >>> input_tensor = torch.randn(32, 3, 256, 256) + >>> output = model(input_tensor) + >>> output["anomaly_map"].shape + torch.Size([32, 256, 256]) """ # Generate latent embeddings decoded image via general object decoder if anomaly_map_to_generate is None: @@ -127,7 +190,8 @@ def forward( embedder_bot = self.discrete_latent_model.vq_vae_bot embedder_top = self.discrete_latent_model.vq_vae_top - # Copy embeddings in order to input them to the subspace restriction module + # Copy embeddings in order to input them to the subspace + # restriction module anomaly_embedding_bot_copy = embd_bot.clone() anomaly_embedding_top_copy = embd_top.clone() @@ -138,7 +202,8 @@ def forward( # Upscale top (lo) embedding up_quantized_recon_t = self.discrete_latent_model.upsample_t(recon_embd_top) - # Concat embeddings and reconstruct image (object specific decoder) + # Concat embeddings and reconstruct image (object specific + # decoder) quant_join = torch.cat((up_quantized_recon_t, recon_embd_bot), dim=1) obj_spec_image = self.image_reconstruction_network(quant_join) @@ -181,7 +246,8 @@ def forward( torch.rand(batch.shape[0]) * (1.0 - self.latent_anomaly_strength) + self.latent_anomaly_strength ).cuda() - # Generate image through general object decoder, and defective & non defective quantized feature maps. + # Generate image through general object decoder, and defective & non + # defective quantized feature maps. with torch.no_grad(): latent_model_outputs = self.discrete_latent_model( batch, @@ -196,7 +262,8 @@ def forward( embd_top_def = latent_model_outputs["anomaly_embedding_lo"] embd_bot_def = latent_model_outputs["anomaly_embedding_hi"] - # Restore the features to normality with the Subspace restriction modules + # Restore the features to normality with the Subspace restriction + # modules recon_feat_hi, recon_embeddings_hi = self.subspace_restriction_module_hi( embd_bot_def, self.discrete_latent_model.vq_vae_bot, diff --git a/src/anomalib/models/image/efficient_ad/__init__.py b/src/anomalib/models/image/efficient_ad/__init__.py index d8b6f5f2b0..295ec61762 100644 --- a/src/anomalib/models/image/efficient_ad/__init__.py +++ b/src/anomalib/models/image/efficient_ad/__init__.py @@ -1,6 +1,17 @@ """EfficientAd: Accurate Visual Anomaly Detection at Millisecond-Level Latencies. -https://arxiv.org/pdf/2303.14535.pdf. +EfficientAd is a fast and accurate anomaly detection model that achieves +state-of-the-art performance with millisecond-level inference times. The model +utilizes a pre-trained EfficientNet backbone and employs a student-teacher +architecture for anomaly detection. + +The implementation is based on the paper: + "EfficientAd: Accurate Visual Anomaly Detection at Millisecond-Level Latencies" + https://arxiv.org/pdf/2303.14535.pdf + +Example: + >>> from anomalib.models import EfficientAd + >>> model = EfficientAd() """ # Copyright (C) 2023-2024 Intel Corporation diff --git a/src/anomalib/models/image/efficient_ad/lightning_model.py b/src/anomalib/models/image/efficient_ad/lightning_model.py index aa99d6a439..a75f889ec3 100644 --- a/src/anomalib/models/image/efficient_ad/lightning_model.py +++ b/src/anomalib/models/image/efficient_ad/lightning_model.py @@ -1,6 +1,36 @@ """EfficientAd: Accurate Visual Anomaly Detection at Millisecond-Level Latencies. -https://arxiv.org/pdf/2303.14535.pdf. +This module implements the EfficientAd model for fast and accurate anomaly +detection. 
EfficientAd uses a student-teacher architecture with a pre-trained +EfficientNet backbone to achieve state-of-the-art performance with +millisecond-level inference times. + +The model consists of: + - A pre-trained EfficientNet teacher network + - A lightweight student network + - Knowledge distillation training + - Anomaly detection via feature comparison + +Example: + >>> from anomalib.data import MVTec + >>> from anomalib.models import EfficientAd + >>> from anomalib.engine import Engine + + >>> datamodule = MVTec() + >>> model = EfficientAd() + >>> engine = Engine() + + >>> engine.fit(model, datamodule=datamodule) # doctest: +SKIP + >>> predictions = engine.predict(model, datamodule=datamodule) # doctest: +SKIP + +Paper: + "EfficientAd: Accurate Visual Anomaly Detection at + Millisecond-Level Latencies" + https://arxiv.org/pdf/2303.14535.pdf + +See Also: + :class:`anomalib.models.image.efficient_ad.torch_model.EfficientAdModel`: + PyTorch implementation of the EfficientAd model architecture. """ # Copyright (C) 2023-2024 Intel Corporation @@ -46,25 +76,45 @@ class EfficientAd(AnomalibModule): """PL Lightning Module for the EfficientAd algorithm. + The EfficientAd model uses a student-teacher architecture with a pretrained + EfficientNet backbone for fast and accurate anomaly detection. + Args: - imagenet_dir (Path|str): directory path for the Imagenet dataset - Defaults to ``./datasets/imagenette``. - teacher_out_channels (int): number of convolution output channels + imagenet_dir (Path | str): Directory path for the Imagenet dataset. + Defaults to ``"./datasets/imagenette"``. + teacher_out_channels (int): Number of convolution output channels. Defaults to ``384``. - model_size (str): size of student and teacher model + model_size (EfficientAdModelSize | str): Size of student and teacher model. Defaults to ``EfficientAdModelSize.S``. - lr (float): learning rate + lr (float): Learning rate. Defaults to ``0.0001``. - weight_decay (float): optimizer weight decay + weight_decay (float): Optimizer weight decay. Defaults to ``0.00001``. - padding (bool): use padding in convoluional layers + padding (bool): Use padding in convolutional layers. Defaults to ``False``. - pad_maps (bool): relevant if padding is set to False. In this case, pad_maps = True pads the - output anomaly maps so that their size matches the size in the padding = True case. + pad_maps (bool): Relevant if ``padding=False``. If ``True``, pads the output + anomaly maps to match size of ``padding=True`` case. + Defaults to ``True``. + pre_processor (PreProcessor | bool, optional): Pre-processor used to transform + input data before passing to model. + Defaults to ``True``. + post_processor (PostProcessor | bool, optional): Post-processor used to process + model predictions. + Defaults to ``True``. + evaluator (Evaluator | bool, optional): Evaluator used to compute metrics. Defaults to ``True``. - pre_processor (PreProcessor, optional): Pre-processor for the model. - This is used to pre-process the input data before it is passed to the model. - Defaults to ``None``. + visualizer (Visualizer | bool, optional): Visualizer used to create + visualizations. + Defaults to ``True``. + + Example: + >>> from anomalib.models import EfficientAd + >>> model = EfficientAd( + ... imagenet_dir="./datasets/imagenette", + ... model_size="s", + ... lr=1e-4 + ... 
) + """ def __init__( @@ -103,7 +153,11 @@ def __init__( self.weight_decay: float = weight_decay def prepare_pretrained_model(self) -> None: - """Prepare the pretrained teacher model.""" + """Prepare the pretrained teacher model. + + Downloads and loads pretrained weights for the teacher model if not already + present. + """ pretrained_models_dir = Path("./pre_trained/") if not (pretrained_models_dir / "efficientad_pretrained_weights").is_dir(): download_and_extract(pretrained_models_dir, WEIGHTS_DOWNLOAD_INFO) @@ -117,8 +171,11 @@ def prepare_pretrained_model(self) -> None: def prepare_imagenette_data(self, image_size: tuple[int, int] | torch.Size) -> None: """Prepare ImageNette dataset transformations. + Sets up data transforms and downloads ImageNette dataset if not present. + Args: - image_size (tuple[int, int] | torch.Size): Image size. + image_size (tuple[int, int] | torch.Size): Target image size for + transforms. """ self.data_transforms_imagenet = Compose( [ @@ -137,15 +194,22 @@ def prepare_imagenette_data(self, image_size: tuple[int, int] | torch.Size) -> N @torch.no_grad() def teacher_channel_mean_std(self, dataloader: DataLoader) -> dict[str, torch.Tensor]: - """Calculate the mean and std of the teacher models activations. + """Calculate channel-wise mean and std of teacher model activations. - Adapted from https://math.stackexchange.com/a/2148949 + Computes running mean and standard deviation of teacher model feature maps + over the full dataset. Args: - dataloader (DataLoader): Dataloader of the respective dataset. + dataloader (DataLoader): Dataloader for the dataset. Returns: - dict[str, torch.Tensor]: Dictionary of channel-wise mean and std + dict[str, torch.Tensor]: Dictionary containing: + - ``mean``: Channel-wise means of shape ``(1, C, 1, 1)`` + - ``std``: Channel-wise standard deviations of shape + ``(1, C, 1, 1)`` + + Raises: + ValueError: If no data is provided (``n`` remains ``None``). """ arrays_defined = False n: torch.Tensor | None = None @@ -178,14 +242,20 @@ def teacher_channel_mean_std(self, dataloader: DataLoader) -> dict[str, torch.Te @torch.no_grad() def map_norm_quantiles(self, dataloader: DataLoader) -> dict[str, torch.Tensor]: - """Calculate 90% and 99.5% quantiles of the student(st) and autoencoder(ae). + """Calculate quantiles of student and autoencoder feature maps. + + Computes the 90% and 99.5% quantiles of the feature maps from both the + student network and autoencoder on normal (good) validation samples. Args: - dataloader (DataLoader): Dataloader of the respective dataset. + dataloader (DataLoader): Validation dataloader. Returns: - dict[str, torch.Tensor]: Dictionary of both the 90% and 99.5% quantiles - of both the student and autoencoder feature maps. + dict[str, torch.Tensor]: Dictionary containing: + - ``qa_st``: 90% quantile of student maps + - ``qa_ae``: 90% quantile of autoencoder maps + - ``qb_st``: 99.5% quantile of student maps + - ``qb_ae``: 99.5% quantile of autoencoder maps """ maps_st = [] maps_ae = [] @@ -202,17 +272,18 @@ def map_norm_quantiles(self, dataloader: DataLoader) -> dict[str, torch.Tensor]: return {"qa_st": qa_st, "qa_ae": qa_ae, "qb_st": qb_st, "qb_ae": qb_ae} def _get_quantiles_of_maps(self, maps: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: - """Calculate 90% and 99.5% quantiles of the given anomaly maps. + """Calculate quantiles of anomaly maps. 
- If the total number of elements in the given maps is larger than 16777216 - the returned quantiles are computed on a random subset of the given - elements. + Computes the 90% and 99.5% quantiles of the given anomaly maps. If total + number of elements exceeds 16777216, uses a random subset. Args: maps (list[torch.Tensor]): List of anomaly maps. Returns: - tuple[torch.Tensor, torch.Tensor]: Two scalars - the 90% and the 99.5% quantile. + tuple[torch.Tensor, torch.Tensor]: Tuple containing: + - 90% quantile scalar + - 99.5% quantile scalar """ maps_flat = reduce_tensor_elems(torch.cat(maps)) qa = torch.quantile(maps_flat, q=0.9).to(self.device) @@ -221,13 +292,35 @@ def _get_quantiles_of_maps(self, maps: list[torch.Tensor]) -> tuple[torch.Tensor @classmethod def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> PreProcessor: - """Default transform for EfficientAd. Imagenet normalization applied in forward.""" + """Configure default pre-processor for EfficientAd. + + Note that ImageNet normalization is applied in the forward pass, not here. + + Args: + image_size (tuple[int, int] | None, optional): Target image size. + Defaults to ``(256, 256)``. + + Returns: + PreProcessor: Configured pre-processor with resize transform. + """ image_size = image_size or (256, 256) transform = Compose([Resize(image_size, antialias=True)]) return PreProcessor(transform=transform) def configure_optimizers(self) -> torch.optim.Optimizer: - """Configure optimizers.""" + """Configure optimizers for training. + + Sets up Adam optimizer with learning rate scheduler that decays LR by 0.1 + at 95% of training. + + Returns: + dict: Dictionary containing: + - ``optimizer``: Adam optimizer + - ``lr_scheduler``: StepLR scheduler + + Raises: + ValueError: If neither ``max_epochs`` nor ``max_steps`` is defined. + """ optimizer = torch.optim.Adam( list(self.model.student.parameters()) + list(self.model.ae.parameters()), lr=self.lr, @@ -256,12 +349,17 @@ def configure_optimizers(self) -> torch.optim.Optimizer: return {"optimizer": optimizer, "lr_scheduler": scheduler} def on_train_start(self) -> None: - """Called before the first training epoch. + """Set up model before training begins. + + Performs the following steps: + 1. Validates training parameters (batch size=1, no normalization) + 2. Sets up pretrained teacher model + 3. Prepares ImageNette dataset + 4. Calculates channel statistics - First check if EfficientAd-specific parameters are set correctly (train_batch_size of 1 - and no Imagenet normalization in transforms), then sets up the pretrained teacher model, - then prepares the imagenette data, and finally calculates or loads - the channel-wise mean and std of the training dataset and push to the model. + Raises: + ValueError: If ``train_batch_size != 1`` or transforms contain + normalization. """ if self.trainer.datamodule.train_batch_size != 1: msg = "train_batch_size for EfficientAd should be 1." @@ -282,15 +380,18 @@ def on_train_start(self) -> None: self.model.mean_std.update(channel_mean_std) def training_step(self, batch: Batch, *args, **kwargs) -> dict[str, torch.Tensor]: - """Perform the training step for EfficientAd returns the student, autoencoder and combined loss. + """Perform training step. + + Computes student, autoencoder and combined losses using both the input + batch and a batch from ImageNette. Args: - batch (Batch): Batch containing image filename, image, label and mask - args: Additional arguments. - kwargs: Additional keyword arguments. 
+ batch (Batch): Input batch containing image and labels + *args: Additional arguments (unused) + **kwargs: Additional keyword arguments (unused) Returns: - Loss. + dict[str, torch.Tensor]: Dictionary containing total loss """ del args, kwargs # These variables are not used. @@ -311,20 +412,25 @@ def training_step(self, batch: Batch, *args, **kwargs) -> dict[str, torch.Tensor return {"loss": loss} def on_validation_start(self) -> None: - """Calculate the feature map quantiles of the validation dataset and push to the model.""" + """Calculate feature map statistics before validation. + + Computes quantiles of feature maps on validation set and updates model. + """ map_norm_quantiles = self.map_norm_quantiles(self.trainer.datamodule.val_dataloader()) self.model.quantiles.update(map_norm_quantiles) def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: - """Perform the validation step of EfficientAd returns anomaly maps for the input image batch. + """Perform validation step. + + Generates anomaly maps for the input batch. Args: - batch (Batch): Input batch - args: Additional arguments. - kwargs: Additional keyword arguments. + batch (Batch): Input batch + *args: Additional arguments (unused) + **kwargs: Additional keyword arguments (unused) Returns: - Dictionary containing anomaly maps. + STEP_OUTPUT: Batch with added predictions """ del args, kwargs # These variables are not used. @@ -333,14 +439,19 @@ def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: @property def trainer_arguments(self) -> dict[str, Any]: - """Return EfficientAD trainer arguments.""" + """Get trainer arguments. + + Returns: + dict[str, Any]: Dictionary with trainer arguments: + - ``num_sanity_val_steps``: 0 + """ return {"num_sanity_val_steps": 0} @property def learning_type(self) -> LearningType: - """Return the learning type of the model. + """Get model's learning type. Returns: - LearningType: Learning type of the model. + LearningType: Always ``LearningType.ONE_CLASS`` """ return LearningType.ONE_CLASS diff --git a/src/anomalib/models/image/efficient_ad/torch_model.py b/src/anomalib/models/image/efficient_ad/torch_model.py index 74f2a507bb..e053736614 100644 --- a/src/anomalib/models/image/efficient_ad/torch_model.py +++ b/src/anomalib/models/image/efficient_ad/torch_model.py @@ -1,4 +1,31 @@ -"""Torch model for student, teacher and autoencoder model in EfficientAd.""" +"""PyTorch implementation of the EfficientAd model architecture. + +This module contains the PyTorch implementation of the student, teacher and +autoencoder networks used in EfficientAd for fast and accurate anomaly detection. + +The model consists of: + - A pre-trained EfficientNet teacher network + - A lightweight student network + - Knowledge distillation training + - Anomaly detection via feature comparison + +Example: + >>> from anomalib.models.image.efficient_ad.torch_model import EfficientAdModel + >>> model = EfficientAdModel() + >>> input_tensor = torch.randn(32, 3, 256, 256) + >>> output = model(input_tensor) + >>> output["anomaly_map"].shape + torch.Size([32, 256, 256]) + +Paper: + "EfficientAd: Accurate Visual Anomaly Detection at + Millisecond-Level Latencies" + https://arxiv.org/pdf/2303.14535.pdf + +See Also: + :class:`anomalib.models.image.efficient_ad.lightning_model.EfficientAd`: + Lightning implementation of the EfficientAd model. 
+""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -19,13 +46,22 @@ def imagenet_norm_batch(x: torch.Tensor) -> torch.Tensor: - """Normalize batch of images with ImageNet mean and std. + """Normalize batch of images using ImageNet mean and standard deviation. + + This function normalizes a batch of images using the standard ImageNet mean and + standard deviation values. The normalization is done channel-wise. Args: - x (torch.Tensor): Input batch. + x (torch.Tensor): Input batch tensor of shape ``(N, C, H, W)`` where + ``N`` is batch size, ``C`` is number of channels (3 for RGB), + ``H`` is height and ``W`` is width. Returns: - torch.Tensor: Normalized batch using the ImageNet mean and std. + torch.Tensor: Normalized batch tensor with same shape as input, where each + channel is normalized using ImageNet statistics: + - Red channel: mean=0.485, std=0.229 + - Green channel: mean=0.456, std=0.224 + - Blue channel: mean=0.406, std=0.225 """ mean = torch.tensor([0.485, 0.456, 0.406])[None, :, None, None].to(x.device) std = torch.tensor([0.229, 0.224, 0.225])[None, :, None, None].to(x.device) @@ -33,21 +69,32 @@ def imagenet_norm_batch(x: torch.Tensor) -> torch.Tensor: def reduce_tensor_elems(tensor: torch.Tensor, m: int = 2**24) -> torch.Tensor: - """Reduce tensor elements. + """Reduce the number of elements in a tensor by random sampling. + + This function flattens an n-dimensional tensor and randomly samples at most ``m`` + elements from it. This is used to handle the limitation of ``torch.quantile`` + operation which supports a maximum of 2^24 elements. - This function flatten n-dimensional tensors, selects m elements from it - and returns the selected elements as tensor. It is used to select - at most 2**24 for torch.quantile operation, as it is the maximum - supported number of elements. - https://github.com/pytorch/pytorch/blob/b9f81a483a7879cd3709fd26bcec5f1ee33577e6/aten/src/ATen/native/Sorting.cpp#L291. + Reference: + https://github.com/pytorch/pytorch/blob/b9f81a483a7879cd3709fd26bcec5f1ee33577e6/aten/src/ATen/native/Sorting.cpp#L291 Args: - tensor (torch.Tensor): input tensor from which elements are selected - m (int): number of maximum tensor elements. - Defaults to ``2**24`` + tensor (torch.Tensor): Input tensor of any shape from which elements will be + sampled. + m (int, optional): Maximum number of elements to sample. If the flattened + tensor has more elements than ``m``, random sampling is performed. + Defaults to ``2**24``. Returns: - Tensor: reduced tensor + torch.Tensor: A flattened tensor containing at most ``m`` elements randomly + sampled from the input tensor. + + Example: + >>> import torch + >>> tensor = torch.randn(1000, 1000) # 1M elements + >>> reduced = reduce_tensor_elems(tensor, m=1000) + >>> reduced.shape + torch.Size([1000]) """ tensor = torch.flatten(tensor) if len(tensor) > m: @@ -59,19 +106,53 @@ def reduce_tensor_elems(tensor: torch.Tensor, m: int = 2**24) -> torch.Tensor: class EfficientAdModelSize(str, Enum): - """Supported EfficientAd model sizes.""" + """Supported EfficientAd model sizes. + + The EfficientAd model comes in two sizes: + - ``M`` (medium): Uses a larger architecture with more parameters + - ``S`` (small): Uses a smaller architecture with fewer parameters + + Example: + >>> from anomalib.models.image.efficient_ad.torch_model import ( + ... EfficientAdModelSize + ... 
) + >>> model_size = EfficientAdModelSize.S + >>> model_size.value + 'small' + >>> model_size = EfficientAdModelSize.M + >>> model_size.value + 'medium' + """ M = "medium" S = "small" class SmallPatchDescriptionNetwork(nn.Module): - """Patch Description Network small. + """Small variant of the Patch Description Network. + + This network processes input images through a series of convolutional and pooling + layers to extract patch-level features. It uses a smaller architecture compared + to the medium variant. Args: - out_channels (int): number of convolution output channels - padding (bool): use padding in convoluional layers + out_channels (int): Number of output channels in the final convolution layer. + padding (bool, optional): Whether to use padding in convolutional layers. Defaults to ``False``. + + Example: + >>> import torch + >>> from anomalib.models.image.efficient_ad.torch_model import ( + ... SmallPatchDescriptionNetwork + ... ) + >>> model = SmallPatchDescriptionNetwork(out_channels=384) + >>> input_tensor = torch.randn(32, 3, 64, 64) + >>> output = model(input_tensor) + >>> output.shape + torch.Size([32, 384, 13, 13]) + + Note: + The network applies ImageNet normalization to the input before processing. """ def __init__(self, out_channels: int, padding: bool = False) -> None: @@ -85,13 +166,15 @@ def __init__(self, out_channels: int, padding: bool = False) -> None: self.avgpool2 = nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult) def forward(self, x: torch.Tensor) -> torch.Tensor: - """Perform a forward pass through the network. + """Forward pass through the network. Args: - x (torch.Tensor): Input batch. + x (torch.Tensor): Input tensor of shape ``(N, 3, H, W)``. Returns: - torch.Tensor: Output from the network. + torch.Tensor: Output feature maps of shape + ``(N, out_channels, H', W')``, where ``H'`` and ``W'`` are + determined by the network architecture and padding settings. """ x = imagenet_norm_batch(x) x = F.relu(self.conv1(x)) @@ -103,12 +186,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MediumPatchDescriptionNetwork(nn.Module): - """Patch Description Network medium. + """Medium-sized patch description network. + + This network processes input images through a series of convolutional and + pooling layers to extract descriptive features from image patches. Args: - out_channels (int): number of convolution output channels - padding (bool): use padding in convoluional layers + out_channels (int): Number of output channels in the final convolution + layer. + padding (bool, optional): Whether to use padding in convolutional layers. Defaults to ``False``. + + Example: + >>> import torch + >>> from anomalib.models.image.efficient_ad.torch_model import ( + ... MediumPatchDescriptionNetwork + ... ) + >>> model = MediumPatchDescriptionNetwork(out_channels=384) + >>> input_tensor = torch.randn(32, 3, 64, 64) + >>> output = model(input_tensor) + >>> output.shape + torch.Size([32, 384, 13, 13]) + + Note: + The network applies ImageNet normalization to the input before + processing. """ def __init__(self, out_channels: int, padding: bool = False) -> None: @@ -124,13 +226,15 @@ def __init__(self, out_channels: int, padding: bool = False) -> None: self.avgpool2 = nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult) def forward(self, x: torch.Tensor) -> torch.Tensor: - """Perform a forward pass through the network. + """Forward pass through the network. Args: - x (torch.Tensor): Input batch. + x (torch.Tensor): Input tensor of shape ``(N, 3, H, W)``.
Returns: - torch.Tensor: Output from the network. + torch.Tensor: Output feature maps of shape + ``(N, out_channels, H', W')``, where ``H'`` and ``W'`` are + determined by the network architecture and padding settings. """ x = imagenet_norm_batch(x) x = F.relu(self.conv1(x)) @@ -144,7 +248,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Encoder(nn.Module): - """Autoencoder Encoder model.""" + """Encoder module for the autoencoder architecture. + + The encoder consists of 6 convolutional layers that progressively reduce the + spatial dimensions while increasing the number of channels. + + Example: + >>> import torch + >>> from anomalib.models.image.efficient_ad.torch_model import Encoder + >>> model = Encoder() + >>> input_tensor = torch.randn(32, 3, 256, 256) + >>> output = model(input_tensor) + >>> output.shape + torch.Size([32, 64, 1, 1]) + + Note: + The encoder uses ReLU activation after each convolutional layer except + the last one. + """ def __init__(self) -> None: super().__init__() @@ -156,13 +277,14 @@ def __init__(self) -> None: self.enconv6 = nn.Conv2d(64, 64, kernel_size=8, stride=1, padding=0) def forward(self, x: torch.Tensor) -> torch.Tensor: - """Perform the forward pass through the network. + """Forward pass through the encoder network. Args: - x (torch.Tensor): Input batch. + x (torch.Tensor): Input tensor of shape ``(N, 3, H, W)``. Returns: - torch.Tensor: Output from the network. + torch.Tensor: Encoded features of shape ``(N, 64, H', W')``, where + ``H'`` and ``W'`` are determined by the network architecture. """ x = F.relu(self.enconv1(x)) x = F.relu(self.enconv2(x)) @@ -173,11 +295,32 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Decoder(nn.Module): - """Autoencoder Decoder model. + """Decoder module for the autoencoder architecture. + + The decoder consists of 8 convolutional layers with upsampling that + progressively increase spatial dimensions while maintaining or reducing + channel dimensions. Args: - out_channels (int): number of convolution output channels - padding (int): use padding in convoluional layers + out_channels (int): Number of output channels in final conv layer. + padding (int): Whether to use padding in convolutional layers. + + Example: + >>> import torch + >>> from anomalib.models.image.efficient_ad.torch_model import Decoder + >>> model = Decoder(out_channels=384, padding=True) + >>> input_tensor = torch.randn(32, 64, 1, 1) + >>> image_size = (256, 256) + >>> output = model(input_tensor, image_size) + >>> output.shape + torch.Size([32, 384, 64, 64]) + + Note: + - Uses ReLU activation and dropout after most convolutional layers + - Performs bilinear upsampling between conv layers to increase spatial + dimensions + - Final output size depends on ``padding`` parameter and input + ``image_size`` """ def __init__(self, out_channels: int, padding: int, *args, **kwargs) -> None: @@ -203,11 +346,14 @@ def forward(self, x: torch.Tensor, image_size: tuple[int, int] | torch.Size) -> """Perform a forward pass through the network. Args: - x (torch.Tensor): Input batch. - image_size (tuple): size of input images. + x (torch.Tensor): Input tensor of shape ``(N, 64, H, W)``. + image_size (tuple[int, int] | torch.Size): Target output size + ``(H, W)``. Returns: - torch.Tensor: Output from the network. + torch.Tensor: Decoded features of shape + ``(N, out_channels, H', W')``, where ``H'`` and ``W'`` are + determined by the network architecture and padding settings. 
""" last_upsample = ( math.ceil(image_size[0] / 4) if self.padding else math.ceil(image_size[0] / 4) - 8, @@ -239,9 +385,26 @@ def forward(self, x: torch.Tensor, image_size: tuple[int, int] | torch.Size) -> class AutoEncoder(nn.Module): """EfficientAd Autoencoder. + The autoencoder consists of an encoder and decoder network. The encoder extracts features + from the input image which are then reconstructed by the decoder. + Args: - out_channels (int): number of convolution output channels - padding (int): use padding in convoluional layers + out_channels (int): Number of convolution output channels in the decoder. + padding (int): Whether to use padding in convolutional layers. + *args: Variable length argument list passed to parent class. + **kwargs: Arbitrary keyword arguments passed to parent class. + + Example: + >>> from torch import randn + >>> autoencoder = AutoEncoder(out_channels=384, padding=True) + >>> input_tensor = randn(32, 3, 256, 256) + >>> output = autoencoder(input_tensor, image_size=(256, 256)) + >>> output.shape + torch.Size([32, 384, 256, 256]) + + Notes: + The input images are normalized using ImageNet statistics before being passed + through the encoder. """ def __init__(self, out_channels: int, padding: int, *args, **kwargs) -> None: @@ -250,14 +413,16 @@ def __init__(self, out_channels: int, padding: int, *args, **kwargs) -> None: self.decoder = Decoder(out_channels, padding) def forward(self, x: torch.Tensor, image_size: tuple[int, int] | torch.Size) -> torch.Tensor: - """Perform the forward pass through the network. + """Forward pass through the autoencoder. Args: - x (torch.Tensor): Input batch. - image_size (tuple): size of input images. + x (torch.Tensor): Input tensor of shape ``(N, C, H, W)``. + image_size (tuple[int, int] | torch.Size): Target output size ``(H, W)``. Returns: - torch.Tensor: Output from the network. + torch.Tensor: Reconstructed features of shape ``(N, out_channels, H', W')``, + where ``H'`` and ``W'`` are determined by the decoder architecture and + padding settings. """ x = imagenet_norm_batch(x) x = self.encoder(x) @@ -267,14 +432,41 @@ def forward(self, x: torch.Tensor, image_size: tuple[int, int] | torch.Size) -> class EfficientAdModel(nn.Module): """EfficientAd model. + The EfficientAd model consists of a teacher and student network for anomaly + detection. The teacher network is pre-trained and frozen, while the student + network is trained to match the teacher's outputs. + Args: - teacher_out_channels (int): number of convolution output channels of the pre-trained teacher model - model_size (str): size of student and teacher model - padding (bool): use padding in convoluional layers + teacher_out_channels (int): Number of convolution output channels of the + pre-trained teacher model. + model_size (EfficientAdModelSize): Size of student and teacher model. + Defaults to ``EfficientAdModelSize.S``. + padding (bool): Whether to use padding in convolutional layers. Defaults to ``False``. - pad_maps (bool): relevant if padding is set to False. In this case, pad_maps = True pads the - output anomaly maps so that their size matches the size in the padding = True case. + pad_maps (bool): Whether to pad output anomaly maps when ``padding=False`` + to match size of padded case. Only relevant if ``padding=False``. Defaults to ``True``. + + Example: + >>> from anomalib.models.image.efficient_ad.torch_model import ( + ... EfficientAdModel, + ... EfficientAdModelSize + ... ) + >>> model = EfficientAdModel( + ... teacher_out_channels=384, + ... 
model_size=EfficientAdModelSize.S + ... ) + >>> input_tensor = torch.randn(32, 3, 256, 256) + >>> output = model(input_tensor) + >>> output.anomaly_map.shape + torch.Size([32, 1, 256, 256]) + + Notes: + The model uses a student-teacher architecture where: + - Teacher network is pre-trained and frozen + - Student network learns to match teacher outputs + - Autoencoder provides additional feature extraction + - Anomaly scores are computed from student-teacher differences """ def __init__( @@ -323,25 +515,28 @@ def __init__( @staticmethod def is_set(p_dic: nn.ParameterDict) -> bool: - """Check if any of the parameters in the parameter dictionary is set. + """Check if any parameters in the dictionary are non-zero. Args: - p_dic (nn.ParameterDict): Parameter dictionary. + p_dic (nn.ParameterDict): Parameter dictionary to check. Returns: - bool: Boolean indicating whether any of the parameters in the parameter dictionary is set. + bool: ``True`` if any parameter is non-zero, ``False`` otherwise. """ return any(value.sum() != 0 for _, value in p_dic.items()) @staticmethod def choose_random_aug_image(image: torch.Tensor) -> torch.Tensor: - """Choose a random augmentation function and apply it to the input image. + """Apply random augmentation to input image. + + Randomly selects and applies one of: brightness, contrast or saturation + adjustment with coefficient sampled from U(0.8, 1.2). Args: - image (torch.Tensor): Input image. + image (torch.Tensor): Input image tensor. Returns: - Tensor: Augmented image. + torch.Tensor: Augmented image tensor. """ transform_functions = [ transforms.functional.adjust_brightness, @@ -359,15 +554,22 @@ def forward( batch_imagenet: torch.Tensor | None = None, normalize: bool = True, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | InferenceBatch: - """Perform the forward-pass of the EfficientAd models. + """Forward pass through the model. Args: - batch (torch.Tensor): Input images. - batch_imagenet (torch.Tensor): ImageNet batch. Defaults to None. - normalize (bool): Normalize anomaly maps or not + batch (torch.Tensor): Input batch of images. + batch_imagenet (torch.Tensor | None): Optional batch of ImageNet + images for training. Defaults to ``None``. + normalize (bool): Whether to normalize anomaly maps. + Defaults to ``True``. Returns: - Tensor: Predictions + tuple[torch.Tensor, torch.Tensor, torch.Tensor] | InferenceBatch: + If training: + - Loss components (student-teacher, autoencoder, + student-autoencoder) + If inference: + - Batch containing anomaly maps and scores """ student_output, distance_st = self.compute_student_teacher_distance(batch) if self.training: @@ -382,12 +584,12 @@ def compute_student_teacher_distance(self, batch: torch.Tensor) -> tuple[torch.T """Compute the student-teacher distance vectors. Args: - batch (torch.Tensor): Input images. - batch_imagenet (torch.Tensor): ImageNet batch. Defaults to None. - normalize (bool): Normalize anomaly maps or not + batch (torch.Tensor): Input batch of images. Returns: - Tensor: Predictions + tuple[torch.Tensor, torch.Tensor]: + - Student network output features + - Squared distance between normalized teacher and student features """ with torch.no_grad(): teacher_output = self.teacher(batch) @@ -404,7 +606,24 @@ def compute_losses( batch_imagenet: torch.Tensor, distance_st: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute the student-teacher loss and the autoencoder loss.""" + """Compute training losses. 
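The `choose_random_aug_image` hunk above is truncated mid-list. Under the behaviour its new docstring states (one of brightness, contrast, or saturation with a coefficient from U(0.8, 1.2)), an equivalent sketch would be the following; the exact sampling call in anomalib may differ:

```python
import random

import torch
from torchvision.transforms import functional as TF

def choose_random_aug_image(image: torch.Tensor) -> torch.Tensor:
    # One of the three photometric adjustments named in the docstring.
    transform = random.choice([
        TF.adjust_brightness,
        TF.adjust_contrast,
        TF.adjust_saturation,
    ])
    # Coefficient sampled from U(0.8, 1.2).
    coefficient = random.uniform(0.8, 1.2)
    return transform(image, coefficient)
```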
+ + Computes three loss components: + - Student-teacher loss (hard examples + ImageNet penalty) + - Autoencoder reconstruction loss + - Student-autoencoder consistency loss + + Args: + batch (torch.Tensor): Input batch of images. + batch_imagenet (torch.Tensor): Batch of ImageNet images. + distance_st (torch.Tensor): Student-teacher distances. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - Student-teacher loss + - Autoencoder loss + - Student-autoencoder loss + """ # Student loss distance_st = reduce_tensor_elems(distance_st) d_hard = torch.quantile(distance_st, 0.999) @@ -438,7 +657,20 @@ def compute_maps( distance_st: torch.Tensor, normalize: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: - """Compute the anomaly maps.""" + """Compute anomaly maps from model outputs. + + Args: + batch (torch.Tensor): Input batch of images. + student_output (torch.Tensor): Student network output features. + distance_st (torch.Tensor): Student-teacher distances. + normalize (bool): Whether to normalize maps with pre-computed + quantiles. Defaults to ``True``. + + Returns: + tuple[torch.Tensor, torch.Tensor]: + - Student-teacher anomaly map + - Student-autoencoder anomaly map + """ image_size = batch.shape[-2:] # Eval mode. with torch.no_grad(): @@ -463,6 +695,18 @@ def compute_maps( return map_st, map_stae def get_maps(self, batch: torch.Tensor, normalize: bool = False) -> tuple[torch.Tensor, torch.Tensor]: - """Standalone function to compute anomaly maps.""" + """Compute anomaly maps for a batch of images. + + Convenience method that combines distance computation and map generation. + + Args: + batch (torch.Tensor): Input batch of images. + normalize (bool): Whether to normalize maps. Defaults to ``False``. + + Returns: + tuple[torch.Tensor, torch.Tensor]: + - Student-teacher anomaly map + - Student-autoencoder anomaly map + """ student_output, distance_st = self.compute_student_teacher_distance(batch) return self.compute_maps(batch, student_output, distance_st, normalize) diff --git a/src/anomalib/models/image/fastflow/__init__.py b/src/anomalib/models/image/fastflow/__init__.py index 7abb420e33..f9221b7ee0 100644 --- a/src/anomalib/models/image/fastflow/__init__.py +++ b/src/anomalib/models/image/fastflow/__init__.py @@ -1,4 +1,30 @@ -"""FastFlow Algorithm Implementation.""" +"""FastFlow Algorithm Implementation. + +FastFlow is a fast flow-based anomaly detection model that uses normalizing flows +to model the distribution of features extracted from a pre-trained CNN backbone. +The model achieves competitive performance while maintaining fast inference times. + +Example: + >>> from anomalib.data import MVTec + >>> from anomalib.models import Fastflow + >>> from anomalib.engine import Engine + + >>> datamodule = MVTec() + >>> model = Fastflow() + >>> engine = Engine() + + >>> engine.fit(model, datamodule=datamodule) # doctest: +SKIP + >>> predictions = engine.predict(model, datamodule=datamodule) # doctest: +SKIP + +Paper: + Title: FastFlow: Unsupervised Anomaly Detection and Localization via 2D + Normalizing Flows + URL: https://arxiv.org/abs/2111.07677 + +See Also: + :class:`anomalib.models.image.fastflow.torch_model.FastflowModel`: + PyTorch implementation of the FastFlow model architecture. 
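Looping back to the EfficientAd `compute_losses` hunk, which is truncated right after `d_hard`: the hard-feature mining it begins reduces, in sketch form, to averaging only the hardest 0.1% of student-teacher distances. The masking step below is an assumption based on the EfficientAd paper; only the quantile line is visible in the hunk.

```python
import torch

distance_st = torch.rand(4096)  # stand-in for the real distance tensor
d_hard = torch.quantile(distance_st, 0.999)
# Back-propagate only through distances at or above the 99.9% quantile.
loss_hard = torch.mean(distance_st[distance_st >= d_hard])
```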
+""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/image/fastflow/anomaly_map.py b/src/anomalib/models/image/fastflow/anomaly_map.py index b0bded15b6..4d195eec56 100644 --- a/src/anomalib/models/image/fastflow/anomaly_map.py +++ b/src/anomalib/models/image/fastflow/anomaly_map.py @@ -1,4 +1,15 @@ -"""FastFlow Anomaly Map Generator Implementation.""" +"""FastFlow Anomaly Map Generator Implementation. + +This module implements the anomaly map generation for the FastFlow model. The +generator takes hidden variables from normalizing flow blocks and produces an +anomaly heatmap by computing flow maps. + +Example: + >>> from anomalib.models.image.fastflow.anomaly_map import AnomalyMapGenerator + >>> generator = AnomalyMapGenerator(input_size=(256, 256)) + >>> hidden_vars = [torch.randn(1, 64, 32, 32)] # from NF blocks + >>> anomaly_map = generator(hidden_vars) # returns anomaly heatmap +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -10,10 +21,26 @@ class AnomalyMapGenerator(nn.Module): - """Generate Anomaly Heatmap. + """Generate anomaly heatmaps from FastFlow hidden variables. + + The generator takes hidden variables from normalizing flow blocks and produces + an anomaly heatmap. For each hidden variable tensor, it: + 1. Computes negative log probability + 2. Converts to probability via exponential + 3. Interpolates to input size + 4. Stacks and averages flow maps to produce final anomaly map Args: - input_size (ListConfig | tuple): Input size. + input_size (ListConfig | tuple): Target size for the anomaly map as + ``(height, width)``. If ``ListConfig`` is provided, it will be + converted to tuple. + + Example: + >>> generator = AnomalyMapGenerator(input_size=(256, 256)) + >>> hidden_vars = [torch.randn(1, 64, 32, 32)] # from NF blocks + >>> anomaly_map = generator(hidden_vars) + >>> anomaly_map.shape + torch.Size([1, 1, 256, 256]) """ def __init__(self, input_size: ListConfig | tuple) -> None: @@ -21,18 +48,26 @@ def __init__(self, input_size: ListConfig | tuple) -> None: self.input_size = input_size if isinstance(input_size, tuple) else tuple(input_size) def forward(self, hidden_variables: list[torch.Tensor]) -> torch.Tensor: - """Generate Anomaly Heatmap. + """Generate anomaly heatmap from hidden variables. + + This implementation generates the heatmap based on the flow maps computed + from the normalizing flow (NF) FastFlow blocks. Each block yields a flow + map, which overall is stacked and averaged to produce an anomaly map. - This implementation generates the heatmap based on the flow maps - computed from the normalizing flow (NF) FastFlow blocks. Each block - yields a flow map, which overall is stacked and averaged to an anomaly - map. + The process for each hidden variable is: + 1. Compute negative log probability as mean of squared values + 2. Convert to probability via exponential + 3. Interpolate to input size + 4. Stack all flow maps and average to get final anomaly map Args: - hidden_variables (list[torch.Tensor]): List of hidden variables from each NF FastFlow block. + hidden_variables (list[torch.Tensor]): List of hidden variables from + each NF FastFlow block. Each tensor has shape + ``(N, C, H, W)``. Returns: - Tensor: Anomaly Map. + torch.Tensor: Anomaly heatmap with shape ``(N, 1, H, W)`` where + ``H, W`` match the ``input_size``. 
""" flow_maps: list[torch.Tensor] = [] for hidden_variable in hidden_variables: diff --git a/src/anomalib/models/image/fastflow/lightning_model.py b/src/anomalib/models/image/fastflow/lightning_model.py index 8a98ea9e7a..eb3e7deb45 100644 --- a/src/anomalib/models/image/fastflow/lightning_model.py +++ b/src/anomalib/models/image/fastflow/lightning_model.py @@ -1,6 +1,35 @@ """FastFlow Lightning Model Implementation. -https://arxiv.org/abs/2111.07677 +This module provides a PyTorch Lightning implementation of the FastFlow model for anomaly +detection. FastFlow is a fast flow-based model that uses normalizing flows to model the +distribution of features extracted from a pre-trained CNN backbone. + +The model achieves competitive performance while maintaining fast inference times by +leveraging normalizing flows to transform feature distributions into a simpler form that +can be efficiently modeled. + +Example: + >>> from anomalib.data import MVTec + >>> from anomalib.models import Fastflow + >>> from anomalib.engine import Engine + + >>> datamodule = MVTec() + >>> model = Fastflow() + >>> engine = Engine() + + >>> engine.fit(model, datamodule=datamodule) # doctest: +SKIP + >>> predictions = engine.predict(model, datamodule=datamodule) # doctest: +SKIP + +Paper: + Title: FastFlow: Unsupervised Anomaly Detection and Localization via 2D + Normalizing Flows + URL: https://arxiv.org/abs/2111.07677 + +See Also: + :class:`anomalib.models.image.fastflow.torch_model.FastflowModel`: + PyTorch implementation of the FastFlow model architecture. + :class:`anomalib.models.image.fastflow.loss.FastflowLoss`: + Loss function used to train the FastFlow model. """ # Copyright (C) 2022-2024 Intel Corporation @@ -27,20 +56,45 @@ class Fastflow(AnomalibModule): """PL Lightning Module for the FastFlow algorithm. + The FastFlow model uses normalizing flows to transform feature distributions from a + pre-trained CNN backbone into a simpler form that can be efficiently modeled for + anomaly detection. + Args: - backbone (str): Backbone CNN network - Defaults to ``resnet18``. - pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + backbone (str): Backbone CNN network architecture. Available options are + ``"resnet18"``, ``"wide_resnet50_2"``, etc. + Defaults to ``"resnet18"``. + pre_trained (bool, optional): Whether to use pre-trained backbone weights. Defaults to ``True``. - flow_steps (int, optional): Flow steps. + flow_steps (int, optional): Number of steps in the normalizing flow. Defaults to ``8``. - conv3x3_only (bool, optinoal): Use only conv3x3 in fast_flow model. + conv3x3_only (bool, optional): Whether to use only 3x3 convolutions in the + FastFlow model. Defaults to ``False``. - hidden_ratio (float, optional): Ratio to calculate hidden var channels. + hidden_ratio (float, optional): Ratio used to calculate hidden variable + channels. Defaults to ``1.0``. - pre_processor (PreProcessor, optional): Pre-processor for the model. - This is used to pre-process the input data before it is passed to the model. - Defaults to ``None``. + pre_processor (PreProcessor | bool, optional): Pre-processor to use for + input data. + Defaults to ``True``. + post_processor (PostProcessor | bool, optional): Post-processor to use for + model outputs. + Defaults to ``True``. + evaluator (Evaluator | bool, optional): Evaluator to compute metrics. + Defaults to ``True``. + visualizer (Visualizer | bool, optional): Visualizer for model outputs. + Defaults to ``True``. 
+ + Raises: + ValueError: If ``input_size`` is not provided during initialization. + + Example: + >>> from anomalib.models import Fastflow + >>> model = Fastflow( + ... backbone="resnet18", + ... pre_trained=True, + ... flow_steps=8 + ... ) """ def __init__( diff --git a/src/anomalib/models/image/fastflow/loss.py b/src/anomalib/models/image/fastflow/loss.py index a47f49df88..b2b36b1619 100644 --- a/src/anomalib/models/image/fastflow/loss.py +++ b/src/anomalib/models/image/fastflow/loss.py @@ -1,4 +1,22 @@ -"""Loss function for the FastFlow Model Implementation.""" +"""Loss function for the FastFlow Model Implementation. + +This module implements the loss function used to train the FastFlow model. The loss is +computed based on the hidden variables and Jacobian determinants produced by the +normalizing flow transformations. + +Example: + >>> from anomalib.models.image.fastflow.loss import FastflowLoss + >>> criterion = FastflowLoss() + >>> hidden_vars = [torch.randn(2, 64, 32, 32)] # from NF blocks + >>> jacobians = [torch.randn(2)] # log det jacobians + >>> loss = criterion(hidden_vars, jacobians) + >>> loss.shape + torch.Size([]) + +See Also: + :class:`anomalib.models.image.fastflow.torch_model.FastflowModel`: + PyTorch implementation of the FastFlow model architecture. +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -8,18 +26,38 @@ class FastflowLoss(nn.Module): - """FastFlow Loss.""" + """FastFlow Loss Module. + + Computes the negative log-likelihood loss used to train the FastFlow model. The loss + combines the log-likelihood of the hidden variables with the log determinant of the + Jacobian matrices from the normalizing flow transformations. + """ @staticmethod def forward(hidden_variables: list[torch.Tensor], jacobians: list[torch.Tensor]) -> torch.Tensor: - """Calculate the Fastflow loss. + """Calculate the FastFlow loss. + + The loss is computed as the negative log-likelihood of the hidden variables + transformed by the normalizing flows, taking into account the Jacobian + determinants of the transformations. Args: - hidden_variables (list[torch.Tensor]): Hidden variables from the fastflow model. f: X -> Z - jacobians (list[torch.Tensor]): Log of the jacobian determinants from the fastflow model. + hidden_variables (list[torch.Tensor]): List of hidden variable tensors + produced by the normalizing flow transformations. Each tensor has + shape ``(N, C, H, W)`` where ``N`` is batch size. + jacobians (list[torch.Tensor]): List of log determinants of Jacobian + matrices for each normalizing flow transformation. Each tensor has + shape ``(N,)`` where ``N`` is batch size. Returns: - Tensor: Fastflow loss computed based on the hidden variables and the log of the Jacobians. + torch.Tensor: Scalar loss value combining the negative log-likelihood + of hidden variables and Jacobian determinants. 
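For reference, the negative log-likelihood that the `FastflowLoss` docstring describes has a compact closed form under a standard-normal prior. A sketch of the per-block term follows; the diff itself only shows the accumulator and the `zip` loop, so the inner expression here is an assumption consistent with that prior:

```python
import torch

def fastflow_loss(hidden_variables: list[torch.Tensor], jacobians: list[torch.Tensor]) -> torch.Tensor:
    loss = torch.tensor(0.0, device=hidden_variables[0].device)
    for hidden_variable, jacobian in zip(hidden_variables, jacobians, strict=True):
        # NLL under N(0, I) minus the log|det J| contribution of the flow.
        loss += torch.mean(0.5 * torch.sum(hidden_variable**2, dim=(1, 2, 3)) - jacobian)
    return loss
```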
+ + Example: + >>> criterion = FastflowLoss() + >>> h_vars = [torch.randn(2, 64, 32, 32)] # hidden variables + >>> jacs = [torch.randn(2)] # log det jacobians + >>> loss = criterion(h_vars, jacs) """ loss = torch.tensor(0.0, device=hidden_variables[0].device) # pylint: disable=not-callable for hidden_variable, jacobian in zip(hidden_variables, jacobians, strict=True): diff --git a/src/anomalib/models/image/fre/__init__.py b/src/anomalib/models/image/fre/__init__.py index 7de3b5b399..91646c778f 100755 --- a/src/anomalib/models/image/fre/__init__.py +++ b/src/anomalib/models/image/fre/__init__.py @@ -1,4 +1,26 @@ -"""Deep Feature Extraction (DFM) model.""" +"""Feature Reconstruction Error (FRE) Algorithm Implementation. + +FRE is an anomaly detection model that uses feature reconstruction error to detect +anomalies. The model extracts features from a pre-trained CNN backbone and learns +to reconstruct them using an autoencoder. Anomalies are detected by measuring the +reconstruction error. + +Example: + >>> from anomalib.data import MVTec + >>> from anomalib.models import Fre + >>> from anomalib.engine import Engine + + >>> datamodule = MVTec() + >>> model = Fre() + >>> engine = Engine() + + >>> engine.fit(model, datamodule=datamodule) # doctest: +SKIP + >>> predictions = engine.predict(model, datamodule=datamodule) # doctest: +SKIP + +See Also: + :class:`anomalib.models.image.fre.lightning_model.Fre`: + PyTorch Lightning implementation of the FRE model. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/image/fre/lightning_model.py b/src/anomalib/models/image/fre/lightning_model.py index 953fcd4322..6021f6655a 100755 --- a/src/anomalib/models/image/fre/lightning_model.py +++ b/src/anomalib/models/image/fre/lightning_model.py @@ -1,6 +1,30 @@ -"""FRE: Feature-Reconstruction Error. +"""Feature Reconstruction Error (FRE) Algorithm Implementation. -https://papers.bmvc2023.org/0614.pdf +FRE is an anomaly detection model that uses feature reconstruction error to detect +anomalies. The model extracts features from a pre-trained CNN backbone and learns +to reconstruct them using a tied autoencoder. Anomalies are detected by measuring +the reconstruction error between the original and reconstructed features. + +Example: + >>> from anomalib.data import MVTec + >>> from anomalib.models import Fre + >>> from anomalib.engine import Engine + + >>> datamodule = MVTec() + >>> model = Fre() + >>> engine = Engine() + + >>> engine.fit(model, datamodule=datamodule) # doctest: +SKIP + >>> predictions = engine.predict(model, datamodule=datamodule) # doctest: +SKIP + +Paper: + Title: FRE: Feature Reconstruction Error for Unsupervised Anomaly Detection + and Segmentation + URL: https://papers.bmvc2023.org/0614.pdf + +See Also: + :class:`anomalib.models.image.fre.torch_model.FREModel`: + PyTorch implementation of the FRE model architecture. """ # Copyright (C) 2024 Intel Corporation @@ -29,23 +53,51 @@ class Fre(AnomalibModule): """FRE: Feature-reconstruction error using Tied AutoEncoder. + The FRE model extracts features from a pre-trained CNN backbone and learns to + reconstruct them using a tied autoencoder. Anomalies are detected by measuring + the reconstruction error between original and reconstructed features. + Args: - backbone (str): Backbone CNN network - Defaults to ``resnet50``. - layer (str): Layer to extract features from the backbone CNN - Defaults to ``layer3``. 
- pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + backbone (str): Backbone CNN network architecture. + Defaults to ``"resnet50"``. + layer (str): Layer name to extract features from the backbone CNN. + Defaults to ``"layer3"``. + pre_trained (bool, optional): Whether to use pre-trained backbone weights. Defaults to ``True``. - pooling_kernel_size (int, optional): Kernel size to pool features extracted from the CNN. + pooling_kernel_size (int, optional): Kernel size for pooling features + extracted from the CNN. Defaults to ``2``. - input_dim (int, optional): Dimension of feature at output of layer specified in layer. + input_dim (int, optional): Dimension of features at output of specified + layer. Defaults to ``65536``. - latent_dim (int, optional): Reduced size of feature after applying dimensionality reduction - via shallow linear autoencoder. + latent_dim (int, optional): Reduced feature dimension after applying + dimensionality reduction via shallow linear autoencoder. Defaults to ``220``. - pre_processor (PreProcessor, optional): Pre-processor for the model. - This is used to pre-process the input data before it is passed to the model. - Defaults to ``None``. + pre_processor (PreProcessor | bool, optional): Pre-processor to transform + inputs before passing to model. + Defaults to ``True``. + post_processor (PostProcessor | bool, optional): Post-processor to generate + predictions from model outputs. + Defaults to ``True``. + evaluator (Evaluator | bool, optional): Evaluator to compute metrics. + Defaults to ``True``. + visualizer (Visualizer | bool, optional): Visualizer to display results. + Defaults to ``True``. + + Example: + >>> from anomalib.models import Fre + >>> model = Fre( + ... backbone="resnet50", + ... layer="layer3", + ... pre_trained=True, + ... pooling_kernel_size=2, + ... input_dim=65536, + ... latent_dim=220, + ... ) + + See Also: + :class:`anomalib.models.image.fre.torch_model.FREModel`: + PyTorch implementation of the FRE model architecture. """ def __init__( @@ -82,22 +134,24 @@ def configure_optimizers(self) -> torch.optim.Optimizer: """Configure optimizers. Returns: - Optimizer: Adam optimizer + torch.optim.Optimizer: Adam optimizer for training the model. """ return optim.Adam(params=self.model.fre_model.parameters(), lr=1e-3) def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: """Perform the training step of FRE. - For each batch, features are extracted from the CNN. + For each batch, features are extracted from the CNN backbone and + reconstructed using the tied autoencoder. The loss is computed as the MSE + between original and reconstructed features. Args: - batch (Batch): Input batch - args: Arguments. - kwargs: Keyword arguments. + batch (Batch): Input batch containing images and labels. + args: Additional arguments (unused). + kwargs: Additional keyword arguments (unused). Returns: - Deep CNN features. + STEP_OUTPUT: Dictionary containing the loss value. """ del args, kwargs # These variables are not used. features_in, features_out, _ = self.model.get_features(batch.image) @@ -108,15 +162,16 @@ def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: """Perform the validation step of FRE. - Similar to the training step, features are extracted from the CNN for each batch. + Similar to training, features are extracted and reconstructed. The + reconstruction error is used to compute anomaly scores and maps. 
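The FRE training objective stated above, MSE between the original and reconstructed features, amounts to the following; the tensors are random stand-ins for the extracted features:

```python
import torch
import torch.nn.functional as F

features_in = torch.randn(8, 65536)   # features from the CNN backbone
features_out = torch.randn(8, 65536)  # tied-autoencoder reconstruction
loss = F.mse_loss(features_out, features_in)
```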
Args: - batch (Batch): Input batch - args: Arguments. - kwargs: Keyword arguments. + batch (Batch): Input batch containing images and labels. + args: Additional arguments (unused). + kwargs: Additional keyword arguments (unused). Returns: - Dictionary containing FRE anomaly scores and anomaly maps. + STEP_OUTPUT: Dictionary containing anomaly scores and maps. """ del args, kwargs # These variables are not used. @@ -125,7 +180,14 @@ def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: @property def trainer_arguments(self) -> dict[str, Any]: - """Return FRE-specific trainer arguments.""" + """Return FRE-specific trainer arguments. + + Returns: + dict[str, Any]: Dictionary of trainer arguments: + - ``gradient_clip_val``: ``0`` + - ``max_epochs``: ``220`` + - ``num_sanity_val_steps``: ``0`` + """ return {"gradient_clip_val": 0, "max_epochs": 220, "num_sanity_val_steps": 0} @property @@ -133,6 +195,6 @@ def learning_type(self) -> LearningType: """Return the learning type of the model. Returns: - LearningType: Learning type of the model. + LearningType: Learning type of the model (``ONE_CLASS``). """ return LearningType.ONE_CLASS diff --git a/src/anomalib/models/image/fre/torch_model.py b/src/anomalib/models/image/fre/torch_model.py index c2eb0c3416..760bae258a 100755 --- a/src/anomalib/models/image/fre/torch_model.py +++ b/src/anomalib/models/image/fre/torch_model.py @@ -1,4 +1,35 @@ -"""PyTorch model for FRE model implementation.""" +"""PyTorch model for the Feature Reconstruction Error (FRE) algorithm implementation. + +The FRE model extracts features from a pre-trained CNN backbone and learns to +reconstruct them using a tied autoencoder. Anomalies are detected by measuring +the reconstruction error between original and reconstructed features. + +Example: + >>> from anomalib.models.image.fre.torch_model import FREModel + >>> model = FREModel( + ... backbone="resnet50", + ... layer="layer3", + ... input_dim=65536, + ... latent_dim=220, + ... pre_trained=True, + ... pooling_kernel_size=4 + ... ) + >>> input_tensor = torch.randn(32, 3, 256, 256) + >>> output = model(input_tensor) + >>> output.pred_score.shape + torch.Size([32]) + >>> output.anomaly_map.shape + torch.Size([32, 1, 256, 256]) + +Paper: + Title: FRE: Feature Reconstruction Error for Unsupervised Anomaly Detection + and Segmentation + URL: https://papers.bmvc2023.org/0614.pdf + +See Also: + :class:`anomalib.models.image.fre.lightning_model.Fre`: + PyTorch Lightning implementation of the FRE model. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -12,11 +43,21 @@ class TiedAE(nn.Module): - """Model for the Tied AutoEncoder used for FRE calculation. + """Tied Autoencoder used for feature reconstruction error calculation. + + The tied autoencoder uses shared weights between encoder and decoder to reduce + the number of parameters while maintaining reconstruction capability. Args: - input_dim (int): Dimension of input to the tied auto-encoder. - latent_dim (int): Dimension of the reduced-dimension latent space of the tied auto-encoder. + input_dim (int): Dimension of input features to the tied autoencoder. + latent_dim (int): Dimension of the reduced latent space representation. 
+ + Example: + >>> tied_ae = TiedAE(input_dim=1024, latent_dim=128) + >>> features = torch.randn(32, 1024) + >>> reconstructed = tied_ae(features) + >>> reconstructed.shape + torch.Size([32, 1024]) """ def __init__(self, input_dim: int, latent_dim: int) -> None: @@ -31,31 +72,59 @@ def __init__(self, input_dim: int, latent_dim: int) -> None: def forward(self, features: torch.Tensor) -> torch.Tensor: """Run input features through the autoencoder. + The features are first encoded to a lower dimensional latent space and + then decoded back to the original feature space using transposed weights. + Args: - features (torch.Tensor): Feature batch. + features (torch.Tensor): Input feature batch of shape + ``(N, input_dim)``. Returns: - Tensor: torch.Tensor containing reconstructed features. + torch.Tensor: Reconstructed features of shape ``(N, input_dim)``. """ encoded = F.linear(features, self.weight, self.encoder_bias) return F.linear(encoded, self.weight.t(), self.decoder_bias) class FREModel(nn.Module): - """Model for the FRE algorithm. + """Feature Reconstruction Error (FRE) model implementation. + + The model extracts features from a pre-trained CNN backbone and learns to + reconstruct them using a tied autoencoder. Anomalies are detected by + measuring the reconstruction error between original and reconstructed + features. Args: - backbone (str): Pre-trained model backbone. - layer (str): Layer from which to extract features. - pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. - Defaults to ``True``. - pooling_kernel_size (int, optional): Kernel size to pool features extracted from the CNN. - Defaults to ``4``. - input_dim (int, optional): Dimension of feature at output of layer specified in layer. + backbone (str): Pre-trained CNN backbone architecture (e.g. + ``"resnet18"``, ``"resnet50"``, etc.). + layer (str): Layer name from which to extract features (e.g. + ``"layer2"``, ``"layer3"``, etc.). + input_dim (int, optional): Dimension of features at output of specified + layer. Defaults to ``65536``. - latent_dim (int, optional): Reduced size of feature after applying dimensionality reduction - via shallow linear autoencoder. + latent_dim (int, optional): Reduced feature dimension after applying + dimensionality reduction via shallow linear autoencoder. Defaults to ``220``. + pre_trained (bool, optional): Whether to use pre-trained backbone + weights. + Defaults to ``True``. + pooling_kernel_size (int, optional): Kernel size for pooling features + extracted from the CNN. + Defaults to ``4``. + + Example: + >>> model = FREModel( + ... backbone="resnet50", + ... layer="layer3", + ... input_dim=65536, + ... latent_dim=220 + ... ) + >>> input_tensor = torch.randn(32, 3, 256, 256) + >>> output = model(input_tensor) + >>> output.pred_score.shape + torch.Size([32]) + >>> output.anomaly_map.shape + torch.Size([32, 1, 256, 256]) """ def __init__( @@ -79,13 +148,18 @@ def __init__( ).eval() def get_features(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Extract features from the pretrained network. + """Extract and reconstruct features from the pretrained network. Args: - batch (torch.Tensor): Image batch. + batch (torch.Tensor): Input image batch of shape + ``(N, C, H, W)``. Returns: - Tensor: torch.Tensor containing extracted features. 
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple containing: + - Original features of shape ``(N, D)`` + - Reconstructed features of shape ``(N, D)`` + - Original feature tensor shape ``(N, C, H, W)`` + where ``D`` is the flattened feature dimension. """ self.feature_extractor.eval() features_in = self.feature_extractor(batch)[self.layer] @@ -98,13 +172,22 @@ def get_features(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, return features_in, features_out, feature_shapes def forward(self, batch: torch.Tensor) -> InferenceBatch: - """Compute score from input images. + """Generate anomaly predictions for input images. + + The method: + 1. Extracts and reconstructs features using the tied autoencoder + 2. Computes reconstruction error as anomaly scores + 3. Generates pixel-wise anomaly maps + 4. Upsamples anomaly maps to input image size Args: - batch (torch.Tensor): Input images + batch (torch.Tensor): Input image batch of shape + ``(N, C, H, W)``. Returns: - tuple[torch.Tensor, torch.Tensor]: Scores, Anomaly Map + InferenceBatch: Batch containing: + - Anomaly scores of shape ``(N,)`` + - Anomaly maps of shape ``(N, 1, H, W)`` """ features_in, features_out, feature_shapes = self.get_features(batch) fre = torch.square(features_in - features_out).reshape(feature_shapes) diff --git a/src/anomalib/models/image/ganomaly/__init__.py b/src/anomalib/models/image/ganomaly/__init__.py index ec872b077d..ea07b478ca 100644 --- a/src/anomalib/models/image/ganomaly/__init__.py +++ b/src/anomalib/models/image/ganomaly/__init__.py @@ -1,4 +1,30 @@ -"""GANomaly Model.""" +"""GANomaly Algorithm Implementation. + +GANomaly is an anomaly detection model that uses a conditional GAN architecture to +learn the normal data distribution. The model consists of a generator network that +learns to reconstruct normal images, and a discriminator that helps ensure the +reconstructions are realistic. + +Example: + >>> from anomalib.data import MVTec + >>> from anomalib.models import Ganomaly + >>> from anomalib.engine import Engine + + >>> datamodule = MVTec() + >>> model = Ganomaly() + >>> engine = Engine() + + >>> engine.fit(model, datamodule=datamodule) # doctest: +SKIP + >>> predictions = engine.predict(model, datamodule=datamodule) # doctest: +SKIP + +Paper: + Title: GANomaly: Semi-Supervised Anomaly Detection via Adversarial Training + URL: https://arxiv.org/abs/1805.06725 + +See Also: + :class:`anomalib.models.image.ganomaly.lightning_model.Ganomaly`: + PyTorch Lightning implementation of the GANomaly model. +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/image/ganomaly/lightning_model.py b/src/anomalib/models/image/ganomaly/lightning_model.py index 4b48b0b633..4f2421d37b 100644 --- a/src/anomalib/models/image/ganomaly/lightning_model.py +++ b/src/anomalib/models/image/ganomaly/lightning_model.py @@ -1,6 +1,33 @@ """GANomaly: Semi-Supervised Anomaly Detection via Adversarial Training. -https://arxiv.org/abs/1805.06725 +GANomaly is an anomaly detection model that uses a conditional GAN architecture to +learn the normal data distribution. The model consists of a generator network that +learns to reconstruct normal images, and a discriminator that helps ensure the +reconstructions are realistic. 
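Returning to the FRE `forward` hunk above, which is truncated after the squared-error line: the remaining documented steps (channel reduction and upsampling to the input resolution) would look roughly like this sketch. The specific reductions below are assumptions for illustration, not the verbatim implementation:

```python
import torch
import torch.nn.functional as F

features_in = torch.randn(8, 256, 16, 16)   # original CNN features
features_out = torch.randn(8, 256, 16, 16)  # reconstructed features

fre = torch.square(features_in - features_out)
anomaly_map = torch.sum(fre, dim=1, keepdim=True)             # (8, 1, 16, 16)
anomaly_map = F.interpolate(anomaly_map, size=(256, 256), mode="bilinear")
pred_score = torch.sum(anomaly_map, dim=(1, 2, 3))            # image-level score
```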
+ +Example: + >>> from anomalib.data import MVTec + >>> from anomalib.models import Ganomaly + >>> from anomalib.engine import Engine + + >>> datamodule = MVTec() + >>> model = Ganomaly() + >>> engine = Engine() + + >>> engine.fit(model, datamodule=datamodule) # doctest: +SKIP + >>> predictions = engine.predict(model, datamodule=datamodule) # doctest: +SKIP + +Paper: + Title: GANomaly: Semi-Supervised Anomaly Detection via Adversarial Training + URL: https://arxiv.org/abs/1805.06725 + +See Also: + :class:`anomalib.models.image.ganomaly.torch_model.GanomalyModel`: + PyTorch implementation of the GANomaly model architecture. + :class:`anomalib.models.image.ganomaly.loss.GeneratorLoss`: + Loss function for the generator network. + :class:`anomalib.models.image.ganomaly.loss.DiscriminatorLoss`: + Loss function for the discriminator network. """ # Copyright (C) 2022-2024 Intel Corporation @@ -30,32 +57,63 @@ class Ganomaly(AnomalibModule): """PL Lightning Module for the GANomaly Algorithm. + The GANomaly model consists of a generator and discriminator network. The + generator learns to reconstruct normal images while the discriminator helps + ensure the reconstructions are realistic. Anomalies are detected by measuring + the reconstruction error and latent space differences. + Args: - batch_size (int): Batch size. + batch_size (int): Number of samples in each batch. Defaults to ``32``. - n_features (int): Number of features layers in the CNNs. + n_features (int): Number of feature channels in CNN layers. Defaults to ``64``. - latent_vec_size (int): Size of autoencoder latent vector. + latent_vec_size (int): Dimension of the latent space vectors. Defaults to ``100``. - extra_layers (int, optional): Number of extra layers for encoder/decoder. + extra_layers (int, optional): Number of extra layers in encoder/decoder. Defaults to ``0``. add_final_conv_layer (bool, optional): Add convolution layer at the end. Defaults to ``True``. - wadv (int, optional): Weight for adversarial loss. + wadv (int, optional): Weight for adversarial loss component. Defaults to ``1``. - wcon (int, optional): Image regeneration weight. + wcon (int, optional): Weight for image reconstruction loss component. Defaults to ``50``. - wenc (int, optional): Latent vector encoder weight. + wenc (int, optional): Weight for latent vector encoding loss component. Defaults to ``1``. - lr (float, optional): Learning rate. + lr (float, optional): Learning rate for optimizers. Defaults to ``0.0002``. - beta1 (float, optional): Adam beta1. + beta1 (float, optional): Beta1 parameter for Adam optimizers. Defaults to ``0.5``. - beta2 (float, optional): Adam beta2. + beta2 (float, optional): Beta2 parameter for Adam optimizers. Defaults to ``0.999``. - pre_processor (PreProcessor, optional): Pre-processor for the model. - This is used to pre-process the input data before it is passed to the model. - Defaults to ``None``. + pre_processor (PreProcessor | bool, optional): Pre-processor to transform + inputs before passing to model. + Defaults to ``True``. + post_processor (PostProcessor | bool, optional): Post-processor to generate + predictions from model outputs. + Defaults to ``True``. + evaluator (Evaluator | bool, optional): Evaluator to compute metrics. + Defaults to ``True``. + visualizer (Visualizer | bool, optional): Visualizer to display results. + Defaults to ``True``. + + Example: + >>> from anomalib.models import Ganomaly + >>> model = Ganomaly( + ... batch_size=32, + ... n_features=64, + ... latent_vec_size=100, + ... wadv=1, + ... 
wcon=50, + ... wenc=1, + ... ) + + See Also: + :class:`anomalib.models.image.ganomaly.torch_model.GanomalyModel`: + PyTorch implementation of the GANomaly model architecture. + :class:`anomalib.models.image.ganomaly.loss.GeneratorLoss`: + Loss function for the generator network. + :class:`anomalib.models.image.ganomaly.loss.DiscriminatorLoss`: + Loss function for the discriminator network. """ def __init__( diff --git a/src/anomalib/models/image/ganomaly/loss.py b/src/anomalib/models/image/ganomaly/loss.py index 6262ef1764..fb50ce24b5 100644 --- a/src/anomalib/models/image/ganomaly/loss.py +++ b/src/anomalib/models/image/ganomaly/loss.py @@ -1,4 +1,23 @@ -"""Loss function for the GANomaly Model Implementation.""" +"""Loss functions for the GANomaly model implementation. + +The GANomaly model uses two loss functions: + +1. Generator Loss: Combines adversarial loss, reconstruction loss and encoding loss +2. Discriminator Loss: Binary cross entropy loss for real/fake image discrimination + +Example: + >>> from anomalib.models.image.ganomaly.loss import GeneratorLoss + >>> generator_loss = GeneratorLoss(wadv=1, wcon=50, wenc=1) + >>> loss = generator_loss(latent_i, latent_o, images, fake, pred_real, pred_fake) + + >>> from anomalib.models.image.ganomaly.loss import DiscriminatorLoss + >>> discriminator_loss = DiscriminatorLoss() + >>> loss = discriminator_loss(pred_real, pred_fake) + +See Also: + :class:`anomalib.models.image.ganomaly.torch_model.GanomalyModel`: + PyTorch implementation of the GANomaly model architecture. +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -10,13 +29,27 @@ class GeneratorLoss(nn.Module): """Generator loss for the GANomaly model. + Combines three components: + 1. Adversarial loss: Helps generate realistic images + 2. Contextual loss: Ensures generated images match input + 3. Encoding loss: Enforces consistency in latent space + Args: - wadv (int, optional): Weight for adversarial loss. - Defaults to ``1``. - wcon (int, optional): Image regeneration weight. + wadv (int, optional): Weight for adversarial loss. Defaults to ``1``. + wcon (int, optional): Weight for contextual/reconstruction loss. Defaults to ``50``. - wenc (int, optional): Latent vector encoder weight. - Defaults to ``1``. + wenc (int, optional): Weight for encoding/latent loss. Defaults to ``1``. + + Example: + >>> generator_loss = GeneratorLoss(wadv=1, wcon=50, wenc=1) + >>> loss = generator_loss( + ... latent_i=torch.randn(32, 100), + ... latent_o=torch.randn(32, 100), + ... images=torch.randn(32, 3, 256, 256), + ... fake=torch.randn(32, 3, 256, 256), + ... pred_real=torch.randn(32, 1), + ... pred_fake=torch.randn(32, 1) + ... ) """ def __init__(self, wadv: int = 1, wcon: int = 50, wenc: int = 1) -> None: @@ -39,18 +72,22 @@ def forward( pred_real: torch.Tensor, pred_fake: torch.Tensor, ) -> torch.Tensor: - """Compute the loss for a batch. + """Compute the generator loss for a batch. Args: - latent_i (torch.Tensor): Latent features of the first encoder. - latent_o (torch.Tensor): Latent features of the second encoder. - images (torch.Tensor): Real image that served as input of the generator. - fake (torch.Tensor): Generated image. - pred_real (torch.Tensor): Discriminator predictions for the real image. - pred_fake (torch.Tensor): Discriminator predictions for the fake image. + latent_i (torch.Tensor): Latent features from the first encoder. + latent_o (torch.Tensor): Latent features from the second encoder. 
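For orientation, the weighted combination that `GeneratorLoss` documents reduces to the sketch below. The concrete criteria (SmoothL1 for the encoding term, L1 for the reconstruction term, MSE on the discriminator outputs) are assumptions about the implementation, not shown in this hunk:

```python
import torch
from torch import nn

wadv, wcon, wenc = 1, 50, 1
loss_adv, loss_con, loss_enc = nn.MSELoss(), nn.L1Loss(), nn.SmoothL1Loss()

latent_i, latent_o = torch.randn(4, 100), torch.randn(4, 100)
images, fake = torch.rand(4, 3, 64, 64), torch.rand(4, 3, 64, 64)
pred_real, pred_fake = torch.randn(4, 1), torch.randn(4, 1)

loss = (
    wadv * loss_adv(pred_real, pred_fake)
    + wcon * loss_con(images, fake)
    + wenc * loss_enc(latent_i, latent_o)
)
```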
+ images (torch.Tensor): Real images that served as generator input. + fake (torch.Tensor): Generated/fake images. + pred_real (torch.Tensor): Discriminator predictions for real images. + pred_fake (torch.Tensor): Discriminator predictions for fake images. Returns: - Tensor: The computed generator loss. + torch.Tensor: Combined weighted generator loss. + + Example: + >>> loss = generator_loss(latent_i, latent_o, images, fake, + ... pred_real, pred_fake) """ error_enc = self.loss_enc(latent_i, latent_o) error_con = self.loss_con(images, fake) @@ -60,7 +97,18 @@ class DiscriminatorLoss(nn.Module): - """Discriminator loss for the GANomaly model.""" + """Discriminator loss for the GANomaly model. + + Uses binary cross entropy to train the discriminator to distinguish between + real and generated images. + + Example: + >>> discriminator_loss = DiscriminatorLoss() + >>> loss = discriminator_loss( + ... pred_real=torch.rand(32, 1), + ... pred_fake=torch.rand(32, 1) + ... ) + """ def __init__(self) -> None: super().__init__() @@ -68,14 +116,17 @@ def __init__(self) -> None: self.loss_bce = nn.BCELoss() def forward(self, pred_real: torch.Tensor, pred_fake: torch.Tensor) -> torch.Tensor: - """Compute the loss for a predicted batch. + """Compute the discriminator loss for a predicted batch. Args: - pred_real (torch.Tensor): Discriminator predictions for the real image. - pred_fake (torch.Tensor): Discriminator predictions for the fake image. + pred_real (torch.Tensor): Discriminator predictions for real images. + pred_fake (torch.Tensor): Discriminator predictions for fake images. Returns: - Tensor: The computed discriminator loss. + torch.Tensor: Average discriminator loss. + + Example: + >>> loss = discriminator_loss(pred_real, pred_fake) """ error_discriminator_real = self.loss_bce( pred_real, diff --git a/src/anomalib/models/image/ganomaly/torch_model.py b/src/anomalib/models/image/ganomaly/torch_model.py index 3d791c8501..26d4be55ab 100644 --- a/src/anomalib/models/image/ganomaly/torch_model.py +++ b/src/anomalib/models/image/ganomaly/torch_model.py @@ -1,11 +1,46 @@ -"""Torch models defining encoder, decoder, Generator and Discriminator. - -Code adapted from https://github.com/samet-akcay/ganomaly. +"""Torch models defining encoder, decoder, generator and discriminator networks. + +The GANomaly model consists of several key components: + +1. Encoder: Compresses input images into latent vectors +2. Decoder: Reconstructs images from latent vectors +3. Generator: Combines encoder-decoder-encoder for image generation +4. Discriminator: Distinguishes real from generated images + +The architecture follows an encoder-decoder-encoder pattern where: +- First encoder compresses input image to latent space +- Decoder reconstructs the image from latent vector +- Second encoder re-encodes reconstructed image +- Anomaly score is based on difference between latent vectors + +Example: + >>> from anomalib.models.image.ganomaly.torch_model import GanomalyModel + >>> model = GanomalyModel( + ... input_size=(256, 256), + ... num_input_channels=3, + ... n_features=64, + ... latent_vec_size=100, + ... extra_layers=0, + ... add_final_conv_layer=True + ...
) + >>> input_tensor = torch.randn(32, 3, 256, 256) + >>> output = model(input_tensor) + +Code adapted from: + Title: GANomaly - PyTorch Implementation + Authors: Samet Akcay + URL: https://github.com/samet-akcay/ganomaly + License: MIT + +See Also: + - :class:`anomalib.models.image.ganomaly.lightning_model.Ganomaly`: + Lightning implementation of the GANomaly model + - :class:`anomalib.models.image.ganomaly.loss.GeneratorLoss`: + Loss function for the generator network + - :class:`anomalib.models.image.ganomaly.loss.DiscriminatorLoss`: + Loss function for the discriminator network """ -# Copyright (c) 2018-2022 Samet Akcay, Durham University, UK -# SPDX-License-Identifier: MIT -# # Copyright (C) 2020-2022 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -21,15 +56,28 @@ class Encoder(nn.Module): """Encoder Network. + Compresses input images into latent vectors through a series of convolution + layers. + Args: - input_size (tuple[int, int]): Size of input image - latent_vec_size (int): Size of latent vector z - num_input_channels (int): Number of input channels in the image - n_features (int): Number of features per convolution layer - extra_layers (int): Number of extra layers since the network uses only a single encoder layer by default. + input_size (tuple[int, int]): Size of input image (height, width) + latent_vec_size (int): Size of output latent vector + num_input_channels (int): Number of input image channels + n_features (int): Number of feature maps in convolution layers + extra_layers (int, optional): Number of extra intermediate layers. Defaults to ``0``. - add_final_conv_layer (bool): Add a final convolution layer in the encoder. - Defaults to ``True``. + add_final_conv_layer (bool, optional): Whether to add final convolution + layer. Defaults to ``True``. + + Example: + >>> encoder = Encoder( + ... input_size=(256, 256), + ... latent_vec_size=100, + ... num_input_channels=3, + ... n_features=64 + ... ) + >>> input_tensor = torch.randn(32, 3, 256, 256) + >>> latent = encoder(input_tensor) """ def __init__( @@ -88,7 +136,15 @@ def __init__( ) def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: - """Return latent vectors.""" + """Forward pass through encoder network. + + Args: + input_tensor (torch.Tensor): Input tensor of shape + ``(batch_size, channels, height, width)`` + + Returns: + torch.Tensor: Latent vector tensor + """ output = self.input_layers(input_tensor) output = self.extra_layers(output) output = self.pyramid_features(output) @@ -101,13 +157,25 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: class Decoder(nn.Module): """Decoder Network. + Reconstructs images from latent vectors through transposed convolutions. + Args: - input_size (tuple[int, int]): Size of input image - latent_vec_size (int): Size of latent vector z - num_input_channels (int): Number of input channels in the image - n_features (int): Number of features per convolution layer - extra_layers (int): Number of extra layers since the network uses only a single encoder layer by default. + input_size (tuple[int, int]): Size of output image (height, width) + latent_vec_size (int): Size of input latent vector + num_input_channels (int): Number of output image channels + n_features (int): Number of feature maps in convolution layers + extra_layers (int, optional): Number of extra intermediate layers. Defaults to ``0``. + + Example: + >>> decoder = Decoder( + ... input_size=(256, 256), + ... latent_vec_size=100, + ... num_input_channels=3, + ... n_features=64 + ... 
) + >>> latent = torch.randn(32, 100, 1, 1) + >>> reconstruction = decoder(latent) """ def __init__( @@ -195,7 +263,14 @@ def __init__( self.final_layers.add_module(f"final-{num_input_channels}-tanh", nn.Tanh()) def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: - """Return generated image.""" + """Forward pass through decoder network. + + Args: + input_tensor (torch.Tensor): Input latent tensor + + Returns: + torch.Tensor: Reconstructed image tensor + """ output = self.latent_input(input_tensor) output = self.inverse_pyramid(output) output = self.extra_layers(output) @@ -203,16 +278,25 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: class Discriminator(nn.Module): - """Discriminator. + """Discriminator Network. - Made of only one encoder layer which takes x and x_hat to produce a score. + Classifies images as real or generated using a modified encoder architecture. Args: - input_size (tuple[int, int]): Input image size. - num_input_channels (int): Number of image channels. - n_features (int): Number of feature maps in each convolution layer. - extra_layers (int, optional): Add extra intermediate layers. + input_size (tuple[int, int]): Input image size (height, width) + num_input_channels (int): Number of input image channels + n_features (int): Number of feature maps in convolution layers + extra_layers (int, optional): Number of extra intermediate layers. Defaults to ``0``. + + Example: + >>> discriminator = Discriminator( + ... input_size=(256, 256), + ... num_input_channels=3, + ... n_features=64 + ... ) + >>> input_tensor = torch.randn(32, 3, 256, 256) + >>> prediction, features = discriminator(input_tensor) """ def __init__( @@ -236,7 +320,16 @@ def __init__( self.classifier.add_module("Sigmoid", nn.Sigmoid()) def forward(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """Return class of object and features.""" + """Forward pass through discriminator network. + + Args: + input_tensor (torch.Tensor): Input image tensor + + Returns: + tuple[torch.Tensor, torch.Tensor]: Tuple containing: + - Classification scores (real/fake) + - Intermediate features + """ features = self.features(input_tensor) classifier = self.classifier(features) classifier = classifier.view(-1, 1).squeeze(1) @@ -244,19 +337,30 @@ def forward(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tenso class Generator(nn.Module): - """Generator model. + """Generator Network. - Made of an encoder-decoder-encoder architecture. + Combines encoder-decoder-encoder architecture for image generation and + reconstruction. Args: - input_size (tuple[int, int]): Size of input data. - latent_vec_size (int): Dimension of latent vector produced between the first encoder-decoder. - num_input_channels (int): Number of channels in input image. - n_features (int): Number of feature maps in each convolution layer. - extra_layers (int, optional): Extra intermediate layers in the encoder/decoder. + input_size (tuple[int, int]): Input/output image size (height, width) + latent_vec_size (int): Size of latent vector between encoder-decoder + num_input_channels (int): Number of input/output image channels + n_features (int): Number of feature maps in convolution layers + extra_layers (int, optional): Number of extra intermediate layers. Defaults to ``0``. - add_final_conv_layer (bool, optional): Add a final convolution layer in the decoder. + add_final_conv_layer (bool, optional): Add final convolution to encoders. Defaults to ``True``. 
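+
+    Note:
+        The two latent vectors returned by the forward pass are what the
+        anomaly score is later derived from; a minimal sketch of the scoring
+        idea used by :class:`GanomalyModel` (illustrative, not the exact
+        implementation):
+
+        >>> score = torch.mean(torch.pow(latent_i - latent_o, 2), dim=1)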
+ + Example: + >>> generator = Generator( + ... input_size=(256, 256), + ... latent_vec_size=100, + ... num_input_channels=3, + ... n_features=64 + ... ) + >>> input_tensor = torch.randn(32, 3, 256, 256) + >>> gen_img, latent_i, latent_o = generator(input_tensor) """ def __init__( @@ -288,7 +392,17 @@ def __init__( ) def forward(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Return generated image and the latent vectors.""" + """Forward pass through generator network. + + Args: + input_tensor (torch.Tensor): Input image tensor + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple containing: + - Generated image + - First encoder's latent vector + - Second encoder's latent vector + """ latent_i = self.encoder1(input_tensor) gen_image = self.decoder(latent_i) latent_o = self.encoder2(gen_image) @@ -296,17 +410,35 @@ def forward(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tenso class GanomalyModel(nn.Module): - """Ganomaly Model. + """GANomaly model for anomaly detection. + + Complete model combining Generator and Discriminator networks. Args: - input_size (tuple[int, int]): Input dimension. - num_input_channels (int): Number of input channels. - n_features (int): Number of features layers in the CNNs. - latent_vec_size (int): Size of autoencoder latent vector. - extra_layers (int, optional): Number of extra layers for encoder/decoder. + input_size (tuple[int, int]): Input image size (height, width) + num_input_channels (int): Number of input image channels + n_features (int): Number of feature maps in convolution layers + latent_vec_size (int): Size of latent vector between encoder-decoder + extra_layers (int, optional): Number of extra intermediate layers. Defaults to ``0``. - add_final_conv_layer (bool, optional): Add convolution layer at the end. + add_final_conv_layer (bool, optional): Add final convolution to encoders. Defaults to ``True``. + + Example: + >>> model = GanomalyModel( + ... input_size=(256, 256), + ... num_input_channels=3, + ... n_features=64, + ... latent_vec_size=100 + ... ) + >>> input_tensor = torch.randn(32, 3, 256, 256) + >>> output = model(input_tensor) + + References: + - Title: GANomaly: Semi-Supervised Anomaly Detection via Adversarial + Training + - Authors: Samet Akcay, Amir Atapour-Abarghouei, Toby P. Breckon + - URL: https://arxiv.org/abs/1805.06725 """ def __init__( @@ -341,7 +473,7 @@ def weights_init(module: nn.Module) -> None: """Initialize DCGAN weights. Args: - module (nn.Module): [description] + module (nn.Module): Neural network module to initialize """ classname = module.__class__.__name__ if classname.find("Conv") != -1: @@ -354,13 +486,21 @@ def forward( self, batch: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | InferenceBatch: - """Get scores for batch. + """Forward pass through GANomaly model. Args: - batch (torch.Tensor): Images + batch (torch.Tensor): Batch of input images Returns: - Tensor: Regeneration scores. 
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | + InferenceBatch: + If training: + - Padded input batch + - Generated images + - First encoder's latent vectors + - Second encoder's latent vectors + If inference: + - Batch containing anomaly scores """ padded_batch = pad_nextpow2(batch) fake, latent_i, latent_o = self.generator(padded_batch) diff --git a/src/anomalib/models/image/padim/__init__.py b/src/anomalib/models/image/padim/__init__.py index 944e8f20c3..3dcbbd1d43 100644 --- a/src/anomalib/models/image/padim/__init__.py +++ b/src/anomalib/models/image/padim/__init__.py @@ -1,4 +1,17 @@ -"""PADIM model.""" +"""PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization. + +The PaDiM model is an anomaly detection approach that leverages patch-based +distribution modeling using pretrained CNN feature embeddings. It models the +distribution of patch embeddings at each spatial location using multivariate +Gaussian distributions. + +The model uses features extracted from multiple layers of networks like +``ResNet`` to capture both semantic and low-level visual information. During +inference, it computes Mahalanobis distances between test patch embeddings and +their corresponding reference distributions to detect anomalies. + +Paper: https://arxiv.org/abs/2011.08785 +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/image/padim/anomaly_map.py b/src/anomalib/models/image/padim/anomaly_map.py index 054a930664..4807ccf3dd 100644 --- a/src/anomalib/models/image/padim/anomaly_map.py +++ b/src/anomalib/models/image/padim/anomaly_map.py @@ -1,4 +1,32 @@ -"""Anomaly Map Generator for the PaDiM model implementation.""" +"""Anomaly Map Generator for the PaDiM model implementation. + +This module generates anomaly heatmaps for the PaDiM model by computing Mahalanobis +distances between test patch embeddings and reference distributions. + +The anomaly map generation process involves: +1. Computing Mahalanobis distances between embeddings and reference statistics +2. Upsampling the distance map to match input image size +3. Applying Gaussian smoothing to obtain the final anomaly map + +Example: + >>> from anomalib.models.image.padim.anomaly_map import AnomalyMapGenerator + >>> generator = AnomalyMapGenerator(sigma=4) + >>> embedding = torch.randn(32, 1024, 28, 28) + >>> mean = torch.randn(1024, 784) # 784 = 28*28 + >>> inv_covariance = torch.randn(784, 1024, 1024) + >>> anomaly_map = generator( + ... embedding=embedding, + ... mean=mean, + ... inv_covariance=inv_covariance, + ... image_size=(224, 224) + ... ) + +See Also: + - :class:`anomalib.models.image.padim.lightning_model.Padim`: + Lightning implementation of the PaDiM model + - :class:`anomalib.models.components.GaussianBlur2d`: + Gaussian blur module used for smoothing anomaly maps +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -13,10 +41,24 @@ class AnomalyMapGenerator(nn.Module): """Generate Anomaly Heatmap. + This class implements anomaly map generation for the PaDiM model by computing + Mahalanobis distances and applying post-processing steps. + Args: - image_size (ListConfig, tuple): Size of the input image. The anomaly map is upsampled to this dimension. - sigma (int, optional): Standard deviation for Gaussian Kernel. - Defaults to ``4``. + sigma (int, optional): Standard deviation for Gaussian smoothing kernel. + Higher values produce smoother anomaly maps. Defaults to ``4``. 
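+
+    Note:
+        Per patch position, the score is the Mahalanobis distance between the
+        embedding and its reference Gaussian; a minimal sketch for a single
+        embedding vector ``x`` with statistics ``mu`` and ``inv_cov`` (names
+        are illustrative):
+
+        >>> delta = x - mu
+        >>> distance = torch.sqrt(delta @ inv_cov @ delta)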
+ + Example: + >>> generator = AnomalyMapGenerator(sigma=4) + >>> embedding = torch.randn(32, 1024, 28, 28) + >>> mean = torch.randn(1024, 784) + >>> inv_covariance = torch.randn(784, 1024, 1024) + >>> anomaly_map = generator.compute_anomaly_map( + ... embedding=embedding, + ... mean=mean, + ... inv_covariance=inv_covariance, + ... image_size=(224, 224) + ... ) """ def __init__(self, sigma: int = 4) -> None: @@ -26,16 +68,20 @@ def __init__(self, sigma: int = 4) -> None: @staticmethod def compute_distance(embedding: torch.Tensor, stats: list[torch.Tensor]) -> torch.Tensor: - """Compute anomaly score to the patch in position(i,j) of a test image. + """Compute anomaly score for each patch position using Mahalanobis distance. - Ref: Equation (2), Section III-C of the paper. + Implements Equation (2) from Section III-C of the PaDiM paper to compute + the distance between patch embeddings and their reference distributions. Args: - embedding (torch.Tensor): Embedding Vector - stats (list[torch.Tensor]): Mean and Covariance Matrix of the multivariate Gaussian distribution + embedding (torch.Tensor): Feature embeddings from the CNN backbone, + shape ``(batch_size, n_features, height, width)`` + stats (list[torch.Tensor]): List containing mean and inverse covariance + tensors for the multivariate Gaussian distributions Returns: - Anomaly score of a test image via mahalanobis distance. + torch.Tensor: Anomaly scores computed via Mahalanobis distance, + shape ``(batch_size, 1, height, width)`` """ batch, channel, height, width = embedding.shape embedding = embedding.reshape(batch, channel, height * width) @@ -53,11 +99,13 @@ def up_sample(distance: torch.Tensor, image_size: tuple[int, int] | torch.Size) """Up sample anomaly score to match the input image size. Args: - distance (torch.Tensor): Anomaly score computed via the mahalanobis distance. - image_size (tuple[int, int] | torch.Size): Size to which the anomaly map should be upsampled. + distance (torch.Tensor): Anomaly scores, shape + ``(batch_size, 1, height, width)`` + image_size (tuple[int, int] | torch.Size): Target size for upsampling, + usually the original input image size Returns: - Resized distance matrix matching the input image size + torch.Tensor: Upsampled anomaly scores matching the input image size """ return F.interpolate( distance, @@ -67,13 +115,14 @@ def up_sample(distance: torch.Tensor, image_size: tuple[int, int] | torch.Size) ) def smooth_anomaly_map(self, anomaly_map: torch.Tensor) -> torch.Tensor: - """Apply gaussian smoothing to the anomaly map. + """Apply Gaussian smoothing to the anomaly map. Args: - anomaly_map (torch.Tensor): Anomaly score for the test image(s). + anomaly_map (torch.Tensor): Raw anomaly scores, + shape ``(batch_size, 1, height, width)`` Returns: - Filtered anomaly scores + torch.Tensor: Smoothed anomaly scores with reduced noise """ return self.blur(anomaly_map) @@ -84,19 +133,20 @@ def compute_anomaly_map( inv_covariance: torch.Tensor, image_size: tuple[int, int] | torch.Size | None = None, ) -> torch.Tensor: - """Compute anomaly score. + """Compute anomaly map from feature embeddings and distribution parameters. - Scores are calculated based on embedding vector, mean and inv_covariance of the multivariate gaussian - distribution. + This method combines distance computation, upsampling, and smoothing to + generate the final anomaly map. Args: - embedding (torch.Tensor): Embedding vector extracted from the test set. 
- mean (torch.Tensor): Mean of the multivariate gaussian distribution - inv_covariance (torch.Tensor): Inverse Covariance matrix of the multivariate gaussian distribution. - image_size (tuple[int, int] | torch.Size, optional): Size to which the anomaly map should be upsampled. + embedding (torch.Tensor): Feature embeddings from the CNN backbone + mean (torch.Tensor): Mean of the multivariate Gaussian distribution + inv_covariance (torch.Tensor): Inverse covariance matrix + image_size (tuple[int, int] | torch.Size | None, optional): Target + size for upsampling. If ``None``, no upsampling is performed. Returns: - Output anomaly score. + torch.Tensor: Final anomaly map after all processing steps """ score_map = self.compute_distance( embedding=embedding, @@ -107,19 +157,29 @@ def compute_anomaly_map( return self.smooth_anomaly_map(score_map) def forward(self, **kwargs) -> torch.Tensor: - """Return anomaly_map. + """Generate anomaly map from the provided embeddings and statistics. - Expects `embedding`, `mean` and `covariance` keywords to be passed explicitly. + Expects ``embedding``, ``mean`` and ``inv_covariance`` keywords to be + passed explicitly. Example: - >>> anomaly_map_generator = AnomalyMapGenerator(image_size=input_size) - >>> output = anomaly_map_generator(embedding=embedding, mean=mean, covariance=covariance) + >>> generator = AnomalyMapGenerator(sigma=4) + >>> anomaly_map = generator( + ... embedding=embedding, + ... mean=mean, + ... inv_covariance=inv_covariance, + ... image_size=(224, 224) + ... ) + + Args: + **kwargs: Keyword arguments containing ``embedding``, ``mean``, + ``inv_covariance`` and optionally ``image_size`` Raises: - ValueError: `embedding`. `mean` or `covariance` keys are not found + ValueError: If required keys are not found in ``kwargs`` Returns: - torch.Tensor: anomaly map + torch.Tensor: Generated anomaly map """ if not ("embedding" in kwargs and "mean" in kwargs and "inv_covariance" in kwargs): msg = f"Expected keys `embedding`, `mean` and `covariance`. Found {kwargs.keys()}" diff --git a/src/anomalib/models/image/padim/lightning_model.py b/src/anomalib/models/image/padim/lightning_model.py index 78f17861c0..242cd309e7 100644 --- a/src/anomalib/models/image/padim/lightning_model.py +++ b/src/anomalib/models/image/padim/lightning_model.py @@ -1,6 +1,31 @@ """PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization. -Paper https://arxiv.org/abs/2011.08785 +This model implements the PaDiM algorithm for anomaly detection and localization. +PaDiM models the distribution of patch embeddings at each spatial location using +multivariate Gaussian distributions. + +The model extracts features from multiple layers of pretrained CNN backbones to +capture both semantic and low-level visual information. During inference, it +computes Mahalanobis distances between test patch embeddings and their +corresponding reference distributions. + +Paper: https://arxiv.org/abs/2011.08785 + +Example: + >>> from anomalib.models.image.padim import Padim + >>> model = Padim( + ... backbone="resnet18", + ... layers=["layer1", "layer2", "layer3"], + ... pre_trained=True + ... 
) + >>> from anomalib.engine import Engine + >>> engine = Engine() + >>> engine.fit(model=model, datamodule=datamodule) + >>> prediction = model(image) + +See Also: + - :class:`anomalib.models.image.padim.torch_model.PadimModel`: + PyTorch implementation of the PaDiM model architecture + - :class:`anomalib.models.image.padim.anomaly_map.AnomalyMapGenerator`: + Anomaly map generation for PaDiM using Mahalanobis distance """ # Copyright (C) 2022-2024 Intel Corporation @@ -27,21 +52,41 @@ class Padim(MemoryBankMixin, AnomalibModule): - """PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization. + """PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection. Args: - backbone (str): Backbone CNN network - Defaults to ``resnet18``. - layers (list[str]): Layers to extract features from the backbone CNN - Defaults to ``["layer1", "layer2", "layer3"]``. - pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + backbone (str): Name of the backbone CNN network. Available options are + ``resnet18``, ``wide_resnet50_2`` etc. Defaults to ``resnet18``. + layers (list[str]): List of layer names to extract features from the + backbone CNN. Defaults to ``["layer1", "layer2", "layer3"]``. + pre_trained (bool, optional): Use pre-trained backbone weights. Defaults to ``True``. - n_features (int, optional): Number of features to retain in the dimension reduction step. - Default values from the paper are available for: resnet18 (100), wide_resnet50_2 (550). - Defaults to ``None``. - pre_processor (PreProcessor, optional): Pre-processor for the model. - This is used to pre-process the input data before it is passed to the model. - Defaults to ``None``. + n_features (int | None, optional): Number of features to retain after + dimension reduction. Default values from paper: ``resnet18=100``, + ``wide_resnet50_2=550``. Defaults to ``None``. + pre_processor (PreProcessor | bool, optional): Preprocessor to apply on + input data. Defaults to ``True``. + post_processor (PostProcessor | bool, optional): Post processor to apply + on model outputs. Defaults to ``True``. + evaluator (Evaluator | bool, optional): Evaluator for computing metrics. + Defaults to ``True``. + visualizer (Visualizer | bool, optional): Visualizer for generating + result images. Defaults to ``True``. + + Example: + >>> from anomalib.models.image.padim import Padim + >>> from anomalib.engine import Engine + >>> model = Padim( + ... backbone="resnet18", + ... layers=["layer1", "layer2", "layer3"], + ... pre_trained=True + ... ) + >>> engine = Engine() + >>> engine.fit(model=model, datamodule=datamodule) + >>> prediction = model(image) + + Note: + The model does not require training in the traditional sense. It fits + Gaussian distributions to the extracted features during the training + phase. """ def __init__( @@ -78,15 +123,17 @@ def configure_optimizers() -> None: return def training_step(self, batch: Batch, *args, **kwargs) -> None: - """Perform the training step of PADIM. For each batch, hierarchical features are extracted from the CNN. + """Perform the training step of PADIM. + + For each batch, hierarchical features are extracted from the CNN. Args: - batch (dict[str, str | torch.Tensor]): Batch containing image filename, image, label and mask - args: Additional arguments. - kwargs: Additional keyword arguments. + batch (Batch): Input batch containing image and metadata + args: Additional arguments (unused) + kwargs: Additional keyword arguments (unused) Returns: - Hierarchical feature map + torch.Tensor: Dummy loss tensor for Lightning compatibility """ del args, kwargs # These variables are not used.
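For context: ``training_step`` above only collects per-batch embeddings; ``fit()`` then aggregates them and estimates the per-patch Gaussian statistics (mean and inverse covariance). A minimal sketch of the aggregation step, with illustrative shapes:

    >>> import torch
    >>> per_batch = [torch.randn(8, 100, 28, 28) for _ in range(4)]  # collected by training_step
    >>> embeddings = torch.vstack(per_batch)  # aggregated in fit()
    >>> embeddings.shape
    torch.Size([32, 100, 28, 28])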
@@ -107,16 +154,17 @@ def fit(self) -> None: def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: """Perform a validation step of PADIM. - Similar to the training step, hierarchical features are extracted from the CNN for each batch. + Similar to the training step, hierarchical features are extracted from + the CNN for each batch. Args: - batch (dict[str, str | torch.Tensor]): Input batch - args: Additional arguments. - kwargs: Additional keyword arguments. + batch (Batch): Input batch containing image and metadata + args: Additional arguments (unused) + kwargs: Additional keyword arguments (unused) Returns: - Dictionary containing images, features, true labels and masks. - These are required in `validation_epoch_end` for feature concatenation. + STEP_OUTPUT: Dictionary containing images, features, true labels + and masks required for validation """ del args, kwargs # These variables are not used. @@ -128,7 +176,11 @@ def trainer_arguments(self) -> dict[str, int | float]: """Return PADIM trainer arguments. Since the model does not require training, we limit the max_epochs to 1. - Since we need to run training epoch before validation, we also set the sanity steps to 0 + Since we need to run the training epoch before validation, we also set + the sanity steps to 0. + + Returns: + dict[str, int | float]: Dictionary of trainer arguments """ return {"max_epochs": 1, "val_check_interval": 1.0, "num_sanity_val_steps": 0} @@ -137,11 +189,15 @@ def learning_type(self) -> LearningType: """Return the learning type of the model. Returns: - LearningType: Learning type of the model. + LearningType: Learning type (ONE_CLASS for PaDiM) """ return LearningType.ONE_CLASS @staticmethod def configure_post_processor() -> OneClassPostProcessor: - """Return the default post-processor for PADIM.""" + """Return the default post-processor for PADIM. + + Returns: + OneClassPostProcessor: Default post-processor + """ return OneClassPostProcessor() diff --git a/src/anomalib/models/image/padim/torch_model.py b/src/anomalib/models/image/padim/torch_model.py index e537d87ca3..5f8e165a04 100644 --- a/src/anomalib/models/image/padim/torch_model.py +++ b/src/anomalib/models/image/padim/torch_model.py @@ -1,4 +1,35 @@ -"""PyTorch model for the PaDiM model implementation.""" +"""PyTorch model for the PaDiM model implementation. + +This module implements the PaDiM model architecture using PyTorch. PaDiM models the +distribution of patch embeddings at each spatial location using multivariate +Gaussian distributions. + +The model extracts features from multiple layers of pretrained CNN backbones to +capture both semantic and low-level visual information. During inference, it +computes Mahalanobis distances between test patch embeddings and their +corresponding reference distributions. + +Example: + >>> from anomalib.models.image.padim.torch_model import PadimModel + >>> model = PadimModel( + ... backbone="resnet18", + ... layers=["layer1", "layer2", "layer3"], + ... pre_trained=True, + ... n_features=100 + ...
) + >>> input_tensor = torch.randn(32, 3, 224, 224) + >>> output = model(input_tensor) + +Paper: https://arxiv.org/abs/2011.08785 + +See Also: + - :class:`anomalib.models.image.padim.lightning_model.Padim`: + Lightning implementation of the PaDiM model + - :class:`anomalib.models.image.padim.anomaly_map.AnomalyMapGenerator`: + Anomaly map generation for PaDiM using Mahalanobis distance + - :class:`anomalib.models.components.MultiVariateGaussian`: + Multivariate Gaussian distribution modeling +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -33,11 +64,22 @@ def _deduce_dims( ) -> tuple[int, int]: """Run a dry run to deduce the dimensions of the extracted features. - Important: `layers` is assumed to be ordered and the first (layers[0]) - is assumed to be the layer with largest resolution. + This function performs a forward pass to determine the dimensions of features + extracted from the specified layers of the backbone network. + + Args: + feature_extractor (TimmFeatureExtractor): Feature extraction model + input_size (tuple[int, int]): Input image dimensions (height, width) + layers (list[str]): Names of layers to extract features from + + Important: + ``layers`` is assumed to be ordered and the first (``layers[0]``) + is assumed to be the layer with largest resolution. Returns: - tuple[int, int]: Dimensions of the extracted features: (n_dims_original, n_patches) + tuple[int, int]: Dimensions of extracted features: + - n_dims_original: Total number of feature dimensions + - n_patches: Number of spatial patches """ dimensions_mapping = dryrun_find_featuremap_dims(feature_extractor, input_size, layers) @@ -56,13 +98,13 @@ class PadimModel(nn.Module): Args: layers (list[str]): Layers used for feature extraction - backbone (str, optional): Pre-trained model backbone. Defaults to "resnet18". - Defaults to ``resnet18``. - pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. - Defaults to ``True``. - n_features (int, optional): Number of features to retain in the dimension reduction step. - Default values from the paper are available for: resnet18 (100), wide_resnet50_2 (550). - Defaults to ``None``. + backbone (str, optional): Pre-trained model backbone. Defaults to + ``resnet18``. + pre_trained (bool, optional): Boolean to check whether to use a + pre_trained backbone. Defaults to ``True``. + n_features (int, optional): Number of features to retain in the dimension + reduction step. Default values from the paper are available for: + resnet18 (100), wide_resnet50_2 (550). Defaults to ``None``. """ def __init__( @@ -110,18 +152,19 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: """Forward-pass image-batch (N, C, H, W) into model to extract features. Args: - input_tensor: Image-batch (N, C, H, W) + input_tensor (torch.Tensor): Image batch with shape (N, C, H, W) Returns: - If training, returns the embeddings. - If inference, returns the prediction score and the anomaly map. + torch.Tensor | InferenceBatch: If training, returns the embeddings. + If inference, returns ``InferenceBatch`` containing prediction + scores and anomaly maps. 
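+
+        Note:
+            The return type follows the module mode; a minimal sketch
+            (assuming the Gaussian statistics have already been fitted):
+
+            >>> _ = model.train()
+            >>> embeddings = model(input_tensor)  # torch.Tensor
+            >>> _ = model.eval()
+            >>> predictions = model(input_tensor)  # InferenceBatch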
Example: + >>> model = PadimModel() >>> x = torch.randn(32, 3, 224, 224) - >>> features = self.extract_features(input_tensor) + >>> features = model.feature_extractor(x) >>> features.keys() dict_keys(['layer1', 'layer2', 'layer3']) - >>> [v.shape for v in features.values()] [torch.Size([32, 64, 56, 56]), torch.Size([32, 128, 28, 28]), @@ -153,11 +196,17 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: def generate_embedding(self, features: dict[str, torch.Tensor]) -> torch.Tensor: """Generate embedding from hierarchical feature map. + This method combines features from multiple layers of the backbone network + to create a rich embedding that captures both low-level and high-level + image features. + Args: - features (dict[str, torch.Tensor]): Hierarchical feature map from a CNN (ResNet18 or WideResnet) + features (dict[str, torch.Tensor]): Dictionary mapping layer names to + their feature tensors extracted from the backbone CNN. Returns: - Embedding vector + torch.Tensor: Embedding tensor combining features from all specified + layers, with dimensions reduced according to ``n_features``. """ embeddings = features[self.layers[0]] for layer in self.layers[1:]: diff --git a/src/anomalib/models/image/patchcore/__init__.py b/src/anomalib/models/image/patchcore/__init__.py index 1e69fa8571..1d716b53f0 100644 --- a/src/anomalib/models/image/patchcore/__init__.py +++ b/src/anomalib/models/image/patchcore/__init__.py @@ -1,4 +1,26 @@ -"""PatchCore model.""" +"""PatchCore: Towards Total Recall in Industrial Anomaly Detection. + +PatchCore is an anomaly detection model that uses a memory bank of patch features +extracted from a pretrained CNN backbone. It stores representative patch features +from normal training images and detects anomalies by comparing test image patches +against this memory bank. + +The model uses a nearest neighbor search to find the most similar patches in the +memory bank and computes anomaly scores based on these distances. It achieves +high performance while maintaining interpretability through localization maps. + +Example: + >>> from anomalib.models.image.patchcore import Patchcore + >>> from anomalib.engine import Engine + >>> model = Patchcore( + ... backbone="wide_resnet50_2", + ... layers=["layer2", "layer3"], + ... coreset_sampling_ratio=0.1 + ... ) + >>> engine = Engine() + >>> engine.fit(model=model, datamodule=datamodule) + >>> prediction = model(image) + +Paper: https://arxiv.org/abs/2106.08265 +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/image/patchcore/anomaly_map.py b/src/anomalib/models/image/patchcore/anomaly_map.py index 2c6cf5e69c..c2a9748305 100644 --- a/src/anomalib/models/image/patchcore/anomaly_map.py +++ b/src/anomalib/models/image/patchcore/anomaly_map.py @@ -1,4 +1,28 @@ -"""Anomaly Map Generator for the PatchCore model implementation.""" +"""Anomaly Map Generator for the PatchCore model implementation. + +This module generates anomaly heatmaps for the PatchCore model by upsampling +patch-level anomaly scores and applying Gaussian smoothing. + +The anomaly map generation process involves: +1. Taking patch-level anomaly scores as input +2. Optionally upsampling scores to match input image dimensions +3. Applying Gaussian blur to smooth the final anomaly map + +Example: + >>> from anomalib.models.image.patchcore.anomaly_map import AnomalyMapGenerator + >>> generator = AnomalyMapGenerator(sigma=4) + >>> patch_scores = torch.randn(32, 1, 28, 28) # (B, 1, H, W) + >>> anomaly_map = generator( + ... patch_scores=patch_scores, + ...
image_size=(224, 224) + ... ) + +See Also: + - :class:`anomalib.models.image.patchcore.lightning_model.Patchcore`: + Lightning implementation of the PatchCore model + - :class:`anomalib.models.components.GaussianBlur2d`: + Gaussian blur module used for smoothing anomaly maps +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -13,10 +37,17 @@ class AnomalyMapGenerator(nn.Module): """Generate Anomaly Heatmap. + This class implements anomaly map generation for the PatchCore model by + upsampling patch scores and applying Gaussian smoothing. + Args: - The anomaly map is upsampled to this dimension. - sigma (int, optional): Standard deviation for Gaussian Kernel. - Defaults to ``4``. + sigma (int, optional): Standard deviation for Gaussian smoothing kernel. + Higher values produce smoother anomaly maps. Defaults to ``4``. + + Example: + >>> generator = AnomalyMapGenerator(sigma=4) + >>> patch_scores = torch.randn(32, 1, 28, 28) + >>> anomaly_map = generator(patch_scores) """ def __init__( @@ -32,16 +63,18 @@ def compute_anomaly_map( patch_scores: torch.Tensor, image_size: tuple[int, int] | torch.Size | None = None, ) -> torch.Tensor: - """Pixel Level Anomaly Heatmap. + """Compute pixel-level anomaly heatmap from patch scores. Args: - patch_scores (torch.Tensor): Patch-level anomaly scores - image_size (tuple[int, int] | torch.Size, optional): Size of the input image. - The anomaly map is upsampled to this dimension. - Defaults to None. + patch_scores (torch.Tensor): Patch-level anomaly scores with shape + ``(B, 1, H, W)`` + image_size (tuple[int, int] | torch.Size | None, optional): Target + size ``(H, W)`` to upsample anomaly map. If ``None``, keeps + original size. Defaults to ``None``. Returns: - Tensor: Map of the pixel-level anomaly scores + torch.Tensor: Pixel-level anomaly scores after upsampling and + smoothing, with shape ``(B, 1, H, W)`` """ if image_size is None: anomaly_map = patch_scores @@ -54,19 +87,25 @@ def forward( patch_scores: torch.Tensor, image_size: tuple[int, int] | torch.Size | None = None, ) -> torch.Tensor: - """Return anomaly_map and anomaly_score. + """Generate smoothed anomaly map from patch scores. Args: - patch_scores (torch.Tensor): Patch-level anomaly scores - image_size (tuple[int, int] | torch.Size, optional): Size of the input image. - The anomaly map is upsampled to this dimension. - Defaults to None. + patch_scores (torch.Tensor): Patch-level anomaly scores with shape + ``(B, 1, H, W)`` + image_size (tuple[int, int] | torch.Size | None, optional): Target + size ``(H, W)`` to upsample anomaly map. If ``None``, keeps + original size. Defaults to ``None``. Example: - >>> anomaly_map_generator = AnomalyMapGenerator() - >>> map = anomaly_map_generator(patch_scores=patch_scores) + >>> generator = AnomalyMapGenerator(sigma=4) + >>> patch_scores = torch.randn(32, 1, 28, 28) + >>> anomaly_map = generator( + ... patch_scores=patch_scores, + ... image_size=(224, 224) + ... ) Returns: - Tensor: anomaly_map + torch.Tensor: Anomaly heatmap after upsampling and smoothing, + with shape ``(B, 1, H, W)`` """ return self.compute_anomaly_map(patch_scores, image_size) diff --git a/src/anomalib/models/image/patchcore/lightning_model.py b/src/anomalib/models/image/patchcore/lightning_model.py index e58185e50e..bd8f9da4f7 100644 --- a/src/anomalib/models/image/patchcore/lightning_model.py +++ b/src/anomalib/models/image/patchcore/lightning_model.py @@ -1,6 +1,31 @@ -"""Towards Total Recall in Industrial Anomaly Detection. 
- -Paper https://arxiv.org/abs/2106.08265. +"""PatchCore: Towards Total Recall in Industrial Anomaly Detection. + +This module implements the PatchCore model for anomaly detection using a memory bank +of patch features extracted from a pretrained CNN backbone. The model stores +representative patch features from normal training images and detects anomalies by +comparing test image patches against this memory bank. + +The model uses a nearest neighbor search to find the most similar patches in the +memory bank and computes anomaly scores based on these distances. It achieves high +performance while maintaining interpretability through localization maps. + +Example: + >>> from anomalib.models.image.patchcore import Patchcore + >>> from anomalib.engine import Engine + >>> model = Patchcore( + ... backbone="wide_resnet50_2", + ... layers=["layer2", "layer3"], + ... coreset_sampling_ratio=0.1 + ... ) + >>> engine = Engine() + >>> engine.fit(model=model, datamodule=datamodule) + >>> prediction = model(image) + +Paper: https://arxiv.org/abs/2106.08265 + +See Also: + - :class:`anomalib.models.image.patchcore.torch_model.PatchcoreModel`: + PyTorch implementation of the PatchCore model architecture + - :class:`anomalib.models.image.patchcore.anomaly_map.AnomalyMapGenerator`: + Anomaly map generation for PatchCore using nearest neighbor search """ # Copyright (C) 2022-2024 Intel Corporation @@ -28,22 +53,57 @@ class Patchcore(MemoryBankMixin, AnomalibModule): - """PatchcoreLightning Module to train PatchCore algorithm. + """PatchCore Lightning Module for anomaly detection. + + This class implements the PatchCore algorithm which uses a memory bank of patch + features for anomaly detection. Features are extracted from a pretrained CNN + backbone and stored in a memory bank. Anomalies are detected by comparing test + image patches with the stored features using nearest neighbor search. + + The model works in two phases: + 1. Training: Extract and store patch features from normal training images + 2. Inference: Compare test image patches against stored features to detect + anomalies Args: - backbone (str): Backbone CNN network - Defaults to ``wide_resnet50_2``. - layers (list[str]): Layers to extract features from the backbone CNN - Defaults to ``["layer2", "layer3"]``. - pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + backbone (str): Name of the backbone CNN network. + Defaults to ``"wide_resnet50_2"``. + layers (Sequence[str]): Names of layers to extract features from. + Defaults to ``("layer2", "layer3")``. + pre_trained (bool, optional): Whether to use pre-trained backbone weights. Defaults to ``True``. - coreset_sampling_ratio (float, optional): Coreset sampling ratio to subsample embedding. - Defaults to ``0.1``. - num_neighbors (int, optional): Number of nearest neighbors. + coreset_sampling_ratio (float, optional): Ratio for coreset sampling to + subsample embeddings. Defaults to ``0.1``. + num_neighbors (int, optional): Number of nearest neighbors to use. Defaults to ``9``. - pre_processor (PreProcessor, optional): Pre-processor for the model. - This is used to pre-process the input data before it is passed to the model. - Defaults to ``None``. + pre_processor (PreProcessor | bool, optional): Pre-processor instance or + bool flag. Defaults to ``True``. + post_processor (PostProcessor | bool, optional): Post-processor instance or + bool flag. Defaults to ``True``. + evaluator (Evaluator | bool, optional): Evaluator instance or bool flag. + Defaults to ``True``. + visualizer (Visualizer | bool, optional): Visualizer instance or bool flag.
+ Defaults to ``True``. + + Example: + >>> from anomalib.models.image.patchcore import Patchcore + >>> from anomalib.engine import Engine + >>> model = Patchcore( + ... backbone="wide_resnet50_2", + ... layers=["layer2", "layer3"], + ... coreset_sampling_ratio=0.1 + ... ) + >>> engine = Engine() + >>> engine.fit(model=model, datamodule=datamodule) + >>> predictions = model(image) + + Notes: + The model requires no optimization/backpropagation as it uses a pretrained + backbone and nearest neighbor search. + + See Also: + - :class:`anomalib.models.components.AnomalibModule`: + Base class for all anomaly detection models + - :class:`anomalib.models.components.MemoryBankMixin`: + Mixin class for models using feature memory banks """ def __init__( @@ -80,7 +140,29 @@ def configure_pre_processor( image_size: tuple[int, int] | None = None, center_crop_size: tuple[int, int] | None = None, ) -> PreProcessor: - """Default transform for Padim.""" + """Configure the default pre-processor for PatchCore. + + The pre-processor performs the following steps: + 1. Resize image to specified size + 2. Center crop to a size proportional to the image size + 3. Normalize using ImageNet mean and std + + Args: + image_size (tuple[int, int] | None, optional): Target size for + resizing. Defaults to ``(256, 256)``. + center_crop_size (tuple[int, int] | None, optional): Size for center + cropping. If ``None``, scales proportionally to ``image_size``. + Defaults to ``None``. + + Returns: + PreProcessor: Configured pre-processor instance. + + Example: + >>> pre_processor = Patchcore.configure_pre_processor( + ... image_size=(256, 256) + ... ) + >>> transformed_image = pre_processor(image) + """ image_size = image_size or (256, 256) if center_crop_size is None: # scale center crop size proportional to image size @@ -99,7 +181,7 @@ def configure_optimizers() -> None: """Configure optimizers. Returns: - None: Do not set optimizers by returning None. + None: PatchCore requires no optimization. """ return @@ -107,12 +189,16 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None: """Generate feature embedding of the batch. Args: - batch (dict[str, str | torch.Tensor]): Batch containing image filename, image, label and mask - args: Additional arguments. - kwargs: Additional keyword arguments. + batch (Batch): Input batch containing image and metadata + *args: Additional arguments (unused) + **kwargs: Additional keyword arguments (unused) Returns: - dict[str, np.ndarray]: Embedding Vector + torch.Tensor: Dummy loss tensor for Lightning compatibility + + Note: + The method stores embeddings in ``self.embeddings`` for later use in + ``fit()``. """ del args, kwargs # These variables are not used. @@ -122,7 +208,12 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None: return torch.tensor(0.0, requires_grad=True, device=self.device) def fit(self) -> None: - """Apply subsampling to the embedding collected from the training set.""" + """Apply subsampling to the embedding collected from the training set. + + This method: + 1. Aggregates embeddings from all training batches + 2. Applies coreset subsampling to reduce memory requirements + """ logger.info("Aggregating the embedding extracted from the training set.") embeddings = torch.vstack(self.embeddings) @@ -130,15 +221,19 @@ def fit(self) -> None: self.model.subsample_embedding(embeddings, self.coreset_sampling_ratio) def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: - """Get batch of anomaly maps from input image batch. + """Generate predictions for a batch of images.
Args: - batch (dict[str, str | torch.Tensor]): Batch containing image filename, image, label and mask - args: Additional arguments. - kwargs: Additional keyword arguments. + batch (Batch): Input batch containing images and metadata + *args: Additional arguments (unused) + **kwargs: Additional keyword arguments (unused) Returns: - dict[str, Any]: Image filenames, test images, GT and predicted label/masks + STEP_OUTPUT: Batch with added predictions + + Note: + Predictions include anomaly maps and scores computed using nearest + neighbor search. """ # These variables are not used. del args, kwargs @@ -150,23 +245,32 @@ def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: @property def trainer_arguments(self) -> dict[str, Any]: - """Return Patchcore trainer arguments.""" + """Get default trainer arguments for PatchCore. + + Returns: + dict[str, Any]: Trainer arguments + - ``gradient_clip_val``: ``0`` (no gradient clipping needed) + - ``max_epochs``: ``1`` (single pass through training data) + - ``num_sanity_val_steps``: ``0`` (skip validation sanity checks) + """ return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0} @property def learning_type(self) -> LearningType: - """Return the learning type of the model. + """Get the learning type. Returns: - LearningType: Learning type of the model. + LearningType: Always ``LearningType.ONE_CLASS`` as PatchCore only + trains on normal samples """ return LearningType.ONE_CLASS @staticmethod def configure_post_processor() -> OneClassPostProcessor: - """Return the default post-processor for the model. + """Configure the default post-processor. Returns: - OneClassPostProcessor: Post-processor for one-class models. + OneClassPostProcessor: Post-processor for one-class models that + converts raw scores to anomaly predictions """ return OneClassPostProcessor() diff --git a/src/anomalib/models/image/patchcore/torch_model.py b/src/anomalib/models/image/patchcore/torch_model.py index 80133b4bd2..ac74686994 100644 --- a/src/anomalib/models/image/patchcore/torch_model.py +++ b/src/anomalib/models/image/patchcore/torch_model.py @@ -1,4 +1,34 @@ -"""PyTorch model for the PatchCore model implementation.""" +"""PyTorch model for the PatchCore model implementation. + +This module implements the PatchCore model architecture using PyTorch. PatchCore +uses a memory bank of patch features extracted from a pretrained CNN backbone to +detect anomalies. + +The model stores representative patch features from normal training images and +detects anomalies by comparing test image patches against this memory bank using +nearest neighbor search. + +Example: + >>> from anomalib.models.image.patchcore.torch_model import PatchcoreModel + >>> model = PatchcoreModel( + ... backbone="wide_resnet50_2", + ... layers=["layer2", "layer3"], + ... pre_trained=True, + ... num_neighbors=9 + ... 
) + >>> input_tensor = torch.randn(32, 3, 224, 224) + >>> output = model(input_tensor) + +Paper: https://arxiv.org/abs/2106.08265 + +See Also: + - :class:`anomalib.models.image.patchcore.lightning_model.Patchcore`: + Lightning implementation of the PatchCore model + - :class:`anomalib.models.image.patchcore.anomaly_map.AnomalyMapGenerator`: + Anomaly map generation for PatchCore using nearest neighbor search + - :class:`anomalib.models.components.KCenterGreedy`: + Coreset subsampling using k-center-greedy approach +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -20,16 +50,56 @@ class PatchcoreModel(DynamicBufferMixin, nn.Module): - """Patchcore Module. + """PatchCore PyTorch model for anomaly detection. + + This model implements the PatchCore algorithm which uses a memory bank of patch + features for anomaly detection. Features are extracted from a pretrained CNN + backbone and stored in a memory bank. Anomalies are detected by comparing test + image patches with the stored features using nearest neighbor search. + + The model works in two phases: + 1. Training: Extract and store patch features from normal training images + 2. Inference: Compare test image patches against stored features to detect + anomalies Args: - layers (list[str]): Layers used for feature extraction - backbone (str, optional): Pre-trained model backbone. - Defaults to ``resnet18``. - pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone. + layers (Sequence[str]): Names of layers to extract features from. + backbone (str, optional): Name of the backbone CNN network. + Defaults to ``"wide_resnet50_2"``. + pre_trained (bool, optional): Whether to use pre-trained backbone weights. Defaults to ``True``. - num_neighbors (int, optional): Number of nearest neighbors. + num_neighbors (int, optional): Number of nearest neighbors to use. Defaults to ``9``. + + Example: + >>> from anomalib.models.image.patchcore.torch_model import PatchcoreModel + >>> model = PatchcoreModel( + ... backbone="wide_resnet50_2", + ... layers=["layer2", "layer3"], + ... pre_trained=True, + ... num_neighbors=9 + ... ) + >>> input_tensor = torch.randn(32, 3, 224, 224) + >>> output = model(input_tensor) + + Attributes: + tiler (Tiler | None): Optional tiler for processing large images. + feature_extractor (TimmFeatureExtractor): CNN feature extractor. + feature_pooler (torch.nn.AvgPool2d): Average pooling layer. + anomaly_map_generator (AnomalyMapGenerator): Generates anomaly heatmaps. + memory_bank (torch.Tensor): Storage for patch features from training. + + Notes: + The model requires no optimization/backpropagation as it uses a pretrained + backbone and nearest neighbor search. + + See Also: + - :class:`anomalib.models.image.patchcore.lightning_model.Patchcore`: + Lightning implementation of the PatchCore model + - :class:`anomalib.models.image.patchcore.anomaly_map.AnomalyMapGenerator`: + Anomaly map generation for PatchCore + - :class:`anomalib.models.components.KCenterGreedy`: + Coreset subsampling using k-center-greedy approach """ def __init__( @@ -58,18 +128,29 @@ def __init__( self.memory_bank: torch.Tensor def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: - """Return Embedding during training, or a tuple of anomaly map and anomaly score during testing. + """Process input tensor through the model. - Steps performed: - 1. Get features from a CNN. - 2. Generate embedding based on the features. - 3. Compute anomaly map in test mode. 
+ During training, returns embeddings extracted from the input. During + inference, returns anomaly maps and scores computed by comparing input + embeddings against the memory bank. Args: - input_tensor (torch.Tensor): Input tensor + input_tensor (torch.Tensor): Input images of shape + ``(batch_size, channels, height, width)``. Returns: - Tensor | dict[str, torch.Tensor]: Embedding for training, anomaly map and anomaly score for testing. + torch.Tensor | InferenceBatch: During training, returns embeddings. + During inference, returns ``InferenceBatch`` containing anomaly + maps and scores. + + Example: + >>> model = PatchcoreModel(layers=["layer1"]) + >>> input_tensor = torch.randn(32, 3, 224, 224) + >>> output = model(input_tensor) + >>> if model.training: + ... assert isinstance(output, torch.Tensor) + ... else: + ... assert isinstance(output, InferenceBatch) """ output_size = input_tensor.shape[-2:] if self.tiler: @@ -104,14 +185,27 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: return InferenceBatch(pred_score=pred_score, anomaly_map=anomaly_map) def generate_embedding(self, features: dict[str, torch.Tensor]) -> torch.Tensor: - """Generate embedding from hierarchical feature map. + """Generate embedding by concatenating multi-scale feature maps. + + Combines feature maps from different CNN layers by upsampling them to a + common size and concatenating along the channel dimension. Args: - features: Hierarchical feature map from a CNN (ResNet18 or WideResnet) - features: dict[str:Tensor]: + features (dict[str, torch.Tensor]): Dictionary mapping layer names to + feature tensors extracted from the backbone CNN. Returns: - Embedding vector + torch.Tensor: Concatenated feature embedding of shape + ``(batch_size, num_features, height, width)``. + + Example: + >>> features = { + ... "layer1": torch.randn(32, 64, 56, 56), + ... "layer2": torch.randn(32, 128, 28, 28) + ... } + >>> embedding = model.generate_embedding(features) + >>> embedding.shape + torch.Size([32, 192, 56, 56]) """ embeddings = features[self.layers[0]] for layer in self.layers[1:]: @@ -123,26 +217,43 @@ def generate_embedding(self, features: dict[str, torch.Tensor]) -> torch.Tensor: @staticmethod def reshape_embedding(embedding: torch.Tensor) -> torch.Tensor: - """Reshape Embedding. + """Reshape embedding tensor for patch-wise processing. - Reshapes Embedding to the following format: - - [Batch, Embedding, Patch, Patch] to [Batch*Patch*Patch, Embedding] + Converts a 4D embedding tensor into a 2D matrix where each row represents + a patch embedding vector. Args: - embedding (torch.Tensor): Embedding tensor extracted from CNN features. + embedding (torch.Tensor): Input embedding tensor of shape + ``(batch_size, embedding_dim, height, width)``. Returns: - Tensor: Reshaped embedding tensor. + torch.Tensor: Reshaped embedding tensor of shape + ``(batch_size * height * width, embedding_dim)``. + + Example: + >>> embedding = torch.randn(32, 512, 7, 7) + >>> reshaped = PatchcoreModel.reshape_embedding(embedding) + >>> reshaped.shape + torch.Size([1568, 512]) """ embedding_size = embedding.size(1) return embedding.permute(0, 2, 3, 1).reshape(-1, embedding_size) def subsample_embedding(self, embedding: torch.Tensor, sampling_ratio: float) -> None: - """Subsample embedding based on coreset sampling and store to memory. + """Subsample embeddings using coreset selection. + + Uses k-center-greedy coreset subsampling to select a representative + subset of patch embeddings to store in the memory bank. 
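+
+        Note:
+            Selection is delegated to
+            :class:`anomalib.models.components.KCenterGreedy`; roughly
+            (a sketch assuming the current sampler API):
+
+            >>> sampler = KCenterGreedy(embedding=embedding, sampling_ratio=0.1)
+            >>> coreset = sampler.sample_coreset()
+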
Args: - embedding (np.ndarray): Embedding tensor from the CNN - sampling_ratio (float): Coreset sampling ratio + embedding (torch.Tensor): Embedding tensor to subsample from. + sampling_ratio (float): Fraction of embeddings to keep, in range (0,1]. + + Example: + >>> embedding = torch.randn(1000, 512) + >>> model.subsample_embedding(embedding, sampling_ratio=0.1) + >>> model.memory_bank.shape + torch.Size([100, 512]) """ # Coreset Subsampling sampler = KCenterGreedy(embedding=embedding, sampling_ratio=sampling_ratio) @@ -151,17 +262,30 @@ def subsample_embedding(self, embedding: torch.Tensor, sampling_ratio: float) -> @staticmethod def euclidean_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """Calculate pair-wise distance between row vectors in x and those in y. + """Compute pairwise Euclidean distances between two sets of vectors. - Replaces torch cdist with p=2, as cdist is not properly exported to onnx and openvino format. - Resulting matrix is indexed by x vectors in rows and y vectors in columns. + Implements an efficient matrix computation of Euclidean distances between + all pairs of vectors in ``x`` and ``y`` without using ``torch.cdist()``. Args: - x: input tensor 1 - y: input tensor 2 + x (torch.Tensor): First tensor of shape ``(n, d)``. + y (torch.Tensor): Second tensor of shape ``(m, d)``. Returns: - Matrix of distances between row vectors in x and y. + torch.Tensor: Distance matrix of shape ``(n, m)`` where element + ``(i,j)`` is the distance between row ``i`` of ``x`` and row + ``j`` of ``y``. + + Example: + >>> x = torch.randn(100, 512) + >>> y = torch.randn(50, 512) + >>> distances = PatchcoreModel.euclidean_dist(x, y) + >>> distances.shape + torch.Size([100, 50]) + + Note: + This implementation avoids using ``torch.cdist()`` for better + compatibility with ONNX export and OpenVINO conversion. """ x_norm = x.pow(2).sum(dim=-1, keepdim=True) # |x| y_norm = y.pow(2).sum(dim=-1, keepdim=True) # |y| @@ -170,15 +294,28 @@ def euclidean_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return res.clamp_min_(0).sqrt_() def nearest_neighbors(self, embedding: torch.Tensor, n_neighbors: int) -> tuple[torch.Tensor, torch.Tensor]: - """Nearest Neighbours using brute force method and euclidean norm. + """Find nearest neighbors in memory bank for input embeddings. + + Uses brute force search with Euclidean distance to find the closest + matches in the memory bank for each input embedding. Args: - embedding (torch.Tensor): Features to compare the distance with the memory bank. - n_neighbors (int): Number of neighbors to look at + embedding (torch.Tensor): Query embeddings to find neighbors for. + n_neighbors (int): Number of nearest neighbors to return. Returns: - Tensor: Patch scores. - Tensor: Locations of the nearest neighbor(s). + tuple[torch.Tensor, torch.Tensor]: Tuple containing: + - Distances to nearest neighbors (shape: ``(n, k)``) + - Indices of nearest neighbors (shape: ``(n, k)``) + where ``n`` is number of query embeddings and ``k`` is + ``n_neighbors``. + + Example: + >>> embedding = torch.randn(100, 512) + >>> # Assuming memory_bank is already populated + >>> scores, locations = model.nearest_neighbors(embedding, n_neighbors=5) + >>> scores.shape, locations.shape + (torch.Size([100, 5]), torch.Size([100, 5])) """ distances = self.euclidean_dist(embedding, self.memory_bank) if n_neighbors == 1: @@ -194,15 +331,32 @@ def compute_anomaly_score( locations: torch.Tensor, embedding: torch.Tensor, ) -> torch.Tensor: - """Compute Image-Level Anomaly Score. 
+ """Compute image-level anomaly scores. + + Implements the paper's weighted scoring mechanism that considers both + the distance to nearest neighbors and the local neighborhood structure + in the memory bank. Args: - patch_scores (torch.Tensor): Patch-level anomaly scores - locations: Memory bank locations of the nearest neighbor for each patch location - embedding: The feature embeddings that generated the patch scores + patch_scores (torch.Tensor): Patch-level anomaly scores. + locations (torch.Tensor): Memory bank indices of nearest neighbors. + embedding (torch.Tensor): Input embeddings that generated the scores. Returns: - Tensor: Image-level anomaly scores + torch.Tensor: Image-level anomaly scores. + + Example: + >>> patch_scores = torch.randn(32, 49) # 7x7 patches + >>> locations = torch.randint(0, 1000, (32, 49)) + >>> embedding = torch.randn(32 * 49, 512) + >>> scores = model.compute_anomaly_score(patch_scores, locations, + ... embedding) + >>> scores.shape + torch.Size([32]) + + Note: + When ``num_neighbors=1``, returns the maximum patch score directly. + Otherwise, computes weighted scores using neighborhood information. """ # Don't need to compute weights if num_neighbors is 1 if self.num_neighbors == 1: diff --git a/src/anomalib/models/image/reverse_distillation/__init__.py b/src/anomalib/models/image/reverse_distillation/__init__.py index 7dd60dcb25..616c06c4f8 100644 --- a/src/anomalib/models/image/reverse_distillation/__init__.py +++ b/src/anomalib/models/image/reverse_distillation/__init__.py @@ -1,4 +1,27 @@ -"""Reverse Distillation Model.""" +"""Reverse Distillation Model for anomaly detection. + +This module implements the Reverse Distillation model for anomaly detection as described in +the paper "Reverse Distillation: A New Training Strategy for Feature Reconstruction +Networks in Anomaly Detection" (Deng et al., 2022). + +The model consists of: +- A pre-trained encoder (e.g. ResNet) that extracts multi-scale features +- A bottleneck layer that compresses features into a compact representation +- A decoder that reconstructs features back to the original feature space +- A scoring mechanism based on reconstruction error + +Example: + >>> from anomalib.models.image import ReverseDistillation + >>> model = ReverseDistillation() + >>> model.fit(train_dataloader) + >>> predictions = model.predict(test_dataloader) + +See Also: + - :class:`anomalib.models.image.reverse_distillation.lightning_model.ReverseDistillation`: + Lightning implementation of the model + - :class:`anomalib.models.image.reverse_distillation.torch_model.ReverseDistillationModel`: + PyTorch implementation of the model architecture +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/image/reverse_distillation/anomaly_map.py b/src/anomalib/models/image/reverse_distillation/anomaly_map.py index 74dc19e1df..8357eb6acd 100644 --- a/src/anomalib/models/image/reverse_distillation/anomaly_map.py +++ b/src/anomalib/models/image/reverse_distillation/anomaly_map.py @@ -1,4 +1,26 @@ -"""Compute Anomaly map.""" +"""Anomaly map computation for Reverse Distillation model. + +This module implements functionality to generate anomaly heatmaps from the feature +reconstruction errors of the Reverse Distillation model. + +The anomaly maps are generated by: +1. Computing reconstruction error between original and reconstructed features +2. Upscaling the error maps to original image size +3. Optional smoothing via Gaussian blur +4. 
Combining multiple scale errors via addition or multiplication + +Example: + >>> from anomalib.models.image.reverse_distillation.anomaly_map import ( + ... AnomalyMapGenerator + ... ) + >>> generator = AnomalyMapGenerator(image_size=(256, 256)) + >>> features = [torch.randn(1, 64, 32, 32), torch.randn(1, 128, 16, 16)] + >>> anomaly_map = generator(features) + +See Also: + - :class:`AnomalyMapGenerator`: Main class for generating anomaly maps + - :class:`AnomalyMapGenerationMode`: Enum defining map generation modes +""" # Original Code # Copyright (c) 2022 hq-deng diff --git a/src/anomalib/models/image/reverse_distillation/components/__init__.py b/src/anomalib/models/image/reverse_distillation/components/__init__.py index b3f4796605..cb5c14afc1 100644 --- a/src/anomalib/models/image/reverse_distillation/components/__init__.py +++ b/src/anomalib/models/image/reverse_distillation/components/__init__.py @@ -1,4 +1,28 @@ -"""PyTorch modules for Reverse Distillation.""" +"""PyTorch modules for the Reverse Distillation model implementation. + +This module contains the core components used in the Reverse Distillation model +architecture, including the bottleneck layer and decoder network. + +The components work together to learn a compact representation of normal images +through distillation and reconstruction: + +- Bottleneck layer: Compresses features into a lower dimensional space +- Decoder network: Reconstructs features from the bottleneck representation + +Example: + >>> from anomalib.models.image.reverse_distillation.components import ( + ... get_bottleneck_layer, + ... get_decoder + ... ) + >>> bottleneck = get_bottleneck_layer() + >>> decoder = get_decoder() + +See Also: + - :func:`anomalib.models.image.reverse_distillation.components.bottleneck`: + Bottleneck layer implementation + - :func:`anomalib.models.image.reverse_distillation.components.de_resnet`: + Decoder network implementation +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/image/reverse_distillation/components/bottleneck.py b/src/anomalib/models/image/reverse_distillation/components/bottleneck.py index 220fc1d670..a5a5bde542 100644 --- a/src/anomalib/models/image/reverse_distillation/components/bottleneck.py +++ b/src/anomalib/models/image/reverse_distillation/components/bottleneck.py @@ -1,4 +1,28 @@ -"""Torch model defining the bottleneck layer.""" +"""PyTorch model defining the bottleneck layer for Reverse Distillation. + +This module implements the bottleneck layer used in the Reverse Distillation model +architecture. The bottleneck layer compresses features into a lower dimensional +space while preserving important information for anomaly detection. + +The module contains: +- Bottleneck layer implementation using convolutional blocks +- Helper functions for creating 3x3 and 1x1 convolutions +- One-Class Bottleneck Embedding (OCBE) module for feature compression + +Example: + >>> from anomalib.models.image.reverse_distillation.components.bottleneck import ( + ... get_bottleneck_layer + ... 
) + >>> bottleneck = get_bottleneck_layer() + >>> features = torch.randn(32, 512, 28, 28) + >>> compressed = bottleneck(features) + +See Also: + - :class:`anomalib.models.image.reverse_distillation.torch_model.ReverseDistillationModel`: + Main model implementation using this bottleneck layer + - :class:`anomalib.models.image.reverse_distillation.components.OCBE`: + One-Class Bottleneck Embedding module +""" # Original Code # Copyright (c) 2022 hq-deng @@ -38,13 +62,51 @@ def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: class OCBE(nn.Module): """One-Class Bottleneck Embedding module. + This module implements a bottleneck layer that compresses multi-scale features into a + compact representation. It consists of: + + 1. Multiple convolutional layers to process features at different scales + 2. Feature fusion through concatenation + 3. Final bottleneck compression through residual blocks + + The module takes features from multiple scales of an encoder network and outputs a + compressed bottleneck representation. + Args: - block (Bottleneck): Expansion value is extracted from this block. - layers (int): Numbers of OCE layers to create after multiscale feature fusion. - groups (int, optional): Number of blocked connections from input channels to output channels. - Defaults to 1. - width_per_group (int, optional): Number of layers in each intermediate convolution layer. Defaults to 64. - norm_layer (Callable[..., nn.Module] | None, optional): Batch norm layer to use. Defaults to None. + block (Bottleneck | BasicBlock): Block type that determines expansion factor. + Can be either ``Bottleneck`` or ``BasicBlock``. + layers (int): Number of OCE layers to create after multi-scale feature fusion. + groups (int, optional): Number of blocked connections from input channels to + output channels. Defaults to ``1``. + width_per_group (int, optional): Number of channels in each intermediate + convolution layer. Defaults to ``64``. + norm_layer (Callable[..., nn.Module] | None, optional): Normalization layer to + use. If ``None``, uses ``BatchNorm2d``. Defaults to ``None``. + + Example: + >>> import torch + >>> from torchvision.models.resnet import Bottleneck + >>> from anomalib.models.image.reverse_distillation.components import OCBE + >>> model = OCBE(block=Bottleneck, layers=3) + >>> # Create 3 feature maps of different scales + >>> f1 = torch.randn(1, 256, 28, 28) # First scale + >>> f2 = torch.randn(1, 512, 14, 14) # Second scale + >>> f3 = torch.randn(1, 1024, 7, 7) # Third scale + >>> features = [f1, f2, f3] + >>> output = model(features) + >>> output.shape + torch.Size([1, 2048, 4, 4]) + + Notes: + - The module expects exactly 3 input feature maps at different scales + - Features are processed through conv layers before fusion + - Final output dimensions depend on the input feature dimensions and stride + - Initialization uses Kaiming normal for conv layers and constant for norms + + See Also: + - :func:`get_bottleneck_layer`: Factory function to create OCBE instances + - :class:`torchvision.models.resnet.Bottleneck`: ResNet bottleneck block + - :class:`torchvision.models.resnet.BasicBlock`: ResNet basic block """ def __init__( @@ -136,13 +198,24 @@ def _make_layer( return nn.Sequential(*layers) def forward(self, features: list[torch.Tensor]) -> torch.Tensor: - """Forward-pass of Bottleneck layer. + """Forward pass of the bottleneck layer. + + Processes multi-scale features through convolution layers, fuses them via + concatenation, and applies final bottleneck compression. 
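The fuse-then-compress pattern that the OCBE docstring above describes can be mocked up in a few lines. The module below is a hypothetical stand-in, not anomalib's `OCBE`; it only reproduces the shape bookkeeping of the documented example (three input scales in, one ``(1, 2048, 4, 4)`` map out):

```python
import torch
from torch import nn

class MultiScaleFusion(nn.Module):
    """Toy version of the fuse-then-compress pattern used by OCBE."""

    def __init__(self) -> None:
        super().__init__()
        # Stride-2 convs bring the shallower maps down to the deepest resolution.
        self.down1 = nn.Sequential(
            nn.Conv2d(256, 512, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(512, 1024, 3, stride=2, padding=1), nn.ReLU(),
        )
        self.down2 = nn.Conv2d(512, 1024, 3, stride=2, padding=1)
        # After concatenation the fused map is compressed one step further.
        self.compress = nn.Conv2d(3 * 1024, 2048, 3, stride=2, padding=1)

    def forward(self, features: list[torch.Tensor]) -> torch.Tensor:
        f1, f2, f3 = features  # shallow -> deep
        fused = torch.cat([self.down1(f1), self.down2(f2), f3], dim=1)
        return self.compress(fused)

model = MultiScaleFusion()
out = model([
    torch.randn(1, 256, 28, 28),
    torch.randn(1, 512, 14, 14),
    torch.randn(1, 1024, 7, 7),
])
assert out.shape == (1, 2048, 4, 4)
```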
Args: - features (list[torch.Tensor]): List of features extracted from the encoder. + features (list[torch.Tensor]): List of 3 feature tensors from different + scales of the encoder network. Expected shapes: + - features[0]: ``(B, C1, H1, W1)`` + - features[1]: ``(B, C2, H2, W2)`` + - features[2]: ``(B, C3, H3, W3)`` + where B is batch size, Ci are channel dimensions, and Hi, Wi are + spatial dimensions. Returns: - Tensor: Output of the bottleneck layer + torch.Tensor: Compressed bottleneck representation with shape + ``(B, C_out, H_out, W_out)``, where dimensions depend on the input + feature shapes and stride values. """ # Always assumes that features has length of 3 feature0 = self.relu(self.bn2(self.conv2(self.relu(self.bn1(self.conv1(features[0])))))) diff --git a/src/anomalib/models/image/reverse_distillation/components/de_resnet.py b/src/anomalib/models/image/reverse_distillation/components/de_resnet.py index 3bb8886e8b..be4389cccf 100644 --- a/src/anomalib/models/image/reverse_distillation/components/de_resnet.py +++ b/src/anomalib/models/image/reverse_distillation/components/de_resnet.py @@ -1,4 +1,28 @@ -"""Torch model defining the decoder.""" +"""PyTorch model defining the decoder network for Reverse Distillation. + +This module implements the decoder network used in the Reverse Distillation model +architecture. The decoder reconstructs features from the bottleneck representation +back to the original feature space. + +The module contains: +- Decoder block implementations using transposed convolutions +- Helper functions for creating decoder layers +- Full decoder network architecture + +Example: + >>> from anomalib.models.image.reverse_distillation.components.de_resnet import ( + ... get_decoder + ... ) + >>> decoder = get_decoder() + >>> features = torch.randn(32, 512, 28, 28) + >>> reconstructed = decoder(features) + +See Also: + - :class:`anomalib.models.image.reverse_distillation.torch_model.ReverseDistillationModel`: + Main model implementation using this decoder + - :class:`anomalib.models.image.reverse_distillation.components.DecoderBasicBlock`: + Basic building block for the decoder network +""" # Original Code # Copyright (c) 2022 hq-deng @@ -19,20 +43,46 @@ class DecoderBasicBlock(nn.Module): """Basic block for decoder ResNet architecture. + This module implements a basic decoder block used in the decoder network. It performs + upsampling and feature reconstruction through transposed convolutions and skip + connections. + + The block consists of: + 1. Optional upsampling via transposed convolution when ``stride=2`` + 2. Two convolutional layers with batch normalization and ReLU activation + 3. Skip connection that adds input to output features + Args: - inplanes (int): Number of input channels. - planes (int): Number of output channels. - stride (int, optional): Stride for convolution and de-convolution layers. Defaults to 1. - upsample (nn.Module | None, optional): Module used for upsampling output. Defaults to None. - groups (int, optional): Number of blocked connections from input channels to output channels. - Defaults to 1. - base_width (int, optional): Number of layers in each intermediate convolution layer. Defaults to 64. - dilation (int, optional): Spacing between kernel elements. Defaults to 1. - norm_layer (Callable[..., nn.Module] | None, optional): Batch norm layer to use.Defaults to None. + inplanes (int): Number of input channels + planes (int): Number of output channels + stride (int, optional): Stride for convolution and transposed convolution. 
+ When ``stride=2``, upsampling is performed. Defaults to ``1``. + upsample (nn.Module | None, optional): Module used for upsampling the + identity branch. Defaults to ``None``. + groups (int, optional): Number of blocked connections from input to output + channels. Must be ``1``. Defaults to ``1``. + base_width (int, optional): Width of intermediate conv layers. Must be + ``64``. Defaults to ``64``. + dilation (int, optional): Dilation rate for convolutions. Must be ``1``. + Defaults to ``1``. + norm_layer (Callable[..., nn.Module] | None, optional): Normalization layer + to use. Defaults to ``None`` which uses ``BatchNorm2d``. Raises: - ValueError: If groups are not equal to 1 and base width is not 64. - NotImplementedError: If dilation is greater than 1. + ValueError: If ``groups != 1`` or ``base_width != 64`` + NotImplementedError: If ``dilation > 1`` + + Example: + >>> block = DecoderBasicBlock(64, 128, stride=2) + >>> x = torch.randn(1, 64, 32, 32) + >>> output = block(x) # Shape: (1, 128, 64, 64) + + Notes: + - When ``stride=2``, the first conv is replaced with transposed conv for + upsampling + - The block maintains the same architectural pattern as ResNet's BasicBlock + but in reverse + - Skip connections help preserve spatial information during reconstruction """ expansion: int = 1 @@ -78,7 +128,15 @@ def __init__( self.stride = stride def forward(self, batch: torch.Tensor) -> torch.Tensor: - """Forward-pass of de-resnet block.""" + """Forward pass of the decoder basic block. + + Args: + batch (torch.Tensor): Input tensor of shape ``(B, C, H, W)`` + + Returns: + torch.Tensor: Output tensor of shape ``(B, C', H', W')``, where C' is + determined by ``planes`` and H', W' depend on ``stride`` + """ identity = batch out = self.conv1(batch) @@ -96,18 +154,50 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor: class DecoderBottleneck(nn.Module): - """Bottleneck for Decoder. + """Bottleneck block for the decoder network. + + This module implements a bottleneck block used in the decoder part of the Reverse + Distillation model. It performs upsampling and feature reconstruction through a series of + convolutional layers. + + The block consists of three convolution layers: + 1. 1x1 conv to adjust channels + 2. 3x3 conv (or transpose conv) for processing + 3. 1x1 conv to expand channels Args: inplanes (int): Number of input channels. - planes (int): Number of output channels. - stride (int, optional): Stride for convolution and de-convolution layers. Defaults to 1. - upsample (nn.Module | None, optional): Module used for upsampling output. Defaults to None. - groups (int, optional): Number of blocked connections from input channels to output channels. - Defaults to 1. - base_width (int, optional): Number of layers in each intermediate convolution layer. Defaults to 64. - dilation (int, optional): Spacing between kernel elements. Defaults to 1. - norm_layer (Callable[..., nn.Module] | None, optional): Batch norm layer to use.Defaults to None. + planes (int): Number of intermediate channels (will be expanded by ``expansion``). + stride (int, optional): Stride for convolution and transpose convolution layers. + Defaults to ``1``. + upsample (nn.Module | None, optional): Module used for upsampling the residual branch. + Defaults to ``None``. + groups (int, optional): Number of blocked connections from input to output channels. + Defaults to ``1``. + base_width (int, optional): Base width for the conv layers. + Defaults to ``64``. + dilation (int, optional): Dilation rate for conv layers. 
+ Defaults to ``1``. + norm_layer (Callable[..., nn.Module] | None, optional): Normalization layer to use. + Defaults to ``None`` which will use ``nn.BatchNorm2d``. + + Attributes: + expansion (int): Channel expansion factor (4 for bottleneck blocks). + + Example: + >>> import torch + >>> from anomalib.models.image.reverse_distillation.components.de_resnet import ( + ... DecoderBottleneck + ... ) + >>> layer = DecoderBottleneck(256, 64) + >>> x = torch.randn(32, 256, 28, 28) + >>> output = layer(x) + >>> output.shape + torch.Size([32, 256, 28, 28]) + + Notes: + - When ``stride=2``, the middle conv layer becomes a transpose conv for upsampling + - The actual output channels will be ``planes * expansion`` """ expansion: int = 4 @@ -150,7 +240,15 @@ def __init__( self.stride = stride def forward(self, batch: torch.Tensor) -> torch.Tensor: - """Forward-pass of de-resnet bottleneck block.""" + """Forward pass of the decoder bottleneck block. + + Args: + batch (torch.Tensor): Input tensor of shape ``(B, C, H, W)`` + + Returns: + torch.Tensor: Output tensor of shape ``(B, C', H', W')``, where ``C'`` is + ``planes * expansion`` and ``H'``, ``W'`` depend on ``stride`` + """ identity = batch out = self.conv1(batch) @@ -172,17 +270,55 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor: class ResNet(nn.Module): - """ResNet model for decoder. + """Decoder ResNet model for feature reconstruction. + + This module implements a decoder version of the ResNet architecture, which + reconstructs features from a bottleneck representation back to higher + dimensional feature spaces. + + The decoder consists of multiple layers that progressively upsample and + reconstruct features through transposed convolutions and skip connections. Args: - block (Type[DecoderBasicBlock | DecoderBottleneck]): Type of block to use in a layer. - layers (list[int]): List to specify number for blocks per layer. - zero_init_residual (bool, optional): If true, initializes the last batch norm in each layer to zero. - Defaults to False. - groups (int, optional): Number of blocked connections per layer from input channels to output channels. - Defaults to 1. - width_per_group (int, optional): Number of layers in each intermediate convolution layer.. Defaults to 64. - norm_layer (Callable[..., nn.Module] | None, optional): Batch norm layer to use. Defaults to None. + block (Type[DecoderBasicBlock | DecoderBottleneck]): Type of decoder block + to use in each layer. Can be either ``DecoderBasicBlock`` or + ``DecoderBottleneck``. + layers (list[int]): List specifying number of blocks in each decoder + layer. + zero_init_residual (bool, optional): If ``True``, initializes the last + batch norm in each layer to zero. This improves model performance by + 0.2~0.3% according to https://arxiv.org/abs/1706.02677. + Defaults to ``False``. + groups (int, optional): Number of blocked connections from input channels + to output channels per layer. Defaults to ``1``. + width_per_group (int, optional): Number of channels in each intermediate + convolution layer. Defaults to ``64``. + norm_layer (Callable[..., nn.Module] | None, optional): Normalization + layer to use. If ``None``, uses ``BatchNorm2d``. Defaults to ``None``. + + Example: + >>> from anomalib.models.image.reverse_distillation.components import ( + ... DecoderBasicBlock, + ... ResNet + ... ) + >>> model = ResNet( + ... block=DecoderBasicBlock, + ... layers=[2, 2, 2, 2] + ... 
)
+        >>> x = torch.randn(1, 512, 8, 8)
+        >>> features = model(x)  # Returns list of features at different scales
+
+    Notes:
+        - The decoder reverses the typical ResNet architecture, starting from a
+          bottleneck and expanding to larger feature maps
+        - Features are returned at multiple scales for multi-scale reconstruction
+        - The implementation follows the original ResNet paper but in reverse
+          for decoding
+
+    See Also:
+        - :class:`DecoderBasicBlock`: Basic building block for decoder layers
+        - :class:`DecoderBottleneck`: Bottleneck building block for deeper
+          decoder architectures
     """
 
     def __init__(
@@ -270,7 +406,30 @@ def _make_layer(
         return nn.Sequential(*layers)
 
     def forward(self, batch: torch.Tensor) -> list[torch.Tensor]:
-        """Forward pass for Decoder ResNet. Returns list of features."""
+        """Forward pass through the decoder ResNet.
+
+        Progressively reconstructs features through multiple decoder layers,
+        returning features at different scales.
+
+        Args:
+            batch (torch.Tensor): Input tensor of shape ``(B, C, H, W)`` where:
+                - ``B`` is batch size
+                - ``C`` is number of input channels (512 * block.expansion)
+                - ``H`` and ``W`` are spatial dimensions
+
+        Returns:
+            list[torch.Tensor]: List of feature tensors at different scales:
+                - ``feature_c``: ``(B, 64, H*8, W*8)``
+                - ``feature_b``: ``(B, 128, H*4, W*4)``
+                - ``feature_a``: ``(B, 256, H*2, W*2)``
+
+        Example:
+            >>> model = ResNet(DecoderBasicBlock, [2, 2, 2])
+            >>> x = torch.randn(1, 512, 8, 8)
+            >>> features = model(x)
+            >>> [f.shape for f in features]
+            [torch.Size([1, 64, 64, 64]), torch.Size([1, 128, 32, 32]), torch.Size([1, 256, 16, 16])]
+        """
        feature_a = self.layer1(batch)  # 512*8*8->256*16*16
        feature_b = self.layer2(feature_a)  # 256*16*16->128*32*32
        feature_c = self.layer3(feature_b)  # 128*32*32->64*64*64
diff --git a/src/anomalib/models/image/reverse_distillation/lightning_model.py b/src/anomalib/models/image/reverse_distillation/lightning_model.py
index 3eb3bf903c..9436549568 100644
--- a/src/anomalib/models/image/reverse_distillation/lightning_model.py
+++ b/src/anomalib/models/image/reverse_distillation/lightning_model.py
@@ -1,6 +1,27 @@
 """Anomaly Detection via Reverse Distillation from One-Class Embedding.
 
-https://arxiv.org/abs/2201.10703v2
+This module implements the Reverse Distillation model for anomaly detection as described in
+`Deng et al. (2022) <https://arxiv.org/abs/2201.10703v2>`_.
+
+The model consists of:
+- A pre-trained encoder (e.g. ResNet) that extracts multi-scale features
+- A bottleneck layer that compresses features into a compact representation
+- A decoder that reconstructs features back to the original feature space
+- A scoring mechanism based on reconstruction error
+
+Example:
+    >>> from anomalib.models.image import ReverseDistillation
+    >>> from anomalib.engine import Engine
+    >>> from anomalib.data import MVTec
+    >>> datamodule = MVTec()
+    >>> model = ReverseDistillation(
+    ...     backbone="wide_resnet50_2",
+    ...     layers=["layer1", "layer2", "layer3"]
+    ... )
+    >>> engine = Engine(model=model, datamodule=datamodule)
+    >>> engine.fit()  # doctest: +SKIP
+    >>> predictions = engine.predict()  # doctest: +SKIP
+
+See Also:
+    - :class:`ReverseDistillation`: Lightning implementation of the model
+    - :class:`ReverseDistillationModel`: PyTorch implementation of the model
+    - :class:`ReverseDistillationLoss`: Loss function for training
 """
 
 # Copyright (C) 2022-2024 Intel Corporation
diff --git a/src/anomalib/models/image/reverse_distillation/loss.py b/src/anomalib/models/image/reverse_distillation/loss.py
index 3d563238ff..7d6f50d569 100644
--- a/src/anomalib/models/image/reverse_distillation/loss.py
+++ b/src/anomalib/models/image/reverse_distillation/loss.py
@@ -1,4 +1,29 @@
-"""Loss function for Reverse Distillation."""
+"""Loss function for Reverse Distillation model.
+
+This module implements the loss function used to train the Reverse Distillation model
+for anomaly detection. The loss is based on cosine similarity between encoder and
+decoder features.
+
+The loss function:
+1. Takes encoder and decoder feature maps as input
+2. Flattens the spatial dimensions of each feature map
+3. Computes cosine similarity between corresponding encoder-decoder pairs
+4. Averages the similarities across spatial dimensions and feature pairs
+
+Example:
+    >>> import torch
+    >>> from anomalib.models.image.reverse_distillation.loss import (
+    ...     ReverseDistillationLoss
+    ... )
+    >>> criterion = ReverseDistillationLoss()
+    >>> encoder_features = [torch.randn(2, 64, 32, 32)]
+    >>> decoder_features = [torch.randn(2, 64, 32, 32)]
+    >>> loss = criterion(encoder_features, decoder_features)
+
+See Also:
+    - :class:`ReverseDistillationLoss`: Main loss class implementation
+    - :class:`ReverseDistillation`: Lightning implementation of the full model
+"""
 
 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -8,22 +33,49 @@
 
 class ReverseDistillationLoss(nn.Module):
-    """Loss function for Reverse Distillation."""
+    """Loss function for Reverse Distillation model.
+
+    This class implements the cosine similarity loss used to train the Reverse
+    Distillation model. The loss measures the dissimilarity between encoder and
+    decoder feature maps.
+
+    The loss computation involves:
+    1. Flattening the spatial dimensions of encoder and decoder feature maps
+    2. Computing cosine similarity between corresponding encoder-decoder pairs
+    3. Subtracting similarities from 1 to get a dissimilarity measure
+    4. Taking mean across spatial dimensions and feature pairs
+
+    Example:
+        >>> import torch
+        >>> from anomalib.models.image.reverse_distillation.loss import (
+        ...     ReverseDistillationLoss
+        ... )
+        >>> criterion = ReverseDistillationLoss()
+        >>> encoder_features = [torch.randn(2, 64, 32, 32)]
+        >>> decoder_features = [torch.randn(2, 64, 32, 32)]
+        >>> loss = criterion(encoder_features, decoder_features)
+
+    References:
+        - Official Implementation:
+          https://github.com/hq-deng/RD4AD/blob/main/main.py
+        - Implementation Details:
+          https://github.com/hq-deng/RD4AD/issues/22
+    """
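Since `ReverseDistillationLoss.forward` is documented just below, a self-contained sketch of the flatten-then-cosine recipe may help. This is an illustrative re-derivation from the docstring, not a copy of the anomalib implementation:

```python
import torch
from torch import nn

def reverse_distillation_loss(
    encoder_features: list[torch.Tensor],
    decoder_features: list[torch.Tensor],
) -> torch.Tensor:
    cos = nn.CosineSimilarity()  # operates over dim=1 by default
    loss = torch.tensor(0.0)
    for enc, dec in zip(encoder_features, decoder_features):
        # Flatten each (B, C, H, W) map to (B, C*H*W), then penalise dissimilarity.
        similarity = cos(enc.flatten(start_dim=1), dec.flatten(start_dim=1))
        loss = loss + torch.mean(1 - similarity)
    return loss

loss = reverse_distillation_loss(
    [torch.randn(2, 64, 32, 32)],
    [torch.randn(2, 64, 32, 32)],
)
```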
 
    @staticmethod
    def forward(encoder_features: list[torch.Tensor], decoder_features: list[torch.Tensor]) -> torch.Tensor:
-        """Compute cosine similarity loss based on features from encoder and decoder.
-
-        Based on the official code:
-        https://github.com/hq-deng/RD4AD/blob/6554076872c65f8784f6ece8cfb39ce77e1aee12/main.py#L33C25-L33C25
-        Calculates loss from flattened arrays of features, see https://github.com/hq-deng/RD4AD/issues/22
+        """Compute cosine similarity loss between encoder and decoder features.
 
         Args:
-            encoder_features (list[torch.Tensor]): List of features extracted from encoder
-            decoder_features (list[torch.Tensor]): List of features extracted from decoder
+            encoder_features (list[torch.Tensor]): List of feature tensors from the
+                encoder network. Each tensor has shape ``(B, C, H, W)`` where B is
+                batch size, C is channels, H and W are spatial dimensions.
+            decoder_features (list[torch.Tensor]): List of feature tensors from the
+                decoder network. Must match encoder features in length and shapes.
 
         Returns:
-            Tensor: Cosine similarity loss
+            torch.Tensor: Scalar loss value computed as mean of (1 - cosine
+                similarity) across all feature pairs.
         """
         cos_loss = torch.nn.CosineSimilarity()
         loss_sum = 0
diff --git a/src/anomalib/models/image/reverse_distillation/torch_model.py b/src/anomalib/models/image/reverse_distillation/torch_model.py
index b20e19b02f..e6149e8a95 100644
--- a/src/anomalib/models/image/reverse_distillation/torch_model.py
+++ b/src/anomalib/models/image/reverse_distillation/torch_model.py
@@ -1,4 +1,32 @@
-"""PyTorch model for Reverse Distillation."""
+"""PyTorch model implementation for Reverse Distillation.
+
+This module implements the core PyTorch model architecture for the Reverse Distillation
+anomaly detection method as described in `Deng et al. (2022)
+<https://arxiv.org/abs/2201.10703v2>`_.
+
+The model consists of:
+- A pre-trained encoder (e.g. ResNet) that extracts multi-scale features
+- A bottleneck layer that compresses features into a compact representation
+- A decoder that reconstructs features back to the original feature space
+- A scoring mechanism based on reconstruction error
+
+Example:
+    >>> from anomalib.models.image.reverse_distillation.torch_model import (
+    ...     ReverseDistillationModel
+    ... )
+    >>> model = ReverseDistillationModel(
+    ...     backbone="wide_resnet50_2",
+    ...     input_size=(256, 256),
+    ...     layers=["layer1", "layer2", "layer3"],
+    ...     anomaly_map_mode="multiply"
+    ... )
+    >>> features = model(torch.randn(1, 3, 256, 256))
+
+See Also:
+    - :class:`ReverseDistillationModel`: Main PyTorch model implementation
+    - :class:`ReverseDistillationLoss`: Loss function for training
+    - :class:`AnomalyMapGenerator`: Anomaly map generation from features
+"""
 
 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -20,18 +48,48 @@
 
 class ReverseDistillationModel(nn.Module):
-    """Reverse Distillation Model.
+    """PyTorch implementation of the Reverse Distillation model.
 
-    To reproduce results in the paper, use torchvision model for the encoder:
-        self.encoder = torchvision.models.wide_resnet50_2(pretrained=True)
+    The model consists of an encoder-decoder architecture where the encoder extracts
+    multi-scale features and the decoder reconstructs them back to the original
+    feature space. The reconstruction error is used to detect anomalies.
 
     Args:
-        backbone (str): Name of the backbone used for encoder and decoder.
-        input_size (tuple[int, int]): Size of input image.
-        layers (list[str]): Name of layers from which the features are extracted.
-        anomaly_map_mode (str): Mode used to generate anomaly map. Options are between ``multiply`` and ``add``.
-        pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone.
-            Defaults to ``True``.
+        backbone (str): Name of the backbone CNN architecture used for encoder and
+            decoder. Supported backbones can be found in timm library.
+        input_size (tuple[int, int]): Size of input images in format ``(H, W)``.
+        layers (Sequence[str]): Names of layers from which to extract features.
+ For example ``["layer1", "layer2", "layer3"]``. + anomaly_map_mode (AnomalyMapGenerationMode): Mode used to generate anomaly + map. Options are ``"multiply"`` or ``"add"``. + pre_trained (bool, optional): Whether to use pre-trained weights for the + encoder backbone. Defaults to ``True``. + + Example: + >>> import torch + >>> from anomalib.models.image.reverse_distillation.torch_model import ( + ... ReverseDistillationModel + ... ) + >>> model = ReverseDistillationModel( + ... backbone="wide_resnet50_2", + ... input_size=(256, 256), + ... layers=["layer1", "layer2", "layer3"], + ... anomaly_map_mode="multiply" + ... ) + >>> input_tensor = torch.randn(1, 3, 256, 256) + >>> features = model(input_tensor) + + Note: + The original paper uses torchvision's pre-trained wide_resnet50_2 as the + encoder backbone. + + Attributes: + tiler (Tiler | None): Optional tiler for processing large images in patches. + encoder (TimmFeatureExtractor): Feature extraction backbone. + bottleneck (nn.Module): Bottleneck layer to compress features. + decoder (nn.Module): Decoder network to reconstruct features. + anomaly_map_generator (AnomalyMapGenerator): Module to generate anomaly + maps from features. """ def __init__( @@ -53,17 +111,39 @@ def __init__( self.anomaly_map_generator = AnomalyMapGenerator(image_size=input_size, mode=anomaly_map_mode) def forward(self, images: torch.Tensor) -> tuple[list[torch.Tensor], list[torch.Tensor]] | InferenceBatch: - """Forward-pass images to the network. + """Forward pass through the model. - During the training mode the model extracts features from encoder and decoder networks. - During evaluation mode, it returns the predicted anomaly map. + The behavior differs between training and evaluation modes: + - Training: Returns encoder and decoder features for computing loss + - Evaluation: Returns anomaly maps and scores Args: - images (torch.Tensor): Batch of images + images (torch.Tensor): Input tensor of shape ``(N, C, H, W)`` where + ``N`` is batch size, ``C`` is number of channels, ``H`` and ``W`` + are height and width. Returns: - torch.Tensor | tuple[list[torch.Tensor]] | InferenceBatch: Encoder and decoder features - in training mode, else anomaly maps. + tuple[list[torch.Tensor], list[torch.Tensor]] | InferenceBatch: + - In training mode: Tuple of lists containing encoder and decoder + features + - In evaluation mode: ``InferenceBatch`` containing anomaly maps + and scores + + Example: + >>> import torch + >>> model = ReverseDistillationModel( + ... backbone="wide_resnet50_2", + ... input_size=(256, 256), + ... layers=["layer1", "layer2", "layer3"], + ... anomaly_map_mode="multiply" + ... ) + >>> input_tensor = torch.randn(1, 3, 256, 256) + >>> # Training mode + >>> model.train() + >>> encoder_features, decoder_features = model(input_tensor) + >>> # Evaluation mode + >>> model.eval() + >>> predictions = model(input_tensor) """ self.encoder.eval() diff --git a/src/anomalib/models/image/stfpm/__init__.py b/src/anomalib/models/image/stfpm/__init__.py index 049695a63e..d6c456acb5 100644 --- a/src/anomalib/models/image/stfpm/__init__.py +++ b/src/anomalib/models/image/stfpm/__init__.py @@ -1,4 +1,33 @@ -"""STFPM Model.""" +"""Student-Teacher Feature Pyramid Matching Model for anomaly detection. + +This module implements the STFPM model for anomaly detection as described in +Wang et al., 2021: Student-Teacher Feature Pyramid Matching for Unsupervised +Anomaly Detection. 
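The forward passes documented in this part of the diff (Reverse Distillation above, STFPM below) share one PyTorch idiom: branch on ``self.training`` so that ``model.train()`` yields raw features for the loss while ``model.eval()`` yields predictions. A stripped-down sketch of the pattern, using a hypothetical ``TinyModel`` rather than any anomalib class:

```python
import torch
from torch import nn

class TinyModel(nn.Module):
    """Illustrates the train/eval dual-return convention."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Conv2d(3, 8, 3, padding=1)
        self.decoder = nn.Conv2d(8, 8, 3, padding=1)

    def forward(self, images: torch.Tensor):
        encoded = self.encoder(images)
        decoded = self.decoder(encoded)
        if self.training:
            # Training: hand both feature sets to the loss function.
            return encoded, decoded
        # Inference: reduce the reconstruction error to an anomaly map and score.
        anomaly_map = (encoded - decoded).pow(2).mean(dim=1, keepdim=True)
        score = anomaly_map.amax(dim=(-2, -1))
        return anomaly_map, score

model = TinyModel()
model.train()
encoded, decoded = model(torch.randn(2, 3, 32, 32))
model.eval()
anomaly_map, score = model(torch.randn(2, 3, 32, 32))
```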
+ +The model consists of: +- A pre-trained teacher network that extracts multi-scale features +- A student network that learns to match the teacher's feature representations +- Feature pyramid matching between student and teacher features +- Anomaly detection based on feature discrepancy + +Example: + >>> from anomalib.models.image import Stfpm + >>> from anomalib.engine import Engine + >>> from anomalib.data import MVTec + + >>> datamodule = MVTec() + >>> model = Stfpm() + >>> engine = Engine(model=model, datamodule=datamodule) + + >>> engine.fit() # doctest: +SKIP + >>> predictions = engine.predict() # doctest: +SKIP + +See Also: + - :class:`anomalib.models.image.stfpm.lightning_model.Stfpm`: + Lightning implementation of the model + - :class:`anomalib.models.image.stfpm.torch_model.StfpmModel`: + PyTorch implementation of the model architecture +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/image/stfpm/anomaly_map.py b/src/anomalib/models/image/stfpm/anomaly_map.py index 9cd7887fea..afb38eafb8 100644 --- a/src/anomalib/models/image/stfpm/anomaly_map.py +++ b/src/anomalib/models/image/stfpm/anomaly_map.py @@ -1,4 +1,30 @@ -"""Anomaly Map Generator for the STFPM model implementation.""" +"""Anomaly map computation for Student-Teacher Feature Pyramid Matching model. + +This module implements functionality to generate anomaly heatmaps by comparing +features between a pre-trained teacher network and a student network that learns +to match the teacher's representations. + +The anomaly maps are generated by: +1. Computing cosine similarity between teacher and student features +2. Converting similarity scores to anomaly scores via L2 norm +3. Upscaling anomaly scores to original image size +4. Combining multiple layer scores via element-wise multiplication + +Example: + >>> from anomalib.models.image.stfpm.anomaly_map import AnomalyMapGenerator + >>> generator = AnomalyMapGenerator() + >>> teacher_features = {"layer1": torch.randn(1, 64, 32, 32)} + >>> student_features = {"layer1": torch.randn(1, 64, 32, 32)} + >>> anomaly_map = generator.compute_anomaly_map( + ... teacher_features, + ... student_features, + ... image_size=(256, 256) + ... ) + +See Also: + - :class:`AnomalyMapGenerator`: Main class for generating anomaly maps + - :func:`compute_layer_map`: Function to compute per-layer anomaly scores +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -9,9 +35,36 @@ class AnomalyMapGenerator(nn.Module): - """Generate Anomaly Heatmap.""" + """Generate anomaly heatmaps by comparing teacher and student features. + + This class implements functionality to generate anomaly maps by comparing + feature representations between a pre-trained teacher network and a student + network. The comparison is done via cosine similarity and L2 distance. + + The anomaly map generation process involves: + 1. Computing cosine similarity between teacher-student feature pairs + 2. Converting similarity scores to anomaly scores using L2 norm + 3. Upscaling the scores to original image size + 4. Combining multiple layer scores via element-wise multiplication + + Example: + >>> from anomalib.models.image.stfpm.anomaly_map import AnomalyMapGenerator + >>> generator = AnomalyMapGenerator() + >>> teacher_features = {"layer1": torch.randn(1, 64, 32, 32)} + >>> student_features = {"layer1": torch.randn(1, 64, 32, 32)} + >>> anomaly_map = generator.compute_anomaly_map( + ... teacher_features, + ... 
student_features, + ... image_size=(256, 256) + ... ) + + See Also: + - :func:`compute_layer_map`: Function to compute per-layer anomaly scores + - :func:`compute_anomaly_map`: Function to combine layer scores + """ def __init__(self) -> None: + """Initialize pairwise distance metric.""" super().__init__() self.distance = torch.nn.PairwiseDistance(p=2, keepdim=True) @@ -21,15 +74,24 @@ def compute_layer_map( student_features: torch.Tensor, image_size: tuple[int, int] | torch.Size, ) -> torch.Tensor: - """Compute the layer map based on cosine similarity. + """Compute anomaly map for a single feature layer. + + The layer map is computed by: + 1. Normalizing teacher and student features + 2. Computing L2 distance between normalized features + 3. Upscaling the distance map to original image size Args: - teacher_features (torch.Tensor): Teacher features - student_features (torch.Tensor): Student features - image_size (tuple[int, int]): Image size to which the anomaly map should be resized. + teacher_features (torch.Tensor): Features from teacher network with + shape ``(B, C, H, W)`` + student_features (torch.Tensor): Features from student network with + matching shape + image_size (tuple[int, int] | torch.Size): Target size for upscaling + in format ``(H, W)`` Returns: - Anomaly score based on cosine similarity. + torch.Tensor: Anomaly scores for the layer, upscaled to + ``image_size`` """ norm_teacher_features = F.normalize(teacher_features) norm_student_features = F.normalize(student_features) @@ -43,15 +105,23 @@ def compute_anomaly_map( student_features: dict[str, torch.Tensor], image_size: tuple[int, int] | torch.Size, ) -> torch.Tensor: - """Compute the overall anomaly map via element-wise production the interpolated anomaly maps. + """Compute overall anomaly map by combining multiple layer maps. + + The final anomaly map is generated by: + 1. Computing per-layer anomaly maps via :func:`compute_layer_map` + 2. Combining layer maps through element-wise multiplication Args: - teacher_features (dict[str, torch.Tensor]): Teacher features - student_features (dict[str, torch.Tensor]): Student features - image_size (tuple[int, int]): Image size to which the anomaly map should be resized. + teacher_features (dict[str, torch.Tensor]): Dictionary mapping layer + names to teacher feature tensors + student_features (dict[str, torch.Tensor]): Dictionary mapping layer + names to student feature tensors + image_size (tuple[int, int] | torch.Size): Target size for the + anomaly map in format ``(H, W)`` Returns: - Final anomaly map + torch.Tensor: Final anomaly map with shape ``(B, 1, H, W)`` where + ``B`` is batch size and ``(H, W)`` matches ``image_size`` """ batch_size = next(iter(teacher_features.values())).shape[0] anomaly_map = torch.ones(batch_size, 1, image_size[0], image_size[1]) @@ -63,25 +133,30 @@ def compute_anomaly_map( return anomaly_map def forward(self, **kwargs: dict[str, torch.Tensor]) -> torch.Tensor: - """Return anomaly map. + """Generate anomaly map from teacher and student features. - Expects `teach_features` and `student_features` keywords to be passed explicitly. 
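Putting `compute_layer_map` and `compute_anomaly_map` together, the documented recipe (normalize, measure distance, upscale, multiply across layers) fits in one function. The sketch below is an approximation written from the docstrings, assuming the two feature dicts share keys; the real anomalib code differs in details such as scaling of the distance term:

```python
import torch
import torch.nn.functional as F

def stfpm_anomaly_map(
    teacher: dict[str, torch.Tensor],
    student: dict[str, torch.Tensor],
    image_size: tuple[int, int],
) -> torch.Tensor:
    batch_size = next(iter(teacher.values())).shape[0]
    anomaly_map = torch.ones(batch_size, 1, *image_size)
    for layer in teacher:
        # L2-normalise the features, take the channel-wise distance, then
        # upscale the per-layer map to the requested image size.
        distance = torch.linalg.vector_norm(
            F.normalize(teacher[layer]) - F.normalize(student[layer]),
            dim=1, keepdim=True,
        )
        layer_map = F.interpolate(
            distance, size=image_size, mode="bilinear", align_corners=False
        )
        anomaly_map *= layer_map  # element-wise product across layers
    return anomaly_map

teacher = {"layer1": torch.randn(2, 64, 32, 32)}
student = {"layer1": torch.randn(2, 64, 32, 32)}
assert stfpm_anomaly_map(teacher, student, (256, 256)).shape == (2, 1, 256, 256)
```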
+        Expects the following keys in ``kwargs``:
+        - ``teacher_features``: Dictionary of teacher network features
+        - ``student_features``: Dictionary of student network features
+        - ``image_size``: Target size for the anomaly map
 
         Args:
-            kwargs (dict[str, torch.Tensor]): Keyword arguments
+            kwargs (dict[str, torch.Tensor]): Keyword arguments containing
+                required inputs
 
         Example:
-            >>> anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(hparams.model.input_size))
-            >>> output = self.anomaly_map_generator(
-                teacher_features=teacher_features,
-                student_features=student_features
-            )
+            >>> generator = AnomalyMapGenerator()
+            >>> anomaly_map = generator(
+            ...     teacher_features=teacher_features,
+            ...     student_features=student_features,
+            ...     image_size=(256, 256)
+            ... )
 
         Raises:
-            ValueError: `teach_features` and `student_features` keys are not found
+            ValueError: If required keys are missing from ``kwargs``
 
         Returns:
-            torch.Tensor: anomaly map
+            torch.Tensor: Anomaly map with shape ``(B, 1, H, W)``
         """
         if not ("teacher_features" in kwargs and "student_features" in kwargs):
             msg = f"Expected keys `teacher_features` and `student_features. Found {kwargs.keys()}"
diff --git a/src/anomalib/models/image/stfpm/lightning_model.py b/src/anomalib/models/image/stfpm/lightning_model.py
index f3daafe407..dc07f9035e 100644
--- a/src/anomalib/models/image/stfpm/lightning_model.py
+++ b/src/anomalib/models/image/stfpm/lightning_model.py
@@ -1,6 +1,31 @@
-"""STFPM: Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection.
-
-https://arxiv.org/abs/2103.04257
+"""Student-Teacher Feature Pyramid Matching for anomaly detection.
+
+This module implements the STFPM model for anomaly detection as described in
+`Wang et al. (2021) <https://arxiv.org/abs/2103.04257>`_.
+
+The model consists of:
+- A pre-trained teacher network that extracts multi-scale features
+- A student network that learns to match the teacher's feature representations
+- Feature pyramid matching between student and teacher features
+- Anomaly detection based on feature discrepancy
+
+Example:
+    >>> from anomalib.models.image import Stfpm
+    >>> from anomalib.engine import Engine
+    >>> from anomalib.data import MVTec
+    >>> datamodule = MVTec()
+    >>> model = Stfpm(
+    ...     backbone="resnet18",
+    ...     layers=["layer1", "layer2", "layer3"]
+    ... )
+    >>> engine = Engine(model=model, datamodule=datamodule)
+    >>> engine.fit()  # doctest: +SKIP
+    >>> predictions = engine.predict()  # doctest: +SKIP
+
+See Also:
+    - :class:`Stfpm`: Lightning implementation of the model
+    - :class:`STFPMModel`: PyTorch implementation of the model architecture
+    - :class:`STFPMLoss`: Loss function for training
 """
 
 # Copyright (C) 2022-2024 Intel Corporation
@@ -30,14 +55,45 @@
 
 class Stfpm(AnomalibModule):
     """PL Lightning Module for the STFPM algorithm.
 
+    The Student-Teacher Feature Pyramid Matching (STFPM) model consists of a
+    pre-trained teacher network and a student network that learns to match the
+    teacher's feature representations. The model detects anomalies by comparing
+    feature discrepancies between the teacher and student networks.
+
     Args:
-        backbone (str): Backbone CNN network
-            Defaults to ``resnet18``.
-        layers (list[str]): Layers to extract features from the backbone CNN
+        backbone (str): Name of the backbone CNN network used for both teacher
+            and student. Defaults to ``"resnet18"``.
+        layers (list[str]): Names of layers from which to extract features.
             Defaults to ``["layer1", "layer2", "layer3"]``.
-        pre_processor (PreProcessor, optional): Pre-processor for the model.
- This is used to pre-process the input data before it is passed to the model. - Defaults to ``None``. + pre_processor (PreProcessor | bool, optional): Pre-processor to transform + input data before passing to model. If ``True``, uses default. + Defaults to ``True``. + post_processor (PostProcessor | bool, optional): Post-processor to generate + predictions from model outputs. If ``True``, uses default. + Defaults to ``True``. + evaluator (Evaluator | bool, optional): Evaluator to compute metrics. + If ``True``, uses default. Defaults to ``True``. + visualizer (Visualizer | bool, optional): Visualizer to display results. + If ``True``, uses default. Defaults to ``True``. + + Example: + >>> from anomalib.models.image import Stfpm + >>> from anomalib.data import MVTec + >>> from anomalib.engine import Engine + >>> datamodule = MVTec() + >>> model = Stfpm( + ... backbone="resnet18", + ... layers=["layer1", "layer2", "layer3"] + ... ) + >>> engine = Engine(model=model, datamodule=datamodule) + >>> engine.fit() # doctest: +SKIP + >>> predictions = engine.predict() # doctest: +SKIP + + See Also: + - :class:`anomalib.models.image.stfpm.torch_model.STFPMModel`: + PyTorch implementation of the model architecture + - :class:`anomalib.models.image.stfpm.loss.STFPMLoss`: + Loss function for training """ def __init__( @@ -62,15 +118,15 @@ def __init__( def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: """Perform a training step of STFPM. - For each batch, teacher and student and teacher features are extracted from the CNN. + For each batch, teacher and student features are extracted from the CNN. Args: - batch (Batch): Input batch. - args: Additional arguments. - kwargs: Additional keyword arguments. + batch (Batch): Input batch containing images and labels. + args: Additional arguments (unused). + kwargs: Additional keyword arguments (unused). Returns: - Loss value + STEP_OUTPUT: Dictionary containing the loss value. """ del args, kwargs # These variables are not used. @@ -80,19 +136,19 @@ def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: return {"loss": loss} def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: - """Perform a validation Step of STFPM. + """Perform a validation step of STFPM. - Similar to the training step, student/teacher features are extracted from the CNN for each batch, and - anomaly map is computed. + Similar to training, extracts student/teacher features from CNN and + computes anomaly maps. Args: - batch (Batch): Input batch - args: Additional arguments - kwargs: Additional keyword arguments + batch (Batch): Input batch containing images and labels. + args: Additional arguments (unused). + kwargs: Additional keyword arguments (unused). Returns: - Dictionary containing images, anomaly maps, true labels and masks. - These are required in `validation_epoch_end` for feature concatenation. + STEP_OUTPUT: Dictionary containing images, anomaly maps, labels and + masks for evaluation. """ del args, kwargs # These variables are not used. @@ -101,14 +157,24 @@ def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: @property def trainer_arguments(self) -> dict[str, Any]: - """Required trainer arguments.""" + """Get required trainer arguments for the model. 
+
+        Returns:
+            dict[str, Any]: Dictionary of trainer arguments:
+                - ``gradient_clip_val``: Set to 0 to disable gradient clipping
+                - ``num_sanity_val_steps``: Set to 0 to skip validation sanity
+                  checks
+        """
         return {"gradient_clip_val": 0, "num_sanity_val_steps": 0}
 
     def configure_optimizers(self) -> torch.optim.Optimizer:
-        """Configure optimizers.
+        """Configure optimizers for training.
 
         Returns:
-            Optimizer: SGD optimizer
+            torch.optim.Optimizer: SGD optimizer with the following parameters:
+                - Learning rate: 0.4
+                - Momentum: 0.9
+                - Weight decay: 0.001
         """
         return optim.SGD(
             params=self.model.student_model.parameters(),
@@ -120,9 +186,9 @@
 
     @property
     def learning_type(self) -> LearningType:
-        """Return the learning type of the model.
+        """Get the learning type of the model.
 
         Returns:
-            LearningType: Learning type of the model.
+            LearningType: The model uses one-class learning.
         """
         return LearningType.ONE_CLASS
diff --git a/src/anomalib/models/image/stfpm/loss.py b/src/anomalib/models/image/stfpm/loss.py
index 412caf8fdd..3d598def98 100644
--- a/src/anomalib/models/image/stfpm/loss.py
+++ b/src/anomalib/models/image/stfpm/loss.py
@@ -1,4 +1,37 @@
-"""Loss function for the STFPM Model Implementation."""
+"""Loss function for Student-Teacher Feature Pyramid Matching model.
+
+This module implements the loss function used to train the STFPM model for anomaly
+detection as described in `Wang et al. (2021) <https://arxiv.org/abs/2103.04257>`_.
+
+The loss function:
+1. Takes feature maps from teacher and student networks as input
+2. Normalizes the features using L2 normalization
+3. Computes MSE loss between normalized features
+4. Scales the loss by spatial dimensions of feature maps
+
+Example:
+    >>> from anomalib.models.components import TimmFeatureExtractor
+    >>> from anomalib.models.image.stfpm.loss import STFPMLoss
+    >>> from torchvision.models import resnet18
+    >>> layers = ["layer1", "layer2", "layer3"]
+    >>> teacher_model = TimmFeatureExtractor(
+    ...     model=resnet18(pretrained=True),
+    ...     layers=layers
+    ... )
+    >>> student_model = TimmFeatureExtractor(
+    ...     model=resnet18(pretrained=False),
+    ...     layers=layers
+    ... )
+    >>> criterion = STFPMLoss()
+    >>> features = torch.randn(4, 3, 256, 256)
+    >>> teacher_features = teacher_model(features)
+    >>> student_features = student_model(features)
+    >>> loss = criterion(student_features, teacher_features)
+
+See Also:
+    - :class:`STFPMLoss`: Main loss class implementation
+    - :class:`Stfpm`: Lightning implementation of the full model
+"""
 
 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -9,23 +42,42 @@
 
 class STFPMLoss(nn.Module):
-    """Feature Pyramid Loss This class implmenents the feature pyramid loss function proposed in STFPM paper.
+    """Loss function for Student-Teacher Feature Pyramid Matching model.
 
-    Example:
-        >>> from anomalib.models.components.feature_extractors import TimmFeatureExtractor
-        >>> from anomalib.models.stfpm.loss import STFPMLoss
-        >>> from torchvision.models import resnet18
+    This class implements the feature pyramid loss function proposed in the STFPM
+    paper. The loss measures the discrepancy between feature representations from
+    a pre-trained teacher network and a student network that learns to match them.
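The steps listed in the module docstring above translate almost line-for-line into code. Below is a minimal sketch of the per-layer term (Equation (1) of the paper) and the summed total, assuming teacher and student dicts share keys:

```python
import torch
import torch.nn.functional as F

def layer_loss(teacher_feats: torch.Tensor, student_feats: torch.Tensor) -> torch.Tensor:
    # Steps 1-2: L2-normalise along channels, then a summed MSE between the maps.
    height, width = teacher_feats.shape[2:]
    norm_teacher = F.normalize(teacher_feats)
    norm_student = F.normalize(student_feats)
    # Step 3: scale by the spatial size so layers of different resolution
    # contribute comparably.
    return (0.5 / (width * height)) * F.mse_loss(norm_student, norm_teacher, reduction="sum")

def total_loss(teacher: dict[str, torch.Tensor], student: dict[str, torch.Tensor]) -> torch.Tensor:
    # Step 4: the overall loss is the plain sum of the per-layer terms.
    return torch.stack([layer_loss(teacher[k], student[k]) for k in teacher]).sum()

teacher = {"layer1": torch.randn(4, 64, 64, 64), "layer2": torch.randn(4, 128, 32, 32)}
student = {k: torch.randn_like(v) for k, v in teacher.items()}
loss = total_loss(teacher, student)
```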
-        >>> layers = ['layer1', 'layer2', 'layer3']
-        >>> teacher_model = TimmFeatureExtractor(model=resnet18(pretrained=True), layers=layers)
-        >>> student_model = TimmFeatureExtractor(model=resnet18(pretrained=False), layers=layers)
-        >>> loss = Loss()
+    The loss computation involves:
+    1. Normalizing teacher and student features using L2 normalization
+    2. Computing MSE loss between normalized features
+    3. Scaling the loss by spatial dimensions of feature maps
+    4. Summing losses across all feature layers
 
-        >>> inp = torch.rand((4, 3, 256, 256))
-        >>> teacher_features = teacher_model(inp)
-        >>> student_features = student_model(inp)
-        >>> loss(student_features, teacher_features)
-        tensor(51.2015, grad_fn=<SumBackward0>)
+    Example:
+        >>> from anomalib.models.components import TimmFeatureExtractor
+        >>> from anomalib.models.image.stfpm.loss import STFPMLoss
+        >>> from torchvision.models import resnet18
+        >>> layers = ["layer1", "layer2", "layer3"]
+        >>> teacher_model = TimmFeatureExtractor(
+        ...     model=resnet18(pretrained=True),
+        ...     layers=layers
+        ... )
+        >>> student_model = TimmFeatureExtractor(
+        ...     model=resnet18(pretrained=False),
+        ...     layers=layers
+        ... )
+        >>> criterion = STFPMLoss()
+        >>> features = torch.randn(4, 3, 256, 256)
+        >>> teacher_features = teacher_model(features)
+        >>> student_features = student_model(features)
+        >>> loss = criterion(student_features, teacher_features)
+        >>> loss
+        tensor(51.2015, grad_fn=<SumBackward0>)
+
+    See Also:
+        - :class:`Stfpm`: Lightning implementation of the full model
+        - :class:`STFPMModel`: PyTorch implementation of the model architecture
     """
 
     def __init__(self) -> None:
@@ -33,14 +85,22 @@
         self.mse_loss = nn.MSELoss(reduction="sum")
 
     def compute_layer_loss(self, teacher_feats: torch.Tensor, student_feats: torch.Tensor) -> torch.Tensor:
-        """Compute layer loss based on Equation (1) in Section 3.2 of the paper.
+        """Compute loss between teacher and student features for a single layer.
+
+        This implements the loss computation based on Equation (1) in Section 3.2
+        of the paper. The loss is computed as:
+        1. L2 normalize teacher and student features
+        2. Compute MSE loss between normalized features
+        3. Scale loss by spatial dimensions (height * width)
 
         Args:
-            teacher_feats (torch.Tensor): Teacher features
-            student_feats (torch.Tensor): Student features
+            teacher_feats (torch.Tensor): Features from teacher network with shape
+                ``(B, C, H, W)``
+            student_feats (torch.Tensor): Features from student network with shape
+                ``(B, C, H, W)``
 
         Returns:
-            L2 distance between teacher and student features.
+            torch.Tensor: Scalar loss value for the layer
         """
         height, width = teacher_feats.shape[2:]
 
@@ -53,14 +113,20 @@
     def forward(
         self,
         teacher_features: dict[str, torch.Tensor],
        student_features: dict[str, torch.Tensor],
     ) -> torch.Tensor:
-        """Compute the overall loss via the weighted average of the layer losses computed by the cosine similarity.
+        """Compute total loss across all feature layers.
+
+        The total loss is computed as the sum of individual layer losses. Each
+        layer loss measures the discrepancy between teacher and student features
+        at that layer.
 
         Args:
-            teacher_features (dict[str, torch.Tensor]): Teacher features
-            student_features (dict[str, torch.Tensor]): Student features
+            teacher_features (dict[str, torch.Tensor]): Dictionary mapping layer
+                names to teacher feature tensors
+            student_features (dict[str, torch.Tensor]): Dictionary mapping layer
+                names to student feature tensors
 
         Returns:
-            Total loss, which is the weighted average of the layer losses.
+            torch.Tensor: Total loss summed across all layers
         """
         layer_losses: list[torch.Tensor] = []
         for layer in teacher_features:
diff --git a/src/anomalib/models/image/stfpm/torch_model.py b/src/anomalib/models/image/stfpm/torch_model.py
index 72638b1531..a4308ecce9 100644
--- a/src/anomalib/models/image/stfpm/torch_model.py
+++ b/src/anomalib/models/image/stfpm/torch_model.py
@@ -1,4 +1,28 @@
-"""PyTorch model for the STFPM model implementation."""
+"""PyTorch model implementation for Student-Teacher Feature Pyramid Matching.
+
+This module implements the core PyTorch model architecture for the STFPM anomaly
+detection method as described in `Wang et al. (2021)
+<https://arxiv.org/abs/2103.04257>`_.
+
+The model consists of:
+- A pre-trained teacher network that extracts multi-scale features
+- A student network that learns to match the teacher's feature representations
+- Feature pyramid matching between student and teacher features
+- Anomaly detection based on feature discrepancy
+
+Example:
+    >>> from anomalib.models.image.stfpm.torch_model import STFPMModel
+    >>> model = STFPMModel(
+    ...     backbone="resnet18",
+    ...     layers=["layer1", "layer2", "layer3"]
+    ... )
+    >>> features = model(torch.randn(1, 3, 256, 256))
+
+See Also:
+    - :class:`STFPMModel`: Main PyTorch model implementation
+    - :class:`STFPMLoss`: Loss function for training
+    - :class:`AnomalyMapGenerator`: Anomaly map generation from features
+"""
 
 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -19,12 +43,43 @@
 
 class STFPMModel(nn.Module):
-    """STFPM: Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection.
+    """PyTorch implementation of the STFPM model.
+
+    The Student-Teacher Feature Pyramid Matching model consists of a pre-trained
+    teacher network and a student network that learns to match the teacher's
+    feature representations. The model detects anomalies by comparing feature
+    discrepancies between the teacher and student networks.
 
     Args:
-        layers (list[str]): Layers used for feature extraction.
-        backbone (str, optional): Pre-trained model backbone.
-            Defaults to ``resnet18``.
+        layers (Sequence[str]): Names of layers from which to extract features.
+            For example ``["layer1", "layer2", "layer3"]``.
+        backbone (str, optional): Name of the backbone CNN architecture used for
+            both teacher and student networks. Supported backbones can be found
+            in timm library. Defaults to ``"resnet18"``.
+
+    Example:
+        >>> import torch
+        >>> from anomalib.models.image.stfpm.torch_model import STFPMModel
+        >>> model = STFPMModel(
+        ...     backbone="resnet18",
+        ...     layers=["layer1", "layer2", "layer3"]
+        ... )
+        >>> input_tensor = torch.randn(1, 3, 256, 256)
+        >>> features = model(input_tensor)
+
+    Note:
+        The teacher model is initialized with pre-trained weights and frozen
+        during training, while the student model is trained from scratch.
+
+    Attributes:
+        tiler (Tiler | None): Optional tiler for processing large images in
+            patches.
+        teacher_model (TimmFeatureExtractor): Pre-trained teacher network for
+            feature extraction.
+        student_model (TimmFeatureExtractor): Student network that learns to
+            match teacher features.
+        anomaly_map_generator (AnomalyMapGenerator): Module to generate anomaly
+            maps from features.
     """

     def __init__(
@@ -54,16 +109,36 @@ def forward(
         self,
         images: torch.Tensor,
     ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]] | InferenceBatch:
-        """Forward-pass images into the network.
+        """Forward pass through teacher and student networks.

-        During the training mode the model extracts the features from the teacher and student networks.
-        During the evaluation mode, it returns the predicted anomaly map.
+        The forward pass behavior differs between training and evaluation:
+        - Training: Returns features from both teacher and student networks
+        - Evaluation: Returns anomaly maps generated from feature differences

         Args:
-            images (torch.Tensor): Batch of images.
+            images (torch.Tensor): Batch of input images with shape
+                ``(N, C, H, W)``.

         Returns:
-            Teacher and student features when in training mode, otherwise the predicted anomaly maps.
+            Training mode:
+                tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
+                Features from teacher and student networks respectively.
+                Each dict maps layer names to feature tensors.
+            Evaluation mode:
+                InferenceBatch:
+                Batch containing anomaly maps and prediction scores.
+
+        Example:
+            >>> import torch
+            >>> from anomalib.models.image.stfpm.torch_model import STFPMModel
+            >>> model = STFPMModel(layers=["layer1", "layer2", "layer3"])
+            >>> input_tensor = torch.randn(1, 3, 256, 256)
+            >>> # Training mode
+            >>> model.train()
+            >>> teacher_feats, student_feats = model(input_tensor)
+            >>> # Evaluation mode
+            >>> model.eval()
+            >>> predictions = model(input_tensor)
         """
         output_size = images.shape[-2:]
         if self.tiler:
diff --git a/src/anomalib/models/image/uflow/__init__.py b/src/anomalib/models/image/uflow/__init__.py
index 653f7835fa..71693e3b69 100644
--- a/src/anomalib/models/image/uflow/__init__.py
+++ b/src/anomalib/models/image/uflow/__init__.py
@@ -1,4 +1,32 @@
-"""U-Flow: A U-shaped Normalizing Flow for Anomaly Detection with Unsupervised Threshold."""
+"""U-Flow: A U-shaped Normalizing Flow for Anomaly Detection with Unsupervised Threshold.
+
+This module implements the U-Flow model for anomaly detection as described in
+Tailanian et al., 2022: U-Flow: A U-shaped Normalizing Flow for Anomaly Detection
+with Unsupervised Threshold.
+ +The model consists of: +- A U-shaped normalizing flow architecture for density estimation +- Unsupervised threshold estimation based on the learned density +- Anomaly detection by comparing likelihoods to the threshold + +Example: + >>> from anomalib.models.image import Uflow + >>> from anomalib.engine import Engine + >>> from anomalib.data import MVTec + + >>> datamodule = MVTec() + >>> model = Uflow() + >>> engine = Engine(model=model, datamodule=datamodule) + + >>> engine.fit() # doctest: +SKIP + >>> predictions = engine.predict() # doctest: +SKIP + +See Also: + - :class:`anomalib.models.image.uflow.lightning_model.Uflow`: + Lightning implementation of the model + - :class:`anomalib.models.image.uflow.torch_model.UflowModel`: + PyTorch implementation of the model architecture +""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/image/uflow/anomaly_map.py b/src/anomalib/models/image/uflow/anomaly_map.py index 697f03321d..4457bd17e5 100644 --- a/src/anomalib/models/image/uflow/anomaly_map.py +++ b/src/anomalib/models/image/uflow/anomaly_map.py @@ -1,4 +1,22 @@ -"""UFlow Anomaly Map Generator Implementation.""" +"""Anomaly map computation for U-Flow model. + +This module implements functionality to generate anomaly heatmaps from the latent +variables produced by a U-Flow model. The anomaly maps are generated by: + +1. Computing per-scale likelihoods from latent variables +2. Upscaling likelihoods to original image size +3. Combining multiple scale likelihoods + +Example: + >>> from anomalib.models.image.uflow.anomaly_map import AnomalyMapGenerator + >>> generator = AnomalyMapGenerator(input_size=(256, 256)) + >>> latent_vars = [torch.randn(1, 64, 32, 32), torch.randn(1, 128, 16, 16)] + >>> anomaly_map = generator(latent_vars) + +See Also: + - :class:`AnomalyMapGenerator`: Main class for generating anomaly maps + - :func:`compute_anomaly_map`: Function to generate anomaly maps from latents +""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -16,27 +34,66 @@ class AnomalyMapGenerator(nn.Module): - """Generate Anomaly Heatmap and segmentation.""" + """Generate anomaly heatmaps and segmentation masks from U-Flow latent variables. + + This class implements functionality to generate anomaly maps by analyzing the latent + variables produced by a U-Flow model. The anomaly maps can be generated in two ways: + + 1. Using likelihood-based scoring (default method): + - Computes per-scale likelihoods from latent variables + - Upscales likelihoods to original image size + - Combines multiple scale likelihoods via averaging + + 2. 
Using NFA-based segmentation (optional method): + - Applies binomial testing on local windows + - Computes Number of False Alarms (NFA) statistics + - Generates binary segmentation masks + + Args: + input_size (ListConfig | tuple): Size of input images as ``(height, width)`` + + Example: + >>> from anomalib.models.image.uflow.anomaly_map import AnomalyMapGenerator + >>> generator = AnomalyMapGenerator(input_size=(256, 256)) + >>> latents = [torch.randn(1, 64, 32, 32), torch.randn(1, 128, 16, 16)] + >>> anomaly_map = generator(latents) + >>> anomaly_map.shape + torch.Size([1, 1, 256, 256]) + + See Also: + - :func:`compute_anomaly_map`: Main method for likelihood-based maps + - :func:`compute_anomaly_mask`: Optional method for NFA-based segmentation + """ def __init__(self, input_size: ListConfig | tuple) -> None: super().__init__() self.input_size = input_size if isinstance(input_size, tuple) else tuple(input_size) def forward(self, latent_variables: list[Tensor]) -> Tensor: - """Return anomaly map.""" + """Generate anomaly map from latent variables. + + Args: + latent_variables (list[Tensor]): List of latent tensors from U-Flow model + + Returns: + Tensor: Anomaly heatmap of shape ``(batch_size, 1, height, width)`` + """ return self.compute_anomaly_map(latent_variables) def compute_anomaly_map(self, latent_variables: list[Tensor]) -> Tensor: - """Generate a likelihood-based anomaly map, from latent variables. + """Generate likelihood-based anomaly map from latent variables. + + The method: + 1. Computes per-scale likelihoods from latent variables + 2. Upscales each likelihood map to input image size + 3. Combines scale likelihoods via averaging Args: - latent_variables: List of latent variables from the UFlow model. Each element is a tensor of shape - (N, Cl, Hl, Wl), where N is the batch size, Cl is the number of channels, and Hl and Wl are the height and - width of the latent variables, respectively, for each scale l. + latent_variables (list[Tensor]): List of latent tensors from U-Flow model, + each with shape ``(batch_size, channels, height, width)`` Returns: - Final Anomaly Map. Tensor of shape (N, 1, H, W), where N is the batch size, and H and W are the height and - width of the input image, respectively. + Tensor: Anomaly heatmap of shape ``(batch_size, 1, height, width)`` """ likelihoods = [] for z in latent_variables: @@ -61,31 +118,29 @@ def compute_anomaly_mask( binomial_probability_thr: float = 0.5, high_precision: bool = False, ) -> torch.Tensor: - """This method is not used in the basic functionality of training and testing. + """Generate NFA-based anomaly segmentation mask from latent variables. - It is a bit slow, so we decided to - leave it as an option for the user. It is included as it is part of the U-Flow paper, and can be called - separately if an unsupervised anomaly segmentation is needed. + This optional method implements the Number of False Alarms (NFA) approach from + the U-Flow paper. It is slower than the default likelihood method but provides + unsupervised binary segmentation. - Generate an anomaly mask, from latent variables. It is based on the NFA (Number of False Alarms) method, which - is a statistical method to detect anomalies. The NFA is computed as the log of the probability of the null - hypothesis, which is that all pixels are normal. First, we compute a list of candidate pixels, with - suspiciously high values of z^2, by applying a binomial test to each pixel, looking at a window around it. 
- Then, to compute the NFA values (actually the log-NFA), we evaluate how probable is that a pixel belongs to the - normal distribution. The null-hypothesis is that under normality assumptions, all candidate pixels are uniformly - distributed. Then, the detection is based on the concentration of candidate pixels. + The method: + 1. Applies binomial testing on local windows around each pixel + 2. Computes NFA statistics based on concentration of candidate pixels + 3. Generates binary segmentation mask Args: - z (list[torch.Tensor]): List of latent variables from the UFlow model. Each element is a tensor of shape - (N, Cl, Hl, Wl), where N is the batch size, Cl is the number of channels, and Hl and Wl are the height - and width of the latent variables, respectively, for each scale l. - window_size (int): Window size for the binomial test. Defaults to 7. - binomial_probability_thr (float): Probability threshold for the binomial test. Defaults to 0.5 - high_precision (bool): Whether to use high precision for the binomial test. Defaults to False. + z (list[torch.Tensor]): List of latent tensors from U-Flow model + window_size (int, optional): Size of local window for binomial test. + Defaults to ``7``. + binomial_probability_thr (float, optional): Probability threshold for + binomial test. Defaults to ``0.5``. + high_precision (bool, optional): Whether to use high precision NFA + computation. Slower but more accurate. Defaults to ``False``. Returns: - Anomaly mask. Tensor of shape (N, 1, H, W), where N is the batch size, and H and W are the height and - width of the input image, respectively. + torch.Tensor: Binary anomaly mask of shape ``(batch_size, 1, height, + width)`` """ log_prob_l = [ self.binomial_test(zi, window_size / (2**scale), binomial_probability_thr, high_precision) @@ -113,22 +168,27 @@ def binomial_test( probability_thr: float, high_precision: bool = False, ) -> torch.Tensor: - """The binomial test applied to validate or reject the null hypothesis that the pixel is normal. + """Apply binomial test to validate/reject normality hypothesis. - The null hypothesis is that the pixel is normal, and the alternative hypothesis is that the pixel is anomalous. - The binomial test is applied to a window around the pixel, and the number of pixels in the window that ares - anomalous is compared to the number of pixels that are expected to be anomalous under the null hypothesis. + For each pixel, tests the null hypothesis that the pixel and its local + neighborhood are normal against the alternative that they are anomalous. + + The test: + 1. Counts anomalous pixels in local window using chi-square threshold + 2. Compares observed count to expected count under null hypothesis + 3. Returns log probability of observing such extreme counts Args: - z: Latent variable from the UFlow model. Tensor of shape (N, Cl, Hl, Wl), where N is the batch size, Cl is - the number of channels, and Hl and Wl are the height and width of the latent variables, respectively. - window_size (int): Window size for the binomial test. - probability_thr: Probability threshold for the binomial test. - high_precision: Whether to use high precision for the binomial test. + z (torch.Tensor): Latent tensor of shape ``(batch_size, channels, + height, width)`` + window_size (int): Size of local window for counting + probability_thr (float): Probability threshold for chi-square test + high_precision (bool, optional): Whether to use high precision + computation. Defaults to ``False``. 
Returns: - Log of the probability of the null hypothesis. - + torch.Tensor: Log probability tensor of shape ``(batch_size, 1, + height, width)`` """ tau = st.chi2.ppf(probability_thr, 1) half_win = np.max([int(window_size // 2), 1]) diff --git a/src/anomalib/models/image/uflow/feature_extraction.py b/src/anomalib/models/image/uflow/feature_extraction.py index 50cd2ba5e3..7597411a5b 100644 --- a/src/anomalib/models/image/uflow/feature_extraction.py +++ b/src/anomalib/models/image/uflow/feature_extraction.py @@ -1,4 +1,22 @@ -"""Feature Extractor for U-Flow model.""" +"""Feature extraction module for U-Flow model. + +This module implements feature extraction functionality for the U-Flow model for +anomaly detection. It provides: + +1. Feature extractors based on different backbone architectures +2. Utility function to get appropriate feature extractor +3. Support for multiple scales of feature extraction + +Example: + >>> from anomalib.models.image.uflow.feature_extraction import get_feature_extractor + >>> extractor = get_feature_extractor(backbone="resnet18") + >>> features = extractor(torch.randn(1, 3, 256, 256)) + +See Also: + - :func:`get_feature_extractor`: Factory function to get feature extractors + - :class:`FeatureExtractor`: Main feature extractor implementation + - :class:`MCaitFeatureExtractor`: Alternative feature extractor +""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -16,17 +34,32 @@ def get_feature_extractor(backbone: str, input_size: tuple[int, int] = (256, 256)) -> nn.Module: - """Get feature extractor. Currently, is restricted to AVAILABLE_EXTRACTORS. + """Get feature extractor based on specified backbone architecture. + + This function returns a feature extractor model based on the specified backbone + architecture. Currently supported backbones are defined in ``AVAILABLE_EXTRACTORS``. Args: - backbone (str): Backbone name. - input_size (tuple[int, int]): Input size. + backbone (str): Name of the backbone architecture to use. Must be one of + ``["mcait", "resnet18", "wide_resnet50_2"]``. + input_size (tuple[int, int], optional): Input image dimensions as + ``(height, width)``. Defaults to ``(256, 256)``. + + Returns: + nn.Module: Feature extractor model instance. Raises: - ValueError if unknown backbone is provided. + ValueError: If ``backbone`` is not one of the supported architectures in + ``AVAILABLE_EXTRACTORS``. - Returns: - FeatureExtractorInterface: Feature extractor. + Example: + >>> from anomalib.models.image.uflow.feature_extraction import get_feature_extractor + >>> extractor = get_feature_extractor(backbone="resnet18") + >>> features = extractor(torch.randn(1, 3, 256, 256)) + + See Also: + - :class:`FeatureExtractor`: Main feature extractor implementation + - :class:`MCaitFeatureExtractor`: Alternative feature extractor """ if backbone not in AVAILABLE_EXTRACTORS: msg = f"Feature extractor must be one of {AVAILABLE_EXTRACTORS}." @@ -44,11 +77,33 @@ def get_feature_extractor(backbone: str, input_size: tuple[int, int] = (256, 256 class FeatureExtractor(TimmFeatureExtractor): """Feature extractor based on ResNet (or others) backbones. + This class extends TimmFeatureExtractor to extract and normalize features from + common CNN backbones like ResNet. It adds layer normalization to the extracted + features. + Args: - backbone (str): Backbone of the feature extractor. - input_size (tuple[int, int]): Input image size used for computing normalization layers. 
- layers (tuple[str], optional): Layers from which to extract features. - Defaults to ("layer1", "layer2", "layer3"). + backbone (str): Name of the backbone CNN architecture to use for feature + extraction (e.g. ``"resnet18"``, ``"wide_resnet50_2"``). + input_size (tuple[int, int]): Input image dimensions as ``(height, width)`` + used for computing normalization layers. + layers (tuple[str, ...], optional): Names of layers from which to extract + features. Defaults to ``("layer1", "layer2", "layer3")``. + **kwargs: Additional keyword arguments (unused). + + Example: + >>> import torch + >>> extractor = FeatureExtractor( + ... backbone="resnet18", + ... input_size=(256, 256) + ... ) + >>> features = extractor(torch.randn(1, 3, 256, 256)) + + Attributes: + channels (list[int]): Number of channels in each extracted feature layer. + scale_factors (list[int]): Downsampling factor for each feature layer. + scales (range): Range object for iterating over feature scales. + feature_normalizations (nn.ModuleList): Layer normalization modules for + each feature scale. """ def __init__( @@ -76,26 +131,69 @@ def __init__( param.requires_grad = False def forward(self, img: torch.Tensor) -> torch.Tensor: - """Normalized features.""" + """Extract and normalize features from input image. + + Args: + img (torch.Tensor): Input image tensor of shape + ``(batch_size, channels, height, width)``. + + Returns: + torch.Tensor: Normalized features from multiple network layers. + """ features = self.extract_features(img) return self.normalize_features(features) def extract_features(self, img: torch.Tensor) -> torch.Tensor: - """Extract features.""" + """Extract features from input image using backbone network. + + Args: + img (torch.Tensor): Input image tensor of shape + ``(batch_size, channels, height, width)``. + + Returns: + torch.Tensor: Features extracted from multiple network layers. + """ self.feature_extractor.eval() return self.feature_extractor(img) def normalize_features(self, features: Iterable[torch.Tensor]) -> list[torch.Tensor]: - """Normalize features.""" + """Apply layer normalization to extracted features. + + Args: + features (Iterable[torch.Tensor]): Features extracted from multiple + network layers. + + Returns: + list[torch.Tensor]: Normalized features from each layer. + """ return [self.feature_normalizations[i](feature) for i, feature in enumerate(features)] class MCaitFeatureExtractor(nn.Module): """Feature extractor based on MCait backbone. - This is the proposed feature extractor in the paper. It uses two - independently trained Cait models, at different scales, with input sizes 448 and 224, respectively. - It also includes a normalization layer for each scale. + This class implements the feature extractor proposed in the U-Flow paper. It uses two + independently trained CaiT models at different scales: + - A CaiT-M48 model with input size 448x448 + - A CaiT-S24 model with input size 224x224 + + Each model extracts features at a different scale, and includes normalization layers. 
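The two-scale setup can be reproduced with timm directly. A minimal sketch, assuming the ``cait_m48_448`` and ``cait_s24_224`` checkpoint names from timm and bilinear downsampling for the second scale:

```python
import timm
import torch
from torch.nn import functional as F  # noqa: N812

# Full-resolution input for CaiT-M48; a 2x-downsampled copy for CaiT-S24.
extractor1 = timm.create_model("cait_m48_448", pretrained=True).eval()
extractor2 = timm.create_model("cait_s24_224", pretrained=True).eval()

image = torch.randn(1, 3, 448, 448)
image_small = F.interpolate(image, size=(224, 224), mode="bilinear", align_corners=False)
```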
+ + Example: + >>> from anomalib.models.image.uflow.feature_extraction import MCaitFeatureExtractor + >>> extractor = MCaitFeatureExtractor() + >>> image = torch.randn(1, 3, 448, 448) + >>> features = extractor(image) + >>> [f.shape for f in features] + [torch.Size([1, 768, 28, 28]), torch.Size([1, 384, 14, 14])] + + Attributes: + input_size (int): Size of input images (448) + extractor1 (nn.Module): CaiT-M48 model for scale 1 (448x448) + extractor2 (nn.Module): CaiT-S24 model for scale 2 (224x224) + channels (list[int]): Number of channels for each scale [768, 384] + scale_factors (list[int]): Downsampling factors for each scale [16, 32] + scales (range): Range object for iterating over scales """ def __init__(self) -> None: @@ -112,20 +210,33 @@ def __init__(self) -> None: for param in self.extractor2.parameters(): param.requires_grad = False - def forward(self, img: torch.Tensor, training: bool = True) -> torch.Tensor: - """Return normalized features.""" + def forward(self, img: torch.Tensor) -> torch.Tensor: + """Extract and normalize features from input image. + + Args: + img (torch.Tensor): Input image tensor of shape + ``(batch_size, channels, height, width)`` + + Returns: + torch.Tensor: List of normalized feature tensors from each scale + """ features = self.extract_features(img) - return self.normalize_features(features, training=training) + return self.normalize_features(features) - def extract_features(self, img: torch.Tensor, **kwargs) -> tuple[torch.Tensor, torch.Tensor]: # noqa: ARG002 | unused argument - """Extract features from ``img`` from the two extractors. + def extract_features(self, img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Extract features from input image using both CaiT models. + + The features are extracted at two scales: + - Scale 1: Using CaiT-M48 up to block index 40 (448x448 input) + - Scale 2: Using CaiT-S24 up to block index 20 (224x224 input) Args: - img (torch.Tensor): Input image - kwargs: unused + img (torch.Tensor): Input image tensor of shape + ``(batch_size, channels, height, width)`` Returns: - tuple[torch.Tensor, torch.Tensor]: Features from the two extractors. + tuple[torch.Tensor, torch.Tensor]: Features from both extractors with shapes: + ``[(B, 768, H/16, W/16), (B, 384, H/32, W/32)]`` """ self.extractor1.eval() self.extractor2.eval() @@ -147,15 +258,20 @@ def extract_features(self, img: torch.Tensor, **kwargs) -> tuple[torch.Tensor, t return (x1, x2) - def normalize_features(self, features: torch.Tensor, **kwargs) -> torch.Tensor: # noqa: ARG002 | unused argument - """Normalize features. + def normalize_features(self, features: torch.Tensor) -> torch.Tensor: + """Normalize extracted features from both scales. + + For each scale: + 1. Apply layer normalization + 2. Reshape features to spatial format + 3. Append to list of normalized features Args: - features (torch.Tensor): Features to normalize. - **kwargs: unused + features (torch.Tensor): Tuple of features from both extractors Returns: - torch.Tensor: Normalized features. 
+            torch.Tensor: List of normalized feature tensors with shapes:
+                ``[(B, 768, H/16, W/16), (B, 384, H/32, W/32)]``
         """
         normalized_features = []
         for i, extractor in enumerate([self.extractor1, self.extractor2]):
diff --git a/src/anomalib/models/image/uflow/lightning_model.py b/src/anomalib/models/image/uflow/lightning_model.py
index bfd51195ca..02715837e9 100644
--- a/src/anomalib/models/image/uflow/lightning_model.py
+++ b/src/anomalib/models/image/uflow/lightning_model.py
@@ -1,6 +1,28 @@
 """U-Flow: A U-shaped Normalizing Flow for Anomaly Detection with Unsupervised Threshold.

-https://arxiv.org/pdf/2211.12353.pdf
+This module implements the U-Flow model for anomaly detection as described in
+`<https://arxiv.org/pdf/2211.12353.pdf>`_. The model consists
+of:
+
+- A U-shaped normalizing flow architecture for density estimation
+- Multi-scale feature extraction using pre-trained backbones
+- Unsupervised threshold estimation based on the learned density
+- Anomaly detection by comparing likelihoods to the threshold
+
+Example:
+    >>> from anomalib.models.image import Uflow
+    >>> from anomalib.engine import Engine
+    >>> from anomalib.data import MVTec
+    >>> datamodule = MVTec()
+    >>> model = Uflow()
+    >>> engine = Engine(model=model, datamodule=datamodule)
+    >>> engine.fit()  # doctest: +SKIP
+    >>> predictions = engine.predict()  # doctest: +SKIP
+
+See Also:
+    - :class:`UflowModel`: PyTorch implementation of the model architecture
+    - :class:`UFlowLoss`: Loss function for training
+    - :class:`AnomalyMapGenerator`: Anomaly map generation from features
 """

 # Copyright (C) 2023-2024 Intel Corporation
@@ -32,14 +54,55 @@

 class Uflow(AnomalibModule):
-    """Uflow model.
+    """Lightning implementation of the U-Flow model.
+
+    This class implements the U-Flow model for anomaly detection as described in
+    Tailanian et al., 2022. The model consists of:
+
+    - A U-shaped normalizing flow architecture for density estimation
+    - Multi-scale feature extraction using pre-trained backbones
+    - Unsupervised threshold estimation based on the learned density
+    - Anomaly detection by comparing likelihoods to the threshold

     Args:
-        backbone (str): Backbone name.
-        flow_steps (int): Number of flow steps.
-        affine_clamp (float): Affine clamp.
-        affine_subnet_channels_ratio (float): Affine subnet channels ratio.
-        permute_soft (bool): Whether to use soft permutation.
+        backbone (str, optional): Name of the backbone feature extractor. Must be
+            one of ``["mcait", "resnet18", "wide_resnet50_2"]``. Defaults to
+            ``"mcait"``.
+        flow_steps (int, optional): Number of normalizing flow steps. Defaults to
+            ``4``.
+        affine_clamp (float, optional): Clamping value for affine coupling
+            layers. Defaults to ``2.0``.
+        affine_subnet_channels_ratio (float, optional): Channel ratio for affine
+            coupling subnet. Defaults to ``1.0``.
+        permute_soft (bool, optional): Whether to use soft permutation. Defaults
+            to ``False``.
+        pre_processor (PreProcessor | bool, optional): Pre-processor for input
+            data. If ``True``, uses default pre-processor. Defaults to ``True``.
+        post_processor (PostProcessor | bool, optional): Post-processor for model
+            outputs. If ``True``, uses default post-processor. Defaults to
+            ``True``.
+        evaluator (Evaluator | bool, optional): Evaluator for model performance.
+            If ``True``, uses default evaluator. Defaults to ``True``.
+        visualizer (Visualizer | bool, optional): Visualizer for model outputs.
+            If ``True``, uses default visualizer. Defaults to ``True``.
+ + Example: + >>> from anomalib.models.image import Uflow + >>> from anomalib.engine import Engine + >>> from anomalib.data import MVTec + >>> datamodule = MVTec() + >>> model = Uflow(backbone="resnet18") + >>> engine = Engine(model=model, datamodule=datamodule) + >>> engine.fit() # doctest: +SKIP + >>> predictions = engine.predict() # doctest: +SKIP + + Raises: + ValueError: If ``input_size`` is not provided during initialization. + + See Also: + - :class:`UflowModel`: PyTorch implementation of the model architecture + - :class:`UFlowLoss`: Loss function for training + - :class:`AnomalyMapGenerator`: Anomaly map generation from features """ def __init__( @@ -54,26 +117,35 @@ def __init__( evaluator: Evaluator | bool = True, visualizer: Visualizer | bool = True, ) -> None: - """Uflow model. + """Initialize U-Flow model. Args: - backbone (str): Backbone name. - flow_steps (int): Number of flow steps. - affine_clamp (float): Affine clamp. - affine_subnet_channels_ratio (float): Affine subnet channels ratio. - permute_soft (bool): Whether to use soft permutation. - pre_processor (PreProcessor, optional): Pre-processor for the model. - This is used to pre-process the input data before it is passed to the model. - Defaults to ``None``. - post_processor (PostProcessor, optional): Post-processor for the model. - This is used to post-process the output data after it is passed to the model. - Defaults to ``None``. - evaluator (Evaluator, optional): Evaluator for the model. - This is used to evaluate the model. - Defaults to ``True``. - visualizer (Visualizer, optional): Visualizer for the model. - This is used to visualize the model. + backbone (str, optional): Name of the backbone feature extractor. + Must be one of ``["mcait", "resnet18", "wide_resnet50_2"]``. + Defaults to ``"mcait"``. + flow_steps (int, optional): Number of normalizing flow steps. + Defaults to ``4``. + affine_clamp (float, optional): Clamping value for affine coupling + layers. Defaults to ``2.0``. + affine_subnet_channels_ratio (float, optional): Channel ratio for + affine coupling subnet. Defaults to ``1.0``. + permute_soft (bool, optional): Whether to use soft permutation. + Defaults to ``False``. + pre_processor (PreProcessor | bool, optional): Pre-processor for + input data. If ``True``, uses default pre-processor. Defaults to + ``True``. + post_processor (PostProcessor | bool, optional): Post-processor for + model outputs. If ``True``, uses default post-processor. Defaults to ``True``. + evaluator (Evaluator | bool, optional): Evaluator for model + performance. If ``True``, uses default evaluator. Defaults to + ``True``. + visualizer (Visualizer | bool, optional): Visualizer for model + outputs. If ``True``, uses default visualizer. Defaults to + ``True``. + + Raises: + ValueError: If ``input_size`` is not provided during initialization. """ super().__init__( pre_processor=pre_processor, @@ -103,7 +175,19 @@ def __init__( @classmethod def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> PreProcessor: - """Default pre-processor for UFlow.""" + """Configure default pre-processor for U-Flow model. + + Args: + image_size (tuple[int, int] | None, optional): Input image size. + Not used as U-Flow has fixed input size. Defaults to ``None``. + + Returns: + PreProcessor: Default pre-processor with resizing and normalization. + + Note: + The input image size is fixed to 448x448 for U-Flow regardless of + the provided ``image_size``. 
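Spelled out, that default pre-processor is a fixed resize plus normalization. A sketch of an equivalent object, assuming the ``anomalib.pre_processing`` import path and ImageNet statistics for the ``Normalize`` step:

```python
from torchvision.transforms.v2 import Compose, Normalize, Resize

from anomalib.pre_processing import PreProcessor

# Fixed 448x448 input, as the Note above explains; the normalization
# statistics here (ImageNet) are an assumption.
pre_processor = PreProcessor(
    transform=Compose([
        Resize((448, 448)),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)
```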
+ """ if image_size is not None: logger.warning("Image size is not used in UFlow. The input image size is determined by the model.") transform = Compose([ @@ -113,7 +197,13 @@ def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> P return PreProcessor(transform=transform) def configure_optimizers(self) -> tuple[list[LightningOptimizer], list[LRScheduler]]: - """Return optimizer and scheduler.""" + """Configure optimizers and learning rate schedulers. + + Returns: + tuple[list[LightningOptimizer], list[LRScheduler]]: Tuple containing: + - List of optimizers (Adam with initial lr=1e-3) + - List of schedulers (LinearLR reducing to 0.4 over 25000 steps) + """ # Optimizer # values used in paper: bottle: 0.0001128999, cable: 0.0016160391, capsule: 0.0012118892, carpet: 0.0012118892, # grid: 0.0000362248, hazelnut: 0.0013268899, leather: 0.0006124724, metal_nut: 0.0008148858, @@ -131,27 +221,49 @@ def configure_optimizers(self) -> tuple[list[LightningOptimizer], list[LRSchedul return [optimizer], [scheduler] def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: # noqa: ARG002 | unused arguments - """Training step.""" + """Perform a training step. + + Args: + batch (Batch): Input batch containing images + *args: Variable length argument list + **kwargs: Arbitrary keyword arguments + + Returns: + STEP_OUTPUT: Dictionary containing the loss value + """ z, ljd = self.model(batch.image) loss = self.loss(z, ljd) self.log_dict({"loss": loss}, on_step=True, on_epoch=False, prog_bar=False, logger=True) return {"loss": loss} def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: # noqa: ARG002 | unused arguments - """Validation step.""" + """Perform a validation step. + + Args: + batch (Batch): Input batch containing images + *args: Variable length argument list + **kwargs: Arbitrary keyword arguments + + Returns: + STEP_OUTPUT: Batch updated with model predictions + """ predictions = self.model(batch.image) return batch.update(**predictions._asdict()) @property def trainer_arguments(self) -> dict[str, Any]: - """Return EfficientAD trainer arguments.""" + """Get trainer arguments for U-Flow. + + Returns: + dict[str, Any]: Dictionary containing trainer arguments + """ return {"num_sanity_val_steps": 0} @property def learning_type(self) -> LearningType: - """Return the learning type of the model. + """Get the learning type of the model. Returns: - LearningType: Learning type of the model. + LearningType: Learning type (ONE_CLASS for U-Flow) """ return LearningType.ONE_CLASS diff --git a/src/anomalib/models/image/uflow/loss.py b/src/anomalib/models/image/uflow/loss.py index 08f2dfbe31..c9d09547f5 100644 --- a/src/anomalib/models/image/uflow/loss.py +++ b/src/anomalib/models/image/uflow/loss.py @@ -1,4 +1,23 @@ -"""Loss function for the UFlow Model Implementation.""" +"""Loss function implementation for the U-Flow model. + +This module implements the loss function used to train the U-Flow model for anomaly +detection as described in `_. 
+The loss combines: + +- A likelihood term based on the hidden variables +- A Jacobian determinant term from the normalizing flow + +Example: + >>> from anomalib.models.image.uflow.loss import UFlowLoss + >>> loss_fn = UFlowLoss() + >>> hidden_vars = [torch.randn(2, 64, 32, 32)] + >>> jacobians = [torch.randn(2)] + >>> loss = loss_fn(hidden_vars, jacobians) + +See Also: + - :class:`UFlowLoss`: Main loss function implementation + - :class:`UflowModel`: PyTorch model using this loss +""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -8,18 +27,45 @@ class UFlowLoss(nn.Module): - """UFlow Loss.""" + """Loss function for training the U-Flow model. + + This class implements the loss function used to train the U-Flow model. + The loss combines: + + 1. A likelihood term based on the hidden variables (``lpz``) + 2. A Jacobian determinant term from the normalizing flow + + The total loss is computed as: + ``loss = mean(lpz - jacobians)`` + + Example: + >>> from anomalib.models.image.uflow.loss import UFlowLoss + >>> loss_fn = UFlowLoss() + >>> hidden_vars = [torch.randn(2, 64, 32, 32)] # List of hidden variables + >>> jacobians = [torch.randn(2)] # List of log Jacobian determinants + >>> loss = loss_fn(hidden_vars, jacobians) + >>> loss.shape + torch.Size([]) + + See Also: + - :class:`UflowModel`: PyTorch model using this loss function + - :class:`Uflow`: Lightning implementation using this loss + """ @staticmethod def forward(hidden_variables: list[Tensor], jacobians: list[Tensor]) -> Tensor: """Calculate the UFlow loss. Args: - hidden_variables (list[Tensor]): Hidden variables from the fastflow model. f: X -> Z - jacobians (list[Tensor]): Log of the jacobian determinants from the fastflow model. + hidden_variables (list[Tensor]): List of hidden variable tensors from the + normalizing flow transformation f: X -> Z. Each tensor has shape + ``(batch_size, channels, height, width)``. + jacobians (list[Tensor]): List of log Jacobian determinant tensors from the + flow transformation. Each tensor has shape ``(batch_size,)``. Returns: - Tensor: UFlow loss computed based on the hidden variables and the log of the Jacobians. + Tensor: Scalar loss value combining the likelihood of hidden variables and + the log Jacobian determinants. """ lpz = torch.sum(torch.stack([0.5 * torch.sum(z_i**2, dim=(1, 2, 3)) for z_i in hidden_variables], dim=0)) return torch.mean(lpz - jacobians) diff --git a/src/anomalib/models/image/uflow/torch_model.py b/src/anomalib/models/image/uflow/torch_model.py index 7c376328b9..2612b16356 100644 --- a/src/anomalib/models/image/uflow/torch_model.py +++ b/src/anomalib/models/image/uflow/torch_model.py @@ -1,4 +1,19 @@ -"""U-Flow torch model.""" +"""U-Flow PyTorch Implementation. + +This module provides the PyTorch implementation of the U-Flow model for anomaly detection. +U-Flow combines normalizing flows with a U-Net style architecture to learn the distribution +of normal images and detect anomalies. + +The model consists of several key components: + - Feature extraction using a pre-trained backbone + - Normalizing flow blocks arranged in a U-Net structure + - Anomaly map generation for localization + +The implementation includes classes for: + - Affine coupling subnet construction + - Main U-Flow model architecture + - Anomaly map generation +""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -18,11 +33,28 @@ class AffineCouplingSubnet: """Class for building the Affine Coupling subnet. 
- It is passed as an argument to the `AllInOneBlock` module. + This class creates a subnet used within the affine coupling layers of the normalizing + flow. The subnet is passed as an argument to the ``AllInOneBlock`` module and + determines how features are transformed within the coupling layer. Args: - kernel_size (int): Kernel size. - subnet_channels_ratio (float): Subnet channels ratio. + kernel_size (int): Size of convolutional kernels used in subnet layers. + subnet_channels_ratio (float): Ratio determining the number of intermediate + channels in the subnet relative to input channels. + + Example: + >>> subnet = AffineCouplingSubnet(kernel_size=3, subnet_channels_ratio=1.0) + >>> layer = subnet(in_channels=64, out_channels=128) + >>> layer + Sequential( + (0): Conv2d(64, 64, kernel_size=(3, 3), padding=same) + (1): ReLU() + (2): Conv2d(64, 128, kernel_size=(3, 3), padding=same) + ) + + See Also: + - :class:`AllInOneBlock`: Flow block using this subnet + - :class:`UflowModel`: Main model incorporating these subnets """ def __init__(self, kernel_size: int, subnet_channels_ratio: float) -> None: @@ -30,14 +62,21 @@ def __init__(self, kernel_size: int, subnet_channels_ratio: float) -> None: self.subnet_channels_ratio = subnet_channels_ratio def __call__(self, in_channels: int, out_channels: int) -> nn.Sequential: - """Return AffineCouplingSubnet network. + """Create and return the affine coupling subnet. + + The subnet consists of two convolutional layers with a ReLU activation in + between. The intermediate channel dimension is determined by + ``subnet_channels_ratio``. Args: - in_channels (int): Input channels. - out_channels (int): Output channels. + in_channels (int): Number of input channels to the subnet. + out_channels (int): Number of output channels from the subnet. Returns: - nn.Sequential: Affine Coupling subnet. + nn.Sequential: Sequential container of the subnet layers including: + - Conv2d layer mapping input to intermediate channels + - ReLU activation + - Conv2d layer mapping intermediate to output channels """ mid_channels = int(in_channels * self.subnet_channels_ratio) return nn.Sequential( @@ -48,15 +87,44 @@ def __call__(self, in_channels: int, out_channels: int) -> nn.Sequential: class UflowModel(nn.Module): - """U-Flow model. + """PyTorch implementation of the U-Flow model architecture. + + This class implements the U-Flow model for anomaly detection. + The model consists of: + + - A U-shaped normalizing flow architecture for density estimation + - Multi-scale feature extraction using pre-trained backbones + - Unsupervised threshold estimation based on the learned density + - Anomaly detection by comparing likelihoods to the threshold Args: - input_size (tuple[int, int]): Input image size. - flow_steps (int): Number of flow steps. - backbone (str): Backbone name. - affine_clamp (float): Affine clamp. - affine_subnet_channels_ratio (float): Affine subnet channels ratio. - permute_soft (bool): Whether to use soft permutation. + input_size (tuple[int, int]): Input image dimensions as ``(height, width)``. + Defaults to ``(448, 448)``. + flow_steps (int): Number of normalizing flow steps in each flow stage. + Defaults to ``4``. + backbone (str): Name of the backbone feature extractor. Must be one of + ``["mcait", "resnet18", "wide_resnet50_2"]``. Defaults to ``"mcait"``. + affine_clamp (float): Clamping value for affine coupling layers. Defaults + to ``2.0``. + affine_subnet_channels_ratio (float): Channel ratio for affine coupling + subnet. Defaults to ``1.0``. 
+ permute_soft (bool): Whether to use soft permutation. Defaults to + ``False``. + + Example: + >>> import torch + >>> from anomalib.models.image.uflow.torch_model import UflowModel + >>> model = UflowModel( + ... input_size=(256, 256), + ... backbone="resnet18" + ... ) + >>> image = torch.randn(1, 3, 256, 256) + >>> output = model(image) # Returns anomaly map during inference + + See Also: + - :class:`Uflow`: Lightning implementation using this model + - :class:`UFlowLoss`: Loss function for training + - :class:`AnomalyMapGenerator`: Anomaly map generation from features """ def __init__( @@ -80,21 +148,28 @@ def __init__( self.anomaly_map_generator = AnomalyMapGenerator(input_size) def build_flow(self, flow_steps: int) -> ff.GraphINN: - """Build the flow model. + """Build the U-shaped normalizing flow architecture. + + The flow is built in a U-shaped structure, processing features from coarse + to fine scales: - First we start with the input nodes, which have to match the feature extractor output. - Then, we build the U-Shaped flow. Starting from the bottom (the coarsest scale), the flow is built as follows: - 1. Pass the input through a Flow Stage (`build_flow_stage`). - 2. Split the output of the flow stage into two parts, one that goes directly to the output, - 3. and the other is up-sampled, and will be concatenated with the output of the next flow stage (next scale) - 4. Repeat steps 1-3 for the next scale. - Finally, we build the Flow graph using the input nodes, the flow stages, and the output nodes. + 1. Start with input nodes matching feature extractor outputs + 2. For each scale (coarse to fine): + - Pass through flow stage (sequence of coupling layers) + - Split output into two parts + - Send one part to output + - Upsample other part and concatenate with next scale + 3. Build final flow graph combining all nodes Args: - flow_steps (int): Number of flow steps. + flow_steps (int): Number of coupling layers in each flow stage. Returns: - ff.GraphINN: Flow model. + ff.GraphINN: Constructed normalizing flow graph. + + See Also: + - :meth:`build_flow_stage`: Builds individual flow stages + - :class:`AllInOneBlock`: Individual coupling layer blocks """ input_nodes = [] for channel, s_factor in zip( @@ -138,17 +213,24 @@ def build_flow(self, flow_steps: int) -> ff.GraphINN: return ff.GraphINN(input_nodes + nodes + output_nodes[::-1]) def build_flow_stage(self, in_node: ff.Node, flow_steps: int, condition_node: ff.Node = None) -> list[ff.Node]: - """Build a flow stage, which is a sequence of flow steps. + """Build a single flow stage consisting of multiple coupling layers. - Each flow stage is essentially a sequence of `flow_steps` Glow blocks (`AllInOneBlock`). + Each flow stage is a sequence of ``flow_steps`` Glow-style coupling blocks + (``AllInOneBlock``). The blocks alternate between 3x3 and 1x1 convolutions + in their coupling subnets. Args: - in_node (ff.Node): Input node. - flow_steps (int): Number of flow steps. - condition_node (ff.Node): Condition node. + in_node (ff.Node): Input node to the flow stage. + flow_steps (int): Number of coupling layers to use. + condition_node (ff.Node, optional): Optional conditioning node. + Defaults to ``None``. Returns: - List[ff.Node]: List of flow steps. + list[ff.Node]: List of constructed coupling layer nodes. 
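A condensed sketch of this stage construction, using FrEIA's ``Node`` API with the ``AffineCouplingSubnet`` defined earlier in this file; the even/odd order of the kernel alternation is an assumption:

```python
import FrEIA.framework as ff
from FrEIA.modules import AllInOneBlock


def flow_stage_sketch(in_node: ff.Node, flow_steps: int = 4) -> list[ff.Node]:
    """Chain `flow_steps` coupling blocks, alternating 3x3 and 1x1 subnet kernels."""
    nodes = []
    node = in_node
    for step in range(flow_steps):
        kernel_size = 3 if step % 2 == 0 else 1  # alternate 3x3 / 1x1 convolutions
        node = ff.Node(
            node,
            AllInOneBlock,
            {
                "subnet_constructor": AffineCouplingSubnet(kernel_size, subnet_channels_ratio=1.0),
                "affine_clamping": 2.0,
                "permute_soft": False,
            },
            name=f"flow_step_{step}",
        )
        nodes.append(node)
    return nodes
```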
+ + See Also: + - :class:`AllInOneBlock`: Individual coupling layer implementation + - :class:`AffineCouplingSubnet`: Subnet used in coupling layers """ flow_size = in_node.output_dims[0][-1] nodes = [] @@ -173,7 +255,20 @@ def build_flow_stage(self, in_node: ff.Node, flow_steps: int, condition_node: ff return nodes def forward(self, image: torch.Tensor) -> torch.Tensor | InferenceBatch: - """Return anomaly map.""" + """Process input image through the model. + + During training, returns latent variables and log-Jacobian determinant. + During inference, returns anomaly scores and anomaly map. + + Args: + image (torch.Tensor): Input image tensor of shape + ``(batch_size, channels, height, width)``. + + Returns: + torch.Tensor | InferenceBatch: During training, returns tuple of + ``(latent_vars, log_jacobian)``. During inference, returns + ``InferenceBatch`` with anomaly scores and map. + """ features = self.feature_extractor(image) z, ljd = self.encode(features) @@ -185,7 +280,16 @@ def forward(self, image: torch.Tensor) -> torch.Tensor | InferenceBatch: return InferenceBatch(pred_score=pred_score, anomaly_map=anomaly_map) def encode(self, features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """Return""" + """Encode input features to latent space using normalizing flow. + + Args: + features (torch.Tensor): Input features from feature extractor. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Tuple containing: + - Latent variables from flow transformation + - Log-Jacobian determinant of the transformation + """ z, ljd = self.flow(features, rev=False) if len(self.feature_extractor.scales) == 1: z = [z] diff --git a/src/anomalib/models/image/vlm_ad/__init__.py b/src/anomalib/models/image/vlm_ad/__init__.py index 46ab8e0fee..f13d6c46d9 100644 --- a/src/anomalib/models/image/vlm_ad/__init__.py +++ b/src/anomalib/models/image/vlm_ad/__init__.py @@ -1,4 +1,23 @@ -"""Visual Anomaly Model.""" +"""Vision Language Model (VLM) based Anomaly Detection. + +This module implements anomaly detection using Vision Language Models (VLMs) like +GPT-4V, LLaVA, etc. The models use natural language prompting to detect anomalies +in images by comparing them with reference normal images. + +Example: + >>> from anomalib.models.image import VlmAd + >>> model = VlmAd( # doctest: +SKIP + ... backend="chatgpt", + ... model_name="gpt-4-vision-preview" + ... ) + >>> model.fit(["normal1.jpg", "normal2.jpg"]) # doctest: +SKIP + >>> prediction = model.predict("test.jpg") # doctest: +SKIP + +See Also: + - :class:`VlmAd`: Main model class for VLM-based anomaly detection + - :mod:`.backends`: Different VLM backend implementations + - :mod:`.utils`: Utility functions for prompting and responses +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/image/vlm_ad/backends/__init__.py b/src/anomalib/models/image/vlm_ad/backends/__init__.py index 44009f8f83..6de26ffa24 100644 --- a/src/anomalib/models/image/vlm_ad/backends/__init__.py +++ b/src/anomalib/models/image/vlm_ad/backends/__init__.py @@ -1,4 +1,23 @@ -"""VLM backends.""" +"""Vision Language Model (VLM) backends for anomaly detection. + +This module provides backend implementations for different Vision Language Models +(VLMs) that can be used for anomaly detection. 
The backends include:
+
+- :class:`ChatGPT`: OpenAI's ChatGPT model
+- :class:`Huggingface`: Models from Hugging Face Hub
+- :class:`Ollama`: Open source LLM models via Ollama
+
+Example:
+    >>> from anomalib.models.image.vlm_ad.backends import ChatGPT
+    >>> backend = ChatGPT(model_name="gpt-4-vision-preview")  # doctest: +SKIP
+    >>> response = backend.predict("test.jpg", prompt)  # doctest: +SKIP
+
+See Also:
+    - :class:`Backend`: Base class for VLM backends
+    - :class:`ChatGPT`: ChatGPT backend implementation
+    - :class:`Huggingface`: Hugging Face backend implementation
+    - :class:`Ollama`: Ollama backend implementation
+"""

 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
diff --git a/src/anomalib/models/image/vlm_ad/backends/base.py b/src/anomalib/models/image/vlm_ad/backends/base.py
index b4aadf9a22..37fb20d0df 100644
--- a/src/anomalib/models/image/vlm_ad/backends/base.py
+++ b/src/anomalib/models/image/vlm_ad/backends/base.py
@@ -1,4 +1,26 @@
-"""Base backend."""
+"""Base backend for Vision Language Models (VLMs).
+
+This module provides the abstract base class for VLM backends used in anomaly detection.
+The backends handle communication with different VLM services and models.
+
+Example:
+    >>> from anomalib.models.image.vlm_ad.backends import Backend
+    >>> class CustomBackend(Backend):
+    ...     def __init__(self, model_name: str) -> None:
+    ...         super().__init__(model_name)
+    ...     def add_reference_images(self, image: str) -> None:
+    ...         pass
+    ...     def predict(self, image: str, prompt: Prompt) -> str:
+    ...         return "normal"
+    ...     @property
+    ...     def num_reference_images(self) -> int:
+    ...         return 0
+
+See Also:
+    - :class:`ChatGPT`: OpenAI's ChatGPT backend implementation
+    - :class:`Huggingface`: Hugging Face models backend implementation
+    - :class:`Ollama`: Ollama models backend implementation
+"""

 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

@@ -10,21 +32,65 @@

 class Backend(ABC):
-    """Base backend."""
+    """Abstract base class for Vision Language Model (VLM) backends.
+
+    This class defines the interface that all VLM backends must implement. Backends
+    handle communication with different VLM services and models for anomaly detection.
+
+    Example:
+        >>> from anomalib.models.image.vlm_ad.backends import Backend
+        >>> class CustomBackend(Backend):
+        ...     def __init__(self, model_name: str) -> None:
+        ...         super().__init__(model_name)
+        ...     def add_reference_images(self, image: str) -> None:
+        ...         pass
+        ...     def predict(self, image: str, prompt: Prompt) -> str:
+        ...         return "normal"
+        ...     @property
+        ...     def num_reference_images(self) -> int:
+        ...         return 0
+
+    See Also:
+        - :class:`ChatGPT`: OpenAI's ChatGPT backend implementation
+        - :class:`Huggingface`: Hugging Face models backend implementation
+        - :class:`Ollama`: Ollama models backend implementation
+    """

     @abstractmethod
     def __init__(self, model_name: str) -> None:
-        """Initialize the backend."""
+        """Initialize the VLM backend.
+
+        Args:
+            model_name (str): Name or identifier of the VLM model to use
+        """

     @abstractmethod
     def add_reference_images(self, image: str | Path) -> None:
-        """Add reference images for k-shot."""
+        """Add reference images for few-shot learning.
+
+        The backend stores these images to use as examples when making predictions.
+
+        Args:
+            image (str | Path): Path to the reference image file
+        """

     @abstractmethod
     def predict(self, image: str | Path, prompt: Prompt) -> str:
-        """Predict the anomaly label."""
+        """Predict whether an image contains anomalies.
+ + Args: + image (str | Path): Path to the image file to analyze + prompt (Prompt): Prompt template to use for querying the VLM + + Returns: + str: Prediction result from the VLM + """ @property @abstractmethod def num_reference_images(self) -> int: - """Get the number of reference images.""" + """Get the number of stored reference images. + + Returns: + int: Count of reference images currently stored in the backend + """ diff --git a/src/anomalib/models/image/vlm_ad/backends/chat_gpt.py b/src/anomalib/models/image/vlm_ad/backends/chat_gpt.py index 53648e688a..e81c1a2d63 100644 --- a/src/anomalib/models/image/vlm_ad/backends/chat_gpt.py +++ b/src/anomalib/models/image/vlm_ad/backends/chat_gpt.py @@ -1,4 +1,29 @@ -"""ChatGPT backend.""" +"""ChatGPT backend for Vision Language Models (VLMs). + +This module implements a backend for using OpenAI's ChatGPT model for vision-language +tasks in anomaly detection. The backend handles: + +- Authentication with OpenAI API +- Encoding and sending images +- Prompting the model +- Processing responses + +Example: + >>> from anomalib.models.image.vlm_ad.backends import ChatGPT + >>> backend = ChatGPT(model_name="gpt-4-vision-preview") # doctest: +SKIP + >>> backend.add_reference_images("normal_image.jpg") # doctest: +SKIP + >>> response = backend.predict("test.jpg", prompt) # doctest: +SKIP + +Args: + model_name (str): Name of the ChatGPT model to use (e.g. ``"gpt-4-vision-preview"``) + api_key (str | None, optional): OpenAI API key. If not provided, will attempt to + load from environment. Defaults to ``None``. + +See Also: + - :class:`Backend`: Base class for VLM backends + - :class:`Huggingface`: Alternative backend using Hugging Face models + - :class:`Ollama`: Alternative backend using Ollama models +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -28,7 +53,37 @@ class ChatGPT(Backend): - """ChatGPT backend.""" + """OpenAI ChatGPT backend for vision-language anomaly detection. + + This class implements a backend for using OpenAI's ChatGPT models with vision + capabilities (e.g. GPT-4V) for anomaly detection. It handles: + + - Authentication with OpenAI API + - Image encoding and formatting + - Few-shot learning with reference images + - Model prompting and response processing + + Args: + model_name (str): Name of the ChatGPT model to use (e.g. + ``"gpt-4-vision-preview"``) + api_key (str | None, optional): OpenAI API key. If not provided, will + attempt to load from environment. Defaults to ``None``. + + Example: + >>> from anomalib.models.image.vlm_ad.backends import ChatGPT + >>> backend = ChatGPT(model_name="gpt-4-vision-preview") # doctest: +SKIP + >>> backend.add_reference_images("normal_image.jpg") # doctest: +SKIP + >>> response = backend.predict("test.jpg", prompt) # doctest: +SKIP + + Raises: + ImportError: If OpenAI package is not installed + ValueError: If no API key is provided or found in environment + + See Also: + - :class:`Backend`: Base class for VLM backends + - :class:`Huggingface`: Alternative backend using Hugging Face models + - :class:`Ollama`: Alternative backend using Ollama models + """ def __init__(self, model_name: str, api_key: str | None = None) -> None: """Initialize the ChatGPT backend.""" @@ -39,7 +94,14 @@ def __init__(self, model_name: str, api_key: str | None = None) -> None: @property def client(self) -> OpenAI: - """Get the OpenAI client.""" + """Get the OpenAI client. 
+ + Returns: + OpenAI: Initialized OpenAI client instance + + Raises: + ImportError: If OpenAI package is not installed + """ if OpenAI is None: msg = "OpenAI is not installed. Please install it to use ChatGPT backend." raise ImportError(msg) @@ -48,16 +110,33 @@ def client(self) -> OpenAI: return self._client def add_reference_images(self, image: str | Path) -> None: - """Add reference images for k-shot.""" + """Add reference images for few-shot learning. + + Args: + image (str | Path): Path to the reference image file + """ self._ref_images_encoded.append(self._encode_image_to_url(image)) @property def num_reference_images(self) -> int: - """Get the number of reference images.""" + """Get the number of reference images. + + Returns: + int: Number of reference images added for few-shot learning + """ return len(self._ref_images_encoded) def predict(self, image: str | Path, prompt: Prompt) -> str: - """Predict the anomaly label.""" + """Predict whether an image contains anomalies. + + Args: + image (str | Path): Path to the image file to analyze + prompt (Prompt): Prompt object containing few-shot and prediction + prompts + + Returns: + str: Model's response indicating if anomalies were detected + """ image_encoded = self._encode_image_to_url(image) messages = [] @@ -72,7 +151,15 @@ def predict(self, image: str | Path, prompt: Prompt) -> str: @staticmethod def _generate_message(content: str, images: list[str] | None) -> dict: - """Generate a message.""" + """Generate a message for the ChatGPT API. + + Args: + content (str): Text content of the message + images (list[str] | None): List of base64-encoded image URLs + + Returns: + dict: Formatted message dictionary for the API + """ message: dict[str, list[dict] | str] = {"role": "user"} if images is not None: _content: list[dict[str, str | dict]] = [{"type": "text", "text": content}] @@ -83,7 +170,14 @@ def _generate_message(content: str, images: list[str] | None) -> dict: return message def _encode_image_to_url(self, image: str | Path) -> str: - """Encode the image to base64 and embed in url string.""" + """Encode an image file to a base64 URL string. + + Args: + image (str | Path): Path to the image file + + Returns: + str: Base64-encoded image URL string + """ image_path = Path(image) extension = image_path.suffix base64_encoded = self._encode_image_to_base_64(image_path) @@ -91,11 +185,35 @@ def _encode_image_to_url(self, image: str | Path) -> str: @staticmethod def _encode_image_to_base_64(image: str | Path) -> str: - """Encode the image to base64.""" + """Encode an image file to base64. + + Args: + image (str | Path): Path to the image file + + Returns: + str: Base64-encoded image string + """ image = Path(image) return base64.b64encode(image.read_bytes()).decode("utf-8") def _get_api_key(self, api_key: str | None = None) -> str: + """Get the OpenAI API key. + + Attempts to get the API key in the following order: + 1. From the provided argument + 2. From environment variable ``OPENAI_API_KEY`` + 3. From ``.env`` file + + Args: + api_key (str | None, optional): API key provided directly. Defaults to + ``None``. 
+ + Returns: + str: Valid OpenAI API key + + Raises: + ValueError: If no API key is found + """ if api_key is None: load_dotenv() api_key = os.getenv("OPENAI_API_KEY") diff --git a/src/anomalib/models/image/vlm_ad/backends/huggingface.py b/src/anomalib/models/image/vlm_ad/backends/huggingface.py index e8d3c1e84b..9e427b6965 100644 --- a/src/anomalib/models/image/vlm_ad/backends/huggingface.py +++ b/src/anomalib/models/image/vlm_ad/backends/huggingface.py @@ -1,4 +1,28 @@ -"""Huggingface backend.""" +"""Hugging Face backend for Vision Language Models (VLMs). + +This module implements a backend for using Hugging Face models for vision-language +tasks in anomaly detection. The backend handles: + +- Loading models and processors from Hugging Face Hub +- Processing images into model inputs +- Few-shot learning with reference images +- Model inference and response processing + +Example: + >>> from anomalib.models.image.vlm_ad.backends import Huggingface + >>> backend = Huggingface(model_name="llava-hf/llava-1.5-7b-hf") # doctest: +SKIP + >>> backend.add_reference_images("normal_image.jpg") # doctest: +SKIP + >>> response = backend.predict("test.jpg", prompt) # doctest: +SKIP + +Args: + model_name (str): Name of the Hugging Face model to use (e.g. + ``"llava-hf/llava-1.5-7b-hf"``) + +See Also: + - :class:`Backend`: Base class for VLM backends + - :class:`ChatGPT`: Alternative backend using OpenAI models + - :class:`Ollama`: Alternative backend using Ollama models +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -28,13 +52,46 @@ class Huggingface(Backend): - """Huggingface backend.""" + """Hugging Face backend for vision-language anomaly detection. + + This class implements a backend for using Hugging Face vision-language models for + anomaly detection. It handles: + + - Loading models and processors from Hugging Face Hub + - Processing images into model inputs + - Few-shot learning with reference images + - Model inference and response processing + + Args: + model_name (str): Name of the Hugging Face model to use (e.g. + ``"llava-hf/llava-1.5-7b-hf"``) + + Example: + >>> from anomalib.models.image.vlm_ad.backends import Huggingface + >>> backend = Huggingface( # doctest: +SKIP + ... model_name="llava-hf/llava-1.5-7b-hf" + ... ) + >>> backend.add_reference_images("normal_image.jpg") # doctest: +SKIP + >>> response = backend.predict("test.jpg", prompt) # doctest: +SKIP + + Raises: + ValueError: If transformers package is not installed + + See Also: + - :class:`Backend`: Base class for VLM backends + - :class:`ChatGPT`: Alternative backend using OpenAI models + - :class:`Ollama`: Alternative backend using Ollama models + """ def __init__( self, model_name: str, ) -> None: - """Initialize the Huggingface backend.""" + """Initialize the Huggingface backend. + + Args: + model_name (str): Name of the Hugging Face model to use + """ self.model_name: str = model_name self._ref_images: list[str] = [] self._processor: ProcessorMixin | None = None @@ -42,7 +99,14 @@ def __init__( @property def processor(self) -> "ProcessorMixin": - """Get the Huggingface processor.""" + """Get the Hugging Face processor. + + Returns: + ProcessorMixin: Initialized processor for the model + + Raises: + ValueError: If transformers package is not installed + """ if self._processor is None: if transformers is None: msg = "transformers is not installed." 
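The key-resolution order documented for ``_get_api_key`` (explicit argument, then environment variable, then ``.env`` file) reduces to a few lines. A stand-alone sketch of the same pattern:

```python
import os

from dotenv import load_dotenv


def resolve_api_key(api_key: str | None = None) -> str:
    """Sketch of the resolution order: explicit argument, then environment/.env."""
    if api_key is None:
        load_dotenv()  # makes a local .env file visible to os.getenv
        api_key = os.getenv("OPENAI_API_KEY")
    if api_key is None:
        msg = "OpenAI API key not found. Pass api_key or set OPENAI_API_KEY."
        raise ValueError(msg)
    return api_key
```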
@@ -52,7 +116,14 @@ def processor(self) -> "ProcessorMixin":
 @property
 def model(self) -> "PreTrainedModel":
- """Get the Huggingface model."""
+ """Get the Hugging Face model.
+
+ Returns:
+ PreTrainedModel: Initialized model instance
+
+ Raises:
+ ValueError: If transformers package is not installed
+ """
 if self._model is None:
 if transformers is None:
 msg = "transformers is not installed."
@@ -62,7 +133,15 @@ def model(self) -> "PreTrainedModel":
 @staticmethod
 def _generate_message(content: str, images: list[str] | None) -> dict:
- """Generate a message."""
+ """Generate a message for the model.
+
+ Args:
+ content (str): Text content of the message
+ images (list[str] | None): List of image paths to include in message
+
+ Returns:
+ dict: Formatted message dictionary with role and content
+ """
 message: dict[str, str | list[dict]] = {"role": "user"}
 _content: list[dict[str, str]] = [{"type": "text", "text": content}]
 if images is not None:
@@ -71,16 +150,32 @@ def _generate_message(content: str, images: list[str] | None) -> dict:
 return message
 def add_reference_images(self, image: str | Path) -> None:
- """Add reference images for k-shot."""
+ """Add reference images for few-shot learning.
+
+ Args:
+ image (str | Path): Path to the reference image file
+ """
 self._ref_images.append(Image.open(image))
 @property
 def num_reference_images(self) -> int:
- """Get the number of reference images."""
+ """Get the number of reference images.
+
+ Returns:
+ int: Number of reference images added
+ """
 return len(self._ref_images)
 def predict(self, image_path: str | Path, prompt: Prompt) -> str:
- """Predict the anomaly label."""
+ """Predict whether an image contains anomalies.
+
+ Args:
+ image_path (str | Path): Path to the image to analyze
+ prompt (Prompt): Prompt object containing few-shot and prediction prompts
+
+ Returns:
+ str: Model's prediction response
+ """
 image = Image.open(image_path)
 messages: list[dict] = []
@@ -93,6 +188,4 @@ def predict(self, image_path: str | Path, prompt: Prompt) -> str:
 images = [*self._ref_images, image]
 inputs = self.processor(images, processed_prompt, return_tensors="pt", padding=True).to(self.model.device)
 outputs = self.model.generate(**inputs, max_new_tokens=100)
- result = self.processor.decode(outputs[0], skip_special_tokens=True)
- print(result)
- return result
+ return self.processor.decode(outputs[0], skip_special_tokens=True)
diff --git a/src/anomalib/models/image/vlm_ad/backends/ollama.py b/src/anomalib/models/image/vlm_ad/backends/ollama.py
index ff680bee3b..4c712cdba8 100644
--- a/src/anomalib/models/image/vlm_ad/backends/ollama.py
+++ b/src/anomalib/models/image/vlm_ad/backends/ollama.py
@@ -1,9 +1,34 @@
-"""Ollama backend.
-
-Assumes that the Ollama service is running in the background.
-See: https://github.com/ollama/ollama
-Ensure that ollama is running. On linux: `ollama serve`
-On Mac and Windows ensure that the ollama service is running by launching from the application list.
+"""Ollama backend for Vision Language Models (VLMs).
+
+This module implements a backend for using Ollama models for vision-language tasks in
+anomaly detection.
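# All three backends assemble the same chat shape in ``predict``: an optional
# few-shot message carrying the reference images, then a prediction message
# carrying the test image. A self-contained sketch (``Prompt`` is redefined
# here so the snippet runs on its own; the dict layout mirrors the
# Ollama-style message format):
from dataclasses import dataclass


@dataclass
class Prompt:
    few_shot: str
    predict: str


def build_messages(prompt: Prompt, ref_images: list[str], test_image: str) -> list[dict]:
    messages: list[dict] = []
    if ref_images:  # few-shot context is only included when references exist
        messages.append({"role": "user", "content": prompt.few_shot, "images": ref_images})
    messages.append({"role": "user", "content": prompt.predict, "images": [test_image]})
    return messages


print(build_messages(Prompt("These are normal.", "Is this anomalous?"), ["ref.jpg"], "test.jpg"))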
The backend handles:
+
+- Communication with local Ollama service
+- Image encoding and formatting
+- Few-shot learning with reference images
+- Model inference and response processing
+
+Example:
+ >>> from anomalib.models.image.vlm_ad.backends import Ollama
+ >>> backend = Ollama(model_name="llava") # doctest: +SKIP
+ >>> backend.add_reference_images("normal_image.jpg") # doctest: +SKIP
+ >>> response = backend.predict("test.jpg", prompt) # doctest: +SKIP
+
+Note:
+ Requires Ollama service to be running in the background:
+
+ - Linux: Run ``ollama serve``
+ - Mac/Windows: Launch Ollama application from applications list
+
+ See `Ollama documentation <https://github.com/ollama/ollama>`_ for setup details.
+
+Args:
+ model_name (str): Name of the Ollama model to use (e.g. ``"llava"``)
+
+See Also:
+ - :class:`Backend`: Base class for VLM backends
+ - :class:`ChatGPT`: Alternative backend using OpenAI models
+ - :class:`Huggingface`: Alternative backend using Hugging Face models
"""
# Copyright (C) 2024 Intel Corporation
@@ -28,32 +53,94 @@
class Ollama(Backend):
- """Ollama backend."""
+ """Ollama backend for vision-language anomaly detection.
+
+ This class implements a backend for using Ollama models with vision capabilities
+ for anomaly detection. It handles:
+
+ - Communication with local Ollama service
+ - Image encoding and formatting
+ - Few-shot learning with reference images
+ - Model inference and response processing
+
+ Args:
+ model_name (str): Name of the Ollama model to use (e.g. ``"llava"``)
+
+ Example:
+ >>> from anomalib.models.image.vlm_ad.backends import Ollama
+ >>> backend = Ollama(model_name="llava") # doctest: +SKIP
+ >>> backend.add_reference_images("normal_image.jpg") # doctest: +SKIP
+ >>> response = backend.predict("test.jpg", prompt) # doctest: +SKIP
+
+ Note:
+ Requires Ollama service to be running in the background:
+
+ - Linux: Run ``ollama serve``
+ - Mac/Windows: Launch Ollama application from applications list
+
+ See Also:
+ - :class:`Backend`: Base class for VLM backends
+ - :class:`ChatGPT`: Alternative backend using OpenAI models
+ - :class:`Huggingface`: Alternative backend using Hugging Face models
+ """
 def __init__(self, model_name: str) -> None:
- """Initialize the Ollama backend."""
+ """Initialize the Ollama backend.
+
+ Args:
+ model_name (str): Name of the Ollama model to use
+ """
 self.model_name: str = model_name
 self._ref_images_encoded: list[str] = []
 def add_reference_images(self, image: str | Path) -> None:
- """Encode the image to base64."""
+ """Add and encode reference images for few-shot learning.
+
+ The images are encoded to base64 format for sending to the Ollama service.
+
+ Args:
+ image (str | Path): Path to the reference image file
+ """
 self._ref_images_encoded.append(_encode_image(image))
 @property
 def num_reference_images(self) -> int:
- """Get the number of reference images."""
+ """Get the number of reference images.
+
+ Returns:
+ int: Number of reference images added
+ """
 return len(self._ref_images_encoded)
 @staticmethod
 def _generate_message(content: str, images: list[str] | None) -> dict:
- """Generate a message."""
+ """Generate a message for the Ollama chat API.
+
+ Args:
+ content (str): Text content of the message
+ images (list[str] | None): List of base64 encoded images to include
+
+ Returns:
+ dict: Formatted message dictionary with role, content and optional images
+ """
 message: dict[str, str | list[str]] = {"role": "user", "content": content}
 if images:
 message["images"] = images
 return message
 def predict(self, image: str | Path, prompt: Prompt) -> str:
- """Predict the anomaly label."""
+ """Predict whether an image contains anomalies.
+
+ Args:
+ image (str | Path): Path to the image to analyze
+ prompt (Prompt): Prompt object containing few-shot and prediction prompts
+
+ Returns:
+ str: Model's prediction response
+
+ Raises:
+ ImportError: If Ollama package is not installed
+ """
 if not chat:
 msg = "Ollama is not installed. Please install it using `pip install ollama`."
 raise ImportError(msg)
diff --git a/src/anomalib/models/image/vlm_ad/lightning_model.py b/src/anomalib/models/image/vlm_ad/lightning_model.py
index 7340474f29..92a52a7c75 100644
--- a/src/anomalib/models/image/vlm_ad/lightning_model.py
+++ b/src/anomalib/models/image/vlm_ad/lightning_model.py
@@ -1,4 +1,29 @@
-"""Visual Anomaly Model for Zero/Few-Shot Anomaly Classification."""
+"""Vision Language Model (VLM) based Anomaly Detection.
+
+This module implements anomaly detection using Vision Language Models (VLMs) like
+GPT-4V, LLaVA, etc. The models use natural language prompting to detect anomalies
+in images by comparing them with reference normal images.
+
+The module supports both zero-shot and few-shot learning approaches:
+
+- Zero-shot: No reference images needed
+- Few-shot: Uses ``k`` reference normal images for better context
+
+Example:
+ >>> from anomalib.data import MVTec
+ >>> from anomalib.engine import Engine
+ >>> from anomalib.models.image import VlmAd
+ >>> model = VlmAd( # doctest: +SKIP
+ ... model="gpt-4o-mini",
+ ... api_key="YOUR_API_KEY",
+ ... k_shot=3
+ ... )
+ >>> datamodule = MVTec(root="./datasets/MVTec") # doctest: +SKIP
+ >>> engine = Engine() # doctest: +SKIP
+ >>> engine.test(model=model, datamodule=datamodule) # doctest: +SKIP
+
+See Also:
+ - :class:`VlmAd`: Main model class for VLM-based anomaly detection
+ - :mod:`.backends`: Different VLM backend implementations
+ - :mod:`.utils`: Utility functions for prompting and responses
+"""
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
@@ -21,7 +46,41 @@
class VlmAd(AnomalibModule):
- """Visual anomaly model."""
+ """Vision Language Model (VLM) based anomaly detection model.
+
+ This model uses VLMs like GPT-4V, LLaVA, etc. to detect anomalies in images by
+ comparing them with reference normal images through natural language prompting.
+
+ Args:
+ model (ModelName | str): Name of the VLM model to use. Can be one of:
+ - ``ModelName.LLAMA_OLLAMA``
+ - ``ModelName.GPT_4O_MINI``
+ - ``ModelName.VICUNA_7B_HF``
+ - ``ModelName.VICUNA_13B_HF``
+ - ``ModelName.MISTRAL_7B_HF``
+ Defaults to ``ModelName.LLAMA_OLLAMA``.
+ api_key (str | None, optional): API key for models that require
+ authentication. Defaults to None.
+ k_shot (int, optional): Number of reference normal images to use for
+ few-shot learning. If 0, uses zero-shot approach. Defaults to 0.
+
+ Example:
+ >>> from anomalib.models.image import VlmAd
+ >>> # Zero-shot approach
+ >>> model = VlmAd( # doctest: +SKIP
+ ... model="gpt-4o-mini",
+ ... api_key="YOUR_API_KEY"
+ ... )
+ >>> # Few-shot approach with 3 reference images
+ >>> model = VlmAd( # doctest: +SKIP
+ ... model="gpt-4o-mini",
+ ... api_key="YOUR_API_KEY",
+ ... k_shot=3
+ ...
) + + Raises: + ValueError: If an unsupported VLM model is specified. + """ def __init__( self, @@ -53,7 +112,12 @@ def _setup(self) -> None: self.collect_reference_images(dataloader) def collect_reference_images(self, dataloader: DataLoader) -> None: - """Collect reference images for few-shot inference.""" + """Collect reference images for few-shot inference. + + Args: + dataloader (DataLoader): DataLoader containing normal images for + reference. + """ for batch in dataloader: for img_path in batch.image_path: self.vlm_backend.add_reference_images(img_path) @@ -62,7 +126,11 @@ def collect_reference_images(self, dataloader: DataLoader) -> None: @property def prompt(self) -> Prompt: - """Get the prompt.""" + """Get the prompt for VLM interaction. + + Returns: + Prompt: Object containing prompts for prediction and few-shot learning. + """ return Prompt( predict=( "You are given an image. It is either normal or anomalous." @@ -78,7 +146,16 @@ def prompt(self) -> Prompt: ) def validation_step(self, batch: ImageBatch, *args, **kwargs) -> ImageBatch: - """Validation step.""" + """Perform validation step. + + Args: + batch (ImageBatch): Batch of images to validate. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + ImageBatch: Batch with predictions and explanations added. + """ del args, kwargs # These variables are not used. assert batch.image_path is not None responses = [(self.vlm_backend.predict(img_path, self.prompt)) for img_path in batch.image_path] @@ -88,30 +165,48 @@ def validation_step(self, batch: ImageBatch, *args, **kwargs) -> ImageBatch: @property def learning_type(self) -> LearningType: - """The learning type of the model.""" + """Get the learning type of the model. + + Returns: + LearningType: ZERO_SHOT if k_shot=0, else FEW_SHOT. + """ return LearningType.ZERO_SHOT if self.k_shot == 0 else LearningType.FEW_SHOT @property def trainer_arguments(self) -> dict[str, int | float]: - """Doesn't need training.""" + """Get trainer arguments. + + Returns: + dict[str, int | float]: Empty dict as no training is needed. + """ return {} @staticmethod def configure_transforms(image_size: tuple[int, int] | None = None) -> None: - """This modes does not require any transforms.""" + """Configure image transforms. + + Args: + image_size (tuple[int, int] | None, optional): Ignored as each backend + has its own transforms. Defaults to None. + """ if image_size is not None: logger.warning("Ignoring image_size argument as each backend has its own transforms.") @classmethod def configure_post_processor(cls) -> PostProcessor | None: - """Post processing is not required for this model.""" + """Configure post processor. + + Returns: + PostProcessor | None: None as post processing is not required. + """ return None @staticmethod def configure_evaluator() -> Evaluator: - """Default evaluator. + """Configure default evaluator. - Override in subclass for model-specific evaluator behaviour. + Returns: + Evaluator: Evaluator configured with F1Score metric. """ image_f1score = F1Score(fields=["pred_label", "gt_label"], prefix="image_") return Evaluator(test_metrics=image_f1score) diff --git a/src/anomalib/models/image/vlm_ad/utils.py b/src/anomalib/models/image/vlm_ad/utils.py index ce9b9067ac..dec6f05327 100644 --- a/src/anomalib/models/image/vlm_ad/utils.py +++ b/src/anomalib/models/image/vlm_ad/utils.py @@ -1,4 +1,22 @@ -"""Dataclasses.""" +"""Utility classes and functions for Vision Language Model (VLM) based anomaly detection. 
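# ``validation_step`` above maps each backend's free-text response to a label.
# The diff does not show how anomalib parses the reply, so the keyword check
# below is a hypothetical stand-in that makes the idea concrete:
def response_to_label(response: str) -> int:
    """Return 1 (anomalous) if the reply starts with 'YES', else 0 (normal)."""
    return 1 if response.strip().upper().startswith("YES") else 0


assert response_to_label("YES: there is a crack on the casing.") == 1
assert response_to_label("NO: the object looks normal.") == 0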
+
+This module provides utility classes for VLM-based anomaly detection:
+
+- :class:`Prompt`: Dataclass for storing few-shot and prediction prompts
+- :class:`ModelName`: Enum of supported VLM models
+
+Example:
+ >>> from anomalib.models.image.vlm_ad.utils import Prompt, ModelName
+ >>> prompt = Prompt( # doctest: +SKIP
+ ... few_shot="These are normal examples...",
+ ... predict="Is this image normal or anomalous?"
+ ... )
+ >>> model_name = ModelName.LLAMA_OLLAMA # doctest: +SKIP
+
+See Also:
+ - :class:`VlmAd`: Main model class using these utilities
+ - :mod:`.backends`: VLM backend implementations using these utilities
+"""
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
@@ -9,14 +27,56 @@
@dataclass
class Prompt:
- """Prompt."""
+ """Dataclass for storing prompts used in VLM-based anomaly detection.
+
+ This class stores two types of prompts used when querying vision language models:
+
+ - Few-shot prompt: Used to provide context about normal examples
+ - Prediction prompt: Used to query about a specific test image
+
+ Args:
+ few_shot (str): Prompt template for few-shot learning with reference normal
+ images. Used to establish context about what constitutes normal.
+ predict (str): Prompt template for querying about test images. Used to ask
+ the model whether a given image contains anomalies.
+
+ Example:
+ >>> from anomalib.models.image.vlm_ad.utils import Prompt
+ >>> prompt = Prompt( # doctest: +SKIP
+ ... few_shot="Here are some examples of normal items...",
+ ... predict="Is this image normal or does it contain defects?"
+ ... )
+
+ See Also:
+ - :class:`VlmAd`: Main model class using these prompts
+ - :mod:`.backends`: VLM backend implementations using these prompts
+ """
 few_shot: str
 predict: str
class ModelName(Enum):
- """List of supported models."""
+ """Enumeration of supported Vision Language Models (VLMs).
+
+ This enum defines the available VLM models that can be used for anomaly detection:
+
+ - ``LLAMA_OLLAMA``: LLaVA model running via Ollama
+ - ``GPT_4O_MINI``: GPT-4O Mini model
+ - ``VICUNA_7B_HF``: LLaVA v1.6 with Vicuna 7B base from Hugging Face
+ - ``VICUNA_13B_HF``: LLaVA v1.6 with Vicuna 13B base from Hugging Face
+ - ``MISTRAL_7B_HF``: LLaVA v1.6 with Mistral 7B base from Hugging Face
+
+ Example:
+ >>> from anomalib.models.image.vlm_ad.utils import ModelName
+ >>> model_name = ModelName.LLAMA_OLLAMA
+ >>> model_name.value
+ 'llava'
+
+ See Also:
+ - :class:`VlmAd`: Main model class using these model options
+ - :mod:`.backends`: Backend implementations for different models
+ """
 LLAMA_OLLAMA = "llava"
 GPT_4O_MINI = "gpt-4o-mini"
diff --git a/src/anomalib/models/image/winclip/__init__.py b/src/anomalib/models/image/winclip/__init__.py
index 8435a3c1aa..86f2b72691 100644
--- a/src/anomalib/models/image/winclip/__init__.py
+++ b/src/anomalib/models/image/winclip/__init__.py
@@ -1,4 +1,18 @@
-"""WinCLIP Model."""
+"""WinCLIP Model for anomaly detection.
+
+This module implements anomaly detection using the WinCLIP model, which leverages
+CLIP embeddings and a sliding window approach to detect anomalies in images.
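# Because ``ModelName`` above is a value-backed enum, a model string
# round-trips through the enum and backend dispatch becomes a plain lookup.
# A minimal sketch with a local two-member enum (the mapping is illustrative,
# not anomalib's actual dispatch table):
from enum import Enum


class MiniModelName(Enum):
    LLAMA_OLLAMA = "llava"
    GPT_4O_MINI = "gpt-4o-mini"


BACKEND_FOR = {MiniModelName.LLAMA_OLLAMA: "Ollama", MiniModelName.GPT_4O_MINI: "ChatGPT"}

name = MiniModelName("llava")  # ValueError here would signal an unsupported model
print(name, "->", BACKEND_FOR[name])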
+
+Example:
+ >>> from anomalib.models.image import WinClip
+ >>> from anomalib.data import MVTec
+ >>> from anomalib.engine import Engine
+ >>> datamodule = MVTec(root="./datasets/MVTec") # doctest: +SKIP
+ >>> model = WinClip() # doctest: +SKIP
+ >>> engine = Engine() # doctest: +SKIP
+ >>> engine.test(model=model, datamodule=datamodule) # doctest: +SKIP
+
+See Also:
+ - :class:`WinClip`: Main model class for WinCLIP-based anomaly detection
+ - :class:`WinClipModel`: PyTorch implementation of the WinCLIP model
+"""
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
diff --git a/src/anomalib/models/image/winclip/lightning_model.py b/src/anomalib/models/image/winclip/lightning_model.py
index 23a7cf23a1..e078f60e50 100644
--- a/src/anomalib/models/image/winclip/lightning_model.py
+++ b/src/anomalib/models/image/winclip/lightning_model.py
@@ -1,6 +1,28 @@
"""WinCLIP: Zero-/Few-Shot Anomaly Classification and Segmentation.
-Paper https://arxiv.org/abs/2303.14814
+This module implements the WinCLIP model for zero-shot and few-shot anomaly
+detection using CLIP embeddings and a sliding window approach.
+
+The model can perform both anomaly classification and segmentation tasks by
+comparing image regions with normal reference examples through CLIP embeddings.
+
+Example:
+ >>> from anomalib.data import MVTec
+ >>> from anomalib.engine import Engine
+ >>> from anomalib.models.image import WinClip
+
+ >>> datamodule = MVTec(root="./datasets/MVTec") # doctest: +SKIP
+ >>> model = WinClip() # doctest: +SKIP
+
+ >>> Engine().test(model=model, datamodule=datamodule) # doctest: +SKIP
+
+Paper:
+ WinCLIP: Zero-/Few-Shot Anomaly Classification and Segmentation
+ https://arxiv.org/abs/2303.14814
+
+See Also:
+ - :class:`WinClip`: Main model class for WinCLIP-based anomaly detection
+ - :class:`WinClipModel`: PyTorch implementation of the WinCLIP model
"""
# Copyright (C) 2024 Intel Corporation
@@ -34,18 +56,49 @@
class WinClip(AnomalibModule):
"""WinCLIP Lightning model.
+ This model implements the WinCLIP algorithm for zero-/few-shot anomaly detection using CLIP
+ embeddings and a sliding window approach. The model can perform both anomaly classification
+ and segmentation by comparing image regions with normal reference examples.
+
 Args:
- class_name (str, optional): The name of the object class used in the prompt ensemble.
- Defaults to ``None``.
- k_shot (int): The number of reference images for few-shot inference.
- Defaults to ``0``.
- scales (tuple[int], optional): The scales of the sliding windows used for multiscale anomaly detection.
- Defaults to ``(2, 3)``.
- few_shot_source (str | Path, optional): Path to a folder of reference images used for few-shot inference.
- Defaults to ``None``.
- pre_processor (PreProcessor, optional): Pre-processor for the model.
- This is used to pre-process the input data before it is passed to the model.
- Defaults to ``None``.
+ class_name (str | None, optional): Name of the object class used in the prompt
+ ensemble. If not provided, will try to infer from the datamodule or use "object"
+ as default. Defaults to ``None``.
+ k_shot (int, optional): Number of reference images to use for few-shot inference.
+ If 0, uses zero-shot approach. Defaults to ``0``.
+ scales (tuple[int], optional): Scales of sliding windows used for multiscale anomaly
+ detection. Defaults to ``(2, 3)``.
+ few_shot_source (str | Path | None, optional): Path to folder containing reference
+ images for few-shot inference. If not provided, reference images are sampled from
+ training data. Defaults to ``None``.
+ pre_processor (PreProcessor | bool, optional): Pre-processor instance or flag to use + default. Used to pre-process input data before model inference. Defaults to + ``True``. + post_processor (PostProcessor | bool, optional): Post-processor instance or flag to + use default. Used to post-process model predictions. Defaults to ``True``. + evaluator (Evaluator | bool, optional): Evaluator instance or flag to use default. + Used to compute metrics. Defaults to ``True``. + visualizer (Visualizer | bool, optional): Visualizer instance or flag to use default. + Used to create visualizations. Defaults to ``True``. + + Example: + >>> from anomalib.models.image import WinClip + >>> # Zero-shot approach + >>> model = WinClip() # doctest: +SKIP + >>> # Few-shot with 5 reference images + >>> model = WinClip(k_shot=5) # doctest: +SKIP + >>> # Custom class name + >>> model = WinClip(class_name="transistor") # doctest: +SKIP + + Notes: + - The model automatically excludes CLIP backbone parameters from checkpoints to + reduce size + - Input image size is fixed at 240x240 and cannot be modified + - Uses a custom normalization transform specific to CLIP + + See Also: + - :class:`WinClipModel`: PyTorch implementation of the core model + - :class:`OneClassPostProcessor`: Default post-processor used by WinCLIP """ EXCLUDE_FROM_STATE_DICT = frozenset({"model.clip"}) @@ -74,13 +127,15 @@ def __init__( self.few_shot_source = Path(few_shot_source) if few_shot_source else None def _setup(self) -> None: - """Setup WinCLIP. + """Setup WinCLIP model. - - Set the class name used in the prompt ensemble. - - Collect text embeddings for zero-shot inference. - - Collect reference images for few-shot inference. + This method: + - Sets the class name used in the prompt ensemble + - Collects text embeddings for zero-shot inference + - Collects reference images for few-shot inference if ``k_shot > 0`` - We need to pass the device because this hook is called before the model is moved to the device. + Note: + This hook is called before the model is moved to the target device. """ # get class name self.class_name = self._get_class_name() @@ -105,12 +160,15 @@ def _setup(self) -> None: self.model.setup(self.class_name, ref_images) def _get_class_name(self) -> str: - """Set the class name used in the prompt ensemble. + """Get the class name used in the prompt ensemble. - - When a class name is provided by the user, it is used. - - When the user did not provide a class name, the category name from the datamodule is used, if available. - - When the user did not provide a class name and the datamodule does not have a category name, the default - class name "object" is used. + The class name is determined in the following order: + 1. Use class name provided in initialization + 2. Use category name from datamodule if available + 3. Use default value "object" + + Returns: + str: Class name to use in prompts """ if self.class_name is not None: logger.info("Using class name from init args: %s", self.class_name) @@ -124,11 +182,14 @@ class name "object" is used. def collect_reference_images(self, dataloader: DataLoader) -> torch.Tensor: """Collect reference images for few-shot inference. - The reference images are collected by iterating the training dataset until the required number of images are - collected. + Iterates through the training dataset until the required number of reference images + (specified by ``k_shot``) are collected. 
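# The reference-collection loop described above can be reproduced with plain
# PyTorch: concatenate batches until ``k_shot`` images are gathered, then
# truncate. A standalone sketch with random data standing in for the datamodule:
import torch
from torch.utils.data import DataLoader, TensorDataset

k_shot = 5
loader = DataLoader(TensorDataset(torch.rand(32, 3, 240, 240)), batch_size=8)

ref_images = torch.empty(0, 3, 240, 240)  # empty start so cat() always works
for (images,) in loader:
    if ref_images.shape[0] >= k_shot:
        break
    ref_images = torch.cat((ref_images, images))
ref_images = ref_images[:k_shot]
print(ref_images.shape)  # torch.Size([5, 3, 240, 240])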
+ + Args: + dataloader (DataLoader): DataLoader to collect reference images from Returns: - ref_images (Tensor): A tensor containing the reference images. + torch.Tensor: Tensor containing the collected reference images """ ref_images = torch.Tensor() for batch in dataloader: @@ -140,34 +201,56 @@ def collect_reference_images(self, dataloader: DataLoader) -> torch.Tensor: @staticmethod def configure_optimizers() -> None: - """WinCLIP doesn't require optimization, therefore returns no optimizers.""" + """Configure optimizers. + + WinCLIP doesn't require optimization, therefore returns no optimizers. + """ return def validation_step(self, batch: Batch, *args, **kwargs) -> dict: - """Validation Step of WinCLIP.""" + """Validation Step of WinCLIP. + + Args: + batch (Batch): Input batch + *args: Variable length argument list + **kwargs: Arbitrary keyword arguments + + Returns: + dict: Dictionary containing the batch updated with predictions + """ del args, kwargs # These variables are not used. predictions = self.model(batch.image) return batch.update(**predictions._asdict()) @property def trainer_arguments(self) -> dict[str, int | float]: - """Set model-specific trainer arguments.""" + """Get model-specific trainer arguments. + + Returns: + dict[str, int | float]: Empty dictionary as WinCLIP needs no special arguments + """ return {} @property def learning_type(self) -> LearningType: - """The learning type of the model. + """Get the learning type of the model. - WinCLIP is a zero-/few-shot model, depending on the user configuration. Therefore, the learning type is - set to ``LearningType.FEW_SHOT`` when ``k_shot`` is greater than zero and ``LearningType.ZERO_SHOT`` otherwise. + Returns: + LearningType: ``LearningType.FEW_SHOT`` if ``k_shot > 0``, else + ``LearningType.ZERO_SHOT`` """ return LearningType.FEW_SHOT if self.k_shot else LearningType.ZERO_SHOT def state_dict(self, **kwargs) -> OrderedDict[str, Any]: - """Return the state dict of the model. + """Get the state dict of the model. + + Removes parameters of the frozen backbone to reduce checkpoint size. - Before returning the state dict, we remove the parameters of the frozen backbone to reduce the size of the - checkpoint. + Args: + **kwargs: Additional arguments to pass to parent's state_dict + + Returns: + OrderedDict[str, Any]: State dict with backbone parameters removed """ state_dict = super().state_dict(**kwargs) for pattern in self.EXCLUDE_FROM_STATE_DICT: @@ -179,8 +262,16 @@ def state_dict(self, **kwargs) -> OrderedDict[str, Any]: def load_state_dict(self, state_dict: OrderedDict[str, Any], strict: bool = True) -> Any: # noqa: ANN401 """Load the state dict of the model. - Before loading the state dict, we restore the parameters of the frozen backbone to ensure that the model - is loaded correctly. We also restore the auxiliary objects like threshold classes and normalization metrics. + Restores backbone parameters before loading to ensure correct model initialization. + + Args: + state_dict (OrderedDict[str, Any]): State dict to load + strict (bool, optional): Whether to strictly enforce that the keys in + ``state_dict`` match the keys returned by this module's + ``state_dict()`` function. Defaults to ``True``. 
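# The ``EXCLUDE_FROM_STATE_DICT`` mechanism described above boils down to:
# drop the frozen backbone's weights when saving, and merge them back from the
# live module when loading. A toy sketch of that pattern, not anomalib's
# implementation:
from torch import nn


class SlimCheckpoint(nn.Module):
    EXCLUDE_PREFIX = "backbone."

    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Linear(8, 8)  # frozen, ships with the library
        self.head = nn.Linear(8, 2)  # trained, worth checkpointing

    def state_dict(self, **kwargs):
        state = super().state_dict(**kwargs)
        return {k: v for k, v in state.items() if not k.startswith(self.EXCLUDE_PREFIX)}

    def load_state_dict(self, state_dict, strict: bool = True):
        full = super().state_dict()  # current (frozen) backbone weights
        full.update(state_dict)  # overlay the checkpointed head
        return super().load_state_dict(full, strict=strict)


model = SlimCheckpoint()
checkpoint = model.state_dict()
assert all(not key.startswith("backbone.") for key in checkpoint)
model.load_state_dict(checkpoint)  # backbone restored from the live module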
+
+ Returns:
+ Any: Return value from parent's load_state_dict
+ """
 # restore the parameters of the excluded modules, if any
 full_dict = super().state_dict()
@@ -191,7 +282,15 @@ def load_state_dict(self, state_dict: OrderedDict[str, Any], strict: bool = True
 @classmethod
 def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> PreProcessor:
- """Configure the default pre-processor used by the model."""
+ """Configure the default pre-processor used by the model.
+
+ Args:
+ image_size (tuple[int, int] | None, optional): Not used as WinCLIP has fixed
+ input size. Defaults to ``None``.
+
+ Returns:
+ PreProcessor: Configured pre-processor with CLIP-specific transforms
+ """
 if image_size is not None:
 logger.warning("Image size is not used in WinCLIP. The input image size is determined by the model.")
@@ -203,5 +302,9 @@ def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> P
 @staticmethod
 def configure_post_processor() -> OneClassPostProcessor:
- """Return the default post-processor for WinCLIP."""
+ """Configure the default post-processor for WinCLIP.
+
+ Returns:
+ OneClassPostProcessor: Default post-processor instance
+ """
 return OneClassPostProcessor()
diff --git a/src/anomalib/models/image/winclip/prompting.py b/src/anomalib/models/image/winclip/prompting.py
index f33a63d1f4..2c3d661b28 100644
--- a/src/anomalib/models/image/winclip/prompting.py
+++ b/src/anomalib/models/image/winclip/prompting.py
@@ -1,4 +1,24 @@
-"""Compositional prompt ensemble for WinCLIP."""
+"""Compositional prompt ensemble for WinCLIP.
+
+This module provides prompt templates and utilities for generating prompt ensembles
+used by the WinCLIP model. The prompts are used to query CLIP about normal and
+anomalous states of objects.
+
+The module contains:
+ - Lists of normal and anomalous state descriptors
+ - Templates for constructing image description prompts
+ - Functions to generate prompt ensembles by combining states and templates
+
+Example:
+ >>> from anomalib.models.image.winclip.prompting import create_prompt_ensemble
+ >>> normal, anomalous = create_prompt_ensemble("transistor") # doctest: +SKIP
+ >>> print(normal[0]) # doctest: +SKIP
+ 'a cropped photo of the transistor.'
+
+See Also:
+ - :class:`WinClip`: Main model class using these prompts
+ - :class:`WinClipModel`: PyTorch model implementation
+"""
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
@@ -46,22 +66,39 @@
def create_prompt_ensemble(class_name: str = "object") -> tuple[list[str], list[str]]:
- """Create prompt ensemble for WinCLIP.
+ """Create prompt ensemble for WinCLIP model.
- All combinations of states and templates are generated for both normal and anomalous prompts.
+ This function generates a comprehensive set of text prompts used by the WinCLIP model for
+ zero-shot anomaly detection. It creates two sets of prompts:
+
+ 1. Normal prompts describing non-anomalous objects
+ 2. Anomalous prompts describing objects with defects
+
+ The prompts are generated by combining predefined states (e.g., "flawless", "damaged")
+ with templates (e.g., "a photo of a {}") for the given object class.
 Args:
- class_name (str): Name of the object.
+ class_name (str, optional): Name of the object class to use in the prompts.
+ Defaults to ``"object"``.
 Returns:
- tuple[list[str], list[str]]: Tuple containing the normal and anomalous prompts.
+ tuple[list[str], list[str]]: A tuple containing: + - List of normal prompts describing non-anomalous objects + - List of anomalous prompts describing defective objects + + Example: + Generate prompts for the "bottle" class: - Examples: >>> normal_prompts, anomalous_prompts = create_prompt_ensemble("bottle") - >>> normal_prompts[:2] - ['a cropped photo of the bottle.', 'a close-up photo of a bottle.'] - >>> anomalous_prompts[:2] - ['a cropped photo of the damaged bottle.', 'a close-up photo of a damaged bottle.'] + >>> print(normal_prompts[0]) + 'a cropped photo of the bottle.' + >>> print(anomalous_prompts[0]) + 'a cropped photo of the damaged bottle.' + + See Also: + - :data:`NORMAL_STATES`: Predefined states for normal objects + - :data:`ANOMALOUS_STATES`: Predefined states for anomalous objects + - :data:`TEMPLATES`: Predefined templates for prompt generation """ normal_states = [state.format(class_name) for state in NORMAL_STATES] normal_ensemble = [template.format(state) for state in normal_states for template in TEMPLATES] diff --git a/src/anomalib/models/image/winclip/torch_model.py b/src/anomalib/models/image/winclip/torch_model.py index 8d2bfc69f9..9847030074 100644 --- a/src/anomalib/models/image/winclip/torch_model.py +++ b/src/anomalib/models/image/winclip/torch_model.py @@ -1,4 +1,29 @@ -"""PyTorch model for the WinCLIP implementation.""" +"""PyTorch model implementation of WinCLIP for zero-/few-shot anomaly detection. + +This module provides the core PyTorch model implementation of WinCLIP, which uses +CLIP embeddings and a sliding window approach to detect anomalies in images. + +The model can operate in both zero-shot and few-shot modes: +- Zero-shot: No reference images needed +- Few-shot: Uses ``k`` reference normal images for better context + +Example: + >>> from anomalib.models.image.winclip.torch_model import WinClipModel + >>> model = WinClipModel() # doctest: +SKIP + >>> # Zero-shot inference + >>> prediction = model(image) # doctest: +SKIP + >>> # Few-shot with reference images + >>> model = WinClipModel(reference_images=normal_images) # doctest: +SKIP + +Paper: + WinCLIP: Zero-/Few-Shot Anomaly Classification and Segmentation + https://arxiv.org/abs/2303.14814 + +See Also: + - :class:`WinClip`: Lightning model wrapper + - :mod:`.prompting`: Prompt ensemble generation + - :mod:`.utils`: Utility functions for scoring and aggregation +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -27,25 +52,42 @@ class WinClipModel(DynamicBufferMixin, BufferListMixin, nn.Module): """PyTorch module that implements the WinClip model for image anomaly detection. + The model uses CLIP embeddings and a sliding window approach to detect anomalies in + images. It can operate in both zero-shot and few-shot modes. + Args: - class_name (str, optional): The name of the object class used in the prompt ensemble. - Defaults to ``None``. - reference_images (torch.Tensor, optional): Tensor of shape ``(K, C, H, W)`` containing the reference images. - Defaults to ``None``. - scales (tuple[int], optional): The scales of the sliding windows used for multi-scale anomaly detection. - Defaults to ``(2, 3)``. - apply_transform (bool, optional): Whether to apply the default CLIP transform to the input images. - Defaults to ``False``. + class_name (str | None, optional): Name of the object class used in prompt + ensemble. Defaults to ``None``. + reference_images (torch.Tensor | None, optional): Reference images of shape + ``(K, C, H, W)``. Defaults to ``None``. 
+ scales (tuple[int], optional): Scales of sliding windows for multi-scale
+ detection. Defaults to ``(2, 3)``.
+ apply_transform (bool, optional): Whether to apply default CLIP transform to
+ inputs. Defaults to ``False``.
 Attributes:
- clip (CLIP): The CLIP model used for image and text encoding.
- grid_size (tuple[int]): The size of the feature map grid.
- k_shot (int): The number of reference images used for few-shot anomaly detection.
- scales (tuple[int]): The scales of the sliding windows used for multi-scale anomaly detection.
- masks (list[torch.Tensor] | None): The masks representing the sliding window locations.
- _text_embeddings (torch.Tensor | None): The text embeddings for the compositional prompt ensemble.
- _visual_embeddings (list[torch.Tensor] | None): The multi-scale embeddings for the reference images.
- _patch_embeddings (torch.Tensor | None): The patch embeddings for the reference images.
+ clip (CLIP): CLIP model for image and text encoding.
+ grid_size (tuple[int]): Size of feature map grid.
+ k_shot (int): Number of reference images for few-shot detection.
+ scales (tuple[int]): Scales of sliding windows.
+ masks (list[torch.Tensor] | None): Masks for sliding window locations.
+ _text_embeddings (torch.Tensor | None): Text embeddings for prompt ensemble.
+ _visual_embeddings (list[torch.Tensor] | None): Multi-scale reference embeddings.
+ _patch_embeddings (torch.Tensor | None): Patch embeddings for reference images.
+
+ Example:
+ >>> from anomalib.models.image.winclip.torch_model import WinClipModel
+ >>> # Zero-shot mode
+ >>> model = WinClipModel(class_name="transistor") # doctest: +SKIP
+ >>> image = torch.rand(1, 3, 240, 240) # doctest: +SKIP
+ >>> prediction = model(image) # doctest: +SKIP
+ >>>
+ >>> # Few-shot mode with reference images
+ >>> ref_images = torch.rand(3, 3, 240, 240) # doctest: +SKIP
+ >>> model = WinClipModel( # doctest: +SKIP
+ ... class_name="transistor",
+ ... reference_images=ref_images
+ ... )
 """
 def __init__(
@@ -80,46 +122,50 @@ def __init__(
 self.setup(class_name, reference_images)
 def setup(self, class_name: str | None = None, reference_images: torch.Tensor | None = None) -> None:
- """Setup WinCLIP.
-
- WinCLIP's setup stage consists of collecting the text and visual embeddings used during inference. The
- following steps are performed, depending on the arguments passed to the model:
- - Collect text embeddings for zero-shot inference.
- - Collect reference images for few-shot inference.
- The k_shot attribute is updated based on the number of reference images.
-
- The setup method is called internally by the constructor. However, it can also be called manually to update the
- text and visual embeddings after the model has been initialized.
+ """Setup WinCLIP model with class name and/or reference images.
+
+ The setup stage collects text and visual embeddings used during inference:
+ - Text embeddings for zero-shot inference if ``class_name`` provided
+ - Visual embeddings for few-shot inference if ``reference_images`` provided
+ The ``k_shot`` attribute is updated based on number of reference images.
+
+ This method is called by constructor but can also be called manually to update
+ embeddings after initialization.
 Args:
- class_name (str): The name of the object class used in the prompt ensemble.
- reference_images (torch.Tensor): Tensor of shape ``(batch_size, C, H, W)`` containing the reference images.
+ class_name (str | None, optional): Name of object class for prompt ensemble.
+ Defaults to ``None``.
+ reference_images (torch.Tensor | None, optional): Reference images of shape + ``(batch_size, C, H, W)``. Defaults to ``None``. Examples: - >>> model = WinClipModel() - >>> model.setup("transistor") - >>> model.text_embeddings.shape + >>> model = WinClipModel() # doctest: +SKIP + >>> model.setup("transistor") # doctest: +SKIP + >>> model.text_embeddings.shape # doctest: +SKIP torch.Size([2, 640]) - >>> ref_images = torch.rand(2, 3, 240, 240) - >>> model = WinClipModel() - >>> model.setup("transistor", ref_images) - >>> model.k_shot + >>> ref_images = torch.rand(2, 3, 240, 240) # doctest: +SKIP + >>> model = WinClipModel() # doctest: +SKIP + >>> model.setup("transistor", ref_images) # doctest: +SKIP + >>> model.k_shot # doctest: +SKIP 2 - >>> model.visual_embeddings[0].shape + >>> model.visual_embeddings[0].shape # doctest: +SKIP torch.Size([2, 196, 640]) - >>> model = WinClipModel("transistor") - >>> model.k_shot + >>> model = WinClipModel("transistor") # doctest: +SKIP + >>> model.k_shot # doctest: +SKIP 0 - >>> model.setup(reference_images=ref_images) - >>> model.k_shot + >>> model.setup(reference_images=ref_images) # doctest: +SKIP + >>> model.k_shot # doctest: +SKIP 2 - >>> model = WinClipModel(class_name="transistor", reference_images=ref_images) - >>> model.text_embeddings.shape + >>> model = WinClipModel( # doctest: +SKIP + ... class_name="transistor", + ... reference_images=ref_images + ... ) + >>> model.text_embeddings.shape # doctest: +SKIP torch.Size([2, 640]) - >>> model.visual_embeddings[0].shape + >>> model.visual_embeddings[0].shape # doctest: +SKIP torch.Size([2, 196, 640]) """ # update class name and text embeddings @@ -133,29 +179,35 @@ def setup(self, class_name: str | None = None, reference_images: torch.Tensor | self._collect_visual_embeddings(self.reference_images) def encode_image(self, batch: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]: - """Encode the batch of images to obtain image embeddings, window embeddings, and patch embeddings. + """Encode batch of images to get image, window and patch embeddings. - The image embeddings and patch embeddings are obtained by passing the batch of images through the model. The - window embeddings are obtained by masking the feature map and passing it through the transformer. A forward hook - is used to retrieve the intermediate feature map and share computation between the image and window embeddings. + The image and patch embeddings are obtained by passing images through the model. + Window embeddings are obtained by masking feature map and passing through + transformer. A forward hook retrieves intermediate feature map to share + computation. Args: - batch (torch.Tensor): Batch of input images of shape ``(N, C, H, W)``. + batch (torch.Tensor): Input images of shape ``(N, C, H, W)``. Returns: - Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]: A tuple containing the image embeddings, - window embeddings, and patch embeddings respectively. + tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]: Tuple containing: + - Image embeddings of shape ``(N, D)`` + - Window embeddings list, each of shape ``(N, W, D)`` + - Patch embeddings of shape ``(N, P, D)`` + where ``D`` is embedding dimension, ``W`` is number of windows, + and ``P`` is number of patches. 
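# The shared-computation trick described above: a forward hook stashes an
# intermediate activation during a single pass so it can be reused afterwards.
# A toy sketch, not the WinCLIP code:
import torch
from torch import nn

features: dict[str, torch.Tensor] = {}


def hook(_module: nn.Module, _inputs: tuple, output: torch.Tensor) -> None:
    features["intermediate"] = output  # keep the activation for later reuse


net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
handle = net[1].register_forward_hook(hook)

out = net(torch.randn(2, 16))  # one pass fills both ``out`` and ``features``
print(out.shape, features["intermediate"].shape)
handle.remove()  # detach the hook once it is no longer needed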
 Examples:
- >>> model = WinClipModel()
- >>> model.prepare_masks()
- >>> batch = torch.rand((1, 3, 240, 240))
- >>> image_embeddings, window_embeddings, patch_embeddings = model.encode_image(batch)
- >>> image_embeddings.shape
+ >>> model = WinClipModel() # doctest: +SKIP
+ >>> model.prepare_masks() # doctest: +SKIP
+ >>> batch = torch.rand((1, 3, 240, 240)) # doctest: +SKIP
+ >>> outputs = model.encode_image(batch) # doctest: +SKIP
+ >>> image_embeddings, window_embeddings, patch_embeddings = outputs # doctest: +SKIP
+ >>> image_embeddings.shape # doctest: +SKIP
 torch.Size([1, 640])
- >>> [embedding.shape for embedding in window_embeddings]
+ >>> [emb.shape for emb in window_embeddings] # doctest: +SKIP
 [torch.Size([1, 196, 640]), torch.Size([1, 169, 640])]
- >>> patch_embeddings.shape
+ >>> patch_embeddings.shape # doctest: +SKIP
 torch.Size([1, 225, 896])
 """
 # apply transform if needed
@@ -189,14 +241,16 @@ def hook(_model: Identity, inputs: tuple[torch.Tensor,], _outputs: torch.Tensor)
 )
 def _get_window_embeddings(self, feature_map: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
- """Computes the embeddings for each window in the feature map using the given masks.
+ """Compute embeddings for each window in feature map using given masks.
 Args:
- feature_map (torch.Tensor): The input feature map of shape ``(n_batches, n_patches, dimensionality)``.
- masks (torch.Tensor): Masks of shape ``(kernel_size, n_masks)`` representing the sliding window locations.
+ feature_map (torch.Tensor): Input features of shape
+ ``(n_batches, n_patches, dimensionality)``.
+ masks (torch.Tensor): Window location masks of shape
+ ``(kernel_size, n_masks)``.
 Returns:
- torch.Tensor: The embeddings for each sliding window location.
+ torch.Tensor: Embeddings for each sliding window location.
 """
 batch_size = feature_map.shape[0]
 n_masks = masks.shape[1]
@@ -225,13 +279,16 @@ def _get_window_embeddings(self, feature_map: torch.Tensor, masks: torch.Tensor)
 @torch.no_grad
 def forward(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor] | InferenceBatch:
- """Forward-pass through the model to obtain image and pixel scores.
+ """Forward pass to get image and pixel anomaly scores.
 Args:
- batch (torch.Tensor): Batch of input images of shape ``(batch_size, C, H, W)``.
+ batch (torch.Tensor): Input images of shape ``(batch_size, C, H, W)``.
 Returns:
- Tuple[torch.Tensor, torch.Tensor]: Tuple containing the image scores and pixel scores.
+ tuple[torch.Tensor, torch.Tensor] | InferenceBatch: Either tuple containing:
+ - Image scores of shape ``(batch_size,)``
+ - Pixel scores of shape ``(batch_size, H, W)``
+ Or ``InferenceBatch`` with same information.
 """
 image_embeddings, window_embeddings, patch_embeddings = self.encode_image(batch)
@@ -258,19 +315,20 @@ def _compute_zero_shot_scores(
 image_scores: torch.Tensor,
 window_embeddings: list[torch.Tensor],
 ) -> torch.Tensor:
- """Compute the multi-scale anomaly score maps based on the text embeddings.
- Each window embedding is compared to the text embeddings to obtain a similarity score for each window. Harmonic
- averaging is then used to aggregate the scores for each window into a single score map for each scale. Finally,
- the score maps are combined into a single multi-scale score map by aggregating across scales.
+ """Compute multi-scale anomaly score maps using text embeddings.
+
+ Each window embedding is compared to text embeddings for similarity scores.
+ Harmonic averaging aggregates window scores into score maps per scale.
+ Score maps are combined into single multi-scale map by cross-scale + aggregation. Args: - image_scores (torch.Tensor): Tensor of shape ``(batch_size)`` representing the full image scores. - window_embeddings (list[torch.Tensor]): List of tensors of shape ``(batch_size, n_windows, n_features)`` - representing the embeddings for each sliding window location. + image_scores (torch.Tensor): Full image scores of shape ``(batch_size)``. + window_embeddings (list[torch.Tensor]): Window embeddings list, each of + shape ``(batch_size, n_windows, n_features)``. Returns: - torch.Tensor: Tensor of shape ``(batch_size, H, W)`` representing the 0-shot scores for each patch location. + torch.Tensor: Zero-shot scores of shape ``(batch_size, H, W)``. """ # image scores are added to represent the full image scale multi_scale_scores = [image_scores.view(-1, 1, 1).repeat(1, self.grid_size[0], self.grid_size[1])] @@ -286,21 +344,21 @@ def _compute_few_shot_scores( patch_embeddings: torch.Tensor, window_embeddings: list[torch.Tensor], ) -> torch.Tensor: - """Compute the multi-scale anomaly score maps based on the reference image embeddings. + """Compute multi-scale anomaly score maps using reference embeddings. - Visual association scores are computed between the extracted embeddings and the reference image embeddings for - each scale. The window-level scores are additionally aggregated into a single score map for each scale using - harmonic averaging. The final score maps are obtained by averaging across scales. + Visual association scores are computed between extracted embeddings and + reference embeddings at each scale. Window scores are aggregated into score + maps per scale using harmonic averaging. Final maps obtained by averaging + across scales. Args: patch_embeddings (torch.Tensor): Full-scale patch embeddings of shape ``(batch_size, n_patches, n_features)``. - window_embeddings (list[torch.Tensor]): List of tensors of shape ``(batch_size, n_windows, n_features)`` - representing the embeddings for each sliding window location. + window_embeddings (list[torch.Tensor]): Window embeddings list, each of + shape ``(batch_size, n_windows, n_features)``. Returns: - torch.Tensor: Tensor of shape ``(batch_size, H, W)`` representing the few-shot scores for each patch - location. + torch.Tensor: Few-shot scores of shape ``(batch_size, H, W)``. """ multi_scale_scores = [ visual_association_score(patch_embeddings, self.patch_embeddings).reshape((-1, *self.grid_size)), @@ -318,15 +376,14 @@ def _compute_few_shot_scores( @torch.no_grad def _collect_text_embeddings(self, class_name: str) -> None: - """Collect text embeddings for the object class using a compositional prompt ensemble. + """Collect text embeddings using compositional prompt ensemble. - First, an ensemble of normal and anomalous prompts is created based on the name of the object class. The - prompt ensembles are then tokenized and encoded to obtain prompt embeddings. The prompt embeddings are - averaged to obtain a single text embedding for the object class. These final text embeddings are stored in - the model to be used during inference. + Creates ensemble of normal and anomalous prompts based on class name. + Prompts are tokenized and encoded to get embeddings. Embeddings are averaged + per class and stored for inference. Args: - class_name (str): The name of the object class used in the prompt ensemble. + class_name (str): Object class name for prompt ensemble. 
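# Averaging a prompt ensemble into a single per-class text embedding, as
# described above. Random vectors stand in for the CLIP text encoder, and the
# final re-normalisation is an assumption of this sketch:
import torch
from torch.nn import functional as F  # noqa: N812

prompt_embeddings = torch.randn(245, 640)  # one row per prompt in the ensemble

class_embedding = prompt_embeddings.mean(dim=0)  # average over the ensemble
class_embedding = F.normalize(class_embedding, dim=0)  # back to unit length
print(class_embedding.shape)  # torch.Size([640])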
""" # get the device, this is to ensure that we move the text embeddings to the same device as the model device = next(self.parameters()).device @@ -347,31 +404,34 @@ def _collect_text_embeddings(self, class_name: str) -> None: @torch.no_grad def _collect_visual_embeddings(self, images: torch.Tensor) -> None: - """Collect visual embeddings based on a set of normal reference images. + """Collect visual embeddings from normal reference images. Args: - images (torch.Tensor): Tensor of shape ``(K, C, H, W)`` containing the reference images. + images (torch.Tensor): Reference images of shape ``(K, C, H, W)``. """ _, self._visual_embeddings, self._patch_embeddings = self.encode_image(images) def _generate_masks(self) -> list[torch.Tensor]: - """Prepare a set of masks that operate as multi-scale sliding windows. + """Prepare multi-scale sliding window masks. - For each of the scales, a set of masks is created that select patches from the feature map. Each mask represents - a sliding window location in the pixel domain. The masks are stored in the model to be used during inference. + Creates masks for each scale that select patches from feature map. Each mask + represents a sliding window location. Masks are stored for inference. Returns: - list[torch.Tensor]: A list of tensors of shape ``(n_patches_per_mask, n_masks)`` representing the sliding - window locations for each scale. + list[torch.Tensor]: List of masks, each of shape + ``(n_patches_per_mask, n_masks)``. """ return [make_masks(self.grid_size, scale, 1) for scale in self.scales] @property def transform(self) -> Compose: - """The transform used by the model. + """Get model's transform pipeline. + + Retrieves transforms from CLIP backbone and prepends ``ToPILImage`` transform + since original transforms expect PIL images. - To obtain the transforms, we retrieve the transforms from the clip backbone. Since the original transforms are - intended for PIL images, we prepend a ToPILImage transform to the list of transforms. + Returns: + Compose: Transform pipeline for preprocessing images. """ transforms = copy(self._transform.transforms) transforms.insert(0, ToPILImage()) @@ -379,7 +439,14 @@ def transform(self) -> Compose: @property def text_embeddings(self) -> torch.Tensor: - """The text embeddings used by the model.""" + """Get model's text embeddings. + + Returns: + torch.Tensor: Text embeddings used for zero-shot inference. + + Raises: + RuntimeError: If text embeddings not collected via ``setup``. + """ if self._text_embeddings.numel() == 0: msg = "Text embeddings have not been collected. Pass a class name to the model using ``setup``." raise RuntimeError(msg) @@ -387,7 +454,14 @@ def text_embeddings(self) -> torch.Tensor: @property def visual_embeddings(self) -> list[torch.Tensor]: - """The visual embeddings used by the model.""" + """Get model's visual embeddings. + + Returns: + list[torch.Tensor]: Visual embeddings used for few-shot inference. + + Raises: + RuntimeError: If visual embeddings not collected via ``setup``. + """ if self._visual_embeddings[0].numel() == 0: msg = "Visual embeddings have not been collected. Pass some reference images to the model using ``setup``." raise RuntimeError(msg) @@ -395,7 +469,14 @@ def visual_embeddings(self) -> list[torch.Tensor]: @property def patch_embeddings(self) -> torch.Tensor: - """The patch embeddings used by the model.""" + """Get model's patch embeddings. + + Returns: + torch.Tensor: Patch embeddings used for few-shot inference. 
+ + Raises: + RuntimeError: If patch embeddings not collected via ``setup``. + """ if self._patch_embeddings.numel() == 0: msg = "Patch embeddings have not been collected. Pass some reference images to the model using ``setup``." raise RuntimeError(msg) diff --git a/src/anomalib/models/image/winclip/utils.py b/src/anomalib/models/image/winclip/utils.py index 620d04d867..b48928c170 100644 --- a/src/anomalib/models/image/winclip/utils.py +++ b/src/anomalib/models/image/winclip/utils.py @@ -1,4 +1,24 @@ -"""WinCLIP utils.""" +"""Utility functions for WinCLIP model. + +This module provides utility functions used by the WinCLIP model for anomaly detection: + +- :func:`cosine_similarity`: Compute pairwise cosine similarity between tensors +- :func:`class_scores`: Calculate anomaly scores from CLIP embeddings +- :func:`harmonic_aggregation`: Aggregate scores using harmonic mean +- :func:`make_masks`: Generate sliding window masks +- :func:`visual_association_score`: Compute visual association scores + +Example: + >>> import torch + >>> from anomalib.models.image.winclip.utils import cosine_similarity + >>> input1 = torch.randn(100, 128) # doctest: +SKIP + >>> input2 = torch.randn(200, 128) # doctest: +SKIP + >>> similarity = cosine_similarity(input1, input2) # doctest: +SKIP + +See Also: + - :class:`WinClip`: Main model class using these utilities + - :class:`WinClipModel`: PyTorch model implementation +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -10,31 +30,56 @@ def cosine_similarity(input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: """Compute pairwise cosine similarity matrix between two tensors. - Computes the cosine similarity between all pairs of vectors in the two tensors. + Computes the cosine similarity between all pairs of vectors in the two input tensors. + The inputs can be either 2D or 3D tensors. For 2D inputs, an implicit batch + dimension of 1 is added. Args: - input1 (torch.Tensor): Input tensor of shape ``(N, D)`` or ``(B, N, D)``. - input2 (torch.Tensor): Input tensor of shape ``(M, D)`` or ``(B, M, D)``. + input1 (torch.Tensor): First input tensor of shape ``(N, D)`` or ``(B, N, D)``, + where: + - ``B`` is the optional batch dimension + - ``N`` is the number of vectors in first input + - ``D`` is the dimension of each vector + input2 (torch.Tensor): Second input tensor of shape ``(M, D)`` or ``(B, M, D)``, + where: + - ``B`` is the optional batch dimension + - ``M`` is the number of vectors in second input + - ``D`` is the dimension of each vector (must match input1) Returns: - torch.Tensor: Cosine similarity matrix of shape ``(N, M)`` or ``(B, N, M)``. + torch.Tensor: Cosine similarity matrix of shape ``(N, M)`` for 2D inputs or + ``(B, N, M)`` for 3D inputs, where each element ``[i,j]`` is the cosine + similarity between vector ``i`` from ``input1`` and vector ``j`` from + ``input2``. 
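# The semantics documented above reduce to "normalise, then matrix-multiply".
# A compact reference computation (not anomalib's implementation), checked
# against torch's built-in cosine similarity for one pair:
import torch
from torch.nn import functional as F  # noqa: N812

input1 = torch.randn(100, 128)
input2 = torch.randn(200, 128)

similarity = F.normalize(input1, dim=-1) @ F.normalize(input2, dim=-1).T
print(similarity.shape)  # torch.Size([100, 200])

torch.testing.assert_close(
    similarity[0, 0],
    F.cosine_similarity(input1[0], input2[0], dim=0),
)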
Examples: + 2D inputs (single batch): + >>> input1 = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) >>> input2 = torch.tensor([[0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]) >>> cosine_similarity(input1, input2) tensor([[[0.0000, 0.7071], [1.0000, 0.7071]]]) - >>> input1 = torch.randn(100, 128) - >>> input2 = torch.randn(200, 128) - >>> cosine_similarity(input1, input2).shape + Different sized inputs: + + >>> input1 = torch.randn(100, 128) # 100 vectors of dimension 128 + >>> input2 = torch.randn(200, 128) # 200 vectors of dimension 128 + >>> similarity = cosine_similarity(input1, input2) + >>> similarity.shape torch.Size([100, 200]) - >>> input1 = torch.randn(10, 100, 128) - >>> input2 = torch.randn(10, 200, 128) - >>> cosine_similarity(input1, input2).shape + 3D inputs (batched): + + >>> input1 = torch.randn(10, 100, 128) # 10 batches of 100 vectors + >>> input2 = torch.randn(10, 200, 128) # 10 batches of 200 vectors + >>> similarity = cosine_similarity(input1, input2) + >>> similarity.shape torch.Size([10, 100, 200]) + + Note: + The function automatically handles both 2D and 3D inputs by adding a batch + dimension to 2D inputs. The vector dimension ``D`` must match between inputs. """ ndim = input1.ndim input1 = input1.unsqueeze(0) if input1.ndim == 2 else input1 @@ -54,42 +99,65 @@ def class_scores( temperature: float = 1.0, target_class: int | None = None, ) -> torch.Tensor: - """Compute class scores between a set of N image embeddings and a set of M text embeddings. + """Compute class scores between image embeddings and text embeddings. + + Computes similarity scores between image and text embeddings by first calculating + cosine similarity and then applying temperature scaling and softmax. This follows + Equation (1) in the WinCLIP paper. - Each text embedding represents the embedding of a prompt for a specific class. By computing the cosine similarity - between each image embedding and each text embedding, we obtain a similarity matrix of shape (N, M). This matrix is - then used to compute the confidence scores for each class by scaling by a temperature parameter and applying the - softmax function (Equation (1) in the WinCLIP paper). + Each text embedding represents a prompt for a specific class. The similarity matrix + is used to compute confidence scores for each class. Args: - image_embeddings (torch.Tensor): Image embedding matrix of shape ``(N, D)`` or ``(B, N, D)``. - text_embeddings (torch.Tensor): Text embedding matrix of shape ``(M, D)`` or ``(B, M, D)``. - temperature (float): Temperature hyperparameter. - target_class (int): Index of the target class. If None, the scores for all classes are returned. + image_embeddings (torch.Tensor): Image embeddings with shape ``(N, D)`` or + ``(B, N, D)``, where: + - ``B`` is optional batch dimension + - ``N`` is number of image embeddings + - ``D`` is embedding dimension + text_embeddings (torch.Tensor): Text embeddings with shape ``(M, D)`` or + ``(B, M, D)``, where: + - ``B`` is optional batch dimension + - ``M`` is number of text embeddings + - ``D`` is embedding dimension (must match image embeddings) + temperature (float, optional): Temperature scaling parameter. Higher values + make distribution more uniform, lower values make it more peaked. + Defaults to ``1.0``. + target_class (int | None, optional): Index of target class. If provided, + returns scores only for that class. Defaults to ``None``. Returns: - torch.Tensor: Similarity score of shape ``(N, M)`` or ``(B, N, M)``. + torch.Tensor: Class similarity scores. 
Shape depends on inputs and + ``target_class``: + - If no target class: ``(N, M)`` or ``(B, N, M)`` + - If target class specified: ``(N,)`` or ``(B, N)`` Examples: + Basic usage with 2D inputs: + >>> image_embeddings = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) >>> text_embeddings = torch.tensor([[0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]) >>> class_scores(image_embeddings, text_embeddings) tensor([[0.3302, 0.6698], [0.5727, 0.4273]]) - >>> image_embeddings = torch.randn(100, 128) - >>> text_embeddings = torch.randn(200, 128) + With different sized inputs: + + >>> image_embeddings = torch.randn(100, 128) # 100 vectors + >>> text_embeddings = torch.randn(200, 128) # 200 class prompts >>> class_scores(image_embeddings, text_embeddings).shape torch.Size([100, 200]) - >>> image_embeddings = torch.randn(10, 100, 128) - >>> text_embeddings = torch.randn(10, 200, 128) + With batched 3D inputs: + + >>> image_embeddings = torch.randn(10, 100, 128) # 10 batches + >>> text_embeddings = torch.randn(10, 200, 128) # 10 batches >>> class_scores(image_embeddings, text_embeddings).shape torch.Size([10, 100, 200]) - >>> image_embeddings = torch.randn(10, 100, 128) - >>> text_embeddings = torch.randn(10, 200, 128) - >>> class_scores(image_embeddings, text_embeddings, target_class=0).shape + With target class specified: + + >>> scores = class_scores(image_embeddings, text_embeddings, target_class=0) + >>> scores.shape torch.Size([10, 100]) """ scores = (cosine_similarity(image_embeddings, text_embeddings) / temperature).softmax(dim=-1) @@ -101,31 +169,37 @@ def class_scores( def harmonic_aggregation(window_scores: torch.Tensor, output_size: tuple, masks: torch.Tensor) -> torch.Tensor: """Perform harmonic aggregation on window scores. - Computes a single score for each patch location by aggregating the scores of all windows that cover the patch. - Scores are aggregated using the harmonic mean. + Computes a single score for each patch location by aggregating the scores of all + windows that cover the patch. Scores are aggregated using the harmonic mean. Args: - window_scores (torch.Tensor): Tensor of shape ``(batch_size, n_masks)`` representing the scores for each sliding - window location. - output_size (tuple): Tuple of integers representing the output size ``(H, W)``. - masks (torch.Tensor): Tensor of shape ``(n_patches_per_mask, n_masks)`` representing the masks. Each mask is - set of indices indicating which patches are covered by the mask. + window_scores (torch.Tensor): Scores for each sliding window location. + Shape: ``(batch_size, n_masks)``. + output_size (tuple): Output dimensions ``(H, W)``. + masks (torch.Tensor): Binary masks indicating which patches are covered by each + window. Shape: ``(n_patches_per_mask, n_masks)``. Returns: - torch.Tensor: Tensor of shape ``(batch_size, H, W)```` representing the aggregated scores. + torch.Tensor: Aggregated scores. Shape: ``(batch_size, H, W)``. + + Example: + Example for a 3x3 patch grid with 4 sliding windows of size 2x2: - Examples: - >>> # example for a 3x3 patch grid with 4 sliding windows of size 2x2 >>> window_scores = torch.tensor([[1.0, 0.75, 0.5, 0.25]]) >>> output_size = (3, 3) >>> masks = torch.Tensor([[0, 1, 3, 4], - [1, 2, 4, 5], - [3, 4, 6, 7], - [4, 5, 7, 8]]) + ... [1, 2, 4, 5], + ... [3, 4, 6, 7], + ... 
[4, 5, 7, 8]]) >>> harmonic_aggregation(window_scores, output_size, masks) tensor([[[1.0000, 0.8571, 0.7500], [0.6667, 0.4800, 0.3750], [0.5000, 0.3333, 0.2500]]]) + + Note: + The harmonic mean is used instead of arithmetic mean as it is more sensitive to + low scores, making it better suited for anomaly detection where we want to + emphasize potential defects. """ batch_size = window_scores.shape[0] height, width = output_size @@ -170,37 +244,57 @@ def visual_association_score(embeddings: torch.Tensor, reference_embeddings: tor def make_masks(grid_size: tuple[int, int], kernel_size: int, stride: int = 1) -> torch.Tensor: - """Make a set of masks to select patches from a feature map in a sliding window fashion. + """Make masks to select patches from a feature map using sliding windows. + + Creates a set of masks for selecting patches from a feature map in a sliding window + fashion. Each column in the returned tensor represents one mask. A mask consists of + indices indicating which patches are covered by that sliding window position. - Each column in the returned tensor represents a mask. Each mask is a set of indices indicating which patches are - covered by the mask. The number of masks is equal to the number of sliding windows that fit in the feature map. + The number of masks equals the number of possible sliding window positions that fit + in the feature map given the kernel size and stride. Args: - grid_size (tuple[int, int]): The shape of the feature map. - kernel_size (int): The size of the kernel in number of patches. - stride (int): The size of the stride in number of patches. + grid_size (tuple[int, int]): Height and width of the feature map grid as + ``(H, W)``. + kernel_size (int): Size of the sliding window kernel in number of patches. + stride (int, optional): Stride of the sliding window in number of patches. + Defaults to ``1``. Returns: - torch.Tensor: Set of masks of shape ``(n_patches_per_mask, n_masks)``. + torch.Tensor: Set of masks with shape ``(n_patches_per_mask, n_masks)``. Each + column represents indices of patches covered by one sliding window position. + + Raises: + ValueError: If any dimension of ``grid_size`` is smaller than ``kernel_size``. Examples: + Create masks for a 3x3 grid with kernel size 2 and stride 1: + >>> make_masks((3, 3), 2) tensor([[0, 1, 3, 4], [1, 2, 4, 5], [3, 4, 6, 7], [4, 5, 7, 8]], dtype=torch.int32) + Create masks for a 4x4 grid with kernel size 2 and stride 1: + >>> make_masks((4, 4), 2) tensor([[ 0, 1, 2, 4, 5, 6, 8, 9, 10], [ 1, 2, 3, 5, 6, 7, 9, 10, 11], [ 4, 5, 6, 8, 9, 10, 12, 13, 14], [ 5, 6, 7, 9, 10, 11, 13, 14, 15]], dtype=torch.int32) + Create masks for a 4x4 grid with kernel size 2 and stride 2: + >>> make_masks((4, 4), 2, stride=2) tensor([[ 0, 2, 8, 10], [ 1, 3, 9, 11], [ 4, 6, 12, 14], [ 5, 7, 13, 15]], dtype=torch.int32) + + Note: + The returned masks can be used with :func:`visual_association_score` to compute + scores for sliding window positions. """ if any(dim < kernel_size for dim in grid_size): msg = ( diff --git a/src/anomalib/models/video/__init__.py b/src/anomalib/models/video/__init__.py index ae952f0e30..0de4f1328d 100644 --- a/src/anomalib/models/video/__init__.py +++ b/src/anomalib/models/video/__init__.py @@ -1,4 +1,31 @@ -"""Anomalib Video Models.""" +"""Anomalib Video Models. + +This module contains implementations of various deep learning models for video-based +anomaly detection. 
+
+Example:
+    >>> from anomalib.models.video import AiVad
+    >>> from anomalib.data import Avenue
+    >>> from anomalib.data.utils import VideoTargetFrame
+    >>> from anomalib.engine import Engine
+
+    >>> # Initialize a model and datamodule
+    >>> datamodule = Avenue(
+    ...     clip_length_in_frames=2,
+    ...     frames_between_clips=1,
+    ...     target_frame=VideoTargetFrame.LAST
+    ... )
+    >>> model = AiVad()
+
+    >>> # Train using the engine
+    >>> engine = Engine()  # doctest: +SKIP
+    >>> engine.fit(model=model, datamodule=datamodule)  # doctest: +SKIP
+
+    >>> # Get predictions
+    >>> predictions = engine.predict(model=model, datamodule=datamodule)  # doctest: +SKIP
+
+Available Models:
+    - :class:`AiVad`: Accurate and Interpretable Video Anomaly Detection
+"""

 # Copyright (C) 2023-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
diff --git a/src/anomalib/models/video/ai_vad/__init__.py b/src/anomalib/models/video/ai_vad/__init__.py
index 740636009b..4652c299e7 100644
--- a/src/anomalib/models/video/ai_vad/__init__.py
+++ b/src/anomalib/models/video/ai_vad/__init__.py
@@ -1,8 +1,31 @@
-"""Implementatation of the AI-VAD Model.
+"""Implementation of the AI-VAD model.

-AI-VAD: Accurate and Interpretable Video Anomaly Detection
+This module provides the implementation of AI-VAD: Attribute-based
+Representations for Accurate and Interpretable Video Anomaly
+Detection.

-Paper https://arxiv.org/pdf/2212.00789.pdf
+The model extracts three types of features from video regions:
+    - Velocity features: Histogram of optical flow magnitudes
+    - Pose features: Human keypoint detections using KeypointRCNN
+    - Deep features: CLIP embeddings of region crops
+
+These features are used to model normal behavior patterns and detect anomalies as
+deviations from the learned distributions.
+
+Example:
+    >>> from anomalib.models.video.ai_vad import AiVad
+    >>> # Initialize the model
+    >>> model = AiVad(
+    ...     use_pose_features=True,
+    ...     use_deep_features=True,
+    ...     use_velocity_features=True
+    ... )
+
+Reference:
+    Tal Reiss, Yedid Hoshen, "AI-VAD: Attribute-based Representations for
+    Accurate and Interpretable Video Anomaly Detection", arXiv:2212.00789, 2022
+    https://arxiv.org/pdf/2212.00789.pdf
 """

 # Copyright (C) 2023-2024 Intel Corporation
diff --git a/src/anomalib/models/video/ai_vad/density.py b/src/anomalib/models/video/ai_vad/density.py
index 778e945769..65cef958f9 100644
--- a/src/anomalib/models/video/ai_vad/density.py
+++ b/src/anomalib/models/video/ai_vad/density.py
@@ -1,4 +1,29 @@
-"""Density estimation module for AI-VAD model implementation."""
+"""Density estimation module for AI-VAD model implementation.
+
+This module implements the density estimation stage of the AI-VAD model. It provides
+density estimators for modeling the distribution of extracted features from normal
+video samples.
+
+The module provides the following components:
+    - :class:`BaseDensityEstimator`: Abstract base class for density estimators
+    - :class:`CombinedDensityEstimator`: Main density estimator that combines
+      multiple feature-specific estimators
+
+Example:
+    >>> import torch
+    >>> from anomalib.models.video.ai_vad.density import CombinedDensityEstimator
+    >>> from anomalib.models.video.ai_vad.features import FeatureType
+    >>> estimator = CombinedDensityEstimator()
+    >>> features = {
+    ...     FeatureType.VELOCITY: torch.randn(32, 8),
+    ...     FeatureType.POSE: torch.randn(32, 34),
+    ...     FeatureType.DEEP: torch.randn(32, 512)
+    ... }
+    >>> scores = estimator(features)  # Returns anomaly scores during inference
+
+The density estimators are used to model the distribution of normal behavior and
+detect anomalies as samples with low likelihood under the learned distributions.
+"""

 # Copyright (C) 2023-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -16,11 +41,38 @@
 class BaseDensityEstimator(nn.Module, ABC):
-    """Base density estimator."""
+    """Abstract base class for density estimators.
+
+    This class defines the interface for density estimators used in the AI-VAD model.
+    Subclasses must implement methods for updating the density model with new features,
+    predicting densities for test samples, and fitting the model.
+
+    Example:
+        >>> import torch
+        >>> from anomalib.models.video.ai_vad.density import BaseDensityEstimator
+        >>> class MyEstimator(BaseDensityEstimator):
+        ...     def update(self, features, group=None):
+        ...         pass
+        ...     def predict(self, features):
+        ...         return torch.rand(features.shape[0])
+        ...     def fit(self):
+        ...         pass
+        >>> estimator = MyEstimator().eval()  # predictions require eval mode
+        >>> features = torch.randn(32, 8)
+        >>> scores = estimator(features)  # Forward pass returns predictions
+    """

     @abstractmethod
     def update(self, features: dict[FeatureType, torch.Tensor] | torch.Tensor, group: str | None = None) -> None:
-        """Update the density model with a new set of features."""
+        """Update the density model with a new set of features.
+
+        Args:
+            features (dict[FeatureType, torch.Tensor] | torch.Tensor): Input features
+                to update the model. Can be either a dictionary mapping feature types
+                to tensors, or a single tensor.
+            group (str | None, optional): Optional group identifier for grouped
+                density estimation. Defaults to ``None``.
+        """
         raise NotImplementedError

     @abstractmethod
@@ -28,19 +80,45 @@ def predict(
         self,
         features: dict[FeatureType, torch.Tensor] | torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        """Predict the density of a set of features."""
+        """Predict the density of a set of features.
+
+        Args:
+            features (dict[FeatureType, torch.Tensor] | torch.Tensor): Input features
+                to compute density for. Can be either a dictionary mapping feature
+                types to tensors, or a single tensor.
+
+        Returns:
+            torch.Tensor | tuple[torch.Tensor, torch.Tensor]: Predicted density
+                scores. May return either a single tensor of scores or a tuple of
+                tensors for more complex estimators.
+        """
         raise NotImplementedError

     @abstractmethod
     def fit(self) -> None:
-        """Compose model using collected features."""
+        """Compose model using collected features.
+
+        This method should be called after updating the model with features to fit
+        the density estimator to the collected data.
+        """
         raise NotImplementedError

     def forward(
         self,
         features: dict[FeatureType, torch.Tensor] | torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None:
-        """Update or predict depending on training status."""
+        """Forward pass that either updates or predicts based on training status.
+
+        Args:
+            features (dict[FeatureType, torch.Tensor] | torch.Tensor): Input
+                features. Can be either a dictionary mapping feature types to
+                tensors, or a single tensor.
+
+        Returns:
+            torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None: During
+                training, returns ``None`` after updating. During inference,
+                returns density predictions.
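+
+        Example:
+            A minimal illustration with the ``MyEstimator`` sketch from the class
+            docstring (any concrete subclass behaves the same way):
+
+            >>> estimator = MyEstimator()
+            >>> _ = estimator.train()
+            >>> estimator(torch.randn(32, 8)) is None  # training mode: update only
+            True
+            >>> _ = estimator.eval()
+            >>> estimator(torch.randn(32, 8)).shape  # eval mode: predict
+            torch.Size([32])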
+ """ if self.training: self.update(features) return None @@ -53,18 +131,38 @@ class CombinedDensityEstimator(BaseDensityEstimator): Combines density estimators for the different feature types included in the model. Args: - use_pose_features (bool): Flag indicating if pose features should be used. - Defaults to ``True``. - use_deep_features (bool): Flag indicating if deep features should be used. - Defaults to ``True``. - use_velocity_features (bool): Flag indicating if velocity features should be used. - Defaults to ``False``. - n_neighbors_pose (int): Number of neighbors used in KNN density estimation for pose features. - Defaults to ``1``. - n_neighbors_deep (int): Number of neighbors used in KNN density estimation for deep features. - Defaults to ``1``. - n_components_velocity (int): Number of components used by GMM density estimation for velocity features. - Defaults to ``5``. + use_pose_features (bool, optional): Flag indicating if pose features should be + used. Defaults to ``True``. + use_deep_features (bool, optional): Flag indicating if deep features should be + used. Defaults to ``True``. + use_velocity_features (bool, optional): Flag indicating if velocity features + should be used. Defaults to ``False``. + n_neighbors_pose (int, optional): Number of neighbors used in KNN density + estimation for pose features. Defaults to ``1``. + n_neighbors_deep (int, optional): Number of neighbors used in KNN density + estimation for deep features. Defaults to ``1``. + n_components_velocity (int, optional): Number of components used by GMM density + estimation for velocity features. Defaults to ``5``. + + Raises: + ValueError: If none of the feature types (velocity, pose, deep) are enabled. + + Example: + >>> from anomalib.models.video.ai_vad.density import CombinedDensityEstimator + >>> estimator = CombinedDensityEstimator( + ... use_pose_features=True, + ... use_deep_features=True, + ... use_velocity_features=True, + ... n_neighbors_pose=1, + ... n_neighbors_deep=1, + ... n_components_velocity=5 + ... ) + >>> # Update with features from training data + >>> estimator.update(features, group="video_001") + >>> # Fit the density estimators + >>> estimator.fit() + >>> # Get predictions for test data + >>> region_scores, image_score = estimator.predict(features) """ def __init__( @@ -96,8 +194,12 @@ def update(self, features: dict[FeatureType, torch.Tensor], group: str | None = """Update the density estimators for the different feature types. Args: - features (dict[FeatureType, torch.Tensor]): Dictionary containing extracted features for a single frame. - group (str): Identifier of the video from which the frame was sampled. Used for grouped density estimation. + features (dict[FeatureType, torch.Tensor]): Dictionary containing + extracted features for a single frame. Keys are feature types and + values are the corresponding feature tensors. + group (str | None, optional): Identifier of the video from which the + frame was sampled. Used for grouped density estimation. Defaults to + ``None``. """ if self.use_velocity_features: self.velocity_estimator.update(features[FeatureType.VELOCITY]) @@ -107,7 +209,11 @@ def update(self, features: dict[FeatureType, torch.Tensor], group: str | None = self.pose_estimator.update(features[FeatureType.POSE], group=group) def fit(self) -> None: - """Fit the density estimation models on the collected features.""" + """Fit the density estimation models on the collected features. 
+ + This method should be called after updating with all training features to + fit the density estimators to the collected data. + """ if self.use_velocity_features: self.velocity_estimator.fit() if self.use_deep_features: @@ -116,14 +222,28 @@ def fit(self) -> None: self.pose_estimator.fit() def predict(self, features: dict[FeatureType, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: - """Predict the region- and image-level anomaly scores for an image based on a set of features. + """Predict region and image-level anomaly scores. + + Computes anomaly scores for each region in the frame and an overall frame + score based on the maximum region score. Args: - features (dict[Tensor]): Dictionary containing extracted features for a single frame. + features (dict[FeatureType, torch.Tensor]): Dictionary containing + extracted features for a single frame. Keys are feature types and + values are the corresponding feature tensors. Returns: - Tensor: Region-level anomaly scores for all regions withing the frame. - Tensor: Frame-level anomaly score for the frame. + tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Region-level anomaly scores for all regions within the frame + - Frame-level anomaly score for the frame + + Example: + >>> features = { + ... FeatureType.VELOCITY: velocity_features, + ... FeatureType.DEEP: deep_features, + ... FeatureType.POSE: pose_features + ... } + >>> region_scores, image_score = estimator.predict(features) """ n_regions = next(iter(features.values())).shape[0] device = next(iter(features.values())).device @@ -147,13 +267,30 @@ def predict(self, features: dict[FeatureType, torch.Tensor]) -> tuple[torch.Tens class GroupedKNNEstimator(DynamicBufferMixin, BaseDensityEstimator): """Grouped KNN density estimator. - Keeps track of the group (e.g. video id) from which the features were sampled for normalization purposes. + Keeps track of the group (e.g. video id) from which the features were sampled for + normalization purposes. Args: n_neighbors (int): Number of neighbors used in KNN search. + + Example: + >>> from anomalib.models.video.ai_vad.density import GroupedKNNEstimator + >>> import torch + >>> estimator = GroupedKNNEstimator(n_neighbors=5) + >>> features = torch.randn(32, 512) # (N, D) + >>> estimator.update(features, group="video_1") + >>> estimator.fit() + >>> scores = estimator.predict(features) + >>> scores.shape + torch.Size([32]) """ def __init__(self, n_neighbors: int) -> None: + """Initialize the grouped KNN density estimator. + + Args: + n_neighbors (int): Number of neighbors used in KNN search. + """ super().__init__() self.n_neighbors = n_neighbors @@ -168,8 +305,15 @@ def update(self, features: torch.Tensor, group: str | None = None) -> None: """Update the internal feature bank while keeping track of the group. Args: - features (torch.Tensor): Feature vectors extracted from a video frame. - group (str): Identifier of the group (video) from which the frame was sampled. + features (torch.Tensor): Feature vectors extracted from a video frame of + shape ``(N, D)``. + group (str | None, optional): Identifier of the group (video) from which + the frame was sampled. Defaults to ``None``. 
+ + Example: + >>> estimator = GroupedKNNEstimator(n_neighbors=5) + >>> features = torch.randn(32, 512) # (N, D) + >>> estimator.update(features, group="video_1") """ group = group or "default" @@ -179,7 +323,17 @@ def update(self, features: torch.Tensor, group: str | None = None) -> None: self.feature_collection[group] = [features] def fit(self) -> None: - """Fit the KNN model by stacking the feature vectors and computing the normalization statistics.""" + """Fit the KNN model by stacking features and computing normalization stats. + + Stacks the collected feature vectors group-wise and computes the normalization + statistics. After fitting, the feature collection is deleted to free up memory. + + Example: + >>> estimator = GroupedKNNEstimator(n_neighbors=5) + >>> features = torch.randn(32, 512) # (N, D) + >>> estimator.update(features, group="video_1") + >>> estimator.fit() + """ # stack the collected features group-wise feature_collection = {key: torch.vstack(value) for key, value in self.feature_collection.items()} # assign memory bank, group index and group names @@ -202,17 +356,30 @@ def predict( """Predict the (normalized) density for a set of features. Args: - features (torch.Tensor): Input features that will be compared to the density model. - group (str, optional): Group (video id) from which the features originate. If passed, all features of the - same group in the memory bank will be excluded from the density estimation. + features (torch.Tensor): Input features of shape ``(N, D)`` that will be + compared to the density model. + group (str | None, optional): Group (video id) from which the features + originate. If passed, all features of the same group in the memory + bank will be excluded from the density estimation. Defaults to ``None``. - n_neighbors (int): Number of neighbors used in the KNN search. + n_neighbors (int, optional): Number of neighbors used in the KNN search. Defaults to ``1``. - normalize (bool): Flag indicating if the density should be normalized to min-max stats of the feature bank. - Defatuls to ``True``. + normalize (bool, optional): Flag indicating if the density should be + normalized to min-max stats of the feature bank. + Defaults to ``True``. Returns: - Tensor: Mean (normalized) distances of input feature vectors to k nearest neighbors in feature bank. + torch.Tensor: Mean (normalized) distances of input feature vectors to k + nearest neighbors in feature bank. + + Example: + >>> estimator = GroupedKNNEstimator(n_neighbors=5) + >>> features = torch.randn(32, 512) # (N, D) + >>> estimator.update(features, group="video_1") + >>> estimator.fit() + >>> scores = estimator.predict(features, group="video_1") + >>> scores.shape + torch.Size([32]) """ n_neighbors = n_neighbors or self.n_neighbors @@ -234,12 +401,15 @@ def _nearest_neighbors(feature_bank: torch.Tensor, features: torch.Tensor, n_nei """Perform the KNN search. Args: - feature_bank (torch.Tensor): Feature bank used for KNN search. - features (Ternsor): Input features. - n_neighbors (int): Number of neighbors used in KNN search. + feature_bank (torch.Tensor): Feature bank of shape ``(M, D)`` used for + KNN search. + features (torch.Tensor): Input features of shape ``(N, D)``. + n_neighbors (int, optional): Number of neighbors used in KNN search. + Defaults to ``1``. Returns: - Tensor: Distances between the input features and their K nearest neighbors in the feature bank. + torch.Tensor: Distances between the input features and their K nearest + neighbors in the feature bank. 
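+
+        Example:
+            A minimal shape check (illustrative only):
+
+            >>> bank = torch.randn(100, 64)   # (M, D) memory bank
+            >>> feats = torch.randn(10, 64)   # (N, D) query features
+            >>> GroupedKNNEstimator._nearest_neighbors(bank, feats, n_neighbors=3).shape
+            torch.Size([10, 3])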
""" distances = torch.cdist(features, feature_bank, p=2.0) # euclidean norm if n_neighbors == 1: @@ -250,7 +420,12 @@ def _nearest_neighbors(feature_bank: torch.Tensor, features: torch.Tensor, n_nei return distances def _compute_normalization_statistics(self, grouped_features: dict[str, Tensor]) -> None: - """Compute min-max normalization statistics while taking the group into account.""" + """Compute min-max normalization statistics while taking the group into account. + + Args: + grouped_features (dict[str, Tensor]): Dictionary mapping group names to + feature tensors. + """ for group, features in grouped_features.items(): distances = self.predict(features, group, normalize=False) self.normalization_statistics.update(distances) @@ -264,7 +439,7 @@ def _normalize(self, distances: torch.Tensor) -> torch.Tensor: distances (torch.Tensor): Distance tensor produced by KNN search. Returns: - Tensor: Normalized distances. + torch.Tensor: Normalized distances. """ return (distances - self.normalization_statistics.min) / ( self.normalization_statistics.max - self.normalization_statistics.min @@ -274,9 +449,23 @@ def _normalize(self, distances: torch.Tensor) -> torch.Tensor: class GMMEstimator(BaseDensityEstimator): """Density estimation based on Gaussian Mixture Model. + Fits a GMM to the training features and uses the negative log-likelihood as an + anomaly score during inference. + Args: - n_components (int): Number of components used in the GMM. + n_components (int, optional): Number of Gaussian components used in the GMM. Defaults to ``2``. + + Example: + >>> import torch + >>> from anomalib.models.video.ai_vad.density import GMMEstimator + >>> estimator = GMMEstimator(n_components=2) + >>> features = torch.randn(32, 8) # (N, D) + >>> estimator.update(features) + >>> estimator.fit() + >>> scores = estimator.predict(features) + >>> scores.shape + torch.Size([32]) """ def __init__(self, n_components: int = 2) -> None: @@ -288,27 +477,44 @@ def __init__(self, n_components: int = 2) -> None: self.normalization_statistics = MinMax() def update(self, features: torch.Tensor, group: str | None = None) -> None: - """Update the feature bank.""" + """Update the feature bank with new features. + + Args: + features (torch.Tensor): Feature vectors of shape ``(N, D)`` to add to + the memory bank. + group (str | None, optional): Unused group parameter included for + interface compatibility. Defaults to ``None``. + """ del group if isinstance(self.memory_bank, list): self.memory_bank.append(features) def fit(self) -> None: - """Fit the GMM and compute normalization statistics.""" + """Fit the GMM and compute normalization statistics. + + Concatenates all features in the memory bank, fits the GMM to the combined + features, and computes min-max normalization statistics over the training + scores. + """ self.memory_bank = torch.vstack(self.memory_bank) self.gmm.fit(self.memory_bank) self._compute_normalization_statistics() def predict(self, features: torch.Tensor, normalize: bool = True) -> torch.Tensor: - """Predict the density of a set of feature vectors. + """Predict anomaly scores for input features. + + Computes the negative log-likelihood of each feature vector under the + fitted GMM. Lower likelihood (higher score) indicates more anomalous + samples. Args: - features (torch.Tensor): Input feature vectors. - normalize (bool): Flag indicating if the density should be normalized to min-max stats of the feature bank. - Defaults to ``True``. + features (torch.Tensor): Input feature vectors of shape ``(N, D)``. 
+            normalize (bool, optional): Whether to normalize scores using min-max
+                statistics from training. Defaults to ``True``.

         Returns:
-            Tensor: Density scores of the input feature vectors.
+            torch.Tensor: Anomaly scores of shape ``(N,)``. Higher values indicate
+                more anomalous samples.
         """
         density = -self.gmm.score_samples(features)
         if normalize:
@@ -316,19 +522,23 @@ def predict(self, features: torch.Tensor, normalize: bool = True) -> torch.Tenso
         return density

     def _compute_normalization_statistics(self) -> None:
-        """Compute min-max normalization statistics over the feature bank."""
+        """Compute min-max normalization statistics over the feature bank.
+
+        Computes anomaly scores for all training features and updates the min-max
+        statistics used for score normalization during inference.
+        """
         training_scores = self.predict(self.memory_bank, normalize=False)
         self.normalization_statistics.update(training_scores)
         self.normalization_statistics.compute()

     def _normalize(self, density: torch.Tensor) -> torch.Tensor:
-        """Normalize distance predictions.
+        """Normalize anomaly scores using min-max statistics.

         Args:
-            density (torch.Tensor): Distance tensor produced by KNN search.
+            density (torch.Tensor): Raw anomaly scores of shape ``(N,)``.

         Returns:
-            Tensor: Normalized distances.
+            torch.Tensor: Normalized anomaly scores of shape ``(N,)``.
         """
         return (density - self.normalization_statistics.min) / (
             self.normalization_statistics.max - self.normalization_statistics.min
diff --git a/src/anomalib/models/video/ai_vad/features.py b/src/anomalib/models/video/ai_vad/features.py
index f2107f217c..312946b4ea 100644
--- a/src/anomalib/models/video/ai_vad/features.py
+++ b/src/anomalib/models/video/ai_vad/features.py
@@ -1,4 +1,25 @@
-"""Feature extraction module for AI-VAD model implementation."""
+"""Feature extraction module for AI-VAD model implementation.
+
+This module implements the feature extraction stage of the AI-VAD model. It extracts
+three types of features from video regions:
+
+- Velocity features: Histogram of optical flow magnitudes
+- Pose features: Human keypoint detections using KeypointRCNN
+- Deep features: CLIP embeddings of region crops
+
+Example:
+    >>> from anomalib.models.video.ai_vad.features import FeatureExtractor
+    >>> import torch
+    >>> extractor = FeatureExtractor()
+    >>> rgb_batch = torch.randn(32, 3, 256, 256)  # (N, C, H, W)
+    >>> flow = torch.randn(32, 2, 256, 256)  # (N, 2, H, W)
+    >>> regions = [{"boxes": torch.randn(5, 4)}] * 32  # List of region dicts
+    >>> features = extractor(rgb_batch, flow, regions)
+
+The module provides the following components:
+    - :class:`FeatureType`: Enum of available feature types
+    - :class:`FeatureExtractor`: Main class that handles feature extraction
+"""

 # Copyright (C) 2023-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -16,7 +37,26 @@ class FeatureType(str, Enum):
-    """Names of the different feature streams used in AI-VAD."""
+    """Names of the different feature streams used in AI-VAD.
+
+    This enum defines the available feature types that can be extracted from video
+    regions in the AI-VAD model.
+
+    Attributes:
+        POSE: Keypoint features extracted using KeypointRCNN model
+        VELOCITY: Histogram features computed from optical flow magnitudes
+        DEEP: Visual embedding features extracted using CLIP model
+
+    Example:
+        >>> from anomalib.models.video.ai_vad.features import FeatureType
+        >>> feature_type = FeatureType.POSE
+        >>> feature_type
+        <FeatureType.POSE: 'pose'>
+        >>> feature_type == "pose"
+        True
+        >>> feature_type in [FeatureType.POSE, FeatureType.VELOCITY]
+        True
+    """

     POSE = "pose"
     VELOCITY = "velocity"
@@ -26,15 +66,31 @@ class FeatureExtractor(nn.Module):
     """Feature extractor for AI-VAD.

+    Extracts velocity, pose and deep features from video regions based on the enabled
+    feature types.
+
     Args:
-        n_velocity_bins (int): Number of discrete bins used for velocity histogram features.
-            Defaults to ``8``.
-        use_velocity_features (bool): Flag indicating if velocity features should be used.
-            Defaults to ``True``.
-        use_pose_features (bool): Flag indicating if pose features should be used.
-            Defaults to ``True``.
-        use_deep_features (bool): Flag indicating if deep features should be used.
-            Defaults to ``True``.
+        n_velocity_bins (int, optional): Number of discrete bins used for velocity
+            histogram features. Defaults to ``8``.
+        use_velocity_features (bool, optional): Flag indicating if velocity features
+            should be used. Defaults to ``True``.
+        use_pose_features (bool, optional): Flag indicating if pose features should be
+            used. Defaults to ``True``.
+        use_deep_features (bool, optional): Flag indicating if deep features should be
+            used. Defaults to ``True``.
+
+    Raises:
+        ValueError: If none of the feature types (velocity, pose, deep) are enabled.
+
+    Example:
+        >>> import torch
+        >>> from anomalib.models.video.ai_vad.features import FeatureExtractor
+        >>> extractor = FeatureExtractor()
+        >>> rgb_batch = torch.randn(32, 3, 256, 256)  # (N, C, H, W)
+        >>> flow_batch = torch.randn(32, 2, 256, 256)  # (N, 2, H, W)
+        >>> regions = [{"boxes": torch.randn(5, 4)}] * 32  # List of region dicts
+        >>> features = extractor(rgb_batch, flow_batch, regions)
+        >>> # Returns list of dicts with keys: velocity, pose, deep
     """

     def __init__(
@@ -65,15 +121,31 @@
     ) -> list[dict]:
         """Forward pass through the feature extractor.

-        Extract any combination of velocity, pose and deep features depending on configuration.
+        Extract any combination of velocity, pose and deep features depending on
+        configuration.

         Args:
-            rgb_batch (torch.Tensor): Batch of RGB images of shape (N, 3, H, W)
-            flow_batch (torch.Tensor): Batch of optical flow images of shape (N, 2, H, W)
-            regions (list[dict]): Region information per image in batch.
+            rgb_batch (torch.Tensor): Batch of RGB images of shape ``(N, 3, H, W)``.
+            flow_batch (torch.Tensor): Batch of optical flow images of shape
+                ``(N, 2, H, W)``.
+            regions (list[dict]): Region information per image in batch. Each dict
+                contains bounding boxes of shape ``(M, 4)``.

         Returns:
-            list[dict]: Feature dictionary per image in batch.
+            list[dict]: Feature dictionary per image in batch. Each dict contains
+                the enabled feature types as keys with corresponding feature tensors
+                as values.
+ + Example: + >>> import torch + >>> from anomalib.models.video.ai_vad.features import FeatureExtractor + >>> extractor = FeatureExtractor() + >>> rgb_batch = torch.randn(32, 3, 256, 256) # (N, C, H, W) + >>> flow_batch = torch.randn(32, 2, 256, 256) # (N, 2, H, W) + >>> regions = [{"boxes": torch.randn(5, 4)}] * 32 # List of region dicts + >>> features = extractor(rgb_batch, flow_batch, regions) + >>> features[0].keys() # Features for first image + dict_keys(['velocity', 'pose', 'deep']) """ batch_size = rgb_batch.shape[0] @@ -104,7 +176,21 @@ def forward( class DeepExtractor(nn.Module): """Deep feature extractor. - Extracts the deep (appearance) features from the input regions. + Extracts deep (appearance) features from input regions using a CLIP vision encoder. + + The extractor uses a pre-trained ViT-B/16 CLIP model to encode image regions into + a 512-dimensional feature space. Input regions are resized to 224x224 and + normalized using CLIP's default preprocessing. + + Example: + >>> import torch + >>> from anomalib.models.video.ai_vad.features import DeepExtractor + >>> extractor = DeepExtractor() + >>> batch = torch.randn(32, 3, 256, 256) # (N, C, H, W) + >>> boxes = torch.tensor([[0, 10, 20, 50, 60]]) # (M, 5) with batch indices + >>> features = extractor(batch, boxes, batch_size=32) + >>> features.shape + torch.Size([1, 512]) """ def __init__(self) -> None: @@ -118,13 +204,16 @@ def forward(self, batch: torch.Tensor, boxes: torch.Tensor, batch_size: int) -> """Extract deep features using CLIP encoder. Args: - batch (torch.Tensor): Batch of RGB input images of shape (N, 3, H, W) - boxes (torch.Tensor): Bounding box coordinates of shaspe (M, 5). - First column indicates batch index of the bbox. + batch (torch.Tensor): Batch of RGB input images of shape ``(N, 3, H, W)`` + boxes (torch.Tensor): Bounding box coordinates of shape ``(M, 5)``. First + column indicates batch index of the bbox, remaining columns are + coordinates ``[x1, y1, x2, y2]``. batch_size (int): Number of images in the batch. Returns: - Tensor: Deep feature tensor of shape (M, 512) + torch.Tensor: Deep feature tensor of shape ``(M, 512)``, where ``M`` is + the number of input regions and 512 is the CLIP feature dimension. + Returns empty tensor if no valid regions. """ rgb_regions = roi_align(batch, boxes, output_size=[224, 224]) @@ -138,10 +227,23 @@ def forward(self, batch: torch.Tensor, boxes: torch.Tensor, batch_size: int) -> class VelocityExtractor(nn.Module): """Velocity feature extractor. - Extracts histograms of optical flow magnitude and direction. + Extracts histograms of optical flow magnitude and direction from video regions. + The histograms capture motion patterns by binning flow vectors based on their + direction and weighting by magnitude. Args: - n_bins (int): Number of direction bins used for the feature histograms. + n_bins (int, optional): Number of direction bins used for the feature + histograms. Defaults to ``8``. 
+ + Example: + >>> import torch + >>> from anomalib.models.video.ai_vad.features import VelocityExtractor + >>> extractor = VelocityExtractor(n_bins=8) + >>> flows = torch.randn(32, 2, 256, 256) # (N, 2, H, W) + >>> boxes = torch.tensor([[0, 10, 20, 50, 60]]) # (M, 5) with batch indices + >>> features = extractor(flows, boxes) + >>> features.shape + torch.Size([1, 8]) """ def __init__(self, n_bins: int = 8) -> None: @@ -150,15 +252,25 @@ def __init__(self, n_bins: int = 8) -> None: self.n_bins = n_bins def forward(self, flows: torch.Tensor, boxes: torch.Tensor) -> torch.Tensor: - """Extract velocioty features by filling a histogram. + """Extract velocity features by computing flow direction histograms. + + For each region, computes a histogram of optical flow directions weighted by + flow magnitudes. The flow vectors are converted from cartesian to polar + coordinates, with directions binned into ``n_bins`` equal intervals between + ``-π`` and ``π``. The histogram values are normalized by the bin counts. Args: - flows (torch.Tensor): Batch of optical flow images of shape (N, 2, H, W) - boxes (torch.Tensor): Bounding box coordinates of shaspe (M, 5). - First column indicates batch index of the bbox. + flows (torch.Tensor): Batch of optical flow images of shape + ``(N, 2, H, W)``, where the second dimension contains x and y flow + components. + boxes (torch.Tensor): Bounding box coordinates of shape ``(M, 5)``. First + column indicates batch index of the bbox, remaining columns are + coordinates ``[x1, y1, x2, y2]``. Returns: - Tensor: Velocity feature tensor of shape (M, n_bins) + torch.Tensor: Velocity feature tensor of shape ``(M, n_bins)``, where + ``M`` is the number of input regions. Returns empty tensor if no + valid regions. """ flow_regions = roi_align(flows, boxes, output_size=[224, 224]) @@ -189,10 +301,25 @@ def forward(self, flows: torch.Tensor, boxes: torch.Tensor) -> torch.Tensor: class PoseExtractor(nn.Module): """Pose feature extractor. - Extracts pose features based on estimated body landmark keypoints. + Extracts pose features based on estimated body landmark keypoints using a + KeypointRCNN model. + + Example: + >>> import torch + >>> from anomalib.models.video.ai_vad.features import PoseExtractor + >>> extractor = PoseExtractor() + >>> batch = torch.randn(2, 3, 256, 256) # (N, C, H, W) + >>> boxes = torch.tensor([[0, 10, 10, 50, 50], [1, 20, 20, 60, 60]]) + >>> features = extractor(batch, boxes) + >>> # Returns list of pose feature tensors for each image """ def __init__(self, *args, **kwargs) -> None: + """Initialize the pose feature extractor. + + Loads a pre-trained KeypointRCNN model and extracts its components for + feature extraction. + """ super().__init__(*args, **kwargs) weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT @@ -206,13 +333,17 @@ def __init__(self, *args, **kwargs) -> None: def _post_process(keypoint_detections: list[dict]) -> list[torch.Tensor]: """Convert keypoint predictions to 1D feature vectors. - Post-processing consists of flattening and normalizing to bbox coordinates. + Post-processing consists of flattening the keypoint coordinates and + normalizing them relative to the bounding box coordinates. Args: keypoint_detections (list[dict]): Outputs of the keypoint extractor + containing detected keypoints and bounding boxes. 
Returns: - list[torch.Tensor]: List of pose feature tensors for each image + list[torch.Tensor]: List of pose feature tensors for each image, where + each tensor has shape ``(N, K*2)`` with ``N`` being the number of + detections and ``K`` the number of keypoints. """ poses = [] for detection in keypoint_detections: @@ -226,13 +357,23 @@ def _post_process(keypoint_detections: list[dict]) -> list[torch.Tensor]: def forward(self, batch: torch.Tensor, boxes: torch.Tensor) -> list[torch.Tensor]: """Extract pose features using a human keypoint estimation model. + The method performs the following steps: + 1. Transform input images + 2. Extract backbone features + 3. Pool ROI features for each box + 4. Predict keypoint locations + 5. Post-process predictions + Args: - batch (torch.Tensor): Batch of RGB input images of shape (N, 3, H, W) - boxes (torch.Tensor): Bounding box coordinates of shaspe (M, 5). - First column indicates batch index of the bbox. + batch (torch.Tensor): Batch of RGB input images of shape + ``(N, 3, H, W)``. + boxes (torch.Tensor): Bounding box coordinates of shape ``(M, 5)``. + First column indicates batch index of the bbox, remaining columns + are coordinates ``[x1, y1, x2, y2]``. Returns: - list[torch.Tensor]: list of pose feature tensors for each image. + list[torch.Tensor]: List of pose feature tensors for each image, where + each tensor contains normalized keypoint coordinates. """ images, _ = self.transform(batch) features = self.backbone(images.tensors) diff --git a/src/anomalib/models/video/ai_vad/flow.py b/src/anomalib/models/video/ai_vad/flow.py index 9728a23290..fc1fb2b68e 100644 --- a/src/anomalib/models/video/ai_vad/flow.py +++ b/src/anomalib/models/video/ai_vad/flow.py @@ -1,4 +1,23 @@ -"""Optical Flow extraction module for AI-VAD implementation.""" +"""Optical Flow extraction module for AI-VAD implementation. + +This module implements the optical flow extraction stage of the AI-VAD model. It uses +RAFT (Recurrent All-Pairs Field Transforms) to compute dense optical flow between +consecutive video frames. + +Example: + >>> from anomalib.models.video.ai_vad.flow import FlowExtractor + >>> import torch + >>> extractor = FlowExtractor() + >>> first_frame = torch.randn(32, 3, 256, 256) # (N, C, H, W) + >>> last_frame = torch.randn(32, 3, 256, 256) # (N, C, H, W) + >>> flow = extractor(first_frame, last_frame) + >>> flow.shape + torch.Size([32, 2, 256, 256]) + +The module provides the following components: + - :class:`FlowExtractor`: Main class that handles optical flow computation using + RAFT model +""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/models/video/ai_vad/lightning_model.py b/src/anomalib/models/video/ai_vad/lightning_model.py index 3afd674673..ebca72a289 100644 --- a/src/anomalib/models/video/ai_vad/lightning_model.py +++ b/src/anomalib/models/video/ai_vad/lightning_model.py @@ -1,6 +1,38 @@ -"""Attribute-based Representations for Accurate and Interpretable Video Anomaly Detection. +"""AI-VAD. -Paper https://arxiv.org/pdf/2212.00789.pdf +Attribute-based Representations for Accurate and Interpretable Video Anomaly +Detection. + +This module implements the AI-VAD model as described in the paper "AI-VAD: +Attribute-based Representations for Accurate and Interpretable Video Anomaly +Detection." 
+ +The model extracts regions of interest from video frames using object detection and +foreground detection, then computes attribute-based representations including +velocity, pose and deep features for anomaly detection. + +Example: + >>> from anomalib.models.video import AiVad + >>> from anomalib.data import Avenue + >>> from anomalib.data.utils import VideoTargetFrame + >>> from anomalib.engine import Engine + + >>> # Initialize model and datamodule + >>> datamodule = Avenue( + ... clip_length_in_frames=2, + ... frames_between_clips=1, + ... target_frame=VideoTargetFrame.LAST + ... ) + >>> model = AiVad() + + >>> # Train using the engine + >>> engine = Engine() + >>> engine.fit(model=model, datamodule=datamodule) + +Reference: + Tal Reiss, Yedid Hoshen. "AI-VAD: Attribute-based Representations for Accurate + and Interpretable Video Anomaly Detection." arXiv preprint arXiv:2212.00789 + (2022). https://arxiv.org/pdf/2212.00789.pdf """ # Copyright (C) 2023-2024 Intel Corporation @@ -26,42 +58,70 @@ class AiVad(MemoryBankMixin, AnomalibModule): - """AI-VAD: Attribute-based Representations for Accurate and Interpretable Video Anomaly Detection. + """AI-VAD: Attribute-based Representations for Video Anomaly Detection. + + This model extracts regions of interest from video frames using object detection and + foreground detection, then computes attribute-based representations including + velocity, pose and deep features for anomaly detection. Args: - box_score_thresh (float): Confidence threshold for bounding box predictions. - Defaults to ``0.7``. - persons_only (bool): When enabled, only regions labeled as person are included. - Defaults to ``False``. - min_bbox_area (int): Minimum bounding box area. Regions with a surface area lower than this value are excluded. - Defaults to ``100``. - max_bbox_overlap (float): Maximum allowed overlap between bounding boxes. - Defaults to ``0.65``. - enable_foreground_detections (bool): Add additional foreground detections based on pixel difference between - consecutive frames. + box_score_thresh (float, optional): Confidence threshold for bounding box + predictions. Defaults to ``0.7``. + persons_only (bool, optional): When enabled, only regions labeled as person are + included. Defaults to ``False``. + min_bbox_area (int, optional): Minimum bounding box area. Regions with surface + area lower than this value are excluded. Defaults to ``100``. + max_bbox_overlap (float, optional): Maximum allowed overlap between bounding + boxes. Defaults to ``0.65``. + enable_foreground_detections (bool, optional): Add additional foreground + detections based on pixel difference between consecutive frames. Defaults to ``True``. - foreground_kernel_size (int): Gaussian kernel size used in foreground detection. - Defaults to ``3``. - foreground_binary_threshold (int): Value between 0 and 255 which acts as binary threshold in foreground - detection. - Defaults to ``18``. - n_velocity_bins (int): Number of discrete bins used for velocity histogram features. - Defaults to ``1``. - use_velocity_features (bool): Flag indicating if velocity features should be used. - Defaults to ``True``. - use_pose_features (bool): Flag indicating if pose features should be used. - Defaults to ``True``. - use_deep_features (bool): Flag indicating if deep features should be used. - Defaults to ``True``. - n_components_velocity (int): Number of components used by GMM density estimation for velocity features. - Defaults to ``2``. 
- n_neighbors_pose (int): Number of neighbors used in KNN density estimation for pose features. - Defaults to ``1``. - n_neighbors_deep (int): Number of neighbors used in KNN density estimation for deep features. - Defaults to ``1``. - pre_processor (PreProcessor, optional): Pre-processor for the model. - This is used to pre-process the input data before it is passed to the model. - Defaults to ``None``. + foreground_kernel_size (int, optional): Gaussian kernel size used in foreground + detection. Defaults to ``3``. + foreground_binary_threshold (int, optional): Value between 0 and 255 which acts + as binary threshold in foreground detection. Defaults to ``18``. + n_velocity_bins (int, optional): Number of discrete bins used for velocity + histogram features. Defaults to ``1``. + use_velocity_features (bool, optional): Flag indicating if velocity features + should be used. Defaults to ``True``. + use_pose_features (bool, optional): Flag indicating if pose features should be + used. Defaults to ``True``. + use_deep_features (bool, optional): Flag indicating if deep features should be + used. Defaults to ``True``. + n_components_velocity (int, optional): Number of components used by GMM density + estimation for velocity features. Defaults to ``2``. + n_neighbors_pose (int, optional): Number of neighbors used in KNN density + estimation for pose features. Defaults to ``1``. + n_neighbors_deep (int, optional): Number of neighbors used in KNN density + estimation for deep features. Defaults to ``1``. + pre_processor (PreProcessor | bool, optional): Pre-processor instance or bool + flag to enable default pre-processor. Defaults to ``True``. + post_processor (PostProcessor | bool, optional): Post-processor instance or bool + flag to enable default post-processor. Defaults to ``True``. + **kwargs: Additional keyword arguments passed to the parent class. + + Example: + >>> from anomalib.models.video import AiVad + >>> from anomalib.data import Avenue + >>> from anomalib.data.utils import VideoTargetFrame + >>> from anomalib.engine import Engine + + >>> # Initialize model and datamodule + >>> datamodule = Avenue( + ... clip_length_in_frames=2, + ... frames_between_clips=1, + ... target_frame=VideoTargetFrame.LAST + ... ) + >>> model = AiVad() + + >>> # Train using the engine + >>> engine = Engine() + >>> engine.fit(model=model, datamodule=datamodule) + + Note: + The model follows a one-class learning approach and does not require + optimization during training. Instead, it builds density estimators based on + extracted features from normal samples. """ def __init__( @@ -115,7 +175,7 @@ def training_step(self, batch: VideoBatch) -> None: Extract features from the batch of clips and update the density estimators. Args: - batch (Batch): Batch containing image filename, image, label and mask + batch (VideoBatch): Batch containing video frames and metadata. """ features_per_batch = self.model(batch.image) @@ -128,7 +188,11 @@ def training_step(self, batch: VideoBatch) -> None: return torch.tensor(0.0, requires_grad=True, device=self.device) def fit(self) -> None: - """Fit the density estimators to the extracted features from the training set.""" + """Fit the density estimators to the extracted features from the training set. + + Raises: + ValueError: If no regions were extracted during training. + """ if self.total_detections == 0: msg = "No regions were extracted during training." 
             raise ValueError(msg)
@@ -137,15 +201,15 @@
     def validation_step(self, batch: VideoBatch, *args, **kwargs) -> STEP_OUTPUT:
         """Perform the validation step of AI-VAD.

-        Extract boxes and box scores..
+        Extract boxes and box scores from the input batch.

         Args:
-            batch (Batch): Input batch
-            *args: Arguments.
-            **kwargs: Keyword arguments.
+            batch (VideoBatch): Input batch containing video frames and metadata.
+            *args: Additional arguments (unused).
+            **kwargs: Additional keyword arguments (unused).

         Returns:
-            Batch dictionary with added boxes and box scores.
+            STEP_OUTPUT: Batch dictionary with added predictions and anomaly maps.
         """
         del args, kwargs  # Unused arguments.
@@ -154,15 +218,19 @@
     @property
     def trainer_arguments(self) -> dict[str, Any]:
-        """AI-VAD specific trainer arguments."""
+        """Get AI-VAD specific trainer arguments.
+
+        Returns:
+            dict[str, Any]: Dictionary of trainer arguments.
+        """
         return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0}

     @property
     def learning_type(self) -> LearningType:
-        """Return the learning type of the model.
+        """Get the learning type of the model.

         Returns:
-            LearningType: Learning type of the model.
+            LearningType: Learning type of the model (ONE_CLASS).
         """
         return LearningType.ONE_CLASS

@@ -172,11 +240,22 @@ def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> P
         AI-VAD does not need a pre-processor or transforms, as the region- and
         feature-extractors apply their own transforms.
+
+        Args:
+            image_size (tuple[int, int] | None, optional): Image size (unused).
+                Defaults to ``None``.
+
+        Returns:
+            PreProcessor: Empty pre-processor instance.
         """
         del image_size
         return PreProcessor()  # A pre-processor with no transforms.

     @staticmethod
     def configure_post_processor() -> PostProcessor:
-        """Return the default post-processor for AI-VAD."""
+        """Configure the post-processor for AI-VAD.
+
+        Returns:
+            PostProcessor: One-class post-processor instance.
+        """
         return OneClassPostProcessor()
diff --git a/src/anomalib/models/video/ai_vad/regions.py b/src/anomalib/models/video/ai_vad/regions.py
index 441af32493..0ca7a4bed4 100644
--- a/src/anomalib/models/video/ai_vad/regions.py
+++ b/src/anomalib/models/video/ai_vad/regions.py
@@ -1,4 +1,20 @@
-"""Regions extraction module of AI-VAD model implementation."""
+"""Regions extraction module of AI-VAD model implementation.
+
+This module implements the region extraction stage of the AI-VAD model. It extracts
+regions of interest from video frames using object detection and foreground
+detection.
+
+Example:
+    >>> from anomalib.models.video.ai_vad.regions import RegionExtractor
+    >>> import torch
+    >>> extractor = RegionExtractor()
+    >>> first_frame = torch.randn(32, 3, 256, 256)  # (N, C, H, W)
+    >>> last_frame = torch.randn(32, 3, 256, 256)  # (N, C, H, W)
+    >>> regions = extractor(first_frame, last_frame)
+
+The module provides the following components:
+    - :class:`RegionExtractor`: Main class that handles region extraction using
+      object detection and foreground detection
+"""

 # Copyright (C) 2023-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -17,23 +33,35 @@ class RegionExtractor(nn.Module):
     """Region extractor for AI-VAD.

+    This class extracts regions of interest from video frames using object detection and
+    foreground detection. It uses a Mask R-CNN model for object detection and can
+    optionally detect foreground regions based on frame differences.
+
     Args:
-        box_score_thresh (float): Confidence threshold for bounding box predictions.
- Defaults to ``0.8``. - persons_only (bool): When enabled, only regions labeled as person are included. - Defaults to ``False``. - min_bbox_area (int): Minimum bounding box area. Regions with a surface area lower than this value are excluded. - Defaults to ``100``. - max_bbox_overlap (float): Maximum allowed overlap between bounding boxes. - Defaults to ``0.65``. - enable_foreground_detections (bool): Add additional foreground detections based on pixel difference between - consecutive frames. + box_score_thresh (float, optional): Confidence threshold for bounding box + predictions. Defaults to ``0.8``. + persons_only (bool, optional): When enabled, only regions labeled as person are + included. Defaults to ``False``. + min_bbox_area (int, optional): Minimum bounding box area. Regions with a surface + area lower than this value are excluded. Defaults to ``100``. + max_bbox_overlap (float, optional): Maximum allowed overlap between bounding + boxes. Defaults to ``0.65``. + enable_foreground_detections (bool, optional): Add additional foreground + detections based on pixel difference between consecutive frames. Defaults to ``True``. - foreground_kernel_size (int): Gaussian kernel size used in foreground detection. - Defaults to ``3``. - foreground_binary_threshold (int): Value between 0 and 255 which acts as binary threshold in foreground - detection. - Defaults to ``18``. + foreground_kernel_size (int, optional): Gaussian kernel size used in foreground + detection. Defaults to ``3``. + foreground_binary_threshold (int, optional): Value between 0 and 255 which acts + as binary threshold in foreground detection. Defaults to ``18``. + + Example: + >>> import torch + >>> from anomalib.models.video.ai_vad.regions import RegionExtractor + >>> extractor = RegionExtractor() + >>> first_frame = torch.randn(2, 3, 256, 256) # (N, C, H, W) + >>> last_frame = torch.randn(2, 3, 256, 256) # (N, C, H, W) + >>> regions = extractor(first_frame, last_frame) + >>> # Returns list of dicts with keys: boxes, labels, scores, masks """ def __init__( @@ -61,13 +89,24 @@ def __init__( def forward(self, first_frame: torch.Tensor, last_frame: torch.Tensor) -> list[dict]: """Perform forward-pass through region extractor. + The forward pass consists of: + 1. Object detection on the last frame using Mask R-CNN + 2. Optional foreground detection by comparing first and last frames + 3. Post-processing to filter and refine detections + Args: - first_frame (torch.Tensor): Batch of input images of shape (N, C, H, W) + first_frame (torch.Tensor): Batch of input images of shape ``(N, C, H, W)`` forming the first frames in the clip. - last_frame (torch.Tensor): Batch of input images of shape (N, C, H, W) forming the last frame in the clip. + last_frame (torch.Tensor): Batch of input images of shape ``(N, C, H, W)`` + forming the last frame in the clip. Returns: - list[dict]: List of Mask RCNN predictions for each image in the batch. + list[dict]: List of Mask R-CNN predictions for each image in the batch. Each + dict contains: + - boxes (torch.Tensor): Detected bounding boxes + - labels (torch.Tensor): Class labels for each detection + - scores (torch.Tensor): Confidence scores for each detection + - masks (torch.Tensor): Instance segmentation masks """ with torch.no_grad(): regions = self.backbone(last_frame) @@ -93,21 +132,30 @@ def _add_foreground_boxes( ) -> list[dict[str, torch.Tensor]]: """Add any foreground regions that were not detected by the region extractor. 
-        This method adds regions that likely belong to the foreground of the video scene, but were not detected by the
-        region extractor module. The foreground pixels are determined by taking the pixel difference between two
-        consecutive video frames and applying a binary threshold. The final detections consist of all connected
-        components in the foreground that do not fall in one of the bounding boxes predicted by the region extractor.
+        This method adds regions that likely belong to the foreground of the video
+        scene, but were not detected by the region extractor module. The foreground
+        pixels are determined by taking the pixel difference between two consecutive
+        video frames and applying a binary threshold. The final detections consist of
+        all connected components in the foreground that do not fall in one of the
+        bounding boxes predicted by the region extractor.

         Args:
-            regions (list[dict[str, torch.Tensor]]): Region detections for a batch of images, generated by the region
-                extraction module.
-            first_frame (torch.Tensor): video frame at time t-1
-            last_frame (torch.Tensor): Video frame time t
-            kernel_size (int): Kernel size for Gaussian smoothing applied to input frames
-            binary_threshold (int): Binary threshold used in foreground detection, should be in range [0, 255]
+            regions (list[dict[str, torch.Tensor]]): Region detections for a batch of
+                images, generated by the region extraction module.
+            first_frame (torch.Tensor): Video frame at time ``t-1``.
+            last_frame (torch.Tensor): Video frame at time ``t``.
+            kernel_size (int): Kernel size for Gaussian smoothing applied to input
+                frames.
+            binary_threshold (int): Binary threshold used in foreground detection,
+                should be in range ``[0, 255]``.

         Returns:
-            list[dict[str, torch.Tensor]]: region detections with foreground regions appended
+            list[dict[str, torch.Tensor]]: Region detections with foreground regions
+                appended. Each dict contains:
+                - boxes (torch.Tensor): Updated bounding boxes
+                - labels (torch.Tensor): Updated class labels
+                - scores (torch.Tensor): Updated confidence scores
+                - masks (torch.Tensor): Updated instance masks
         """
         # apply gaussian blur to first and last frame
         first_frame = gaussian_blur(first_frame, [kernel_size, kernel_size])
@@ -157,14 +205,16 @@ def _add_foreground_boxes(
     def post_process_bbox_detections(self, regions: list[dict[str, torch.Tensor]]) -> list[dict[str, torch.Tensor]]:
         """Post-process the region detections.

-        The region detections are filtered based on class label, bbox area and overlap with other regions.
+        The region detections are filtered based on class label, bbox area and overlap
+        with other regions.

         Args:
-            regions (list[dict[str, torch.Tensor]]): Region detections for a batch of images, generated by the region
-                extraction module.
+            regions (list[dict[str, torch.Tensor]]): Region detections for a batch of
+                images, generated by the region extraction module.

         Returns:
-            list[dict[str, torch.Tensor]]: Filtered regions
+            list[dict[str, torch.Tensor]]: Filtered regions containing only valid
+                detections based on the filtering criteria.
         """
         filtered_regions_list = []
         for img_regions in regions:
@@ -175,13 +225,15 @@ def post_process_bbox_detections(self, regions: list[dict[str, torch.Tensor]]) -
         return filtered_regions_list

     def _keep_only_persons(self, regions: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
-        """Remove all region detections that are not labeled as a person by the region extractor.
+        """Remove all region detections that are not labeled as a person.
Args: - regions (dict[str, torch.Tensor]): Region detections for a single image in the batch. + regions (dict[str, torch.Tensor]): Region detections for a single image in + the batch. Returns: - dict[str, torch.Tensor]: Region detections from which non-person objects have been removed. + dict[str, torch.Tensor]: Region detections from which non-person objects + have been removed. """ keep = torch.where(regions["labels"] == PERSON_LABEL) return self.subsample_regions(regions, keep) @@ -190,11 +242,14 @@ def _filter_by_area(self, regions: dict[str, torch.Tensor], min_area: int) -> di """Remove all regions with a surface area smaller than the specified value. Args: - regions (dict[str, torch.Tensor]): Region detections for a single image in the batch. - min_area (int): Minimum bounding box area. Regions with a surface area lower than this value are excluded. + regions (dict[str, torch.Tensor]): Region detections for a single image in + the batch. + min_area (int): Minimum bounding box area. Regions with a surface area + lower than this value are excluded. Returns: - dict[str, torch.Tensor]: Region detections from which small regions have been removed. + dict[str, torch.Tensor]: Region detections from which small regions have + been removed. """ areas = box_area(regions["boxes"]) keep = torch.where(areas > min_area) @@ -203,16 +258,20 @@ def _filter_by_area(self, regions: dict[str, torch.Tensor], min_area: int) -> di def _delete_overlapping_boxes(self, regions: dict[str, torch.Tensor], threshold: float) -> dict[str, torch.Tensor]: """Delete overlapping bounding boxes. - For each bounding box, the overlap with all other bounding boxes relative to their own surface area is computed. - When the relative overlap with any other box is higher than the specified threshold, the box is removed. when - both boxes have a relative overlap higher than the threshold, only the smaller box is removed. + For each bounding box, the overlap with all other bounding boxes relative to + their own surface area is computed. When the relative overlap with any other + box is higher than the specified threshold, the box is removed. When both boxes + have a relative overlap higher than the threshold, only the smaller box is + removed. Args: - regions (dict[str, torch.Tensor]): Region detections for a single image in the batch. + regions (dict[str, torch.Tensor]): Region detections for a single image in + the batch. threshold (float): Maximum allowed overlap between bounding boxes. Returns: - dict[str, torch.Tensor]: Region detections from which overlapping regions have been removed. + dict[str, torch.Tensor]: Region detections from which overlapping regions + have been removed. """ # sort boxes by area areas = box_area(regions["boxes"]) @@ -240,11 +299,13 @@ def subsample_regions(regions: dict[str, torch.Tensor], indices: torch.Tensor) - """Subsample the items in a region dictionary based on a Tensor of indices. Args: - regions (dict[str, torch.Tensor]): Region detections for a single image in the batch. + regions (dict[str, torch.Tensor]): Region detections for a single image in + the batch. indices (torch.Tensor): Indices of region detections that should be kept. Returns: - dict[str, torch.Tensor]: Subsampled region detections. + dict[str, torch.Tensor]: Subsampled region detections containing only the + specified indices. 
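+
+        Example:
+            A minimal sketch with dummy tensors (the values below are
+            illustrative only):
+
+            >>> import torch
+            >>> regions = {
+            ...     "boxes": torch.rand(4, 4),
+            ...     "labels": torch.zeros(4, dtype=torch.int64),
+            ...     "scores": torch.rand(4),
+            ... }
+            >>> keep = torch.tensor([0, 2])  # keep the first and third detections
+            >>> subsampled = RegionExtractor.subsample_regions(regions, keep)
+            >>> subsampled["boxes"].shape
+            torch.Size([2, 4])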
""" new_regions_dict = {} for key, value in regions.items(): diff --git a/src/anomalib/models/video/ai_vad/torch_model.py b/src/anomalib/models/video/ai_vad/torch_model.py index 2679470d01..dfe3e563f6 100644 --- a/src/anomalib/models/video/ai_vad/torch_model.py +++ b/src/anomalib/models/video/ai_vad/torch_model.py @@ -1,6 +1,31 @@ """PyTorch model for AI-VAD model implementation. -Paper https://arxiv.org/pdf/2212.00789.pdf +This module implements the AI-VAD model as described in the paper +"AI-VAD: Attribute-based Representations for Accurate and Interpretable Video +Anomaly Detection." + +Example: + >>> from anomalib.models.video import AiVad + >>> from anomalib.data import Avenue + >>> from anomalib.data.utils import VideoTargetFrame + >>> from anomalib.engine import Engine + + >>> # Initialize model and datamodule + >>> datamodule = Avenue( + ... clip_length_in_frames=2, + ... frames_between_clips=1, + ... target_frame=VideoTargetFrame.LAST + ... ) + >>> model = AiVad() + + >>> # Train using the engine + >>> engine = Engine() + >>> engine.fit(model=model, datamodule=datamodule) + +Reference: + Tal Reiss, Yedid Hoshen. "AI-VAD: Attribute-based Representations for Accurate and + Interpretable Video Anomaly Detection." arXiv preprint arXiv:2212.00789 (2022). + https://arxiv.org/pdf/2212.00789.pdf """ # Copyright (C) 2023-2024 Intel Corporation @@ -20,37 +45,55 @@ class AiVadModel(nn.Module): """AI-VAD model. + The model consists of several stages: + 1. Flow extraction between consecutive frames + 2. Region extraction using object detection and foreground detection + 3. Feature extraction including velocity, pose and deep features + 4. Density estimation for anomaly detection + Args: - box_score_thresh (float): Confidence threshold for region extraction stage. - Defaults to ``0.8``. - persons_only (bool): When enabled, only regions labeled as person are included. - Defaults to ``False``. - min_bbox_area (int): Minimum bounding box area. Regions with a surface area lower than this value are excluded. - Defaults to ``100``. - max_bbox_overlap (float): Maximum allowed overlap between bounding boxes. - Defaults to ``0.65``. - enable_foreground_detections (bool): Add additional foreground detections based on pixel difference between - consecutive frames. - Defaults to ``True``. - foreground_kernel_size (int): Gaussian kernel size used in foreground detection. - Defaults to ``3``. - foreground_binary_threshold (int): Value between 0 and 255 which acts as binary threshold in foreground - detection. - Defaults to ``18``. - n_velocity_bins (int): Number of discrete bins used for velocity histogram features. - Defaults to ``8``. - use_velocity_features (bool): Flag indicating if velocity features should be used. - Defaults to ``True``. - use_pose_features (bool): Flag indicating if pose features should be used. + box_score_thresh (float, optional): Confidence threshold for region extraction + stage. Defaults to ``0.8``. + persons_only (bool, optional): When enabled, only regions labeled as person are + included. Defaults to ``False``. + min_bbox_area (int, optional): Minimum bounding box area. Regions with a surface + area lower than this value are excluded. Defaults to ``100``. + max_bbox_overlap (float, optional): Maximum allowed overlap between bounding + boxes. Defaults to ``0.65``. + enable_foreground_detections (bool, optional): Add additional foreground + detections based on pixel difference between consecutive frames. Defaults to ``True``. 
- use_deep_features (bool): Flag indicating if deep features should be used. - Defaults to ``True``. - n_components_velocity (int): Number of components used by GMM density estimation for velocity features. - Defaults to ``5``. - n_neighbors_pose (int): Number of neighbors used in KNN density estimation for pose features. - Defaults to ``1``. - n_neighbors_deep (int): Number of neighbors used in KNN density estimation for deep features. - Defaults to ``1``. + foreground_kernel_size (int, optional): Gaussian kernel size used in foreground + detection. Defaults to ``3``. + foreground_binary_threshold (int, optional): Value between 0 and 255 which acts + as binary threshold in foreground detection. Defaults to ``18``. + n_velocity_bins (int, optional): Number of discrete bins used for velocity + histogram features. Defaults to ``8``. + use_velocity_features (bool, optional): Flag indicating if velocity features + should be used. Defaults to ``True``. + use_pose_features (bool, optional): Flag indicating if pose features should be + used. Defaults to ``True``. + use_deep_features (bool, optional): Flag indicating if deep features should be + used. Defaults to ``True``. + n_components_velocity (int, optional): Number of components used by GMM density + estimation for velocity features. Defaults to ``5``. + n_neighbors_pose (int, optional): Number of neighbors used in KNN density + estimation for pose features. Defaults to ``1``. + n_neighbors_deep (int, optional): Number of neighbors used in KNN density + estimation for deep features. Defaults to ``1``. + + Raises: + ValueError: If none of the feature types (velocity, pose, deep) are enabled. + + Example: + >>> from anomalib.models.video.ai_vad.torch_model import AiVadModel + >>> model = AiVadModel() + >>> batch = torch.randn(32, 2, 3, 256, 256) # (N, L, C, H, W) + >>> output = model(batch) + >>> output.pred_score.shape + torch.Size([32]) + >>> output.anomaly_map.shape + torch.Size([32, 256, 256]) """ def __init__( @@ -110,13 +153,31 @@ def __init__( def forward(self, batch: torch.Tensor) -> InferenceBatch: """Forward pass through AI-VAD model. + The forward pass consists of the following steps: + 1. Extract first and last frame from input clip + 2. Extract optical flow between frames and detect regions of interest + 3. Extract features (velocity, pose, deep) for each region + 4. Estimate density and compute anomaly scores + Args: - batch (torch.Tensor): Input image of shape (N, L, C, H, W) + batch (torch.Tensor): Input tensor of shape ``(N, L, C, H, W)`` where: + - ``N``: Batch size + - ``L``: Sequence length + - ``C``: Number of channels + - ``H``: Height + - ``W``: Width Returns: - list[torch.Tensor]: List of bbox locations for each image. - list[torch.Tensor]: List of per-bbox anomaly scores for each image. - list[torch.Tensor]: List of per-image anomaly scores. 
+ InferenceBatch: Batch containing: + - ``pred_score``: Per-image anomaly scores of shape ``(N,)`` + - ``anomaly_map``: Per-pixel anomaly scores of shape ``(N, H, W)`` + + Example: + >>> batch = torch.randn(32, 2, 3, 256, 256) + >>> model = AiVadModel() + >>> output = model(batch) + >>> output.pred_score.shape, output.anomaly_map.shape + (torch.Size([32]), torch.Size([32, 256, 256])) """ self.flow_extractor.eval() self.region_extractor.eval() diff --git a/src/anomalib/pipelines/__init__.py b/src/anomalib/pipelines/__init__.py index 0ca537d4de..3612aed388 100644 --- a/src/anomalib/pipelines/__init__.py +++ b/src/anomalib/pipelines/__init__.py @@ -1,4 +1,27 @@ -"""Pipelines for end-to-end usecases.""" +"""Pipelines for end-to-end anomaly detection use cases. + +This module provides high-level pipeline implementations for common anomaly detection +workflows: + +- :class:`Benchmark`: Pipeline for benchmarking model performance across datasets + +The pipelines handle: + - Configuration and setup + - Data loading and preprocessing + - Model training and evaluation + - Result collection and analysis + - Logging and visualization + +Example: + >>> from anomalib.pipelines import Benchmark + >>> benchmark = Benchmark(config_path="config.yaml") + >>> results = benchmark.run() + +The pipelines leverage components from :mod:`anomalib.pipelines.components` for: + - Job management and execution + - Parameter grid search + - Result gathering +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/pipelines/benchmark/__init__.py b/src/anomalib/pipelines/benchmark/__init__.py index bfb34aded2..759ba32276 100644 --- a/src/anomalib/pipelines/benchmark/__init__.py +++ b/src/anomalib/pipelines/benchmark/__init__.py @@ -1,4 +1,23 @@ -"""Benchmarking.""" +"""Benchmarking pipeline for anomaly detection models. + +This module provides functionality for benchmarking anomaly detection models in +anomalib. The benchmarking pipeline allows evaluating and comparing multiple models +across different datasets and metrics. + +Example: + >>> from anomalib.pipelines import Benchmark + >>> from anomalib.data import MVTec + >>> from anomalib.models import Padim, Patchcore + + >>> # Initialize benchmark with models and datasets + >>> benchmark = Benchmark( + ... models=[Padim(), Patchcore()], + ... datasets=[MVTec(category="bottle"), MVTec(category="cable")] + ... ) + + >>> # Run benchmark + >>> results = benchmark.run() +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/pipelines/benchmark/generator.py b/src/anomalib/pipelines/benchmark/generator.py index 988e0111b7..2da6f93dfd 100644 --- a/src/anomalib/pipelines/benchmark/generator.py +++ b/src/anomalib/pipelines/benchmark/generator.py @@ -1,4 +1,22 @@ -"""Benchmark job generator.""" +"""Benchmark job generator for running model benchmarking experiments. + +This module provides functionality for generating benchmark jobs that evaluate model +performance. It generates jobs based on provided configurations for models, +datasets and other parameters. + +Example: + >>> from anomalib.pipelines.benchmark.generator import BenchmarkJobGenerator + >>> generator = BenchmarkJobGenerator(accelerator="gpu") + >>> args = { + ... "seed": 42, + ... "model": {"class_path": "Padim"}, + ... "data": {"class_path": "MVTec", "init_args": {"category": "bottle"}} + ... 
} + >>> jobs = list(generator.generate_jobs(args, None)) + +The generator creates :class:`BenchmarkJob` instances that can be executed to run +benchmarking experiments with specified models and datasets. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -17,10 +35,25 @@ class BenchmarkJobGenerator(JobGenerator): - """Generate BenchmarkJob. + """Generate benchmark jobs for evaluating model performance. + + This class generates benchmark jobs based on provided configurations for models, + datasets and other parameters. Each job evaluates a specific model-dataset + combination. Args: - accelerator (str): The accelerator to use. + accelerator (str): Type of accelerator to use for running the jobs (e.g. + ``"cpu"``, ``"gpu"``). + + Example: + >>> from anomalib.pipelines.benchmark.generator import BenchmarkJobGenerator + >>> generator = BenchmarkJobGenerator(accelerator="gpu") + >>> args = { + ... "seed": 42, + ... "model": {"class_path": "Padim"}, + ... "data": {"class_path": "MVTec", "init_args": {"category": "bottle"}} + ... } + >>> jobs = list(generator.generate_jobs(args, None)) """ def __init__(self, accelerator: str) -> None: @@ -28,7 +61,11 @@ def __init__(self, accelerator: str) -> None: @property def job_class(self) -> type: - """Return the job class.""" + """Get the job class used by this generator. + + Returns: + type: The :class:`BenchmarkJob` class. + """ return BenchmarkJob @hide_output @@ -37,7 +74,27 @@ def generate_jobs( args: dict, previous_stage_result: PREV_STAGE_RESULT, ) -> Generator[BenchmarkJob, None, None]: - """Return iterator based on the arguments.""" + """Generate benchmark jobs from the provided arguments. + + Args: + args (dict): Dictionary containing job configuration including model, + dataset and other parameters. + previous_stage_result (PREV_STAGE_RESULT): Results from previous pipeline + stage (unused). + + Yields: + Generator[BenchmarkJob, None, None]: Generator yielding benchmark job + instances. + + Example: + >>> generator = BenchmarkJobGenerator(accelerator="cpu") + >>> args = { + ... "seed": 42, + ... "model": {"class_path": "Padim"}, + ... "data": {"class_path": "MVTec"} + ... } + >>> jobs = list(generator.generate_jobs(args, None)) + """ del previous_stage_result # Not needed for this job for _container in get_iterator_from_grid_dict(args): # Pass experimental configs as a flatten dictionary to the job runner. diff --git a/src/anomalib/pipelines/benchmark/job.py b/src/anomalib/pipelines/benchmark/job.py index d98b689304..dccacf77e7 100644 --- a/src/anomalib/pipelines/benchmark/job.py +++ b/src/anomalib/pipelines/benchmark/job.py @@ -1,4 +1,32 @@ -"""Benchmarking job.""" +"""Benchmarking job for evaluating model performance. + +This module provides functionality for running individual benchmarking jobs that +evaluate model performance on specific datasets. Each job runs a model on a dataset +and collects performance metrics. + +Example: + >>> from anomalib.data import MVTec + >>> from anomalib.models import Padim + >>> from anomalib.pipelines.benchmark.job import BenchmarkJob + + >>> # Initialize model, datamodule and job + >>> model = Padim() + >>> datamodule = MVTec(category="bottle") + >>> job = BenchmarkJob( + ... accelerator="gpu", + ... model=model, + ... datamodule=datamodule, + ... seed=42, + ... flat_cfg={"model.name": "padim"} + ... ) + + >>> # Run the benchmark job + >>> results = job.run() + +The job executes model training and evaluation, collecting metrics like accuracy, +F1-score, and inference time. 
Results are returned in a standardized format for +comparison across different model-dataset combinations. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -25,14 +53,42 @@ class BenchmarkJob(Job): - """Benchmarking job. + """Benchmarking job for evaluating anomaly detection models. + + This class implements a benchmarking job that evaluates model performance by + training and testing on a given dataset. It collects metrics like accuracy, + F1-score, and timing information. Args: - accelerator (str): The accelerator to use. - model (AnomalibModule): The model to use. - datamodule (AnomalibDataModule): The data module to use. - seed (int): The seed to use. - flat_cfg (dict): The flat dictionary of configs with dotted keys. + accelerator (str): Type of accelerator to use for computation (e.g. + ``"cpu"``, ``"gpu"``). + model (AnomalibModule): Anomaly detection model instance to benchmark. + datamodule (AnomalibDataModule): Data module providing the dataset. + seed (int): Random seed for reproducibility. + flat_cfg (dict): Flattened configuration dictionary with dotted keys. + + Example: + >>> from anomalib.data import MVTec + >>> from anomalib.models import Padim + >>> from anomalib.pipelines.benchmark.job import BenchmarkJob + + >>> # Initialize model, datamodule and job + >>> model = Padim() + >>> datamodule = MVTec(category="bottle") + >>> job = BenchmarkJob( + ... accelerator="gpu", + ... model=model, + ... datamodule=datamodule, + ... seed=42, + ... flat_cfg={"model.name": "padim"} + ... ) + + >>> # Run the benchmark job + >>> results = job.run() + + The job executes model training and evaluation, collecting metrics like + accuracy, F1-score, and inference time. Results are returned in a standardized + format for comparison across different model-dataset combinations. """ name = "benchmark" @@ -57,7 +113,23 @@ def run( self, task_id: int | None = None, ) -> dict[str, Any]: - """Run the benchmark.""" + """Run the benchmark job. + + This method executes the full benchmarking pipeline including model + training and testing. It measures execution time for different stages and + collects performance metrics. + + Args: + task_id (int | None, optional): ID of the task when running in + distributed mode. When provided, the job will use the specified + device. Defaults to ``None``. + + Returns: + dict[str, Any]: Dictionary containing benchmark results including: + - Timing information (job, fit and test duration) + - Model configuration + - Performance metrics from testing + """ job_start_time = time.time() devices: str | list[int] = "auto" if task_id is not None: @@ -93,7 +165,16 @@ def run( @staticmethod def collect(results: list[dict[str, Any]]) -> pd.DataFrame: - """Gather the results returned from run.""" + """Collect and aggregate results from multiple benchmark runs. + + Args: + results (list[dict[str, Any]]): List of result dictionaries from + individual benchmark runs. + + Returns: + pd.DataFrame: DataFrame containing aggregated results with each row + representing a benchmark run. + """ output: dict[str, Any] = {} for key in results[0]: output[key] = [] @@ -104,7 +185,14 @@ def collect(results: list[dict[str, Any]]) -> pd.DataFrame: @staticmethod def save(result: pd.DataFrame) -> None: - """Save the result to a csv file.""" + """Save benchmark results to CSV file. + + The results are saved in the ``runs/benchmark/YYYY-MM-DD-HH_MM_SS`` + directory. The method also prints a tabular view of the results. 
+ + Args: + result (pd.DataFrame): DataFrame containing benchmark results to save. + """ BenchmarkJob._print_tabular_results(result) file_path = Path("runs") / BenchmarkJob.name / datetime.now().strftime("%Y-%m-%d-%H_%M_%S") / "results.csv" file_path.parent.mkdir(parents=True, exist_ok=True) @@ -113,7 +201,12 @@ def save(result: pd.DataFrame) -> None: @staticmethod def _print_tabular_results(gathered_result: pd.DataFrame) -> None: - """Print the tabular results.""" + """Print benchmark results in a formatted table. + + Args: + gathered_result (pd.DataFrame): DataFrame containing results to + display. + """ if gathered_result is not None: console = Console() table = Table(title=f"{BenchmarkJob.name} Results", show_header=True, header_style="bold magenta") diff --git a/src/anomalib/pipelines/benchmark/pipeline.py b/src/anomalib/pipelines/benchmark/pipeline.py index 3b27caeec1..9e31c4e043 100644 --- a/src/anomalib/pipelines/benchmark/pipeline.py +++ b/src/anomalib/pipelines/benchmark/pipeline.py @@ -1,4 +1,27 @@ -"""Benchmarking.""" +"""Benchmarking pipeline for evaluating anomaly detection models. + +This module provides functionality for running benchmarking experiments that evaluate +and compare multiple anomaly detection models. The benchmarking pipeline supports +running experiments in parallel across multiple GPUs when available. + +Example: + >>> from anomalib.pipelines import Benchmark + >>> from anomalib.data import MVTec + >>> from anomalib.models import Padim, Patchcore + + >>> # Initialize benchmark with models and datasets + >>> benchmark = Benchmark( + ... models=[Padim(), Patchcore()], + ... datasets=[MVTec(category="bottle"), MVTec(category="cable")] + ... ) + + >>> # Run benchmark + >>> results = benchmark.run() + +The pipeline handles setting up appropriate runners based on available hardware, +using parallel execution when multiple GPUs are available and serial execution +otherwise. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -12,11 +35,51 @@ class Benchmark(Pipeline): - """Benchmarking pipeline.""" + """Benchmarking pipeline for evaluating anomaly detection models. + + This pipeline handles running benchmarking experiments that evaluate and compare + multiple anomaly detection models. It supports both serial and parallel execution + depending on available hardware. + + Example: + >>> from anomalib.pipelines import Benchmark + >>> from anomalib.data import MVTec + >>> from anomalib.models import Padim, Patchcore + + >>> # Initialize benchmark with models and datasets + >>> benchmark = Benchmark( + ... models=[Padim(), Patchcore()], + ... datasets=[MVTec(category="bottle"), MVTec(category="cable")] + ... ) + + >>> # Run benchmark + >>> results = benchmark.run() + """ @staticmethod def _setup_runners(args: dict) -> list[Runner]: - """Setup the runners for the pipeline.""" + """Set up the appropriate runners for benchmark execution. + + This method configures either serial or parallel runners based on the + specified accelerator(s) and available hardware. For CUDA devices, parallel + execution is used when multiple GPUs are available. + + Args: + args (dict): Dictionary containing configuration arguments. Must include + an ``"accelerator"`` key specifying either a single accelerator or + list of accelerators to use. + + Returns: + list[Runner]: List of configured runner instances. + + Raises: + ValueError: If an unsupported accelerator type is specified. Only + ``"cpu"`` and ``"cuda"`` are supported. 
+ + Example: + >>> args = {"accelerator": "cuda"} + >>> runners = Benchmark._setup_runners(args) + """ accelerators = args["accelerator"] if isinstance(args["accelerator"], list) else [args["accelerator"]] runners: list[Runner] = [] for accelerator in accelerators: diff --git a/src/anomalib/pipelines/components/__init__.py b/src/anomalib/pipelines/components/__init__.py index 1350937639..e831487797 100644 --- a/src/anomalib/pipelines/components/__init__.py +++ b/src/anomalib/pipelines/components/__init__.py @@ -1,4 +1,25 @@ -"""Utilities for the pipeline modules.""" +"""Components for building and executing pipelines. + +This module provides core components for constructing and running data processing +pipelines: + +- :class:`Job`: Base class for defining pipeline jobs +- :class:`JobGenerator`: Creates job instances for pipeline stages +- :class:`Pipeline`: Manages execution flow between pipeline stages +- :class:`Runner`: Executes jobs serially or in parallel + +Example: + >>> from anomalib.pipelines.components import Pipeline, JobGenerator + >>> generator = JobGenerator() + >>> pipeline = Pipeline([generator]) + >>> pipeline.run({"param": "value"}) + +The components handle: + - Job creation and configuration + - Pipeline stage organization + - Job execution and result gathering + - Error handling and logging +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/pipelines/components/base/__init__.py b/src/anomalib/pipelines/components/base/__init__.py index 90682e9cd0..4d1ec79baa 100644 --- a/src/anomalib/pipelines/components/base/__init__.py +++ b/src/anomalib/pipelines/components/base/__init__.py @@ -1,4 +1,28 @@ -"""Base classes for pipelines.""" +"""Base classes for pipeline components in anomalib. + +This module provides the core base classes used to build pipelines in anomalib: + +- :class:`Job`: Base class for individual pipeline jobs +- :class:`JobGenerator`: Base class for generating pipeline jobs +- :class:`Runner`: Base class for executing pipeline jobs +- :class:`Pipeline`: Base class for creating complete pipelines + +Example: + >>> from anomalib.pipelines.components.base import Pipeline + >>> from anomalib.pipelines.components.base import Runner + >>> from anomalib.pipelines.components.base import Job, JobGenerator + + >>> # Create custom pipeline components + >>> class MyJob(Job): + ... pass + >>> class MyRunner(Runner): + ... pass + >>> class MyPipeline(Pipeline): + ... pass + +The base classes provide the foundation for building modular and extensible +pipelines for tasks like training, inference and benchmarking. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/pipelines/components/base/job.py b/src/anomalib/pipelines/components/base/job.py index f10278d0f1..bdd69521e2 100644 --- a/src/anomalib/pipelines/components/base/job.py +++ b/src/anomalib/pipelines/components/base/job.py @@ -1,4 +1,34 @@ -"""Job from which all the jobs inherit from.""" +"""Base job class that defines the interface for pipeline jobs. + +This module provides the base :class:`Job` class that all pipeline jobs inherit from. Jobs +are atomic units of work that can be executed independently, either serially or in +parallel. + +Example: + >>> from anomalib.pipelines.components.base import Job + >>> class MyJob(Job): + ... name = "my_job" + ... def run(self, task_id=None): + ... # Implement job logic + ... pass + ... @staticmethod + ... def collect(results): + ... 
# Combine results from multiple runs + ... pass + ... @staticmethod + ... def save(results): + ... # Save final results + ... pass + +The base job interface defines three key methods that subclasses must implement: + +- :meth:`run`: Execute the core job logic +- :meth:`collect`: Gather and combine results from multiple job runs +- :meth:`save`: Save or export the final collected results + +Jobs can be used as building blocks in pipelines for tasks like training, +inference, or benchmarking. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/pipelines/components/base/pipeline.py b/src/anomalib/pipelines/components/base/pipeline.py index 850c64afcb..8203f79b99 100644 --- a/src/anomalib/pipelines/components/base/pipeline.py +++ b/src/anomalib/pipelines/components/base/pipeline.py @@ -1,4 +1,27 @@ -"""Base class for pipeline.""" +"""Base class for building pipelines in anomalib. + +This module provides the abstract base class for creating pipelines that can execute +jobs in a configurable way. Pipelines handle setting up runners, parsing configs, +and orchestrating job execution. + +Example: + >>> from anomalib.pipelines.components.base import Pipeline + >>> class MyPipeline(Pipeline): + ... def _setup_runners(self, args: dict) -> list[Runner]: + ... # Configure and return list of runners + ... pass + ... def run(self, args: Namespace | None = None): + ... # Execute pipeline logic + ... pass + +The base pipeline interface defines key methods that subclasses must implement: + +- :meth:`_setup_runners`: Configure the runners that will execute pipeline jobs +- :meth:`run`: Execute the core pipeline logic + +Pipelines can be used to implement workflows like training, inference, or +benchmarking by composing jobs and runners in a modular way. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/pipelines/components/base/runner.py b/src/anomalib/pipelines/components/base/runner.py index cee46dfacb..86aa7a4222 100644 --- a/src/anomalib/pipelines/components/base/runner.py +++ b/src/anomalib/pipelines/components/base/runner.py @@ -1,4 +1,31 @@ -"""Base runner.""" +"""Base runner class for executing pipeline jobs. + +This module provides the abstract base class for runners that execute pipeline jobs. +Runners handle the mechanics of job execution, whether serial or parallel. + +Example: + >>> from anomalib.pipelines.components.base import Runner + >>> from anomalib.pipelines.components.base import JobGenerator + >>> class MyRunner(Runner): + ... def run(self, args: dict, prev_stage_results=None): + ... # Implement runner logic + ... pass + + >>> # Create and use runner + >>> generator = JobGenerator() + >>> runner = MyRunner(generator) + >>> results = runner.run({"param": "value"}) + +The base runner interface defines the core :meth:`run` method that subclasses must +implement to execute jobs. Runners work with job generators to create and execute +pipeline jobs. 
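+
+As a rough sketch of that contract (the ``EagerRunner`` name is hypothetical and
+for illustration only), a minimal serial implementation could simply drain the
+generator and collect the results:
+
+>>> class EagerRunner(Runner):
+...     def run(self, args, prev_stage_results=None):
+...         # Generate all jobs for this stage, run them one by one,
+...         # then combine their outputs via the job's collect() hook.
+...         jobs = self.generator.generate_jobs(args, prev_stage_results)
+...         results = [job.run() for job in jobs]
+...         return self.generator.job_class.collect(results)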
+ +Runners can implement different execution strategies like: + +- Serial execution of jobs one after another +- Parallel execution across multiple processes +- Distributed execution across machines +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/pipelines/components/runners/__init__.py b/src/anomalib/pipelines/components/runners/__init__.py index 27ef21046f..1527244ac1 100644 --- a/src/anomalib/pipelines/components/runners/__init__.py +++ b/src/anomalib/pipelines/components/runners/__init__.py @@ -1,4 +1,22 @@ -"""Executor for running a single job.""" +"""Runners for executing pipeline jobs. + +This module provides runner implementations for executing pipeline jobs in different +ways: + +- :class:`SerialRunner`: Executes jobs sequentially on a single device +- :class:`ParallelRunner`: Executes jobs in parallel across multiple devices + +Example: + >>> from anomalib.pipelines.components.runners import SerialRunner + >>> from anomalib.pipelines.components.base import JobGenerator + >>> generator = JobGenerator() + >>> runner = SerialRunner(generator) + >>> results = runner.run({"param": "value"}) + +The runners handle the mechanics of job execution while working with job generators +to create and execute pipeline jobs. They implement the :class:`Runner` interface +defined in ``anomalib.pipelines.components.base``. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/pipelines/components/runners/parallel.py b/src/anomalib/pipelines/components/runners/parallel.py index 148980a6c2..4064edf71c 100644 --- a/src/anomalib/pipelines/components/runners/parallel.py +++ b/src/anomalib/pipelines/components/runners/parallel.py @@ -1,4 +1,26 @@ -"""Process pool executor.""" +"""Parallel execution of pipeline jobs using process pools. + +This module provides the :class:`ParallelRunner` class for executing pipeline jobs in +parallel across multiple processes. It uses Python's :class:`ProcessPoolExecutor` to +manage a pool of worker processes. + +Example: + >>> from anomalib.pipelines.components.runners import ParallelRunner + >>> from anomalib.pipelines.components.base import JobGenerator + >>> generator = JobGenerator() + >>> runner = ParallelRunner(generator, n_jobs=4) + >>> results = runner.run({"param": "value"}) + +The parallel runner handles: + +- Creating and managing a pool of worker processes +- Distributing jobs across available workers +- Collecting and combining results from parallel executions +- Error handling for failed jobs + +The number of parallel jobs can be configured based on available compute resources +like CPU cores or GPUs. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -22,26 +44,42 @@ class ParallelExecutionError(Exception): class ParallelRunner(Runner): - """Run the job in parallel using a process pool. + """Run jobs in parallel using a process pool. - It creates a pool of processes and submits the jobs to the pool. - This is useful when you have fixed resources that you want to re-use. - Once a process is done, it is replaced with a new job. + This runner executes jobs concurrently using a pool of worker processes. It manages + process creation, job distribution, and result collection. Args: - generator (JobGenerator): The generator that generates the jobs. - n_jobs (int): The number of jobs to run in parallel. + generator (JobGenerator): Generator that creates jobs to be executed. 
+ n_jobs (int): Number of parallel processes to use. Example: - Creating a pool with the size of the number of available GPUs and submitting jobs to the pool. - >>> ParallelRunner(generator, n_jobs=torch.cuda.device_count()) - Each time a job is submitted to the pool, an additional parameter `task_id` will be passed to `job.run` method. - The job can then use this `task_id` to assign a particular device to train on. - >>> def run(self, arg1: int, arg2: nn.Module, task_id: int) -> None: - >>> device = torch.device(f"cuda:{task_id}") - >>> model = arg2.to(device) - >>> ... - + Create a pool with size matching available GPUs and submit jobs: + + >>> from anomalib.pipelines.components.runners import ParallelRunner + >>> from anomalib.pipelines.components.base import JobGenerator + >>> import torch + >>> generator = JobGenerator() + >>> runner = ParallelRunner(generator, n_jobs=torch.cuda.device_count()) + >>> results = runner.run({"param": "value"}) + + Notes: + When a job is submitted to the pool, a ``task_id`` parameter is passed to the + job's ``run()`` method. Jobs can use this ID to manage device assignment: + + .. code-block:: python + + def run(self, arg1: int, arg2: nn.Module, task_id: int) -> None: + device = torch.device(f"cuda:{task_id}") + model = arg2.to(device) + # ... rest of job logic + + The runner handles: + - Creating and managing worker processes + - Distributing jobs to available workers + - Collecting and combining results + - Error handling for failed jobs + - Resource cleanup """ def __init__(self, generator: JobGenerator, n_jobs: int) -> None: diff --git a/src/anomalib/pipelines/components/runners/serial.py b/src/anomalib/pipelines/components/runners/serial.py index 86cc3533ea..3caa274660 100644 --- a/src/anomalib/pipelines/components/runners/serial.py +++ b/src/anomalib/pipelines/components/runners/serial.py @@ -1,4 +1,32 @@ -"""Executor for running a job serially.""" +"""Serial execution of pipeline jobs. + +This module provides the :class:`SerialRunner` class for executing pipeline jobs +sequentially on a single device. It processes jobs one at a time in order. + +Example: + >>> from anomalib.pipelines.components.runners import SerialRunner + >>> from anomalib.pipelines.components.base import JobGenerator + >>> generator = JobGenerator() + >>> runner = SerialRunner(generator) + >>> results = runner.run({"param": "value"}) + +The serial runner handles: + +- Sequential execution of jobs in order +- Progress tracking with progress bars +- Result collection and combination +- Error handling for failed jobs + +This is useful when: + +- Resources are limited to a single device +- Jobs need to be executed in a specific order +- Debugging pipeline execution +- Simple workflows that don't require parallelization + +The runner implements the :class:`Runner` interface defined in +``anomalib.pipelines.components.base``. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -18,13 +46,67 @@ class SerialExecutionError(Exception): class SerialRunner(Runner): - """Serial executor for running a single job at a time.""" + """Serial executor for running jobs sequentially. + + This runner executes jobs one at a time in a sequential manner. It provides progress + tracking and error handling while running jobs serially. + + Args: + generator (JobGenerator): Generator that creates jobs to be executed. 
+
+    Example:
+        Create a runner and execute jobs sequentially:
+
+        >>> from anomalib.pipelines.components.runners import SerialRunner
+        >>> from anomalib.pipelines.components.base import JobGenerator
+        >>> generator = JobGenerator()
+        >>> runner = SerialRunner(generator)
+        >>> results = runner.run({"param": "value"})
+
+    The runner handles:
+    - Sequential execution of jobs
+    - Progress tracking with progress bars
+    - Result collection and combination
+    - Error handling for failed jobs
+    """

     def __init__(self, generator: JobGenerator) -> None:
         super().__init__(generator)

     def run(self, args: dict, prev_stage_results: PREV_STAGE_RESULT = None) -> GATHERED_RESULTS:
-        """Run the job."""
+        """Execute jobs sequentially and gather results.
+
+        This method runs each job one at a time, collecting results and handling any
+        failures that occur during execution.
+
+        Args:
+            args (dict): Arguments specific to the job. For example, if there is a
+                pipeline defined where one of the job generators is hyperparameter
+                optimization, then the pipeline configuration file will look something
+                like:
+
+                .. code-block:: yaml
+
+                    arg1:
+                    arg2:
+                    hpo:
+                        param1:
+                        param2:
+                        ...
+
+                In this case, ``args`` will receive a dictionary with all keys under
+                ``hpo``.
+
+            prev_stage_results (PREV_STAGE_RESULT, optional): Results from the previous
+                pipeline stage. Used when the current stage depends on previous results.
+                Defaults to None.
+
+        Returns:
+            GATHERED_RESULTS: Combined results from all executed jobs.
+
+        Raises:
+            SerialExecutionError: If any job fails during execution.
+        """
         results = []
         failures = False
         logger.info(f"Running job {self.generator.job_class.name}")
diff --git a/src/anomalib/pipelines/components/utils/__init__.py b/src/anomalib/pipelines/components/utils/__init__.py
index 230edc6891..85f293bbbe 100644
--- a/src/anomalib/pipelines/components/utils/__init__.py
+++ b/src/anomalib/pipelines/components/utils/__init__.py
@@ -1,4 +1,22 @@
-"""Utils."""
+"""Utility functions for pipeline components.
+
+This module provides utility functions used by various pipeline components for tasks
+like:
+
+- Grid search parameter iteration via :func:`get_iterator_from_grid_dict`
+- Other utility functions for pipeline execution
+
+Example:
+    >>> from anomalib.pipelines.components.utils import get_iterator_from_grid_dict
+    >>> params = {"lr": {"grid": [0.1, 0.01]}, "batch_size": {"grid": [32, 64]}}
+    >>> iterator = get_iterator_from_grid_dict(params)
+    >>> for config in iterator:
+    ...     print(config)
+    {'lr': 0.1, 'batch_size': 32}
+    {'lr': 0.1, 'batch_size': 64}
+    {'lr': 0.01, 'batch_size': 32}
+    {'lr': 0.01, 'batch_size': 64}
+"""

 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
diff --git a/src/anomalib/pipelines/components/utils/grid_search.py b/src/anomalib/pipelines/components/utils/grid_search.py
index 04e481ca6a..240d12829a 100644
--- a/src/anomalib/pipelines/components/utils/grid_search.py
+++ b/src/anomalib/pipelines/components/utils/grid_search.py
@@ -1,4 +1,30 @@
-"""Utils for benchmarking."""
+"""Utilities for grid search parameter iteration.
+
+This module provides utilities for iterating over grid search parameter combinations
+in a structured way. The main function :func:`get_iterator_from_grid_dict` takes a
+dictionary of parameters and yields all possible combinations.
+
+Example:
+    >>> from anomalib.pipelines.components.utils import get_iterator_from_grid_dict
+    >>> params = {
+    ...     "model": {
+    ...         "backbone": {"grid": ["resnet18", "resnet50"]},
+    ...         "lr": {"grid": [0.001, 0.0001]}
+    ...     }
+    ... }
+    >>> for config in get_iterator_from_grid_dict(params):
+    ...     print(config)
+    {'model': {'backbone': 'resnet18', 'lr': 0.001}}
+    {'model': {'backbone': 'resnet18', 'lr': 0.0001}}
+    {'model': {'backbone': 'resnet50', 'lr': 0.001}}
+    {'model': {'backbone': 'resnet50', 'lr': 0.0001}}
+
+The module handles:
+    - Flattening nested parameter dictionaries
+    - Generating all combinations of grid parameters
+    - Reconstructing nested dictionary structure
+    - Preserving non-grid parameters
+"""

 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -7,11 +33,7 @@
 from itertools import product
 from typing import Any

-from anomalib.utils.config import (
-    convert_valuesview_to_tuple,
-    flatten_dict,
-    to_nested_dict,
-)
+from anomalib.utils.config import convert_valuesview_to_tuple, flatten_dict, to_nested_dict


 def get_iterator_from_grid_dict(container: dict) -> Generator[dict, Any, None]:
diff --git a/src/anomalib/pipelines/types.py b/src/anomalib/pipelines/types.py
index dbb1572122..a4af438d36 100644
--- a/src/anomalib/pipelines/types.py
+++ b/src/anomalib/pipelines/types.py
@@ -1,4 +1,20 @@
-"""Types."""
+"""Types used in pipeline components.
+
+This module defines type aliases used throughout the pipeline components for type
+hinting and documentation.
+
+The following types are defined:
+    - ``RUN_RESULTS``: Return type of individual job runs
+    - ``GATHERED_RESULTS``: Combined results from multiple job runs
+    - ``PREV_STAGE_RESULT``: Optional results from previous pipeline stage
+
+Example:
+    >>> from anomalib.pipelines.types import RUN_RESULTS, GATHERED_RESULTS
+    >>> def my_job() -> RUN_RESULTS:
+    ...     return {"metric": 0.95}
+    >>> def gather_results(results: list[RUN_RESULTS]) -> GATHERED_RESULTS:
+    ...     return {"mean_metric": sum(r["metric"] for r in results) / len(results)}
+"""

 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
diff --git a/src/anomalib/post_processing/__init__.py b/src/anomalib/post_processing/__init__.py
index 25e3ab2adf..1c68178d8e 100644
--- a/src/anomalib/post_processing/__init__.py
+++ b/src/anomalib/post_processing/__init__.py
@@ -1,4 +1,21 @@
-"""Anomalib post-processing module."""
+"""Post-processing module for anomaly detection results.
+
+This module provides post-processing functionality for anomaly detection outputs:
+
+- Base :class:`PostProcessor` class defining the post-processing interface
+- :class:`OneClassPostProcessor` for one-class anomaly detection results
+
+The post-processors handle:
+    - Normalizing anomaly scores
+    - Thresholding and anomaly classification
+    - Mask generation and refinement
+    - Result aggregation and formatting
+
+Example:
+    >>> from anomalib.post_processing import OneClassPostProcessor
+    >>> post_processor = OneClassPostProcessor(image_sensitivity=0.5)
+    >>> predictions = post_processor(outputs)  # outputs is an InferenceBatch
+"""

 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
diff --git a/src/anomalib/post_processing/base.py b/src/anomalib/post_processing/base.py
index f5b49bc8b1..2d4d378dc2 100644
--- a/src/anomalib/post_processing/base.py
+++ b/src/anomalib/post_processing/base.py
@@ -1,4 +1,25 @@
-"""Base class for post-processor."""
+"""Base class for post-processing anomaly detection results.
+
+This module provides the abstract base class :class:`PostProcessor` that defines
+the interface for post-processing anomaly detection outputs.
+
+The post-processors handle:
+    - Normalizing anomaly scores
+    - Thresholding and anomaly classification
+    - Mask generation and refinement
+    - Result aggregation and formatting
+
+Example:
+    >>> from anomalib.post_processing import PostProcessor
+    >>> class MyPostProcessor(PostProcessor):
+    ...     def forward(self, batch):
+    ...         # Post-process the batch
+    ...         return batch
+
+The post-processors are implemented as both :class:`torch.nn.Module` and
+:class:`lightning.pytorch.Callback` to support both inference and training
+workflows.
+"""

 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -12,11 +33,37 @@
 class PostProcessor(nn.Module, Callback, ABC):
-    """Base class for post-processor.
+    """Base class for post-processing anomaly detection results.
+
+    The post-processor is implemented as both a :class:`torch.nn.Module` and
+    :class:`lightning.pytorch.Callback` to support inference and training workflows.
+    It handles tasks like score normalization, thresholding, and mask refinement.

-    The post-processor is a callback that is used to post-process the predictions of the model.
+    The class must be inherited and the :meth:`forward` method must be implemented
+    to define the post-processing logic.
+
+    Example:
+        >>> from anomalib.post_processing import PostProcessor
+        >>> class MyPostProcessor(PostProcessor):
+        ...     def forward(self, batch):
+        ...         # Normalize scores between 0 and 1
+        ...         batch.pred_score = normalize(batch.pred_score)
+        ...         return batch
     """

     @abstractmethod
     def forward(self, batch: InferenceBatch) -> InferenceBatch:
-        """Functional forward method for post-processing."""
+        """Post-process a batch of model predictions.
+
+        Args:
+            batch (:class:`anomalib.data.InferenceBatch`): Batch containing model
+                predictions and metadata.
+
+        Returns:
+            :class:`anomalib.data.InferenceBatch`: Post-processed batch with
+                normalized scores, thresholded predictions, and/or refined masks.
+
+        Raises:
+            NotImplementedError: This is an abstract method that must be
+                implemented by subclasses.
+        """
diff --git a/src/anomalib/post_processing/one_class.py b/src/anomalib/post_processing/one_class.py
index c19ef85300..ca89ba4df5 100644
--- a/src/anomalib/post_processing/one_class.py
+++ b/src/anomalib/post_processing/one_class.py
@@ -1,4 +1,19 @@
-"""Post-processing module for anomaly detection models."""
+"""Post-processing module for one-class anomaly detection results.
+
+This module provides post-processing functionality for one-class anomaly detection
+outputs through the :class:`OneClassPostProcessor` class.
+
+The post-processor handles:
+    - Normalizing image and pixel-level anomaly scores
+    - Computing adaptive thresholds for anomaly classification
+    - Applying sensitivity adjustments to thresholds
+    - Formatting results for downstream use
+
+Example:
+    >>> from anomalib.post_processing import OneClassPostProcessor
+    >>> post_processor = OneClassPostProcessor(image_sensitivity=0.5)
+    >>> predictions = post_processor(outputs)  # outputs is an InferenceBatch
+"""

 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -13,7 +28,28 @@
 class OneClassPostProcessor(PostProcessor):
-    """Default post-processor for one-class anomaly detection."""
+    """Post-processor for one-class anomaly detection.
+
+    This class handles post-processing of anomaly detection results by:
+    - Normalizing image and pixel-level anomaly scores
+    - Computing adaptive thresholds for anomaly classification
+    - Applying sensitivity adjustments to thresholds
+    - Formatting results for downstream use
+
+    Args:
+        image_sensitivity (float | None, optional): Sensitivity value for image-level
+            predictions. Higher values make the model more sensitive to anomalies.
+            Defaults to None.
+        pixel_sensitivity (float | None, optional): Sensitivity value for pixel-level
+            predictions. Higher values make the model more sensitive to anomalies.
+            Defaults to None.
+        **kwargs: Additional keyword arguments passed to parent class.
+
+    Example:
+        >>> from anomalib.post_processing import OneClassPostProcessor
+        >>> post_processor = OneClassPostProcessor(image_sensitivity=0.5)
+        >>> predictions = post_processor(outputs)  # outputs is an InferenceBatch
+    """

     def __init__(
         self,
@@ -39,7 +75,15 @@ def on_validation_batch_end(
         *args,
         **kwargs,
     ) -> None:
-        """Update the normalization and thresholding metrics using the batch output."""
+        """Update normalization and thresholding metrics using batch output.
+
+        Args:
+            trainer (Trainer): PyTorch Lightning trainer instance.
+            pl_module (LightningModule): PyTorch Lightning module instance.
+            outputs (Batch): Batch containing model predictions and ground truth.
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+        """
         del trainer, pl_module, args, kwargs  # Unused arguments.
         if outputs.pred_score is not None:
             self._image_threshold.update(outputs.pred_score, outputs.gt_label)
@@ -51,7 +95,12 @@ def on_validation_batch_end(
             self._pixel_normalization_stats.update(outputs.anomaly_map)

     def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
-        """Compute the final threshold and normalization values."""
+        """Compute final threshold and normalization values.
+
+        Args:
+            trainer (Trainer): PyTorch Lightning trainer instance.
+            pl_module (LightningModule): PyTorch Lightning module instance.
+        """
         del trainer, pl_module
         if self._image_threshold.update_called:
             self._image_threshold.compute()
@@ -70,7 +119,15 @@ def on_test_batch_end(
         *args,
         **kwargs,
     ) -> None:
-        """Apply the post-processing steps to the current batch of predictions."""
+        """Apply post-processing steps to current batch of predictions.
+
+        Args:
+            trainer (Trainer): PyTorch Lightning trainer instance.
+            pl_module (LightningModule): PyTorch Lightning module instance.
+            outputs (Batch): Batch containing model predictions.
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+        """
         del trainer, pl_module, args, kwargs
         self.post_process_batch(outputs)
@@ -82,12 +139,31 @@ def on_predict_batch_end(
         *args,
         **kwargs,
     ) -> None:
-        """Normalize the predicted scores and anomaly maps."""
+        """Normalize predicted scores and anomaly maps.
+
+        Args:
+            trainer (Trainer): PyTorch Lightning trainer instance.
+            pl_module (LightningModule): PyTorch Lightning module instance.
+            outputs (Batch): Batch containing model predictions.
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+        """
         del trainer, pl_module, args, kwargs
         self.post_process_batch(outputs)

     def forward(self, predictions: InferenceBatch) -> InferenceBatch:
-        """Funcional forward method for post-processing."""
+        """Post-process model predictions.
+
+        Args:
+            predictions (InferenceBatch): Batch containing model predictions.
+ + Returns: + InferenceBatch: Post-processed batch with normalized scores and + thresholded predictions. + + Raises: + ValueError: If neither `pred_score` nor `anomaly_map` is provided. + """ if predictions.pred_score is None and predictions.anomaly_map is None: msg = "At least one of pred_score or anomaly_map must be provided." raise ValueError(msg) @@ -104,14 +180,24 @@ def forward(self, predictions: InferenceBatch) -> InferenceBatch: ) def post_process_batch(self, batch: Batch) -> None: - """Normalize the predicted scores and anomaly maps.""" + """Post-process a batch of predictions. + + Applies normalization and thresholding to the batch predictions. + + Args: + batch (Batch): Batch containing model predictions. + """ # apply normalization self.normalize_batch(batch) # apply threshold self.threshold_batch(batch) def threshold_batch(self, batch: Batch) -> None: - """Apply thresholding to the batch predictions.""" + """Apply thresholding to batch predictions. + + Args: + batch (Batch): Batch containing model predictions. + """ batch.pred_label = ( batch.pred_label if batch.pred_label is not None @@ -124,7 +210,11 @@ def threshold_batch(self, batch: Batch) -> None: ) def normalize_batch(self, batch: Batch) -> None: - """Normalize the predicted scores and anomaly maps.""" + """Normalize predicted scores and anomaly maps. + + Args: + batch (Batch): Batch containing model predictions. + """ # normalize pixel-level predictions batch.anomaly_map = self._normalize(batch.anomaly_map, self.pixel_min, self.pixel_max, self.raw_pixel_threshold) # normalize image-level predictions @@ -132,7 +222,15 @@ def normalize_batch(self, batch: Batch) -> None: @staticmethod def _threshold(preds: torch.Tensor | None, threshold: float) -> torch.Tensor | None: - """Apply thresholding to a single tensor.""" + """Apply thresholding to a single tensor. + + Args: + preds (torch.Tensor | None): Predictions to threshold. + threshold (float): Threshold value. + + Returns: + torch.Tensor | None: Thresholded predictions or None if input is None. + """ if preds is None: return None return preds > threshold @@ -144,7 +242,17 @@ def _normalize( norm_max: float, threshold: float, ) -> torch.Tensor | None: - """Normalize a tensor using the min, max, and threshold values.""" + """Normalize a tensor using min, max, and threshold values. + + Args: + preds (torch.Tensor | None): Predictions to normalize. + norm_min (float): Minimum value for normalization. + norm_max (float): Maximum value for normalization. + threshold (float): Threshold value. + + Returns: + torch.Tensor | None: Normalized predictions or None if input is None. + """ if preds is None: return None preds = ((preds - threshold) / (norm_max - norm_min)) + 0.5 @@ -153,44 +261,76 @@ def _normalize( @property def raw_image_threshold(self) -> float: - """Get the image-level threshold.""" + """Get the raw image-level threshold. + + Returns: + float: Raw image-level threshold value. + """ return self._image_threshold.value @property def raw_pixel_threshold(self) -> float: - """Get the pixel-level threshold.""" + """Get the raw pixel-level threshold. + + Returns: + float: Raw pixel-level threshold value. + """ return self._pixel_threshold.value @property def normalized_image_threshold(self) -> float: - """Get the image-level threshold.""" + """Get the normalized image-level threshold. + + Returns: + float: Normalized image-level threshold value, adjusted by sensitivity. 
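+
+        Example:
+            A small sketch of the sensitivity adjustment (the value below is
+            illustrative):
+
+            >>> post_processor = OneClassPostProcessor(image_sensitivity=0.8)
+            >>> round(post_processor.normalized_image_threshold, 2)
+            0.2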
+ """ if self.image_sensitivity is not None: return 1 - self.image_sensitivity return 0.5 @property def normalized_pixel_threshold(self) -> float: - """Get the pixel-level threshold.""" + """Get the normalized pixel-level threshold. + + Returns: + float: Normalized pixel-level threshold value, adjusted by sensitivity. + """ if self.pixel_sensitivity is not None: return 1 - self.pixel_sensitivity return 0.5 @property def image_min(self) -> float: - """Get the minimum value for normalization.""" + """Get the minimum value for image-level normalization. + + Returns: + float: Minimum image-level value. + """ return self._image_normalization_stats.min @property def image_max(self) -> float: - """Get the maximum value for normalization.""" + """Get the maximum value for image-level normalization. + + Returns: + float: Maximum image-level value. + """ return self._image_normalization_stats.max @property def pixel_min(self) -> float: - """Get the minimum value for normalization.""" + """Get the minimum value for pixel-level normalization. + + Returns: + float: Minimum pixel-level value. + """ return self._pixel_normalization_stats.min @property def pixel_max(self) -> float: - """Get the maximum value for normalization.""" + """Get the maximum value for pixel-level normalization. + + Returns: + float: Maximum pixel-level value. + """ return self._pixel_normalization_stats.max diff --git a/src/anomalib/pre_processing/__init__.py b/src/anomalib/pre_processing/__init__.py index d70565f882..db63480c35 100644 --- a/src/anomalib/pre_processing/__init__.py +++ b/src/anomalib/pre_processing/__init__.py @@ -1,4 +1,23 @@ -"""Anomalib pre-processing module.""" +"""Pre-processing module for anomaly detection pipelines. + +This module provides functionality for pre-processing data before model training +and inference through the :class:`PreProcessor` class. + +The pre-processor handles: + - Applying transforms to data during different pipeline stages + - Managing stage-specific transforms (train/val/test) + - Integrating with both PyTorch and Lightning workflows + +Example: + >>> from anomalib.pre_processing import PreProcessor + >>> from torchvision.transforms.v2 import Resize + >>> pre_processor = PreProcessor(transform=Resize(size=(256, 256))) + >>> transformed_batch = pre_processor(batch) + +The pre-processor is implemented as both a :class:`torch.nn.Module` and +:class:`lightning.pytorch.Callback` to support both inference and training +workflows. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/pre_processing/pre_processing.py b/src/anomalib/pre_processing/pre_processing.py index 27cffc7605..95a1a7b880 100644 --- a/src/anomalib/pre_processing/pre_processing.py +++ b/src/anomalib/pre_processing/pre_processing.py @@ -1,4 +1,23 @@ -"""Anomalib pre-processing module.""" +"""Pre-processing module for anomaly detection pipelines. + +This module provides functionality for pre-processing data before model training +and inference through the :class:`PreProcessor` class. 
+
+The pre-processor handles:
+    - Applying transforms to data during different pipeline stages
+    - Managing stage-specific transforms (train/val/test)
+    - Integrating with both PyTorch and Lightning workflows
+
+Example:
+    >>> from anomalib.pre_processing import PreProcessor
+    >>> from torchvision.transforms.v2 import Resize
+    >>> pre_processor = PreProcessor(transform=Resize(size=(256, 256)))
+    >>> transformed_batch = pre_processor(batch)
+
+The pre-processor is implemented as both a :class:`torch.nn.Module` and
+:class:`lightning.pytorch.Callback` to support both inference and training
+workflows.
+"""

 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

@@ -33,44 +52,56 @@ class PreProcessor(nn.Module, Callback):
     training, validation, testing, and prediction.

     Args:
-        train_transform (Transform | None): Transform to apply during training.
-        val_transform (Transform | None): Transform to apply during validation.
-        test_transform (Transform | None): Transform to apply during testing.
-        transform (Transform | None): General transform to apply if stage-specific
-            transforms are not provided.
+        train_transform (Transform | None, optional): Transform to apply during
+            training. Defaults to None.
+        val_transform (Transform | None, optional): Transform to apply during
+            validation. Defaults to None.
+        test_transform (Transform | None, optional): Transform to apply during
+            testing. Defaults to None.
+        transform (Transform | None, optional): General transform to apply if
+            stage-specific transforms are not provided. Defaults to None.

     Raises:
-        ValueError: If both `transform` and any of the stage-specific transforms
+        ValueError: If both ``transform`` and any of the stage-specific transforms
             are provided simultaneously.

     Notes:
-        If only `transform` is provided, it will be used for all stages (train, val, test).
+        If only ``transform`` is provided, it will be used for all stages (train,
+        val, test).

         Priority of transforms:
-        1. Explicitly set PreProcessor transforms (highest priority)
-        2. Datamodule transforms (if PreProcessor has no transforms)
-        3. Dataloader transforms (if neither PreProcessor nor datamodule have transforms)
-        4. Default transforms (lowest priority)
+        1. Explicitly set ``PreProcessor`` transforms (highest priority)
+        2. Datamodule transforms (if ``PreProcessor`` has no transforms)
+        3. Dataloader transforms (if neither ``PreProcessor`` nor datamodule
+           have transforms)
+        4. Default transforms (lowest priority)

-    Examples:
-        >>> from torchvision.transforms.v2 import Compose, Resize, ToTensor
+    Example:
+        >>> from torchvision.transforms.v2 import CenterCrop, Compose, Resize, ToTensor
         >>> from anomalib.pre_processing import PreProcessor

-        >>> # Define transforms
-        >>> train_transform = Compose([Resize((224, 224)), ToTensor()])
-        >>> val_transform = Compose([Resize((256, 256)), CenterCrop((224, 224)), ToTensor()])
-
+        >>> train_transform = Compose([
+        ...     Resize((224, 224)),
+        ...     ToTensor()
+        ... ])
+        >>> val_transform = Compose([
+        ...     Resize((256, 256)),
+        ...     CenterCrop((224, 224)),
+        ...     ToTensor()
+        ... ])
         >>> # Create PreProcessor with stage-specific transforms
         >>> pre_processor = PreProcessor(
         ...     train_transform=train_transform,
         ...     val_transform=val_transform
         ... )

-        >>> # Create PreProcessor with a single transform for all stages
-        >>> common_transform = Compose([Resize((224, 224)), ToTensor()])
+        >>> common_transform = Compose([
+        ...     Resize((224, 224)),
+        ...     ToTensor()
+        ...
]) >>> pre_processor_common = PreProcessor(transform=common_transform) - >>> # Use in a Lightning module + Integration with Lightning: >>> class MyModel(LightningModule): ... def __init__(self): ... super().__init__() @@ -80,7 +111,7 @@ class PreProcessor(nn.Module, Callback): ... return [self.pre_processor] ... ... def training_step(self, batch, batch_idx): - ... # The pre_processor will automatically apply the correct transform + ... # Pre-processor automatically applies correct transform ... processed_batch = self.pre_processor(batch) ... # Rest of the training step """ @@ -110,7 +141,12 @@ def __init__( self.export_transform = get_exportable_transform(self.test_transform) def setup_datamodule_transforms(self, datamodule: "AnomalibDataModule") -> None: - """Set up datamodule transforms.""" + """Set up datamodule transforms. + + Args: + datamodule (AnomalibDataModule): The datamodule to configure + transforms for. + """ # If PreProcessor has transforms, propagate them to datamodule if any([self.train_transform, self.val_transform, self.test_transform]): transforms = { @@ -125,7 +161,12 @@ def setup_datamodule_transforms(self, datamodule: "AnomalibDataModule") -> None: set_datamodule_stage_transform(datamodule, transform, stage) def setup_dataloader_transforms(self, dataloaders: "EVAL_DATALOADERS | TRAIN_DATALOADERS") -> None: - """Set up dataloader transforms.""" + """Set up dataloader transforms. + + Args: + dataloaders (EVAL_DATALOADERS | TRAIN_DATALOADERS): The dataloaders + to configure transforms for. + """ if isinstance(dataloaders, DataLoader): dataloaders = [dataloaders] @@ -153,9 +194,9 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> Non """Configure transforms at the start of each stage. Args: - trainer: The Lightning trainer. - pl_module: The Lightning module. - stage: The stage (e.g., 'fit', 'validate', 'test', 'predict'). + trainer (Trainer): The Lightning trainer. + pl_module (LightningModule): The Lightning module. + stage (str): The stage (e.g., 'fit', 'validate', 'test', 'predict'). """ stage = TrainerFn(stage).value # Ensure stage is str @@ -171,7 +212,13 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor: """Apply transforms to the batch of tensors for inference. This forward-pass is only used after the model is exported. - Within the Lightning training/validation/testing loops, the transforms are applied - in the `on_*_batch_start` methods. + Within the Lightning training/validation/testing loops, the transforms are + applied in the ``on_*_batch_start`` methods. + + Args: + batch (torch.Tensor): Input batch to transform. + + Returns: + torch.Tensor: Transformed batch. """ return self.export_transform(batch) if self.export_transform else batch diff --git a/src/anomalib/pre_processing/utils/__init__.py b/src/anomalib/pre_processing/utils/__init__.py index 8361223189..f4ff633692 100644 --- a/src/anomalib/pre_processing/utils/__init__.py +++ b/src/anomalib/pre_processing/utils/__init__.py @@ -1,4 +1,17 @@ -"""Utility functions for pre-processing.""" +"""Utility functions for pre-processing. + +This module provides utility functions used by the pre-processing module for +handling transforms and data processing tasks. 
+ +The utilities include: + - Transform management for different pipeline stages + - Conversion between transform types + - Helper functions for dataloader/datamodule transform handling + +Example: + >>> from anomalib.pre_processing.utils import get_exportable_transform + >>> transform = get_exportable_transform(train_transform) +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/pre_processing/utils/transform.py b/src/anomalib/pre_processing/utils/transform.py index 37eb1e9dd1..e2032e6284 100644 --- a/src/anomalib/pre_processing/utils/transform.py +++ b/src/anomalib/pre_processing/utils/transform.py @@ -1,4 +1,23 @@ -"""Utility functions for transforms.""" +"""Utility functions for transforms. + +This module provides utility functions for managing transforms in the pre-processing +pipeline. The utilities handle: + - Getting and setting transforms for different pipeline stages + - Converting between transform types + - Managing transforms across dataloaders and datamodules + +Example: + >>> from anomalib.pre_processing.utils.transform import get_dataloaders_transforms + >>> transforms = get_dataloaders_transforms(dataloaders) + >>> print(transforms["train"]) # Get training stage transform + Compose( + Resize(size=(256, 256), ...), + ToTensor() + ) + +The module ensures consistent transform handling across the training, validation, +and testing stages of the anomaly detection pipeline. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -13,13 +32,40 @@ def get_dataloaders_transforms(dataloaders: Sequence[DataLoader]) -> dict[str, Transform]: - """Get transforms from dataloaders. + """Extract transforms from a sequence of dataloaders. + + This function retrieves the transforms associated with different stages (train, + validation, test) from a sequence of dataloaders. It maps Lightning stage names + to their corresponding transform stages. + + The stage mapping is: + - ``fit`` -> ``train`` + - ``validate`` -> ``val`` + - ``test`` -> ``test`` + - ``predict`` -> ``test`` Args: - dataloaders: The dataloaders to get transforms from. + dataloaders: A sequence of PyTorch :class:`DataLoader` objects to extract + transforms from. Each dataloader should have a ``dataset`` attribute + with a ``transform`` property. Returns: - Dictionary mapping stages to their transforms. + A dictionary mapping stage names (``train``, ``val``, ``test``) to their + corresponding :class:`torchvision.transforms.v2.Transform` objects. + + Example: + >>> from torch.utils.data import DataLoader + >>> from torchvision.transforms.v2 import Resize, ToTensor + >>> # Create dataloaders with transforms + >>> train_loader = DataLoader(dataset_with_transform) + >>> val_loader = DataLoader(dataset_with_transform) + >>> # Get transforms + >>> transforms = get_dataloaders_transforms([train_loader, val_loader]) + >>> print(transforms["train"]) # Access training transform + Compose( + Resize(size=(256, 256)), + ToTensor() + ) """ transforms: dict[str, Transform] = {} stage_lookup = { @@ -41,11 +87,36 @@ def get_dataloaders_transforms(dataloaders: Sequence[DataLoader]) -> dict[str, T def set_dataloaders_transforms(dataloaders: Sequence[DataLoader], transforms: dict[str, Transform | None]) -> None: - """Set transforms to dataloaders. + """Set transforms to dataloaders based on their stage. + + This function propagates transforms to dataloaders based on their stage mapping. 
+    The stage mapping follows the convention:
+
+    - ``fit`` -> ``train``
+    - ``validate`` -> ``val``
+    - ``test`` -> ``test``
+    - ``predict`` -> ``test``

     Args:
-        dataloaders: The dataloaders to propagate transforms to.
-        transforms: Dictionary mapping stages to their transforms.
+        dataloaders: A sequence of PyTorch :class:`DataLoader` objects to set
+            transforms for. Each dataloader should have a ``dataset`` attribute.
+        transforms: Dictionary mapping stage names (``train``, ``val``, ``test``)
+            to their corresponding :class:`torchvision.transforms.v2.Transform`
+            objects. The transforms can be ``None``.
+
+    Example:
+        >>> from torch.utils.data import DataLoader
+        >>> from torchvision.transforms.v2 import Compose, Resize, ToTensor
+        >>> # Create transforms
+        >>> transforms = {
+        ...     "train": Compose([Resize((256, 256)), ToTensor()]),
+        ...     "val": Compose([Resize((256, 256)), ToTensor()])
+        ... }
+        >>> # Create dataloaders
+        >>> train_loader = DataLoader(dataset_with_transform)
+        >>> val_loader = DataLoader(dataset_with_transform)
+        >>> # Set transforms
+        >>> set_dataloaders_transforms([train_loader, val_loader], transforms)
     """
     stage_mapping = {
         "fit": "train",
@@ -66,11 +137,31 @@ def set_dataloaders_transforms(dataloaders: Sequence[DataLoader], transforms: di

 def set_dataloader_transform(dataloader: DataLoader | Sequence[DataLoader], transform: Transform) -> None:
-    """Set a transform for a dataloader or list of dataloaders.
+    """Set a transform for a dataloader or sequence of dataloaders.
+
+    This function sets the transform for either a single dataloader or multiple dataloaders.
+    The transform is set on the dataset object of each dataloader if it has a ``transform``
+    attribute.

     Args:
-        dataloader: The dataloader(s) to set the transform for.
-        transform: The transform to set.
+        dataloader: A single :class:`torch.utils.data.DataLoader` or a sequence of
+            dataloaders to set the transform for. Each dataloader should have a
+            ``dataset`` attribute with a ``transform`` attribute.
+        transform: The :class:`torchvision.transforms.v2.Transform` object to set as
+            the transform.
+
+    Raises:
+        TypeError: If ``dataloader`` is neither a :class:`torch.utils.data.DataLoader`
+            nor a sequence of dataloaders.
+
+    Example:
+        >>> from torch.utils.data import DataLoader
+        >>> from torchvision.transforms.v2 import Resize
+        >>> # Create transform and dataloader
+        >>> transform = Resize(size=(256, 256))
+        >>> dataloader = DataLoader(dataset_with_transform)
+        >>> # Set transform
+        >>> set_dataloader_transform(dataloader, transform)
     """
     if isinstance(dataloader, DataLoader):
         if hasattr(dataloader.dataset, "transform"):
@@ -84,19 +175,36 @@ def set_dataloader_transform(dataloader: DataLoader | Sequence[DataLoader], tran

 def set_datamodule_stage_transform(datamodule: AnomalibDataModule, transform: Transform, stage: str) -> None:
-    """Set a transform for a specific stage in a AnomalibDataModule.
+    """Set a transform for a specific stage in a :class:`AnomalibDataModule`.
+
+    This function sets the transform for a specific stage (train/val/test/predict) in an
+    AnomalibDataModule by mapping the stage name to the corresponding dataset attribute
+    and setting its transform.

     Args:
-        datamodule: The AnomalibDataModule to set the transform for.
-        transform: The transform to set.
-        stage: The stage to set the transform for.
+        datamodule: The :class:`AnomalibDataModule` instance to set the transform for.
+            Must have dataset attributes corresponding to different stages.
+ transform: The :class:`torchvision.transforms.v2.Transform` object to set as + the transform for the specified stage. + stage: The pipeline stage to set the transform for. Must be one of: + ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``. Note: - The stage parameter maps to dataset attributes as follows: - - 'fit' -> 'train_data' - - 'validate' -> 'val_data' - - 'test' -> 'test_data' - - 'predict' -> 'test_data' + The ``stage`` parameter maps to dataset attributes as follows: + + - ``'fit'`` -> ``'train_data'`` + - ``'validate'`` -> ``'val_data'`` + - ``'test'`` -> ``'test_data'`` + - ``'predict'`` -> ``'test_data'`` + + Example: + >>> from torchvision.transforms.v2 import Resize + >>> from anomalib.data import MVTec + >>> # Create transform and datamodule + >>> transform = Resize(size=(256, 256)) + >>> datamodule = MVTec() + >>> # Set transform for training stage + >>> set_datamodule_stage_transform(datamodule, transform, "fit") """ stage_datasets = { "fit": "train_data", @@ -113,9 +221,35 @@ def set_datamodule_stage_transform(datamodule: AnomalibDataModule, transform: Tr def get_exportable_transform(transform: Transform | None) -> Transform | None: - """Get exportable transform. + """Get an exportable version of a transform. + + This function converts a torchvision transform into a format that is compatible with + ONNX and OpenVINO export. It handles two main compatibility issues: - Some transforms are not supported by ONNX/OpenVINO, so we need to replace them with exportable versions. + 1. Disables antialiasing in ``Resize`` transforms + 2. Converts ``CenterCrop`` to ``ExportableCenterCrop`` + + Args: + transform (Transform | None): The transform to convert. If ``None``, returns + ``None``. + + Returns: + Transform | None: The converted transform that is compatible with ONNX/OpenVINO + export. Returns ``None`` if input transform is ``None``. + + Example: + >>> from torchvision.transforms.v2 import Compose, Resize, CenterCrop + >>> transform = Compose([ + ... Resize((224, 224), antialias=True), + ... CenterCrop(200) + ... ]) + >>> exportable = get_exportable_transform(transform) + >>> # Now transform is compatible with ONNX/OpenVINO export + + Note: + Some torchvision transforms are not directly supported by ONNX/OpenVINO. This + function handles the most common cases, but additional transforms may need + special handling. """ if transform is None: return None @@ -126,7 +260,30 @@ def get_exportable_transform(transform: Transform | None) -> Transform | None: def disable_antialiasing(transform: Transform) -> Transform: """Disable antialiasing in Resize transforms. - Resizing with antialiasing is not supported by ONNX, so we need to disable it. + This function recursively disables antialiasing in any ``Resize`` transforms found + within the provided transform or transform composition. This is necessary because + antialiasing is not supported during ONNX export. + + Args: + transform (Transform): Transform or composition of transforms to process. + + Returns: + Transform: The processed transform with antialiasing disabled in any + ``Resize`` transforms. + + Example: + >>> from torchvision.transforms.v2 import Compose, Resize + >>> transform = Compose([ + ... Resize((224, 224), antialias=True), + ... Resize((256, 256), antialias=True) + ... ]) + >>> transform = disable_antialiasing(transform) + >>> # Now all Resize transforms have antialias=False + + Note: + This function modifies the transforms in-place by setting their + ``antialias`` attribute to ``False``. 
The original transform object is + returned. """ if isinstance(transform, Resize): transform.antialias = False @@ -137,9 +294,33 @@ def disable_antialiasing(transform: Transform) -> Transform: def convert_center_crop_transform(transform: Transform) -> Transform: - """Convert CenterCrop to ExportableCenterCrop. + """Convert torchvision's CenterCrop to ExportableCenterCrop. - Torchvision's CenterCrop is not supported by ONNX, so we need to replace it with our own ExportableCenterCrop. + This function recursively converts any ``CenterCrop`` transforms found within the + provided transform or transform composition to ``ExportableCenterCrop``. This is + necessary because torchvision's ``CenterCrop`` is not supported during ONNX + export. + + Args: + transform (Transform): Transform or composition of transforms to process. + + Returns: + Transform: The processed transform with all ``CenterCrop`` transforms + converted to ``ExportableCenterCrop``. + + Example: + >>> from torchvision.transforms.v2 import Compose, CenterCrop + >>> transform = Compose([ + ... CenterCrop(224), + ... CenterCrop((256, 256)) + ... ]) + >>> transform = convert_center_crop_transform(transform) + >>> # Now all CenterCrop transforms are converted to ExportableCenterCrop + + Note: + This function creates new ``ExportableCenterCrop`` instances to replace the + original ``CenterCrop`` transforms. The original transform object is + returned with the replacements applied. """ if isinstance(transform, CenterCrop): transform = ExportableCenterCrop(size=transform.size) diff --git a/src/anomalib/utils/__init__.py b/src/anomalib/utils/__init__.py index 8ffe7654fe..2787c5772d 100644 --- a/src/anomalib/utils/__init__.py +++ b/src/anomalib/utils/__init__.py @@ -1,4 +1,27 @@ -"""Helpers for downloading files, calculating metrics, computing anomaly maps, and visualization.""" +"""Utility functions and helpers for anomaly detection. -# Copyright (C) 2022 Intel Corporation +This module provides various utility functions and helpers for: + - File downloading and management + - Metric calculation and evaluation + - Anomaly map computation and processing + - Result visualization and plotting + +The utilities ensure consistent behavior across the library and provide common +functionality used by multiple components. + +Example: + >>> from anomalib.utils.visualization import ImageVisualizer + >>> # Create visualizer + >>> visualizer = ImageVisualizer() + >>> # Generate visualization + >>> vis_result = visualizer.visualize(image=img, pred_mask=mask) + +The module is organized into submodules for different types of utilities: + - ``download``: Functions for downloading datasets and models + - ``metrics``: Implementations of evaluation metrics + - ``map``: Tools for generating and processing anomaly maps + - ``visualization``: Classes for visualizing detection results +""" + +# Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/utils/config.py b/src/anomalib/utils/config.py index aadaa6a42b..a7611b24ee 100644 --- a/src/anomalib/utils/config.py +++ b/src/anomalib/utils/config.py @@ -1,4 +1,11 @@ -"""Get configurable parameters.""" +"""Configuration utilities. 
+ +This module contains utility functions for handling configuration objects, including: +- Converting between different configuration formats (dict, Namespace, DictConfig) +- Flattening and nesting dictionaries +- Converting paths and values +- Updating configurations +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -16,7 +23,36 @@ def _convert_nested_path_to_str(config: Any) -> Any: # noqa: ANN401 - """Goes over the dictionary and converts all path values to str.""" + """Convert all path values to strings recursively in a configuration object. + + This function traverses a configuration object and converts any ``Path`` or + ``JSONArgparsePath`` objects to string representations. It handles nested + dictionaries and lists recursively. + + Args: + config: Configuration object that may contain path values. Can be a + dictionary, list, Path object, or other types. + + Returns: + Any: Configuration with all path values converted to strings. The returned + object maintains the same structure as the input, with only path + values converted to strings. + + Examples: + >>> from pathlib import Path + >>> config = { + ... "model_path": Path("/path/to/model"), + ... "data": { + ... "train_path": Path("/data/train"), + ... "val_path": Path("/data/val") + ... } + ... } + >>> converted = _convert_nested_path_to_str(config) + >>> print(converted["model_path"]) + /path/to/model + >>> print(converted["data"]["train_path"]) + /data/train + """ if isinstance(config, dict): for key, value in config.items(): config[key] = _convert_nested_path_to_str(value) @@ -29,22 +65,40 @@ def _convert_nested_path_to_str(config: Any) -> Any: # noqa: ANN401 def to_nested_dict(config: dict) -> dict: - """Convert the flattened dictionary to nested dictionary. + """Convert a flattened dictionary to a nested dictionary. + + This function takes a dictionary with dot-separated keys and converts it into a nested + dictionary structure. Keys containing dots (`.`) are split and used to create nested + dictionaries. + + Args: + config: Flattened dictionary where keys can contain dots to indicate nesting + levels. For example, ``"dataset.category"`` will become + ``{"dataset": {"category": ...}}``. + + Returns: + dict: A nested dictionary where dot-separated keys in the input are converted + to nested dictionary structures. Keys without dots remain at the top + level. Examples: >>> config = { - "dataset.category": "bottle", - "dataset.image_size": 224, - "model_name": "padim", - } - >>> to_nested_dict(config) - { - "dataset": { - "category": "bottle", - "image_size": 224, - }, - "model_name": "padim", - } + ... "dataset.category": "bottle", + ... "dataset.image_size": 224, + ... "model_name": "padim" + ... } + >>> result = to_nested_dict(config) + >>> print(result["dataset"]["category"]) + bottle + >>> print(result["dataset"]["image_size"]) + 224 + >>> print(result["model_name"]) + padim + + Note: + - The function preserves the original values while only restructuring the keys + - Non-dot keys are kept as-is at the root level + - Empty key segments (e.g. ``"dataset..name"``) are handled as literal keys """ out: dict[str, Any] = {} for key, value in config.items(): @@ -57,13 +111,34 @@ def to_nested_dict(config: dict) -> dict: def to_yaml(config: Namespace | ListConfig | DictConfig) -> str: - """Convert the config to a yaml string. + """Convert configuration object to YAML string. + + This function takes a configuration object and converts it to a YAML formatted string. 
+    It handles different configuration object types including ``Namespace``,
+    ``ListConfig``, and ``DictConfig``.

     Args:
-        config (Namespace | ListConfig | DictConfig): Config
+        config: Configuration object to convert. Can be one of:
+            - ``Namespace``: A namespace object from jsonargparse
+            - ``ListConfig``: A list configuration from OmegaConf
+            - ``DictConfig``: A dictionary configuration from OmegaConf

     Returns:
-        str: YAML string
+        str: Configuration as YAML formatted string
+
+    Examples:
+        >>> from omegaconf import DictConfig
+        >>> config = DictConfig({"model": "padim", "dataset": {"name": "mvtec"}})
+        >>> yaml_str = to_yaml(config)
+        >>> print(yaml_str)
+        model: padim
+        dataset:
+          name: mvtec
+
+    Note:
+        - For ``Namespace`` objects, the function first converts to dictionary format
+        - Nested paths in the configuration are converted to strings
+        - The original configuration object is not modified
     """
     _config = config.clone() if isinstance(config, Namespace) else config.copy()
     if isinstance(_config, Namespace):
@@ -73,22 +148,40 @@ def to_yaml(config: Namespace | ListConfig | DictConfig) -> str:

 def to_tuple(input_size: int | ListConfig) -> tuple[int, int]:
-    """Convert int or list to a tuple.
+    """Convert input size to a tuple of (height, width).
+
+    This function takes either a single integer or a sequence of two integers and
+    converts it to a tuple representing image dimensions (height, width). If a single
+    integer is provided, it is used for both dimensions.

     Args:
-        input_size (int | ListConfig): input_size
+        input_size: Input size specification. Can be either:
+            - A single ``int`` that will be used for both height and width
+            - A ``ListConfig`` or sequence containing exactly 2 integers for height
+              and width
+
+    Returns:
+        tuple[int, int]: A tuple of ``(height, width)`` dimensions
+
+    Examples:
+        Create a square tuple from single integer:

-    Example:
         >>> to_tuple(256)
         (256, 256)
+
+        Create a tuple from list of dimensions:
+
         >>> to_tuple([256, 256])
         (256, 256)

     Raises:
-        ValueError: Unsupported value type.
+        ValueError: If ``input_size`` is a sequence without exactly 2 elements
+        TypeError: If ``input_size`` is neither an integer nor a sequence of
+            integers

-    Returns:
-        tuple[int, int]: Tuple of input_size
+    Note:
+        When using a sequence input, the first value is interpreted as height and
+        the second as width.
     """
     ret_val: tuple[int, int]
     if isinstance(input_size, int):
@@ -106,30 +199,46 @@ def to_tuple(input_size: int | ListConfig) -> tuple[int, int]:

 def convert_valuesview_to_tuple(values: ValuesView) -> list[tuple]:
-    """Convert a ValuesView object to a list of tuples.
+    """Convert ``ValuesView`` to list of tuples for parameter combinations.

-    This is useful to get list of possible values for each parameter in the config and a tuple for values that are
-    are to be patched. Ideally this is useful when used with product.
+    This function takes a ``ValuesView`` object and converts it to a list of tuples
+    that can be used for creating parameter combinations. It is particularly useful
+    when working with ``itertools.product`` to generate all possible parameter
+    combinations.
- Example: - >>> params = DictConfig({ - "dataset.category": [ - "bottle", - "cable", - ], - "dataset.image_size": 224, - "model_name": ["padim"], - }) - >>> convert_to_tuple(params.values()) - [('bottle', 'cable'), (224,), ('padim',)] - >>> list(itertools.product(*convert_to_tuple(params.values()))) - [('bottle', 224, 'padim'), ('cable', 224, 'padim')] + The function handles both iterable and non-iterable values: + - Iterable values (except strings) are converted to tuples + - Non-iterable values and strings are wrapped in single-element tuples Args: - values: ValuesView: ValuesView object to be converted to a list of tuples. + values: A ``ValuesView`` object containing parameter values to convert Returns: - list[Tuple]: List of tuples. + list[tuple]: A list of tuples where each tuple contains parameter values. + Single values are wrapped in 1-element tuples. + + Examples: + Create parameter combinations from a config: + + >>> params = DictConfig({ + ... "dataset.category": [ + ... "bottle", + ... "cable", + ... ], + ... "dataset.image_size": 224, + ... "model_name": ["padim"], + ... }) + >>> convert_valuesview_to_tuple(params.values()) + [('bottle', 'cable'), (224,), ('padim',)] + + Use with ``itertools.product`` to get all combinations: + + >>> list(itertools.product(*convert_valuesview_to_tuple(params.values()))) + [('bottle', 224, 'padim'), ('cable', 224, 'padim')] + + Note: + Strings are treated as non-iterable values even though they are technically + iterable in Python. This prevents unwanted character-by-character splitting. """ return_list = [] for value in values: @@ -141,21 +250,48 @@ def convert_valuesview_to_tuple(values: ValuesView) -> list[tuple]: def flatten_dict(config: dict, prefix: str = "") -> dict: - """Flatten the dictionary. + """Flatten a nested dictionary using dot notation. + + Takes a nested dictionary and flattens it into a single-level dictionary where + nested keys are joined using dot notation. This is useful for converting + hierarchical configurations into a flat format. + + Args: + config: Nested dictionary to flatten. Can contain arbitrary levels of + nesting. + prefix: Optional string prefix to prepend to all flattened keys. Defaults + to empty string. + + Returns: + dict: Flattened dictionary where nested keys are joined with dots. + For example, ``{"a": {"b": 1}}`` becomes ``{"a.b": 1}``. Examples: + Basic nested dictionary flattening: + >>> config = { - "dataset": { - "category": "bottle", - "image_size": 224, - }, - "model_name": "padim", - } - >>> flatten_dict(config) + ... "dataset": { + ... "category": "bottle", + ... "image_size": 224 + ... }, + ... "model_name": "padim" + ... } + >>> flattened = flatten_dict(config) + >>> print(flattened) # doctest: +SKIP { - "dataset.category": "bottle", - "dataset.image_size": 224, - "model_name": "padim", + 'dataset.category': 'bottle', + 'dataset.image_size': 224, + 'model_name': 'padim' + } + + With custom prefix: + + >>> flattened = flatten_dict(config, prefix="config.") + >>> print(flattened) # doctest: +SKIP + { + 'config.dataset.category': 'bottle', + 'config.dataset.image_size': 224, + 'config.model_name': 'padim' } """ out = {} @@ -168,18 +304,44 @@ def flatten_dict(config: dict, prefix: str = "") -> dict: def namespace_from_dict(container: dict) -> Namespace: - """Convert dictionary to Namespace recursively. + """Convert a dictionary to a Namespace object recursively. + + This function takes a dictionary and recursively converts it and all nested + dictionaries into ``Namespace`` objects. 
This is useful for accessing dictionary + keys as attributes. + + Args: + container: Dictionary to convert into a ``Namespace`` object. Can contain + arbitrary levels of nesting. + + Returns: + ``Namespace`` object with equivalent structure to input dictionary. Nested + dictionaries are converted to nested ``Namespace`` objects. Examples: + Basic dictionary conversion: + >>> container = { - "dataset": { - "category": "bottle", - "image_size": 224, - }, - "model_name": "padim", - } - >>> namespace_from_dict(container) - Namespace(dataset=Namespace(category='bottle', image_size=224), model_name='padim') + ... "dataset": { + ... "category": "bottle", + ... "image_size": 224, + ... }, + ... "model_name": "padim", + ... } + >>> namespace = namespace_from_dict(container) + >>> namespace.dataset.category + 'bottle' + >>> namespace.model_name + 'padim' + + The returned object allows attribute-style access: + + >>> namespace.dataset.image_size + 224 + + Note: + All dictionary keys must be valid Python identifiers to be accessed as + attributes in the resulting ``Namespace`` object. """ output = Namespace() for k, v in container.items(): @@ -191,9 +353,23 @@ def namespace_from_dict(container: dict) -> Namespace: def dict_from_namespace(container: Namespace) -> dict: - """Convert Namespace to dictionary recursively. + """Convert a Namespace object to a dictionary recursively. + + This function takes a ``Namespace`` object and recursively converts it and all nested + ``Namespace`` objects into dictionaries. This is useful for serializing ``Namespace`` + objects or converting them to a format that can be easily saved or transmitted. + + Args: + container: ``Namespace`` object to convert into a dictionary. Can contain + arbitrary levels of nesting. + + Returns: + Dictionary with equivalent structure to input ``Namespace``. Nested + ``Namespace`` objects are converted to nested dictionaries. Examples: + Basic namespace conversion: + >>> from jsonargparse import Namespace >>> ns = Namespace() >>> ns.a = 1 @@ -201,6 +377,20 @@ def dict_from_namespace(container: Namespace) -> dict: >>> ns.b.c = 2 >>> dict_from_namespace(ns) {'a': 1, 'b': {'c': 2}} + + The function handles arbitrary nesting: + + >>> ns = Namespace() + >>> ns.x = Namespace() + >>> ns.x.y = Namespace() + >>> ns.x.y.z = 3 + >>> dict_from_namespace(ns) + {'x': {'y': {'z': 3}}} + + Note: + This function is the inverse of :func:`namespace_from_dict`. Together they + provide bidirectional conversion between dictionaries and ``Namespace`` + objects. """ output = {} for k, v in container.__dict__.items(): @@ -212,13 +402,34 @@ def dict_from_namespace(container: Namespace) -> dict: def update_config(config: DictConfig | ListConfig | Namespace) -> DictConfig | ListConfig | Namespace: - """Update config. + """Update configuration with warnings and NNCF settings. + + This function processes the provided configuration by: + - Showing relevant configuration-specific warnings via ``_show_warnings`` + - Updating NNCF (Neural Network Compression Framework) settings via + ``_update_nncf_config`` Args: - config: Configurable parameters. + config: Configuration object to update. Can be either a ``DictConfig``, + ``ListConfig``, or ``Namespace`` instance containing model and training + parameters. Returns: - DictConfig | ListConfig | Namespace: Updated config. + Updated configuration with any NNCF-specific modifications applied. Returns + the same type as the input configuration. 
+
+    Examples:
+        >>> from omegaconf import DictConfig
+        >>> config = DictConfig({
+        ...     "data": {"init_args": {"clip_length_in_frames": 1}},
+        ...     "optimization": {"nncf": {"apply": True}},
+        ... })
+        >>> updated = update_config(config)
+
+        >>> from jsonargparse import Namespace
+        >>> config = Namespace(data={"init_args": {"clip_length_in_frames": 1}})
+        >>> updated = update_config(config)
+
+    Note:
+        This function is typically called after loading the initial configuration
+        but before using it for model training or inference.
     """
     _show_warnings(config)

@@ -226,13 +437,40 @@ def update_config(config: DictConfig | ListConfig | Namespace) -> DictConfig | L

 def _update_nncf_config(config: DictConfig | ListConfig) -> DictConfig | ListConfig:
-    """Set the NNCF input size based on the value of the crop_size parameter in the configurable parameters object.
+    """Update NNCF configuration with input size settings.
+
+    This function updates the Neural Network Compression Framework (NNCF)
+    configuration by setting default input size parameters if they are not already
+    specified. It also handles merging any NNCF-specific configuration updates.
+
+    The function checks if NNCF optimization settings exist in the config and adds
+    default input shape information of ``[1, 3, 10, 10]`` if not present. If NNCF
+    is enabled and contains update configuration, it merges those updates.

     Args:
-        config (DictConfig | ListConfig): Configurable parameters of the current run.
+        config: Configuration object containing NNCF settings. Must be either a
+            ``DictConfig`` or ``ListConfig`` instance.

     Returns:
-        DictConfig | ListConfig: Updated configurable parameters in DictConfig object.
+        ``DictConfig`` or ``ListConfig`` with updated NNCF configuration settings.
+
+    Example:
+        >>> from omegaconf import DictConfig
+        >>> config = DictConfig({
+        ...     "optimization": {
+        ...         "nncf": {
+        ...             "apply": True,
+        ...             "input_info": {"sample_size": [1, 3, 224, 224]}
+        ...         }
+        ...     }
+        ... })
+        >>> updated = _update_nncf_config(config)
+
+    Note:
+        The default input size of ``[1, 3, 10, 10]`` represents:
+        - Batch size of 1
+        - 3 input channels (RGB)
+        - Height and width of 10 pixels
     """
     if "optimization" in config and "nncf" in config.optimization:
         if "input_info" not in config.optimization.nncf:
@@ -244,10 +482,32 @@ def _update_nncf_config(config: DictConfig | ListConfig) -> DictConfig | ListCon

 def _show_warnings(config: DictConfig | ListConfig | Namespace) -> None:
-    """Show warnings if any based on the configuration settings.
+    """Show configuration-specific warnings.
+
+    This function checks the provided configuration for conditions that may cause
+    issues and displays appropriate warning messages. Currently checks for:
+
+    - Video clip length compatibility issues with models and visualizers

     Args:
-        config (DictConfig | ListConfig | Namespace): Configurable parameters for the current run.
+        config: Configuration object to check for warning conditions. Can be one of:
+            - ``DictConfig``
+            - ``ListConfig``
+            - ``Namespace``
+
+    Example:
+        >>> from omegaconf import DictConfig
+        >>> config = DictConfig({
+        ...     "data": {
+        ...         "clip_length_in_frames": 2,
+        ...         "init_args": {"clip_length_in_frames": 2},
+        ...     }
+        ... })
+        >>> _show_warnings(config)  # Will show video clip length warning
+
+    Note:
+        The function currently focuses on video-related configuration warnings,
+        specifically checking the ``clip_length_in_frames`` parameter in the data
+        configuration section.
""" if "clip_length_in_frames" in config.data and config.data.init_args.clip_length_in_frames > 1: logger.warning( diff --git a/src/anomalib/utils/cv/__init__.py b/src/anomalib/utils/cv/__init__.py index 72435b61dc..537d57173f 100644 --- a/src/anomalib/utils/cv/__init__.py +++ b/src/anomalib/utils/cv/__init__.py @@ -1,6 +1,22 @@ -"""Anomalib computer vision utilities.""" +"""Computer vision utilities for anomaly detection. -# Copyright (C) 2022 Intel Corporation +This module provides computer vision utilities used by the anomalib library for +processing and analyzing images during anomaly detection. + +The utilities include: + - Connected components analysis for both CPU and GPU + - Image processing operations + - Computer vision helper functions + +Example: + >>> from anomalib.utils.cv import connected_components_cpu + >>> # Process image to get binary mask + >>> mask = get_binary_mask(image) + >>> # Find connected components + >>> labels = connected_components_cpu(mask) +""" + +# Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from .connected_components import connected_components_cpu, connected_components_gpu diff --git a/src/anomalib/utils/cv/connected_components.py b/src/anomalib/utils/cv/connected_components.py index e2fc1000df..fe2c7fa1c2 100644 --- a/src/anomalib/utils/cv/connected_components.py +++ b/src/anomalib/utils/cv/connected_components.py @@ -1,6 +1,22 @@ -"""Connected component labeling.""" +"""Connected component labeling for anomaly detection. -# Copyright (C) 2022 Intel Corporation +This module provides functions for performing connected component labeling on both +GPU and CPU. Connected components are used to identify and label contiguous +regions in binary images, which is useful for post-processing anomaly detection +results. + +Example: + >>> import torch + >>> from anomalib.utils.cv import connected_components_gpu + >>> # Create a binary mask tensor (1 for anomaly, 0 for normal) + >>> mask = torch.zeros(1, 1, 4, 4) + >>> mask[0, 0, 1:3, 1:3] = 1 # Create a 2x2 square anomaly + >>> # Get labeled components + >>> labels = connected_components_gpu(mask) + >>> print(labels.unique()) # Should show [0, 1] for background and one component +""" + +# Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import cv2 @@ -10,14 +26,39 @@ def connected_components_gpu(image: torch.Tensor, num_iterations: int = 1000) -> torch.Tensor: - """Perform connected component labeling on GPU and remap the labels from 0 to N. + """Perform connected component labeling on GPU. + + Labels connected regions in a binary image and remaps the labels sequentially + from 0 to N, where N is the number of unique components. Uses the GPU for + faster processing of large images. Args: - image (torch.Tensor): Binary input image from which we want to extract connected components (Bx1xHxW) - num_iterations (int): Number of iterations used in the connected component computation. + image (torch.Tensor): Binary input image tensor of shape ``(B, 1, H, W)`` + where ``B`` is batch size, ``H`` is height and ``W`` is width. + Values should be binary (0 or 1). + num_iterations (int, optional): Number of iterations for the connected + components algorithm. Higher values may be needed for complex regions. + Defaults to 1000. Returns: - Tensor: Components labeled from 0 to N. + torch.Tensor: Integer tensor of same shape as input, containing labeled + components from 0 to N. 
Background (zero) pixels in the input remain + ``0``, while connected regions are labeled with integers from ``1`` + to ``N``. + + Example: + >>> import torch + >>> from anomalib.utils.cv import connected_components_gpu + >>> # Create a binary mask with a 2x2 square anomaly + >>> mask = torch.zeros(1, 1, 4, 4) + >>> mask[0, 0, 1:3, 1:3] = 1 + >>> labels = connected_components_gpu(mask) + >>> print(labels.unique()) # Should show tensor([0, 1]) + >>> print(labels[0, 0]) # Show the labeled components + tensor([[0, 0, 0, 0], + [0, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 0]]) """ components = connected_components(image, num_iterations=num_iterations) @@ -32,11 +73,39 @@ def connected_components_gpu(image: torch.Tensor, num_iterations: int = 1000) -> def connected_components_cpu(image: torch.Tensor) -> torch.Tensor: """Perform connected component labeling on CPU. + Labels connected regions in a binary image using OpenCV's implementation. + Ensures unique labeling across batched inputs by remapping component labels + sequentially. + Args: - image (torch.Tensor): Binary input data from which we want to extract connected components (Bx1xHxW) + image (torch.Tensor): Binary input tensor of shape ``(B, 1, H, W)`` where + ``B`` is batch size, ``H`` is height and ``W`` is width. Values should + be binary (``0`` or ``1``). Returns: - Tensor: Components labeled from 0 to N. + torch.Tensor: Integer tensor of same shape as input, containing labeled + components from ``0`` to ``N``. Background (zero) pixels in the input + remain ``0``, while connected regions are labeled with integers from + ``1`` to ``N``, ensuring unique labels across the batch. + + Example: + >>> import torch + >>> from anomalib.utils.cv import connected_components_cpu + >>> # Create a binary mask with a 2x2 square anomaly + >>> mask = torch.zeros(1, 1, 4, 4) + >>> mask[0, 0, 1:3, 1:3] = 1 + >>> labels = connected_components_cpu(mask) + >>> print(labels.unique()) # Should show tensor([0, 1]) + >>> print(labels[0, 0]) # Show the labeled components + tensor([[0, 0, 0, 0], + [0, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 0]]) + + Note: + This function uses OpenCV's ``connectedComponents`` implementation which + runs on CPU. For GPU acceleration, use :func:`connected_components_gpu` + instead. """ components = torch.zeros_like(image) label_idx = 1 diff --git a/src/anomalib/utils/exceptions/__init__.py b/src/anomalib/utils/exceptions/__init__.py index 52d64883d1..fc9c2a55c2 100644 --- a/src/anomalib/utils/exceptions/__init__.py +++ b/src/anomalib/utils/exceptions/__init__.py @@ -1,4 +1,21 @@ -"""Utilities related to exception and error handling.""" +"""Exception and error handling utilities for anomaly detection. + +This module provides utilities for handling exceptions and errors in the anomalib +library. The utilities include: + - Dynamic import handling with graceful fallbacks + - Custom exception types for anomaly detection + - Error handling helpers and decorators + +Example: + >>> from anomalib.utils.exceptions import try_import + >>> # Try importing an optional dependency + >>> torch_fidelity = try_import("torch_fidelity") + >>> if torch_fidelity is None: + ... print("torch-fidelity not installed") + +The module ensures consistent and informative error handling across the anomalib +codebase. 
+""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/utils/exceptions/imports.py b/src/anomalib/utils/exceptions/imports.py index 6ef8dbd89d..36609c9652 100644 --- a/src/anomalib/utils/exceptions/imports.py +++ b/src/anomalib/utils/exceptions/imports.py @@ -1,4 +1,21 @@ -"""Import handling utilities.""" +"""Import handling utilities for anomaly detection. + +This module provides utilities for handling dynamic imports and import-related +exceptions in the anomalib library. The utilities include: + - Dynamic module import with graceful error handling + - Import availability checking + - Deprecation warnings for legacy import functions + +Example: + >>> from anomalib.utils.exceptions import try_import + >>> # Try importing an optional dependency + >>> torch_fidelity = try_import("torch_fidelity") + >>> if torch_fidelity is None: + ... print("torch-fidelity not installed") + +The module ensures consistent handling of optional dependencies and provides +helpful error messages when imports fail. +""" # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -10,13 +27,33 @@ def try_import(import_path: str) -> bool: - """Try to import a module. + """Try to import a module and return whether the import succeeded. + + This function attempts to dynamically import a Python module and handles any + import errors gracefully. It is deprecated and will be removed in v2.0.0. + Users should migrate to ``module_available`` from lightning-utilities instead. Args: - import_path (str): The import path of the module. + import_path (str): The import path of the module to try importing. This can + be a top-level package name (e.g. ``"torch"``) or a submodule path + (e.g. ``"torch.nn"``). Returns: - bool: True if import succeeds, False otherwise. + bool: ``True`` if the import succeeds, ``False`` if an ``ImportError`` + occurs. + + Warns: + DeprecationWarning: This function is deprecated and will be removed in + v2.0.0. Use ``module_available`` from lightning-utilities instead. + + Example: + >>> from anomalib.utils.exceptions import try_import + >>> # Try importing an optional dependency + >>> has_torch = try_import("torch") + >>> if not has_torch: + ... print("PyTorch is not installed") + >>> # Try importing a submodule + >>> has_torchvision = try_import("torchvision.transforms") """ import warnings diff --git a/src/anomalib/utils/logging.py b/src/anomalib/utils/logging.py index d73ef440c4..a92149c53e 100644 --- a/src/anomalib/utils/logging.py +++ b/src/anomalib/utils/logging.py @@ -1,4 +1,28 @@ -"""Logging Utility functions.""" +"""Logging utility functions for anomaly detection. + +This module provides utilities for logging and output management. The key components include: + + - ``LoggerRedirectError``: Custom exception for logging redirection failures + - ``hide_output``: Decorator to suppress function output streams + - Helper functions for redirecting output to loggers + +Example: + >>> from anomalib.utils.logging import hide_output + >>> @hide_output + >>> def my_function(): + ... print("This output will be hidden") + >>> my_function() + +The module ensures consistent logging behavior by: + - Providing decorators for output control + - Handling both stdout and stderr redirection + - Supporting exception propagation + - Offering flexible output management + +Note: + The logging utilities are designed to work with both standard Python logging + and custom logging implementations. 
+""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -13,35 +37,67 @@ class LoggerRedirectError(Exception): - """Exception occurred when executing function with outputs redirected to logger.""" + """Exception raised when redirecting function output to logger fails. + + This exception is raised when there is an error while redirecting the output + streams (stdout/stderr) of a function to a logger. It typically occurs in + functions decorated with ``@hide_output``. + + Example: + >>> @hide_output + >>> def problematic_function(): + ... raise ValueError("Something went wrong") + >>> problematic_function() + Traceback (most recent call last): + ... + LoggerRedirectError: Error occurred while executing problematic_function + + Note: + This exception wraps the original exception that caused the redirection + failure, which can be accessed through the ``__cause__`` attribute. + """ def hide_output(func: Callable[..., Any]) -> Callable[..., Any]: - """Hide output of the function. + """Hide output of a function by redirecting stdout and stderr. + + This decorator captures and discards any output that would normally be printed + to stdout or stderr when the decorated function executes. The function's + return value is preserved. Args: - func (function): Hides output from all streams of this function. + func (Callable[..., Any]): Function whose output should be hidden. + All output streams from this function will be captured. + + Returns: + Callable[..., Any]: Wrapped function that executes silently. + + Raises: + LoggerRedirectError: If an error occurs during function execution. The + original exception can be accessed via ``__cause__``. Example: + Basic usage to hide print statements: + >>> @hide_output - >>> def my_function(): - >>> print("This will not be printed") - >>> my_function() + ... def my_function(): + ... print("This will not be printed") + >>> my_function() # No output will appear + + Exceptions are still propagated: >>> @hide_output - >>> def my_function(): - >>> 1/0 + ... def my_function(): + ... 1/0 # doctest: +IGNORE_EXCEPTION_DETAIL >>> my_function() Traceback (most recent call last): - File "", line 1, in - File "", line 2, in my_fun - ZeroDivisionError: division by zero + ... + LoggerRedirectError: Error occurred while executing my_function - Raises: - Exception: In case the execution of function fails, it raises an exception. - - Returns: - object of the called function + Note: + - The decorator preserves the function's metadata using ``functools.wraps`` + - Both ``stdout`` and ``stderr`` streams are captured + - Original streams are always restored, even if an exception occurs """ @functools.wraps(func) @@ -66,11 +122,33 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 def redirect_logs(log_file: str) -> None: - """Add file handler to logger. + """Add file handler to logger and remove other handlers. - It also removes all other handlers from the loggers. + This function sets up file-based logging by: + - Creating a file handler for the specified log file + - Setting a standard format for log messages + - Removing all other handlers from existing loggers + - Configuring warning capture - Note: This feature does not work well with multiprocessing and won't redirect logs from child processes. + Args: + log_file: Path to the log file where messages will be written. + Parent directories will be created if they don't exist. 
+ + Example: + >>> from pathlib import Path + >>> log_path = Path("logs/app.log") + >>> redirect_logs(str(log_path)) # doctest: +SKIP + >>> import logging + >>> logger = logging.getLogger(__name__) + >>> logger.info("Test message") # Message written to logs/app.log + + Note: + - The log format includes timestamp, logger name, level and message + - All existing handlers are removed from loggers to ensure logs only go + to file + - This function does not work well with multiprocessing - logs from + child processes will not be redirected + - The function captures Python warnings in addition to regular logs """ Path(log_file).parent.mkdir(exist_ok=True, parents=True) logger_file_handler = logging.FileHandler(log_file) diff --git a/src/anomalib/utils/normalization/__init__.py b/src/anomalib/utils/normalization/__init__.py index ebf4493204..6434926074 100644 --- a/src/anomalib/utils/normalization/__init__.py +++ b/src/anomalib/utils/normalization/__init__.py @@ -1,13 +1,51 @@ -"""Tools for anomaly score normalization.""" +"""Tools for anomaly score normalization. -# Copyright (C) 2022 Intel Corporation +This module provides utilities for normalizing anomaly scores in anomaly detection +tasks. The utilities include: + - Min-max normalization to scale scores to [0,1] range + - Enum class to specify normalization methods + +Example: + >>> from anomalib.utils.normalization import NormalizationMethod + >>> # Use min-max normalization + >>> method = NormalizationMethod.MIN_MAX + >>> print(method) + min_max + >>> # Use no normalization + >>> method = NormalizationMethod.NONE + >>> print(method) + none + +The module ensures consistent normalization of anomaly scores across different +detection algorithms. +""" + +# Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from enum import Enum class NormalizationMethod(str, Enum): - """Normalization method for normalization.""" + """Enumeration of supported normalization methods for anomaly scores. + + This enum class defines the available methods for normalizing anomaly scores: + - ``MIN_MAX``: Scales scores to [0,1] range using min-max normalization + - ``NONE``: No normalization is applied, raw scores are used + + Example: + >>> from anomalib.utils.normalization import NormalizationMethod + >>> # Use min-max normalization + >>> method = NormalizationMethod.MIN_MAX + >>> print(method) + min_max + >>> # Use no normalization + >>> method = NormalizationMethod.NONE + >>> print(method) + none + + The enum inherits from ``str`` to enable string comparison and serialization. + """ MIN_MAX = "min_max" NONE = "none" diff --git a/src/anomalib/utils/normalization/min_max.py b/src/anomalib/utils/normalization/min_max.py index 9df69c8d06..bf1982be36 100644 --- a/src/anomalib/utils/normalization/min_max.py +++ b/src/anomalib/utils/normalization/min_max.py @@ -1,4 +1,25 @@ -"""Tools for min-max normalization.""" +"""Tools for min-max normalization. + +This module provides utilities for min-max normalization of anomaly scores. The +main function :func:`normalize` scales values to [0,1] range and centers them +around a threshold. 
+ +Example: + >>> import numpy as np + >>> from anomalib.utils.normalization.min_max import normalize + >>> # Create sample anomaly scores + >>> scores = np.array([0.1, 0.5, 0.8]) + >>> threshold = 0.5 + >>> min_val = 0.0 + >>> max_val = 1.0 + >>> # Normalize scores + >>> normalized = normalize(scores, threshold, min_val, max_val) + >>> print(normalized) # Values centered around 0.5 + [0.1 0.5 0.8] + +The module supports both NumPy arrays and PyTorch tensors as inputs, with +appropriate handling for each type. +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -13,7 +34,39 @@ def normalize( min_val: float | np.ndarray | torch.Tensor, max_val: float | np.ndarray | torch.Tensor, ) -> np.ndarray | torch.Tensor: - """Apply min-max normalization and shift the values such that the threshold value is centered at 0.5.""" + """Apply min-max normalization and center values around a threshold. + + This function performs min-max normalization on the input values and shifts them + such that the threshold value is centered at 0.5. The output is clipped to the + range [0,1]. + + Args: + targets (numpy.ndarray | numpy.float32 | torch.Tensor): Input values to + normalize. Can be either a NumPy array or PyTorch tensor. + threshold (float | numpy.ndarray | torch.Tensor): Threshold value that will + be centered at 0.5 after normalization. + min_val (float | numpy.ndarray | torch.Tensor): Minimum value used for + normalization scaling. + max_val (float | numpy.ndarray | torch.Tensor): Maximum value used for + normalization scaling. + + Returns: + numpy.ndarray | torch.Tensor: Normalized values in range [0,1] with + threshold centered at 0.5. Output type matches input type. + + Raises: + TypeError: If ``targets`` is neither a NumPy array nor PyTorch tensor. + + Example: + >>> import torch + >>> scores = torch.tensor([0.1, 0.5, 0.8]) + >>> threshold = 0.5 + >>> min_val = 0.0 + >>> max_val = 1.0 + >>> normalized = normalize(scores, threshold, min_val, max_val) + >>> print(normalized) + tensor([0.1000, 0.5000, 0.8000]) + """ normalized = ((targets - threshold) / (max_val - min_val)) + 0.5 if isinstance(targets, np.ndarray | np.float32 | np.float64): normalized = np.minimum(normalized, 1) diff --git a/src/anomalib/utils/path.py b/src/anomalib/utils/path.py index c9f92937d2..aea614eb54 100644 --- a/src/anomalib/utils/path.py +++ b/src/anomalib/utils/path.py @@ -1,4 +1,30 @@ -"""Anomalib Path Utils.""" +"""Path utilities for anomaly detection. + +This module provides utilities for managing paths and directories in anomaly +detection projects. The key components include: + + - Version directory creation and management + - Symbolic link handling + - Path resolution and validation + +Example: + >>> from anomalib.utils.path import create_versioned_dir + >>> from pathlib import Path + >>> # Create versioned directory + >>> version_dir = create_versioned_dir(Path("experiments")) + >>> version_dir.name + 'v1' + +The module ensures consistent path handling by: + - Creating incrementing version directories (v1, v2, etc.) + - Maintaining a ``latest`` symbolic link + - Handling both string and ``Path`` inputs + - Providing cross-platform compatibility + +Note: + All paths are resolved to absolute paths to ensure consistent behavior + across different working directories. 
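The version-directory scheme this docstring describes (``v1``, ``v2``, ..., plus a ``latest`` symlink) can be sketched in a few lines. The helper below is hypothetical, not the function's actual body; only the ``^v(\d+)$`` pattern is taken from the patch itself:

```python
import re
from pathlib import Path


def next_version_sketch(root_dir: Path) -> Path:
    """Hypothetical helper mirroring the documented v1/v2/latest scheme."""
    version_pattern = re.compile(r"^v(\d+)$")  # the pattern used in this file
    root_dir.mkdir(parents=True, exist_ok=True)
    existing = [
        int(match.group(1))
        for child in root_dir.iterdir()
        if child.is_dir() and (match := version_pattern.match(child.name))
    ]
    new_dir = root_dir / f"v{max(existing, default=0) + 1}"
    new_dir.mkdir()
    # Re-point ``latest`` at the fresh version directory.
    latest = root_dir / "latest"
    if latest.is_symlink() or latest.exists():
        latest.unlink()
    latest.symlink_to(new_dir.name, target_is_directory=True)
    return latest
```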
+""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -10,28 +36,41 @@ def create_versioned_dir(root_dir: str | Path) -> Path: """Create a new version directory and update the ``latest`` symbolic link. + This function creates a new versioned directory (e.g. ``v1``, ``v2``, etc.) inside the + specified root directory and updates a ``latest`` symbolic link to point to it. + The version numbers increment automatically based on existing directories. + Args: - root_dir (Path): The root directory where the version directories are stored. + root_dir (Union[str, Path]): Root directory path where version directories will be + created. Can be provided as a string or ``Path`` object. Directory will be + created if it doesn't exist. Returns: - latest_link_path (Path): The path to the ``latest`` symbolic link. + Path: Path to the ``latest`` symbolic link that points to the newly created + version directory. Examples: - >>> version_dir = create_version_dir(Path('path/to/experiments/')) - PosixPath('/path/to/experiments/latest') - - >>> version_dir.resolve().name - v1 - - Calling the function again will create a new version directory and - update the ``latest`` symbolic link: - - >>> version_dir = create_version_dir('path/to/experiments/') - PosixPath('/path/to/experiments/latest') - - >>> version_dir.resolve().name - v2 - + Create first version directory: + + >>> from pathlib import Path + >>> version_dir = create_versioned_dir(Path("experiments")) + >>> version_dir + PosixPath('experiments/latest') + >>> version_dir.resolve().name # Points to v1 + 'v1' + + Create second version directory: + + >>> version_dir = create_versioned_dir("experiments") + >>> version_dir.resolve().name # Now points to v2 + 'v2' + + Note: + - The function resolves all paths to absolute paths + - Creates parent directories if they don't exist + - Handles existing symbolic links by removing and recreating them + - Version directories follow the pattern ``v1``, ``v2``, etc. + - The ``latest`` link always points to the most recently created version """ # Compile a regular expression to match version directories version_pattern = re.compile(r"^v(\d+)$") @@ -66,23 +105,55 @@ def create_versioned_dir(root_dir: str | Path) -> Path: def convert_to_snake_case(s: str) -> str: - """Converts a string to snake case. + """Convert a string to snake case format. + + This function converts various string formats (space-separated, camelCase, + PascalCase, etc.) to snake_case by: + + - Converting spaces and punctuation to underscores + - Inserting underscores before capital letters + - Converting to lowercase + - Removing redundant underscores Args: - s (str): The input string to be converted. + s (str): Input string to convert to snake case. Returns: - str: The converted string in snake case. + str: The input string converted to snake case format. 
Examples: + Convert space-separated string: + >>> convert_to_snake_case("Snake Case") 'snake_case' + Convert camelCase: + >>> convert_to_snake_case("snakeCase") 'snake_case' + Convert PascalCase: + + >>> convert_to_snake_case("SnakeCase") + 'snake_case' + + Handle existing snake_case: + >>> convert_to_snake_case("snake_case") 'snake_case' + + Handle punctuation: + + >>> convert_to_snake_case("snake.case") + 'snake_case' + + >>> convert_to_snake_case("snake-case") + 'snake_case' + + Note: + - Leading/trailing underscores are removed + - Multiple consecutive underscores are collapsed to a single underscore + - Punctuation marks (``.``, ``-``, ``'``) are converted to underscores """ # Replace whitespace, hyphens, periods, and apostrophes with underscores s = re.sub(r"\s+|[-.\']", "_", s) @@ -98,45 +169,63 @@ def convert_to_snake_case(s: str) -> str: def convert_to_title_case(text: str) -> str: - """Converts a given text to title case, handling regular text, snake_case, and camelCase. + """Convert text to title case, handling various text formats. + + This function converts text from various formats (regular text, snake_case, camelCase, + PascalCase) to title case format. It preserves punctuation and handles contractions + appropriately. Args: - text (str): The input text to be converted to title case. + text (str): Input text to convert to title case. Can be in any text format like + snake_case, camelCase, PascalCase or regular text. Returns: - str: The input text converted to title case. + str: The input text converted to title case format. Raises: - TypeError: If the input is not a string. + TypeError: If the input ``text`` is not a string. Examples: Regular text: + >>> convert_to_title_case("the quick brown fox") 'The Quick Brown Fox' Snake case: + >>> convert_to_title_case("convert_snake_case_to_title_case") 'Convert Snake Case To Title Case' Camel case: + >>> convert_to_title_case("convertCamelCaseToTitleCase") 'Convert Camel Case To Title Case' Pascal case: + >>> convert_to_title_case("ConvertPascalCaseToTitleCase") 'Convert Pascal Case To Title Case' Mixed cases: + >>> convert_to_title_case("mixed_snake_camelCase and PascalCase") 'Mixed Snake Camel Case And Pascal Case' Handling punctuation and contractions: + >>> convert_to_title_case("what's the_weather_like? it'sSunnyToday.") "What's The Weather Like? It's Sunny Today." With numbers and special characters: + >>> convert_to_title_case("python3.9_features and camelCaseNames") 'Python 3.9 Features And Camel Case Names' + + Note: + - Preserves contractions (e.g., "what's" -> "What's") + - Handles mixed case formats in the same string + - Maintains punctuation and spacing + - Properly capitalizes words after numbers and special characters """ if not isinstance(text, str): msg = "Input must be a string" @@ -166,48 +255,62 @@ def generate_output_filename( category: str | None = None, mkdir: bool = True, ) -> Path: - """Generate an output filename based on the input path, preserving the directory structure. + """Generate an output filename based on the input path. - This function takes an input path, an output base directory, a dataset name, and an optional - category. It generates an output path that preserves the directory structure after the dataset - name (and category, if provided) while placing the file in the specified output directory. + This function generates an output path that preserves the directory structure after the + dataset name (and category if provided) while placing the file in the specified output + directory. 
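A rough sketch of the matching rule documented for ``generate_output_filename`` below (case-insensitive matching is taken from the Note further down; the helper itself is hypothetical):

```python
from pathlib import Path


def output_path_sketch(
    input_path: str,
    output_base: str,
    dataset_name: str,
    category: str | None = None,
) -> Path:
    """Hypothetical helper: keep the path components after dataset/category."""
    parts = Path(input_path).parts
    lowered = [part.lower() for part in parts]  # matching is case-insensitive
    if dataset_name.lower() not in lowered:
        msg = f"Dataset name '{dataset_name}' not found in the input path."
        raise ValueError(msg)
    index = lowered.index(dataset_name.lower()) + 1
    if category is not None:
        index = lowered.index(category.lower(), index) + 1
    # Everything after the matched component lands under the output base.
    return Path(output_base).joinpath(*parts[index:])


print(output_path_sketch(
    "/data/MVTec/bottle/test/broken_large/000.png", "/results", "MVTec", "bottle",
))  # /results/test/broken_large/000.png
```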
Args: - input_path (str | Path): The input file path. - output_path (str | Path): The base output directory. - dataset_name (str): The name of the dataset in the input path. - category (str | None, optional): The category name in the input path. Defaults to None. - mkdir (bool, optional): Whether to create the output directory. Defaults to True. + input_path (str | Path): Path to the input file. + output_path (str | Path): Base output directory path. + dataset_name (str): Name of the dataset to find in the input path. + category (str | None, optional): Category name to find in the input path after + dataset name. Defaults to ``None``. + mkdir (bool, optional): Whether to create the output directory structure. + Defaults to ``True``. Returns: - Path: The generated output file path. + Path: Generated output file path preserving relevant directory structure. Raises: - ValueError: If the dataset name or category (if provided) is not found in the input path. + ValueError: If ``dataset_name`` is not found in ``input_path``. + ValueError: If ``category`` is provided but not found in ``input_path`` after + ``dataset_name``. Examples: + Basic usage with category: + >>> input_path = "/data/MVTec/bottle/test/broken_large/000.png" >>> output_base = "/results" >>> dataset = "MVTec" - - # With category >>> generate_output_filename(input_path, output_base, dataset, "bottle") PosixPath('/results/test/broken_large/000.png') - # Without category + Without category preserves more structure: + >>> generate_output_filename(input_path, output_base, dataset) PosixPath('/results/bottle/test/broken_large/000.png') - # Different dataset structure - >>> input_path = "/datasets/MyDataset/train/class_A/image_001.jpg" - >>> generate_output_filename(input_path, "/output", "MyDataset", "class_A") + Different dataset structure: + + >>> path = "/datasets/MyDataset/train/class_A/image_001.jpg" + >>> generate_output_filename(path, "/output", "MyDataset", "class_A") PosixPath('/output/image_001.jpg') - # Error case: Dataset not in path - >>> generate_output_filename("/wrong/path/image.png", "/out", "NonexistentDataset") + Dataset not found raises error: + + >>> generate_output_filename("/wrong/path/image.png", "/out", "Missing") Traceback (most recent call last): ... - ValueError: Dataset name 'NonexistentDataset' not found in the input path. + ValueError: Dataset name 'Missing' not found in the input path. + + Note: + - Directory structure after ``dataset_name`` (or ``category`` if provided) is + preserved in output path + - If ``mkdir=True``, creates output directory structure if it doesn't exist + - Dataset and category name matching is case-insensitive + - Original filename is preserved in output path """ input_path = Path(input_path) output_path = Path(output_path) diff --git a/src/anomalib/utils/post_processing.py b/src/anomalib/utils/post_processing.py index ff6a7d33eb..18b2c3b3a5 100644 --- a/src/anomalib/utils/post_processing.py +++ b/src/anomalib/utils/post_processing.py @@ -1,4 +1,34 @@ -"""Post Process This module contains utils function to apply post-processing to the output predictions.""" +"""Post-processing utilities for anomaly detection predictions. + +This module provides utilities for post-processing anomaly detection predictions. 
+The key components include: + + - Label addition to images with confidence scores + - Morphological operations on prediction masks + - Normalization and thresholding of anomaly maps + +Example: + >>> import numpy as np + >>> from anomalib.utils.post_processing import add_label + >>> # Add label to image + >>> image = np.zeros((100, 100, 3), dtype=np.uint8) + >>> labeled_image = add_label( + ... image=image, + ... label_name="Anomalous", + ... color=(255, 0, 0), + ... confidence=0.95 + ... ) + +The module ensures consistent post-processing by: + - Providing standardized label formatting + - Supporting both classification and segmentation outputs + - Handling proper scaling of visual elements + - Offering configurable processing parameters + +Note: + All functions preserve the input data types and handle proper normalization + of values where needed. +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -18,18 +48,54 @@ def add_label( font_scale: float = 5e-3, thickness_scale: float = 1e-3, ) -> np.ndarray: - """Add a label to an image. + """Add a text label with optional confidence score to an image. + + This function adds a text label to the top-left corner of an image. The label has a + colored background patch and can optionally include a confidence percentage. Args: - image (np.ndarray): Input image. - label_name (str): Name of the label that will be displayed on the image. - color (tuple[int, int, int]): RGB values for background color of label. - confidence (float | None): confidence score of the label. - font_scale (float): scale of the font size relative to image size. Increase for bigger font. - thickness_scale (float): scale of the font thickness. Increase for thicker font. + image (np.ndarray): Input image to add the label to. Must be a 3-channel RGB or + BGR image. + label_name (str): Text label to display on the image (e.g. "normal", + "anomalous"). + color (tuple[int, int, int]): RGB color values for the label background as a + tuple of 3 integers in range [0,255]. + confidence (float | None, optional): Confidence score between 0 and 1 to display + as percentage. If ``None``, only the label name is shown. Defaults to + ``None``. + font_scale (float, optional): Scale factor for font size relative to image + dimensions. Larger values produce bigger text. Defaults to ``5e-3``. + thickness_scale (float, optional): Scale factor for font thickness relative to + image dimensions. Larger values produce thicker text. Defaults to ``1e-3``. Returns: - np.ndarray: Image with label. + np.ndarray: Copy of input image with label added to top-left corner. + + Example: + Add a normal label with 95% confidence: + + >>> import numpy as np + >>> image = np.zeros((100, 100, 3), dtype=np.uint8) + >>> labeled_image = add_label( + ... image=image, + ... label_name="normal", + ... color=(0, 255, 0), + ... confidence=0.95 + ... ) + + Add an anomalous label without confidence: + + >>> labeled_image = add_label( + ... image=image, + ... label_name="anomalous", + ... color=(255, 0, 0) + ... 
) + + Note: + - The function creates a copy of the input image to avoid modifying it + - Font size and thickness scale automatically with image dimensions + - Label is always placed in the top-left corner + - Uses OpenCV's FONT_HERSHEY_PLAIN font family """ image = image.copy() img_height, img_width, _ = image.shape @@ -62,25 +128,110 @@ def add_label( def add_normal_label(image: np.ndarray, confidence: float | None = None) -> np.ndarray: - """Add the normal label to the image.""" + """Add a 'normal' label to the image. + + This function adds a 'normal' label to the top-left corner of the image using a + light green color. The label can optionally include a confidence score. + + Args: + image (np.ndarray): Input image to add the label to. Should be a 3-channel + RGB or BGR image. + confidence (float | None, optional): Confidence score between 0 and 1 to + display with the label. If ``None``, only the label is shown. + Defaults to ``None``. + + Returns: + np.ndarray: Copy of input image with 'normal' label added. + + Examples: + Add normal label without confidence: + + >>> labeled_image = add_normal_label(image) + + Add normal label with 95% confidence: + + >>> labeled_image = add_normal_label(image, confidence=0.95) + + Note: + - Creates a copy of the input image + - Uses a light green color (RGB: 225, 252, 134) + - Label is placed in top-left corner + - Font size scales with image dimensions + """ return add_label(image, "normal", (225, 252, 134), confidence) def add_anomalous_label(image: np.ndarray, confidence: float | None = None) -> np.ndarray: - """Add the anomalous label to the image.""" + """Add an 'anomalous' label to the image. + + This function adds an 'anomalous' label to the top-left corner of the image using a + light red color. The label can optionally include a confidence score. + + Args: + image (np.ndarray): Input image to add the label to. Should be a 3-channel + RGB or BGR image. + confidence (float | None, optional): Confidence score between 0 and 1 to + display with the label. If ``None``, only the label is shown. + Defaults to ``None``. + + Returns: + np.ndarray: Copy of input image with 'anomalous' label added. + + Examples: + Add anomalous label without confidence: + + >>> labeled_image = add_anomalous_label(image) + + Add anomalous label with 95% confidence: + + >>> labeled_image = add_anomalous_label(image, confidence=0.95) + + Note: + - Creates a copy of the input image + - Uses a light red color (RGB: 255, 100, 100) + - Label is placed in top-left corner + - Font size scales with image dimensions + """ return add_label(image, "anomalous", (255, 100, 100), confidence) def anomaly_map_to_color_map(anomaly_map: np.ndarray, normalize: bool = True) -> np.ndarray: - """Compute anomaly color heatmap. + """Convert an anomaly map to a color heatmap visualization. + + This function converts a grayscale anomaly map into a color heatmap using the JET + colormap. The anomaly map can optionally be normalized before coloring. Args: - anomaly_map (np.ndarray): Final anomaly map computed by the distance metric. - normalize (bool, optional): Bool to normalize the anomaly map prior to applying - the color map. Defaults to True. + anomaly_map (np.ndarray): Grayscale anomaly map computed by the model's + distance metric. Should be a 2D array of float values. + normalize (bool, optional): Whether to normalize the anomaly map to [0,1] range + before applying the colormap. If ``True``, the map is normalized using + min-max scaling. Defaults to ``True``. 
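Assuming OpenCV, as the Notes in this file state, the heatmap conversion reduces to a min-max scale plus ``cv2.applyColorMap``; a standalone sketch, not the exact implementation:

```python
import cv2
import numpy as np

anomaly_map = np.random.rand(224, 224).astype(np.float32)

# Min-max scale into [0, 255] before the JET colormap is applied.
scaled = ((anomaly_map - anomaly_map.min()) / np.ptp(anomaly_map) * 255).astype(np.uint8)
heatmap = cv2.applyColorMap(scaled, cv2.COLORMAP_JET)  # OpenCV returns BGR
heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # convert back to RGB
print(heatmap.shape, heatmap.dtype)  # (224, 224, 3) uint8
```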
Returns: - np.ndarray: [description] + np.ndarray: RGB color heatmap visualization of the anomaly map. Values are in + range [0,255] and type uint8. + + Examples: + Convert anomaly map without normalization: + + >>> heatmap = anomaly_map_to_color_map(anomaly_map, normalize=False) + >>> heatmap.shape + (224, 224, 3) + >>> heatmap.dtype + dtype('uint8') + + Convert with normalization (default): + + >>> heatmap = anomaly_map_to_color_map(anomaly_map) + >>> heatmap.min(), heatmap.max() + (0, 255) + + Note: + - Input map is converted to uint8 by scaling to [0,255] range + - Uses OpenCV's JET colormap for visualization + - Output is converted from BGR to RGB color format + - Shape of output matches input with added channel dimension """ if normalize: anomaly_map = (anomaly_map - anomaly_map.min()) / np.ptp(anomaly_map) @@ -98,23 +249,55 @@ def superimpose_anomaly_map( gamma: int = 0, normalize: bool = False, ) -> np.ndarray: - """Superimpose anomaly map on top of in the input image. + """Superimpose an anomaly heatmap on top of an input image. - Args: - anomaly_map (np.ndarray): Anomaly map - image (np.ndarray): Input image - alpha (float, optional): Weight to overlay anomaly map - on the input image. Defaults to 0.4. - gamma (int, optional): Value to add to the blended image - to smooth the processing. Defaults to 0. Overall, - the formula to compute the blended image is - I' = (alpha*I1 + (1-alpha)*I2) + gamma - normalize: whether or not the anomaly maps should - be normalized to image min-max at image level + This function overlays a colored anomaly map visualization on an input image using + alpha blending. The anomaly map can optionally be normalized before blending. + Args: + anomaly_map (np.ndarray): Grayscale anomaly map computed by the model's + distance metric. Should be a 2D array of float values. + image (np.ndarray): Input image to overlay the anomaly map on. Will be + resized to match anomaly map dimensions. + alpha (float, optional): Blending weight for the anomaly map overlay. + Should be in range [0,1] where 0 shows only the input image and 1 + shows only the anomaly map. Defaults to ``0.4``. + gamma (int, optional): Value added to the blended result for smoothing. + The blending formula is: + ``output = (alpha * anomaly_map + (1-alpha) * image) + gamma`` + Defaults to ``0``. + normalize (bool, optional): Whether to normalize the anomaly map to [0,1] + range before coloring. If ``True``, uses min-max scaling at the image + level. Defaults to ``False``. Returns: - np.ndarray: Image with anomaly map superimposed on top of it. + np.ndarray: RGB image with the colored anomaly map overlay. Values are in + range [0,255] and type uint8. + + Examples: + Basic overlay without normalization: + + >>> result = superimpose_anomaly_map(anomaly_map, image, alpha=0.4) + >>> result.shape + (224, 224, 3) + >>> result.dtype + dtype('uint8') + + Overlay with normalization and custom blending: + + >>> result = superimpose_anomaly_map( + ... anomaly_map, + ... image, + ... alpha=0.7, + ... gamma=10, + ... normalize=True + ... 
) + + Note: + - Input image is resized to match anomaly map dimensions + - Anomaly map is converted to a color heatmap using JET colormap + - Output maintains RGB color format + - Shape of output matches the anomaly map dimensions """ anomaly_map = anomaly_map_to_color_map(anomaly_map.squeeze(), normalize=normalize) height, width = anomaly_map.shape[:2] @@ -123,15 +306,46 @@ def superimpose_anomaly_map( def compute_mask(anomaly_map: np.ndarray, threshold: float, kernel_size: int = 4) -> np.ndarray: - """Compute anomaly mask via thresholding the predicted anomaly map. + """Compute binary anomaly mask by thresholding and post-processing anomaly map. + + This function converts a continuous-valued anomaly map into a binary mask by: + - Thresholding the anomaly scores + - Applying morphological operations to reduce noise + - Scaling to 8-bit range [0, 255] Args: - anomaly_map (np.ndarray): Anomaly map predicted via the model - threshold (float): Value to threshold anomaly scores into 0-1 range. - kernel_size (int): Value to apply morphological operations to the predicted mask. Defaults to 4. + anomaly_map (np.ndarray): Anomaly map containing predicted anomaly scores. + Should be a 2D array of float values. + threshold (float): Threshold value to binarize anomaly scores. Values above + this threshold are considered anomalous (1) and below as normal (0). + kernel_size (int, optional): Size of the morphological structuring element + used for noise removal. Higher values result in smoother masks. + Defaults to ``4``. Returns: - Predicted anomaly mask + np.ndarray: Binary anomaly mask where anomalous regions are marked with + 255 and normal regions with 0. Output is uint8 type. + + Examples: + Basic thresholding with default kernel size: + + >>> anomaly_scores = np.random.rand(100, 100) + >>> mask = compute_mask(anomaly_scores, threshold=0.5) + >>> mask.shape + (100, 100) + >>> mask.dtype + dtype('uint8') + >>> np.unique(mask) + array([ 0, 255], dtype=uint8) + + Custom kernel size for stronger smoothing: + + >>> mask = compute_mask(anomaly_scores, threshold=0.5, kernel_size=8) + + Note: + - Input anomaly map is squeezed to remove singleton dimensions + - Morphological opening is used to remove small noise artifacts + - Output is scaled to [0, 255] range for visualization """ anomaly_map = anomaly_map.squeeze() mask: np.ndarray = np.zeros_like(anomaly_map).astype(np.uint8) @@ -148,13 +362,45 @@ def compute_mask(anomaly_map: np.ndarray, threshold: float, kernel_size: int = 4 def draw_boxes(image: np.ndarray, boxes: np.ndarray, color: tuple[int, int, int]) -> np.ndarray: """Draw bounding boxes on an image. + This function draws rectangular bounding boxes on an input image using OpenCV. Each box + is drawn with the specified color and a fixed thickness of 2 pixels. + Args: - image (np.ndarray): Source image. - boxes (np.nparray): 2D array of shape (N, 4) where each row contains the xyxy coordinates of a bounding box. - color (tuple[int, int, int]): Color of the drawn boxes in RGB format. + image (np.ndarray): Source image on which to draw the boxes. Should be a valid + OpenCV-compatible image array. + boxes (np.ndarray): 2D array of shape ``(N, 4)`` where each row contains the + ``(x1, y1, x2, y2)`` coordinates of a bounding box in pixel units. The + coordinates specify the top-left and bottom-right corners. + color (tuple[int, int, int]): Color of the drawn boxes in RGB format, specified + as a tuple of 3 integers in the range ``[0, 255]``. 
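Returning to ``compute_mask`` above: the threshold-then-morphological-opening recipe it documents can be sketched with OpenCV (the square kernel is an assumption; only the opening operation and the [0, 255] scaling are stated in the docstring):

```python
import cv2
import numpy as np

anomaly_map = np.random.rand(100, 100).astype(np.float32)
threshold, kernel_size = 0.5, 4

mask = (anomaly_map > threshold).astype(np.uint8)
# Morphological opening removes small, isolated noise blobs.
kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)
mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
mask *= 255  # scale to [0, 255] for visualization
print(np.unique(mask))  # values drawn from {0, 255}
```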
 Returns:
-        np.ndarray: Image showing the bounding boxes drawn on top of the source image.
+        np.ndarray: Modified image with bounding boxes drawn on top. Has the same
+            dimensions and type as the input image.
+
+    Examples:
+        Draw a single red box:
+
+        >>> import numpy as np
+        >>> image = np.zeros((100, 100, 3), dtype=np.uint8)
+        >>> boxes = np.array([[10, 10, 50, 50]])  # Single box
+        >>> result = draw_boxes(image, boxes, color=(255, 0, 0))
+        >>> result.shape
+        (100, 100, 3)
+
+        Draw multiple boxes in green:
+
+        >>> boxes = np.array([
+        ...     [20, 20, 40, 40],
+        ...     [60, 60, 80, 80]
+        ... ])  # Two boxes
+        >>> result = draw_boxes(image, boxes, color=(0, 255, 0))
+
+    Note:
+        - Input coordinates are converted to integers before drawing
+        - Boxes are drawn with a fixed thickness of 2 pixels
+        - The function modifies the input image in-place
+        - OpenCV uses BGR color format internally but the function expects RGB
     """
     for box in boxes:
         x_1, y_1, x_2, y_2 = box.astype(int)
diff --git a/src/anomalib/utils/types/__init__.py b/src/anomalib/utils/types/__init__.py
index a220571bc0..da61db770d 100644
--- a/src/anomalib/utils/types/__init__.py
+++ b/src/anomalib/utils/types/__init__.py
@@ -1,4 +1,23 @@
-"""Typing aliases for Anomalib."""
+"""Type aliases for anomaly detection.
+
+This module provides type aliases used throughout the anomalib library. The aliases
+include:
+    - ``NORMALIZATION``: Type for normalization methods and configurations
+    - ``THRESHOLD``: Type for threshold values and configurations
+
+Example:
+    >>> from anomalib.utils.types import NORMALIZATION, THRESHOLD
+    >>> from anomalib.utils.normalization import NormalizationMethod
+    >>> # Use min-max normalization
+    >>> norm: NORMALIZATION = NormalizationMethod.MIN_MAX
+    >>> print(norm)
+    min_max
+    >>> # Use threshold configuration
+    >>> thresh: THRESHOLD = {"method": "adaptive", "delta": 0.1}
+
+The module ensures consistent typing across the codebase and provides helpful type
+hints for configuration objects.
+"""
 
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
diff --git a/src/anomalib/utils/visualization/__init__.py b/src/anomalib/utils/visualization/__init__.py
index 404036dfad..582707aa13 100644
--- a/src/anomalib/utils/visualization/__init__.py
+++ b/src/anomalib/utils/visualization/__init__.py
@@ -1,4 +1,26 @@
-"""Visualization utils."""
+"""Tools for visualizing anomaly detection results.
+
+This module provides utilities for visualizing anomaly detection outputs. The
+utilities include:
+    - Base visualization interface and common functionality
+    - Image-based visualization for detection results
+    - Explanation visualization for model interpretability
+    - Metrics visualization for performance analysis
+
+Example:
+    >>> from anomalib.utils.visualization import ImageVisualizer
+    >>> # Create visualizer for detection results
+    >>> visualizer = ImageVisualizer()
+    >>> # Visualize detections for a batch of model outputs
+    >>> results = visualizer.generate(
+    ...     outputs={
+    ...         "image": images,
+    ...         "pred_mask": masks,
+    ...         "anomaly_map": heatmaps,
+    ...     },
+    ... )
+
+The module ensures consistent and informative visualization across different
+detection approaches and result types.
+""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/utils/visualization/base.py b/src/anomalib/utils/visualization/base.py index bafac0c0fa..f2a36430bd 100644 --- a/src/anomalib/utils/visualization/base.py +++ b/src/anomalib/utils/visualization/base.py @@ -1,4 +1,26 @@ -"""Base visualization generator.""" +"""Base visualization generator for anomaly detection. + +This module provides the base visualization interface and common functionality used +across different visualization types. The key components include: + + - ``GeneratorResult``: Dataclass for standardized visualization outputs + - ``VisualizationStep``: Enum for controlling when visualizations are generated + - ``BaseVisualizer``: Abstract base class defining the visualization interface + +Example: + >>> from anomalib.utils.visualization import BaseVisualizer + >>> # Create custom visualizer + >>> class CustomVisualizer(BaseVisualizer): + ... def generate(self, **kwargs): + ... # Generate visualization + ... yield GeneratorResult(image=img) + >>> # Use visualizer + >>> vis = CustomVisualizer(visualize_on="batch") + >>> results = vis.generate(image=input_img) + +The module ensures consistent visualization behavior and output formats across +different visualization implementations. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/utils/visualization/explanation.py b/src/anomalib/utils/visualization/explanation.py index 10904161e3..b840f8bd11 100644 --- a/src/anomalib/utils/visualization/explanation.py +++ b/src/anomalib/utils/visualization/explanation.py @@ -1,6 +1,31 @@ -"""Explanation visualization generator. - -Note: This is a temporary visualizer, and will be replaced with the new visualizer in the future. +"""Explanation visualization generator for model interpretability. + +This module provides utilities for visualizing model explanations and +interpretability results. The key components include: + + - Text-based explanations rendered on images + - Label visualization for model decisions + - Combined visualization with original image and explanation + +Example: + >>> from anomalib.utils.visualization import ExplanationVisualizer + >>> # Create visualizer + >>> visualizer = ExplanationVisualizer() + >>> # Generate visualization + >>> results = visualizer.generate( + ... outputs={ + ... "image": images, + ... "explanation": explanations, + ... "image_path": paths + ... } + ... ) + +Note: + This is a temporary visualizer that will be replaced with an enhanced + version in a future release. + +The module ensures consistent visualization of model explanations across +different interpretability approaches. """ # Copyright (C) 2024 Intel Corporation diff --git a/src/anomalib/utils/visualization/image.py b/src/anomalib/utils/visualization/image.py index 16b852235f..9aa0b68822 100644 --- a/src/anomalib/utils/visualization/image.py +++ b/src/anomalib/utils/visualization/image.py @@ -1,4 +1,46 @@ -"""Image/video generator.""" +"""Image and video visualization generator. + +This module provides utilities for visualizing anomaly detection results on images +and videos. 
The key components include: + + - ``ImageResult``: Dataclass for storing visualization data + - ``ImageVisualizer``: Main visualization generator class + - ``VisualizationMode``: Enum for controlling visualization style + - ``_ImageGrid``: Helper class for creating image grids + +The module supports both classification and segmentation tasks, with options for: + + - Full visualization showing all available outputs + - Simple visualization showing only key predictions + - Customizable normalization of anomaly maps + - Automatic handling of both image and video inputs + +Example: + >>> from anomalib.utils.visualization import ImageVisualizer + >>> from anomalib.utils.visualization.image import VisualizationMode + >>> # Create visualizer + >>> visualizer = ImageVisualizer( + ... mode=VisualizationMode.FULL, + ... task="segmentation", + ... normalize=True + ... ) + >>> # Generate visualization + >>> results = visualizer.generate( + ... outputs={ + ... "image": images, + ... "pred_mask": masks, + ... "anomaly_map": heatmaps + ... } + ... ) + +The module ensures consistent visualization across different anomaly detection +approaches and result types. It handles proper scaling and formatting of inputs, +and provides a flexible interface for customizing the visualization output. + +Note: + When using video inputs, the visualizer automatically handles frame extraction + and maintains proper frame ordering in the output. +""" # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -17,11 +59,7 @@ from anomalib import TaskType from anomalib.data import ImageItem, NumpyImageItem, VideoItem from anomalib.data.utils import read_image -from anomalib.utils.post_processing import ( - add_anomalous_label, - add_normal_label, - superimpose_anomaly_map, -) +from anomalib.utils.post_processing import add_anomalous_label, add_normal_label, superimpose_anomaly_map from .base import BaseVisualizer, GeneratorResult, VisualizationStep @@ -30,7 +68,13 @@ class VisualizationMode(str, Enum): - """Type of visualization mode.""" + """Visualization mode for controlling output style. + + The mode determines how results are displayed: + + - ``FULL``: Shows all available visualizations in a grid + - ``SIMPLE``: Shows only the key prediction results + """ FULL = "full" SIMPLE = "simple" @@ -38,7 +82,21 @@ class VisualizationMode(str, Enum): @dataclass class ImageResult: - """Collection of data needed to visualize the predictions for an image.""" + """Collection of data needed to visualize predictions for an image. + + Args: + image (np.ndarray): Input image to visualize + pred_score (float): Predicted anomaly score + pred_label (str): Predicted label (e.g. "normal" or "anomalous") + anomaly_map (np.ndarray | None): Anomaly heatmap if available + gt_mask (np.ndarray | None): Ground truth mask if available + pred_mask (np.ndarray | None): Predicted segmentation mask if available + normalize (InitVar[bool]): Whether to normalize anomaly maps to [0,1] + + Note: + The class automatically handles proper scaling and type conversion of + inputs during initialization. + """ image: np.ndarray pred_score: float @@ -90,7 +148,14 @@ def __repr__(self) -> str: def from_dataset_item(cls: type["ImageResult"], item: ImageItem | NumpyImageItem) -> "ImageResult": """Create an ImageResult object from a DatasetItem object. - This is a temporary solution until we refactor the visualizer to take a DatasetItem object directly as input. 
+ This is a temporary solution until we refactor the visualizer to take a + DatasetItem object directly as input. + + Args: + item (ImageItem | NumpyImageItem): Dataset item to convert + + Returns: + ImageResult: New image result object """ if isinstance(item, ImageItem): item = item.to_numpy() @@ -100,14 +165,19 @@ def from_dataset_item(cls: type["ImageResult"], item: ImageItem | NumpyImageItem class ImageVisualizer(BaseVisualizer): - """Image/video generator. + """Image and video visualization generator. Args: - mode (VisualizationMode, optional): Type of visualization mode. Defaults to VisualizationMode.FULL. - task (TaskType, optional): Type of task. Defaults to TaskType.CLASSIFICATION. - normalize (bool, optional): Whether or not the anomaly maps should be normalized to image min-max at image - level. Defaults to False. Note: This is more useful when NormalizationMethod is set to None. Otherwise, - the overlayed anomaly map will contain the raw scores. + mode (VisualizationMode, optional): Visualization mode. Defaults to + ``VisualizationMode.FULL``. + task (TaskType | str, optional): Type of task. Defaults to + ``TaskType.CLASSIFICATION``. + normalize (bool, optional): Whether to normalize anomaly maps to image + min-max. Defaults to ``False``. + + Note: + Normalization is most useful when no other normalization method is used, + as otherwise the overlay will show raw anomaly scores. """ def __init__( @@ -122,7 +192,17 @@ def __init__( self.normalize = normalize def generate(self, **kwargs) -> Iterator[GeneratorResult]: - """Generate images and return them as an iterator.""" + """Generate images and return them as an iterator. + + Args: + **kwargs: Keyword arguments containing model outputs. + + Returns: + Iterator yielding visualization results. + + Raises: + ValueError: If outputs are not provided in kwargs. + """ outputs = kwargs.get("outputs", None) if outputs is None: msg = "Outputs must be provided to generate images." @@ -133,10 +213,14 @@ def _visualize_batch(self, batch: dict) -> Iterator[GeneratorResult]: """Yield a visualization result for each item in the batch. Args: - batch (dict): Dictionary containing the ground truth and predictions of a batch of images. + batch (dict): Dictionary containing the ground truth and predictions + of a batch of images. Returns: Generator that yields a display-ready visualization for each image. + + Raises: + TypeError: If item has neither image path nor video path defined. """ for item in batch: if hasattr(item, "image_path") and item.image_path is not None: @@ -167,7 +251,10 @@ def visualize_image(self, image_result: ImageResult) -> np.ndarray: image_result (ImageResult): GT and Prediction data for a single image. Returns: - The full or simple visualization for the image, depending on the specified mode. + np.ndarray: The full or simple visualization for the image. + + Raises: + ValueError: If visualization mode is unknown. """ if self.mode == VisualizationMode.FULL: return self._visualize_full(image_result) @@ -179,15 +266,21 @@ def visualize_image(self, image_result: ImageResult) -> np.ndarray: def _visualize_full(self, image_result: ImageResult) -> np.ndarray: """Generate the full set of visualization for an image. - The full visualization mode shows a grid with subplots that contain the original image, the GT mask (if - available), the predicted heat map, the predicted segmentation mask (if available), and the predicted - segmentations (if available). 
+ The full visualization mode shows a grid with subplots that contain: + - Original image + - GT mask (if available) + - Predicted heat map + - Predicted segmentation mask (if available) + - Predicted segmentations (if available) Args: image_result (ImageResult): GT and Prediction data for a single image. Returns: - An image showing the full set of visualizations for the input image. + np.ndarray: Image showing the full set of visualizations. + + Raises: + ValueError: If predicted mask is None for segmentation task. """ image_grid = _ImageGrid() if self.task == TaskType.SEGMENTATION: @@ -216,13 +309,17 @@ def _visualize_full(self, image_result: ImageResult) -> np.ndarray: def _visualize_simple(self, image_result: ImageResult) -> np.ndarray: """Generate a simple visualization for an image. - The simple visualization mode only shows the model's predictions in a single image. + The simple visualization mode only shows the model's predictions in a + single image. Args: image_result (ImageResult): GT and Prediction data for a single image. Returns: - An image showing the simple visualization for the input image. + np.ndarray: Image showing the simple visualization. + + Raises: + ValueError: If task type is unknown. """ if self.task == TaskType.SEGMENTATION: visualization = mark_boundaries( @@ -245,8 +342,9 @@ def _visualize_simple(self, image_result: ImageResult) -> np.ndarray: class _ImageGrid: """Helper class that compiles multiple images into a grid using subplots. - Individual images can be added with the `add_image` method. When all images have been added, the `generate` method - must be called to compile the image grid and obtain the final visualization. + Individual images can be added with the ``add_image`` method. When all images + have been added, the ``generate`` method must be called to compile the image + grid and obtain the final visualization. """ def __init__(self) -> None: @@ -258,18 +356,24 @@ def add_image(self, image: np.ndarray, title: str | None = None, color_map: str """Add an image to the grid. Args: - image (np.ndarray): Image which should be added to the figure. - title (str): Image title shown on the plot. - color_map (str | None): Name of matplotlib color map used to map scalar data to colours. Defaults to None. + image (np.ndarray): Image to add to the figure + title (str | None): Image title shown on the plot + color_map (str | None): Name of matplotlib color map for mapping + scalar data to colours. Defaults to ``None``. """ image_data = {"image": image, "title": title, "color_map": color_map} self.images.append(image_data) def generate(self) -> np.ndarray: - """Generate the image. + """Generate the image grid. Returns: - Image consisting of a grid of added images and their title. + np.ndarray: Image consisting of a grid of added images and their + titles. + + Note: + Uses Agg backend to avoid issues with dimension mismatch when using + backends like MacOSX. 
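The ``buffer_rgba`` change that follows in this hunk can be sanity-checked in isolation; a standalone sketch of the canvas-to-array round trip under the Agg backend mentioned in the Note:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend with a stable RGBA buffer

import cv2
import matplotlib.pyplot as plt
import numpy as np

figure, axis = plt.subplots(figsize=(2, 2))
axis.imshow(np.random.rand(8, 8), cmap="gray")
figure.canvas.draw()

# buffer_rgba() replaces the removed tostring_rgb(); note the fourth channel.
img = np.frombuffer(figure.canvas.buffer_rgba(), dtype=np.uint8)
img = img.reshape(figure.canvas.get_width_height()[::-1] + (4,))
img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
plt.close(figure)
print(img.shape)  # (height, width, 3)
```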
""" num_cols = len(self.images) figure_size = (num_cols * 5, 5) @@ -290,7 +394,8 @@ def generate(self) -> np.ndarray: axis.title.set_text(image_dict["title"]) self.figure.canvas.draw() # convert canvas to numpy array to prepare for visualization with opencv - img = np.frombuffer(self.figure.canvas.tostring_rgb(), dtype=np.uint8) - img = img.reshape(self.figure.canvas.get_width_height()[::-1] + (3,)) + img = np.frombuffer(self.figure.canvas.buffer_rgba(), dtype=np.uint8) + img = img.reshape(self.figure.canvas.get_width_height()[::-1] + (4,)) # RGBA has 4 channels + img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB) plt.close(self.figure) return img diff --git a/src/anomalib/utils/visualization/metrics.py b/src/anomalib/utils/visualization/metrics.py index 36f4948405..d5320cca4b 100644 --- a/src/anomalib/utils/visualization/metrics.py +++ b/src/anomalib/utils/visualization/metrics.py @@ -1,4 +1,29 @@ -"""Metrics visualization generator.""" +"""Metrics visualization generator for anomaly detection results. + +This module provides utilities for visualizing metric plots from anomaly detection +models. The key components include: + + - Automatic generation of metric plots from model metrics + - Support for both image-level and pixel-level metrics + - Consistent file naming and output format + +Example: + >>> from anomalib.utils.visualization import MetricsVisualizer + >>> # Create metrics visualizer + >>> visualizer = MetricsVisualizer() + >>> # Generate metric plots + >>> results = visualizer.generate(pl_module=model) + +The module ensures proper visualization of model performance metrics by: + - Automatically detecting plottable metrics + - Generating standardized plot formats + - Handling both classification and segmentation metrics + - Providing consistent file naming conventions + +Note: + Metrics must implement a ``generate_figure`` method to be visualized. + The method should return a tuple of (figure, log_name). +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -13,14 +38,36 @@ class MetricsVisualizer(BaseVisualizer): - """Generate metric plots.""" + """Generate metric plots from model metrics. + + This class handles the automatic generation of metric plots from an anomalib + model's metrics. It supports both image-level and pixel-level metrics. + """ def __init__(self) -> None: super().__init__(VisualizationStep.STAGE_END) @staticmethod def generate(**kwargs) -> Iterator[GeneratorResult]: - """Generate metric plots and return them as an iterator.""" + """Generate metric plots and return them as an iterator. + + Args: + **kwargs: Keyword arguments passed to the generator. + Must include ``pl_module`` containing the model metrics. + + Yields: + Iterator[GeneratorResult]: Generator results containing the plot + figures and filenames. + + Raises: + ValueError: If ``pl_module`` is not provided in kwargs. + + Example: + >>> visualizer = MetricsVisualizer() + >>> for result in visualizer.generate(pl_module=model): + ... # Process the visualization result + ... print(result.file_name) + """ pl_module: AnomalibModule = kwargs.get("pl_module", None) if pl_module is None: msg = "`pl_module` must be provided" diff --git a/src/anomalib/visualization/__init__.py b/src/anomalib/visualization/__init__.py index 989f4cc34c..c49e72a6a3 100644 --- a/src/anomalib/visualization/__init__.py +++ b/src/anomalib/visualization/__init__.py @@ -1,4 +1,30 @@ -"""Visualization module.""" +"""Visualization module for anomaly detection. 
+ +This module provides utilities for visualizing anomaly detection results. The key +components include: + + - Base ``Visualizer`` class defining the visualization interface + - ``ImageVisualizer`` class for image-based visualization + - Functions for visualizing anomaly maps and segmentation masks + - Tools for visualizing ``ImageItem`` objects + +Example: + >>> from anomalib.visualization import ImageVisualizer + >>> # Create visualizer + >>> visualizer = ImageVisualizer() + >>> # Generate visualization + >>> vis_result = visualizer.visualize(image=img, pred_mask=mask) + +The module ensures consistent visualization by: + - Providing standardized visualization interfaces + - Supporting both classification and segmentation results + - Handling various input formats + - Maintaining consistent output formats + +Note: + All visualization functions preserve the input format and dimensions unless + explicitly specified otherwise. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/visualization/base.py b/src/anomalib/visualization/base.py index dc49a85401..0229d05a3b 100644 --- a/src/anomalib/visualization/base.py +++ b/src/anomalib/visualization/base.py @@ -1,4 +1,30 @@ -"""Base Visualizer.""" +"""Base visualization module for anomaly detection. + +This module provides the base ``Visualizer`` class that defines the interface for +visualizing anomaly detection results. The key components include: + + - Base ``Visualizer`` class that inherits from PyTorch Lightning's ``Callback`` + - Interface for visualizing model outputs during testing and prediction + - Support for customizable visualization formats and configurations + +Example: + >>> from anomalib.visualization import Visualizer + >>> # Create custom visualizer + >>> class CustomVisualizer(Visualizer): + ... def visualize(self, **kwargs): + ... # Custom visualization logic + ... pass + +The module ensures consistent visualization by: + - Providing a standardized visualization interface + - Supporting both classification and segmentation results + - Enabling customizable visualization formats + - Maintaining consistent output formats + +Note: + All visualizer implementations should inherit from the base ``Visualizer`` + class and implement the required visualization methods. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -9,6 +35,31 @@ class Visualizer(Callback): """Base class for all visualizers. - In Anomalib, the visualizer is used to visualize the results of the model - during the testing and prediction phases. + This class serves as the foundation for implementing visualization functionality in + Anomalib. It inherits from PyTorch Lightning's ``Callback`` class to integrate with + the training workflow. + + The visualizer is responsible for generating visual representations of model outputs + during testing and prediction phases. This includes: + + - Visualizing input images + - Displaying model predictions + - Showing ground truth annotations + - Generating overlays and heatmaps + - Saving visualization results + + Example: + >>> from anomalib.visualization import Visualizer + >>> # Create custom visualizer + >>> class CustomVisualizer(Visualizer): + ... def visualize(self, **kwargs): + ... # Custom visualization logic + ... 
pass + + Note: + All custom visualizers should: + - Inherit from this base class + - Implement the ``visualize`` method + - Handle relevant visualization configurations + - Maintain consistent output formats """ diff --git a/src/anomalib/visualization/image/__init__.py b/src/anomalib/visualization/image/__init__.py index 9f60f1399a..fb7e299407 100644 --- a/src/anomalib/visualization/image/__init__.py +++ b/src/anomalib/visualization/image/__init__.py @@ -1,4 +1,31 @@ -"""Image visualization module.""" +"""Image visualization module for anomaly detection. + +This module provides utilities for visualizing images and anomaly detection results. +The key components include: + + - Functions for visualizing anomaly maps and segmentation masks + - Tools for overlaying images and adding text annotations + - Colormap application utilities + - Image item visualization + - ``ImageVisualizer`` class for consistent visualization + +Example: + >>> from anomalib.visualization.image import ImageVisualizer + >>> # Create visualizer + >>> visualizer = ImageVisualizer() + >>> # Generate visualization + >>> vis_result = visualizer.visualize(image=img, pred_mask=mask) + +The module ensures consistent visualization by: + - Providing standardized colormaps and overlays + - Supporting both classification and segmentation results + - Handling various input formats + - Maintaining consistent output formats + +Note: + All visualization functions preserve the input image format and dimensions + unless explicitly specified otherwise. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/visualization/image/functional.py b/src/anomalib/visualization/image/functional.py index 940586c2d6..558e55613e 100644 --- a/src/anomalib/visualization/image/functional.py +++ b/src/anomalib/visualization/image/functional.py @@ -1,4 +1,30 @@ -"""Visualizer for ImageItem fields using PIL and torchvision.""" +"""Image visualization functions using PIL and torchvision. + +This module provides functions for visualizing images and anomaly detection results using +PIL and torchvision. The key components include: + + - Functions for adding text overlays to images + - Tools for applying colormaps to anomaly maps + - Image overlay and blending utilities + - Mask and anomaly map visualization + +Example: + >>> from PIL import Image + >>> from anomalib.visualization.image.functional import add_text_to_image + >>> # Create image and add text + >>> image = Image.new('RGB', (100, 100)) + >>> result = add_text_to_image(image, text="Anomaly") + +The module ensures consistent visualization by: + - Providing standardized text rendering + - Supporting various color formats and fonts + - Handling different image formats + - Maintaining aspect ratios + +Note: + All visualization functions preserve the input image format and dimensions + unless explicitly specified otherwise. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -11,7 +37,7 @@ import torch import torch.nn.functional as F # noqa: N812 -from PIL import Image, ImageDraw, ImageEnhance, ImageFilter, ImageFont +from PIL import Image, ImageDraw, ImageEnhance, ImageFont from torchvision.transforms.functional import to_pil_image logger = logging.getLogger(__name__) @@ -20,14 +46,40 @@ def dynamic_font_size(image_size: tuple[int, int], min_size: int = 20, max_size: int = 100, divisor: int = 10) -> int: """Calculate a dynamic font size based on image dimensions. 
+ This function determines an appropriate font size based on the image dimensions while + staying within specified bounds. The font size is calculated by dividing the smaller + image dimension by the divisor. + Args: - image_size: Tuple of image dimensions (width, height). - min_size: Minimum font size (default: 20). - max_size: Maximum font size (default: 100). - divisor: Divisor for calculating font size (default: 10). + image_size (tuple[int, int]): Tuple of image dimensions ``(width, height)``. + min_size (int, optional): Minimum allowed font size. Defaults to ``20``. + max_size (int, optional): Maximum allowed font size. Defaults to ``100``. + divisor (int, optional): Value to divide the minimum image dimension by. + Defaults to ``10``. Returns: - Calculated font size within the specified range. + int: Calculated font size constrained between ``min_size`` and ``max_size``. + + Examples: + Calculate font size for a small image: + + >>> dynamic_font_size((200, 100)) + 20 + + Calculate font size for a large image: + + >>> dynamic_font_size((1000, 800)) + 80 + + Font size is capped at max_size: + + >>> dynamic_font_size((2000, 2000), max_size=50) + 50 + + Note: + - The function uses the smaller dimension to ensure text fits in both directions + - The calculated size is clamped between ``min_size`` and ``max_size`` + - Larger ``divisor`` values result in smaller font sizes """ min_dimension = min(image_size) return max(min_size, min(max_size, min_dimension // divisor)) @@ -43,7 +95,64 @@ def add_text_to_image( position: tuple[int, int] = (10, 10), padding: int = 3, ) -> Image.Image: - """Add text to an image with configurable parameters.""" + """Add text to an image with configurable parameters. + + This function adds text to a PIL Image with customizable font, size, color and + background options. The text can be positioned anywhere on the image and includes + an optional background box. + + Args: + image (Image.Image): The PIL Image to add text to. + text (str): The text string to add to the image. + font (str | None, optional): Path to a font file. If ``None`` or loading fails, + the default system font is used. Defaults to ``None``. + size (int | None, optional): Font size in pixels. If ``None``, size is + calculated dynamically based on image dimensions. Defaults to ``None``. + color (tuple[int, int, int] | str, optional): Text color as RGB tuple or color + name. Defaults to ``"white"``. + background (tuple[int, ...] | str | None, optional): Background color for text + box. Can be RGB/RGBA tuple or color name. If ``None``, no background is + drawn. Defaults to semi-transparent black ``(0, 0, 0, 128)``. + position (tuple[int, int], optional): Top-left position of text as ``(x, y)`` + coordinates. Defaults to ``(10, 10)``. + padding (int, optional): Padding around text in background box in pixels. + Defaults to ``3``. + + Returns: + Image.Image: New PIL Image with text added. + + Examples: + Basic white text: + + >>> from PIL import Image + >>> img = Image.new('RGB', (200, 100)) + >>> result = add_text_to_image(img, "Hello") + + Custom font and color: + + >>> result = add_text_to_image( + ... img, + ... "Hello", + ... font="arial.ttf", + ... color=(255, 0, 0) + ... ) + + Text with custom background: + + >>> result = add_text_to_image( + ... img, + ... "Hello", + ... background=(0, 0, 255, 200), + ... position=(50, 50) + ... 
) + + Note: + - The function creates a transparent overlay for the text + - Font size is calculated dynamically if not specified + - Falls back to default system font if custom font fails to load + - Input image is converted to RGBA for compositing + - Output is converted back to RGB + """ # Create a new RGBA image as a transparent overlay overlay = Image.new("RGBA", image.size, (0, 0, 0, 0)) draw = ImageDraw.Draw(overlay) @@ -77,24 +186,52 @@ def apply_colormap(image: Image.Image) -> Image.Image: """Apply a colormap to a single-channel PIL Image using torch and PIL. This function converts a grayscale image to a colored image using the 'jet' colormap. + The colormap is created by interpolating between 9 key colors from dark blue to dark + red. Args: - image (Image.Image): A single-channel PIL Image or an object that can be converted to PIL Image. + image (``Image.Image``): A single-channel PIL Image or an object that can be + converted to PIL Image. If not already in 'L' mode (8-bit grayscale), it will + be converted. Returns: - Image.Image: A new PIL Image with the colormap applied. + ``Image.Image``: A new PIL Image in RGB mode with the colormap applied. Raises: TypeError: If the input cannot be converted to a PIL Image. Example: + Create a random grayscale image and apply colormap: + >>> from PIL import Image >>> import numpy as np >>> # Create a sample grayscale image - >>> gray_image = Image.fromarray(np.random.randint(0, 256, (100, 100), dtype=np.uint8), mode='L') + >>> gray = np.random.randint(0, 256, (100, 100), dtype=np.uint8) + >>> gray_image = Image.fromarray(gray, mode='L') >>> # Apply the jet colormap >>> colored_image = apply_colormap(gray_image) - >>> colored_image.show() + >>> colored_image.mode + 'RGB' + + Apply to non-PIL input: + + >>> # NumPy array input is automatically converted + >>> colored_image = apply_colormap(gray) + >>> isinstance(colored_image, Image.Image) + True + + Invalid input raises TypeError: + + >>> apply_colormap("not an image") # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TypeError: Input must be a PIL Image object or an object that can be... + + Note: + - Input is automatically converted to grayscale if not already + - Uses a custom 'jet' colormap interpolated between 9 key colors + - Output is always in RGB mode regardless of input mode + - The colormap interpolation uses bilinear mode for smooth transitions """ # Try to convert the input to a PIL Image if it's not already if not isinstance(image, Image.Image): @@ -141,28 +278,47 @@ def apply_colormap(image: Image.Image) -> Image.Image: def overlay_image(base: Image.Image, overlay: Image.Image, alpha: float = 0.5) -> Image.Image: """Overlay an image on top of another image with a specified alpha value. + This function takes a base image and overlays another image on top of it with the + specified transparency level. Both images are converted to RGBA mode to enable alpha + compositing. If the overlay image has a different size than the base image, it will + be resized to match. + Args: - base (Image.Image): The base image. - overlay (Image.Image): The image to overlay. - alpha (float): The alpha value for blending (0.0 to 1.0). Defaults to 0.5. + base (:class:`PIL.Image.Image`): The base image that will serve as the + background. + overlay (:class:`PIL.Image.Image`): The image to overlay on top of the base + image. + alpha (float, optional): The alpha/transparency value for blending, between + 0.0 (fully transparent) and 1.0 (fully opaque). Defaults to ``0.5``. 
     Returns:
-        Image.Image: The image with the overlay applied.
+        :class:`PIL.Image.Image`: A new image with the overlay composited on top of
+            the base image using the specified alpha value.

     Examples:
-        # Overlay a random mask on an image
-        >>> from PIL import Image, ImageDraw
+        Create a base image with a yellow triangle on a green background:
+
+        >>> from PIL import Image, ImageDraw
         >>> image = Image.new('RGB', (200, 200), color='green')
         >>> draw = ImageDraw.Draw(image)
         >>> draw.polygon([(50, 50), (150, 50), (100, 150)], fill='yellow')

+        Create a mask with a white rectangle on black background:
+
         >>> mask = Image.new('L', (200, 200), color=0)
         >>> draw = ImageDraw.Draw(mask)
         >>> draw.rectangle([75, 75, 125, 125], fill=255)

+        Overlay the mask on the image with 30% opacity:
+
         >>> result = overlay_image(image, mask, alpha=0.3)
-        >>> result.show()
+        >>> result.show()  # doctest: +SKIP
+
+    Note:
+        - Both input images are converted to RGBA mode internally
+        - The overlay is automatically resized to match the base image size
+        - The function uses PIL's alpha compositing for high-quality blending
+        - The output image preserves the RGBA mode of the composite result
     """
     base = base.convert("RGBA")
     overlay = overlay.convert("RGBA")
@@ -185,45 +341,53 @@ def overlay_images(
     overlays: Image.Image | list[Image.Image],
     alpha: float | list[float] = 0.5,
 ) -> Image.Image:
-    """Overlay multiple images on top of a base image with a specified alpha value.
+    """Overlay multiple images on top of a base image with specified transparency.

-    If the overlay is a mask (L mode), draw its contours on the image instead.
+    This function overlays one or more images on top of a base image with specified
+    alpha/transparency values. If an overlay is a mask (L mode), it will be drawn
+    as a semi-transparent overlay.

     Args:
-        base: The base PIL Image.
-        overlays: PIL Image or list of PIL Images to overlay on top of the base image.
-        alpha: The alpha value for blending (0.0 to 1.0). Defaults to 0.5.
+        base (:class:`PIL.Image.Image`): The base image on which the overlays are
+            placed.
+        overlays (:class:`PIL.Image.Image` | list[:class:`PIL.Image.Image`]): One
+            or more images to overlay on the base image.
+        alpha (float | list[float], optional): Alpha/transparency value(s) between
+            0.0 (fully transparent) and 1.0 (fully opaque). Can be a single float
+            applied to all overlays, or a list of values. Defaults to ``0.5``.

     Returns:
-        A new PIL Image with all overlays applied.
+        :class:`PIL.Image.Image`: A new image with all overlays composited on top
+            of the base image using the specified alpha values.
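+
+    A single float ``alpha`` is applied to every overlay, while a list is matched
+    up with the overlays one by one; conceptually something like
+    ``alphas = alpha if isinstance(alpha, list) else [alpha] * len(overlays)``
+    (an informal sketch inferred from the documented signature, not the literal
+    implementation).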
     Examples:
-        # Overlay a single image
+        Overlay a single mask:
+
         >>> from PIL import Image, ImageDraw
+        >>> # Create base image with yellow triangle on green background
         >>> image = Image.new('RGB', (200, 200), color='green')
         >>> draw = ImageDraw.Draw(image)
         >>> draw.polygon([(50, 50), (150, 50), (100, 150)], fill='yellow')
-
+        >>> # Create mask with white rectangle
         >>> mask = Image.new('L', (200, 200), color=0)
         >>> draw = ImageDraw.Draw(mask)
         >>> draw.rectangle([75, 75, 125, 125], fill=255)
-
+        >>> # Apply overlay
         >>> result = overlay_images(image, mask)

-        # Overlay multiple images
-        >>> image = Image.new('RGB', (200, 200), color='green')
-        >>> draw = ImageDraw.Draw(image)
-        >>> draw.polygon([(50, 50), (150, 50), (100, 150)], fill='yellow')
-
-        >>> mask1 = Image.new('L', (200, 200), color=0)
-        >>> draw = ImageDraw.Draw(mask1)
-        >>> draw.rectangle([25, 25, 75, 75], fill=255)
+        Overlay multiple masks with different alphas:

+        >>> # Create second mask with white ellipse
         >>> mask2 = Image.new('L', (200, 200), color=0)
         >>> draw = ImageDraw.Draw(mask2)
         >>> draw.ellipse([50, 50, 150, 100], fill=255)
-
-        >>> result = overlay_images(image, [mask1, mask2])
+        >>> # Apply overlays with different alpha values
+        >>> result = overlay_images(image, [mask, mask2], alpha=[0.3, 0.7])
+
+    Note:
+        - All images are converted to RGBA mode internally
+        - Overlays are automatically resized to match the base image size
+        - Uses PIL's alpha compositing for high-quality blending
+        - The output preserves the RGBA mode of the composite result
     """
     if not isinstance(overlays, list):
         overlays = [overlays]
@@ -240,35 +404,56 @@ def visualize_anomaly_map(
     colormap: bool = True,
     normalize: bool = False,
 ) -> Image.Image:
-    """Visualize the anomaly map.
+    """Visualize an anomaly map by applying normalization and/or colormap.

-    This function takes an anomaly map as input and applies normalization and/or colormap
-    based on the provided parameters.
+    This function takes an anomaly map and converts it to a visualization by optionally
+    normalizing the values and applying a colormap. The input can be either a PIL Image
+    or PyTorch tensor.

     Args:
-        anomaly_map (Image.Image | torch.Tensor): The input anomaly map as a PIL Image or torch Tensor.
-        colormap (bool, optional): Whether to apply a colormap to the anomaly map. Defaults to True.
-        normalize (bool, optional): Whether to normalize the anomaly map. Defaults to False.
+        anomaly_map (:class:`PIL.Image.Image` | :class:`torch.Tensor`): Input anomaly
+            map to visualize. If a tensor is provided, it will be converted to a PIL
+            Image.
+        colormap (bool, optional): Whether to apply a colormap to the anomaly map.
+            When ``True``, converts the image to a colored heatmap visualization.
+            When ``False``, converts to RGB grayscale. Defaults to ``True``.
+        normalize (bool, optional): Whether to normalize the anomaly map values to
+            [0, 255] range before visualization. When ``True``, linearly scales the
+            values using min-max normalization. Defaults to ``False``.

     Returns:
-        Image.Image: The visualized anomaly map as a PIL Image in RGB mode.
+        :class:`PIL.Image.Image`: Visualized anomaly map as a PIL Image in RGB mode.
+            If ``colormap=True``, returns a heatmap visualization. Otherwise returns
+            a grayscale RGB image.
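+
+    With ``normalize=True`` the values are first rescaled with a standard min-max
+    transform, roughly ``(x - x.min()) / (x.max() - x.min()) * 255`` (an informal
+    sketch of the documented behaviour, not necessarily the literal
+    implementation).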
+
+    Examples:
+        Visualize a PIL Image anomaly map:
+
-    Example:
         >>> from PIL import Image
         >>> import numpy as np
-        >>> import torch
+        >>> # Create sample anomaly map
+        >>> data = np.random.rand(100, 100).astype(np.float32)
+        >>> anomaly_map = Image.fromarray(data, mode='F')
+        >>> # Visualize with normalization and colormap
+        >>> vis = visualize_anomaly_map(anomaly_map, normalize=True, colormap=True)
+        >>> vis.mode
+        'RGB'

-        >>> # Create a sample anomaly map as PIL Image
-        >>> anomaly_map_pil = Image.fromarray(np.random.rand(100, 100).astype(np.float32), mode='F')
+        Visualize a PyTorch tensor anomaly map:

-        >>> # Create a sample anomaly map as torch Tensor
-        >>> anomaly_map_tensor = torch.rand(100, 100)
+        >>> import torch
+        >>> # Create random tensor
+        >>> tensor_map = torch.rand(100, 100)
+        >>> # Visualize without normalization
+        >>> vis = visualize_anomaly_map(tensor_map, normalize=False, colormap=True)
+        >>> isinstance(vis, Image.Image)
+        True

-        >>> # Visualize the anomaly maps
-        >>> visualized_map_pil = visualize_anomaly_map(anomaly_map_pil, normalize=True, colormap=True)
-        >>> visualized_map_tensor = visualize_anomaly_map(anomaly_map_tensor, normalize=True, colormap=True)
-        >>> visualized_map_pil.show()
-        >>> visualized_map_tensor.show()
+    Note:
+        - Input tensors are automatically converted to PIL Images
+        - The function always returns an RGB mode image
+        - When ``normalize=True``, uses min-max normalization to [0, 255] range
+        - The colormap used is the default from :func:`apply_colormap`
     """
     image = to_pil_image(anomaly_map) if isinstance(anomaly_map, torch.Tensor) else anomaly_map.copy()

@@ -293,81 +478,82 @@ def visualize_mask(
 ) -> Image.Image:
     """Visualize a mask with different modes.

+    This function takes a binary mask and visualizes it in different styles based on the
+    specified mode.
+
     Args:
-        mask (Image.Image | torch.Tensor): The input mask. Can be a PIL Image or a PyTorch tensor.
-        mode (Literal["contour", "binary", "fill"]): The visualization mode.
-            - "contour": Draw contours of the mask.
-            - "fill": Fill the masked area with a color.
-            - "binary": Return the original binary mask.
-            - "L": Return the original grayscale mask.
-            - "1": Return the original binary mask.
-        alpha (float): The alpha value for blending (0.0 to 1.0). Only used in "fill" mode.
-            Defaults to 0.5.
-        color (tuple[int, int, int]): The color to apply to the mask.
-            Defaults to (255, 0, 0) (red).
-        background_color (tuple[int, int, int, int]): The background color (RGBA).
-            Defaults to (0, 0, 0, 0) (transparent).
+        mask (:class:`PIL.Image.Image` | :class:`torch.Tensor`): Input mask to visualize.
+            Can be a PIL Image or PyTorch tensor. If tensor, should be 2D with values in
+            [0, 1] or [0, 255].
+        mode (Literal["contour", "fill", "binary", "L", "1"]): Visualization mode:
+
+            - ``"contour"``: Draw contours around masked regions
+            - ``"fill"``: Fill masked regions with semi-transparent color
+            - ``"binary"``: Return original binary mask
+            - ``"L"``: Return original grayscale mask
+            - ``"1"``: Return original binary mask
+
+        alpha (float, optional): Alpha value for blending in ``"fill"`` mode.
+            Should be between 0.0 and 1.0. Defaults to ``0.5``.
+        color (tuple[int, int, int], optional): RGB color to apply to mask.
+            Each value should be 0-255. Defaults to ``(255, 0, 0)`` (red).
+        background_color (tuple[int, int, int, int], optional): RGBA background color.
+            Each value should be 0-255. Defaults to ``(0, 0, 0, 0)`` (transparent).

     Returns:
-        Image.Image: The visualized mask as a PIL Image.
+        :class:`PIL.Image.Image`: Visualized mask as a PIL Image. The output mode
+            depends on the visualization mode:
+
+            - ``"contour"`` and ``"fill"``: Returns RGBA image
+            - ``"binary"``, ``"L"``, ``"1"``: Returns grayscale image

     Raises:
-        TypeError: If the mask is not a PIL Image or PyTorch tensor.
-        ValueError: If an invalid mode is provided.
+        TypeError: If ``mask`` is not a PIL Image or PyTorch tensor.
+        ValueError: If ``mode`` is not one of the allowed values.

     Examples:
+        Create a random binary mask:
+
+        >>> import numpy as np
+        >>> from PIL import Image
         >>> mask_array = np.random.randint(0, 2, size=(100, 100), dtype=np.uint8) * 255
         >>> mask_image = Image.fromarray(mask_array, mode='L')

-        >>> contour_mask = visualize_mask(mask_image, mode="contour", color=(255, 0, 0))
-        >>> contour_mask.show()
-
-        >>> binary_mask = visualize_mask(mask_image, mode="binary")
-        >>> binary_mask.show()
+        Visualize mask contours in red:

-        >>> fill_mask = visualize_mask(mask_image, mode="fill", color=(0, 255, 0), alpha=0.3)
-        >>> fill_mask.show()
-    """
+        >>> contour_vis = visualize_mask(
+        ...     mask_image,
+        ...     mode="contour",
+        ...     color=(255, 0, 0)
+        ... )
+        >>> isinstance(contour_vis, Image.Image)
+        True
+
+        Fill mask regions with semi-transparent green:
+
+        >>> fill_vis = visualize_mask(
+        ...     mask_image,
+        ...     mode="fill",
+        ...     color=(0, 255, 0),
+        ...     alpha=0.3
+        ... )
+        >>> isinstance(fill_vis, Image.Image)
+        True
+
+        Return original binary mask:
+
+        >>> binary_vis = visualize_mask(mask_image, mode="binary")
+        >>> binary_vis.mode
+        'L'
+
+    Note:
+        - Input tensors are automatically converted to PIL Images
+        - Binary masks are expected to have values of 0 and 255 (or 0 and 1 for tensors)
+        - The function preserves the original mask when using ``"binary"``, ``"L"`` or
+          ``"1"`` modes
+        - ``"contour"`` mode uses edge detection to find mask boundaries
+        - ``"fill"`` mode creates a semi-transparent overlay using the specified color
+    """
     # Convert torch.Tensor to PIL Image if necessary
     if isinstance(mask, torch.Tensor):
         if mask.dtype == torch.bool:
             mask = mask.to(torch.uint8) * 255
         mask = to_pil_image(mask)

     if not isinstance(mask, Image.Image):
         msg = "Mask must be a PIL Image or PyTorch tensor"
         raise TypeError(msg)

     # Ensure mask is in binary mode
     mask = mask.convert("L")
     if mode in {"binary", "L", "1"}:
         return mask

     # Create a background image
     background = Image.new("RGBA", mask.size, background_color)

     match mode:
         case "contour":
             # Find edges of the mask
             edges = mask.filter(ImageFilter.FIND_EDGES)

             # Create a colored version of the edges
             colored_edges = Image.new("RGBA", mask.size, (*color, 255))
             colored_edges.putalpha(edges)

             # Composite the colored edges onto the background
             return Image.alpha_composite(background, colored_edges)

         case "fill":
             # Create a solid color image for the overlay
             overlay = Image.new("RGBA", mask.size, (*color, int(255 * alpha)))

             # Use the mask to blend the overlay with the background
             return Image.composite(overlay, background, mask)

         case _:
             msg = f"Invalid mode: {mode}. Allowed modes are 'contour', 'binary', or 'fill'."
             raise ValueError(msg)


 def visualize_gt_mask(
@@ -378,7 +564,49 @@ def visualize_gt_mask(
     color: tuple[int, int, int] = (255, 0, 0),
     background_color: tuple[int, int, int, int] = (0, 0, 0, 0),
 ) -> Image.Image:
-    """Visualize a ground truth mask."""
+    """Visualize a ground truth mask.
+
+    This is a convenience wrapper around :func:`visualize_mask` specifically for
+    ground truth masks.
+    It provides the same functionality with default parameters suitable for
+    ground truth visualization.
+
+    Args:
+        mask (Image.Image | torch.Tensor): Input mask to visualize. Can be either a
+            PIL Image or PyTorch tensor.
+        mode (Literal["contour", "fill", "binary", "L", "1"]): Visualization mode.
+            Defaults to ``"binary"``.
+
+            - ``"contour"``: Draw mask boundaries
+            - ``"fill"``: Fill mask regions with semi-transparent color
+            - ``"binary"``, ``"L"``, ``"1"``: Return original binary mask
+
+        alpha (float): Opacity for the mask visualization in ``"fill"`` mode.
+            Range [0, 1]. Defaults to ``0.5``.
+        color (tuple[int, int, int]): RGB color for visualizing the mask.
+            Defaults to red ``(255, 0, 0)``.
+        background_color (tuple[int, int, int, int]): RGBA color for the
+            background. Defaults to transparent ``(0, 0, 0, 0)``.
+
+    Returns:
+        Image.Image: Visualized mask as a PIL Image.
+
+    Examples:
+        >>> import torch
+        >>> from PIL import Image
+        >>> # Create a sample binary mask
+        >>> mask = torch.zeros((100, 100))
+        >>> mask[25:75, 25:75] = 1
+        >>> # Visualize with default settings (binary mode)
+        >>> vis = visualize_gt_mask(mask)
+        >>> isinstance(vis, Image.Image)
+        True
+        >>> # Visualize with contours in blue
+        >>> vis = visualize_gt_mask(mask, mode="contour", color=(0, 0, 255))
+        >>> isinstance(vis, Image.Image)
+        True
+
+    Note:
+        See :func:`visualize_mask` for more details on the visualization modes and
+        parameters.
+    """
     return visualize_mask(mask, mode=mode, alpha=alpha, color=color, background_color=background_color)


@@ -390,19 +618,94 @@ def visualize_pred_mask(
     alpha: float = 0.5,
     background_color: tuple[int, int, int, int] = (0, 0, 0, 0),
 ) -> Image.Image:
-    """Visualize a prediction mask."""
+    """Visualize a prediction mask.
+
+    This is a convenience wrapper around :func:`visualize_mask` specifically for
+    prediction masks. It provides the same functionality with default parameters
+    suitable for prediction visualization.
+
+    Args:
+        mask (Image.Image | torch.Tensor): Input mask to visualize. Can be either a
+            PIL Image or PyTorch tensor.
+        mode (Literal["contour", "fill", "binary", "L", "1"]): Visualization mode.
+            Defaults to ``"binary"``.
+
+            - ``"contour"``: Draw mask boundaries
+            - ``"fill"``: Fill mask regions with semi-transparent color
+            - ``"binary"``, ``"L"``, ``"1"``: Return original binary mask
+
+        color (tuple[int, int, int]): RGB color for visualizing the mask.
+            Defaults to red ``(255, 0, 0)``.
+        alpha (float): Opacity for the mask visualization in ``"fill"`` mode.
+            Range [0, 1]. Defaults to ``0.5``.
+        background_color (tuple[int, int, int, int]): RGBA color for the
+            background. Defaults to transparent ``(0, 0, 0, 0)``.
+
+    Returns:
+        Image.Image: Visualized mask as a PIL Image.
+
+    Examples:
+        >>> import torch
+        >>> from PIL import Image
+        >>> # Create a sample binary mask
+        >>> mask = torch.zeros((100, 100))
+        >>> mask[25:75, 25:75] = 1
+        >>> # Visualize with default settings (binary mode)
+        >>> vis = visualize_pred_mask(mask)
+        >>> isinstance(vis, Image.Image)
+        True
+        >>> # Visualize with contours in blue
+        >>> vis = visualize_pred_mask(mask, mode="contour", color=(0, 0, 255))
+        >>> isinstance(vis, Image.Image)
+        True
+
+    Note:
+        See :func:`visualize_mask` for more details on the visualization modes and
+        parameters.
+    """
     return visualize_mask(mask, mode=mode, alpha=alpha, color=color, background_color=background_color)


 def create_image_grid(images: list[Image.Image], nrow: int) -> Image.Image:
     """Create a grid of images using PIL.
+    This function arranges a list of PIL images into a grid layout with a specified
+    number of images per row. All input images must have the same dimensions.
+
     Args:
-        images: List of PIL Images to arrange in a grid.
-        nrow: Number of images per row.
+        images (list[Image.Image]): List of PIL Images to arrange in a grid. All
+            images must have identical dimensions.
+        nrow (int): Number of images to display per row in the grid.

     Returns:
-        A new PIL Image containing the grid of images.
+        Image.Image: A new PIL Image containing the arranged grid of input images
+            with white background.
+
+    Raises:
+        ValueError: If ``images`` list is empty.
+
+    Examples:
+        Create a 2x2 grid from 4 images:
+
+        >>> from PIL import Image
+        >>> import numpy as np
+        >>> # Create sample images
+        >>> img1 = Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))
+        >>> img2 = Image.fromarray(np.ones((64, 64, 3), dtype=np.uint8) * 255)
+        >>> images = [img1, img2, img1, img2]
+        >>> # Create grid with 2 images per row
+        >>> grid = create_image_grid(images, nrow=2)
+        >>> isinstance(grid, Image.Image)
+        True
+        >>> grid.size
+        (128, 128)
+
+    Note:
+        - All input images must have identical dimensions
+        - The grid is filled row by row, left to right, top to bottom
+        - If the number of images is not divisible by ``nrow``, the last row may
+          be partially filled
+        - The output image dimensions will be:
+          width = ``nrow`` * image_width,
+          height = ceil(len(images) / nrow) * image_height
     """
     if not images:
         msg = "No images provided to create grid"
@@ -431,33 +734,59 @@ def create_image_grid(images: list[Image.Image], nrow: int) -> Image.Image:
 def get_field_kwargs(field: str) -> dict[str, Any]:
     """Get the keyword arguments for a visualization function.

-    This function retrieves the default keyword arguments for a given visualization function.
+    This function retrieves the default keyword arguments for a given visualization
+    function by inspecting its signature.

     Args:
-        field (str): The name of the visualization field (e.g., 'mask', 'anomaly_map').
+        field (str): The name of the visualization field (e.g., ``'mask'``,
+            ``'anomaly_map'``).

     Returns:
-        dict[str, Any]: A dictionary containing the default keyword arguments for the visualization function.
+        dict[str, Any]: A dictionary containing the default keyword arguments for
+            the visualization function. Each key is a parameter name and the value
+            is its default value.

     Raises:
-        ValueError: If the specified field does not have a corresponding visualization function.
+        ValueError: If the specified ``field`` does not have a corresponding
+            visualization function in the current module.
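+
+    Conceptually, the lookup boils down to something like
+    ``inspect.signature(getattr(module, f"visualize_{field}"))`` with only the
+    parameters that declare a default kept (an informal sketch; the exact logic
+    lives in the function body).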
     Examples:
+        Get keyword arguments for visualizing a mask:
+
         >>> # Get keyword arguments for visualizing a mask
         >>> mask_kwargs = get_field_kwargs('mask')
-        >>> print(mask_kwargs)
-        {'mode': 'binary', 'color': (255, 0, 0), 'alpha': 0.5, 'background_color': (0, 0, 0, 0)}
+        >>> print(mask_kwargs)  # doctest: +SKIP
+        {
+            'mode': 'binary',
+            'color': (255, 0, 0),
+            'alpha': 0.5,
+            'background_color': (0, 0, 0, 0)
+        }
+
+        Get keyword arguments for visualizing an anomaly map:

         >>> # Get keyword arguments for visualizing an anomaly map
         >>> anomaly_map_kwargs = get_field_kwargs('anomaly_map')
-        >>> print(anomaly_map_kwargs)
-        {'colormap': True, 'normalize': False}
+        >>> print(anomaly_map_kwargs)  # doctest: +SKIP
+        {
+            'colormap': True,
+            'normalize': False
+        }

-        >>> # Attempt to get keyword arguments for an invalid field
-        >>> get_field_kwargs('invalid_field')
+        Attempt to get keyword arguments for an invalid field:
+
+        >>> get_field_kwargs('invalid_field')  # doctest: +IGNORE_EXCEPTION_DETAIL
         Traceback (most recent call last):
             ...
         ValueError: 'invalid_field' is not a valid function in the current module.
+
+    Note:
+        - The function looks for a visualization function named
+          ``visualize_{field}`` in the current module
+        - Only parameters with default values are included in the returned dict
+        - Variable keyword arguments (``**kwargs``) are noted in the dict with a
+          descriptive string
+        - Both keyword-only and positional-or-keyword parameters are included
     """
     # Get the current module
     current_module = sys.modules[__name__]
@@ -491,37 +820,61 @@ def get_field_kwargs(field: str) -> dict[str, Any]:
 def get_visualize_function(field: str) -> Callable:
     """Get the visualization function for a given field.

+    This function retrieves the visualization function corresponding to a specified field
+    from the current module. The function name is constructed by prepending
+    ``visualize_`` to the field name.
+
     Args:
-        field (str): The name of the visualization field
-            (e.g., 'image', 'mask', 'anomaly_map').
+        field (str): Name of the visualization field. Common values include:
+
+            - ``"image"``: For basic image visualization
+            - ``"mask"``: For segmentation mask visualization
+            - ``"anomaly_map"``: For anomaly heatmap visualization

     Returns:
-        Callable: The visualization function corresponding to the given field.
+        Callable: Visualization function corresponding to the given field.
+            The returned function will accept parameters specific to that
+            visualization type.

     Raises:
-        AttributeError: If the specified field does not have a corresponding
-            visualization function.
+        AttributeError: If no visualization function exists for the specified
+            ``field``. The error message will indicate which function name was
+            not found.

     Examples:
-        >>> from PIL import Image
+        Get visualization function for an anomaly map:

-        Get the visualize function for an anomaly map
+        >>> from PIL import Image
         >>> visualize_func = get_visualize_function('anomaly_map')
         >>> anomaly_map = Image.new('F', (256, 256))
-        >>> visualized_map = visualize_func(anomaly_map, colormap=True, normalize=True)
+        >>> visualized_map = visualize_func(
+        ...     anomaly_map,
+        ...     colormap=True,
+        ...     normalize=True
+        ... )
         >>> isinstance(visualized_map, Image.Image)
         True

+        Get visualization function for a mask:
+
         >>> visualize_func = get_visualize_function('mask')
         >>> mask = Image.new('1', (256, 256))
         >>> visualized_mask = visualize_func(mask, color=(255, 0, 0))
         >>> isinstance(visualized_mask, Image.Image)
         True

-        Attempt to get a function for an invalid field
-        >>> get_visualize_function('invalid_field')
-        Raises AttributeError: module 'anomalib.visualization.image.functional'
-        has no attribute 'visualize_invalid_field'
+        Attempting to get a function for an invalid field raises an error:
+
+        >>> get_visualize_function('invalid_field')  # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+            ...
+        AttributeError: module 'anomalib.visualization.image.functional' has no
+        attribute 'visualize_invalid_field'
+
+    Note:
+        - The function looks for visualization functions in the current module
+        - Function names must follow the pattern ``visualize_{field}``
+        - Each visualization function may have different parameters
+        - All visualization functions return PIL Image objects
     """
     current_module = sys.modules[__name__]
     func_name = f"visualize_{field}"
diff --git a/src/anomalib/visualization/image/visualizer.py b/src/anomalib/visualization/image/visualizer.py
index c230bdde03..906220464c 100644
--- a/src/anomalib/visualization/image/visualizer.py
+++ b/src/anomalib/visualization/image/visualizer.py
@@ -1,4 +1,30 @@
-"""Image Visualizer."""
+"""Image visualization module for anomaly detection.
+
+This module provides the ``ImageVisualizer`` class for visualizing images and their
+associated anomaly detection results. The key components include:
+
+    - Visualization of individual fields (images, masks, anomaly maps)
+    - Overlay of multiple fields
+    - Configurable visualization parameters
+    - Support for saving visualizations
+
+Example:
+    >>> from anomalib.visualization.image import ImageVisualizer
+    >>> # Create visualizer with default settings
+    >>> visualizer = ImageVisualizer()
+    >>> # Generate visualization
+    >>> vis_result = visualizer.visualize(image=img, pred_mask=mask)
+
+The module ensures consistent visualization by:
+    - Providing standardized field configurations
+    - Supporting flexible overlay options
+    - Handling text annotations
+    - Maintaining consistent output formats
+
+Note:
+    All visualization functions preserve the input image format and dimensions
+    unless explicitly specified in the configuration.
+"""

 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
@@ -27,70 +53,82 @@ class ImageVisualizer(Visualizer):
     """Image Visualizer.

-    This class is responsible for visualizing images and their corresponding anomaly maps
-    during the testing and prediction phases of an anomaly detection model.
+    This class visualizes images and their corresponding anomaly maps during testing and
+    prediction phases of an anomaly detection model.

     Args:
-        fields (list[str] | None): List of fields to visualize.
-            Defaults to ["image", "gt_mask"].
-        overlay_fields (list[tuple[str, list[str]]] | None): List of tuples specifying fields to overlay.
-            Defaults to [("image", ["anomaly_map"]), ("image", ["pred_mask"])].
-        field_size (tuple[int, int]): Size of each field in the visualization.
-            Defaults to (256, 256).
-        fields_config (dict[str, dict[str, Any]]): Custom configurations for field visualization.
-            Defaults to DEFAULT_FIELDS_CONFIG.
-        overlay_fields_config (dict[str, dict[str, Any]]): Custom configurations for field overlays.
-            Defaults to DEFAULT_OVERLAY_FIELDS_CONFIG.
-        text_config (dict[str, Any]): Configuration for text overlay.
-            Defaults to DEFAULT_TEXT_CONFIG.
-        output_dir (str | Path | None): Directory to save the visualizations.
-            Defaults to None.
+        fields (list[str] | None, optional): List of fields to visualize.
+            Defaults to ``["image", "gt_mask"]``.
+        overlay_fields (list[tuple[str, list[str]]] | None, optional): List of tuples
+            specifying fields to overlay. Each tuple contains a base field and list of
+            fields to overlay on it.
+            Defaults to ``[("image", ["anomaly_map"]), ("image", ["pred_mask"])]``.
+        field_size (tuple[int, int], optional): Size of each field in visualization as
+            ``(width, height)``. Defaults to ``(256, 256)``.
+        fields_config (dict[str, dict[str, Any]] | None, optional): Custom configurations
+            for field visualization. Merged with ``DEFAULT_FIELDS_CONFIG``.
+            Defaults to ``None``.
+        overlay_fields_config (dict[str, dict[str, Any]] | None, optional): Custom
+            configurations for field overlays. Merged with
+            ``DEFAULT_OVERLAY_FIELDS_CONFIG``. Defaults to ``None``.
+        text_config (dict[str, Any] | None, optional): Configuration for text overlay.
+            Merged with ``DEFAULT_TEXT_CONFIG``. Defaults to ``None``.
+        output_dir (str | Path | None, optional): Directory to save visualizations.
+            Defaults to ``None``.

     Examples:
         Basic usage with default settings:
+
         >>> visualizer = ImageVisualizer()

-        Customizing fields to visualize:
+        Customize fields to visualize:
+
         >>> visualizer = ImageVisualizer(
         ...     fields=["image", "gt_mask", "anomaly_map"],
         ...     overlay_fields=[("image", ["anomaly_map"])]
         ... )

-        Adjusting field size:
+        Adjust field size:
+
         >>> visualizer = ImageVisualizer(field_size=(512, 512))

-        Customizing anomaly map visualization:
-        >>> visualizer = ImageVisualizer(
-        ...     fields_config={
-        ...         "anomaly_map": {"colormap": True, "normalize": True}
-        ...     }
-        ... )
+        Customize anomaly map visualization:

-        Modifying overlay appearance:
-        >>> visualizer = ImageVisualizer(
-        ...     overlay_fields_config={
-        ...         "pred_mask": {"alpha": 0.7, "color": (255, 0, 0), "mode": "fill"},
-        ...         "anomaly_map": {"alpha": 0.5, "color": (0, 255, 0), "mode": "contour"}
-        ...     }
-        ... )
+        >>> fields_config = {
+        ...     "anomaly_map": {"colormap": True, "normalize": True}
+        ... }
+        >>> visualizer = ImageVisualizer(fields_config=fields_config)

-        Customizing text overlay:
-        >>> visualizer = ImageVisualizer(
-        ...     text_config={
-        ...         "font": "arial.ttf",
-        ...         "size": 20,
-        ...         "color": "yellow",
-        ...         "background": (0, 0, 0, 200)
-        ...     }
-        ... )
+        Modify overlay appearance:
+
+        >>> overlay_config = {
+        ...     "pred_mask": {"alpha": 0.7, "color": (255, 0, 0), "mode": "fill"},
+        ...     "anomaly_map": {"alpha": 0.5, "color": (0, 255, 0), "mode": "contour"}
+        ... }
+        >>> visualizer = ImageVisualizer(overlay_fields_config=overlay_config)
+
+        Customize text overlay:
+
+        >>> text_config = {
+        ...     "font": "arial.ttf",
+        ...     "size": 20,
+        ...     "color": "yellow",
+        ...     "background": (0, 0, 0, 200)
+        ... }
+        >>> visualizer = ImageVisualizer(text_config=text_config)
+
+        Specify output directory:

-        Specifying output directory:
         >>> visualizer = ImageVisualizer(output_dir="./output/visualizations")

         Advanced configuration combining multiple customizations:
+
         >>> visualizer = ImageVisualizer(
         ...     fields=["image", "gt_mask", "anomaly_map", "pred_mask"],
-        ...     overlay_fields=[("image", ["anomaly_map"]), ("image", ["pred_mask"])],
+        ...     overlay_fields=[
+        ...         ("image", ["anomaly_map"]),
+        ...         ("image", ["pred_mask"])
+        ...     ],
         ...     field_size=(384, 384),
         ...     fields_config={
"anomaly_map": {"colormap": True, "normalize": True}, @@ -110,15 +148,21 @@ class ImageVisualizer(Visualizer): ... ) Note: - - The 'fields' parameter determines which individual fields are visualized. - - The 'overlay_fields' parameter specifies which fields should be overlaid on others. - - Field configurations in 'fields_config' affect how individual fields are visualized. - - Overlay configurations in 'overlay_fields_config' determine how fields are blended when overlaid. - - Text configurations in 'text_config' control the appearance of text labels on visualizations. - - If 'output_dir' is not specified, visualizations will be saved in a default location. - - For more details on available options for each configuration, refer to the documentation - of the `visualize_image_item`, `visualize_field`, and related functions. + - The ``fields`` parameter determines which individual fields are visualized + - The ``overlay_fields`` parameter specifies which fields should be overlaid + on others + - Field configurations in ``fields_config`` affect how individual fields are + visualized + - Overlay configurations in ``overlay_fields_config`` determine how fields are + blended when overlaid + - Text configurations in ``text_config`` control the appearance of text labels + on visualizations + - If ``output_dir`` is not specified, visualizations will be saved in a + default location + + For more details on available options for each configuration, refer to the + documentation of the :func:`visualize_image_item`, :func:`visualize_field`, and + related functions. """ def __init__( diff --git a/tests/unit/cli/test_installation.py b/tests/unit/cli/test_installation.py index 6a34017639..6fd32c4db2 100644 --- a/tests/unit/cli/test_installation.py +++ b/tests/unit/cli/test_installation.py @@ -1,6 +1,6 @@ """Tests for installation utils.""" -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import os