Skip to content

Commit 917603c

Browse files
authored
Added ability for priors of transformed distributions to have their parameters saved to and loaded from state dicts (#2551)
1 parent c118306 commit 917603c

File tree

5 files changed

+188
-4
lines changed

5 files changed

+188
-4
lines changed

gpytorch/priors/prior.py

+25
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
#!/usr/bin/env python3
22

33
from abc import ABC
4+
from typing import Any, Mapping
45

6+
from torch.distributions import TransformedDistribution
57
from torch.nn import Module
68

79
from ..distributions import Distribution
10+
from .utils import _load_transformed_to_base_dist
11+
12+
13+
TRANSFORMED_ERROR_MSG = """Priors of TransformedDistributions should not have their \
14+
'_transformed' attributes modified, these are just copies of the base attribute. \
15+
Please modify the base attribute (e.g. {}) instead."""
816

917

1018
class Prior(Distribution, Module, ABC):
@@ -25,3 +33,20 @@ def log_prob(self, x):
2533
:rtype: torch.Tensor
2634
"""
2735
return super(Prior, self).log_prob(self.transform(x))
36+
37+
def load_state_dict(self, state_dict: Mapping[str, Any], *args, **kwargs):
    """Load parameters from ``state_dict`` into this prior.

    :param state_dict: mapping of parameter/buffer names to tensors.
    :return: the ``NamedTuple`` of ``missing_keys``/``unexpected_keys``
        produced by ``torch.nn.Module.load_state_dict`` (previously this
        result was silently discarded).
    """
    result = Module.load_state_dict(self, state_dict, *args, **kwargs)
    # TransformedDistribution parameters live on base_dist and are only
    # mirrored as "_transformed_" buffers; copy the loaded values back.
    if isinstance(self, TransformedDistribution):
        _load_transformed_to_base_dist(self)
    return result
41+
42+
def __setattr__(self, name: str, value: Any) -> None:
    """Keep ``_transformed_<attr>`` buffers in sync with ``base_dist``.

    Writes to a ``_transformed_`` mirror attribute are rejected; writes to
    a mirrored base attribute are forwarded to ``base_dist`` and to the
    mirror buffer. All other attributes are set normally.
    """
    prefix = "_transformed_"
    if hasattr(self, name) and name.startswith(prefix):
        # Mirror buffers are copies of the base attribute; mutating them
        # directly would desynchronize them from base_dist. (Prefix match,
        # not substring containment, so unrelated names are never caught.)
        raise AttributeError(TRANSFORMED_ERROR_MSG.format(name[len(prefix):]))

    elif hasattr(self, f"{prefix}{name}"):
        # Update the base distribution and its mirrored buffer together.
        self.base_dist.__setattr__(name, value)
        super().__setattr__(f"{prefix}{name}", value)

    else:
        return super().__setattr__(name, value)

gpytorch/priors/torch_priors.py

+3
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class HalfNormalPrior(Prior, HalfNormal):
4040
def __init__(self, scale, validate_args=None, transform=None):
4141
TModule.__init__(self)
4242
HalfNormal.__init__(self, scale=scale, validate_args=validate_args)
43+
_bufferize_attributes(self, ("scale",))
4344
self._transform = transform
4445

4546
def expand(self, batch_shape):
@@ -54,6 +55,7 @@ class LogNormalPrior(Prior, LogNormal):
5455
def __init__(self, loc, scale, validate_args=None, transform=None):
5556
TModule.__init__(self)
5657
LogNormal.__init__(self, loc=loc, scale=scale, validate_args=validate_args)
58+
_bufferize_attributes(self, ("loc", "scale"))
5759
self._transform = transform
5860

5961
def expand(self, batch_shape):
@@ -84,6 +86,7 @@ class HalfCauchyPrior(Prior, HalfCauchy):
8486
def __init__(self, scale, validate_args=None, transform=None):
8587
TModule.__init__(self)
8688
HalfCauchy.__init__(self, scale=scale, validate_args=validate_args)
89+
_bufferize_attributes(self, ("scale",))
8790
self._transform = transform
8891

8992
def expand(self, batch_shape):

gpytorch/priors/utils.py

+29-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,36 @@
11
#!/usr/bin/env python3
22

3+
from torch.distributions import TransformedDistribution
4+
35

46
def _bufferize_attributes(module, attributes):
5-
attr_clones = {attr: getattr(module, attr).clone() for attr in attributes}
6-
for attr, value in attr_clones.items():
7-
delattr(module, attr)
8-
module.register_buffer(attr, value)
7+
r"""
8+
Adds the parameters of the prior as a torch buffer to enable saving/
9+
loading to/from state_dicts.
10+
For TransformedDistributions Adds a _transformed_ attribute to the
11+
parameters. This enables its parameters to be saved and
12+
loaded to/from state_dicts, as the original parameters cannot be.
13+
"""
14+
if isinstance(module, TransformedDistribution):
15+
for attr in attributes:
16+
module.register_buffer(f"_transformed_{attr}", getattr(module, attr))
17+
else:
18+
attr_clones = {attr: getattr(module, attr).clone() for attr in attributes}
19+
for attr, value in attr_clones.items():
20+
delattr(module, attr)
21+
module.register_buffer(attr, value)
22+
23+
24+
def _load_transformed_to_base_dist(module):
25+
r"""loads the _transformed_ attributes to the parameters of a torch
26+
TransformedDistribution. This enables its parameters to be saved and
27+
loaded to/from state_dicts, as the original parameters cannot be.
28+
"""
29+
transf_str = "_transformed_"
30+
transformed_attrs = [attr for attr in dir(module) if transf_str in attr]
31+
for transf_attr in transformed_attrs:
32+
base_attr_name = transf_attr.replace(transf_str, "")
33+
setattr(module.base_dist, base_attr_name, getattr(module, transf_attr))
934

1035

1136
def _del_attributes(module, attributes, raise_on_error=False):

test/priors/test_prior.py

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#!/usr/bin/env python3
2+
3+
import unittest
4+
5+
from torch import Tensor
6+
7+
from gpytorch.priors import GammaPrior, HalfCauchyPrior, LogNormalPrior, NormalPrior
8+
9+
10+
TRANSFORMED_ERROR_MSG = """Priors of TransformedDistributions should not have their \
11+
'_transformed' attributes modified, these are just copies of the base attribute. \
12+
Please modify the base attribute (e.g. {}) instead."""
13+
14+
15+
class TestPrior(unittest.TestCase):
    def test_state_dict(self):
        # Plain priors expose their parameters directly in the state dict.
        normal = NormalPrior(0.1, 1).state_dict()
        self.assertTrue("loc" in normal)
        self.assertTrue("scale" in normal)
        self.assertEqual(normal["loc"], 0.1)

        gamma = GammaPrior(1.1, 2).state_dict()
        self.assertTrue("concentration" in gamma)
        self.assertTrue("rate" in gamma)
        self.assertEqual(gamma["concentration"], 1.1)

        # TransformedDistribution priors expose "_transformed_" mirrors.
        ln = LogNormalPrior(2.1, 1.2).state_dict()
        self.assertTrue("_transformed_loc" in ln)
        self.assertTrue("_transformed_scale" in ln)
        self.assertEqual(ln["_transformed_loc"], 2.1)

        hc = HalfCauchyPrior(1.3).state_dict()
        self.assertTrue("_transformed_scale" in hc)

    def test_load_state_dict(self):
        ln1 = LogNormalPrior(loc=0.5, scale=0.1)
        ln2 = LogNormalPrior(loc=2.5, scale=2.1)
        gm1 = GammaPrior(concentration=0.5, rate=0.1)
        gm2 = GammaPrior(concentration=2.5, rate=2.1)
        hc1 = HalfCauchyPrior(scale=1.1)
        hc2 = HalfCauchyPrior(scale=101.1)

        ln2.load_state_dict(ln1.state_dict())
        self.assertEqual(ln2.loc, ln1.loc)
        self.assertEqual(ln2.scale, ln1.scale)

        gm2.load_state_dict(gm1.state_dict())
        self.assertEqual(gm2.concentration, gm1.concentration)
        self.assertEqual(gm2.rate, gm1.rate)

        hc2.load_state_dict(hc1.state_dict())
        self.assertEqual(hc2.scale, hc1.scale)

    def test_transformed_attributes(self):
        norm = NormalPrior(loc=2.5, scale=2.1)
        ln = LogNormalPrior(loc=2.5, scale=2.1)
        hc = HalfCauchyPrior(scale=2.2)

        # Non-transformed priors must not grow "_transformed_" mirrors.
        with self.assertRaisesRegex(AttributeError, "'NormalPrior' object has no attribute '_transformed_loc'"):
            getattr(norm, "_transformed_loc")

        # Was assertTrue(value, 2.5), which passed 2.5 as the failure
        # message and asserted nothing about the value.
        self.assertEqual(getattr(ln, "_transformed_loc"), 2.5)
        norm.loc = Tensor([1.01])
        ln.loc = Tensor([1.01])
        self.assertEqual(ln._transformed_loc, 1.01)
        # Direct writes to the mirror attribute must be rejected.
        with self.assertRaises(AttributeError):
            ln._transformed_loc = 1.1

        with self.assertRaises(AttributeError):
            hc._transformed_scale = 1.01

test/priors/test_utils.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#!/usr/bin/env python3
2+
3+
import unittest
4+
5+
from torch import Tensor
6+
7+
from gpytorch.priors import GammaPrior, HalfCauchyPrior, LogNormalPrior, NormalPrior
8+
9+
10+
class TestPrior(unittest.TestCase):
    def test_state_dict(self):
        # Plain priors expose their parameters directly in the state dict.
        normal = NormalPrior(0.1, 1).state_dict()
        self.assertTrue("loc" in normal)
        self.assertTrue("scale" in normal)
        self.assertEqual(normal["loc"], 0.1)

        gamma = GammaPrior(1.1, 2).state_dict()
        self.assertTrue("concentration" in gamma)
        self.assertTrue("rate" in gamma)
        self.assertEqual(gamma["concentration"], 1.1)

        # TransformedDistribution priors expose "_transformed_" mirrors.
        ln = LogNormalPrior(2.1, 1.2).state_dict()
        self.assertTrue("_transformed_loc" in ln)
        self.assertTrue("_transformed_scale" in ln)
        self.assertEqual(ln["_transformed_loc"], 2.1)

        hc = HalfCauchyPrior(1.3).state_dict()
        self.assertTrue("_transformed_scale" in hc)

    def test_load_state_dict(self):
        ln1 = LogNormalPrior(loc=0.5, scale=0.1)
        ln2 = LogNormalPrior(loc=2.5, scale=2.1)
        gm1 = GammaPrior(concentration=0.5, rate=0.1)
        gm2 = GammaPrior(concentration=2.5, rate=2.1)
        hc1 = HalfCauchyPrior(scale=1.1)
        hc2 = HalfCauchyPrior(scale=101.1)

        ln2.load_state_dict(ln1.state_dict())
        self.assertEqual(ln2.loc, ln1.loc)
        self.assertEqual(ln2.scale, ln1.scale)

        gm2.load_state_dict(gm1.state_dict())
        self.assertEqual(gm2.concentration, gm1.concentration)
        self.assertEqual(gm2.rate, gm1.rate)

        hc2.load_state_dict(hc1.state_dict())
        self.assertEqual(hc2.scale, hc1.scale)

    def test_transformed_attributes(self):
        norm = NormalPrior(loc=2.5, scale=2.1)
        ln = LogNormalPrior(loc=2.5, scale=2.1)
        hc = HalfCauchyPrior(scale=2.2)

        # Non-transformed priors must not grow "_transformed_" mirrors.
        with self.assertRaisesRegex(AttributeError, "'NormalPrior' object has no attribute '_transformed_loc'"):
            getattr(norm, "_transformed_loc")

        # Was assertTrue(value, 2.5), which passed 2.5 as the failure
        # message and asserted nothing about the value.
        self.assertEqual(getattr(ln, "_transformed_loc"), 2.5)
        norm.loc = Tensor([1.01])
        ln.loc = Tensor([1.01])
        self.assertEqual(ln._transformed_loc, 1.01)
        self.assertEqual(hc._transformed_scale, 2.2)

0 commit comments

Comments
 (0)