This repository was archived by the owner on Dec 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Proposal of utility function + command to push models to HF Hub #5370
Merged
Merged
Changes from 4 commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
0d7ea93
Implement initial push_to_hf utility + command
osanseviero 907c423
Improve command description
osanseviero 31d8238
Merge branch 'main' into push_to_hf
epwalsh 9f35a05
Merge branch 'main' into push_to_hf
epwalsh a899c81
Split serialization_dir and archive_path use cases
osanseviero 83d322d
Change to PathLike
osanseviero c503ffd
Fix
osanseviero 7cec4a4
Merge branch 'main' into push_to_hf
dirkgr ea0734e
dummy fix
osanseviero 5b9e839
Add tests for pushing to Hub
osanseviero 0131481
Merge branch 'main' into push_to_hf
dirkgr 88842e2
Update allennlp/commands/push_to_hf.py
osanseviero eb945a0
Changelog
osanseviero ad07622
Merge branch 'main' into push_to_hf
epwalsh b5a0be0
Fix GA workflow
osanseviero 842aaf4
invalidate cache
epwalsh 50fa431
Patch staging in tests
osanseviero 4c303ed
Merge branch 'push_to_hf' of https://github.com/osanseviero/allennlp …
osanseviero 559a22d
Style and remove diffs
osanseviero 84273ca
Merge branch 'main' into push_to_hf
epwalsh cc36028
fix changelog
epwalsh 0d4e1f1
replace '_' with '-' for consistency
epwalsh File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
""" | ||
The `push_to_hf` subcommand can be used to push a trained model to the | ||
Hugging Face Hub ([hf.co](https://hf.co/)). | ||
""" | ||
|
||
import argparse | ||
import logging | ||
|
||
from overrides import overrides | ||
|
||
from allennlp.commands.subcommand import Subcommand | ||
from allennlp.common.push_to_hf import push_to_hf | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@Subcommand.register("push_to_hf")
class PushToHf(Subcommand):
    """
    The `push_to_hf` subcommand: collects the command-line options needed to
    upload a trained model to the Hugging Face Hub and dispatches to `push`.
    """

    @overrides
    def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
        """
        Adds the `push_to_hf` subparser and its options to `parser`.

        # Parameters

        parser : `argparse._SubParsersAction`
            The sub-parsers action of the main command to which this
            subcommand is attached.

        # Returns

        `argparse.ArgumentParser`
            The newly created subparser, with `func` defaulting to `push`.
        """
        description = """Push a model to the Hugging Face Hub

        Pushing your models to the Hugging Face Hub ([hf.co](https://hf.co/))
        allows you to share your models with others. On top of that, you can try
        the models directly in the browser with the available widgets.

        Before running this command, login to Hugging Face with huggingface-cli login
        """
        subparser = parser.add_parser(self.name, description=description, help=description)
        subparser.set_defaults(func=push)

        subparser.add_argument(
            "-a",
            "--archive_path",
            required=True,
            type=str,
            help="full path to the zipped model or to a directory with the serialized model.",
        )

        # This option is required, so it deliberately carries no `default`:
        # argparse would never use one, and the previous default was just a
        # copy of the help text.
        subparser.add_argument(
            "-n",
            "--repo_name",
            required=True,
            type=str,
            help="Name of the repository",
        )

        subparser.add_argument(
            "-o",
            "--organization",
            required=False,
            type=str,
            help="name of organization to which the model should be uploaded",
        )

        subparser.add_argument(
            "-c",
            "--commit_message",
            required=False,
            type=str,
            default="Update repository",
            help="Commit message to use for the push",
        )

        subparser.add_argument(
            "-l",
            "--local_repo_path",
            required=False,
            type=str,
            default="hub",
            help="local path for creating repo",
        )

        return subparser
|
||
|
||
def push(args: argparse.Namespace):
    """
    Entry point of the `push_to_hf` subcommand.

    Forwards the parsed command-line options to `push_to_hf`.

    # Parameters

    args : `argparse.Namespace`
        Parsed arguments; must provide `archive_path`, `repo_name`,
        `organization`, `commit_message` and `local_repo_path`.
    """
    push_to_hf(
        archive_path=args.archive_path,
        repo_name=args.repo_name,
        organization=args.organization,
        commit_message=args.commit_message,
        local_repo_path=args.local_repo_path,
    )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
""" | ||
Utilities for pushing models to the Hugging Face Hub ([hf.co](https://hf.co/)). | ||
""" | ||
|
||
import logging | ||
import sys | ||
from typing import Optional, Union | ||
from pathlib import Path | ||
|
||
from allennlp.common.file_utils import cached_path | ||
import shutil | ||
|
||
import zipfile | ||
import tarfile | ||
import tempfile | ||
|
||
from huggingface_hub import Repository, HfApi, HfFolder | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# YAML metadata header that tags the repository as an AllenNLP model.
_MODEL_CARD_FRONT_MATTER = """---
tags:
- allennlp
---
"""

# Full placeholder model card: the metadata header followed by a TODO heading.
README_TEMPLATE = (
    _MODEL_CARD_FRONT_MATTER
    + """
# TODO: Fill this model card
"""
)


