
Commit 71ebcd8

add infer_and_cast (#2324)
* add infer_and_cast
* remove print statement + add comment
* address PR feedback
* pylint
1 parent 059b057 commit 71ebcd8

File tree

4 files changed: +109 -6 lines changed

  allennlp/common/params.py
  allennlp/tests/common/params_test.py
  allennlp/tests/training/optimizer_test.py
  allennlp/training/optimizers.py


allennlp/common/params.py (+49 -4)

@@ -31,6 +31,46 @@ def evaluate_snippet(_filename: str, expr: str, **_kwargs) -> str:
 
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 
+# pylint: disable=inconsistent-return-statements
+def infer_and_cast(value: Any):
+    """
+    In some cases we'll be feeding params dicts to functions we don't own;
+    for example, PyTorch optimizers. In that case we can't use ``pop_int``
+    or similar to force casts (which means you can't specify ``int`` parameters
+    using environment variables). This function takes something that looks JSON-like
+    and recursively casts things that look like (bool, int, float) to (bool, int, float).
+    """
+    # pylint: disable=too-many-return-statements
+    if isinstance(value, (int, float, bool)):
+        # Already one of our desired types, so leave as is.
+        return value
+    elif isinstance(value, list):
+        # Recursively call on each list element.
+        return [infer_and_cast(item) for item in value]
+    elif isinstance(value, dict):
+        # Recursively call on each dict value.
+        return {key: infer_and_cast(item) for key, item in value.items()}
+    elif isinstance(value, str):
+        # If it looks like a bool, make it a bool.
+        if value.lower() == "true":
+            return True
+        elif value.lower() == "false":
+            return False
+        else:
+            # See if it could be an int.
+            try:
+                return int(value)
+            except ValueError:
+                pass
+            # See if it could be a float.
+            try:
+                return float(value)
+            except ValueError:
+                # Just return it as a string.
+                return value
+    else:
+        raise ValueError(f"cannot infer type of {value}")
+# pylint: enable=inconsistent-return-statements
 
 def unflatten(flat_dict: Dict[str, Any]) -> Dict[str, Any]:
     """

@@ -259,18 +299,23 @@ def pop_choice(self, key: str, choices: List[Any], default_to_first_choice: bool
             raise ConfigurationError(message)
         return value
 
-    def as_dict(self, quiet=False):
+    def as_dict(self, quiet: bool = False, infer_type_and_cast: bool = False):
         """
         Sometimes we need to just represent the parameters as a dict, for instance when we pass
-        them to a Keras layer(so that they can be serialised).
+        them to PyTorch code.
 
         Parameters
         ----------
         quiet: bool, optional (default = False)
             Whether to log the parameters before returning them as a dict.
         """
+        if infer_type_and_cast:
+            params_as_dict = infer_and_cast(self.params)
+        else:
+            params_as_dict = self.params
+
         if quiet:
-            return self.params
+            return params_as_dict
 
         def log_recursively(parameters, history):
             for key, value in parameters.items():

@@ -285,7 +330,7 @@ def log_recursively(parameters, history):
                             "used subsequently.")
         logger.info("CURRENTLY DEFINED PARAMETERS: ")
         log_recursively(self.params, self.history)
-        return self.params
+        return params_as_dict
 
     def as_flat_dict(self):
         """

allennlp/tests/common/params_test.py (+25 -1)

@@ -7,7 +7,7 @@
 
 import pytest
 
-from allennlp.common.params import Params, unflatten, with_fallback, parse_overrides
+from allennlp.common.params import Params, unflatten, with_fallback, parse_overrides, infer_and_cast
 from allennlp.common.testing import AllenNlpTestCase
 
 

@@ -314,3 +314,27 @@ def test_to_file(self):
         assert json.dumps(expected_ordered_params_dict) == json.dumps(ordered_params_dict)
         # check without preference orders doesn't give error
         params.to_file(file_path)
+
+    def test_infer_and_cast(self):
+        lots_of_strings = {
+                "a": ["10", "1.3", "true"],
+                "b": {"x": 10, "y": "20.1", "z": "other things"},
+                "c": "just a string"
+        }
+
+        casted = {
+                "a": [10, 1.3, True],
+                "b": {"x": 10, "y": 20.1, "z": "other things"},
+                "c": "just a string"
+        }
+
+        assert infer_and_cast(lots_of_strings) == casted
+
+        contains_bad_data = {"x": 10, "y": int}
+        with pytest.raises(ValueError, match="cannot infer type"):
+            infer_and_cast(contains_bad_data)
+
+        params = Params(lots_of_strings)
+
+        assert params.as_dict() == lots_of_strings
+        assert params.as_dict(infer_type_and_cast=True) == casted

allennlp/tests/training/optimizer_test.py (+27)

@@ -1,4 +1,6 @@
 # pylint: disable=invalid-name
+import pytest
+
 from allennlp.common.testing import AllenNlpTestCase
 from allennlp.data import Vocabulary
 from allennlp.common.params import Params

@@ -75,6 +77,31 @@ def test_optimizer_parameter_groups(self):
         assert len(param_groups[2]['params']) == 3
 
 
+    def test_parameter_type_inference(self):
+        # Should work ok even with lr as a string
+        optimizer_params = Params({
+                "type": "sgd",
+                "lr": "0.1"
+        })
+
+        parameters = [[n, p] for n, p in self.model.named_parameters() if p.requires_grad]
+        optimizer = Optimizer.from_params(parameters, optimizer_params)
+
+        assert optimizer.defaults["lr"] == 0.1
+
+        # But should crash (in the Pytorch code) if we don't do the type inference
+        optimizer_params = Params({
+                "type": "sgd",
+                "lr": "0.1",
+                "infer_type_and_cast": False
+        })
+
+        parameters = [[n, p] for n, p in self.model.named_parameters() if p.requires_grad]
+
+        with pytest.raises(TypeError):
+            optimizer = Optimizer.from_params(parameters, optimizer_params)
+
+
 class TestDenseSparseAdam(AllenNlpTestCase):
 
     def setUp(self):
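The TypeError the second half of this test expects is raised inside PyTorch rather than AllenNLP: the SGD constructor validates the learning rate (roughly, checking lr < 0.0), and comparing a str against a float raises in Python 3. A standalone sketch, assuming a PyTorch version that performs this validation:

    import torch

    model = torch.nn.Linear(2, 2)

    # With infer_type_and_cast disabled, the string "0.1" reaches the constructor unchanged:
    torch.optim.SGD(model.parameters(), lr="0.1")
    # TypeError: '<' not supported between instances of 'str' and 'float'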

allennlp/training/optimizers.py (+8 -1)

@@ -121,7 +121,14 @@ def from_params(cls, model_parameters: List, params: Params):  # type: ignore
             else:
                 num_parameters += parameter_group.numel()
         logger.info("Number of trainable parameters: %s", num_parameters)
-        return Optimizer.by_name(optimizer)(parameter_groups, **params.as_dict())  # type: ignore
+
+        # By default we cast things that e.g. look like floats to floats before handing them
+        # to the Optimizer constructor, but if you want to disable that behavior you could add a
+        #       "infer_type_and_cast": false
+        # key to your "trainer.optimizer" config.
+        infer_type_and_cast = params.pop_bool("infer_type_and_cast", True)
+        params_as_dict = params.as_dict(infer_type_and_cast=infer_type_and_cast)
+        return Optimizer.by_name(optimizer)(parameter_groups, **params_as_dict)  # type: ignore
 
 # We just use the Pytorch optimizers, so here we force them into
 # Registry._registry so we can build them from params.
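Per the comment above, the cast is enabled by default, so opting out is a matter of one extra key. A sketch of the relevant config fragment (surrounding keys omitted; only the trainer.optimizer block matters here):

    "trainer": {
        "optimizer": {
            "type": "sgd",
            "lr": 0.1,
            "infer_type_and_cast": false
        }
    }

With the cast disabled, string-valued entries are passed through to the optimizer constructor verbatim.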
