[Bug] cholesky decomposition on multitask sparse GP #1479

Closed
vcharvet opened this issue Feb 12, 2021 · 2 comments

@vcharvet

🐛 Bug

Hello,
I'm using a sparse multitask GP to learn a dynamical model in a reinforcement learning problem. I then use the model to compute moment-matching predictions at uncertain inputs.
It works well up to a certain number of training points.

To reproduce

**Code snippet to reproduce**

import torch
import gpytorch
import numpy as np

from gpytorch.likelihoods import MultitaskGaussianLikelihood
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.variational import VariationalStrategy, IndependentMultitaskVariationalStrategy
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.distributions import MultivariateNormal
from gpytorch.means import ConstantMean
from gpytorch.mlls import VariationalELBO
from gpytorch.lazy import DiagLazyTensor

torch.set_default_dtype(torch.float64)


np.random.seed(0)

k = 2
d = 3
n = 850


X0 = np.random.rand(n, d)
A = np.random.rand(d, k)
Y0 = np.sin(X0).dot(A) + 1e-3*(np.random.rand(n, k) - 0.5)  # just something
M = 50

class SparseMultivariate(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        batch_shape = torch.Size([k])
        var_distrib = CholeskyVariationalDistribution(
            inducing_points.size(0), batch_shape)
        var_strategy = IndependentMultitaskVariationalStrategy(
            VariationalStrategy(self,
                                inducing_points,
                                var_distrib,
                                learn_inducing_locations=True),
            num_tasks=k)
        super(SparseMultivariate, self).__init__(var_strategy)
        self.mean_module = ConstantMean(batch_shape=batch_shape)
        self.covar_module = ScaleKernel(
            RBFKernel(ard_num_dims=d,
                      batch_shape=batch_shape),
            batch_shape=batch_shape)

    @property
    def Z(self):
        Z = self.variational_strategy.base_variational_strategy.inducing_points
        return Z.squeeze()

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        mvn = MultivariateNormal(mean_x, covar_x)
        return mvn

inducing_points = torch.rand(M, d, dtype=torch.float64)
model = SparseMultivariate(inducing_points)

likelihood = MultitaskGaussianLikelihood(num_tasks=k)

optimizer = torch.optim.Adam([
    {'params': model.parameters()},
    {'params': likelihood.parameters()}], lr=0.1)

mll = VariationalELBO(likelihood, model, num_data=n)


X, Y = torch.tensor(X0), torch.tensor(Y0)
for i in range(100):
    optimizer.zero_grad()
    output = model(X)
    loss = -mll(output, Y)
    loss.backward()
    optimizer.step()

# code to compute the moment-matching quantities
eye_full = torch.eye(n).repeat(k, 1, 1)

Kmm = model.covar_module(model.Z)
Kmn = model.covar_module(model.Z, X)
Knn = model.covar_module(X)

noise = likelihood.noise_covar.noise
# Q = Knm Kmm^{-1} Kmn, the Nystroem approximation of Knn
Q = Kmm.inv_matmul(Kmn.evaluate(), left_tensor=Kmn.transpose(-1, -2).evaluate())
# G = diag(Knn - Q) + noise * I
G = DiagLazyTensor((Knn - Q).diag()) + noise[:, None, None] * eye_full
# B = Kmm + Kmn G^{-1} Knm
B = Kmm + G.inv_matmul(Kmn.transpose(-1, -2).evaluate(),
                       left_tensor=Kmn.evaluate())

# beta = B^{-1} Kmn G^{-1} y
beta = G.inv_matmul(Y.T[:, :, None],
                    left_tensor=B.inv_matmul(Kmn.evaluate()))
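
In the notation of the reference cited below, this computes $Q = K_{nm} K_{mm}^{-1} K_{mn}$, $G = \mathrm{diag}(K_{nn} - Q) + \sigma_n^2 I$, $B = K_{mm} + K_{mn} G^{-1} K_{nm}$, and $\beta = B^{-1} K_{mn} G^{-1} y$, where $K_{mn}$ is the kernel matrix between the $M$ inducing points and the $n$ training inputs.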

It works for n = 800 but not for n = 900.

**Stack trace/error message**

Traceback (most recent call last):
  File "/home/valou/workspace/pythonProjects/.venvs-python37/venv-gym-fb/lib/python3.7/site-packages/gpytorch/utils/cholesky.py", line 27, in psd_safe_cholesky
    L = torch.cholesky(A, upper=upper, out=out)
RuntimeError: cholesky_cpu: For batch 0: U(13,13) is zero, singular U.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "gpytorch_issue.py", line 95, in <module>
    left_tensor=B.inv_matmul(Kmn.evaluate()))
  File "/home/valou/workspace/pythonProjects/.venvs-python37/venv-gym-fb/lib/python3.7/site-packages/gpytorch/lazy/lazy_tensor.py", line 963, in inv_matmul
    return func.apply(self.representation_tree(), False, right_tensor, *self.representation())
  File "/home/valou/workspace/pythonProjects/.venvs-python37/venv-gym-fb/lib/python3.7/site-packages/gpytorch/functions/_inv_matmul.py", line 51, in forward
    solves = _solve(lazy_tsr, right_tensor)
  File "/home/valou/workspace/pythonProjects/.venvs-python37/venv-gym-fb/lib/python3.7/site-packages/gpytorch/functions/_inv_matmul.py", line 15, in _solve
    return lazy_tsr.cholesky()._cholesky_solve(rhs)
  File "/home/valou/workspace/pythonProjects/.venvs-python37/venv-gym-fb/lib/python3.7/site-packages/gpytorch/lazy/lazy_tensor.py", line 750, in cholesky
    chol = self._cholesky(upper=False)
  File "/home/valou/workspace/pythonProjects/.venvs-python37/venv-gym-fb/lib/python3.7/site-packages/gpytorch/utils/memoize.py", line 59, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
  File "/home/valou/workspace/pythonProjects/.venvs-python37/venv-gym-fb/lib/python3.7/site-packages/gpytorch/lazy/lazy_tensor.py", line 419, in _cholesky
    cholesky = psd_safe_cholesky(evaluated_mat, jitter=settings.cholesky_jitter.value(), upper=upper).contiguous()
  File "/home/valou/workspace/pythonProjects/.venvs-python37/venv-gym-fb/lib/python3.7/site-packages/gpytorch/utils/cholesky.py", line 51, in psd_safe_cholesky
    f"Matrix not positive definite after repeatedly adding jitter up to {jitter_new:.1e}. "
gpytorch.utils.errors.NotPSDError: Matrix not positive definite after repeatedly adding jitter up to 1.0e-06. Original error on first attempt: cholesky_cpu: For batch 0: U(13,13) is zero, singular U.

Expected Behavior

The computation should succeed for any n. It does for low values of n, but when n is too high the matrix B becomes numerically singular.

System information

Please complete the following information:
GPyTorch version: 1.3.1
PyTorch version: 1.7.0

OS (output of lsb_release -a):
Distributor ID: Debian
Description: Debian GNU/Linux 9.13 (stretch)
Release: 9.13
Codename: stretch

Additional context

In the RL context, we should be able to compute the predictions as $n \rightarrow \infty$.

Reference for the moment-matching prediction: Deisenroth, M. P. (2010). Efficient Reinforcement Learning using Gaussian Processes, chapter 2.4.

@vcharvet vcharvet added the bug label Feb 12, 2021
@wjmaddox (Collaborator)

This can be resolved by raising the cholesky_jitter setting around the last line. This works for me:

with gpytorch.settings.cholesky_jitter(1e-1):
    beta = G.inv_matmul(Y.T[:, :, None],
                        left_tensor=B.inv_matmul(Kmn.evaluate()))

Numerically, what seems to be happening is that the matrix B is becoming increasingly ill-conditioned: on your dataset, the smallest eigenvalues from torch.symeig(B.evaluate()) are something like -0.3.
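
If you want to verify this on your end, here is a minimal diagnostic sketch (torch.symeig matches the PyTorch 1.7.0 in your environment; newer releases deprecate it in favor of torch.linalg.eigvalsh):

with torch.no_grad():
    # eigenvalues of each task's B; a negative minimum means B is not
    # numerically positive definite, so its Cholesky factorization must fail
    evals, _ = torch.symeig(B.evaluate())
    print(evals.min().item())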

A more numerically stable implementation exploits the fact that G is diagonal (it is the sum of two diagonal matrices), building it directly as a single DiagLazyTensor:

G = DiagLazyTensor((Knn - Q).diag() + noise.unsqueeze(-1))
B = Kmm + G.inv_matmul(Kmn.transpose(-1, -2).evaluate(),
                       left_tensor=Kmn.evaluate())

beta = G.inv_matmul(Y.T[:, :, None],
                    left_tensor=B.inv_matmul(Kmn.evaluate()))

The second fix is probably the one you should use in this setting.
With it, B now has only positive eigenvalues. What was happening before is that gpytorch doesn't pick up that G is diagonal (how could it, given the sum of a lazy diagonal and a dense term), and so it falls back to running CG once n > 800.
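
To make the structural difference concrete, here is a small sketch (the printed class name of the lazy sum is illustrative and may vary across gpytorch versions):

# The original G is a generic lazy sum, so inv_matmul falls back to
# iterative CG solves for large n; the fixed G is a single DiagLazyTensor,
# whose solve is exact and elementwise.
G_sum = DiagLazyTensor((Knn - Q).diag()) + noise[:, None, None] * eye_full
G_diag = DiagLazyTensor((Knn - Q).diag() + noise.unsqueeze(-1))
print(type(G_sum).__name__)   # a generic lazy sum, not recognized as diagonal
print(type(G_diag).__name__)  # DiagLazyTensor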

@vcharvet (Author)

vcharvet commented Feb 13, 2021

Hi, thank you very much for your quick reply.
The second fix works well, both on the example and integrated into the RL setting.
