Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 9ac6c76

Browse files
authored
Allow overrides to be JSON string or dict (#4680)
1 parent 55cfb47 commit 9ac6c76

File tree

6 files changed

+26
-16
lines changed

6 files changed

+26
-16
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2828

2929
### Changed
3030

31+
- Allow overrides to be JSON string or `dict`.
3132
- `transformers` dependency updated to version 3.1.0.
3233
- When `cached_path` is called on a local archive with `extract_archive=True`, the archive is now extracted into a unique subdirectory of the cache root instead of a subdirectory of the archive's directory. The extraction directory is also unique to the modification time of the archive, so if the file changes, subsequent calls to `cached_path` will know to re-extract the archive.
3334
- Removed the `truncation_strategy` parameter to `PretrainedTransformerTokenizer`. The way we're calling the tokenizer, the truncation strategy takes no effect anyways.

allennlp/commands/train.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def train_model_from_args(args: argparse.Namespace):
122122
def train_model_from_file(
123123
parameter_filename: Union[str, PathLike],
124124
serialization_dir: Union[str, PathLike],
125-
overrides: str = "",
125+
overrides: Union[str, Dict[str, Any]] = "",
126126
recover: bool = False,
127127
force: bool = False,
128128
node_rank: int = 0,
@@ -140,8 +140,8 @@ def train_model_from_file(
140140
serialization_dir : `str`
141141
The directory in which to save results and logs. We just pass this along to
142142
[`train_model`](#train_model).
143-
overrides : `str`
144-
A JSON string that we will use to override values in the input parameter file.
143+
overrides : `Union[str, Dict[str, Any]]`, optional (default = `""`)
144+
A JSON string or a dict that we will use to override values in the input parameter file.
145145
recover : `bool`, optional (default=`False`)
146146
If `True`, we will try to recover a training run from an existing serialization
147147
directory. This is only intended for use when something actually crashed during the middle

allennlp/common/params.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,10 @@ def _check_is_dict(self, new_history, value):
457457

458458
@classmethod
459459
def from_file(
460-
cls, params_file: Union[str, PathLike], params_overrides: str = "", ext_vars: dict = None
460+
cls,
461+
params_file: Union[str, PathLike],
462+
params_overrides: Union[str, Dict[str, Any]] = "",
463+
ext_vars: dict = None,
461464
) -> "Params":
462465
"""
463466
Load a `Params` object from a configuration file.
@@ -468,7 +471,7 @@ def from_file(
468471
469472
The path to the configuration file to load.
470473
471-
params_overrides: `str`, optional
474+
params_overrides: `Union[str, Dict[str, Any]]`, optional (default = `""`)
472475
473476
A dict of overrides that can be applied to final object.
474477
e.g. {"model.embedding_dim": 10}
@@ -490,6 +493,8 @@ def from_file(
490493

491494
file_dict = json.loads(evaluate_file(params_file, ext_vars=ext_vars))
492495

496+
if isinstance(params_overrides, dict):
497+
params_overrides = json.dumps(params_overrides)
493498
overrides_dict = parse_overrides(params_overrides)
494499
param_dict = with_fallback(preferred=overrides_dict, fallback=file_dict)
495500

allennlp/models/archival.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Helper functions for archiving models and restoring archived models.
33
"""
44
from os import PathLike
5-
from typing import NamedTuple, Union
5+
from typing import NamedTuple, Union, Dict, Any
66
import logging
77
import os
88
import tempfile
@@ -132,7 +132,7 @@ def archive_model(
132132
def load_archive(
133133
archive_file: Union[str, Path],
134134
cuda_device: int = -1,
135-
overrides: str = "",
135+
overrides: Union[str, Dict[str, Any]] = "",
136136
weights_file: str = None,
137137
) -> Archive:
138138
"""
@@ -145,7 +145,7 @@ def load_archive(
145145
cuda_device : `int`, optional (default = `-1`)
146146
If `cuda_device` is >= 0, the model will be loaded onto the
147147
corresponding GPU. Otherwise it will be loaded onto the CPU.
148-
overrides : `str`, optional (default = `""`)
148+
overrides : `Union[str, Dict[str, Any]]`, optional (default = `""`)
149149
JSON overrides to apply to the unarchived `Params` object.
150150
weights_file : `str`, optional (default = `None`)
151151
The weights file to use. If unspecified, weights.th in the archive_file will be used.

allennlp/predictors/predictor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def from_path(
239239
dataset_reader_to_load: str = "validation",
240240
frozen: bool = True,
241241
import_plugins: bool = True,
242-
overrides: str = "",
242+
overrides: Union[str, Dict[str, Any]] = "",
243243
) -> "Predictor":
244244
"""
245245
Instantiate a `Predictor` from an archive path.
@@ -267,7 +267,7 @@ def from_path(
267267
This comes with additional overhead, but means you don't need to explicitly
268268
import the modules that your predictor depends on as long as those modules
269269
can be found by `allennlp.common.plugins.import_plugins()`.
270-
overrides : `str`, optional (default = `""`)
270+
overrides : `Union[str, Dict[str, Any]]`, optional (default = `""`)
271271
JSON overrides to apply to the unarchived `Params` object.
272272
273273
# Returns

tests/common/params_test.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,18 @@ def test_bad_unicode_environment_variables(self):
3333
Params.from_file(filename)
3434
del os.environ["BAD_ENVIRONMENT_VARIABLE"]
3535

36-
def test_overrides(self):
36+
@pytest.mark.parametrize("input_type", [dict, str])
37+
def test_overrides(self, input_type):
3738
filename = self.FIXTURES_ROOT / "simple_tagger" / "experiment.json"
38-
overrides = (
39-
'{ "train_data_path": "FOO", "model": { "type": "BAR" },'
40-
'"model.text_field_embedder.tokens.type": "BAZ",'
41-
'"data_loader.batch_sampler.sorting_keys.0": "question"}'
39+
overrides = {
40+
"train_data_path": "FOO",
41+
"model": {"type": "BAR"},
42+
"model.text_field_embedder.tokens.type": "BAZ",
43+
"data_loader.batch_sampler.sorting_keys.0": "question",
44+
}
45+
params = Params.from_file(
46+
filename, overrides if input_type == dict else json.dumps(overrides)
4247
)
43-
params = Params.from_file(filename, overrides)
4448

4549
assert "dataset_reader" in params
4650
assert "trainer" in params

0 commit comments

Comments
 (0)