Skip to content

Commit 5b95eb9

Browse files
committed
Use context managers in test_ssl to simplify test writing.
1 parent 17c0713 commit 5b95eb9

File tree

1 file changed

+38
-64
lines changed

1 file changed

+38
-64
lines changed

Lib/test/test_ssl.py

+38-64
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,14 @@ def __init__(self, certificate, ssl_version=None,
532532
threading.Thread.__init__(self)
533533
self.daemon = True
534534

535+
def __enter__(self):
536+
self.start(threading.Event())
537+
self.flag.wait()
538+
539+
def __exit__(self, *args):
540+
self.stop()
541+
self.join()
542+
535543
def start(self, flag=None):
536544
self.flag = flag
537545
threading.Thread.start(self)
@@ -638,6 +646,20 @@ def __init__(self, certfile):
638646
def __str__(self):
639647
return "<%s %s>" % (self.__class__.__name__, self.server)
640648

649+
def __enter__(self):
650+
self.start(threading.Event())
651+
self.flag.wait()
652+
653+
def __exit__(self, *args):
654+
if test_support.verbose:
655+
sys.stdout.write(" cleanup: stopping server.\n")
656+
self.stop()
657+
if test_support.verbose:
658+
sys.stdout.write(" cleanup: joining server thread.\n")
659+
self.join()
660+
if test_support.verbose:
661+
sys.stdout.write(" cleanup: successfully joined.\n")
662+
641663
def start(self, flag=None):
642664
self.flag = flag
643665
threading.Thread.start(self)
@@ -752,12 +774,7 @@ def bad_cert_test(certfile):
752774
server = ThreadedEchoServer(CERTFILE,
753775
certreqs=ssl.CERT_REQUIRED,
754776
cacerts=CERTFILE, chatty=False)
755-
flag = threading.Event()
756-
server.start(flag)
757-
# wait for it to start
758-
flag.wait()
759-
# try to connect
760-
try:
777+
with server:
761778
try:
762779
s = ssl.wrap_socket(socket.socket(),
763780
certfile=certfile,
@@ -771,9 +788,6 @@ def bad_cert_test(certfile):
771788
sys.stdout.write("\nsocket.error is %s\n" % x[1])
772789
else:
773790
raise AssertionError("Use of invalid cert should have failed!")
774-
finally:
775-
server.stop()
776-
server.join()
777791

778792
def server_params_test(certfile, protocol, certreqs, cacertsfile,
779793
client_certfile, client_protocol=None, indata="FOO\n",
@@ -791,14 +805,10 @@ def server_params_test(certfile, protocol, certreqs, cacertsfile,
791805
chatty=chatty,
792806
connectionchatty=connectionchatty,
793807
wrap_accepting_socket=wrap_accepting_socket)
794-
flag = threading.Event()
795-
server.start(flag)
796-
# wait for it to start
797-
flag.wait()
798-
# try to connect
799-
if client_protocol is None:
800-
client_protocol = protocol
801-
try:
808+
with server:
809+
# try to connect
810+
if client_protocol is None:
811+
client_protocol = protocol
802812
s = ssl.wrap_socket(socket.socket(),
803813
certfile=client_certfile,
804814
ca_certs=cacertsfile,
@@ -826,9 +836,6 @@ def server_params_test(certfile, protocol, certreqs, cacertsfile,
826836
if test_support.verbose:
827837
sys.stdout.write(" client: closing connection.\n")
828838
s.close()
829-
finally:
830-
server.stop()
831-
server.join()
832839

833840
def try_protocol_combo(server_protocol,
834841
client_protocol,
@@ -930,12 +937,7 @@ def test_getpeercert(self):
930937
ssl_version=ssl.PROTOCOL_SSLv23,
931938
cacerts=CERTFILE,
932939
chatty=False)
933-
flag = threading.Event()
934-
server.start(flag)
935-
# wait for it to start
936-
flag.wait()
937-
# try to connect
938-
try:
940+
with server:
939941
s = ssl.wrap_socket(socket.socket(),
940942
certfile=CERTFILE,
941943
ca_certs=CERTFILE,
@@ -957,9 +959,6 @@ def test_getpeercert(self):
957959
"Missing or invalid 'organizationName' field in certificate subject; "
958960
"should be 'Python Software Foundation'.")
959961
s.close()
960-
finally:
961-
server.stop()
962-
server.join()
963962

964963
def test_empty_cert(self):
965964
"""Connecting with an empty cert file"""
@@ -1042,13 +1041,8 @@ def test_starttls(self):
10421041
starttls_server=True,
10431042
chatty=True,
10441043
connectionchatty=True)
1045-
flag = threading.Event()
1046-
server.start(flag)
1047-
# wait for it to start
1048-
flag.wait()
1049-
# try to connect
10501044
wrapped = False
1051-
try:
1045+
with server:
10521046
s = socket.socket()
10531047
s.setblocking(1)
10541048
s.connect((HOST, server.port))
@@ -1093,9 +1087,6 @@ def test_starttls(self):
10931087
else:
10941088
s.send("over\n")
10951089
s.close()
1096-
finally:
1097-
server.stop()
1098-
server.join()
10991090

11001091
def test_socketserver(self):
11011092
"""Using a SocketServer to create and manage SSL connections."""
@@ -1145,12 +1136,7 @@ def test_asyncore_server(self):
11451136
if test_support.verbose:
11461137
sys.stdout.write("\n")
11471138
server = AsyncoreEchoServer(CERTFILE)
1148-
flag = threading.Event()
1149-
server.start(flag)
1150-
# wait for it to start
1151-
flag.wait()
1152-
# try to connect
1153-
try:
1139+
with server:
11541140
s = ssl.wrap_socket(socket.socket())
11551141
s.connect(('127.0.0.1', server.port))
11561142
if test_support.verbose:
@@ -1169,10 +1155,6 @@ def test_asyncore_server(self):
11691155
if test_support.verbose:
11701156
sys.stdout.write(" client: closing connection.\n")
11711157
s.close()
1172-
finally:
1173-
server.stop()
1174-
# wait for server thread to end
1175-
server.join()
11761158

11771159
def test_recv_send(self):
11781160
"""Test recv(), send() and friends."""
@@ -1185,19 +1167,14 @@ def test_recv_send(self):
11851167
cacerts=CERTFILE,
11861168
chatty=True,
11871169
connectionchatty=False)
1188-
flag = threading.Event()
1189-
server.start(flag)
1190-
# wait for it to start
1191-
flag.wait()
1192-
# try to connect
1193-
s = ssl.wrap_socket(socket.socket(),
1194-
server_side=False,
1195-
certfile=CERTFILE,
1196-
ca_certs=CERTFILE,
1197-
cert_reqs=ssl.CERT_NONE,
1198-
ssl_version=ssl.PROTOCOL_TLSv1)
1199-
s.connect((HOST, server.port))
1200-
try:
1170+
with server:
1171+
s = ssl.wrap_socket(socket.socket(),
1172+
server_side=False,
1173+
certfile=CERTFILE,
1174+
ca_certs=CERTFILE,
1175+
cert_reqs=ssl.CERT_NONE,
1176+
ssl_version=ssl.PROTOCOL_TLSv1)
1177+
s.connect((HOST, server.port))
12011178
# helper methods for standardising recv* method signatures
12021179
def _recv_into():
12031180
b = bytearray("\0"*100)
@@ -1285,9 +1262,6 @@ def _recvfrom_into():
12851262

12861263
s.write("over\n".encode("ASCII", "strict"))
12871264
s.close()
1288-
finally:
1289-
server.stop()
1290-
server.join()
12911265

12921266
def test_handshake_timeout(self):
12931267
# Issue #5103: SSL handshake must respect the socket timeout

0 commit comments

Comments
 (0)