Skip to content

Fix bug with Multitask DeepGP predictive variances. #2123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits
Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions gpytorch/likelihoods/multitask_gaussian_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ def marginal(self, function_dist: MultitaskMultivariateNormal, *params, **kwargs

:param function_dist: Random variable whose covariance
matrix is a :obj:`~linear_operator.operators.LinearOperator` we intend to augment.
Returns:
:obj:`gpytorch.distributions.MultitaskMultivariateNormal`:
:rtype: `gpytorch.distributions.MultitaskMultivariateNormal`:
:return: A new random variable whose covariance matrix is a
:obj:`~linear_operator.operators.LinearOperator` with
:math:`\mathbf D_{t} \otimes \mathbf I_{n}` and :math:`\sigma^{2} \mathbf I_{nt}` added.
Expand All @@ -104,13 +103,15 @@ def marginal(self, function_dist: MultitaskMultivariateNormal, *params, **kwargs
if isinstance(covar, LazyEvaluatedKernelTensor):
covar = covar.evaluate_kernel()

covar_kron_lt = self._shaped_noise_covar(mean.shape, add_noise=self.has_global_noise)
covar_kron_lt = self._shaped_noise_covar(
mean.shape, add_noise=self.has_global_noise, interleaved=function_dist._interleaved
)
covar = covar + covar_kron_lt

return function_dist.__class__(mean, covar)
return function_dist.__class__(mean, covar, interleaved=function_dist._interleaved)

def _shaped_noise_covar(
self, shape: torch.Size, add_noise: Optional[bool] = True, *params, **kwargs
self, shape: torch.Size, add_noise: Optional[bool] = True, interleaved: bool = True, *params, **kwargs
) -> LinearOperator:
if not self.has_task_noise:
noise = ConstantDiagLinearOperator(self.noise, diag_shape=shape[-2] * self.num_tasks)
Expand Down Expand Up @@ -140,7 +141,10 @@ def _shaped_noise_covar(
noise = ConstantDiagLinearOperator(self.noise, diag_shape=task_var_lt.shape[-1])
task_var_lt = task_var_lt + noise

covar_kron_lt = ckl_init(eye_lt, task_var_lt)
if interleaved:
covar_kron_lt = ckl_init(eye_lt, task_var_lt)
else:
covar_kron_lt = ckl_init(task_var_lt, eye_lt)

return covar_kron_lt

Expand Down
34 changes: 30 additions & 4 deletions test/likelihoods/test_multitask_gaussian_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import unittest

import torch
from linear_operator.operators import KroneckerProductLinearOperator, RootLinearOperator
from linear_operator.operators import KroneckerProductLinearOperator, RootLinearOperator, ToeplitzLinearOperator

from gpytorch.distributions import MultitaskMultivariateNormal
from gpytorch.likelihoods import MultitaskGaussianLikelihood
Expand All @@ -17,9 +17,9 @@ def _create_conditional_input(self, batch_shape=torch.Size([])):
return torch.randn(*batch_shape, 5, 4)

def _create_marginal_input(self, batch_shape=torch.Size([])):
    """Build a ``MultitaskMultivariateNormal`` with a known, deterministic covariance.

    The covariance is a Kronecker product of a fixed data factor and a fixed
    task factor, so tests can compute expected marginal variances by hand.

    :param batch_shape: optional leading batch dimensions for the mean.
    :return: a ``MultitaskMultivariateNormal`` over 5 data points and 4 tasks.
    """
    # Data covariance: symmetric Toeplitz with unit diagonal, so each data
    # point contributes variance 1 before the task factor is applied.
    data_mat = ToeplitzLinearOperator(torch.tensor([1, 0.6, 0.4, 0.2, 0.1]))
    # Task covariance: rank-1 root [[1],[2],[3],[4]] -> diagonal [1, 4, 9, 16].
    task_mat = RootLinearOperator(torch.tensor([[1.0], [2.0], [3.0], [4.0]]))
    covar = KroneckerProductLinearOperator(data_mat, task_mat)
    return MultitaskMultivariateNormal(torch.randn(*batch_shape, 5, 4), covar)

def _create_targets(self, batch_shape=torch.Size([])):
Expand All @@ -28,6 +28,22 @@ def _create_targets(self, batch_shape=torch.Size([])):
def create_likelihood(self):
    """Return the likelihood under test: 4 tasks, rank-2 inter-task noise covariance."""
    num_tasks, noise_rank = 4, 2
    return MultitaskGaussianLikelihood(num_tasks=num_tasks, rank=noise_rank)

def test_marginal_variance(self):
    """Marginal variances must place each task's noise on that task's entries.

    The marginal input from ``_create_marginal_input`` has per-task prior
    variances [1, 4, 9, 16] at every data point, so the expected totals can
    be written down exactly.
    """
    # Case 1: rank-0 (diagonal) task noise, no global noise term.
    lik = MultitaskGaussianLikelihood(num_tasks=4, rank=0, has_global_noise=False)
    lik.task_noises = torch.tensor([[0.1], [0.2], [0.3], [0.4]])
    marginal = lik(self._create_marginal_input())
    # prior [1, 4, 9, 16] + task noises [0.1, 0.2, 0.3, 0.4], same at all 5 points.
    expected = torch.tensor([1.1, 4.2, 9.3, 16.4]).repeat(5, 1)
    self.assertAllClose(marginal.variance, expected)

    # Case 2: rank-1 task noise covariance plus a shared global noise.
    lik = MultitaskGaussianLikelihood(num_tasks=4, rank=1, has_global_noise=True)
    lik.noise = torch.tensor(0.1)
    lik.task_noise_covar_factor.data = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
    marginal = lik(self._create_marginal_input())
    # prior [1, 4, 9, 16] + task covar diag [1, 4, 9, 16] + global 0.1.
    expected = torch.tensor([2.1, 8.1, 18.1, 32.1]).repeat(5, 1)
    self.assertAllClose(marginal.variance, expected)

def test_setters(self):
likelihood = MultitaskGaussianLikelihood(num_tasks=3, rank=0)

Expand Down Expand Up @@ -59,6 +75,16 @@ def test_setters(self):
self.assertTrue("task noises" in str(context.exception))


class TestMultitaskGaussianLikelihoodNonInterleaved(TestMultitaskGaussianLikelihood, unittest.TestCase):
    """Re-run the multitask likelihood suite with non-interleaved covariances.

    Only the marginal-input factory differs: the Kronecker factors are given
    in (task, data) order and the distribution is built with
    ``interleaved=False``.
    """

    seed = 2

    def _create_marginal_input(self, batch_shape=torch.Size([])):
        # Same deterministic factors as the interleaved suite, but with the
        # task factor first, matching the non-interleaved memory layout.
        toeplitz_data = ToeplitzLinearOperator(torch.tensor([1, 0.6, 0.4, 0.2, 0.1]))
        rank1_task = RootLinearOperator(torch.tensor([[1.0], [2.0], [3.0], [4.0]]))
        joint_covar = KroneckerProductLinearOperator(rank1_task, toeplitz_data)
        mean = torch.randn(*batch_shape, 5, 4)
        return MultitaskMultivariateNormal(mean, joint_covar, interleaved=False)


class TestMultitaskGaussianLikelihoodBatch(TestMultitaskGaussianLikelihood):
seed = 0

Expand Down