1
1
#!/usr/bin/env python3
2
2
3
+ from __future__ import annotations
4
+
3
5
import math
4
6
5
7
import torch
@@ -23,11 +25,21 @@ def __init__(self, lower_bound, upper_bound, transform=sigmoid, inv_transform=in
23
25
lower_bound (float or torch.Tensor): The lower bound on the parameter.
24
26
upper_bound (float or torch.Tensor): The upper bound on the parameter.
25
27
"""
26
- lower_bound = torch .as_tensor (lower_bound ).float ()
27
- upper_bound = torch .as_tensor (upper_bound ).float ()
28
+ dtype = torch .get_default_dtype ()
29
+ lower_bound = torch .as_tensor (lower_bound ).to (dtype )
30
+ upper_bound = torch .as_tensor (upper_bound ).to (dtype )
28
31
29
32
if torch .any (torch .ge (lower_bound , upper_bound )):
30
- raise RuntimeError ("Got parameter bounds with empty intervals." )
33
+ raise ValueError ("Got parameter bounds with empty intervals." )
34
+
35
+ if type (self ) == Interval :
36
+ max_bound = torch .max (upper_bound )
37
+ min_bound = torch .min (lower_bound )
38
+ if max_bound == math .inf or min_bound == - math .inf :
39
+ raise ValueError (
40
+ "Cannot make an Interval directly with non-finite bounds. Use a derived class like "
41
+ "GreaterThan or LessThan instead."
42
+ )
31
43
32
44
super ().__init__ ()
33
45
@@ -111,16 +123,6 @@ def transform(self, tensor):
111
123
if not self .enforced :
112
124
return tensor
113
125
114
- if settings .debug .on ():
115
- max_bound = torch .max (self .upper_bound )
116
- min_bound = torch .min (self .lower_bound )
117
-
118
- if max_bound == math .inf or min_bound == - math .inf :
119
- raise RuntimeError (
120
- "Cannot make an Interval directly with non-finite bounds. Use a derived class like "
121
- "GreaterThan or LessThan instead."
122
- )
123
-
124
126
transformed_tensor = (self ._transform (tensor ) * (self .upper_bound - self .lower_bound )) + self .lower_bound
125
127
126
128
return transformed_tensor
0 commit comments