Skip to content

Commit 4a8a574

Browse files
committed
remove extra test fixtures, add tests
1 parent 349e1ac commit 4a8a574

15 files changed

+92
-204
lines changed

cached_path/__init__.py

+1 -142
Original file line number | Diff line number | Diff line change
@@ -2,15 +2,11 @@
22
Utilities for working with the local dataset cache.
33
"""
44
import string
5-
import weakref
6-
from contextlib import contextmanager
75
import glob
8-
import io
96
import os
107
import logging
118
import tempfile
129
import json
13-
from abc import ABC
1410
from collections import defaultdict
1511
from dataclasses import dataclass, asdict
1612
from datetime import timedelta
@@ -26,15 +22,11 @@
2622
Callable,
2723
Set,
2824
List,
29-
Iterator,
30-
Iterable,
3125
Dict,
3226
NamedTuple,
33-
MutableMapping,
3427
)
3528
from hashlib import sha256
3629
from functools import wraps
37-
from weakref import WeakValueDictionary
3830
from zipfile import ZipFile, is_zipfile
3931
import tarfile
4032
import shutil
@@ -667,11 +659,10 @@ def _hf_hub_download(
667659

668660
if filename is not None:
669661
hub_url = hf_hub.hf_hub_url(repo_id=repo_id, filename=filename, revision=revision)
670-
# TODO: change library name?
671662
cache_path = str(
672663
hf_hub.cached_download(
673664
url=hub_url,
674-
library_name="allennlp",
665+
library_name="cached_path",
675666
library_version=VERSION,
676667
cache_dir=cache_dir,
677668
)
@@ -831,24 +822,6 @@ def get_file_extension(path: str, dot=True, lower: bool = True):
831822
return ext.lower() if lower else ext
832823

833824

834-
def open_compressed(
835-
filename: Union[str, PathLike], mode: str = "rt", encoding: Optional[str] = "UTF-8", **kwargs
836-
):
837-
if not isinstance(filename, str):
838-
filename = str(filename)
839-
open_fn: Callable = open
840-
841-
if filename.endswith(".gz"):
842-
import gzip
843-
844-
open_fn = gzip.open
845-
elif filename.endswith(".bz2"):
846-
import bz2
847-
848-
open_fn = bz2.open
849-
return open_fn(get_cached_path(filename), mode=mode, encoding=encoding, **kwargs)
850-
851-
852825
def _get_resource_size(path: str) -> int:
853826
"""
854827
Get the size of a file or directory.
@@ -867,117 +840,3 @@ def _get_resource_size(path: str) -> int:
867840
inodes.add(inode)
868841
total_size += os.path.getsize(fp)
869842
return total_size
870-
871-
872-
class _CacheEntry(NamedTuple):
873-
regular_files: List[_Meta]
874-
extraction_dirs: List[_Meta]
875-
876-
877-
def _find_entries(
878-
patterns: List[str] = None,
879-
cache_dir: Union[str, Path] = None,
880-
) -> Tuple[int, Dict[str, _CacheEntry]]:
881-
"""
882-
Find all cache entries, filtering ones that don't match any of the glob patterns given.
883-
884-
Returns the total size of the matching entries and a mapping of resource name to metadata.
885-
886-
The values in the returned mapping are tuples because we separate meta entries that
887-
correspond to extraction directories vs regular cache entries.
888-
"""
889-
cache_dir = os.path.expanduser(cache_dir or CACHE_DIRECTORY)
890-
891-
total_size: int = 0
892-
cache_entries: Dict[str, _CacheEntry] = defaultdict(lambda: _CacheEntry([], []))
893-
for meta_path in glob.glob(str(cache_dir) + "/*.json"):
894-
meta = _Meta.from_path(meta_path)
895-
if patterns and not any(fnmatch(meta.resource, p) for p in patterns):
896-
continue
897-
if meta.extraction_dir:
898-
cache_entries[meta.resource].extraction_dirs.append(meta)
899-
else:
900-
cache_entries[meta.resource].regular_files.append(meta)
901-
total_size += meta.size
902-
903-
# Sort entries for each resource by creation time, newest first.
904-
for entry in cache_entries.values():
905-
entry.regular_files.sort(key=lambda meta: meta.creation_time, reverse=True)
906-
entry.extraction_dirs.sort(key=lambda meta: meta.creation_time, reverse=True)
907-
908-
return total_size, cache_entries
909-
910-
911-
def remove_cache_entries(patterns: List[str], cache_dir: Union[str, Path] = None) -> int:
912-
"""
913-
Remove cache entries matching the given patterns.
914-
915-
Returns the total reclaimed space in bytes.
916-
"""
917-
total_size, cache_entries = _find_entries(patterns=patterns, cache_dir=cache_dir)
918-
for resource, entry in cache_entries.items():
919-
for meta in entry.regular_files:
920-
logger.info("Removing cached version of %s at %s", resource, meta.cached_path)
921-
os.remove(meta.cached_path)
922-
if os.path.exists(meta.cached_path + ".lock"):
923-
os.remove(meta.cached_path + ".lock")
924-
os.remove(meta.cached_path + ".json")
925-
for meta in entry.extraction_dirs:
926-
logger.info("Removing extracted version of %s at %s", resource, meta.cached_path)
927-
shutil.rmtree(meta.cached_path)
928-
if os.path.exists(meta.cached_path + ".lock"):
929-
os.remove(meta.cached_path + ".lock")
930-
os.remove(meta.cached_path + ".json")
931-
return total_size
932-
933-
934-
def inspect_cache(patterns: List[str] = None, cache_dir: Union[str, Path] = None):
935-
"""
936-
Print out useful information about the cache directory.
937-
"""
938-
from allennlp.common.util import format_timedelta, format_size
939-
940-
cache_dir = os.path.expanduser(cache_dir or CACHE_DIRECTORY)
941-
942-
# Gather cache entries by resource.
943-
total_size, cache_entries = _find_entries(patterns=patterns, cache_dir=cache_dir)
944-
945-
if patterns:
946-
print(f"Cached resources matching {patterns}:")
947-
else:
948-
print("Cached resources:")
949-
950-
for resource, entry in sorted(
951-
cache_entries.items(),
952-
# Sort by creation time, latest first.
953-
key=lambda x: max(
954-
0 if not x[1][0] else x[1][0][0].creation_time,
955-
0 if not x[1][1] else x[1][1][0].creation_time,
956-
),
957-
reverse=True,
958-
):
959-
print("\n-", resource)
960-
if entry.regular_files:
961-
td = timedelta(seconds=time.time() - entry.regular_files[0].creation_time)
962-
n_versions = len(entry.regular_files)
963-
size = entry.regular_files[0].size
964-
print(
965-
f" {n_versions} {'versions' if n_versions > 1 else 'version'} cached, "
966-
f"latest {format_size(size)} from {format_timedelta(td)} ago"
967-
)
968-
if entry.extraction_dirs:
969-
td = timedelta(seconds=time.time() - entry.extraction_dirs[0].creation_time)
970-
n_versions = len(entry.extraction_dirs)
971-
size = entry.extraction_dirs[0].size
972-
print(
973-
f" {n_versions} {'versions' if n_versions > 1 else 'version'} extracted, "
974-
f"latest {format_size(size)} from {format_timedelta(td)} ago"
975-
)
976-
print(f"\nTotal size: {format_size(total_size)}")
977-
978-
979-
SAFE_FILENAME_CHARS = frozenset("-_.%s%s" % (string.ascii_letters, string.digits))
980-
981-
982-
def filename_is_safe(filename: str) -> bool:
983-
return all(c in SAFE_FILENAME_CHARS for c in filename)

