From 87c98a36672c034276001d6aceba8a9d8c823ae8 Mon Sep 17 00:00:00 2001 From: Luhuan Wu Date: Sun, 25 Jun 2023 23:56:53 -0700 Subject: [PATCH 01/11] fix vnngp batch mode --- .../nearest_neighbor_variational_strategy.py | 157 ++++++++++++------ 1 file changed, 109 insertions(+), 48 deletions(-) diff --git a/gpytorch/variational/nearest_neighbor_variational_strategy.py b/gpytorch/variational/nearest_neighbor_variational_strategy.py index 6f9b429b4..e6c571c4a 100644 --- a/gpytorch/variational/nearest_neighbor_variational_strategy.py +++ b/gpytorch/variational/nearest_neighbor_variational_strategy.py @@ -8,6 +8,7 @@ from linear_operator.utils.cholesky import psd_safe_cholesky from torch import LongTensor, Tensor +from .. import settings from ..distributions import MultivariateNormal from ..models import ApproximateGP, ExactGP from ..module import Module @@ -62,7 +63,8 @@ class NNVariationalStrategy(UnwhitenedVariationalStrategy): VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)` :param k: Number of nearest neighbors. :param training_batch_size: The number of data points that will be in the training batch size. - :param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability + :param jitter_val: Amount of diagonal jitter to add for covariance matrix numerical stability. + :param compute_full_kl: Whether to compute full kl divergence or stochastic estimate. .. _Wu et al (2022): https://arxiv.org/pdf/2202.01694.pdf @@ -79,7 +81,8 @@ def __init__( variational_distribution: _VariationalDistribution, k: int, training_batch_size: int, - jitter_val: Optional[float] = None, + jitter_val: Optional[float] = 1e-3, + compute_full_kl: Optional[bool] = False ): assert isinstance( variational_distribution, MeanFieldVariationalDistribution @@ -88,8 +91,7 @@ def __init__( super().__init__( model, inducing_points, variational_distribution, learn_inducing_locations=False, jitter_val=jitter_val ) - # Make sure we don't try to initialize variational parameters - because of minibatching - self.variational_params_initialized.fill_(1) + #self.variational_params_initialized.fill_(1) # Model object.__setattr__(self, "model", model) @@ -111,10 +113,13 @@ def __init__( k, dim=self.D, batch_shape=self._inducing_batch_shape, device=inducing_points.device ) self._compute_nn() + # otherwise, no nearest neighbor approximation is used self.training_batch_size = training_batch_size self._set_training_iterator() + self.compute_full_kl = compute_full_kl + @property @cached(name="prior_distribution_memo") def prior_distribution(self) -> MultivariateNormal: @@ -142,6 +147,17 @@ def __call__(self, x: Tensor, prior: bool = False, **kwargs: Any) -> Multivariat # Delete previously cached items from the training distribution if self.training: self._clear_cache() + + # (Maybe) initialize variational distribution + if not self.variational_params_initialized.item(): + prior_dist = self.prior_distribution + self._variational_distribution.variational_mean.data.copy_(prior_dist.mean) + self._variational_distribution.variational_mean.data.add_( + torch.randn_like(prior_dist.mean), alpha=self._variational_distribution.mean_init_std) + # initialize with a small variational stddev for quicker conv. of kl divergence + self._variational_distribution._variational_stddev.data.copy_(1e-2) + self.variational_params_initialized.fill_(1) + return self.forward( x, self.inducing_points, inducing_values=None, variational_inducing_covar=None, **kwargs ) @@ -200,6 +216,7 @@ def forward( x_bsz = x.shape[-2] assert nn_indices.shape == (*x_batch_shape, x_bsz, self.k), nn_indices.shape + # select K nearest neighbors from inducing points for test point x expanded_nn_indices = nn_indices.unsqueeze(-1).expand(*x_batch_shape, x_bsz, self.k, self.D) expanded_inducing_points = inducing_points.unsqueeze(-2).expand(*x_batch_shape, self.M, self.k, self.D) inducing_points = expanded_inducing_points.gather(-3, expanded_nn_indices) @@ -225,10 +242,19 @@ def forward( x = x.unsqueeze(-2) assert x.shape == (*x_batch_shape, x_bsz, 1, self.D) - # Compute forward mode in the standard way - dist = super().forward(x, inducing_points, inducing_values, variational_inducing_covar, **kwargs) - predictive_mean = dist.mean # (*batch_shape, x_bsz, 1) - predictive_covar = dist.covariance_matrix # (*batch_shape, x_bsz, 1, 1) + # Compute forward mode in the standard way + _x_batch_dims = tuple(range(len(x_batch_shape))) + _x = x.permute((-3,) + _x_batch_dims + (-2,-1)) + _inducing_points = inducing_points.permute((-3,) + _x_batch_dims + (-2,-1)) + _inducing_values = inducing_values.permute((-2,) + _x_batch_dims + (-1,)) + _variational_inducing_covar = variational_inducing_covar.permute((-3,)+ _x_batch_dims + (-2,-1)) + dist = super().forward(_x, _inducing_points, _inducing_values, _variational_inducing_covar, **kwargs) + + _x_batch_dims = tuple(range(1, 1 + len(x_batch_shape))) + predictive_mean = dist.mean # (x_bsz, *x_batch_shape, 1) + predictive_covar = dist.covariance_matrix # (x_bsz, *x_batch_shape, 1, 1) + predictive_mean = predictive_mean.permute(_x_batch_dims + (0,-1)) + predictive_covar = predictive_covar.permute(_x_batch_dims + (0,-2,-1)) # Undo batch mode predictive_mean = predictive_mean.squeeze(-1) @@ -254,8 +280,11 @@ def get_fantasy_model( def _set_training_iterator(self) -> None: self._training_indices_iter = 0 - training_indices = torch.randperm(self.M - self.k, device=self.inducing_points.device) + self.k - self._training_indices_iterator = (torch.arange(self.k),) + training_indices.split(self.training_batch_size) + if self. k < self.M: + training_indices = torch.randperm(self.M - self.k, device=self.inducing_points.device) + self.k + self._training_indices_iterator = (torch.arange(self.k),) + training_indices.split(self.training_batch_size) + else: + self._training_indices_iterator = (torch.arange(self.k),) self._total_training_batches = len(self._training_indices_iterator) def _get_training_indices(self) -> LongTensor: @@ -280,78 +309,108 @@ def _firstk_kl_helper(self) -> Tensor: variational_inducing_covar = DiagLinearOperator(variational_covar_fisrtk) variational_distribution = MultivariateNormal(inducing_values, variational_inducing_covar) - kl = torch.distributions.kl.kl_divergence(variational_distribution, prior_dist) # model_batch_shape + with settings.max_preconditioner_size(0): + kl = torch.distributions.kl.kl_divergence(variational_distribution, prior_dist) # model_batch_shape return kl def _stochastic_kl_helper(self, kl_indices: Tensor) -> Tensor: - # Compute the KL divergence for a mini batch of the rest M-1 inducing points + # Compute the KL divergence for a mini batch of the rest M-k inducing points # See paper appendix for kl breakdown - kl_bs = len(kl_indices) - variational_mean = self._variational_distribution.variational_mean + kl_bs = len(kl_indices) # training_batch_size + variational_mean = self._variational_distribution.variational_mean # (*model_bs, M) variational_stddev = self._variational_distribution._variational_stddev - # compute logdet_q + ### (1) compute logdet_q inducing_point_log_variational_covar = (variational_stddev[..., kl_indices] ** 2).log() - logdet_q = torch.sum(inducing_point_log_variational_covar, dim=-1) - - # Select a mini-batch of inducing points according to kl_indices, and their k-nearest neighbors - inducing_points = self.inducing_points[..., kl_indices, :] - nearest_neighbor_indices = self.nn_xinduce_idx[..., kl_indices - self.k, :].to(inducing_points.device) + logdet_q = torch.sum(inducing_point_log_variational_covar, dim=-1) # model_bs + + ### (2) compute lodet_p + # Select a mini-batch of inducing points according to kl_indices + inducing_points = self.inducing_points[..., kl_indices, :] # (*inducing_bs, kl_bs, D) + # Select their K nearest neighbors + nearest_neighbor_indices = self.nn_xinduce_idx[..., kl_indices - self.k, :].to(inducing_points.device) + # (*inducing_bs, kl_bs, K) expanded_inducing_points_all = self.inducing_points.unsqueeze(-2).expand( *self._inducing_batch_shape, self.M, self.k, self.D - ) + ) expanded_nearest_neighbor_indices = nearest_neighbor_indices.unsqueeze(-1).expand( *self._inducing_batch_shape, kl_bs, self.k, self.D ) nearest_neighbors = expanded_inducing_points_all.gather(-3, expanded_nearest_neighbor_indices) - - # compute interp_term - cov = self.model.covar_module.forward(nearest_neighbors, nearest_neighbors) - cross_cov = to_dense(self.model.covar_module.forward(nearest_neighbors, inducing_points.unsqueeze(-2))) + # (*inducing_bs, kl_bs, K, D) + + # Compute prior distribution + # Move the kl_bs dimension to the first dimension to enable batch covar_module computation + nearest_neighbors_ = nearest_neighbors.permute((-3,)+tuple(range(len(self._inducing_batch_shape)))+(-2,-1)) + # (kl_bs, *inducing_bs, K, D) + inducing_points_ = inducing_points.permute((-2,)+tuple(range(len(self._inducing_batch_shape)))+(-1,)) + # (kl_bs, *inducing_bs, D) + full_output = self.model.forward(torch.cat([nearest_neighbors_, inducing_points_.unsqueeze(-2)], dim=-2)) + full_mean, full_covar = full_output.mean, full_output.covariance_matrix + + # Mean terms + _undo_permute_dims = tuple(range(1,1+len(self._inducing_batch_shape)))+(0,-1) + nearest_neighbors_prior_mean = full_mean[..., :self.k].permute(_undo_permute_dims) # (*inducing_bs, kl_bs, K) + inducing_prior_mean = full_mean[..., self.k:].permute(_undo_permute_dims).squeeze(-1) # (*inducing_bs, kl_bs) + # Covar terms + nearest_neighbors_prior_cov = full_covar[..., :self.k, :self.k] + nearest_neighbors_inducing_prior_cross_cov = full_covar[..., :self.k, self.k:] + inducing_prior_cov = full_covar[..., self.k:, self.k:] + inducing_prior_cov = inducing_prior_cov.squeeze(-1).squeeze(-1).permute((-1,)+tuple(range(len(self._inducing_batch_shape)))) + + # Interpolation term K_nn^{-1} k_{nu} interp_term = torch.linalg.solve( - cov + self.jitter_val * torch.eye(self.k, device=self.inducing_points.device), cross_cov - ).squeeze(-1) - - # compte logdet_p - invquad_term_for_F = torch.sum(interp_term * cross_cov.squeeze(-1), dim=-1) - cov_inducing_points = self.model.covar_module.forward(inducing_points, inducing_points, diag=True) - F = cov_inducing_points - invquad_term_for_F + nearest_neighbors_prior_cov + self.jitter_val * torch.eye(self.k, device=self.inducing_points.device), \ + nearest_neighbors_inducing_prior_cross_cov + ).squeeze(-1) # (kl_bs, *inducing_bs, K) + interp_term = interp_term.permute(_undo_permute_dims) # (*inducing_bs, kl_bs, K) + nearest_neighbors_inducing_prior_cross_cov \ + = nearest_neighbors_inducing_prior_cross_cov.squeeze(-1).permute(_undo_permute_dims) # k_{n(j),j}, (*inducing_bs, kl_bs, K) + + invquad_term_for_F = torch.sum(interp_term * nearest_neighbors_inducing_prior_cross_cov, dim=-1) # (*inducing_bs, kl_bs) + + inducing_prior_cov = self.model.covar_module.forward( + inducing_points, inducing_points, diag=True + ) # (*inducing_bs, kl_bs) + + F = inducing_prior_cov - invquad_term_for_F F = F + self.jitter_val - logdet_p = F.log().sum(dim=-1) + # K_uu - k_un K_nn^{-1} k_nu + logdet_p = F.log().sum(dim=-1)# shape: inducing_bs - # compute trace_term + ### (3) compute trace_term expanded_variational_stddev = variational_stddev.unsqueeze(-1).expand(*self._batch_shape, self.M, self.k) expanded_variational_mean = variational_mean.unsqueeze(-1).expand(*self._batch_shape, self.M, self.k) expanded_nearest_neighbor_indices = nearest_neighbor_indices.expand(*self._batch_shape, kl_bs, self.k) nearest_neighbor_variational_covar = ( expanded_variational_stddev.gather(-2, expanded_nearest_neighbor_indices) ** 2 - ) - bjsquared_s = torch.sum(interp_term**2 * nearest_neighbor_variational_covar, dim=-1) - inducing_point_covar = variational_stddev[..., kl_indices] ** 2 - trace_term = (1.0 / F * (bjsquared_s + inducing_point_covar)).sum(dim=-1) + ) # (*batch_shape, kl_bs, k) + bjsquared_s_nearest_neighbors = torch.sum(interp_term**2 * nearest_neighbor_variational_covar, dim=-1) # (*batch_shape, kl_bs) + inducing_point_variational_covar = variational_stddev[..., kl_indices] ** 2 # (model_bs, kl_bs) + trace_term = (1.0 / F * (bjsquared_s_nearest_neighbors + inducing_point_variational_covar)).sum(dim=-1) # batch_shape - # compute invquad_term - nearest_neighbor_variational_mean = expanded_variational_mean.gather(-2, expanded_nearest_neighbor_indices) - Bj_m = torch.sum(interp_term * nearest_neighbor_variational_mean, dim=-1) - inducing_point_variational_mean = variational_mean[..., kl_indices] - invquad_term = torch.sum((inducing_point_variational_mean - Bj_m) ** 2 / F, dim=-1) + ### (4) compute invquad_term + nearest_neighbors_variational_mean = expanded_variational_mean.gather(-2, expanded_nearest_neighbor_indices) + Bj_m_nearest_neighbors = torch.sum(interp_term * (nearest_neighbors_variational_mean-nearest_neighbors_prior_mean), dim=-1) + inducing_variational_mean = variational_mean[..., kl_indices] + invquad_term = torch.sum((inducing_variational_mean - inducing_prior_mean - Bj_m_nearest_neighbors) ** 2 / F, dim=-1) kl = (logdet_p - logdet_q - kl_bs + trace_term + invquad_term) * (1.0 / 2) assert kl.shape == self._batch_shape, kl.shape - kl = kl.mean() return kl def _kl_divergence( - self, kl_indices: Optional[LongTensor] = None, compute_full: bool = False, batch_size: Optional[int] = None + self, kl_indices: Optional[LongTensor] = None, batch_size: Optional[int] = None ) -> Tensor: - if compute_full: + if self.compute_full_kl: if batch_size is None: batch_size = self.training_batch_size kl = self._firstk_kl_helper() for kl_indices in torch.split(torch.arange(self.k, self.M), batch_size): kl += self._stochastic_kl_helper(kl_indices) else: + # compute a stochastic estimate assert kl_indices is not None if (self._training_indices_iter == 1) or (self.M == self.k): assert len(kl_indices) == self.k, ( @@ -373,5 +432,7 @@ def _compute_nn(self) -> "NNVariationalStrategy": with torch.no_grad(): inducing_points_fl = self.inducing_points.data.float() self.nn_util.set_nn_idx(inducing_points_fl) - self.nn_xinduce_idx = self.nn_util.build_sequential_nn_idx(inducing_points_fl) - return self + if self.k < self.M: + self.nn_xinduce_idx = self.nn_util.build_sequential_nn_idx(inducing_points_fl) + # shape (*_inducing_batch_shape, M-k, k) + return self \ No newline at end of file From 1e0c505a8e18e69d6ab0235798141b9afeea4e81 Mon Sep 17 00:00:00 2001 From: Luhuan Wu Date: Mon, 26 Jun 2023 00:11:48 -0700 Subject: [PATCH 02/11] fix format issue --- .../nearest_neighbor_variational_strategy.py | 140 ++++++++++-------- 1 file changed, 76 insertions(+), 64 deletions(-) diff --git a/gpytorch/variational/nearest_neighbor_variational_strategy.py b/gpytorch/variational/nearest_neighbor_variational_strategy.py index e6c571c4a..5be995cee 100644 --- a/gpytorch/variational/nearest_neighbor_variational_strategy.py +++ b/gpytorch/variational/nearest_neighbor_variational_strategy.py @@ -8,7 +8,7 @@ from linear_operator.utils.cholesky import psd_safe_cholesky from torch import LongTensor, Tensor -from .. import settings +from .. import settings from ..distributions import MultivariateNormal from ..models import ApproximateGP, ExactGP from ..module import Module @@ -63,7 +63,7 @@ class NNVariationalStrategy(UnwhitenedVariationalStrategy): VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)` :param k: Number of nearest neighbors. :param training_batch_size: The number of data points that will be in the training batch size. - :param jitter_val: Amount of diagonal jitter to add for covariance matrix numerical stability. + :param jitter_val: Amount of diagonal jitter to add for covariance matrix numerical stability. :param compute_full_kl: Whether to compute full kl divergence or stochastic estimate. .. _Wu et al (2022): @@ -82,7 +82,7 @@ def __init__( k: int, training_batch_size: int, jitter_val: Optional[float] = 1e-3, - compute_full_kl: Optional[bool] = False + compute_full_kl: Optional[bool] = False, ): assert isinstance( variational_distribution, MeanFieldVariationalDistribution @@ -91,7 +91,6 @@ def __init__( super().__init__( model, inducing_points, variational_distribution, learn_inducing_locations=False, jitter_val=jitter_val ) - #self.variational_params_initialized.fill_(1) # Model object.__setattr__(self, "model", model) @@ -113,12 +112,12 @@ def __init__( k, dim=self.D, batch_shape=self._inducing_batch_shape, device=inducing_points.device ) self._compute_nn() - # otherwise, no nearest neighbor approximation is used + # otherwise, no nearest neighbor approximation is used self.training_batch_size = training_batch_size self._set_training_iterator() - self.compute_full_kl = compute_full_kl + self.compute_full_kl = compute_full_kl @property @cached(name="prior_distribution_memo") @@ -147,13 +146,14 @@ def __call__(self, x: Tensor, prior: bool = False, **kwargs: Any) -> Multivariat # Delete previously cached items from the training distribution if self.training: self._clear_cache() - + # (Maybe) initialize variational distribution if not self.variational_params_initialized.item(): prior_dist = self.prior_distribution self._variational_distribution.variational_mean.data.copy_(prior_dist.mean) self._variational_distribution.variational_mean.data.add_( - torch.randn_like(prior_dist.mean), alpha=self._variational_distribution.mean_init_std) + torch.randn_like(prior_dist.mean), alpha=self._variational_distribution.mean_init_std + ) # initialize with a small variational stddev for quicker conv. of kl divergence self._variational_distribution._variational_stddev.data.copy_(1e-2) self.variational_params_initialized.fill_(1) @@ -209,7 +209,6 @@ def forward( return MultivariateNormal(predictive_mean, DiagLinearOperator(predictive_var)) else: - nn_indices = self.nn_util.find_nn_idx(x.float()) x_batch_shape = x.shape[:-2] @@ -242,19 +241,19 @@ def forward( x = x.unsqueeze(-2) assert x.shape == (*x_batch_shape, x_bsz, 1, self.D) - # Compute forward mode in the standard way + # Compute forward mode in the standard way _x_batch_dims = tuple(range(len(x_batch_shape))) - _x = x.permute((-3,) + _x_batch_dims + (-2,-1)) - _inducing_points = inducing_points.permute((-3,) + _x_batch_dims + (-2,-1)) + _x = x.permute((-3,) + _x_batch_dims + (-2, -1)) + _inducing_points = inducing_points.permute((-3,) + _x_batch_dims + (-2, -1)) _inducing_values = inducing_values.permute((-2,) + _x_batch_dims + (-1,)) - _variational_inducing_covar = variational_inducing_covar.permute((-3,)+ _x_batch_dims + (-2,-1)) + _variational_inducing_covar = variational_inducing_covar.permute((-3,) + _x_batch_dims + (-2, -1)) dist = super().forward(_x, _inducing_points, _inducing_values, _variational_inducing_covar, **kwargs) _x_batch_dims = tuple(range(1, 1 + len(x_batch_shape))) predictive_mean = dist.mean # (x_bsz, *x_batch_shape, 1) predictive_covar = dist.covariance_matrix # (x_bsz, *x_batch_shape, 1, 1) - predictive_mean = predictive_mean.permute(_x_batch_dims + (0,-1)) - predictive_covar = predictive_covar.permute(_x_batch_dims + (0,-2,-1)) + predictive_mean = predictive_mean.permute(_x_batch_dims + (0, -1)) + predictive_covar = predictive_covar.permute(_x_batch_dims + (0, -2, -1)) # Undo batch mode predictive_mean = predictive_mean.squeeze(-1) @@ -280,11 +279,11 @@ def get_fantasy_model( def _set_training_iterator(self) -> None: self._training_indices_iter = 0 - if self. k < self.M: + if self.k < self.M: training_indices = torch.randperm(self.M - self.k, device=self.inducing_points.device) + self.k self._training_indices_iterator = (torch.arange(self.k),) + training_indices.split(self.training_batch_size) else: - self._training_indices_iterator = (torch.arange(self.k),) + self._training_indices_iterator = (torch.arange(self.k),) self._total_training_batches = len(self._training_indices_iterator) def _get_training_indices(self) -> LongTensor: @@ -317,92 +316,105 @@ def _stochastic_kl_helper(self, kl_indices: Tensor) -> Tensor: # Compute the KL divergence for a mini batch of the rest M-k inducing points # See paper appendix for kl breakdown kl_bs = len(kl_indices) # training_batch_size - variational_mean = self._variational_distribution.variational_mean # (*model_bs, M) + variational_mean = self._variational_distribution.variational_mean # (*model_bs, M) variational_stddev = self._variational_distribution._variational_stddev - ### (1) compute logdet_q + # (1) compute logdet_q inducing_point_log_variational_covar = (variational_stddev[..., kl_indices] ** 2).log() - logdet_q = torch.sum(inducing_point_log_variational_covar, dim=-1) # model_bs + logdet_q = torch.sum(inducing_point_log_variational_covar, dim=-1) # model_bs - ### (2) compute lodet_p + # (2) compute lodet_p # Select a mini-batch of inducing points according to kl_indices - inducing_points = self.inducing_points[..., kl_indices, :] # (*inducing_bs, kl_bs, D) + inducing_points = self.inducing_points[..., kl_indices, :] # (*inducing_bs, kl_bs, D) # Select their K nearest neighbors - nearest_neighbor_indices = self.nn_xinduce_idx[..., kl_indices - self.k, :].to(inducing_points.device) + nearest_neighbor_indices = self.nn_xinduce_idx[..., kl_indices - self.k, :].to(inducing_points.device) # (*inducing_bs, kl_bs, K) expanded_inducing_points_all = self.inducing_points.unsqueeze(-2).expand( *self._inducing_batch_shape, self.M, self.k, self.D - ) + ) expanded_nearest_neighbor_indices = nearest_neighbor_indices.unsqueeze(-1).expand( *self._inducing_batch_shape, kl_bs, self.k, self.D ) nearest_neighbors = expanded_inducing_points_all.gather(-3, expanded_nearest_neighbor_indices) # (*inducing_bs, kl_bs, K, D) - - # Compute prior distribution - # Move the kl_bs dimension to the first dimension to enable batch covar_module computation - nearest_neighbors_ = nearest_neighbors.permute((-3,)+tuple(range(len(self._inducing_batch_shape)))+(-2,-1)) + + # Compute prior distribution + # Move the kl_bs dimension to the first dimension to enable batch covar_module computation + nearest_neighbors_ = nearest_neighbors.permute((-3,) + tuple(range(len(self._inducing_batch_shape))) + (-2, -1)) # (kl_bs, *inducing_bs, K, D) - inducing_points_ = inducing_points.permute((-2,)+tuple(range(len(self._inducing_batch_shape)))+(-1,)) + inducing_points_ = inducing_points.permute((-2,) + tuple(range(len(self._inducing_batch_shape))) + (-1,)) # (kl_bs, *inducing_bs, D) full_output = self.model.forward(torch.cat([nearest_neighbors_, inducing_points_.unsqueeze(-2)], dim=-2)) - full_mean, full_covar = full_output.mean, full_output.covariance_matrix - + full_mean, full_covar = full_output.mean, full_output.covariance_matrix + # Mean terms - _undo_permute_dims = tuple(range(1,1+len(self._inducing_batch_shape)))+(0,-1) - nearest_neighbors_prior_mean = full_mean[..., :self.k].permute(_undo_permute_dims) # (*inducing_bs, kl_bs, K) - inducing_prior_mean = full_mean[..., self.k:].permute(_undo_permute_dims).squeeze(-1) # (*inducing_bs, kl_bs) - # Covar terms - nearest_neighbors_prior_cov = full_covar[..., :self.k, :self.k] - nearest_neighbors_inducing_prior_cross_cov = full_covar[..., :self.k, self.k:] - inducing_prior_cov = full_covar[..., self.k:, self.k:] - inducing_prior_cov = inducing_prior_cov.squeeze(-1).squeeze(-1).permute((-1,)+tuple(range(len(self._inducing_batch_shape)))) + _undo_permute_dims = tuple(range(1, 1 + len(self._inducing_batch_shape))) + (0, -1) + nearest_neighbors_prior_mean = full_mean[..., : self.k].permute(_undo_permute_dims) # (*inducing_bs, kl_bs, K) + inducing_prior_mean = full_mean[..., self.k :].permute(_undo_permute_dims).squeeze(-1) # (*inducing_bs, kl_bs) + # Covar terms + nearest_neighbors_prior_cov = full_covar[..., : self.k, : self.k] + nearest_neighbors_inducing_prior_cross_cov = full_covar[..., : self.k, self.k :] + inducing_prior_cov = full_covar[..., self.k :, self.k :] + inducing_prior_cov = ( + inducing_prior_cov.squeeze(-1).squeeze(-1).permute((-1,) + tuple(range(len(self._inducing_batch_shape)))) + ) # Interpolation term K_nn^{-1} k_{nu} interp_term = torch.linalg.solve( - nearest_neighbors_prior_cov + self.jitter_val * torch.eye(self.k, device=self.inducing_points.device), \ - nearest_neighbors_inducing_prior_cross_cov - ).squeeze(-1) # (kl_bs, *inducing_bs, K) - interp_term = interp_term.permute(_undo_permute_dims) # (*inducing_bs, kl_bs, K) - nearest_neighbors_inducing_prior_cross_cov \ - = nearest_neighbors_inducing_prior_cross_cov.squeeze(-1).permute(_undo_permute_dims) # k_{n(j),j}, (*inducing_bs, kl_bs, K) - - invquad_term_for_F = torch.sum(interp_term * nearest_neighbors_inducing_prior_cross_cov, dim=-1) # (*inducing_bs, kl_bs) - + nearest_neighbors_prior_cov + self.jitter_val * torch.eye(self.k, device=self.inducing_points.device), + nearest_neighbors_inducing_prior_cross_cov, + ).squeeze( + -1 + ) # (kl_bs, *inducing_bs, K) + interp_term = interp_term.permute(_undo_permute_dims) # (*inducing_bs, kl_bs, K) + nearest_neighbors_inducing_prior_cross_cov = nearest_neighbors_inducing_prior_cross_cov.squeeze(-1).permute( + _undo_permute_dims + ) # k_{n(j),j}, (*inducing_bs, kl_bs, K) + + invquad_term_for_F = torch.sum( + interp_term * nearest_neighbors_inducing_prior_cross_cov, dim=-1 + ) # (*inducing_bs, kl_bs) + inducing_prior_cov = self.model.covar_module.forward( inducing_points, inducing_points, diag=True - ) # (*inducing_bs, kl_bs) - + ) # (*inducing_bs, kl_bs) + F = inducing_prior_cov - invquad_term_for_F F = F + self.jitter_val # K_uu - k_un K_nn^{-1} k_nu - logdet_p = F.log().sum(dim=-1)# shape: inducing_bs + logdet_p = F.log().sum(dim=-1) # shape: inducing_bs - ### (3) compute trace_term + # (3) compute trace_term expanded_variational_stddev = variational_stddev.unsqueeze(-1).expand(*self._batch_shape, self.M, self.k) expanded_variational_mean = variational_mean.unsqueeze(-1).expand(*self._batch_shape, self.M, self.k) expanded_nearest_neighbor_indices = nearest_neighbor_indices.expand(*self._batch_shape, kl_bs, self.k) nearest_neighbor_variational_covar = ( expanded_variational_stddev.gather(-2, expanded_nearest_neighbor_indices) ** 2 - ) # (*batch_shape, kl_bs, k) - bjsquared_s_nearest_neighbors = torch.sum(interp_term**2 * nearest_neighbor_variational_covar, dim=-1) # (*batch_shape, kl_bs) - inducing_point_variational_covar = variational_stddev[..., kl_indices] ** 2 # (model_bs, kl_bs) - trace_term = (1.0 / F * (bjsquared_s_nearest_neighbors + inducing_point_variational_covar)).sum(dim=-1) # batch_shape - - ### (4) compute invquad_term + ) # (*batch_shape, kl_bs, k) + bjsquared_s_nearest_neighbors = torch.sum( + interp_term**2 * nearest_neighbor_variational_covar, dim=-1 + ) # (*batch_shape, kl_bs) + inducing_point_variational_covar = variational_stddev[..., kl_indices] ** 2 # (model_bs, kl_bs) + trace_term = (1.0 / F * (bjsquared_s_nearest_neighbors + inducing_point_variational_covar)).sum( + dim=-1 + ) # batch_shape + + # (4) compute invquad_term nearest_neighbors_variational_mean = expanded_variational_mean.gather(-2, expanded_nearest_neighbor_indices) - Bj_m_nearest_neighbors = torch.sum(interp_term * (nearest_neighbors_variational_mean-nearest_neighbors_prior_mean), dim=-1) + Bj_m_nearest_neighbors = torch.sum( + interp_term * (nearest_neighbors_variational_mean - nearest_neighbors_prior_mean), dim=-1 + ) inducing_variational_mean = variational_mean[..., kl_indices] - invquad_term = torch.sum((inducing_variational_mean - inducing_prior_mean - Bj_m_nearest_neighbors) ** 2 / F, dim=-1) + invquad_term = torch.sum( + (inducing_variational_mean - inducing_prior_mean - Bj_m_nearest_neighbors) ** 2 / F, dim=-1 + ) kl = (logdet_p - logdet_q - kl_bs + trace_term + invquad_term) * (1.0 / 2) assert kl.shape == self._batch_shape, kl.shape return kl - def _kl_divergence( - self, kl_indices: Optional[LongTensor] = None, batch_size: Optional[int] = None - ) -> Tensor: + def _kl_divergence(self, kl_indices: Optional[LongTensor] = None, batch_size: Optional[int] = None) -> Tensor: if self.compute_full_kl: if batch_size is None: batch_size = self.training_batch_size @@ -435,4 +447,4 @@ def _compute_nn(self) -> "NNVariationalStrategy": if self.k < self.M: self.nn_xinduce_idx = self.nn_util.build_sequential_nn_idx(inducing_points_fl) # shape (*_inducing_batch_shape, M-k, k) - return self \ No newline at end of file + return self From 0dd0908793c38e7edf80bf4ad63700148a2480b2 Mon Sep 17 00:00:00 2001 From: Luhuan Wu Date: Sat, 8 Jul 2023 23:27:49 -0700 Subject: [PATCH 03/11] fix nearest_neighbor_variational_strategy batch compatibility --- .../nearest_neighbor_variational_strategy.py | 20 +++++++++++-------- ...t_nearest_neighbor_variational_strategy.py | 2 +- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/gpytorch/variational/nearest_neighbor_variational_strategy.py b/gpytorch/variational/nearest_neighbor_variational_strategy.py index 5be995cee..299bcc53b 100644 --- a/gpytorch/variational/nearest_neighbor_variational_strategy.py +++ b/gpytorch/variational/nearest_neighbor_variational_strategy.py @@ -139,7 +139,7 @@ def __call__(self, x: Tensor, prior: bool = False, **kwargs: Any) -> Multivariat if x is not None: assert self.inducing_points.shape[:-2] == x.shape[:-2], ( f"x batch shape must matches inducing points batch shape, " - f"but got train data batch shape = {x.shape[:-2]}, " + f"but got x batch shape = {x.shape[:-2]}, " f"inducing points batch shape = {self.inducing_points.shape[:-2]}." ) @@ -212,6 +212,7 @@ def forward( nn_indices = self.nn_util.find_nn_idx(x.float()) x_batch_shape = x.shape[:-2] + batch_shape = torch.broadcast_shapes(self._model_batch_shape, x_batch_shape) x_bsz = x.shape[-2] assert nn_indices.shape == (*x_batch_shape, x_bsz, self.k), nn_indices.shape @@ -222,7 +223,6 @@ def forward( assert inducing_points.shape == (*x_batch_shape, x_bsz, self.k, self.D) # get variational mean and covar for nearest neighbors - batch_shape = torch.broadcast_shapes(self._model_batch_shape, x_batch_shape) inducing_values = self._variational_distribution.variational_mean expanded_inducing_values = inducing_values.unsqueeze(-1).expand(*batch_shape, self.M, self.k) expanded_nn_indices = nn_indices.expand(*batch_shape, x_bsz, self.k) @@ -240,16 +240,20 @@ def forward( # Make everything batch mode x = x.unsqueeze(-2) assert x.shape == (*x_batch_shape, x_bsz, 1, self.D) + x = x.expand(*batch_shape, x_bsz, 1, self.D) # Compute forward mode in the standard way - _x_batch_dims = tuple(range(len(x_batch_shape))) - _x = x.permute((-3,) + _x_batch_dims + (-2, -1)) - _inducing_points = inducing_points.permute((-3,) + _x_batch_dims + (-2, -1)) - _inducing_values = inducing_values.permute((-2,) + _x_batch_dims + (-1,)) - _variational_inducing_covar = variational_inducing_covar.permute((-3,) + _x_batch_dims + (-2, -1)) + _batch_dims = tuple(range(len(batch_shape))) + _x = x.permute((-3,) + _batch_dims + (-2, -1)) # (x_bsz, *batch_shape, 1, D) + + # inducing_points.shape (*x_batch_shape, x_bsz, self.k, self.D) + inducing_points = inducing_points.expand(*batch_shape, x_bsz, self.k, self.D) + _inducing_points = inducing_points.permute((-3,) + _batch_dims + (-2, -1)) # (x_bsz, *batch_shape, k, D) + _inducing_values = inducing_values.permute((-2,) + _batch_dims + (-1,)) + _variational_inducing_covar = variational_inducing_covar.permute((-3,) + _batch_dims + (-2, -1)) dist = super().forward(_x, _inducing_points, _inducing_values, _variational_inducing_covar, **kwargs) - _x_batch_dims = tuple(range(1, 1 + len(x_batch_shape))) + _x_batch_dims = tuple(range(1, 1 + len(batch_shape))) predictive_mean = dist.mean # (x_bsz, *x_batch_shape, 1) predictive_covar = dist.covariance_matrix # (x_bsz, *x_batch_shape, 1, 1) predictive_mean = predictive_mean.permute(_x_batch_dims + (0, -1)) diff --git a/test/variational/test_nearest_neighbor_variational_strategy.py b/test/variational/test_nearest_neighbor_variational_strategy.py index 91a7594f7..5576989dc 100644 --- a/test/variational/test_nearest_neighbor_variational_strategy.py +++ b/test/variational/test_nearest_neighbor_variational_strategy.py @@ -113,7 +113,7 @@ def _training_iter( return output, loss def _eval_iter(self, model, cuda=False): - inducing_batch_shape = model.variational_strategy.inducing_points.shape[:-2] + inducing_batch_shape = model.variational_strategy._inducing_batch_shape test_x = torch.randn(*inducing_batch_shape, 32, 2).clamp(-2.5, 2.5) if cuda: test_x = test_x.cuda() From dac209998e4eece8feb18d5bba298900f353f9d1 Mon Sep 17 00:00:00 2001 From: Luhuan Wu Date: Sat, 8 Jul 2023 23:34:29 -0700 Subject: [PATCH 04/11] fix nearest_neighbor_variational_strategy type error --- gpytorch/variational/nearest_neighbor_variational_strategy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpytorch/variational/nearest_neighbor_variational_strategy.py b/gpytorch/variational/nearest_neighbor_variational_strategy.py index 299bcc53b..061764db0 100644 --- a/gpytorch/variational/nearest_neighbor_variational_strategy.py +++ b/gpytorch/variational/nearest_neighbor_variational_strategy.py @@ -155,7 +155,7 @@ def __call__(self, x: Tensor, prior: bool = False, **kwargs: Any) -> Multivariat torch.randn_like(prior_dist.mean), alpha=self._variational_distribution.mean_init_std ) # initialize with a small variational stddev for quicker conv. of kl divergence - self._variational_distribution._variational_stddev.data.copy_(1e-2) + self._variational_distribution._variational_stddev.data.copy_(torch.tensor(1e-2)) self.variational_params_initialized.fill_(1) return self.forward( From 6e304a1980453d23a93b7a66ce984902399ef2ec Mon Sep 17 00:00:00 2001 From: Geoff Pleiss <824157+gpleiss@users.noreply.github.com> Date: Thu, 15 Aug 2024 11:12:07 -0700 Subject: [PATCH 05/11] Improve type hints --- .../nearest_neighbor_variational_strategy.py | 46 +++++++++++-------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/gpytorch/variational/nearest_neighbor_variational_strategy.py b/gpytorch/variational/nearest_neighbor_variational_strategy.py index 061764db0..d158bc28a 100644 --- a/gpytorch/variational/nearest_neighbor_variational_strategy.py +++ b/gpytorch/variational/nearest_neighbor_variational_strategy.py @@ -3,6 +3,7 @@ from typing import Any, Optional import torch +from jaxtyping import Float from linear_operator import to_dense from linear_operator.operators import DiagLinearOperator, LinearOperator, TriangularLinearOperator from linear_operator.utils.cholesky import psd_safe_cholesky @@ -77,10 +78,10 @@ class NNVariationalStrategy(UnwhitenedVariationalStrategy): def __init__( self, model: ApproximateGP, - inducing_points: Tensor, - variational_distribution: _VariationalDistribution, + inducing_points: Float[Tensor, "... M D"], + variational_distribution: Float[_VariationalDistribution, "... M"], k: int, - training_batch_size: int, + training_batch_size: Optional[int] = None, jitter_val: Optional[float] = 1e-3, compute_full_kl: Optional[bool] = False, ): @@ -96,8 +97,7 @@ def __init__( object.__setattr__(self, "model", model) self.inducing_points = inducing_points - self.M: int = inducing_points.shape[-2] - self.D: int = inducing_points.shape[-1] + self.M, self.D = inducing_points.shape[-2:] self.k = k assert self.k <= self.M, ( f"Number of nearest neighbors k must be smaller than or equal to number of inducing points, " @@ -114,24 +114,28 @@ def __init__( self._compute_nn() # otherwise, no nearest neighbor approximation is used - self.training_batch_size = training_batch_size + self.training_batch_size = training_batch_size if training_batch_size is not None else self.M self._set_training_iterator() self.compute_full_kl = compute_full_kl @property @cached(name="prior_distribution_memo") - def prior_distribution(self) -> MultivariateNormal: + def prior_distribution(self) -> Float[MultivariateNormal, "... M"]: out = self.model.forward(self.inducing_points) res = MultivariateNormal(out.mean, out.lazy_covariance_matrix.add_jitter(self.jitter_val)) return res - def _cholesky_factor(self, induc_induc_covar: LinearOperator) -> TriangularLinearOperator: + def _cholesky_factor( + self, induc_induc_covar: Float[LinearOperator, "... M M"] + ) -> Float[TriangularLinearOperator, "... M M"]: # Uncached version L = psd_safe_cholesky(to_dense(induc_induc_covar)) return TriangularLinearOperator(L) - def __call__(self, x: Tensor, prior: bool = False, **kwargs: Any) -> MultivariateNormal: + def __call__( + self, x: Float[Tensor, "... N D"], prior: bool = False, **kwargs: Any + ) -> Float[MultivariateNormal, "... N"]: # If we're in prior mode, then we're done! if prior: return self.model.forward(x, **kwargs) @@ -168,12 +172,12 @@ def __call__(self, x: Tensor, prior: bool = False, **kwargs: Any) -> Multivariat def forward( self, - x: Tensor, - inducing_points: Tensor, - inducing_values: Optional[Tensor] = None, - variational_inducing_covar: Optional[LinearOperator] = None, + x: Float[Tensor, "... N D"], + inducing_points: Float[Tensor, "... M D"], + inducing_values: Float[Tensor, "... M"], + variational_inducing_covar: Optional[Float[LinearOperator, "... M M"]] = None, **kwargs: Any, - ) -> MultivariateNormal: + ) -> Float[MultivariateNormal, "... N"]: if self.training: # In training mode, note that the full inducing points set = full training dataset # Users have the option to choose input None or a tensor of training data for x @@ -270,8 +274,8 @@ def forward( def get_fantasy_model( self, - inputs: Tensor, - targets: Tensor, + inputs: Float[Tensor, "... N D"], + targets: Float[Tensor, "... N"], mean_module: Optional[Module] = None, covar_module: Optional[Module] = None, **kwargs, @@ -297,7 +301,7 @@ def _get_training_indices(self) -> LongTensor: self._set_training_iterator() return self.current_training_indices - def _firstk_kl_helper(self) -> Tensor: + def _firstk_kl_helper(self) -> Float[Tensor, "..."]: # Compute the KL divergence for first k inducing points train_x_firstk = self.inducing_points[..., : self.k, :] full_output = self.model.forward(train_x_firstk) @@ -316,7 +320,7 @@ def _firstk_kl_helper(self) -> Tensor: kl = torch.distributions.kl.kl_divergence(variational_distribution, prior_dist) # model_batch_shape return kl - def _stochastic_kl_helper(self, kl_indices: Tensor) -> Tensor: + def _stochastic_kl_helper(self, kl_indices: Float[Tensor, "n_batch"]) -> Float[Tensor, "..."]: # noqa: F821 # Compute the KL divergence for a mini batch of the rest M-k inducing points # See paper appendix for kl breakdown kl_bs = len(kl_indices) # training_batch_size @@ -418,7 +422,9 @@ def _stochastic_kl_helper(self, kl_indices: Tensor) -> Tensor: return kl - def _kl_divergence(self, kl_indices: Optional[LongTensor] = None, batch_size: Optional[int] = None) -> Tensor: + def _kl_divergence( + self, kl_indices: Optional[LongTensor] = None, batch_size: Optional[int] = None + ) -> Float[Tensor, "..."]: if self.compute_full_kl: if batch_size is None: batch_size = self.training_batch_size @@ -438,7 +444,7 @@ def _kl_divergence(self, kl_indices: Optional[LongTensor] = None, batch_size: Op kl = self._stochastic_kl_helper(kl_indices) * self.M / len(kl_indices) return kl - def kl_divergence(self) -> Tensor: + def kl_divergence(self) -> Float[Tensor, "..."]: try: return pop_from_cache(self, "kl_divergence_memo") except CachingError: From f58f56edc91ee19b31b8aec851c60e959c9b0d59 Mon Sep 17 00:00:00 2001 From: Geoff Pleiss <824157+gpleiss@users.noreply.github.com> Date: Thu, 15 Aug 2024 11:12:31 -0700 Subject: [PATCH 06/11] Allow x to be broadcasted to inducing points batch shape --- .../nearest_neighbor_variational_strategy.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/gpytorch/variational/nearest_neighbor_variational_strategy.py b/gpytorch/variational/nearest_neighbor_variational_strategy.py index d158bc28a..d4701230a 100644 --- a/gpytorch/variational/nearest_neighbor_variational_strategy.py +++ b/gpytorch/variational/nearest_neighbor_variational_strategy.py @@ -141,11 +141,16 @@ def __call__( return self.model.forward(x, **kwargs) if x is not None: - assert self.inducing_points.shape[:-2] == x.shape[:-2], ( - f"x batch shape must matches inducing points batch shape, " - f"but got x batch shape = {x.shape[:-2]}, " - f"inducing points batch shape = {self.inducing_points.shape[:-2]}." - ) + # Make sure x and inducing points have the same batch shape + if not (self.inducing_points.shape[:-2] == x.shape[:-2]): + try: + x = x.expand(*self.inducing_points.shape[:-2], *x.shape[-2:]) + except RuntimeError: + raise RuntimeError( + f"x batch shape must match or broadcast with the inducing points' batch shape, " + f"but got x batch shape = {x.shape[:-2]}, " + f"inducing points batch shape = {self.inducing_points.shape[:-2]}." + ) # Delete previously cached items from the training distribution if self.training: From 45ed860259467fcab4b39f07cf00cc5518aa6b30 Mon Sep 17 00:00:00 2001 From: Geoff Pleiss <824157+gpleiss@users.noreply.github.com> Date: Thu, 15 Aug 2024 11:13:37 -0700 Subject: [PATCH 07/11] Fix behaviour for training_batch_size == M Create only a single taining minibatch with all the data, and use the appropriate kl_helper --- .../nearest_neighbor_variational_strategy.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/gpytorch/variational/nearest_neighbor_variational_strategy.py b/gpytorch/variational/nearest_neighbor_variational_strategy.py index d4701230a..c506ab993 100644 --- a/gpytorch/variational/nearest_neighbor_variational_strategy.py +++ b/gpytorch/variational/nearest_neighbor_variational_strategy.py @@ -292,11 +292,15 @@ def get_fantasy_model( def _set_training_iterator(self) -> None: self._training_indices_iter = 0 - if self.k < self.M: + if self.training_batch_size == self.M: + self._training_indices_iterator = (torch.arange(self.M, device=self.inducing_points.device),) + else: + # The first training batch always contains the first k inducing points + # This is because computing the KL divergence for the first k inducing points is special-cased + # (since the first k inducing points have < k neighbors) + # Note that there is a special function _firstk_kl_helper for this training_indices = torch.randperm(self.M - self.k, device=self.inducing_points.device) + self.k self._training_indices_iterator = (torch.arange(self.k),) + training_indices.split(self.training_batch_size) - else: - self._training_indices_iterator = (torch.arange(self.k),) self._total_training_batches = len(self._training_indices_iterator) def _get_training_indices(self) -> LongTensor: @@ -430,7 +434,7 @@ def _stochastic_kl_helper(self, kl_indices: Float[Tensor, "n_batch"]) -> Float[T def _kl_divergence( self, kl_indices: Optional[LongTensor] = None, batch_size: Optional[int] = None ) -> Float[Tensor, "..."]: - if self.compute_full_kl: + if self.compute_full_kl or (self._total_training_batches == 1): if batch_size is None: batch_size = self.training_batch_size kl = self._firstk_kl_helper() From d28f5997270ef9d22d1a22d779e7f8d267d56dcc Mon Sep 17 00:00:00 2001 From: Geoff Pleiss <824157+gpleiss@users.noreply.github.com> Date: Thu, 15 Aug 2024 11:14:03 -0700 Subject: [PATCH 08/11] Remove unnecessary settings --- gpytorch/variational/nearest_neighbor_variational_strategy.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/gpytorch/variational/nearest_neighbor_variational_strategy.py b/gpytorch/variational/nearest_neighbor_variational_strategy.py index c506ab993..0b7a95bf5 100644 --- a/gpytorch/variational/nearest_neighbor_variational_strategy.py +++ b/gpytorch/variational/nearest_neighbor_variational_strategy.py @@ -9,7 +9,6 @@ from linear_operator.utils.cholesky import psd_safe_cholesky from torch import LongTensor, Tensor -from .. import settings from ..distributions import MultivariateNormal from ..models import ApproximateGP, ExactGP from ..module import Module @@ -325,8 +324,7 @@ def _firstk_kl_helper(self) -> Float[Tensor, "..."]: variational_inducing_covar = DiagLinearOperator(variational_covar_fisrtk) variational_distribution = MultivariateNormal(inducing_values, variational_inducing_covar) - with settings.max_preconditioner_size(0): - kl = torch.distributions.kl.kl_divergence(variational_distribution, prior_dist) # model_batch_shape + kl = torch.distributions.kl.kl_divergence(variational_distribution, prior_dist) # model_batch_shape return kl def _stochastic_kl_helper(self, kl_indices: Float[Tensor, "n_batch"]) -> Float[Tensor, "..."]: # noqa: F821 From dd65761b3538ab1c3ee3f4c456eba7485049d1f8 Mon Sep 17 00:00:00 2001 From: Geoff Pleiss <824157+gpleiss@users.noreply.github.com> Date: Thu, 15 Aug 2024 11:14:45 -0700 Subject: [PATCH 09/11] Remove M=k case --- .../nearest_neighbor_variational_strategy.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/gpytorch/variational/nearest_neighbor_variational_strategy.py b/gpytorch/variational/nearest_neighbor_variational_strategy.py index 0b7a95bf5..e9d93583a 100644 --- a/gpytorch/variational/nearest_neighbor_variational_strategy.py +++ b/gpytorch/variational/nearest_neighbor_variational_strategy.py @@ -98,8 +98,8 @@ def __init__( self.inducing_points = inducing_points self.M, self.D = inducing_points.shape[-2:] self.k = k - assert self.k <= self.M, ( - f"Number of nearest neighbors k must be smaller than or equal to number of inducing points, " + assert self.k < self.M, ( + f"Number of nearest neighbors k must be smaller than the number of inducing points, " f"but got k = {k}, M = {self.M}." ) @@ -441,7 +441,7 @@ def _kl_divergence( else: # compute a stochastic estimate assert kl_indices is not None - if (self._training_indices_iter == 1) or (self.M == self.k): + if self._training_indices_iter == 1: assert len(kl_indices) == self.k, ( f"kl_indices sould be the first batch data of length k, " f"but got len(kl_indices) = {len(kl_indices)} and k = {self.k}." @@ -461,7 +461,6 @@ def _compute_nn(self) -> "NNVariationalStrategy": with torch.no_grad(): inducing_points_fl = self.inducing_points.data.float() self.nn_util.set_nn_idx(inducing_points_fl) - if self.k < self.M: - self.nn_xinduce_idx = self.nn_util.build_sequential_nn_idx(inducing_points_fl) - # shape (*_inducing_batch_shape, M-k, k) + self.nn_xinduce_idx = self.nn_util.build_sequential_nn_idx(inducing_points_fl) + # shape (*_inducing_batch_shape, M-k, k) return self From 960296989e3fa1bb6a904b36c2489d4b1cefb66e Mon Sep 17 00:00:00 2001 From: Geoff Pleiss <824157+gpleiss@users.noreply.github.com> Date: Thu, 15 Aug 2024 13:47:46 -0700 Subject: [PATCH 10/11] Handle mixed model/inducing point batch sizes with KL helper --- .../nearest_neighbor_variational_strategy.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/gpytorch/variational/nearest_neighbor_variational_strategy.py b/gpytorch/variational/nearest_neighbor_variational_strategy.py index e9d93583a..97a1bee63 100644 --- a/gpytorch/variational/nearest_neighbor_variational_strategy.py +++ b/gpytorch/variational/nearest_neighbor_variational_strategy.py @@ -220,7 +220,7 @@ def forward( nn_indices = self.nn_util.find_nn_idx(x.float()) x_batch_shape = x.shape[:-2] - batch_shape = torch.broadcast_shapes(self._model_batch_shape, x_batch_shape) + batch_shape = torch.broadcast_shapes(self._batch_shape, x_batch_shape) x_bsz = x.shape[-2] assert nn_indices.shape == (*x_batch_shape, x_bsz, self.k), nn_indices.shape @@ -340,30 +340,31 @@ def _stochastic_kl_helper(self, kl_indices: Float[Tensor, "n_batch"]) -> Float[T # (2) compute lodet_p # Select a mini-batch of inducing points according to kl_indices - inducing_points = self.inducing_points[..., kl_indices, :] # (*inducing_bs, kl_bs, D) + inducing_points = self.inducing_points[..., kl_indices, :].expand(*self._batch_shape, kl_bs, self.D) + # (*bs, kl_bs, D) # Select their K nearest neighbors nearest_neighbor_indices = self.nn_xinduce_idx[..., kl_indices - self.k, :].to(inducing_points.device) - # (*inducing_bs, kl_bs, K) + # (*bs, kl_bs, K) expanded_inducing_points_all = self.inducing_points.unsqueeze(-2).expand( - *self._inducing_batch_shape, self.M, self.k, self.D + *self._batch_shape, self.M, self.k, self.D ) expanded_nearest_neighbor_indices = nearest_neighbor_indices.unsqueeze(-1).expand( - *self._inducing_batch_shape, kl_bs, self.k, self.D + *self._batch_shape, kl_bs, self.k, self.D ) nearest_neighbors = expanded_inducing_points_all.gather(-3, expanded_nearest_neighbor_indices) - # (*inducing_bs, kl_bs, K, D) + # (*bs, kl_bs, K, D) # Compute prior distribution # Move the kl_bs dimension to the first dimension to enable batch covar_module computation - nearest_neighbors_ = nearest_neighbors.permute((-3,) + tuple(range(len(self._inducing_batch_shape))) + (-2, -1)) - # (kl_bs, *inducing_bs, K, D) - inducing_points_ = inducing_points.permute((-2,) + tuple(range(len(self._inducing_batch_shape))) + (-1,)) - # (kl_bs, *inducing_bs, D) + nearest_neighbors_ = nearest_neighbors.permute((-3,) + tuple(range(len(self._batch_shape))) + (-2, -1)) + # (kl_bs, *bs, K, D) + inducing_points_ = inducing_points.permute((-2,) + tuple(range(len(self._batch_shape))) + (-1,)) + # (kl_bs, *bs, D) full_output = self.model.forward(torch.cat([nearest_neighbors_, inducing_points_.unsqueeze(-2)], dim=-2)) full_mean, full_covar = full_output.mean, full_output.covariance_matrix # Mean terms - _undo_permute_dims = tuple(range(1, 1 + len(self._inducing_batch_shape))) + (0, -1) + _undo_permute_dims = tuple(range(1, 1 + len(self._batch_shape))) + (0, -1) nearest_neighbors_prior_mean = full_mean[..., : self.k].permute(_undo_permute_dims) # (*inducing_bs, kl_bs, K) inducing_prior_mean = full_mean[..., self.k :].permute(_undo_permute_dims).squeeze(-1) # (*inducing_bs, kl_bs) # Covar terms @@ -371,7 +372,7 @@ def _stochastic_kl_helper(self, kl_indices: Float[Tensor, "n_batch"]) -> Float[T nearest_neighbors_inducing_prior_cross_cov = full_covar[..., : self.k, self.k :] inducing_prior_cov = full_covar[..., self.k :, self.k :] inducing_prior_cov = ( - inducing_prior_cov.squeeze(-1).squeeze(-1).permute((-1,) + tuple(range(len(self._inducing_batch_shape)))) + inducing_prior_cov.squeeze(-1).squeeze(-1).permute((-1,) + tuple(range(len(self._batch_shape)))) ) # Interpolation term K_nn^{-1} k_{nu} From 4d4740ba7961af278e5699539a80a20693330761 Mon Sep 17 00:00:00 2001 From: Geoff Pleiss <824157+gpleiss@users.noreply.github.com> Date: Thu, 15 Aug 2024 13:48:13 -0700 Subject: [PATCH 11/11] Test cases catch batch VNNGP errors --- ...test_nearest_neighbor_variational_strategy.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/test/variational/test_nearest_neighbor_variational_strategy.py b/test/variational/test_nearest_neighbor_variational_strategy.py index 5576989dc..6e86a04af 100644 --- a/test/variational/test_nearest_neighbor_variational_strategy.py +++ b/test/variational/test_nearest_neighbor_variational_strategy.py @@ -56,7 +56,10 @@ def __init__(self, inducing_points, k, training_batch_size): else: self.mean_module = gpytorch.means.ZeroMean() - self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) + self.covar_module = gpytorch.kernels.ScaleKernel( + gpytorch.kernels.RBFKernel(batch_shape=batch_shape, ard_num_dims=2), + batch_shape=batch_shape, + ) def forward(self, x): mean_x = self.mean_module(x) @@ -85,7 +88,6 @@ def _training_iter( ): # We cannot inheret the superclass method # Because it sets the training data to be the inducing points - train_x = model.variational_strategy.inducing_points train_y = torch.randn(train_x.shape[:-1]) mll = mll_cls(likelihood, model, num_data=train_x.size(-2)) @@ -98,8 +100,10 @@ def _training_iter( # Single optimization iteration model.train() likelihood.train() - output = model(train_x) - loss = -mll(output, train_y) + output = model(x=None) + current_training_indices = model.variational_strategy.current_training_indices + y_batch = train_y[..., current_training_indices] + loss = -mll(output, y_batch) loss.sum().backward() # Make sure we have gradients for all parameters @@ -136,6 +140,8 @@ def test_training_iteration( ): # We cannot inheret the superclass method # Because it expects `variational_params_intialized` to be set to 0 + # Also, the expected output.event_shape should be the training_batch_size + # not self.event_shape (which is reserved for test_eval_iteration) # Batch shapes model_batch_shape = model_batch_shape if model_batch_shape is not None else self.batch_shape @@ -170,7 +176,7 @@ def test_training_iteration( cuda=self.cuda, ) self.assertEqual(output.batch_shape, expected_batch_shape) - self.assertEqual(output.event_shape, self.event_shape) + self.assertEqual(output.event_shape, torch.Size([model.variational_strategy.training_batch_size])) self.assertEqual(loss.shape, expected_batch_shape) def test_training_iteration_batch_inducing(self):