Skip to content

Commit 9bcdc1d

Browse files
authored
refactor: switch to pathlib.Path in cvedb.py (#1751)
* refactor: switch to pathlib.Path in cvedb.py * fix: windows tests
1 parent 6c13a9d commit 9bcdc1d

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

cve_bin_tool/cvedb.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
import asyncio
1010
import datetime
1111
import logging
12-
import os
1312
import shutil
1413
import sqlite3
14+
from pathlib import Path
1515
from typing import Any
1616

1717
import requests
@@ -26,12 +26,10 @@
2626
logging.basicConfig(level=logging.DEBUG)
2727

2828
# database defaults
29-
DISK_LOCATION_DEFAULT = os.path.join(os.path.expanduser("~"), ".cache", "cve-bin-tool")
30-
DISK_LOCATION_BACKUP = os.path.join(
31-
os.path.expanduser("~"), ".cache", "cve-bin-tool-backup"
32-
)
29+
DISK_LOCATION_DEFAULT = Path("~").expanduser() / ".cache" / "cve-bin-tool"
30+
DISK_LOCATION_BACKUP = Path("~").expanduser() / ".cache" / "cve-bin-tool-backup"
3331
DBNAME = "cve.db"
34-
OLD_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "cvedb")
32+
OLD_CACHE_DIR = Path("~") / ".cache" / "cvedb"
3533

3634

3735
class CVEDB:
@@ -58,9 +56,11 @@ def __init__(
5856
if sources is not None
5957
else [x(error_mode=error_mode) for x in self.SOURCES]
6058
)
61-
self.cachedir = cachedir if cachedir is not None else self.CACHEDIR
59+
self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR
6260
self.backup_cachedir = (
63-
backup_cachedir if backup_cachedir is not None else self.BACKUPCACHEDIR
61+
Path(backup_cachedir)
62+
if backup_cachedir is not None
63+
else self.BACKUPCACHEDIR
6464
)
6565
self.error_mode = error_mode
6666

@@ -71,7 +71,7 @@ def __init__(
7171
self.version_check = version_check
7272

7373
# set up the db if needed
74-
self.dbpath = os.path.join(self.cachedir, DBNAME)
74+
self.dbpath = self.cachedir / DBNAME
7575
self.connection: sqlite3.Connection | None = None
7676

7777
self.data = []
@@ -81,7 +81,7 @@ def __init__(
8181
self.exploits_list = []
8282
self.exploit_count = 0
8383

84-
if not os.path.exists(self.dbpath):
84+
if not self.dbpath.exists():
8585
self.rollback_cache_backup()
8686

8787
def get_cve_count(self) -> int:
@@ -91,20 +91,20 @@ def get_cve_count(self) -> int:
9191
return self.cve_count
9292

9393
def check_db_exists(self) -> bool:
94-
return os.path.isfile(self.dbpath)
94+
return self.dbpath.is_file()
9595

9696
def get_db_update_date(self) -> float:
9797
# last time when CVE data was updated
9898
self.time_of_last_update = datetime.datetime.fromtimestamp(
99-
os.path.getmtime(self.dbpath)
99+
self.dbpath.stat().st_mtime
100100
)
101-
return os.path.getmtime(self.dbpath)
101+
return self.dbpath.stat().st_mtime
102102

103103
async def refresh(self) -> None:
104104
"""Refresh the cve database and check for new version."""
105105
# refresh the database
106-
if not os.path.isdir(self.cachedir):
107-
os.makedirs(self.cachedir)
106+
if not self.cachedir.is_dir():
107+
self.cachedir.mkdir(parents=True)
108108

109109
# check for the latest version
110110
if self.version_check:
@@ -125,15 +125,15 @@ def get_cvelist_if_stale(self) -> None:
125125
"""Update if the local db is more than one day old.
126126
This avoids the full slow update with every execution.
127127
"""
128-
if not os.path.isfile(self.dbpath) or (
128+
if not self.dbpath.is_file() or (
129129
datetime.datetime.today()
130-
- datetime.datetime.fromtimestamp(os.path.getmtime(self.dbpath))
130+
- datetime.datetime.fromtimestamp(self.dbpath.stat().st_mtime)
131131
) > datetime.timedelta(hours=24):
132132
self.refresh_cache_and_update_db()
133133
self.time_of_last_update = datetime.datetime.today()
134134
else:
135135
self.time_of_last_update = datetime.datetime.fromtimestamp(
136-
os.path.getmtime(self.dbpath)
136+
self.dbpath.stat().st_mtime
137137
)
138138
self.LOGGER.info(
139139
"Using cached CVE data (<24h old). Use -u now to update immediately."
@@ -315,11 +315,11 @@ def populate_affected(self, affected_data, cursor):
315315

316316
def clear_cached_data(self) -> None:
317317
self.create_cache_backup()
318-
if os.path.exists(self.cachedir):
318+
if self.cachedir.exists():
319319
self.LOGGER.warning(f"Updating cachedir {self.cachedir}")
320320
shutil.rmtree(self.cachedir)
321321
# Remove files associated with pre-1.0 development tree
322-
if os.path.exists(OLD_CACHE_DIR):
322+
if OLD_CACHE_DIR.exists():
323323
self.LOGGER.warning(f"Deleting old cachedir {OLD_CACHE_DIR}")
324324
shutil.rmtree(OLD_CACHE_DIR)
325325

@@ -381,7 +381,7 @@ def db_close(self) -> None:
381381

382382
def create_cache_backup(self) -> None:
383383
"""Creates a backup of the cachedir in case anything fails"""
384-
if os.path.exists(self.cachedir):
384+
if self.cachedir.exists():
385385
self.LOGGER.debug(
386386
f"Creating backup of cachedir {self.cachedir} at {self.backup_cachedir}"
387387
)
@@ -397,15 +397,15 @@ def copy_db(self, filename, export=True):
397397

398398
def remove_cache_backup(self) -> None:
399399
"""Removes the backup if database was successfully loaded"""
400-
if os.path.exists(self.backup_cachedir):
400+
if self.backup_cachedir.exists():
401401
self.LOGGER.debug(f"Removing backup cache from {self.backup_cachedir}")
402402
shutil.rmtree(self.backup_cachedir)
403403

404404
def rollback_cache_backup(self) -> None:
405405
"""Rollback the cachedir backup in case anything fails"""
406-
if os.path.exists(os.path.join(self.backup_cachedir, DBNAME)):
406+
if (self.backup_cachedir / DBNAME).exists():
407407
self.LOGGER.info("Rolling back the cache to its previous state")
408-
if os.path.exists(self.cachedir):
408+
if self.cachedir.exists():
409409
shutil.rmtree(self.cachedir)
410410
shutil.move(self.backup_cachedir, self.cachedir)
411411

0 commit comments

Comments
 (0)