diff --git a/.conda/meta.yaml b/.conda/meta.yaml index 72486fadd..f15f1bc9f 100644 --- a/.conda/meta.yaml +++ b/.conda/meta.yaml @@ -18,7 +18,7 @@ requirements: run: - pytorch>=1.11 - scikit-learn - - linear_operator>=0.4.0 + - linear_operator>=0.5.0 test: imports: diff --git a/.github/workflows/run_test_suite.yml b/.github/workflows/run_test_suite.yml index 4efd70630..5b12018b3 100644 --- a/.github/workflows/run_test_suite.yml +++ b/.github/workflows/run_test_suite.yml @@ -55,6 +55,7 @@ jobs: pip install -e . if [[ ${{ matrix.extras }} == "with-extras" ]]; then pip install "pyro-ppl>=1.8"; + pip install pykeops; pip install faiss-cpu; # Unofficial pip release: https://pypi.org/project/faiss-cpu/#history fi - name: Run unit tests @@ -75,7 +76,8 @@ jobs: pip install pytest nbval jupyter tqdm matplotlib torchvision scipy pip install -e . pip install "pyro-ppl>=1.8"; - pip install faiss-cpu; # Unofficial pip release: https://pypi.org/project/faiss-cpu/#history + pip install pykeops; + pip install faiss-cpu; # Unofficial pip release: https://pypi.org/project/faiss-cpu/#history - name: Run example notebooks run: | grep -l smoke_test examples/**/*.ipynb | xargs grep -L 'smoke_test = False' | CI=true xargs pytest --nbval-lax --current-env diff --git a/docs/requirements.txt b/docs/requirements.txt index fd3f4d915..2a300985f 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,7 +1,7 @@ setuptools_scm<=7.1.0 ipython<=8.6.0 ipykernel<=6.17.1 -linear_operator>=0.4.0 +linear_operator>=0.5.0 m2r2<=0.3.3.post2 nbclient<=0.7.3 nbformat<=5.8.0 diff --git a/docs/source/index.rst b/docs/source/index.rst index 93975b2f5..2c9d1b3ec 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -29,6 +29,7 @@ GPyTorch's documentation models likelihoods kernels + keops_kernels means marginal_log_likelihoods metrics diff --git a/docs/source/keops_kernels.rst b/docs/source/keops_kernels.rst new file mode 100644 index 000000000..01a02bae4 --- /dev/null +++ 
b/docs/source/keops_kernels.rst @@ -0,0 +1,41 @@ +.. role:: hidden + :class: hidden-section + +gpytorch.kernels.keops +=================================== + +.. automodule:: gpytorch.kernels.keops +.. currentmodule:: gpytorch.kernels.keops + + +These kernels are compatible with the GPyTorch KeOps integration. +For more information, see the `KeOps tutorial`_. + +.. note:: + Only some standard kernels have KeOps implementations. + If there is a kernel you want that's missing, consider submitting a pull request! + + +.. _KeOps Tutorial: + examples/02_Scalable_Exact_GPs/KeOps_GP_Regression.html + + +:hidden:`RBFKernel` +~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: RBFKernel + :members: + + +:hidden:`MaternKernel` +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: MaternKernel + :members: + + +:hidden:`PeriodicKernel` +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: PeriodicKernel + :members: diff --git a/gpytorch/kernels/keops/__init__.py b/gpytorch/kernels/keops/__init__.py index 87e01a84d..639e2572f 100644 --- a/gpytorch/kernels/keops/__init__.py +++ b/gpytorch/kernels/keops/__init__.py @@ -1,4 +1,5 @@ from .matern_kernel import MaternKernel +from .periodic_kernel import PeriodicKernel from .rbf_kernel import RBFKernel -__all__ = ["MaternKernel", "RBFKernel"] +__all__ = ["MaternKernel", "RBFKernel", "PeriodicKernel"] diff --git a/gpytorch/kernels/keops/keops_kernel.py b/gpytorch/kernels/keops/keops_kernel.py index d264447ff..a584b4b3b 100644 --- a/gpytorch/kernels/keops/keops_kernel.py +++ b/gpytorch/kernels/keops/keops_kernel.py @@ -1,18 +1,41 @@ from abc import abstractmethod +from typing import Any import torch +from torch import Tensor +from ... 
import settings from ..kernel import Kernel try: - from pykeops.torch import LazyTensor as KEOLazyTensor + import pykeops # noqa F401 class KeOpsKernel(Kernel): @abstractmethod - def covar_func(self, x1: torch.Tensor, x2: torch.Tensor) -> KEOLazyTensor: - raise NotImplementedError("KeOpsKernels must define a covar_func method") + def _nonkeops_forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs: Any): + r""" + Computes the covariance matrix (or diagonal) without using KeOps. + This function must implement both the diag=True and diag=False options. + """ + raise NotImplementedError - def __call__(self, *args, **kwargs): + @abstractmethod + def _keops_forward(self, x1: Tensor, x2: Tensor, **kwargs: Any): + r""" + Computes the covariance matrix with KeOps. + This function only implements the diag=False option, and no diag keyword should be passed in. + """ + raise NotImplementedError + + def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs: Any): + if diag: + return self._nonkeops_forward(x1, x2, diag=True, **kwargs) + elif x1.size(-2) < settings.max_cholesky_size.value() or x2.size(-2) < settings.max_cholesky_size.value(): + return self._nonkeops_forward(x1, x2, diag=False, **kwargs) + else: + return self._keops_forward(x1, x2, **kwargs) + + def __call__(self, *args: Any, **kwargs: Any): # Hotfix for zero gradients. 
See https://github.com/cornellius-gp/gpytorch/issues/1543 args = [arg.contiguous() if torch.is_tensor(arg) else arg for arg in args] kwargs = {k: v.contiguous() if torch.is_tensor(v) else v for k, v in kwargs.items()} @@ -21,5 +44,5 @@ def __call__(self, *args, **kwargs): except ImportError: class KeOpsKernel(Kernel): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): raise RuntimeError("You must have KeOps installed to use a KeOpsKernel") diff --git a/gpytorch/kernels/keops/matern_kernel.py b/gpytorch/kernels/keops/matern_kernel.py index d7d0baeb8..73365c2d8 100644 --- a/gpytorch/kernels/keops/matern_kernel.py +++ b/gpytorch/kernels/keops/matern_kernel.py @@ -2,21 +2,57 @@ import math import torch -from linear_operator.operators import KeOpsLinearOperator +from linear_operator.operators import KernelLinearOperator -from ... import settings +from ..matern_kernel import MaternKernel as GMaternKernel from .keops_kernel import KeOpsKernel try: from pykeops.torch import LazyTensor as KEOLazyTensor + def _covar_func(x1, x2, nu=2.5, **params): + x1_ = KEOLazyTensor(x1[..., :, None, :]) + x2_ = KEOLazyTensor(x2[..., None, :, :]) + + distance = ((x1_ - x2_) ** 2).sum(-1).sqrt() + exp_component = (-math.sqrt(nu * 2) * distance).exp() + + if nu == 0.5: + constant_component = 1 + elif nu == 1.5: + constant_component = (math.sqrt(3) * distance) + 1 + elif nu == 2.5: + constant_component = (math.sqrt(5) * distance) + (1 + 5.0 / 3.0 * (distance**2)) + + return constant_component * exp_component + class MaternKernel(KeOpsKernel): """ Implements the Matern kernel using KeOps as a driver for kernel matrix multiplies. - This class can be used as a drop in replacement for gpytorch.kernels.MaternKernel in most cases, and supports - the same arguments. There are currently a few limitations, for example a lack of batch mode support. However, - most other features like ARD will work. 
+ This class can be used as a drop in replacement for :class:`gpytorch.kernels.MaternKernel` in most cases, + and supports the same arguments. + + :param nu: (Default: 2.5) The smoothness parameter. + :type nu: float (0.5, 1.5, or 2.5) + :param ard_num_dims: (Default: `None`) Set this if you want a separate lengthscale for each + input dimension. It should be `d` if x1 is a `... x n x d` matrix. + :type ard_num_dims: int, optional + :param batch_shape: (Default: `None`) Set this if you want a separate lengthscale for each + batch of input data. It should be `torch.Size([b1, b2])` for a `b1 x b2 x n x m` kernel output. + :type batch_shape: torch.Size, optional + :param active_dims: (Default: `None`) Set this if you want to + compute the covariance of only a few input dimensions. The ints + correspond to the indices of the dimensions. + :type active_dims: Tuple(int) + :param lengthscale_prior: (Default: `None`) + Set this if you want to apply a prior to the lengthscale parameter. + :type lengthscale_prior: ~gpytorch.priors.Prior, optional + :param lengthscale_constraint: (Default: `Positive`) Set this if you want + to apply a constraint to the lengthscale parameter. + :type lengthscale_constraint: ~gpytorch.constraints.Interval, optional + :param eps: (Default: 1e-6) The minimum value that the lengthscale can take (prevents divide by zero errors). 
+ :type eps: float, optional """ has_lengthscale = True @@ -27,8 +63,12 @@ def __init__(self, nu=2.5, **kwargs): super(MaternKernel, self).__init__(**kwargs) self.nu = nu - def _nonkeops_covar_func(self, x1, x2, diag=False): - distance = self.covar_dist(x1, x2, diag=diag) + def _nonkeops_forward(self, x1, x2, diag=False, **kwargs): + mean = x1.reshape(-1, x1.size(-1)).mean(0)[(None,) * (x1.dim() - 1)] + x1_ = (x1 - mean) / self.lengthscale + x2_ = (x2 - mean) / self.lengthscale + + distance = self.covar_dist(x1_, x2_, diag=diag, **kwargs) exp_component = torch.exp(-math.sqrt(self.nu * 2) * distance) if self.nu == 0.5: @@ -39,63 +79,14 @@ def _nonkeops_covar_func(self, x1, x2, diag=False): constant_component = (math.sqrt(5) * distance).add(1).add(5.0 / 3.0 * distance**2) return constant_component * exp_component - def covar_func(self, x1, x2, diag=False): - # We only should use KeOps on big kernel matrices - # If we would otherwise be performing Cholesky inference, (or when just computing a kernel matrix diag) - # then don't apply KeOps - # enable gradients to ensure that test time caches on small predictions are still - # backprop-able - with torch.autograd.enable_grad(): - if ( - diag - or x1.size(-2) < settings.max_cholesky_size.value() - or x2.size(-2) < settings.max_cholesky_size.value() - ): - return self._nonkeops_covar_func(x1, x2, diag=diag) - # TODO: x1 / x2 size checks are a work around for a very minor bug in KeOps. - # This bug is fixed on KeOps master, and we'll remove that part of the check - # when they cut a new release. 
- elif x1.size(-2) == 1 or x2.size(-2) == 1: - return self._nonkeops_covar_func(x1, x2, diag=diag) - else: - # We only should use KeOps on big kernel matrices - # If we would otherwise be performing Cholesky inference, then don't apply KeOps - if ( - x1.size(-2) < settings.max_cholesky_size.value() - or x2.size(-2) < settings.max_cholesky_size.value() - ): - x1_ = x1[..., :, None, :] - x2_ = x2[..., None, :, :] - else: - x1_ = KEOLazyTensor(x1[..., :, None, :]) - x2_ = KEOLazyTensor(x2[..., None, :, :]) - - distance = ((x1_ - x2_) ** 2).sum(-1).sqrt() - exp_component = (-math.sqrt(self.nu * 2) * distance).exp() - - if self.nu == 0.5: - constant_component = 1 - elif self.nu == 1.5: - constant_component = (math.sqrt(3) * distance) + 1 - elif self.nu == 2.5: - constant_component = (math.sqrt(5) * distance) + (1 + 5.0 / 3.0 * distance**2) - - return constant_component * exp_component - - def forward(self, x1, x2, diag=False, **params): + def _keops_forward(self, x1, x2, **kwargs): mean = x1.reshape(-1, x1.size(-1)).mean(0)[(None,) * (x1.dim() - 1)] - - x1_ = (x1 - mean).div(self.lengthscale) - x2_ = (x2 - mean).div(self.lengthscale) - - if diag: - return self.covar_func(x1_, x2_, diag=True) - - covar_func = lambda x1, x2, diag=False: self.covar_func(x1, x2, diag) - return KeOpsLinearOperator(x1_, x2_, covar_func) + x1_ = (x1 - mean) / self.lengthscale + x2_ = (x2 - mean) / self.lengthscale + # return KernelLinearOperator inst only when calculating the whole covariance matrix + return KernelLinearOperator(x1_, x2_, covar_func=_covar_func, nu=self.nu, **kwargs) except ImportError: - class MaternKernel(KeOpsKernel): - def __init__(self, *args, **kwargs): - super().__init__() + class MaternKernel(GMaternKernel): + pass diff --git a/gpytorch/kernels/keops/periodic_kernel.py b/gpytorch/kernels/keops/periodic_kernel.py new file mode 100644 index 000000000..b844425fa --- /dev/null +++ b/gpytorch/kernels/keops/periodic_kernel.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 + 
+import math + +from linear_operator.operators import KernelLinearOperator + +from ..periodic_kernel import PeriodicKernel as GPeriodicKernel +from .keops_kernel import KeOpsKernel + +# from ...kernels import PeriodicKernel gives a cyclic import + +try: + from pykeops.torch import LazyTensor as KEOLazyTensor + + def _covar_func(x1, x2, lengthscale, **kwargs): + # symbolic array of shape ..., ndatax1_ x 1 x ndim + x1_ = KEOLazyTensor(x1[..., :, None, :]) + # symbolic array of shape ..., 1 x ndatax2_ x ndim + x2_ = KEOLazyTensor(x2[..., None, :, :]) + lengthscale = lengthscale[..., None, None, 0, :] # 1 x 1 x ndim + # do not use .power(2.0) as it gives NaN values on cuda + # seems related to https://github.com/getkeops/keops/issues/112 + K = ((((x1_ - x2_).abs().sin()) ** 2) * (-2.0 / lengthscale)).sum(-1).exp() + return K + + # subclass from original periodic kernel to reduce code duplication + class PeriodicKernel(KeOpsKernel, GPeriodicKernel): + """ + Implements the Periodic Kernel using KeOps as a driver for kernel matrix multiplies. + + This class can be used as a drop in replacement for :class:`gpytorch.kernels.PeriodicKernel` in most cases, + and supports the same arguments. + + :param ard_num_dims: (Default: `None`) Set this if you want a separate lengthscale for each + input dimension. It should be `d` if x1 is a `... x n x d` matrix. + :type ard_num_dims: int, optional + :param batch_shape: (Default: `None`) Set this if you want a separate lengthscale for each + batch of input data. It should be `torch.Size([b1, b2])` for a `b1 x b2 x n x m` kernel output. + :type batch_shape: torch.Size, optional + :param active_dims: (Default: `None`) Set this if you want to + compute the covariance of only a few input dimensions. The ints + corresponds to the indices of the dimensions. + :type active_dims: Tuple(int) + :param period_length_prior: (Default: `None`) + Set this if you want to apply a prior to the period length parameter. 
+ :type period_length_prior: ~gpytorch.priors.Prior, optional + :param period_length_constraint: (Default: `Positive`) Set this if you want + to apply a constraint to the period length parameter. + :type period_length_constraint: ~gpytorch.constraints.Interval, optional + :param lengthscale_prior: (Default: `None`) + Set this if you want to apply a prior to the lengthscale parameter. + :type lengthscale_prior: ~gpytorch.priors.Prior, optional + :param lengthscale_constraint: (Default: `Positive`) Set this if you want + to apply a constraint to the lengthscale parameter. + :type lengthscale_constraint: ~gpytorch.constraints.Interval, optional + :param eps: (Default: 1e-6) The minimum value that the lengthscale can take (prevents divide by zero errors). + :type eps: float, optional + + :var torch.Tensor period_length: The period length parameter. Size/shape of parameter depends on the + ard_num_dims and batch_shape arguments. + """ + + has_lengthscale = True + + # code from the already-implemented Periodic Kernel + def _nonkeops_forward(self, x1, x2, diag=False, **kwargs): + x1_ = x1.div(self.period_length / math.pi) + x2_ = x2.div(self.period_length / math.pi) + + # We are automatically overriding last_dim_is_batch here so that we can manually sum over dimensions. 
+ diff = self.covar_dist(x1_, x2_, diag=diag, last_dim_is_batch=True) + + if diag: + lengthscale = self.lengthscale[..., 0, :, None] + else: + lengthscale = self.lengthscale[..., 0, :, None, None] + + exp_term = diff.sin().pow(2.0).div(lengthscale).mul(-2.0) + exp_term = exp_term.sum(dim=(-2 if diag else -3)) + + return exp_term.exp() + + def _keops_forward(self, x1, x2, **kwargs): + x1_ = x1.div(self.period_length / math.pi) + x2_ = x2.div(self.period_length / math.pi) + # return KernelLinearOperator inst only when calculating the whole covariance matrix + # pass any parameters which are used inside _covar_func as *args to get gradients computed for them + return KernelLinearOperator(x1_, x2_, lengthscale=self.lengthscale, covar_func=_covar_func, **kwargs) + +except ImportError: + + class PeriodicKernel(GPeriodicKernel): + pass diff --git a/gpytorch/kernels/keops/rbf_kernel.py b/gpytorch/kernels/keops/rbf_kernel.py index bfe0579f4..663d092a8 100644 --- a/gpytorch/kernels/keops/rbf_kernel.py +++ b/gpytorch/kernels/keops/rbf_kernel.py @@ -1,63 +1,60 @@ #!/usr/bin/env python3 -import torch -from linear_operator.operators import KeOpsLinearOperator +# from linear_operator.operators import KeOpsLinearOperator +from linear_operator.operators import KernelLinearOperator -from ... import settings -from ..rbf_kernel import postprocess_rbf +from ..rbf_kernel import postprocess_rbf, RBFKernel as GRBFKernel from .keops_kernel import KeOpsKernel try: from pykeops.torch import LazyTensor as KEOLazyTensor + def _covar_func(x1, x2, **kwargs): + x1_ = KEOLazyTensor(x1[..., :, None, :]) + x2_ = KEOLazyTensor(x2[..., None, :, :]) + K = (-((x1_ - x2_) ** 2).sum(-1) / 2).exp() + return K + class RBFKernel(KeOpsKernel): - """ + r""" Implements the RBF kernel using KeOps as a driver for kernel matrix multiplies. - This class can be used as a drop in replacement for gpytorch.kernels.RBFKernel in most cases, and supports - the same arguments. 
There are currently a few limitations, for example a lack of batch mode support. However, - most other features like ARD will work. + This class can be used as a drop in replacement for :class:`gpytorch.kernels.RBFKernel` in most cases, + and supports the same arguments. + + :param ard_num_dims: Set this if you want a separate lengthscale for each input + dimension. It should be `d` if x1 is a `n x d` matrix. (Default: `None`.) + :param batch_shape: Set this if you want a separate lengthscale for each batch of input + data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf x1` is + a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor. + :param active_dims: Set this if you want to compute the covariance of only + a few input dimensions. The ints corresponds to the indices of the + dimensions. (Default: `None`.) + :param lengthscale_prior: Set this if you want to apply a prior to the + lengthscale parameter. (Default: `None`) + :param lengthscale_constraint: Set this if you want to apply a constraint + to the lengthscale parameter. (Default: `Positive`.) + :param eps: The minimum value that the lengthscale can take (prevents + divide by zero errors). (Default: `1e-6`.) + + :ivar torch.Tensor lengthscale: The lengthscale parameter. Size/shape of parameter depends on the + ard_num_dims and batch_shape arguments. 
""" has_lengthscale = True - def _nonkeops_covar_func(self, x1, x2, diag=False): - return postprocess_rbf(self.covar_dist(x1, x2, square_dist=True, diag=diag)) - - def covar_func(self, x1, x2, diag=False): - # We only should use KeOps on big kernel matrices - # If we would otherwise be performing Cholesky inference, (or when just computing a kernel matrix diag) - # then don't apply KeOps - # enable gradients to ensure that test time caches on small predictions are still - # backprop-able - with torch.autograd.enable_grad(): - if ( - diag - or x1.size(-2) < settings.max_cholesky_size.value() - or x2.size(-2) < settings.max_cholesky_size.value() - ): - return self._nonkeops_covar_func(x1, x2, diag=diag) - - x1_ = KEOLazyTensor(x1[..., :, None, :]) - x2_ = KEOLazyTensor(x2[..., None, :, :]) - - K = (-((x1_ - x2_) ** 2).sum(-1) / 2).exp() + def _nonkeops_forward(self, x1, x2, diag=False, **kwargs): + x1_ = x1 / self.lengthscale + x2_ = x2 / self.lengthscale + return postprocess_rbf(self.covar_dist(x1_, x2_, square_dist=True, diag=diag, **kwargs)) - return K - - def forward(self, x1, x2, diag=False, **params): - x1_ = x1.div(self.lengthscale) - x2_ = x2.div(self.lengthscale) - - covar_func = lambda x1, x2, diag=diag: self.covar_func(x1, x2, diag) - - if diag: - return covar_func(x1_, x2_, diag=True) - - return KeOpsLinearOperator(x1_, x2_, covar_func) + def _keops_forward(self, x1, x2, **kwargs): + x1_ = x1 / self.lengthscale + x2_ = x2 / self.lengthscale + # return KernelLinearOperator inst only when calculating the whole covariance matrix + return KernelLinearOperator(x1_, x2_, covar_func=_covar_func, **kwargs) except ImportError: - class RBFKernel(KeOpsKernel): - def __init__(self, *args, **kwargs): - super().__init__() + class RBFKernel(GRBFKernel): + pass diff --git a/gpytorch/test/base_keops_test_case.py b/gpytorch/test/base_keops_test_case.py new file mode 100644 index 000000000..2ce4cd090 --- /dev/null +++ b/gpytorch/test/base_keops_test_case.py @@ -0,0 +1,145 
@@ +from abc import abstractmethod +from unittest.mock import patch + +import torch + +import gpytorch +from .base_test_case import BaseTestCase + + +CHOLESKY_SIZE_KEOPS, CHOLESKY_SIZE_NONKEOPS = 2, 800 + + +class BaseKeOpsTestCase(BaseTestCase): + @abstractmethod + def k1(self): + """Returns first kernel class""" + pass + + @abstractmethod + def k2(self): + """Returns second kernel class""" + pass + + # tests the keops implementation + def test_forward_x1_eq_x2(self, ard=False, use_keops=True, **kwargs): + max_cholesky_size = CHOLESKY_SIZE_KEOPS if use_keops else CHOLESKY_SIZE_NONKEOPS + with gpytorch.settings.max_cholesky_size(max_cholesky_size): + ndims = 3 + x1 = torch.randn(100, 3) + + if ard: + kern1 = self.k1(ard_num_dims=ndims, **kwargs) + kern2 = self.k2(ard_num_dims=ndims, **kwargs) + else: + kern1 = self.k1(**kwargs) + kern2 = self.k2(**kwargs) + + with patch.object(self.k1, "_keops_forward", wraps=kern1._keops_forward) as _keops_forward_mock: + # The patch makes sure that we're actually using KeOps + k1 = kern1(x1, x1).to_dense() + k2 = kern2(x1, x1).to_dense() + self.assertLess(torch.norm(k1 - k2), 1e-4) + + if use_keops: + self.assertTrue(_keops_forward_mock.called) + + def test_forward_x1_eq_x2_ard(self): + return self.test_forward_x1_eq_x2(ard=True) + + def test_forward_x1_neq_x2(self, use_keops=True, ard=False, **kwargs): + max_cholesky_size = CHOLESKY_SIZE_KEOPS if use_keops else CHOLESKY_SIZE_NONKEOPS + with gpytorch.settings.max_cholesky_size(max_cholesky_size): + ndims = 3 + x1 = torch.randn(100, ndims) + x2 = torch.randn(50, ndims) + + if ard: + kern1 = self.k1(ard_num_dims=ndims, **kwargs) + kern2 = self.k2(ard_num_dims=ndims, **kwargs) + else: + kern1 = self.k1(**kwargs) + kern2 = self.k2(**kwargs) + + with patch.object(self.k1, "_keops_forward", wraps=kern1._keops_forward) as _keops_forward_mock: + # The patch makes sure that we're actually using KeOps + k1 = kern1(x1, x2).to_dense() + k2 = kern2(x1, x2).to_dense() + 
self.assertLess(torch.norm(k1 - k2), 1e-4) + + if use_keops: + self.assertTrue(_keops_forward_mock.called) + + def test_forward_x1_meq_x2_ard(self): + return self.test_forward_x1_neq_x2(ard=True) + + def test_batch_matmul(self, use_keops=True, **kwargs): + max_cholesky_size = CHOLESKY_SIZE_KEOPS if use_keops else CHOLESKY_SIZE_NONKEOPS + with gpytorch.settings.max_cholesky_size(max_cholesky_size): + x1 = torch.randn(3, 2, 100, 3) + kern1 = self.k1(**kwargs) + kern2 = self.k2(**kwargs) + + rhs = torch.randn(3, 2, 100, 1) + with patch.object(self.k1, "_keops_forward", wraps=kern1._keops_forward) as _keops_forward_mock: + # The patch makes sure that we're actually using KeOps + res1 = kern1(x1, x1).matmul(rhs) + res2 = kern2(x1, x1).matmul(rhs) + self.assertLess(torch.norm(res1 - res2), 1e-4) + + if use_keops: + self.assertTrue(_keops_forward_mock.called) + + def test_gradient(self, use_keops=True, ard=False, **kwargs): + max_cholesky_size = CHOLESKY_SIZE_KEOPS if use_keops else CHOLESKY_SIZE_NONKEOPS + with gpytorch.settings.max_cholesky_size(max_cholesky_size): + ndims = 3 + + x1 = torch.randn(4, 100, ndims) + + if ard: + kern1 = self.k1(ard_num_dims=ndims, **kwargs) + kern2 = self.k2(ard_num_dims=ndims, **kwargs) + else: + kern1 = self.k1(**kwargs) + kern2 = self.k2(**kwargs) + + with patch.object(self.k1, "_keops_forward", wraps=kern1._keops_forward) as _keops_forward_mock: + # The patch makes sure that we're actually using KeOps + res1 = kern1(x1, x1) + res2 = kern2(x1, x1) + s1 = res1.sum() + s2 = res2.sum() + + # stack all gradients into a tensor + grad_s1 = torch.vstack(torch.autograd.grad(s1, [*kern1.hyperparameters()])) + grad_s2 = torch.vstack(torch.autograd.grad(s2, [*kern2.hyperparameters()])) + self.assertAllClose(grad_s1, grad_s2, rtol=1e-4, atol=1e-5) + + if use_keops: + self.assertTrue(_keops_forward_mock.called) + + def test_gradient_ard(self): + return self.test_gradient(ard=True) + + # tests the nonkeops implementation (_nonkeops_covar_func) + def 
test_forward_x1_eq_x2_nonkeops(self): + self.test_forward_x1_eq_x2(use_keops=False) + + def test_forward_x1_eq_x2_nonkeops_ard(self): + self.test_forward_x1_eq_x2(use_keops=False, ard=True) + + def test_forward_x1_neq_x2_nonkeops(self): + self.test_forward_x1_neq_x2(use_keops=False) + + def test_forward_x1_neq_x2_nonkeops_ard(self): + self.test_forward_x1_neq_x2(use_keops=False, ard=True) + + def test_batch_matmul_nonkeops(self): + self.test_batch_matmul(use_keops=False) + + def test_gradient_nonkeops(self): + self.test_gradient(use_keops=False) + + def test_gradient_nonkeops_ard(self): + self.test_gradient(use_keops=False, ard=True) diff --git a/setup.py b/setup.py index 36969f6c4..649e2fead 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ def find_version(*file_paths): torch_min = "1.11" install_requires = [ "scikit-learn", - "linear_operator>=0.4.0", + "linear_operator>=0.5.0", ] # if recent dev version of PyTorch is installed, no need to install stable try: diff --git a/test/examples/test_keops_gp_regression.py b/test/examples/test_keops_gp_regression.py new file mode 100644 index 000000000..b9d5bdd28 --- /dev/null +++ b/test/examples/test_keops_gp_regression.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 + +import unittest +from math import pi + +import torch +from torch import optim + +import gpytorch +from gpytorch.distributions import MultivariateNormal +from gpytorch.kernels import ScaleKernel +from gpytorch.kernels.keops import RBFKernel +from gpytorch.likelihoods import GaussianLikelihood +from gpytorch.means import ConstantMean +from gpytorch.test.base_test_case import BaseTestCase + + +# Simple training data: let's try to learn a sine function +train_x = torch.randn(1000, 2) +train_y = torch.sin(train_x[..., 0] * (2 * pi) + train_x[..., 1]) +train_y = train_y + torch.randn_like(train_y).mul(0.001) + +test_x = torch.randn(50, 2) +test_y = torch.sin(test_x[..., 0] * (2 * pi) + test_x[..., 1]) + + +class KeOpsGPModel(gpytorch.models.ExactGP): + def 
__init__(self, train_x, train_y, likelihood): + super().__init__(train_x, train_y, likelihood) + self.mean_module = ConstantMean() + self.covar_module = ScaleKernel(RBFKernel(ard_num_dims=2)) + + def forward(self, x): + mean_x = self.mean_module(x) + covar_x = self.covar_module(x) + return MultivariateNormal(mean_x, covar_x) + + +class TestKeOpsGPRegression(BaseTestCase, unittest.TestCase): + seed = 4 + + def test_keops_gp_mean_abs_error(self): + try: + import pykeops # noqa + except ImportError: + return + + likelihood = GaussianLikelihood() + gp_model = KeOpsGPModel(train_x, train_y, likelihood) + mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, gp_model) + + # Optimize the model + gp_model.train() + likelihood.train() + optimizer = optim.Adam(list(gp_model.parameters()), lr=0.01) + optimizer.n_iter = 0 + + with gpytorch.settings.max_cholesky_size(0): # Ensure that we're using KeOps + for i in range(300): + optimizer.zero_grad() + output = gp_model(train_x) + loss = -mll(output, train_y) + loss.backward() + optimizer.step() + + if i == 0: + for param in gp_model.parameters(): + self.assertTrue(param.grad is not None) + + # Test the model + with torch.no_grad(): + gp_model.eval() + likelihood.eval() + test_preds = likelihood(gp_model(test_x)).mean + mean_abs_error = torch.mean(torch.abs(test_y - test_preds)) + + self.assertLess(mean_abs_error.squeeze().item(), 0.02) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/kernels/keops/test_matern_kernel.py b/test/kernels/keops/test_matern_kernel.py index 46a922289..e68d1c26b 100644 --- a/test/kernels/keops/test_matern_kernel.py +++ b/test/kernels/keops/test_matern_kernel.py @@ -2,10 +2,9 @@ import unittest -import torch - from gpytorch.kernels import MaternKernel as GMaternKernel from gpytorch.kernels.keops import MaternKernel +from gpytorch.test.base_keops_test_case import BaseKeOpsTestCase from gpytorch.test.base_kernel_test_case import BaseKernelTestCase try: @@ -18,67 +17,32 @@ def 
create_kernel_no_ard(self, **kwargs): def create_kernel_ard(self, num_dims, **kwargs): return MaternKernel(nu=2.5, ard_num_dims=num_dims, **kwargs) - class TestMaternKeOpsKernel(unittest.TestCase): - def forward_x1_eq_x2(self, nu): - if not torch.cuda.is_available(): - return - - x1 = torch.randn(100, 3).cuda() - - kern1 = MaternKernel(nu=nu).cuda() - kern2 = GMaternKernel(nu=nu).cuda() - - k1 = kern1(x1, x1).to_dense() - k2 = kern2(x1, x1).to_dense() - - self.assertLess(torch.norm(k1 - k2), 1e-4) - - def forward_x1_neq_x2(self, nu): - if not torch.cuda.is_available(): - return - - x1 = torch.randn(100, 3).cuda() - x2 = torch.randn(50, 3).cuda() + class TestMaternKeOpsKernel(unittest.TestCase, BaseKeOpsTestCase): + @property + def k1(self): + return MaternKernel - kern1 = MaternKernel(nu=nu).cuda() - kern2 = GMaternKernel(nu=nu).cuda() - - k1 = kern1(x1, x2).to_dense() - k2 = kern2(x1, x2).to_dense() - - self.assertLess(torch.norm(k1 - k2), 1e-4) + @property + def k2(self): + return GMaternKernel def test_forward_nu25_x1_eq_x2(self): - return self.forward_x1_eq_x2(nu=2.5) + return self.test_forward_x1_eq_x2(nu=2.5) def test_forward_nu25_x1_neq_x2(self): - return self.forward_x1_neq_x2(nu=2.5) + return self.test_forward_x1_neq_x2(nu=2.5) def test_forward_nu15_x1_eq_x2(self): - return self.forward_x1_eq_x2(nu=1.5) + return self.test_forward_x1_eq_x2(nu=1.5) def test_forward_nu15_x1_neq_x2(self): - return self.forward_x1_neq_x2(nu=1.5) + return self.test_forward_x1_neq_x2(nu=1.5) def test_forward_nu05_x1_eq_x2(self): - return self.forward_x1_eq_x2(nu=0.5) + return self.test_forward_x1_eq_x2(nu=0.5) def test_forward_nu05_x1_neq_x2(self): - return self.forward_x1_neq_x2(nu=0.5) - - def test_batch_matmul(self): - if not torch.cuda.is_available(): - return - - x1 = torch.randn(3, 2, 100, 3).cuda() - kern1 = MaternKernel(nu=2.5).cuda() - kern2 = GMaternKernel(nu=2.5).cuda() - - rhs = torch.randn(3, 2, 100, 1).cuda() - res1 = kern1(x1, x1).matmul(rhs) - res2 = kern2(x1, 
x1).matmul(rhs) - - self.assertLess(torch.norm(res1 - res2), 1e-4) + return self.test_forward_x1_neq_x2(nu=0.5) except ImportError: pass diff --git a/test/kernels/keops/test_periodic_kernel.py b/test/kernels/keops/test_periodic_kernel.py new file mode 100644 index 000000000..60e92673a --- /dev/null +++ b/test/kernels/keops/test_periodic_kernel.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 + +import unittest + +from gpytorch.kernels import PeriodicKernel as GPeriodicKernel +from gpytorch.kernels.keops import PeriodicKernel +from gpytorch.priors import NormalPrior +from gpytorch.test.base_keops_test_case import BaseKeOpsTestCase +from gpytorch.test.base_kernel_test_case import BaseKernelTestCase + +try: + import pykeops # noqa + + class TestPeriodicKeOpsBaseKernel(unittest.TestCase, BaseKernelTestCase): + def create_kernel_no_ard(self, **kwargs): + return PeriodicKernel(**kwargs) + + def create_kernel_ard(self, num_dims, **kwargs): + return PeriodicKernel(ard_num_dims=num_dims, **kwargs) + + class TestPeriodicKeOpsKernel(unittest.TestCase, BaseKeOpsTestCase): + @property + def k1(self): + return PeriodicKernel + + @property + def k2(self): + return GPeriodicKernel + + def create_kernel_with_prior(self, period_length_prior): + return self.k1(period_length_prior=period_length_prior) + + def test_prior_type(self): + """ + Raising TypeError if prior type is other than gpytorch.priors.Prior + """ + kernel_fn = lambda prior: self.create_kernel_with_prior(prior) + kernel_fn(None) + kernel_fn(NormalPrior(0, 1)) + self.assertRaises(TypeError, kernel_fn, 1) + +except ImportError: + pass + +if __name__ == "__main__": + unittest.main() diff --git a/test/kernels/keops/test_rbf_kernel.py b/test/kernels/keops/test_rbf_kernel.py index d28943cc3..7e46e7d11 100644 --- a/test/kernels/keops/test_rbf_kernel.py +++ b/test/kernels/keops/test_rbf_kernel.py @@ -2,10 +2,9 @@ import unittest -import torch - from gpytorch.kernels import RBFKernel as GRBFKernel from gpytorch.kernels.keops import 
RBFKernel +from gpytorch.test.base_keops_test_case import BaseKeOpsTestCase from gpytorch.test.base_kernel_test_case import BaseKernelTestCase try: @@ -18,49 +17,14 @@ def create_kernel_no_ard(self, **kwargs): def create_kernel_ard(self, num_dims, **kwargs): return RBFKernel(ard_num_dims=num_dims, **kwargs) - class TestRBFKeOpsKernel(unittest.TestCase): - def test_forward_x1_eq_x2(self): - if not torch.cuda.is_available(): - return - - x1 = torch.randn(100, 3).cuda() - - kern1 = RBFKernel().cuda() - kern2 = GRBFKernel().cuda() - - k1 = kern1(x1, x1).to_dense() - k2 = kern2(x1, x1).to_dense() - - self.assertLess(torch.norm(k1 - k2), 1e-4) - - def test_forward_x1_neq_x2(self): - if not torch.cuda.is_available(): - return - - x1 = torch.randn(100, 3).cuda() - x2 = torch.randn(50, 3).cuda() - - kern1 = RBFKernel().cuda() - kern2 = GRBFKernel().cuda() - - k1 = kern1(x1, x2).to_dense() - k2 = kern2(x1, x2).to_dense() - - self.assertLess(torch.norm(k1 - k2), 1e-4) - - def test_batch_matmul(self): - if not torch.cuda.is_available(): - return - - x1 = torch.randn(3, 2, 100, 3).cuda() - kern1 = RBFKernel().cuda() - kern2 = GRBFKernel().cuda() - - rhs = torch.randn(3, 2, 100, 1).cuda() - res1 = kern1(x1, x1).matmul(rhs) - res2 = kern2(x1, x1).matmul(rhs) + class TestRBFKeOpsKernel(unittest.TestCase, BaseKeOpsTestCase): + @property + def k1(self): + return RBFKernel - self.assertLess(torch.norm(res1 - res2), 1e-4) + @property + def k2(self): + return GRBFKernel except ImportError: pass