@@ -3,15 +3,19 @@
 import functools
 from abc import ABC, abstractproperty
 from copy import deepcopy
+from typing import Optional, Tuple
 
 import torch
+from linear_operator.operators import LinearOperator
+from torch import Tensor
 
 from .. import settings
 from ..distributions import Delta, MultivariateNormal
 from ..likelihoods import GaussianLikelihood
-from ..models import ExactGP
+from ..models import ApproximateGP, ExactGP
 from ..module import Module
 from ..utils.memoize import add_to_cache, cached, clear_cache_hook
+from . import _VariationalDistribution
 
 
 class _BaseExactGP(ExactGP):
@@ -42,14 +46,16 @@ class _VariationalStrategy(Module, ABC):
     has_fantasy_strategy = False
 
     def __init__(
-        self, model, inducing_points, variational_distribution, learn_inducing_locations=True, jitter_val=None
+        self,
+        model: ApproximateGP,
+        inducing_points: Tensor,
+        variational_distribution: _VariationalDistribution,
+        learn_inducing_locations: bool = True,
+        jitter_val: Optional[float] = None,
     ):
         super().__init__()
 
-        if jitter_val is None:
-            self.jitter_val = settings.variational_cholesky_jitter.value(inducing_points.dtype)
-        else:
-            self.jitter_val = jitter_val
+        self._jitter_val = jitter_val
 
         # Model
         object.__setattr__(self, "model", model)
@@ -70,7 +76,7 @@ def __init__(
     def _clear_cache(self):
         clear_cache_hook(self)
 
-    def _expand_inputs(self, x, inducing_points):
+    def _expand_inputs(self, x: Tensor, inducing_points: Tensor) -> Tuple[Tensor, Tensor]:
         """
         Pre-processing step in __call__ to make x the same batch_shape as the inducing points
         """
@@ -79,9 +85,19 @@ def _expand_inputs(self, x, inducing_points):
         x = x.expand(*batch_shape, *x.shape[-2:])
         return x, inducing_points
 
+    @property
+    def jitter_val(self) -> float:
+        if self._jitter_val is None:
+            return settings.variational_cholesky_jitter.value(dtype=self.inducing_points.dtype)
+        return self._jitter_val
+
+    @jitter_val.setter
+    def jitter_val(self, jitter_val: float):
+        self._jitter_val = jitter_val
+
     @abstractproperty
     @cached(name="prior_distribution_memo")
-    def prior_distribution(self):
+    def prior_distribution(self) -> MultivariateNormal:
         r"""
         The :func:`~gpytorch.variational.VariationalStrategy.prior_distribution` method determines how to compute the
         GP prior distribution of the inducing points, e.g. :math:`p(u) \sim N(\mu(X_u), K(X_u, X_u))`. Most commonly,
@@ -94,22 +110,29 @@ def prior_distribution(self):
 
     @property
     @cached(name="variational_distribution_memo")
-    def variational_distribution(self):
+    def variational_distribution(self) -> MultivariateNormal:
         return self._variational_distribution()
 
-    def forward(self, x, inducing_points, inducing_values, variational_inducing_covar=None, **kwargs):
+    def forward(
+        self,
+        x: Tensor,
+        inducing_points: Tensor,
+        inducing_values: Tensor,
+        variational_inducing_covar: Optional[LinearOperator] = None,
+        **kwargs,
+    ) -> MultivariateNormal:
         r"""
         The :func:`~gpytorch.variational.VariationalStrategy.forward` method determines how to marginalize out the
         inducing point function values. Specifically, forward defines how to transform a variational distribution
         over the inducing point values, :math:`q(u)`, in to a variational distribution over the function values at
         specified locations x, :math:`q(f|x)`, by integrating :math:`\int p(f|x, u)q(u)du`
 
-        :param torch.Tensor x: Locations :math:`\mathbf X` to get the
+        :param x: Locations :math:`\mathbf X` to get the
             variational posterior of the function values at.
-        :param torch.Tensor inducing_points: Locations :math:`\mathbf Z` of the inducing points
-        :param torch.Tensor inducing_values: Samples of the inducing function values :math:`\mathbf u`
+        :param inducing_points: Locations :math:`\mathbf Z` of the inducing points
+        :param inducing_values: Samples of the inducing function values :math:`\mathbf u`
             (or the mean of the distribution :math:`q(\mathbf u)` if q is a Gaussian.
-        :param ~linear_operator.operators.LinearOperator variational_inducing_covar: If
+        :param variational_inducing_covar: If
             the distribuiton :math:`q(\mathbf u)` is
             Gaussian, then this variable is the covariance matrix of that Gaussian.
             Otherwise, it will be None.
@@ -119,19 +142,19 @@ def forward(self, x, inducing_points, inducing_values, variational_inducing_cova
         """
         raise NotImplementedError
 
-    def kl_divergence(self):
+    def kl_divergence(self) -> Tensor:
         r"""
         Compute the KL divergence between the variational inducing distribution :math:`q(\mathbf u)`
         and the prior inducing distribution :math:`p(\mathbf u)`.
-
-        :rtype: torch.Tensor
         """
         with settings.max_preconditioner_size(0):
             kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
         return kl_divergence
 
     @cached(name="amortized_exact_gp")
-    def amortized_exact_gp(self, mean_module=None, covar_module=None):
+    def amortized_exact_gp(
+        self, mean_module: Optional[Module] = None, covar_module: Optional[Module] = None
+    ) -> ExactGP:
         mean_module = self.model.mean_module if mean_module is None else mean_module
         covar_module = self.model.covar_module if covar_module is None else covar_module
 
@@ -186,17 +209,17 @@ def amortized_exact_gp(self, mean_module=None, covar_module=None):
         inducing_exact_model.prediction_strategy = pred_strat
         return inducing_exact_model
 
-    def pseudo_points(self):
+    def pseudo_points(self) -> Tuple[Tensor, Tensor]:
         raise NotImplementedError("Each variational strategy must implement its own pseudo points method")
 
     def get_fantasy_model(
         self,
-        inputs,
-        targets,
-        mean_module=None,
-        covar_module=None,
+        inputs: Tensor,
+        targets: Tensor,
+        mean_module: Optional[Module] = None,
+        covar_module: Optional[Module] = None,
         **kwargs,
-    ):
+    ) -> ExactGP:
         r"""
         Performs the online variational conditioning (OVC) strategy of Maddox et al, '21 to return
         an exact GP model that incorporates the inputs and targets alongside the variational model's inducing
@@ -211,17 +234,16 @@ def get_fantasy_model(
         modules are attributes of the model itself called mean_module and covar_module respectively OR that you
         pass them into this method explicitly.
 
-        :param torch.Tensor inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
+        :param inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
            observations.
-        :param torch.Tensor targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
-        :param torch.nn.Module mean_module: torch module describing the mean function of the GP model. Optional if
+        :param targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
+        :param mean_module: torch module describing the mean function of the GP model. Optional if
            `mean_module` is already an attribute of the variational GP.
-        :param torch.nn.Module covar_module: torch module describing the covariance function of the GP model. Optional
+        :param covar_module: torch module describing the covariance function of the GP model. Optional
            if `covar_module` is already an attribute of the variational GP.
         :return: An `ExactGP` model with `k + m` training examples, where the `m` fantasy examples have been added
            and all test-time caches have been updated. We assume that there are `k` inducing points in this variational
            GP. Note that we return an `ExactGP` rather than a variational GP.
-        :rtype: ~gpytorch.models.ExactGP
 
         Reference: "Conditioning Sparse Variational Gaussian Processes for Online Decision-Making,"
            Maddox, Stanton, Wilson, NeurIPS, '21
@@ -282,7 +304,7 @@ def get_fantasy_model(
         fantasy_model.prediction_strategy = fant_pred_strat
         return fantasy_model
 
-    def __call__(self, x, prior=False, **kwargs):
+    def __call__(self, x: Tensor, prior: bool = False, **kwargs) -> MultivariateNormal:
         # If we're in prior mode, then we're done!
         if prior:
             return self.model.forward(x, **kwargs)
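
Below is a minimal usage sketch, not part of the diff above, of what the jitter_val change means downstream: when jitter_val is not passed to the constructor, the new property resolves it lazily from settings.variational_cholesky_jitter for the inducing points' dtype, and the new setter still allows an explicit override after construction. The model, mean, and kernel choices are only illustrative.

# Sketch assuming the standard GPyTorch SVGP pattern; names below are illustrative.
import torch
import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy


class SVGPModel(ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = CholeskyVariationalDistribution(inducing_points.size(-2))
        variational_strategy = VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


model = SVGPModel(torch.randn(16, 1, dtype=torch.float64))
# jitter_val was not passed, so it is resolved lazily from the global setting
# for the inducing points' dtype (float64 here), rather than frozen at __init__ time.
print(model.variational_strategy.jitter_val)
# The new setter still permits an explicit override after construction.
model.variational_strategy.jitter_val = 1e-5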