2
2
3
3
import asyncio
4
4
import collections
5
+ import contextlib
5
6
import functools
6
7
import itertools
7
8
import socket
8
- from typing import List , Optional , Sequence , Union
9
+ from typing import List , Optional , Sequence , Set , Union
9
10
10
11
from . import _staggered
11
12
from .types import AddrInfoType
@@ -75,15 +76,36 @@ async def start_connection(
75
76
except (RuntimeError , OSError ):
76
77
continue
77
78
else : # using happy eyeballs
78
- sock , _ , _ = await _staggered .staggered_race (
79
- (
80
- functools .partial (
81
- _connect_sock , current_loop , exceptions , addrinfo , local_addr_infos
82
- )
83
- for addrinfo in addr_infos
84
- ),
85
- happy_eyeballs_delay ,
86
- )
79
+ open_sockets : Set [socket .socket ] = set ()
80
+ try :
81
+ sock , _ , _ = await _staggered .staggered_race (
82
+ (
83
+ functools .partial (
84
+ _connect_sock ,
85
+ current_loop ,
86
+ exceptions ,
87
+ addrinfo ,
88
+ local_addr_infos ,
89
+ open_sockets ,
90
+ )
91
+ for addrinfo in addr_infos
92
+ ),
93
+ happy_eyeballs_delay ,
94
+ )
95
+ finally :
96
+ # If we have a winner, staggered_race will
97
+ # cancel the other tasks, however there is a
98
+ # small race window where any of the other tasks
99
+ # can be done before they are cancelled which
100
+ # will leave the socket open. To avoid this problem
101
+ # we pass a set to _connect_sock to keep track of
102
+ # the open sockets and close them here if there
103
+ # are any "runner up" sockets.
104
+ for s in open_sockets :
105
+ if s is not sock :
106
+ with contextlib .suppress (OSError ):
107
+ s .close ()
108
+ open_sockets = None # type: ignore[assignment]
87
109
88
110
if sock is None :
89
111
all_exceptions = [exc for sub in exceptions for exc in sub ]
@@ -130,14 +152,26 @@ async def _connect_sock(
130
152
exceptions : List [List [Union [OSError , RuntimeError ]]],
131
153
addr_info : AddrInfoType ,
132
154
local_addr_infos : Optional [Sequence [AddrInfoType ]] = None ,
155
+ open_sockets : Optional [Set [socket .socket ]] = None ,
133
156
) -> socket .socket :
134
- """Create, bind and connect one socket."""
157
+ """
158
+ Create, bind and connect one socket.
159
+
160
+ If open_sockets is passed, add the socket to the set of open sockets.
161
+ Any failure caught here will remove the socket from the set and close it.
162
+
163
+ Callers can use this set to close any sockets that are not the winner
164
+ of all staggered tasks in the result there are runner up sockets aka
165
+ multiple winners.
166
+ """
135
167
my_exceptions : List [Union [OSError , RuntimeError ]] = []
136
168
exceptions .append (my_exceptions )
137
169
family , type_ , proto , _ , address = addr_info
138
170
sock = None
139
171
try :
140
172
sock = socket .socket (family = family , type = type_ , proto = proto )
173
+ if open_sockets is not None :
174
+ open_sockets .add (sock )
141
175
sock .setblocking (False )
142
176
if local_addr_infos is not None :
143
177
for lfamily , _ , _ , _ , laddr in local_addr_infos :
@@ -165,6 +199,8 @@ async def _connect_sock(
165
199
except (RuntimeError , OSError ) as exc :
166
200
my_exceptions .append (exc )
167
201
if sock is not None :
202
+ if open_sockets is not None :
203
+ open_sockets .remove (sock )
168
204
try :
169
205
sock .close ()
170
206
except OSError as e :
@@ -173,6 +209,8 @@ async def _connect_sock(
173
209
raise
174
210
except :
175
211
if sock is not None :
212
+ if open_sockets is not None :
213
+ open_sockets .remove (sock )
176
214
try :
177
215
sock .close ()
178
216
except OSError as e :
0 commit comments