4
4
import logging
5
5
import socket
6
6
import select
7
+ import struct
7
8
import threading
8
9
import subprocess
9
- import typing
10
+ from typing import Callable , ContextManager
10
11
11
12
from tests .containers .cancellation_token import CancellationToken
12
13
@@ -50,7 +51,7 @@ def stop(self):
50
51
class SocketProxy :
51
52
def __init__ (
52
53
self ,
53
- remote_socket_factory : typing . ContextManager [socket .socket ],
54
+ remote_socket_factory : Callable [..., ContextManager [socket .socket ] ],
54
55
local_host : str = "localhost" ,
55
56
local_port : int = 0 ,
56
57
buffer_size : int = 4096
@@ -81,9 +82,14 @@ def listen_and_serve_until_canceled(self):
81
82
Handles at most one client at a time. """
82
83
try :
83
84
while not self .cancellation_token .cancelled :
84
- client_socket , addr = self .server_socket .accept ()
85
- logging .info (f"Accepted connection from { addr [0 ]} :{ addr [1 ]} " )
86
- self ._handle_client (client_socket )
85
+ readable , _ , _ = select .select ([self .server_socket , self .cancellation_token ], [], [])
86
+
87
+ # ISSUE-922: socket.accept() blocks, so if cancel() did not come very fast, we'd loop over and block
88
+ if self .server_socket in readable :
89
+ client_socket , addr = self .server_socket .accept ()
90
+ logging .info (f"Accepted connection from { addr [0 ]} :{ addr [1 ]} " )
91
+ # handle client synchronously, which means that there can be at most one at a time
92
+ self ._handle_client (client_socket )
87
93
except Exception as e :
88
94
logging .exception (f"Proxying failed to listen" , exc_info = e )
89
95
raise
@@ -96,27 +102,38 @@ def get_actual_port(self) -> int:
96
102
return self .server_socket .getsockname ()[1 ]
97
103
98
104
def _handle_client (self , client_socket ):
99
- with client_socket as _ , self .remote_socket_factory as remote_socket :
100
- while True :
105
+ with client_socket as _ , self .remote_socket_factory () as remote_socket :
106
+ while not self . cancellation_token . cancelled :
101
107
readable , _ , _ = select .select ([client_socket , remote_socket , self .cancellation_token ], [], [])
102
108
103
- if self .cancellation_token .cancelled :
104
- break
105
-
106
109
if client_socket in readable :
107
110
data = client_socket .recv (self .buffer_size )
108
111
if not data :
109
112
break
110
113
remote_socket .send (data )
111
114
112
115
if remote_socket in readable :
113
- data = remote_socket .recv (self .buffer_size )
116
+ try :
117
+ data = remote_socket .recv (self .buffer_size )
118
+ except ConnectionResetError :
119
+ # ISSUE-922: it seems best to propagate the error and let the client retry
120
+ # alternatively it would be necessary to resend anything already received from client_socket
121
+ _rst_socket (client_socket )
122
+ break
114
123
if not data :
115
124
break
116
125
client_socket .send (data )
117
126
118
127
119
- if __name__ == "__main__" :
128
+ def _rst_socket (s : socket ):
129
+ """Closing a SO_LINGER socket will RST it
130
+ https://stackoverflow.com/questions/46264404/how-can-i-reset-a-tcp-socket-in-python
131
+ """
132
+ s .setsockopt (socket .SOL_SOCKET , socket .SO_LINGER , struct .pack ('ii' , 1 , 0 ))
133
+ s .close ()
134
+
135
+
136
+ def main () -> None :
120
137
"""Sample application to show how this can work."""
121
138
122
139
@@ -161,13 +178,21 @@ def get_actual_port(self):
161
178
server .join ()
162
179
163
180
164
- proxy = SocketProxy (remote_socket_factory () , "localhost" , 0 )
181
+ proxy = SocketProxy (remote_socket_factory , "localhost" , 0 )
165
182
thread = threading .Thread (target = proxy .listen_and_serve_until_canceled )
166
183
thread .start ()
167
184
168
- client_socket = socket .socket (socket .AF_INET , socket .SOCK_STREAM )
169
- client_socket .connect (("localhost" , proxy .get_actual_port ()))
185
+ for _ in range (2 ):
186
+ client_socket = socket .socket (socket .AF_INET , socket .SOCK_STREAM )
187
+ client_socket .connect (("localhost" , proxy .get_actual_port ()))
170
188
171
- print (client_socket .recv (1024 )) # prints Hello World
189
+ print (client_socket .recv (1024 )) # prints Hello World
190
+ print (client_socket .recv (1024 )) # prints nothing
191
+ client_socket .close ()
192
+ proxy .cancellation_token .cancel ()
172
193
173
194
thread .join ()
195
+
196
+
197
+ if __name__ == "__main__" :
198
+ main ()
0 commit comments