This repository was archived by the owner on Dec 16, 2022. It is now read-only.

save meta data with model archives #5209

Merged 3 commits on May 19, 2021
Changes from 1 commit
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -21,6 +21,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.confidence_checks.task_checklists` module.
- Added `allennlp diff` command to compute a diff on model checkpoints, analogous to what `git diff` does on two files.
- Metadata defined by the class `allennlp.common.meta.Meta` is now saved in the serialization directory and archive file
  when training models from the command line. This is also now part of the `Archive` named tuple that's returned from `load_archive()` (see the usage sketch after this diff).
- Added `nn.util.distributed_device()` helper function.
- Added `allennlp.nn.util.load_state_dict` helper function.
- Added a way to avoid downloading and loading pretrained weights in modules that wrap transformers
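The new `Archive.meta` field from the changelog entry above can be read after loading any archive. A minimal sketch (the archive path is a placeholder):

```python
from allennlp.models.archival import load_archive

# "model.tar.gz" is a placeholder path to a trained model archive.
archive = load_archive("model.tar.gz")

# `meta` is Optional: archives produced before this change contain no
# meta.json, in which case `archive.meta` is None.
if archive.meta is not None:
    print(f"Trained with AllenNLP v{archive.meta.version}")
```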
4 changes: 4 additions & 0 deletions allennlp/commands/train.py
@@ -19,6 +19,7 @@
from allennlp.commands.subcommand import Subcommand
from allennlp.common import Params, Registrable, Lazy
from allennlp.common.checks import check_for_gpu, ConfigurationError
from allennlp.common.meta import Meta, META_NAME
from allennlp.common import logging as common_logging
from allennlp.common import util as common_util
from allennlp.common.plugins import import_plugins
@@ -226,6 +227,9 @@ def train_model(
    training_util.create_serialization_dir(params, serialization_dir, recover, force)
    params.to_file(os.path.join(serialization_dir, CONFIG_NAME))

    meta = Meta.new()
    meta.to_file(os.path.join(serialization_dir, META_NAME))

    include_in_archive = params.pop("include_in_archive", None)
    verify_include_in_archive(include_in_archive)

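The two lines added above write a `meta.json` file into the serialization directory before training begins. Given the `Meta` dataclass introduced in this PR, the file is a one-key JSON object; a sketch of reading it back (the directory and version value are placeholders):

```python
import json
import os

serialization_dir = "output"  # placeholder serialization directory

# meta.json holds the dataclass fields dumped via json.dump,
# currently just the AllenNLP version, e.g. {"version": "2.4.0"}.
with open(os.path.join(serialization_dir, "meta.json")) as f:
    meta_data = json.load(f)
print(meta_data["version"])
```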
1 change: 1 addition & 0 deletions allennlp/common/__init__.py
@@ -4,3 +4,4 @@
from allennlp.common.registrable import Registrable
from allennlp.common.tqdm import Tqdm
from allennlp.common.util import JsonDict
from allennlp.common.meta import Meta
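With the re-export above, `Meta` is importable from `allennlp.common` directly; both imports below refer to the same class:

```python
from allennlp.common import Meta
from allennlp.common.meta import Meta  # equivalent after this change
```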
37 changes: 37 additions & 0 deletions allennlp/common/meta.py
@@ -0,0 +1,37 @@
from os import PathLike
from dataclasses import dataclass, asdict
import json
import logging
from typing import Union

from allennlp.version import VERSION


logger = logging.getLogger(__name__)


META_NAME = "meta.json"


@dataclass
class Meta:
"""
Defines the meta data that's saved in a serialization directory and archive
when training an AllenNLP model.
"""

version: str

@classmethod
def new(cls) -> "Meta":
return cls(version=VERSION)

def to_file(self, path: Union[PathLike, str]) -> None:
with open(path, "w") as meta_file:
json.dump(asdict(self), meta_file)

@classmethod
def from_path(cls, path: Union[PathLike, str]) -> "Meta":
with open(path) as meta_file:
data = json.load(meta_file)
return cls(**data)
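A quick round-trip of the `Meta` API defined above, kept self-contained with a temporary directory:

```python
import os
import tempfile

from allennlp.common.meta import Meta, META_NAME

with tempfile.TemporaryDirectory() as tmp_dir:
    meta_path = os.path.join(tmp_dir, META_NAME)

    # Meta.new() captures the version of the installed AllenNLP package.
    Meta.new().to_file(meta_path)

    # from_path() reconstructs the dataclass from the JSON file.
    restored = Meta.from_path(meta_path)
    print(restored.version)
```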
59 changes: 57 additions & 2 deletions allennlp/models/archival.py
@@ -2,19 +2,22 @@
Helper functions for archiving models and restoring archived models.
"""
from os import PathLike
from typing import NamedTuple, Union, Dict, Any, List, Optional
from typing import Tuple, NamedTuple, Union, Dict, Any, List, Optional
import logging
import os
import tempfile
import tarfile
import shutil
from contextlib import contextmanager
import glob
import warnings

from torch.nn import Module

from allennlp.version import VERSION, _MAJOR, _MINOR, _PATCH
from allennlp.common.checks import ConfigurationError
from allennlp.common.file_utils import cached_path
from allennlp.common.meta import Meta, META_NAME
from allennlp.common.params import Params
from allennlp.data.dataset_readers import DatasetReader
from allennlp.models.model import Model, _DEFAULT_WEIGHTS
@@ -29,6 +32,7 @@ class Archive(NamedTuple):
    config: Params
    dataset_reader: DatasetReader
    validation_dataset_reader: DatasetReader
    meta: Optional[Meta]

    def extract_module(self, path: str, freeze: bool = True) -> Module:
        """
@@ -90,12 +94,13 @@ def extract_module(self, path: str, freeze: bool = True) -> Module:
# These constants are the *known names* under which we archive them.
CONFIG_NAME = "config.json"
_WEIGHTS_NAME = "weights.th"
_VERSION_TUPLE = (_MAJOR, _MINOR, _PATCH)


def verify_include_in_archive(include_in_archive: Optional[List[str]] = None):
    if include_in_archive is None:
        return
    saved_names = [CONFIG_NAME, _WEIGHTS_NAME, _DEFAULT_WEIGHTS, "vocabulary"]
    saved_names = [CONFIG_NAME, _WEIGHTS_NAME, _DEFAULT_WEIGHTS, META_NAME, "vocabulary"]
    for archival_target in include_in_archive:
        if archival_target in saved_names:
            raise ConfigurationError(
@@ -133,18 +138,26 @@ def archive_model(
    config_file = os.path.join(serialization_dir, CONFIG_NAME)
    if not os.path.exists(config_file):
        logger.error("config file %s does not exist, unable to archive model", config_file)
        return

    meta_file = os.path.join(serialization_dir, META_NAME)

    if archive_path is not None:
        archive_file = archive_path
        if os.path.isdir(archive_file):
            archive_file = os.path.join(archive_file, "model.tar.gz")
    else:
        archive_file = os.path.join(serialization_dir, "model.tar.gz")

    logger.info("archiving weights and vocabulary to %s", archive_file)
    with tarfile.open(archive_file, "w:gz") as archive:
        archive.add(config_file, arcname=CONFIG_NAME)
        archive.add(weights_file, arcname=_WEIGHTS_NAME)
        archive.add(os.path.join(serialization_dir, "vocabulary"), arcname="vocabulary")
        if os.path.exists(meta_file):
            archive.add(meta_file, arcname=META_NAME)
        else:
            logger.warning("meta file %s does not exist", meta_file)

        if include_in_archive is not None:
            for archival_target in include_in_archive:
@@ -184,6 +197,8 @@ def load_archive(
    else:
        logger.info(f"loading archive file {archive_file} from cache at {resolved_archive_file}")

    meta: Optional[Meta] = None

    tempdir = None
    try:
        if os.path.isdir(resolved_archive_file):
@@ -205,16 +220,26 @@ def load_archive(
            config.duplicate(), serialization_dir
        )
        model = _load_model(config.duplicate(), weights_path, serialization_dir, cuda_device)

        # Load meta.
        meta_path = os.path.join(serialization_dir, META_NAME)
        if os.path.exists(meta_path):
            meta = Meta.from_path(meta_path)
    finally:
        if tempdir is not None:
            logger.info(f"removing temporary unarchived model dir at {tempdir}")
            shutil.rmtree(tempdir, ignore_errors=True)

    # Check version compatibility.
    if meta is not None:
        _check_version_compatibility(archive_file, meta)

    return Archive(
        model=model,
        config=config,
        dataset_reader=dataset_reader,
        validation_dataset_reader=validation_dataset_reader,
        meta=meta,
    )


@@ -267,3 +292,33 @@ def extracted_archive(resolved_archive_file, cleanup=True):
        if tempdir is not None and cleanup:
            logger.info(f"removing temporary unarchived model dir at {tempdir}")
            shutil.rmtree(tempdir, ignore_errors=True)


def _parse_version(version: str) -> Tuple[str, str, str]:
    """
    Parse a version string into a (major, minor, patch).
    """
    try:
        major, minor, patch = version.split(".")[:3]
    except ValueError:
        raise ValueError(f"Invalid version '{version}', unable to parse")
    return (major, minor, patch)


def _check_version_compatibility(archive_file: Union[PathLike, str], meta: Meta):
    meta_version_tuple = _parse_version(meta.version)
    # Warn if current version is behind the version the model was trained on.
    if _VERSION_TUPLE < meta_version_tuple:
        warnings.warn(
            f"The model {archive_file} was trained on a newer version of AllenNLP (v{meta.version}), "
            f"but you're using version {VERSION}.",
            UserWarning,
        )
    # Warn if major versions differ since there is no guarantee of backwards
    # compatibility across major releases.
    elif _VERSION_TUPLE[0] != meta_version_tuple[0]:
        warnings.warn(
            f"The model {archive_file} was trained on version {meta.version} of AllenNLP, "
            f"but you're using {VERSION} which may not be compatible.",
            UserWarning,
        )
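To make the two warning branches above concrete, here is a sketch that calls the (private) `_check_version_compatibility` helper directly with made-up version strings:

```python
import warnings

from allennlp.common.meta import Meta
from allennlp.models.archival import _check_version_compatibility

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")

    # A model "trained" on a newer version than the installed one
    # triggers the "trained on a newer version" warning.
    _check_version_compatibility("model.tar.gz", Meta(version="99.0.0"))

    # A model from a different major version triggers the
    # "may not be compatible" warning.
    _check_version_compatibility("model.tar.gz", Meta(version="0.9.0"))

for w in caught:
    print(w.message)
```

Note that `_parse_version` returns string components, so the comparison is lexicographic rather than true semantic-version ordering (e.g. `"10" < "9"` as strings); for these warnings that heuristic is good enough.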
7 changes: 6 additions & 1 deletion tests/commands/train_test.py
@@ -13,6 +13,7 @@
import pytest
import torch

from allennlp.version import VERSION
from allennlp.commands.train import Train, train_model, train_model_from_args, TrainModel
from allennlp.common import Params
from allennlp.common.checks import ConfigurationError
@@ -109,7 +110,11 @@ class TestTrain(AllenNlpTestCase):
    def test_train_model(self):
        params = lambda: copy.deepcopy(self.DEFAULT_PARAMS)

        train_model(params(), serialization_dir=os.path.join(self.TEST_DIR, "test_train_model"))
        serialization_dir = os.path.join(self.TEST_DIR, "test_train_model")
        train_model(params(), serialization_dir=serialization_dir)
        archive = load_archive(os.path.join(serialization_dir, "model.tar.gz"))
        assert archive.meta is not None
        assert archive.meta.version == VERSION

        # It's OK if serialization dir exists but is empty:
        serialization_dir2 = os.path.join(self.TEST_DIR, "empty_directory")
19 changes: 18 additions & 1 deletion tests/models/archival_test.py
@@ -6,12 +6,19 @@
import pytest
import torch

from allennlp.version import _MAJOR, _MINOR
from allennlp.commands.train import train_model
from allennlp.common import Params
from allennlp.common.meta import Meta
from allennlp.common.checks import ConfigurationError
from allennlp.common.testing import AllenNlpTestCase
from allennlp.data.dataset_readers import DatasetReader
from allennlp.models.archival import archive_model, load_archive, CONFIG_NAME
from allennlp.models.archival import (
    archive_model,
    load_archive,
    CONFIG_NAME,
    _check_version_compatibility,
)


def assert_models_equal(model, model2):
@@ -32,6 +39,16 @@ def assert_models_equal(model, model2):
    assert vocab._index_to_token == vocab2._index_to_token


def _test_check_version_compatibility():
    meta = Meta(version=f"{_MAJOR}.{int(_MINOR) + 1}.0")
    with pytest.warns(UserWarning, match="trained on a newer version"):
        _check_version_compatibility("model.tar.gz", meta)

    meta = Meta(version="1.2.0")
    with pytest.warns(UserWarning, match="trained on version"):
        _check_version_compatibility("model.tar.gz", meta)


class ArchivalTest(AllenNlpTestCase):
    def setup_method(self):
        super().setup_method()