
Mean and kernel functions for first and second derivatives #2235


Merged

20 commits merged into master from the second-derivs branch on May 26, 2023.

Commits
99697e1  Added an RBF kernel with second (non-mixed) derivatives. Need to be t… (ankushaggarwal, Jun 11, 2021)
e3e97cb  Added constant mean with second derivative and linear means with firs… (ankushaggarwal, Jun 12, 2021)
cbf0bf0  Merge branch 'master' into second-derivs (ankushaggarwal, Jan 3, 2023)
c18e6d8  Hook fixes using pre-commit (ankushaggarwal, Jan 22, 2023)
6fa7f0d  Merge branch 'master' into second-derivs (ankushaggarwal, Jan 22, 2023)
c43c8c8  Added new kernel and means to the documentation (ankushaggarwal, Jan 22, 2023)
451dfd1  Hook fix (ankushaggarwal, Jan 22, 2023)
9bf9c3b  Changed from lazy tensor to linear operator as per the newer gpytorch… (ankushaggarwal, Jan 22, 2023)
0bdfc3e  Changed from utils.broadcasting._mul_broadcast_shape to torch.broadca… (ankushaggarwal, Jan 22, 2023)
f6a3ce2  Revert "Hook fix" (ankushaggarwal, Jan 22, 2023)
6e91183  Fixed a minor error (as per the new version of gpytorch) (ankushaggarwal, Jan 22, 2023)
d473faf  Added the diag=True version of rbf gradgrad kernel (ankushaggarwal, Jan 22, 2023)
1f7ab49  Fix formatting with pre-commit (ankushaggarwal, Jan 23, 2023)
b18ecc6  Increased the underline length to pass doc test (ankushaggarwal, Jan 23, 2023)
e246d2f  Getting rid of the changes to pre-commit-hooks (ankushaggarwal, Apr 18, 2023)
daeedab  Merge remote-tracking branch 'upstream/master' into second-derivs (ankushaggarwal, Apr 18, 2023)
b717bfc  Added docstrings and type hints (ankushaggarwal, Apr 18, 2023)
1623823  Fixed the wrong references in doc of ConstantMeanGradGrad (ankushaggarwal, Apr 18, 2023)
bd8a9cc  Added unit tests (ankushaggarwal, Apr 21, 2023)
94819cf  Merge branch 'master' into second-derivs (gpleiss, May 26, 2023)
6 changes: 6 additions & 0 deletions docs/source/kernels.rst
@@ -176,6 +176,12 @@ Specialty Kernels
.. autoclass:: RBFKernelGrad
    :members:

:hidden:`RBFKernelGradGrad`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: RBFKernelGradGrad
    :members:


Kernels for Scalable GP Regression Methods
--------------------------------------------
18 changes: 18 additions & 0 deletions docs/source/means.rst
@@ -51,3 +51,21 @@ Specialty Means

.. autoclass:: ConstantMeanGrad
    :members:

:hidden:`ConstantMeanGradGrad`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ConstantMeanGradGrad
    :members:

:hidden:`LinearMeanGrad`
~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: LinearMeanGrad
    :members:

:hidden:`LinearMeanGradGrad`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: LinearMeanGradGrad
    :members:
2 changes: 2 additions & 0 deletions gpytorch/kernels/__init__.py
@@ -25,6 +25,7 @@
from .product_structure_kernel import ProductStructureKernel
from .rbf_kernel import RBFKernel
from .rbf_kernel_grad import RBFKernelGrad
from .rbf_kernel_gradgrad import RBFKernelGradGrad
from .rff_kernel import RFFKernel
from .rq_kernel import RQKernel
from .scale_kernel import ScaleKernel
@@ -61,6 +62,7 @@
"RBFKernel",
"RFFKernel",
"RBFKernelGrad",
"RBFKernelGradGrad",
"RQKernel",
"ScaleKernel",
"SpectralDeltaKernel",
169 changes: 169 additions & 0 deletions gpytorch/kernels/rbf_kernel_gradgrad.py
@@ -0,0 +1,169 @@
#!/usr/bin/env python3

import torch
from linear_operator.operators import KroneckerProductLinearOperator

from .rbf_kernel import postprocess_rbf, RBFKernel


class RBFKernelGradGrad(RBFKernel):
r"""
Computes a covariance matrix of the RBF kernel that models the covariance
between the values and first and second (non-mixed) partial derivatives for inputs :math:`\mathbf{x_1}`
and :math:`\mathbf{x_2}`.

See :class:`gpytorch.kernels.Kernel` for descriptions of the lengthscale options.

.. note::

This kernel does not have an `outputscale` parameter. To add a scaling parameter,
decorate this kernel with a :class:`gpytorch.kernels.ScaleKernel`.

:param ard_num_dims: Set this if you want a separate lengthscale for each input
dimension. It should be `d` if x1 is a `n x d` matrix. (Default: `None`.)
:param batch_shape: Set this if you want a separate lengthscale for each batch of input
data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf x1` is
a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor.
:param active_dims: Set this if you want to compute the covariance of only
a few input dimensions. The ints corresponds to the indices of the
dimensions. (Default: `None`.)
:param lengthscale_prior: Set this if you want to apply a prior to the
lengthscale parameter. (Default: `None`)
:param lengthscale_constraint: Set this if you want to apply a constraint
to the lengthscale parameter. (Default: `Positive`.)
:param eps: The minimum value that the lengthscale can take (prevents
divide by zero errors). (Default: `1e-6`.)

:ivar torch.Tensor lengthscale: The lengthscale parameter. Size/shape of parameter depends on the
ard_num_dims and batch_shape arguments.

Example:
>>> x = torch.randn(10, 5)
>>> # Non-batch: Simple option
>>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad())
>>> covar = covar_module(x) # Output: LinearOperator of size (110 x 110), where 110 = n * (2*d + 1)
>>>
>>> batch_x = torch.randn(2, 10, 5)
>>> # Batch: Simple option
>>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad())
>>> # Batch: different lengthscale for each batch
>>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad(batch_shape=torch.Size([2])))
>>> covar = covar_module(x) # Output: LinearOperator of size (2 x 110 x 110)
"""

    def forward(self, x1, x2, diag=False, **params):
        batch_shape = x1.shape[:-2]
        n_batch_dims = len(batch_shape)
        n1, d = x1.shape[-2:]
        n2 = x2.shape[-2]

        K = torch.zeros(*batch_shape, n1 * (2 * d + 1), n2 * (2 * d + 1), device=x1.device, dtype=x1.dtype)

        if not diag:
            # Scale the inputs by the lengthscale (for stability)
            x1_ = x1.div(self.lengthscale)
            x2_ = x2.div(self.lengthscale)

            # Form all possible rank-1 products for the gradient and Hessian blocks
            outer = x1_.view(*batch_shape, n1, 1, d) - x2_.view(*batch_shape, 1, n2, d)
            outer = outer / self.lengthscale.unsqueeze(-2)
            outer = torch.transpose(outer, -1, -2).contiguous()

            # 1) Kernel block
            diff = self.covar_dist(x1_, x2_, square_dist=True, **params)
            K_11 = postprocess_rbf(diff)
            K[..., :n1, :n2] = K_11

            # 2) First gradient block
            outer1 = outer.view(*batch_shape, n1, n2 * d)
            K[..., :n1, n2 : (n2 * (d + 1))] = outer1 * K_11.repeat([*([1] * (n_batch_dims + 1)), d])

            # 3) Second gradient block
            outer2 = outer.transpose(-1, -3).reshape(*batch_shape, n2, n1 * d)
            outer2 = outer2.transpose(-1, -2)
            K[..., n1 : (n1 * (d + 1)), :n2] = -outer2 * K_11.repeat([*([1] * n_batch_dims), d, 1])

            # 4) Hessian block
            outer3 = outer1.repeat([*([1] * n_batch_dims), d, 1]) * outer2.repeat([*([1] * (n_batch_dims + 1)), d])
            kp = KroneckerProductLinearOperator(
                torch.eye(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
                torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
            )
            chain_rule = kp.to_dense() - outer3
            K[..., n1 : (n1 * (d + 1)), n2 : (n2 * (d + 1))] = chain_rule * K_11.repeat([*([1] * n_batch_dims), d, d])

            # 5) 1-3 block
            douter1dx2 = KroneckerProductLinearOperator(
                torch.ones(1, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
                torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
            ).to_dense()

            K_13 = (-douter1dx2 + outer1 * outer1) * K_11.repeat(
                [*([1] * (n_batch_dims + 1)), d]
            )  # verified for n1=n2=1 case
            K[..., :n1, (n2 * (d + 1)) :] = K_13

            K_31 = (-douter1dx2.transpose(-1, -2) + outer2 * outer2) * K_11.repeat(
                [*([1] * n_batch_dims), d, 1]
            )  # verified for n1=n2=1 case
            K[..., (n1 * (d + 1)) :, :n2] = K_31

            # rest of the blocks are all of size (n1*d,n2*d)
            outer1 = outer1.repeat([*([1] * n_batch_dims), d, 1])
            outer2 = outer2.repeat([*([1] * (n_batch_dims + 1)), d])
            # II = (torch.eye(d,d,device=x1.device,dtype=x1.dtype)/lengthscale.pow(2)).repeat(*batch_shape,n1,n2)
            kp2 = KroneckerProductLinearOperator(
                torch.ones(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
                torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
            ).to_dense()

            # II may not be the correct thing to use. It might be more appropriate to use kp instead??
            II = kp.to_dense()
            K_11dd = K_11.repeat([*([1] * (n_batch_dims)), d, d])

            K_23 = ((-kp2 + outer1 * outer1) * (-outer2) + 2.0 * II * outer1) * K_11dd  # verified for n1=n2=1 case

            K[..., n1 : (n1 * (d + 1)), (n2 * (d + 1)) :] = K_23

            K_32 = (
                (-kp2.transpose(-1, -2) + outer2 * outer2) * outer1 - 2.0 * II * outer2
            ) * K_11dd  # verified for n1=n2=1 case

            K[..., (n1 * (d + 1)) :, n2 : (n2 * (d + 1))] = K_32

            K_33 = (
                (-kp2.transpose(-1, -2) + outer2 * outer2) * (-kp2) - 2.0 * II * outer2 * outer1 + 2.0 * (II) ** 2
            ) * K_11dd + (
                (-kp2.transpose(-1, -2) + outer2 * outer2) * outer1 - 2.0 * II * outer2
            ) * outer1 * K_11dd  # verified for n1=n2=1 case

            K[..., (n1 * (d + 1)) :, (n2 * (d + 1)) :] = K_33

            # Symmetrize for stability
            if n1 == n2 and torch.eq(x1, x2).all():
                K = 0.5 * (K.transpose(-1, -2) + K)

            # Apply a perfect shuffle permutation to match the MultiTask ordering
            pi1 = torch.arange(n1 * (2 * d + 1)).view(2 * d + 1, n1).t().reshape((n1 * (2 * d + 1)))
            pi2 = torch.arange(n2 * (2 * d + 1)).view(2 * d + 1, n2).t().reshape((n2 * (2 * d + 1)))
            K = K[..., pi1, :][..., :, pi2]

            return K

        else:
            if not (n1 == n2 and torch.eq(x1, x2).all()):
                raise RuntimeError("diag=True only works when x1 == x2")

            kernel_diag = super(RBFKernelGradGrad, self).forward(x1, x2, diag=True)
            grad_diag = torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype) / self.lengthscale.pow(2)
            grad_diag = grad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d)
            gradgrad_diag = (
                3 * torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype) / self.lengthscale.pow(4)
            )
            gradgrad_diag = gradgrad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d)
            k_diag = torch.cat((kernel_diag, grad_diag, gradgrad_diag), dim=-1)
            pi = torch.arange(n2 * (2 * d + 1)).view(2 * d + 1, n2).t().reshape((n2 * (2 * d + 1)))
            return k_diag[..., pi]

    def num_outputs_per_input(self, x1, x2):
        return x1.size(-1) * 2 + 1
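
Editor's note (not part of the diff): the kernel above is meant to be paired with the new GradGrad means and a multitask likelihood, following the same pattern gpytorch's documentation uses for RBFKernelGrad. The sketch below shows one way the pieces might be wired together; the model name, data, and sizes are illustrative assumptions, not something defined in this PR.

import torch
import gpytorch


class GPModelWithSecondDerivatives(gpytorch.models.ExactGP):
    """Illustrative model: jointly models values, first, and second (non-mixed) derivatives."""

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        # Both modules emit 2*d + 1 outputs per input point: value, d gradients, d second derivatives
        self.mean_module = gpytorch.means.ConstantMeanGradGrad()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)


d = 2
train_x = torch.rand(20, d)
train_y = torch.randn(20, 2 * d + 1)  # columns: value, d first derivatives, d second derivatives
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=2 * d + 1)
model = GPModelWithSecondDerivatives(train_x, train_y, likelihood)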
15 changes: 14 additions & 1 deletion gpytorch/means/__init__.py
@@ -2,9 +2,22 @@

from .constant_mean import ConstantMean
from .constant_mean_grad import ConstantMeanGrad
from .constant_mean_gradgrad import ConstantMeanGradGrad
from .linear_mean import LinearMean
from .linear_mean_grad import LinearMeanGrad
from .linear_mean_gradgrad import LinearMeanGradGrad
from .mean import Mean
from .multitask_mean import MultitaskMean
from .zero_mean import ZeroMean

__all__ = ["Mean", "ConstantMean", "ConstantMeanGrad", "LinearMean", "MultitaskMean", "ZeroMean"]
__all__ = [
    "Mean",
    "ConstantMean",
    "ConstantMeanGrad",
    "ConstantMeanGradGrad",
    "LinearMean",
    "LinearMeanGrad",
    "LinearMeanGradGrad",
    "MultitaskMean",
    "ZeroMean",
]
47 changes: 47 additions & 0 deletions gpytorch/means/constant_mean_gradgrad.py
@@ -0,0 +1,47 @@
#!/usr/bin/env python3

from typing import Any, Optional

import torch

from ..priors import Prior
from .mean import Mean


class ConstantMeanGradGrad(Mean):
r"""
A (non-zero) constant prior mean function and its first and second derivatives, i.e.:

.. math::

\mu(\mathbf x) &= C \\
\Grad \mu(\mathbf x) &= \mathbf 0 \\
\Grad^2 \mu(\mathbf x) &= \mathbf 0

where :math:`C` is a learned constant.

:param prior: Prior for constant parameter :math:`C`.
:type prior: ~gpytorch.priors.Prior, optional
:param batch_shape: The batch shape of the learned constant(s) (default: []).
:type batch_shape: torch.Size, optional

:var torch.Tensor constant: :math:`C` parameter
"""

def __init__(
self,
prior: Optional[Prior] = None,
batch_shape: torch.Size = torch.Size(),
**kwargs: Any,
):
super(ConstantMeanGradGrad, self).__init__()
self.batch_shape = batch_shape
self.register_parameter(name="constant", parameter=torch.nn.Parameter(torch.zeros(*batch_shape, 1)))
if prior is not None:
self.register_prior("mean_prior", prior, "constant")

def forward(self, input):
batch_shape = torch.broadcast_shapes(self.batch_shape, input.shape[:-2])
mean = self.constant.unsqueeze(-1).expand(*batch_shape, input.size(-2), 2 * input.size(-1) + 1).contiguous()
mean[..., 1:] = 0
return mean
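
Editor's note (not part of the diff): a quick sanity check of the shape convention in the forward above. For an `n x d` input, ConstantMeanGradGrad returns an `n x (2d + 1)` mean whose first column is the learned constant and whose derivative columns are zero.

import torch
from gpytorch.means import ConstantMeanGradGrad

mean_module = ConstantMeanGradGrad()
x = torch.randn(10, 3)  # n = 10 points in d = 3 dimensions
out = mean_module(x)
print(out.shape)                 # torch.Size([10, 7]), i.e. n x (2*d + 1)
print(out[..., 1:].abs().sum())  # zero: the gradient and second-derivative entries vanish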
44 changes: 44 additions & 0 deletions gpytorch/means/linear_mean_grad.py
@@ -0,0 +1,44 @@
#!/usr/bin/env python3

import torch

from .mean import Mean


class LinearMeanGrad(Mean):
r"""
A linear prior mean function and its first derivative, i.e.:

.. math::

\mu(\mathbf x) &= \mathbf W \cdot \mathbf x + B \\
\Grad \mu(\mathbf x) &= \mathbf W

where :math:`\mathbf W` and :math:`B` are learned constants.

:param input_size: dimension of input :math:`\mathbf x`.
:type input_size: int
:param batch_shape: The batch shape of the learned constant(s) (default: []).
:type batch_shape: torch.Size, optional
:param bias: True/False flag for whether the bias: :math:`B` should be used in the mean (default: True).
:type bias: bool, optional

:var torch.Tensor weights: :math:`\mathbf W` parameter
:var torch.Tensor bias: :math:`B` parameter
"""

def __init__(self, input_size: int, batch_shape: torch.Size = torch.Size(), bias: bool = True):
super().__init__()
self.dim = input_size
self.register_parameter(name="weights", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size, 1)))
if bias:
self.register_parameter(name="bias", parameter=torch.nn.Parameter(torch.randn(*batch_shape, 1)))
else:
self.bias = None

def forward(self, x):
res = x.matmul(self.weights)
if self.bias is not None:
res = res + self.bias.unsqueeze(-1)
dres = self.weights.expand(x.transpose(-1, -2).shape).transpose(-1, -2)
return torch.cat((res, dres), -1)
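
Editor's note (not part of the diff): the matching shape check for LinearMeanGrad. For an `n x d` input it returns an `n x (d + 1)` mean: the linear mean value followed by its d partial derivatives (the weights, broadcast over the n points).

import torch
from gpytorch.means import LinearMeanGrad

mean_module = LinearMeanGrad(input_size=3)
x = torch.randn(10, 3)
out = mean_module(x)
print(out.shape)  # torch.Size([10, 4]), i.e. n x (d + 1)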
46 changes: 46 additions & 0 deletions gpytorch/means/linear_mean_gradgrad.py
@@ -0,0 +1,46 @@
#!/usr/bin/env python3

import torch

from .mean import Mean


class LinearMeanGradGrad(Mean):
r"""
A linear prior mean function and its first and second derivatives, i.e.:

.. math::

\mu(\mathbf x) &= \mathbf W \cdot \mathbf x + B \\
\Grad \mu(\mathbf x) &= \mathbf W \\
\Grad^2 \mu(\mathbf x) &= \mathbf 0 \\

where :math:`\mathbf W` and :math:`B` are learned constants.

:param input_size: dimension of input :math:`\mathbf x`.
:type input_size: int
:param batch_shape: The batch shape of the learned constant(s) (default: []).
:type batch_shape: torch.Size, optional
:param bias: True/False flag for whether the bias: :math:`B` should be used in the mean (default: True).
:type bias: bool, optional

:var torch.Tensor weights: :math:`\mathbf W` parameter
:var torch.Tensor bias: :math:`B` parameter
"""

def __init__(self, input_size, batch_shape=torch.Size(), bias=True):
[Review comment, Member] Doc string and type hints.
[Reply, Contributor Author] Added
        super().__init__()
        self.dim = input_size
        self.register_parameter(name="weights", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size, 1)))
        if bias:
            self.register_parameter(name="bias", parameter=torch.nn.Parameter(torch.randn(*batch_shape, 1)))
        else:
            self.bias = None

    def forward(self, x):
        res = x.matmul(self.weights)
        if self.bias is not None:
            res = res + self.bias.unsqueeze(-1)
        dres = self.weights.expand(x.transpose(-1, -2).shape).transpose(-1, -2)
        ddres = torch.zeros_like(dres)
        return torch.cat((res, dres, ddres), -1)
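
Editor's note (not part of the diff): the corresponding check for LinearMeanGradGrad, whose output is `n x (2d + 1)`; the second-derivative block of a linear mean is identically zero.

import torch
from gpytorch.means import LinearMeanGradGrad

mean_module = LinearMeanGradGrad(input_size=3)
x = torch.randn(10, 3)
out = mean_module(x)
print(out.shape)                 # torch.Size([10, 7]), i.e. n x (2*d + 1)
print(out[..., 4:].abs().sum())  # zero: the second-derivative columns vanish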