@@ -38,19 +38,32 @@ def __init__(self):
38
38
"%s must be defined, see https://cloud.google.com/kms/docs/reference/rest"
39
39
% env_names [0 ],
40
40
)
41
- service_file_inline = os .getenv ("GOOGLE_KMS_SERVICE_ACCOUNT_JSON_INLINE" )
41
+
42
+ self ._evars = evars
43
+
44
+ service_file_inline = os .getenv (
45
+ "GOOGLE_KMS_SERVICE_ACCOUNT_JSON_INLINE" )
42
46
if service_file_inline is not None :
43
- service_file = io .StringIO (service_file_inline )
47
+ self . _service_file = io .StringIO (service_file_inline )
44
48
else :
45
- service_file = os .getenv ("GOOGLE_KMS_SERVICE_ACCOUNT_JSON" )
46
- self ._kms = KMS (** evars , service_file = service_file )
47
- self .log .info ("Using Google KMS %(keyproject)s/%(keyring)s/%(keyname)s" , evars )
49
+ self ._service_file = os .getenv ("GOOGLE_KMS_SERVICE_ACCOUNT_JSON" )
50
+
51
+ self ._kms = None
52
+ self .log .info (
53
+ "Using Google KMS %(keyproject)s/%(keyring)s/%(keyname)s" , evars )
54
+
55
+ async def _get_kms (self ) -> KMS :
56
+ if self ._kms is None :
57
+ self ._kms = KMS (** self ._evars , service_file = self ._service_file )
58
+
59
+ return self ._kms
48
60
49
61
async def encrypt (self , plaintext : Union [bytes , str ]) -> str :
50
62
"""Encrypt text using Google KMS."""
63
+ kms = await self ._get_kms ()
51
64
for attempt in range (self .timeout_retries ):
52
65
try :
53
- return await self . _kms .encrypt (encode (plaintext ))
66
+ return await kms .encrypt (encode (plaintext ))
54
67
except asyncio .TimeoutError as e :
55
68
self .log .warning ("encrypt attempt %d" , attempt + 1 )
56
69
if attempt == self .timeout_retries - 1 :
@@ -59,9 +72,10 @@ async def encrypt(self, plaintext: Union[bytes, str]) -> str:
59
72
async def decrypt (self , ciphertext : str ) -> bytes :
60
73
"""Decrypt text using Google KMS."""
61
74
# we cannot use gcloud.aio.kms.decode because it converts bytes to string with str.decode()
75
+ kms = await self ._get_kms ()
62
76
for attempt in range (self .timeout_retries ):
63
77
try :
64
- payload = await self . _kms .decrypt (ciphertext )
78
+ payload = await kms .decrypt (ciphertext )
65
79
except asyncio .TimeoutError as e :
66
80
self .log .warning ("decrypt attempt %d" , attempt + 1 )
67
81
if attempt == self .timeout_retries - 1 :
@@ -72,4 +86,5 @@ async def decrypt(self, ciphertext: str) -> bytes:
72
86
73
87
async def close (self ):
74
88
"""Close the underlying HTTPS session."""
75
- await self ._kms .close ()
89
+ if self ._kms is not None :
90
+ await self ._kms .close ()
0 commit comments