Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 0e33b0b

Browse files
authored
Return consistent types from metrics (#4632)
* Return consistent types from metrics * Changelog * Remove unused import
1 parent 2df364f commit 0e33b0b

File tree

3 files changed

+6
-7
lines changed

3 files changed

+6
-7
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2020
### Fixed
2121

2222
- Ignore *args when constructing classes with `FromParams`.
23+
- Ensured some consistency in the types of the values that metrics return
2324

2425
## [v1.1.0](https://github.com/allenai/allennlp/releases/tag/v1.1.0) - 2020-09-08
2526

allennlp/training/metrics/average.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ def get_metric(self, reset: bool = False):
4949
The average of all values that were passed to `__call__`.
5050
"""
5151

52-
average_value = self._total_value / self._count if self._count > 0 else 0
52+
average_value = self._total_value / self._count if self._count > 0 else 0.0
5353
if reset:
5454
self.reset()
55-
return average_value
55+
return float(average_value)
5656

5757
@overrides
5858
def reset(self):

allennlp/training/metrics/perplexity.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from overrides import overrides
2-
import torch
2+
import math
33

44
from allennlp.training.metrics.average import Average
55
from allennlp.training.metrics.metric import Metric
@@ -26,9 +26,7 @@ def get_metric(self, reset: bool = False):
2626
"""
2727
average_loss = super().get_metric(reset)
2828
if average_loss == 0:
29-
perplexity = 0.0
29+
return 0.0
3030

3131
# Exponentiate the loss to compute perplexity
32-
perplexity = float(torch.exp(average_loss))
33-
34-
return perplexity
32+
return math.exp(average_loss)

0 commit comments

Comments
 (0)