cached_path/testing.py

+35
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,35 @@
1+
import logging
2+
import os
3+
import pathlib
4+
import shutil
5+
import tempfile
6+
7+
TEST_DIR = tempfile.mkdtemp(prefix="cached_path_tests")
8+
9+
10+
class BaseTestClass:
11+
"""
12+
A custom testing class that disables some of the more verbose
13+
logging and that creates and destroys a temp directory as a test fixture.
14+
"""
15+
16+
PROJECT_ROOT = (pathlib.Path(__file__).parent / "..").resolve()
17+
MODULE_ROOT = PROJECT_ROOT / "cached_path"
18+
TOOLS_ROOT = MODULE_ROOT / "tools"
19+
TESTS_ROOT = PROJECT_ROOT / "tests"
20+
FIXTURES_ROOT = PROJECT_ROOT / "test_fixtures"
21+
22+
def setup_method(self):
23+
logging.basicConfig(
24+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", level=logging.DEBUG
25+
)
26+
# Disabling some of the more verbose logging statements that typically aren't very helpful
27+
# in tests.
28+
logging.getLogger("urllib3.connectionpool").disabled = True
29+
30+
self.TEST_DIR = pathlib.Path(TEST_DIR)
31+
32+
os.makedirs(self.TEST_DIR, exist_ok=True)
33+
34+
def teardown_method(self):
35+
shutil.rmtree(self.TEST_DIR)

