This repository was archived by the owner on Dec 16, 2022. It is now read-only.
File tree 3 files changed +21
-3
lines changed
3 files changed +21
-3
lines changed Original file line number Diff line number Diff line change @@ -11,7 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
11
11
12
12
- Removed unnecessary warning about deadlocks in ` DataLoader ` .
13
13
- Use slower tqdm intervals when output is being piped or redirected.
14
- - Fixed testing models that only return a loss when they are in training mode
14
+ - Fixed testing models that only return a loss when they are in training mode.
15
+ - Fixed a bug in ` FromParams ` that causes silent failure in case of the parameter type being Optional[ Union[ ...]] .
15
16
16
17
### Added
17
18
Original file line number Diff line number Diff line change @@ -105,8 +105,9 @@ def remove_optional(annotation: type):
105
105
"""
106
106
origin = getattr (annotation , "__origin__" , None )
107
107
args = getattr (annotation , "__args__" , ())
108
- if origin == Union and len (args ) == 2 and args [1 ] == type (None ): # noqa
109
- return args [0 ]
108
+
109
+ if origin == Union :
110
+ return Union [tuple ([arg for arg in args if arg != type (None )])] # noqa: E721
110
111
else :
111
112
return annotation
112
113
Original file line number Diff line number Diff line change @@ -845,3 +845,19 @@ def __init__(self):
845
845
846
846
with pytest .raises (ConfigurationError , match = "no registered concrete types" ):
847
847
B .from_params (Params ({}))
848
+
849
+ def test_from_params_raises_error_on_wrong_parameter_name_in_optional_union (self ):
850
+ class NestedClass (FromParams ):
851
+ def __init__ (self , varname : Optional [str ] = None ):
852
+ self .varname = varname
853
+
854
+ class WrapperClass (FromParams ):
855
+ def __init__ (self , nested_class : Optional [Union [str , NestedClass ]] = None ):
856
+ if isinstance (nested_class , str ):
857
+ nested_class = NestedClass (varname = nested_class )
858
+ self .nested_class = nested_class
859
+
860
+ with pytest .raises (ConfigurationError ):
861
+ WrapperClass .from_params (
862
+ params = Params ({"nested_class" : {"wrong_varname" : "varstring" }})
863
+ )
You can’t perform that action at this time.
0 commit comments