diff --git a/gpytorch/variational/nearest_neighbor_variational_strategy.py b/gpytorch/variational/nearest_neighbor_variational_strategy.py
index 6f9b429b4..97a1bee63 100644
--- a/gpytorch/variational/nearest_neighbor_variational_strategy.py
+++ b/gpytorch/variational/nearest_neighbor_variational_strategy.py
@@ -3,6 +3,7 @@
 from typing import Any, Optional

 import torch
+from jaxtyping import Float
 from linear_operator import to_dense
 from linear_operator.operators import DiagLinearOperator, LinearOperator, TriangularLinearOperator
 from linear_operator.utils.cholesky import psd_safe_cholesky
@@ -62,7 +63,8 @@ class NNVariationalStrategy(UnwhitenedVariationalStrategy):
         VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
     :param k: Number of nearest neighbors.
     :param training_batch_size: The number of data points that will be in the training batch size.
-    :param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability
+    :param jitter_val: Amount of diagonal jitter to add for covariance matrix numerical stability.
+    :param compute_full_kl: Whether to compute the full KL divergence or a stochastic estimate.

     .. _Wu et al (2022):
         https://arxiv.org/pdf/2202.01694.pdf
@@ -75,11 +77,12 @@ class NNVariationalStrategy(UnwhitenedVariationalStrategy):
     def __init__(
         self,
         model: ApproximateGP,
-        inducing_points: Tensor,
-        variational_distribution: _VariationalDistribution,
+        inducing_points: Float[Tensor, "... M D"],
+        variational_distribution: Float[_VariationalDistribution, "... M"],
         k: int,
-        training_batch_size: int,
-        jitter_val: Optional[float] = None,
+        training_batch_size: Optional[int] = None,
+        jitter_val: Optional[float] = 1e-3,
+        compute_full_kl: Optional[bool] = False,
     ):
         assert isinstance(
             variational_distribution, MeanFieldVariationalDistribution
@@ -88,18 +91,15 @@ def __init__(
         super().__init__(
             model, inducing_points, variational_distribution, learn_inducing_locations=False, jitter_val=jitter_val
         )
-        # Make sure we don't try to initialize variational parameters - because of minibatching
-        self.variational_params_initialized.fill_(1)

         # Model
         object.__setattr__(self, "model", model)

         self.inducing_points = inducing_points
-        self.M: int = inducing_points.shape[-2]
-        self.D: int = inducing_points.shape[-1]
+        self.M, self.D = inducing_points.shape[-2:]
         self.k = k
-        assert self.k <= self.M, (
-            f"Number of nearest neighbors k must be smaller than or equal to number of inducing points, "
+        assert self.k < self.M, (
+            f"Number of nearest neighbors k must be smaller than the number of inducing points, "
             f"but got k = {k}, M = {self.M}."
         )

@@ -111,37 +111,61 @@
             k, dim=self.D, batch_shape=self._inducing_batch_shape, device=inducing_points.device
         )
         self._compute_nn()
+        # otherwise, no nearest neighbor approximation is used

-        self.training_batch_size = training_batch_size
+        self.training_batch_size = training_batch_size if training_batch_size is not None else self.M
         self._set_training_iterator()

+        self.compute_full_kl = compute_full_kl
+
     @property
     @cached(name="prior_distribution_memo")
-    def prior_distribution(self) -> MultivariateNormal:
+    def prior_distribution(self) -> Float[MultivariateNormal, "... M"]:
         out = self.model.forward(self.inducing_points)
         res = MultivariateNormal(out.mean, out.lazy_covariance_matrix.add_jitter(self.jitter_val))
         return res

-    def _cholesky_factor(self, induc_induc_covar: LinearOperator) -> TriangularLinearOperator:
+    def _cholesky_factor(
+        self, induc_induc_covar: Float[LinearOperator, "... M M"]
+    ) -> Float[TriangularLinearOperator, "... M M"]:
         # Uncached version
         L = psd_safe_cholesky(to_dense(induc_induc_covar))
         return TriangularLinearOperator(L)

-    def __call__(self, x: Tensor, prior: bool = False, **kwargs: Any) -> MultivariateNormal:
+    def __call__(
+        self, x: Float[Tensor, "... N D"], prior: bool = False, **kwargs: Any
+    ) -> Float[MultivariateNormal, "... N"]:
         # If we're in prior mode, then we're done!
         if prior:
             return self.model.forward(x, **kwargs)

         if x is not None:
-            assert self.inducing_points.shape[:-2] == x.shape[:-2], (
-                f"x batch shape must matches inducing points batch shape, "
-                f"but got train data batch shape = {x.shape[:-2]}, "
-                f"inducing points batch shape = {self.inducing_points.shape[:-2]}."
-            )
+            # Make sure x and inducing points have the same batch shape
+            if not (self.inducing_points.shape[:-2] == x.shape[:-2]):
+                try:
+                    x = x.expand(*self.inducing_points.shape[:-2], *x.shape[-2:])
+                except RuntimeError:
+                    raise RuntimeError(
+                        f"x batch shape must match or broadcast with the inducing points' batch shape, "
+                        f"but got x batch shape = {x.shape[:-2]}, "
+                        f"inducing points batch shape = {self.inducing_points.shape[:-2]}."
+                    )

         # Delete previously cached items from the training distribution
         if self.training:
             self._clear_cache()
+
+        # (Maybe) initialize variational distribution
+        if not self.variational_params_initialized.item():
+            prior_dist = self.prior_distribution
+            self._variational_distribution.variational_mean.data.copy_(prior_dist.mean)
+            self._variational_distribution.variational_mean.data.add_(
+                torch.randn_like(prior_dist.mean), alpha=self._variational_distribution.mean_init_std
+            )
+            # initialize with a small variational stddev for quicker conv. of kl divergence
+            self._variational_distribution._variational_stddev.data.copy_(torch.tensor(1e-2))
+            self.variational_params_initialized.fill_(1)
+
         return self.forward(
             x, self.inducing_points, inducing_values=None, variational_inducing_covar=None, **kwargs
         )
@@ -152,12 +176,12 @@ def __call__(self, x: Tensor, prior: bool = False, **kwargs: Any) -> Multivariat

     def forward(
         self,
-        x: Tensor,
-        inducing_points: Tensor,
-        inducing_values: Optional[Tensor] = None,
-        variational_inducing_covar: Optional[LinearOperator] = None,
+        x: Float[Tensor, "... N D"],
+        inducing_points: Float[Tensor, "... M D"],
+        inducing_values: Float[Tensor, "... M"],
+        variational_inducing_covar: Optional[Float[LinearOperator, "... M M"]] = None,
         **kwargs: Any,
-    ) -> MultivariateNormal:
+    ) -> Float[MultivariateNormal, "... N"]:
         if self.training:
             # In training mode, note that the full inducing points set = full training dataset
             # Users have the option to choose input None or a tensor of training data for x
@@ -193,20 +217,20 @@ def forward(
             return MultivariateNormal(predictive_mean, DiagLinearOperator(predictive_var))
         else:
-
             nn_indices = self.nn_util.find_nn_idx(x.float())

             x_batch_shape = x.shape[:-2]
+            batch_shape = torch.broadcast_shapes(self._batch_shape, x_batch_shape)
             x_bsz = x.shape[-2]
             assert nn_indices.shape == (*x_batch_shape, x_bsz, self.k), nn_indices.shape

+            # select K nearest neighbors from inducing points for test point x
             expanded_nn_indices = nn_indices.unsqueeze(-1).expand(*x_batch_shape, x_bsz, self.k, self.D)
             expanded_inducing_points = inducing_points.unsqueeze(-2).expand(*x_batch_shape, self.M, self.k, self.D)
             inducing_points = expanded_inducing_points.gather(-3, expanded_nn_indices)
             assert inducing_points.shape == (*x_batch_shape, x_bsz, self.k, self.D)

             # get variational mean and covar for nearest neighbors
-            batch_shape = torch.broadcast_shapes(self._model_batch_shape, x_batch_shape)
             inducing_values = self._variational_distribution.variational_mean
             expanded_inducing_values = inducing_values.unsqueeze(-1).expand(*batch_shape, self.M, self.k)
             expanded_nn_indices = nn_indices.expand(*batch_shape, x_bsz, self.k)

@@ -224,11 +248,24 @@ def forward(
             # Make everything batch mode
             x = x.unsqueeze(-2)
             assert x.shape == (*x_batch_shape, x_bsz, 1, self.D)
+            x = x.expand(*batch_shape, x_bsz, 1, self.D)

             # Compute forward mode in the standard way
-            dist = super().forward(x, inducing_points, inducing_values, variational_inducing_covar, **kwargs)
-            predictive_mean = dist.mean  # (*batch_shape, x_bsz, 1)
-            predictive_covar = dist.covariance_matrix  # (*batch_shape, x_bsz, 1, 1)
+            _batch_dims = tuple(range(len(batch_shape)))
+            _x = x.permute((-3,) + _batch_dims + (-2, -1))  # (x_bsz, *batch_shape, 1, D)
+
+            # inducing_points.shape (*x_batch_shape, x_bsz, self.k, self.D)
+            inducing_points = inducing_points.expand(*batch_shape, x_bsz, self.k, self.D)
+            _inducing_points = inducing_points.permute((-3,) + _batch_dims + (-2, -1))  # (x_bsz, *batch_shape, k, D)
+            _inducing_values = inducing_values.permute((-2,) + _batch_dims + (-1,))
+            _variational_inducing_covar = variational_inducing_covar.permute((-3,) + _batch_dims + (-2, -1))
+            dist = super().forward(_x, _inducing_points, _inducing_values, _variational_inducing_covar, **kwargs)
+
+            _x_batch_dims = tuple(range(1, 1 + len(batch_shape)))
+            predictive_mean = dist.mean  # (x_bsz, *x_batch_shape, 1)
+            predictive_covar = dist.covariance_matrix  # (x_bsz, *x_batch_shape, 1, 1)
+            predictive_mean = predictive_mean.permute(_x_batch_dims + (0, -1))
+            predictive_covar = predictive_covar.permute(_x_batch_dims + (0, -2, -1))

             # Undo batch mode
             predictive_mean = predictive_mean.squeeze(-1)
@@ -241,8 +278,8 @@ def forward(

     def get_fantasy_model(
         self,
-        inputs: Tensor,
-        targets: Tensor,
+        inputs: Float[Tensor, "... N D"],
+        targets: Float[Tensor, "... N"],
         mean_module: Optional[Module] = None,
         covar_module: Optional[Module] = None,
         **kwargs,
@@ -254,8 +291,15 @@ def get_fantasy_model(

     def _set_training_iterator(self) -> None:
         self._training_indices_iter = 0
-        training_indices = torch.randperm(self.M - self.k, device=self.inducing_points.device) + self.k
-        self._training_indices_iterator = (torch.arange(self.k),) + training_indices.split(self.training_batch_size)
+        if self.training_batch_size == self.M:
+            self._training_indices_iterator = (torch.arange(self.M, device=self.inducing_points.device),)
+        else:
+            # The first training batch always contains the first k inducing points
+            # This is because computing the KL divergence for the first k inducing points is special-cased
+            # (since the first k inducing points have < k neighbors)
+            # Note that there is a special function _firstk_kl_helper for this
+            training_indices = torch.randperm(self.M - self.k, device=self.inducing_points.device) + self.k
+            self._training_indices_iterator = (torch.arange(self.k),) + training_indices.split(self.training_batch_size)
         self._total_training_batches = len(self._training_indices_iterator)

     def _get_training_indices(self) -> LongTensor:
@@ -265,7 +309,7 @@
             self._set_training_iterator()
         return self.current_training_indices

-    def _firstk_kl_helper(self) -> Tensor:
+    def _firstk_kl_helper(self) -> Float[Tensor, "..."]:
         # Compute the KL divergence for first k inducing points
         train_x_firstk = self.inducing_points[..., : self.k, :]
         full_output = self.model.forward(train_x_firstk)
@@ -283,77 +327,122 @@ def _firstk_kl_helper(self) -> Tensor:
         kl = torch.distributions.kl.kl_divergence(variational_distribution, prior_dist)  # model_batch_shape
         return kl

-    def _stochastic_kl_helper(self, kl_indices: Tensor) -> Tensor:
-        # Compute the KL divergence for a mini batch of the rest M-1 inducing points
+    def _stochastic_kl_helper(self, kl_indices: Float[Tensor, "n_batch"]) -> Float[Tensor, "..."]:  # noqa: F821
+        # Compute the KL divergence for a mini batch of the rest M-k inducing points
         # See paper appendix for kl breakdown
-        kl_bs = len(kl_indices)
-        variational_mean = self._variational_distribution.variational_mean
+        kl_bs = len(kl_indices)  # training_batch_size
+        variational_mean = self._variational_distribution.variational_mean  # (*model_bs, M)
         variational_stddev = self._variational_distribution._variational_stddev

-        # compute logdet_q
+        # (1) compute logdet_q
         inducing_point_log_variational_covar = (variational_stddev[..., kl_indices] ** 2).log()
-        logdet_q = torch.sum(inducing_point_log_variational_covar, dim=-1)
+        logdet_q = torch.sum(inducing_point_log_variational_covar, dim=-1)  # model_bs

-        # Select a mini-batch of inducing points according to kl_indices, and their k-nearest neighbors
-        inducing_points = self.inducing_points[..., kl_indices, :]
+        # (2) compute logdet_p
+        # Select a mini-batch of inducing points according to kl_indices
+        inducing_points = self.inducing_points[..., kl_indices, :].expand(*self._batch_shape, kl_bs, self.D)
+        # (*bs, kl_bs, D)
+        # Select their K nearest neighbors
         nearest_neighbor_indices = self.nn_xinduce_idx[..., kl_indices - self.k, :].to(inducing_points.device)
+        # (*bs, kl_bs, K)
         expanded_inducing_points_all = self.inducing_points.unsqueeze(-2).expand(
-            *self._inducing_batch_shape, self.M, self.k, self.D
+            *self._batch_shape, self.M, self.k, self.D
         )
         expanded_nearest_neighbor_indices = nearest_neighbor_indices.unsqueeze(-1).expand(
-            *self._inducing_batch_shape, kl_bs, self.k, self.D
+            *self._batch_shape, kl_bs, self.k, self.D
         )
         nearest_neighbors = expanded_inducing_points_all.gather(-3, expanded_nearest_neighbor_indices)
+        # (*bs, kl_bs, K, D)
+
+        # Compute prior distribution
+        # Move the kl_bs dimension to the first dimension to enable batch covar_module computation
+        nearest_neighbors_ = nearest_neighbors.permute((-3,) + tuple(range(len(self._batch_shape))) + (-2, -1))
+        # (kl_bs, *bs, K, D)
+        inducing_points_ = inducing_points.permute((-2,) + tuple(range(len(self._batch_shape))) + (-1,))
+        # (kl_bs, *bs, D)
+        full_output = self.model.forward(torch.cat([nearest_neighbors_, inducing_points_.unsqueeze(-2)], dim=-2))
+        full_mean, full_covar = full_output.mean, full_output.covariance_matrix
+
+        # Mean terms
+        _undo_permute_dims = tuple(range(1, 1 + len(self._batch_shape))) + (0, -1)
+        nearest_neighbors_prior_mean = full_mean[..., : self.k].permute(_undo_permute_dims)  # (*inducing_bs, kl_bs, K)
+        inducing_prior_mean = full_mean[..., self.k :].permute(_undo_permute_dims).squeeze(-1)  # (*inducing_bs, kl_bs)
+        # Covar terms
+        nearest_neighbors_prior_cov = full_covar[..., : self.k, : self.k]
+        nearest_neighbors_inducing_prior_cross_cov = full_covar[..., : self.k, self.k :]
+        inducing_prior_cov = full_covar[..., self.k :, self.k :]
+        inducing_prior_cov = (
+            inducing_prior_cov.squeeze(-1).squeeze(-1).permute((-1,) + tuple(range(len(self._batch_shape))))
+        )

-        # compute interp_term
-        cov = self.model.covar_module.forward(nearest_neighbors, nearest_neighbors)
-        cross_cov = to_dense(self.model.covar_module.forward(nearest_neighbors, inducing_points.unsqueeze(-2)))
+        # Interpolation term K_nn^{-1} k_{nu}
         interp_term = torch.linalg.solve(
-            cov + self.jitter_val * torch.eye(self.k, device=self.inducing_points.device), cross_cov
-        ).squeeze(-1)
-
-        # compte logdet_p
-        invquad_term_for_F = torch.sum(interp_term * cross_cov.squeeze(-1), dim=-1)
-        cov_inducing_points = self.model.covar_module.forward(inducing_points, inducing_points, diag=True)
-        F = cov_inducing_points - invquad_term_for_F
+            nearest_neighbors_prior_cov + self.jitter_val * torch.eye(self.k, device=self.inducing_points.device),
+            nearest_neighbors_inducing_prior_cross_cov,
+        ).squeeze(
+            -1
+        )  # (kl_bs, *inducing_bs, K)
+        interp_term = interp_term.permute(_undo_permute_dims)  # (*inducing_bs, kl_bs, K)
+        nearest_neighbors_inducing_prior_cross_cov = nearest_neighbors_inducing_prior_cross_cov.squeeze(-1).permute(
+            _undo_permute_dims
+        )  # k_{n(j),j}, (*inducing_bs, kl_bs, K)
+
+        invquad_term_for_F = torch.sum(
+            interp_term * nearest_neighbors_inducing_prior_cross_cov, dim=-1
+        )  # (*inducing_bs, kl_bs)
+
+        inducing_prior_cov = self.model.covar_module.forward(
+            inducing_points, inducing_points, diag=True
+        )  # (*inducing_bs, kl_bs)
+
+        F = inducing_prior_cov - invquad_term_for_F
         F = F + self.jitter_val
-        logdet_p = F.log().sum(dim=-1)
+        # K_uu - k_un K_nn^{-1} k_nu
+        logdet_p = F.log().sum(dim=-1)  # shape: inducing_bs

-        # compute trace_term
+        # (3) compute trace_term
         expanded_variational_stddev = variational_stddev.unsqueeze(-1).expand(*self._batch_shape, self.M, self.k)
         expanded_variational_mean = variational_mean.unsqueeze(-1).expand(*self._batch_shape, self.M, self.k)
         expanded_nearest_neighbor_indices = nearest_neighbor_indices.expand(*self._batch_shape, kl_bs, self.k)
         nearest_neighbor_variational_covar = (
             expanded_variational_stddev.gather(-2, expanded_nearest_neighbor_indices) ** 2
+        )  # (*batch_shape, kl_bs, k)
+        bjsquared_s_nearest_neighbors = torch.sum(
+            interp_term**2 * nearest_neighbor_variational_covar, dim=-1
+        )  # (*batch_shape, kl_bs)
+        inducing_point_variational_covar = variational_stddev[..., kl_indices] ** 2  # (model_bs, kl_bs)
+        trace_term = (1.0 / F * (bjsquared_s_nearest_neighbors + inducing_point_variational_covar)).sum(
+            dim=-1
+        )  # batch_shape
+
+        # (4) compute invquad_term
+        nearest_neighbors_variational_mean = expanded_variational_mean.gather(-2, expanded_nearest_neighbor_indices)
+        Bj_m_nearest_neighbors = torch.sum(
+            interp_term * (nearest_neighbors_variational_mean - nearest_neighbors_prior_mean), dim=-1
+        )
+        inducing_variational_mean = variational_mean[..., kl_indices]
+        invquad_term = torch.sum(
+            (inducing_variational_mean - inducing_prior_mean - Bj_m_nearest_neighbors) ** 2 / F, dim=-1
         )
-        bjsquared_s = torch.sum(interp_term**2 * nearest_neighbor_variational_covar, dim=-1)
-        inducing_point_covar = variational_stddev[..., kl_indices] ** 2
-        trace_term = (1.0 / F * (bjsquared_s + inducing_point_covar)).sum(dim=-1)
-
-        # compute invquad_term
-        nearest_neighbor_variational_mean = expanded_variational_mean.gather(-2, expanded_nearest_neighbor_indices)
-        Bj_m = torch.sum(interp_term * nearest_neighbor_variational_mean, dim=-1)
-        inducing_point_variational_mean = variational_mean[..., kl_indices]
-        invquad_term = torch.sum((inducing_point_variational_mean - Bj_m) ** 2 / F, dim=-1)

         kl = (logdet_p - logdet_q - kl_bs + trace_term + invquad_term) * (1.0 / 2)
         assert kl.shape == self._batch_shape, kl.shape
-        kl = kl.mean()

         return kl

     def _kl_divergence(
-        self, kl_indices: Optional[LongTensor] = None, compute_full: bool = False, batch_size: Optional[int] = None
-    ) -> Tensor:
-        if compute_full:
+        self, kl_indices: Optional[LongTensor] = None, batch_size: Optional[int] = None
+    ) -> Float[Tensor, "..."]:
+        if self.compute_full_kl or (self._total_training_batches == 1):
             if batch_size is None:
                 batch_size = self.training_batch_size
             kl = self._firstk_kl_helper()
             for kl_indices in torch.split(torch.arange(self.k, self.M), batch_size):
                 kl += self._stochastic_kl_helper(kl_indices)
         else:
+            # compute a stochastic estimate
             assert kl_indices is not None
-            if (self._training_indices_iter == 1) or (self.M == self.k):
+            if self._training_indices_iter == 1:
                assert len(kl_indices) == self.k, (
                     f"kl_indices sould be the first batch data of length k, "
                     f"but got len(kl_indices) = {len(kl_indices)} and k = {self.k}."
@@ -363,7 +452,7 @@ def _kl_divergence(
             kl = self._stochastic_kl_helper(kl_indices) * self.M / len(kl_indices)
         return kl

-    def kl_divergence(self) -> Tensor:
+    def kl_divergence(self) -> Float[Tensor, "..."]:
         try:
             return pop_from_cache(self, "kl_divergence_memo")
         except CachingError:
@@ -374,4 +463,5 @@ def _compute_nn(self) -> "NNVariationalStrategy":
         inducing_points_fl = self.inducing_points.data.float()
         self.nn_util.set_nn_idx(inducing_points_fl)
         self.nn_xinduce_idx = self.nn_util.build_sequential_nn_idx(inducing_points_fl)
+        # shape (*_inducing_batch_shape, M-k, k)
         return self
diff --git a/test/variational/test_nearest_neighbor_variational_strategy.py b/test/variational/test_nearest_neighbor_variational_strategy.py
index 91a7594f7..6e86a04af 100644
--- a/test/variational/test_nearest_neighbor_variational_strategy.py
+++ b/test/variational/test_nearest_neighbor_variational_strategy.py
@@ -56,7 +56,10 @@ def __init__(self, inducing_points, k, training_batch_size):
         else:
             self.mean_module = gpytorch.means.ZeroMean()

-        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
+        self.covar_module = gpytorch.kernels.ScaleKernel(
+            gpytorch.kernels.RBFKernel(batch_shape=batch_shape, ard_num_dims=2),
+            batch_shape=batch_shape,
+        )

     def forward(self, x):
         mean_x = self.mean_module(x)
@@ -85,7 +88,6 @@ def _training_iter(
     ):
         # We cannot inheret the superclass method
         # Because it sets the training data to be the inducing points
-        train_x = model.variational_strategy.inducing_points
         train_y = torch.randn(train_x.shape[:-1])
         mll = mll_cls(likelihood, model, num_data=train_x.size(-2))

@@ -98,8 +100,10 @@ def _training_iter(
         # Single optimization iteration
         model.train()
         likelihood.train()
-        output = model(train_x)
-        loss = -mll(output, train_y)
+        output = model(x=None)
+        current_training_indices = model.variational_strategy.current_training_indices
+        y_batch = train_y[..., current_training_indices]
+        loss = -mll(output, y_batch)
         loss.sum().backward()

         # Make sure we have gradients for all parameters
@@ -113,7 +117,7 @@ def _training_iter(
         return output, loss

     def _eval_iter(self, model, cuda=False):
-        inducing_batch_shape = model.variational_strategy.inducing_points.shape[:-2]
+        inducing_batch_shape = model.variational_strategy._inducing_batch_shape
         test_x = torch.randn(*inducing_batch_shape, 32, 2).clamp(-2.5, 2.5)
         if cuda:
             test_x = test_x.cuda()
@@ -136,6 +140,8 @@ def _eval_iter(self, model, cuda=False):
     ):
         # We cannot inheret the superclass method
         # Because it expects `variational_params_intialized` to be set to 0
+        # Also, the expected output.event_shape should be the training_batch_size
+        # not self.event_shape (which is reserved for test_eval_iteration)

         # Batch shapes
         model_batch_shape = model_batch_shape if model_batch_shape is not None else self.batch_shape
@@ -170,7 +176,7 @@ def test_training_iteration(
             cuda=self.cuda,
         )
         self.assertEqual(output.batch_shape, expected_batch_shape)
-        self.assertEqual(output.event_shape, self.event_shape)
+        self.assertEqual(output.event_shape, torch.Size([model.variational_strategy.training_batch_size]))
         self.assertEqual(loss.shape, expected_batch_shape)

     def test_training_iteration_batch_inducing(self):
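
For readers tracing `_stochastic_kl_helper`: summed over the mini-batch indices, `(logdet_p - logdet_q - kl_bs + trace_term + invquad_term) / 2` is the batched form of the per-inducing-point KL summand sketched below. The notation is adapted from Wu et al. (2022) and is my shorthand, not names used in the patch: `B_j` mirrors `interp_term`, `F_j` mirrors `F`, `(m, s^2)` the variational mean/variance, and `mu` the prior mean.

```latex
% Sketch of the per-point KL summand (assumed notation mirroring the code variables)
\begin{aligned}
B_j &= K_{n(j)\,n(j)}^{-1}\, k_{n(j),\,j},
\qquad
F_j = K_{jj} - k_{j,\,n(j)}\, K_{n(j)\,n(j)}^{-1}\, k_{n(j),\,j} + \text{jitter},\\[2pt]
\mathrm{KL}_j &= \tfrac{1}{2}\!\left[
    \log F_j - \log s_j^{2} - 1
    + \frac{B_j^{\top}\operatorname{diag}\!\big(s_{n(j)}^{2}\big)B_j + s_j^{2}}{F_j}
    + \frac{\big(m_j - \mu_j - B_j^{\top}(m_{n(j)} - \mu_{n(j)})\big)^{2}}{F_j}
\right].
\end{aligned}
```

With this reading, `_kl_divergence` rescales the mini-batch sum by `M / len(kl_indices)` for the stochastic estimate, and only loops over all indices when `compute_full_kl` is set or there is a single training batch.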
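
Taken together, the new constructor defaults (`training_batch_size=None`, `jitter_val=1e-3`, `compute_full_kl=False`) and the mini-batch pattern the updated test exercises (`model(x=None)` plus `current_training_indices`) are used roughly as in the sketch below. This is a minimal, illustrative sketch in the spirit of the VNNGP tutorial, not code from this patch: the class name `VNNGPModel`, the `__call__` override, and all sizes/hyperparameters are assumptions, and it presumes a nearest-neighbor backend for gpytorch's `NNUtil` (e.g. faiss or scikit-learn) is installed.

```python
import torch
import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import MeanFieldVariationalDistribution, NNVariationalStrategy


class VNNGPModel(ApproximateGP):
    # Hypothetical model: in VNNGP the inducing points are the training inputs themselves
    def __init__(self, inducing_points, k=32, training_batch_size=64):
        m = inducing_points.size(-2)
        variational_distribution = MeanFieldVariationalDistribution(m)
        variational_strategy = NNVariationalStrategy(
            self,
            inducing_points,
            variational_distribution,
            k=k,
            training_batch_size=training_batch_size,
            compute_full_kl=False,  # use the per-batch stochastic KL estimate (new flag)
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=2))

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

    def __call__(self, x, prior=False, **kwargs):
        # Accept x=None so the strategy can draw the current training mini-batch itself
        if x is not None and x.dim() == 1:
            x = x.unsqueeze(-1)
        return self.variational_strategy(x=x, prior=prior, **kwargs)


train_x = torch.randn(1000, 2)  # illustrative data; train_x doubles as the inducing points
train_y = torch.randn(1000)

model = VNNGPModel(train_x, k=32, training_batch_size=64)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))
optimizer = torch.optim.Adam(list(model.parameters()) + list(likelihood.parameters()), lr=0.01)

model.train()
likelihood.train()
for _ in range(10):
    optimizer.zero_grad()
    output = model(x=None)  # event_shape == training_batch_size while training
    batch_idx = model.variational_strategy.current_training_indices
    loss = -mll(output, train_y[..., batch_idx])  # keep train_x / train_y ordering consistent
    loss.backward()
    optimizer.step()
```

The `train_y[..., batch_idx]` indexing is the same pattern the revised `_training_iter` test uses: the strategy picks the mini-batch of inducing (training) points internally, so the targets must be gathered with the indices it reports.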