Skip to content

Commit 003b37d

Browse files
authored
Add deterministic version of the FBeta class (#167)
1 parent b368220 commit 003b37d

File tree

5 files changed

+228
-109
lines changed

5 files changed

+228
-109
lines changed

CHANGELOG.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
# v1.x.x
1+
# v1.17
22

3-
-
3+
- [`FBeta`](https://poutyne.org/metrics.html#poutyne.FBeta) uses the non-deterministic torch function [`bincount`](https://pytorch.org/docs/stable/generated/torch.bincount.html). You can now make it deterministic, either by passing the argument `make_deterministic` to the [`FBeta`](https://poutyne.org/metrics.html#poutyne.FBeta) class or by using one of the PyTorch functions `torch.set_deterministic_debug_mode` or `torch.use_deterministic_algorithms`. Note that this might make your code slower.
44

55
# v1.16
66

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""
2+
The source code of this file was copied from the torchmetrics project, and has been modified. All modifications
3+
made from the original source code are under the LGPLv3 license.
4+
5+
6+
Copyright (c) 2022 Poutyne and all respective contributors.
7+
8+
Each contributor holds copyright over their respective contributions. The project versioning (Git)
9+
records all such contribution source information on the Poutyne and AllenNLP repository.
10+
11+
This file is part of Poutyne.
12+
13+
Poutyne is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public
14+
License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
15+
version.
16+
17+
Poutyne is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty
18+
of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
19+
20+
You should have received a copy of the GNU Lesser General Public License along with Poutyne. If not, see
21+
<https://www.gnu.org/licenses/>.
22+
23+
24+
Copyright The PyTorch Lightning team.
25+
26+
Licensed under the Apache License, Version 2.0 (the "License");
27+
you may not use this file except in compliance with the License.
28+
You may obtain a copy of the License at
29+
30+
http://www.apache.org/licenses/LICENSE-2.0
31+
32+
Unless required by applicable law or agreed to in writing, software
33+
distributed under the License is distributed on an "AS IS" BASIS,
34+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
35+
See the License for the specific language governing permissions and
36+
limitations under the License.
37+
"""
38+
from typing import Optional
39+
40+
import torch
41+
from torch import Tensor
42+
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _XLA_AVAILABLE
43+
44+
45+
def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
46+
"""PyTorch currently does not support``torch.bincount`` for:
47+
48+
- deterministic mode on GPU.
49+
- MPS devices
50+
51+
This implementation fallback to a for-loop counting occurrences in that case.
52+
53+
Args:
54+
x: tensor to count
55+
minlength: minimum length to count
56+
57+
Returns:
58+
Number of occurrences for each unique element in x
59+
"""
60+
if minlength is None:
61+
minlength = len(torch.unique(x))
62+
if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or _TORCH_GREATER_EQUAL_1_12 and x.is_mps:
63+
output = torch.zeros(minlength, device=x.device, dtype=torch.long)
64+
for i in range(minlength):
65+
output[i] = (x == i).sum()
66+
return output
67+
return torch.bincount(x, minlength=minlength)

poutyne/framework/metrics/predefined/fscores.py

Lines changed: 83 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343

4444
from poutyne.framework.metrics.base import Metric
4545
from poutyne.framework.metrics.metrics_registering import register_metric_class
46+
from poutyne.framework.metrics.predefined.bincount import _bincount
47+
from poutyne.utils import set_deterministic_debug_mode
4648

4749

4850
class FBeta(Metric):
@@ -115,6 +117,8 @@ class FBeta(Metric):
115117
names (Optional[Union[str, List[str]]]): The names associated to the metrics. It is a string when
116118
a single metric is requested. It is a list of 3 strings if all metrics are requested.
117119
(Default value = None)
120+
make_deterministic (Optional[bool]): Avoid non-deterministic operations in computations. This might make the
121+
code slower.
118122
"""
119123

120124
def __init__(
@@ -127,6 +131,7 @@ def __init__(
127131
ignore_index: int = -100,
128132
threshold: float = 0.0,
129133
names: Optional[Union[str, List[str]]] = None,
134+
make_deterministic: Optional[bool] = None,
130135
) -> None:
131136
super().__init__()
132137
self.metric_options = ('fscore', 'precision', 'recall')
@@ -154,6 +159,9 @@ def __init__(
154159
self.ignore_index = ignore_index
155160
self.threshold = threshold
156161
self.__name__ = self._get_names(names)
162+
self.deterministic_debug_mode = (
163+
"error" if make_deterministic is True else "default" if make_deterministic is False else None
164+
)
157165

158166
# statistics
159167
# the total number of true positive instances under each class
@@ -235,80 +243,81 @@ def update(self, y_pred: torch.Tensor, y_true: Union[torch.Tensor, Tuple[torch.T
235243

236244
def _update(
    self, y_pred: torch.Tensor, y_true: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Accumulate the per-class statistics of one batch into the running sums of the metric.

    The whole computation runs under ``set_deterministic_debug_mode(self.deterministic_debug_mode)``.

    Args:
        y_pred (torch.Tensor): Predictions. When its shape differs from the shape of ``y_true`` (after
            squeezing), it is treated as per-class scores with the classes along dimension 1; otherwise it
            is thresholded with ``self.threshold`` to obtain binary predictions.
        y_true (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]): Ground truth labels, or a tuple
            ``(labels, mask)``. Masked-out positions and labels equal to ``self.ignore_index`` are not
            counted.

    Returns:
        A tuple ``(true_positive_sum, pred_sum, true_sum)`` with the per-class counts of this batch only.

    Raises:
        ValueError: If a label is greater than or equal to the number of classes, or if ``average`` is
            ``'binary'`` while there are more than 2 classes.
    """
    # pylint: disable=too-many-branches
    with set_deterministic_debug_mode(self.deterministic_debug_mode):
        if isinstance(y_true, tuple):
            y_true, mask = y_true
            mask = mask.bool()
        else:
            mask = torch.ones_like(y_true).bool()

        if self.ignore_index is not None:
            # Out-of-place on purpose: `mask.bool()` returns the caller's own tensor when it is already
            # boolean, so an in-place update would modify the mask passed in by the caller.
            mask = mask & (y_true != self.ignore_index)

        if y_pred.shape[0] == 1:
            # Keep the batch dimension, which a plain `squeeze()` would remove for a batch of size 1.
            y_pred, y_true, mask = (
                y_pred.squeeze().unsqueeze(0),
                y_true.squeeze().unsqueeze(0),
                mask.squeeze().unsqueeze(0),
            )
        else:
            y_pred, y_true, mask = y_pred.squeeze(), y_true.squeeze(), mask.squeeze()

        num_classes = 2
        if y_pred.shape != y_true.shape:
            num_classes = y_pred.size(1)

        if (y_true >= num_classes).any():
            raise ValueError(
                f"A gold label passed to FBetaMeasure contains an id >= {num_classes}, the number of classes."
            )

        if self._average == 'binary' and num_classes > 2:
            raise ValueError("When `average` is binary, the number of prediction scores must be 2.")

        # It means we call this metric at the first time
        # when `self._true_positive_sum` is None.
        if self._true_positive_sum is None:
            self._true_positive_sum = torch.zeros(num_classes, device=y_pred.device)
            self._true_sum = torch.zeros(num_classes, device=y_pred.device)
            self._pred_sum = torch.zeros(num_classes, device=y_pred.device)
            self._total_sum = torch.zeros(num_classes, device=y_pred.device)

        y_true = y_true.float()

        if y_pred.shape != y_true.shape:
            argmax_y_pred = y_pred.argmax(1).float()
        else:
            argmax_y_pred = (y_pred > self.threshold).float()
        true_positives = (y_true == argmax_y_pred) * mask
        true_positives_bins = y_true[true_positives]

        # Watch it:
        # The total numbers of true positives under all _predicted_ classes are zeros.
        if true_positives_bins.shape[0] == 0:
            true_positive_sum = torch.zeros(num_classes, device=y_pred.device)
        else:
            true_positive_sum = _bincount(true_positives_bins.long(), minlength=num_classes).float()

        pred_bins = argmax_y_pred[mask].long()
        # Watch it:
        # When the `mask` is all 0, we will get an _empty_ tensor.
        if pred_bins.shape[0] != 0:
            pred_sum = _bincount(pred_bins, minlength=num_classes).float()
        else:
            pred_sum = torch.zeros(num_classes, device=y_pred.device)

        y_true_bins = y_true[mask].long()
        # Check the masked labels (not the unmasked `y_true`), consistently with `pred_bins` above.
        if y_true_bins.shape[0] != 0:
            true_sum = _bincount(y_true_bins, minlength=num_classes).float()
        else:
            true_sum = torch.zeros(num_classes, device=y_pred.device)

        self._true_positive_sum += true_positive_sum
        self._pred_sum += pred_sum
        self._true_sum += true_sum
        self._total_sum += mask.sum().to(torch.float)

        return true_positive_sum, pred_sum, true_sum
312321

313322
def compute(self) -> Union[float, Tuple[float]]:
314323
"""

poutyne/utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717
<https://www.gnu.org/licenses/>.
1818
"""
1919

20+
import contextlib
2021
import numbers
2122

2223
# -*- coding: utf-8 -*-
2324
import os
2425
import random
2526
import warnings
26-
from typing import IO, Any, BinaryIO, Union
27+
from typing import IO, Any, BinaryIO, Optional, Union
2728

2829
import numpy as np
2930
import torch
@@ -332,3 +333,15 @@ def is_torch_or_numpy(v):
332333
"tensor or a Numpy array.\n"
333334
)
334335
return 1
336+
337+
338+
@contextlib.contextmanager
def set_deterministic_debug_mode(mode: Optional[Union[str, int]]):
    """
    Context manager that temporarily sets PyTorch's deterministic debug mode.

    The mode that was active on entry is restored on exit, whether the body completes normally or raises.

    Args:
        mode (Optional[Union[str, int]]): Value forwarded to ``torch.set_deterministic_debug_mode``. When
            ``None``, the current mode is left untouched and the context manager does nothing.
    """
    if mode is None:
        yield
        return

    old_mode = torch.get_deterministic_debug_mode()
    torch.set_deterministic_debug_mode(mode)
    try:
        yield
    finally:
        # Without `finally`, an exception raised inside the `with` body would leave the process-wide
        # deterministic debug mode changed for all subsequent code.
        torch.set_deterministic_debug_mode(old_mode)

0 commit comments

Comments
 (0)