Skip to content

Commit 0c359c5

Browse files
authored
MultitaskMultivariateNormal: fix tensor reshape issue (#2081)
1 parent 2622873 commit 0c359c5

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

gpytorch/distributions/multitask_multivariate_normal.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def log_prob(self, value):
208208
# flip shape of last two dimensions
209209
new_shape = value.shape[:-2] + value.shape[:-3:-1]
210210
value = value.view(new_shape).transpose(-1, -2).contiguous()
211-
return super().log_prob(value.view(*value.shape[:-2], -1))
211+
return super().log_prob(value.reshape(*value.shape[:-2], -1))
212212

213213
@property
214214
def mean(self):

0 commit comments

Comments
 (0)