@@ -31,6 +31,46 @@ def evaluate_snippet(_filename: str, expr: str, **_kwargs) -> str:
31
31
32
32
logger = logging .getLogger (__name__ ) # pylint: disable=invalid-name
33
33
34
+ # pylint: disable=inconsistent-return-statements
35
+ def infer_and_cast (value : Any ):
36
+ """
37
+ In some cases we'll be feeding params dicts to functions we don't own;
38
+ for example, PyTorch optimizers. In that case we can't use ``pop_int``
39
+ or similar to force casts (which means you can't specify ``int`` parameters
40
+ using environment variables). This function takes something that looks JSON-like
41
+ and recursively casts things that look like (bool, int, float) to (bool, int, float).
42
+ """
43
+ # pylint: disable=too-many-return-statements
44
+ if isinstance (value , (int , float , bool )):
45
+ # Already one of our desired types, so leave as is.
46
+ return value
47
+ elif isinstance (value , list ):
48
+ # Recursively call on each list element.
49
+ return [infer_and_cast (item ) for item in value ]
50
+ elif isinstance (value , dict ):
51
+ # Recursively call on each dict value.
52
+ return {key : infer_and_cast (item ) for key , item in value .items ()}
53
+ elif isinstance (value , str ):
54
+ # If it looks like a bool, make it a bool.
55
+ if value .lower () == "true" :
56
+ return True
57
+ elif value .lower () == "false" :
58
+ return False
59
+ else :
60
+ # See if it could be an int.
61
+ try :
62
+ return int (value )
63
+ except ValueError :
64
+ pass
65
+ # See if it could be a float.
66
+ try :
67
+ return float (value )
68
+ except ValueError :
69
+ # Just return it as a string.
70
+ return value
71
+ else :
72
+ raise ValueError (f"cannot infer type of { value } " )
73
+ # pylint: enable=inconsistent-return-statements
34
74
35
75
def unflatten (flat_dict : Dict [str , Any ]) -> Dict [str , Any ]:
36
76
"""
@@ -259,18 +299,23 @@ def pop_choice(self, key: str, choices: List[Any], default_to_first_choice: bool
259
299
raise ConfigurationError (message )
260
300
return value
261
301
262
- def as_dict (self , quiet = False ):
302
+ def as_dict (self , quiet : bool = False , infer_type_and_cast : bool = False ):
263
303
"""
264
304
Sometimes we need to just represent the parameters as a dict, for instance when we pass
265
- them to a Keras layer(so that they can be serialised) .
305
+ them to PyTorch code .
266
306
267
307
Parameters
268
308
----------
269
309
quiet: bool, optional (default = False)
270
310
Whether to log the parameters before returning them as a dict.
271
311
"""
312
+ if infer_type_and_cast :
313
+ params_as_dict = infer_and_cast (self .params )
314
+ else :
315
+ params_as_dict = self .params
316
+
272
317
if quiet :
273
- return self . params
318
+ return params_as_dict
274
319
275
320
def log_recursively (parameters , history ):
276
321
for key , value in parameters .items ():
@@ -285,7 +330,7 @@ def log_recursively(parameters, history):
285
330
"used subsequently." )
286
331
logger .info ("CURRENTLY DEFINED PARAMETERS: " )
287
332
log_recursively (self .params , self .history )
288
- return self . params
333
+ return params_as_dict
289
334
290
335
def as_flat_dict (self ):
291
336
"""
0 commit comments