[Bug] Shape Error in UnwhitenedVariationalStrategy During Batch Prediction #2642


Open
kayween opened this issue Mar 10, 2025 · 0 comments

kayween (Collaborator) commented Mar 10, 2025

🐛 Bug

Variational GPs using UnwhitenedVariationalStrategy raise a shape error during batch prediction when the test batch size changes between calls in eval mode.

To reproduce

Code snippet to reproduce

import torch
import torch.optim as optim

import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.variational import UnwhitenedVariationalStrategy
from gpytorch.variational import VariationalStrategy


class GPModel(ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = CholeskyVariationalDistribution(inducing_points.size(0))
        variational_strategy = UnwhitenedVariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )

        super().__init__(variational_strategy)

        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)

        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


inducing_points = torch.randn(20, 2)
model = GPModel(inducing_points)

train_x = torch.randn(5, 2)
train_y = torch.randn(5)

likelihood = gpytorch.likelihoods.GaussianLikelihood()
mll = gpytorch.mlls.VariationalELBO(likelihood, model, train_y.numel())

likelihood.train()
model.train()

optimizer = optim.Adam(model.parameters(), lr=0.1)

for i in range(10):
    optimizer.zero_grad()

    output = model(train_x)
    loss = -mll(output, train_y)

    loss.backward()
    optimizer.step()

model.eval()
likelihood.eval()

# batch prediction with batch size 10: this call succeeds
test_x = torch.randn(10, 5, 2)
model(test_x)

# a second call with a different batch size (2) raises a RuntimeError
test_x = torch.randn(2, 5, 2)
model(test_x)

Stack trace/error message

Traceback (most recent call last):
  File "/home/kaiwen/Desktop/preference-bo/debug.py", line 64, in <module>
    model(test_x)
  File "/home/kaiwen/anaconda3/envs/gpflow/lib/python3.10/site-packages/gpytorch/models/approximate_gp.py", line 114, in __call__
    return self.variational_strategy(inputs, prior=prior, **kwargs)
  File "/home/kaiwen/anaconda3/envs/gpflow/lib/python3.10/site-packages/gpytorch/variational/_variational_strategy.py", line 347, in __call__
    return super().__call__(
  File "/home/kaiwen/anaconda3/envs/gpflow/lib/python3.10/site-packages/gpytorch/module.py", line 82, in __call__
    outputs = self.forward(*inputs, **kwargs)
  File "/home/kaiwen/anaconda3/envs/gpflow/lib/python3.10/site-packages/gpytorch/variational/unwhitened_variational_strategy.py", line 171, in forward
    shape = torch.broadcast_shapes(*shapes)
  File "/home/kaiwen/anaconda3/envs/gpflow/lib/python3.10/site-packages/torch/functional.py", line 136, in broadcast_shapes
    raise RuntimeError(
RuntimeError: Shape mismatch: objects cannot be broadcast to a single shape
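The final frame fails inside torch.broadcast_shapes. As an illustration only, the pure-Python sketch below re-implements the standard broadcasting rule (align shapes from the right; at each position the sizes must be equal or one of them must be 1) to show why a batch dimension of 10 cannot be combined with a batch dimension of 2. The helper name broadcast_shapes_sketch and the shapes used are hypothetical stand-ins; the actual shapes passed at unwhitened_variational_strategy.py line 171 are not shown in the trace.

```python
from itertools import zip_longest


def broadcast_shapes_sketch(*shapes):
    """Hypothetical pure-Python version of the rule torch.broadcast_shapes
    enforces. Returns the broadcast shape or raises RuntimeError."""
    result = []
    # Walk dimensions from the right; missing leading dims count as size 1.
    for sizes in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_one = {size for size in sizes if size != 1}
        if len(non_one) > 1:
            raise RuntimeError(
                "Shape mismatch: objects cannot be broadcast to a single shape"
            )
        result.append(non_one.pop() if non_one else 1)
    return tuple(reversed(result))


# Hypothetical shapes: something built for the batch-10 call versus
# something built for the batch-2 call.
first_call_shape = (10, 20, 20)
second_call_shape = (2, 20, 20)

print(broadcast_shapes_sketch(first_call_shape, first_call_shape))  # (10, 20, 20)
try:
    broadcast_shapes_sketch(first_call_shape, second_call_shape)
except RuntimeError as err:
    print(err)
```

Matching batch sizes broadcast cleanly, while 10 against 2 has no common size and neither is 1, which is the failure reported above.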

Expected Behavior

The second prediction, with a different batch size, should succeed without a shape error, just like the first one.

System information


  • GPyTorch 1.14
  • PyTorch 2.6.0+cu124
  • Ubuntu 20.04.6 LTS

Additional context

I stepped through the stack trace. The error occurs because self._cholesky_factor caches a Cholesky factor whose shape is incorrect for the second call.

However, WhitenedVariationalStrategy does not raise this shape error, and it is not yet clear why. I'll spend some more time investigating this issue.
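To make the suspected caching mechanism concrete, here is a hypothetical, pure-Python sketch (not GPyTorch's actual caching code; the class and method names are invented for illustration) of a cache that ignores its argument: the value computed on the first call is returned unchanged on later calls, even when the batch shape differs. Alongside it is one possible remedy, offered as an assumption rather than the project's fix: key the cache on the batch shape.

```python
class StaleCacheStrategy:
    """Hypothetical strategy whose cache ignores its input, mimicking a
    Cholesky factor that is computed once and then reused."""

    def __init__(self):
        self._cache = None

    def cholesky_factor(self, batch_shape):
        # Stand-in for the expensive Cholesky computation: just record
        # which batch shape the "factor" was built for.
        if self._cache is None:
            self._cache = ("factor", batch_shape)
        return self._cache


class ShapeKeyedCacheStrategy:
    """Hypothetical variant that caches one factor per batch shape."""

    def __init__(self):
        self._cache = {}

    def cholesky_factor(self, batch_shape):
        if batch_shape not in self._cache:
            self._cache[batch_shape] = ("factor", batch_shape)
        return self._cache[batch_shape]


stale = StaleCacheStrategy()
print(stale.cholesky_factor((10,)))  # ('factor', (10,))
print(stale.cholesky_factor((2,)))   # ('factor', (10,)), stale batch shape

keyed = ShapeKeyedCacheStrategy()
print(keyed.cholesky_factor((10,)))  # ('factor', (10,))
print(keyed.cholesky_factor((2,)))   # ('factor', (2,))
```

Keying by batch shape trades memory for correctness; clearing the cache whenever the batch shape changes would be an alternative with a smaller footprint.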

kayween added the bug label Mar 10, 2025