Hello,
I'm using a sparse multitask GP to learn a dynamical model in a reinforcement learning problem. I'm then using the model to compute moment-matching (MM) predictions at uncertain inputs.
It works well up to a certain number of points, but beyond that I get the following error:
```
Traceback (most recent call last):
  File "/home/valou/workspace/pythonProjects/.venvs-python37/venv-gym-fb/lib/python3.7/site-packages/gpytorch/utils/cholesky.py", line 27, in psd_safe_cholesky
    L = torch.cholesky(A, upper=upper, out=out)
RuntimeError: cholesky_cpu: For batch 0: U(13,13) is zero, singular U.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "gpytorch_issue.py", line 95, in <module>
    left_tensor=B.inv_matmul(Kmn.evaluate()))
  File "/home/valou/workspace/pythonProjects/.venvs-python37/venv-gym-fb/lib/python3.7/site-packages/gpytorch/lazy/lazy_tensor.py", line 963, in inv_matmul
    return func.apply(self.representation_tree(), False, right_tensor, *self.representation())
  File "/home/valou/workspace/pythonProjects/.venvs-python37/venv-gym-fb/lib/python3.7/site-packages/gpytorch/functions/_inv_matmul.py", line 51, in forward
    solves = _solve(lazy_tsr, right_tensor)
  File "/home/valou/workspace/pythonProjects/.venvs-python37/venv-gym-fb/lib/python3.7/site-packages/gpytorch/functions/_inv_matmul.py", line 15, in _solve
    return lazy_tsr.cholesky()._cholesky_solve(rhs)
  File "/home/valou/workspace/pythonProjects/.venvs-python37/venv-gym-fb/lib/python3.7/site-packages/gpytorch/lazy/lazy_tensor.py", line 750, in cholesky
    chol = self._cholesky(upper=False)
  File "/home/valou/workspace/pythonProjects/.venvs-python37/venv-gym-fb/lib/python3.7/site-packages/gpytorch/utils/memoize.py", line 59, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
  File "/home/valou/workspace/pythonProjects/.venvs-python37/venv-gym-fb/lib/python3.7/site-packages/gpytorch/lazy/lazy_tensor.py", line 419, in _cholesky
    cholesky = psd_safe_cholesky(evaluated_mat, jitter=settings.cholesky_jitter.value(), upper=upper).contiguous()
  File "/home/valou/workspace/pythonProjects/.venvs-python37/venv-gym-fb/lib/python3.7/site-packages/gpytorch/utils/cholesky.py", line 51, in psd_safe_cholesky
    f"Matrix not positive definite after repeatedly adding jitter up to {jitter_new:.1e}. "
gpytorch.utils.errors.NotPSDError: Matrix not positive definite after repeatedly adding jitter up to 1.0e-06. Original error on first attempt: cholesky_cpu: For batch 0: U(13,13) is zero, singular U.
```
Expected Behavior
It works for low values of n, but if n is too high, the matrix B becomes singular.
System information
GPyTorch version: 1.3.1
PyTorch version: 1.7.0
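This can be resolved by raising `cholesky_jitter` around the last line of your snippet. This works for me:

```python
import gpytorch
# Raise the diagonal jitter psd_safe_cholesky adds (it gave up at 1e-6 in the traceback above)
with gpytorch.settings.cholesky_jitter(1e-1):
    beta = G.inv_matmul(Y.T[:, :, None],
                        left_tensor=B.inv_matmul(Kmn.evaluate()))
```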
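The idea behind a Cholesky jitter can be sketched in NumPy: try a plain factorization first and, if it fails, retry with a growing multiple of the identity added to the diagonal. This is a simplified sketch, not GPyTorch's code: `safe_cholesky` is a hypothetical helper, and the factor-of-10 growth over three retries is an assumption (chosen so that, starting from 1e-8, the last attempt uses 1e-6, like the error message above).

```python
import numpy as np


def safe_cholesky(A, jitter=1e-8, max_tries=3):
    """Return (L, jitter_used) with L @ L.T == A + jitter_used * I.

    Hypothetical helper: try a plain Cholesky first, then retry with
    jitter, 10 * jitter, 100 * jitter, ... added to the diagonal.
    """
    try:
        return np.linalg.cholesky(A), 0.0
    except np.linalg.LinAlgError:
        pass
    eye = np.eye(A.shape[-1])
    for i in range(max_tries):
        jitter_new = jitter * 10 ** i
        try:
            return np.linalg.cholesky(A + jitter_new * eye), jitter_new
        except np.linalg.LinAlgError:
            continue
    raise np.linalg.LinAlgError(
        f"Matrix not positive definite after adding jitter up to {jitter_new:.1e}"
    )


# Toy symmetric matrix whose smallest eigenvalue is -0.3 (illustrative values)
B = np.diag([2.0, 1.0, -0.3])

try:
    safe_cholesky(B)  # jitter 1e-8, 1e-7, 1e-6: all far too small
except np.linalg.LinAlgError as err:
    print(err)

L, used = safe_cholesky(B, jitter=1e-1)  # 0.1 is still too small, 1.0 succeeds
print(used)
```

In this toy case the retry only succeeds once 1.0 has been added to the diagonal, so the matrix actually being factorized is noticeably different from B; a large jitter hides the problem rather than removing its cause.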
Numerically, what seems to be happening is that the matrix B becomes increasingly ill-conditioned as n grows: on your dataset, the smallest eigenvalues returned by `torch.symeig(B.evaluate())` are around -0.3. B is then not even positive semi-definite, which a jitter of 1e-6 cannot fix.
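The eigenvalue check can be reproduced outside GPyTorch. In the NumPy sketch below, `np.linalg.eigvalsh` plays the role of `torch.symeig`; it is applied to a unit-variance squared-exponential kernel matrix on 1-D grids of increasing density. The kernel, lengthscale, and grids are assumptions chosen for illustration, and `se_gram` / `smallest_eigenvalue` are hypothetical helpers. The point is only that adding nearby points pushes a kernel matrix towards singularity, where floating-point eigenvalues can even come out slightly negative.

```python
import numpy as np


def se_gram(x, lengthscale=0.5):
    """Unit-variance squared-exponential kernel matrix for 1-D inputs x (assumed kernel)."""
    d = x[:, None] - x[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)


def smallest_eigenvalue(A):
    """Smallest eigenvalue of the symmetric part of A (eigvalsh returns them in ascending order)."""
    return np.linalg.eigvalsh(0.5 * (A + A.T))[0]


# The 5-point grid is (up to rounding) contained in the 41-point grid, so by
# eigenvalue interlacing the denser matrix cannot have a larger smallest eigenvalue.
for n in (5, 41):
    K = se_gram(np.linspace(0.0, 1.0, n))
    print(n, smallest_eigenvalue(K))
```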
A more numerically stable implementation exploits G being a diagonal matrix (the sum of two diagonals):
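```python
from gpytorch.lazy import DiagLazyTensor
# G = diag(Knn - Q) + noise is diagonal; wrapping it in DiagLazyTensor makes that structure explicit
G = DiagLazyTensor((Knn - Q).diag() + noise.unsqueeze(-1))
# B = Kmm + Kmn G^{-1} Knm
B = Kmm + G.inv_matmul(Kmn.transpose(-1, -2).evaluate(),
                       left_tensor=Kmn.evaluate())
# beta = B^{-1} Kmn G^{-1} Y
beta = G.inv_matmul(Y.T[:, :, None],
                    left_tensor=B.inv_matmul(Kmn.evaluate()))
```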
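The benefit of keeping G diagonal can be sketched in NumPy for a single output: once G is stored as a vector g, every G^{-1} in B = Kmm + Kmn G^{-1} Knm and beta = B^{-1} Kmn G^{-1} y is an elementwise division, so the only linear solve left is the small m x m one with B. This is a minimal sketch, not a transcription of the GPyTorch code above. It assumes 1-D inputs, a unit-variance SE kernel with made-up hyperparameters, and Q = Knm Kmm^{-1} Kmn (the FITC-style Nyström term); `se_kernel` and `sparse_weights` are hypothetical helpers.

```python
import numpy as np


def se_kernel(A, B, lengthscale=0.3):
    """Unit-variance squared-exponential kernel between the rows of A and B (assumed kernel)."""
    sq = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-0.5 * sq / lengthscale ** 2)


def sparse_weights(X, y, Z, noise=0.1, jitter=1e-8):
    """Return (B, beta, g) for B = Kmm + Kmn G^{-1} Knm and beta = B^{-1} Kmn G^{-1} y.

    G = diag(Knn - Q) + noise is stored as the vector g, assuming Q = Knm Kmm^{-1} Kmn.
    """
    Kmm = se_kernel(Z, Z) + jitter * np.eye(len(Z))
    Kmn = se_kernel(Z, X)
    # diag(Q) = column-wise squared norms of V = L^{-1} Kmn, where Kmm = L L^T
    V = np.linalg.solve(np.linalg.cholesky(Kmm), Kmn)
    g = 1.0 - (V ** 2).sum(axis=0) + noise  # diag(Knn) = 1 for this kernel
    # Each G^{-1} is an elementwise division by g: no n x n solve is needed
    B = Kmm + (Kmn / g) @ Kmn.T
    beta = np.linalg.solve(B, Kmn @ (y / g))
    return B, beta, g


X = np.linspace(0.0, 1.0, 900)[:, None]  # n = 900 training inputs
y = np.sin(6.0 * X[:, 0])
Z = np.linspace(0.0, 1.0, 5)[:, None]    # m = 5 inducing inputs
B, beta, g = sparse_weights(X, y, Z)
print("smallest eigenvalue of B:", np.linalg.eigvalsh(B)[0])
```

In this toy setting the division-based B matches the one obtained with a dense n x n solve against diag(g), and it stays positive definite at n = 900.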
The second fix is probably the one you should use in this setting.
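With this version, B (somewhat surprisingly) has only positive eigenvalues. What's happening in the original code is that GPyTorch doesn't pick up that G is diagonal (how could it, given that G is built as a sum?), so once n > 800 it solves with G using conjugate gradients (CG) instead of a Cholesky factorization. That threshold is the default value of `gpytorch.settings.max_cholesky_size` (800), which matches the n=800 / n=900 boundary in the report.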
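One plausible reading of this explanation is that, for n > 800, the solves with G are only approximate and those errors end up in B. The NumPy sketch below illustrates the general point with a plain conjugate-gradient loop capped at a small, fixed number of iterations on a diagonal but badly scaled G. The spectrum and the 3-iteration budget are assumptions chosen for illustration and do not reflect GPyTorch's actual CG settings or tolerances; `conjugate_gradient` is a hypothetical helper. For a diagonal G, the exact solve is just an elementwise division.

```python
import numpy as np


def conjugate_gradient(matvec, b, max_iter, tol=1e-12):
    """Plain conjugate gradients for an SPD operator, stopped after max_iter steps."""
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    for _ in range(max_iter):
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


# Illustrative diagonal G whose entries span six orders of magnitude
g = np.logspace(-6, 0, 200)
b = np.ones_like(g)

x_exact = b / g  # exact solve with a diagonal matrix: elementwise division
x_cg = conjugate_gradient(lambda v: g * v, b, max_iter=3)  # truncated CG

rel_err = np.linalg.norm(x_cg - x_exact) / np.linalg.norm(x_exact)
print(f"relative error after 3 CG steps: {rel_err:.2e}")
```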
To reproduce
It works for n=800 but not for n=900.
OS:
```
$ lsb_release -a
Distributor ID: Debian
Description:    Debian GNU/Linux 9.13 (stretch)
Release:        9.13
Codename:       stretch
```
Additional context
In the RL context, we should be able to compute the predictions as $n \rightarrow \infty$.
Reference for MM prediction: Deisenroth, M. P. (2010). Efficient Reinforcement Learning using Gaussian Processes, Section 2.4.
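For context, with a squared-exponential kernel with ARD lengthscales, the MM predictive mean at an uncertain input x ~ N(mu, Sigma) has the closed form m = sum_i beta_i q_i, with q_i = sf2 |Sigma Lambda^{-1} + I|^{-1/2} exp(-1/2 (x_i - mu)^T (Sigma + Lambda)^{-1} (x_i - mu)) and Lambda = diag(lengthscales^2), which is how a weight vector like the beta computed above gets used. The NumPy sketch below assumes a single output, made-up inputs and weights, and a hypothetical function name `mm_mean`; it covers only the mean, not the predictive variance.

```python
import numpy as np


def mm_mean(beta, X, mu, Sigma, lengthscales, outputscale=1.0):
    """Moment-matching predictive mean of a GP with an SE-ARD kernel at x ~ N(mu, Sigma).

    m = sum_i beta_i q_i, with
    q_i = outputscale * |Sigma Lambda^{-1} + I|^{-1/2}
          * exp(-0.5 (x_i - mu)^T (Sigma + Lambda)^{-1} (x_i - mu)),
    where Lambda = diag(lengthscales**2) and the x_i are the rows of X.
    """
    D = X.shape[1]
    Lam = np.diag(lengthscales ** 2)
    diff = X - mu                                    # rows x_i - mu, shape (n, D)
    solved = np.linalg.solve(Sigma + Lam, diff.T).T  # rows (Sigma + Lambda)^{-1} (x_i - mu)
    maha = np.einsum("ij,ij->i", diff, solved)       # one quadratic form per row
    det = np.linalg.det(Sigma @ np.linalg.inv(Lam) + np.eye(D))
    q = outputscale / np.sqrt(det) * np.exp(-0.5 * maha)
    return q @ beta


# Usage with made-up values: 2-D inputs, 6 points, an uncertain test input
rng = np.random.default_rng(0)
X = rng.normal(size=(6, 2))
beta = rng.normal(size=6)
mu = np.array([0.1, -0.2])
Sigma = np.array([[0.2, 0.05], [0.05, 0.1]])
print(mm_mean(beta, X, mu, Sigma, lengthscales=np.array([0.7, 1.3])))
```

With Sigma = 0 the formula reduces to the usual GP mean sum_i beta_i k(x_i, mu), which is a convenient sanity check.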