Add a HalfCauchyPrior #1961


Merged: 4 commits, Apr 8, 2022
6 changes: 6 additions & 0 deletions docs/source/priors.rst
@@ -24,6 +24,12 @@ Standard Priors
.. autoclass:: GammaPrior
:members:

:hidden:`HalfCauchyPrior`
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: HalfCauchyPrior
:members:

:hidden:`LKJCovariancePrior`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

10 changes: 9 additions & 1 deletion gpytorch/priors/__init__.py
@@ -4,14 +4,22 @@
from .lkj_prior import LKJCholeskyFactorPrior, LKJCovariancePrior, LKJPrior
from .prior import Prior
from .smoothed_box_prior import SmoothedBoxPrior
from .torch_priors import GammaPrior, LogNormalPrior, MultivariateNormalPrior, NormalPrior, UniformPrior
from .torch_priors import (
GammaPrior,
HalfCauchyPrior,
LogNormalPrior,
MultivariateNormalPrior,
NormalPrior,
UniformPrior,
)

# from .wishart_prior import InverseWishartPrior, WishartPrior


__all__ = [
"Prior",
"GammaPrior",
"HalfCauchyPrior",
"HorseshoePrior",
"LKJPrior",
"LKJCholeskyFactorPrior",
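With the export above in place, the prior is importable from gpytorch.priors. A quick sanity check, hypothetical and not part of this PR:

import torch
from gpytorch.priors import HalfCauchyPrior

# The half-Cauchy has support on the positive reals, with density
# p(x) = 2 / (pi * scale * (1 + (x / scale)**2)) for x >= 0.
prior = HalfCauchyPrior(scale=0.5)
x = torch.tensor([0.1, 1.0, 10.0])
print(prior.log_prob(x))  # three finite log-densities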
17 changes: 16 additions & 1 deletion gpytorch/priors/torch_priors.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3

import torch
from torch.distributions import Gamma, LogNormal, MultivariateNormal, Normal, Uniform
from torch.distributions import Gamma, HalfCauchy, LogNormal, MultivariateNormal, Normal, Uniform
from torch.nn import Module as TModule

from .prior import Prior
@@ -60,6 +60,21 @@ def expand(self, batch_shape):
return UniformPrior(self.low.expand(batch_shape), self.high.expand(batch_shape))


class HalfCauchyPrior(Prior, HalfCauchy):
"""
Half-Cauchy prior.
"""

def __init__(self, scale, validate_args=None, transform=None):
TModule.__init__(self)
HalfCauchy.__init__(self, scale=scale, validate_args=validate_args)
self._transform = transform

def expand(self, batch_shape):
batch_shape = torch.Size(batch_shape)
# HalfCauchy is parameterized by scale only (no loc), and expand should return a prior.
return HalfCauchyPrior(self.scale.expand(batch_shape))


class GammaPrior(Prior, Gamma):
"""Gamma Prior parameterized by concentration and rate

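For context, a minimal usage sketch (not part of the diff) of attaching the new prior to kernel hyperparameters; the kernel and scale choices here are illustrative assumptions:

from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.priors import HalfCauchyPrior

# Half-Cauchy priors suit positive hyperparameters such as lengthscale and
# outputscale; registered priors contribute to the marginal log likelihood
# during training.
kernel = ScaleKernel(
    RBFKernel(lengthscale_prior=HalfCauchyPrior(scale=1.0)),
    outputscale_prior=HalfCauchyPrior(scale=2.0),
)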
87 changes: 87 additions & 0 deletions test/priors/test_half_cauchy_prior.py
@@ -0,0 +1,87 @@
#!/usr/bin/env python3

import unittest

import torch
from torch.distributions import HalfCauchy

from gpytorch.priors import HalfCauchyPrior
from gpytorch.test.utils import least_used_cuda_device


class TestHalfCauchyPrior(unittest.TestCase):
def test_half_cauchy_prior_to_gpu(self):
if torch.cuda.is_available():
prior = HalfCauchyPrior(1.0).cuda()
self.assertEqual(prior.scale.device.type, "cuda")

def test_half_cauchy_prior_validate_args(self):
with self.assertRaises(ValueError):
HalfCauchyPrior(-1, validate_args=True)

def test_half_cauchy_prior_log_prob(self, cuda=False):
device = torch.device("cuda") if cuda else torch.device("cpu")
prior = HalfCauchyPrior(0.1)
dist = HalfCauchy(0.1)

t = torch.tensor(1.0, device=device)
self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t)))
t = torch.tensor([1.5, 0.5], device=device)
self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t)))
t = torch.tensor([[1.0, 0.5], [3.0, 0.25]], device=device)
self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t)))

def test_half_cauchy_prior_log_prob_cuda(self):
if torch.cuda.is_available():
with least_used_cuda_device():
return self.test_half_cauchy_prior_log_prob(cuda=True)

def test_half_cauchy_prior_log_prob_log_transform(self, cuda=False):
device = torch.device("cuda") if cuda else torch.device("cpu")
prior = HalfCauchyPrior(0.1, transform=torch.exp)
dist = HalfCauchy(0.1)

t = torch.tensor(0.0, device=device)
self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t.exp())))
t = torch.tensor([-1, 0.5], device=device)
self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t.exp())))
t = torch.tensor([[-1, 0.5], [0.1, -2.0]], device=device)
self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t.exp())))

def test_half_cauchy_prior_log_prob_log_transform_cuda(self):
if torch.cuda.is_available():
with least_used_cuda_device():
return self.test_half_cauchy_prior_log_prob_log_transform(cuda=True)

def test_half_cauchy_prior_batch_log_prob(self, cuda=False):
device = torch.device("cuda") if cuda else torch.device("cpu")
prior = HalfCauchyPrior(0.1)
dist = HalfCauchy(0.1)
t = torch.ones(2, device=device)
self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t)))
t = torch.ones(2, 2, device=device)
self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t)))

scale = torch.tensor([0.1, 1.0], device=device)
prior = HalfCauchyPrior(scale)
dist = HalfCauchy(scale)
t = torch.ones(2, device=device)
self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t)))
t = torch.ones(2, 2, device=device)
self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t)))
with self.assertRaises(ValueError):
prior.log_prob(torch.ones(3, device=device))
with self.assertRaises(ValueError):
prior.log_prob(torch.ones(2, 3, device=device))

def test_half_cauchy_prior_batch_log_prob_cuda(self):
if torch.cuda.is_available():
with least_used_cuda_device():
return self.test_half_cauchy_prior_batch_log_prob(cuda=True)


if __name__ == "__main__":
unittest.main()
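The log-transform tests above pin down the transform semantics: log_prob is evaluated at transform(x), so a raw parameter on the real line can carry a half-Cauchy prior on its positive-valued transform. A minimal sketch, assuming torch.exp as the transform:

import torch
from torch.distributions import HalfCauchy
from gpytorch.priors import HalfCauchyPrior

# With transform=torch.exp, prior.log_prob(raw) == HalfCauchy.log_prob(exp(raw)).
prior = HalfCauchyPrior(0.1, transform=torch.exp)
raw = torch.tensor(-1.0)  # raw (log-space) parameter value
assert torch.equal(prior.log_prob(raw), HalfCauchy(0.1).log_prob(raw.exp()))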