Commit a10155c

Update typehints for Variational classes

1 parent cea9e6d, commit a10155c

13 files changed: +167 / -86 lines

.gitignore (+3)

@@ -20,6 +20,9 @@ __pycache__/
 # C extensions
 *.so
 
+# Type checking
+.pyre/
+
 # Distribution / packaging
 .Python
 env/

.pyre_configuration (+17)

@@ -0,0 +1,17 @@
+{
+    "site_package_search_strategy": "pep561",
+    "source_directories": [
+        {"import_root": ".", "source": "gpytorch/"}
+    ],
+    "search_path": [
+        ".",
+        "../linear_operator",
+        {"site-package": "faiss"},
+        {"site-package": "linear_operator"},
+        {"site-package": "pykeops"},
+        {"site-package": "pyro"},
+        {"site-package": "scipy"},
+        {"site-package": "sklearn"}
+    ],
+    "strict": true
+}
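With this configuration in place, running "pyre check" from the repository root performs a one-shot type check of the gpytorch/ sources in strict mode (assuming the pyre-check client is installed, e.g. via pip install pyre-check); the bare "pyre" command runs the incremental, server-backed check instead.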

gpytorch/module.py (+3, -2)

@@ -5,10 +5,11 @@
 import itertools
 import operator
 from collections import OrderedDict
+from typing import Union
 
 import torch
 from linear_operator.operators import LinearOperator
-from torch import nn
+from torch import nn, Tensor
 from torch.distributions import Distribution
 
 from .constraints import Interval
@@ -56,7 +57,7 @@ def added_loss_terms(self):
         for _, strategy in self.named_added_loss_terms():
             yield strategy
 
-    def forward(self, *inputs, **kwargs):
+    def forward(self, *inputs, **kwargs) -> Union[Tensor, Distribution, LinearOperator]:
         raise NotImplementedError
 
     def constraints(self):
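The broadened return annotation documents that Module.forward may return a plain Tensor, a Distribution, or a LinearOperator. A minimal sketch of a subclass satisfying the Tensor case (the ConstantShift class below is illustrative only and not part of this commit):

import torch
from torch import Tensor

from gpytorch.module import Module


class ConstantShift(Module):
    # Toy module whose forward returns a plain Tensor, one of the three
    # members of Union[Tensor, Distribution, LinearOperator].
    def __init__(self, shift: float):
        super().__init__()
        self.register_buffer("shift", torch.tensor(shift))

    def forward(self, x: Tensor) -> Tensor:
        return x + self.shift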

gpytorch/variational/_variational_distribution.py (+7, -6)

@@ -4,6 +4,7 @@
 
 import torch
 
+from ..distributions import Distribution, MultivariateNormal
 from ..module import Module
 
 
@@ -15,21 +16,21 @@ class _VariationalDistribution(Module, ABC):
     :ivar torch.dtype device: The device of the VariationalDistribution parameters
     """
 
-    def __init__(self, num_inducing_points, batch_shape=torch.Size([]), mean_init_std=1e-3):
+    def __init__(self, num_inducing_points: int, batch_shape: torch.Size = torch.Size([]), mean_init_std: float = 1e-3):
         super().__init__()
         self.num_inducing_points = num_inducing_points
         self.batch_shape = batch_shape
         self.mean_init_std = mean_init_std
 
     @property
-    def device(self):
+    def device(self) -> torch.device:
         return next(self.parameters()).device
 
     @property
-    def dtype(self):
+    def dtype(self) -> torch.dtype:
         return next(self.parameters()).dtype
 
-    def forward(self):
+    def forward(self) -> Distribution:
         r"""
         Constructs and returns the variational distribution
 
@@ -46,13 +47,13 @@ def shape(self) -> torch.Size:
         return torch.Size([*self.batch_shape, self.num_inducing_points])
 
     @abstractmethod
-    def initialize_variational_distribution(self, prior_dist):
+    def initialize_variational_distribution(self, prior_dist: MultivariateNormal) -> None:
         r"""
         Method for initializing the variational distribution, based on the prior distribution.
 
         :param ~gpytorch.distributions.Distribution prior_dist: The prior distribution :math:`p(\mathbf u)`.
         """
         raise NotImplementedError
 
-    def __call__(self):
+    def __call__(self) -> Distribution:
         return self.forward()
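These hints make the abstract contract explicit: forward returns a Distribution and initialize_variational_distribution consumes a MultivariateNormal prior. A hypothetical mean-field subclass sketching that contract (the class name and parameterization are illustrative, not part of this commit or of GPyTorch):

import torch

from gpytorch.distributions import MultivariateNormal
from gpytorch.variational import _VariationalDistribution


class DiagonalVariationalDistribution(_VariationalDistribution):
    # Illustrative mean-field q(u) with a diagonal covariance matrix.
    def __init__(self, num_inducing_points: int, batch_shape: torch.Size = torch.Size([]), mean_init_std: float = 1e-3):
        super().__init__(num_inducing_points, batch_shape, mean_init_std)
        self.register_parameter(
            name="variational_mean", parameter=torch.nn.Parameter(torch.zeros(*batch_shape, num_inducing_points))
        )
        self.register_parameter(
            name="raw_variational_stddev", parameter=torch.nn.Parameter(torch.zeros(*batch_shape, num_inducing_points))
        )

    def forward(self) -> MultivariateNormal:
        stddev = torch.nn.functional.softplus(self.raw_variational_stddev)
        return MultivariateNormal(self.variational_mean, torch.diag_embed(stddev**2))

    def initialize_variational_distribution(self, prior_dist: MultivariateNormal) -> None:
        # Start the variational mean at the prior mean plus a small amount of noise.
        self.variational_mean.data.copy_(prior_dist.mean)
        self.variational_mean.data.add_(torch.randn_like(prior_dist.mean), alpha=self.mean_init_std)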

gpytorch/variational/_variational_strategy.py (+18, -8)

@@ -3,34 +3,44 @@
 import functools
 from abc import ABC, abstractproperty
 from copy import deepcopy
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 from linear_operator.operators import LinearOperator
 from torch import Tensor
 
 from .. import settings
-from ..distributions import Delta, MultivariateNormal
+from ..distributions import Delta, Distribution, MultivariateNormal
+from ..kernels import Kernel
 from ..likelihoods import GaussianLikelihood
+from ..means import Mean
 from ..models import ApproximateGP, ExactGP
+from ..models.exact_prediction_strategies import DefaultPredictionStrategy
 from ..module import Module
 from ..utils.memoize import add_to_cache, cached, clear_cache_hook
 from . import _VariationalDistribution
 
 
 class _BaseExactGP(ExactGP):
-    def __init__(self, train_inputs, train_targets, likelihood, mean_module, covar_module):
+    def __init__(
+        self,
+        train_inputs: Optional[Union[Tensor, Tuple[Tensor, ...]]],
+        train_targets: Optional[Tensor],
+        likelihood: GaussianLikelihood,
+        mean_module: Mean,
+        covar_module: Kernel,
+    ):
         super().__init__(train_inputs, train_targets, likelihood)
         self.mean_module = mean_module
         self.covar_module = covar_module
 
-    def forward(self, x):
+    def forward(self, x: Tensor, **kwargs) -> MultivariateNormal:
         mean = self.mean_module(x)
         covar = self.covar_module(x)
         return MultivariateNormal(mean, covar)
 
 
-def _add_cache_hook(tsr, pred_strat):
+def _add_cache_hook(tsr: Tensor, pred_strat: DefaultPredictionStrategy) -> Tensor:
     if tsr.grad_fn is not None:
         wrapper = functools.partial(clear_cache_hook, pred_strat)
         functools.update_wrapper(wrapper, clear_cache_hook)
@@ -47,7 +57,7 @@ class _VariationalStrategy(Module, ABC):
 
     def __init__(
         self,
-        model: ApproximateGP,
+        model: Union[ApproximateGP, "_VariationalStrategy"],
         inducing_points: Tensor,
         variational_distribution: _VariationalDistribution,
         learn_inducing_locations: bool = True,
@@ -73,7 +83,7 @@ def __init__(
         self._variational_distribution = variational_distribution
         self.register_buffer("variational_params_initialized", torch.tensor(0))
 
-    def _clear_cache(self):
+    def _clear_cache(self) -> None:
         clear_cache_hook(self)
 
     def _expand_inputs(self, x: Tensor, inducing_points: Tensor) -> Tuple[Tensor, Tensor]:
@@ -110,7 +120,7 @@ def prior_distribution(self) -> MultivariateNormal:
 
     @property
     @cached(name="variational_distribution_memo")
-    def variational_distribution(self) -> MultivariateNormal:
+    def variational_distribution(self) -> Distribution:
        return self._variational_distribution()
 
    def forward(
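For context, the typical consumer of this interface is an ApproximateGP wired to a concrete strategy. The sketch below is standard GPyTorch SVGP usage (not part of this commit) and shows where the model, inducing_points, and variational_distribution arguments annotated above come from:

import torch

from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.means import ConstantMean
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy


class SVGPModel(ApproximateGP):
    def __init__(self, inducing_points: torch.Tensor):
        variational_distribution = CholeskyVariationalDistribution(inducing_points.size(-2))
        variational_strategy = VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(RBFKernel())

    def forward(self, x: torch.Tensor) -> MultivariateNormal:
        return MultivariateNormal(self.mean_module(x), self.covar_module(x))


model = SVGPModel(inducing_points=torch.randn(16, 1))
output = model(torch.randn(8, 1))  # a MultivariateNormal over 8 test inputs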

gpytorch/variational/additive_grid_interpolation_variational_strategy.py (+28, -11)

@@ -1,15 +1,27 @@
 #!/usr/bin/env python3
 
+from typing import Iterable, Optional, Tuple
 
 import torch
+from linear_operator.operators import LinearOperator
+from torch import LongTensor, Tensor
 
 from ..distributions import Delta, MultivariateNormal
+from ..models import ApproximateGP
+from ..variational._variational_distribution import _VariationalDistribution
 from ..variational.grid_interpolation_variational_strategy import GridInterpolationVariationalStrategy
 
 
 class AdditiveGridInterpolationVariationalStrategy(GridInterpolationVariationalStrategy):
     def __init__(
-        self, model, grid_size, grid_bounds, num_dim, variational_distribution, mixing_params=False, sum_output=True
+        self,
+        model: ApproximateGP,
+        grid_size: int,
+        grid_bounds: Iterable[Tuple[float, float]],
+        num_dim: int,
+        variational_distribution: _VariationalDistribution,
+        mixing_params: bool = False,
+        sum_output: bool = True,
     ):
         super(AdditiveGridInterpolationVariationalStrategy, self).__init__(
             model, grid_size, grid_bounds, variational_distribution
@@ -21,20 +33,17 @@ def __init__(
         self.register_parameter(name="mixing_params", parameter=torch.nn.Parameter(torch.ones(num_dim) / num_dim))
 
     @property
-    def prior_distribution(self):
-        """
-        If desired, models can compare the input to forward to inducing_points and use a GridKernel for space
-        efficiency.
-
-        However, when using a default VariationalDistribution which has an O(m^2) space complexity anyways, we find that
-        GridKernel is typically not worth it due to the moderate slow down of using FFTs.
-        """
+    def prior_distribution(self) -> MultivariateNormal:
+        # If desired, models can compare the input to forward to inducing_points and use a GridKernel for space
+        # efficiency.
+        # However, when using a default VariationalDistribution which has an O(m^2) space complexity anyways,
+        # we find that GridKernel is typically not worth it due to the moderate slow down of using FFTs.
         out = super(AdditiveGridInterpolationVariationalStrategy, self).prior_distribution
         mean = out.mean.repeat(self.num_dim, 1)
         covar = out.lazy_covariance_matrix.repeat(self.num_dim, 1, 1)
         return MultivariateNormal(mean, covar)
 
-    def _compute_grid(self, inputs):
+    def _compute_grid(self, inputs: Tensor) -> Tuple[LongTensor, Tensor]:
         num_data, num_dim = inputs.size()
         inputs = inputs.transpose(0, 1).reshape(-1, 1)
         interp_indices, interp_values = super(AdditiveGridInterpolationVariationalStrategy, self)._compute_grid(inputs)
@@ -45,7 +54,15 @@ def _compute_grid(self, inputs):
         interp_values = interp_values.mul(self.mixing_params.unsqueeze(1).unsqueeze(2))
         return interp_indices, interp_values
 
-    def forward(self, x, inducing_points, inducing_values, variational_inducing_covar=None):
+    def forward(
+        self,
+        x: Tensor,
+        inducing_points: Tensor,
+        inducing_values: Tensor,
+        variational_inducing_covar: Optional[LinearOperator] = None,
+        *params,
+        **kwargs,
+    ) -> MultivariateNormal:
         if x.ndimension() == 1:
             x = x.unsqueeze(-1)
         elif x.ndimension() != 2:

gpytorch/variational/cholesky_variational_distribution.py (+11, -6)

@@ -16,15 +16,20 @@ class CholeskyVariationalDistribution(_VariationalDistribution):
     matrix. In order to ensure that the covariance matrix remains positive definite, we only consider the lower
     triangle.
 
-    :param int num_inducing_points: Size of the variational distribution. This implies that the variational mean
+    :param num_inducing_points: Size of the variational distribution. This implies that the variational mean
         should be this size, and the variational covariance matrix should have this many rows and columns.
     :param batch_shape: Specifies an optional batch size
         for the variational parameters. This is useful for example when doing additive variational inference.
-    :type batch_shape: :obj:`torch.Size`, optional
-    :param float mean_init_std: (Default: 1e-3) Standard deviation of gaussian noise to add to the mean initialization.
+    :param mean_init_std: (Default: 1e-3) Standard deviation of gaussian noise to add to the mean initialization.
     """
 
-    def __init__(self, num_inducing_points, batch_shape=torch.Size([]), mean_init_std=1e-3, **kwargs):
+    def __init__(
+        self,
+        num_inducing_points: int,
+        batch_shape: torch.Size = torch.Size([]),
+        mean_init_std: float = 1e-3,
+        **kwargs,
+    ):
         super().__init__(num_inducing_points=num_inducing_points, batch_shape=batch_shape, mean_init_std=mean_init_std)
         mean_init = torch.zeros(num_inducing_points)
         covar_init = torch.eye(num_inducing_points, num_inducing_points)
@@ -34,7 +39,7 @@ def __init__(self, num_inducing_points, batch_shape=torch.Size([]), mean_init_st
         self.register_parameter(name="variational_mean", parameter=torch.nn.Parameter(mean_init))
         self.register_parameter(name="chol_variational_covar", parameter=torch.nn.Parameter(covar_init))
 
-    def forward(self):
+    def forward(self) -> MultivariateNormal:
         chol_variational_covar = self.chol_variational_covar
         dtype = chol_variational_covar.dtype
         device = chol_variational_covar.device
@@ -47,7 +52,7 @@ def forward(self):
         variational_covar = CholLinearOperator(chol_variational_covar)
         return MultivariateNormal(self.variational_mean, variational_covar)
 
-    def initialize_variational_distribution(self, prior_dist):
+    def initialize_variational_distribution(self, prior_dist: MultivariateNormal) -> None:
         self.variational_mean.data.copy_(prior_dist.mean)
         self.variational_mean.data.add_(torch.randn_like(prior_dist.mean), alpha=self.mean_init_std)
         self.chol_variational_covar.data.copy_(prior_dist.lazy_covariance_matrix.cholesky().to_dense())
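A short usage sketch of the now fully annotated constructor (ordinary GPyTorch usage; the batch shape and sizes are illustrative):

import torch

from gpytorch.variational import CholeskyVariationalDistribution

# Four independent variational distributions (e.g. one per task or output),
# each over 32 inducing points.
variational_distribution = CholeskyVariationalDistribution(
    num_inducing_points=32,
    batch_shape=torch.Size([4]),
    mean_init_std=1e-3,
)
q_u = variational_distribution()  # MultivariateNormal with batch_shape == torch.Size([4])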

gpytorch/variational/ciq_variational_strategy.py (+19, -16)

@@ -4,12 +4,13 @@
 
 import torch
 from linear_operator import to_linear_operator
-from linear_operator.operators import DiagLinearOperator, MatmulLinearOperator, SumLinearOperator
+from linear_operator.operators import DiagLinearOperator, LinearOperator, MatmulLinearOperator, SumLinearOperator
 from linear_operator.utils import linear_cg
 from torch import Tensor
+from torch.autograd.function import FunctionCtx
 
 from .. import settings
-from ..distributions import Delta, MultivariateNormal
+from ..distributions import Delta, Distribution, MultivariateNormal
 from ..module import Module
 from ..utils.memoize import cached
 from ._variational_strategy import _VariationalStrategy
@@ -35,7 +36,10 @@ class _NgdInterpTerms(torch.autograd.Function):
 
     @staticmethod
     def forward(
-        ctx, interp_term: torch.Tensor, natural_vec: torch.Tensor, natural_mat: torch.Tensor
+        ctx: FunctionCtx,
+        interp_term: torch.Tensor,
+        natural_vec: torch.Tensor,
+        natural_mat: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Compute precision
        prec = natural_mat.mul(-2.0)
@@ -80,8 +84,8 @@
 
     @staticmethod
     def backward(
-        ctx, interp_mean_grad: torch.Tensor, interp_var_grad: torch.Tensor, kl_div_grad: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        ctx: FunctionCtx, interp_mean_grad: torch.Tensor, interp_var_grad: torch.Tensor, kl_div_grad: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None]:
         # Get the saved terms
         interp_term, s_times_interp_term, interp_mean, natural_vec, expec_vec, prec = ctx.saved_tensors
 
@@ -101,12 +105,10 @@ def backward(
         # interp_mean component: K^{-1/2} k
         # interp_var component: (k^T K^{-1/2} m) K^{-1/2} k
         # kl component: S^{-1} m
-        expec_vec_grad = sum(
-            [
-                (interp_var_grad * interp_mean.unsqueeze(-2) * interp_term).sum(dim=-1).mul(-2),
-                (interp_mean_grad * interp_term).sum(dim=-1),
-                (kl_div_grad.unsqueeze(-1) * natural_vec),
-            ]
+        expec_vec_grad = (
+            (interp_var_grad * interp_mean.unsqueeze(-2) * interp_term).sum(dim=-1).mul(-2)
+            + (interp_mean_grad * interp_term).sum(dim=-1)
+            + (kl_div_grad.unsqueeze(-1) * natural_vec)
         )
 
         # Compute gradient of expected matrix (mm^T + S)
@@ -179,7 +181,7 @@ def prior_distribution(self) -> MultivariateNormal:
 
     @property
     @cached(name="variational_distribution_memo")
-    def variational_distribution(self) -> MultivariateNormal:
+    def variational_distribution(self) -> Distribution:
         if self._ngd():
             raise RuntimeError(
                 "Variational distribution for NGD-CIQ should be computed during forward calls. "
@@ -192,12 +194,13 @@ def forward(
         x: torch.Tensor,
         inducing_points: torch.Tensor,
         inducing_values: torch.Tensor,
-        variational_inducing_covar: Optional[MultivariateNormal] = None,
+        variational_inducing_covar: Optional[LinearOperator] = None,
+        *params,
         **kwargs,
     ) -> MultivariateNormal:
         # Compute full prior distribution
         full_inputs = torch.cat([inducing_points, x], dim=-2)
-        full_output = self.model.forward(full_inputs)
+        full_output = self.model.forward(full_inputs, *params, **kwargs)
         full_covar = full_output.lazy_covariance_matrix
 
         # Covariance terms
@@ -272,7 +275,7 @@ def kl_divergence(self) -> Tensor:
         else:
             return super().kl_divergence()
 
-    def __call__(self, x: torch.Tensor, prior: bool = False, **kwargs) -> MultivariateNormal:
+    def __call__(self, x: torch.Tensor, prior: bool = False, *params, **kwargs) -> MultivariateNormal:
         # This is mostly the same as _VariationalStrategy.__call__()
         # but with special rules for natural gradient descent (to prevent O(M^3) computation)
 
@@ -310,6 +313,7 @@ def __call__(self, x: torch.Tensor, prior: bool = False, **kwargs) -> Multivaria
                 inducing_points,
                 inducing_values=None,
                 variational_inducing_covar=None,
+                *params,
                 **kwargs,
             )
         else:
@@ -332,7 +336,6 @@ def __call__(self, x: torch.Tensor, prior: bool = False, **kwargs) -> Multivaria
                     inducing_points,
                     inducing_values=variational_dist_u.mean,
                     variational_inducing_covar=None,
-                    ngd=False,
                     **kwargs,
                 )
             else:
