
Commit 6b69b5d

gpytorch.settings.variational_cholesky_jitter can be set dynamically.
Previously, this context manager was only consulted when VariationalStrategy modules were initialized. With this change, gpytorch.settings.variational_cholesky_jitter dynamically changes the jitter value (even for variational models already in use), unless the user specifies a `jitter_val` in the VariationalStrategy constructor. In addition, this PR adds type hints to most of the VariationalStrategy modules. [Fixes #2244]
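For illustration, a minimal sketch of the new behavior. The model definition below follows the standard GPyTorch SVGP pattern and is not part of this commit; names such as `SVGPModel` and `test_x` are placeholders:

    import torch
    import gpytorch

    class SVGPModel(gpytorch.models.ApproximateGP):
        def __init__(self, inducing_points):
            variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(inducing_points.size(-2))
            # No jitter_val passed, so the jitter now tracks the global setting dynamically
            variational_strategy = gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            )
            super().__init__(variational_strategy)
            self.mean_module = gpytorch.means.ConstantMean()
            self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

        def forward(self, x):
            return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

    model = SVGPModel(torch.randn(16, 1))
    test_x = torch.randn(8, 1)

    with gpytorch.settings.variational_cholesky_jitter(1e-3):
        model(test_x)  # Cholesky factorizations inside the strategy now add 1e-3 jitter

    model(test_x)  # outside the context manager, the default jitter for the dtype applies again

    # A strategy constructed with an explicit jitter_val (e.g. jitter_val=1e-5) keeps that value
    # and ignores the context manager.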
1 parent 53c2c62 commit 6b69b5d

8 files changed: +221 -114 lines changed

gpytorch/variational/_variational_strategy.py

+52 -30
@@ -3,15 +3,19 @@
 import functools
 from abc import ABC, abstractproperty
 from copy import deepcopy
+from typing import Optional, Tuple

 import torch
+from linear_operator.operators import LinearOperator
+from torch import Tensor

 from .. import settings
 from ..distributions import Delta, MultivariateNormal
 from ..likelihoods import GaussianLikelihood
-from ..models import ExactGP
+from ..models import ApproximateGP, ExactGP
 from ..module import Module
 from ..utils.memoize import add_to_cache, cached, clear_cache_hook
+from . import _VariationalDistribution


 class _BaseExactGP(ExactGP):
@@ -42,14 +46,16 @@ class _VariationalStrategy(Module, ABC):
     has_fantasy_strategy = False

     def __init__(
-        self, model, inducing_points, variational_distribution, learn_inducing_locations=True, jitter_val=None
+        self,
+        model: ApproximateGP,
+        inducing_points: Tensor,
+        variational_distribution: _VariationalDistribution,
+        learn_inducing_locations: bool = True,
+        jitter_val: Optional[float] = None,
     ):
         super().__init__()

-        if jitter_val is None:
-            self.jitter_val = settings.variational_cholesky_jitter.value(inducing_points.dtype)
-        else:
-            self.jitter_val = jitter_val
+        self._jitter_val = jitter_val

         # Model
         object.__setattr__(self, "model", model)
@@ -70,7 +76,7 @@ def __init__(
     def _clear_cache(self):
         clear_cache_hook(self)

-    def _expand_inputs(self, x, inducing_points):
+    def _expand_inputs(self, x: Tensor, inducing_points: Tensor) -> Tuple[Tensor, Tensor]:
         """
         Pre-processing step in __call__ to make x the same batch_shape as the inducing points
         """
@@ -79,9 +85,19 @@ def _expand_inputs(self, x, inducing_points):
         x = x.expand(*batch_shape, *x.shape[-2:])
         return x, inducing_points

+    @property
+    def jitter_val(self) -> float:
+        if self._jitter_val is None:
+            return settings.variational_cholesky_jitter.value(dtype=self.inducing_points.dtype)
+        return self._jitter_val
+
+    @jitter_val.setter
+    def jitter_val(self, jitter_val: float):
+        self._jitter_val = jitter_val
+
     @abstractproperty
     @cached(name="prior_distribution_memo")
-    def prior_distribution(self):
+    def prior_distribution(self) -> MultivariateNormal:
         r"""
         The :func:`~gpytorch.variational.VariationalStrategy.prior_distribution` method determines how to compute the
         GP prior distribution of the inducing points, e.g. :math:`p(u) \sim N(\mu(X_u), K(X_u, X_u))`. Most commonly,
@@ -94,22 +110,29 @@ def prior_distribution(self):

     @property
     @cached(name="variational_distribution_memo")
-    def variational_distribution(self):
+    def variational_distribution(self) -> MultivariateNormal:
         return self._variational_distribution()

-    def forward(self, x, inducing_points, inducing_values, variational_inducing_covar=None, **kwargs):
+    def forward(
+        self,
+        x: Tensor,
+        inducing_points: Tensor,
+        inducing_values: Tensor,
+        variational_inducing_covar: Optional[LinearOperator] = None,
+        **kwargs,
+    ) -> MultivariateNormal:
         r"""
         The :func:`~gpytorch.variational.VariationalStrategy.forward` method determines how to marginalize out the
         inducing point function values. Specifically, forward defines how to transform a variational distribution
         over the inducing point values, :math:`q(u)`, in to a variational distribution over the function values at
         specified locations x, :math:`q(f|x)`, by integrating :math:`\int p(f|x, u)q(u)du`

-        :param torch.Tensor x: Locations :math:`\mathbf X` to get the
+        :param x: Locations :math:`\mathbf X` to get the
            variational posterior of the function values at.
-        :param torch.Tensor inducing_points: Locations :math:`\mathbf Z` of the inducing points
-        :param torch.Tensor inducing_values: Samples of the inducing function values :math:`\mathbf u`
+        :param inducing_points: Locations :math:`\mathbf Z` of the inducing points
+        :param inducing_values: Samples of the inducing function values :math:`\mathbf u`
            (or the mean of the distribution :math:`q(\mathbf u)` if q is a Gaussian.
-        :param ~linear_operator.operators.LinearOperator variational_inducing_covar: If
+        :param variational_inducing_covar: If
            the distribuiton :math:`q(\mathbf u)` is
            Gaussian, then this variable is the covariance matrix of that Gaussian.
            Otherwise, it will be None.
@@ -119,19 +142,19 @@ def forward(self, x, inducing_points, inducing_values, variational_inducing_cova
         """
         raise NotImplementedError

-    def kl_divergence(self):
+    def kl_divergence(self) -> Tensor:
         r"""
         Compute the KL divergence between the variational inducing distribution :math:`q(\mathbf u)`
         and the prior inducing distribution :math:`p(\mathbf u)`.
-
-        :rtype: torch.Tensor
         """
         with settings.max_preconditioner_size(0):
             kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
         return kl_divergence

     @cached(name="amortized_exact_gp")
-    def amortized_exact_gp(self, mean_module=None, covar_module=None):
+    def amortized_exact_gp(
+        self, mean_module: Optional[Module] = None, covar_module: Optional[Module] = None
+    ) -> ExactGP:
         mean_module = self.model.mean_module if mean_module is None else mean_module
         covar_module = self.model.covar_module if covar_module is None else covar_module

@@ -186,17 +209,17 @@ def amortized_exact_gp(self, mean_module=None, covar_module=None):
         inducing_exact_model.prediction_strategy = pred_strat
         return inducing_exact_model

-    def pseudo_points(self):
+    def pseudo_points(self) -> Tuple[Tensor, Tensor]:
         raise NotImplementedError("Each variational strategy must implement its own pseudo points method")

     def get_fantasy_model(
         self,
-        inputs,
-        targets,
-        mean_module=None,
-        covar_module=None,
+        inputs: Tensor,
+        targets: Tensor,
+        mean_module: Optional[Module] = None,
+        covar_module: Optional[Module] = None,
         **kwargs,
-    ):
+    ) -> ExactGP:
         r"""
         Performs the online variational conditioning (OVC) strategy of Maddox et al, '21 to return
         an exact GP model that incorporates the inputs and targets alongside the variational model's inducing
@@ -211,17 +234,16 @@ def get_fantasy_model(
         modules are attributes of the model itself called mean_module and covar_module respectively OR that you
         pass them into this method explicitly.

-        :param torch.Tensor inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
+        :param inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
            observations.
-        :param torch.Tensor targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
-        :param torch.nn.Module mean_module: torch module describing the mean function of the GP model. Optional if
+        :param targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
+        :param mean_module: torch module describing the mean function of the GP model. Optional if
            `mean_module` is already an attribute of the variational GP.
-        :param torch.nn.Module covar_module: torch module describing the covariance function of the GP model. Optional
+        :param covar_module: torch module describing the covariance function of the GP model. Optional
            if `covar_module` is already an attribute of the variational GP.
         :return: An `ExactGP` model with `k + m` training examples, where the `m` fantasy examples have been added
            and all test-time caches have been updated. We assume that there are `k` inducing points in this variational
            GP. Note that we return an `ExactGP` rather than a variational GP.
-        :rtype: ~gpytorch.models.ExactGP

         Reference: "Conditioning Sparse Variational Gaussian Processes for Online Decision-Making,"
            Maddox, Stanton, Wilson, NeurIPS, '21
@@ -282,7 +304,7 @@ def get_fantasy_model(
         fantasy_model.prediction_strategy = fant_pred_strat
         return fantasy_model

-    def __call__(self, x, prior=False, **kwargs):
+    def __call__(self, x: Tensor, prior: bool = False, **kwargs) -> MultivariateNormal:
         # If we're in prior mode, then we're done!
         if prior:
             return self.model.forward(x, **kwargs)
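
In short, the strategy above no longer freezes the jitter at construction time: the raw constructor argument is stored in `_jitter_val`, and the `jitter_val` property falls back to `gpytorch.settings.variational_cholesky_jitter` whenever `_jitter_val` is None, which is what lets the context manager affect models that already exist. Because the property also has a setter, an existing strategy can likewise be pinned to a fixed value after construction (a one-line sketch, assuming the `model` from the example above):

    model.variational_strategy.jitter_val = 1e-4  # pins the jitter; the global setting no longer applies to this strategy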

gpytorch/variational/batch_decoupled_variational_strategy.py

+26 -15
@@ -1,12 +1,17 @@
 #!/usr/bin/env python3

+from typing import Optional, Tuple
+
 import torch
-from linear_operator.operators import MatmulLinearOperator, SumLinearOperator
+from linear_operator.operators import LinearOperator, MatmulLinearOperator, SumLinearOperator
+from torch import Tensor
 from torch.distributions.kl import kl_divergence

 from ..distributions import Delta, MultivariateNormal
+from ..models import ApproximateGP
 from ..utils.errors import CachingError
 from ..utils.memoize import pop_from_cache_ignore_args
+from ._variational_distribution import _VariationalDistribution
 from .delta_variational_distribution import DeltaVariationalDistribution
 from .variational_strategy import VariationalStrategy

@@ -58,21 +63,20 @@ class BatchDecoupledVariationalStrategy(VariationalStrategy):
     :obj:`~gpytorch.variational.OrthogonallyDecoupledVariationalStrategy` (a variant proposed by
     `Salimbeni et al. (2018)`_ that uses orthogonal projections.)

-    :param ~gpytorch.models.ApproximateGP model: Model this strategy is applied to.
+    :param model: Model this strategy is applied to.
        Typically passed in when the VariationalStrategy is created in the
        __init__ method of the user defined model.
-    :param torch.Tensor inducing_points: Tensor containing a set of inducing
+    :param inducing_points: Tensor containing a set of inducing
        points to use for variational inference.
-    :param ~gpytorch.variational.VariationalDistribution variational_distribution: A
+    :param variational_distribution: A
        VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
     :param learn_inducing_locations: (Default True): Whether or not
        the inducing point locations :math:`\mathbf Z` should be learned (i.e. are they
        parameters of the model).
-    :type learn_inducing_locations: `bool`, optional
-    :type mean_var_batch_dim: `int`, optional
     :param mean_var_batch_dim: (Default `None`):
        Set this parameter (ideally to `-1`) to indicate which dimension corresponds to different
        kernel hyperparameters for the mean/variance functions.
+    :param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability

     .. _Cheng et al. (2017):
        https://arxiv.org/abs/1711.10127
@@ -133,12 +137,12 @@ class BatchDecoupledVariationalStrategy(VariationalStrategy):

     def __init__(
         self,
-        model,
-        inducing_points,
-        variational_distribution,
-        learn_inducing_locations=True,
-        mean_var_batch_dim=None,
-        jitter_val=None,
+        model: ApproximateGP,
+        inducing_points: Tensor,
+        variational_distribution: _VariationalDistribution,
+        learn_inducing_locations: bool = True,
+        mean_var_batch_dim: Optional[int] = None,
+        jitter_val: Optional[float] = None,
     ):
         if isinstance(variational_distribution, DeltaVariationalDistribution):
             raise NotImplementedError(
@@ -163,15 +167,22 @@ def __init__(
             model, inducing_points, variational_distribution, learn_inducing_locations, jitter_val=jitter_val
         )

-    def _expand_inputs(self, x, inducing_points):
+    def _expand_inputs(self, x: Tensor, inducing_points: Tensor) -> Tuple[Tensor, Tensor]:
         # If we haven't explicitly marked a dimension as batch, add the corresponding batch dimension to the input
         if self.mean_var_batch_dim is None:
             x = x.unsqueeze(-3)
         else:
             x = x.unsqueeze(self.mean_var_batch_dim - 2)
         return super()._expand_inputs(x, inducing_points)

-    def forward(self, x, inducing_points, inducing_values, variational_inducing_covar=None, **kwargs):
+    def forward(
+        self,
+        x: Tensor,
+        inducing_points: Tensor,
+        inducing_values: Tensor,
+        variational_inducing_covar: Optional[LinearOperator] = None,
+        **kwargs,
+    ) -> MultivariateNormal:
         # We'll compute the covariance, and cross-covariance terms for both the
         # pred-mean and pred-covar, using their different inducing points (and maybe kernel hypers)

@@ -225,7 +236,7 @@ def forward(self, x, inducing_points, inducing_values, variational_inducing_cova

         return MultivariateNormal(predictive_mean, predictive_covar)

-    def kl_divergence(self):
+    def kl_divergence(self) -> Tensor:
         variational_dist = self.variational_distribution
         prior_dist = self.prior_distribution

gpytorch/variational/ciq_variational_strategy.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
 from linear_operator import to_linear_operator
 from linear_operator.operators import DiagLinearOperator, MatmulLinearOperator, SumLinearOperator
 from linear_operator.utils import linear_cg
+from torch import Tensor

 from .. import settings
 from ..distributions import Delta, MultivariateNormal
@@ -141,17 +142,17 @@ class CiqVariationalStrategy(_VariationalStrategy):
     :obj:`~gpytorch.variational.NaturalVariationalDistribution` and
     `natural gradient descent`_.

-    :param ~gpytorch.models.ApproximateGP model: Model this strategy is applied to.
+    :param model: Model this strategy is applied to.
        Typically passed in when the VariationalStrategy is created in the
        __init__ method of the user defined model.
-    :param torch.Tensor inducing_points: Tensor containing a set of inducing
+    :param inducing_points: Tensor containing a set of inducing
        points to use for variational inference.
-    :param ~gpytorch.variational.VariationalDistribution variational_distribution: A
+    :param variational_distribution: A
        VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
     :param learn_inducing_locations: (Default True): Whether or not
        the inducing point locations :math:`\mathbf Z` should be learned (i.e. are they
        parameters of the model).
-    :type learn_inducing_locations: `bool`, optional
+    :param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability

     .. _Pleiss et al. (2020):
        https://arxiv.org/pdf/2006.11267.pdf
@@ -161,12 +162,12 @@ class CiqVariationalStrategy(_VariationalStrategy):
        examples/04_Variational_and_Approximate_GPs/Natural_Gradient_Descent.html
     """

-    def _ngd(self):
+    def _ngd(self) -> bool:
         return isinstance(self._variational_distribution, NaturalVariationalDistribution)

     @property
     @cached(name="prior_distribution_memo")
-    def prior_distribution(self):
+    def prior_distribution(self) -> MultivariateNormal:
         zeros = torch.zeros(
             self._variational_distribution.shape(),
             dtype=self._variational_distribution.dtype,
@@ -178,7 +179,7 @@ def prior_distribution(self):

     @property
     @cached(name="variational_distribution_memo")
-    def variational_distribution(self):
+    def variational_distribution(self) -> MultivariateNormal:
         if self._ngd():
             raise RuntimeError(
                 "Variational distribution for NGD-CIQ should be computed during forward calls. "
@@ -253,8 +254,8 @@ def forward(
         # Return the distribution
         return MultivariateNormal(predictive_mean, predictive_covar)

-    def kl_divergence(self):
-        r"""
+    def kl_divergence(self) -> Tensor:
+        """
         Compute the KL divergence between the variational inducing distribution :math:`q(\mathbf u)`
         and the prior inducing distribution :math:`p(\mathbf u)`.
