Commit 938d4f9

Merge pull request #2376 from douglas-boubert/patch-1
Fix lazy kernel slicing when there are multiple outputs
2 parents: aeb5e23 + 60d2698

2 files changed: +17 −5

gpytorch/lazy/lazy_evaluated_kernel_tensor.py (+5 −5)

@@ -161,21 +161,21 @@ def _getitem(self, row_index, col_index, *batch_indices):
         # However - if we have multiple outputs per input, then the indices won't directly
         # correspond to the entries of row/col. We'll have to do a little pre-processing
         if num_outs_per_in_rows != 1 or num_outs_per_in_cols != 1:
-            if not isinstance(x1, slice) or not isinstance(x2, slice):
+            if not isinstance(row_index, slice) or not isinstance(col_index, slice):
                 # It's too complicated to deal with tensor indices in this case - we'll use the super method
                 return self.evaluate_kernel()._getitem(row_index, col_index, *batch_indices)

             # Now we know that x1 and x2 are slices
             # Let's make sure that the slice dimensions perfectly correspond with the number of
             # outputs per input that we have
             row_start, row_end, row_step = (
-                row_index.start,
-                row_index.stop,
+                row_index.start or 0,
+                row_index.stop or self.shape[-2],
                 row_index.step,
             )
             col_start, col_end, col_step = (
-                col_index.start,
-                col_index.stop,
+                col_index.start or 0,
+                col_index.stop or self.shape[-1],
                 col_index.step,
             )
             if row_step is not None or col_step is not None:
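
The change is twofold: the multi-output branch previously type-checked the kernel inputs x1/x2 instead of the incoming row_index/col_index, and it read slice.start/slice.stop directly, which are None for open-ended slices such as k[..., :7, :14]. Below is a minimal, self-contained sketch of the None-to-integer normalization the patch applies; the normalize_slice helper and the bare shape tuple are illustrative only, not part of gpytorch:

# Illustrative only: mirrors `row_index.start or 0` / `row_index.stop or self.shape[-2]`
# from the patch, turning open-ended slices into concrete (start, stop, step) triples.
def normalize_slice(index, dim_size):
    start = index.start or 0         # None (and 0) both become 0
    stop = index.stop or dim_size    # None becomes the full extent of the dimension
    return start, stop, index.step

shape = (35, 35)  # e.g. 5 inputs x 7 outputs per input on each side

row_index = slice(None, 7, None)   # what `k[..., :7, :14]` produces for the rows
col_index = slice(None, 14, None)  # ... and for the columns

print(normalize_slice(row_index, shape[-2]))                 # (0, 7, None)
print(normalize_slice(col_index, shape[-1]))                 # (0, 14, None)
print(normalize_slice(slice(None, None, None), shape[-1]))   # (0, 35, None)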

test/lazy/test_lazy_evaluated_kernel_tensor.py (+12 −0)

@@ -122,6 +122,18 @@ def test_batch_getitem(self):
         self.assertEqual(k.size(), torch.Size([2, 5, 5]))
         self.assertEqual(k[..., :4, :3].size(), torch.Size([2, 4, 3]))

+    def test_batch_getitem_multioutput(self):
+        """Ensure slicing is efficient when using a multioutput kernel"""
+        x1 = torch.randn(5, 6)
+        x2 = torch.randn(5, 6)
+        kern = gpytorch.kernels.RBFKernelGrad(batch_shape=torch.Size([2]))
+        k = kern(x1, x2)
+        k.evaluate_kernel = MagicMock(name="evaluate_kernel")
+        k_sliced = k[..., :7, :14]
+        self.assertFalse(k.evaluate_kernel.called)
+        self.assertEqual(k.size(), torch.Size([2, 35, 35]))
+        self.assertEqual(k_sliced.size(), torch.Size([2, 7, 14]))
+
     def test_getitem_tensor_index(self):
         # Not supported a.t.m. with LazyEvaluatedKernelTensors
         pass
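
The new test pins down the intended behaviour: MagicMock (from unittest.mock, imported elsewhere in the test module) replaces evaluate_kernel so the test can assert that slicing a batched multi-output kernel never falls back to full evaluation. A rough usage sketch under the same assumptions as the test, with 5 inputs of dimension 6 so RBFKernelGrad yields 7 outputs per input and a 2 × 35 × 35 batched kernel:

import torch
import gpytorch

# Sizes follow the new test: 5 inputs of dimension 6 give 1 + 6 = 7 outputs per
# input with RBFKernelGrad, so the full batched kernel matrix is 2 x 35 x 35.
x1 = torch.randn(5, 6)
x2 = torch.randn(5, 6)
kern = gpytorch.kernels.RBFKernelGrad(batch_shape=torch.Size([2]))

k = kern(x1, x2)            # lazily evaluated kernel tensor; nothing is materialized here
assert k.size() == torch.Size([2, 35, 35])

k_sliced = k[..., :7, :14]  # with this fix, the slice stays lazy instead of
                            # triggering evaluate_kernel() on the full matrix
assert k_sliced.size() == torch.Size([2, 7, 14])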
