Mean and kernel functions for first and second derivatives #2235
Merged
Commits (20):
99697e1  Added an RBF kernel with second (non-mixed) derivatives. Need to be t… (ankushaggarwal)
e3e97cb  Added constant mean with second derivative and linear means with firs… (ankushaggarwal)
cbf0bf0  Merge branch 'master' into second-derivs (ankushaggarwal)
c18e6d8  Hook fixes using pre-commit (ankushaggarwal)
6fa7f0d  Merge branch 'master' into second-derivs (ankushaggarwal)
c43c8c8  Added new kernel and means to the documentation (ankushaggarwal)
451dfd1  Hook fix (ankushaggarwal)
9bf9c3b  Changed from lazy tensor to linear operator as per the newer gpytorch… (ankushaggarwal)
0bdfc3e  Changed from utils.broadcasting._mul_broadcast_shape to torch.broadca… (ankushaggarwal)
f6a3ce2  Revert "Hook fix" (ankushaggarwal)
6e91183  Fixed a minor error (as per the new version of gpytorch) (ankushaggarwal)
d473faf  Added the diag=True version of rbf gradgrad kernel (ankushaggarwal)
1f7ab49  Fix formatting with pre-commit (ankushaggarwal)
b18ecc6  Increased the underline length to pass doc test (ankushaggarwal)
e246d2f  Getting rid of the changes to pre-commit-hooks (ankushaggarwal)
daeedab  Merge remote-tracking branch 'upstream/master' into second-derivs (ankushaggarwal)
b717bfc  Added docstrings and type hints (ankushaggarwal)
1623823  Fixed the wrong references in doc of ConstantMeanGradGrad (ankushaggarwal)
bd8a9cc  Added unit tests (ankushaggarwal)
94819cf  Merge branch 'master' into second-derivs (gpleiss)
@@ -0,0 +1,169 @@
#!/usr/bin/env python3

import torch
from linear_operator.operators import KroneckerProductLinearOperator

from .rbf_kernel import postprocess_rbf, RBFKernel


class RBFKernelGradGrad(RBFKernel):
    r"""
    Computes a covariance matrix of the RBF kernel that models the covariance
    between the values and first and second (non-mixed) partial derivatives for inputs :math:`\mathbf{x_1}`
    and :math:`\mathbf{x_2}`.

    See :class:`gpytorch.kernels.Kernel` for descriptions of the lengthscale options.

    .. note::

        This kernel does not have an `outputscale` parameter. To add a scaling parameter,
        decorate this kernel with a :class:`gpytorch.kernels.ScaleKernel`.

    :param ard_num_dims: Set this if you want a separate lengthscale for each input
        dimension. It should be `d` if x1 is a `n x d` matrix. (Default: `None`.)
    :param batch_shape: Set this if you want a separate lengthscale for each batch of input
        data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf x1` is
        a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor.
    :param active_dims: Set this if you want to compute the covariance of only
        a few input dimensions. The ints correspond to the indices of the
        dimensions. (Default: `None`.)
    :param lengthscale_prior: Set this if you want to apply a prior to the
        lengthscale parameter. (Default: `None`.)
    :param lengthscale_constraint: Set this if you want to apply a constraint
        to the lengthscale parameter. (Default: `Positive`.)
    :param eps: The minimum value that the lengthscale can take (prevents
        divide by zero errors). (Default: `1e-6`.)

    :ivar torch.Tensor lengthscale: The lengthscale parameter. Size/shape of parameter depends on the
        ard_num_dims and batch_shape arguments.

    Example:
        >>> x = torch.randn(10, 5)
        >>> # Non-batch: Simple option
        >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad())
        >>> covar = covar_module(x)  # Output: LinearOperator of size (110 x 110), where 110 = n * (2*d + 1)
        >>>
        >>> batch_x = torch.randn(2, 10, 5)
        >>> # Batch: Simple option
        >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad())
        >>> # Batch: different lengthscale for each batch
        >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad(batch_shape=torch.Size([2])))
        >>> covar = covar_module(batch_x)  # Output: LinearOperator of size (2 x 110 x 110)
    """

    def forward(self, x1, x2, diag=False, **params):
        batch_shape = x1.shape[:-2]
        n_batch_dims = len(batch_shape)
        n1, d = x1.shape[-2:]
        n2 = x2.shape[-2]

        # Each input contributes one value block, d gradient blocks, and d second-derivative blocks
        K = torch.zeros(*batch_shape, n1 * (2 * d + 1), n2 * (2 * d + 1), device=x1.device, dtype=x1.dtype)

        if not diag:
            # Scale the inputs by the lengthscale (for stability)
            x1_ = x1.div(self.lengthscale)
            x2_ = x2.div(self.lengthscale)

            # Form all possible rank-1 products for the gradient and Hessian blocks
            outer = x1_.view(*batch_shape, n1, 1, d) - x2_.view(*batch_shape, 1, n2, d)
            outer = outer / self.lengthscale.unsqueeze(-2)
            outer = torch.transpose(outer, -1, -2).contiguous()

            # 1) Kernel block
            diff = self.covar_dist(x1_, x2_, square_dist=True, **params)
            K_11 = postprocess_rbf(diff)
            K[..., :n1, :n2] = K_11

            # 2) First gradient block
            outer1 = outer.view(*batch_shape, n1, n2 * d)
            K[..., :n1, n2 : (n2 * (d + 1))] = outer1 * K_11.repeat([*([1] * (n_batch_dims + 1)), d])

            # 3) Second gradient block
            outer2 = outer.transpose(-1, -3).reshape(*batch_shape, n2, n1 * d)
            outer2 = outer2.transpose(-1, -2)
            K[..., n1 : (n1 * (d + 1)), :n2] = -outer2 * K_11.repeat([*([1] * n_batch_dims), d, 1])

            # 4) Hessian block
            outer3 = outer1.repeat([*([1] * n_batch_dims), d, 1]) * outer2.repeat([*([1] * (n_batch_dims + 1)), d])
            kp = KroneckerProductLinearOperator(
                torch.eye(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
                torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
            )
            chain_rule = kp.to_dense() - outer3
            K[..., n1 : (n1 * (d + 1)), n2 : (n2 * (d + 1))] = chain_rule * K_11.repeat([*([1] * n_batch_dims), d, d])

            # 5) 1-3 block
            douter1dx2 = KroneckerProductLinearOperator(
                torch.ones(1, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
                torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
            ).to_dense()

            K_13 = (-douter1dx2 + outer1 * outer1) * K_11.repeat(
                [*([1] * (n_batch_dims + 1)), d]
            )  # verified for n1=n2=1 case
            K[..., :n1, (n2 * (d + 1)) :] = K_13

            K_31 = (-douter1dx2.transpose(-1, -2) + outer2 * outer2) * K_11.repeat(
                [*([1] * n_batch_dims), d, 1]
            )  # verified for n1=n2=1 case
            K[..., (n1 * (d + 1)) :, :n2] = K_31

            # rest of the blocks are all of size (n1*d,n2*d)
            outer1 = outer1.repeat([*([1] * n_batch_dims), d, 1])
            outer2 = outer2.repeat([*([1] * (n_batch_dims + 1)), d])
            # II = (torch.eye(d,d,device=x1.device,dtype=x1.dtype)/lengthscale.pow(2)).repeat(*batch_shape,n1,n2)
            kp2 = KroneckerProductLinearOperator(
                torch.ones(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
                torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
            ).to_dense()

            # II may not be the correct thing to use. It might be more appropriate to use kp instead??
            II = kp.to_dense()
            K_11dd = K_11.repeat([*([1] * (n_batch_dims)), d, d])

            K_23 = ((-kp2 + outer1 * outer1) * (-outer2) + 2.0 * II * outer1) * K_11dd  # verified for n1=n2=1 case

            K[..., n1 : (n1 * (d + 1)), (n2 * (d + 1)) :] = K_23

            K_32 = (
                (-kp2.transpose(-1, -2) + outer2 * outer2) * outer1 - 2.0 * II * outer2
            ) * K_11dd  # verified for n1=n2=1 case

            K[..., (n1 * (d + 1)) :, n2 : (n2 * (d + 1))] = K_32

            K_33 = (
                (-kp2.transpose(-1, -2) + outer2 * outer2) * (-kp2) - 2.0 * II * outer2 * outer1 + 2.0 * (II) ** 2
            ) * K_11dd + (
                (-kp2.transpose(-1, -2) + outer2 * outer2) * outer1 - 2.0 * II * outer2
            ) * outer1 * K_11dd  # verified for n1=n2=1 case

            K[..., (n1 * (d + 1)) :, (n2 * (d + 1)) :] = K_33

            # Symmetrize for stability
            if n1 == n2 and torch.eq(x1, x2).all():
                K = 0.5 * (K.transpose(-1, -2) + K)

            # Apply a perfect shuffle permutation to match the MultiTask ordering
            pi1 = torch.arange(n1 * (2 * d + 1)).view(2 * d + 1, n1).t().reshape((n1 * (2 * d + 1)))
            pi2 = torch.arange(n2 * (2 * d + 1)).view(2 * d + 1, n2).t().reshape((n2 * (2 * d + 1)))
            K = K[..., pi1, :][..., :, pi2]

            return K

        else:
            if not (n1 == n2 and torch.eq(x1, x2).all()):
                raise RuntimeError("diag=True only works when x1 == x2")

            kernel_diag = super(RBFKernelGradGrad, self).forward(x1, x2, diag=True)
            grad_diag = torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype) / self.lengthscale.pow(2)
            grad_diag = grad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d)
            gradgrad_diag = (
                3 * torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype) / self.lengthscale.pow(4)
            )
            gradgrad_diag = gradgrad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d)
            k_diag = torch.cat((kernel_diag, grad_diag, gradgrad_diag), dim=-1)
            pi = torch.arange(n2 * (2 * d + 1)).view(2 * d + 1, n2).t().reshape((n2 * (2 * d + 1)))
            return k_diag[..., pi]

    def num_outputs_per_input(self, x1, x2):
        return x1.size(-1) * 2 + 1
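
A quick way to sanity-check the kernel's output layout is to look at its shape. This is a minimal sketch, assuming a gpytorch build that already includes RBFKernelGradGrad (and the linear_operator dependency):

import torch
import gpytorch

x = torch.randn(10, 5)  # n = 10 points in d = 5 dimensions
kernel = gpytorch.kernels.RBFKernelGradGrad()
covar = kernel(x)

# Each input contributes one value, d first derivatives, and d (non-mixed) second
# derivatives, so the covariance has n * (2*d + 1) = 110 rows and columns.
print(covar.shape)                         # torch.Size([110, 110])
print(kernel.num_outputs_per_input(x, x))  # 11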
@@ -0,0 +1,47 @@
#!/usr/bin/env python3

from typing import Any, Optional

import torch

from ..priors import Prior
from .mean import Mean


class ConstantMeanGradGrad(Mean):
    r"""
    A (non-zero) constant prior mean function and its first and second derivatives, i.e.:

    .. math::

        \mu(\mathbf x) &= C \\
        \Grad \mu(\mathbf x) &= \mathbf 0 \\
        \Grad^2 \mu(\mathbf x) &= \mathbf 0

    where :math:`C` is a learned constant.

    :param prior: Prior for constant parameter :math:`C`.
    :type prior: ~gpytorch.priors.Prior, optional
    :param batch_shape: The batch shape of the learned constant(s) (default: []).
    :type batch_shape: torch.Size, optional

    :var torch.Tensor constant: :math:`C` parameter
    """

    def __init__(
        self,
        prior: Optional[Prior] = None,
        batch_shape: torch.Size = torch.Size(),
        **kwargs: Any,
    ):
        super(ConstantMeanGradGrad, self).__init__()
        self.batch_shape = batch_shape
        self.register_parameter(name="constant", parameter=torch.nn.Parameter(torch.zeros(*batch_shape, 1)))
        if prior is not None:
            self.register_prior("mean_prior", prior, "constant")

    def forward(self, input):
        batch_shape = torch.broadcast_shapes(self.batch_shape, input.shape[:-2])
        mean = self.constant.unsqueeze(-1).expand(*batch_shape, input.size(-2), 2 * input.size(-1) + 1).contiguous()
        mean[..., 1:] = 0
        return mean
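
For illustration, here is a small sketch of the mean layout this module produces, assuming ConstantMeanGradGrad is importable from gpytorch.means:

import torch
import gpytorch

x = torch.randn(10, 5)
mean_module = gpytorch.means.ConstantMeanGradGrad()
mean = mean_module(x)

# Column 0 carries the learned constant C; the remaining 2*d columns (gradient and
# second-derivative entries) are zero, giving an n x (2*d + 1) mean.
print(mean.shape)   # torch.Size([10, 11])
print(mean[0, 1:])  # tensor of zeros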
@@ -0,0 +1,44 @@
#!/usr/bin/env python3

import torch

from .mean import Mean


class LinearMeanGrad(Mean):
    r"""
    A linear prior mean function and its first derivative, i.e.:

    .. math::

        \mu(\mathbf x) &= \mathbf W \cdot \mathbf x + B \\
        \Grad \mu(\mathbf x) &= \mathbf W

    where :math:`\mathbf W` and :math:`B` are learned constants.

    :param input_size: dimension of input :math:`\mathbf x`.
    :type input_size: int
    :param batch_shape: The batch shape of the learned constant(s) (default: []).
    :type batch_shape: torch.Size, optional
    :param bias: True/False flag for whether the bias :math:`B` should be used in the mean (default: True).
    :type bias: bool, optional

    :var torch.Tensor weights: :math:`\mathbf W` parameter
    :var torch.Tensor bias: :math:`B` parameter
    """

    def __init__(self, input_size: int, batch_shape: torch.Size = torch.Size(), bias: bool = True):
        super().__init__()
        self.dim = input_size
        self.register_parameter(name="weights", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size, 1)))
        if bias:
            self.register_parameter(name="bias", parameter=torch.nn.Parameter(torch.randn(*batch_shape, 1)))
        else:
            self.bias = None

    def forward(self, x):
        res = x.matmul(self.weights)
        if self.bias is not None:
            res = res + self.bias.unsqueeze(-1)
        dres = self.weights.expand(x.transpose(-1, -2).shape).transpose(-1, -2)
        return torch.cat((res, dres), -1)
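
Similarly, a brief sketch of LinearMeanGrad's output, assuming the class is importable from gpytorch.means:

import torch
import gpytorch

x = torch.randn(7, 3)
mean_module = gpytorch.means.LinearMeanGrad(input_size=3)
out = mean_module(x)

# Column 0 holds W @ x + B; columns 1..d hold the (constant) gradient W,
# so the result is n x (d + 1).
print(out.shape)  # torch.Size([7, 4])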
@@ -0,0 +1,46 @@
#!/usr/bin/env python3

import torch

from .mean import Mean


class LinearMeanGradGrad(Mean):
    r"""
    A linear prior mean function and its first and second derivatives, i.e.:

    .. math::

        \mu(\mathbf x) &= \mathbf W \cdot \mathbf x + B \\
        \Grad \mu(\mathbf x) &= \mathbf W \\
        \Grad^2 \mu(\mathbf x) &= \mathbf 0

    where :math:`\mathbf W` and :math:`B` are learned constants.

    :param input_size: dimension of input :math:`\mathbf x`.
    :type input_size: int
    :param batch_shape: The batch shape of the learned constant(s) (default: []).
    :type batch_shape: torch.Size, optional
    :param bias: True/False flag for whether the bias :math:`B` should be used in the mean (default: True).
    :type bias: bool, optional

    :var torch.Tensor weights: :math:`\mathbf W` parameter
    :var torch.Tensor bias: :math:`B` parameter
    """

    def __init__(self, input_size: int, batch_shape: torch.Size = torch.Size(), bias: bool = True):
        super().__init__()
        self.dim = input_size
        self.register_parameter(name="weights", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size, 1)))
        if bias:
            self.register_parameter(name="bias", parameter=torch.nn.Parameter(torch.randn(*batch_shape, 1)))
        else:
            self.bias = None

    def forward(self, x):
        res = x.matmul(self.weights)
        if self.bias is not None:
            res = res + self.bias.unsqueeze(-1)
        dres = self.weights.expand(x.transpose(-1, -2).shape).transpose(-1, -2)
        ddres = torch.zeros_like(dres)
        return torch.cat((res, dres, ddres), -1)
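
To show how these pieces are meant to fit together, here is a sketch of an exact GP over values plus first and second derivatives, modeled on gpytorch's existing derivative-GP examples; the model class name and the random training data below are illustrative assumptions, not part of this PR:

import torch
import gpytorch

class GPWithSecondDerivatives(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMeanGradGrad()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        # The kernel's perfect-shuffle permutation matches the interleaved
        # ordering expected by the multitask distribution.
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)

d = 2
train_x = torch.randn(20, d)
train_y = torch.randn(20, 2 * d + 1)  # per point: value, d first derivs, d second derivs
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=2 * d + 1)
model = GPWithSecondDerivatives(train_x, train_y, likelihood)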
Review comment: Doc string and type hints.
Reply: Added.