Mean and kernel functions for first and second derivatives #2235

Merged
merged 20 commits into master from second-derivs on May 26, 2023
Changes from 3 commits
20 commits
99697e1
Added an RBF kernel with second (non-mixed) derivatives. Need to be t…
ankushaggarwal Jun 11, 2021
e3e97cb
Added constant mean with second derivative and linear means with firs…
ankushaggarwal Jun 12, 2021
cbf0bf0
Merge branch 'master' into second-derivs
ankushaggarwal Jan 3, 2023
c18e6d8
Hook fixes using pre-commit
ankushaggarwal Jan 22, 2023
6fa7f0d
Merge branch 'master' into second-derivs
ankushaggarwal Jan 22, 2023
c43c8c8
Added new kernel and means to the documentation
ankushaggarwal Jan 22, 2023
451dfd1
Hook fix
ankushaggarwal Jan 22, 2023
9bf9c3b
Changed from lazy tensor to linear operator as per the newer gpytorch…
ankushaggarwal Jan 22, 2023
0bdfc3e
Changed from utils.broadcasting._mul_broadcast_shape to torch.broadca…
ankushaggarwal Jan 22, 2023
f6a3ce2
Revert "Hook fix"
ankushaggarwal Jan 22, 2023
6e91183
Fixed a minor error (as per the new version of gpytorch)
ankushaggarwal Jan 22, 2023
d473faf
Added the diag=True version of rbf gradgrad kernel
ankushaggarwal Jan 22, 2023
1f7ab49
Fix formatting with pre-commit
ankushaggarwal Jan 23, 2023
b18ecc6
Increased the underline length to pass doc test
ankushaggarwal Jan 23, 2023
e246d2f
Getting rid of the changes to pre-commit-hooks
ankushaggarwal Apr 18, 2023
daeedab
Merge remote-tracking branch 'upstream/master' into second-derivs
ankushaggarwal Apr 18, 2023
b717bfc
Added docstrings and type hints
ankushaggarwal Apr 18, 2023
1623823
Fixed the wrong references in doc of ConstantMeanGradGrad
ankushaggarwal Apr 18, 2023
bd8a9cc
Added unit tests
ankushaggarwal Apr 21, 2023
94819cf
Merge branch 'master' into second-derivs
gpleiss May 26, 2023
2 changes: 2 additions & 0 deletions gpytorch/kernels/__init__.py
@@ -24,6 +24,7 @@
from .product_structure_kernel import ProductStructureKernel
from .rbf_kernel import RBFKernel
from .rbf_kernel_grad import RBFKernelGrad
from .rbf_kernel_gradgrad import RBFKernelGradGrad
from .rff_kernel import RFFKernel
from .rq_kernel import RQKernel
from .scale_kernel import ScaleKernel
@@ -59,6 +60,7 @@
"RBFKernel",
"RFFKernel",
"RBFKernelGrad",
"RBFKernelGradGrad",
"RQKernel",
"ScaleKernel",
"SpectralDeltaKernel",
157 changes: 157 additions & 0 deletions gpytorch/kernels/rbf_kernel_gradgrad.py
@@ -0,0 +1,157 @@
#!/usr/bin/env python3
import torch

from ..lazy.kronecker_product_lazy_tensor import KroneckerProductLazyTensor
from .rbf_kernel import RBFKernel, postprocess_rbf


class RBFKernelGradGrad(RBFKernel):
r"""
Computes a covariance matrix of the RBF kernel that models the covariance
between the values and first and second (non-mixed) partial derivatives for inputs :math:`\mathbf{x_1}`
and :math:`\mathbf{x_2}`.

See :class:`gpytorch.kernels.Kernel` for descriptions of the lengthscale options.

.. note::

This kernel does not have an `outputscale` parameter. To add a scaling parameter,
decorate this kernel with a :class:`gpytorch.kernels.ScaleKernel`.

Args:
Review comment (Member): Can you update the doc string to be in the standard sphinx format? (See the Kernel base class for an example)

Reply (Contributor Author): Updated

:attr:`batch_shape` (torch.Size, optional):
Set this if you want a separate lengthscale for each
batch of input data. It should be `b` if :attr:`x1` is a `b x n x d` tensor. Default: `torch.Size([])`.
:attr:`active_dims` (tuple of ints, optional):
Set this if you want to compute the covariance of only a few input dimensions. The ints
corresponds to the indices of the dimensions. Default: `None`.
:attr:`lengthscale_prior` (Prior, optional):
Set this if you want to apply a prior to the lengthscale parameter. Default: `None`.
:attr:`lengthscale_constraint` (Constraint, optional):
Set this if you want to apply a constraint to the lengthscale parameter. Default: `Positive`.
:attr:`eps` (float):
The minimum value that the lengthscale can take (prevents divide by zero errors). Default: `1e-6`.

Attributes:
:attr:`lengthscale` (Tensor):
The lengthscale parameter. Size/shape of parameter depends on the
:attr:`ard_num_dims` and :attr:`batch_shape` arguments.

Example:
>>> x = torch.randn(10, 5)
>>> # Non-batch: Simple option
>>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad())
>>> covar = covar_module(x) # Output: LazyTensor of size (110 x 110), where 110 = n * (2*d + 1)
>>>
>>> batch_x = torch.randn(2, 10, 5)
>>> # Batch: Simple option
>>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad())
>>> # Batch: different lengthscale for each batch
>>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad(batch_shape=torch.Size([2])))
>>> covar = covar_module(batch_x) # Output: LazyTensor of size (2 x 110 x 110)
"""

def forward(self, x1, x2, diag=False, **params):
batch_shape = x1.shape[:-2]
n_batch_dims = len(batch_shape)
n1, d = x1.shape[-2:]
n2 = x2.shape[-2]

K = torch.zeros(*batch_shape, n1 * (2*d + 1), n2 * (2*d + 1), device=x1.device, dtype=x1.dtype)
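# Layout: K is assembled in (2d+1) x (2d+1) blocks of size (n1 x n2), ordered
# block-wise as [values | d first derivatives | d second derivatives] along each
# axis, and then permuted at the end into the interleaved per-point ordering
# expected by multitask models.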

if not diag:
# Scale the inputs by the lengthscale (for stability)
x1_ = x1.div(self.lengthscale)
x2_ = x2.div(self.lengthscale)

# Form all possible rank-1 products for the gradient and Hessian blocks
outer = x1_.view(*batch_shape, n1, 1, d) - x2_.view(*batch_shape, 1, n2, d)
outer = outer / self.lengthscale.unsqueeze(-2)
outer = torch.transpose(outer, -1, -2).contiguous()

# 1) Kernel block
K_11 = self.covar_dist(x1_, x2_, square_dist=True, dist_postprocess_func=postprocess_rbf, **params)
K[..., :n1, :n2] = K_11

# 2) First gradient block
outer1 = outer.view(*batch_shape, n1, n2 * d)
K[..., :n1, n2:(n2*(d+1))] = outer1 * K_11.repeat([*([1] * (n_batch_dims + 1)), d])

# 3) Second gradient block
outer2 = outer.transpose(-1, -3).reshape(*batch_shape, n2, n1 * d)
outer2 = outer2.transpose(-1, -2)
K[..., n1:(n1*(d+1)), :n2] = -outer2 * K_11.repeat([*([1] * n_batch_dims), d, 1])

# 4) Hessian block
outer3 = outer1.repeat([*([1] * n_batch_dims), d, 1]) * outer2.repeat([*([1] * (n_batch_dims + 1)), d])
kp = KroneckerProductLazyTensor(
torch.eye(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
)
chain_rule = kp.evaluate() - outer3
K[..., n1:(n1*(d+1)), n2:(n2*(d+1))] = chain_rule * K_11.repeat([*([1] * n_batch_dims), d, d])

# 5) 1-3 block
douter1dx2 = KroneckerProductLazyTensor(
torch.ones(1, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
).evaluate()

K_13 = (-douter1dx2 + outer1 * outer1) * K_11.repeat([*([1] * (n_batch_dims + 1)), d])  # verified for n1=n2=1 case
K[..., :n1, (n2*(d+1)):] = K_13

K_31 = (-douter1dx2.transpose(-1, -2) + outer2 * outer2) * K_11.repeat([*([1] * n_batch_dims), d, 1])  # verified for n1=n2=1 case
K[..., (n1*(d+1)):, :n2] = K_31

# rest of the blocks are all of size (n1*d,n2*d)
outer1 = outer1.repeat([*([1] * n_batch_dims), d, 1])
outer2 = outer2.repeat([*([1] * (n_batch_dims + 1)), d])
#II = (torch.eye(d,d,device=x1.device,dtype=x1.dtype)/lengthscale.pow(2)).repeat(*batch_shape,n1,n2)
kp2 = KroneckerProductLazyTensor(
torch.ones(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
).evaluate()

# NOTE: it was unclear whether the tiled-identity II (commented out above) was correct, so kp is used here instead.
II = kp.evaluate()
K_11dd = K_11.repeat([*([1] * (n_batch_dims)), d, d])

K_23 = ((-kp2 + outer1 * outer1) * (-outer2) + 2.0 * II * outer1) * K_11dd  # verified for n1=n2=1 case

K[...,n1:(n1*(d+1)),(n2*(d+1)):] = K_23

K_32 = ((-kp2.transpose(-1, -2) + outer2 * outer2) * outer1 - 2.0 * II * outer2) * K_11dd  # verified for n1=n2=1 case

K[...,(n1*(d+1)):,n2:(n2*(d+1))] = K_32

K_33 = ((-kp2.transpose(-1, -2) + outer2 * outer2) * (-kp2) - 2.0 * II * outer2 * outer1 + 2.0 * II**2) * K_11dd + ((-kp2.transpose(-1, -2) + outer2 * outer2) * outer1 - 2.0 * II * outer2) * outer1 * K_11dd  # verified for n1=n2=1 case

K[...,(n1*(d+1)):,(n2*(d+1)):] = K_33

# Symmetrize for stability
if n1 == n2 and torch.eq(x1, x2).all():
K = 0.5 * (K.transpose(-1, -2) + K)

# Apply a perfect shuffle permutation to match the MultiTask ordering
pi1 = torch.arange(n1 * (2*d + 1)).view(2*d + 1, n1).t().reshape((n1 * (2*d + 1)))
pi2 = torch.arange(n2 * (2*d + 1)).view(2*d + 1, n2).t().reshape((n2 * (2*d + 1)))
K = K[..., pi1, :][..., :, pi2]

return K

else:
raise RuntimeError("diag=True not implemented yet")
'''This has not been updated from RBFKernelGrad yet
if not (n1 == n2 and torch.eq(x1, x2).all()):
raise RuntimeError("diag=True only works when x1 == x2")

kernel_diag = super(RBFKernelGrad, self).forward(x1, x2, diag=True)
grad_diag = torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype) / self.lengthscale.pow(2)
grad_diag = grad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d)
k_diag = torch.cat((kernel_diag, grad_diag), dim=-1)
pi = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1)))
return k_diag[..., pi]
'''

def num_outputs_per_input(self, x1, x2):
return x1.size(-1)*2 + 1
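A minimal sanity check of the new kernel could look like the sketch below (assuming this PR branch is installed; forward is called directly to obtain a dense tensor, and after the perfect-shuffle permutation the value entries sit at every (2d+1)-th row and column, so they should match a plain RBFKernel with the same lengthscale):

import torch
from gpytorch.kernels import RBFKernel, RBFKernelGradGrad

n, d = 10, 3
x = torch.randn(n, d)

base = RBFKernel()
kdd = RBFKernelGradGrad()
kdd.lengthscale = base.lengthscale  # use the same lengthscale in both kernels

K = kdd.forward(x, x)  # dense tensor of size n*(2d+1) x n*(2d+1)
assert K.shape == (n * (2 * d + 1), n * (2 * d + 1))
assert kdd.num_outputs_per_input(x, x) == 2 * d + 1

# Value-value entries (every (2d+1)-th row/column) should equal the plain RBF kernel
assert torch.allclose(K[::(2 * d + 1), ::(2 * d + 1)], base.forward(x, x), atol=1e-5)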
5 changes: 4 additions & 1 deletion gpytorch/means/__init__.py
@@ -2,9 +2,12 @@

from .constant_mean import ConstantMean
from .constant_mean_grad import ConstantMeanGrad
from .constant_mean_gradgrad import ConstantMeanGradGrad
from .linear_mean import LinearMean
from .linear_mean_grad import LinearMeanGrad
from .linear_mean_gradgrad import LinearMeanGradGrad
from .mean import Mean
from .multitask_mean import MultitaskMean
from .zero_mean import ZeroMean

__all__ = ["Mean", "ConstantMean", "ConstantMeanGrad", "LinearMean", "MultitaskMean", "ZeroMean"]
__all__ = ["Mean", "ConstantMean", "ConstantMeanGrad", "ConstantMeanGradGrad", "LinearMean", "LinearMeanGrad", "LinearMeanGradGrad", "MultitaskMean", "ZeroMean"]
21 changes: 21 additions & 0 deletions gpytorch/means/constant_mean_gradgrad.py
@@ -0,0 +1,21 @@
#!/usr/bin/env python3

import torch

from ..utils.broadcasting import _mul_broadcast_shape
from .mean import Mean


class ConstantMeanGradGrad(Mean):
def __init__(self, prior=None, batch_shape=torch.Size(), **kwargs):
Review comment (Member): Please add a doc string and type hints.

Reply (Contributor Author): Added

super(ConstantMeanGradGrad, self).__init__()
self.batch_shape = batch_shape
self.register_parameter(name="constant", parameter=torch.nn.Parameter(torch.zeros(*batch_shape, 1)))
if prior is not None:
self.register_prior("mean_prior", prior, "constant")

def forward(self, input):
batch_shape = _mul_broadcast_shape(self.batch_shape, input.shape[:-2])
mean = self.constant.unsqueeze(-1).expand(*batch_shape, input.size(-2), 2*input.size(-1) + 1).contiguous()
mean[..., 1:] = 0
return mean
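As a quick sketch of what this mean produces (assuming this PR branch), each input point gets the learned constant for the value component and zeros for all 2*d derivative components:

import torch
from gpytorch.means import ConstantMeanGradGrad

mean_module = ConstantMeanGradGrad()
x = torch.randn(10, 3)
m = mean_module(x)
print(m.shape)                     # torch.Size([10, 7]), i.e. n x (2*d + 1)
print(torch.all(m[..., 1:] == 0))  # tensor(True): derivative components are zero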
23 changes: 23 additions & 0 deletions gpytorch/means/linear_mean_grad.py
@@ -0,0 +1,23 @@
#!/usr/bin/env python3

import torch

from .mean import Mean


class LinearMeanGrad(Mean):
def __init__(self, input_size, batch_shape=torch.Size(), bias=True):
Review comment (Member): Doc string and type hints.

Reply (Contributor Author): Added

super().__init__()
self.dim = input_size
self.register_parameter(name="weights", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size, 1)))
if bias:
self.register_parameter(name="bias", parameter=torch.nn.Parameter(torch.randn(*batch_shape, 1)))
else:
self.bias = None

def forward(self, x):
res = x.matmul(self.weights)
if self.bias is not None:
res = res + self.bias.unsqueeze(-1)
dres = self.weights.expand(x.transpose(-1, -2).shape).transpose(-1, -2)
return torch.cat((res, dres), -1)
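As a small sketch (assuming this PR branch), the output stacks the linear mean value w^T x + b with its gradient, which is just the weight vector repeated for every input point:

import torch
from gpytorch.means import LinearMeanGrad

mean_module = LinearMeanGrad(input_size=3)
x = torch.randn(10, 3)
m = mean_module(x)
print(m.shape)  # torch.Size([10, 4]): [w^T x + b, dm/dx_1, dm/dx_2, dm/dx_3]
print(torch.allclose(m[:, 1:], mean_module.weights.squeeze(-1).expand(10, 3)))  # True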
24 changes: 24 additions & 0 deletions gpytorch/means/linear_mean_gradgrad.py
@@ -0,0 +1,24 @@
#!/usr/bin/env python3

import torch

from .mean import Mean


class LinearMeanGradGrad(Mean):
def __init__(self, input_size, batch_shape=torch.Size(), bias=True):
Review comment (Member): Doc string and type hints.

Reply (Contributor Author): Added

super().__init__()
self.dim = input_size
self.register_parameter(name="weights", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size, 1)))
if bias:
self.register_parameter(name="bias", parameter=torch.nn.Parameter(torch.randn(*batch_shape, 1)))
else:
self.bias = None

def forward(self, x):
res = x.matmul(self.weights)
if self.bias is not None:
res = res + self.bias.unsqueeze(-1)
dres = self.weights.expand(x.transpose(-1, -2).shape).transpose(-1, -2)
ddres = torch.zeros_like(dres)
return torch.cat((res, dres, ddres), -1)
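Putting the new pieces together, a model over function values plus first and second (non-mixed) derivatives could be sketched along the lines of the existing GPs-with-derivatives examples. The class name below is illustrative (not part of this PR), and the training targets are assumed to be ordered per point as [f, df/dx_1, ..., df/dx_d, d2f/dx_1^2, ..., d2f/dx_d^2]:

import torch
import gpytorch
from gpytorch.kernels import RBFKernelGradGrad
from gpytorch.means import LinearMeanGradGrad


class GPWithSecondDerivatives(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        d = train_x.size(-1)
        self.mean_module = LinearMeanGradGrad(input_size=d)
        self.covar_module = gpytorch.kernels.ScaleKernel(RBFKernelGradGrad())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)


d = 2
train_x = torch.rand(20, d)
train_y = torch.randn(20, 2 * d + 1)  # placeholder targets: value + d first + d second derivatives
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=2 * d + 1)
model = GPWithSecondDerivatives(train_x, train_y, likelihood)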