This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 18daa29

JackKuo666 authored and matt-gardner committed
Fix pearson_correlation.py (#3101)
* fix bug: ZeroDivisionError: float division by zero. Since the input may be, for example, a tensor like [[0., 0., 0., 0.], [0., 0., 0., 0.]], math.sqrt(predictions_variance) or math.sqrt(labels_variance) can be zero, so a check is added to prevent the denominator from being zero. If it is zero, the denominator is assigned a value of 1.
* fix bug: ZeroDivisionError: float division by zero. Same check as above, but if the denominator is zero, pearson_r is assigned a value of 0 instead.
* fix some pylint things
* Update pearson_correlation.py
* Update pearson_correlation_test.py
1 parent e641543 commit 18daa29
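
A minimal sketch of the failure mode described in the commit message (plain Python, with hand-picked values standing in for the statistics accumulated from an all-zeros batch; not part of the commit itself):

import math

# Statistics the metric would accumulate from something like torch.zeros(2, 4):
# both variances are zero, so the old denominator is zero as well.
covariance = 0.0
predictions_variance = 0.0
labels_variance = 0.0

denominator = math.sqrt(predictions_variance) * math.sqrt(labels_variance)
# Old code: pearson_r = covariance / denominator  ->  ZeroDivisionError: float division by zero
# Fixed behaviour: treat a (near-)zero denominator as "no correlation".
pearson_r = 0 if round(denominator, 5) == 0 else covariance / denominator
print(pearson_r)  # 0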

2 files changed: +72 -34 lines changed

allennlp/tests/training/metrics/pearson_correlation_test.py  (+64 -33)

@@ -7,54 +7,85 @@
 from allennlp.training.metrics import PearsonCorrelation
 
 
+def pearson_corrcoef(predictions, labels, fweights=None):
+    covariance_matrices = np.cov(predictions, labels, fweights=fweights)
+    denominator = np.sqrt(covariance_matrices[0, 0] * covariance_matrices[1, 1])
+    if np.around(denominator, decimals=5) == 0:
+        expected_pearson_correlation = 0
+    else:
+        expected_pearson_correlation = covariance_matrices[0, 1] / denominator
+    return expected_pearson_correlation
+
+
 class PearsonCorrelationTest(AllenNlpTestCase):
     def test_pearson_correlation_unmasked_computation(self):
         pearson_correlation = PearsonCorrelation()
         batch_size = 100
         num_labels = 10
-        predictions = np.random.randn(batch_size, num_labels).astype("float32")
-        labels = 0.5 * predictions + np.random.randn(batch_size, num_labels).astype("float32")
+        predictions_1 = np.random.randn(batch_size, num_labels).astype("float32")
+        labels_1 = 0.5 * predictions_1 + np.random.randn(batch_size, num_labels).astype("float32")
+
+        predictions_2 = np.random.randn(1).repeat(num_labels).astype("float32")
+        predictions_2 = predictions_2[np.newaxis, :].repeat(batch_size, axis=0)
+        labels_2 = np.random.randn(1).repeat(num_labels).astype("float32")
+        labels_2 = 0.5 * predictions_2 + labels_2[np.newaxis, :].repeat(batch_size, axis=0)
+
+        # in most cases, the data is constructed like predictions_1, the data of such a batch different.
+        # but in a few cases, for example, predictions_2, the data of such a batch is exactly the same.
+        predictions_labels = [(predictions_1, labels_1), (predictions_2, labels_2)]
 
         stride = 10
 
-        for i in range(batch_size // stride):
-            timestep_predictions = torch.FloatTensor(predictions[stride * i:stride * (i+1), :])
-            timestep_labels = torch.FloatTensor(labels[stride * i:stride * (i+1), :])
-            expected_pearson_correlation = np.corrcoef(predictions[:stride * (i + 1), :].reshape(-1),
-                                                       labels[:stride * (i + 1), :].reshape(-1))[0, 1]
-            pearson_correlation(timestep_predictions, timestep_labels)
-            assert_allclose(expected_pearson_correlation, pearson_correlation.get_metric(), rtol=1e-5)
-        # Test reset
-        pearson_correlation.reset()
-        pearson_correlation(torch.FloatTensor(predictions), torch.FloatTensor(labels))
-        assert_allclose(np.corrcoef(predictions.reshape(-1), labels.reshape(-1))[0, 1],
-                        pearson_correlation.get_metric(), rtol=1e-5)
+        for predictions, labels in predictions_labels:
+            pearson_correlation.reset()
+            for i in range(batch_size // stride):
+                timestep_predictions = torch.FloatTensor(predictions[stride * i:stride * (i + 1), :])
+                timestep_labels = torch.FloatTensor(labels[stride * i:stride * (i + 1), :])
+                expected_pearson_correlation = pearson_corrcoef(predictions[:stride * (i + 1), :].reshape(-1),
+                                                                labels[:stride * (i + 1), :].reshape(-1))
+                pearson_correlation(timestep_predictions, timestep_labels)
+                assert_allclose(expected_pearson_correlation, pearson_correlation.get_metric(), rtol=1e-5)
+            # Test reset
+            pearson_correlation.reset()
+            pearson_correlation(torch.FloatTensor(predictions), torch.FloatTensor(labels))
+            assert_allclose(pearson_corrcoef(predictions.reshape(-1), labels.reshape(-1)),
+                            pearson_correlation.get_metric(), rtol=1e-5)
 
     def test_pearson_correlation_masked_computation(self):
         pearson_correlation = PearsonCorrelation()
         batch_size = 100
         num_labels = 10
-        predictions = np.random.randn(batch_size, num_labels).astype("float32")
-        labels = 0.5 * predictions + np.random.randn(batch_size, num_labels).astype("float32")
+        predictions_1 = np.random.randn(batch_size, num_labels).astype("float32")
+        labels_1 = 0.5 * predictions_1 + np.random.randn(batch_size, num_labels).astype("float32")
+
+        predictions_2 = np.random.randn(1).repeat(num_labels).astype("float32")
+        predictions_2 = predictions_2[np.newaxis, :].repeat(batch_size, axis=0)
+        labels_2 = np.random.randn(1).repeat(num_labels).astype("float32")
+        labels_2 = 0.5 * predictions_2 + labels_2[np.newaxis, :].repeat(batch_size, axis=0)
+
+        predictions_labels = [(predictions_1, labels_1), (predictions_2, labels_2)]
+
         # Random binary mask
         mask = np.random.randint(0, 2, size=(batch_size, num_labels)).astype("float32")
         stride = 10
 
-        for i in range(batch_size // stride):
-            timestep_predictions = torch.FloatTensor(predictions[stride * i:stride * (i+1), :])
-            timestep_labels = torch.FloatTensor(labels[stride * i:stride * (i+1), :])
-            timestep_mask = torch.FloatTensor(mask[stride * i:stride * (i+1), :])
-            covariance_matrices = np.cov(predictions[:stride * (i + 1), :].reshape(-1),
-                                         labels[:stride * (i + 1), :].reshape(-1),
-                                         fweights=mask[:stride * (i + 1), :].reshape(-1))
-            expected_pearson_correlation = covariance_matrices[0, 1] / np.sqrt(covariance_matrices[0, 0] *
-                                                                               covariance_matrices[1, 1])
-            pearson_correlation(timestep_predictions, timestep_labels, timestep_mask)
+        for predictions, labels in predictions_labels:
+            pearson_correlation.reset()
+            for i in range(batch_size // stride):
+                timestep_predictions = torch.FloatTensor(predictions[stride * i:stride * (i + 1), :])
+                timestep_labels = torch.FloatTensor(labels[stride * i:stride * (i + 1), :])
+                timestep_mask = torch.FloatTensor(mask[stride * i:stride * (i + 1), :])
+                expected_pearson_correlation = pearson_corrcoef(predictions[:stride * (i + 1), :].reshape(-1),
+                                                                labels[:stride * (i + 1), :].reshape(-1),
+                                                                fweights=mask[:stride * (i + 1), :].reshape(-1))
+
+                pearson_correlation(timestep_predictions, timestep_labels, timestep_mask)
+                assert_allclose(expected_pearson_correlation, pearson_correlation.get_metric(), rtol=1e-5)
+            # Test reset
+            pearson_correlation.reset()
+            pearson_correlation(torch.FloatTensor(predictions),
+                                torch.FloatTensor(labels), torch.FloatTensor(mask))
+            expected_pearson_correlation = pearson_corrcoef(predictions.reshape(-1), labels.reshape(-1),
+                                                            fweights=mask.reshape(-1))
+
             assert_allclose(expected_pearson_correlation, pearson_correlation.get_metric(), rtol=1e-5)
-        # Test reset
-        pearson_correlation.reset()
-        pearson_correlation(torch.FloatTensor(predictions), torch.FloatTensor(labels), torch.FloatTensor(mask))
-        covariance_matrices = np.cov(predictions.reshape(-1), labels.reshape(-1), fweights=mask.reshape(-1))
-        expected_pearson_correlation = covariance_matrices[0, 1] / np.sqrt(covariance_matrices[0, 0] *
-                                                                           covariance_matrices[1, 1])
-        assert_allclose(expected_pearson_correlation, pearson_correlation.get_metric(), rtol=1e-5)
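
For context (not part of the diff): the tests switch from np.corrcoef to the pearson_corrcoef helper above because np.corrcoef yields nan when one of the inputs is constant, which would make assert_allclose fail on batches like predictions_2. A small sketch assuming only NumPy:

import numpy as np

constant = np.zeros(8, dtype="float32")       # every value identical, variance is 0
noisy = np.random.randn(8).astype("float32")

# np.corrcoef divides by a zero standard deviation: the result is nan
# (NumPy also emits a RuntimeWarning about an invalid value).
print(np.corrcoef(constant, noisy)[0, 1])     # nan

# The pearson_corrcoef helper in the test file guards that division and returns 0
# instead, matching what PearsonCorrelation.get_metric() now reports for such data.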

allennlp/training/metrics/pearson_correlation.py  (+8 -1)

@@ -1,5 +1,6 @@
 from typing import Optional
 import math
+import numpy as np
 
 from overrides import overrides
 import torch
@@ -29,6 +30,8 @@ class PearsonCorrelation(Metric):
     If we have these values, the sample Pearson correlation coefficient is simply:
 
     r = covariance / (sqrt(predictions_variance) * sqrt(labels_variance))
+
+    if predictions_variance or labels_variance is 0, r is 0
     """
     def __init__(self) -> None:
         self._predictions_labels_covariance = Covariance()
@@ -65,7 +68,11 @@ def get_metric(self, reset: bool = False):
         labels_variance = self._labels_variance.get_metric(reset=reset)
         if reset:
             self.reset()
-        pearson_r = covariance / (math.sqrt(predictions_variance) * math.sqrt(labels_variance))
+        denominator = (math.sqrt(predictions_variance) * math.sqrt(labels_variance))
+        if np.around(denominator, decimals=5) == 0:
+            pearson_r = 0
+        else:
+            pearson_r = covariance / denominator
         return pearson_r
 
     @overrides
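
A minimal usage sketch of the fixed metric (assumes an AllenNLP build that includes this change): constant inputs now produce a correlation of 0 instead of crashing in get_metric().

import torch
from allennlp.training.metrics import PearsonCorrelation

metric = PearsonCorrelation()
predictions = torch.zeros(2, 4)  # zero variance in the predictions
labels = torch.zeros(2, 4)       # zero variance in the labels
metric(predictions, labels)
print(metric.get_metric())       # 0 (before this commit: ZeroDivisionError: float division by zero)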
