Skip to content

Commit 8e1bc6a

Browse files
authored
feat: add callback for users to customize socket creation (#147)
Co-authored-by: Kieren <kjander0>
1 parent c4ab1e5 commit 8e1bc6a

File tree

4 files changed

+102
-5
lines changed

4 files changed

+102
-5
lines changed

src/aiohappyeyeballs/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
__version__ = "2.4.8"
22

33
from .impl import start_connection
4-
from .types import AddrInfoType
4+
from .types import AddrInfoType, SocketFactoryType
55
from .utils import addr_to_addr_infos, pop_addr_infos_interleave, remove_addr_infos
66

77
__all__ = (
88
"AddrInfoType",
9+
"SocketFactoryType",
910
"addr_to_addr_infos",
1011
"pop_addr_infos_interleave",
1112
"remove_addr_infos",

src/aiohappyeyeballs/impl.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import List, Optional, Sequence, Set, Union
1010

1111
from . import _staggered
12-
from .types import AddrInfoType
12+
from .types import AddrInfoType, SocketFactoryType
1313

1414

1515
async def start_connection(
@@ -19,6 +19,7 @@ async def start_connection(
1919
happy_eyeballs_delay: Optional[float] = None,
2020
interleave: Optional[int] = None,
2121
loop: Optional[asyncio.AbstractEventLoop] = None,
22+
socket_factory: Optional[SocketFactoryType] = None,
2223
) -> socket.socket:
2324
"""
2425
Connect to a TCP server.
@@ -70,7 +71,12 @@ async def start_connection(
7071
for addrinfo in addr_infos:
7172
try:
7273
sock = await _connect_sock(
73-
current_loop, exceptions, addrinfo, local_addr_infos
74+
current_loop,
75+
exceptions,
76+
addrinfo,
77+
local_addr_infos,
78+
None,
79+
socket_factory,
7480
)
7581
break
7682
except (RuntimeError, OSError):
@@ -87,6 +93,7 @@ async def start_connection(
8793
addrinfo,
8894
local_addr_infos,
8995
open_sockets,
96+
socket_factory,
9097
)
9198
for addrinfo in addr_infos
9299
),
@@ -153,6 +160,7 @@ async def _connect_sock(
153160
addr_info: AddrInfoType,
154161
local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
155162
open_sockets: Optional[Set[socket.socket]] = None,
163+
socket_factory: Optional[SocketFactoryType] = None,
156164
) -> socket.socket:
157165
"""
158166
Create, bind and connect one socket.
@@ -169,7 +177,10 @@ async def _connect_sock(
169177
family, type_, proto, _, address = addr_info
170178
sock = None
171179
try:
172-
sock = socket.socket(family=family, type=type_, proto=proto)
180+
if socket_factory is not None:
181+
sock = socket_factory(addr_info)
182+
else:
183+
sock = socket.socket(family=family, type=type_, proto=proto)
173184
if open_sockets is not None:
174185
open_sockets.add(sock)
175186
sock.setblocking(False)

src/aiohappyeyeballs/types.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Types for aiohappyeyeballs."""
22

33
import socket
4+
from collections.abc import Callable
45
from typing import Tuple, Union
56

67
AddrInfoType = Tuple[
@@ -10,3 +11,5 @@
1011
str,
1112
Tuple, # type: ignore[type-arg]
1213
]
14+
15+
SocketFactoryType = Callable[[AddrInfoType], socket.socket]

tests/test_impl.py

+83-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@
66

77
import pytest
88

9-
from aiohappyeyeballs import AddrInfoType, _staggered, impl, start_connection
9+
from aiohappyeyeballs import (
10+
AddrInfoType,
11+
SocketFactoryType,
12+
_staggered,
13+
impl,
14+
start_connection,
15+
)
1016

1117

1218
def mock_socket_module():
@@ -136,6 +142,33 @@ def _socket(*args, **kw):
136142
)
137143

138144

145+
@pytest.mark.asyncio
146+
@patch_socket
147+
async def test_single_addr_socket_factory(m_socket: ModuleType) -> None:
148+
mock_socket = mock.MagicMock(
149+
family=socket.AF_INET,
150+
type=socket.SOCK_STREAM,
151+
proto=socket.IPPROTO_TCP,
152+
fileno=mock.MagicMock(return_value=1),
153+
)
154+
155+
def factory(addr_info: AddrInfoType) -> socket.socket:
156+
return mock_socket
157+
158+
addr_info = [
159+
(
160+
socket.AF_INET,
161+
socket.SOCK_STREAM,
162+
socket.IPPROTO_TCP,
163+
"",
164+
("107.6.106.82", 80),
165+
)
166+
]
167+
loop = asyncio.get_running_loop()
168+
with mock.patch.object(loop, "sock_connect", return_value=None):
169+
assert await start_connection(addr_info, socket_factory=factory) == mock_socket
170+
171+
139172
@pytest.mark.asyncio
140173
@patch_socket
141174
async def test_multiple_addr_success_second_one(
@@ -201,6 +234,7 @@ async def _connect_sock(
201234
addr_info: AddrInfoType,
202235
local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
203236
sockets: Optional[Set[socket.socket]] = None,
237+
socket_factory: Optional[SocketFactoryType] = None,
204238
) -> socket.socket:
205239
await finish
206240
sock = _socket()
@@ -293,6 +327,54 @@ def _socket(*args, **kw):
293327
)
294328

295329

330+
@pytest.mark.asyncio
331+
@patch_socket
332+
async def test_happy_eyeballs_socket_factory(
333+
m_socket: ModuleType,
334+
) -> None:
335+
mock_socket = mock.MagicMock(
336+
family=socket.AF_INET,
337+
type=socket.SOCK_STREAM,
338+
proto=socket.IPPROTO_TCP,
339+
fileno=mock.MagicMock(return_value=1),
340+
)
341+
342+
idx = -1
343+
errors = ["err1", "err2"]
344+
345+
def factory(addr_info: AddrInfoType) -> socket.socket:
346+
nonlocal idx, errors
347+
idx += 1
348+
if idx == 1:
349+
raise OSError(5, errors[idx])
350+
return mock_socket
351+
352+
addr_info = [
353+
(
354+
socket.AF_INET,
355+
socket.SOCK_STREAM,
356+
socket.IPPROTO_TCP,
357+
"",
358+
("107.6.106.82", 80),
359+
),
360+
(
361+
socket.AF_INET,
362+
socket.SOCK_STREAM,
363+
socket.IPPROTO_TCP,
364+
"",
365+
("107.6.106.83", 80),
366+
),
367+
]
368+
loop = asyncio.get_running_loop()
369+
with mock.patch.object(loop, "sock_connect", return_value=None):
370+
assert (
371+
await start_connection(
372+
addr_info, happy_eyeballs_delay=0.3, socket_factory=factory
373+
)
374+
== mock_socket
375+
)
376+
377+
296378
@pytest.mark.asyncio
297379
@patch_socket
298380
async def test_multiple_addr_all_fail_happy_eyeballs(

0 commit comments

Comments
 (0)