Skip to content

Metrics fixes and cleanup #2325

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 22, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions gpytorch/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
mean_standardized_log_loss,
negative_log_predictive_density,
quantile_coverage_error,
standardized_mean_squared_error,
)

__all__ = [
"mean_absolute_error",
"mean_squared_error",
"standardized_mean_squared_error",
"mean_standardized_log_loss",
"negative_log_predictive_density",
"quantile_coverage_error",
Expand Down
47 changes: 39 additions & 8 deletions gpytorch/metrics/metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from math import pi
from typing import Optional

import torch

Expand All @@ -12,7 +13,7 @@ def mean_absolute_error(
test_y: torch.Tensor,
):
"""
Mean Absolute Error.
Mean absolute error.
"""
combine_dim = -2 if isinstance(pred_dist, MultitaskMultivariateNormal) else -1
return torch.abs(pred_dist.mean - test_y).mean(dim=combine_dim)
Expand All @@ -24,7 +25,7 @@ def mean_squared_error(
squared: bool = True,
):
"""
Mean Squared Error.
Mean squared error.
"""
combine_dim = -2 if isinstance(pred_dist, MultitaskMultivariateNormal) else -1
res = torch.square(pred_dist.mean - test_y).mean(dim=combine_dim)
Expand All @@ -33,29 +34,59 @@ def mean_squared_error(
return res


def standardized_mean_squared_error(
    pred_dist: MultivariateNormal,
    test_y: torch.Tensor,
):
    """Standardized mean squared error.

    Divides the (squared) mean squared error by the variance of the test
    targets, so the scale of the data is factored out of the metric.
    """
    # NOTE: test_y.var() reduces over *all* elements (unbiased, scalar),
    # matching the existing convention in this module.
    mse = mean_squared_error(pred_dist, test_y, squared=True)
    return mse / test_y.var()


def negative_log_predictive_density(
    pred_dist: MultivariateNormal,
    test_y: torch.Tensor,
):
    """Negative log predictive density.

    Returns the negative log probability of ``test_y`` under ``pred_dist``,
    normalized by the number of test points.
    """
    # Multitask predictions use dim -2 as the data dimension (the last dim
    # holds tasks), following the combine_dim convention used in this module.
    if isinstance(pred_dist, MultitaskMultivariateNormal):
        num_points = test_y.shape[-2]
    else:
        num_points = test_y.shape[-1]
    return -pred_dist.log_prob(test_y) / num_points


def mean_standardized_log_loss(
    pred_dist: MultivariateNormal,
    test_y: torch.Tensor,
    train_y: Optional[torch.Tensor] = None,
):
    """Mean standardized log loss.

    Averages the pointwise Gaussian negative log density of the test targets
    under the predictive marginals. When ``train_y`` is supplied, the loss of
    the trivial model — a Gaussian with the training data's mean and variance —
    is subtracted, giving the *standardized* log loss of Rasmussen and
    Williams (2006), p. 23. Without ``train_y``, the plain mean log loss is
    returned.
    """
    combine_dim = -2 if isinstance(pred_dist, MultitaskMultivariateNormal) else -1

    def _avg_gaussian_log_loss(mean, var):
        # Pointwise negative log density of test_y under N(mean, var),
        # averaged along the data dimension.
        nll = 0.5 * torch.log(2 * pi * var) + torch.square(test_y - mean) / (2 * var)
        return nll.mean(dim=combine_dim)

    res = _avg_gaussian_log_loss(pred_dist.mean, pred_dist.variance)

    if train_y is not None:
        # NOTE: train_y.var() reduces over all elements (scalar) while the
        # mean is taken along combine_dim — preserved from the original code.
        trivial = _avg_gaussian_log_loss(train_y.mean(dim=combine_dim), train_y.var())
        res = res - trivial

    return res


def quantile_coverage_error(
Expand Down
4 changes: 4 additions & 0 deletions test/metrics/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
mean_standardized_log_loss,
negative_log_predictive_density,
quantile_coverage_error,
standardized_mean_squared_error,
)
from gpytorch.models import ExactGP

Expand Down Expand Up @@ -126,6 +127,9 @@ def test_negative_log_predictive_density(self):
def test_mean_standardized_log_loss(self):
    # Smoke test: runs the metric end-to-end through the shared harness.
    self._test_metric(mean_standardized_log_loss)

def test_standardized_mean_squared_error(self):
    # TODO(#2326): this only executes the metric code path without checking
    # the computed value; add a proper value assertion.
    self._test_metric(standardized_mean_squared_error)
Comment on lines +130 to +131
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm this is a kind of interesting test in that it only executes the metric code, without really checking what the computation does. Not really in scope for this PR but this should probably do something more reasonable...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. Treating this as out of scope for this PR for now, but created an issue for it: #2326 .


def test_quantile_coverage_error(self):
self._test_metric(
quantile_coverage_error,
Expand Down