setup.py

+2 -4
Original file line number | Diff line number | Diff line change
@@ -42,8 +42,8 @@ def fix_url_dependencies(req: str) -> str:
4242
"Programming Language :: Python :: 3",
4343
"Topic :: Scientific/Engineering :: Artificial Intelligence",
4444
],
45-
keywords="",
46-
url="",
45+
keywords="allennlp cached_path file utils",
46+
url="https://github.com/allenai/cached_path",
4747
author="Allen Institute for Artificial Intelligence",
4848
author_email="contact@allenai.org",
4949
license="Apache",
@@ -55,8 +55,6 @@ def fix_url_dependencies(req: str) -> str:
5555
"tests",
5656
"test_fixtures",
5757
"test_fixtures.*",
58-
"benchmarks",
59-
"benchmarks.*",
6058
]
6159
),
6260
install_requires=install_requirements,

test_fixtures/embeddings/fake_embeddings.5d.txt

-15
This file was deleted.
-295 Bytes
Binary file not shown.
-332 Bytes
Binary file not shown.
-316 Bytes
Binary file not shown.
Binary file not shown.
-485 Bytes
Binary file not shown.
Binary file not shown.
-472 Bytes
Binary file not shown.
-1.08 KB
Binary file not shown.
-5.52 KB
Binary file not shown.
Binary file not shown.

tests/cached_path_test.py

+54 -43
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,11 @@
11
from collections import Counter
22
import json
3+
import os
4+
import shutil
35
import time
6+
import pathlib
47

8+
from filelock import Timeout
59
import pytest
610
import responses
711
from requests.exceptions import ConnectionError, HTTPError
@@ -15,50 +19,11 @@
1519
get_cached_path,
1620
_split_s3_path,
1721
_split_gcs_path,
18-
open_compressed,
1922
CacheFile,
2023
_Meta,
21-
_find_entries,
22-
inspect_cache,
23-
remove_cache_entries,
24-
# LocalCacheResource,
2524
)
2625

27-
import logging
28-
import os
29-
import pathlib
30-
import shutil
31-
import tempfile
32-
33-
TEST_DIR = tempfile.mkdtemp(prefix="cached_path_tests")
34-
35-
36-
class BaseTestCase:
37-
"""
38-
A custom testing class that disables some of the more verbose AllenNLP
39-
logging and that creates and destroys a temp directory as a test fixture.
40-
"""
41-
42-
PROJECT_ROOT = (pathlib.Path(__file__).parent / "..").resolve()
43-
MODULE_ROOT = PROJECT_ROOT / "cached_path"
44-
TOOLS_ROOT = MODULE_ROOT / "tools"
45-
TESTS_ROOT = PROJECT_ROOT / "tests"
46-
FIXTURES_ROOT = PROJECT_ROOT / "test_fixtures"
47-
48-
def setup_method(self):
49-
logging.basicConfig(
50-
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", level=logging.DEBUG
51-
)
52-
# Disabling some of the more verbose logging statements that typically aren't very helpful
53-
# in tests.
54-
logging.getLogger("urllib3.connectionpool").disabled = True
55-
56-
self.TEST_DIR = pathlib.Path(TEST_DIR)
57-
58-
os.makedirs(self.TEST_DIR, exist_ok=True)
59-
60-
def teardown_method(self):
61-
shutil.rmtree(self.TEST_DIR)
26+
from cached_path.testing import BaseTestClass
6227

