diff --git a/gpytorch/lazy/lazy_evaluated_kernel_tensor.py b/gpytorch/lazy/lazy_evaluated_kernel_tensor.py
index 610a15ad0..3efe398b4 100644
--- a/gpytorch/lazy/lazy_evaluated_kernel_tensor.py
+++ b/gpytorch/lazy/lazy_evaluated_kernel_tensor.py
@@ -161,7 +161,7 @@ def _getitem(self, row_index, col_index, *batch_indices):
         # However - if we have multiple outputs per input, then the indices won't directly
         # correspond to the entries of row/col. We'll have to do a little pre-processing
         if num_outs_per_in_rows != 1 or num_outs_per_in_cols != 1:
-            if not isinstance(x1, slice) or not isinstance(x2, slice):
+            if not isinstance(row_index, slice) or not isinstance(col_index, slice):
                # It's too complicated to deal with tensor indices in this case - we'll use the super method
                return self.evaluate_kernel()._getitem(row_index, col_index, *batch_indices)

@@ -169,13 +169,13 @@ def _getitem(self, row_index, col_index, *batch_indices):
             # Let's make sure that the slice dimensions perfectly correspond with the number of
             # outputs per input that we have
             row_start, row_end, row_step = (
-                row_index.start,
-                row_index.stop,
+                row_index.start or 0,
+                row_index.stop or self.shape[-2],
                 row_index.step,
             )
             col_start, col_end, col_step = (
-                col_index.start,
-                col_index.stop,
+                col_index.start or 0,
+                col_index.stop or self.shape[-1],
                 col_index.step,
             )
             if row_step is not None or col_step is not None:
diff --git a/test/lazy/test_lazy_evaluated_kernel_tensor.py b/test/lazy/test_lazy_evaluated_kernel_tensor.py
index 06d937d0c..7bb9e00af 100644
--- a/test/lazy/test_lazy_evaluated_kernel_tensor.py
+++ b/test/lazy/test_lazy_evaluated_kernel_tensor.py
@@ -122,6 +122,18 @@ def test_batch_getitem(self):
         self.assertEqual(k.size(), torch.Size([2, 5, 5]))
         self.assertEqual(k[..., :4, :3].size(), torch.Size([2, 4, 3]))

+    def test_batch_getitem_multioutput(self):
+        """Ensure slicing is efficient when using a multioutput kernel"""
+        x1 = torch.randn(5, 6)
+        x2 = torch.randn(5, 6)
+        kern = gpytorch.kernels.RBFKernelGrad(batch_shape=torch.Size([2]))
+        k = kern(x1, x2)
+        k.evaluate_kernel = MagicMock(name="evaluate_kernel")
+        k_sliced = k[..., :7, :14]
+        self.assertFalse(k.evaluate_kernel.called)
+        self.assertEqual(k.size(), torch.Size([2, 35, 35]))
+        self.assertEqual(k_sliced.size(), torch.Size([2, 7, 14]))
+
     def test_getitem_tensor_index(self):
         # Not supported a.t.m. with LazyEvaluatedKernelTensors
         pass
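
Not part of the patch, but a minimal standalone sketch of why the `or 0` / `or self.shape[-2]` defaults are needed: an open-ended slice such as `k[..., :7, :14]` reaches `_getitem` with `start=None` (and a full slice arrives with both `start` and `stop` as `None`), so the unpatched code carried `None` into the checks that follow. The concrete values below are assumed for illustration only.

# Minimal sketch, not part of the patch. The values 35 and 7 assume 5 inputs
# with 7 outputs each (e.g. RBFKernelGrad on 6-dimensional inputs, as in the
# new test). Python hands an open-ended slice to __getitem__ with start=None,
# which the patched code normalizes before doing any further index arithmetic.
row_index = slice(None, 7, None)        # what k[..., :7, :14] produces for rows
shape_rows = 35                         # stands in for self.shape[-2]

row_start = row_index.start or 0        # None -> 0
row_end = row_index.stop or shape_rows  # None -> full dimension size
row_step = row_index.step               # left as None; rejected later if set

assert (row_start, row_end, row_step) == (0, 7, None)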