#!/usr/bin/env python3

+from typing import Any, Optional
+
+from torch import Tensor
+
+from ..distributions import MultivariateNormal
+from .exact_gp import ExactGP
+
from .gp import GP
from .pyro import _PyroMixin  # This will only contain functions if Pyro is installed

@@ -44,38 +51,38 @@ class ApproximateGP(GP, _PyroMixin):

    def __init__(self, variational_strategy):
        super().__init__()
+
        self.variational_strategy = variational_strategy

-    def forward(self, x):
+    def forward(self, x: Tensor):
        raise NotImplementedError

-    def pyro_guide(self, input, beta=1.0, name_prefix=""):
+    def pyro_guide(self, input: Tensor, beta: float = 1.0, name_prefix: str = ""):
        r"""
        (For Pyro integration only). The component of a `pyro.guide` that
        corresponds to drawing samples from the latent GP function.

-        :param torch.Tensor input: The inputs :math:`\mathbf X`.
-        :param float beta: (default=1.) How much to scale the :math:`\text{KL} [ q(\mathbf f) \Vert p(\mathbf f) ]`
+        :param input: The inputs :math:`\mathbf X`.
+        :param beta: (default=1.) How much to scale the :math:`\text{KL} [ q(\mathbf f) \Vert p(\mathbf f) ]`
            term by.
-        :param str name_prefix: (default="") A name prefix to prepend to pyro sample sites.
+        :param name_prefix: (default="") A name prefix to prepend to pyro sample sites.
        """
        return super().pyro_guide(input, beta=beta, name_prefix=name_prefix)

-    def pyro_model(self, input, beta=1.0, name_prefix=""):
+    def pyro_model(self, input: Tensor, beta: float = 1.0, name_prefix: str = "") -> Tensor:
        r"""
        (For Pyro integration only). The component of a `pyro.model` that
        corresponds to drawing samples from the latent GP function.

-        :param torch.Tensor input: The inputs :math:`\mathbf X`.
-        :param float beta: (default=1.) How much to scale the :math:`\text{KL} [ q(\mathbf f) \Vert p(\mathbf f) ]`
+        :param input: The inputs :math:`\mathbf X`.
+        :param beta: (default=1.) How much to scale the :math:`\text{KL} [ q(\mathbf f) \Vert p(\mathbf f) ]`
            term by.
-        :param str name_prefix: (default="") A name prefix to prepend to pyro sample sites.
+        :param name_prefix: (default="") A name prefix to prepend to pyro sample sites.
        :return: samples from :math:`q(\mathbf f)`
-        :rtype: torch.Tensor
        """
        return super().pyro_model(input, beta=beta, name_prefix=name_prefix)

-    def get_fantasy_model(self, inputs, targets, **kwargs):
+    def get_fantasy_model(self, inputs: Tensor, targets: Tensor, **kwargs: Any) -> ExactGP:
        r"""
        Returns a new GP model that incorporates the specified inputs and targets as new training data using
        online variational conditioning (OVC).
@@ -88,12 +95,11 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
        If `inputs` is of the same (or lesser) dimension as `targets`, then it is assumed that the fantasy points
        are the same for each target batch.

-        :param torch.Tensor inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
+        :param inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
            observations.
-        :param torch.Tensor targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
+        :param targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
        :return: An `ExactGP` model with `n + m` training examples, where the `m` fantasy examples have been added
            and all test-time caches have been updated.
-        :rtype: ~gpytorch.models.ExactGP

        Reference: "Conditioning Sparse Variational Gaussian Processes for Online Decision-Making,"
        Maddox, Stanton, Wilson, NeurIPS, '21
@@ -102,7 +108,7 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
        """
        return self.variational_strategy.get_fantasy_model(inputs=inputs, targets=targets, **kwargs)

-    def __call__(self, inputs, prior=False, **kwargs):
-        if inputs.dim() == 1:
+    def __call__(self, inputs: Optional[Tensor], prior: bool = False, **kwargs) -> MultivariateNormal:
+        if inputs is not None and inputs.dim() == 1:
            inputs = inputs.unsqueeze(-1)
        return self.variational_strategy(inputs, prior=prior, **kwargs)
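
For context beyond the diff itself, here is a minimal sketch of how a subclass would exercise the annotated API. The class name, kernel choice, and data below are illustrative and not part of this change; the variational components are the standard gpytorch ones.

import torch
import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy


class SketchSVGP(ApproximateGP):
    """Illustrative SVGP subclass: forward() defines the prior, __call__ returns q(f)."""

    def __init__(self, inducing_points: torch.Tensor):
        # One variational distribution entry per inducing point
        variational_distribution = CholeskyVariationalDistribution(inducing_points.size(-2))
        variational_strategy = VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x: torch.Tensor) -> gpytorch.distributions.MultivariateNormal:
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


model = SketchSVGP(inducing_points=torch.randn(16, 1))
# A 1-D input is unsqueezed to (n, 1) by the None-guarded check in __call__ above.
q_f = model(torch.linspace(0, 1, 50))
print(q_f.mean.shape)  # torch.Size([50])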
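
Similarly hedged, a sketch of the OVC path that get_fantasy_model's docstring describes. It mirrors only the documented signature; whether a particular variational strategy needs additional keyword arguments (for example, a likelihood) forwarded through **kwargs is not covered by this diff.

# m = 5 hypothetical fantasy observations in d = 1, following the documented
# `b1 x ... x bk x m x d` / `b1 x ... x bk x m` shape convention.
new_x = torch.randn(5, 1)
new_y = torch.randn(5)
# Returns an ExactGP with n + m training examples and refreshed test-time caches;
# any strategy-specific options would be passed via **kwargs.
fantasy_model = model.get_fantasy_model(new_x, new_y)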