diff --git a/gpytorch/distributions/multitask_multivariate_normal.py b/gpytorch/distributions/multitask_multivariate_normal.py index 7c8637979..99d92a5f0 100644 --- a/gpytorch/distributions/multitask_multivariate_normal.py +++ b/gpytorch/distributions/multitask_multivariate_normal.py @@ -208,7 +208,7 @@ def log_prob(self, value): # flip shape of last two dimensions new_shape = value.shape[:-2] + value.shape[:-3:-1] value = value.view(new_shape).transpose(-1, -2).contiguous() - return super().log_prob(value.view(*value.shape[:-2], -1)) + return super().log_prob(value.reshape(*value.shape[:-2], -1)) @property def mean(self):