1
1
#!/usr/bin/env python3
2
2
3
+ from __future__ import annotations
4
+
3
5
import math
6
+ from typing import Optional
4
7
5
8
import torch
6
- from torch import sigmoid
9
+ from torch import Tensor , sigmoid
7
10
from torch .nn import Module
8
11
9
- from .. import settings
10
12
from ..utils .transforms import _get_inv_param_transform , inv_sigmoid , inv_softplus
11
13
12
14
# define softplus here instead of using torch.nn.functional.softplus because the functional version can't be pickled
@@ -23,11 +25,21 @@ def __init__(self, lower_bound, upper_bound, transform=sigmoid, inv_transform=in
23
25
lower_bound (float or torch.Tensor): The lower bound on the parameter.
24
26
upper_bound (float or torch.Tensor): The upper bound on the parameter.
25
27
"""
26
- lower_bound = torch .as_tensor (lower_bound ).float ()
27
- upper_bound = torch .as_tensor (upper_bound ).float ()
28
+ dtype = torch .get_default_dtype ()
29
+ lower_bound = torch .as_tensor (lower_bound ).to (dtype )
30
+ upper_bound = torch .as_tensor (upper_bound ).to (dtype )
28
31
29
32
if torch .any (torch .ge (lower_bound , upper_bound )):
30
- raise RuntimeError ("Got parameter bounds with empty intervals." )
33
+ raise ValueError ("Got parameter bounds with empty intervals." )
34
+
35
+ if type (self ) == Interval :
36
+ max_bound = torch .max (upper_bound )
37
+ min_bound = torch .min (lower_bound )
38
+ if max_bound == math .inf or min_bound == - math .inf :
39
+ raise ValueError (
40
+ "Cannot make an Interval directly with non-finite bounds. Use a derived class like "
41
+ "GreaterThan or LessThan instead."
42
+ )
31
43
32
44
super ().__init__ ()
33
45
@@ -41,9 +53,7 @@ def __init__(self, lower_bound, upper_bound, transform=sigmoid, inv_transform=in
41
53
self ._inv_transform = _get_inv_param_transform (transform )
42
54
43
55
if initial_value is not None :
44
- if not isinstance (initial_value , torch .Tensor ):
45
- initial_value = torch .tensor (initial_value )
46
- self ._initial_value = self .inverse_transform (initial_value )
56
+ self ._initial_value = self .inverse_transform (torch .as_tensor (initial_value ))
47
57
else :
48
58
self ._initial_value = None
49
59
@@ -69,19 +79,19 @@ def _load_from_state_dict(
69
79
return result
70
80
71
81
@property
72
- def enforced (self ):
82
+ def enforced (self ) -> bool :
73
83
return self ._transform is not None
74
84
75
- def check (self , tensor ):
85
+ def check (self , tensor ) -> bool :
76
86
return bool (torch .all (tensor <= self .upper_bound ) and torch .all (tensor >= self .lower_bound ))
77
87
78
- def check_raw (self , tensor ):
88
+ def check_raw (self , tensor ) -> bool :
79
89
return bool (
80
90
torch .all ((self .transform (tensor ) <= self .upper_bound ))
81
91
and torch .all (self .transform (tensor ) >= self .lower_bound )
82
92
)
83
93
84
- def intersect (self , other ) :
94
+ def intersect (self , other : Interval ) -> Interval :
85
95
"""
86
96
Returns a new Interval constraint that is the intersection of this one and another specified one.
87
97
@@ -98,7 +108,7 @@ def intersect(self, other):
98
108
upper_bound = torch .min (self .upper_bound , other .upper_bound )
99
109
return Interval (lower_bound , upper_bound )
100
110
101
- def transform (self , tensor ) :
111
+ def transform (self , tensor : Tensor ) -> Tensor :
102
112
"""
103
113
Transforms a tensor to satisfy the specified bounds.
104
114
@@ -111,49 +121,29 @@ def transform(self, tensor):
111
121
if not self .enforced :
112
122
return tensor
113
123
114
- if settings .debug .on ():
115
- max_bound = torch .max (self .upper_bound )
116
- min_bound = torch .min (self .lower_bound )
117
-
118
- if max_bound == math .inf or min_bound == - math .inf :
119
- raise RuntimeError (
120
- "Cannot make an Interval directly with non-finite bounds. Use a derived class like "
121
- "GreaterThan or LessThan instead."
122
- )
123
-
124
124
transformed_tensor = (self ._transform (tensor ) * (self .upper_bound - self .lower_bound )) + self .lower_bound
125
125
126
126
return transformed_tensor
127
127
128
- def inverse_transform (self , transformed_tensor ) :
128
+ def inverse_transform (self , transformed_tensor : Tensor ) -> Tensor :
129
129
"""
130
130
Applies the inverse transformation.
131
131
"""
132
132
if not self .enforced :
133
133
return transformed_tensor
134
134
135
- if settings .debug .on ():
136
- max_bound = torch .max (self .upper_bound )
137
- min_bound = torch .min (self .lower_bound )
138
-
139
- if max_bound == math .inf or min_bound == - math .inf :
140
- raise RuntimeError (
141
- "Cannot make an Interval directly with non-finite bounds. Use a derived class like "
142
- "GreaterThan or LessThan instead."
143
- )
144
-
145
135
tensor = self ._inv_transform ((transformed_tensor - self .lower_bound ) / (self .upper_bound - self .lower_bound ))
146
136
147
137
return tensor
148
138
149
139
@property
150
- def initial_value (self ):
140
+ def initial_value (self ) -> Optional [ Tensor ] :
151
141
"""
152
142
The initial parameter value (if specified, None otherwise)
153
143
"""
154
144
return self ._initial_value
155
145
156
- def __repr__ (self ):
146
+ def __repr__ (self ) -> str :
157
147
if self .lower_bound .numel () == 1 and self .upper_bound .numel () == 1 :
158
148
return self ._get_name () + f"({ self .lower_bound :.3E} , { self .upper_bound :.3E} )"
159
149
else :
@@ -174,17 +164,17 @@ def __init__(self, lower_bound, transform=softplus, inv_transform=inv_softplus,
174
164
initial_value = initial_value ,
175
165
)
176
166
177
- def __repr__ (self ):
167
+ def __repr__ (self ) -> str :
178
168
if self .lower_bound .numel () == 1 :
179
169
return self ._get_name () + f"({ self .lower_bound :.3E} )"
180
170
else :
181
171
return super ().__repr__ ()
182
172
183
- def transform (self , tensor ) :
173
+ def transform (self , tensor : Tensor ) -> Tensor :
184
174
transformed_tensor = self ._transform (tensor ) + self .lower_bound if self .enforced else tensor
185
175
return transformed_tensor
186
176
187
- def inverse_transform (self , transformed_tensor ) :
177
+ def inverse_transform (self , transformed_tensor : Tensor ) -> Tensor :
188
178
tensor = self ._inv_transform (transformed_tensor - self .lower_bound ) if self .enforced else transformed_tensor
189
179
return tensor
190
180
@@ -193,14 +183,14 @@ class Positive(GreaterThan):
193
183
def __init__ (self , transform = softplus , inv_transform = inv_softplus , initial_value = None ):
194
184
super ().__init__ (lower_bound = 0.0 , transform = transform , inv_transform = inv_transform , initial_value = initial_value )
195
185
196
- def __repr__ (self ):
186
+ def __repr__ (self ) -> str :
197
187
return self ._get_name () + "()"
198
188
199
- def transform (self , tensor ) :
189
+ def transform (self , tensor : Tensor ) -> Tensor :
200
190
transformed_tensor = self ._transform (tensor ) if self .enforced else tensor
201
191
return transformed_tensor
202
192
203
- def inverse_transform (self , transformed_tensor ) :
193
+ def inverse_transform (self , transformed_tensor : Tensor ) -> Tensor :
204
194
tensor = self ._inv_transform (transformed_tensor ) if self .enforced else transformed_tensor
205
195
return tensor
206
196
@@ -215,13 +205,13 @@ def __init__(self, upper_bound, transform=softplus, inv_transform=inv_softplus,
215
205
initial_value = initial_value ,
216
206
)
217
207
218
- def transform (self , tensor ) :
208
+ def transform (self , tensor : Tensor ) -> Tensor :
219
209
transformed_tensor = - self ._transform (- tensor ) + self .upper_bound if self .enforced else tensor
220
210
return transformed_tensor
221
211
222
- def inverse_transform (self , transformed_tensor ) :
212
+ def inverse_transform (self , transformed_tensor : Tensor ) -> Tensor :
223
213
tensor = - self ._inv_transform (- (transformed_tensor - self .upper_bound )) if self .enforced else transformed_tensor
224
214
return tensor
225
215
226
- def __repr__ (self ):
216
+ def __repr__ (self ) -> str :
227
217
return self ._get_name () + f"({ self .upper_bound :.3E} )"
0 commit comments