|
5 | 5 | import glob
|
6 | 6 | import os
|
7 | 7 | import logging
|
8 |
| -import shutil |
9 | 8 | import tempfile
|
10 | 9 | import json
|
11 | 10 | from urllib.parse import urlparse
|
@@ -243,6 +242,46 @@ def _find_latest_cached(url: str, cache_dir: str) -> Optional[str]:
|
243 | 242 | return None
|
244 | 243 |
|
245 | 244 |
|
class CacheFile:
    """
    Context manager that makes robust caching easier.

    The temporary file is created when the object is constructed, in the same
    directory as the target cache file, so the final rename never has to cross
    a filesystem boundary.

    On `__enter__`, an IO handle to the temporary file is returned, which can
    be treated as if it's the actual cache file.

    On `__exit__`, the temporary file is renamed to the cache file. If anything
    goes wrong while writing to the temporary file, or while closing/renaming
    it, the temporary file is removed so no partial entry is left behind.

    # Parameters

    cache_filename : `Union[Path, str]`
        Final location of the cache entry.
    mode : `str`, optional (default = `"w+b"`)
        Mode used to open the temporary file.
    """

    def __init__(self, cache_filename: Union[Path, str], mode: str = "w+b") -> None:
        # `Path(...)` accepts both `str` and `Path`, so no isinstance branch is needed.
        self.cache_filename = Path(cache_filename)
        self.cache_directory = os.path.dirname(self.cache_filename)
        self.mode = mode
        # `delete=False` because we decide ourselves, in `__exit__`, whether the
        # file gets promoted to the cache entry or thrown away.
        self.temp_file = tempfile.NamedTemporaryFile(
            self.mode, dir=self.cache_directory, delete=False, suffix=".tmp"
        )

    def __enter__(self):
        return self.temp_file

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_value is None:
            # Success: flush/close, then swap the temp file into place.
            try:
                self.temp_file.close()
                logger.info(
                    "Renaming temp file %s to cache at %s", self.temp_file.name, self.cache_filename
                )
                # Rename the temp file to the actual cache filename.
                os.replace(self.temp_file.name, self.cache_filename)
            except BaseException:
                # Closing or renaming failed; don't leave a stray ".tmp" file around.
                self._discard_temp_file()
                raise
            # The return value is ignored by `with` when no exception is in flight;
            # `True` is kept for callers that invoke `__exit__` directly.
            return True
        # Something went wrong, remove the temp file.
        logger.info("removing temp file %s", self.temp_file.name)
        self._discard_temp_file()
        # Returning False lets the exception from the `with` body propagate.
        return False

    def _discard_temp_file(self) -> None:
        """
        Best-effort close and removal of the temporary file.

        `OSError`s are logged instead of raised so that cleanup problems never
        replace the exception that triggered the cleanup.
        """
        try:
            self.temp_file.close()
        except OSError:
            logger.warning("could not close temp file %s", self.temp_file.name)
        try:
            os.remove(self.temp_file.name)
        except FileNotFoundError:
            # Already gone, nothing left to clean up.
            pass
        except OSError:
            logger.warning("could not remove temp file %s", self.temp_file.name)
| 283 | + |
| 284 | + |
246 | 285 | # TODO(joelgrus): do we want to do checksums or anything like that?
|
247 | 286 | def get_from_cache(url: str, cache_dir: str = None) -> str:
|
248 | 287 | """
|
@@ -303,33 +342,20 @@ def get_from_cache(url: str, cache_dir: str = None) -> str:
|
303 | 342 | if os.path.exists(cache_path):
|
304 | 343 | logger.info("cache of %s is up-to-date", url)
|
305 | 344 | else:
|
306 |
| - # Download to temporary file, then copy to cache dir once finished. |
307 |
| - # Otherwise you get corrupt cache entries if the download gets interrupted. |
308 |
| - with tempfile.NamedTemporaryFile() as temp_file: |
309 |
| - logger.info("%s not found in cache, downloading to %s", url, temp_file.name) |
| 345 | + with CacheFile(cache_path) as cache_file: |
| 346 | + logger.info("%s not found in cache, downloading to %s", url, cache_file.name) |
310 | 347 |
|
311 | 348 | # GET file object
|
312 | 349 | if url.startswith("s3://"):
|
313 |
| - _s3_get(url, temp_file) |
| 350 | + _s3_get(url, cache_file) |
314 | 351 | else:
|
315 |
| - _http_get(url, temp_file) |
316 |
| - |
317 |
| - # we are copying the file before closing it, so flush to avoid truncation |
318 |
| - temp_file.flush() |
319 |
| - # shutil.copyfileobj() starts at the current position, so go to the start |
320 |
| - temp_file.seek(0) |
321 |
| - |
322 |
| - logger.info("copying %s to cache at %s", temp_file.name, cache_path) |
323 |
| - with open(cache_path, "wb") as cache_file: |
324 |
| - shutil.copyfileobj(temp_file, cache_file) # type: ignore |
325 |
| - |
326 |
| - logger.info("creating metadata file for %s", cache_path) |
327 |
| - meta = {"url": url, "etag": etag} |
328 |
| - meta_path = cache_path + ".json" |
329 |
| - with open(meta_path, "w") as meta_file: |
330 |
| - json.dump(meta, meta_file) |
| 352 | + _http_get(url, cache_file) |
331 | 353 |
|
332 |
| - logger.info("removing temp file %s", temp_file.name) |
| 354 | + logger.info("creating metadata file for %s", cache_path) |
| 355 | + meta = {"url": url, "etag": etag} |
| 356 | + meta_path = cache_path + ".json" |
| 357 | + with open(meta_path, "w") as meta_file: |
| 358 | + json.dump(meta, meta_file) |
333 | 359 |
|
334 | 360 | return cache_path
|
335 | 361 |
|
|
0 commit comments