6
6
import select
7
7
import threading
8
8
import subprocess
9
- import typing
9
+ import time
10
+ from typing import Callable , ContextManager
10
11
11
12
from tests .containers .cancellation_token import CancellationToken
12
13
@@ -50,7 +51,7 @@ def stop(self):
50
51
class SocketProxy :
51
52
def __init__ (
52
53
self ,
53
- remote_socket_factory : typing . ContextManager [socket .socket ],
54
+ remote_socket_factory : Callable [..., ContextManager [socket .socket ] ],
54
55
local_host : str = "localhost" ,
55
56
local_port : int = 0 ,
56
57
buffer_size : int = 4096
@@ -96,27 +97,32 @@ def get_actual_port(self) -> int:
96
97
return self .server_socket .getsockname ()[1 ]
97
98
98
99
def _handle_client (self , client_socket ):
99
- with client_socket as _ , self .remote_socket_factory as remote_socket :
100
- while True :
101
- readable , _ , _ = select .select ([client_socket , remote_socket , self .cancellation_token ], [], [])
102
-
103
- if self .cancellation_token .cancelled :
104
- break
105
-
106
- if client_socket in readable :
107
- data = client_socket .recv (self .buffer_size )
108
- if not data :
109
- break
110
- remote_socket .send (data )
111
-
112
- if remote_socket in readable :
113
- data = remote_socket .recv (self .buffer_size )
114
- if not data :
115
- break
116
- client_socket .send (data )
117
-
118
-
119
- if __name__ == "__main__" :
100
+ with client_socket as _ :
101
+ while not self .cancellation_token .cancelled :
102
+ try :
103
+ with self .remote_socket_factory () as remote_socket :
104
+ while not self .cancellation_token .cancelled :
105
+ readable , _ , _ = select .select ([client_socket , remote_socket , self .cancellation_token ], [], [])
106
+
107
+ if client_socket in readable :
108
+ data = client_socket .recv (self .buffer_size )
109
+ if not data :
110
+ return
111
+ remote_socket .send (data )
112
+
113
+ if remote_socket in readable :
114
+ data = remote_socket .recv (self .buffer_size )
115
+ if not data :
116
+ return
117
+ client_socket .send (data )
118
+ except ConnectionResetError as e :
119
+ # data = remote_socket.recv(self.buffer_size) may fail like this
120
+ # usually it happens if the pod has not been fully up, so retry is in order
121
+ logging .info ("failed to read, will try again" , exc_info = e )
122
+ time .sleep (2 )
123
+
124
+
125
+ def main () -> None :
120
126
"""Sample application to show how this can work."""
121
127
122
128
@@ -161,7 +167,7 @@ def get_actual_port(self):
161
167
server .join ()
162
168
163
169
164
- proxy = SocketProxy (remote_socket_factory () , "localhost" , 0 )
170
+ proxy = SocketProxy (remote_socket_factory , "localhost" , 0 )
165
171
thread = threading .Thread (target = proxy .listen_and_serve_until_canceled )
166
172
thread .start ()
167
173
@@ -171,3 +177,7 @@ def get_actual_port(self):
171
177
print (client_socket .recv (1024 )) # prints Hello World
172
178
173
179
thread .join ()
180
+
181
+
182
+ if __name__ == "__main__" :
183
+ main ()
0 commit comments