This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 71a98c2

stricter typing for Optional[T] types, improve handling of Lazy params (#4743)
* stricter typing for Optional[T] types
* fix linting error
* fix checkpointer test
* fix add_field method
* fix '_extract_token_and_type_ids' method
* fix typing on Lazy
* improve Lazy API
* add notes about Lazy to CHANGELOG
* fix CHANGELOG
1 parent 27edfbf commit 71a98c2

Note: this is a large commit, so some content is hidden by default and not all 55 changed files are shown below.

55 files changed (+318 -242 lines)

CHANGELOG.md (+10)

```diff
@@ -7,6 +7,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ## Unreleased

+### Changed
+
+- Enforced stricter typing requirements around the use of `Optional[T]` types.
+- Changed the behavior of `Lazy` types in `from_params` methods. Previously, if you defined a `Lazy` parameter like
+  `foo: Lazy[Foo] = None` in a custom `from_params` classmethod, then `foo` would actually never be `None`.
+  This behavior is now different. If no params were given for `foo`, it will be `None`.
+  You can also now set default values for `foo` like `foo: Lazy[Foo] = Lazy(Foo)`.
+  Or, if you want a default value but also want to allow for `None` values, you can
+  write it like this: `foo: Optional[Lazy[Foo]] = Lazy(Foo)`.
+
 ### Fixed

 - Made it possible to instantiate `TrainerCallback` from config files.
```
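To make the new semantics concrete, here is a minimal sketch of the three default styles from the entry above; `Foo` and `MyClass` are hypothetical names for illustration:

```python
from typing import Optional

from allennlp.common import FromParams, Lazy


class Foo(FromParams):
    def __init__(self, x: int = 0) -> None:
        self.x = x


class MyClass(FromParams):
    def __init__(
        self,
        foo: Lazy[Foo] = None,                 # None when no params are given for "foo"
        bar: Lazy[Foo] = Lazy(Foo),            # a default that can always be constructed
        baz: Optional[Lazy[Foo]] = Lazy(Foo),  # default value, but None is also allowed
    ) -> None:
        # Under the new behavior an absent Lazy param really is None,
        # so a plain None check is now the correct idiom.
        self.foo = None if foo is None else foo.construct()
        self.bar = bar.construct()  # construct() now returns T, not Optional[T]
        self.baz = None if baz is None else baz.construct()
```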

Makefile (+1 -5)

```diff
@@ -50,11 +50,7 @@ format :

 .PHONY : typecheck
 typecheck :
-	mypy . \
-		--ignore-missing-imports \
-		--no-strict-optional \
-		--no-site-packages \
-		--cache-dir=/dev/null
+	mypy . --cache-dir=/dev/null

 .PHONY : test
 test :
```

allennlp/__main__.py (+1 -1)

```diff
@@ -6,7 +6,7 @@
 if os.environ.get("ALLENNLP_DEBUG"):
     LEVEL = logging.DEBUG
 else:
-    level_name = os.environ.get("ALLENNLP_LOG_LEVEL")
+    level_name = os.environ.get("ALLENNLP_LOG_LEVEL", "INFO")
     LEVEL = logging._nameToLevel.get(level_name, logging.INFO)

 sys.path.insert(0, os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir))))
```
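The reason the `"INFO"` default matters for this commit: `os.environ.get(key)` is typed `Optional[str]`, and with `--no-strict-optional` removed from the Makefile that value can no longer be used where a `str` is expected. A small sketch of the distinction:

```python
import logging
import os

# Optional[str]: fails strict optional checking when used as a str key.
maybe_name = os.environ.get("ALLENNLP_LOG_LEVEL")

# str: the fallback guarantees a valid key for the lookup below.
level_name = os.environ.get("ALLENNLP_LOG_LEVEL", "INFO")
LEVEL = logging._nameToLevel.get(level_name, logging.INFO)
```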

allennlp/commands/predict.py (+2 -8)

```diff
@@ -128,16 +128,10 @@ def __init__(

         self._predictor = predictor
         self._input_file = input_file
-        if output_file is not None:
-            self._output_file = open(output_file, "w")
-        else:
-            self._output_file = None
+        self._output_file = None if output_file is None else open(output_file, "w")
         self._batch_size = batch_size
         self._print_to_console = print_to_console
-        if has_dataset_reader:
-            self._dataset_reader = predictor._dataset_reader
-        else:
-            self._dataset_reader = None
+        self._dataset_reader = None if not has_dataset_reader else predictor._dataset_reader

     def _predict_json(self, batch_data: List[JsonDict]) -> Iterator[str]:
         if len(batch_data) == 1:
```

allennlp/commands/train.py (+14 -16)

```diff
@@ -401,6 +401,8 @@ def _train_worker(
     include_package = include_package or []

     if distributed:
+        assert distributed_device_ids is not None
+
         # Since the worker is spawned and not forked, the extra imports need to be done again.
         # Both the ones from the plugins and the ones from `include_package`.
         import_plugins()
@@ -556,7 +558,7 @@ def from_partial_objects(
         model: Lazy[Model],
         data_loader: Lazy[DataLoader],
         trainer: Lazy[Trainer],
-        vocabulary: Lazy[Vocabulary] = None,
+        vocabulary: Lazy[Vocabulary] = Lazy(Vocabulary),
         datasets_for_vocab_creation: List[str] = None,
         validation_dataset_reader: DatasetReader = None,
         validation_data_path: str = None,
@@ -610,7 +612,7 @@ def from_partial_objects(
         trainer: `Lazy[Trainer]`
             The `Trainer` that actually implements the training loop. This is a lazy object because
             it depends on the model that's going to be trained.
-        vocabulary: `Lazy[Vocabulary]`, optional (default=`None`)
+        vocabulary: `Lazy[Vocabulary]`, optional (default=`Lazy(Vocabulary)`)
            The `Vocabulary` that we will use to convert strings in the data to integer ids (and
            possibly set sizes of embedding matrices in the `Model`). By default we construct the
            vocabulary from the instances that we read.
@@ -664,8 +666,7 @@ def from_partial_objects(
         )

         vocabulary_ = vocabulary.construct(instances=instance_generator)
-        if not vocabulary_:
-            vocabulary_ = Vocabulary.from_instances(instance_generator)
+
         model_ = model.construct(vocab=vocabulary_, serialization_dir=serialization_dir)

         # Initializing the model can have side effect of expanding the vocabulary.
@@ -682,13 +683,9 @@ def from_partial_objects(

         data_loader_ = data_loader.construct(dataset=datasets["train"])
         validation_data = datasets.get("validation")
+        validation_data_loader_: Optional[DataLoader] = None
         if validation_data is not None:
-            # Because of the way Lazy[T] works, we can't check it's existence
-            # _before_ we've tried to construct it. It returns None if it is not
-            # present, so we try to construct it first, and then afterward back off
-            # to the data_loader configuration used for training if it returns None.
-            validation_data_loader_ = validation_data_loader.construct(dataset=validation_data)
-            if validation_data_loader_ is None:
+            if validation_data_loader is None:
                 validation_data_loader_ = data_loader.construct(dataset=validation_data)
                 if getattr(validation_data_loader_, "_batches_per_epoch", None) is not None:
                     warnings.warn(
@@ -698,16 +695,16 @@ def from_partial_objects(
                         "validation datasets for each epoch.",
                         UserWarning,
                     )
-        else:
-            validation_data_loader_ = None
+            else:
+                validation_data_loader_ = validation_data_loader.construct(dataset=validation_data)

         test_data = datasets.get("test")
+        test_data_loader: Optional[DataLoader] = None
         if test_data is not None:
-            test_data_loader = validation_data_loader.construct(dataset=test_data)
-            if test_data_loader is None:
+            if validation_data_loader is None:
                 test_data_loader = data_loader.construct(dataset=test_data)
-        else:
-            test_data_loader = None
+            else:
+                test_data_loader = validation_data_loader.construct(dataset=test_data)

         # We don't need to pass serialization_dir and local_rank here, because they will have been
         # passed through the trainer by from_params already, because they were keyword arguments to
@@ -717,6 +714,7 @@ def from_partial_objects(
             data_loader=data_loader_,
             validation_data_loader=validation_data_loader_,
         )
+        assert trainer_ is not None

         return cls(
             serialization_dir=serialization_dir,
```
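The validation and test loader rewrites above both follow the same new pattern: test the `Lazy` parameter itself for `None`, then construct, instead of constructing first and backing off when the result is `None`. A self-contained sketch of that pattern (the `Loader` class and `build_loaders` helper are illustrative, not part of the library):

```python
from typing import Optional, Tuple

from allennlp.common import FromParams, Lazy


class Loader(FromParams):
    def __init__(self, batch_size: int = 32) -> None:
        self.batch_size = batch_size


def build_loaders(
    data_loader: Lazy[Loader],
    validation_data_loader: Optional[Lazy[Loader]] = None,
) -> Tuple[Loader, Loader]:
    train_loader = data_loader.construct()
    # New style: check the Lazy for None *before* constructing.
    if validation_data_loader is None:
        # Fall back to the training loader configuration.
        val_loader = data_loader.construct()
    else:
        val_loader = validation_data_loader.construct()
    return train_loader, val_loader
```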

allennlp/common/file_utils.py (+1 -1)

```diff
@@ -445,7 +445,7 @@ class method.
     The unix timestamp of when the corresponding resource was cached or extracted.
     """

-    size: int = None
+    size: int = 0
     """
     The size of the corresponding resource, in bytes.
     """
```

allennlp/common/from_params.py (+8 -7)

```diff
@@ -112,7 +112,9 @@ def remove_optional(annotation: type):
     return annotation


-def infer_params(cls: Type[T], constructor: Callable[..., T] = None) -> Dict[str, Any]:
+def infer_params(
+    cls: Type[T], constructor: Union[Callable[..., T], Callable[[T], None]] = None
+) -> Dict[str, Any]:
     if constructor is None:
         constructor = cls.__init__

@@ -298,9 +300,6 @@ def pop_and_construct_arg(

     popped_params = params.pop(name, default) if default != _NO_DEFAULT else params.pop(name)
     if popped_params is None:
-        origin = getattr(annotation, "__origin__", None)
-        if origin == Lazy:
-            return Lazy(lambda **kwargs: None)
         return None

     return construct_arg(class_name, name, popped_params, annotation, default, **extras)
@@ -450,7 +449,8 @@ def construct_arg(
         )
     elif origin == Lazy:
         if popped_params is default:
-            return Lazy(lambda **kwargs: default)
+            return default
+
         value_cls = args[0]
         subextras = create_extras(value_cls, extras)

@@ -509,7 +509,7 @@ def from_params(
         cls: Type[T],
         params: Params,
         constructor_to_call: Callable[..., T] = None,
-        constructor_to_inspect: Callable[..., T] = None,
+        constructor_to_inspect: Union[Callable[..., T], Callable[[T], None]] = None,
         **extras,
     ) -> T:
         """
@@ -584,7 +584,7 @@ def from_params(
                 constructor_to_inspect = subclass.__init__
                 constructor_to_call = subclass  # type: ignore
             else:
-                constructor_to_inspect = getattr(subclass, constructor_name)
+                constructor_to_inspect = cast(Callable[..., T], getattr(subclass, constructor_name))
                 constructor_to_call = constructor_to_inspect

             if hasattr(subclass, "from_params"):
@@ -623,6 +623,7 @@ def from_params(
             params.assert_empty(cls.__name__)
         else:
             # This class has a constructor, so create kwargs for it.
+            constructor_to_inspect = cast(Callable[..., T], constructor_to_inspect)
            kwargs = create_kwargs(constructor_to_inspect, cls, params, **extras)

        return constructor_to_call(**kwargs)  # type: ignore
```
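The net effect of the two `from_params` changes above: when a `Lazy` argument is absent from the params, its declared default now passes through untouched rather than being wrapped in a `Lazy` whose `construct()` returns `None`. A hedged sketch of the observable behavior (`Foo` and `Bar` are hypothetical names):

```python
from typing import Optional

from allennlp.common import FromParams, Lazy, Params


class Foo(FromParams):
    def __init__(self, x: int = 0) -> None:
        self.x = x


class Bar(FromParams):
    def __init__(self, foo: Optional[Lazy[Foo]] = None) -> None:
        self.foo = foo


# No "foo" key in the params: previously self.foo would have been a Lazy
# whose construct() returned None; now it is the declared default, None.
bar = Bar.from_params(Params({}))
assert bar.foo is None

# With params present, foo arrives as a real Lazy[Foo].
bar = Bar.from_params(Params({"foo": {"x": 3}}))
assert bar.foo is not None and bar.foo.construct().x == 3
```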

allennlp/common/lazy.py (+32 -20)

````diff
@@ -1,14 +1,19 @@
-from typing import Callable, Generic, TypeVar, Optional
+import inspect
+from typing import Callable, Generic, TypeVar, Type, Union
+
+from allennlp.common.params import Params
+

 T = TypeVar("T")


 class Lazy(Generic[T]):
     """
     This class is for use when constructing objects using `FromParams`, when an argument to a
-    constructor has a _sequential dependency_ with another argument to the same constructor. For
-    example, in a `Trainer` class you might want to take a `Model` and an `Optimizer` as arguments,
-    but the `Optimizer` needs to be constructed using the parameters from the `Model`. You can give
+    constructor has a _sequential dependency_ with another argument to the same constructor.
+
+    For example, in a `Trainer` class you might want to take a `Model` and an `Optimizer` as arguments,
+    but the `Optimizer` needs to be constructed using the parameters from the `Model`. You can give
     the type annotation `Lazy[Optimizer]` to the optimizer argument, then inside the constructor
     call `optimizer.construct(parameters=model.parameters)`.

@@ -21,26 +26,33 @@ class Lazy(Generic[T]):
     construction is actually found in `FromParams`, where we have a special case for a `Lazy` type
     annotation.

-    !!! Warning
-        The way this class is used in from_params means that optional constructor arguments CANNOT
-        be compared to `None` _before_ it is constructed. See the example below for correct usage.
-
-    ```
+    ```python
     @classmethod
-    def my_constructor(cls, some_object: Lazy[MyObject] = None) -> MyClass:
-        ...
-        # WRONG! some_object will never be None at this point, it will be
-        # a Lazy[] that returns None
-        obj = some_object or MyObjectDefault()
-        # CORRECT:
-        obj = some_object.construct(kwarg=kwarg) or MyObjectDefault()
-        ...
+    def my_constructor(
+        cls,
+        some_object: Lazy[MyObject],
+        optional_object: Lazy[MyObject] = None,
+        required_object_with_default: Lazy[MyObject] = Lazy(MyObjectDefault),
+    ) -> MyClass:
+        obj1 = some_object.construct()
+        obj2 = None if optional_object is None else optional_object.construct()
+        obj3 = required_object_with_default.construct()
     ```

     """

-    def __init__(self, constructor: Callable[..., T]):
-        self._constructor = constructor
+    def __init__(self, constructor: Union[Type[T], Callable[..., T]]):
+        constructor_to_use: Callable[..., T]
+
+        if inspect.isclass(constructor):
+
+            def constructor_to_use(**kwargs):
+                return constructor.from_params(Params({}), **kwargs)  # type: ignore[union-attr]
+
+        else:
+            constructor_to_use = constructor
+
+        self._constructor = constructor_to_use

-    def construct(self, **kwargs) -> Optional[T]:
+    def construct(self, **kwargs) -> T:
         return self._constructor(**kwargs)
````
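With the new `__init__`, a `Lazy` can be built directly from a class (in which case `construct` goes through `from_params` with empty `Params`) or, as before, from any callable. A short usage sketch, assuming `Lazy` and `Params` are importable from `allennlp.common` and using a hypothetical `MyObject`:

```python
from allennlp.common import FromParams, Lazy, Params


class MyObject(FromParams):
    def __init__(self, size: int = 1) -> None:
        self.size = size


# From a class: construct() calls MyObject.from_params(Params({}), **kwargs).
obj1 = Lazy(MyObject).construct(size=3)

# From a callable, e.g. closing over params pulled from a config file.
params = Params({"size": 5})
obj2 = Lazy(lambda **kwargs: MyObject.from_params(params, **kwargs)).construct()

assert obj1.size == 3 and obj2.size == 5
```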

allennlp/common/logging.py (+1 -1)

```diff
@@ -99,7 +99,7 @@ def prepare_global_logging(
     if os.environ.get("ALLENNLP_DEBUG"):
         LEVEL = logging.DEBUG
     else:
-        level_name = os.environ.get("ALLENNLP_LOG_LEVEL")
+        level_name = os.environ.get("ALLENNLP_LOG_LEVEL", "INFO")
         LEVEL = logging._nameToLevel.get(level_name, logging.INFO)

     file_handler.setLevel(LEVEL)
```

allennlp/common/params.py (+4 -4)

```diff
@@ -6,7 +6,7 @@
 from collections import OrderedDict
 from collections.abc import MutableMapping
 from os import PathLike
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Union, Optional

 from overrides import overrides

@@ -250,7 +250,7 @@ def pop(self, key: str, default: Any = DEFAULT, keep_as_dict: bool = False) -> A
         else:
             return self._check_is_dict(key, value)

-    def pop_int(self, key: str, default: Any = DEFAULT) -> int:
+    def pop_int(self, key: str, default: Any = DEFAULT) -> Optional[int]:
         """
         Performs a pop and coerces to an int.
         """
@@ -260,7 +260,7 @@ def pop_int(self, key: str, default: Any = DEFAULT) -> int:
         else:
             return int(value)

-    def pop_float(self, key: str, default: Any = DEFAULT) -> float:
+    def pop_float(self, key: str, default: Any = DEFAULT) -> Optional[float]:
         """
         Performs a pop and coerces to a float.
         """
@@ -270,7 +270,7 @@ def pop_float(self, key: str, default: Any = DEFAULT) -> float:
         else:
             return float(value)

-    def pop_bool(self, key: str, default: Any = DEFAULT) -> bool:
+    def pop_bool(self, key: str, default: Any = DEFAULT) -> Optional[bool]:
         """
         Performs a pop and coerces to a bool.
         """
```

allennlp/common/registrable.py (+3 -3)

```diff
@@ -38,8 +38,8 @@ class Registrable(FromParams):
     a subclass to load all other subclasses and the abstract class).
     """

-    _registry: Dict[Type, Dict[str, Tuple[Type, str]]] = defaultdict(dict)
-    default_implementation: str = None
+    _registry: Dict[Type, Dict[str, Tuple[Type, Optional[str]]]] = defaultdict(dict)
+    default_implementation: Optional[str] = None

     @classmethod
     def register(cls: Type[T], name: str, constructor: str = None, exist_ok: bool = False):
@@ -152,7 +152,7 @@ def resolve_class_name(cls: Type[T], name: str) -> Tuple[Type[T], Optional[str]]
         function to use).
         """
         if name in Registrable._registry[cls]:
-            subclass, constructor = Registrable._registry[cls].get(name)
+            subclass, constructor = Registrable._registry[cls][name]
             return subclass, constructor
         elif "." in name:
             # This might be a fully qualified class name, so we'll try importing its "module"
```

allennlp/common/testing/distributed_test.py (+7 -6)

```diff
@@ -9,9 +9,9 @@

 def init_process(
     process_rank: int,
-    distributed_device_ids: List[int] = None,
-    world_size: int = 1,
-    func: Callable = None,
+    world_size: int,
+    distributed_device_ids: List[int],
+    func: Callable,
     func_args: Tuple = None,
     func_kwargs: Dict[str, Any] = None,
     master_addr: str = "127.0.0.1",
@@ -40,13 +40,13 @@ def init_process(
         timeout=datetime.timedelta(seconds=120),
     )

-    func(global_rank, world_size, gpu_id, *func_args, **func_kwargs)
+    func(global_rank, world_size, gpu_id, *(func_args or []), **(func_kwargs or {}))

     dist.barrier()


 def run_distributed_test(
-    device_ids: List[int] = [-1, -1],
+    device_ids: List[int] = None,
     func: Callable = None,
     *args,
     **kwargs,
@@ -62,14 +62,15 @@ def run_distributed_test(
     func: `Callable`
         `func` needs to be global for spawning the processes, so that it can be pickled.
     """
+    device_ids = device_ids or [-1, -1]
     check_for_gpu(device_ids)
     # "fork" start method is the default and should be preferred, except when we're
     # running the tests on GPU, in which case we need to use "spawn".
     start_method = "spawn" if any(x >= 0 for x in device_ids) else "fork"
     nprocs = world_size = len(device_ids)
     mp.start_processes(
         init_process,
-        args=(device_ids, world_size, func, args, kwargs),
+        args=(world_size, device_ids, func, args, kwargs),
         nprocs=nprocs,
         start_method=start_method,
     )
```
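The `device_ids` change is the standard fix for a mutable default argument: the list default is replaced by `None` plus a fallback inside the function, so each call gets a fresh list. A minimal illustration with a hypothetical `pick_devices` helper:

```python
from typing import List, Optional


def pick_devices(device_ids: Optional[List[int]] = None) -> List[int]:
    # A fresh list per call; a literal [-1, -1] default would be a single
    # object shared across every invocation of the function.
    device_ids = device_ids or [-1, -1]
    return device_ids
```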
