Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 6d480aa

Browse files
authored
Improve handling of **kwargs in FromParams (#4629)
* Improve handling of **kwargs in FromParams * remove duplicated line * improved handling of *args and *kwargs in FromParams * changelog * **kwargs is needed * Revert "**kwargs is needed" This reverts commit c74a94c. * revert FromParams changes * ignore *args * change optimizers' inheritance order * add bare constructor to ForParams * Revert "add bare constructor to ForParams" This reverts commit 4afe072. * Revert "change optimizers' inheritance order" This reverts commit d9fe036. * remove comment
1 parent bf3206a commit 6d480aa

File tree

3 files changed

+29
-2
lines changed

3 files changed

+29
-2
lines changed

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## Unreleased
99

10+
### Fixed
11+
12+
- Ignore *args when constructing classes with `FromParams`.
13+
1014
## [v1.1.0](https://github.com/allenai/allennlp/releases/tag/v1.1.0) - 2020-09-08
1115

1216
### Fixed

allennlp/common/from_params.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -113,18 +113,22 @@ def remove_optional(annotation: type):
113113

114114

115115
def infer_params(cls: Type[T], constructor: Callable[..., T] = None) -> Dict[str, Any]:
116-
if cls == FromParams:
117-
return {}
118116
if constructor is None:
119117
constructor = cls.__init__
120118

121119
signature = inspect.signature(constructor)
122120
parameters = dict(signature.parameters)
123121

124122
has_kwargs = False
123+
var_positional_key = None
125124
for param in parameters.values():
126125
if param.kind == param.VAR_KEYWORD:
127126
has_kwargs = True
127+
elif param.kind == param.VAR_POSITIONAL:
128+
var_positional_key = param.name
129+
130+
if var_positional_key:
131+
del parameters[var_positional_key]
128132

129133
if not has_kwargs:
130134
return parameters

tests/common/from_params_test.py

+19
Original file line numberDiff line numberDiff line change
@@ -961,3 +961,22 @@ def __init__(self, a: int, b: str = None, **kwargs) -> None:
961961
assert foo.a == 2
962962
assert foo.b == "hi"
963963
assert foo.c == {"2": "3"}
964+
965+
def test_from_params_child_has_kwargs_base_implicit_constructor(self):
966+
class Foo(FromParams):
967+
pass
968+
969+
class Bar(Foo):
970+
def __init__(self, a: int, **kwargs) -> None:
971+
self.a = a
972+
973+
bar = Bar.from_params(Params({"a": 2}))
974+
assert bar.a == 2
975+
976+
def test_from_params_has_args(self):
977+
class Foo(FromParams):
978+
def __init__(self, a: int, *args) -> None:
979+
self.a = a
980+
981+
foo = Foo.from_params(Params({"a": 2}))
982+
assert foo.a == 2

0 commit comments

Comments
 (0)