Skip to content

Commit 4e117cf

Browse files
PIMO: Port Numpy → Torch (#2316)
* remove numba Signed-off-by: Ashwin Vaidya <[email protected]> * fix pre-commit checks Signed-off-by: Ashwin Vaidya <[email protected]> * remove all unused methods Signed-off-by: Ashwin Vaidya <[email protected]> * replace numpy with torch Signed-off-by: Ashwin Vaidya <[email protected]> --------- Signed-off-by: Ashwin Vaidya <[email protected]>
1 parent 1e29cf4 commit 4e117cf

File tree

10 files changed

+643
-1063
lines changed

10 files changed

+643
-1063
lines changed

src/anomalib/metrics/per_image/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@
77
# Copyright (C) 2024 Intel Corporation
88
# SPDX-License-Identifier: Apache-2.0
99

10-
from .binclf_curve_numpy import BinclfThreshsChoice
10+
from .binclf_curve import BinclfThreshsChoice
11+
from .enums import StatsOutliersPolicy, StatsRepeatedPolicy
1112
from .pimo import AUPIMO, PIMO, AUPIMOResult, PIMOResult, aupimo_scores, pimo_curves
1213
from .utils import (
1314
compare_models_pairwise_ttest_rel,
1415
compare_models_pairwise_wilcoxon,
1516
format_pairwise_tests_results,
1617
per_image_scores_stats,
1718
)
18-
from .utils_numpy import StatsOutliersPolicy, StatsRepeatedPolicy
1919

2020
__all__ = [
2121
# constants

src/anomalib/metrics/per_image/_validate.py

Lines changed: 67 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
"""Utils for validating arguments and results.
22
3-
`torch` is imported in the functions that use it, so this module can be used in numpy-standalone mode.
4-
53
TODO(jpcbertoldo): Move validations to a common place and reuse them across the codebase.
64
https://github.com/openvinotoolkit/anomalib/issues/2093
75
"""
@@ -13,21 +11,8 @@
1311
# Copyright (C) 2024 Intel Corporation
1412
# SPDX-License-Identifier: Apache-2.0
1513

16-
from typing import Any
17-
18-
import numpy as np
19-
from numpy import ndarray
20-
21-
22-
def is_tensor(tensor: Any, argname: str | None = None) -> None: # noqa: ANN401
23-
"""Validate that `tensor` is a `torch.Tensor`."""
24-
from torch import Tensor
25-
26-
argname = f"'{argname}'" if argname is not None else "argument"
27-
28-
if not isinstance(tensor, Tensor):
29-
msg = f"Expected {argname} to be a tensor, but got {type(tensor)}"
30-
raise TypeError(msg)
14+
import torch
15+
from torch import Tensor
3116

3217

3318
def is_num_threshs_gte2(num_threshs: int) -> None:
@@ -98,22 +83,22 @@ def is_rate_range(bounds: tuple[float, float]) -> None:
9883
raise ValueError(msg)
9984

10085

101-
def is_threshs(threshs: ndarray) -> None:
86+
def is_threshs(threshs: Tensor) -> None:
10287
"""Validate that the thresholds are valid and monotonically increasing."""
103-
if not isinstance(threshs, ndarray):
104-
msg = f"Expected thresholds to be an ndarray, but got {type(threshs)}"
88+
if not isinstance(threshs, Tensor):
89+
msg = f"Expected thresholds to be an Tensor, but got {type(threshs)}"
10590
raise TypeError(msg)
10691

10792
if threshs.ndim != 1:
10893
msg = f"Expected thresholds to be 1D, but got {threshs.ndim}"
10994
raise ValueError(msg)
11095

111-
if threshs.dtype.kind != "f":
112-
msg = f"Expected thresholds to be of float type, but got ndarray with dtype {threshs.dtype}"
96+
if not threshs.dtype.is_floating_point:
97+
msg = f"Expected thresholds to be of float type, but got Tensor with dtype {threshs.dtype}"
11398
raise TypeError(msg)
11499

115100
# make sure they are strictly increasing
116-
if not np.all(np.diff(threshs) > 0):
101+
if not torch.all(torch.diff(threshs) > 0):
117102
msg = "Expected thresholds to be strictly increasing, but it is not."
118103
raise ValueError(msg)
119104

@@ -142,55 +127,55 @@ def is_thresh_bounds(thresh_bounds: tuple[float, float]) -> None:
142127
raise ValueError(msg)
143128

144129

145-
def is_anomaly_maps(anomaly_maps: ndarray) -> None:
146-
if not isinstance(anomaly_maps, ndarray):
147-
msg = f"Expected anomaly maps to be an ndarray, but got {type(anomaly_maps)}"
130+
def is_anomaly_maps(anomaly_maps: Tensor) -> None:
131+
if not isinstance(anomaly_maps, Tensor):
132+
msg = f"Expected anomaly maps to be an Tensor, but got {type(anomaly_maps)}"
148133
raise TypeError(msg)
149134

150135
if anomaly_maps.ndim != 3:
151136
msg = f"Expected anomaly maps have 3 dimensions (N, H, W), but got {anomaly_maps.ndim} dimensions"
152137
raise ValueError(msg)
153138

154-
if anomaly_maps.dtype.kind != "f":
139+
if not anomaly_maps.dtype.is_floating_point:
155140
msg = (
156-
"Expected anomaly maps to be an floating ndarray with anomaly scores,"
157-
f" but got ndarray with dtype {anomaly_maps.dtype}"
141+
"Expected anomaly maps to be an floating Tensor with anomaly scores,"
142+
f" but got Tensor with dtype {anomaly_maps.dtype}"
158143
)
159144
raise TypeError(msg)
160145

161146

162-
def is_masks(masks: ndarray) -> None:
163-
if not isinstance(masks, ndarray):
164-
msg = f"Expected masks to be an ndarray, but got {type(masks)}"
147+
def is_masks(masks: Tensor) -> None:
148+
if not isinstance(masks, Tensor):
149+
msg = f"Expected masks to be an Tensor, but got {type(masks)}"
165150
raise TypeError(msg)
166151

167152
if masks.ndim != 3:
168153
msg = f"Expected masks have 3 dimensions (N, H, W), but got {masks.ndim} dimensions"
169154
raise ValueError(msg)
170155

171-
if masks.dtype.kind == "b":
156+
if masks.dtype == torch.bool:
172157
pass
173-
174-
elif masks.dtype.kind in {"i", "u"}:
175-
masks_unique_vals = np.unique(masks)
176-
if np.any((masks_unique_vals != 0) & (masks_unique_vals != 1)):
177-
msg = (
178-
"Expected masks to be a *binary* ndarray with ground truth labels, "
179-
f"but got ndarray with unique values {sorted(masks_unique_vals)}"
180-
)
181-
raise ValueError(msg)
182-
183-
else:
158+
elif masks.dtype.is_floating_point:
184159
msg = (
185-
"Expected masks to be an integer or boolean ndarray with ground truth labels, "
186-
f"but got ndarray with dtype {masks.dtype}"
160+
"Expected masks to be an integer or boolean Tensor with ground truth labels, "
161+
f"but got Tensor with dtype {masks.dtype}"
187162
)
188163
raise TypeError(msg)
164+
else:
165+
# assumes the type to be (signed or unsigned) integer
166+
# this will change with the dataclass refactor
167+
masks_unique_vals = torch.unique(masks)
168+
if torch.any((masks_unique_vals != 0) & (masks_unique_vals != 1)):
169+
msg = (
170+
"Expected masks to be a *binary* Tensor with ground truth labels, "
171+
f"but got Tensor with unique values {sorted(masks_unique_vals)}"
172+
)
173+
raise ValueError(msg)
189174

190175

191-
def is_binclf_curves(binclf_curves: ndarray, valid_threshs: ndarray | None) -> None:
192-
if not isinstance(binclf_curves, ndarray):
193-
msg = f"Expected binclf curves to be an ndarray, but got {type(binclf_curves)}"
176+
def is_binclf_curves(binclf_curves: Tensor, valid_threshs: Tensor | None) -> None:
177+
if not isinstance(binclf_curves, Tensor):
178+
msg = f"Expected binclf curves to be an Tensor, but got {type(binclf_curves)}"
194179
raise TypeError(msg)
195180

196181
if binclf_curves.ndim != 4:
@@ -201,7 +186,7 @@ def is_binclf_curves(binclf_curves: ndarray, valid_threshs: ndarray | None) -> N
201186
msg = f"Expected binclf curves to have shape (..., 2, 2), but got {binclf_curves.shape}"
202187
raise ValueError(msg)
203188

204-
if binclf_curves.dtype != np.int64:
189+
if binclf_curves.dtype != torch.int64:
205190
msg = f"Expected binclf curves to have dtype int64, but got {binclf_curves.dtype}."
206191
raise TypeError(msg)
207192

@@ -232,47 +217,49 @@ def is_binclf_curves(binclf_curves: ndarray, valid_threshs: ndarray | None) -> N
232217
raise RuntimeError(msg)
233218

234219

235-
def is_images_classes(images_classes: ndarray) -> None:
236-
if not isinstance(images_classes, ndarray):
237-
msg = f"Expected image classes to be an ndarray, but got {type(images_classes)}."
220+
def is_images_classes(images_classes: Tensor) -> None:
221+
if not isinstance(images_classes, Tensor):
222+
msg = f"Expected image classes to be an Tensor, but got {type(images_classes)}."
238223
raise TypeError(msg)
239224

240225
if images_classes.ndim != 1:
241226
msg = f"Expected image classes to be 1D, but got {images_classes.ndim}D."
242227
raise ValueError(msg)
243228

244-
if images_classes.dtype.kind == "b":
229+
if images_classes.dtype == torch.bool:
245230
pass
246-
elif images_classes.dtype.kind in {"i", "u"}:
247-
unique_vals = np.unique(images_classes)
248-
if np.any((unique_vals != 0) & (unique_vals != 1)):
249-
msg = (
250-
"Expected image classes to be a *binary* ndarray with ground truth labels, "
251-
f"but got ndarray with unique values {sorted(unique_vals)}"
252-
)
253-
raise ValueError(msg)
254-
else:
231+
elif images_classes.dtype.is_floating_point:
255232
msg = (
256-
"Expected image classes to be an integer or boolean ndarray with ground truth labels, "
257-
f"but got ndarray with dtype {images_classes.dtype}"
233+
"Expected image classes to be an integer or boolean Tensor with ground truth labels, "
234+
f"but got Tensor with dtype {images_classes.dtype}"
258235
)
259236
raise TypeError(msg)
237+
else:
238+
# assumes the type to be (signed or unsigned) integer
239+
# this will change with the dataclass refactor
240+
unique_vals = torch.unique(images_classes)
241+
if torch.any((unique_vals != 0) & (unique_vals != 1)):
242+
msg = (
243+
"Expected image classes to be a *binary* Tensor with ground truth labels, "
244+
f"but got Tensor with unique values {sorted(unique_vals)}"
245+
)
246+
raise ValueError(msg)
260247

261248

262-
def is_rates(rates: ndarray, nan_allowed: bool) -> None:
263-
if not isinstance(rates, ndarray):
264-
msg = f"Expected rates to be an ndarray, but got {type(rates)}."
249+
def is_rates(rates: Tensor, nan_allowed: bool) -> None:
250+
if not isinstance(rates, Tensor):
251+
msg = f"Expected rates to be an Tensor, but got {type(rates)}."
265252
raise TypeError(msg)
266253

267254
if rates.ndim != 1:
268255
msg = f"Expected rates to be 1D, but got {rates.ndim}D."
269256
raise ValueError(msg)
270257

271-
if rates.dtype.kind != "f":
258+
if not rates.dtype.is_floating_point:
272259
msg = f"Expected rates to have dtype of float type, but got {rates.dtype}."
273260
raise ValueError(msg)
274261

275-
isnan_mask = np.isnan(rates)
262+
isnan_mask = torch.isnan(rates)
276263
if nan_allowed:
277264
# if they are all nan, then there is nothing to validate
278265
if isnan_mask.all():
@@ -293,11 +280,11 @@ def is_rates(rates: ndarray, nan_allowed: bool) -> None:
293280
raise ValueError(msg)
294281

295282

296-
def is_rate_curve(rate_curve: ndarray, nan_allowed: bool, decreasing: bool) -> None:
283+
def is_rate_curve(rate_curve: Tensor, nan_allowed: bool, decreasing: bool) -> None:
297284
is_rates(rate_curve, nan_allowed=nan_allowed)
298285

299-
diffs = np.diff(rate_curve)
300-
diffs_valid = diffs[~np.isnan(diffs)] if nan_allowed else diffs
286+
diffs = torch.diff(rate_curve)
287+
diffs_valid = diffs[~torch.isnan(diffs)] if nan_allowed else diffs
301288

302289
if decreasing and (diffs_valid > 0).any():
303290
msg = "Expected rate curve to be monotonically decreasing, but got non-monotonically decreasing values."
@@ -308,20 +295,20 @@ def is_rate_curve(rate_curve: ndarray, nan_allowed: bool, decreasing: bool) -> N
308295
raise ValueError(msg)
309296

310297

311-
def is_per_image_rate_curves(rate_curves: ndarray, nan_allowed: bool, decreasing: bool | None) -> None:
312-
if not isinstance(rate_curves, ndarray):
313-
msg = f"Expected per-image rate curves to be an ndarray, but got {type(rate_curves)}."
298+
def is_per_image_rate_curves(rate_curves: Tensor, nan_allowed: bool, decreasing: bool | None) -> None:
299+
if not isinstance(rate_curves, Tensor):
300+
msg = f"Expected per-image rate curves to be an Tensor, but got {type(rate_curves)}."
314301
raise TypeError(msg)
315302

316303
if rate_curves.ndim != 2:
317304
msg = f"Expected per-image rate curves to be 2D, but got {rate_curves.ndim}D."
318305
raise ValueError(msg)
319306

320-
if rate_curves.dtype.kind != "f":
307+
if not rate_curves.dtype.is_floating_point:
321308
msg = f"Expected per-image rate curves to have dtype of float type, but got {rate_curves.dtype}."
322309
raise ValueError(msg)
323310

324-
isnan_mask = np.isnan(rate_curves)
311+
isnan_mask = torch.isnan(rate_curves)
325312
if nan_allowed:
326313
# if they are all nan, then there is nothing to validate
327314
if isnan_mask.all():
@@ -344,8 +331,8 @@ def is_per_image_rate_curves(rate_curves: ndarray, nan_allowed: bool, decreasing
344331
if decreasing is None:
345332
return
346333

347-
diffs = np.diff(rate_curves, axis=1)
348-
diffs_valid = diffs[~np.isnan(diffs)] if nan_allowed else diffs
334+
diffs = torch.diff(rate_curves, axis=1)
335+
diffs_valid = diffs[~torch.isnan(diffs)] if nan_allowed else diffs
349336

350337
if decreasing and (diffs_valid > 0).any():
351338
msg = (

0 commit comments

Comments
 (0)