6328

6429
def set_up_glove(url: str, byt: bytes, change_etag_every: int = 1000):
@@ -94,7 +59,53 @@ def head_callback(_):
9459
responses.add_callback(responses.HEAD, url, callback=head_callback)
9560

9661

97-
class TestFileUtils(BaseTestCase):
62+
class TestFileLock(BaseTestClass):
63+
def setup_method(self):
64+
super().setup_method()
65+
66+
# Set up a regular lock and a read-only lock.
67+
open(self.TEST_DIR / "lock", "a").close()
68+
open(self.TEST_DIR / "read_only_lock", "a").close()
69+
os.chmod(self.TEST_DIR / "read_only_lock", 0o555)
70+
71+
# Also set up a read-only directory.
72+
os.mkdir(self.TEST_DIR / "read_only_dir", 0o555)
73+
74+
def test_locking(self):
75+
with FileLock(self.TEST_DIR / "lock"):
76+
# Trying to acquire the lock again should fail.
77+
with pytest.raises(Timeout):
78+
with FileLock(self.TEST_DIR / "lock", timeout=0.1):
79+
pass
80+
81+
# Trying to acquire a lock when lacking write permissions on the file should fail.
82+
with pytest.raises(PermissionError):
83+
with FileLock(self.TEST_DIR / "read_only_lock"):
84+
pass
85+
86+
# But this should only issue a warning if we set the `read_only_ok` flag to `True`.
87+
with pytest.warns(UserWarning, match="Lacking permissions"):
88+
with FileLock(self.TEST_DIR / "read_only_lock", read_only_ok=True):
89+
pass
90+
91+
# However this should always fail when we lack write permissions and the file lock
92+
# doesn't exist yet.
93+
with pytest.raises(PermissionError):
94+
with FileLock(self.TEST_DIR / "read_only_dir" / "lock", read_only_ok=True):
95+
pass
96+
97+
98+
class TestCacheFile(BaseTestClass):
99+
def test_temp_file_removed_on_error(self):
100+
cache_filename = self.TEST_DIR / "cache_file"
101+
with pytest.raises(IOError, match="I made this up"):
102+
with CacheFile(cache_filename) as handle:
103+
raise IOError("I made this up")
104+
assert not os.path.exists(handle.name)
105+
assert not os.path.exists(cache_filename)
106+
107+
108+
class TestFileUtils(BaseTestClass):
98109
def setup_method(self):
99110
super().setup_method()
100111
self.glove_file = self.FIXTURES_ROOT / "embeddings/glove.6B.100d.sample.txt.gz"
@@ -333,7 +344,7 @@ def test_extract_with_external_symlink(self):
333344
get_cached_path(dangerous_file, extract_archive=True)
334345

335346

336-
class TestCachedPathWithArchive(BaseTestCase):
347+
class TestCachedPathWithArchive(BaseTestClass):
337348
def setup_method(self):
338349
super().setup_method()
339350
self.tar_file = self.TEST_DIR / "utf-8.tar.gz"
@@ -411,7 +422,7 @@ def test_cached_path_extract_remote_zip(self):
411422
self.check_extracted(extracted)
412423

413424

414-
class TestHFHubDownload(BaseTestCase):
425+
class TestHFHubDownload(BaseTestClass):
415426
def test_cached_download_no_user_or_org(self):
416427
path = get_cached_path("hf://t5-small/config.json", cache_dir=self.TEST_DIR)
417428
assert os.path.isfile(path)

0 commit comments

Comments
 (0)