diff --git a/gpytorch/priors/prior.py b/gpytorch/priors/prior.py
index 1a6e6e1f7..2c5468bf1 100644
--- a/gpytorch/priors/prior.py
+++ b/gpytorch/priors/prior.py
@@ -1,10 +1,18 @@
 #!/usr/bin/env python3
 
 from abc import ABC
+from typing import Any, Mapping
 
+from torch.distributions import TransformedDistribution
 from torch.nn import Module
 
 from ..distributions import Distribution
+from .utils import _load_transformed_to_base_dist
+
+
+TRANSFORMED_ERROR_MSG = """Priors of TransformedDistributions should not have their \
+'_transformed' attributes modified, these are just copies of the base attribute. \
+Please modify the base attribute (e.g. {}) instead."""
 
 
 class Prior(Distribution, Module, ABC):
@@ -25,3 +33,37 @@ def log_prob(self, x):
         :rtype: torch.Tensor
         """
         return super(Prior, self).log_prob(self.transform(x))
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], *args, **kwargs):
+        """Load ``state_dict`` and re-sync the base distribution.
+
+        For ``TransformedDistribution`` priors, the loaded ``_transformed_*``
+        values are copied onto ``base_dist`` after the regular module load.
+
+        :return: The result of ``Module.load_state_dict``, unchanged.
+        """
+        # Hand the result back to the caller instead of dropping it.
+        result = Module.load_state_dict(self, state_dict, *args, **kwargs)
+        if isinstance(self, TransformedDistribution):
+            _load_transformed_to_base_dist(self)
+        return result
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        """Set an attribute while keeping ``_transformed_*`` copies in sync.
+
+        :raises AttributeError: If an existing ``_transformed_*`` attribute
+            is assigned to directly.
+        """
+        prefix = "_transformed_"
+        # Test the prefix, not a substring, so unrelated attributes whose names
+        # merely contain "_transformed_" can still be reassigned.
+        if name.startswith(prefix) and hasattr(self, name):
+            base_attr_name = name[len(prefix) :]
+            raise AttributeError(TRANSFORMED_ERROR_MSG.format(base_attr_name))
+
+        if hasattr(self, f"{prefix}{name}"):
+            # Update the base distribution and the copy stored on this prior.
+            setattr(self.base_dist, name, value)
+            super().__setattr__(f"{prefix}{name}", value)
+        else:
+            super().__setattr__(name, value)
diff --git a/gpytorch/priors/torch_priors.py b/gpytorch/priors/torch_priors.py
index 5e5dd2669..a3e243384 100644
--- a/gpytorch/priors/torch_priors.py
+++ b/gpytorch/priors/torch_priors.py
@@ -40,6 +40,7 @@ class HalfNormalPrior(Prior, HalfNormal):
     def __init__(self, scale, validate_args=None, transform=None):
         TModule.__init__(self)
         HalfNormal.__init__(self, scale=scale, validate_args=validate_args)
+        _bufferize_attributes(self, ("scale",))
         self._transform = transform
 
     def expand(self, batch_shape):
@@ -54,6 +55,7 @@ class LogNormalPrior(Prior, LogNormal):
     def __init__(self, loc, scale, validate_args=None, transform=None):
         TModule.__init__(self)
         LogNormal.__init__(self, loc=loc, scale=scale, validate_args=validate_args)
+        _bufferize_attributes(self, ("loc", "scale"))
         self._transform = transform
 
     def expand(self, batch_shape):
@@ -84,6 +86,7 @@ class HalfCauchyPrior(Prior, HalfCauchy):
     def __init__(self, scale, validate_args=None, transform=None):
         TModule.__init__(self)
         HalfCauchy.__init__(self, scale=scale, validate_args=validate_args)
+        _bufferize_attributes(self, ("scale",))
         self._transform = transform
 
     def expand(self, batch_shape):
diff --git a/gpytorch/priors/utils.py b/gpytorch/priors/utils.py
index 3cfce190e..e4468ab78 100644
--- a/gpytorch/priors/utils.py
+++ b/gpytorch/priors/utils.py
@@ -1,11 +1,38 @@
 #!/usr/bin/env python3
 
+from torch.distributions import TransformedDistribution
+
 
 def _bufferize_attributes(module, attributes):
-    attr_clones = {attr: getattr(module, attr).clone() for attr in attributes}
-    for attr, value in attr_clones.items():
-        delattr(module, attr)
-        module.register_buffer(attr, value)
+    r"""
+    Adds the parameters of the prior as a torch buffer to enable saving/
+    loading to/from state_dicts.
+    For TransformedDistributions, adds a _transformed_ attribute to the
+    parameters. This enables its parameters to be saved and
+    loaded to/from state_dicts, as the original parameters cannot be.
+    """
+    if isinstance(module, TransformedDistribution):
+        for attr in attributes:
+            module.register_buffer(f"_transformed_{attr}", getattr(module, attr))
+    else:
+        attr_clones = {attr: getattr(module, attr).clone() for attr in attributes}
+        for attr, value in attr_clones.items():
+            delattr(module, attr)
+            module.register_buffer(attr, value)
+
+
+def _load_transformed_to_base_dist(module):
+    r"""Loads the _transformed_ attributes to the parameters of a torch
+    TransformedDistribution. This enables its parameters to be saved and
+    loaded to/from state_dicts, as the original parameters cannot be.
+    """
+    transf_str = "_transformed_"
+    # Match and strip the prefix only; a substring test plus str.replace would
+    # also pick up (and mangle) names that contain the marker elsewhere.
+    transformed_attrs = [attr for attr in dir(module) if attr.startswith(transf_str)]
+    for transf_attr in transformed_attrs:
+        base_attr_name = transf_attr[len(transf_str) :]
+        setattr(module.base_dist, base_attr_name, getattr(module, transf_attr))
 
 
 def _del_attributes(module, attributes, raise_on_error=False):
diff --git a/test/priors/test_prior.py b/test/priors/test_prior.py
new file mode 100644
index 000000000..53fef5976
--- /dev/null
+++ b/test/priors/test_prior.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+
+import unittest
+
+from torch import Tensor
+
+from gpytorch.priors import GammaPrior, HalfCauchyPrior, LogNormalPrior, NormalPrior
+
+
+TRANSFORMED_ERROR_MSG = """Priors of TransformedDistributions should not have their \
+'_transformed' attributes modified, these are just copies of the base attribute. \
+Please modify the base attribute (e.g. {}) instead."""
+
+
+class TestPrior(unittest.TestCase):
+    def test_state_dict(self):
+        normal = NormalPrior(0.1, 1).state_dict()
+        self.assertTrue("loc" in normal)
+        self.assertTrue("scale" in normal)
+        self.assertEqual(normal["loc"], 0.1)
+
+        gamma = GammaPrior(1.1, 2).state_dict()
+        self.assertTrue("concentration" in gamma)
+        self.assertTrue("rate" in gamma)
+        self.assertEqual(gamma["concentration"], 1.1)
+
+        ln = LogNormalPrior(2.1, 1.2).state_dict()
+        self.assertTrue("_transformed_loc" in ln)
+        self.assertTrue("_transformed_scale" in ln)
+        self.assertEqual(ln["_transformed_loc"], 2.1)
+
+        hc = HalfCauchyPrior(1.3).state_dict()
+        self.assertTrue("_transformed_scale" in hc)
+
+    def test_load_state_dict(self):
+        ln1 = LogNormalPrior(loc=0.5, scale=0.1)
+        ln2 = LogNormalPrior(loc=2.5, scale=2.1)
+        gm1 = GammaPrior(concentration=0.5, rate=0.1)
+        gm2 = GammaPrior(concentration=2.5, rate=2.1)
+        hc1 = HalfCauchyPrior(scale=1.1)
+        hc2 = HalfCauchyPrior(scale=101.1)
+
+        ln2.load_state_dict(ln1.state_dict())
+        self.assertEqual(ln2.loc, ln1.loc)
+        self.assertEqual(ln2.scale, ln1.scale)
+
+        gm2.load_state_dict(gm1.state_dict())
+        self.assertEqual(gm2.concentration, gm1.concentration)
+        self.assertEqual(gm2.rate, gm1.rate)
+
+        hc2.load_state_dict(hc1.state_dict())
+        self.assertEqual(hc2.scale, hc1.scale)
+
+    def test_transformed_attributes(self):
+        norm = NormalPrior(loc=2.5, scale=2.1)
+        ln = LogNormalPrior(loc=2.5, scale=2.1)
+        hc = HalfCauchyPrior(scale=2.2)
+
+        with self.assertRaisesRegex(AttributeError, "'NormalPrior' object has no attribute '_transformed_loc'"):
+            getattr(norm, "_transformed_loc")
+
+        self.assertEqual(getattr(ln, "_transformed_loc"), 2.5)
+        norm.loc = Tensor([1.01])
+        ln.loc = Tensor([1.01])
+        self.assertEqual(ln._transformed_loc, 1.01)
+        with self.assertRaises(AttributeError):
+            ln._transformed_loc = 1.1
+
+        with self.assertRaises(AttributeError):
+            hc._transformed_scale = 1.01
diff --git a/test/priors/test_utils.py b/test/priors/test_utils.py
new file mode 100644
index 000000000..c62bbaefb
--- /dev/null
+++ b/test/priors/test_utils.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python3
+
+import unittest
+
+from torch import Tensor
+
+from gpytorch.priors import GammaPrior, HalfCauchyPrior, LogNormalPrior, NormalPrior
+
+
+class TestPrior(unittest.TestCase):
+    def test_state_dict(self):
+        normal = NormalPrior(0.1, 1).state_dict()
+        self.assertTrue("loc" in normal)
+        self.assertTrue("scale" in normal)
+        self.assertEqual(normal["loc"], 0.1)
+
+        gamma = GammaPrior(1.1, 2).state_dict()
+        self.assertTrue("concentration" in gamma)
+        self.assertTrue("rate" in gamma)
+        self.assertEqual(gamma["concentration"], 1.1)
+
+        ln = LogNormalPrior(2.1, 1.2).state_dict()
+        self.assertTrue("_transformed_loc" in ln)
+        self.assertTrue("_transformed_scale" in ln)
+        self.assertEqual(ln["_transformed_loc"], 2.1)
+
+        hc = HalfCauchyPrior(1.3).state_dict()
+        self.assertTrue("_transformed_scale" in hc)
+
+    def test_load_state_dict(self):
+        ln1 = LogNormalPrior(loc=0.5, scale=0.1)
+        ln2 = LogNormalPrior(loc=2.5, scale=2.1)
+        gm1 = GammaPrior(concentration=0.5, rate=0.1)
+        gm2 = GammaPrior(concentration=2.5, rate=2.1)
+        hc1 = HalfCauchyPrior(scale=1.1)
+        hc2 = HalfCauchyPrior(scale=101.1)
+
+        ln2.load_state_dict(ln1.state_dict())
+        self.assertEqual(ln2.loc, ln1.loc)
+        self.assertEqual(ln2.scale, ln1.scale)
+
+        gm2.load_state_dict(gm1.state_dict())
+        self.assertEqual(gm2.concentration, gm1.concentration)
+        self.assertEqual(gm2.rate, gm1.rate)
+
+        hc2.load_state_dict(hc1.state_dict())
+        self.assertEqual(hc2.scale, hc1.scale)
+
+    def test_transformed_attributes(self):
+        norm = NormalPrior(loc=2.5, scale=2.1)
+        ln = LogNormalPrior(loc=2.5, scale=2.1)
+        hc = HalfCauchyPrior(scale=2.2)
+
+        with self.assertRaisesRegex(AttributeError, "'NormalPrior' object has no attribute '_transformed_loc'"):
+            getattr(norm, "_transformed_loc")
+
+        self.assertEqual(getattr(ln, "_transformed_loc"), 2.5)
+        norm.loc = Tensor([1.01])
+        ln.loc = Tensor([1.01])
+        self.assertEqual(ln._transformed_loc, 1.01)
+        self.assertEqual(hc._transformed_scale, 2.2)