@@ -88,7 +88,8 @@ def extract_module(self, path: str, freeze: bool = True) -> Module:
88
88
89
89
def archive_model (serialization_dir : str ,
90
90
weights : str = _DEFAULT_WEIGHTS ,
91
- files_to_archive : Dict [str , str ] = None ) -> None :
91
+ files_to_archive : Dict [str , str ] = None ,
92
+ archive_path : str = None ) -> None :
92
93
"""
93
94
Archive the model weights, its training configuration, and its
94
95
vocabulary to `model.tar.gz`. Include the additional ``files_to_archive``
@@ -104,6 +105,10 @@ def archive_model(serialization_dir: str,
104
105
A mapping {flattened_key -> filename} of supplementary files to include
105
106
in the archive. That is, if you wanted to include ``params['model']['weights']``
106
107
then you would specify the key as `"model.weights"`.
108
+ archive_path : ``str``, optional, (default = None)
109
+ A full path to serialize the model to. The default is "model.tar.gz" inside the
110
+ serialization_dir. If you pass a directory here, we'll serialize the model
111
+ to "model.tar.gz" inside the directory.
107
112
"""
108
113
weights_file = os .path .join (serialization_dir , weights )
109
114
if not os .path .exists (weights_file ):
@@ -121,8 +126,12 @@ def archive_model(serialization_dir: str,
121
126
with open (fta_filename , 'w' ) as fta_file :
122
127
fta_file .write (json .dumps (files_to_archive ))
123
128
124
-
125
- archive_file = os .path .join (serialization_dir , "model.tar.gz" )
129
+ if archive_path is not None :
130
+ archive_file = archive_path
131
+ if os .path .isdir (archive_file ):
132
+ archive_file = os .path .join (archive_file , "model.tar.gz" )
133
+ else :
134
+ archive_file = os .path .join (serialization_dir , "model.tar.gz" )
126
135
logger .info ("archiving weights and vocabulary to %s" , archive_file )
127
136
with tarfile .open (archive_file , 'w:gz' ) as archive :
128
137
archive .add (config_file , arcname = CONFIG_NAME )
0 commit comments