
Update MVN.expand() to support non-lazy MVN & reuse scale_tril where possible #2623


Merged · 3 commits · Jan 22, 2025
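For context, a minimal sketch of the behavior this PR targets, assuming a recent gpytorch and torch install (tensor values here are illustrative):

```python
import torch
from gpytorch.distributions import MultivariateNormal

mean = torch.zeros(3)
covar = torch.eye(3)

# A non-lazy MVN: built from a dense covariance matrix rather than a
# LinearOperator, so islazy is False.
mvn = MultivariateNormal(mean, covar)

# Before this change, expand() rebuilt the distribution as
# self.__class__(new_loc, new_covar), which assumed the lazy path; it now
# handles the non-lazy case and reuses a cached scale_tril when present.
expanded = mvn.expand(torch.Size([4]))
print(expanded.batch_shape)  # torch.Size([4])
```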
28 changes: 24 additions & 4 deletions gpytorch/distributions/multivariate_normal.py
@@ -136,10 +136,30 @@ def expand(self, batch_size: torch.Size) -> MultivariateNormal:
        See :py:meth:`torch.distributions.Distribution.expand
        <torch.distributions.distribution.Distribution.expand>`.
        """
-        new_loc = self.loc.expand(torch.Size(batch_size) + self.loc.shape[-1:])
-        new_covar = self._covar.expand(torch.Size(batch_size) + self._covar.shape[-2:])
-        res = self.__class__(new_loc, new_covar)
-        return res
+        # NOTE: Pyro may call this method with list[int] instead of torch.Size.
+        batch_size = torch.Size(batch_size)
+        new_loc = self.loc.expand(batch_size + self.loc.shape[-1:])
+        if self.islazy:
+            new_covar = self._covar.expand(batch_size + self._covar.shape[-2:])
+            new = self.__class__(mean=new_loc, covariance_matrix=new_covar)
+            if self.__unbroadcasted_scale_tril is not None:
+                # Reuse the scale tril if available.
+                new.__unbroadcasted_scale_tril = self.__unbroadcasted_scale_tril.expand(
+                    batch_size + self.__unbroadcasted_scale_tril.shape[-2:]
+                )
+        else:
+            # Non-lazy MVN is represented using scale_tril in PyTorch.
+            # Constructing it from scale_tril will avoid unnecessary computation.
+            # Initialize using __new__, so that we can skip __init__ and use scale_tril.
Collaborator: This seems slightly hairy - would an alternative here be to just add support for MultivariateNormal to be initialized directly from the scale_tril?

Collaborator (Author): Yeah, that'd be an option. I didn't want to touch it since it can potentially affect the correctness of MVN. But it doesn't seem like _covar is directly used anywhere, so it may be safer than I initially thought.

Collaborator (Author): Actually, what would we do if both covar and scale_tril are provided? Do we validate that they are compatible? That would require re-computing one or the other, which is not the cheapest operation, and seems to conflict with the goal of reusing them whenever they are available.

Collaborator: I feel like we could also just make it such that we error out if both of them are provided.
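A hypothetical sketch of that guard, for illustration only (the helper name and signature are invented here, not part of gpytorch):

```python
import torch

def _resolve_covariance(covariance_matrix=None, scale_tril=None):
    # Hypothetical helper illustrating the suggestion above: accept exactly
    # one parameterization, since validating that both agree would require
    # re-computing one from the other (e.g. a Cholesky factorization).
    if (covariance_matrix is None) == (scale_tril is None):
        raise ValueError("Specify exactly one of covariance_matrix or scale_tril.")
    if scale_tril is not None:
        return scale_tril @ scale_tril.mT  # recover covariance only when asked
    return covariance_matrix

# Usage: fine with one argument; raises with both (or neither).
cov = _resolve_covariance(scale_tril=torch.linalg.cholesky(torch.eye(3)))
```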

+            new = self.__new__(type(self))
+            new._islazy = False
+            new_scale_tril = self.__unbroadcasted_scale_tril.expand(
+                batch_size + self.__unbroadcasted_scale_tril.shape[-2:]
+            )
+            super(MultivariateNormal, new).__init__(loc=new_loc, scale_tril=new_scale_tril)
+            # Set the covar matrix, since it is always available for GPyTorch MVN.
+            new.covariance_matrix = self.covariance_matrix.expand(batch_size + self.covariance_matrix.shape[-2:])
+        return new

    def get_base_samples(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
        r"""
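As an aside, a self-contained sketch of the `__new__` trick used in the non-lazy branch above. The subclass and tensors here are illustrative stand-ins, not gpytorch code:

```python
import torch
from torch.distributions import MultivariateNormal as TorchMVN

class MyMVN(TorchMVN):
    def __init__(self, loc, covariance_matrix):
        # Stands in for a subclass __init__ that expand() wants to bypass,
        # e.g. because it only accepts a covariance matrix.
        super().__init__(loc, covariance_matrix=covariance_matrix)

# Allocate an instance without running MyMVN.__init__, then initialize it
# through the torch parent directly from scale_tril: the same pattern as
# `new = self.__new__(type(self))` followed by
# `super(MultivariateNormal, new).__init__(...)` in the diff above.
new = MyMVN.__new__(MyMVN)
super(MyMVN, new).__init__(loc=torch.zeros(2, 3), scale_tril=torch.eye(3).expand(2, 3, 3))
assert isinstance(new, MyMVN)
assert new.batch_shape == torch.Size([2])
```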
26 changes: 26 additions & 0 deletions test/distributions/test_multivariate_normal.py
@@ -2,6 +2,7 @@

import math
import unittest
+from itertools import product

import torch
from linear_operator import to_linear_operator
@@ -323,6 +324,31 @@ def test_base_sample_shape(self):
        samples = dist.rsample(torch.Size((16,)), base_samples=torch.randn(16, 5))
        self.assertEqual(samples.shape, torch.Size((16, 5)))

+    def test_multivariate_normal_expand(self, cuda=False):
+        device = torch.device("cuda") if cuda else torch.device("cpu")
+        for dtype, lazy in product((torch.float, torch.double), (True, False)):
+            mean = torch.tensor([0, 1, 2], device=device, dtype=dtype)
+            covmat = torch.diag(torch.tensor([1, 0.75, 1.5], device=device, dtype=dtype))
+            if lazy:
+                mvn = MultivariateNormal(mean=mean, covariance_matrix=DenseLinearOperator(covmat), validate_args=True)
+                # Initialize scale_tril so we can test that it was expanded.
+                mvn.scale_tril
+            else:
+                mvn = MultivariateNormal(mean=mean, covariance_matrix=covmat, validate_args=True)
+            self.assertEqual(mvn.batch_shape, torch.Size([]))
+            self.assertEqual(mvn.islazy, lazy)
+            expanded = mvn.expand(torch.Size([2]))
+            self.assertIsInstance(expanded, MultivariateNormal)
+            self.assertEqual(expanded.islazy, lazy)
+            self.assertEqual(expanded.batch_shape, torch.Size([2]))
+            self.assertEqual(expanded.event_shape, mvn.event_shape)
+            self.assertTrue(torch.equal(expanded.mean, mean.expand(2, -1)))
+            self.assertEqual(expanded.mean.shape, torch.Size([2, 3]))
+            self.assertTrue(torch.allclose(expanded.covariance_matrix, covmat.expand(2, -1, -1)))
+            self.assertEqual(expanded.covariance_matrix.shape, torch.Size([2, 3, 3]))
+            self.assertTrue(torch.allclose(expanded.scale_tril, mvn.scale_tril.expand(2, -1, -1)))
+            self.assertEqual(expanded.scale_tril.shape, torch.Size([2, 3, 3]))


if __name__ == "__main__":
    unittest.main()
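To run just the new test locally (assuming pytest is installed): `python -m pytest test/distributions/test_multivariate_normal.py -k test_multivariate_normal_expand`.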