From 7999029f9f429eb567a8bf76ca5aef567a54f974 Mon Sep 17 00:00:00 2001
From: Geoff Pleiss
Date: Tue, 6 Sep 2022 16:29:51 +0000
Subject: [PATCH 1/2] Fix bug with Multitask DeepGP predictive variances.

The latent variances of multitask DeepGP models are stored in
non-interleaved covariance matrices. Previously, the
MultitaskGaussianLikelihood.marginal method implicitly assumed that the
function covariance matrix was interleaved, which broke multitask
DeepGP models that use a non-interleaved latent covariance matrix.

With this PR, MultitaskGaussianLikelihood.marginal checks whether the
input covariance matrix is interleaved, and ensures that the returned
predictive covariance matrix matches the interleaved/non-interleaved
layout of the latent covariance matrix.

[Fixes #2702]
---
 .../multitask_gaussian_likelihood.py      | 16 +++++----
 .../test_multitask_gaussian_likelihood.py | 34 ++++++++++++++++---
 2 files changed, 40 insertions(+), 10 deletions(-)

diff --git a/gpytorch/likelihoods/multitask_gaussian_likelihood.py b/gpytorch/likelihoods/multitask_gaussian_likelihood.py
index 1f8faa501..bb452d0b3 100644
--- a/gpytorch/likelihoods/multitask_gaussian_likelihood.py
+++ b/gpytorch/likelihoods/multitask_gaussian_likelihood.py
@@ -92,8 +92,7 @@ def marginal(self, function_dist: MultitaskMultivariateNormal, *params, **kwargs
         :param function_dist: Random variable whose covariance matrix is a
             :obj:`~linear_operator.operators.LinearOperator` we intend to augment.
 
-        Returns:
-            :obj:`gpytorch.distributions.MultitaskMultivariateNormal`:
+        :rtype: `gpytorch.distributions.MultitaskMultivariateNormal`:
         :return: A new random variable whose covariance matrix is a
             :obj:`~linear_operator.operators.LinearOperator` with
             :math:`\mathbf D_{t} \otimes \mathbf I_{n}` and :math:`\sigma^{2} \mathbf I_{nt}` added.
@@ -104,13 +103,15 @@ def marginal(self, function_dist: MultitaskMultivariateNormal, *params, **kwargs
         if isinstance(covar, LazyEvaluatedKernelTensor):
             covar = covar.evaluate_kernel()
 
-        covar_kron_lt = self._shaped_noise_covar(mean.shape, add_noise=self.has_global_noise)
+        covar_kron_lt = self._shaped_noise_covar(
+            mean.shape, add_noise=self.has_global_noise, interleaved=function_dist._interleaved
+        )
         covar = covar + covar_kron_lt
 
-        return function_dist.__class__(mean, covar)
+        return function_dist.__class__(mean, covar, interleaved=function_dist._interleaved)
 
     def _shaped_noise_covar(
-        self, shape: torch.Size, add_noise: Optional[bool] = True, *params, **kwargs
+        self, shape: torch.Size, add_noise: Optional[bool] = True, interleaved=True, *params, **kwargs
     ) -> LinearOperator:
         if not self.has_task_noise:
             noise = ConstantDiagLinearOperator(self.noise, diag_shape=shape[-2] * self.num_tasks)
@@ -140,7 +141,10 @@ def _shaped_noise_covar(
             noise = ConstantDiagLinearOperator(self.noise, diag_shape=task_var_lt.shape[-1])
             task_var_lt = task_var_lt + noise
 
-        covar_kron_lt = ckl_init(eye_lt, task_var_lt)
+        if interleaved:
+            covar_kron_lt = ckl_init(eye_lt, task_var_lt)
+        else:
+            covar_kron_lt = ckl_init(task_var_lt, eye_lt)
 
         return covar_kron_lt
 
diff --git a/test/likelihoods/test_multitask_gaussian_likelihood.py b/test/likelihoods/test_multitask_gaussian_likelihood.py
index 5c15d504b..c2d337495 100644
--- a/test/likelihoods/test_multitask_gaussian_likelihood.py
+++ b/test/likelihoods/test_multitask_gaussian_likelihood.py
@@ -3,7 +3,7 @@
 import unittest
 
 import torch
-from linear_operator.operators import KroneckerProductLinearOperator, RootLinearOperator
+from linear_operator.operators import KroneckerProductLinearOperator, RootLinearOperator, ToeplitzLinearOperator
 
 from gpytorch.distributions import MultitaskMultivariateNormal
 from gpytorch.likelihoods import MultitaskGaussianLikelihood
@@ -17,9 +17,9 @@ def _create_conditional_input(self, batch_shape=torch.Size([])):
         return torch.randn(*batch_shape, 5, 4)
 
     def _create_marginal_input(self, batch_shape=torch.Size([])):
-        mat = torch.randn(*batch_shape, 5, 5)
-        mat2 = torch.randn(*batch_shape, 4, 4)
-        covar = KroneckerProductLinearOperator(RootLinearOperator(mat), RootLinearOperator(mat2))
+        data_mat = ToeplitzLinearOperator(torch.tensor([1, 0.6, 0.4, 0.2, 0.1]))
+        task_mat = RootLinearOperator(torch.tensor([[1.0], [2.0], [3.0], [4.0]]))
+        covar = KroneckerProductLinearOperator(data_mat, task_mat)
         return MultitaskMultivariateNormal(torch.randn(*batch_shape, 5, 4), covar)
 
     def _create_targets(self, batch_shape=torch.Size([])):
@@ -28,6 +28,22 @@ def _create_targets(self, batch_shape=torch.Size([])):
 
     def create_likelihood(self):
         return MultitaskGaussianLikelihood(num_tasks=4, rank=2)
 
+    def test_marginal_variance(self):
+        likelihood = MultitaskGaussianLikelihood(num_tasks=4, rank=0, has_global_noise=False)
+        likelihood.task_noises = torch.tensor([[0.1], [0.2], [0.3], [0.4]])
+
+        input = self._create_marginal_input()
+        variance = likelihood(input).variance
+        self.assertAllClose(variance, torch.tensor([1.1, 4.2, 9.3, 16.4]).repeat(5, 1))
+
+        likelihood = MultitaskGaussianLikelihood(num_tasks=4, rank=1, has_global_noise=True)
+        likelihood.noise = torch.tensor(0.1)
+        likelihood.task_noise_covar_factor.data = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
+
+        input = self._create_marginal_input()
+        variance = likelihood(input).variance
+        self.assertAllClose(variance, torch.tensor([2.1, 8.1, 18.1, 32.1]).repeat(5, 1))
+
     def test_setters(self):
         likelihood = MultitaskGaussianLikelihood(num_tasks=3, rank=0)
@@ -59,6 +75,16 @@ def test_setters(self):
         self.assertTrue("task noises" in str(context.exception))
 
 
+class TestMultitaskGaussianLikelihoodNonInterleaved(TestMultitaskGaussianLikelihood, unittest.TestCase):
+    seed = 2
+
+    def _create_marginal_input(self, batch_shape=torch.Size([])):
+        data_mat = ToeplitzLinearOperator(torch.tensor([1, 0.6, 0.4, 0.2, 0.1]))
+        task_mat = RootLinearOperator(torch.tensor([[1.0], [2.0], [3.0], [4.0]]))
+        covar = KroneckerProductLinearOperator(task_mat, data_mat)
+        return MultitaskMultivariateNormal(torch.randn(*batch_shape, 5, 4), covar, interleaved=False)
+
+
 class TestMultitaskGaussianLikelihoodBatch(TestMultitaskGaussianLikelihood):
     seed = 0
 

From 88ce8aa779c0e5fca4cd5e35ef0863fdd6d3ea93 Mon Sep 17 00:00:00 2001
From: Geoff Pleiss
Date: Wed, 7 Sep 2022 10:07:26 -0400
Subject: [PATCH 2/2] Update gpytorch/likelihoods/multitask_gaussian_likelihood.py

Co-authored-by: Max Balandat
---
 gpytorch/likelihoods/multitask_gaussian_likelihood.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gpytorch/likelihoods/multitask_gaussian_likelihood.py b/gpytorch/likelihoods/multitask_gaussian_likelihood.py
index bb452d0b3..af18fc5ce 100644
--- a/gpytorch/likelihoods/multitask_gaussian_likelihood.py
+++ b/gpytorch/likelihoods/multitask_gaussian_likelihood.py
@@ -111,7 +111,7 @@ def marginal(self, function_dist: MultitaskMultivariateNormal, *params, **kwargs
         return function_dist.__class__(mean, covar, interleaved=function_dist._interleaved)
 
     def _shaped_noise_covar(
-        self, shape: torch.Size, add_noise: Optional[bool] = True, interleaved=True, *params, **kwargs
+        self, shape: torch.Size, add_noise: Optional[bool] = True, interleaved: bool = True, *params, **kwargs
     ) -> LinearOperator:
         if not self.has_task_noise:
             noise = ConstantDiagLinearOperator(self.noise, diag_shape=shape[-2] * self.num_tasks)
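
Note (not part of either patch): the snippet below is a minimal sketch of the behavior these patches target, assuming a gpytorch / linear_operator install that already includes them. It reuses the operators from the new tests, and the check of the private _interleaved attribute is for illustration only.

import torch
from linear_operator.operators import (
    KroneckerProductLinearOperator,
    RootLinearOperator,
    ToeplitzLinearOperator,
)

from gpytorch.distributions import MultitaskMultivariateNormal
from gpytorch.likelihoods import MultitaskGaussianLikelihood

# 5 data points, 4 tasks: Toeplitz data covariance, rank-1 task covariance.
data_mat = ToeplitzLinearOperator(torch.tensor([1.0, 0.6, 0.4, 0.2, 0.1]))  # 5 x 5
task_mat = RootLinearOperator(torch.tensor([[1.0], [2.0], [3.0], [4.0]]))   # 4 x 4
mean = torch.zeros(5, 4)

# The same multitask prior in both layouts:
# interleaved stores Kron(data, task); non-interleaved stores Kron(task, data).
f_int = MultitaskMultivariateNormal(mean, KroneckerProductLinearOperator(data_mat, task_mat))
f_non = MultitaskMultivariateNormal(
    mean, KroneckerProductLinearOperator(task_mat, data_mat), interleaved=False
)

likelihood = MultitaskGaussianLikelihood(num_tasks=4)

# With the fix, the noise Kronecker factor is built in the layout of the incoming
# distribution, so both layouts give the same (n x t) predictive variances, and the
# returned marginal keeps the input's interleaving pattern.
y_int, y_non = likelihood(f_int), likelihood(f_non)
assert torch.allclose(y_int.variance, y_non.variance)
assert not y_non._interleaved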