Skip to content

Commit 0ae019f

Browse files
committed
fix nearest_neighbor_variational_strategy batch compatibility
1 parent 84a9d77 commit 0ae019f

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

gpytorch/variational/nearest_neighbor_variational_strategy.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def __call__(self, x: Tensor, prior: bool = False, **kwargs: Any) -> Multivariat
138138
if x is not None:
139139
assert self.inducing_points.shape[:-2] == x.shape[:-2], (
140140
f"x batch shape must matches inducing points batch shape, "
141-
f"but got train data batch shape = {x.shape[:-2]}, "
141+
f"but got x batch shape = {x.shape[:-2]}, "
142142
f"inducing points batch shape = {self.inducing_points.shape[:-2]}."
143143
)
144144

@@ -211,6 +211,7 @@ def forward(
211211
nn_indices = self.nn_util.find_nn_idx(x.float())
212212

213213
x_batch_shape = x.shape[:-2]
214+
batch_shape = torch.broadcast_shapes(self._model_batch_shape, x_batch_shape)
214215
x_bsz = x.shape[-2]
215216
assert nn_indices.shape == (*x_batch_shape, x_bsz, self.k), nn_indices.shape
216217

@@ -221,7 +222,6 @@ def forward(
221222
assert inducing_points.shape == (*x_batch_shape, x_bsz, self.k, self.D)
222223

223224
# get variational mean and covar for nearest neighbors
224-
batch_shape = torch.broadcast_shapes(self._model_batch_shape, x_batch_shape)
225225
inducing_values = self._variational_distribution.variational_mean
226226
expanded_inducing_values = inducing_values.unsqueeze(-1).expand(*batch_shape, self.M, self.k)
227227
expanded_nn_indices = nn_indices.expand(*batch_shape, x_bsz, self.k)
@@ -239,16 +239,20 @@ def forward(
239239
# Make everything batch mode
240240
x = x.unsqueeze(-2)
241241
assert x.shape == (*x_batch_shape, x_bsz, 1, self.D)
242+
x = x.expand(*batch_shape, x_bsz, 1, self.D)
242243

243244
# Compute forward mode in the standard way
244-
_x_batch_dims = tuple(range(len(x_batch_shape)))
245-
_x = x.permute((-3,) + _x_batch_dims + (-2, -1))
246-
_inducing_points = inducing_points.permute((-3,) + _x_batch_dims + (-2, -1))
247-
_inducing_values = inducing_values.permute((-2,) + _x_batch_dims + (-1,))
248-
_variational_inducing_covar = variational_inducing_covar.permute((-3,) + _x_batch_dims + (-2, -1))
245+
_batch_dims = tuple(range(len(batch_shape)))
246+
_x = x.permute((-3,) + _batch_dims + (-2, -1)) # (x_bsz, *batch_shape, 1, D)
247+
248+
# inducing_points.shape (*x_batch_shape, x_bsz, self.k, self.D)
249+
inducing_points = inducing_points.expand(*batch_shape, x_bsz, self.k, self.D)
250+
_inducing_points = inducing_points.permute((-3,) + _batch_dims + (-2, -1)) # (x_bsz, *batch_shape, k, D)
251+
_inducing_values = inducing_values.permute((-2,) + _batch_dims + (-1,))
252+
_variational_inducing_covar = variational_inducing_covar.permute((-3,) + _batch_dims + (-2, -1))
249253
dist = super().forward(_x, _inducing_points, _inducing_values, _variational_inducing_covar, **kwargs)
250254

251-
_x_batch_dims = tuple(range(1, 1 + len(x_batch_shape)))
255+
_x_batch_dims = tuple(range(1, 1 + len(batch_shape)))
252256
predictive_mean = dist.mean # (x_bsz, *batch_shape, 1)
253257
predictive_covar = dist.covariance_matrix # (x_bsz, *batch_shape, 1, 1)
254258
predictive_mean = predictive_mean.permute(_x_batch_dims + (0, -1))

test/variational/test_nearest_neighbor_variational_strategy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def _training_iter(
115115
return output, loss
116116

117117
def _eval_iter(self, model, cuda=False):
118-
inducing_batch_shape = model.variational_strategy.inducing_points.shape[:-2]
118+
inducing_batch_shape = model.variational_strategy._inducing_batch_shape
119119
test_x = torch.randn(*inducing_batch_shape, 32, 2).clamp(-2.5, 2.5)
120120
if cuda:
121121
test_x = test_x.cuda()

0 commit comments

Comments
 (0)