🐛 Bug
Variational GPs with UnwhitenedVariationalStrategy have a shape error during batch prediction.
To reproduce
** Code snippet to reproduce **
```python
import torch
import torch.optim as optim

import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.variational import UnwhitenedVariationalStrategy
from gpytorch.variational import VariationalStrategy


class GPModel(ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = CholeskyVariationalDistribution(inducing_points.size(0))
        variational_strategy = UnwhitenedVariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


inducing_points = torch.randn(20, 2)
model = GPModel(inducing_points)

train_x = torch.randn(5, 2)
train_y = torch.randn(5)

likelihood = gpytorch.likelihoods.GaussianLikelihood()
mll = gpytorch.mlls.VariationalELBO(likelihood, model, train_y.numel())

likelihood.train()
model.train()
optimizer = optim.Adam(model.parameters(), lr=0.1)
for i in range(10):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    optimizer.step()

model.eval()
likelihood.eval()

# batch prediction
test_x = torch.randn(10, 5, 2)
model(test_x)

# now use a different batch size, which results in an error
test_x = torch.randn(2, 5, 2)
model(test_x)
```
** Stack trace/error message **
Traceback (most recent call last):
File "/home/kaiwen/Desktop/preference-bo/debug.py", line 64, in <module>
model(test_x)
File "/home/kaiwen/anaconda3/envs/gpflow/lib/python3.10/site-packages/gpytorch/models/approximate_gp.py", line 114, in __call__
return self.variational_strategy(inputs, prior=prior, **kwargs)
File "/home/kaiwen/anaconda3/envs/gpflow/lib/python3.10/site-packages/gpytorch/variational/_variational_strategy.py", line 347, in __call__
return super().__call__(
File "/home/kaiwen/anaconda3/envs/gpflow/lib/python3.10/site-packages/gpytorch/module.py", line 82, in __call__
outputs = self.forward(*inputs, **kwargs)
File "/home/kaiwen/anaconda3/envs/gpflow/lib/python3.10/site-packages/gpytorch/variational/unwhitened_variational_strategy.py", line 171, in forward
shape = torch.broadcast_shapes(*shapes)
File "/home/kaiwen/anaconda3/envs/gpflow/lib/python3.10/site-packages/torch/functional.py", line 136, in broadcast_shapes
raise RuntimeError(
RuntimeError: Shape mismatch: objects cannot be broadcast to a single shape
Expected Behavior
The second batch prediction (with a different batch size) should succeed without a shape error.
System information
Please complete the following information:
GPyTorch 1.14
PyTorch 2.6.0+cu124
Ubuntu 20.04.6 LTS
Additional context
I stepped through the error stack. The error occurs because self._cholesky_factor caches a Cholesky factor of the wrong shape.
However, the whitened VariationalStrategy does not have this shape error, and it's not clear why. I'll spend some more time investigating this issue.
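To make the failure mode concrete, here is a minimal sketch. The broadcast call below only mimics the mismatch that torch.broadcast_shapes hits inside UnwhitenedVariationalStrategy.forward; the exact shapes there may differ. The cache-clearing workaround is an assumption: it relies on the private _memoize_cache dictionary used by GPyTorch's @cached decorator, which may change between versions.

```python
import torch

# The first batch prediction (batch size 10) appears to cache a Cholesky factor with
# that batch shape; the second call (batch size 2) then fails to broadcast against it.
# This standalone call reproduces the same RuntimeError:
try:
    torch.broadcast_shapes((10, 5, 20), (2, 5, 20))
except RuntimeError as e:
    print(e)  # Shape mismatch: objects cannot be broadcast to a single shape

# Hypothetical workaround (untested, relies on private internals): drop the strategy's
# memoization cache before predicting with a different batch size, so that
# _cholesky_factor is recomputed for the new batch shape.
# `model` is the GPModel from the snippet above.
strategy = model.variational_strategy
if hasattr(strategy, "_memoize_cache"):
    strategy._memoize_cache.clear()
test_x = torch.randn(2, 5, 2)
model(test_x)  # should no longer hit the stale cache entry (assumption)
```

If clearing the cache removes the error, that would at least support the diagnosis that the stale cached Cholesky factor is the culprit.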