|
| 1 | +#!/usr/bin/env python3 |
| 2 | + |
| 3 | +import math |
| 4 | + |
| 5 | +import torch |
| 6 | +from linear_operator.operators import KroneckerProductLinearOperator |
| 7 | + |
| 8 | +from gpytorch.kernels.matern_kernel import MaternKernel |
| 9 | + |
| 10 | +sqrt5 = math.sqrt(5) |
| 11 | +five_thirds = 5.0 / 3.0 |
| 12 | + |
| 13 | + |
| 14 | +class Matern52KernelGrad(MaternKernel): |
| 15 | + r""" |
| 16 | + Computes a covariance matrix of the Matern52 kernel that models the covariance |
| 17 | + between the values and partial derivatives for inputs :math:`\mathbf{x_1}` |
| 18 | + and :math:`\mathbf{x_2}`. |
| 19 | +
|
| 20 | + See :class:`gpytorch.kernels.Kernel` for descriptions of the lengthscale options. |
| 21 | +
|
| 22 | + .. note:: |
| 23 | +
|
| 24 | + This kernel does not have an `outputscale` parameter. To add a scaling parameter, |
| 25 | + decorate this kernel with a :class:`gpytorch.kernels.ScaleKernel`. |
| 26 | +
|
| 27 | + :param ard_num_dims: Set this if you want a separate lengthscale for each input |
| 28 | + dimension. It should be `d` if x1 is a `n x d` matrix. (Default: `None`.) |
| 29 | + :param batch_shape: Set this if you want a separate lengthscale for each batch of input |
| 30 | + data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf x1` is |
| 31 | + a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor. |
| 32 | + :param active_dims: Set this if you want to compute the covariance of only |
| 33 | + a few input dimensions. The ints corresponds to the indices of the |
| 34 | + dimensions. (Default: `None`.) |
| 35 | + :param lengthscale_prior: Set this if you want to apply a prior to the |
| 36 | + lengthscale parameter. (Default: `None`) |
| 37 | + :param lengthscale_constraint: Set this if you want to apply a constraint |
| 38 | + to the lengthscale parameter. (Default: `Positive`.) |
| 39 | + :param eps: The minimum value that the lengthscale can take (prevents |
| 40 | + divide by zero errors). (Default: `1e-6`.) |
| 41 | +
|
| 42 | + :ivar torch.Tensor lengthscale: The lengthscale parameter. Size/shape of parameter depends on the |
| 43 | + ard_num_dims and batch_shape arguments. |
| 44 | +
|
| 45 | + Example: |
| 46 | + >>> x = torch.randn(10, 5) |
| 47 | + >>> # Non-batch: Simple option |
| 48 | + >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.Matern52KernelGrad()) |
| 49 | + >>> covar = covar_module(x) # Output: LinearOperator of size (60 x 60), where 60 = n * (d + 1) |
| 50 | + >>> |
| 51 | + >>> batch_x = torch.randn(2, 10, 5) |
| 52 | + >>> # Batch: Simple option |
| 53 | + >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.Matern52KernelGrad()) |
| 54 | + >>> # Batch: different lengthscale for each batch |
| 55 | + >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.Matern52KernelGrad(batch_shape=torch.Size([2]))) # noqa: E501 |
| 56 | + >>> covar = covar_module(x) # Output: LinearOperator of size (2 x 60 x 60) |
| 57 | + """ |
| 58 | + |
| 59 | + def __init__(self, **kwargs): |
| 60 | + |
| 61 | + # remove nu in case it was set |
| 62 | + kwargs.pop("nu", None) |
| 63 | + super(Matern52KernelGrad, self).__init__(nu=2.5, **kwargs) |
| 64 | + |
| 65 | + def forward(self, x1, x2, diag=False, **params): |
| 66 | + |
| 67 | + lengthscale = self.lengthscale |
| 68 | + |
| 69 | + batch_shape = x1.shape[:-2] |
| 70 | + n_batch_dims = len(batch_shape) |
| 71 | + n1, d = x1.shape[-2:] |
| 72 | + n2 = x2.shape[-2] |
| 73 | + |
| 74 | + if not diag: |
| 75 | + |
| 76 | + K = torch.zeros(*batch_shape, n1 * (d + 1), n2 * (d + 1), device=x1.device, dtype=x1.dtype) |
| 77 | + |
| 78 | + distance_matrix = self.covar_dist(x1.div(lengthscale), x2.div(lengthscale), diag=diag, **params) |
| 79 | + exp_neg_sqrt5r = torch.exp(-sqrt5 * distance_matrix) |
| 80 | + |
| 81 | + # differences matrix in each dimension to be used for derivatives |
| 82 | + # shape of n1 x n2 x d |
| 83 | + outer = x1.view(*batch_shape, n1, 1, d) - x2.view(*batch_shape, 1, n2, d) |
| 84 | + outer = outer / lengthscale.unsqueeze(-2) ** 2 |
| 85 | + # shape of n1 x d x n2 |
| 86 | + outer = torch.transpose(outer, -1, -2).contiguous() |
| 87 | + |
| 88 | + # 1) Kernel block, cov(f^m, f^n) |
| 89 | + # shape is n1 x n2 |
| 90 | + exp_component = torch.exp(-sqrt5 * distance_matrix) |
| 91 | + constant_component = (sqrt5 * distance_matrix).add(1).add(five_thirds * distance_matrix**2) |
| 92 | + |
| 93 | + K[..., :n1, :n2] = constant_component * exp_component |
| 94 | + |
| 95 | + # 2) First gradient block, cov(f^m, omega^n_d) |
| 96 | + outer1 = outer.view(*batch_shape, n1, n2 * d) |
| 97 | + K[..., :n1, n2:] = outer1 * (-five_thirds * (1 + sqrt5 * distance_matrix) * exp_neg_sqrt5r).repeat( |
| 98 | + [*([1] * (n_batch_dims + 1)), d] |
| 99 | + ) |
| 100 | + |
| 101 | + # 3) Second gradient block, cov(omega^m_d, f^n) |
| 102 | + outer2 = outer.transpose(-1, -3).reshape(*batch_shape, n2, n1 * d) |
| 103 | + outer2 = outer2.transpose(-1, -2) |
| 104 | + # the - signs on -outer2 and -five_thirds cancel out |
| 105 | + K[..., n1:, :n2] = outer2 * (five_thirds * (1 + sqrt5 * distance_matrix) * exp_neg_sqrt5r).repeat( |
| 106 | + [*([1] * n_batch_dims), d, 1] |
| 107 | + ) |
| 108 | + |
| 109 | + # 4) Hessian block, cov(omega^m_d, omega^n_d) |
| 110 | + outer3 = outer1.repeat([*([1] * n_batch_dims), d, 1]) * outer2.repeat([*([1] * (n_batch_dims + 1)), d]) |
| 111 | + kp = KroneckerProductLinearOperator( |
| 112 | + torch.eye(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / lengthscale**2, |
| 113 | + torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1), |
| 114 | + ) |
| 115 | + |
| 116 | + part1 = -five_thirds * exp_neg_sqrt5r |
| 117 | + part2 = 5 * outer3 |
| 118 | + part3 = 1 + sqrt5 * distance_matrix |
| 119 | + |
| 120 | + K[..., n1:, n2:] = part1.repeat([*([1] * n_batch_dims), d, d]).mul_( |
| 121 | + # need to use kp.to_dense().mul instead of kp.to_dense().mul_ |
| 122 | + # because otherwise a RuntimeError is raised due to how autograd works with |
| 123 | + # view + inplace operations in the case of 1-dimensional input |
| 124 | + part2.sub_(kp.to_dense().mul(part3.repeat([*([1] * n_batch_dims), d, d]))) |
| 125 | + ) |
| 126 | + |
| 127 | + # Symmetrize for stability |
| 128 | + if n1 == n2 and torch.eq(x1, x2).all(): |
| 129 | + K = 0.5 * (K.transpose(-1, -2) + K) |
| 130 | + |
| 131 | + # Apply a perfect shuffle permutation to match the MutiTask ordering |
| 132 | + pi1 = torch.arange(n1 * (d + 1)).view(d + 1, n1).t().reshape((n1 * (d + 1))) |
| 133 | + pi2 = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1))) |
| 134 | + K = K[..., pi1, :][..., :, pi2] |
| 135 | + |
| 136 | + return K |
| 137 | + else: |
| 138 | + if not (n1 == n2 and torch.eq(x1, x2).all()): |
| 139 | + raise RuntimeError("diag=True only works when x1 == x2") |
| 140 | + |
| 141 | + # nu is set to 2.5 |
| 142 | + kernel_diag = super(Matern52KernelGrad, self).forward(x1, x2, diag=True) |
| 143 | + grad_diag = ( |
| 144 | + five_thirds * torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype) |
| 145 | + ) / lengthscale**2 |
| 146 | + grad_diag = grad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d) |
| 147 | + k_diag = torch.cat((kernel_diag, grad_diag), dim=-1) |
| 148 | + pi = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1))) |
| 149 | + return k_diag[..., pi] |
| 150 | + |
| 151 | + def num_outputs_per_input(self, x1, x2): |
| 152 | + return x1.size(-1) + 1 |
0 commit comments