Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 6d8da97

Browse files
authored
make archival take an optional output path (#2510)
1 parent fefc439 commit 6d8da97

File tree

2 files changed, with 23 additions and 3 deletions.

allennlp/models/archival.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ def extract_module(self, path: str, freeze: bool = True) -> Module:
8888

8989
def archive_model(serialization_dir: str,
9090
weights: str = _DEFAULT_WEIGHTS,
91-
files_to_archive: Dict[str, str] = None) -> None:
91+
files_to_archive: Dict[str, str] = None,
92+
archive_path: str = None) -> None:
9293
"""
9394
Archive the model weights, its training configuration, and its
9495
vocabulary to `model.tar.gz`. Include the additional ``files_to_archive``
@@ -104,6 +105,10 @@ def archive_model(serialization_dir: str,
104105
A mapping {flattened_key -> filename} of supplementary files to include
105106
in the archive. That is, if you wanted to include ``params['model']['weights']``
106107
then you would specify the key as `"model.weights"`.
108+
archive_path : ``str``, optional, (default = None)
109+
A full path to serialize the model to. The default is "model.tar.gz" inside the
110+
serialization_dir. If you pass a directory here, we'll serialize the model
111+
to "model.tar.gz" inside the directory.
107112
"""
108113
weights_file = os.path.join(serialization_dir, weights)
109114
if not os.path.exists(weights_file):
@@ -121,8 +126,12 @@ def archive_model(serialization_dir: str,
121126
with open(fta_filename, 'w') as fta_file:
122127
fta_file.write(json.dumps(files_to_archive))
123128

124-
125-
archive_file = os.path.join(serialization_dir, "model.tar.gz")
129+
if archive_path is not None:
130+
archive_file = archive_path
131+
if os.path.isdir(archive_file):
132+
archive_file = os.path.join(archive_file, "model.tar.gz")
133+
else:
134+
archive_file = os.path.join(serialization_dir, "model.tar.gz")
126135
logger.info("archiving weights and vocabulary to %s", archive_file)
127136
with tarfile.open(archive_file, 'w:gz') as archive:
128137
archive.add(config_file, arcname=CONFIG_NAME)

allennlp/tests/models/archival_test.py

+11
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,17 @@ def test_archiving(self):
7676
params2 = archive.config
7777
assert params2.as_dict() == params_copy
7878

79+
def test_archive_model_uses_archive_path(self):
80+
81+
serialization_dir = self.TEST_DIR / 'serialization'
82+
# Train a model
83+
train_model(self.params, serialization_dir=serialization_dir)
84+
# Use a new path.
85+
archive_model(serialization_dir=serialization_dir,
86+
archive_path=serialization_dir / "new_path.tar.gz")
87+
archive = load_archive(serialization_dir / 'new_path.tar.gz')
88+
assert archive
89+
7990
def test_extra_files(self):
8091

8192
serialization_dir = self.TEST_DIR / 'serialization'

0 commit comments

Comments
 (0)