Skip to content

Commit dee6d41

Browse files
committed
ISSUE-922: chore(tests/containers): implement retry if port-forwarding fails
1 parent 540495b commit dee6d41

File tree

2 files changed

+42
-17
lines changed

2 files changed

+42
-17
lines changed

tests/containers/kubernetes_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def deploy(self, container_name: str) -> None:
225225
assert len(pod_name.items) == 1
226226
pod: kubernetes.client.models.v1_pod.V1Pod = pod_name.items[0]
227227

228-
p = socket_proxy.SocketProxy(exposing_contextmanager(core_v1_api, pod), "localhost", 0)
228+
p = socket_proxy.SocketProxy(lambda: exposing_contextmanager(core_v1_api, pod), "localhost", 0)
229229
t = threading.Thread(target=p.listen_and_serve_until_canceled)
230230
t.start()
231231
self.tf.defer(t, lambda thread: thread.join())

tests/containers/socket_proxy.py

+41-16
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
import logging
55
import socket
66
import select
7+
import struct
78
import threading
89
import subprocess
9-
import typing
10+
from typing import Callable, ContextManager
1011

1112
from tests.containers.cancellation_token import CancellationToken
1213

@@ -50,7 +51,7 @@ def stop(self):
5051
class SocketProxy:
5152
def __init__(
5253
self,
53-
remote_socket_factory: typing.ContextManager[socket.socket],
54+
remote_socket_factory: Callable[..., ContextManager[socket.socket]],
5455
local_host: str = "localhost",
5556
local_port: int = 0,
5657
buffer_size: int = 4096
@@ -81,9 +82,14 @@ def listen_and_serve_until_canceled(self):
8182
Handles at most one client at a time. """
8283
try:
8384
while not self.cancellation_token.cancelled:
84-
client_socket, addr = self.server_socket.accept()
85-
logging.info(f"Accepted connection from {addr[0]}:{addr[1]}")
86-
self._handle_client(client_socket)
85+
readable, _, _ = select.select([self.server_socket, self.cancellation_token], [], [])
86+
87+
# ISSUE-922: socket.accept() blocks, so if cancel() did not come very fast, we'd loop over and block
88+
if self.server_socket in readable:
89+
client_socket, addr = self.server_socket.accept()
90+
logging.info(f"Accepted connection from {addr[0]}:{addr[1]}")
91+
# handle client synchronously, which means that there can be at most one at a time
92+
self._handle_client(client_socket)
8793
except Exception as e:
8894
logging.exception(f"Proxying failed to listen", exc_info=e)
8995
raise
@@ -96,27 +102,38 @@ def get_actual_port(self) -> int:
96102
return self.server_socket.getsockname()[1]
97103

98104
def _handle_client(self, client_socket):
99-
with client_socket as _, self.remote_socket_factory as remote_socket:
100-
while True:
105+
with client_socket as _, self.remote_socket_factory() as remote_socket:
106+
while not self.cancellation_token.cancelled:
101107
readable, _, _ = select.select([client_socket, remote_socket, self.cancellation_token], [], [])
102108

103-
if self.cancellation_token.cancelled:
104-
break
105-
106109
if client_socket in readable:
107110
data = client_socket.recv(self.buffer_size)
108111
if not data:
109112
break
110113
remote_socket.send(data)
111114

112115
if remote_socket in readable:
113-
data = remote_socket.recv(self.buffer_size)
116+
try:
117+
data = remote_socket.recv(self.buffer_size)
118+
except ConnectionResetError:
119+
# ISSUE-922: it seems best to propagate the error and let the client retry
120+
# alternatively it would be necessary to resend anything already received from client_socket
121+
_rst_socket(client_socket)
122+
break
114123
if not data:
115124
break
116125
client_socket.send(data)
117126

118127

119-
if __name__ == "__main__":
128+
def _rst_socket(s: socket):
129+
"""Closing a SO_LINGER socket will RST it
130+
https://stackoverflow.com/questions/46264404/how-can-i-reset-a-tcp-socket-in-python
131+
"""
132+
s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0))
133+
s.close()
134+
135+
136+
def main() -> None:
120137
"""Sample application to show how this can work."""
121138

122139

@@ -161,13 +178,21 @@ def get_actual_port(self):
161178
server.join()
162179

163180

164-
proxy = SocketProxy(remote_socket_factory(), "localhost", 0)
181+
proxy = SocketProxy(remote_socket_factory, "localhost", 0)
165182
thread = threading.Thread(target=proxy.listen_and_serve_until_canceled)
166183
thread.start()
167184

168-
client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
169-
client_socket.connect(("localhost", proxy.get_actual_port()))
185+
for _ in range(2):
186+
client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
187+
client_socket.connect(("localhost", proxy.get_actual_port()))
170188

171-
print(client_socket.recv(1024)) # prints Hello World
189+
print(client_socket.recv(1024)) # prints Hello World
190+
print(client_socket.recv(1024)) # prints nothing
191+
client_socket.close()
192+
proxy.cancellation_token.cancel()
172193

173194
thread.join()
195+
196+
197+
if __name__ == "__main__":
198+
main()

0 commit comments

Comments
 (0)