1
1
"""Utils for validating arguments and results.
2
2
3
- `torch` is imported in the functions that use it, so this module can be used in numpy-standalone mode.
4
-
5
3
TODO(jpcbertoldo): Move validations to a common place and reuse them across the codebase.
6
4
https://github.com/openvinotoolkit/anomalib/issues/2093
7
5
"""
13
11
# Copyright (C) 2024 Intel Corporation
14
12
# SPDX-License-Identifier: Apache-2.0
15
13
16
- from typing import Any
17
-
18
- import numpy as np
19
- from numpy import ndarray
20
-
21
-
22
- def is_tensor (tensor : Any , argname : str | None = None ) -> None : # noqa: ANN401
23
- """Validate that `tensor` is a `torch.Tensor`."""
24
- from torch import Tensor
25
-
26
- argname = f"'{ argname } '" if argname is not None else "argument"
27
-
28
- if not isinstance (tensor , Tensor ):
29
- msg = f"Expected { argname } to be a tensor, but got { type (tensor )} "
30
- raise TypeError (msg )
14
+ import torch
15
+ from torch import Tensor
31
16
32
17
33
18
def is_num_threshs_gte2 (num_threshs : int ) -> None :
@@ -98,22 +83,22 @@ def is_rate_range(bounds: tuple[float, float]) -> None:
98
83
raise ValueError (msg )
99
84
100
85
101
- def is_threshs (threshs : ndarray ) -> None :
86
+ def is_threshs (threshs : Tensor ) -> None :
102
87
"""Validate that the thresholds are valid and monotonically increasing."""
103
- if not isinstance (threshs , ndarray ):
104
- msg = f"Expected thresholds to be an ndarray , but got { type (threshs )} "
88
+ if not isinstance (threshs , Tensor ):
89
+ msg = f"Expected thresholds to be an Tensor , but got { type (threshs )} "
105
90
raise TypeError (msg )
106
91
107
92
if threshs .ndim != 1 :
108
93
msg = f"Expected thresholds to be 1D, but got { threshs .ndim } "
109
94
raise ValueError (msg )
110
95
111
- if threshs .dtype .kind != "f" :
112
- msg = f"Expected thresholds to be of float type, but got ndarray with dtype { threshs .dtype } "
96
+ if not threshs .dtype .is_floating_point :
97
+ msg = f"Expected thresholds to be of float type, but got Tensor with dtype { threshs .dtype } "
113
98
raise TypeError (msg )
114
99
115
100
# make sure they are strictly increasing
116
- if not np .all (np .diff (threshs ) > 0 ):
101
+ if not torch .all (torch .diff (threshs ) > 0 ):
117
102
msg = "Expected thresholds to be strictly increasing, but it is not."
118
103
raise ValueError (msg )
119
104
@@ -142,55 +127,55 @@ def is_thresh_bounds(thresh_bounds: tuple[float, float]) -> None:
142
127
raise ValueError (msg )
143
128
144
129
145
- def is_anomaly_maps (anomaly_maps : ndarray ) -> None :
146
- if not isinstance (anomaly_maps , ndarray ):
147
- msg = f"Expected anomaly maps to be an ndarray , but got { type (anomaly_maps )} "
130
+ def is_anomaly_maps (anomaly_maps : Tensor ) -> None :
131
+ if not isinstance (anomaly_maps , Tensor ):
132
+ msg = f"Expected anomaly maps to be an Tensor , but got { type (anomaly_maps )} "
148
133
raise TypeError (msg )
149
134
150
135
if anomaly_maps .ndim != 3 :
151
136
msg = f"Expected anomaly maps have 3 dimensions (N, H, W), but got { anomaly_maps .ndim } dimensions"
152
137
raise ValueError (msg )
153
138
154
- if anomaly_maps .dtype .kind != "f" :
139
+ if not anomaly_maps .dtype .is_floating_point :
155
140
msg = (
156
- "Expected anomaly maps to be an floating ndarray with anomaly scores,"
157
- f" but got ndarray with dtype { anomaly_maps .dtype } "
141
+ "Expected anomaly maps to be an floating Tensor with anomaly scores,"
142
+ f" but got Tensor with dtype { anomaly_maps .dtype } "
158
143
)
159
144
raise TypeError (msg )
160
145
161
146
162
- def is_masks (masks : ndarray ) -> None :
163
- if not isinstance (masks , ndarray ):
164
- msg = f"Expected masks to be an ndarray , but got { type (masks )} "
147
+ def is_masks (masks : Tensor ) -> None :
148
+ if not isinstance (masks , Tensor ):
149
+ msg = f"Expected masks to be an Tensor , but got { type (masks )} "
165
150
raise TypeError (msg )
166
151
167
152
if masks .ndim != 3 :
168
153
msg = f"Expected masks have 3 dimensions (N, H, W), but got { masks .ndim } dimensions"
169
154
raise ValueError (msg )
170
155
171
- if masks .dtype . kind == "b" :
156
+ if masks .dtype == torch . bool :
172
157
pass
173
-
174
- elif masks .dtype .kind in {"i" , "u" }:
175
- masks_unique_vals = np .unique (masks )
176
- if np .any ((masks_unique_vals != 0 ) & (masks_unique_vals != 1 )):
177
- msg = (
178
- "Expected masks to be a *binary* ndarray with ground truth labels, "
179
- f"but got ndarray with unique values { sorted (masks_unique_vals )} "
180
- )
181
- raise ValueError (msg )
182
-
183
- else :
158
+ elif masks .dtype .is_floating_point :
184
159
msg = (
185
- "Expected masks to be an integer or boolean ndarray with ground truth labels, "
186
- f"but got ndarray with dtype { masks .dtype } "
160
+ "Expected masks to be an integer or boolean Tensor with ground truth labels, "
161
+ f"but got Tensor with dtype { masks .dtype } "
187
162
)
188
163
raise TypeError (msg )
164
+ else :
165
+ # assumes the type to be (signed or unsigned) integer
166
+ # this will change with the dataclass refactor
167
+ masks_unique_vals = torch .unique (masks )
168
+ if torch .any ((masks_unique_vals != 0 ) & (masks_unique_vals != 1 )):
169
+ msg = (
170
+ "Expected masks to be a *binary* Tensor with ground truth labels, "
171
+ f"but got Tensor with unique values { sorted (masks_unique_vals )} "
172
+ )
173
+ raise ValueError (msg )
189
174
190
175
191
- def is_binclf_curves (binclf_curves : ndarray , valid_threshs : ndarray | None ) -> None :
192
- if not isinstance (binclf_curves , ndarray ):
193
- msg = f"Expected binclf curves to be an ndarray , but got { type (binclf_curves )} "
176
+ def is_binclf_curves (binclf_curves : Tensor , valid_threshs : Tensor | None ) -> None :
177
+ if not isinstance (binclf_curves , Tensor ):
178
+ msg = f"Expected binclf curves to be an Tensor , but got { type (binclf_curves )} "
194
179
raise TypeError (msg )
195
180
196
181
if binclf_curves .ndim != 4 :
@@ -201,7 +186,7 @@ def is_binclf_curves(binclf_curves: ndarray, valid_threshs: ndarray | None) -> N
201
186
msg = f"Expected binclf curves to have shape (..., 2, 2), but got { binclf_curves .shape } "
202
187
raise ValueError (msg )
203
188
204
- if binclf_curves .dtype != np .int64 :
189
+ if binclf_curves .dtype != torch .int64 :
205
190
msg = f"Expected binclf curves to have dtype int64, but got { binclf_curves .dtype } ."
206
191
raise TypeError (msg )
207
192
@@ -232,47 +217,49 @@ def is_binclf_curves(binclf_curves: ndarray, valid_threshs: ndarray | None) -> N
232
217
raise RuntimeError (msg )
233
218
234
219
235
- def is_images_classes (images_classes : ndarray ) -> None :
236
- if not isinstance (images_classes , ndarray ):
237
- msg = f"Expected image classes to be an ndarray , but got { type (images_classes )} ."
220
+ def is_images_classes (images_classes : Tensor ) -> None :
221
+ if not isinstance (images_classes , Tensor ):
222
+ msg = f"Expected image classes to be an Tensor , but got { type (images_classes )} ."
238
223
raise TypeError (msg )
239
224
240
225
if images_classes .ndim != 1 :
241
226
msg = f"Expected image classes to be 1D, but got { images_classes .ndim } D."
242
227
raise ValueError (msg )
243
228
244
- if images_classes .dtype . kind == "b" :
229
+ if images_classes .dtype == torch . bool :
245
230
pass
246
- elif images_classes .dtype .kind in {"i" , "u" }:
247
- unique_vals = np .unique (images_classes )
248
- if np .any ((unique_vals != 0 ) & (unique_vals != 1 )):
249
- msg = (
250
- "Expected image classes to be a *binary* ndarray with ground truth labels, "
251
- f"but got ndarray with unique values { sorted (unique_vals )} "
252
- )
253
- raise ValueError (msg )
254
- else :
231
+ elif images_classes .dtype .is_floating_point :
255
232
msg = (
256
- "Expected image classes to be an integer or boolean ndarray with ground truth labels, "
257
- f"but got ndarray with dtype { images_classes .dtype } "
233
+ "Expected image classes to be an integer or boolean Tensor with ground truth labels, "
234
+ f"but got Tensor with dtype { images_classes .dtype } "
258
235
)
259
236
raise TypeError (msg )
237
+ else :
238
+ # assumes the type to be (signed or unsigned) integer
239
+ # this will change with the dataclass refactor
240
+ unique_vals = torch .unique (images_classes )
241
+ if torch .any ((unique_vals != 0 ) & (unique_vals != 1 )):
242
+ msg = (
243
+ "Expected image classes to be a *binary* Tensor with ground truth labels, "
244
+ f"but got Tensor with unique values { sorted (unique_vals )} "
245
+ )
246
+ raise ValueError (msg )
260
247
261
248
262
- def is_rates (rates : ndarray , nan_allowed : bool ) -> None :
263
- if not isinstance (rates , ndarray ):
264
- msg = f"Expected rates to be an ndarray , but got { type (rates )} ."
249
+ def is_rates (rates : Tensor , nan_allowed : bool ) -> None :
250
+ if not isinstance (rates , Tensor ):
251
+ msg = f"Expected rates to be an Tensor , but got { type (rates )} ."
265
252
raise TypeError (msg )
266
253
267
254
if rates .ndim != 1 :
268
255
msg = f"Expected rates to be 1D, but got { rates .ndim } D."
269
256
raise ValueError (msg )
270
257
271
- if rates .dtype .kind != "f" :
258
+ if not rates .dtype .is_floating_point :
272
259
msg = f"Expected rates to have dtype of float type, but got { rates .dtype } ."
273
260
raise ValueError (msg )
274
261
275
- isnan_mask = np .isnan (rates )
262
+ isnan_mask = torch .isnan (rates )
276
263
if nan_allowed :
277
264
# if they are all nan, then there is nothing to validate
278
265
if isnan_mask .all ():
@@ -293,11 +280,11 @@ def is_rates(rates: ndarray, nan_allowed: bool) -> None:
293
280
raise ValueError (msg )
294
281
295
282
296
- def is_rate_curve (rate_curve : ndarray , nan_allowed : bool , decreasing : bool ) -> None :
283
+ def is_rate_curve (rate_curve : Tensor , nan_allowed : bool , decreasing : bool ) -> None :
297
284
is_rates (rate_curve , nan_allowed = nan_allowed )
298
285
299
- diffs = np .diff (rate_curve )
300
- diffs_valid = diffs [~ np .isnan (diffs )] if nan_allowed else diffs
286
+ diffs = torch .diff (rate_curve )
287
+ diffs_valid = diffs [~ torch .isnan (diffs )] if nan_allowed else diffs
301
288
302
289
if decreasing and (diffs_valid > 0 ).any ():
303
290
msg = "Expected rate curve to be monotonically decreasing, but got non-monotonically decreasing values."
@@ -308,20 +295,20 @@ def is_rate_curve(rate_curve: ndarray, nan_allowed: bool, decreasing: bool) -> N
308
295
raise ValueError (msg )
309
296
310
297
311
- def is_per_image_rate_curves (rate_curves : ndarray , nan_allowed : bool , decreasing : bool | None ) -> None :
312
- if not isinstance (rate_curves , ndarray ):
313
- msg = f"Expected per-image rate curves to be an ndarray , but got { type (rate_curves )} ."
298
+ def is_per_image_rate_curves (rate_curves : Tensor , nan_allowed : bool , decreasing : bool | None ) -> None :
299
+ if not isinstance (rate_curves , Tensor ):
300
+ msg = f"Expected per-image rate curves to be an Tensor , but got { type (rate_curves )} ."
314
301
raise TypeError (msg )
315
302
316
303
if rate_curves .ndim != 2 :
317
304
msg = f"Expected per-image rate curves to be 2D, but got { rate_curves .ndim } D."
318
305
raise ValueError (msg )
319
306
320
- if rate_curves .dtype .kind != "f" :
307
+ if not rate_curves .dtype .is_floating_point :
321
308
msg = f"Expected per-image rate curves to have dtype of float type, but got { rate_curves .dtype } ."
322
309
raise ValueError (msg )
323
310
324
- isnan_mask = np .isnan (rate_curves )
311
+ isnan_mask = torch .isnan (rate_curves )
325
312
if nan_allowed :
326
313
# if they are all nan, then there is nothing to validate
327
314
if isnan_mask .all ():
@@ -344,8 +331,8 @@ def is_per_image_rate_curves(rate_curves: ndarray, nan_allowed: bool, decreasing
344
331
if decreasing is None :
345
332
return
346
333
347
- diffs = np .diff (rate_curves , axis = 1 )
348
- diffs_valid = diffs [~ np .isnan (diffs )] if nan_allowed else diffs
334
+ diffs = torch .diff (rate_curves , axis = 1 )
335
+ diffs_valid = diffs [~ torch .isnan (diffs )] if nan_allowed else diffs
349
336
350
337
if decreasing and (diffs_valid > 0 ).any ():
351
338
msg = (
0 commit comments