
Commit 25da2cc

Merge pull request #2512 from m-julian/matern52_grad
Matern52 grad
2 parents 2e7959d + f8b9a5e commit 25da2cc

3 files changed (+232 −0)

gpytorch/kernels/__init__.py (+2)

@@ -15,6 +15,7 @@
 from .kernel import AdditiveKernel, Kernel, ProductKernel
 from .lcm_kernel import LCMKernel
 from .linear_kernel import LinearKernel
+from .matern52_kernel_grad import Matern52KernelGrad
 from .matern_kernel import MaternKernel
 from .multi_device_kernel import MultiDeviceKernel
 from .multitask_kernel import MultitaskKernel
@@ -69,4 +70,5 @@
     "ScaleKernel",
     "SpectralDeltaKernel",
     "SpectralMixtureKernel",
+    "Matern52KernelGrad",
 ]
gpytorch/kernels/matern52_kernel_grad.py (new file, +152)

@@ -0,0 +1,152 @@
#!/usr/bin/env python3

import math

import torch
from linear_operator.operators import KroneckerProductLinearOperator

from gpytorch.kernels.matern_kernel import MaternKernel

sqrt5 = math.sqrt(5)
five_thirds = 5.0 / 3.0


class Matern52KernelGrad(MaternKernel):
    r"""
    Computes a covariance matrix of the Matern52 kernel that models the covariance
    between the values and partial derivatives for inputs :math:`\mathbf{x_1}`
    and :math:`\mathbf{x_2}`.

    See :class:`gpytorch.kernels.Kernel` for descriptions of the lengthscale options.

    .. note::

        This kernel does not have an `outputscale` parameter. To add a scaling parameter,
        decorate this kernel with a :class:`gpytorch.kernels.ScaleKernel`.

    :param ard_num_dims: Set this if you want a separate lengthscale for each input
        dimension. It should be `d` if x1 is a `n x d` matrix. (Default: `None`.)
    :param batch_shape: Set this if you want a separate lengthscale for each batch of input
        data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf x1` is
        a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor.
    :param active_dims: Set this if you want to compute the covariance of only
        a few input dimensions. The ints correspond to the indices of the
        dimensions. (Default: `None`.)
    :param lengthscale_prior: Set this if you want to apply a prior to the
        lengthscale parameter. (Default: `None`.)
    :param lengthscale_constraint: Set this if you want to apply a constraint
        to the lengthscale parameter. (Default: `Positive`.)
    :param eps: The minimum value that the lengthscale can take (prevents
        divide by zero errors). (Default: `1e-6`.)

    :ivar torch.Tensor lengthscale: The lengthscale parameter. Size/shape of parameter depends on the
        ard_num_dims and batch_shape arguments.

    Example:
        >>> x = torch.randn(10, 5)
        >>> # Non-batch: Simple option
        >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.Matern52KernelGrad())
        >>> covar = covar_module(x)  # Output: LinearOperator of size (60 x 60), where 60 = n * (d + 1)
        >>>
        >>> batch_x = torch.randn(2, 10, 5)
        >>> # Batch: Simple option
        >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.Matern52KernelGrad())
        >>> # Batch: different lengthscale for each batch
        >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.Matern52KernelGrad(batch_shape=torch.Size([2])))  # noqa: E501
        >>> covar = covar_module(batch_x)  # Output: LinearOperator of size (2 x 60 x 60)
    """

    def __init__(self, **kwargs):
        # remove nu in case it was set
        kwargs.pop("nu", None)
        super(Matern52KernelGrad, self).__init__(nu=2.5, **kwargs)

    def forward(self, x1, x2, diag=False, **params):
        lengthscale = self.lengthscale

        batch_shape = x1.shape[:-2]
        n_batch_dims = len(batch_shape)
        n1, d = x1.shape[-2:]
        n2 = x2.shape[-2]

        if not diag:
            K = torch.zeros(*batch_shape, n1 * (d + 1), n2 * (d + 1), device=x1.device, dtype=x1.dtype)

            distance_matrix = self.covar_dist(x1.div(lengthscale), x2.div(lengthscale), diag=diag, **params)
            exp_neg_sqrt5r = torch.exp(-sqrt5 * distance_matrix)

            # differences matrix in each dimension to be used for derivatives
            # shape of n1 x n2 x d
            outer = x1.view(*batch_shape, n1, 1, d) - x2.view(*batch_shape, 1, n2, d)
            outer = outer / lengthscale.unsqueeze(-2) ** 2
            # shape of n1 x d x n2
            outer = torch.transpose(outer, -1, -2).contiguous()

            # 1) Kernel block, cov(f^m, f^n)
            # shape is n1 x n2
            exp_component = torch.exp(-sqrt5 * distance_matrix)
            constant_component = (sqrt5 * distance_matrix).add(1).add(five_thirds * distance_matrix**2)

            K[..., :n1, :n2] = constant_component * exp_component

            # 2) First gradient block, cov(f^m, omega^n_d)
            outer1 = outer.view(*batch_shape, n1, n2 * d)
            K[..., :n1, n2:] = outer1 * (-five_thirds * (1 + sqrt5 * distance_matrix) * exp_neg_sqrt5r).repeat(
                [*([1] * (n_batch_dims + 1)), d]
            )

            # 3) Second gradient block, cov(omega^m_d, f^n)
            outer2 = outer.transpose(-1, -3).reshape(*batch_shape, n2, n1 * d)
            outer2 = outer2.transpose(-1, -2)
            # the - signs on -outer2 and -five_thirds cancel out
            K[..., n1:, :n2] = outer2 * (five_thirds * (1 + sqrt5 * distance_matrix) * exp_neg_sqrt5r).repeat(
                [*([1] * n_batch_dims), d, 1]
            )

            # 4) Hessian block, cov(omega^m_d, omega^n_d)
            outer3 = outer1.repeat([*([1] * n_batch_dims), d, 1]) * outer2.repeat([*([1] * (n_batch_dims + 1)), d])
            kp = KroneckerProductLinearOperator(
                torch.eye(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / lengthscale**2,
                torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
            )

            part1 = -five_thirds * exp_neg_sqrt5r
            part2 = 5 * outer3
            part3 = 1 + sqrt5 * distance_matrix

            K[..., n1:, n2:] = part1.repeat([*([1] * n_batch_dims), d, d]).mul_(
                # need to use kp.to_dense().mul instead of kp.to_dense().mul_
                # because otherwise a RuntimeError is raised due to how autograd works with
                # view + inplace operations in the case of 1-dimensional input
                part2.sub_(kp.to_dense().mul(part3.repeat([*([1] * n_batch_dims), d, d])))
            )

            # Symmetrize for stability
            if n1 == n2 and torch.eq(x1, x2).all():
                K = 0.5 * (K.transpose(-1, -2) + K)

            # Apply a perfect shuffle permutation to match the MultiTask ordering
            pi1 = torch.arange(n1 * (d + 1)).view(d + 1, n1).t().reshape((n1 * (d + 1)))
            pi2 = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1)))
            K = K[..., pi1, :][..., :, pi2]

            return K
        else:
            if not (n1 == n2 and torch.eq(x1, x2).all()):
                raise RuntimeError("diag=True only works when x1 == x2")

            # nu is set to 2.5
            kernel_diag = super(Matern52KernelGrad, self).forward(x1, x2, diag=True)
            grad_diag = (
                five_thirds * torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype)
            ) / lengthscale**2
            grad_diag = grad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d)
            k_diag = torch.cat((kernel_diag, grad_diag), dim=-1)
            pi = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1)))
            return k_diag[..., pi]

    def num_outputs_per_input(self, x1, x2):
        return x1.size(-1) + 1
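For orientation, a sketch of the algebra behind the block comments in forward, written with the same sqrt5 and five_thirds constants as the code: the base Matern 5/2 kernel on the scaled distance r = ||(x1 - x2) / lengthscale|| (the code's distance_matrix) and its radial derivative are

    k(r) = \left(1 + \sqrt{5}\,r + \tfrac{5}{3}\,r^{2}\right) e^{-\sqrt{5}\,r},
    \qquad
    \frac{\mathrm{d}k}{\mathrm{d}r} = -\tfrac{5}{3}\,r\,\left(1 + \sqrt{5}\,r\right) e^{-\sqrt{5}\,r}.

The gradient blocks 2) and 3) multiply the common radial factor five_thirds * (1 + sqrt5 * r) * exp(-sqrt5 * r) (dk/dr with the factor of r cancelled by the chain rule, up to sign) by the scaled differences (x1 - x2) / lengthscale**2 stored in outer; the Hessian block 4) differentiates once more, which yields the 5 * outer3 term and the Kronecker-structured identity-over-lengthscale**2 correction scaled by (1 + sqrt5 * r).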
test/kernels/test_matern52_kernel_grad.py (new file, +78)

@@ -0,0 +1,78 @@
#!/usr/bin/env python3

import unittest

import torch

from gpytorch.kernels import Matern52KernelGrad
from gpytorch.test.base_kernel_test_case import BaseKernelTestCase


class TestMatern52KernelGrad(unittest.TestCase, BaseKernelTestCase):
    def create_kernel_no_ard(self, **kwargs):
        return Matern52KernelGrad(**kwargs)

    def create_kernel_ard(self, num_dims, **kwargs):
        return Matern52KernelGrad(ard_num_dims=num_dims, **kwargs)

    def test_kernel(self, cuda=False):
        a = torch.tensor([[[1, 2], [2, 4]]], dtype=torch.float)
        b = torch.tensor([[[1, 3], [0, 4]]], dtype=torch.float)

        actual = torch.tensor(
            [
                [0.3056225, -0.0000000, 0.5822443, 0.0188260, -0.0209871, 0.0419742],
                [0.0000000, 0.5822443, 0.0000000, 0.0209871, -0.0056045, 0.0531832],
                [-0.5822443, 0.0000000, -0.8515886, -0.0419742, 0.0531832, -0.0853792],
                [0.1304891, -0.2014212, -0.2014212, 0.0336440, -0.0815567, -0.0000000],
                [0.2014212, -0.1754366, -0.3768578, 0.0815567, -0.1870145, -0.0000000],
                [0.2014212, -0.3768578, -0.1754366, 0.0000000, -0.0000000, 0.0407784],
            ]
        )

        kernel = Matern52KernelGrad()

        if cuda:
            a = a.cuda()
            b = b.cuda()
            actual = actual.cuda()
            kernel = kernel.cuda()

        res = kernel(a, b).to_dense()

        self.assertLess(torch.norm(res - actual), 1e-5)

    def test_kernel_cuda(self):
        if torch.cuda.is_available():
            self.test_kernel(cuda=True)

    def test_kernel_batch(self):
        a = torch.tensor([[[1, 2, 3], [2, 4, 0]], [[-1, 1, 2], [2, 1, 4]]], dtype=torch.float)
        b = torch.tensor([[[1, 3, 1]], [[2, -1, 0]]], dtype=torch.float).repeat(1, 2, 1)

        kernel = Matern52KernelGrad()
        res = kernel(a, b).to_dense()

        # Compute each batch separately
        actual = torch.zeros(2, 8, 8)
        actual[0, :, :] = kernel(a[0, :, :].squeeze(), b[0, :, :].squeeze()).to_dense()
        actual[1, :, :] = kernel(a[1, :, :].squeeze(), b[1, :, :].squeeze()).to_dense()

        self.assertLess(torch.norm(res - actual), 1e-5)

    def test_initialize_lengthscale(self):
        kernel = Matern52KernelGrad()
        kernel.initialize(lengthscale=3.14)
        actual_value = torch.tensor(3.14).view_as(kernel.lengthscale)
        self.assertLess(torch.norm(kernel.lengthscale - actual_value), 1e-5)

    def test_initialize_lengthscale_batch(self):
        kernel = Matern52KernelGrad(batch_shape=torch.Size([2]))
        ls_init = torch.tensor([3.14, 4.13])
        kernel.initialize(lengthscale=ls_init)
        actual_value = ls_init.view_as(kernel.lengthscale)
        self.assertLess(torch.norm(kernel.lengthscale - actual_value), 1e-5)


if __name__ == "__main__":
    unittest.main()
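To show where the new kernel slots in, here is a minimal usage sketch in the style of GPyTorch's existing derivative-GP models. The GPWithDerivatives class and the toy training data are illustrative only (not part of this commit); ConstantMeanGrad, MultitaskMultivariateNormal and MultitaskGaussianLikelihood are existing GPyTorch APIs.

import torch
import gpytorch
from gpytorch.kernels import Matern52KernelGrad


class GPWithDerivatives(gpytorch.models.ExactGP):
    # Illustrative model: each target row is [f(x), df/dx_1, ..., df/dx_d],
    # matching the kernel's num_outputs_per_input = d + 1.
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        d = train_x.shape[-1]
        self.mean_module = gpytorch.means.ConstantMeanGrad()
        self.covar_module = gpytorch.kernels.ScaleKernel(Matern52KernelGrad(ard_num_dims=d))

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)


train_x = torch.rand(10, 2)
# Toy data: f(x) = sin(x_1) + sin(x_2), so df/dx_i = cos(x_i)
train_y = torch.stack(
    [train_x.sin().sum(-1), train_x[:, 0].cos(), train_x[:, 1].cos()], dim=-1
)
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=3)  # d + 1 = 3
model = GPWithDerivatives(train_x, train_y, likelihood)

As in the docstring example, the covariance has one value output plus d derivative outputs per input, so the likelihood uses num_tasks = d + 1 and the targets stack the function value and its partial derivatives along the last dimension.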
