@@ -4,12 +4,13 @@

 import torch
 from linear_operator import to_linear_operator
-from linear_operator.operators import DiagLinearOperator, MatmulLinearOperator, SumLinearOperator
+from linear_operator.operators import DiagLinearOperator, LinearOperator, MatmulLinearOperator, SumLinearOperator
 from linear_operator.utils import linear_cg
 from torch import Tensor
+from torch.autograd.function import FunctionCtx

 from .. import settings
-from ..distributions import Delta, MultivariateNormal
+from ..distributions import Delta, Distribution, MultivariateNormal
 from ..module import Module
 from ..utils.memoize import cached
 from ._variational_strategy import _VariationalStrategy
@@ -35,7 +36,10 @@ class _NgdInterpTerms(torch.autograd.Function):

     @staticmethod
     def forward(
-        ctx, interp_term: torch.Tensor, natural_vec: torch.Tensor, natural_mat: torch.Tensor
+        ctx: FunctionCtx,
+        interp_term: torch.Tensor,
+        natural_vec: torch.Tensor,
+        natural_mat: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         # Compute precision
         prec = natural_mat.mul(-2.0)
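The context line `prec = natural_mat.mul(-2.0)` encodes the standard natural parameterization of a Gaussian. Below is a minimal sketch of that relationship, assuming N(m, S) with natural parameters S^{-1} m and -1/2 S^{-1}; the sketch is illustrative and not part of this commit:

import torch

# For a Gaussian N(m, S): natural_vec = S^{-1} m and natural_mat = -1/2 S^{-1},
# so the precision matrix is recovered as -2 * natural_mat.
m = torch.randn(3)
S = torch.diag(torch.tensor([0.5, 1.0, 2.0]))

natural_vec = torch.linalg.solve(S, m)    # S^{-1} m
natural_mat = torch.linalg.inv(S) * -0.5  # -1/2 S^{-1}

prec = natural_mat.mul(-2.0)              # same recovery as in forward()
assert torch.allclose(prec, torch.linalg.inv(S))
assert torch.allclose(torch.linalg.solve(prec, natural_vec), m)  # back to the mean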
@@ -80,8 +84,8 @@ def forward(

     @staticmethod
     def backward(
-        ctx, interp_mean_grad: torch.Tensor, interp_var_grad: torch.Tensor, kl_div_grad: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        ctx: FunctionCtx, interp_mean_grad: torch.Tensor, interp_var_grad: torch.Tensor, kl_div_grad: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None]:
         # Get the saved terms
         interp_term, s_times_interp_term, interp_mean, natural_vec, expec_vec, prec = ctx.saved_tensors

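The widened return annotation, `Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None]`, follows the `torch.autograd.Function` contract: `backward` returns one entry per `forward` input, with `None` standing in for inputs that receive no gradient. A self-contained sketch of that pattern (the `ScaleBy` function here is hypothetical, not from this file):

import torch
from torch.autograd.function import FunctionCtx


class ScaleBy(torch.autograd.Function):
    @staticmethod
    def forward(ctx: FunctionCtx, x: torch.Tensor, factor: float) -> torch.Tensor:
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx: FunctionCtx, grad_out: torch.Tensor):
        # One return value per forward() input (after ctx):
        # a gradient for x, and None for the non-differentiable factor.
        return grad_out * ctx.factor, None


x = torch.randn(4, requires_grad=True)
ScaleBy.apply(x, 3.0).sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 3.0))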
@@ -101,12 +105,10 @@ def backward(
         # interp_mean component: K^{-1/2} k
         # interp_var component: (k^T K^{-1/2} m) K^{-1/2} k
         # kl component: S^{-1} m
-        expec_vec_grad = sum(
-            [
-                (interp_var_grad * interp_mean.unsqueeze(-2) * interp_term).sum(dim=-1).mul(-2),
-                (interp_mean_grad * interp_term).sum(dim=-1),
-                (kl_div_grad.unsqueeze(-1) * natural_vec),
-            ]
+        expec_vec_grad = (
+            (interp_var_grad * interp_mean.unsqueeze(-2) * interp_term).sum(dim=-1).mul(-2)
+            + (interp_mean_grad * interp_term).sum(dim=-1)
+            + (kl_div_grad.unsqueeze(-1) * natural_vec)
         )

         # Compute gradient of expected matrix (mm^T + S)
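The `expec_vec_grad` rewrite is behavior-preserving: summing a list of tensors with Python's `sum()` and chaining `+` produce the same value, and the chained form simply skips the intermediate list and `sum()`'s integer start value. A quick check (illustrative, not from the diff):

import torch

a, b, c = torch.randn(5), torch.randn(5), torch.randn(5)
assert torch.allclose(sum([a, b, c]), a + b + c)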
@@ -179,7 +181,7 @@ def prior_distribution(self) -> MultivariateNormal:

     @property
     @cached(name="variational_distribution_memo")
-    def variational_distribution(self) -> MultivariateNormal:
+    def variational_distribution(self) -> Distribution:
         if self._ngd():
             raise RuntimeError(
                 "Variational distribution for NGD-CIQ should be computed during forward calls. "
@@ -192,12 +194,13 @@ def forward(
         x: torch.Tensor,
         inducing_points: torch.Tensor,
         inducing_values: torch.Tensor,
-        variational_inducing_covar: Optional[MultivariateNormal] = None,
+        variational_inducing_covar: Optional[LinearOperator] = None,
+        *params,
         **kwargs,
     ) -> MultivariateNormal:
         # Compute full prior distribution
         full_inputs = torch.cat([inducing_points, x], dim=-2)
-        full_output = self.model.forward(full_inputs)
+        full_output = self.model.forward(full_inputs, *params, **kwargs)
         full_covar = full_output.lazy_covariance_matrix

         # Covariance terms
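The new `*params` threading means extra positional arguments handed to the strategy now reach `self.model.forward` alongside `**kwargs`. A toy, self-contained illustration of the forwarding pattern, with all names hypothetical:

import torch


def model_forward(full_inputs, task_id, noise_scale=1.0):
    # Stand-in for a model whose forward takes extra positional arguments.
    return full_inputs * noise_scale + task_id


def strategy_forward(x, inducing_points, *params, **kwargs):
    full_inputs = torch.cat([inducing_points, x], dim=-2)
    # Previously only full_inputs was passed; extra args now flow through.
    return model_forward(full_inputs, *params, **kwargs)


out = strategy_forward(torch.randn(3, 2), torch.randn(5, 2), 7, noise_scale=0.5)
assert out.shape == (8, 2)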
@@ -272,7 +275,7 @@ def kl_divergence(self) -> Tensor:
         else:
             return super().kl_divergence()

-    def __call__(self, x: torch.Tensor, prior: bool = False, **kwargs) -> MultivariateNormal:
+    def __call__(self, x: torch.Tensor, prior: bool = False, *params, **kwargs) -> MultivariateNormal:
         # This is mostly the same as _VariationalStrategy.__call__()
         # but with special rules for natural gradient descent (to prevent O(M^3) computation)

@@ -310,6 +313,7 @@ def __call__(self, x: torch.Tensor, prior: bool = False, **kwargs) -> MultivariateNormal:
                 inducing_points,
                 inducing_values=None,
                 variational_inducing_covar=None,
+                *params,
                 **kwargs,
             )
         else:
@@ -332,7 +336,6 @@ def __call__(self, x: torch.Tensor, prior: bool = False, **kwargs) -> MultivariateNormal:
                 inducing_points,
                 inducing_values=variational_dist_u.mean,
                 variational_inducing_covar=None,
-                ngd=False,
                 **kwargs,
             )
         else: