Skip to content

Commit 74870c5

Browse files
committed
Move infinite interval bounds check into Internval constructor
Addresses #2258, avoids repeatedly checking the same condition in each forward pass. Avoids forcing GPU synchronization.
1 parent 2e1ccec commit 74870c5

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

gpytorch/constraints/constraints.py

+15-13
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/usr/bin/env python3
22

3+
from __future__ import annotations
4+
35
import math
46

57
import torch
@@ -23,11 +25,21 @@ def __init__(self, lower_bound, upper_bound, transform=sigmoid, inv_transform=in
2325
lower_bound (float or torch.Tensor): The lower bound on the parameter.
2426
upper_bound (float or torch.Tensor): The upper bound on the parameter.
2527
"""
26-
lower_bound = torch.as_tensor(lower_bound).float()
27-
upper_bound = torch.as_tensor(upper_bound).float()
28+
dtype = torch.get_default_dtype()
29+
lower_bound = torch.as_tensor(lower_bound).to(dtype)
30+
upper_bound = torch.as_tensor(upper_bound).to(dtype)
2831

2932
if torch.any(torch.ge(lower_bound, upper_bound)):
30-
raise RuntimeError("Got parameter bounds with empty intervals.")
33+
raise ValueError("Got parameter bounds with empty intervals.")
34+
35+
if type(self) == Interval:
36+
max_bound = torch.max(upper_bound)
37+
min_bound = torch.min(lower_bound)
38+
if max_bound == math.inf or min_bound == -math.inf:
39+
raise ValueError(
40+
"Cannot make an Interval directly with non-finite bounds. Use a derived class like "
41+
"GreaterThan or LessThan instead."
42+
)
3143

3244
super().__init__()
3345

@@ -111,16 +123,6 @@ def transform(self, tensor):
111123
if not self.enforced:
112124
return tensor
113125

114-
if settings.debug.on():
115-
max_bound = torch.max(self.upper_bound)
116-
min_bound = torch.min(self.lower_bound)
117-
118-
if max_bound == math.inf or min_bound == -math.inf:
119-
raise RuntimeError(
120-
"Cannot make an Interval directly with non-finite bounds. Use a derived class like "
121-
"GreaterThan or LessThan instead."
122-
)
123-
124126
transformed_tensor = (self._transform(tensor) * (self.upper_bound - self.lower_bound)) + self.lower_bound
125127

126128
return transformed_tensor

test/constraints/test_constraints.py

+8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/usr/bin/env python3
22

3+
import math
34
import unittest
45

56
import torch
@@ -69,6 +70,13 @@ def test_initial_value(self):
6970
lkhd = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=constraint)
7071
self.assertEqual(lkhd.noise.item(), 3.0)
7172

73+
def test_error_on_infinite(self):
74+
err_msg = "Cannot make an Interval directly with non-finite bounds"
75+
with self.assertRaisesRegex(ValueError, err_msg):
76+
gpytorch.constraints.Interval(0.0, math.inf)
77+
with self.assertRaisesRegex(ValueError, err_msg):
78+
gpytorch.constraints.Interval(-math.inf, 0.0)
79+
7280

7381
class TestGreaterThan(unittest.TestCase, BaseTestCase):
7482
def test_transform_float_greater_than(self):

0 commit comments

Comments
 (0)