Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit e53d185

Browse files
authored
Bug fix for case when param type is Optional[Union...] (#4510)
* Bug fix for case when param type is Optional[Union...] * Update CHANGELOG * Adding test case
1 parent 14f63b7 commit e53d185

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

CHANGELOG.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111

1212
- Removed unnecessary warning about deadlocks in `DataLoader`.
1313
- Use slower tqdm intervals when output is being piped or redirected.
14-
- Fixed testing models that only return a loss when they are in training mode
14+
- Fixed testing models that only return a loss when they are in training mode.
15+
- Fixed a bug in `FromParams` that causes silent failure in case of the parameter type being Optional[Union[...]].
1516

1617
### Added
1718

allennlp/common/from_params.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,9 @@ def remove_optional(annotation: type):
105105
"""
106106
origin = getattr(annotation, "__origin__", None)
107107
args = getattr(annotation, "__args__", ())
108-
if origin == Union and len(args) == 2 and args[1] == type(None): # noqa
109-
return args[0]
108+
109+
if origin == Union:
110+
return Union[tuple([arg for arg in args if arg != type(None)])] # noqa: E721
110111
else:
111112
return annotation
112113

tests/common/from_params_test.py

+16
Original file line numberDiff line numberDiff line change
@@ -845,3 +845,19 @@ def __init__(self):
845845

846846
with pytest.raises(ConfigurationError, match="no registered concrete types"):
847847
B.from_params(Params({}))
848+
849+
def test_from_params_raises_error_on_wrong_parameter_name_in_optional_union(self):
850+
class NestedClass(FromParams):
851+
def __init__(self, varname: Optional[str] = None):
852+
self.varname = varname
853+
854+
class WrapperClass(FromParams):
855+
def __init__(self, nested_class: Optional[Union[str, NestedClass]] = None):
856+
if isinstance(nested_class, str):
857+
nested_class = NestedClass(varname=nested_class)
858+
self.nested_class = nested_class
859+
860+
with pytest.raises(ConfigurationError):
861+
WrapperClass.from_params(
862+
params=Params({"nested_class": {"wrong_varname": "varstring"}})
863+
)

0 commit comments

Comments
 (0)