Skip to content

Commit 16cff15

Browse files
authored
Merge pull request #2621 from saitcakmak/mtmvn
Support broadcasting batch shapes in MTMVN.from_independent_mvns
2 parents c62c324 + aafb92b commit 16cff15

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

gpytorch/distributions/multitask_multivariate_normal.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,8 @@ def from_independent_mvns(cls, mvns):
153153
if any(isinstance(mvn, MultitaskMultivariateNormal) for mvn in mvns):
154154
raise ValueError("Cannot accept MultitaskMultivariateNormals")
155155
if not all(m.batch_shape == mvns[0].batch_shape for m in mvns[1:]):
156-
raise ValueError("All MultivariateNormals must have the same batch shape")
156+
batch_shape = torch.broadcast_shapes(*(m.batch_shape for m in mvns))
157+
mvns = [mvn.expand(batch_shape) for mvn in mvns]
157158
if not all(m.event_shape == mvns[0].event_shape for m in mvns[1:]):
158159
raise ValueError("All MultivariateNormals must have the same event shape")
159160
mean = torch.stack([mvn.mean for mvn in mvns], -1)

test/distributions/test_multitask_multivariate_normal.py

+8
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,14 @@ def test_from_independent_mvns(self, cuda=False):
278278
self.assertEqual(list(mvn.mean.shape), expected_mean_shape)
279279
self.assertEqual(list(mvn.covariance_matrix.shape), expected_covar_shape)
280280

281+
# Test mixed batch mode mvns
282+
# Second MVN is batched, so the first one will be expanded to match.
283+
mvns[1] = mvns[1].expand(torch.Size([3]))
284+
expected_mvn = mvn.expand(torch.Size([3]))
285+
mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
286+
self.assertTrue(torch.equal(mvn.mean, expected_mvn.mean))
287+
self.assertTrue(torch.equal(mvn.covariance_matrix, expected_mvn.covariance_matrix))
288+
281289
# Test batch mode mvns
282290
b = 3
283291
mvns = [

0 commit comments

Comments
 (0)