diff --git a/CHANGELOG.md b/CHANGELOG.md index 6cb3e3c4ac5..01cf80ba0a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added support for the HuggingFace Hub as an alternative way to handle loading files. Hub downloads should be made through the `hf://` URL scheme. - Add new dimension to the `interpret` module: influence functions via the `InfluenceInterpreter` base class, along with a concrete implementation: `SimpleInfluence`. - Added a `quiet` parameter to the `MultiProcessDataLoading` that disables `Tqdm` progress bars. - The test for distributed metrics now takes a parameter specifying how often you want to run it. @@ -26,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- Ported the following Huggingface `LambdaLR`-based schedulers: `ConstantLearningRateScheduler`, `ConstantWithWarmupLearningRateScheduler`, `CosineWithWarmupLearningRateScheduler`, `CosineHardRestartsWithWarmupLearningRateScheduler`. +- Ported the following HuggingFace `LambdaLR`-based schedulers: `ConstantLearningRateScheduler`, `ConstantWithWarmupLearningRateScheduler`, `CosineWithWarmupLearningRateScheduler`, `CosineHardRestartsWithWarmupLearningRateScheduler`. - Added new `sub_token_mode` parameter to `pretrained_transformer_mismatched_embedder` class to support first sub-token embedding - Added a way to run a multi task model with a dataset reader as part of `allennlp predict`. - Added new `eval_mode` in `PretrainedTransformerEmbedder`. If it is set to `True`, the transformer is _always_ run in evaluation mode, which, e.g., disables dropout and does not update batch normalization statistics. 
diff --git a/allennlp/common/file_utils.py b/allennlp/common/file_utils.py index 97fa3bb8b97..4eba73f32c9 100644 --- a/allennlp/common/file_utils.py +++ b/allennlp/common/file_utils.py @@ -54,6 +54,8 @@ from requests.packages.urllib3.util.retry import Retry import lmdb from torch import Tensor +from huggingface_hub import hf_hub_url, cached_download, snapshot_download +from allennlp.version import VERSION from allennlp.common.tqdm import Tqdm @@ -233,9 +235,46 @@ def cached_path( cache_dir = os.path.expanduser(cache_dir) os.makedirs(cache_dir, exist_ok=True) + extraction_path: Optional[str] = None + if not isinstance(url_or_filename, str): url_or_filename = str(url_or_filename) + if url_or_filename.startswith("hf://"): + # Remove the hf:// prefix + identifier = url_or_filename[5:] + + filename: Optional[str] + if len(identifier.split("/")) > 2: + filename = "/".join(identifier.split("/")[2:]) + model_identifier = "/".join(identifier.split("/")[:2]) + else: + filename = None + model_identifier = identifier + + revision: Optional[str] + if "@" in model_identifier: + repo_id = model_identifier.split("@")[0] + revision = model_identifier.split("@")[1] + else: + repo_id = model_identifier + revision = None + + if filename is not None: + url = hf_hub_url(repo_id=repo_id, filename=filename, revision=revision) + url_or_filename = str( + cached_download( + url=url, + library_name="allennlp", + library_version=VERSION, + cache_dir=cache_dir, + ) + ) + else: + extraction_path = snapshot_download( + repo_id, revision=revision, cache_dir=cache_dir + ) + file_path: str # If we're using the /a/b/foo.zip!c/d/file.txt syntax, handle it here. 
@@ -261,9 +300,7 @@ def cached_path( parsed = urlparse(url_or_filename) - extraction_path: Optional[str] = None - - if parsed.scheme in ("http", "https", "s3"): + if parsed.scheme in ("http", "https", "s3") and extraction_path is None: # URL, so get it from the cache (downloading if necessary) file_path = get_from_cache(url_or_filename, cache_dir) @@ -272,7 +309,7 @@ def cached_path( # For example ~/.allennlp/cache/234234.21341 -> ~/.allennlp/cache/234234.21341-extracted extraction_path = file_path + "-extracted" - else: + elif extraction_path is None: url_or_filename = os.path.expanduser(url_or_filename) if os.path.exists(url_or_filename): diff --git a/setup.py b/setup.py index 7bfb949189a..22d600c6806 100644 --- a/setup.py +++ b/setup.py @@ -72,6 +72,7 @@ "lmdb", "more-itertools", "wandb>=0.10.0,<0.11.0", + "huggingface_hub>=0.0.8", ], entry_points={"console_scripts": ["allennlp=allennlp.__main__:run"]}, include_package_data=True, diff --git a/tests/common/file_utils_test.py b/tests/common/file_utils_test.py index 9e56781a196..c8391fa72aa 100644 --- a/tests/common/file_utils_test.py +++ b/tests/common/file_utils_test.py @@ -28,7 +28,10 @@ LocalCacheResource, TensorCache, ) +from allennlp.common import Params +from allennlp.modules.token_embedders import ElmoTokenEmbedder from allennlp.common.testing import AllenNlpTestCase +from allennlp.predictors import Predictor def set_up_glove(url: str, byt: bytes, change_etag_every: int = 1000): @@ -563,3 +566,25 @@ def test_tensor_cache(self): with pytest.warns(UserWarning, match="cache will be read-only"): cache = TensorCache(self.TEST_DIR / "cache") assert cache.read_only + + +class TestHFHubDownload(AllenNlpTestCase): + def test_cached_download(self): + params = Params( + { + "options_file": "hf://lysandre/test-elmo-tiny/options.json", + "weight_file": "hf://lysandre/test-elmo-tiny/lm_weights.hdf5", + } + ) + embedding_layer = ElmoTokenEmbedder.from_params(vocab=None, params=params) + + assert isinstance( + 
embedding_layer, ElmoTokenEmbedder + ), "Embedding layer badly instantiated from HF Hub." + assert ( + embedding_layer.get_output_dim() == 32 + ), "Embedding layer badly instantiated from HF Hub." + + def test_snapshot_download(self): + predictor = Predictor.from_path("hf://lysandre/test-simple-tagger-tiny") + assert predictor._dataset_reader._token_indexers["tokens"].namespace == "test_tokens"