Skip to content

Commit ba6d8be

Browse files
committed
Add constant_constraint to ConstantMean
[Fixes #2074]
1 parent 2622873 commit ba6d8be

File tree

2 files changed

+163
-16
lines changed

2 files changed

+163
-16
lines changed

gpytorch/means/constant_mean.py

+98-8
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,119 @@
11
#!/usr/bin/env python3
22

3+
import warnings
4+
from typing import Any, Optional
5+
36
import torch
47

8+
from ..constraints import Interval
9+
from ..priors import Prior
510
from ..utils.broadcasting import _mul_broadcast_shape
11+
from ..utils.warnings import OldVersionWarning
612
from .mean import Mean
713

814

15+
def _ensure_updated_strategy_flag_set(
16+
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
17+
):
18+
if prefix + "constant" in state_dict:
19+
constant = state_dict.pop(prefix + "constant")
20+
state_dict[prefix + "raw_constant"] = constant
21+
warnings.warn(
22+
"You have loaded a GP model with a ConstantMean from a previous version of "
23+
"GPyTorch. The mean module parameter `constant` has been renamed to `raw_constant`. "
24+
"We have updated the name of the parameter in your state dict, but we recommend that you "
25+
"re-save your model.",
26+
OldVersionWarning,
27+
)
28+
29+
930
class ConstantMean(Mean):
10-
def __init__(self, prior=None, batch_shape=torch.Size(), **kwargs):
31+
r"""
32+
A (non-zero) constant prior mean function, i.e.:
33+
34+
.. math::
35+
\mu(\mathbf x) = C
36+
37+
where :math:`C` is a learned constant.
38+
39+
:param constant_prior: Prior for constant parameter :math:`C`.
40+
:type constant_prior: ~gpytorch.priors.Prior, optional
41+
:param constant_constraint: Constraint for constant parameter :math:`C`.
42+
:type constant_constraint: ~gpytorch.priors.Interval, optional
43+
:param batch_shape: The batch shape of the learned constant(s) (default: []).
44+
:type batch_shape: torch.Size, optional
45+
46+
:var torch.Tensor constant: :math:`C` parameter
47+
"""
48+
49+
def __init__(
50+
self,
51+
constant_prior: Optional[Prior] = None,
52+
constant_constraint: Optional[Interval] = None,
53+
batch_shape: torch.Size = torch.Size(),
54+
**kwargs: Any,
55+
):
1156
super(ConstantMean, self).__init__()
57+
58+
# Deprecated kwarg
59+
constant_prior_deprecated = kwargs.get("prior")
60+
if constant_prior_deprecated is not None:
61+
if constant_prior is None: # Using the old kwarg for the constant_prior
62+
warnings.warn(
63+
"The kwarg `prior` for ConstantMean has been renamed to `constant_prior`, and will be deprecated.",
64+
DeprecationWarning,
65+
)
66+
constant_prior = constant_prior_deprecated
67+
else: # Weird edge case where someone set both `prior` and `constant_prior`
68+
warnings.warn(
69+
"You have set both the `constant_prior` and the deprecated `prior` arguments for ConstantMean. "
70+
"`prior` is deprecated, and will be ignored.",
71+
DeprecationWarning,
72+
)
73+
74+
# Ensure that old versions of the model still load
75+
self._register_load_state_dict_pre_hook(_ensure_updated_strategy_flag_set)
76+
1277
self.batch_shape = batch_shape
13-
self.register_parameter(name="constant", parameter=torch.nn.Parameter(torch.zeros(*batch_shape, 1)))
14-
if prior is not None:
15-
self.register_prior("mean_prior", prior, self._constant_param, self._constant_closure)
78+
self.register_parameter(name="raw_constant", parameter=torch.nn.Parameter(torch.zeros(*batch_shape, 1)))
79+
if constant_prior is not None:
80+
self.register_prior("mean_prior", constant_prior, self._constant_param, self._constant_closure)
81+
if constant_constraint is not None:
82+
self.register_constraint("raw_constant", constant_constraint)
83+
84+
@property
85+
def constant(self):
86+
return self._constant_param(self)
1687

88+
@constant.setter
89+
def constant(self, value):
90+
self._constant_closure(self, value)
91+
92+
# We need a getter of this form so that we can pickle ConstantMean modules with a mean prior, see PR #1992
1793
def _constant_param(self, m):
18-
return m.constant
94+
if hasattr(m, "raw_constant_constraint"):
95+
return m.raw_constant_constraint.transform(m.raw_constant)
96+
return m.raw_constant
1997

98+
# We need a setter of this form so that we can pickle ConstantMean modules with a mean prior, see PR #1992
2099
def _constant_closure(self, m, value):
21100
if not torch.is_tensor(value):
22-
value = torch.as_tensor(value).to(self.constant)
23-
m.initialize(constant=value.reshape(self.constant.shape))
101+
value = torch.as_tensor(value).to(m.raw_constant)
102+
103+
# Reshape the value so that it has a singleton dimension on the end
104+
if value.numel() != self.raw_constant.numel():
105+
raise RuntimeError(
106+
f"Value of shape {value.shape} is incompatibile with ConstantMean of batch shape {m.batch_shape}"
107+
)
108+
value = value.reshape(self.constant.shape)
109+
110+
if hasattr(m, "raw_constant_constraint"):
111+
m.initialize(raw_constant=m.raw_constant_constraint.inverse_transform(value))
112+
else:
113+
m.initialize(raw_constant=value)
24114

25115
def forward(self, input):
26116
if input.shape[:-2] == self.batch_shape:
27117
return self.constant.expand(input.shape[:-1])
28118
else:
29-
return self.constant.expand(_mul_broadcast_shape(input.shape[:-1], self.constant.shape))
119+
return self.constant.expand(_mul_broadcast_shape(input.shape[:-1], self.raw_constant.shape))

test/means/test_constant_mean.py

+65-8
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,40 @@
11
#!/usr/bin/env python3
22

3+
import math
34
import pickle
45
import unittest
6+
import warnings
7+
from collections import OrderedDict
58

69
import torch
710

11+
import gpytorch
12+
from gpytorch.constraints import GreaterThan
813
from gpytorch.means import ConstantMean
914
from gpytorch.priors import NormalPrior
1015
from gpytorch.test.base_mean_test_case import BaseMeanTestCase
16+
from gpytorch.utils.warnings import OldVersionWarning
17+
18+
19+
# Test class for loading models that have state dicts with the old ConstantMean parameter names
20+
class _GPModel(gpytorch.models.ExactGP):
21+
def __init__(self, mean_module):
22+
train_x = torch.randn(10, 3)
23+
train_y = torch.randn(10)
24+
likelihood = gpytorch.likelihoods.GaussianLikelihood()
25+
super().__init__(train_x, train_y, likelihood)
26+
self.mean_module = mean_module
1127

1228

1329
class TestConstantMean(BaseMeanTestCase, unittest.TestCase):
1430
batch_shape = None
1531

16-
def create_mean(self, prior=None):
17-
return ConstantMean(prior=prior, batch_shape=torch.Size([]))
32+
def create_mean(self, prior=None, constraint=None):
33+
return ConstantMean(
34+
constant_prior=prior,
35+
constant_constraint=constraint,
36+
batch_shape=(self.__class__.batch_shape or torch.Size([])),
37+
)
1838

1939
def test_prior(self):
2040
if self.batch_shape is None:
@@ -28,16 +48,53 @@ def test_prior(self):
2848
mean._constant_closure(mean, value)
2949
self.assertTrue(torch.equal(mean.constant.data, value.reshape(mean.constant.data.shape)))
3050

51+
def test_constraint(self):
52+
mean = self.create_mean()
53+
self.assertAllClose(mean.constant, torch.zeros(mean.constant.shape))
54+
55+
constraint = GreaterThan(1.5)
56+
mean = self.create_mean(constraint=constraint)
57+
self.assertTrue(torch.all(mean.constant >= 1.5))
58+
mean.constant = torch.full(self.__class__.batch_shape or torch.Size([]), fill_value=1.65)
59+
self.assertAllClose(mean.constant, torch.tensor([1.65]).expand(mean.constant.shape))
60+
61+
def test_loading_old_module(self):
62+
batch_shape = self.__class__.batch_shape or torch.Size([])
63+
constant = torch.randn(*batch_shape, 1)
64+
mean = self.create_mean()
65+
model = _GPModel(mean)
66+
67+
old_state_dict = OrderedDict(
68+
[
69+
("likelihood.noise_covar.raw_noise", torch.tensor([0.0])),
70+
("likelihood.noise_covar.raw_noise_constraint.lower_bound", torch.tensor(1.0000e-04)),
71+
("likelihood.noise_covar.raw_noise_constraint.upper_bound", torch.tensor(math.inf)),
72+
("mean_module.constant", constant),
73+
]
74+
)
75+
with warnings.catch_warnings(record=True) as ws:
76+
warnings.simplefilter("always", OldVersionWarning)
77+
model.load_state_dict(old_state_dict)
78+
self.assertTrue(any(issubclass(w.category, OldVersionWarning) for w in ws))
79+
self.assertEqual(model.mean_module.constant.data, constant)
80+
81+
new_state_dict = OrderedDict(
82+
[
83+
("likelihood.noise_covar.raw_noise", torch.tensor([0.0])),
84+
("likelihood.noise_covar.raw_noise_constraint.lower_bound", torch.tensor(1.0000e-04)),
85+
("likelihood.noise_covar.raw_noise_constraint.upper_bound", torch.tensor(math.inf)),
86+
("mean_module.raw_constant", constant),
87+
]
88+
)
89+
with warnings.catch_warnings(record=True) as ws:
90+
warnings.simplefilter("always", OldVersionWarning)
91+
model.load_state_dict(new_state_dict)
92+
self.assertFalse(any(issubclass(w.category, OldVersionWarning) for w in ws))
93+
3194

3295
class TestConstantMeanBatch(TestConstantMean, unittest.TestCase):
3396
batch_shape = torch.Size([3])
3497

35-
def create_mean(self, prior=None):
36-
return ConstantMean(prior=prior, batch_shape=self.__class__.batch_shape)
37-
3898

3999
class TestConstantMeanMultiBatch(TestConstantMean, unittest.TestCase):
40100
batch_shape = torch.Size([2, 3])
41-
42-
def create_mean(self, prior=None):
43-
return ConstantMean(prior=prior, batch_shape=self.__class__.batch_shape)

0 commit comments

Comments
 (0)