Skip to content

Commit 1314acd

Browse files
committed
Add constant_constraint to ConstantMean
[Fixes #2074]
1 parent 0c359c5 commit 1314acd

File tree

2 files changed

+159
-20
lines changed

2 files changed

+159
-20
lines changed

gpytorch/means/constant_mean.py

+94-12
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,111 @@
11
#!/usr/bin/env python3
22

3+
import warnings
4+
from typing import Any, Optional
5+
36
import torch
47

5-
from ..utils.broadcasting import _mul_broadcast_shape
8+
from ..constraints import Interval
9+
from ..priors import Prior
10+
from ..utils.warnings import OldVersionWarning
611
from .mean import Mean
712

813

14+
def _ensure_updated_strategy_flag_set(
15+
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
16+
):
17+
if prefix + "constant" in state_dict:
18+
constant = state_dict.pop(prefix + "constant").squeeze(-1) # Remove deprecated singleton dimension
19+
state_dict[prefix + "raw_constant"] = constant
20+
warnings.warn(
21+
"You have loaded a GP model with a ConstantMean from a previous version of "
22+
"GPyTorch. The mean module parameter `constant` has been renamed to `raw_constant`. "
23+
"Additionally, the shape of `raw_constant` is now *batch_shape, whereas the shape of "
24+
"`constant` was *batch_shape x 1. "
25+
"We have updated the name/shape of the parameter in your state dict, but we recommend that you "
26+
"re-save your model.",
27+
OldVersionWarning,
28+
)
29+
30+
931
class ConstantMean(Mean):
10-
def __init__(self, prior=None, batch_shape=torch.Size(), **kwargs):
32+
r"""
33+
A (non-zero) constant prior mean function, i.e.:
34+
35+
.. math::
36+
\mu(\mathbf x) = C
37+
38+
where :math:`C` is a learned constant.
39+
40+
:param constant_prior: Prior for constant parameter :math:`C`.
41+
:type constant_prior: ~gpytorch.priors.Prior, optional
42+
:param constant_constraint: Constraint for constant parameter :math:`C`.
43+
:type constant_constraint: ~gpytorch.priors.Interval, optional
44+
:param batch_shape: The batch shape of the learned constant(s) (default: []).
45+
:type batch_shape: torch.Size, optional
46+
47+
:var torch.Tensor constant: :math:`C` parameter
48+
"""
49+
50+
def __init__(
51+
self,
52+
constant_prior: Optional[Prior] = None,
53+
constant_constraint: Optional[Interval] = None,
54+
batch_shape: torch.Size = torch.Size(),
55+
**kwargs: Any,
56+
):
1157
super(ConstantMean, self).__init__()
58+
59+
# Deprecated kwarg
60+
constant_prior_deprecated = kwargs.get("prior")
61+
if constant_prior_deprecated is not None:
62+
if constant_prior is None: # Using the old kwarg for the constant_prior
63+
warnings.warn(
64+
"The kwarg `prior` for ConstantMean has been renamed to `constant_prior`, and will be deprecated.",
65+
DeprecationWarning,
66+
)
67+
constant_prior = constant_prior_deprecated
68+
else: # Weird edge case where someone set both `prior` and `constant_prior`
69+
warnings.warn(
70+
"You have set both the `constant_prior` and the deprecated `prior` arguments for ConstantMean. "
71+
"`prior` is deprecated, and will be ignored.",
72+
DeprecationWarning,
73+
)
74+
75+
# Ensure that old versions of the model still load
76+
self._register_load_state_dict_pre_hook(_ensure_updated_strategy_flag_set)
77+
1278
self.batch_shape = batch_shape
13-
self.register_parameter(name="constant", parameter=torch.nn.Parameter(torch.zeros(*batch_shape, 1)))
14-
if prior is not None:
15-
self.register_prior("mean_prior", prior, self._constant_param, self._constant_closure)
79+
self.register_parameter(name="raw_constant", parameter=torch.nn.Parameter(torch.zeros(batch_shape)))
80+
if constant_prior is not None:
81+
self.register_prior("mean_prior", constant_prior, self._constant_param, self._constant_closure)
82+
if constant_constraint is not None:
83+
self.register_constraint("raw_constant", constant_constraint)
84+
85+
@property
86+
def constant(self):
87+
return self._constant_param(self)
1688

89+
@constant.setter
90+
def constant(self, value):
91+
self._constant_closure(self, value)
92+
93+
# We need a getter of this form so that we can pickle ConstantMean modules with a mean prior, see PR #1992
1794
def _constant_param(self, m):
18-
return m.constant
95+
if hasattr(m, "raw_constant_constraint"):
96+
return m.raw_constant_constraint.transform(m.raw_constant)
97+
return m.raw_constant
1998

99+
# We need a setter of this form so that we can pickle ConstantMean modules with a mean prior, see PR #1992
20100
def _constant_closure(self, m, value):
21101
if not torch.is_tensor(value):
22-
value = torch.as_tensor(value).to(self.constant)
23-
m.initialize(constant=value.reshape(self.constant.shape))
102+
value = torch.as_tensor(value).to(m.raw_constant)
24103

25-
def forward(self, input):
26-
if input.shape[:-2] == self.batch_shape:
27-
return self.constant.expand(input.shape[:-1])
104+
if hasattr(m, "raw_constant_constraint"):
105+
m.initialize(raw_constant=m.raw_constant_constraint.inverse_transform(value))
28106
else:
29-
return self.constant.expand(_mul_broadcast_shape(input.shape[:-1], self.constant.shape))
107+
m.initialize(raw_constant=value)
108+
109+
def forward(self, input):
110+
constant = self.constant.unsqueeze(-1) # *batch_shape x 1
111+
return constant.expand(torch.broadcast_shapes(constant.shape, input.shape[:-1]))

test/means/test_constant_mean.py

+65-8
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,40 @@
11
#!/usr/bin/env python3
22

3+
import math
34
import pickle
45
import unittest
6+
import warnings
7+
from collections import OrderedDict
58

69
import torch
710

11+
import gpytorch
12+
from gpytorch.constraints import GreaterThan
813
from gpytorch.means import ConstantMean
914
from gpytorch.priors import NormalPrior
1015
from gpytorch.test.base_mean_test_case import BaseMeanTestCase
16+
from gpytorch.utils.warnings import OldVersionWarning
17+
18+
19+
# Test class for loading models that have state dicts with the old ConstantMean parameter names
20+
class _GPModel(gpytorch.models.ExactGP):
21+
def __init__(self, mean_module):
22+
train_x = torch.randn(10, 3)
23+
train_y = torch.randn(10)
24+
likelihood = gpytorch.likelihoods.GaussianLikelihood()
25+
super().__init__(train_x, train_y, likelihood)
26+
self.mean_module = mean_module
1127

1228

1329
class TestConstantMean(BaseMeanTestCase, unittest.TestCase):
1430
batch_shape = None
1531

16-
def create_mean(self, prior=None):
17-
return ConstantMean(prior=prior, batch_shape=torch.Size([]))
32+
def create_mean(self, prior=None, constraint=None):
33+
return ConstantMean(
34+
constant_prior=prior,
35+
constant_constraint=constraint,
36+
batch_shape=(self.__class__.batch_shape or torch.Size([])),
37+
)
1838

1939
def test_prior(self):
2040
if self.batch_shape is None:
@@ -28,16 +48,53 @@ def test_prior(self):
2848
mean._constant_closure(mean, value)
2949
self.assertTrue(torch.equal(mean.constant.data, value.reshape(mean.constant.data.shape)))
3050

51+
def test_constraint(self):
52+
mean = self.create_mean()
53+
self.assertAllClose(mean.constant, torch.zeros(mean.constant.shape))
54+
55+
constraint = GreaterThan(1.5)
56+
mean = self.create_mean(constraint=constraint)
57+
self.assertTrue(torch.all(mean.constant >= 1.5))
58+
mean.constant = torch.full(self.__class__.batch_shape or torch.Size([]), fill_value=1.65)
59+
self.assertAllClose(mean.constant, torch.tensor(1.65).expand(mean.constant.shape))
60+
61+
def test_loading_old_module(self):
62+
batch_shape = self.__class__.batch_shape or torch.Size([])
63+
constant = torch.randn(batch_shape)
64+
mean = self.create_mean()
65+
model = _GPModel(mean)
66+
67+
old_state_dict = OrderedDict(
68+
[
69+
("likelihood.noise_covar.raw_noise", torch.tensor([0.0])),
70+
("likelihood.noise_covar.raw_noise_constraint.lower_bound", torch.tensor(1.0000e-04)),
71+
("likelihood.noise_covar.raw_noise_constraint.upper_bound", torch.tensor(math.inf)),
72+
("mean_module.constant", constant.unsqueeze(-1)),
73+
]
74+
)
75+
with warnings.catch_warnings(record=True) as ws:
76+
warnings.simplefilter("always", OldVersionWarning)
77+
model.load_state_dict(old_state_dict)
78+
self.assertTrue(any(issubclass(w.category, OldVersionWarning) for w in ws))
79+
self.assertEqual(model.mean_module.constant.data, constant)
80+
81+
new_state_dict = OrderedDict(
82+
[
83+
("likelihood.noise_covar.raw_noise", torch.tensor([0.0])),
84+
("likelihood.noise_covar.raw_noise_constraint.lower_bound", torch.tensor(1.0000e-04)),
85+
("likelihood.noise_covar.raw_noise_constraint.upper_bound", torch.tensor(math.inf)),
86+
("mean_module.raw_constant", constant),
87+
]
88+
)
89+
with warnings.catch_warnings(record=True) as ws:
90+
warnings.simplefilter("always", OldVersionWarning)
91+
model.load_state_dict(new_state_dict)
92+
self.assertFalse(any(issubclass(w.category, OldVersionWarning) for w in ws))
93+
3194

3295
class TestConstantMeanBatch(TestConstantMean, unittest.TestCase):
3396
batch_shape = torch.Size([3])
3497

35-
def create_mean(self, prior=None):
36-
return ConstantMean(prior=prior, batch_shape=self.__class__.batch_shape)
37-
3898

3999
class TestConstantMeanMultiBatch(TestConstantMean, unittest.TestCase):
40100
batch_shape = torch.Size([2, 3])
41-
42-
def create_mean(self, prior=None):
43-
return ConstantMean(prior=prior, batch_shape=self.__class__.batch_shape)

0 commit comments

Comments
 (0)