23
23
import struct
24
24
from typing import Any , Callable , Literal , Optional
25
25
26
- from cryptography .x509 import (
27
- CertificateBuilder as x509_CertificateBuilder ,
28
- NameAttribute as x509_NameAttribute ,
29
- random_serial_number as x509_random_serial_number ,
30
- SubjectAlternativeName as x509_SubjectAlternativeName ,
31
- IPAddress as x509_IPAddress ,
32
- DNSName as x509_DNSName ,
33
- load_pem_x509_certificate as x509_load_pem_x509_certificate ,
34
- Name as x509_Name ,
35
- )
26
+ from cryptography import x509
36
27
from cryptography .hazmat .primitives import hashes
37
28
from cryptography .hazmat .primitives import serialization
38
29
from cryptography .hazmat .primitives .asymmetric import rsa
@@ -98,7 +89,7 @@ def token_state(
98
89
99
90
def generate_cert (
100
91
common_name : str , expires_in : int = 60 , server_cert : bool = False
101
- ) -> tuple [x509_CertificateBuilder , rsa .RSAPrivateKey ]:
92
+ ) -> tuple [x509 . CertificateBuilder , rsa .RSAPrivateKey ]:
102
93
"""
103
94
Generate a private key and cert object to be used in testing.
104
95
@@ -108,40 +99,40 @@ def generate_cert(
108
99
server_cert (bool): Whether it is a server certificate.
109
100
110
101
Returns:
111
- tuple[x509_CertificateBuilder , rsa.RSAPrivateKey]
102
+ tuple[x509.CertificateBuilder , rsa.RSAPrivateKey]
112
103
"""
113
104
# generate private key
114
105
key = rsa .generate_private_key (public_exponent = 65537 , key_size = 2048 )
115
106
# calculate expiry time
116
107
now = datetime .now (timezone .utc )
117
108
expiration = now + timedelta (minutes = expires_in )
118
109
# configure cert subject
119
- subject = issuer = x509_Name (
110
+ subject = issuer = x509 . Name (
120
111
[
121
- x509_NameAttribute (NameOID .COUNTRY_NAME , "US" ),
122
- x509_NameAttribute (NameOID .STATE_OR_PROVINCE_NAME , "California" ),
123
- x509_NameAttribute (NameOID .LOCALITY_NAME , "Mountain View" ),
124
- x509_NameAttribute (NameOID .ORGANIZATION_NAME , "Google Inc" ),
125
- x509_NameAttribute (NameOID .COMMON_NAME , common_name ),
112
+ x509 . NameAttribute (NameOID .COUNTRY_NAME , "US" ),
113
+ x509 . NameAttribute (NameOID .STATE_OR_PROVINCE_NAME , "California" ),
114
+ x509 . NameAttribute (NameOID .LOCALITY_NAME , "Mountain View" ),
115
+ x509 . NameAttribute (NameOID .ORGANIZATION_NAME , "Google Inc" ),
116
+ x509 . NameAttribute (NameOID .COMMON_NAME , common_name ),
126
117
]
127
118
)
128
119
# build cert
129
120
cert = (
130
- x509_CertificateBuilder ()
121
+ x509 . CertificateBuilder ()
131
122
.subject_name (subject )
132
123
.issuer_name (issuer )
133
124
.public_key (key .public_key ())
134
- .serial_number (x509_random_serial_number ())
125
+ .serial_number (x509 . random_serial_number ())
135
126
.not_valid_before (now )
136
127
.not_valid_after (expiration )
137
128
)
138
129
if server_cert :
139
130
cert = cert .add_extension (
140
- x509_SubjectAlternativeName (
131
+ x509 . SubjectAlternativeName (
141
132
general_names = [
142
- x509_IPAddress (ipaddress .ip_address ("127.0.0.1" )),
143
- x509_IPAddress (ipaddress .ip_address ("10.0.0.1" )),
144
- x509_DNSName ("x.y.alloydb.goog." ),
133
+ x509 . IPAddress (ipaddress .ip_address ("127.0.0.1" )),
134
+ x509 . IPAddress (ipaddress .ip_address ("10.0.0.1" )),
135
+ x509 . DNSName ("x.y.alloydb.goog." ),
145
136
]
146
137
),
147
138
critical = False ,
@@ -215,11 +206,11 @@ def generate_pem_certificate_chain(self, pub_key: str) -> tuple[str, list[str]]:
215
206
)
216
207
# build client cert
217
208
client_cert = (
218
- x509_CertificateBuilder ()
209
+ x509 . CertificateBuilder ()
219
210
.subject_name (self .intermediate_cert .subject )
220
211
.issuer_name (self .intermediate_cert .issuer )
221
212
.public_key (pub_key_bytes )
222
- .serial_number (x509_random_serial_number ())
213
+ .serial_number (x509 . random_serial_number ())
223
214
.not_valid_before (self .cert_before )
224
215
.not_valid_after (self .cert_expiry )
225
216
)
@@ -262,11 +253,11 @@ async def _get_client_certificate(
262
253
)
263
254
# build client cert
264
255
client_cert = (
265
- x509_CertificateBuilder ()
256
+ x509 . CertificateBuilder ()
266
257
.subject_name (self .instance .intermediate_cert .subject )
267
258
.issuer_name (self .instance .intermediate_cert .issuer )
268
259
.public_key (pub_key_bytes )
269
- .serial_number (x509_random_serial_number ())
260
+ .serial_number (x509 . random_serial_number ())
270
261
.not_valid_before (self .instance .cert_before )
271
262
.not_valid_after (self .instance .cert_expiry )
272
263
)
@@ -315,7 +306,7 @@ async def get_connection_info(
315
306
# unpack certs
316
307
ca_cert , cert_chain = certs
317
308
# get expiration from client certificate
318
- cert_obj = x509_load_pem_x509_certificate (cert_chain [0 ].encode ("UTF-8" ))
309
+ cert_obj = x509 . load_pem_x509_certificate (cert_chain [0 ].encode ("UTF-8" ))
319
310
expiration = cert_obj .not_valid_after_utc
320
311
321
312
return ConnectionInfo (
0 commit comments