9
9
import asyncio
10
10
import datetime
11
11
import logging
12
- import os
13
12
import shutil
14
13
import sqlite3
14
+ from pathlib import Path
15
15
from typing import Any
16
16
17
17
import requests
26
26
logging .basicConfig (level = logging .DEBUG )
27
27
28
28
# database defaults
29
- DISK_LOCATION_DEFAULT = os .path .join (os .path .expanduser ("~" ), ".cache" , "cve-bin-tool" )
30
- DISK_LOCATION_BACKUP = os .path .join (
31
- os .path .expanduser ("~" ), ".cache" , "cve-bin-tool-backup"
32
- )
29
+ DISK_LOCATION_DEFAULT = Path ("~" ).expanduser () / ".cache" / "cve-bin-tool"
30
+ DISK_LOCATION_BACKUP = Path ("~" ).expanduser () / ".cache" / "cve-bin-tool-backup"
33
31
DBNAME = "cve.db"
34
- OLD_CACHE_DIR = os . path . join ( os . path . expanduser ( "~" ), ".cache" , "cvedb" )
32
+ OLD_CACHE_DIR = Path ( "~" ) / ".cache" / "cvedb"
35
33
36
34
37
35
class CVEDB :
@@ -58,9 +56,11 @@ def __init__(
58
56
if sources is not None
59
57
else [x (error_mode = error_mode ) for x in self .SOURCES ]
60
58
)
61
- self .cachedir = cachedir if cachedir is not None else self .CACHEDIR
59
+ self .cachedir = Path ( cachedir ) if cachedir is not None else self .CACHEDIR
62
60
self .backup_cachedir = (
63
- backup_cachedir if backup_cachedir is not None else self .BACKUPCACHEDIR
61
+ Path (backup_cachedir )
62
+ if backup_cachedir is not None
63
+ else self .BACKUPCACHEDIR
64
64
)
65
65
self .error_mode = error_mode
66
66
@@ -71,7 +71,7 @@ def __init__(
71
71
self .version_check = version_check
72
72
73
73
# set up the db if needed
74
- self .dbpath = os . path . join ( self .cachedir , DBNAME )
74
+ self .dbpath = self .cachedir / DBNAME
75
75
self .connection : sqlite3 .Connection | None = None
76
76
77
77
self .data = []
@@ -81,7 +81,7 @@ def __init__(
81
81
self .exploits_list = []
82
82
self .exploit_count = 0
83
83
84
- if not os . path .exists (self . dbpath ):
84
+ if not self . dbpath .exists ():
85
85
self .rollback_cache_backup ()
86
86
87
87
def get_cve_count (self ) -> int :
@@ -91,20 +91,20 @@ def get_cve_count(self) -> int:
91
91
return self .cve_count
92
92
93
93
def check_db_exists (self ) -> bool :
94
- return os . path . isfile ( self .dbpath )
94
+ return self .dbpath . is_file ( )
95
95
96
96
def get_db_update_date (self ) -> float :
97
97
# last time when CVE data was updated
98
98
self .time_of_last_update = datetime .datetime .fromtimestamp (
99
- os . path . getmtime ( self .dbpath )
99
+ self .dbpath . stat (). st_mtime
100
100
)
101
- return os . path . getmtime ( self .dbpath )
101
+ return self .dbpath . stat (). st_mtime
102
102
103
103
async def refresh (self ) -> None :
104
104
"""Refresh the cve database and check for new version."""
105
105
# refresh the database
106
- if not os . path . isdir ( self .cachedir ):
107
- os . makedirs ( self .cachedir )
106
+ if not self .cachedir . is_dir ( ):
107
+ self .cachedir . mkdir ( parents = True )
108
108
109
109
# check for the latest version
110
110
if self .version_check :
@@ -125,15 +125,15 @@ def get_cvelist_if_stale(self) -> None:
125
125
"""Update if the local db is more than one day old.
126
126
This avoids the full slow update with every execution.
127
127
"""
128
- if not os . path . isfile ( self .dbpath ) or (
128
+ if not self .dbpath . is_file ( ) or (
129
129
datetime .datetime .today ()
130
- - datetime .datetime .fromtimestamp (os . path . getmtime ( self .dbpath ) )
130
+ - datetime .datetime .fromtimestamp (self .dbpath . stat (). st_mtime )
131
131
) > datetime .timedelta (hours = 24 ):
132
132
self .refresh_cache_and_update_db ()
133
133
self .time_of_last_update = datetime .datetime .today ()
134
134
else :
135
135
self .time_of_last_update = datetime .datetime .fromtimestamp (
136
- os . path . getmtime ( self .dbpath )
136
+ self .dbpath . stat (). st_mtime
137
137
)
138
138
self .LOGGER .info (
139
139
"Using cached CVE data (<24h old). Use -u now to update immediately."
@@ -315,11 +315,11 @@ def populate_affected(self, affected_data, cursor):
315
315
316
316
def clear_cached_data (self ) -> None :
317
317
self .create_cache_backup ()
318
- if os . path .exists (self . cachedir ):
318
+ if self . cachedir .exists ():
319
319
self .LOGGER .warning (f"Updating cachedir { self .cachedir } " )
320
320
shutil .rmtree (self .cachedir )
321
321
# Remove files associated with pre-1.0 development tree
322
- if os . path . exists (OLD_CACHE_DIR ):
322
+ if OLD_CACHE_DIR . exists ():
323
323
self .LOGGER .warning (f"Deleting old cachedir { OLD_CACHE_DIR } " )
324
324
shutil .rmtree (OLD_CACHE_DIR )
325
325
@@ -381,7 +381,7 @@ def db_close(self) -> None:
381
381
382
382
def create_cache_backup (self ) -> None :
383
383
"""Creates a backup of the cachedir in case anything fails"""
384
- if os . path .exists (self . cachedir ):
384
+ if self . cachedir .exists ():
385
385
self .LOGGER .debug (
386
386
f"Creating backup of cachedir { self .cachedir } at { self .backup_cachedir } "
387
387
)
@@ -397,15 +397,15 @@ def copy_db(self, filename, export=True):
397
397
398
398
def remove_cache_backup (self ) -> None :
399
399
"""Removes the backup if database was successfully loaded"""
400
- if os . path .exists (self . backup_cachedir ):
400
+ if self . backup_cachedir .exists ():
401
401
self .LOGGER .debug (f"Removing backup cache from { self .backup_cachedir } " )
402
402
shutil .rmtree (self .backup_cachedir )
403
403
404
404
def rollback_cache_backup (self ) -> None :
405
405
"""Rollback the cachedir backup in case anything fails"""
406
- if os . path . exists ( os . path . join ( self .backup_cachedir , DBNAME )):
406
+ if ( self .backup_cachedir / DBNAME ). exists ( ):
407
407
self .LOGGER .info ("Rolling back the cache to its previous state" )
408
- if os . path .exists (self . cachedir ):
408
+ if self . cachedir .exists ():
409
409
shutil .rmtree (self .cachedir )
410
410
shutil .move (self .backup_cachedir , self .cachedir )
411
411
0 commit comments