diff --git a/src/anomalib/metrics/per_image/__init__.py b/src/anomalib/metrics/per_image/__init__.py index b98ea9fae6..f853f83ce1 100644 --- a/src/anomalib/metrics/per_image/__init__.py +++ b/src/anomalib/metrics/per_image/__init__.py @@ -7,7 +7,8 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from .binclf_curve_numpy import BinclfThreshsChoice +from .binclf_curve import BinclfThreshsChoice +from .enums import StatsOutliersPolicy, StatsRepeatedPolicy from .pimo import AUPIMO, PIMO, AUPIMOResult, PIMOResult, aupimo_scores, pimo_curves from .utils import ( compare_models_pairwise_ttest_rel, @@ -15,7 +16,6 @@ format_pairwise_tests_results, per_image_scores_stats, ) -from .utils_numpy import StatsOutliersPolicy, StatsRepeatedPolicy __all__ = [ # constants diff --git a/src/anomalib/metrics/per_image/_validate.py b/src/anomalib/metrics/per_image/_validate.py index 0ebc6916e0..7852681bef 100644 --- a/src/anomalib/metrics/per_image/_validate.py +++ b/src/anomalib/metrics/per_image/_validate.py @@ -1,7 +1,5 @@ """Utils for validating arguments and results. -`torch` is imported in the functions that use it, so this module can be used in numpy-standalone mode. - TODO(jpcbertoldo): Move validations to a common place and reuse them across the codebase. https://github.com/openvinotoolkit/anomalib/issues/2093 """ @@ -13,21 +11,8 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import numpy as np -from numpy import ndarray - - -def is_tensor(tensor: Any, argname: str | None = None) -> None: # noqa: ANN401 - """Validate that `tensor` is a `torch.Tensor`.""" - from torch import Tensor - - argname = f"'{argname}'" if argname is not None else "argument" - - if not isinstance(tensor, Tensor): - msg = f"Expected {argname} to be a tensor, but got {type(tensor)}" - raise TypeError(msg) +import torch +from torch import Tensor def is_num_threshs_gte2(num_threshs: int) -> None: @@ -98,22 +83,22 @@ def is_rate_range(bounds: tuple[float, float]) -> None: raise ValueError(msg) -def is_threshs(threshs: ndarray) -> None: +def is_threshs(threshs: Tensor) -> None: """Validate that the thresholds are valid and monotonically increasing.""" - if not isinstance(threshs, ndarray): - msg = f"Expected thresholds to be an ndarray, but got {type(threshs)}" + if not isinstance(threshs, Tensor): + msg = f"Expected thresholds to be an Tensor, but got {type(threshs)}" raise TypeError(msg) if threshs.ndim != 1: msg = f"Expected thresholds to be 1D, but got {threshs.ndim}" raise ValueError(msg) - if threshs.dtype.kind != "f": - msg = f"Expected thresholds to be of float type, but got ndarray with dtype {threshs.dtype}" + if not threshs.dtype.is_floating_point: + msg = f"Expected thresholds to be of float type, but got Tensor with dtype {threshs.dtype}" raise TypeError(msg) # make sure they are strictly increasing - if not np.all(np.diff(threshs) > 0): + if not torch.all(torch.diff(threshs) > 0): msg = "Expected thresholds to be strictly increasing, but it is not." raise ValueError(msg) @@ -142,55 +127,55 @@ def is_thresh_bounds(thresh_bounds: tuple[float, float]) -> None: raise ValueError(msg) -def is_anomaly_maps(anomaly_maps: ndarray) -> None: - if not isinstance(anomaly_maps, ndarray): - msg = f"Expected anomaly maps to be an ndarray, but got {type(anomaly_maps)}" +def is_anomaly_maps(anomaly_maps: Tensor) -> None: + if not isinstance(anomaly_maps, Tensor): + msg = f"Expected anomaly maps to be an Tensor, but got {type(anomaly_maps)}" raise TypeError(msg) if anomaly_maps.ndim != 3: msg = f"Expected anomaly maps have 3 dimensions (N, H, W), but got {anomaly_maps.ndim} dimensions" raise ValueError(msg) - if anomaly_maps.dtype.kind != "f": + if not anomaly_maps.dtype.is_floating_point: msg = ( - "Expected anomaly maps to be an floating ndarray with anomaly scores," - f" but got ndarray with dtype {anomaly_maps.dtype}" + "Expected anomaly maps to be an floating Tensor with anomaly scores," + f" but got Tensor with dtype {anomaly_maps.dtype}" ) raise TypeError(msg) -def is_masks(masks: ndarray) -> None: - if not isinstance(masks, ndarray): - msg = f"Expected masks to be an ndarray, but got {type(masks)}" +def is_masks(masks: Tensor) -> None: + if not isinstance(masks, Tensor): + msg = f"Expected masks to be an Tensor, but got {type(masks)}" raise TypeError(msg) if masks.ndim != 3: msg = f"Expected masks have 3 dimensions (N, H, W), but got {masks.ndim} dimensions" raise ValueError(msg) - if masks.dtype.kind == "b": + if masks.dtype == torch.bool: pass - - elif masks.dtype.kind in {"i", "u"}: - masks_unique_vals = np.unique(masks) - if np.any((masks_unique_vals != 0) & (masks_unique_vals != 1)): - msg = ( - "Expected masks to be a *binary* ndarray with ground truth labels, " - f"but got ndarray with unique values {sorted(masks_unique_vals)}" - ) - raise ValueError(msg) - - else: + elif masks.dtype.is_floating_point: msg = ( - "Expected masks to be an integer or boolean ndarray with ground truth labels, " - f"but got ndarray with dtype {masks.dtype}" + "Expected masks to be an integer or boolean Tensor with ground truth labels, " + f"but got Tensor with dtype {masks.dtype}" ) raise TypeError(msg) + else: + # assumes the type to be (signed or unsigned) integer + # this will change with the dataclass refactor + masks_unique_vals = torch.unique(masks) + if torch.any((masks_unique_vals != 0) & (masks_unique_vals != 1)): + msg = ( + "Expected masks to be a *binary* Tensor with ground truth labels, " + f"but got Tensor with unique values {sorted(masks_unique_vals)}" + ) + raise ValueError(msg) -def is_binclf_curves(binclf_curves: ndarray, valid_threshs: ndarray | None) -> None: - if not isinstance(binclf_curves, ndarray): - msg = f"Expected binclf curves to be an ndarray, but got {type(binclf_curves)}" +def is_binclf_curves(binclf_curves: Tensor, valid_threshs: Tensor | None) -> None: + if not isinstance(binclf_curves, Tensor): + msg = f"Expected binclf curves to be an Tensor, but got {type(binclf_curves)}" raise TypeError(msg) if binclf_curves.ndim != 4: @@ -201,7 +186,7 @@ def is_binclf_curves(binclf_curves: ndarray, valid_threshs: ndarray | None) -> N msg = f"Expected binclf curves to have shape (..., 2, 2), but got {binclf_curves.shape}" raise ValueError(msg) - if binclf_curves.dtype != np.int64: + if binclf_curves.dtype != torch.int64: msg = f"Expected binclf curves to have dtype int64, but got {binclf_curves.dtype}." raise TypeError(msg) @@ -232,47 +217,49 @@ def is_binclf_curves(binclf_curves: ndarray, valid_threshs: ndarray | None) -> N raise RuntimeError(msg) -def is_images_classes(images_classes: ndarray) -> None: - if not isinstance(images_classes, ndarray): - msg = f"Expected image classes to be an ndarray, but got {type(images_classes)}." +def is_images_classes(images_classes: Tensor) -> None: + if not isinstance(images_classes, Tensor): + msg = f"Expected image classes to be an Tensor, but got {type(images_classes)}." raise TypeError(msg) if images_classes.ndim != 1: msg = f"Expected image classes to be 1D, but got {images_classes.ndim}D." raise ValueError(msg) - if images_classes.dtype.kind == "b": + if images_classes.dtype == torch.bool: pass - elif images_classes.dtype.kind in {"i", "u"}: - unique_vals = np.unique(images_classes) - if np.any((unique_vals != 0) & (unique_vals != 1)): - msg = ( - "Expected image classes to be a *binary* ndarray with ground truth labels, " - f"but got ndarray with unique values {sorted(unique_vals)}" - ) - raise ValueError(msg) - else: + elif images_classes.dtype.is_floating_point: msg = ( - "Expected image classes to be an integer or boolean ndarray with ground truth labels, " - f"but got ndarray with dtype {images_classes.dtype}" + "Expected image classes to be an integer or boolean Tensor with ground truth labels, " + f"but got Tensor with dtype {images_classes.dtype}" ) raise TypeError(msg) + else: + # assumes the type to be (signed or unsigned) integer + # this will change with the dataclass refactor + unique_vals = torch.unique(images_classes) + if torch.any((unique_vals != 0) & (unique_vals != 1)): + msg = ( + "Expected image classes to be a *binary* Tensor with ground truth labels, " + f"but got Tensor with unique values {sorted(unique_vals)}" + ) + raise ValueError(msg) -def is_rates(rates: ndarray, nan_allowed: bool) -> None: - if not isinstance(rates, ndarray): - msg = f"Expected rates to be an ndarray, but got {type(rates)}." +def is_rates(rates: Tensor, nan_allowed: bool) -> None: + if not isinstance(rates, Tensor): + msg = f"Expected rates to be an Tensor, but got {type(rates)}." raise TypeError(msg) if rates.ndim != 1: msg = f"Expected rates to be 1D, but got {rates.ndim}D." raise ValueError(msg) - if rates.dtype.kind != "f": + if not rates.dtype.is_floating_point: msg = f"Expected rates to have dtype of float type, but got {rates.dtype}." raise ValueError(msg) - isnan_mask = np.isnan(rates) + isnan_mask = torch.isnan(rates) if nan_allowed: # if they are all nan, then there is nothing to validate if isnan_mask.all(): @@ -293,11 +280,11 @@ def is_rates(rates: ndarray, nan_allowed: bool) -> None: raise ValueError(msg) -def is_rate_curve(rate_curve: ndarray, nan_allowed: bool, decreasing: bool) -> None: +def is_rate_curve(rate_curve: Tensor, nan_allowed: bool, decreasing: bool) -> None: is_rates(rate_curve, nan_allowed=nan_allowed) - diffs = np.diff(rate_curve) - diffs_valid = diffs[~np.isnan(diffs)] if nan_allowed else diffs + diffs = torch.diff(rate_curve) + diffs_valid = diffs[~torch.isnan(diffs)] if nan_allowed else diffs if decreasing and (diffs_valid > 0).any(): msg = "Expected rate curve to be monotonically decreasing, but got non-monotonically decreasing values." @@ -308,20 +295,20 @@ def is_rate_curve(rate_curve: ndarray, nan_allowed: bool, decreasing: bool) -> N raise ValueError(msg) -def is_per_image_rate_curves(rate_curves: ndarray, nan_allowed: bool, decreasing: bool | None) -> None: - if not isinstance(rate_curves, ndarray): - msg = f"Expected per-image rate curves to be an ndarray, but got {type(rate_curves)}." +def is_per_image_rate_curves(rate_curves: Tensor, nan_allowed: bool, decreasing: bool | None) -> None: + if not isinstance(rate_curves, Tensor): + msg = f"Expected per-image rate curves to be an Tensor, but got {type(rate_curves)}." raise TypeError(msg) if rate_curves.ndim != 2: msg = f"Expected per-image rate curves to be 2D, but got {rate_curves.ndim}D." raise ValueError(msg) - if rate_curves.dtype.kind != "f": + if not rate_curves.dtype.is_floating_point: msg = f"Expected per-image rate curves to have dtype of float type, but got {rate_curves.dtype}." raise ValueError(msg) - isnan_mask = np.isnan(rate_curves) + isnan_mask = torch.isnan(rate_curves) if nan_allowed: # if they are all nan, then there is nothing to validate if isnan_mask.all(): @@ -344,8 +331,8 @@ def is_per_image_rate_curves(rate_curves: ndarray, nan_allowed: bool, decreasing if decreasing is None: return - diffs = np.diff(rate_curves, axis=1) - diffs_valid = diffs[~np.isnan(diffs)] if nan_allowed else diffs + diffs = torch.diff(rate_curves, axis=1) + diffs_valid = diffs[~torch.isnan(diffs)] if nan_allowed else diffs if decreasing and (diffs_valid > 0).any(): msg = ( diff --git a/src/anomalib/metrics/per_image/binclf_curve_numpy.py b/src/anomalib/metrics/per_image/binclf_curve.py similarity index 79% rename from src/anomalib/metrics/per_image/binclf_curve_numpy.py rename to src/anomalib/metrics/per_image/binclf_curve.py index ae444735b6..e958263664 100644 --- a/src/anomalib/metrics/per_image/binclf_curve_numpy.py +++ b/src/anomalib/metrics/per_image/binclf_curve.py @@ -18,6 +18,7 @@ from functools import partial import numpy as np +import torch from numpy import ndarray from . import _validate @@ -38,16 +39,16 @@ class BinclfThreshsChoice(Enum): # =========================================== ARGS VALIDATION =========================================== -def _validate_is_scores_batch(scores_batch: ndarray) -> None: - """scores_batch (ndarray): floating (N, D).""" - if not isinstance(scores_batch, ndarray): - msg = f"Expected `scores_batch` to be an ndarray, but got {type(scores_batch)}" +def _validate_is_scores_batch(scores_batch: torch.Tensor) -> None: + """scores_batch (torch.Tensor): floating (N, D).""" + if not isinstance(scores_batch, torch.Tensor): + msg = f"Expected `scores_batch` to be an torch.Tensor, but got {type(scores_batch)}" raise TypeError(msg) - if scores_batch.dtype.kind != "f": + if not scores_batch.dtype.is_floating_point: msg = ( - "Expected `scores_batch` to be an floating ndarray with anomaly scores_batch," - f" but got ndarray with dtype {scores_batch.dtype}" + "Expected `scores_batch` to be an floating torch.Tensor with anomaly scores_batch," + f" but got torch.Tensor with dtype {scores_batch.dtype}" ) raise TypeError(msg) @@ -56,16 +57,16 @@ def _validate_is_scores_batch(scores_batch: ndarray) -> None: raise ValueError(msg) -def _validate_is_gts_batch(gts_batch: ndarray) -> None: - """gts_batch (ndarray): boolean (N, D).""" - if not isinstance(gts_batch, ndarray): - msg = f"Expected `gts_batch` to be an ndarray, but got {type(gts_batch)}" +def _validate_is_gts_batch(gts_batch: torch.Tensor) -> None: + """gts_batch (torch.Tensor): boolean (N, D).""" + if not isinstance(gts_batch, torch.Tensor): + msg = f"Expected `gts_batch` to be an torch.Tensor, but got {type(gts_batch)}" raise TypeError(msg) - if gts_batch.dtype.kind != "b": + if gts_batch.dtype != torch.bool: msg = ( - "Expected `gts_batch` to be an boolean ndarray with anomaly scores_batch," - f" but got ndarray with dtype {gts_batch.dtype}" + "Expected `gts_batch` to be an boolean torch.Tensor with anomaly scores_batch," + f" but got torch.Tensor with dtype {gts_batch.dtype}" ) raise TypeError(msg) @@ -74,19 +75,14 @@ def _validate_is_gts_batch(gts_batch: ndarray) -> None: raise ValueError(msg) -# =========================================== PYTHON VERSION =========================================== - - def _binclf_one_curve(scores: ndarray, gts: ndarray, threshs: ndarray) -> ndarray: - """One binary classification matrix at each threshold. + """One binary classification matrix at each threshold (PYTHON implementation). In the case where the thresholds are given (i.e. not considering all possible thresholds based on the scores), this weird-looking function is faster than the two options in `torchmetrics` on the CPU: - `_binary_precision_recall_curve_update_vectorized` - `_binary_precision_recall_curve_update_loop` - (both in module `torchmetrics.functional.classification.precision_recall_curve` in `torchmetrics==1.1.0`). - Note: VALIDATION IS NOT DONE HERE. Make sure to validate the arguments before calling this function. Args: @@ -96,7 +92,6 @@ def _binclf_one_curve(scores: ndarray, gts: ndarray, threshs: ndarray) -> ndarra Returns: ndarray: Binary classification matrix curve (K, 2, 2) - Details: `anomalib.metrics.per_image.binclf_curve_numpy.binclf_multiple_curves`. """ num_th = len(threshs) @@ -149,20 +144,11 @@ def score_less_than_thresh(score: float, thresh: float) -> bool: ).transpose(0, 2, 1) -_binclf_multiple_curves = np.vectorize(_binclf_one_curve, signature="(n),(n),(k)->(k,2,2)") -_binclf_multiple_curves.__doc__ = """ -Multiple binary classification matrix at each threshold. -vectorized version of `_binclf_one_curve` (see above) -""" - -# =========================================== INTERFACE =========================================== - - def binclf_multiple_curves( - scores_batch: ndarray, - gts_batch: ndarray, - threshs: ndarray, -) -> ndarray: + scores_batch: torch.Tensor, + gts_batch: torch.Tensor, + threshs: torch.Tensor, +) -> torch.Tensor: """Multiple binary classification matrix (per-instance scope) at each threshold (shared). This is a wrapper around `_binclf_multiple_curves_python` and `_binclf_multiple_curves_numba`. @@ -171,13 +157,12 @@ def binclf_multiple_curves( Note: predicted as positive condition is `score >= thresh`. Args: - scores_batch (ndarray): Anomaly scores (N, D,). - gts_batch (ndarray): Binary (bool) ground truth of shape (N, D,). - threshs (ndarray): Sequence of thresholds in ascending order (K,). - algorithm (str, optional): Algorithm to use. Defaults to ALGORITHM_NUMBA. + scores_batch (torch.Tensor): Anomaly scores (N, D,). + gts_batch (torch.Tensor): Binary (bool) ground truth of shape (N, D,). + threshs (torch.Tensor): Sequence of thresholds in ascending order (K,). Returns: - ndarray: Binary classification matrix curves (N, K, 2, 2) + torch.Tensor: Binary classification matrix curves (N, K, 2, 2) The last two dimensions are the confusion matrix (ground truth, predictions) So for each thresh it gives: @@ -205,14 +190,20 @@ def binclf_multiple_curves( _validate_is_gts_batch(gts_batch) _validate.is_same_shape(scores_batch, gts_batch) _validate.is_threshs(threshs) - - return _binclf_multiple_curves(scores_batch, gts_batch, threshs) + # TODO(ashwinvaidya17): this is kept as numpy for now because it is much faster. + # TEMP-0 + result = np.vectorize(_binclf_one_curve, signature="(n),(n),(k)->(k,2,2)")( + scores_batch.detach().cpu().numpy(), + gts_batch.detach().cpu().numpy(), + threshs.detach().cpu().numpy(), + ) + return torch.from_numpy(result).to(scores_batch.device) # ========================================= PER-IMAGE BINCLF CURVE ========================================= -def _get_threshs_minmax_linspace(anomaly_maps: ndarray, num_threshs: int) -> ndarray: +def _get_threshs_minmax_linspace(anomaly_maps: torch.Tensor, num_threshs: int) -> torch.Tensor: """Get thresholds linearly spaced between the min and max of the anomaly maps.""" _validate.is_num_threshs_gte2(num_threshs) # this operation can be a bit expensive @@ -222,33 +213,33 @@ def _get_threshs_minmax_linspace(anomaly_maps: ndarray, num_threshs: int) -> nda except ValueError as ex: msg = f"Invalid threshold bounds computed from the given anomaly maps. Cause: {ex}" raise ValueError(msg) from ex - return np.linspace(thresh_low, thresh_high, num_threshs, dtype=anomaly_maps.dtype) + return torch.linspace(thresh_low, thresh_high, num_threshs, dtype=anomaly_maps.dtype) def per_image_binclf_curve( - anomaly_maps: ndarray, - masks: ndarray, + anomaly_maps: torch.Tensor, + masks: torch.Tensor, threshs_choice: BinclfThreshsChoice | str = BinclfThreshsChoice.MINMAX_LINSPACE.value, - threshs_given: ndarray | None = None, + threshs_given: torch.Tensor | None = None, num_threshs: int | None = None, -) -> tuple[ndarray, ndarray]: +) -> tuple[torch.Tensor, torch.Tensor]: """Compute the binary classification matrix of each image in the batch for multiple thresholds (shared). Args: - anomaly_maps (ndarray): Anomaly score maps of shape (N, H, W) - masks (ndarray): Binary ground truth masks of shape (N, H, W) + anomaly_maps (torch.Tensor): Anomaly score maps of shape (N, H, W) + masks (torch.Tensor): Binary ground truth masks of shape (N, H, W) threshs_choice (str, optional): Sequence of thresholds to use. Defaults to THRESH_SEQUENCE_MINMAX_LINSPACE. # # `threshs_choice`-dependent arguments # # THRESH_SEQUENCE_GIVEN - threshs_given (ndarray, optional): Sequence of thresholds to use. + threshs_given (torch.Tensor, optional): Sequence of thresholds to use. # # THRESH_SEQUENCE_MINMAX_LINSPACE num_threshs (int, optional): Number of thresholds between the min and max of the anomaly maps. Returns: - tuple[ndarray, ndarray]: + tuple[torch.Tensor, torch.Tensor]: [0] Thresholds of shape (K,) and dtype is the same as `anomaly_maps.dtype`. [1] Binary classification matrices of shape (N, K, 2, 2) @@ -281,7 +272,7 @@ def per_image_binclf_curve( _validate.is_masks(masks) _validate.is_same_shape(anomaly_maps, masks) - threshs: ndarray + threshs: torch.Tensor if threshs_choice == BinclfThreshsChoice.GIVEN: assert threshs_given is not None @@ -291,7 +282,7 @@ def per_image_binclf_curve( "Argument `num_threshs` was given, " f"but it is ignored because `threshs_choice` is '{threshs_choice.value}'.", ) - threshs = threshs_given.astype(anomaly_maps.dtype) + threshs = threshs_given.to(anomaly_maps.dtype) elif threshs_choice == BinclfThreshsChoice.MINMAX_LINSPACE: assert num_threshs is not None @@ -315,7 +306,7 @@ def per_image_binclf_curve( # keep the batch dimension and flatten the rest scores_batch = anomaly_maps.reshape(anomaly_maps.shape[0], -1) - gts_batch = masks.reshape(masks.shape[0], -1).astype(bool) # make sure it is boolean + gts_batch = masks.reshape(masks.shape[0], -1).to(bool) # make sure it is boolean binclf_curves = binclf_multiple_curves(scores_batch, gts_batch, threshs) @@ -343,7 +334,7 @@ def per_image_binclf_curve( # =========================================== RATE METRICS =========================================== -def per_image_tpr(binclf_curves: ndarray) -> ndarray: +def per_image_tpr(binclf_curves: torch.Tensor) -> torch.Tensor: """True positive rates (TPR) for image for each thresh. TPR = TP / P = TP / (TP + FN) @@ -353,10 +344,10 @@ def per_image_tpr(binclf_curves: ndarray) -> ndarray: P: positives (TP + FN) Args: - binclf_curves (ndarray): Binary classification matrix curves (N, K, 2, 2). See `per_image_binclf_curve`. + binclf_curves (torch.Tensor): Binary classification matrix curves (N, K, 2, 2). See `per_image_binclf_curve`. Returns: - ndarray: shape (N, K), dtype float64 + torch.Tensor: shape (N, K), dtype float64 N: number of images K: number of thresholds @@ -367,10 +358,10 @@ def per_image_tpr(binclf_curves: ndarray) -> ndarray: pos = binclf_curves[..., 1, :].sum(axis=2) # 2 was the 3 originally # tprs will be nan if pos == 0 (normal image), which is expected - return tps.astype(np.float64) / pos.astype(np.float64) + return tps.to(torch.float64) / pos.to(torch.float64) -def per_image_fpr(binclf_curves: ndarray) -> ndarray: +def per_image_fpr(binclf_curves: torch.Tensor) -> torch.Tensor: """False positive rates (TPR) for image for each thresh. FPR = FP / N = FP / (FP + TN) @@ -380,10 +371,10 @@ def per_image_fpr(binclf_curves: ndarray) -> ndarray: N: negatives (FP + TN) Args: - binclf_curves (ndarray): Binary classification matrix curves (N, K, 2, 2). See `per_image_binclf_curve`. + binclf_curves (torch.Tensor): Binary classification matrix curves (N, K, 2, 2). See `per_image_binclf_curve`. Returns: - ndarray: shape (N, K), dtype float64 + torch.Tensor: shape (N, K), dtype float64 N: number of images K: number of thresholds @@ -394,4 +385,4 @@ def per_image_fpr(binclf_curves: ndarray) -> ndarray: neg = binclf_curves[..., 0, :].sum(axis=2) # 2 was the 3 originally # it can be `nan` if an anomalous image is fully covered by the mask - return fps.astype(np.float64) / neg.astype(np.float64) + return fps.to(torch.float64) / neg.to(torch.float64) diff --git a/src/anomalib/metrics/per_image/enums.py b/src/anomalib/metrics/per_image/enums.py new file mode 100644 index 0000000000..f17d7692b0 --- /dev/null +++ b/src/anomalib/metrics/per_image/enums.py @@ -0,0 +1,46 @@ +"""Enumerations for per-image metrics.""" + +# Based on the code: https://github.com/jpcbertoldo/aupimo + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from enum import Enum + + +class StatsOutliersPolicy(Enum): + """How to handle outliers in per-image metrics boxplots. Use them? Only high? Only low? Both? + + Outliers are defined as in a boxplot, i.e. values that are more than 1.5 times the interquartile range (IQR) away + from the Q1 and Q3 quartiles (respectively low and high outliers). The IQR is the difference between Q3 and Q1. + + None | "none": do not include outliers. + "high": only include high outliers. + "low": only include low outliers. + "both": include both high and low outliers. + """ + + NONE: str = "none" + HIGH: str = "high" + LOW: str = "low" + BOTH: str = "both" + + +class StatsRepeatedPolicy(Enum): + """How to handle repeated values in per-image metrics boxplots (two stats with same value). Avoid them? + + None | "none": do not avoid repeated values, so several stats can have the same value and image index. + "avoid": if a stat has the same value as another stat, the one with the closest then another image, + with the nearest score, is selected. + """ + + NONE: str = "none" + AVOID: str = "avoid" + + +class StatsAlternativeHypothesis(Enum): + """Alternative hypothesis for the statistical tests used to compare per-image metrics.""" + + TWO_SIDED: str = "two-sided" + LESS: str = "less" + GREATER: str = "greater" diff --git a/src/anomalib/metrics/per_image/pimo_numpy.py b/src/anomalib/metrics/per_image/functional.py similarity index 82% rename from src/anomalib/metrics/per_image/pimo_numpy.py rename to src/anomalib/metrics/per_image/functional.py index 2bd5c0cb89..8a841fd514 100644 --- a/src/anomalib/metrics/per_image/pimo_numpy.py +++ b/src/anomalib/metrics/per_image/functional.py @@ -13,10 +13,10 @@ import logging import numpy as np -from numpy import ndarray +import torch -from . import _validate, binclf_curve_numpy -from .binclf_curve_numpy import BinclfThreshsChoice +from . import _validate, binclf_curve +from .binclf_curve import BinclfThreshsChoice logger = logging.getLogger(__name__) @@ -24,30 +24,30 @@ # =========================================== AUX =========================================== -def _images_classes_from_masks(masks: ndarray) -> ndarray: +def _images_classes_from_masks(masks: torch.Tensor) -> torch.Tensor: """Deduce the image classes from the masks.""" _validate.is_masks(masks) - return (masks == 1).any(axis=(1, 2)).astype(np.int32) + return (masks == 1).any(axis=(1, 2)).to(torch.int32) # =========================================== ARGS VALIDATION =========================================== -def _validate_has_at_least_one_anomalous_image(masks: ndarray) -> None: +def _validate_has_at_least_one_anomalous_image(masks: torch.Tensor) -> None: image_classes = _images_classes_from_masks(masks) if (image_classes == 1).sum() == 0: msg = "Expected at least one ANOMALOUS image, but found none." raise ValueError(msg) -def _validate_has_at_least_one_normal_image(masks: ndarray) -> None: +def _validate_has_at_least_one_normal_image(masks: torch.Tensor) -> None: image_classes = _images_classes_from_masks(masks) if (image_classes == 0).sum() == 0: msg = "Expected at least one NORMAL image, but found none." raise ValueError(msg) -def _joint_validate_threshs_shared_fpr(threshs: ndarray, shared_fpr: ndarray) -> None: +def _joint_validate_threshs_shared_fpr(threshs: torch.Tensor, shared_fpr: torch.Tensor) -> None: if threshs.shape[0] != shared_fpr.shape[0]: msg = ( "Expected `threshs` and `shared_fpr` to have the same number of elements, " @@ -60,10 +60,10 @@ def _joint_validate_threshs_shared_fpr(threshs: ndarray, shared_fpr: ndarray) -> def pimo_curves( - anomaly_maps: ndarray, - masks: ndarray, + anomaly_maps: torch.Tensor, + masks: torch.Tensor, num_threshs: int, -) -> tuple[ndarray, ndarray, ndarray, ndarray]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Compute the Per-IMage Overlap (PIMO, pronounced pee-mo) curves. PIMO is a curve of True Positive Rate (TPR) values on each image across multiple anomaly score thresholds. @@ -84,7 +84,7 @@ def pimo_curves( num_threshs: number of thresholds to compute (K) Returns: - tuple[ndarray, ndarray, ndarray, ndarray]: + tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: [0] thresholds of shape (K,) in ascending order [1] shared FPR values of shape (K,) in descending order (indices correspond to the thresholds) [2] per-image TPR curves of shape (N, K), axis 1 in descending order (indices correspond to the thresholds) @@ -92,9 +92,9 @@ def pimo_curves( """ # validate the strings are valid _validate.is_num_threshs_gte2(num_threshs) - _validate.is_anomaly_maps(anomaly_maps) # redundant - _validate.is_masks(masks) # redundant - _validate.is_same_shape(anomaly_maps, masks) # redundant + _validate.is_anomaly_maps(anomaly_maps) + _validate.is_masks(masks) + _validate.is_same_shape(anomaly_maps, masks) _validate_has_at_least_one_anomalous_image(masks) _validate_has_at_least_one_normal_image(masks) @@ -104,14 +104,14 @@ def pimo_curves( # therefore getting a better resolution in terms of FPR quantization # otherwise the function `binclf_curve_numpy.per_image_binclf_curve` would have the range of thresholds # computed from all the images (normal + anomalous) - threshs = binclf_curve_numpy._get_threshs_minmax_linspace( # noqa: SLF001 + threshs = binclf_curve._get_threshs_minmax_linspace( # noqa: SLF001 anomaly_maps[image_classes == 0], num_threshs, ) # N: number of images, K: number of thresholds # shapes are (K,) and (N, K, 2, 2) - threshs, binclf_curves = binclf_curve_numpy.per_image_binclf_curve( + threshs, binclf_curves = binclf_curve.per_image_binclf_curve( anomaly_maps=anomaly_maps, masks=masks, threshs_choice=BinclfThreshsChoice.GIVEN.value, @@ -119,10 +119,10 @@ def pimo_curves( num_threshs=None, ) - shared_fpr: ndarray + shared_fpr: torch.Tensor # mean-per-image-fpr on normal images # shape -> (N, K) - per_image_fprs_normals = binclf_curve_numpy.per_image_fpr(binclf_curves[image_classes == 0]) + per_image_fprs_normals = binclf_curve.per_image_fpr(binclf_curves[image_classes == 0]) try: _validate.is_per_image_rate_curves(per_image_fprs_normals, nan_allowed=False, decreasing=True) except ValueError as ex: @@ -135,7 +135,7 @@ def pimo_curves( shared_fpr = per_image_fprs_normals.mean(axis=0) # shape -> (N, K) - per_image_tprs = binclf_curve_numpy.per_image_tpr(binclf_curves) + per_image_tprs = binclf_curve.per_image_tpr(binclf_curves) return threshs, shared_fpr, per_image_tprs, image_classes @@ -144,12 +144,12 @@ def pimo_curves( def aupimo_scores( - anomaly_maps: ndarray, - masks: ndarray, + anomaly_maps: torch.Tensor, + masks: torch.Tensor, num_threshs: int = 300_000, fpr_bounds: tuple[float, float] = (1e-5, 1e-4), force: bool = False, -) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]: """Compute the PIMO curves and their Area Under the Curve (i.e. AUPIMO) scores. Scores are computed from the integration of the PIMO curves within the given FPR bounds, then normalized to [0, 1]. @@ -171,7 +171,7 @@ def aupimo_scores( force: whether to force the computation despite bad conditions Returns: - tuple[ndarray, ndarray, ndarray, ndarray, ndarray]: + tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: [0] thresholds of shape (K,) in ascending order [1] shared FPR values of shape (K,) in descending order (indices correspond to the thresholds) [2] per-image TPR curves of shape (N, K), axis 1 in descending order (indices correspond to the thresholds) @@ -211,13 +211,21 @@ def aupimo_scores( fpr_upper_bound, ) - if not np.isclose(fpr_lower_bound_defacto, fpr_lower_bound, rtol=(rtol := 1e-2)): + if not torch.isclose( + fpr_lower_bound_defacto, + torch.tensor(fpr_lower_bound, dtype=fpr_lower_bound_defacto.dtype, device=fpr_lower_bound_defacto.device), + rtol=(rtol := 1e-2), + ): logger.warning( "The lower bound of the shared FPR integration range is not exactly achieved. " f"Expected {fpr_lower_bound} but got {fpr_lower_bound_defacto}, which is not within {rtol=}.", ) - if not np.isclose(fpr_upper_bound_defacto, fpr_upper_bound, rtol=rtol): + if not torch.isclose( + fpr_upper_bound_defacto, + torch.tensor(fpr_upper_bound, dtype=fpr_upper_bound_defacto.dtype, device=fpr_upper_bound_defacto.device), + rtol=rtol, + ): logger.warning( "The upper bound of the shared FPR integration range is not exactly achieved. " f"Expected {fpr_upper_bound} but got {fpr_upper_bound_defacto}, which is not within {rtol=}.", @@ -237,18 +245,18 @@ def aupimo_scores( raise RuntimeError(msg) # limit the curves to the integration range [lbound, ubound] - shared_fpr_bounded: ndarray = shared_fpr[thresh_lower_bound_idx : (thresh_upper_bound_idx + 1)] - per_image_tprs_bounded: ndarray = per_image_tprs[:, thresh_lower_bound_idx : (thresh_upper_bound_idx + 1)] + shared_fpr_bounded: torch.Tensor = shared_fpr[thresh_lower_bound_idx : (thresh_upper_bound_idx + 1)] + per_image_tprs_bounded: torch.Tensor = per_image_tprs[:, thresh_lower_bound_idx : (thresh_upper_bound_idx + 1)] # `shared_fpr` and `tprs` are in descending order; `flip()` reverts to ascending order - shared_fpr_bounded = np.flip(shared_fpr_bounded) - per_image_tprs_bounded = np.flip(per_image_tprs_bounded, axis=1) + shared_fpr_bounded = torch.flip(shared_fpr_bounded, dims=[0]) + per_image_tprs_bounded = torch.flip(per_image_tprs_bounded, dims=[1]) # the log's base does not matter because it's a constant factor canceled by normalization factor - shared_fpr_bounded_log = np.log(shared_fpr_bounded) + shared_fpr_bounded_log = torch.log(shared_fpr_bounded) # deal with edge cases - invalid_shared_fpr = ~np.isfinite(shared_fpr_bounded_log) + invalid_shared_fpr = ~torch.isfinite(shared_fpr_bounded_log) if invalid_shared_fpr.all(): msg = ( @@ -287,7 +295,7 @@ def aupimo_scores( "Try increasing `num_threshs`.", ) - aucs: ndarray = np.trapz(per_image_tprs_bounded, x=shared_fpr_bounded_log, axis=1) # noqa: NPY201 deprecated in Numpy 2.0 + aucs: torch.Tensor = torch.trapezoid(per_image_tprs_bounded, x=shared_fpr_bounded_log, axis=1) # normalize, then clip(0, 1) makes sure that the values are in [0, 1] in case of numerical errors normalization_factor = aupimo_normalizing_factor(fpr_bounds) @@ -299,7 +307,11 @@ def aupimo_scores( # =========================================== AUX =========================================== -def thresh_at_shared_fpr_level(threshs: ndarray, shared_fpr: ndarray, fpr_level: float) -> tuple[int, float, float]: +def thresh_at_shared_fpr_level( + threshs: torch.Tensor, + shared_fpr: torch.Tensor, + fpr_level: float, +) -> tuple[int, float, torch.Tensor]: """Return the threshold and its index at the given shared FPR level. Three cases are possible: @@ -342,13 +354,13 @@ def thresh_at_shared_fpr_level(threshs: ndarray, shared_fpr: ndarray, fpr_level: # fpr_level == 0 or 1 are special case # because there may be multiple solutions, and the chosen should their MINIMUM/MAXIMUM respectively if fpr_level == 0.0: - index = np.min(np.where(shared_fpr == fpr_level)) + index = torch.min(torch.where(shared_fpr == fpr_level)[0]) elif fpr_level == 1.0: - index = np.max(np.where(shared_fpr == fpr_level)) + index = torch.max(torch.where(shared_fpr == fpr_level)[0]) else: - index = np.argmin(np.abs(shared_fpr - fpr_level)) + index = torch.argmin(torch.abs(shared_fpr - fpr_level)) index = int(index) fpr_level_defacto = shared_fpr[index] diff --git a/src/anomalib/metrics/per_image/pimo.py b/src/anomalib/metrics/per_image/pimo.py index bda8d800b6..16dc22ed41 100644 --- a/src/anomalib/metrics/per_image/pimo.py +++ b/src/anomalib/metrics/per_image/pimo.py @@ -44,65 +44,45 @@ from dataclasses import dataclass, field import torch -from torch import Tensor from torchmetrics import Metric from anomalib.data.utils.path import validate_path -from . import _validate, pimo_numpy +from . import _validate, functional logger = logging.getLogger(__name__) -def _images_classes_from_masks(masks: Tensor) -> Tensor: +def _images_classes_from_masks(masks: torch.Tensor) -> torch.Tensor: masks = torch.concat(masks, dim=0) device = masks.device - image_classes = pimo_numpy._images_classes_from_masks(masks.numpy()) # noqa: SLF001 + image_classes = functional._images_classes_from_masks(masks) # noqa: SLF001 return torch.from_numpy(image_classes, device=device) # =========================================== ARGS VALIDATION =========================================== -def _validate_is_anomaly_maps(anomaly_maps: Tensor) -> None: - _validate.is_tensor(anomaly_maps, argname="anomaly_maps") - _validate.is_anomaly_maps(anomaly_maps.numpy()) +def _validate_is_shared_fpr(shared_fpr: torch.Tensor, nan_allowed: bool = False, decreasing: bool = True) -> None: + _validate.is_rate_curve(shared_fpr, nan_allowed=nan_allowed, decreasing=decreasing) -def _validate_is_masks(masks: Tensor) -> None: - _validate.is_tensor(masks, argname="masks") - _validate.is_masks(masks.numpy()) +def _validate_is_image_classes(image_classes: torch.Tensor) -> None: + _validate.is_images_classes(image_classes) -def _validate_is_threshs(threshs: Tensor) -> None: - _validate.is_tensor(threshs, argname="threshs") - _validate.is_threshs(threshs.numpy()) - - -def _validate_is_shared_fpr(shared_fpr: Tensor, nan_allowed: bool = False, decreasing: bool = True) -> None: - _validate.is_tensor(shared_fpr, argname="shared_fpr") - _validate.is_rate_curve(shared_fpr.numpy(), nan_allowed=nan_allowed, decreasing=decreasing) - - -def _validate_is_image_classes(image_classes: Tensor) -> None: - _validate.is_tensor(image_classes, argname="image_classes") - _validate.is_images_classes(image_classes.numpy()) - - -def _validate_is_per_image_tprs(per_image_tprs: Tensor, image_classes: Tensor) -> None: +def _validate_is_per_image_tprs(per_image_tprs: torch.Tensor, image_classes: torch.Tensor) -> None: _validate_is_image_classes(image_classes) - _validate.is_tensor(per_image_tprs, argname="per_image_tprs") - # general validations _validate.is_per_image_rate_curves( - per_image_tprs.numpy(), + per_image_tprs, nan_allowed=True, # normal images have NaN TPRs decreasing=None, # not checked here ) # specific to anomalous images _validate.is_per_image_rate_curves( - per_image_tprs[image_classes == 1].numpy(), + per_image_tprs[image_classes == 1], nan_allowed=False, decreasing=True, ) @@ -114,9 +94,8 @@ def _validate_is_per_image_tprs(per_image_tprs: Tensor, image_classes: Tensor) - raise ValueError(msg) -def _validate_is_aupimos(aupimos: Tensor) -> None: - _validate.is_tensor(aupimos, argname="aupimos") - _validate.is_rates(aupimos.numpy(), nan_allowed=True) +def _validate_is_aupimos(aupimos: torch.Tensor) -> None: + _validate.is_rates(aupimos, nan_allowed=True) def _validate_is_source_images_paths(paths: Sequence[str], expected_num_paths: int | None) -> None: @@ -169,16 +148,17 @@ class PIMOResult: - TPR: True Positive Rate Attributes: - threshs (Tensor): sequence of K (monotonically increasing) thresholds used to compute the PIMO curve - shared_fpr (Tensor): K values of the shared FPR metric at the corresponding thresholds - per_image_tprs (Tensor): for each of the N images, the K values of in-image TPR at the corresponding thresholds + threshs (torch.Tensor): sequence of K (monotonically increasing) thresholds used to compute the PIMO curve + shared_fpr (torch.Tensor): K values of the shared FPR metric at the corresponding thresholds + per_image_tprs (torch.Tensor): for each of the N images, the K values of in-image TPR at the corresponding + thresholds paths (list[str]) (optional): [metadata] paths to the source images to which the PIMO curves correspond """ # data - threshs: Tensor = field(repr=False) # shape => (K,) - shared_fpr: Tensor = field(repr=False) # shape => (K,) - per_image_tprs: Tensor = field(repr=False) # shape => (N, K) + threshs: torch.Tensor = field(repr=False) # shape => (K,) + shared_fpr: torch.Tensor = field(repr=False) # shape => (K,) + per_image_tprs: torch.Tensor = field(repr=False) # shape => (N, K) # optional metadata paths: list[str] | None = field(repr=False, default=None) @@ -194,7 +174,7 @@ def num_images(self) -> int: return self.per_image_tprs.shape[0] @property - def image_classes(self) -> Tensor: + def image_classes(self) -> torch.Tensor: """Image classes (0: normal, 1: anomalous). Deduced from the per-image TPRs. @@ -205,7 +185,7 @@ def image_classes(self) -> Tensor: def __post_init__(self) -> None: """Validate the inputs for the result object are consistent.""" try: - _validate_is_threshs(self.threshs) + _validate.is_threshs(self.threshs) _validate_is_shared_fpr(self.shared_fpr, nan_allowed=False) _validate_is_per_image_tprs(self.per_image_tprs, self.image_classes) @@ -244,9 +224,9 @@ def thresh_at(self, fpr_level: float) -> tuple[int, float, float]: [1] threshold [2] the actual shared FPR value at the returned threshold """ - return pimo_numpy.thresh_at_shared_fpr_level( - self.threshs.numpy(), - self.shared_fpr.numpy(), + return functional.thresh_at_shared_fpr_level( + self.threshs, + self.shared_fpr, fpr_level, ) @@ -264,7 +244,7 @@ class AUPIMOResult: should not be confused with the number of thresholds used to compute the PIMO curve thresh_lower_bound (float): LOWER threshold bound --> corresponds to the UPPER FPR bound thresh_upper_bound (float): UPPER threshold bound --> corresponds to the LOWER FPR bound - aupimos (Tensor): values of AUPIMO scores (1 per image) + aupimos (torch.Tensor): values of AUPIMO scores (1 per image) """ # metadata @@ -275,7 +255,7 @@ class AUPIMOResult: # data thresh_lower_bound: float = field(repr=False) thresh_upper_bound: float = field(repr=False) - aupimos: Tensor = field(repr=False) # shape => (N,) + aupimos: torch.Tensor = field(repr=False) # shape => (N,) # optional metadata paths: list[str] | None = field(repr=False, default=None) @@ -296,7 +276,7 @@ def num_anomalous_images(self) -> int: return int((self.image_classes == 1).sum()) @property - def image_classes(self) -> Tensor: + def image_classes(self) -> torch.Tensor: """Image classes (0: normal, 1: anomalous).""" # if an instance has `nan` aupimo it's because it's a normal image return self.aupimos.isnan().to(torch.int32) @@ -339,7 +319,7 @@ def from_pimoresult( pimoresult: PIMOResult, fpr_bounds: tuple[float, float], num_threshs_auc: int, - aupimos: Tensor, + aupimos: torch.Tensor, paths: list[str] | None = None, ) -> "AUPIMOResult": """Return an AUPIMO result object from a PIMO result object. @@ -391,8 +371,8 @@ def from_pimoresult( # =========================================== FUNCTIONAL =========================================== def pimo_curves( - anomaly_maps: Tensor, - masks: Tensor, + anomaly_maps: torch.Tensor, + masks: torch.Tensor, num_threshs: int, paths: list[str] | None = None, ) -> PIMOResult: @@ -423,34 +403,17 @@ def pimo_curves( Returns: PIMOResult: PIMO curves dataclass object. See `PIMOResult` for details. """ - _validate_is_anomaly_maps(anomaly_maps) - anomaly_maps_array = anomaly_maps.detach().cpu().numpy() - - _validate_is_masks(masks) - masks_array = masks.detach().cpu().numpy() - if paths is not None: _validate_is_source_images_paths(paths, expected_num_paths=anomaly_maps.shape[0]) # other validations are done in the numpy code - threshs_array, shared_fpr_array, per_image_tprs_array, _ = pimo_numpy.pimo_curves( - anomaly_maps_array, - masks_array, + threshs, shared_fpr, per_image_tprs, _ = functional.pimo_curves( + anomaly_maps, + masks, num_threshs, ) # _ is `image_classes` -- not needed here because it's a property in the result object - # tensors are build with `torch.from_numpy` and so the returned tensors - # will share the same memory as the numpy arrays - device = anomaly_maps.device - # N: number of images, K: number of thresholds - # shape => (K,) - threshs = torch.from_numpy(threshs_array).to(device) - # shape => (K,) - shared_fpr = torch.from_numpy(shared_fpr_array).to(device) - # shape => (N, K) - per_image_tprs = torch.from_numpy(per_image_tprs_array).to(device) - return PIMOResult( threshs=threshs, shared_fpr=shared_fpr, @@ -460,8 +423,8 @@ def pimo_curves( def aupimo_scores( - anomaly_maps: Tensor, - masks: Tensor, + anomaly_maps: torch.Tensor, + masks: torch.Tensor, num_threshs: int = 300_000, fpr_bounds: tuple[float, float] = (1e-5, 1e-4), force: bool = False, @@ -495,33 +458,18 @@ def aupimo_scores( Returns: tuple[PIMOResult, AUPIMOResult]: PIMO and AUPIMO results dataclass objects. See `PIMOResult` and `AUPIMOResult`. """ - anomaly_maps_array = anomaly_maps.detach().cpu().numpy() - masks_array = masks.detach().cpu().numpy() - if paths is not None: _validate_is_source_images_paths(paths, expected_num_paths=anomaly_maps.shape[0]) # other validations are done in the numpy code - threshs_array, shared_fpr_array, per_image_tprs_array, _, aupimos_array, num_threshs_auc = pimo_numpy.aupimo_scores( - anomaly_maps_array, - masks_array, + threshs, shared_fpr, per_image_tprs, _, aupimos, num_threshs_auc = functional.aupimo_scores( + anomaly_maps, + masks, num_threshs, fpr_bounds=fpr_bounds, force=force, ) - # tensors are build with `torch.from_numpy` and so the returned tensors - # will share the same memory as the numpy arrays - device = anomaly_maps.device - # N: number of images, K: number of thresholds - # shape => (K,) - threshs = torch.from_numpy(threshs_array).to(device) - # shape => (K,) - shared_fpr = torch.from_numpy(shared_fpr_array).to(device) - # shape => (N, K) - per_image_tprs = torch.from_numpy(per_image_tprs_array).to(device) - # shape => (N,) - aupimos = torch.from_numpy(aupimos_array).to(device) pimoresult = PIMOResult( threshs=threshs, @@ -582,8 +530,8 @@ class PIMO(Metric): num_threshs: int binclf_algorithm: str - anomaly_maps: list[Tensor] - masks: list[Tensor] + anomaly_maps: list[torch.Tensor] + masks: list[torch.Tensor] @property def _is_empty(self) -> bool: @@ -596,7 +544,7 @@ def num_images(self) -> int: return sum(am.shape[0] for am in self.anomaly_maps) @property - def image_classes(self) -> Tensor: + def image_classes(self) -> torch.Tensor: """Image classes (0: normal, 1: anomalous).""" return _images_classes_from_masks(self.masks) @@ -625,15 +573,15 @@ def __init__( self.add_state("anomaly_maps", default=[], dist_reduce_fx="cat") self.add_state("masks", default=[], dist_reduce_fx="cat") - def update(self, anomaly_maps: Tensor, masks: Tensor) -> None: + def update(self, anomaly_maps: torch.Tensor, masks: torch.Tensor) -> None: """Update lists of anomaly maps and masks. Args: - anomaly_maps (Tensor): predictions of the model (ndim == 2, float) - masks (Tensor): ground truth masks (ndim == 2, binary) + anomaly_maps (torch.Tensor): predictions of the model (ndim == 2, float) + masks (torch.Tensor): ground truth masks (ndim == 2, binary) """ - _validate_is_anomaly_maps(anomaly_maps) - _validate_is_masks(masks) + _validate.is_anomaly_maps(anomaly_maps) + _validate.is_masks(masks) _validate.is_same_shape(anomaly_maps, masks) self.anomaly_maps.append(anomaly_maps) self.masks.append(masks) @@ -706,7 +654,7 @@ def normalizing_factor(fpr_bounds: tuple[float, float]) -> float: Returns: float: the normalization factor (>0). """ - return pimo_numpy.aupimo_normalizing_factor(fpr_bounds) + return functional.aupimo_normalizing_factor(fpr_bounds) def __repr__(self) -> str: """Show the metric name and its integration bounds.""" diff --git a/src/anomalib/metrics/per_image/utils.py b/src/anomalib/metrics/per_image/utils.py index 927ae4989f..d54aa218d8 100644 --- a/src/anomalib/metrics/per_image/utils.py +++ b/src/anomalib/metrics/per_image/utils.py @@ -7,18 +7,22 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import itertools import logging from collections import OrderedDict from copy import deepcopy from typing import TYPE_CHECKING +import matplotlib as mpl import pandas as pd +import scipy +import scipy.stats import torch from pandas import DataFrame from torch import Tensor -from . import _validate, utils_numpy -from .utils_numpy import StatsOutliersPolicy, StatsRepeatedPolicy +from . import _validate +from .enums import StatsAlternativeHypothesis, StatsOutliersPolicy, StatsRepeatedPolicy if TYPE_CHECKING: from .pimo import AUPIMOResult @@ -26,7 +30,18 @@ logger = logging.getLogger(__name__) + # =========================================== ARGS VALIDATION =========================================== +def _validate_is_per_image_scores(per_image_scores: torch.Tensor) -> None: + if per_image_scores.ndim != 1: + msg = f"Expected per-image scores to be 1D, but got {per_image_scores.ndim}D." + raise ValueError(msg) + + +def _validate_is_image_class(image_class: int) -> None: + if image_class not in {0, 1}: + msg = f"Expected image class to be either 0 for 'normal' or 1 for 'anomalous', but got {image_class}." + raise ValueError(msg) def _validate_is_models_ordered(models_ordered: tuple[str, ...]) -> None: @@ -294,26 +309,100 @@ def per_image_scores_stats( The list is sorted by increasing `stat_value`. """ - _validate.is_tensor(per_image_scores, "per_image_scores") - per_image_scores_array = per_image_scores.detach().cpu().numpy() + # other validations happen inside `utils_numpy.per_image_scores_stats` - if images_classes is not None: - _validate.is_tensor(images_classes, "images_classes") - images_classes_array = images_classes.detach().cpu().numpy() + outliers_policy = StatsOutliersPolicy(outliers_policy) + repeated_policy = StatsRepeatedPolicy(repeated_policy) + _validate_is_per_image_scores(per_image_scores) - else: - images_classes_array = None + # restrain the images to the class `only_class` if given, else use all images + if images_classes is None: + images_selection_mask = torch.ones_like(per_image_scores, dtype=bool) - # other validations happen inside `utils_numpy.per_image_scores_stats` + elif only_class is not None: + _validate.is_images_classes(images_classes) + _validate.is_same_shape(per_image_scores, images_classes) + _validate_is_image_class(only_class) + images_selection_mask = images_classes == only_class - return utils_numpy.per_image_scores_stats( - per_image_scores_array, - images_classes_array, - only_class=only_class, - outliers_policy=outliers_policy, - repeated_policy=repeated_policy, - repeated_replacement_atol=repeated_replacement_atol, - ) + else: + images_selection_mask = torch.ones_like(per_image_scores, dtype=bool) + + # indexes in `per_image_scores` are referred to as `candidate_idx` + # while the indexes in the original array are referred to as `image_idx` + # - `candidate_idx` works for `per_image_scores` and `candidate2image_idx` (see below) + # - `image_idx` works for `images_classes` and `images_idxs_selected` + per_image_scores = per_image_scores[images_selection_mask] + # converts `candidate_idx` to `image_idx` + candidate2image_idx = torch.nonzero(images_selection_mask, as_tuple=True)[0] + + # function used in `matplotlib.boxplot` + boxplot_stats = mpl.cbook.boxplot_stats(per_image_scores)[0] # [0] is for the only boxplot + + # remove unnecessary keys + boxplot_stats = {name: value for name, value in boxplot_stats.items() if name not in {"iqr", "cilo", "cihi"}} + + # unroll `fliers` (outliers), remove unnecessary ones according to `outliers_policy`, + # then add them to `boxplot_stats` with unique keys + outliers = boxplot_stats.pop("fliers") + outliers_lo = outliers[outliers < boxplot_stats["med"]] + outliers_hi = outliers[outliers > boxplot_stats["med"]] + + if outliers_policy in {StatsOutliersPolicy.HIGH, StatsOutliersPolicy.BOTH}: + boxplot_stats = { + **boxplot_stats, + **{f"outhi_{idx:06}": value for idx, value in enumerate(outliers_hi)}, + } + + if outliers_policy in {StatsOutliersPolicy.LOW, StatsOutliersPolicy.BOTH}: + boxplot_stats = { + **boxplot_stats, + **{f"outlo_{idx:06}": value for idx, value in enumerate(outliers_lo)}, + } + + # state variables for the stateful function `append_record` below + images_idxs_selected: set[int] = set() + records: list[dict[str, str | int | float]] = [] + + def append_record(stat_name: str, stat_value: float) -> None: + candidates_sorted = torch.abs(per_image_scores - stat_value).argsort() + candidate_idx = candidates_sorted[0] + image_idx = candidate2image_idx[candidate_idx] + + # handle repeated values + if image_idx not in images_idxs_selected or repeated_policy == StatsRepeatedPolicy.NONE: + pass + + elif repeated_policy == StatsRepeatedPolicy.AVOID: + for other_candidate_idx in candidates_sorted: + other_candidate_image_idx = candidate2image_idx[other_candidate_idx] + if other_candidate_image_idx in images_idxs_selected: + continue + # if the code reaches here, it means that `other_candidate_image_idx` is not in `images_idxs_selected` + # i.e. this image has not been selected yet, so it can be used + other_candidate_score = per_image_scores[other_candidate_idx] + # if the other candidate is not too far from the value, use it + # note that the first choice has not changed, so if no other is selected in the loop + # it will be the first choice + if torch.isclose(other_candidate_score, stat_value, atol=repeated_replacement_atol): + candidate_idx = other_candidate_idx + image_idx = other_candidate_image_idx + break + + images_idxs_selected.add(image_idx) + records.append( + { + "stat_name": stat_name, + "stat_value": float(stat_value), + "image_idx": int(image_idx), + "score": float(per_image_scores[candidate_idx]), + }, + ) + + # loop over the stats from the lowest to the highest value + for stat, val in sorted(boxplot_stats.items(), key=lambda x: x[1]): + append_record(stat, val) + return sorted(records, key=lambda r: r["score"]) def compare_models_pairwise_ttest_rel( @@ -374,14 +463,44 @@ def compare_models_pairwise_ttest_rel( scores_per_model_items = [ ( model_name, - (scores if isinstance(scores, Tensor) else scores.aupimos).detach().cpu().numpy(), + (scores if isinstance(scores, Tensor) else scores.aupimos), ) for model_name, scores in scores_per_model.items() ] cls = OrderedDict if isinstance(scores_per_model, OrderedDict) else dict scores_per_model_with_arrays = cls(scores_per_model_items) - return utils_numpy.compare_models_pairwise_ttest_rel(scores_per_model_with_arrays, alternative, higher_is_better) + _validate_is_scores_per_model(scores_per_model_with_arrays) + StatsAlternativeHypothesis(alternative) + + # remove nan values; list of items keeps the order of the OrderedDict + scores_per_model_nonan_items = [ + (model_name, scores[~torch.isnan(scores)]) for model_name, scores in scores_per_model_with_arrays.items() + ] + + # sort models by average value if not an ordered dictionary + # position 0 is assumed the best model + if isinstance(scores_per_model_with_arrays, OrderedDict): + scores_per_model_nonan = OrderedDict(scores_per_model_nonan_items) + else: + scores_per_model_nonan = OrderedDict( + sorted(scores_per_model_nonan_items, key=lambda kv: kv[1].mean(), reverse=higher_is_better), + ) + + models_ordered = tuple(scores_per_model_nonan.keys()) + models_pairs = list(itertools.permutations(models_ordered, 2)) + confidences: dict[tuple[str, str], float] = {} + for model_i, model_j in models_pairs: + values_i = scores_per_model_nonan[model_i] + values_j = scores_per_model_nonan[model_j] + pvalue = scipy.stats.ttest_rel( + values_i, + values_j, + alternative=alternative, + ).pvalue + confidences[model_i, model_j] = 1.0 - float(pvalue) + + return models_ordered, confidences def compare_models_pairwise_wilcoxon( @@ -391,6 +510,7 @@ def compare_models_pairwise_wilcoxon( | OrderedDict[str, "AUPIMOResult"], alternative: str, higher_is_better: bool, + atol: float | None = 1e-3, ) -> tuple[tuple[str, ...], dict[tuple[str, str], float]]: """Compare all pairs of models using the Wilcoxon signed-rank test (non-parametric). @@ -444,14 +564,57 @@ def compare_models_pairwise_wilcoxon( scores_per_model_items = [ ( model_name, - (scores if isinstance(scores, Tensor) else scores.aupimos).detach().cpu().numpy(), + (scores if isinstance(scores, Tensor) else scores.aupimos), ) for model_name, scores in scores_per_model.items() ] cls = OrderedDict if isinstance(scores_per_model, OrderedDict) else dict scores_per_model_with_arrays = cls(scores_per_model_items) - return utils_numpy.compare_models_pairwise_wilcoxon(scores_per_model_with_arrays, alternative, higher_is_better) + _validate_is_scores_per_model(scores_per_model_with_arrays) + StatsAlternativeHypothesis(alternative) + + # remove nan values; list of items keeps the order of the OrderedDict + scores_per_model_nonan_items = [ + (model_name, scores[~torch.isnan(scores)]) for model_name, scores in scores_per_model_with_arrays.items() + ] + + # sort models by average value if not an ordered dictionary + # position 0 is assumed the best model + if isinstance(scores_per_model_with_arrays, OrderedDict): + scores_per_model_nonan = OrderedDict(scores_per_model_nonan_items) + else: + # these average ranks will NOT consider `atol` because we want to rank the models anyway + scores_nonan = torch.stack([v for _, v in scores_per_model_nonan_items], axis=0) + avg_ranks = scipy.stats.rankdata( + -scores_nonan if higher_is_better else scores_nonan, + method="average", + axis=0, + ).mean(axis=1) + # ascending order, lower score is better --> best to worst model + argsort_avg_ranks = avg_ranks.argsort() + scores_per_model_nonan = OrderedDict(scores_per_model_nonan_items[idx] for idx in argsort_avg_ranks) + + models_ordered = tuple(scores_per_model_nonan.keys()) + models_pairs = list(itertools.permutations(models_ordered, 2)) + confidences: dict[tuple[str, str], float] = {} + for model_i, model_j in models_pairs: + values_i = scores_per_model_nonan[model_i] + values_j = scores_per_model_nonan[model_j] + diff = values_i - values_j + + if atol is not None: + # make the difference null if below the tolerance + diff[torch.abs(diff) <= atol] = 0.0 + + # extreme case + if (diff == 0).all(): # noqa: SIM108 + pvalue = 1.0 + else: + pvalue = scipy.stats.wilcoxon(diff, alternative=alternative).pvalue + confidences[model_i, model_j] = 1.0 - float(pvalue) + + return models_ordered, confidences def format_pairwise_tests_results( diff --git a/src/anomalib/metrics/per_image/utils_numpy.py b/src/anomalib/metrics/per_image/utils_numpy.py deleted file mode 100644 index 619e7c1677..0000000000 --- a/src/anomalib/metrics/per_image/utils_numpy.py +++ /dev/null @@ -1,481 +0,0 @@ -"""Utility functions for per-image metrics.""" - -# Original Code -# https://github.com/jpcbertoldo/aupimo -# -# Modified -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -import itertools -from collections import OrderedDict -from enum import Enum - -import matplotlib as mpl -import numpy as np -import scipy -import scipy.stats -from numpy import ndarray - -from . import _validate - -# =========================================== CONSTANTS =========================================== - - -class StatsOutliersPolicy(Enum): - """How to handle outliers in per-image metrics boxplots. Use them? Only high? Only low? Both? - - Outliers are defined as in a boxplot, i.e. values that are more than 1.5 times the interquartile range (IQR) away - from the Q1 and Q3 quartiles (respectively low and high outliers). The IQR is the difference between Q3 and Q1. - - None | "none": do not include outliers. - "high": only include high outliers. - "low": only include low outliers. - "both": include both high and low outliers. - """ - - NONE: str = "none" - HIGH: str = "high" - LOW: str = "low" - BOTH: str = "both" - - -class StatsRepeatedPolicy(Enum): - """How to handle repeated values in per-image metrics boxplots (two stats with same value). Avoid them? - - None | "none": do not avoid repeated values, so several stats can have the same value and image index. - "avoid": if a stat has the same value as another stat, the one with the closest then another image, - with the nearest score, is selected. - """ - - NONE: str = "none" - AVOID: str = "avoid" - - -class StatsAlternativeHypothesis(Enum): - """Alternative hypothesis for the statistical tests used to compare per-image metrics.""" - - TWO_SIDED: str = "two-sided" - LESS: str = "less" - GREATER: str = "greater" - - -# =========================================== ARGS VALIDATION =========================================== -def _validate_is_image_class(image_class: int) -> None: - if not isinstance(image_class, int): - msg = f"Expected image class to be an int (0 for 'normal', 1 for 'anomalous'), but got {type(image_class)}." - raise TypeError(msg) - - if image_class not in {0, 1}: - msg = f"Expected image class to be either 0 for 'normal' or 1 for 'anomalous', but got {image_class}." - raise ValueError(msg) - - -def _validate_is_per_image_scores(per_image_scores: ndarray) -> None: - if not isinstance(per_image_scores, ndarray): - msg = f"Expected per-image scores to be a numpy array, but got {type(per_image_scores)}." - raise TypeError(msg) - - if per_image_scores.ndim != 1: - msg = f"Expected per-image scores to be 1D, but got {per_image_scores.ndim}D." - raise ValueError(msg) - - -def _validate_is_scores_per_model(scores_per_model: dict[str, ndarray] | OrderedDict[str, ndarray]) -> None: - if not isinstance(scores_per_model, dict | OrderedDict): - msg = f"Expected scores per model to be a dictionary or ordered dictionary, but got {type(scores_per_model)}." - raise TypeError(msg) - - if len(scores_per_model) < 2: - msg = f"Expected scores per model to have at least 2 models, but got {len(scores_per_model)}." - raise ValueError(msg) - - first_key_value = None - - for model_name, scores in scores_per_model.items(): - if not isinstance(model_name, str): - msg = f"Expected model name to be a string, but got {type(model_name)} for model {model_name}." - raise TypeError(msg) - - if not isinstance(scores, ndarray): - msg = f"Expected scores to be a numpy array, but got {type(scores)} for model {model_name}." - raise TypeError(msg) - - if scores.ndim != 1: - msg = f"Expected scores to be 1D, but got {scores.ndim}D for model {model_name}." - raise ValueError(msg) - - num_valid_scores = scores[~np.isnan(scores)].shape[0] - - if num_valid_scores < 2: - msg = f"Expected at least 2 scores, but got {num_valid_scores} for model {model_name}." - raise ValueError(msg) - - if first_key_value is None: - first_key_value = (model_name, scores) - continue - - first_model_name, first_scores = first_key_value - - # same shape - if scores.shape != first_scores.shape: - msg = ( - "Expected scores to have the same shape, " - f"but got ({model_name}) {scores.shape} != {first_scores.shape} ({first_model_name})." - ) - raise ValueError(msg) - - # `nan` at the same indices - if (np.isnan(scores) != np.isnan(first_scores)).any(): - msg = ( - "Expected `nan` values, if any, to be at the same indices, " - f"but there are differences between models {model_name} and {first_model_name}." - ) - raise ValueError(msg) - - -# =========================================== FUNCTIONS =========================================== - - -def per_image_scores_stats( - per_image_scores: ndarray, - images_classes: ndarray | None = None, - only_class: int | None = None, - outliers_policy: StatsOutliersPolicy | str = StatsOutliersPolicy.NONE.value, - repeated_policy: StatsRepeatedPolicy | str = StatsRepeatedPolicy.AVOID.value, - repeated_replacement_atol: float = 1e-2, -) -> list[dict[str, str | int | float]]: - """Compute statistics of per-image scores (based on a boxplot's statistics). - - For a single per-image metric collection (1 model, 1 dataset), compute statistics (based on a boxplot) - and find the closest image to each statistic. - - This function uses `matplotlib.cbook.boxplot_stats`, which is the same function used by `matplotlib.pyplot.boxplot`. - - ** OUTLIERS ** - Outliers are defined as in a boxplot, i.e. values that are more than 1.5 times the interquartile range (IQR) away - from the Q1 and Q3 quartiles (respectively low and high outliers). The IQR is the difference between Q3 and Q1. - - Outliers are handled according to `outliers_policy`: - - None | "none": do not include outliers. - - "high": only include high outliers. - - "low": only include low outliers. - - "both": include both high and low outliers. - - ** IMAGE INDEX ** - Each statistic is associated with the image whose score is the closest to the statistic's value. - - ** REPEATED VALUES ** - It is possible that two stats have the same value (e.g. the median and the 25th percentile can be the same). - Such cases are handled according to `repeated_policy`: - - None | "none": do not address the issue, so several stats can have the same value and image index. - - "avoid": avoid repeated values by iterativealy looking for other images with similar score, whose score - must be within `repeated_replacement_atol` (absolute tolerance) of the repeated value. - - Args: - per_image_scores (ndarray): 1D ndarray of per-image scores. - images_classes (ndarray | None): - Used to filter statistics to only one class. If None, all images are considered. - If given, 1D ndarray of binary image classes (0 for 'normal', 1 for 'anomalous'). Defaults to None. - only_class (int | None): - Only used if `images_classes` is not None. - If not None, only compute statistics for images of the given class. - `None` means both image classes are used. - Defaults to None. - outliers_policy (str | None): How to handle outliers stats (use them?). See `OutliersPolicy`. Defaults to None. - repeated_policy (str | None): How to handle repeated values in boxplot stats (two stats with same value). - See `RepeatedPolicy`. Defaults to None. - repeated_replacement_atol (float): Absolute tolerance used to replace repeated values. Only used if - `repeated_policy` is not None (or 'none'). Defaults to 1e-2 (1%). - - Returns: - list[dict[str, str | int | float]]: List of boxplot statistics. - - Each dictionary has the following keys: - - 'stat_name': Name of the statistic. Possible values: - - 'mean': Mean of the scores. - - 'med': Median of the scores. - - 'q1': 25th percentile of the scores. - - 'q3': 75th percentile of the scores. - - 'whishi': Upper whisker value. - - 'whislo': Lower whisker value. - - 'outlo_i': low outlier value; `i` is a unique index for each low outlier. - - 'outhi_j': high outlier value; `j` is a unique index for each high outlier. - - 'stat_value': Value of the statistic (same units as `values`). - - 'image_idx': Index of the image in `per_image_scores` whose score is the closest to the statistic's value. - - 'score': The score of the image at index `image_idx` (not necessarily the same as `stat_value`). - - The list is sorted by increasing `stat_value`. - """ - outliers_policy = StatsOutliersPolicy(outliers_policy) - repeated_policy = StatsRepeatedPolicy(repeated_policy) - _validate_is_per_image_scores(per_image_scores) - - # restrain the images to the class `only_class` if given, else use all images - if images_classes is None: - images_selection_mask = np.ones_like(per_image_scores, dtype=bool) - - elif only_class is not None: - _validate.is_images_classes(images_classes) - _validate.is_same_shape(per_image_scores, images_classes) - _validate_is_image_class(only_class) - images_selection_mask = images_classes == only_class - - else: - images_selection_mask = np.ones_like(per_image_scores, dtype=bool) - - # indexes in `per_image_scores` are referred to as `candidate_idx` - # while the indexes in the original array are referred to as `image_idx` - # - `candidate_idx` works for `per_image_scores` and `candidate2image_idx` (see below) - # - `image_idx` works for `images_classes` and `images_idxs_selected` - per_image_scores = per_image_scores[images_selection_mask] - # converts `candidate_idx` to `image_idx` - candidate2image_idx = np.nonzero(images_selection_mask)[0] - - # function used in `matplotlib.boxplot` - boxplot_stats = mpl.cbook.boxplot_stats(per_image_scores)[0] # [0] is for the only boxplot - - # remove unnecessary keys - boxplot_stats = {name: value for name, value in boxplot_stats.items() if name not in {"iqr", "cilo", "cihi"}} - - # unroll `fliers` (outliers), remove unnecessary ones according to `outliers_policy`, - # then add them to `boxplot_stats` with unique keys - outliers = boxplot_stats.pop("fliers") - outliers_lo = outliers[outliers < boxplot_stats["med"]] - outliers_hi = outliers[outliers > boxplot_stats["med"]] - - if outliers_policy in {StatsOutliersPolicy.HIGH, StatsOutliersPolicy.BOTH}: - boxplot_stats = { - **boxplot_stats, - **{f"outhi_{idx:06}": value for idx, value in enumerate(outliers_hi)}, - } - - if outliers_policy in {StatsOutliersPolicy.LOW, StatsOutliersPolicy.BOTH}: - boxplot_stats = { - **boxplot_stats, - **{f"outlo_{idx:06}": value for idx, value in enumerate(outliers_lo)}, - } - - # state variables for the stateful function `append_record` below - images_idxs_selected: set[int] = set() - records: list[dict[str, str | int | float]] = [] - - def append_record(stat_name: str, stat_value: float) -> None: - candidates_sorted = np.abs(per_image_scores - stat_value).argsort() - candidate_idx = candidates_sorted[0] - image_idx = candidate2image_idx[candidate_idx] - - # handle repeated values - if image_idx not in images_idxs_selected or repeated_policy == StatsRepeatedPolicy.NONE: - pass - - elif repeated_policy == StatsRepeatedPolicy.AVOID: - for other_candidate_idx in candidates_sorted: - other_candidate_image_idx = candidate2image_idx[other_candidate_idx] - if other_candidate_image_idx in images_idxs_selected: - continue - # if the code reaches here, it means that `other_candidate_image_idx` is not in `images_idxs_selected` - # i.e. this image has not been selected yet, so it can be used - other_candidate_score = per_image_scores[other_candidate_idx] - # if the other candidate is not too far from the value, use it - # note that the first choice has not changed, so if no other is selected in the loop - # it will be the first choice - if np.isclose(other_candidate_score, stat_value, atol=repeated_replacement_atol): - candidate_idx = other_candidate_idx - image_idx = other_candidate_image_idx - break - - images_idxs_selected.add(image_idx) - records.append( - { - "stat_name": stat_name, - "stat_value": float(stat_value), - "image_idx": int(image_idx), - "score": float(per_image_scores[candidate_idx]), - }, - ) - - # loop over the stats from the lowest to the highest value - for stat, val in sorted(boxplot_stats.items(), key=lambda x: x[1]): - append_record(stat, val) - return sorted(records, key=lambda r: r["score"]) - - -def compare_models_pairwise_ttest_rel( - scores_per_model: dict[str, ndarray] | OrderedDict[str, ndarray], - alternative: str, - higher_is_better: bool, -) -> tuple[tuple[str, ...], dict[tuple[str, str], float]]: - """Compare all pairs of models using the paired t-test on two related samples (parametric). - - This is a test for the null hypothesis that two repeated samples have identical average (expected) values. - In fact, it tests whether the average of the differences between the two samples is significantly different from 0. - - Refs: - - `scipy.stats.ttest_rel`: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ttest_rel.html - - Wikipedia page: https://en.wikipedia.org/wiki/Student's_t-test#Dependent_t-test_for_paired_samples - - === - - If an ordered dictionary is given, the models are sorted by the order of the dictionary. - Otherwise, the models are sorted by average SCORE. - - Args: - scores_per_model: Dictionary of `n` models and their per-image scores. - key: model name - value: tensor of shape (num_images,). All `nan` values must be at the same positions. - higher_is_better: Whether higher values of score are better or worse. Defaults to True. - alternative: Alternative hypothesis for the statistical tests. See `confidences` in "Returns" section. - Valid values are `StatsAlternativeHypothesis.ALTERNATIVES`. - - Returns: - (models_ordered, test_results): - - models_ordered: Models sorted by the user (`OrderedDict` input) or automatically (`dict` input). - - Automatic sorting is by average score from best to worst model. - Depending on `higher_is_better`, this corresponds to: - - `higher_is_better=True` ==> descending score order - - `higher_is_better=False` ==> ascending score order - along the indices from 0 to `n-1`. - - - confidences: Dictionary of confidence values for each pair of models. - - For all pairs of indices i and j from 0 to `n-1` such that i != j: - - key: (models_ordered[i], models_ordered[j]) - - value: confidence on the alternative hypothesis. - - For models `models_ordered[i]` and `models_ordered[j]`, the alternative hypothesis is: - - if `less`: model[i] < model[j] - - if `greater`: model[i] > model[j] - - if `two-sided`: model[i] != model[j] - in termos of average score. - """ - _validate_is_scores_per_model(scores_per_model) - StatsAlternativeHypothesis(alternative) - - # remove nan values; list of items keeps the order of the OrderedDict - scores_per_model_nonan_items = [ - (model_name, scores[~np.isnan(scores)]) for model_name, scores in scores_per_model.items() - ] - - # sort models by average value if not an ordered dictionary - # position 0 is assumed the best model - if isinstance(scores_per_model, OrderedDict): - scores_per_model_nonan = OrderedDict(scores_per_model_nonan_items) - else: - scores_per_model_nonan = OrderedDict( - sorted(scores_per_model_nonan_items, key=lambda kv: kv[1].mean(), reverse=higher_is_better), - ) - - models_ordered = tuple(scores_per_model_nonan.keys()) - models_pairs = list(itertools.permutations(models_ordered, 2)) - confidences: dict[tuple[str, str], float] = {} - for model_i, model_j in models_pairs: - values_i = scores_per_model_nonan[model_i] - values_j = scores_per_model_nonan[model_j] - pvalue = scipy.stats.ttest_rel( - values_i, - values_j, - alternative=alternative, - ).pvalue - confidences[model_i, model_j] = 1.0 - float(pvalue) - - return models_ordered, confidences - - -def compare_models_pairwise_wilcoxon( - scores_per_model: dict[str, ndarray] | OrderedDict[str, ndarray], - alternative: str, - higher_is_better: bool, - atol: float | None = 1e-3, -) -> tuple[tuple[str, ...], dict[tuple[str, str], float]]: - """Compare all pairs of models using the Wilcoxon signed-rank test (non-parametric). - - Each comparison of two models is a Wilcoxon signed-rank test (null hypothesis is that they are equal). - - It tests whether the distribution of the differences of scores is symmetric about zero in a non-parametric way. - This is like the non-parametric version of the paired t-test. - - Refs: - - `scipy.stats.wilcoxon`: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wilcoxon.html#scipy.stats.wilcoxon - - Wikipedia page: https://en.wikipedia.org/wiki/Wilcoxon_signed-rank_test - - === - - If an ordered dictionary is given, the models are sorted by the order of the dictionary. - Otherwise, the models are sorted by average RANK. - - Args: - scores_per_model: Dictionary of `n` models and their per-image scores. - key: model name - value: tensor of shape (num_images,). All `nan` values must be at the same positions. - higher_is_better: Whether higher values of score are better or worse. Defaults to True. - alternative: Alternative hypothesis for the statistical tests. See `confidences` in "Returns" section. - Valid values are `StatsAlternativeHypothesis.ALTERNATIVES`. - atol: Absolute tolerance used to consider two scores as equal. Defaults to 1e-3 (0.1%). - When doing a paired test, if the difference between two scores is below `atol`, the difference is - truncated to 0. If `atol` is None, no truncation is done. - - Returns: - (models_ordered, test_results): - - models_ordered: Models sorted by the user (`OrderedDict` input) or automatically (`dict` input). - - Automatic sorting is from "best to worst" model, which corresponds to ascending average rank - along the indices from 0 to `n-1`. - - - confidences: Dictionary of confidence values for each pair of models. - - For all pairs of indices i and j from 0 to `n-1` such that i != j: - - key: (models_ordered[i], models_ordered[j]) - - value: confidence on the alternative hypothesis. - - For models `models_ordered[i]` and `models_ordered[j]`, the alternative hypothesis is: - - if `less`: model[i] < model[j] - - if `greater`: model[i] > model[j] - - if `two-sided`: model[i] != model[j] - in terms of average ranks (not scores!). - """ - _validate_is_scores_per_model(scores_per_model) - StatsAlternativeHypothesis(alternative) - - # remove nan values; list of items keeps the order of the OrderedDict - scores_per_model_nonan_items = [ - (model_name, scores[~np.isnan(scores)]) for model_name, scores in scores_per_model.items() - ] - - # sort models by average value if not an ordered dictionary - # position 0 is assumed the best model - if isinstance(scores_per_model, OrderedDict): - scores_per_model_nonan = OrderedDict(scores_per_model_nonan_items) - else: - # these average ranks will NOT consider `atol` because we want to rank the models anyway - scores_nonan = np.stack([v for _, v in scores_per_model_nonan_items], axis=0) - avg_ranks = scipy.stats.rankdata( - -scores_nonan if higher_is_better else scores_nonan, - method="average", - axis=0, - ).mean(axis=1) - # ascending order, lower score is better --> best to worst model - argsort_avg_ranks = avg_ranks.argsort() - scores_per_model_nonan = OrderedDict(scores_per_model_nonan_items[idx] for idx in argsort_avg_ranks) - - models_ordered = tuple(scores_per_model_nonan.keys()) - models_pairs = list(itertools.permutations(models_ordered, 2)) - confidences: dict[tuple[str, str], float] = {} - for model_i, model_j in models_pairs: - values_i = scores_per_model_nonan[model_i] - values_j = scores_per_model_nonan[model_j] - diff = values_i - values_j - - if atol is not None: - # make the difference null if below the tolerance - diff[np.abs(diff) <= atol] = 0.0 - - # extreme case - if (diff == 0).all(): # noqa: SIM108 - pvalue = 1.0 - else: - pvalue = scipy.stats.wilcoxon(diff, alternative=alternative).pvalue - confidences[model_i, model_j] = 1.0 - float(pvalue) - - return models_ordered, confidences diff --git a/tests/unit/metrics/per_image/test_binclf_curve.py b/tests/unit/metrics/per_image/test_binclf_curve.py index cd7c0cdd98..eed53f3248 100644 --- a/tests/unit/metrics/per_image/test_binclf_curve.py +++ b/tests/unit/metrics/per_image/test_binclf_curve.py @@ -9,88 +9,87 @@ # ruff: noqa: SLF001, PT011 -import numpy as np import pytest -from numpy import ndarray +import torch -from anomalib.metrics.per_image import binclf_curve_numpy +from anomalib.metrics.per_image import binclf_curve def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: """Generate test cases.""" - pred = np.arange(1, 5, dtype=np.float32) - threshs = np.arange(1, 5, dtype=np.float32) + pred = torch.arange(1, 5, dtype=torch.float32) + threshs = torch.arange(1, 5, dtype=torch.float32) - gt_norm = np.zeros(4).astype(bool) - gt_anom = np.concatenate([np.zeros(2), np.ones(2)]).astype(bool) + gt_norm = torch.zeros(4).to(bool) + gt_anom = torch.concatenate([torch.zeros(2), torch.ones(2)]).to(bool) # in the case where thresholds are all unique values in the predictions - expected_norm = np.stack( + expected_norm = torch.stack( [ - np.array([[0, 4], [0, 0]]), - np.array([[1, 3], [0, 0]]), - np.array([[2, 2], [0, 0]]), - np.array([[3, 1], [0, 0]]), + torch.tensor([[0, 4], [0, 0]]), + torch.tensor([[1, 3], [0, 0]]), + torch.tensor([[2, 2], [0, 0]]), + torch.tensor([[3, 1], [0, 0]]), ], axis=0, - ).astype(int) - expected_anom = np.stack( + ).to(int) + expected_anom = torch.stack( [ - np.array([[0, 2], [0, 2]]), - np.array([[1, 1], [0, 2]]), - np.array([[2, 0], [0, 2]]), - np.array([[2, 0], [1, 1]]), + torch.tensor([[0, 2], [0, 2]]), + torch.tensor([[1, 1], [0, 2]]), + torch.tensor([[2, 0], [0, 2]]), + torch.tensor([[2, 0], [1, 1]]), ], axis=0, - ).astype(int) + ).to(int) - expected_tprs_norm = np.array([np.nan, np.nan, np.nan, np.nan]) - expected_tprs_anom = np.array([1.0, 1.0, 1.0, 0.5]) - expected_tprs = np.stack([expected_tprs_anom, expected_tprs_norm], axis=0).astype(np.float64) + expected_tprs_norm = torch.tensor([torch.nan, torch.nan, torch.nan, torch.nan]) + expected_tprs_anom = torch.tensor([1.0, 1.0, 1.0, 0.5]) + expected_tprs = torch.stack([expected_tprs_anom, expected_tprs_norm], axis=0).to(torch.float64) - expected_fprs_norm = np.array([1.0, 0.75, 0.5, 0.25]) - expected_fprs_anom = np.array([1.0, 0.5, 0.0, 0.0]) - expected_fprs = np.stack([expected_fprs_anom, expected_fprs_norm], axis=0).astype(np.float64) + expected_fprs_norm = torch.tensor([1.0, 0.75, 0.5, 0.25]) + expected_fprs_anom = torch.tensor([1.0, 0.5, 0.0, 0.0]) + expected_fprs = torch.stack([expected_fprs_anom, expected_fprs_norm], axis=0).to(torch.float64) # in the case where all thresholds are higher than the highest prediction - expected_norm_threshs_too_high = np.stack( + expected_norm_threshs_too_high = torch.stack( [ - np.array([[4, 0], [0, 0]]), - np.array([[4, 0], [0, 0]]), - np.array([[4, 0], [0, 0]]), - np.array([[4, 0], [0, 0]]), + torch.tensor([[4, 0], [0, 0]]), + torch.tensor([[4, 0], [0, 0]]), + torch.tensor([[4, 0], [0, 0]]), + torch.tensor([[4, 0], [0, 0]]), ], axis=0, - ).astype(int) - expected_anom_threshs_too_high = np.stack( + ).to(int) + expected_anom_threshs_too_high = torch.stack( [ - np.array([[2, 0], [2, 0]]), - np.array([[2, 0], [2, 0]]), - np.array([[2, 0], [2, 0]]), - np.array([[2, 0], [2, 0]]), + torch.tensor([[2, 0], [2, 0]]), + torch.tensor([[2, 0], [2, 0]]), + torch.tensor([[2, 0], [2, 0]]), + torch.tensor([[2, 0], [2, 0]]), ], axis=0, - ).astype(int) + ).to(int) # in the case where all thresholds are lower than the lowest prediction - expected_norm_threshs_too_low = np.stack( + expected_norm_threshs_too_low = torch.stack( [ - np.array([[0, 4], [0, 0]]), - np.array([[0, 4], [0, 0]]), - np.array([[0, 4], [0, 0]]), - np.array([[0, 4], [0, 0]]), + torch.tensor([[0, 4], [0, 0]]), + torch.tensor([[0, 4], [0, 0]]), + torch.tensor([[0, 4], [0, 0]]), + torch.tensor([[0, 4], [0, 0]]), ], axis=0, - ).astype(int) - expected_anom_threshs_too_low = np.stack( + ).to(int) + expected_anom_threshs_too_low = torch.stack( [ - np.array([[0, 2], [0, 2]]), - np.array([[0, 2], [0, 2]]), - np.array([[0, 2], [0, 2]]), - np.array([[0, 2], [0, 2]]), + torch.tensor([[0, 2], [0, 2]]), + torch.tensor([[0, 2], [0, 2]]), + torch.tensor([[0, 2], [0, 2]]), + torch.tensor([[0, 2], [0, 2]]), ], axis=0, - ).astype(int) + ).to(int) if metafunc.function is test__binclf_one_curve: metafunc.parametrize( @@ -106,11 +105,14 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: ], ) - preds = np.stack([pred, pred], axis=0) - gts = np.stack([gt_anom, gt_norm], axis=0) - binclf_curves = np.stack([expected_anom, expected_norm], axis=0) - binclf_curves_threshs_too_high = np.stack([expected_anom_threshs_too_high, expected_norm_threshs_too_high], axis=0) - binclf_curves_threshs_too_low = np.stack([expected_anom_threshs_too_low, expected_norm_threshs_too_low], axis=0) + preds = torch.stack([pred, pred], axis=0) + gts = torch.stack([gt_anom, gt_norm], axis=0) + binclf_curves = torch.stack([expected_anom, expected_norm], axis=0) + binclf_curves_threshs_too_high = torch.stack( + [expected_anom_threshs_too_high, expected_norm_threshs_too_high], + axis=0, + ) + binclf_curves_threshs_too_low = torch.stack([expected_anom_threshs_too_low, expected_norm_threshs_too_low], axis=0) if metafunc.function is test__binclf_multiple_curves: metafunc.parametrize( @@ -149,16 +151,16 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: ([preds, gts[:1], threshs], {}, ValueError), ([preds[:, :2], gts, threshs], {}, ValueError), # `scores` be of type float - ([preds.astype(int), gts, threshs], {}, TypeError), + ([preds.to(int), gts, threshs], {}, TypeError), # `gts` be of type bool - ([preds, gts.astype(int), threshs], {}, TypeError), + ([preds, gts.to(int), threshs], {}, TypeError), # `threshs` be of type float - ([preds, gts, threshs.astype(int)], {}, TypeError), + ([preds, gts, threshs.to(int)], {}, TypeError), # `threshs` must be sorted in ascending order - ([preds, gts, np.flip(threshs)], {}, ValueError), - ([preds, gts, np.concatenate([threshs[-2:], threshs[:2]])], {}, ValueError), + ([preds, gts, torch.flip(threshs, dims=[0])], {}, ValueError), + ([preds, gts, torch.concatenate([threshs[-2:], threshs[:2]])], {}, ValueError), # `threshs` must be unique - ([preds, gts, np.sort(np.concatenate([threshs, threshs]))], {}, ValueError), + ([preds, gts, torch.sort(torch.concatenate([threshs, threshs]))[0]], {}, ValueError), ], ) @@ -167,7 +169,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: preds = preds.reshape(2, 2, 2) gts = gts.reshape(2, 2, 2) - per_image_binclf_curves_numpy_argvalues = [ + per_image_binclf_curves_argvalues = [ # `threshs_choice` = "given" ( preds, @@ -208,7 +210,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: ), ( 2 * preds, - gts.astype(int), # this is ok + gts.to(int), # this is ok "minmax-linspace", None, len(threshs), @@ -217,7 +219,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: ), ] - if metafunc.function is test_per_image_binclf_curve_numpy: + if metafunc.function is test_per_image_binclf_curve: metafunc.parametrize( argnames=( "anomaly_maps", @@ -228,10 +230,10 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: "expected_threshs", "expected_binclf_curves", ), - argvalues=per_image_binclf_curves_numpy_argvalues, + argvalues=per_image_binclf_curves_argvalues, ) - if metafunc.function is test_per_image_binclf_curve_numpy_validations: + if metafunc.function is test_per_image_binclf_curve_validations: metafunc.parametrize( argnames=("args", "exception"), argvalues=[ @@ -242,11 +244,11 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: ([preds, gts[:1]], ValueError), ([preds[:, :1], gts], ValueError), # `scores` be of type float - ([preds.astype(int), gts], TypeError), + ([preds.to(int), gts], TypeError), # `gts` be of type bool or int - ([preds, gts.astype(float)], TypeError), + ([preds, gts.to(float)], TypeError), # `threshs` be of type float - ([preds, gts, threshs.astype(int)], TypeError), + ([preds, gts, threshs.to(int)], TypeError), ], ) metafunc.parametrize( @@ -263,7 +265,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: ) # same as above but testing other validations - if metafunc.function is test_per_image_binclf_curve_numpy_validations_alt: + if metafunc.function is test_per_image_binclf_curve_validations_alt: metafunc.parametrize( argnames=("args", "kwargs", "exception"), argvalues=[ @@ -276,7 +278,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: ], ) - if metafunc.function is test_rate_metrics_numpy: + if metafunc.function is test_rate_metrics: metafunc.parametrize( argnames=("binclf_curves", "expected_fprs", "expected_tprs"), argvalues=[ @@ -290,22 +292,22 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: # LOW-LEVEL FUNCTIONS (PYTHON) -def test__binclf_one_curve(pred: ndarray, gt: ndarray, threshs: ndarray, expected: ndarray) -> None: +def test__binclf_one_curve(pred: torch.Tensor, gt: torch.Tensor, threshs: torch.Tensor, expected: torch.Tensor) -> None: """Test if `_binclf_one_curve()` returns the expected values.""" - computed = binclf_curve_numpy._binclf_one_curve(pred, gt, threshs) - assert computed.shape == (threshs.size, 2, 2) - assert (computed == expected).all() + computed = binclf_curve._binclf_one_curve(pred, gt, threshs) + assert computed.shape == (threshs.numel(), 2, 2) + assert (computed == expected.numpy()).all() def test__binclf_multiple_curves( - preds: ndarray, - gts: ndarray, - threshs: ndarray, - expecteds: ndarray, + preds: torch.Tensor, + gts: torch.Tensor, + threshs: torch.Tensor, + expecteds: torch.Tensor, ) -> None: """Test if `_binclf_multiple_curves()` returns the expected values.""" - computed = binclf_curve_numpy.binclf_multiple_curves(preds, gts, threshs) - assert computed.shape == (preds.shape[0], threshs.size, 2, 2) + computed = binclf_curve.binclf_multiple_curves(preds, gts, threshs) + assert computed.shape == (preds.shape[0], threshs.numel(), 2, 2) assert (computed == expecteds).all() @@ -314,13 +316,13 @@ def test__binclf_multiple_curves( def test_binclf_multiple_curves( - preds: ndarray, - gts: ndarray, - threshs: ndarray, - expected_binclf_curves: ndarray, + preds: torch.Tensor, + gts: torch.Tensor, + threshs: torch.Tensor, + expected_binclf_curves: torch.Tensor, ) -> None: """Test if `binclf_multiple_curves()` returns the expected values.""" - computed = binclf_curve_numpy.binclf_multiple_curves( + computed = binclf_curve.binclf_multiple_curves( preds, gts, threshs, @@ -329,39 +331,39 @@ def test_binclf_multiple_curves( assert (computed == expected_binclf_curves).all() # it's ok to have the threhsholds beyond the range of the preds - binclf_curve_numpy.binclf_multiple_curves(preds, gts, 2 * threshs) + binclf_curve.binclf_multiple_curves(preds, gts, 2 * threshs) # or inside the bounds without reaching them - binclf_curve_numpy.binclf_multiple_curves(preds, gts, 0.5 * threshs) + binclf_curve.binclf_multiple_curves(preds, gts, 0.5 * threshs) # it's also ok to have more threshs than unique values in the preds # add the values in between the threshs threshs_unncessary = 0.5 * (threshs[:-1] + threshs[1:]) - threshs_unncessary = np.concatenate([threshs_unncessary, threshs]) - threshs_unncessary = np.sort(threshs_unncessary) - binclf_curve_numpy.binclf_multiple_curves(preds, gts, threshs_unncessary) + threshs_unncessary = torch.concatenate([threshs_unncessary, threshs]) + threshs_unncessary = torch.sort(threshs_unncessary)[0] + binclf_curve.binclf_multiple_curves(preds, gts, threshs_unncessary) # or less - binclf_curve_numpy.binclf_multiple_curves(preds, gts, threshs[1:3]) + binclf_curve.binclf_multiple_curves(preds, gts, threshs[1:3]) def test_binclf_multiple_curves_validations(args: list, kwargs: dict, exception: Exception) -> None: """Test if `_binclf_multiple_curves_python()` raises the expected errors.""" with pytest.raises(exception): - binclf_curve_numpy.binclf_multiple_curves(*args, **kwargs) + binclf_curve.binclf_multiple_curves(*args, **kwargs) -def test_per_image_binclf_curve_numpy( - anomaly_maps: ndarray, - masks: ndarray, +def test_per_image_binclf_curve( + anomaly_maps: torch.Tensor, + masks: torch.Tensor, threshs_choice: str, - threshs_given: ndarray | None, + threshs_given: torch.Tensor | None, num_threshs: int | None, - expected_threshs: ndarray, - expected_binclf_curves: ndarray, + expected_threshs: torch.Tensor, + expected_binclf_curves: torch.Tensor, ) -> None: """Test if `per_image_binclf_curve()` returns the expected values.""" - computed_threshs, computed_binclf_curves = binclf_curve_numpy.per_image_binclf_curve( + computed_threshs, computed_binclf_curves = binclf_curve.per_image_binclf_curve( anomaly_maps, masks, threshs_choice=threshs_choice, @@ -380,24 +382,28 @@ def test_per_image_binclf_curve_numpy( assert (computed_binclf_curves == expected_binclf_curves).all() -def test_per_image_binclf_curve_numpy_validations(args: list, kwargs: dict, exception: Exception) -> None: +def test_per_image_binclf_curve_validations(args: list, kwargs: dict, exception: Exception) -> None: """Test if `per_image_binclf_curve()` raises the expected errors.""" with pytest.raises(exception): - binclf_curve_numpy.per_image_binclf_curve(*args, **kwargs) + binclf_curve.per_image_binclf_curve(*args, **kwargs) -def test_per_image_binclf_curve_numpy_validations_alt(args: list, kwargs: dict, exception: Exception) -> None: +def test_per_image_binclf_curve_validations_alt(args: list, kwargs: dict, exception: Exception) -> None: """Test if `per_image_binclf_curve()` raises the expected errors.""" - test_per_image_binclf_curve_numpy_validations(args, kwargs, exception) + test_per_image_binclf_curve_validations(args, kwargs, exception) -def test_rate_metrics_numpy(binclf_curves: ndarray, expected_fprs: ndarray, expected_tprs: ndarray) -> None: +def test_rate_metrics( + binclf_curves: torch.Tensor, + expected_fprs: torch.Tensor, + expected_tprs: torch.Tensor, +) -> None: """Test if rate metrics are computed correctly.""" - tprs = binclf_curve_numpy.per_image_tpr(binclf_curves) - fprs = binclf_curve_numpy.per_image_fpr(binclf_curves) + tprs = binclf_curve.per_image_tpr(binclf_curves) + fprs = binclf_curve.per_image_fpr(binclf_curves) assert tprs.shape == expected_tprs.shape assert fprs.shape == expected_fprs.shape - assert np.allclose(tprs, expected_tprs, equal_nan=True) - assert np.allclose(fprs, expected_fprs, equal_nan=True) + assert torch.allclose(tprs, expected_tprs, equal_nan=True) + assert torch.allclose(fprs, expected_fprs, equal_nan=True) diff --git a/tests/unit/metrics/per_image/test_pimo.py b/tests/unit/metrics/per_image/test_pimo.py index d0092fb616..061f0f8042 100644 --- a/tests/unit/metrics/per_image/test_pimo.py +++ b/tests/unit/metrics/per_image/test_pimo.py @@ -1,4 +1,4 @@ -"""Test `anomalib.metrics.per_image.pimo_numpy`.""" +"""Test `anomalib.metrics.per_image.functional`.""" # Original Code # https://github.com/jpcbertoldo/aupimo @@ -7,13 +7,13 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import numpy as np +import logging + import pytest import torch -from numpy import ndarray from torch import Tensor -from anomalib.metrics.per_image import pimo, pimo_numpy +from anomalib.metrics.per_image import functional, pimo from anomalib.metrics.per_image.pimo import AUPIMOResult, PIMOResult @@ -23,7 +23,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: All functions are parametrized with the same setting: 1 normal and 2 anomalous images. The anomaly maps are the same for all functions, but the masks are different. """ - expected_threshs = np.arange(1, 7 + 1, dtype=np.float32) + expected_threshs = torch.arange(1, 7 + 1, dtype=torch.float32) shape = (1000, 1000) # (H, W), 1 million pixels # --- normal --- @@ -31,7 +31,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: # value: 7 6 5 4 3 2 1 # count: 1 9 90 900 9k 90k 900k # cumsum: 1 10 100 1k 10k 100k 1M - pred_norm = np.ones(1_000_000, dtype=np.float32) + pred_norm = torch.ones(1_000_000, dtype=torch.float32) pred_norm[:100_000] += 1 pred_norm[:10_000] += 1 pred_norm[:1_000] += 1 @@ -39,36 +39,31 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: pred_norm[:10] += 1 pred_norm[:1] += 1 pred_norm = pred_norm.reshape(shape) - mask_norm = np.zeros_like(pred_norm, dtype=np.int32) + mask_norm = torch.zeros_like(pred_norm, dtype=torch.int32) - expected_fpr_norm = np.array([1.0, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6], dtype=np.float64) - expected_tpr_norm = np.full((7,), np.nan, dtype=np.float64) + expected_fpr_norm = torch.tensor([1.0, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6], dtype=torch.float64) + expected_tpr_norm = torch.full((7,), torch.nan, dtype=torch.float64) # --- anomalous --- - pred_anom1 = pred_norm.copy() - mask_anom1 = np.ones_like(pred_anom1, dtype=np.int32) - expected_tpr_anom1 = expected_fpr_norm.copy() + pred_anom1 = pred_norm.clone() + mask_anom1 = torch.ones_like(pred_anom1, dtype=torch.int32) + expected_tpr_anom1 = expected_fpr_norm.clone() # only the first 100_000 pixels are anomalous # which corresponds to the first 100_000 highest scores (2 to 7) - pred_anom2 = pred_norm.copy() - mask_anom2 = np.concatenate([np.ones(100_000), np.zeros(900_000)]).reshape(shape).astype(np.int32) + pred_anom2 = pred_norm.clone() + mask_anom2 = torch.concatenate([torch.ones(100_000), torch.zeros(900_000)]).reshape(shape).to(torch.int32) expected_tpr_anom2 = (10 * expected_fpr_norm).clip(0, 1) - anomaly_maps = np.stack([pred_norm, pred_anom1, pred_anom2], axis=0) - masks = np.stack([mask_norm, mask_anom1, mask_anom2], axis=0) + anomaly_maps = torch.stack([pred_norm, pred_anom1, pred_anom2], axis=0) + masks = torch.stack([mask_norm, mask_anom1, mask_anom2], axis=0) expected_shared_fpr = expected_fpr_norm - expected_per_image_tprs = np.stack([expected_tpr_norm, expected_tpr_anom1, expected_tpr_anom2], axis=0) - expected_image_classes = np.array([0, 1, 1], dtype=np.int32) - - if ( - metafunc.function is test_pimo_numpy - or metafunc.function is test_pimo - or metafunc.function is test_aupimo_values_numpy - or metafunc.function is test_aupimo_values - ): - argvalues_arrays = [ + expected_per_image_tprs = torch.stack([expected_tpr_norm, expected_tpr_anom1, expected_tpr_anom2], axis=0) + expected_image_classes = torch.tensor([0, 1, 1], dtype=torch.int32) + + if metafunc.function is test_pimo or metafunc.function is test_aupimo_values: + argvalues_tensors = [ ( anomaly_maps, masks, @@ -86,11 +81,6 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: expected_image_classes, ), ] - argvalues_tensors = [ - tuple(torch.from_numpy(arg) if isinstance(arg, ndarray) else arg for arg in arvals) - for arvals in argvalues_arrays - ] - argvalues = argvalues_arrays if "numpy" in metafunc.function.__name__ else argvalues_tensors metafunc.parametrize( argnames=( "anomaly_maps", @@ -100,58 +90,53 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: "expected_per_image_tprs", "expected_image_classes", ), - argvalues=argvalues, + argvalues=argvalues_tensors, ) - if metafunc.function is test_aupimo_values_numpy or metafunc.function is test_aupimo_values: - argvalues_arrays = [ + if metafunc.function is test_aupimo_values: + argvalues_tensors = [ ( (1e-1, 1.0), - np.array( + torch.tensor( [ - np.nan, + torch.nan, # recall: trapezium area = (a + b) * h / 2 (0.10 + 1.0) * 1 / 2, (1.0 + 1.0) * 1 / 2, ], - dtype=np.float64, + dtype=torch.float64, ), ), ( (1e-3, 1e-1), - np.array( + torch.tensor( [ - np.nan, + torch.nan, # average of two trapezium areas / 2 (normalizing factor) (((1e-3 + 1e-2) * 1 / 2) + ((1e-2 + 1e-1) * 1 / 2)) / 2, (((1e-2 + 1e-1) * 1 / 2) + ((1e-1 + 1.0) * 1 / 2)) / 2, ], - dtype=np.float64, + dtype=torch.float64, ), ), ( (1e-5, 1e-4), - np.array( + torch.tensor( [ - np.nan, + torch.nan, (1e-5 + 1e-4) * 1 / 2, (1e-4 + 1e-3) * 1 / 2, ], - dtype=np.float64, + dtype=torch.float64, ), ), ] - argvalues_tensors = [ - tuple(torch.from_numpy(arg) if isinstance(arg, ndarray) else arg for arg in arvals) - for arvals in argvalues_arrays - ] - argvalues = argvalues_arrays if "numpy" in metafunc.function.__name__ else argvalues_tensors metafunc.parametrize( argnames=( "fpr_bounds", "expected_aupimos", # trapezoid surfaces ), - argvalues=argvalues, + argvalues=argvalues_tensors, ) if metafunc.function is test_aupimo_edge: @@ -183,39 +168,24 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: def _do_test_pimo_outputs( - threshs: ndarray | Tensor, - shared_fpr: ndarray | Tensor, - per_image_tprs: ndarray | Tensor, - image_classes: ndarray | Tensor, - expected_threshs: ndarray | Tensor, - expected_shared_fpr: ndarray | Tensor, - expected_per_image_tprs: ndarray | Tensor, - expected_image_classes: ndarray | Tensor, + threshs: Tensor, + shared_fpr: Tensor, + per_image_tprs: Tensor, + image_classes: Tensor, + expected_threshs: Tensor, + expected_shared_fpr: Tensor, + expected_per_image_tprs: Tensor, + expected_image_classes: Tensor, ) -> None: """Test if the outputs of any of the PIMO interfaces are correct.""" - if isinstance(threshs, Tensor): - assert isinstance(shared_fpr, Tensor) - assert isinstance(per_image_tprs, Tensor) - assert isinstance(image_classes, Tensor) - assert isinstance(expected_threshs, Tensor) - assert isinstance(expected_shared_fpr, Tensor) - assert isinstance(expected_per_image_tprs, Tensor) - assert isinstance(expected_image_classes, Tensor) - allclose = torch.allclose - - elif isinstance(threshs, ndarray): - assert isinstance(shared_fpr, ndarray) - assert isinstance(per_image_tprs, ndarray) - assert isinstance(image_classes, ndarray) - assert isinstance(expected_threshs, ndarray) - assert isinstance(expected_shared_fpr, ndarray) - assert isinstance(expected_per_image_tprs, ndarray) - assert isinstance(expected_image_classes, ndarray) - allclose = np.allclose - - else: - msg = "Expected `threshs` to be a Tensor or ndarray." - raise TypeError(msg) + assert isinstance(shared_fpr, Tensor) + assert isinstance(per_image_tprs, Tensor) + assert isinstance(image_classes, Tensor) + assert isinstance(expected_threshs, Tensor) + assert isinstance(expected_shared_fpr, Tensor) + assert isinstance(expected_per_image_tprs, Tensor) + assert isinstance(expected_image_classes, Tensor) + allclose = torch.allclose assert threshs.ndim == 1 assert shared_fpr.ndim == 1 @@ -228,32 +198,6 @@ def _do_test_pimo_outputs( assert (image_classes == expected_image_classes).all() -def test_pimo_numpy( - anomaly_maps: ndarray, - masks: ndarray, - expected_threshs: ndarray, - expected_shared_fpr: ndarray, - expected_per_image_tprs: ndarray, - expected_image_classes: ndarray, -) -> None: - """Test if `pimo()` returns the expected values.""" - threshs, shared_fpr, per_image_tprs, image_classes = pimo_numpy.pimo_curves( - anomaly_maps, - masks, - num_threshs=7, - ) - _do_test_pimo_outputs( - threshs, - shared_fpr, - per_image_tprs, - image_classes, - expected_threshs, - expected_shared_fpr, - expected_per_image_tprs, - expected_image_classes, - ) - - def test_pimo( anomaly_maps: Tensor, masks: Tensor, @@ -298,16 +242,16 @@ def do_assertions(pimoresult: PIMOResult) -> None: def _do_test_aupimo_outputs( - threshs: ndarray | Tensor, - shared_fpr: ndarray | Tensor, - per_image_tprs: ndarray | Tensor, - image_classes: ndarray | Tensor, - aupimos: ndarray | Tensor, - expected_threshs: ndarray | Tensor, - expected_shared_fpr: ndarray | Tensor, - expected_per_image_tprs: ndarray | Tensor, - expected_image_classes: ndarray | Tensor, - expected_aupimos: ndarray | Tensor, + threshs: Tensor, + shared_fpr: Tensor, + per_image_tprs: Tensor, + image_classes: Tensor, + aupimos: Tensor, + expected_threshs: Tensor, + expected_shared_fpr: Tensor, + expected_per_image_tprs: Tensor, + expected_image_classes: Tensor, + expected_aupimos: Tensor, ) -> None: _do_test_pimo_outputs( threshs, @@ -319,60 +263,22 @@ def _do_test_aupimo_outputs( expected_per_image_tprs, expected_image_classes, ) - if isinstance(threshs, Tensor): - assert isinstance(aupimos, Tensor) - assert isinstance(expected_aupimos, Tensor) - allclose = torch.allclose - - elif isinstance(threshs, ndarray): - assert isinstance(aupimos, ndarray) - assert isinstance(expected_aupimos, ndarray) - allclose = np.allclose + assert isinstance(aupimos, Tensor) + assert isinstance(expected_aupimos, Tensor) + allclose = torch.allclose assert tuple(aupimos.shape) == (3,) assert allclose(aupimos, expected_aupimos, equal_nan=True) -def test_aupimo_values_numpy( - anomaly_maps: ndarray, - masks: ndarray, - fpr_bounds: tuple[float, float], - expected_threshs: ndarray, - expected_shared_fpr: ndarray, - expected_per_image_tprs: ndarray, - expected_image_classes: ndarray, - expected_aupimos: ndarray, -) -> None: - """Test if `aupimo()` returns the expected values.""" - threshs, shared_fpr, per_image_tprs, image_classes, aupimos, _ = pimo_numpy.aupimo_scores( - anomaly_maps, - masks, - num_threshs=7, - fpr_bounds=fpr_bounds, - force=True, - ) - _do_test_aupimo_outputs( - threshs, - shared_fpr, - per_image_tprs, - image_classes, - aupimos, - expected_threshs, - expected_shared_fpr, - expected_per_image_tprs, - expected_image_classes, - expected_aupimos, - ) - - def test_aupimo_values( - anomaly_maps: ndarray, - masks: ndarray, + anomaly_maps: torch.Tensor, + masks: torch.Tensor, fpr_bounds: tuple[float, float], - expected_threshs: ndarray, - expected_shared_fpr: ndarray, - expected_per_image_tprs: ndarray, - expected_image_classes: ndarray, - expected_aupimos: ndarray, + expected_threshs: torch.Tensor, + expected_shared_fpr: torch.Tensor, + expected_per_image_tprs: torch.Tensor, + expected_image_classes: torch.Tensor, + expected_aupimos: torch.Tensor, ) -> None: """Test if `aupimo()` returns the expected values.""" @@ -441,9 +347,10 @@ def do_assertions(pimoresult: PIMOResult, aupimoresult: AUPIMOResult) -> None: def test_aupimo_edge( - anomaly_maps: ndarray, - masks: ndarray, + anomaly_maps: torch.Tensor, + masks: torch.Tensor, fpr_bounds: tuple[float, float], + caplog: pytest.LogCaptureFixture, ) -> None: """Test some edge cases.""" # None is the case of testing the default bounds @@ -452,7 +359,7 @@ def test_aupimo_edge( # not enough points on the curve # 10 threshs / 6 decades = 1.6 threshs per decade < 3 with pytest.raises(RuntimeError): # force=False --> raise error - pimo_numpy.aupimo_scores( + functional.aupimo_scores( anomaly_maps, masks, num_threshs=10, @@ -460,19 +367,20 @@ def test_aupimo_edge( **fpr_bounds, ) - with pytest.warns(RuntimeWarning): # force=True --> warn - pimo_numpy.aupimo_scores( + with caplog.at_level(logging.WARNING): # force=True --> warn + functional.aupimo_scores( anomaly_maps, masks, num_threshs=10, force=True, **fpr_bounds, ) + assert "Computation was forced!" in caplog.text # default number of points on the curve (300k threshs) should be enough - rng = np.random.default_rng(42) - pimo_numpy.aupimo_scores( - anomaly_maps * rng.uniform(1.0, 1.1, size=anomaly_maps.shape), + torch.manual_seed(42) + functional.aupimo_scores( + anomaly_maps * torch.FloatTensor(anomaly_maps.shape).uniform_(1.0, 1.1), masks, force=False, **fpr_bounds,