Skip to content

Commit 5a0ff6b

Browse files
Merge pull request #1861 from dme65/master
Make Cholesky max_tries a setting
2 parents 307ef9c + 25cdcff commit 5a0ff6b

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

gpytorch/settings.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ class min_variance(_dtype_value_context):
437437

438438
class cholesky_jitter(_dtype_value_context):
439439
"""
440-
The jitter value passed to `psd_safe_cholesky` when using cholesky solves.
440+
The jitter value used by `psd_safe_cholesky` when using cholesky solves.
441441
442442
- Default for `float`: 1e-6
443443
- Default for `double`: 1e-8
@@ -458,6 +458,16 @@ def value(cls, dtype=None):
458458
return super().value(dtype=dtype)
459459

460460

461+
class cholesky_max_tries(_value_context):
462+
"""
463+
The max_tries value used by `psd_safe_cholesky` when using cholesky solves.
464+
465+
(Default: 3)
466+
"""
467+
468+
_global_value = 3
469+
470+
461471
class cg_tolerance(_value_context):
462472
"""
463473
Relative residual tolerance to use for terminating CG.

gpytorch/utils/cholesky.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .warnings import NumericalWarning
1010

1111

12-
def _psd_safe_cholesky(A, out=None, jitter=None, max_tries=3):
12+
def _psd_safe_cholesky(A, out=None, jitter=None, max_tries=None):
1313
# Maybe log
1414
if settings.verbose_linalg.on():
1515
settings.verbose_linalg.logger.debug(f"Running Cholesky on a matrix of size {A.shape}.")
@@ -27,6 +27,8 @@ def _psd_safe_cholesky(A, out=None, jitter=None, max_tries=3):
2727

2828
if jitter is None:
2929
jitter = settings.cholesky_jitter.value(A.dtype)
30+
if max_tries is None:
31+
max_tries = settings.cholesky_max_tries.value()
3032
Aprime = A.clone()
3133
jitter_prev = 0
3234
for i in range(max_tries):
@@ -45,7 +47,7 @@ def _psd_safe_cholesky(A, out=None, jitter=None, max_tries=3):
4547
raise NotPSDError(f"Matrix not positive definite after repeatedly adding jitter up to {jitter_new:.1e}.")
4648

4749

48-
def psd_safe_cholesky(A, upper=False, out=None, jitter=None, max_tries=3):
50+
def psd_safe_cholesky(A, upper=False, out=None, jitter=None, max_tries=None):
4951
"""Compute the Cholesky decomposition of A. If A is only p.s.d, add a small jitter to the diagonal.
5052
Args:
5153
:attr:`A` (Tensor):

0 commit comments

Comments
 (0)