Commit d8b6328
Fix KeOps regressions from #2296.
KernelLinearOperator was throwing errors when computing the diagonal of a KeOps kernel. (This computation happens during preconditioning, which requires the diagonal of the already-formed kernel LinearOperator object.) The error occurred because KeopsLinearOperator.diagonal calls to_dense on the output of a batch kernel operation, but to_dense is not defined for KeOps LazyTensors. This PR is in some sense a hacky fix for the bug (a less hacky fix will require changes to KernelLinearOperator), but it is also a generally helpful refactor that improves KeOps kernels across the board.

The fixes:

- KeOpsKernels now define only a forward function, which is used both when we want to use KeOps and when we want to bypass it.
- KeOpsKernels now use a `_lazify_and_expand_inputs` helper method, which (potentially) wraps the inputs as KeOps LazyTensors, or leaves the inputs as torch Tensors.
- The KeOps wrapping happens unless we want to bypass KeOps, which occurs when either (1) the matrix is small (below the max Cholesky size) or (2) the user has turned off the `gpytorch.settings.use_keops` option (*NEW IN THIS PR*).

Why this is beneficial:

- KeOps kernels now follow the same API as non-KeOps kernels (define a forward method).
- The user only has to define one forward method, which works in both the KeOps and non-KeOps cases.
- The `diagonal` call in KeopsLinearOperator constructs a batch 1x1 matrix, which is small enough to bypass KeOps and thus avoid the current bug. (This is why the solution is currently a hack; it could become less hacky with a small modification to KernelLinearOperator and/or the to_dense method in LinearOperator.)

Other changes:

- Fix stability issues with the KeOps MaternKernel. (There were some NaN issues.)
- Introduce a `gpytorch.settings.use_keops` feature flag (a usage sketch follows below).
- Clean up the KeOps notebook.

[Fixes #2363]
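For illustration, here is a minimal sketch of the new flag in action, assuming `use_keops` follows the same context-manager pattern as other `gpytorch.settings` flags (the data and kernel below are illustrative, not part of this PR):

```python
import torch
import gpytorch
from gpytorch.kernels.keops import MaternKernel

train_x = torch.randn(10_000, 3)  # large enough to exceed max_cholesky_size

# A KeOps kernel is constructed like any other GPyTorch kernel
covar_module = gpytorch.kernels.ScaleKernel(MaternKernel(nu=2.5))

# Assumption: the flag acts as a context manager, like other gpytorch
# settings; disabling it forces the plain-torch code path in forward.
with gpytorch.settings.use_keops(False):
    covar = covar_module(train_x)

# Matrices below settings.max_cholesky_size bypass KeOps automatically;
# the batch 1x1 matrices built by `diagonal` fall in this regime, which
# is what sidesteps the to_dense bug described above.
```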
1 parent 43383c2 commit d8b6328

File tree

8 files changed: +317 -349 lines changed

examples/02_Scalable_Exact_GPs/KeOps_GP_Regression.ipynb
+77 -78
@@ -17,22 +17,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 1,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "The autoreload extension is already loaded. To reload it, use:\n",
-      "  %reload_ext autoreload\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "import math\n",
     "import torch\n",
     "import gpytorch\n",
+    "import tqdm.notebook as tqdm\n",
     "from matplotlib import pyplot as plt\n",
     "\n",
     "%matplotlib inline\n",
@@ -45,22 +37,16 @@
    "metadata": {},
    "source": [
     "### Downloading Data\n",
-    "We will be using the 3droad UCI dataset which contains a total of 278,319 data points. The next cell will download this dataset from a Google drive and load it."
+    "We will be using the 3droad UCI dataset which contains a total of 434,874 data points. We will split the data in half for training and half for testing.\n",
+    "\n",
+    "The next cell will download this dataset from a Google drive and load it."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 2,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Downloading '3droad' UCI dataset...\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "import urllib.request\n",
     "import os.path\n",
@@ -76,15 +62,25 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 3,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Num train: 217437\n",
+      "Num test: 217437\n"
+     ]
+    }
+   ],
    "source": [
     "import numpy as np\n",
     "\n",
     "N = data.shape[0]\n",
     "# make train/val/test\n",
-    "n_train = int(0.8 * N)\n",
+    "n_train = int(0.5 * N)\n",
+    "\n",
     "train_x, train_y = data[:n_train, :-1], data[:n_train, -1]\n",
     "test_x, test_y = data[n_train:, :-1], data[n_train:, -1]\n",
     "\n",
@@ -106,7 +102,12 @@
     "output_device = torch.device('cuda:0')\n",
     "\n",
     "train_x, train_y = train_x.to(output_device), train_y.to(output_device)\n",
-    "test_x, test_y = test_x.to(output_device), test_y.to(output_device)"
+    "test_x, test_y = test_x.to(output_device), test_y.to(output_device)\n",
+    "\n",
+    "print(\n",
+    "    f\"Num train: {train_y.size(-1)}\\n\"\n",
+    "    f\"Num test: {test_y.size(-1)}\"\n",
+    ")"
    ]
   },
   {
@@ -120,7 +121,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -139,16 +140,36 @@
     "\n",
     "# initialize likelihood and model\n",
     "likelihood = gpytorch.likelihoods.GaussianLikelihood().cuda()\n",
-    "model = ExactGPModel(train_x, train_y, likelihood).cuda()"
+    "model = ExactGPModel(train_x, train_y, likelihood).cuda()\n",
+    "\n",
+    "# Because we know some properties about this dataset,\n",
+    "# we will initialize the lengthscale to be somewhat small\n",
+    "# This step isn't necessary, but it will help the model converge faster.\n",
+    "model.covar_module.base_kernel.lengthscale = 0.05"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 5,
    "metadata": {
     "scrolled": false
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "691194d2d51e4d389fef9f0f7cb34f6b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Training: 0%| | 0/25 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
    "source": [
     "# Find optimal model hyperparameters\n",
     "model.train()\n",
@@ -158,64 +179,44 @@
     "optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters\n",
     "\n",
     "# \"Loss\" for GPs - the marginal log likelihood\n",
-    "mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n",
+    "mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)\n",
     "\n",
     "import time\n",
-    "training_iter = 50\n",
-    "for i in range(training_iter):\n",
+    "training_iter = 25\n",
+    "iterator = tqdm.tqdm(range(training_iter), desc=\"Training\")\n",
+    "for i in iterator:\n",
     "    start_time = time.time()\n",
    "    # Zero gradients from previous iteration\n",
     "    optimizer.zero_grad()\n",
     "    # Output from model\n",
     "    output = model(train_x)\n",
     "    # Calc loss and backprop gradients\n",
     "    loss = -mll(output, train_y)\n",
+    "    print_values = dict(\n",
+    "        loss=loss.item(),\n",
+    "        ls=model.covar_module.base_kernel.lengthscale.norm().item(),\n",
+    "        os=model.covar_module.outputscale.item(),\n",
+    "        noise=model.likelihood.noise.item(),\n",
+    "        mu=model.mean_module.constant.item(),\n",
+    "    )\n",
+    "    iterator.set_postfix(**print_values)\n",
     "    loss.backward()\n",
-    "    print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (\n",
-    "        i + 1, training_iter, loss.item(),\n",
-    "        model.covar_module.base_kernel.lengthscale.item(),\n",
-    "        model.likelihood.noise.item()\n",
-    "    ))\n",
-    "    optimizer.step()\n",
-    "    print(time.time() - start_time)"
+    "    optimizer.step()"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 6,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Compiling libKeOpstorchd7ba409487 in /home/jake.gardner/.cache/pykeops-1.1.1//build-libKeOpstorchd7ba409487:\n",
-      "  formula: Sum_Reduction(((((Var(0,1,2) * Sqrt(Sum(Square((Var(1,18,0) - Var(2,18,1)))))) + (IntCst(1) + (Var(3,1,2) * Square(Sqrt(Sum(Square((Var(1,18,0) - Var(2,18,1))))))))) * Exp((Var(4,1,2) * Sqrt(Sum(Square((Var(1,18,0) - Var(2,18,1)))))))) * Var(5,3320,1)),0)\n",
-      "  aliases: Var(0,1,2); Var(1,18,0); Var(2,18,1); Var(3,1,2); Var(4,1,2); Var(5,3320,1); \n",
-      "  dtype : float32\n",
-      "... Done.\n",
-      "Compiling libKeOpstorch7385e76d34 in /home/jake.gardner/.cache/pykeops-1.1.1//build-libKeOpstorch7385e76d34:\n",
-      "  formula: Sum_Reduction(((((Var(0,1,2) * Sqrt(Sum(Square((Var(1,18,0) - Var(2,18,1)))))) + (IntCst(1) + (Var(3,1,2) * Square(Sqrt(Sum(Square((Var(1,18,0) - Var(2,18,1))))))))) * Exp((Var(4,1,2) * Sqrt(Sum(Square((Var(1,18,0) - Var(2,18,1)))))))) * Var(5,1,1)),0)\n",
-      "  aliases: Var(0,1,2); Var(1,18,0); Var(2,18,1); Var(3,1,2); Var(4,1,2); Var(5,1,1); \n",
-      "  dtype : float32\n",
-      "... Done.\n",
-      "Compiling libKeOpstorch97105370ea in /home/jake.gardner/.cache/pykeops-1.1.1//build-libKeOpstorch97105370ea:\n",
-      "  formula: Sum_Reduction(((((Var(0,1,2) * Sqrt(Sum(Square((Var(1,18,0) - Var(2,18,1)))))) + (IntCst(1) + (Var(3,1,2) * Square(Sqrt(Sum(Square((Var(1,18,0) - Var(2,18,1))))))))) * Exp((Var(4,1,2) * Sqrt(Sum(Square((Var(1,18,0) - Var(2,18,1)))))))) * Var(5,100,1)),0)\n",
-      "  aliases: Var(0,1,2); Var(1,18,0); Var(2,18,1); Var(3,1,2); Var(4,1,2); Var(5,100,1); \n",
-      "  dtype : float32\n",
-      "... Done.\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "# Get into evaluation (predictive posterior) mode\n",
     "model.eval()\n",
-    "likelihood.eval()\n",
     "\n",
     "# Test points are regularly spaced along [0,1]\n",
     "# Make predictions by feeding model through likelihood\n",
     "with torch.no_grad(), gpytorch.settings.fast_pred_var():\n",
-    "    observed_pred = likelihood(model(test_x))"
+    "    observed_pred = model.likelihood(model(test_x))"
    ]
   },
   {
@@ -227,29 +228,27 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [
     {
-     "data": {
-      "text/plain": [
-       "tensor(0.1068, device='cuda:0')"
-      ]
-     },
-     "execution_count": 15,
-     "metadata": {},
-     "output_type": "execute_result"
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "RMSE: 0.138\n"
+     ]
     }
    ],
    "source": [
-    "torch.sqrt(torch.mean(torch.pow(observed_pred.mean - test_y, 2)))"
+    "rmse = (observed_pred.mean - test_y).square().mean().sqrt().item()\n",
+    "print(f\"RMSE: {rmse:.3f}\")"
    ]
   }
  ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
-  "display_name": "Python 3",
+  "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
@@ -263,7 +262,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
-  "version": "3.7.1"
+  "version": "3.8.0"
  }
 },
 "nbformat": 4,

gpytorch/kernels/keops/__init__.py
+2 -1
@@ -1,5 +1,6 @@
+from .keops_kernel import KeOpsKernel
 from .matern_kernel import MaternKernel
 from .periodic_kernel import PeriodicKernel
 from .rbf_kernel import RBFKernel
 
-__all__ = ["MaternKernel", "RBFKernel", "PeriodicKernel"]
+__all__ = ["KeOpsKernel", "MaternKernel", "PeriodicKernel", "RBFKernel"]
gpytorch/kernels/keops/keops_kernel.py
+45 -29
@@ -1,48 +1,64 @@
-from abc import abstractmethod
-from typing import Any
+import warnings
+from typing import Any, Tuple, Union
 
 import torch
+from linear_operator import LinearOperator
 from torch import Tensor
 
 from ... import settings
 from ..kernel import Kernel
 
 try:
     import pykeops  # noqa F401
+    from pykeops.torch import LazyTensor
+
+    def _lazify_and_expand_inputs(
+        x1: Tensor, x2: Tensor
+    ) -> Tuple[Union[Tensor, LazyTensor], Union[Tensor, LazyTensor]]:
+        r"""
+        Potentially wrap inputs x1 and x2 as KeOps LazyTensors,
+        depending on whether or not we want to use KeOps under the hood.
+        """
+        x1_ = x1[..., :, None, :]
+        x2_ = x2[..., None, :, :]
+        if _use_keops(x1, x2):
+            res = LazyTensor(x1_), LazyTensor(x2_)
+            return res
+        return x1_, x2_
+
+    def _use_keops(x1: Tensor, x2: Tensor) -> bool:
+        r"""
+        Determine whether or not to use KeOps under the hood.
+        This largely depends on the size of the kernel matrix.
+
+        There are situations where we do not want the KeOps linear operator to use KeOps under the hood.
+        See https://github.com/cornellius-gp/gpytorch/pull/1319
+        """
+        return (
+            settings.use_keops.on()
+            and x1.size(-2) >= settings.max_cholesky_size.value()
+            and x2.size(-2) >= settings.max_cholesky_size.value()
+        )
 
     class KeOpsKernel(Kernel):
-        @abstractmethod
-        def _nonkeops_forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs: Any):
-            r"""
-            Computes the covariance matrix (or diagonal) without using KeOps.
-            This function must implement both the diag=True and diag=False options.
-            """
-            raise NotImplementedError
-
-        @abstractmethod
-        def _keops_forward(self, x1: Tensor, x2: Tensor, **kwargs: Any):
-            r"""
-            Computes the covariance matrix with KeOps.
-            This function only implements the diag=False option, and no diag keyword should be passed in.
-            """
-            raise NotImplementedError
-
-        def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs: Any):
-            if diag:
-                return self._nonkeops_forward(x1, x2, diag=True, **kwargs)
-            elif x1.size(-2) < settings.max_cholesky_size.value() or x2.size(-2) < settings.max_cholesky_size.value():
-                return self._nonkeops_forward(x1, x2, diag=False, **kwargs)
-            else:
-                return self._keops_forward(x1, x2, **kwargs)
-
-        def __call__(self, *args: Any, **kwargs: Any):
+        def __call__(self, *args: Any, **kwargs: Any) -> Union[LinearOperator, Tensor, LazyTensor]:
             # Hotfix for zero gradients. See https://github.com/cornellius-gp/gpytorch/issues/1543
             args = [arg.contiguous() if torch.is_tensor(arg) else arg for arg in args]
             kwargs = {k: v.contiguous() if torch.is_tensor(v) else v for k, v in kwargs.items()}
             return super().__call__(*args, **kwargs)
 
 except ImportError:
 
+    def _lazify_and_expand_inputs(x1: Tensor, x2: Tensor) -> Tuple[Tensor, Tensor]:
+        return x1, x2
+
+    def _use_keops(x1: Tensor, x2: Tensor) -> bool:
+        return False
+
     class KeOpsKernel(Kernel):
-        def __init__(self, *args: Any, **kwargs: Any):
-            raise RuntimeError("You must have KeOps installed to use a KeOpsKernel")
+        def __call__(self, *args: Any, **kwargs: Any) -> Union[LinearOperator, Tensor]:
+            warnings.warn(
+                "KeOps is not installed. " f"{type(self)} will revert to the non-keops version of this kernel.",
+                RuntimeWarning,
+            )
+            return super().__call__(*args, **kwargs)

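With this refactor, a concrete KeOps kernel needs only a single forward that runs on both code paths. A minimal sketch of such a subclass (this `ToyRBFKernel` is illustrative, not the RBFKernel shipped in this PR; diag handling is omitted):

```python
from torch import Tensor

from gpytorch.kernels.keops.keops_kernel import KeOpsKernel, _lazify_and_expand_inputs


class ToyRBFKernel(KeOpsKernel):
    has_lengthscale = True

    def forward(self, x1: Tensor, x2: Tensor, **kwargs) -> Tensor:
        x1_ = x1 / self.lengthscale
        x2_ = x2 / self.lengthscale
        # Broadcast to (... x n x 1 x d) and (... x 1 x m x d); wrapped as
        # KeOps LazyTensors only when _use_keops allows it (matrix is large
        # and settings.use_keops is on), else left as plain torch Tensors.
        x1_, x2_ = _lazify_and_expand_inputs(x1_, x2_)
        # The same arithmetic works on Tensors and LazyTensors alike.
        sq_dist = ((x1_ - x2_) ** 2).sum(-1)
        return (-0.5 * sq_dist).exp()
```

Because both branches of the module define `_lazify_and_expand_inputs` and `_use_keops` (the non-KeOps branch as no-op fallbacks), the same forward also works when pykeops is not installed.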