diff --git a/gpytorch/distributions/multitask_multivariate_normal.py b/gpytorch/distributions/multitask_multivariate_normal.py
index 342815902..b692217f4 100644
--- a/gpytorch/distributions/multitask_multivariate_normal.py
+++ b/gpytorch/distributions/multitask_multivariate_normal.py
@@ -244,7 +244,7 @@ def rsample(self, sample_shape=torch.Size(), base_samples=None):
             return samples.view(new_shape).transpose(-1, -2).contiguous()
         return samples.view(sample_shape + self._output_shape)
 
-    def to_data_independent_dist(self):
+    def to_data_independent_dist(self, jitter_val=1e-4):
         """
         Convert a multitask MVN into a batched (non-multitask) MVNs
         The result retains the intertask covariances, but gets rid of the inter-data covariances.
@@ -256,12 +256,16 @@ def to_data_independent_dist(self):
         # Create batch distribution where all data are independent, but the tasks are dependent
         full_covar = self.lazy_covariance_matrix
         num_data, num_tasks = self.mean.shape[-2:]
-        data_indices = torch.arange(0, num_data * num_tasks, num_tasks, device=full_covar.device).view(-1, 1, 1)
-        task_indices = torch.arange(num_tasks, device=full_covar.device)
+        if self._interleaved:
+            data_indices = torch.arange(0, num_data * num_tasks, num_tasks, device=full_covar.device).view(-1, 1, 1)
+            task_indices = torch.arange(num_tasks, device=full_covar.device)
+        else:
+            data_indices = torch.arange(num_data, device=full_covar.device).view(-1, 1, 1)
+            task_indices = torch.arange(0, num_data * num_tasks, num_data, device=full_covar.device)
         task_covars = full_covar[
             ..., data_indices + task_indices.unsqueeze(-2), data_indices + task_indices.unsqueeze(-1)
         ]
-        return MultivariateNormal(self.mean, to_linear_operator(task_covars).add_jitter())
+        return MultivariateNormal(self.mean, to_linear_operator(task_covars).add_jitter(jitter_val=jitter_val))
 
     @property
     def variance(self):
diff --git a/test/distributions/test_multitask_multivariate_normal.py b/test/distributions/test_multitask_multivariate_normal.py
index bea6dee8c..773a1b004 100644
--- a/test/distributions/test_multitask_multivariate_normal.py
+++ b/test/distributions/test_multitask_multivariate_normal.py
@@ -6,7 +6,7 @@
 import unittest
 
 import torch
-from linear_operator.operators import DiagLinearOperator
+from linear_operator.operators import DiagLinearOperator, KroneckerProductLinearOperator
 
 from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
 from gpytorch.test.base_test_case import BaseTestCase
@@ -201,6 +201,28 @@ def test_log_prob_cuda(self):
         with least_used_cuda_device():
             self.test_log_prob(cuda=True)
 
+    def test_to_data_independent_dist(self, dtype=torch.float, device="cpu", interleaved=True):
+        # Create a fake covariance
+        factor = torch.randn(4, 4, device=device, dtype=dtype)
+        data_covar = factor.mT @ factor
+        task_covar = torch.tensor([[1.0, 0.3, 0.1], [0.3, 1.0, 0.3], [0.1, 0.3, 1.0]], device=device, dtype=dtype)
+        if interleaved:
+            covar = KroneckerProductLinearOperator(data_covar, task_covar)
+        else:
+            covar = KroneckerProductLinearOperator(task_covar, data_covar)
+
+        mean = torch.randn(4, 3, device=device, dtype=dtype)
+        dist = MultitaskMultivariateNormal(mean, covar, interleaved=interleaved)
+
+        res = dist.to_data_independent_dist(jitter_val=1e-4)
+        self.assertEqual(res.mean, mean)
+        data_var = data_covar.diagonal(dim1=-1, dim2=-2)
+        jitter = torch.eye(3, dtype=dtype, device=device) * 1e-4
+        self.assertAllClose(res.covariance_matrix, data_var.view(-1, 1, 1) * task_covar + jitter)
+
+    def test_to_data_independent_dist_no_interleave(self, dtype=torch.float, device="cpu"):
+        return self.test_to_data_independent_dist(dtype=dtype, device=device, interleaved=False)
+
     def test_multitask_from_batch(self):
         mean = torch.randn(2, 3)
         variance = torch.randn(2, 3).clamp_min(1e-6)
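
For reference, a minimal usage sketch of the patched API (the shapes and covariance values are illustrative, not part of the change): it builds a non-interleaved multitask MVN, the layout this patch adds support for, and converts it with the new jitter_val keyword.

import torch
from linear_operator.operators import KroneckerProductLinearOperator
from gpytorch.distributions import MultitaskMultivariateNormal

# 4 data points, 3 tasks; a Kronecker-structured multitask covariance.
factor = torch.randn(4, 4)
data_covar = factor.mT @ factor  # make the data covariance PSD
task_covar = torch.tensor([[1.0, 0.3, 0.1], [0.3, 1.0, 0.3], [0.1, 0.3, 1.0]])

# Non-interleaved layout: task blocks are contiguous, so the task covariance
# is the outer Kronecker factor.
mean = torch.randn(4, 3)
dist = MultitaskMultivariateNormal(
    mean, KroneckerProductLinearOperator(task_covar, data_covar), interleaved=False
)

# New keyword from this patch: jitter_val is forwarded to add_jitter() on the
# extracted per-datum task covariances.
batched = dist.to_data_independent_dist(jitter_val=1e-4)
print(batched.mean.shape)               # torch.Size([4, 3])
print(batched.covariance_matrix.shape)  # torch.Size([4, 3, 3])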