Skip to content

Commit f863542

Browse files
authored
Move infinite interval bounds check into Interval constructor (#2259)
* Move infinite interval bounds check into Internval constructor Addresses #2258, avoids repeatedly checking the same condition in each forward pass. Avoids forcing GPU synchronization. * Remove second obsolete check in forward, some type annotations.
1 parent 2e1ccec commit f863542

File tree

2 files changed

+43
-45
lines changed

2 files changed

+43
-45
lines changed

gpytorch/constraints/constraints.py

+35-45
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
#!/usr/bin/env python3
22

3+
from __future__ import annotations
4+
35
import math
6+
from typing import Optional
47

58
import torch
6-
from torch import sigmoid
9+
from torch import Tensor, sigmoid
710
from torch.nn import Module
811

9-
from .. import settings
1012
from ..utils.transforms import _get_inv_param_transform, inv_sigmoid, inv_softplus
1113

1214
# define softplus here instead of using torch.nn.functional.softplus because the functional version can't be pickled
@@ -23,11 +25,21 @@ def __init__(self, lower_bound, upper_bound, transform=sigmoid, inv_transform=in
2325
lower_bound (float or torch.Tensor): The lower bound on the parameter.
2426
upper_bound (float or torch.Tensor): The upper bound on the parameter.
2527
"""
26-
lower_bound = torch.as_tensor(lower_bound).float()
27-
upper_bound = torch.as_tensor(upper_bound).float()
28+
dtype = torch.get_default_dtype()
29+
lower_bound = torch.as_tensor(lower_bound).to(dtype)
30+
upper_bound = torch.as_tensor(upper_bound).to(dtype)
2831

2932
if torch.any(torch.ge(lower_bound, upper_bound)):
30-
raise RuntimeError("Got parameter bounds with empty intervals.")
33+
raise ValueError("Got parameter bounds with empty intervals.")
34+
35+
if type(self) == Interval:
36+
max_bound = torch.max(upper_bound)
37+
min_bound = torch.min(lower_bound)
38+
if max_bound == math.inf or min_bound == -math.inf:
39+
raise ValueError(
40+
"Cannot make an Interval directly with non-finite bounds. Use a derived class like "
41+
"GreaterThan or LessThan instead."
42+
)
3143

3244
super().__init__()
3345

@@ -41,9 +53,7 @@ def __init__(self, lower_bound, upper_bound, transform=sigmoid, inv_transform=in
4153
self._inv_transform = _get_inv_param_transform(transform)
4254

4355
if initial_value is not None:
44-
if not isinstance(initial_value, torch.Tensor):
45-
initial_value = torch.tensor(initial_value)
46-
self._initial_value = self.inverse_transform(initial_value)
56+
self._initial_value = self.inverse_transform(torch.as_tensor(initial_value))
4757
else:
4858
self._initial_value = None
4959

@@ -69,19 +79,19 @@ def _load_from_state_dict(
6979
return result
7080

7181
@property
72-
def enforced(self):
82+
def enforced(self) -> bool:
7383
return self._transform is not None
7484

75-
def check(self, tensor):
85+
def check(self, tensor) -> bool:
7686
return bool(torch.all(tensor <= self.upper_bound) and torch.all(tensor >= self.lower_bound))
7787

78-
def check_raw(self, tensor):
88+
def check_raw(self, tensor) -> bool:
7989
return bool(
8090
torch.all((self.transform(tensor) <= self.upper_bound))
8191
and torch.all(self.transform(tensor) >= self.lower_bound)
8292
)
8393

84-
def intersect(self, other):
94+
def intersect(self, other: Interval) -> Interval:
8595
"""
8696
Returns a new Interval constraint that is the intersection of this one and another specified one.
8797
@@ -98,7 +108,7 @@ def intersect(self, other):
98108
upper_bound = torch.min(self.upper_bound, other.upper_bound)
99109
return Interval(lower_bound, upper_bound)
100110

101-
def transform(self, tensor):
111+
def transform(self, tensor: Tensor) -> Tensor:
102112
"""
103113
Transforms a tensor to satisfy the specified bounds.
104114
@@ -111,49 +121,29 @@ def transform(self, tensor):
111121
if not self.enforced:
112122
return tensor
113123

114-
if settings.debug.on():
115-
max_bound = torch.max(self.upper_bound)
116-
min_bound = torch.min(self.lower_bound)
117-
118-
if max_bound == math.inf or min_bound == -math.inf:
119-
raise RuntimeError(
120-
"Cannot make an Interval directly with non-finite bounds. Use a derived class like "
121-
"GreaterThan or LessThan instead."
122-
)
123-
124124
transformed_tensor = (self._transform(tensor) * (self.upper_bound - self.lower_bound)) + self.lower_bound
125125

126126
return transformed_tensor
127127

128-
def inverse_transform(self, transformed_tensor):
128+
def inverse_transform(self, transformed_tensor: Tensor) -> Tensor:
129129
"""
130130
Applies the inverse transformation.
131131
"""
132132
if not self.enforced:
133133
return transformed_tensor
134134

135-
if settings.debug.on():
136-
max_bound = torch.max(self.upper_bound)
137-
min_bound = torch.min(self.lower_bound)
138-
139-
if max_bound == math.inf or min_bound == -math.inf:
140-
raise RuntimeError(
141-
"Cannot make an Interval directly with non-finite bounds. Use a derived class like "
142-
"GreaterThan or LessThan instead."
143-
)
144-
145135
tensor = self._inv_transform((transformed_tensor - self.lower_bound) / (self.upper_bound - self.lower_bound))
146136

147137
return tensor
148138

149139
@property
150-
def initial_value(self):
140+
def initial_value(self) -> Optional[Tensor]:
151141
"""
152142
The initial parameter value (if specified, None otherwise)
153143
"""
154144
return self._initial_value
155145

156-
def __repr__(self):
146+
def __repr__(self) -> str:
157147
if self.lower_bound.numel() == 1 and self.upper_bound.numel() == 1:
158148
return self._get_name() + f"({self.lower_bound:.3E}, {self.upper_bound:.3E})"
159149
else:
@@ -174,17 +164,17 @@ def __init__(self, lower_bound, transform=softplus, inv_transform=inv_softplus,
174164
initial_value=initial_value,
175165
)
176166

177-
def __repr__(self):
167+
def __repr__(self) -> str:
178168
if self.lower_bound.numel() == 1:
179169
return self._get_name() + f"({self.lower_bound:.3E})"
180170
else:
181171
return super().__repr__()
182172

183-
def transform(self, tensor):
173+
def transform(self, tensor: Tensor) -> Tensor:
184174
transformed_tensor = self._transform(tensor) + self.lower_bound if self.enforced else tensor
185175
return transformed_tensor
186176

187-
def inverse_transform(self, transformed_tensor):
177+
def inverse_transform(self, transformed_tensor: Tensor) -> Tensor:
188178
tensor = self._inv_transform(transformed_tensor - self.lower_bound) if self.enforced else transformed_tensor
189179
return tensor
190180

@@ -193,14 +183,14 @@ class Positive(GreaterThan):
193183
def __init__(self, transform=softplus, inv_transform=inv_softplus, initial_value=None):
194184
super().__init__(lower_bound=0.0, transform=transform, inv_transform=inv_transform, initial_value=initial_value)
195185

196-
def __repr__(self):
186+
def __repr__(self) -> str:
197187
return self._get_name() + "()"
198188

199-
def transform(self, tensor):
189+
def transform(self, tensor: Tensor) -> Tensor:
200190
transformed_tensor = self._transform(tensor) if self.enforced else tensor
201191
return transformed_tensor
202192

203-
def inverse_transform(self, transformed_tensor):
193+
def inverse_transform(self, transformed_tensor: Tensor) -> Tensor:
204194
tensor = self._inv_transform(transformed_tensor) if self.enforced else transformed_tensor
205195
return tensor
206196

@@ -215,13 +205,13 @@ def __init__(self, upper_bound, transform=softplus, inv_transform=inv_softplus,
215205
initial_value=initial_value,
216206
)
217207

218-
def transform(self, tensor):
208+
def transform(self, tensor: Tensor) -> Tensor:
219209
transformed_tensor = -self._transform(-tensor) + self.upper_bound if self.enforced else tensor
220210
return transformed_tensor
221211

222-
def inverse_transform(self, transformed_tensor):
212+
def inverse_transform(self, transformed_tensor: Tensor) -> Tensor:
223213
tensor = -self._inv_transform(-(transformed_tensor - self.upper_bound)) if self.enforced else transformed_tensor
224214
return tensor
225215

226-
def __repr__(self):
216+
def __repr__(self) -> str:
227217
return self._get_name() + f"({self.upper_bound:.3E})"

test/constraints/test_constraints.py

+8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/usr/bin/env python3
22

3+
import math
34
import unittest
45

56
import torch
@@ -69,6 +70,13 @@ def test_initial_value(self):
6970
lkhd = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=constraint)
7071
self.assertEqual(lkhd.noise.item(), 3.0)
7172

73+
def test_error_on_infinite(self):
74+
err_msg = "Cannot make an Interval directly with non-finite bounds"
75+
with self.assertRaisesRegex(ValueError, err_msg):
76+
gpytorch.constraints.Interval(0.0, math.inf)
77+
with self.assertRaisesRegex(ValueError, err_msg):
78+
gpytorch.constraints.Interval(-math.inf, 0.0)
79+
7280

7381
class TestGreaterThan(unittest.TestCase, BaseTestCase):
7482
def test_transform_float_greater_than(self):

0 commit comments

Comments
 (0)