@@ -532,6 +532,14 @@ def __init__(self, certificate, ssl_version=None,
532
532
threading .Thread .__init__ (self )
533
533
self .daemon = True
534
534
535
+ def __enter__ (self ):
536
+ self .start (threading .Event ())
537
+ self .flag .wait ()
538
+
539
+ def __exit__ (self , * args ):
540
+ self .stop ()
541
+ self .join ()
542
+
535
543
def start (self , flag = None ):
536
544
self .flag = flag
537
545
threading .Thread .start (self )
@@ -638,6 +646,20 @@ def __init__(self, certfile):
638
646
def __str__ (self ):
639
647
return "<%s %s>" % (self .__class__ .__name__ , self .server )
640
648
649
+ def __enter__ (self ):
650
+ self .start (threading .Event ())
651
+ self .flag .wait ()
652
+
653
+ def __exit__ (self , * args ):
654
+ if test_support .verbose :
655
+ sys .stdout .write (" cleanup: stopping server.\n " )
656
+ self .stop ()
657
+ if test_support .verbose :
658
+ sys .stdout .write (" cleanup: joining server thread.\n " )
659
+ self .join ()
660
+ if test_support .verbose :
661
+ sys .stdout .write (" cleanup: successfully joined.\n " )
662
+
641
663
def start (self , flag = None ):
642
664
self .flag = flag
643
665
threading .Thread .start (self )
@@ -752,12 +774,7 @@ def bad_cert_test(certfile):
752
774
server = ThreadedEchoServer (CERTFILE ,
753
775
certreqs = ssl .CERT_REQUIRED ,
754
776
cacerts = CERTFILE , chatty = False )
755
- flag = threading .Event ()
756
- server .start (flag )
757
- # wait for it to start
758
- flag .wait ()
759
- # try to connect
760
- try :
777
+ with server :
761
778
try :
762
779
s = ssl .wrap_socket (socket .socket (),
763
780
certfile = certfile ,
@@ -771,9 +788,6 @@ def bad_cert_test(certfile):
771
788
sys .stdout .write ("\n socket.error is %s\n " % x [1 ])
772
789
else :
773
790
raise AssertionError ("Use of invalid cert should have failed!" )
774
- finally :
775
- server .stop ()
776
- server .join ()
777
791
778
792
def server_params_test (certfile , protocol , certreqs , cacertsfile ,
779
793
client_certfile , client_protocol = None , indata = "FOO\n " ,
@@ -791,14 +805,10 @@ def server_params_test(certfile, protocol, certreqs, cacertsfile,
791
805
chatty = chatty ,
792
806
connectionchatty = connectionchatty ,
793
807
wrap_accepting_socket = wrap_accepting_socket )
794
- flag = threading .Event ()
795
- server .start (flag )
796
- # wait for it to start
797
- flag .wait ()
798
- # try to connect
799
- if client_protocol is None :
800
- client_protocol = protocol
801
- try :
808
+ with server :
809
+ # try to connect
810
+ if client_protocol is None :
811
+ client_protocol = protocol
802
812
s = ssl .wrap_socket (socket .socket (),
803
813
certfile = client_certfile ,
804
814
ca_certs = cacertsfile ,
@@ -826,9 +836,6 @@ def server_params_test(certfile, protocol, certreqs, cacertsfile,
826
836
if test_support .verbose :
827
837
sys .stdout .write (" client: closing connection.\n " )
828
838
s .close ()
829
- finally :
830
- server .stop ()
831
- server .join ()
832
839
833
840
def try_protocol_combo (server_protocol ,
834
841
client_protocol ,
@@ -930,12 +937,7 @@ def test_getpeercert(self):
930
937
ssl_version = ssl .PROTOCOL_SSLv23 ,
931
938
cacerts = CERTFILE ,
932
939
chatty = False )
933
- flag = threading .Event ()
934
- server .start (flag )
935
- # wait for it to start
936
- flag .wait ()
937
- # try to connect
938
- try :
940
+ with server :
939
941
s = ssl .wrap_socket (socket .socket (),
940
942
certfile = CERTFILE ,
941
943
ca_certs = CERTFILE ,
@@ -957,9 +959,6 @@ def test_getpeercert(self):
957
959
"Missing or invalid 'organizationName' field in certificate subject; "
958
960
"should be 'Python Software Foundation'." )
959
961
s .close ()
960
- finally :
961
- server .stop ()
962
- server .join ()
963
962
964
963
def test_empty_cert (self ):
965
964
"""Connecting with an empty cert file"""
@@ -1042,13 +1041,8 @@ def test_starttls(self):
1042
1041
starttls_server = True ,
1043
1042
chatty = True ,
1044
1043
connectionchatty = True )
1045
- flag = threading .Event ()
1046
- server .start (flag )
1047
- # wait for it to start
1048
- flag .wait ()
1049
- # try to connect
1050
1044
wrapped = False
1051
- try :
1045
+ with server :
1052
1046
s = socket .socket ()
1053
1047
s .setblocking (1 )
1054
1048
s .connect ((HOST , server .port ))
@@ -1093,9 +1087,6 @@ def test_starttls(self):
1093
1087
else :
1094
1088
s .send ("over\n " )
1095
1089
s .close ()
1096
- finally :
1097
- server .stop ()
1098
- server .join ()
1099
1090
1100
1091
def test_socketserver (self ):
1101
1092
"""Using a SocketServer to create and manage SSL connections."""
@@ -1145,12 +1136,7 @@ def test_asyncore_server(self):
1145
1136
if test_support .verbose :
1146
1137
sys .stdout .write ("\n " )
1147
1138
server = AsyncoreEchoServer (CERTFILE )
1148
- flag = threading .Event ()
1149
- server .start (flag )
1150
- # wait for it to start
1151
- flag .wait ()
1152
- # try to connect
1153
- try :
1139
+ with server :
1154
1140
s = ssl .wrap_socket (socket .socket ())
1155
1141
s .connect (('127.0.0.1' , server .port ))
1156
1142
if test_support .verbose :
@@ -1169,10 +1155,6 @@ def test_asyncore_server(self):
1169
1155
if test_support .verbose :
1170
1156
sys .stdout .write (" client: closing connection.\n " )
1171
1157
s .close ()
1172
- finally :
1173
- server .stop ()
1174
- # wait for server thread to end
1175
- server .join ()
1176
1158
1177
1159
def test_recv_send (self ):
1178
1160
"""Test recv(), send() and friends."""
@@ -1185,19 +1167,14 @@ def test_recv_send(self):
1185
1167
cacerts = CERTFILE ,
1186
1168
chatty = True ,
1187
1169
connectionchatty = False )
1188
- flag = threading .Event ()
1189
- server .start (flag )
1190
- # wait for it to start
1191
- flag .wait ()
1192
- # try to connect
1193
- s = ssl .wrap_socket (socket .socket (),
1194
- server_side = False ,
1195
- certfile = CERTFILE ,
1196
- ca_certs = CERTFILE ,
1197
- cert_reqs = ssl .CERT_NONE ,
1198
- ssl_version = ssl .PROTOCOL_TLSv1 )
1199
- s .connect ((HOST , server .port ))
1200
- try :
1170
+ with server :
1171
+ s = ssl .wrap_socket (socket .socket (),
1172
+ server_side = False ,
1173
+ certfile = CERTFILE ,
1174
+ ca_certs = CERTFILE ,
1175
+ cert_reqs = ssl .CERT_NONE ,
1176
+ ssl_version = ssl .PROTOCOL_TLSv1 )
1177
+ s .connect ((HOST , server .port ))
1201
1178
# helper methods for standardising recv* method signatures
1202
1179
def _recv_into ():
1203
1180
b = bytearray ("\0 " * 100 )
@@ -1285,9 +1262,6 @@ def _recvfrom_into():
1285
1262
1286
1263
s .write ("over\n " .encode ("ASCII" , "strict" ))
1287
1264
s .close ()
1288
- finally :
1289
- server .stop ()
1290
- server .join ()
1291
1265
1292
1266
def test_handshake_timeout (self ):
1293
1267
# Issue #5103: SSL handshake must respect the socket timeout
0 commit comments