Skip to content

refactor: switch to pathlib.Path in cvedb.py #1751

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 6, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 24 additions & 24 deletions cve_bin_tool/cvedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import asyncio
import datetime
import logging
import os
import shutil
import sqlite3
from pathlib import Path
from typing import Any

import requests
Expand All @@ -26,12 +26,10 @@
logging.basicConfig(level=logging.DEBUG)

# database defaults
DISK_LOCATION_DEFAULT = os.path.join(os.path.expanduser("~"), ".cache", "cve-bin-tool")
DISK_LOCATION_BACKUP = os.path.join(
os.path.expanduser("~"), ".cache", "cve-bin-tool-backup"
)
DISK_LOCATION_DEFAULT = Path("~").expanduser() / ".cache" / "cve-bin-tool"
DISK_LOCATION_BACKUP = Path("~").expanduser() / ".cache" / "cve-bin-tool-backup"
DBNAME = "cve.db"
OLD_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "cvedb")
OLD_CACHE_DIR = Path("~") / ".cache" / "cvedb"


class CVEDB:
Expand All @@ -58,9 +56,11 @@ def __init__(
if sources is not None
else [x(error_mode=error_mode) for x in self.SOURCES]
)
self.cachedir = cachedir if cachedir is not None else self.CACHEDIR
self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR
self.backup_cachedir = (
backup_cachedir if backup_cachedir is not None else self.BACKUPCACHEDIR
Path(backup_cachedir)
if backup_cachedir is not None
else self.BACKUPCACHEDIR
)
self.error_mode = error_mode

Expand All @@ -71,7 +71,7 @@ def __init__(
self.version_check = version_check

# set up the db if needed
self.dbpath = os.path.join(self.cachedir, DBNAME)
self.dbpath = self.cachedir / DBNAME
self.connection: sqlite3.Connection | None = None

self.data = []
Expand All @@ -81,7 +81,7 @@ def __init__(
self.exploits_list = []
self.exploit_count = 0

if not os.path.exists(self.dbpath):
if not self.dbpath.exists():
self.rollback_cache_backup()

def get_cve_count(self) -> int:
Expand All @@ -91,20 +91,20 @@ def get_cve_count(self) -> int:
return self.cve_count

def check_db_exists(self) -> bool:
return os.path.isfile(self.dbpath)
return self.dbpath.is_file()

def get_db_update_date(self) -> float:
# last time when CVE data was updated
self.time_of_last_update = datetime.datetime.fromtimestamp(
os.path.getmtime(self.dbpath)
self.dbpath.stat().st_mtime
)
return os.path.getmtime(self.dbpath)
return self.dbpath.stat().st_mtime

async def refresh(self) -> None:
"""Refresh the cve database and check for new version."""
# refresh the database
if not os.path.isdir(self.cachedir):
os.makedirs(self.cachedir)
if not self.cachedir.is_dir():
self.cachedir.mkdir(parents=True)

# check for the latest version
if self.version_check:
Expand All @@ -125,15 +125,15 @@ def get_cvelist_if_stale(self) -> None:
"""Update if the local db is more than one day old.
This avoids the full slow update with every execution.
"""
if not os.path.isfile(self.dbpath) or (
if not self.dbpath.is_file() or (
datetime.datetime.today()
- datetime.datetime.fromtimestamp(os.path.getmtime(self.dbpath))
- datetime.datetime.fromtimestamp(self.dbpath.stat().st_mtime)
) > datetime.timedelta(hours=24):
self.refresh_cache_and_update_db()
self.time_of_last_update = datetime.datetime.today()
else:
self.time_of_last_update = datetime.datetime.fromtimestamp(
os.path.getmtime(self.dbpath)
self.dbpath.stat().st_mtime
)
self.LOGGER.info(
"Using cached CVE data (<24h old). Use -u now to update immediately."
Expand Down Expand Up @@ -315,11 +315,11 @@ def populate_affected(self, affected_data, cursor):

def clear_cached_data(self) -> None:
self.create_cache_backup()
if os.path.exists(self.cachedir):
if self.cachedir.exists():
self.LOGGER.warning(f"Updating cachedir {self.cachedir}")
shutil.rmtree(self.cachedir)
# Remove files associated with pre-1.0 development tree
if os.path.exists(OLD_CACHE_DIR):
if OLD_CACHE_DIR.exists():
self.LOGGER.warning(f"Deleting old cachedir {OLD_CACHE_DIR}")
shutil.rmtree(OLD_CACHE_DIR)

Expand Down Expand Up @@ -381,7 +381,7 @@ def db_close(self) -> None:

def create_cache_backup(self) -> None:
"""Creates a backup of the cachedir in case anything fails"""
if os.path.exists(self.cachedir):
if self.cachedir.exists():
self.LOGGER.debug(
f"Creating backup of cachedir {self.cachedir} at {self.backup_cachedir}"
)
Expand All @@ -397,15 +397,15 @@ def copy_db(self, filename, export=True):

def remove_cache_backup(self) -> None:
"""Removes the backup if database was successfully loaded"""
if os.path.exists(self.backup_cachedir):
if self.backup_cachedir.exists():
self.LOGGER.debug(f"Removing backup cache from {self.backup_cachedir}")
shutil.rmtree(self.backup_cachedir)

def rollback_cache_backup(self) -> None:
"""Rollback the cachedir backup in case anything fails"""
if os.path.exists(os.path.join(self.backup_cachedir, DBNAME)):
if (self.backup_cachedir / DBNAME).exists():
self.LOGGER.info("Rolling back the cache to its previous state")
if os.path.exists(self.cachedir):
if self.cachedir.exists():
shutil.rmtree(self.cachedir)
shutil.move(self.backup_cachedir, self.cachedir)

Expand Down