Add Fantasy Strategy for Variational GPs #1874
@@ -1,21 +1,46 @@
#!/usr/bin/env python3

import functools
from abc import ABC, abstractproperty
from copy import deepcopy

import torch

from .. import settings
from ..distributions import Delta, MultivariateNormal
from ..models import ExactGP
from ..module import Module
from ..utils.broadcasting import _mul_broadcast_shape
from ..utils.memoize import add_to_cache, cached, clear_cache_hook


class _BaseExactGP(ExactGP):
    # Minimal ExactGP used internally: it treats the inducing points as the
    # "training data" of an exact GP that we can then fantasize around.
    def __init__(self, train_inputs, train_targets, likelihood, mean_module, covar_module):
        super().__init__(train_inputs, train_targets, likelihood)
        self.mean_module = mean_module
        self.covar_module = covar_module

    def forward(self, x):
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return MultivariateNormal(mean, covar)


def _add_cache_hook(tsr, pred_strat):
    # Register a backward hook on the tensor's grad_fn so that the prediction
    # strategy's memoized caches are cleared once gradients flow through it.
    if tsr.grad_fn is not None:
        wrapper = functools.partial(clear_cache_hook, pred_strat)
        functools.update_wrapper(wrapper, clear_cache_hook)
        tsr.grad_fn.register_hook(wrapper)
    return tsr
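Not part of the diff: a minimal standalone sketch of the grad_fn hook pattern that _add_cache_hook relies on, with a print callback standing in for clear_cache_hook(pred_strat):

import torch

x = torch.randn(3, requires_grad=True)
y = (x * 2.0).sum()

def on_backward(*grads):
    # _add_cache_hook registers clear_cache_hook(pred_strat) here, so the
    # prediction strategy's memoized caches are invalidated on backward
    print("backward reached y.grad_fn")

# hooks attach to the autograd graph node (grad_fn), not the tensor itself
y.grad_fn.register_hook(on_backward)
y.backward()  # triggers on_backward, then populates x.grad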
class _VariationalStrategy(Module, ABC):
    """
    Abstract base class for all Variational Strategies.
    """

    has_fantasy_strategy = False

    def __init__(self, model, inducing_points, variational_distribution, learn_inducing_locations=True):
        super().__init__()
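For context, a concrete strategy advertises OVC support by overriding this class attribute; a minimal sketch of the pattern (the real subclasses live in other files of this PR, not shown in this hunk):

class VariationalStrategy(_VariationalStrategy):
    # opting in lets get_fantasy_model below accept this strategy
    has_fantasy_strategy = True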
@@ -97,6 +122,113 @@ def kl_divergence(self):
        kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
        return kl_divergence

    @cached(name="inducing_model")
    def inducing_model(self):
        with torch.no_grad():
            inducing_noise_covar, inducing_mean = self.pseudo_points
            inducing_points = self.inducing_points.detach()
            if inducing_points.ndim < inducing_mean.ndim:
                inducing_points = inducing_points.expand(*inducing_mean.shape[:-2], *inducing_points.shape)
            # TODO: add flag for conditioning into SGPR after building fantasy strategy for SGPR
            new_covar_module = deepcopy(self.model.covar_module)

Review comment: This feels a bit brittle: it's not guaranteed that people use the `covar_module` attribute name on their models.
            # update inducing mean if necessary
            inducing_mean = inducing_mean.squeeze() + self.model.mean_module(inducing_points)

            inducing_exact_model = _BaseExactGP(
                inducing_points,
                inducing_mean,
                mean_module=deepcopy(self.model.mean_module),
                covar_module=new_covar_module,
                likelihood=deepcopy(self.model.likelihood),
            )

Review thread:
- Same concern here for `mean_module`.
- Seems like this is playing the same role as this line: https://github.com/cornellius-gp/gpytorch/blob/master/gpytorch/models/exact_gp.py#L223. But here we want to copy some of the attributes of one class into a completely different class. Not sure there is a general way to do this without assuming the attribute names.
- It would be good to add an informative error message if the attributes are missing.
- The solution we're going to try here is to require either …
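Sketching the reviewers' suggestion (a hypothetical helper, not part of this PR): look the submodule up explicitly and fail with an informative message when a model doesn't follow the attribute-name convention:

def _get_required_module(model, name):
    # hypothetical helper: fetch model.<name> or raise an informative error
    module = getattr(model, name, None)
    if module is None:
        raise AttributeError(
            f"Fantasization requires the variational GP's model to expose a `{name}` "
            f"attribute, but {type(model).__name__} does not define one."
        )
    return module

# usage sketch: new_covar_module = deepcopy(_get_required_module(self.model, "covar_module"))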
            # now fantasize around this model
            # as this model is new, we need to compute a posterior to construct the prediction strategy,
            # which uses the likelihood pseudo caches; the dummy forward pass in eval mode below forces
            # the lazy construction of inducing_exact_model.prediction_strategy
            faked_points = torch.randn(
                *inducing_mean.shape[:-2],
                1,
                inducing_points.shape[-1],
                device=inducing_points.device,
                dtype=inducing_points.dtype,
            )
            inducing_exact_model.eval()
            _ = inducing_exact_model(faked_points)

            # then we overwrite the likelihood to take into account the multivariate normal term
            pred_strat = inducing_exact_model.prediction_strategy
            pred_strat._memoize_cache = {}
            with torch.no_grad():
                updated_lik_train_train_covar = (
                    pred_strat.train_prior_dist.lazy_covariance_matrix + inducing_noise_covar
                )
                pred_strat.lik_train_train_covar = updated_lik_train_train_covar

            # redo the mean cache, because the existing mean cache doesn't solve against lik_train_train_covar
            train_mean = inducing_exact_model.mean_module(*inducing_exact_model.train_inputs)
            train_labels_offset = (inducing_exact_model.prediction_strategy.train_labels - train_mean).unsqueeze(-1)
            mean_cache = updated_lik_train_train_covar.inv_matmul(train_labels_offset).squeeze(-1)
            mean_cache = _add_cache_hook(mean_cache, inducing_exact_model.prediction_strategy)
            add_to_cache(pred_strat, "mean_cache", mean_cache)
            # TODO: check to see if we need to update the covar_cache as well

        inducing_exact_model.prediction_strategy = pred_strat
        return inducing_exact_model
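For reference, in standard exact-GP notation (this annotation is an assumption drawn from the code above, not text from the diff), the cached quantity is the usual predictive-mean solve, with the pseudo-point noise covariance in place of the likelihood noise:

\[
\texttt{mean\_cache} = \left( K_{ZZ} + \Sigma_{\mathrm{pseudo}} \right)^{-1} \left( \tilde{y} - \mu(Z) \right)
\]

where Z denotes the detached inducing points, \Sigma_{\mathrm{pseudo}} is inducing_noise_covar, and \tilde{y} is the pseudo-target vector inducing_mean. get_fantasy_model below rebuilds the same solve after conditioning on the new inputs and targets.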
    @abstractproperty
    def pseudo_points(self):
        # accessed as a property in inducing_model() above
        raise NotImplementedError("Each variational strategy must implement its own pseudo points method")
    def get_fantasy_model(
        self,
        inputs,
        targets,
        **kwargs,
    ):
        r"""
        Performs the online variational conditioning (OVC) strategy of Maddox et al., '21 to return
        an exact GP model that incorporates the inputs and targets alongside the variational model's
        inducing points and targets.

        Reference: "Conditioning Sparse Variational Gaussian Processes for Online Decision-Making,"
        Maddox, Stanton, Wilson, NeurIPS '21.
        https://papers.nips.cc/paper/2021/hash/325eaeac5bef34937cfdc1bd73034d17-Abstract.html
        """
        # currently, we only support fantasization for CholeskyVariationalDistribution and
        # whitened / unwhitened variational strategies
        if not self.has_fantasy_strategy:
            raise NotImplementedError(
                f"No fantasy model support for {self.__class__.__name__}. "
                "Only VariationalStrategy and UnwhitenedVariationalStrategy are currently supported."
            )
        # first we construct an exact model over the inducing points with the inducing covariance matrix
        inducing_exact_model = self.inducing_model()
        # then we update this model by adding in the inputs and pseudo targets;
        # finally we fantasize wrt the targets
        fantasy_model = inducing_exact_model.get_fantasy_model(inputs, targets, **kwargs)
        fant_pred_strat = fantasy_model.prediction_strategy

        # first we update the lik_train_train_covar, then redo the mean cache,
        # because fantasization resets the likelihood forward
        train_mean = fantasy_model.mean_module(*fantasy_model.train_inputs)
        train_labels_offset = (fant_pred_strat.train_labels - train_mean).unsqueeze(-1)
        fantasy_lik_train_root_inv = fant_pred_strat.lik_train_train_covar.root_inv_decomposition()
        mean_cache = fantasy_lik_train_root_inv.matmul(train_labels_offset).squeeze(-1)
        mean_cache = _add_cache_hook(mean_cache, fant_pred_strat)
        add_to_cache(fant_pred_strat, "mean_cache", mean_cache)

        fantasy_model.prediction_strategy = fant_pred_strat
        return fantasy_model
    def __call__(self, x, prior=False, **kwargs):
        # If we're in prior mode, then we're done!
        if prior:
Review thread on the return type:
- Reviewer: I can see why it might be convenient to always have `get_fantasy_model` return an `ExactGP`, regardless of the original model class, but it might be worth considering naming this something else, reserving `get_fantasy_model` for the version of OVC that returns a variational GP (in other words, make a package-level decision to require that `get_fantasy_model` always return an instance of the original class).
- Reply: That was a thought I originally had, but it requires the unstable direct updates to `m`, `S` in order to return its own model class rather than the exact GP. Although it is potentially lower overhead in the future for implementing new fantasization strategies.
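To illustrate the intended call pattern, a hedged usage sketch (the SVGPModel class and training details are illustrative assumptions, not part of the diff; note that inducing_model() deepcopies the model's mean_module, covar_module, and likelihood, so all three attributes must exist):

import torch
import gpytorch

class SVGPModel(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2)
        )
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        # the strategy's deepcopy assumes the model exposes a likelihood attribute
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

model = SVGPModel(inducing_points=torch.randn(16, 1))
# ... fit the ELBO as usual, then condition online on newly observed points ...
model.eval()
new_x, new_y = torch.randn(5, 1), torch.randn(5)
fantasy_model = model.variational_strategy.get_fantasy_model(new_x, new_y)  # returns an exact GP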