Skip to content

Commit 9665857

Browse files
authored
Switch to rich for progress (#94)
* Switch to `rich` for progress
* CHANGELOG
* fix
* more fixes
* fix
* allow sized progress bars
* Clean up
* clean up
* fix
* fix
* mypy is awful
* Fix up the wrapper file class
1 parent c08c9e7 commit 9665857

20 files changed

+315
-214
lines changed

CHANGELOG.md

+13
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## Unreleased
99

10+
### Added
11+
12+
- Added `quiet` parameter to `cached_path()` for turning off progress displays, and `progress` parameter for customizing displays.
13+
- Added `SchemeClient.get_size()` method.
14+
15+
### Changed
16+
17+
- Switched to `rich` for progress displays, removed dependency on `tqdm`.
18+
19+
### Removed
20+
21+
- Removed `file_friendly_logging()` function.
22+
1023
## [v1.1.2](https://github.com/allenai/cached_path/releases/tag/v1.1.2) - 2022-04-08
1124

1225
## [v1.1.1](https://github.com/allenai/cached_path/releases/tag/v1.1.1) - 2022-03-25

cached_path/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
"""
1212

1313
from ._cached_path import cached_path
14-
from .common import file_friendly_logging, get_cache_dir, set_cache_dir
14+
from .common import get_cache_dir, set_cache_dir
15+
from .progress import get_download_progress
1516
from .schemes import SchemeClient, add_scheme_client
1617
from .util import (
1718
check_tarfile,

cached_path/_cached_path.py

+51-5
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import tarfile
55
import tempfile
66
from pathlib import Path
7-
from typing import Optional, Tuple
7+
from typing import TYPE_CHECKING, Optional, Tuple
88
from urllib.parse import urlparse
99
from zipfile import ZipFile, is_zipfile
1010

@@ -21,6 +21,9 @@
2121
resource_to_filename,
2222
)
2323

24+
if TYPE_CHECKING:
25+
from rich.progress import Progress
26+
2427
logger = logging.getLogger("cached_path")
2528

2629

@@ -29,6 +32,8 @@ def cached_path(
2932
cache_dir: Optional[PathOrStr] = None,
3033
extract_archive: bool = False,
3134
force_extract: bool = False,
35+
quiet: bool = False,
36+
progress: Optional["Progress"] = None,
3237
) -> Path:
3338
"""
3439
Given something that might be a URL or local path, determine which.
@@ -97,6 +102,13 @@ def cached_path(
97102
Use this flag with caution! This can lead to race conditions if used
98103
from multiple processes on the same file.
99104
105+
quiet :
106+
If ``True``, progress displays won't be printed.
107+
108+
progress :
109+
A custom progress display to use. If not set and ``quiet=False``, a default display
110+
from :func:`~cached_path.get_download_progress()` will be used.
111+
100112
Returns
101113
-------
102114
:class:`pathlib.Path`
@@ -133,7 +145,14 @@ def cached_path(
133145
file_name = url_or_filename[exclamation_index + 1 :]
134146

135147
# Call 'cached_path' recursively now to get the local path to the archive itself.
136-
cached_archive_path = cached_path(archive_path, cache_dir, True, force_extract)
148+
cached_archive_path = cached_path(
149+
archive_path,
150+
cache_dir=cache_dir,
151+
extract_archive=True,
152+
force_extract=force_extract,
153+
quiet=quiet,
154+
progress=progress,
155+
)
137156
if not cached_archive_path.is_dir():
138157
raise ValueError(
139158
f"{url_or_filename} uses the ! syntax, but does not specify an archive file."
@@ -151,7 +170,7 @@ def cached_path(
151170

152171
if parsed.scheme in get_supported_schemes():
153172
# URL, so get it from the cache (downloading if necessary)
154-
file_path, etag = get_from_cache(url_or_filename, cache_dir)
173+
file_path, etag = get_from_cache(url_or_filename, cache_dir, quiet=quiet, progress=progress)
155174

156175
if extract_archive and (is_zipfile(file_path) or tarfile.is_tarfile(file_path)):
157176
# This is the path the file should be extracted to.
@@ -243,7 +262,12 @@ def cached_path(
243262
return file_path
244263

245264

246-
def get_from_cache(url: str, cache_dir: Optional[PathOrStr] = None) -> Tuple[Path, Optional[str]]:
265+
def get_from_cache(
266+
url: str,
267+
cache_dir: Optional[PathOrStr] = None,
268+
quiet: bool = False,
269+
progress: Optional["Progress"] = None,
270+
) -> Tuple[Path, Optional[str]]:
247271
"""
248272
Given a URL, look for the corresponding dataset in the local cache.
249273
If it's not there, download it. Then return the path to the cached file and the ETag.
@@ -301,9 +325,31 @@ def get_from_cache(url: str, cache_dir: Optional[PathOrStr] = None) -> Tuple[Pat
301325
if os.path.exists(cache_path):
302326
logger.info("cache of %s is up-to-date", url)
303327
else:
328+
size = client.get_size()
304329
with CacheFile(cache_path) as cache_file:
305330
logger.info("%s not found in cache, downloading to %s", url, cache_path)
306-
client.get_resource(cache_file)
331+
332+
from .progress import BufferedWriterWithProgress, get_download_progress
333+
334+
start_and_cleanup = progress is None
335+
progress = progress or get_download_progress(quiet=quiet)
336+
337+
if start_and_cleanup:
338+
progress.start()
339+
340+
try:
341+
display_url = url if len(url) <= 50 else f"{url[:49]}\N{horizontal ellipsis}"
342+
task_id = progress.add_task(f"Downloading [cyan i]{display_url}[/]", total=size)
343+
writer_with_progress = BufferedWriterWithProgress(cache_file, progress, task_id)
344+
client.get_resource(writer_with_progress)
345+
progress.update(
346+
task_id,
347+
total=writer_with_progress.total_written,
348+
completed=writer_with_progress.total_written,
349+
)
350+
finally:
351+
if start_and_cleanup:
352+
progress.stop()
307353

308354
logger.debug("creating metadata file for %s", cache_path)
309355
meta = Meta.new(

cached_path/common.py

-21
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,6 @@
1414
"""
1515

1616

17-
def _parse_bool(value: Union[bool, str]) -> bool:
18-
if isinstance(value, bool):
19-
return value
20-
if value in {"1", "true", "True", "TRUE"}:
21-
return True
22-
return False
23-
24-
25-
FILE_FRIENDLY_LOGGING: bool = _parse_bool(os.environ.get("FILE_FRIENDLY_LOGGING", False))
26-
27-
2817
def _split_cloud_path(url: str, provider: str) -> Tuple[str, str]:
2918
"""Split a full s3 path into the bucket name and path."""
3019
parsed = urlparse(url)
@@ -51,13 +40,3 @@ def get_cache_dir() -> Path:
5140
Get the global default cache directory.
5241
"""
5342
return Path(CACHE_DIRECTORY)
54-
55-
56-
def file_friendly_logging(on: bool = True) -> None:
57-
"""
58-
Turn on (or off) file-friendly logging globally.
59-
60-
You can also control this through the environment variable `FILE_FRIENDLY_LOGGING`.
61-
"""
62-
global FILE_FRIENDLY_LOGGING
63-
FILE_FRIENDLY_LOGGING = on

cached_path/progress.py

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import io
2+
from typing import List, Optional
3+
4+
from rich.progress import BarColumn, DownloadColumn, Progress, TaskID, TimeElapsedColumn
5+
6+
7+
class BufferedWriterWithProgress(io.BufferedWriter):
    """
    A wrapper around an ``io.BufferedWriter`` that reports bytes written
    to a ``rich`` progress display.

    Every successful :meth:`write` advances the associated progress task by
    the number of bytes actually written, and :meth:`seek` re-syncs the
    task's ``completed`` count with the new file position. All other I/O
    operations are delegated verbatim to the wrapped ``handle``.

    .. note::
        This subclasses ``io.BufferedWriter`` so it passes ``isinstance``
        checks, but it deliberately never calls ``super().__init__()`` —
        every operation is forwarded to ``handle`` instead.
    """

    def __init__(self, handle: io.BufferedWriter, progress: "Progress", task_id: "TaskID"):
        # The wrapped writer that actually receives the bytes.
        self.handle = handle
        # The rich progress display and the task within it to advance.
        self.progress = progress
        self.task_id = task_id
        # Running total of bytes written through this wrapper.
        self.total_written = 0

    def __enter__(self) -> "BufferedWriterWithProgress":
        self.handle.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    @property
    def closed(self) -> bool:
        return self.handle.closed

    def close(self):
        self.handle.close()

    def fileno(self):
        return self.handle.fileno()

    def flush(self):
        self.handle.flush()

    def isatty(self) -> bool:
        return self.handle.isatty()

    def readable(self) -> bool:
        return self.handle.readable()

    def seekable(self) -> bool:
        return self.handle.seekable()

    def writable(self) -> bool:
        return True

    def read(self, size: Optional[int] = -1) -> bytes:
        return self.handle.read(size)

    def read1(self, size: Optional[int] = -1) -> bytes:
        # Fix: forward ``size`` to the underlying handle — it was previously
        # dropped, so explicit-size ``read1`` calls were silently ignored.
        return self.handle.read1(size)

    def readinto(self, b):
        return self.handle.readinto(b)

    def readinto1(self, b):
        return self.handle.readinto1(b)

    def readline(self, size: Optional[int] = -1) -> bytes:
        return self.handle.readline(size)

    def readlines(self, hint: int = -1) -> List[bytes]:
        return self.handle.readlines(hint)

    def write(self, b) -> int:
        n = self.handle.write(b)
        self.total_written += n
        # Advance the progress task by the bytes actually accepted.
        self.progress.advance(self.task_id, n)
        return n

    def writelines(self, lines):
        return self.handle.writelines(lines)

    def seek(self, offset: int, whence: int = 0) -> int:
        pos = self.handle.seek(offset, whence)
        # Keep the progress display consistent with the new file position.
        self.progress.update(self.task_id, completed=pos)
        return pos

    def tell(self) -> int:
        return self.handle.tell()

    @property
    def raw(self):
        return self.handle.raw

    def detach(self):
        return self.handle.detach()
87+
88+
89+
def get_download_progress(quiet: bool = False) -> Progress:
    """
    Build the default ``rich`` progress display used for downloads.

    :param quiet: when ``True`` the display is created disabled, so nothing
        is rendered to the terminal.
    """
    columns = (
        "[progress.description]{task.description}",
        BarColumn(),
        "[progress.percentage]{task.percentage:>3.0f}%",
        TimeElapsedColumn(),
        DownloadColumn(),
    )
    return Progress(*columns, disable=quiet)

cached_path/schemes/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def add_scheme_client(client: Type[SchemeClient]) -> None:
2727

2828

2929
for client in (HttpClient, S3Client, GsClient):
30-
add_scheme_client(client)
30+
add_scheme_client(client) # type: ignore
3131

3232

3333
def get_scheme_client(resource: str) -> SchemeClient:

cached_path/schemes/gs.py

+19-17
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
Google Cloud Storage.
33
"""
44

5-
from typing import IO, Optional, Tuple
5+
import io
6+
from typing import Optional, Tuple
67

78
from google.api_core.exceptions import NotFound
89
from google.auth.exceptions import DefaultCredentialsError
@@ -11,7 +12,6 @@
1112

1213
from cached_path.common import _split_cloud_path
1314
from cached_path.schemes.scheme_client import SchemeClient
14-
from cached_path.tqdm import Tqdm
1515

1616

1717
class GsClient(SchemeClient):
@@ -20,25 +20,27 @@ class GsClient(SchemeClient):
2020
def __init__(self, resource: str) -> None:
2121
super().__init__(resource)
2222
self.blob = GsClient.get_gcs_blob(resource)
23+
self._loaded = False
24+
25+
def load(self):
26+
if not self._loaded:
27+
try:
28+
self.blob.reload()
29+
self._loaded = True
30+
except NotFound:
31+
raise FileNotFoundError(self.resource)
2332

2433
def get_etag(self) -> Optional[str]:
25-
try:
26-
self.blob.reload()
27-
except NotFound:
28-
raise FileNotFoundError(self.resource)
34+
self.load()
2935
return self.blob.etag or self.blob.md5_hash
3036

31-
def get_resource(self, temp_file: IO) -> None:
32-
with Tqdm.wrapattr(
33-
temp_file,
34-
"write",
35-
unit="iB",
36-
unit_scale=True,
37-
unit_divisor=1024,
38-
total=self.blob.size,
39-
desc="downloading",
40-
) as file_obj:
41-
self.blob.download_to_file(file_obj, checksum="md5", retry=DEFAULT_RETRY)
37+
def get_size(self) -> Optional[int]:
38+
self.load()
39+
return self.blob.size
40+
41+
def get_resource(self, temp_file: io.BufferedWriter) -> None:
42+
self.load()
43+
self.blob.download_to_file(temp_file, checksum="md5", retry=DEFAULT_RETRY)
4244

4345
@staticmethod
4446
def split_gcs_path(resource: str) -> Tuple[str, str]:

0 commit comments

Comments
 (0)