diff --git a/cve_bin_tool/cvedb.py b/cve_bin_tool/cvedb.py index e4bb1837c2..fe492ca752 100644 --- a/cve_bin_tool/cvedb.py +++ b/cve_bin_tool/cvedb.py @@ -9,9 +9,9 @@ import asyncio import datetime import logging -import os import shutil import sqlite3 +from pathlib import Path from typing import Any import requests @@ -26,12 +26,10 @@ logging.basicConfig(level=logging.DEBUG) # database defaults -DISK_LOCATION_DEFAULT = os.path.join(os.path.expanduser("~"), ".cache", "cve-bin-tool") -DISK_LOCATION_BACKUP = os.path.join( - os.path.expanduser("~"), ".cache", "cve-bin-tool-backup" -) +DISK_LOCATION_DEFAULT = Path("~").expanduser() / ".cache" / "cve-bin-tool" +DISK_LOCATION_BACKUP = Path("~").expanduser() / ".cache" / "cve-bin-tool-backup" DBNAME = "cve.db" -OLD_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "cvedb") +OLD_CACHE_DIR = Path("~") / ".cache" / "cvedb" class CVEDB: @@ -58,9 +56,11 @@ def __init__( if sources is not None else [x(error_mode=error_mode) for x in self.SOURCES] ) - self.cachedir = cachedir if cachedir is not None else self.CACHEDIR + self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR self.backup_cachedir = ( - backup_cachedir if backup_cachedir is not None else self.BACKUPCACHEDIR + Path(backup_cachedir) + if backup_cachedir is not None + else self.BACKUPCACHEDIR ) self.error_mode = error_mode @@ -71,7 +71,7 @@ def __init__( self.version_check = version_check # set up the db if needed - self.dbpath = os.path.join(self.cachedir, DBNAME) + self.dbpath = self.cachedir / DBNAME self.connection: sqlite3.Connection | None = None self.data = [] @@ -81,7 +81,7 @@ def __init__( self.exploits_list = [] self.exploit_count = 0 - if not os.path.exists(self.dbpath): + if not self.dbpath.exists(): self.rollback_cache_backup() def get_cve_count(self) -> int: @@ -91,20 +91,20 @@ def get_cve_count(self) -> int: return self.cve_count def check_db_exists(self) -> bool: - return os.path.isfile(self.dbpath) + return self.dbpath.is_file() def get_db_update_date(self) -> float: # last time when CVE data was updated self.time_of_last_update = datetime.datetime.fromtimestamp( - os.path.getmtime(self.dbpath) + self.dbpath.stat().st_mtime ) - return os.path.getmtime(self.dbpath) + return self.dbpath.stat().st_mtime async def refresh(self) -> None: """Refresh the cve database and check for new version.""" # refresh the database - if not os.path.isdir(self.cachedir): - os.makedirs(self.cachedir) + if not self.cachedir.is_dir(): + self.cachedir.mkdir(parents=True) # check for the latest version if self.version_check: @@ -125,15 +125,15 @@ def get_cvelist_if_stale(self) -> None: """Update if the local db is more than one day old. This avoids the full slow update with every execution. """ - if not os.path.isfile(self.dbpath) or ( + if not self.dbpath.is_file() or ( datetime.datetime.today() - - datetime.datetime.fromtimestamp(os.path.getmtime(self.dbpath)) + - datetime.datetime.fromtimestamp(self.dbpath.stat().st_mtime) ) > datetime.timedelta(hours=24): self.refresh_cache_and_update_db() self.time_of_last_update = datetime.datetime.today() else: self.time_of_last_update = datetime.datetime.fromtimestamp( - os.path.getmtime(self.dbpath) + self.dbpath.stat().st_mtime ) self.LOGGER.info( "Using cached CVE data (<24h old). Use -u now to update immediately." @@ -315,11 +315,11 @@ def populate_affected(self, affected_data, cursor): def clear_cached_data(self) -> None: self.create_cache_backup() - if os.path.exists(self.cachedir): + if self.cachedir.exists(): self.LOGGER.warning(f"Updating cachedir {self.cachedir}") shutil.rmtree(self.cachedir) # Remove files associated with pre-1.0 development tree - if os.path.exists(OLD_CACHE_DIR): + if OLD_CACHE_DIR.exists(): self.LOGGER.warning(f"Deleting old cachedir {OLD_CACHE_DIR}") shutil.rmtree(OLD_CACHE_DIR) @@ -381,7 +381,7 @@ def db_close(self) -> None: def create_cache_backup(self) -> None: """Creates a backup of the cachedir in case anything fails""" - if os.path.exists(self.cachedir): + if self.cachedir.exists(): self.LOGGER.debug( f"Creating backup of cachedir {self.cachedir} at {self.backup_cachedir}" ) @@ -397,15 +397,15 @@ def copy_db(self, filename, export=True): def remove_cache_backup(self) -> None: """Removes the backup if database was successfully loaded""" - if os.path.exists(self.backup_cachedir): + if self.backup_cachedir.exists(): self.LOGGER.debug(f"Removing backup cache from {self.backup_cachedir}") shutil.rmtree(self.backup_cachedir) def rollback_cache_backup(self) -> None: """Rollback the cachedir backup in case anything fails""" - if os.path.exists(os.path.join(self.backup_cachedir, DBNAME)): + if (self.backup_cachedir / DBNAME).exists(): self.LOGGER.info("Rolling back the cache to its previous state") - if os.path.exists(self.cachedir): + if self.cachedir.exists(): shutil.rmtree(self.cachedir) shutil.move(self.backup_cachedir, self.cachedir)