diff --git a/gpytorch/kernels/__init__.py b/gpytorch/kernels/__init__.py
index 1d87e764b..55119b784 100644
--- a/gpytorch/kernels/__init__.py
+++ b/gpytorch/kernels/__init__.py
@@ -15,6 +15,7 @@
 from .kernel import AdditiveKernel, Kernel, ProductKernel
 from .lcm_kernel import LCMKernel
 from .linear_kernel import LinearKernel
+from .matern52_kernel_grad import Matern52KernelGrad
 from .matern_kernel import MaternKernel
 from .multi_device_kernel import MultiDeviceKernel
 from .multitask_kernel import MultitaskKernel
@@ -69,4 +70,5 @@
     "ScaleKernel",
     "SpectralDeltaKernel",
     "SpectralMixtureKernel",
+    "Matern52KernelGrad",
 ]
diff --git a/gpytorch/kernels/matern52_kernel_grad.py b/gpytorch/kernels/matern52_kernel_grad.py
new file mode 100644
index 000000000..04aa95c2f
--- /dev/null
+++ b/gpytorch/kernels/matern52_kernel_grad.py
@@ -0,0 +1,152 @@
+#!/usr/bin/env python3
+
+import math
+
+import torch
+from linear_operator.operators import KroneckerProductLinearOperator
+
+from gpytorch.kernels.matern_kernel import MaternKernel
+
+sqrt5 = math.sqrt(5)
+five_thirds = 5.0 / 3.0
+
+
+class Matern52KernelGrad(MaternKernel):
+    r"""
+    Computes a covariance matrix of the Matern52 kernel that models the covariance
+    between the values and partial derivatives for inputs :math:`\mathbf{x_1}`
+    and :math:`\mathbf{x_2}`.
+
+    See :class:`gpytorch.kernels.Kernel` for descriptions of the lengthscale options.
+
+    .. note::
+
+        This kernel does not have an `outputscale` parameter. To add a scaling parameter,
+        decorate this kernel with a :class:`gpytorch.kernels.ScaleKernel`.
+
+    :param ard_num_dims: Set this if you want a separate lengthscale for each input
+        dimension. It should be `d` if x1 is an `n x d` matrix. (Default: `None`.)
+    :param batch_shape: Set this if you want a separate lengthscale for each batch of input
+        data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf x1` is
+        a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor.
+    :param active_dims: Set this if you want to compute the covariance of only
+        a few input dimensions. The ints correspond to the indices of the
+        dimensions. (Default: `None`.)
+    :param lengthscale_prior: Set this if you want to apply a prior to the
+        lengthscale parameter. (Default: `None`.)
+    :param lengthscale_constraint: Set this if you want to apply a constraint
+        to the lengthscale parameter. (Default: `Positive`.)
+    :param eps: The minimum value that the lengthscale can take (prevents
+        divide by zero errors). (Default: `1e-6`.)
+
+    :ivar torch.Tensor lengthscale: The lengthscale parameter. Size/shape of the parameter depends on the
+        ard_num_dims and batch_shape arguments.
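+
+    With :math:`r = \| (\mathbf{x_1} - \mathbf{x_2}) / \ell \|` the lengthscale-scaled distance and
+    :math:`\Delta_j = x_{1,j} - x_{2,j}`, the blocks assembled in :py:meth:`~forward` take the
+    Matern-5/2 closed forms
+
+    .. math::
+        \begin{aligned}
+            k(r) &= \left(1 + \sqrt{5} r + \tfrac{5}{3} r^2\right) e^{-\sqrt{5} r}, \\
+            \operatorname{cov}\left(f, \partial_j f\right) &=
+                \pm \tfrac{5}{3} \left(1 + \sqrt{5} r\right) e^{-\sqrt{5} r} \frac{\Delta_j}{\ell^2}, \\
+            \operatorname{cov}\left(\partial_i f, \partial_j f\right) &= \tfrac{5}{3} e^{-\sqrt{5} r}
+                \left( \left(1 + \sqrt{5} r\right) \frac{\delta_{ij}}{\ell^2}
+                - 5 \frac{\Delta_i \Delta_j}{\ell^4} \right),
+        \end{aligned}
+
+    where the sign of the cross-covariance block depends on which argument is differentiated.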
+
+    Example:
+        >>> x = torch.randn(10, 5)
+        >>> # Non-batch: Simple option
+        >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.Matern52KernelGrad())
+        >>> covar = covar_module(x)  # Output: LinearOperator of size (60 x 60), where 60 = n * (d + 1)
+        >>>
+        >>> batch_x = torch.randn(2, 10, 5)
+        >>> # Batch: Simple option
+        >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.Matern52KernelGrad())
+        >>> # Batch: different lengthscale for each batch
+        >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.Matern52KernelGrad(batch_shape=torch.Size([2])))  # noqa: E501
+        >>> covar = covar_module(batch_x)  # Output: LinearOperator of size (2 x 60 x 60)
+    """
+
+    def __init__(self, **kwargs):
+
+        # remove nu in case it was set
+        kwargs.pop("nu", None)
+        super(Matern52KernelGrad, self).__init__(nu=2.5, **kwargs)
+
+    def forward(self, x1, x2, diag=False, **params):
+
+        lengthscale = self.lengthscale
+
+        batch_shape = x1.shape[:-2]
+        n_batch_dims = len(batch_shape)
+        n1, d = x1.shape[-2:]
+        n2 = x2.shape[-2]
+
+        if not diag:
+
+            K = torch.zeros(*batch_shape, n1 * (d + 1), n2 * (d + 1), device=x1.device, dtype=x1.dtype)
+
+            distance_matrix = self.covar_dist(x1.div(lengthscale), x2.div(lengthscale), diag=diag, **params)
+            exp_neg_sqrt5r = torch.exp(-sqrt5 * distance_matrix)
+
+            # differences matrix in each dimension to be used for derivatives
+            # shape of n1 x n2 x d
+            outer = x1.view(*batch_shape, n1, 1, d) - x2.view(*batch_shape, 1, n2, d)
+            outer = outer / lengthscale.unsqueeze(-2) ** 2
+            # shape of n1 x d x n2
+            outer = torch.transpose(outer, -1, -2).contiguous()
+
+            # 1) Kernel block, cov(f^m, f^n)
+            # shape is n1 x n2
+            constant_component = (sqrt5 * distance_matrix).add(1).add(five_thirds * distance_matrix**2)
+            K[..., :n1, :n2] = constant_component * exp_neg_sqrt5r
+
+            # 2) First gradient block, cov(f^m, omega^n_d)
+            outer1 = outer.view(*batch_shape, n1, n2 * d)
+            K[..., :n1, n2:] = outer1 * (-five_thirds * (1 + sqrt5 * distance_matrix) * exp_neg_sqrt5r).repeat(
+                [*([1] * (n_batch_dims + 1)), d]
+            )
+
+            # 3) Second gradient block, cov(omega^m_d, f^n)
+            outer2 = outer.transpose(-1, -3).reshape(*batch_shape, n2, n1 * d)
+            outer2 = outer2.transpose(-1, -2)
+            # the - signs on -outer2 and -five_thirds cancel out
+            K[..., n1:, :n2] = outer2 * (five_thirds * (1 + sqrt5 * distance_matrix) * exp_neg_sqrt5r).repeat(
+                [*([1] * n_batch_dims), d, 1]
+            )
+
+            # 4) Hessian block, cov(omega^m_d, omega^n_d)
+            outer3 = outer1.repeat([*([1] * n_batch_dims), d, 1]) * outer2.repeat([*([1] * (n_batch_dims + 1)), d])
+            kp = KroneckerProductLinearOperator(
+                torch.eye(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / lengthscale**2,
+                torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
+            )
+
+            part1 = -five_thirds * exp_neg_sqrt5r
+            part2 = 5 * outer3
+            part3 = 1 + sqrt5 * distance_matrix
+
+            K[..., n1:, n2:] = part1.repeat([*([1] * n_batch_dims), d, d]).mul_(
+                # need to use kp.to_dense().mul instead of kp.to_dense().mul_
+                # because otherwise a RuntimeError is raised due to how autograd works with
+                # view + inplace operations in the case of 1-dimensional input
+                part2.sub_(kp.to_dense().mul(part3.repeat([*([1] * n_batch_dims), d, d])))
+            )
+
+            # Symmetrize for stability
+            if n1 == n2 and torch.eq(x1, x2).all():
+                K = 0.5 * (K.transpose(-1, -2) + K)
+
+            # Apply a perfect shuffle permutation to match the MultiTask ordering
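+            # For example, with n1 = 2 and d = 1, torch.arange(4).view(2, 2).t().reshape(4)
+            # yields pi1 = [0, 2, 1, 3], so the row order [f^1, f^2, d/dx f^1, d/dx f^2]
+            # becomes [f^1, d/dx f^1, f^2, d/dx f^2]: the value and the d partial derivatives
+            # of each point end up adjacent, as the interleaved multitask layout expects.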
+            pi1 = torch.arange(n1 * (d + 1)).view(d + 1, n1).t().reshape((n1 * (d + 1)))
+            pi2 = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1)))
+            K = K[..., pi1, :][..., :, pi2]
+
+            return K
+        else:
+            if not (n1 == n2 and torch.eq(x1, x2).all()):
+                raise RuntimeError("diag=True only works when x1 == x2")
+
+            # nu is set to 2.5
+            kernel_diag = super(Matern52KernelGrad, self).forward(x1, x2, diag=True)
+            grad_diag = (
+                five_thirds * torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype)
+            ) / lengthscale**2
+            grad_diag = grad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d)
+            k_diag = torch.cat((kernel_diag, grad_diag), dim=-1)
+            pi = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1)))
+            return k_diag[..., pi]
+
+    def num_outputs_per_input(self, x1, x2):
+        return x1.size(-1) + 1
diff --git a/test/kernels/test_matern52_kernel_grad.py b/test/kernels/test_matern52_kernel_grad.py
new file mode 100644
index 000000000..5a76a2a33
--- /dev/null
+++ b/test/kernels/test_matern52_kernel_grad.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+
+import unittest
+
+import torch
+
+from gpytorch.kernels import Matern52KernelGrad
+from gpytorch.test.base_kernel_test_case import BaseKernelTestCase
+
+
+class TestMatern52KernelGrad(unittest.TestCase, BaseKernelTestCase):
+    def create_kernel_no_ard(self, **kwargs):
+        return Matern52KernelGrad(**kwargs)
+
+    def create_kernel_ard(self, num_dims, **kwargs):
+        return Matern52KernelGrad(ard_num_dims=num_dims, **kwargs)
+
+    def test_kernel(self, cuda=False):
+        a = torch.tensor([[[1, 2], [2, 4]]], dtype=torch.float)
+        b = torch.tensor([[[1, 3], [0, 4]]], dtype=torch.float)
+
+        actual = torch.tensor(
+            [
+                [0.3056225, -0.0000000, 0.5822443, 0.0188260, -0.0209871, 0.0419742],
+                [0.0000000, 0.5822443, 0.0000000, 0.0209871, -0.0056045, 0.0531832],
+                [-0.5822443, 0.0000000, -0.8515886, -0.0419742, 0.0531832, -0.0853792],
+                [0.1304891, -0.2014212, -0.2014212, 0.0336440, -0.0815567, -0.0000000],
+                [0.2014212, -0.1754366, -0.3768578, 0.0815567, -0.1870145, -0.0000000],
+                [0.2014212, -0.3768578, -0.1754366, 0.0000000, -0.0000000, 0.0407784],
+            ]
+        )
+
+        kernel = Matern52KernelGrad()
+
+        if cuda:
+            a = a.cuda()
+            b = b.cuda()
+            actual = actual.cuda()
+            kernel = kernel.cuda()
+
+        res = kernel(a, b).to_dense()
+
+        self.assertLess(torch.norm(res - actual), 1e-5)
+
+    def test_kernel_cuda(self):
+        if torch.cuda.is_available():
+            self.test_kernel(cuda=True)
+
+    def test_kernel_batch(self):
+        a = torch.tensor([[[1, 2, 3], [2, 4, 0]], [[-1, 1, 2], [2, 1, 4]]], dtype=torch.float)
+        b = torch.tensor([[[1, 3, 1]], [[2, -1, 0]]], dtype=torch.float).repeat(1, 2, 1)
+
+        kernel = Matern52KernelGrad()
+        res = kernel(a, b).to_dense()
+
+        # Compute each batch separately
+        actual = torch.zeros(2, 8, 8)
+        actual[0, :, :] = kernel(a[0, :, :].squeeze(), b[0, :, :].squeeze()).to_dense()
+        actual[1, :, :] = kernel(a[1, :, :].squeeze(), b[1, :, :].squeeze()).to_dense()
+
+        self.assertLess(torch.norm(res - actual), 1e-5)
+
+    def test_initialize_lengthscale(self):
+        kernel = Matern52KernelGrad()
+        kernel.initialize(lengthscale=3.14)
+        actual_value = torch.tensor(3.14).view_as(kernel.lengthscale)
+        self.assertLess(torch.norm(kernel.lengthscale - actual_value), 1e-5)
+
+    def test_initialize_lengthscale_batch(self):
+        kernel = Matern52KernelGrad(batch_shape=torch.Size([2]))
+        ls_init = torch.tensor([3.14, 4.13])
+        kernel.initialize(lengthscale=ls_init)
+        actual_value = ls_init.view_as(kernel.lengthscale)
+        self.assertLess(torch.norm(kernel.lengthscale - actual_value), 1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main()
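A minimal usage sketch (not part of this patch, following the pattern of GPyTorch's existing
derivative-kernel examples such as RBFKernelGrad): the kernel pairs with
gpytorch.means.ConstantMeanGrad and a MultitaskGaussianLikelihood with num_tasks = d + 1,
since every input contributes one value plus d partial derivatives. The model class below is
an assumed example, not something defined in this diff.

    import torch
    import gpytorch
    from gpytorch.kernels import Matern52KernelGrad


    class GPModelWithDerivatives(gpytorch.models.ExactGP):
        # Exact GP over (f, df/dx_1, ..., df/dx_d); mirrors the derivative-GP tutorial pattern
        def __init__(self, train_x, train_y, likelihood):
            super().__init__(train_x, train_y, likelihood)
            self.mean_module = gpytorch.means.ConstantMeanGrad()
            # ScaleKernel restores the outputscale that Matern52KernelGrad itself omits
            self.covar_module = gpytorch.kernels.ScaleKernel(Matern52KernelGrad())

        def forward(self, x):
            mean_x = self.mean_module(x)
            covar_x = self.covar_module(x)
            return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)


    train_x = torch.randn(10, 2)  # n = 10 points in d = 2 dimensions
    train_y = torch.randn(10, 3)  # columns: f(x), df/dx_1, df/dx_2
    likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=3)  # num_tasks = d + 1
    model = GPModelWithDerivatives(train_x, train_y, likelihood)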