VNNGP with Batches #2344

Open · gpleiss opened this issue on May 15, 2023 · Discussed in #2300 · 1 comment

@gpleiss (Member) commented on May 15, 2023

Discussed in #2300

Originally posted by Turakar on March 12, 2023
I am trying to get a VNNGP to work in a batched setting. To this end, I tried the following code, which is based on the tutorial.

import math

import matplotlib.pyplot as plt
import torch
from torch import Tensor
from tqdm.auto import tqdm

from gpytorch import settings
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.mlls import PredictiveLogLikelihood
from gpytorch.models import ApproximateGP
from gpytorch.variational import MeanFieldVariationalDistribution, NNVariationalStrategy


class BatchGPModel(ApproximateGP):
    def __init__(self, train_x: Tensor):
        batch_shape = train_x.shape[:-2]

        inducing_points = torch.clone(train_x)
        variational_distribution = MeanFieldVariationalDistribution(inducing_points.size(-2), batch_shape=batch_shape)
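        # NNVariationalStrategy args after the distribution: k=25 nearest
        # neighbors and a training batch size of 25.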
        variational_strategy = NNVariationalStrategy(self, inducing_points, variational_distribution, 25, 25)
        super().__init__(variational_strategy)

        self.mean_module = ConstantMean(batch_shape=batch_shape)
        self.covar_module = ScaleKernel(RBFKernel(batch_shape=batch_shape), batch_shape=batch_shape)
        self.likelihood = GaussianLikelihood(batch_shape=batch_shape)

    def forward(self, x: Tensor) -> MultivariateNormal:
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

    def __call__(self, x, prior=False, **kwargs):
        # Reshape 1-D inputs to (n, 1) before handing off to the strategy.
        if x is not None and x.dim() == 1:
            x = x.unsqueeze(-1)
        return self.variational_strategy(x=x, prior=prior, **kwargs)


def main():
    x = torch.linspace(0, 1, 100)
    train_y = torch.stack(
        [
            torch.sin(x * (2 * math.pi)) + torch.randn(x.size()) * 0.2,
            torch.cos(x * (2 * math.pi)) + torch.randn(x.size()) * 0.2,
            torch.sin(x * (2 * math.pi)) + 2 * torch.cos(x * (2 * math.pi)) + torch.randn(x.size()) * 0.2,
            -torch.cos(x * (2 * math.pi)) + torch.randn(x.size()) * 0.2,
        ],
        0,
    )
    train_x = torch.stack([x] * 4).unsqueeze(-1)
    num_tasks = 4

    # initialize model
    model = BatchGPModel(train_x)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = PredictiveLogLikelihood(model.likelihood, model, num_data=x.size(0))

    num_batches = model.variational_strategy._total_training_batches
    epochs_iter = tqdm(range(50), desc="Epoch")
    for _ in epochs_iter:
        minibatch_iter = tqdm(range(num_batches), desc="Minibatch", leave=False)

        for _ in minibatch_iter:
            optimizer.zero_grad()
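            # Passing x=None lets NNVariationalStrategy draw the next training
            # minibatch of inducing points itself; current_training_indices
            # then reports which targets to slice out of train_y.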
            output = model(x=None)
            current_training_indices = model.variational_strategy.current_training_indices
            y_batch = train_y[:, current_training_indices]
            loss = -mll(output, y_batch).sum()
            minibatch_iter.set_postfix(loss=loss.item())
            loss.backward()
            optimizer.step()

    # Get into evaluation (predictive posterior) mode
    model.eval()

    # Initialize plots
    fig, axs = plt.subplots(1, num_tasks, figsize=(4 * num_tasks, 3))

    # Make predictions
    with torch.no_grad(), settings.fast_pred_var():
        test_x = torch.stack([torch.linspace(0, 1, 51)] * num_tasks).unsqueeze(-1)
        predictions = model.likelihood(model(test_x))
        mean = predictions.mean
        lower, upper = predictions.confidence_region()

    for task, ax in enumerate(axs):
        # Plot training data as black stars
        ax.plot(x.detach().numpy(), train_y[task].detach().numpy(), "k*")
        # Predictive mean as blue line
        ax.plot(test_x[0, :, 0].numpy(), mean[task].numpy(), "b")
        # Shade in confidence
        ax.fill_between(test_x[0, :, 0].numpy(), lower[task].numpy(), upper[task].numpy(), alpha=0.5)
        ax.set_ylim([-3, 3])
        ax.legend(["Observed Data", "Mean", "Confidence"])
        ax.set_title(f"Task {task + 1}")

    fig.tight_layout()

    plt.show()


if __name__ == "__main__":
    main()

However, this does not work: _stochastic_kl_helper() in NNVariationalStrategy calls forward() of the covar_module with an input of shape (4, 25, 25, 1), while the batched kernel expects a shape of (4, n, 1). Did I get something wrong about the usage of NNVariationalStrategy, or is there something wrong in GPyTorch? This is the exact stack trace:

Traceback (most recent call last):
  File "/path/to/gpytorch/snippet_variational_batch_sogp.py", line 109, in <module>
    main()
  File "/path/to/gpytorch/snippet_variational_batch_sogp.py", line 71, in main
    output = model(x=None)
  File "/path/to/gpytorch/snippet_variational_batch_sogp.py", line 40, in __call__
    return self.variational_strategy(x=x, prior=False, **kwargs)
  File "/path/to/gpytorch/gpytorch/variational/nearest_neighbor_variational_strategy.py", line 131, in __call__
    return self.forward(x, self.inducing_points, None, None)
  File "/path/to/gpytorch/gpytorch/variational/nearest_neighbor_variational_strategy.py", line 168, in forward
    kl = self._kl_divergence(kl_indices)
  File "/path/to/gpytorch/gpytorch/variational/nearest_neighbor_variational_strategy.py", line 325, in _kl_divergence
    kl = self._stochastic_kl_helper(kl_indices) * self.M / len(kl_indices)
  File "/path/to/gpytorch/gpytorch/variational/nearest_neighbor_variational_strategy.py", line 273, in _stochastic_kl_helper
    cov = self.model.covar_module.forward(nearest_neighbors, nearest_neighbors)
  File "/path/to/gpytorch/gpytorch/kernels/scale_kernel.py", line 109, in forward
    orig_output = self.base_kernel.forward(x1, x2, diag=diag, last_dim_is_batch=last_dim_is_batch, **params)
  File "/path/to/gpytorch/gpytorch/kernels/rbf_kernel.py", line 80, in forward
    return RBFCovariance.apply(
  File "/path/to/gpytorch/gpytorch/functions/rbf_covariance.py", line 12, in forward
    x1_ = x1.div(lengthscale)
RuntimeError: The size of tensor a (25) must match the size of tensor b (4) at non-singleton dimension 1

Does somebody have an idea what's going on here? Maybe @LuhuanWu or @gpleiss?
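For reference, the shape clash at the bottom of the stack trace can be reproduced with raw tensors. Here is a minimal sketch (not from the original report), assuming that an RBFKernel with batch_shape=torch.Size([4]) stores a lengthscale of shape (4, 1, 1) and that the KL helper passes a (batch, kl_batch, k, d) nearest-neighbor tensor:

import torch

lengthscale = torch.ones(4, 1, 1)   # assumed lengthscale shape for batch_shape=(4,)

x_ok = torch.randn(4, 100, 1)       # (batch, n, d): what covar_module.forward() expects
x_ok.div(lengthscale)               # broadcasts fine: (4, 100, 1) / (4, 1, 1)

x_bad = torch.randn(4, 25, 25, 1)   # (batch, kl_batch, k, d): what the KL helper passes
x_bad.div(lengthscale)              # RuntimeError: The size of tensor a (25) must match
                                    # the size of tensor b (4) at non-singleton dimension 1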

@LuhuanWu (Contributor) commented on Jul 9, 2023

Submitted PR #2375 and replied in #2300.
