diff --git a/docs/source/kernels.rst b/docs/source/kernels.rst index 5e1dce6e7..714e46a6c 100644 --- a/docs/source/kernels.rst +++ b/docs/source/kernels.rst @@ -176,6 +176,12 @@ Specialty Kernels .. autoclass:: RBFKernelGrad :members: +:hidden:`RBFKernelGradGrad` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: RBFKernelGradGrad + :members: + Kernels for Scalable GP Regression Methods -------------------------------------------- diff --git a/docs/source/means.rst b/docs/source/means.rst index 69630ae15..f9d828e14 100644 --- a/docs/source/means.rst +++ b/docs/source/means.rst @@ -51,3 +51,21 @@ Specialty Means .. autoclass:: ConstantMeanGrad :members: + +:hidden:`ConstantMeanGradGrad` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: ConstantMeanGradGrad + :members: + +:hidden:`LinearMeanGrad` +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: LinearMeanGrad + :members: + +:hidden:`LinearMeanGradGrad` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: LinearMeanGradGrad + :members: diff --git a/gpytorch/kernels/__init__.py b/gpytorch/kernels/__init__.py index 9828cf1ad..cc85fe624 100644 --- a/gpytorch/kernels/__init__.py +++ b/gpytorch/kernels/__init__.py @@ -25,6 +25,7 @@ from .product_structure_kernel import ProductStructureKernel from .rbf_kernel import RBFKernel from .rbf_kernel_grad import RBFKernelGrad +from .rbf_kernel_gradgrad import RBFKernelGradGrad from .rff_kernel import RFFKernel from .rq_kernel import RQKernel from .scale_kernel import ScaleKernel @@ -61,6 +62,7 @@ "RBFKernel", "RFFKernel", "RBFKernelGrad", + "RBFKernelGradGrad", "RQKernel", "ScaleKernel", "SpectralDeltaKernel", diff --git a/gpytorch/kernels/rbf_kernel_gradgrad.py b/gpytorch/kernels/rbf_kernel_gradgrad.py new file mode 100644 index 000000000..e95422166 --- /dev/null +++ b/gpytorch/kernels/rbf_kernel_gradgrad.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 + +import torch +from linear_operator.operators import KroneckerProductLinearOperator + +from .rbf_kernel import postprocess_rbf, RBFKernel + + +class RBFKernelGradGrad(RBFKernel): + r""" + Computes a covariance matrix of the RBF kernel that models the covariance + between the values and first and second (non-mixed) partial derivatives for inputs :math:`\mathbf{x_1}` + and :math:`\mathbf{x_2}`. + + See :class:`gpytorch.kernels.Kernel` for descriptions of the lengthscale options. + + .. note:: + + This kernel does not have an `outputscale` parameter. To add a scaling parameter, + decorate this kernel with a :class:`gpytorch.kernels.ScaleKernel`. + + :param ard_num_dims: Set this if you want a separate lengthscale for each input + dimension. It should be `d` if x1 is a `n x d` matrix. (Default: `None`.) + :param batch_shape: Set this if you want a separate lengthscale for each batch of input + data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf x1` is + a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor. + :param active_dims: Set this if you want to compute the covariance of only + a few input dimensions. The ints corresponds to the indices of the + dimensions. (Default: `None`.) + :param lengthscale_prior: Set this if you want to apply a prior to the + lengthscale parameter. (Default: `None`) + :param lengthscale_constraint: Set this if you want to apply a constraint + to the lengthscale parameter. (Default: `Positive`.) + :param eps: The minimum value that the lengthscale can take (prevents + divide by zero errors). (Default: `1e-6`.) + + :ivar torch.Tensor lengthscale: The lengthscale parameter. 
Size/shape of parameter depends on the + ard_num_dims and batch_shape arguments. + + Example: + >>> x = torch.randn(10, 5) + >>> # Non-batch: Simple option + >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad()) + >>> covar = covar_module(x) # Output: LinearOperator of size (110 x 110), where 110 = n * (2*d + 1) + >>> + >>> batch_x = torch.randn(2, 10, 5) + >>> # Batch: Simple option + >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad()) + >>> # Batch: different lengthscale for each batch + >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGradGrad(batch_shape=torch.Size([2]))) + >>> covar = covar_module(x) # Output: LinearOperator of size (2 x 110 x 110) + """ + + def forward(self, x1, x2, diag=False, **params): + batch_shape = x1.shape[:-2] + n_batch_dims = len(batch_shape) + n1, d = x1.shape[-2:] + n2 = x2.shape[-2] + + K = torch.zeros(*batch_shape, n1 * (2 * d + 1), n2 * (2 * d + 1), device=x1.device, dtype=x1.dtype) + + if not diag: + # Scale the inputs by the lengthscale (for stability) + x1_ = x1.div(self.lengthscale) + x2_ = x2.div(self.lengthscale) + + # Form all possible rank-1 products for the gradient and Hessian blocks + outer = x1_.view(*batch_shape, n1, 1, d) - x2_.view(*batch_shape, 1, n2, d) + outer = outer / self.lengthscale.unsqueeze(-2) + outer = torch.transpose(outer, -1, -2).contiguous() + + # 1) Kernel block + diff = self.covar_dist(x1_, x2_, square_dist=True, **params) + K_11 = postprocess_rbf(diff) + K[..., :n1, :n2] = K_11 + + # 2) First gradient block + outer1 = outer.view(*batch_shape, n1, n2 * d) + K[..., :n1, n2 : (n2 * (d + 1))] = outer1 * K_11.repeat([*([1] * (n_batch_dims + 1)), d]) + + # 3) Second gradient block + outer2 = outer.transpose(-1, -3).reshape(*batch_shape, n2, n1 * d) + outer2 = outer2.transpose(-1, -2) + K[..., n1 : (n1 * (d + 1)), :n2] = -outer2 * K_11.repeat([*([1] * n_batch_dims), d, 1]) + + # 4) Hessian block + outer3 = outer1.repeat([*([1] * n_batch_dims), d, 1]) * outer2.repeat([*([1] * (n_batch_dims + 1)), d]) + kp = KroneckerProductLinearOperator( + torch.eye(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2), + torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1), + ) + chain_rule = kp.to_dense() - outer3 + K[..., n1 : (n1 * (d + 1)), n2 : (n2 * (d + 1))] = chain_rule * K_11.repeat([*([1] * n_batch_dims), d, d]) + + # 5) 1-3 block + douter1dx2 = KroneckerProductLinearOperator( + torch.ones(1, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2), + torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1), + ).to_dense() + + K_13 = (-douter1dx2 + outer1 * outer1) * K_11.repeat( + [*([1] * (n_batch_dims + 1)), d] + ) # verified for n1=n2=1 case + K[..., :n1, (n2 * (d + 1)) :] = K_13 + + K_31 = (-douter1dx2.transpose(-1, -2) + outer2 * outer2) * K_11.repeat( + [*([1] * n_batch_dims), d, 1] + ) # verified for n1=n2=1 case + K[..., (n1 * (d + 1)) :, :n2] = K_31 + + # rest of the blocks are all of size (n1*d,n2*d) + outer1 = outer1.repeat([*([1] * n_batch_dims), d, 1]) + outer2 = outer2.repeat([*([1] * (n_batch_dims + 1)), d]) + # II = (torch.eye(d,d,device=x1.device,dtype=x1.dtype)/lengthscale.pow(2)).repeat(*batch_shape,n1,n2) + kp2 = KroneckerProductLinearOperator( + torch.ones(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2), + torch.ones(n1, n2, device=x1.device, 
dtype=x1.dtype).repeat(*batch_shape, 1, 1), + ).to_dense() + + # II may not be the correct thing to use. It might be more appropriate to use kp instead?? + II = kp.to_dense() + K_11dd = K_11.repeat([*([1] * (n_batch_dims)), d, d]) + + K_23 = ((-kp2 + outer1 * outer1) * (-outer2) + 2.0 * II * outer1) * K_11dd # verified for n1=n2=1 case + + K[..., n1 : (n1 * (d + 1)), (n2 * (d + 1)) :] = K_23 + + K_32 = ( + (-kp2.transpose(-1, -2) + outer2 * outer2) * outer1 - 2.0 * II * outer2 + ) * K_11dd # verified for n1=n2=1 case + + K[..., (n1 * (d + 1)) :, n2 : (n2 * (d + 1))] = K_32 + + K_33 = ( + (-kp2.transpose(-1, -2) + outer2 * outer2) * (-kp2) - 2.0 * II * outer2 * outer1 + 2.0 * (II) ** 2 + ) * K_11dd + ( + (-kp2.transpose(-1, -2) + outer2 * outer2) * outer1 - 2.0 * II * outer2 + ) * outer1 * K_11dd # verified for n1=n2=1 case + + K[..., (n1 * (d + 1)) :, (n2 * (d + 1)) :] = K_33 + + # Symmetrize for stability + if n1 == n2 and torch.eq(x1, x2).all(): + K = 0.5 * (K.transpose(-1, -2) + K) + + # Apply a perfect shuffle permutation to match the MutiTask ordering + pi1 = torch.arange(n1 * (2 * d + 1)).view(2 * d + 1, n1).t().reshape((n1 * (2 * d + 1))) + pi2 = torch.arange(n2 * (2 * d + 1)).view(2 * d + 1, n2).t().reshape((n2 * (2 * d + 1))) + K = K[..., pi1, :][..., :, pi2] + + return K + + else: + if not (n1 == n2 and torch.eq(x1, x2).all()): + raise RuntimeError("diag=True only works when x1 == x2") + + kernel_diag = super(RBFKernelGradGrad, self).forward(x1, x2, diag=True) + grad_diag = torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype) / self.lengthscale.pow(2) + grad_diag = grad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d) + gradgrad_diag = ( + 3 * torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype) / self.lengthscale.pow(4) + ) + gradgrad_diag = gradgrad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d) + k_diag = torch.cat((kernel_diag, grad_diag, gradgrad_diag), dim=-1) + pi = torch.arange(n2 * (2 * d + 1)).view(2 * d + 1, n2).t().reshape((n2 * (2 * d + 1))) + return k_diag[..., pi] + + def num_outputs_per_input(self, x1, x2): + return x1.size(-1) * 2 + 1 diff --git a/gpytorch/means/__init__.py b/gpytorch/means/__init__.py index 7317b56e8..29b01017e 100644 --- a/gpytorch/means/__init__.py +++ b/gpytorch/means/__init__.py @@ -2,9 +2,22 @@ from .constant_mean import ConstantMean from .constant_mean_grad import ConstantMeanGrad +from .constant_mean_gradgrad import ConstantMeanGradGrad from .linear_mean import LinearMean +from .linear_mean_grad import LinearMeanGrad +from .linear_mean_gradgrad import LinearMeanGradGrad from .mean import Mean from .multitask_mean import MultitaskMean from .zero_mean import ZeroMean -__all__ = ["Mean", "ConstantMean", "ConstantMeanGrad", "LinearMean", "MultitaskMean", "ZeroMean"] +__all__ = [ + "Mean", + "ConstantMean", + "ConstantMeanGrad", + "ConstantMeanGradGrad", + "LinearMean", + "LinearMeanGrad", + "LinearMeanGradGrad", + "MultitaskMean", + "ZeroMean", +] diff --git a/gpytorch/means/constant_mean_gradgrad.py b/gpytorch/means/constant_mean_gradgrad.py new file mode 100644 index 000000000..b221f26d9 --- /dev/null +++ b/gpytorch/means/constant_mean_gradgrad.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 + +from typing import Any, Optional + +import torch + +from ..priors import Prior +from .mean import Mean + + +class ConstantMeanGradGrad(Mean): + r""" + A (non-zero) constant prior mean function and its first and second derivatives, i.e.: + + .. 
math:: + + \mu(\mathbf x) &= C \\ + \Grad \mu(\mathbf x) &= \mathbf 0 \\ + \Grad^2 \mu(\mathbf x) &= \mathbf 0 + + where :math:`C` is a learned constant. + + :param prior: Prior for constant parameter :math:`C`. + :type prior: ~gpytorch.priors.Prior, optional + :param batch_shape: The batch shape of the learned constant(s) (default: []). + :type batch_shape: torch.Size, optional + + :var torch.Tensor constant: :math:`C` parameter + """ + + def __init__( + self, + prior: Optional[Prior] = None, + batch_shape: torch.Size = torch.Size(), + **kwargs: Any, + ): + super(ConstantMeanGradGrad, self).__init__() + self.batch_shape = batch_shape + self.register_parameter(name="constant", parameter=torch.nn.Parameter(torch.zeros(*batch_shape, 1))) + if prior is not None: + self.register_prior("mean_prior", prior, "constant") + + def forward(self, input): + batch_shape = torch.broadcast_shapes(self.batch_shape, input.shape[:-2]) + mean = self.constant.unsqueeze(-1).expand(*batch_shape, input.size(-2), 2 * input.size(-1) + 1).contiguous() + mean[..., 1:] = 0 + return mean diff --git a/gpytorch/means/linear_mean_grad.py b/gpytorch/means/linear_mean_grad.py new file mode 100644 index 000000000..df4154513 --- /dev/null +++ b/gpytorch/means/linear_mean_grad.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 + +import torch + +from .mean import Mean + + +class LinearMeanGrad(Mean): + r""" + A linear prior mean function and its first derivative, i.e.: + + .. math:: + + \mu(\mathbf x) &= \mathbf W \cdot \mathbf x + B \\ + \Grad \mu(\mathbf x) &= \mathbf W + + where :math:`\mathbf W` and :math:`B` are learned constants. + + :param input_size: dimension of input :math:`\mathbf x`. + :type input_size: int + :param batch_shape: The batch shape of the learned constant(s) (default: []). + :type batch_shape: torch.Size, optional + :param bias: True/False flag for whether the bias: :math:`B` should be used in the mean (default: True). + :type bias: bool, optional + + :var torch.Tensor weights: :math:`\mathbf W` parameter + :var torch.Tensor bias: :math:`B` parameter + """ + + def __init__(self, input_size: int, batch_shape: torch.Size = torch.Size(), bias: bool = True): + super().__init__() + self.dim = input_size + self.register_parameter(name="weights", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size, 1))) + if bias: + self.register_parameter(name="bias", parameter=torch.nn.Parameter(torch.randn(*batch_shape, 1))) + else: + self.bias = None + + def forward(self, x): + res = x.matmul(self.weights) + if self.bias is not None: + res = res + self.bias.unsqueeze(-1) + dres = self.weights.expand(x.transpose(-1, -2).shape).transpose(-1, -2) + return torch.cat((res, dres), -1) diff --git a/gpytorch/means/linear_mean_gradgrad.py b/gpytorch/means/linear_mean_gradgrad.py new file mode 100644 index 000000000..4e57fc0e0 --- /dev/null +++ b/gpytorch/means/linear_mean_gradgrad.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 + +import torch + +from .mean import Mean + + +class LinearMeanGradGrad(Mean): + r""" + A linear prior mean function and its first and second derivatives, i.e.: + + .. math:: + + \mu(\mathbf x) &= \mathbf W \cdot \mathbf x + B \\ + \Grad \mu(\mathbf x) &= \mathbf W \\ + \Grad^2 \mu(\mathbf x) &= \mathbf 0 \\ + + where :math:`\mathbf W` and :math:`B` are learned constants. + + :param input_size: dimension of input :math:`\mathbf x`. + :type input_size: int + :param batch_shape: The batch shape of the learned constant(s) (default: []). 
+ :type batch_shape: torch.Size, optional + :param bias: True/False flag for whether the bias: :math:`B` should be used in the mean (default: True). + :type bias: bool, optional + + :var torch.Tensor weights: :math:`\mathbf W` parameter + :var torch.Tensor bias: :math:`B` parameter + """ + + def __init__(self, input_size, batch_shape=torch.Size(), bias=True): + super().__init__() + self.dim = input_size + self.register_parameter(name="weights", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size, 1))) + if bias: + self.register_parameter(name="bias", parameter=torch.nn.Parameter(torch.randn(*batch_shape, 1))) + else: + self.bias = None + + def forward(self, x): + res = x.matmul(self.weights) + if self.bias is not None: + res = res + self.bias.unsqueeze(-1) + dres = self.weights.expand(x.transpose(-1, -2).shape).transpose(-1, -2) + ddres = torch.zeros_like(dres) + return torch.cat((res, dres, ddres), -1) diff --git a/test/kernels/test_rbf_kernel_gradgrad.py b/test/kernels/test_rbf_kernel_gradgrad.py new file mode 100644 index 000000000..3e01357d5 --- /dev/null +++ b/test/kernels/test_rbf_kernel_gradgrad.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 + +import unittest + +import torch + +from gpytorch.kernels import RBFKernelGradGrad +from gpytorch.test.base_kernel_test_case import BaseKernelTestCase + + +class TestRBFKernelGradGrad(unittest.TestCase, BaseKernelTestCase): + def create_kernel_no_ard(self, **kwargs): + return RBFKernelGradGrad(**kwargs) + + def create_kernel_ard(self, num_dims, **kwargs): + return RBFKernelGradGrad(ard_num_dims=num_dims, **kwargs) + + def test_kernel(self, cuda=False): + a = torch.tensor([[[1, 2], [2, 4]]], dtype=torch.float) + b = torch.tensor([[[1, 3], [0, 4]]], dtype=torch.float) + + actual = torch.tensor( + [ + [ + [ + 3.5321289e-01, + 0.0000000e00, + -7.3516625e-01, + -7.3516631e-01, + 7.9498571e-01, + 5.4977159e-03, + 1.1442775e-02, + -2.2885550e-02, + 1.2373861e-02, + 8.3823770e-02, + ], + [ + -0.0000000e00, + 7.3516631e-01, + 0.0000000e00, + 0.0000000e00, + -0.0000000e00, + -1.1442775e-02, + -1.2373861e-02, + 4.7633272e-02, + 2.1878703e-02, + -1.7446819e-01, + ], + [ + 7.3516625e-01, + 0.0000000e00, + -7.9498571e-01, + -1.5301522e00, + -1.4056460e00, + 2.2885550e-02, + 4.7633272e-02, + -8.3823770e-02, + 5.1509142e-02, + 2.5366980e-01, + ], + [ + -7.3516631e-01, + -0.0000000e00, + 1.5301522e00, + 4.5904574e00, + -1.6546586e00, + 1.2373861e-02, + -2.1878703e-02, + -5.1509142e-02, + -1.2280136e-01, + 1.8866448e-01, + ], + [ + 7.9498571e-01, + 0.0000000e00, + 1.4056460e00, + -1.6546586e00, + -7.8896437e00, + 8.3823770e-02, + 1.7446819e-01, + -2.5366980e-01, + 1.8866447e-01, + 5.3255635e-01, + ], + [ + 1.2475928e-01, + 2.5967008e-01, + 2.5967011e-01, + 2.8079915e-01, + 2.8079927e-01, + 1.5564885e-02, + 6.4792536e-02, + 0.0000000e00, + 2.3731807e-01, + -3.2396268e-02, + ], + [ + -2.5967008e-01, + -2.8079915e-01, + -5.4046929e-01, + 4.9649185e-01, + -5.8444691e-01, + -6.4792536e-02, + -2.3731807e-01, + 0.0000000e00, + -7.1817851e-01, + 1.3485716e-01, + ], + [ + -2.5967011e-01, + -5.4046929e-01, + -2.8079927e-01, + -5.8444673e-01, + 4.9649167e-01, + -0.0000000e00, + 0.0000000e00, + 3.2396268e-02, + 0.0000000e00, + 0.0000000e00, + ], + [ + 2.8079915e-01, + -4.9649185e-01, + 5.8444673e-01, + -2.7867227e00, + 6.3200271e-01, + 2.3731807e-01, + 7.1817851e-01, + 0.0000000e00, + 1.5077497e00, + -4.9394643e-01, + ], + [ + 2.8079927e-01, + 5.8444691e-01, + -4.9649167e-01, + 6.3200271e-01, + -2.7867231e00, + -3.2396268e-02, + -1.3485716e-01, + -0.0000000e00, 
+ -4.9394643e-01, + 2.0228577e-01, + ], + ] + ] + ) + + kernel = RBFKernelGradGrad() + + if cuda: + a = a.cuda() + b = b.cuda() + actual = actual.cuda() + kernel = kernel.cuda() + + res = kernel(a, b).to_dense() + + self.assertLess(torch.norm(res - actual), 1e-5) + + def test_kernel_cuda(self): + if torch.cuda.is_available(): + self.test_kernel(cuda=True) + + def test_kernel_batch(self): + a = torch.tensor([[[1, 2, 3], [2, 4, 0]], [[-1, 1, 2], [2, 1, 4]]], dtype=torch.float) + b = torch.tensor([[[1, 3, 1]], [[2, -1, 0]]], dtype=torch.float).repeat(1, 2, 1) + + kernel = RBFKernelGradGrad() + res = kernel(a, b).to_dense() + + # Compute each batch separately + actual = torch.zeros(2, 14, 14) + actual[0, :, :] = kernel(a[0, :, :].squeeze(), b[0, :, :].squeeze()).to_dense() + actual[1, :, :] = kernel(a[1, :, :].squeeze(), b[1, :, :].squeeze()).to_dense() + + self.assertLess(torch.norm(res - actual), 1e-5) + + def test_initialize_lengthscale(self): + kernel = RBFKernelGradGrad() + kernel.initialize(lengthscale=3.14) + actual_value = torch.tensor(3.14).view_as(kernel.lengthscale) + self.assertLess(torch.norm(kernel.lengthscale - actual_value), 1e-5) + + def test_initialize_lengthscale_batch(self): + kernel = RBFKernelGradGrad(batch_shape=torch.Size([2])) + ls_init = torch.tensor([3.14, 4.13]) + kernel.initialize(lengthscale=ls_init) + actual_value = ls_init.view_as(kernel.lengthscale) + self.assertLess(torch.norm(kernel.lengthscale - actual_value), 1e-5) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/means/test_constant_mean_gradgrad.py b/test/means/test_constant_mean_gradgrad.py new file mode 100644 index 000000000..3008850df --- /dev/null +++ b/test/means/test_constant_mean_gradgrad.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 + +import unittest + +import torch + +from gpytorch.means import ConstantMeanGradGrad +from gpytorch.test.base_mean_test_case import BaseMeanTestCase + + +class TestConstantMeanGradGrad(BaseMeanTestCase, unittest.TestCase): + batch_shape = None + + def create_mean(self): + return ConstantMeanGradGrad(batch_shape=self.__class__.batch_shape or torch.Size()) + + def test_forward_vec(self): + test_x = torch.randn(4) + mean = self.create_mean() + if self.__class__.batch_shape is None: + self.assertEqual(mean(test_x).shape, torch.Size([4, 3])) + else: + self.assertEqual(mean(test_x).shape, torch.Size([*self.__class__.batch_shape, 4, 3])) + self.assertEqual(mean(test_x)[..., 1:].norm().item(), 0) + + def test_forward_mat(self): + test_x = torch.randn(4, 3) + mean = self.create_mean() + if self.__class__.batch_shape is None: + self.assertEqual(mean(test_x).shape, torch.Size([4, 7])) + else: + self.assertEqual(mean(test_x).shape, torch.Size([*self.__class__.batch_shape, 4, 7])) + self.assertEqual(mean(test_x)[..., 1:].norm().item(), 0) + + def test_forward_mat_batch(self): + test_x = torch.randn(3, 4, 3) + mean = self.create_mean() + if self.__class__.batch_shape is None: + self.assertEqual(mean(test_x).shape, torch.Size([3, 4, 7])) + else: + self.assertEqual(mean(test_x).shape, torch.Size([*self.__class__.batch_shape, 4, 7])) + self.assertEqual(mean(test_x)[..., 1:].norm().item(), 0) + + def test_forward_mat_multi_batch(self): + test_x = torch.randn(2, 3, 4, 3) + mean = self.create_mean() + self.assertEqual(mean(test_x).shape, torch.Size([2, 3, 4, 7])) + self.assertEqual(mean(test_x)[..., 1:].norm().item(), 0) + + +class TestConstantMeanGradGradBatch(TestConstantMeanGradGrad): + batch_shape = torch.Size([3]) + + +class 
TestConstantMeanGradGradMultiBatch(TestConstantMeanGradGrad): + batch_shape = torch.Size([2, 3]) diff --git a/test/means/test_linear_mean_grad.py b/test/means/test_linear_mean_grad.py new file mode 100644 index 000000000..3846b3684 --- /dev/null +++ b/test/means/test_linear_mean_grad.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 + +import unittest + +import torch + +from gpytorch.means import LinearMeanGrad +from gpytorch.test.base_mean_test_case import BaseMeanTestCase + + +class TestLinearMeanGrad(BaseMeanTestCase, unittest.TestCase): + def create_mean(self, input_size=1, batch_shape=torch.Size(), bias=True, **kwargs): + return LinearMeanGrad(input_size=input_size, batch_shape=batch_shape, bias=bias) + + def test_forward_vec(self): + n = 4 + test_x = torch.randn(n) + mean = self.create_mean(input_size=1) + self.assertEqual(mean(test_x).shape, torch.Size([n, 2])) + + def test_forward_mat(self): + n, d = 4, 5 + test_x = torch.randn(n, d) + mean = self.create_mean(d) + self.assertEqual(mean(test_x).shape, torch.Size([n, d + 1])) + + def test_forward_mat_batch(self): + b, n, d = torch.Size([3]), 4, 5 + test_x = torch.randn(*b, n, d) + mean = self.create_mean(d, b) + self.assertEqual(mean(test_x).shape, torch.Size([*b, n, d + 1])) + + def test_forward_mat_multi_batch(self): + b, n, d = torch.Size([2, 3]), 4, 5 + test_x = torch.randn(*b, n, d) + mean = self.create_mean(d, b) + self.assertEqual(mean(test_x).shape, torch.Size([*b, n, d + 1])) diff --git a/test/means/test_linear_mean_gradgrad.py b/test/means/test_linear_mean_gradgrad.py new file mode 100644 index 000000000..993950770 --- /dev/null +++ b/test/means/test_linear_mean_gradgrad.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 + +import unittest + +import torch + +from gpytorch.means import LinearMeanGradGrad +from gpytorch.test.base_mean_test_case import BaseMeanTestCase + + +class TestLinearMeanGradGrad(BaseMeanTestCase, unittest.TestCase): + def create_mean(self, input_size=1, batch_shape=torch.Size(), bias=True, **kwargs): + return LinearMeanGradGrad(input_size=input_size, batch_shape=batch_shape, bias=bias) + + def test_forward_vec(self): + n = 4 + test_x = torch.randn(n) + mean = self.create_mean(input_size=1) + self.assertEqual(mean(test_x).shape, torch.Size([n, 3])) + + def test_forward_mat(self): + n, d = 4, 5 + test_x = torch.randn(n, d) + mean = self.create_mean(d) + self.assertEqual(mean(test_x).shape, torch.Size([n, 2 * d + 1])) + + def test_forward_mat_batch(self): + b, n, d = torch.Size([3]), 4, 5 + test_x = torch.randn(*b, n, d) + mean = self.create_mean(d, b) + self.assertEqual(mean(test_x).shape, torch.Size([*b, n, 2 * d + 1])) + + def test_forward_mat_multi_batch(self): + b, n, d = torch.Size([2, 3]), 4, 5 + test_x = torch.randn(*b, n, d) + mean = self.create_mean(d, b) + self.assertEqual(mean(test_x).shape, torch.Size([*b, n, 2 * d + 1]))
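
By analogy with the existing derivative-GP pattern used with RBFKernelGrad and ConstantMeanGrad, the following is a minimal sketch (not part of the patch above) of how the new kernel and mean classes could be combined in an exact GP over values, first derivatives, and second non-mixed derivatives. The GPModelWithSecondDerivatives name and the training-data shapes are illustrative assumptions; training targets are assumed to stack, for each input point, the function value with the d first and d second (non-mixed) partial derivatives, giving 2*d + 1 tasks per point.

    # Sketch only: assumes the RBFKernelGrad-style multitask setup carries over
    # to the second-derivative case added in this change set.
    import torch
    import gpytorch


    class GPModelWithSecondDerivatives(gpytorch.models.ExactGP):  # hypothetical model name
        def __init__(self, train_x, train_y, likelihood):
            super().__init__(train_x, train_y, likelihood)
            d = train_x.size(-1)
            # RBFKernelGradGrad has no outputscale, so wrap it in a ScaleKernel
            self.mean_module = gpytorch.means.LinearMeanGradGrad(input_size=d)
            self.covar_module = gpytorch.kernels.ScaleKernel(
                gpytorch.kernels.RBFKernelGradGrad(ard_num_dims=d)
            )

        def forward(self, x):
            mean_x = self.mean_module(x)    # shape: n x (2d + 1)
            covar_x = self.covar_module(x)  # shape: n(2d + 1) x n(2d + 1)
            return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)


    # Illustrative data: d = 2 inputs, so each of the 10 points carries
    # 2 * d + 1 = 5 targets (value, two first partials, two second partials).
    train_x = torch.rand(10, 2)
    train_y = torch.randn(10, 5)
    likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=5)
    model = GPModelWithSecondDerivatives(train_x, train_y, likelihood)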