Skip to content

Move infinite interval bounds check into Interval constructor #2259

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 35 additions & 45 deletions gpytorch/constraints/constraints.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
#!/usr/bin/env python3

from __future__ import annotations

import math
from typing import Optional

import torch
from torch import sigmoid
from torch import Tensor, sigmoid
from torch.nn import Module

from .. import settings
from ..utils.transforms import _get_inv_param_transform, inv_sigmoid, inv_softplus

# define softplus here instead of using torch.nn.functional.softplus because the functional version can't be pickled
Expand All @@ -23,11 +25,21 @@ def __init__(self, lower_bound, upper_bound, transform=sigmoid, inv_transform=in
lower_bound (float or torch.Tensor): The lower bound on the parameter.
upper_bound (float or torch.Tensor): The upper bound on the parameter.
"""
lower_bound = torch.as_tensor(lower_bound).float()
upper_bound = torch.as_tensor(upper_bound).float()
dtype = torch.get_default_dtype()
lower_bound = torch.as_tensor(lower_bound).to(dtype)
upper_bound = torch.as_tensor(upper_bound).to(dtype)

if torch.any(torch.ge(lower_bound, upper_bound)):
raise RuntimeError("Got parameter bounds with empty intervals.")
raise ValueError("Got parameter bounds with empty intervals.")

if type(self) == Interval:
max_bound = torch.max(upper_bound)
min_bound = torch.min(lower_bound)
if max_bound == math.inf or min_bound == -math.inf:
raise ValueError(
"Cannot make an Interval directly with non-finite bounds. Use a derived class like "
"GreaterThan or LessThan instead."
)

super().__init__()

Expand All @@ -41,9 +53,7 @@ def __init__(self, lower_bound, upper_bound, transform=sigmoid, inv_transform=in
self._inv_transform = _get_inv_param_transform(transform)

if initial_value is not None:
if not isinstance(initial_value, torch.Tensor):
initial_value = torch.tensor(initial_value)
self._initial_value = self.inverse_transform(initial_value)
self._initial_value = self.inverse_transform(torch.as_tensor(initial_value))
else:
self._initial_value = None

Expand All @@ -69,19 +79,19 @@ def _load_from_state_dict(
return result

@property
def enforced(self):
def enforced(self) -> bool:
return self._transform is not None

def check(self, tensor):
def check(self, tensor) -> bool:
return bool(torch.all(tensor <= self.upper_bound) and torch.all(tensor >= self.lower_bound))

def check_raw(self, tensor):
def check_raw(self, tensor) -> bool:
return bool(
torch.all((self.transform(tensor) <= self.upper_bound))
and torch.all(self.transform(tensor) >= self.lower_bound)
)

def intersect(self, other):
def intersect(self, other: Interval) -> Interval:
"""
Returns a new Interval constraint that is the intersection of this one and another specified one.

Expand All @@ -98,7 +108,7 @@ def intersect(self, other):
upper_bound = torch.min(self.upper_bound, other.upper_bound)
return Interval(lower_bound, upper_bound)

def transform(self, tensor):
def transform(self, tensor: Tensor) -> Tensor:
"""
Transforms a tensor to satisfy the specified bounds.

Expand All @@ -111,49 +121,29 @@ def transform(self, tensor):
if not self.enforced:
return tensor

if settings.debug.on():
max_bound = torch.max(self.upper_bound)
min_bound = torch.min(self.lower_bound)

if max_bound == math.inf or min_bound == -math.inf:
raise RuntimeError(
"Cannot make an Interval directly with non-finite bounds. Use a derived class like "
"GreaterThan or LessThan instead."
)

transformed_tensor = (self._transform(tensor) * (self.upper_bound - self.lower_bound)) + self.lower_bound

return transformed_tensor

def inverse_transform(self, transformed_tensor):
def inverse_transform(self, transformed_tensor: Tensor) -> Tensor:
"""
Applies the inverse transformation.
"""
if not self.enforced:
return transformed_tensor

if settings.debug.on():
max_bound = torch.max(self.upper_bound)
min_bound = torch.min(self.lower_bound)

if max_bound == math.inf or min_bound == -math.inf:
raise RuntimeError(
"Cannot make an Interval directly with non-finite bounds. Use a derived class like "
"GreaterThan or LessThan instead."
)

tensor = self._inv_transform((transformed_tensor - self.lower_bound) / (self.upper_bound - self.lower_bound))

return tensor

@property
def initial_value(self):
def initial_value(self) -> Optional[Tensor]:
"""
The initial parameter value (if specified, None otherwise)
"""
return self._initial_value

def __repr__(self):
def __repr__(self) -> str:
if self.lower_bound.numel() == 1 and self.upper_bound.numel() == 1:
return self._get_name() + f"({self.lower_bound:.3E}, {self.upper_bound:.3E})"
else:
Expand All @@ -174,17 +164,17 @@ def __init__(self, lower_bound, transform=softplus, inv_transform=inv_softplus,
initial_value=initial_value,
)

def __repr__(self):
def __repr__(self) -> str:
if self.lower_bound.numel() == 1:
return self._get_name() + f"({self.lower_bound:.3E})"
else:
return super().__repr__()

def transform(self, tensor):
def transform(self, tensor: Tensor) -> Tensor:
transformed_tensor = self._transform(tensor) + self.lower_bound if self.enforced else tensor
return transformed_tensor

def inverse_transform(self, transformed_tensor):
def inverse_transform(self, transformed_tensor: Tensor) -> Tensor:
tensor = self._inv_transform(transformed_tensor - self.lower_bound) if self.enforced else transformed_tensor
return tensor

Expand All @@ -193,14 +183,14 @@ class Positive(GreaterThan):
def __init__(self, transform=softplus, inv_transform=inv_softplus, initial_value=None):
super().__init__(lower_bound=0.0, transform=transform, inv_transform=inv_transform, initial_value=initial_value)

def __repr__(self):
def __repr__(self) -> str:
return self._get_name() + "()"

def transform(self, tensor):
def transform(self, tensor: Tensor) -> Tensor:
transformed_tensor = self._transform(tensor) if self.enforced else tensor
return transformed_tensor

def inverse_transform(self, transformed_tensor):
def inverse_transform(self, transformed_tensor: Tensor) -> Tensor:
tensor = self._inv_transform(transformed_tensor) if self.enforced else transformed_tensor
return tensor

Expand All @@ -215,13 +205,13 @@ def __init__(self, upper_bound, transform=softplus, inv_transform=inv_softplus,
initial_value=initial_value,
)

def transform(self, tensor):
def transform(self, tensor: Tensor) -> Tensor:
transformed_tensor = -self._transform(-tensor) + self.upper_bound if self.enforced else tensor
return transformed_tensor

def inverse_transform(self, transformed_tensor):
def inverse_transform(self, transformed_tensor: Tensor) -> Tensor:
tensor = -self._inv_transform(-(transformed_tensor - self.upper_bound)) if self.enforced else transformed_tensor
return tensor

def __repr__(self):
def __repr__(self) -> str:
return self._get_name() + f"({self.upper_bound:.3E})"
8 changes: 8 additions & 0 deletions test/constraints/test_constraints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python3

import math
import unittest

import torch
Expand Down Expand Up @@ -69,6 +70,13 @@ def test_initial_value(self):
lkhd = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=constraint)
self.assertEqual(lkhd.noise.item(), 3.0)

def test_error_on_infinite(self):
err_msg = "Cannot make an Interval directly with non-finite bounds"
with self.assertRaisesRegex(ValueError, err_msg):
gpytorch.constraints.Interval(0.0, math.inf)
with self.assertRaisesRegex(ValueError, err_msg):
gpytorch.constraints.Interval(-math.inf, 0.0)


class TestGreaterThan(unittest.TestCase, BaseTestCase):
def test_transform_float_greater_than(self):
Expand Down