def _create_model_card(repo_dir: Path):
    """Creates a model card for the repository.

    Writes `README.md` inside `repo_dir`. If a README already exists, its
    content is kept and `README_TEMPLATE` is placed in front of it. If the
    existing README already begins with the generated metadata header (for
    example because this function ran on a previous push), the file is left
    untouched so that repeated pushes do not stack duplicate headers.

    # Parameters

    repo_dir : `Path`
        Directory in which `README.md` is created or updated.

    TODO: Add metrics to model-index
    TODO: Use information from common model cards
    """
    readme_path = repo_dir / "README.md"
    prev_readme = ""
    if readme_path.exists():
        with readme_path.open("r", encoding="utf-8") as f:
            prev_readme = f.read()

    # Already carries our front matter: prepending again would duplicate it.
    if prev_readme.startswith(_MODEL_CARD_FRONT_MATTER):
        return

    with readme_path.open("w", encoding="utf-8") as f:
        f.write(README_TEMPLATE)
        f.write(prev_readme)
|
||
|
||
# Names of the files/directories that may be uploaded to the Hub.
_ALLOWLIST_PATHS = ["vocabulary", "config.json", "weights.th", "best.th", "metrics.json", "log"]


def _copy_allowed_file(filepath: Path, dst_directory: Path):
    """
    Copies `filepath` into `dst_directory` when its name is in the allowlist,
    replacing any same-named file or directory already present there.
    Entries whose name is not allowlisted are ignored. A source file named
    `best.th` is written to the destination as `model.th`.
    """
    entry_name = filepath.name
    if entry_name in _ALLOWLIST_PATHS:
        target = dst_directory / entry_name

        # Remove whatever currently occupies the same-named destination.
        if target.is_dir():
            shutil.rmtree(str(target))
        elif target.is_file():
            target.unlink()

        if filepath.is_dir():
            shutil.copytree(filepath, target)
        elif filepath.is_file():
            final_target = dst_directory / "model.th" if entry_name == "best.th" else target
            shutil.copy(str(filepath), str(final_target))
|
||
|
||
def push_to_hf(
    archive_path: Union[str, Path],
    repo_name: str,
    organization: Optional[str] = None,
    commit_message: str = "Update repository",
    local_repo_path: Union[str, Path] = "hub",
):
    """Pushes model and related files to the Hugging Face Hub ([hf.co](https://hf.co/))

    The input path is validated first. The Hub repository is then created (or
    reused), cloned under `local_repo_path`, filled with the allowlisted model
    files and a model card, and pushed.

    # Parameters

    archive_path : `Union[str, Path]`
        Full path to the zipped model (e.g. model/model.tar.gz) or to a directory with the
        serialized model.

    repo_name: `str`
        Name of the repository in the Hugging Face Hub.

    organization : `Optional[str]`, optional (default = `None`)
        Name of organization to which the model should be uploaded.

    commit_message: `str` (default=`Update repository`)
        Commit message to use for the push.

    local_repo_path : `Union[str, Path]`, optional (default=`hub`)
        Local directory where the repository will be saved.

    # Raises

    `SystemExit`
        With exit code 1 when `archive_path` does not exist, or when it is
        neither a directory nor a zip/tar archive.
    """
    archive_path = Path(archive_path)

    if not archive_path.exists():
        logger.error(
            f"Can't find archive path: {archive_path}, please "
            "point to either a .tar.gz archive or to a directory "
            "with the serialized model."
        )
        sys.exit(1)

    # Validate the kind of input before any repository is created, so an
    # unsupported file cannot result in a repository holding only a README.
    is_serialization_dir = archive_path.is_dir()
    if not is_serialization_dir and not (
        zipfile.is_zipfile(archive_path) or tarfile.is_tarfile(archive_path)
    ):
        logger.error(
            f"Unsupported archive path: {archive_path}, please "
            "point to either a .tar.gz archive or to a directory "
            "with the serialized model."
        )
        sys.exit(1)

    # Create the repo (or clone its content if it's nonempty)
    api = HfApi()
    repo_url = api.create_repo(
        name=repo_name,
        token=HfFolder.get_token(),
        organization=organization,
        private=False,
        exist_ok=True,
    )

    repo_local_path = Path(local_repo_path) / repo_name
    repo = Repository(repo_local_path, clone_from=repo_url)
    repo.git_pull(rebase=True)

    # Model file should be tracked with Git LFS
    repo.lfs_track(["*.th"])
    info_msg = f"Preparing repository '{repo_name}'"
    if organization is not None:
        info_msg += f" ({organization})"
    logger.info(info_msg)

    # Extract information from either serializable directory or a
    # .tar.gz file
    if is_serialization_dir:
        for entry in archive_path.iterdir():
            _copy_allowed_file(entry, repo_local_path)
    else:
        # The archive is unpacked into a temporary directory that is removed
        # once the allowlisted files have been copied into the local repo.
        with tempfile.TemporaryDirectory() as temp_dir:
            extracted_dir = Path(cached_path(archive_path, temp_dir, extract_archive=True))
            for entry in extracted_dir.iterdir():
                _copy_allowed_file(entry, repo_local_path)

    _create_model_card(repo_local_path)

    logger.info(f"Pushing repo {repo_name} to the Hugging Face Hub")
    commit_url = repo.push_to_hub(commit_message=commit_message)

    # Keep only the part before "/commit/" (the repository page). Indexing the
    # split avoids a ValueError when the separator is absent or repeated.
    repo_page_url = commit_url.split("/commit/")[0]
    logger.info(f"View your model in {repo_page_url}")
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.