diff --git a/docs/source/priors.rst b/docs/source/priors.rst index 1dcd55374..ce5647c46 100644 --- a/docs/source/priors.rst +++ b/docs/source/priors.rst @@ -24,6 +24,12 @@ Standard Priors .. autoclass:: GammaPrior :members: +:hidden:`HalfCauchyPrior` +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: HalfCauchyPrior + :members: + :hidden:`LKJCovariancePrior` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/gpytorch/priors/__init__.py b/gpytorch/priors/__init__.py index 24c259a83..89e68acbf 100644 --- a/gpytorch/priors/__init__.py +++ b/gpytorch/priors/__init__.py @@ -4,7 +4,14 @@ from .lkj_prior import LKJCholeskyFactorPrior, LKJCovariancePrior, LKJPrior from .prior import Prior from .smoothed_box_prior import SmoothedBoxPrior -from .torch_priors import GammaPrior, LogNormalPrior, MultivariateNormalPrior, NormalPrior, UniformPrior +from .torch_priors import ( + GammaPrior, + HalfCauchyPrior, + LogNormalPrior, + MultivariateNormalPrior, + NormalPrior, + UniformPrior, +) # from .wishart_prior import InverseWishartPrior, WishartPrior @@ -12,6 +19,7 @@ __all__ = [ "Prior", "GammaPrior", + "HalfCauchyPrior", "HorseshoePrior", "LKJPrior", "LKJCholeskyFactorPrior", diff --git a/gpytorch/priors/torch_priors.py b/gpytorch/priors/torch_priors.py index 087f981dc..9df86d8d0 100644 --- a/gpytorch/priors/torch_priors.py +++ b/gpytorch/priors/torch_priors.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import torch -from torch.distributions import Gamma, LogNormal, MultivariateNormal, Normal, Uniform +from torch.distributions import Gamma, HalfCauchy, LogNormal, MultivariateNormal, Normal, Uniform from torch.nn import Module as TModule from .prior import Prior @@ -60,6 +60,20 @@ def expand(self, batch_shape): return UniformPrior(self.low.expand(batch_shape), self.high.expand(batch_shape)) +class HalfCauchyPrior(Prior, HalfCauchy): + """ + Half-Cauchy prior. + """ + + def __init__(self, scale, validate_args=None, transform=None): + TModule.__init__(self) + HalfCauchy.__init__(self, scale=scale, validate_args=validate_args) + self._transform = transform + + def expand(self, batch_shape): + return HalfCauchy(self.loc.expand(batch_shape), self.scale.expand(batch_shape)) + + class GammaPrior(Prior, Gamma): """Gamma Prior parameterized by concentration and rate diff --git a/test/priors/test_half_cauchy_prior.py b/test/priors/test_half_cauchy_prior.py new file mode 100644 index 000000000..f6d7d9fa5 --- /dev/null +++ b/test/priors/test_half_cauchy_prior.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 + +import unittest + +import torch +from torch.distributions import HalfCauchy + +from gpytorch.priors import HalfCauchyPrior +from gpytorch.test.utils import least_used_cuda_device + + +class TestHalfCauchyPrior(unittest.TestCase): + def test_half_cauchy_prior_to_gpu(self): + if torch.cuda.is_available(): + prior = HalfCauchy(1.0).cuda() + self.assertEqual(prior.concentration.device.type, "cuda") + self.assertEqual(prior.rate.device.type, "cuda") + + def test_half_cauchy_prior_validate_args(self): + with self.assertRaises(ValueError): + HalfCauchyPrior(-1, validate_args=True) + with self.assertRaises(ValueError): + HalfCauchyPrior(-1, validate_args=True) + + def test_half_cauchy_prior_log_prob(self, cuda=False): + device = torch.device("cuda") if cuda else torch.device("cpu") + prior = HalfCauchyPrior(0.1) + dist = HalfCauchy(0.1) + + t = torch.tensor(1.0, device=device) + self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t))) + t = torch.tensor([1.5, 0.5], device=device) + self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t))) + t = torch.tensor([[1.0, 0.5], [3.0, 0.25]], device=device) + self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t))) + + def test_half_cauchy_prior_log_prob_cuda(self): + if torch.cuda.is_available(): + with least_used_cuda_device(): + return self.test_gamma_prior_log_prob(cuda=True) + + def test_half_cauchy_prior_log_prob_log_transform(self, cuda=False): + device = torch.device("cuda") if cuda else torch.device("cpu") + prior = HalfCauchyPrior(0.1, transform=torch.exp) + dist = HalfCauchy(0.1) + + t = torch.tensor(0.0, device=device) + self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t.exp()))) + t = torch.tensor([-1, 0.5], device=device) + self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t.exp()))) + t = torch.tensor([[-1, 0.5], [0.1, -2.0]], device=device) + self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t.exp()))) + + def test_half_cauchy_prior_log_prob_log_transform_cuda(self): + if torch.cuda.is_available(): + with least_used_cuda_device(): + return self.test_half_cauchy_prior_log_prob_log_transform(cuda=True) + + def test_half_cauchy_prior_batch_log_prob(self, cuda=False): + device = torch.device("cuda") if cuda else torch.device("cpu") + prior = HalfCauchyPrior(0.1) + dist = HalfCauchy(0.1) + t = torch.ones(2, device=device) + self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t))) + t = torch.ones(2, 2, device=device) + self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t))) + + scale = torch.tensor([0.1, 1.0], device=device) + prior = HalfCauchyPrior(scale) + dist = HalfCauchy(scale) + t = torch.ones(2, device=device) + self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t))) + t = torch.ones(2, 2, device=device) + self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t))) + with self.assertRaises(ValueError): + prior.log_prob(torch.ones(3, device=device)) + with self.assertRaises(ValueError): + prior.log_prob(torch.ones(2, 3, device=device)) + + def test_half_cauchy_prior_batch_log_prob_cuda(self): + if torch.cuda.is_available(): + with least_used_cuda_device(): + return self.test_half_cauchy_prior_batch_log_prob(cuda=True) + + +if __name__ == "__main__": + unittest.main()