43
43
44
44
from poutyne .framework .metrics .base import Metric
45
45
from poutyne .framework .metrics .metrics_registering import register_metric_class
46
+ from poutyne .framework .metrics .predefined .bincount import _bincount
47
+ from poutyne .utils import set_deterministic_debug_mode
46
48
47
49
48
50
class FBeta (Metric ):
@@ -115,6 +117,8 @@ class FBeta(Metric):
115
117
names (Optional[Union[str, List[str]]]): The names associated to the metrics. It is a string when
116
118
a single metric is requested. It is a list of 3 strings if all metrics are requested.
117
119
(Default value = None)
120
+ make_deterministic (Optional[bool]): Avoid non-deterministic operations in computations. This might make the
121
+ code slower.
118
122
"""
119
123
120
124
def __init__ (
@@ -127,6 +131,7 @@ def __init__(
127
131
ignore_index : int = - 100 ,
128
132
threshold : float = 0.0 ,
129
133
names : Optional [Union [str , List [str ]]] = None ,
134
+ make_deterministic : Optional [bool ] = None ,
130
135
) -> None :
131
136
super ().__init__ ()
132
137
self .metric_options = ('fscore' , 'precision' , 'recall' )
@@ -154,6 +159,9 @@ def __init__(
154
159
self .ignore_index = ignore_index
155
160
self .threshold = threshold
156
161
self .__name__ = self ._get_names (names )
162
+ self .deterministic_debug_mode = (
163
+ "error" if make_deterministic is True else "default" if make_deterministic is False else None
164
+ )
157
165
158
166
# statistics
159
167
# the total number of true positive instances under each class
@@ -235,80 +243,81 @@ def update(self, y_pred: torch.Tensor, y_true: Union[torch.Tensor, Tuple[torch.T
235
243
236
244
def _update (self , y_pred : torch .Tensor , y_true : Union [torch .Tensor , Tuple [torch .Tensor , torch .Tensor ]]) -> None :
237
245
# pylint: disable=too-many-branches
238
- if isinstance (y_true , tuple ):
239
- y_true , mask = y_true
240
- mask = mask .bool ()
241
- else :
242
- mask = torch .ones_like (y_true ).bool ()
243
-
244
- if self .ignore_index is not None :
245
- mask *= y_true != self .ignore_index
246
-
247
- if y_pred .shape [0 ] == 1 :
248
- y_pred , y_true , mask = (
249
- y_pred .squeeze ().unsqueeze (0 ),
250
- y_true .squeeze ().unsqueeze (0 ),
251
- mask .squeeze ().unsqueeze (0 ),
252
- )
253
- else :
254
- y_pred , y_true , mask = y_pred .squeeze (), y_true .squeeze (), mask .squeeze ()
255
-
256
- num_classes = 2
257
- if y_pred .shape != y_true .shape :
258
- num_classes = y_pred .size (1 )
259
-
260
- if (y_true >= num_classes ).any ():
261
- raise ValueError (
262
- f"A gold label passed to FBetaMeasure contains an id >= { num_classes } , the number of classes."
263
- )
264
-
265
- if self ._average == 'binary' and num_classes > 2 :
266
- raise ValueError ("When `average` is binary, the number of prediction scores must be 2." )
267
-
268
- # It means we call this metric at the first time
269
- # when `self._true_positive_sum` is None.
270
- if self ._true_positive_sum is None :
271
- self ._true_positive_sum = torch .zeros (num_classes , device = y_pred .device )
272
- self ._true_sum = torch .zeros (num_classes , device = y_pred .device )
273
- self ._pred_sum = torch .zeros (num_classes , device = y_pred .device )
274
- self ._total_sum = torch .zeros (num_classes , device = y_pred .device )
275
-
276
- y_true = y_true .float ()
277
-
278
- if y_pred .shape != y_true .shape :
279
- argmax_y_pred = y_pred .argmax (1 ).float ()
280
- else :
281
- argmax_y_pred = (y_pred > self .threshold ).float ()
282
- true_positives = (y_true == argmax_y_pred ) * mask
283
- true_positives_bins = y_true [true_positives ]
284
-
285
- # Watch it:
286
- # The total numbers of true positives under all _predicted_ classes are zeros.
287
- if true_positives_bins .shape [0 ] == 0 :
288
- true_positive_sum = torch .zeros (num_classes , device = y_pred .device )
289
- else :
290
- true_positive_sum = torch .bincount (true_positives_bins .long (), minlength = num_classes ).float ()
291
-
292
- pred_bins = argmax_y_pred [mask ].long ()
293
- # Watch it:
294
- # When the `mask` is all 0, we will get an _empty_ tensor.
295
- if pred_bins .shape [0 ] != 0 :
296
- pred_sum = torch .bincount (pred_bins , minlength = num_classes ).float ()
297
- else :
298
- pred_sum = torch .zeros (num_classes , device = y_pred .device )
299
-
300
- y_true_bins = y_true [mask ].long ()
301
- if y_true .shape [0 ] != 0 :
302
- true_sum = torch .bincount (y_true_bins , minlength = num_classes ).float ()
303
- else :
304
- true_sum = torch .zeros (num_classes , device = y_pred .device )
305
-
306
- self ._true_positive_sum += true_positive_sum
307
- self ._pred_sum += pred_sum
308
- self ._true_sum += true_sum
309
- self ._total_sum += mask .sum ().to (torch .float )
310
-
311
- return true_positive_sum , pred_sum , true_sum
246
+ with set_deterministic_debug_mode (self .deterministic_debug_mode ):
247
+ if isinstance (y_true , tuple ):
248
+ y_true , mask = y_true
249
+ mask = mask .bool ()
250
+ else :
251
+ mask = torch .ones_like (y_true ).bool ()
252
+
253
+ if self .ignore_index is not None :
254
+ mask *= y_true != self .ignore_index
255
+
256
+ if y_pred .shape [0 ] == 1 :
257
+ y_pred , y_true , mask = (
258
+ y_pred .squeeze ().unsqueeze (0 ),
259
+ y_true .squeeze ().unsqueeze (0 ),
260
+ mask .squeeze ().unsqueeze (0 ),
261
+ )
262
+ else :
263
+ y_pred , y_true , mask = y_pred .squeeze (), y_true .squeeze (), mask .squeeze ()
264
+
265
+ num_classes = 2
266
+ if y_pred .shape != y_true .shape :
267
+ num_classes = y_pred .size (1 )
268
+
269
+ if (y_true >= num_classes ).any ():
270
+ raise ValueError (
271
+ f"A gold label passed to FBetaMeasure contains an id >= { num_classes } , the number of classes."
272
+ )
273
+
274
+ if self ._average == 'binary' and num_classes > 2 :
275
+ raise ValueError ("When `average` is binary, the number of prediction scores must be 2." )
276
+
277
+ # It means we call this metric at the first time
278
+ # when `self._true_positive_sum` is None.
279
+ if self ._true_positive_sum is None :
280
+ self ._true_positive_sum = torch .zeros (num_classes , device = y_pred .device )
281
+ self ._true_sum = torch .zeros (num_classes , device = y_pred .device )
282
+ self ._pred_sum = torch .zeros (num_classes , device = y_pred .device )
283
+ self ._total_sum = torch .zeros (num_classes , device = y_pred .device )
284
+
285
+ y_true = y_true .float ()
286
+
287
+ if y_pred .shape != y_true .shape :
288
+ argmax_y_pred = y_pred .argmax (1 ).float ()
289
+ else :
290
+ argmax_y_pred = (y_pred > self .threshold ).float ()
291
+ true_positives = (y_true == argmax_y_pred ) * mask
292
+ true_positives_bins = y_true [true_positives ]
293
+
294
+ # Watch it:
295
+ # The total numbers of true positives under all _predicted_ classes are zeros.
296
+ if true_positives_bins .shape [0 ] == 0 :
297
+ true_positive_sum = torch .zeros (num_classes , device = y_pred .device )
298
+ else :
299
+ true_positive_sum = _bincount (true_positives_bins .long (), minlength = num_classes ).float ()
300
+
301
+ pred_bins = argmax_y_pred [mask ].long ()
302
+ # Watch it:
303
+ # When the `mask` is all 0, we will get an _empty_ tensor.
304
+ if pred_bins .shape [0 ] != 0 :
305
+ pred_sum = _bincount (pred_bins , minlength = num_classes ).float ()
306
+ else :
307
+ pred_sum = torch .zeros (num_classes , device = y_pred .device )
308
+
309
+ y_true_bins = y_true [mask ].long ()
310
+ if y_true .shape [0 ] != 0 :
311
+ true_sum = _bincount (y_true_bins , minlength = num_classes ).float ()
312
+ else :
313
+ true_sum = torch .zeros (num_classes , device = y_pred .device )
314
+
315
+ self ._true_positive_sum += true_positive_sum
316
+ self ._pred_sum += pred_sum
317
+ self ._true_sum += true_sum
318
+ self ._total_sum += mask .sum ().to (torch .float )
319
+
320
+ return true_positive_sum , pred_sum , true_sum
312
321
313
322
def compute (self ) -> Union [float , Tuple [float ]]:
314
323
"""
0 commit comments