keops periodic and keops kernels unit tests #2296
Merged · 31 commits · Jun 2, 2023
Commits
da5c83d
working version
m-julian Mar 9, 2023
59c4f7f
change name of import
m-julian Mar 9, 2023
49743cc
add periodic kernel
m-julian Mar 9, 2023
e3b0737
copy original periodic kernel code
m-julian Mar 9, 2023
4abaa4d
tests for keops periodic kernel
m-julian Mar 9, 2023
74fd11d
add tests
m-julian Mar 9, 2023
d367880
remove copied file
m-julian Mar 9, 2023
f8371bc
fixed NaN
m-julian Mar 9, 2023
a92b596
add more tests
m-julian Mar 9, 2023
2b412a8
update docstring
m-julian Mar 9, 2023
dc256f8
add nonkeops tests
m-julian Mar 9, 2023
1f19026
add cuda check
m-julian Mar 9, 2023
41c4785
formatting
m-julian Mar 9, 2023
93dce5b
use arithmetic operators
m-julian Mar 9, 2023
a5587ba
add tests
m-julian Mar 9, 2023
96e1f85
use cuda tensors
m-julian Mar 10, 2023
ad0efbc
subclass from periodic kernel
m-julian Mar 10, 2023
abe64a6
docstring update
m-julian Mar 10, 2023
1e9a5c2
base keops class for tests
m-julian Mar 24, 2023
8fddebd
run keops tests on cpu
m-julian Mar 24, 2023
aca88ad
formatting
m-julian Mar 24, 2023
f589d99
use KernelLinearOperator
m-julian May 11, 2023
d6a3e3e
add comment
m-julian May 13, 2023
e34791b
gradient and ard tests
m-julian May 13, 2023
c0f3171
another gradient test
m-julian May 13, 2023
1f4d08c
diag and refactor
m-julian May 13, 2023
1940cfb
Update test cases, adapt to new KernelLinearOperator style
gpleiss Jun 2, 2023
fed7222
Update to latest version of LinearOperator, add keops tests to CI
gpleiss Jun 2, 2023
e9bbac8
Add behavioral test for KeOps regression
gpleiss Jun 2, 2023
a3b6040
Include KeOps kernels in the docs
gpleiss Jun 2, 2023
9e7f066
Refactor keops implementation, add more testing
gpleiss Jun 2, 2023
2 changes: 1 addition & 1 deletion .conda/meta.yaml
@@ -18,7 +18,7 @@ requirements:
run:
- pytorch>=1.11
- scikit-learn
- linear_operator>=0.4.0
- linear_operator>=0.5.0

test:
imports:
4 changes: 3 additions & 1 deletion .github/workflows/run_test_suite.yml
@@ -55,6 +55,7 @@ jobs:
pip install -e .
if [[ ${{ matrix.extras }} == "with-extras" ]]; then
pip install "pyro-ppl>=1.8";
pip install pykeops;
pip install faiss-cpu; # Unofficial pip release: https://pypi.org/project/faiss-cpu/#history
fi
- name: Run unit tests
@@ -75,7 +76,8 @@ jobs:
pip install pytest nbval jupyter tqdm matplotlib torchvision scipy
pip install -e .
pip install "pyro-ppl>=1.8";
pip install faiss-cpu; # Unofficial pip release: https://pypi.org/project/faiss-cpu/#history
pip install pykeops;
pip install faiss-cpu; # Unofficial pip release: https://pypi.org/project/faiss-cpu/#history
- name: Run example notebooks
run: |
grep -l smoke_test examples/**/*.ipynb | xargs grep -L 'smoke_test = False' | CI=true xargs pytest --nbval-lax --current-env
2 changes: 1 addition & 1 deletion docs/requirements.txt
@@ -1,7 +1,7 @@
setuptools_scm<=7.1.0
ipython<=8.6.0
ipykernel<=6.17.1
linear_operator>=0.4.0
linear_operator>=0.5.0
m2r2<=0.3.3.post2
nbclient<=0.7.3
nbformat<=5.8.0
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -29,6 +29,7 @@ GPyTorch's documentation
models
likelihoods
kernels
keops_kernels
means
marginal_log_likelihoods
metrics
41 changes: 41 additions & 0 deletions docs/source/keops_kernels.rst
@@ -0,0 +1,41 @@
.. role:: hidden
:class: hidden-section

gpytorch.kernels.keops
===================================

.. automodule:: gpytorch.kernels.keops
.. currentmodule:: gpytorch.kernels.keops


These kernels are compatible with the GPyTorch KeOps integration.
For more information, see the `KeOps tutorial`_.

.. note::
Only some standard kernels have KeOps implementations.
If there is a kernel you want that's missing, consider submitting a pull request!


.. _KeOps Tutorial:
examples/02_Scalable_Exact_GPs/KeOps_GP_Regression.html


:hidden:`RBFKernel`
~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: RBFKernel
:members:


:hidden:`MaternKernel`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MaternKernel
:members:


:hidden:`PeriodicKernel`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: PeriodicKernel
:members:
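
The module documented above simply re-exports the KeOps kernel classes (as the `__init__.py` diff just below confirms), so they import the same way as their dense counterparts. A minimal usage sketch, assuming `pykeops` is installed:

```python
# The KeOps kernels are drop-in replacements for the standard gpytorch kernels.
from gpytorch.kernels.keops import MaternKernel, PeriodicKernel, RBFKernel

# Same constructor arguments as gpytorch.kernels.RBFKernel
kernel = RBFKernel(ard_num_dims=2)
```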
3 changes: 2 additions & 1 deletion gpytorch/kernels/keops/__init__.py
@@ -1,4 +1,5 @@
from .matern_kernel import MaternKernel
from .periodic_kernel import PeriodicKernel
from .rbf_kernel import RBFKernel

__all__ = ["MaternKernel", "RBFKernel"]
__all__ = ["MaternKernel", "RBFKernel", "PeriodicKernel"]
33 changes: 28 additions & 5 deletions gpytorch/kernels/keops/keops_kernel.py
@@ -1,18 +1,41 @@
from abc import abstractmethod
from typing import Any

import torch
from torch import Tensor

from ... import settings
from ..kernel import Kernel

try:
from pykeops.torch import LazyTensor as KEOLazyTensor
import pykeops # noqa F401

class KeOpsKernel(Kernel):
@abstractmethod
def covar_func(self, x1: torch.Tensor, x2: torch.Tensor) -> KEOLazyTensor:
raise NotImplementedError("KeOpsKernels must define a covar_func method")
def _nonkeops_forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs: Any):
r"""
Computes the covariance matrix (or diagonal) without using KeOps.
This function must implement both the diag=True and diag=False options.
"""
raise NotImplementedError

def __call__(self, *args, **kwargs):
@abstractmethod
def _keops_forward(self, x1: Tensor, x2: Tensor, **kwargs: Any):
r"""
Computes the covariance matrix with KeOps.
This function only implements the diag=False option, and no diag keyword should be passed in.
"""
raise NotImplementedError

def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs: Any):
if diag:
return self._nonkeops_forward(x1, x2, diag=True, **kwargs)
elif x1.size(-2) < settings.max_cholesky_size.value() or x2.size(-2) < settings.max_cholesky_size.value():
return self._nonkeops_forward(x1, x2, diag=False, **kwargs)
else:
return self._keops_forward(x1, x2, **kwargs)

def __call__(self, *args: Any, **kwargs: Any):
# Hotfix for zero gradients. See https://github.com/cornellius-gp/gpytorch/issues/1543
args = [arg.contiguous() if torch.is_tensor(arg) else arg for arg in args]
kwargs = {k: v.contiguous() if torch.is_tensor(v) else v for k, v in kwargs.items()}
@@ -21,5 +44,5 @@ def __call__(self, *args, **kwargs):
except ImportError:

class KeOpsKernel(Kernel):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
raise RuntimeError("You must have KeOps installed to use a KeOpsKernel")
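
The base class above turns `forward` into a dispatcher: `diag=True` requests and inputs below `settings.max_cholesky_size` go to `_nonkeops_forward`, while everything else goes to `_keops_forward`. A minimal sketch of what a subclass must provide, using a hypothetical squared-exponential kernel (for illustration only, not part of this PR; `pykeops` assumed installed):

```python
import torch
from linear_operator.operators import KernelLinearOperator
from pykeops.torch import LazyTensor

from gpytorch.kernels.keops.keops_kernel import KeOpsKernel


def _sq_exp_covar_func(x1, x2, **kwargs):
    # Symbolic tensors of shape ... x n x 1 x d and ... x 1 x m x d
    x1_ = LazyTensor(x1[..., :, None, :])
    x2_ = LazyTensor(x2[..., None, :, :])
    return (((x1_ - x2_) ** 2).sum(-1) * -0.5).exp()


class SquaredExponentialKernel(KeOpsKernel):  # hypothetical illustration
    has_lengthscale = True

    def _nonkeops_forward(self, x1, x2, diag=False, **kwargs):
        # Must handle both diag=True and diag=False
        x1_ = x1 / self.lengthscale
        x2_ = x2 / self.lengthscale
        sq_dist = self.covar_dist(x1_, x2_, square_dist=True, diag=diag)
        return torch.exp(-0.5 * sq_dist)

    def _keops_forward(self, x1, x2, **kwargs):
        # Full covariance matrix only; returned lazily via KernelLinearOperator
        x1_ = x1 / self.lengthscale
        x2_ = x2 / self.lengthscale
        return KernelLinearOperator(x1_, x2_, covar_func=_sq_exp_covar_func, **kwargs)
```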
117 changes: 54 additions & 63 deletions gpytorch/kernels/keops/matern_kernel.py
@@ -2,21 +2,57 @@
import math

import torch
from linear_operator.operators import KeOpsLinearOperator
from linear_operator.operators import KernelLinearOperator

from ... import settings
from ..matern_kernel import MaternKernel as GMaternKernel
from .keops_kernel import KeOpsKernel

try:
from pykeops.torch import LazyTensor as KEOLazyTensor

def _covar_func(x1, x2, nu=2.5, **params):
x1_ = KEOLazyTensor(x1[..., :, None, :])
x2_ = KEOLazyTensor(x2[..., None, :, :])

distance = ((x1_ - x2_) ** 2).sum(-1).sqrt()
exp_component = (-math.sqrt(nu * 2) * distance).exp()

if nu == 0.5:
constant_component = 1
elif nu == 1.5:
constant_component = (math.sqrt(3) * distance) + 1
elif nu == 2.5:
constant_component = (math.sqrt(5) * distance) + (1 + 5.0 / 3.0 * (distance**2))

return constant_component * exp_component

class MaternKernel(KeOpsKernel):
"""
Implements the Matern kernel using KeOps as a driver for kernel matrix multiplies.

This class can be used as a drop in replacement for gpytorch.kernels.MaternKernel in most cases, and supports
the same arguments. There are currently a few limitations, for example a lack of batch mode support. However,
most other features like ARD will work.
This class can be used as a drop in replacement for :class:`gpytorch.kernels.MaternKernel` in most cases,
and supports the same arguments.

:param nu: (Default: 2.5) The smoothness parameter.
:type nu: float (0.5, 1.5, or 2.5)
:param ard_num_dims: (Default: `None`) Set this if you want a separate lengthscale for each
input dimension. It should be `d` if x1 is a `... x n x d` matrix.
:type ard_num_dims: int, optional
:param batch_shape: (Default: `None`) Set this if you want a separate lengthscale for each
batch of input data. It should be `torch.Size([b1, b2])` for a `b1 x b2 x n x m` kernel output.
:type batch_shape: torch.Size, optional
:param active_dims: (Default: `None`) Set this if you want to
compute the covariance of only a few input dimensions. The ints
correspond to the indices of the dimensions.
:type active_dims: Tuple(int)
:param lengthscale_prior: (Default: `None`)
Set this if you want to apply a prior to the lengthscale parameter.
:type lengthscale_prior: ~gpytorch.priors.Prior, optional
:param lengthscale_constraint: (Default: `Positive`) Set this if you want
to apply a constraint to the lengthscale parameter.
:type lengthscale_constraint: ~gpytorch.constraints.Interval, optional
:param eps: (Default: 1e-6) The minimum value that the lengthscale can take (prevents divide by zero errors).
:type eps: float, optional
"""

has_lengthscale = True
@@ -27,8 +63,12 @@ def __init__(self, nu=2.5, **kwargs):
super(MaternKernel, self).__init__(**kwargs)
self.nu = nu

def _nonkeops_covar_func(self, x1, x2, diag=False):
distance = self.covar_dist(x1, x2, diag=diag)
def _nonkeops_forward(self, x1, x2, diag=False, **kwargs):
mean = x1.reshape(-1, x1.size(-1)).mean(0)[(None,) * (x1.dim() - 1)]
x1_ = (x1 - mean) / self.lengthscale
x2_ = (x2 - mean) / self.lengthscale

distance = self.covar_dist(x1_, x2_, diag=diag, **kwargs)
exp_component = torch.exp(-math.sqrt(self.nu * 2) * distance)

if self.nu == 0.5:
@@ -39,63 +79,14 @@ def _nonkeops_covar_func(self, x1, x2, diag=False):
constant_component = (math.sqrt(5) * distance).add(1).add(5.0 / 3.0 * distance**2)
return constant_component * exp_component

def covar_func(self, x1, x2, diag=False):
# We only should use KeOps on big kernel matrices
# If we would otherwise be performing Cholesky inference, (or when just computing a kernel matrix diag)
# then don't apply KeOps
# enable gradients to ensure that test time caches on small predictions are still
# backprop-able
with torch.autograd.enable_grad():
if (
diag
or x1.size(-2) < settings.max_cholesky_size.value()
or x2.size(-2) < settings.max_cholesky_size.value()
):
return self._nonkeops_covar_func(x1, x2, diag=diag)
# TODO: x1 / x2 size checks are a work around for a very minor bug in KeOps.
# This bug is fixed on KeOps master, and we'll remove that part of the check
# when they cut a new release.
elif x1.size(-2) == 1 or x2.size(-2) == 1:
return self._nonkeops_covar_func(x1, x2, diag=diag)
else:
# We only should use KeOps on big kernel matrices
# If we would otherwise be performing Cholesky inference, then don't apply KeOps
if (
x1.size(-2) < settings.max_cholesky_size.value()
or x2.size(-2) < settings.max_cholesky_size.value()
):
x1_ = x1[..., :, None, :]
x2_ = x2[..., None, :, :]
else:
x1_ = KEOLazyTensor(x1[..., :, None, :])
x2_ = KEOLazyTensor(x2[..., None, :, :])

distance = ((x1_ - x2_) ** 2).sum(-1).sqrt()
exp_component = (-math.sqrt(self.nu * 2) * distance).exp()

if self.nu == 0.5:
constant_component = 1
elif self.nu == 1.5:
constant_component = (math.sqrt(3) * distance) + 1
elif self.nu == 2.5:
constant_component = (math.sqrt(5) * distance) + (1 + 5.0 / 3.0 * distance**2)

return constant_component * exp_component

def forward(self, x1, x2, diag=False, **params):
def _keops_forward(self, x1, x2, **kwargs):
mean = x1.reshape(-1, x1.size(-1)).mean(0)[(None,) * (x1.dim() - 1)]

x1_ = (x1 - mean).div(self.lengthscale)
x2_ = (x2 - mean).div(self.lengthscale)

if diag:
return self.covar_func(x1_, x2_, diag=True)

covar_func = lambda x1, x2, diag=False: self.covar_func(x1, x2, diag)
return KeOpsLinearOperator(x1_, x2_, covar_func)
x1_ = (x1 - mean) / self.lengthscale
x2_ = (x2 - mean) / self.lengthscale
# return KernelLinearOperator inst only when calculating the whole covariance matrix
return KernelLinearOperator(x1_, x2_, covar_func=_covar_func, nu=self.nu, **kwargs)

except ImportError:

class MaternKernel(KeOpsKernel):
def __init__(self, *args, **kwargs):
super().__init__()
class MaternKernel(GMaternKernel):
pass
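
As a usage sketch (assuming `pykeops` is installed): for `nu=2.5` the code above computes k(r) = (1 + √5·r + (5/3)·r²)·exp(−√5·r) on the lengthscale-normalized distance r, and the class drops in for `gpytorch.kernels.MaternKernel`. The `max_cholesky_size` setting controls which branch of `forward` runs:

```python
import torch
import gpytorch
from gpytorch.kernels.keops import MaternKernel

x = torch.randn(2000, 3)
kernel = MaternKernel(nu=2.5, ard_num_dims=3)

# n = 2000 exceeds the threshold below, so forward() dispatches to
# _keops_forward and the dense 2000 x 2000 matrix is never materialized.
with gpytorch.settings.max_cholesky_size(1000):
    K = kernel(x, x)                    # lazy kernel operator
    out = K @ torch.randn(2000, 1)      # matrix-vector product driven by KeOps
```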
93 changes: 93 additions & 0 deletions gpytorch/kernels/keops/periodic_kernel.py
@@ -0,0 +1,93 @@
#!/usr/bin/env python3

import math

from linear_operator.operators import KernelLinearOperator

from ..periodic_kernel import PeriodicKernel as GPeriodicKernel
from .keops_kernel import KeOpsKernel

# from ...kernels import PeriodicKernel gives a cyclic import

try:
from pykeops.torch import LazyTensor as KEOLazyTensor

def _covar_func(x1, x2, lengthscale, **kwargs):
# symbolic array of shape ..., ndatax1_ x 1 x ndim
x1_ = KEOLazyTensor(x1[..., :, None, :])
# symbolic array of shape ..., 1 x ndatax2_ x ndim
x2_ = KEOLazyTensor(x2[..., None, :, :])
lengthscale = lengthscale[..., None, None, 0, :] # 1 x 1 x ndim
# do not use .power(2.0) as it gives NaN values on cuda
# seems related to https://github.com/getkeops/keops/issues/112
K = ((((x1_ - x2_).abs().sin()) ** 2) * (-2.0 / lengthscale)).sum(-1).exp()
return K

# subclass from original periodic kernel to reduce code duplication
class PeriodicKernel(KeOpsKernel, GPeriodicKernel):
"""
Implements the Periodic Kernel using KeOps as a driver for kernel matrix multiplies.

This class can be used as a drop in replacement for :class:`gpytorch.kernels.PeriodicKernel` in most cases,
and supports the same arguments.

:param ard_num_dims: (Default: `None`) Set this if you want a separate lengthscale for each
input dimension. It should be `d` if x1 is a `... x n x d` matrix.
:type ard_num_dims: int, optional
:param batch_shape: (Default: `None`) Set this if you want a separate lengthscale for each
batch of input data. It should be `torch.Size([b1, b2])` for a `b1 x b2 x n x m` kernel output.
:type batch_shape: torch.Size, optional
:param active_dims: (Default: `None`) Set this if you want to
compute the covariance of only a few input dimensions. The ints
correspond to the indices of the dimensions.
:type active_dims: Tuple(int)
:param period_length_prior: (Default: `None`)
Set this if you want to apply a prior to the period length parameter.
:type period_length_prior: ~gpytorch.priors.Prior, optional
:param period_length_constraint: (Default: `Positive`) Set this if you want
to apply a constraint to the period length parameter.
:type period_length_constraint: ~gpytorch.constraints.Interval, optional
:param lengthscale_prior: (Default: `None`)
Set this if you want to apply a prior to the lengthscale parameter.
:type lengthscale_prior: ~gpytorch.priors.Prior, optional
:param lengthscale_constraint: (Default: `Positive`) Set this if you want
to apply a constraint to the lengthscale parameter.
:type lengthscale_constraint: ~gpytorch.constraints.Interval, optional
:param eps: (Default: 1e-6) The minimum value that the lengthscale can take (prevents divide by zero errors).
:type eps: float, optional

:var torch.Tensor period_length: The period length parameter. Size/shape of parameter depends on the
ard_num_dims and batch_shape arguments.
"""

has_lengthscale = True

# code from the already-implemented Periodic Kernel
def _nonkeops_forward(self, x1, x2, diag=False, **kwargs):
x1_ = x1.div(self.period_length / math.pi)
x2_ = x2.div(self.period_length / math.pi)

# We are automatically overriding last_dim_is_batch here so that we can manually sum over dimensions.
diff = self.covar_dist(x1_, x2_, diag=diag, last_dim_is_batch=True)

if diag:
lengthscale = self.lengthscale[..., 0, :, None]
else:
lengthscale = self.lengthscale[..., 0, :, None, None]

exp_term = diff.sin().pow(2.0).div(lengthscale).mul(-2.0)
exp_term = exp_term.sum(dim=(-2 if diag else -3))

return exp_term.exp()

def _keops_forward(self, x1, x2, **kwargs):
x1_ = x1.div(self.period_length / math.pi)
x2_ = x2.div(self.period_length / math.pi)
# return KernelLinearOperator inst only when calculating the whole covariance matrix
# pass any parameters which are used inside _covar_func as *args to get gradients computed for them
return KernelLinearOperator(x1_, x2_, lengthscale=self.lengthscale, covar_func=_covar_func, **kwargs)

except ImportError:

class PeriodicKernel(GPeriodicKernel):
pass
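
A quick consistency sketch (again assuming `pykeops`): both branches of `forward` should evaluate k(x1, x2) = exp(−2 Σ_d sin²(π(x1_d − x2_d)/p_d)/ℓ_d), so the KeOps and dense paths can be compared on the same inputs by moving the `max_cholesky_size` threshold — the same pattern the new unit tests rely on:

```python
import torch
import gpytorch
from gpytorch.kernels.keops import PeriodicKernel

x = torch.randn(100, 2)
kernel = PeriodicKernel(ard_num_dims=2)

# A threshold below n = 100 forces the KeOps branch ...
with gpytorch.settings.max_cholesky_size(10):
    k_keops = kernel(x, x).to_dense()
# ... and a threshold above n forces the dense torch branch.
with gpytorch.settings.max_cholesky_size(1000):
    k_dense = kernel(x, x).to_dense()

assert torch.allclose(k_keops, k_dense, atol=1e-5)
```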