Skip to content

Commit 476a05b

Browse files
authored
fix: close runner up sockets in the event there are multiple winners (#143)
The first attempt to fix this was to use the cpython staggered race updates in #142 but there is still a race there where there can be multiple winners. Instead we now accept that we will not be able to cancel all coros in time and there will always be a risk of multiple winners. We store all sockets in a set that were not already cleaned up and we close all but the first winner after the staggered race finishes.
1 parent 9c55c91 commit 476a05b

File tree

2 files changed

+120
-13
lines changed

2 files changed

+120
-13
lines changed

src/aiohappyeyeballs/impl.py

+49-11
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
import asyncio
44
import collections
5+
import contextlib
56
import functools
67
import itertools
78
import socket
8-
from typing import List, Optional, Sequence, Union
9+
from typing import List, Optional, Sequence, Set, Union
910

1011
from . import _staggered
1112
from .types import AddrInfoType
@@ -75,15 +76,36 @@ async def start_connection(
7576
except (RuntimeError, OSError):
7677
continue
7778
else: # using happy eyeballs
78-
sock, _, _ = await _staggered.staggered_race(
79-
(
80-
functools.partial(
81-
_connect_sock, current_loop, exceptions, addrinfo, local_addr_infos
82-
)
83-
for addrinfo in addr_infos
84-
),
85-
happy_eyeballs_delay,
86-
)
79+
open_sockets: Set[socket.socket] = set()
80+
try:
81+
sock, _, _ = await _staggered.staggered_race(
82+
(
83+
functools.partial(
84+
_connect_sock,
85+
current_loop,
86+
exceptions,
87+
addrinfo,
88+
local_addr_infos,
89+
open_sockets,
90+
)
91+
for addrinfo in addr_infos
92+
),
93+
happy_eyeballs_delay,
94+
)
95+
finally:
96+
# If we have a winner, staggered_race will
97+
# cancel the other tasks, however there is a
98+
# small race window where any of the other tasks
99+
# can be done before they are cancelled which
100+
# will leave the socket open. To avoid this problem
101+
# we pass a set to _connect_sock to keep track of
102+
# the open sockets and close them here if there
103+
# are any "runner up" sockets.
104+
for s in open_sockets:
105+
if s is not sock:
106+
with contextlib.suppress(OSError):
107+
s.close()
108+
open_sockets = None # type: ignore[assignment]
87109

88110
if sock is None:
89111
all_exceptions = [exc for sub in exceptions for exc in sub]
@@ -130,14 +152,26 @@ async def _connect_sock(
130152
exceptions: List[List[Union[OSError, RuntimeError]]],
131153
addr_info: AddrInfoType,
132154
local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
155+
open_sockets: Optional[Set[socket.socket]] = None,
133156
) -> socket.socket:
134-
"""Create, bind and connect one socket."""
157+
"""
158+
Create, bind and connect one socket.
159+
160+
If open_sockets is passed, add the socket to the set of open sockets.
161+
Any failure caught here will remove the socket from the set and close it.
162+
163+
Callers can use this set to close any sockets that are not the winner
164+
of all staggered tasks in the result there are runner up sockets aka
165+
multiple winners.
166+
"""
135167
my_exceptions: List[Union[OSError, RuntimeError]] = []
136168
exceptions.append(my_exceptions)
137169
family, type_, proto, _, address = addr_info
138170
sock = None
139171
try:
140172
sock = socket.socket(family=family, type=type_, proto=proto)
173+
if open_sockets is not None:
174+
open_sockets.add(sock)
141175
sock.setblocking(False)
142176
if local_addr_infos is not None:
143177
for lfamily, _, _, _, laddr in local_addr_infos:
@@ -165,6 +199,8 @@ async def _connect_sock(
165199
except (RuntimeError, OSError) as exc:
166200
my_exceptions.append(exc)
167201
if sock is not None:
202+
if open_sockets is not None:
203+
open_sockets.remove(sock)
168204
try:
169205
sock.close()
170206
except OSError as e:
@@ -173,6 +209,8 @@ async def _connect_sock(
173209
raise
174210
except:
175211
if sock is not None:
212+
if open_sockets is not None:
213+
open_sockets.remove(sock)
176214
try:
177215
sock.close()
178216
except OSError as e:

tests/test_impl.py

+71-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import asyncio
22
import socket
33
from types import ModuleType
4-
from typing import Tuple
4+
from typing import List, Optional, Sequence, Set, Tuple, Union
55
from unittest import mock
66

77
import pytest
88

9-
from aiohappyeyeballs import _staggered, start_connection
9+
from aiohappyeyeballs import AddrInfoType, _staggered, impl, start_connection
1010

1111

1212
def mock_socket_module():
@@ -179,6 +179,75 @@ def _socket(*args, **kw):
179179
assert await start_connection(addr_info) == mock_socket
180180

181181

182+
@pytest.mark.asyncio
183+
@patch_socket
184+
async def test_multiple_winners_cleaned_up(
185+
m_socket: ModuleType,
186+
) -> None:
187+
loop = asyncio.get_running_loop()
188+
finish = loop.create_future()
189+
190+
def _socket(*args, **kw):
191+
return mock.MagicMock(
192+
family=socket.AF_INET,
193+
type=socket.SOCK_STREAM,
194+
proto=socket.IPPROTO_TCP,
195+
fileno=mock.MagicMock(return_value=1),
196+
)
197+
198+
async def _connect_sock(
199+
loop: asyncio.AbstractEventLoop,
200+
exceptions: List[List[Union[OSError, RuntimeError]]],
201+
addr_info: AddrInfoType,
202+
local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
203+
sockets: Optional[Set[socket.socket]] = None,
204+
) -> socket.socket:
205+
await finish
206+
sock = _socket()
207+
assert sockets is not None
208+
sockets.add(sock)
209+
return sock
210+
211+
m_socket.socket = _socket # type: ignore
212+
addr_info = [
213+
(
214+
socket.AF_INET,
215+
socket.SOCK_STREAM,
216+
socket.IPPROTO_TCP,
217+
"",
218+
("107.6.106.82", 80),
219+
),
220+
(
221+
socket.AF_INET,
222+
socket.SOCK_STREAM,
223+
socket.IPPROTO_TCP,
224+
"",
225+
("107.6.106.83", 80),
226+
),
227+
(
228+
socket.AF_INET,
229+
socket.SOCK_STREAM,
230+
socket.IPPROTO_TCP,
231+
"",
232+
("107.6.106.84", 80),
233+
),
234+
(
235+
socket.AF_INET,
236+
socket.SOCK_STREAM,
237+
socket.IPPROTO_TCP,
238+
"",
239+
("107.6.106.85", 80),
240+
),
241+
]
242+
with mock.patch.object(impl, "_connect_sock", _connect_sock):
243+
task = loop.create_task(
244+
start_connection(addr_info, happy_eyeballs_delay=0.0001, interleave=0)
245+
)
246+
await asyncio.sleep(0.1)
247+
loop.call_soon(finish.set_result, None)
248+
await task
249+
250+
182251
@pytest.mark.asyncio
183252
@patch_socket
184253
async def test_multiple_addr_success_second_one_happy_eyeballs(

0 commit comments

Comments
 (0)