Skip to content

Commit b7f373d

Browse files
authored
Merge branch 'master' into fix
2 parents e9e9026 + 8cb3136 commit b7f373d

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

gpytorch/lazy/cat_lazy_tensor.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,10 @@ def _get_indices(self, row_index, col_index, *batch_indices):
181181
if len(res_list) == 1:
182182
return res_list[0].view(target_shape).to(self.device)
183183
else:
184-
return torch.cat(res_list).view(target_shape).to(self.device)
184+
# Explicitly move tensors to one device as torch.cat no longer moves tensors:
185+
# https://github.com/pytorch/pytorch/issues/35045
186+
res_list = [lazy_tensor.to(self.device) for lazy_tensor in res_list]
187+
return torch.cat(res_list).view(target_shape)
185188

186189
def _getitem(self, row_index, col_index, *batch_indices):
187190
indices = [*batch_indices, row_index, col_index]

gpytorch/utils/permutation.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def apply_permutation(
7676
right_permutation = torch.arange(matrix.size(-1), device=matrix.device)
7777

7878
# Apply permutations
79-
return delazify(matrix.__getitem__((*batch_idx, left_permutation.unsqueeze(-1), right_permutation.unsqueeze(-2))))
79+
res = delazify(matrix.__getitem__((*batch_idx, left_permutation.unsqueeze(-1), right_permutation.unsqueeze(-2))))
80+
return res.to(device=matrix.device)
8081

8182

8283
def inverse_permutation(permutation):

0 commit comments

Comments
 (0)