Mean and kernel functions for first and second derivatives #2235
@@ -0,0 +1,157 @@
#!/usr/bin/env python3
import torch

from ..lazy.kronecker_product_lazy_tensor import KroneckerProductLazyTensor
from .rbf_kernel import RBFKernel, postprocess_rbf


class RBFKernelGradGrad(RBFKernel):
    r"""
    Computes a covariance matrix of the RBF kernel that models the covariance
    between the values and the first and second (non-mixed) partial derivatives for inputs
    :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`.

    See :class:`gpytorch.kernels.Kernel` for descriptions of the lengthscale options.

    .. note::

        This kernel does not have an `outputscale` parameter. To add a scaling parameter,
        decorate this kernel with a :class:`gpytorch.kernels.ScaleKernel`.

    Args:
        :attr:`batch_shape` (torch.Size, optional):
            Set this if you want a separate lengthscale for each batch of input data.
            It should be `b` if :attr:`x1` is a `b x n x d` tensor. Default: `torch.Size([])`.
        :attr:`active_dims` (tuple of ints, optional):
            Set this if you want to compute the covariance of only a few input dimensions.
            The ints correspond to the indices of the dimensions. Default: `None`.
        :attr:`lengthscale_prior` (Prior, optional):
            Set this if you want to apply a prior to the lengthscale parameter. Default: `None`.
        :attr:`lengthscale_constraint` (Constraint, optional):
            Set this if you want to apply a constraint to the lengthscale parameter. Default: `Positive`.
        :attr:`eps` (float):
            The minimum value that the lengthscale can take (prevents divide by zero errors). Default: `1e-6`.

    Attributes:
        :attr:`lengthscale` (Tensor):
            The lengthscale parameter. Size/shape of the parameter depends on the
            :attr:`ard_num_dims` and :attr:`batch_shape` arguments.

    Example:
        >>> x = torch.randn(10, 5)
        >>> # Non-batch: Simple option
        >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad())
        >>> covar = covar_module(x)  # Output: LazyTensor of size (110 x 110), where 110 = n * (2*d + 1)
        >>>
        >>> batch_x = torch.randn(2, 10, 5)
        >>> # Batch: Simple option
        >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad())
        >>> # Batch: different lengthscale for each batch
        >>> covar_module = gpytorch.kernels.ScaleKernel(
        >>>     gpytorch.kernels.RBFKernelGradGrad(batch_shape=torch.Size([2]))
        >>> )
        >>> covar = covar_module(batch_x)  # Output: LazyTensor of size (2 x 110 x 110)
    """

    def forward(self, x1, x2, diag=False, **params):
        batch_shape = x1.shape[:-2]
        n_batch_dims = len(batch_shape)
        n1, d = x1.shape[-2:]
        n2 = x2.shape[-2]

        K = torch.zeros(*batch_shape, n1 * (2 * d + 1), n2 * (2 * d + 1), device=x1.device, dtype=x1.dtype)

        if not diag:
            # Scale the inputs by the lengthscale (for stability)
            x1_ = x1.div(self.lengthscale)
            x2_ = x2.div(self.lengthscale)

            # Form all possible rank-1 products for the gradient and Hessian blocks
            outer = x1_.view(*batch_shape, n1, 1, d) - x2_.view(*batch_shape, 1, n2, d)
            outer = outer / self.lengthscale.unsqueeze(-2)
            outer = torch.transpose(outer, -1, -2).contiguous()

            # 1) Kernel block
            K_11 = self.covar_dist(x1_, x2_, square_dist=True, dist_postprocess_func=postprocess_rbf, **params)
            K[..., :n1, :n2] = K_11

            # 2) First gradient block
            outer1 = outer.view(*batch_shape, n1, n2 * d)
            K[..., :n1, n2:(n2 * (d + 1))] = outer1 * K_11.repeat([*([1] * (n_batch_dims + 1)), d])

            # 3) Second gradient block
            outer2 = outer.transpose(-1, -3).reshape(*batch_shape, n2, n1 * d)
            outer2 = outer2.transpose(-1, -2)
            K[..., n1:(n1 * (d + 1)), :n2] = -outer2 * K_11.repeat([*([1] * n_batch_dims), d, 1])

            # 4) Hessian block
            outer3 = outer1.repeat([*([1] * n_batch_dims), d, 1]) * outer2.repeat([*([1] * (n_batch_dims + 1)), d])
            kp = KroneckerProductLazyTensor(
                torch.eye(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
                torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
            )
            chain_rule = kp.evaluate() - outer3
            K[..., n1:(n1 * (d + 1)), n2:(n2 * (d + 1))] = chain_rule * K_11.repeat([*([1] * n_batch_dims), d, d])

            # 5) Value / second-derivative cross blocks (K_13 and K_31)
            douter1dx2 = KroneckerProductLazyTensor(
                torch.ones(1, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
                torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
            ).evaluate()

            K_13 = (-douter1dx2 + outer1 * outer1) * K_11.repeat([*([1] * (n_batch_dims + 1)), d])  # verified for n1=n2=1 case
            K[..., :n1, (n2 * (d + 1)):] = K_13

            K_31 = (-douter1dx2.transpose(-1, -2) + outer2 * outer2) * K_11.repeat([*([1] * n_batch_dims), d, 1])  # verified for n1=n2=1 case
            K[..., (n1 * (d + 1)):, :n2] = K_31

            # The remaining blocks are all of size (n1*d, n2*d)
            outer1 = outer1.repeat([*([1] * n_batch_dims), d, 1])
            outer2 = outer2.repeat([*([1] * (n_batch_dims + 1)), d])
            # II = (torch.eye(d, d, device=x1.device, dtype=x1.dtype) / lengthscale.pow(2)).repeat(*batch_shape, n1, n2)
            kp2 = KroneckerProductLazyTensor(
                torch.ones(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
                torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
            ).evaluate()

            # II may not be the correct thing to use. It might be more appropriate to use kp instead??
            II = kp.evaluate()
            K_11dd = K_11.repeat([*([1] * n_batch_dims), d, d])

            # 6) First-derivative / second-derivative cross blocks (K_23 and K_32)
            K_23 = ((-kp2 + outer1 * outer1) * (-outer2) + 2.0 * II * outer1) * K_11dd  # verified for n1=n2=1 case
            K[..., n1:(n1 * (d + 1)), (n2 * (d + 1)):] = K_23

            K_32 = ((-kp2.transpose(-1, -2) + outer2 * outer2) * outer1 - 2.0 * II * outer2) * K_11dd  # verified for n1=n2=1 case
            K[..., (n1 * (d + 1)):, n2:(n2 * (d + 1))] = K_32

            # 7) Second-derivative / second-derivative block (K_33)
            K_33 = (
                ((-kp2.transpose(-1, -2) + outer2 * outer2) * (-kp2) - 2.0 * II * outer2 * outer1 + 2.0 * II ** 2) * K_11dd
                + ((-kp2.transpose(-1, -2) + outer2 * outer2) * outer1 - 2.0 * II * outer2) * outer1 * K_11dd
            )  # verified for n1=n2=1 case
            K[..., (n1 * (d + 1)):, (n2 * (d + 1)):] = K_33

            # Symmetrize for stability
            if n1 == n2 and torch.eq(x1, x2).all():
                K = 0.5 * (K.transpose(-1, -2) + K)

            # Apply a perfect shuffle permutation to match the MultiTask ordering
            pi1 = torch.arange(n1 * (2 * d + 1)).view(2 * d + 1, n1).t().reshape((n1 * (2 * d + 1)))
            pi2 = torch.arange(n2 * (2 * d + 1)).view(2 * d + 1, n2).t().reshape((n2 * (2 * d + 1)))
            K = K[..., pi1, :][..., :, pi2]

            return K

        else:
            raise RuntimeError("diag=True not implemented yet")
            '''This has not been updated from RBFKernelGrad yet
            if not (n1 == n2 and torch.eq(x1, x2).all()):
                raise RuntimeError("diag=True only works when x1 == x2")

            kernel_diag = super(RBFKernelGrad, self).forward(x1, x2, diag=True)
            grad_diag = torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype) / self.lengthscale.pow(2)
            grad_diag = grad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d)
            k_diag = torch.cat((kernel_diag, grad_diag), dim=-1)
            pi = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1)))
            return k_diag[..., pi]
            '''

    def num_outputs_per_input(self, x1, x2):
        return x1.size(-1) * 2 + 1
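A quick way to sanity-check the output layout described in the docstring. This sketch is not part of the diff; it assumes the kernel is exposed as gpytorch.kernels.RBFKernelGradGrad once the PR is merged and that kernel calls return a LazyTensor with .evaluate(), as in the gpytorch version this PR targets.

import torch
import gpytorch

n, d = 10, 5
x = torch.randn(n, d)
kernel = gpytorch.kernels.RBFKernelGradGrad()
K = kernel(x).evaluate()  # dense covariance over values, first and second derivatives

# Each input point contributes one value row, d first-derivative rows, and d second-derivative rows
assert kernel.num_outputs_per_input(x, x) == 2 * d + 1
assert K.shape == (n * (2 * d + 1), n * (2 * d + 1))  # 110 x 110 for n=10, d=5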
@@ -0,0 +1,21 @@
#!/usr/bin/env python3

import torch

from ..utils.broadcasting import _mul_broadcast_shape
from .mean import Mean


class ConstantMeanGradGrad(Mean):
    def __init__(self, prior=None, batch_shape=torch.Size(), **kwargs):
        super(ConstantMeanGradGrad, self).__init__()
        self.batch_shape = batch_shape
        self.register_parameter(name="constant", parameter=torch.nn.Parameter(torch.zeros(*batch_shape, 1)))
        if prior is not None:
            self.register_prior("mean_prior", prior, "constant")

    def forward(self, input):
        batch_shape = _mul_broadcast_shape(self.batch_shape, input.shape[:-2])
        mean = self.constant.unsqueeze(-1).expand(*batch_shape, input.size(-2), 2 * input.size(-1) + 1).contiguous()
        mean[..., 1:] = 0
        return mean

Review comment on `__init__`: "Please add a doc string and type hints." Author reply: "Added."
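As a small illustration of the shape ConstantMeanGradGrad.forward produces, here is a sketch under the assumption that the class is importable from gpytorch.means after this PR.

import torch
from gpytorch.means import ConstantMeanGradGrad

mean_module = ConstantMeanGradGrad()
x = torch.randn(10, 5)
m = mean_module(x)

# One column for the mean value, then 5 first-derivative and 5 second-derivative columns,
# all of which are zero for a constant mean
assert m.shape == (10, 2 * 5 + 1)
assert torch.all(m[..., 1:] == 0)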
@@ -0,0 +1,23 @@
#!/usr/bin/env python3

import torch

from .mean import Mean


class LinearMeanGrad(Mean):
    def __init__(self, input_size, batch_shape=torch.Size(), bias=True):
        super().__init__()
        self.dim = input_size
        self.register_parameter(name="weights", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size, 1)))
        if bias:
            self.register_parameter(name="bias", parameter=torch.nn.Parameter(torch.randn(*batch_shape, 1)))
        else:
            self.bias = None

    def forward(self, x):
        res = x.matmul(self.weights)
        if self.bias is not None:
            res = res + self.bias.unsqueeze(-1)
        dres = self.weights.expand(x.transpose(-1, -2).shape).transpose(-1, -2)
        return torch.cat((res, dres), -1)

Review comment on `__init__`: "Doc string and type hints." Author reply: "Added."
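For LinearMeanGrad the output stacks the affine mean and its (constant) gradient along the last dimension. A small sketch, assuming the class is importable from gpytorch.means after this PR:

import torch
from gpytorch.means import LinearMeanGrad

mean_module = LinearMeanGrad(input_size=3)
x = torch.randn(7, 3)
m = mean_module(x)

# Column 0 is w^T x + b; columns 1..3 repeat the gradient w for every input point
assert m.shape == (7, 3 + 1)
assert torch.allclose(m[..., 1:], mean_module.weights.squeeze(-1).expand(7, 3))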
@@ -0,0 +1,24 @@
#!/usr/bin/env python3

import torch

from .mean import Mean


class LinearMeanGradGrad(Mean):
    def __init__(self, input_size, batch_shape=torch.Size(), bias=True):
        super().__init__()
        self.dim = input_size
        self.register_parameter(name="weights", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size, 1)))
        if bias:
            self.register_parameter(name="bias", parameter=torch.nn.Parameter(torch.randn(*batch_shape, 1)))
        else:
            self.bias = None

    def forward(self, x):
        res = x.matmul(self.weights)
        if self.bias is not None:
            res = res + self.bias.unsqueeze(-1)
        dres = self.weights.expand(x.transpose(-1, -2).shape).transpose(-1, -2)
        ddres = torch.zeros_like(dres)
        return torch.cat((res, dres, ddres), -1)

Review comment on `__init__`: "Doc string and type hints." Author reply: "Added."
Review comment: "Can you update the doc string to be in the standard sphinx format? (See the Kernel base class for an example.)" Author reply: "Updated."
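Putting the new pieces together, here is a minimal sketch of an exact GP over function values plus first and second derivatives, following the derivative-GP pattern used elsewhere in gpytorch. Everything outside this diff (the ExactGP, likelihood, and distribution classes, and the gpytorch.means / gpytorch.kernels entry points for the new modules) is assumed rather than shown in the PR.

import torch
import gpytorch


class GPWithSecondDerivatives(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        # New mean and kernel from this PR (assumed registered under gpytorch.means / gpytorch.kernels)
        self.mean_module = gpytorch.means.ConstantMeanGradGrad()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)


d = 2
train_x = torch.rand(20, d)
train_y = torch.randn(20, 2 * d + 1)  # columns: value, d first derivatives, d second derivatives
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=2 * d + 1)
model = GPWithSecondDerivatives(train_x, train_y, likelihood)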