diff --git a/gpytorch/metrics/__init__.py b/gpytorch/metrics/__init__.py
index 669cb0f1c..3f3336040 100644
--- a/gpytorch/metrics/__init__.py
+++ b/gpytorch/metrics/__init__.py
@@ -4,11 +4,13 @@
     mean_standardized_log_loss,
     negative_log_predictive_density,
     quantile_coverage_error,
+    standardized_mean_squared_error,
 )
 
 __all__ = [
     "mean_absolute_error",
     "mean_squared_error",
+    "standardized_mean_squared_error",
     "mean_standardized_log_loss",
     "negative_log_predictive_density",
     "quantile_coverage_error",
diff --git a/gpytorch/metrics/metrics.py b/gpytorch/metrics/metrics.py
index 1fbe5b97e..13e29448a 100644
--- a/gpytorch/metrics/metrics.py
+++ b/gpytorch/metrics/metrics.py
@@ -1,4 +1,5 @@
 from math import pi
+from typing import Optional
 
 import torch
 
@@ -12,7 +13,7 @@ def mean_absolute_error(
     test_y: torch.Tensor,
 ):
     """
-    Mean Absolute Error.
+    Mean absolute error.
     """
     combine_dim = -2 if isinstance(pred_dist, MultitaskMultivariateNormal) else -1
     return torch.abs(pred_dist.mean - test_y).mean(dim=combine_dim)
@@ -24,7 +25,7 @@ def mean_squared_error(
     squared: bool = True,
 ):
     """
-    Mean Squared Error.
+    Mean squared error.
     """
     combine_dim = -2 if isinstance(pred_dist, MultitaskMultivariateNormal) else -1
     res = torch.square(pred_dist.mean - test_y).mean(dim=combine_dim)
@@ -33,10 +34,25 @@
     return res
 
 
+def standardized_mean_squared_error(
+    pred_dist: MultivariateNormal,
+    test_y: torch.Tensor,
+):
+    """Standardized mean squared error.
+
+    Standardizes the mean squared error by the variance of the test data.
+    """
+    return mean_squared_error(pred_dist, test_y, squared=True) / test_y.var()
+
+
 def negative_log_predictive_density(
     pred_dist: MultivariateNormal,
     test_y: torch.Tensor,
 ):
+    """Negative log predictive density.
+
+    Computes the negative log predictive density normalized by the size of the test data.
+    """
     combine_dim = -2 if isinstance(pred_dist, MultitaskMultivariateNormal) else -1
     return -pred_dist.log_prob(test_y) / test_y.shape[combine_dim]
 
@@ -44,18 +60,33 @@ def mean_standardized_log_loss(
     pred_dist: MultivariateNormal,
     test_y: torch.Tensor,
+    train_y: Optional[torch.Tensor] = None,
 ):
     """
-    Mean Standardized Log Loss.
-    Reference: Page No. 23,
-    Gaussian Processes for Machine Learning,
-    Carl Edward Rasmussen and Christopher K. I. Williams,
-    The MIT Press, 2006. ISBN 0-262-18253-X
+    Mean standardized log loss.
+
+    Computes the average *standardized* log loss, which subtracts from the mean log
+    loss the loss obtained under the trivial model that predicts with the mean and
+    variance of the training data. See p. 23 of Rasmussen and Williams (2006).
+
+    If no training data is supplied, the mean log loss is computed.
""" combine_dim = -2 if isinstance(pred_dist, MultitaskMultivariateNormal) else -1 + f_mean = pred_dist.mean f_var = pred_dist.variance - return (0.5 * torch.log(2 * pi * f_var) + torch.square(test_y - f_mean) / (2 * f_var)).mean(dim=combine_dim) + loss_model = (0.5 * torch.log(2 * pi * f_var) + torch.square(test_y - f_mean) / (2 * f_var)).mean(dim=combine_dim) + res = loss_model + + if train_y is not None: + data_mean = train_y.mean(dim=combine_dim) + data_var = train_y.var() + loss_trivial_model = ( + 0.5 * torch.log(2 * pi * data_var) + torch.square(test_y - data_mean) / (2 * data_var) + ).mean(dim=combine_dim) + res = res - loss_trivial_model + + return res def quantile_coverage_error( diff --git a/test/metrics/test_metrics.py b/test/metrics/test_metrics.py index ee46bc138..9800fb1af 100644 --- a/test/metrics/test_metrics.py +++ b/test/metrics/test_metrics.py @@ -14,6 +14,7 @@ mean_standardized_log_loss, negative_log_predictive_density, quantile_coverage_error, + standardized_mean_squared_error, ) from gpytorch.models import ExactGP @@ -126,6 +127,9 @@ def test_negative_log_predictive_density(self): def test_mean_standardized_log_loss(self): self._test_metric(mean_standardized_log_loss) + def test_standardized_mean_squared_error(self): + self._test_metric(standardized_mean_squared_error) + def test_quantile_coverage_error(self): self._test_metric( quantile_coverage_error,