
Proposal of utility function + command to push models to HF Hub #5370

Merged · 22 commits · Sep 30, 2021
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
@@ -362,6 +362,8 @@ jobs:
strategy:
matrix:
python: ['3.7', '3.8']
env:
HUGGINGFACE_CO_STAGING: yes

steps:
- name: Setup Python
1 change: 1 addition & 0 deletions allennlp/commands/__init__.py
@@ -16,6 +16,7 @@
from allennlp.commands.subcommand import Subcommand
from allennlp.commands.test_install import TestInstall
from allennlp.commands.train import Train
from allennlp.commands.push_to_hf import PushToHf
from allennlp.commands.count_instances import CountInstances
from allennlp.commands.tango import Tango
from allennlp.common.plugins import import_plugins
97 changes: 97 additions & 0 deletions allennlp/commands/push_to_hf.py
@@ -0,0 +1,97 @@
"""
The `push_to_hf` subcommand can be used to push a trained model to the
Hugging Face Hub ([hf.co](https://hf.co/)).
"""

import argparse
import logging

from overrides import overrides

from allennlp.commands.subcommand import Subcommand
from allennlp.common.push_to_hf import push_to_hf

logger = logging.getLogger(__name__)


@Subcommand.register("push_to_hf")
class PushToHf(Subcommand):
@overrides
def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
description = """Push a model to the Hugging Face Hub

Pushing your models to the Hugging Face Hub ([hf.co](https://hf.co/))
allows you to share your models with others. On top of that, you can try
the models directly in the browser with the available widgets.

Before running this command, login to Hugging Face with `huggingface-cli login`.

You can specify either a `serialization_dir` or an `archive_path`, but using the
first option is recommended since the `serialization_dir` contains more useful
information such as metrics and TensorBoard traces.
"""
subparser = parser.add_parser(self.name, description=description, help=description)
subparser.set_defaults(func=push)

subparser.add_argument(
"-n",
"--repo_name",
required=True,
type=str,
help="Name of the repository",
)

model_dir_group = subparser.add_mutually_exclusive_group(required=True)
model_dir_group.add_argument(
"-s",
"--serialization_dir",
type=str,
help="directory in which to save the model and its logs are saved",
)

model_dir_group.add_argument(
"-a",
"--archive_path",
type=str,
help="full path to the zipped model, using serialization_dir instead is recommended",
)

subparser.add_argument(
"-o",
"--organization",
required=False,
type=str,
help="name of organization to which the model should be uploaded",
)

subparser.add_argument(
"-c",
"--commit_message",
required=False,
type=str,
default="Update repository",
help="Commit message to use for the push",
)

subparser.add_argument(
"-l",
"--local_repo_path",
required=False,
type=str,
default="hub",
help="local path for creating repo",
)

return subparser


def push(args: argparse.Namespace):
push_to_hf(
args.repo_name,
args.serialization_dir,
args.archive_path,
args.organization,
args.commit_message,
args.local_repo_path,
)
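
For reference, a minimal usage sketch of the new subcommand's plumbing (the repository name and serialization directory below are hypothetical, not part of this diff): the command simply forwards its parsed arguments to `allennlp.common.push_to_hf.push_to_hf`.

# Hedged sketch: calling the utility directly, equivalent to
# `allennlp push_to_hf -n my-allennlp-model -s output/my_model` after `huggingface-cli login`.
from allennlp.common.push_to_hf import push_to_hf

push_to_hf(
    repo_name="my-allennlp-model",        # hypothetical repository name
    serialization_dir="output/my_model",  # hypothetical output of `allennlp train`
    organization=None,                    # or an organization you belong to
    commit_message="Initial model upload",
    local_repo_path="hub",
)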
1 change: 1 addition & 0 deletions allennlp/common/__init__.py
@@ -5,3 +5,4 @@
from allennlp.common.tqdm import Tqdm
from allennlp.common.util import JsonDict
from allennlp.common.meta import Meta
from allennlp.common.push_to_hf import push_to_hf
180 changes: 180 additions & 0 deletions allennlp/common/push_to_hf.py
@@ -0,0 +1,180 @@
"""
Utilities for pushing models to the Hugging Face Hub ([hf.co](https://hf.co/)).
"""

import logging
import shutil
import tarfile
import tempfile
import zipfile
from os import PathLike
from pathlib import Path
from typing import Optional, Union

from huggingface_hub import HfApi, HfFolder, Repository

from allennlp.common.file_utils import cached_path

logger = logging.getLogger(__name__)

README_TEMPLATE = """---
tags:
- allennlp
---

# TODO: Fill this model card
"""


def _create_model_card(repo_dir: Path):
"""Creates a model card for the repository.

TODO: Add metrics to model-index
TODO: Use information from common model cards
"""
readme_path = repo_dir / "README.md"
prev_readme = ""
if readme_path.exists():
with readme_path.open("r", encoding="utf8") as f:
prev_readme = f.read()
with readme_path.open("w", encoding="utf-8") as f:
f.write(README_TEMPLATE)
f.write(prev_readme)


_ALLOWLIST_PATHS = ["vocabulary", "config.json", "weights.th", "best.th", "metrics.json", "log"]


def _copy_allowed_file(filepath: Path, dst_directory: Path):
"""
Copies a file or directory from the allowlist into a destination directory,
overwriting any existing file or directory with the same name.
"""
if filepath.name not in _ALLOWLIST_PATHS:
return

dst = dst_directory / filepath.name
if dst.is_dir():
shutil.rmtree(str(dst))
elif dst.is_file():
dst.unlink()
if filepath.is_dir():
shutil.copytree(filepath, dst)
elif filepath.is_file():
if filepath.name in ["best.th", "weights.th"]:
dst = dst_directory / "model.th"
shutil.copy(str(filepath), str(dst))


def push_to_hf(
repo_name: str,
serialization_dir: Optional[Union[str, PathLike]] = None,
archive_path: Optional[Union[str, PathLike]] = None,
organization: Optional[str] = None,
commit_message: str = "Update repository",
local_repo_path: Union[str, PathLike] = "hub",
use_auth_token: Union[bool, str] = True,
) -> str:
"""Pushes model and related files to the Hugging Face Hub ([hf.co](https://hf.co/))

# Parameters

repo_name: `str`
Name of the repository in the Hugging Face Hub.

serialization_dir : `Union[str, PathLike]`, optional (default = `None`)
Full path to a directory with the serialized model.

archive_path : `Union[str, PathLike]`, optional (default = `None`)
Full path to the zipped model (e.g. model/model.tar.gz). Use `serialization_dir` if possible.

organization : `Optional[str]`, optional (default = `None`)
Name of organization to which the model should be uploaded.

commit_message: `str` (default=`Update repository`)
Commit message to use for the push.

local_repo_path : `Union[str, Path]`, optional (default=`hub`)
Local directory where the repository will be saved.

use_auth_token : `Union[bool, str]`, optional (default = `True`)
The Hugging Face token, which can be extracted from `HfApi().login(username, password)`, is used to
authenticate against the Hugging Face Hub (useful from Google Colab, for instance). It is
automatically retrieved if you have run `huggingface-cli login` before.
"""

if serialization_dir is not None:
working_dir = Path(serialization_dir)
if archive_path is not None:
raise ValueError(
"serialization_dir and archive_path are mutually exclusive, please just use one."
)
if not working_dir.exists() or not working_dir.is_dir():
raise ValueError(
f"Can't find path: {serialization_dir}, please point"
"to a directory with the serialized model."
)
elif archive_path is not None:
working_dir = Path(archive_path)
if (
not working_dir.exists()
or not zipfile.is_zipfile(working_dir)
and not tarfile.is_tarfile(working_dir)
):
raise ValueError(
f"Can't find path: {archive_path}, please point to a .tar.gz archive"
"or to a directory with the serialized model."
)
else:
logger.info(
"Using the archive_path is discouraged. Using the serialization_dir "
"will also upload metrics and TensorBoard traces to the Hugging Face Hub."
)
else:
raise ValueError("please specify either serialization_dir or archive_path")

info_msg = f"Preparing repository '{use_auth_token}'"
if isinstance(use_auth_token, str):
huggingface_token = use_auth_token
elif use_auth_token:
huggingface_token = HfFolder.get_token()

# Create the repo (or clone its content if it's nonempty)
api = HfApi()
repo_url = api.create_repo(
name=repo_name,
token=huggingface_token,
organization=organization,
private=False,
exist_ok=True,
)

repo_local_path = Path(local_repo_path) / repo_name
repo = Repository(repo_local_path, clone_from=repo_url, use_auth_token=use_auth_token)
repo.git_pull(rebase=True)

# Model file should be tracked with Git LFS
repo.lfs_track(["*.th"])
info_msg = f"Preparing repository '{repo_name}'"
if organization is not None:
info_msg += f" ({organization})"
logger.info(info_msg)

# Copy allowed files from either the serialization directory or the
# extracted model archive
if serialization_dir is not None:
for filename in working_dir.iterdir():
_copy_allowed_file(Path(filename), repo_local_path)
else:
with tempfile.TemporaryDirectory() as temp_dir:
extracted_dir = Path(cached_path(working_dir, temp_dir, extract_archive=True))
for filename in extracted_dir.iterdir():
_copy_allowed_file(Path(filename), repo_local_path)

_create_model_card(repo_local_path)

logging.info(f"Pushing repo {repo_name} to the Hugging Face Hub")
repo.push_to_hub(commit_message=commit_message)

logging.info(f"View your model in {repo_url}")
return repo_url
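
As a complementary sketch (the archive path below is an assumption for illustration), the utility also accepts a packaged model archive, although the serialization directory is preferred because it additionally carries metrics and TensorBoard traces.

# Hedged sketch: pushing from a model archive rather than a serialization directory.
from allennlp.common import push_to_hf

repo_url = push_to_hf(
    repo_name="my-allennlp-model",       # hypothetical repository name
    archive_path="output/model.tar.gz",  # hypothetical archive; mutually exclusive with serialization_dir
    commit_message="Upload archived model",
)
print(f"Model pushed to {repo_url}")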
2 changes: 1 addition & 1 deletion setup.py
@@ -75,7 +75,7 @@
"termcolor==1.1.0",
"checklist==0.0.11",
"wandb>=0.10.0,<0.13.0",
"huggingface_hub>=0.0.8",
"huggingface_hub>=0.0.16",
"datasets>=1.2.1,<2.0",
"dill",
"base58",