Skip to content

Commit d248314

Browse files
committed
fix: update staggered race implementation to use latest cpython version
related aiohttp issue aio-libs/aiohttp#10506 In #101 we replaced the staggered race implementation since the cpython version had races that were not fixed at the time. cpython has since updated the implementation to fix additional races. Our current implementation still has problems with cancellation and cpython has fixed that in python/cpython#128475 and python/cpython#124847 This PR ports the latest cpython implementation
1 parent 035d976 commit d248314

File tree

3 files changed

+170
-124
lines changed

3 files changed

+170
-124
lines changed

src/aiohappyeyeballs/_staggered.py

+100-118
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
import asyncio
22
import contextlib
33
from typing import (
4-
TYPE_CHECKING,
54
Any,
65
Awaitable,
76
Callable,
87
Iterable,
98
List,
109
Optional,
11-
Set,
1210
Tuple,
1311
TypeVar,
1412
Union,
@@ -17,33 +15,6 @@
1715
_T = TypeVar("_T")
1816

1917

20-
def _set_result(wait_next: "asyncio.Future[None]") -> None:
21-
"""Set the result of a future if it is not already done."""
22-
if not wait_next.done():
23-
wait_next.set_result(None)
24-
25-
26-
async def _wait_one(
27-
futures: "Iterable[asyncio.Future[Any]]",
28-
loop: asyncio.AbstractEventLoop,
29-
) -> _T:
30-
"""Wait for the first future to complete."""
31-
wait_next = loop.create_future()
32-
33-
def _on_completion(fut: "asyncio.Future[Any]") -> None:
34-
if not wait_next.done():
35-
wait_next.set_result(fut)
36-
37-
for f in futures:
38-
f.add_done_callback(_on_completion)
39-
40-
try:
41-
return await wait_next
42-
finally:
43-
for f in futures:
44-
f.remove_done_callback(_on_completion)
45-
46-
4718
async def staggered_race(
4819
coro_fns: Iterable[Callable[[], Awaitable[_T]]],
4920
delay: Optional[float],
@@ -75,18 +46,16 @@ async def staggered_race(
7546
raise
7647
7748
Args:
78-
----
7949
coro_fns: an iterable of coroutine functions, i.e. callables that
8050
return a coroutine object when called. Use ``functools.partial`` or
8151
lambdas to pass arguments.
8252
8353
delay: amount of time, in seconds, between starting coroutines. If
8454
``None``, the coroutines will run sequentially.
8555
86-
loop: the event loop to use. If ``None``, the running loop is used.
56+
loop: the event loop to use.
8757
8858
Returns:
89-
-------
9059
tuple *(winner_result, winner_index, exceptions)* where
9160
9261
- *winner_result*: the result of the winning coroutine, or ``None``
@@ -103,100 +72,113 @@ async def staggered_race(
10372
coroutine's entry is ``None``.
10473
10574
"""
75+
# TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
10676
loop = loop or asyncio.get_running_loop()
107-
exceptions: List[Optional[BaseException]] = []
108-
tasks: Set[asyncio.Task[Optional[Tuple[_T, int]]]] = set()
77+
enum_coro_fns = enumerate(coro_fns)
78+
winner_result: Optional[Any] = None
79+
winner_index: Union[int, None] = None
80+
unhandled_exceptions: list[BaseException] = []
81+
exceptions: list[Union[BaseException, None]] = []
82+
running_tasks: set[asyncio.Task[Any]] = set()
83+
on_completed_fut: Union[asyncio.Future[None], None] = None
84+
85+
def task_done(task: asyncio.Task[Any]) -> None:
86+
running_tasks.discard(task)
87+
if (
88+
on_completed_fut is not None
89+
and not on_completed_fut.done()
90+
and not running_tasks
91+
):
92+
on_completed_fut.set_result(None)
93+
94+
if task.cancelled():
95+
return
96+
97+
exc = task.exception()
98+
if exc is None:
99+
return
100+
unhandled_exceptions.append(exc) # noqa: F821 - defined in the outer scope
109101

110102
async def run_one_coro(
111-
coro_fn: Callable[[], Awaitable[_T]],
112-
this_index: int,
113-
start_next: "asyncio.Future[None]",
114-
) -> Optional[Tuple[_T, int]]:
115-
"""
116-
Run a single coroutine.
117-
118-
If the coroutine fails, set the exception in the exceptions list and
119-
start the next coroutine by setting the result of the start_next.
120-
121-
If the coroutine succeeds, return the result and the index of the
122-
coroutine in the coro_fns list.
103+
ok_to_start: asyncio.Event,
104+
previous_failed: Union[None, asyncio.Event],
105+
) -> None:
106+
# in eager tasks this waits for the calling task to append this task
107+
# to running_tasks, in regular tasks this wait is a no-op that does
108+
# not yield a future. See gh-124309.
109+
await ok_to_start.wait()
110+
# Wait for the previous task to finish, or for delay seconds
111+
if previous_failed is not None:
112+
with contextlib.suppress(asyncio.TimeoutError):
113+
# Use asyncio.wait_for() instead of asyncio.wait() here, so
114+
# that if we get cancelled at this point, Event.wait() is also
115+
# cancelled, otherwise there will be a "Task destroyed but it is
116+
# pending" later.
117+
await asyncio.wait_for(previous_failed.wait(), delay)
118+
# Get the next coroutine to run
119+
try:
120+
this_index, coro_fn = next(enum_coro_fns)
121+
except StopIteration:
122+
return
123+
# Start task that will run the next coroutine
124+
this_failed = asyncio.Event()
125+
next_ok_to_start = asyncio.Event()
126+
next_task = loop.create_task(run_one_coro(next_ok_to_start, this_failed))
127+
running_tasks.add(next_task)
128+
next_task.add_done_callback(task_done)
129+
# next_task has been appended to running_tasks so next_task is ok to
130+
# start.
131+
next_ok_to_start.set()
132+
# Prepare place to put this coroutine's exceptions if not won
133+
exceptions.append(None) # noqa: F821 - defined in the outer scope
134+
assert len(exceptions) == this_index + 1 # noqa: F821, S101 - defined in the outer scope
123135

124-
If SystemExit or KeyboardInterrupt is raised, re-raise it.
125-
"""
126136
try:
127137
result = await coro_fn()
128138
except (SystemExit, KeyboardInterrupt):
129139
raise
130140
except BaseException as e:
131-
exceptions[this_index] = e
132-
_set_result(start_next) # Kickstart the next coroutine
133-
return None
134-
135-
return result, this_index
136-
137-
start_next_timer: Optional[asyncio.TimerHandle] = None
138-
start_next: Optional[asyncio.Future[None]]
139-
task: asyncio.Task[Optional[Tuple[_T, int]]]
140-
done: Union[asyncio.Future[None], asyncio.Task[Optional[Tuple[_T, int]]]]
141-
coro_iter = iter(coro_fns)
142-
this_index = -1
141+
exceptions[this_index] = e # noqa: F821 - defined in the outer scope
142+
this_failed.set() # Kickstart the next coroutine
143+
else:
144+
# Store winner's results
145+
nonlocal winner_index, winner_result
146+
assert winner_index is None # noqa: S101
147+
winner_index = this_index
148+
winner_result = result
149+
# Cancel all other tasks. We take care to not cancel the current
150+
# task as well. If we do so, then since there is no `await` after
151+
# here and CancelledError are usually thrown at one, we will
152+
# encounter a curious corner case where the current task will end
153+
# up as done() == True, cancelled() == False, exception() ==
154+
# asyncio.CancelledError. This behavior is specified in
155+
# https://bugs.python.org/issue30048
156+
current_task = asyncio.current_task(loop)
157+
for t in running_tasks:
158+
if t is not current_task:
159+
t.cancel()
160+
161+
propagate_cancellation_error = None
143162
try:
144-
while True:
145-
if coro_fn := next(coro_iter, None):
146-
this_index += 1
147-
exceptions.append(None)
148-
start_next = loop.create_future()
149-
task = loop.create_task(run_one_coro(coro_fn, this_index, start_next))
150-
tasks.add(task)
151-
start_next_timer = (
152-
loop.call_later(delay, _set_result, start_next) if delay else None
153-
)
154-
elif not tasks:
155-
# We exhausted the coro_fns list and no tasks are running
156-
# so we have no winner and all coroutines failed.
157-
break
158-
159-
while tasks or start_next:
160-
done = await _wait_one(
161-
(*tasks, start_next) if start_next else tasks, loop
162-
)
163-
if done is start_next:
164-
# The current task has failed or the timer has expired
165-
# so we need to start the next task.
166-
start_next = None
167-
if start_next_timer:
168-
start_next_timer.cancel()
169-
start_next_timer = None
170-
171-
# Break out of the task waiting loop to start the next
172-
# task.
173-
break
174-
175-
if TYPE_CHECKING:
176-
assert isinstance(done, asyncio.Task)
177-
178-
tasks.remove(done)
179-
if winner := done.result():
180-
return *winner, exceptions
163+
ok_to_start = asyncio.Event()
164+
first_task = loop.create_task(run_one_coro(ok_to_start, None))
165+
running_tasks.add(first_task)
166+
first_task.add_done_callback(task_done)
167+
# first_task has been appended to running_tasks so first_task is ok to start.
168+
ok_to_start.set()
169+
propagate_cancellation_error = None
170+
# Make sure no tasks are left running if we leave this function
171+
while running_tasks:
172+
on_completed_fut = loop.create_future()
173+
try:
174+
await on_completed_fut
175+
except asyncio.CancelledError as ex:
176+
propagate_cancellation_error = ex
177+
for task in running_tasks:
178+
task.cancel(*ex.args)
179+
on_completed_fut = None
180+
if propagate_cancellation_error is not None:
181+
raise propagate_cancellation_error
182+
return winner_result, winner_index, exceptions
181183
finally:
182-
# We either have:
183-
# - a winner
184-
# - all tasks failed
185-
# - a KeyboardInterrupt or SystemExit.
186-
187-
#
188-
# If the timer is still running, cancel it.
189-
#
190-
if start_next_timer:
191-
start_next_timer.cancel()
192-
193-
#
194-
# If there are any tasks left, cancel them and than
195-
# wait them so they fill the exceptions list.
196-
#
197-
for task in tasks:
198-
task.cancel()
199-
with contextlib.suppress(asyncio.CancelledError):
200-
await task
201-
202-
return None, None, exceptions
184+
del exceptions, propagate_cancellation_error, unhandled_exceptions

tests/test_staggered.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,15 @@ async def coro(idx):
4747
await asyncio.sleep(0.1)
4848
loop.call_soon(finish.set_result, None)
4949
winner, index, excs = await task
50-
assert len(winners) == 4
51-
assert winners == [0, 1, 2, 3]
50+
assert len(winners) == 1
51+
assert winners == [0]
5252
assert winner == 0
5353
assert index == 0
54-
assert excs == [None, None, None, None]
54+
assert len(excs) == 4
55+
assert excs[0] is None
56+
assert isinstance(excs[1], asyncio.CancelledError)
57+
assert isinstance(excs[2], asyncio.CancelledError)
58+
assert isinstance(excs[3], asyncio.CancelledError)
5559

5660

5761
@pytest.mark.skipif(sys.version_info < (3, 12), reason="requires python3.12 or higher")
@@ -77,10 +81,14 @@ async def coro(idx):
7781
await asyncio.sleep(0.1)
7882
loop.call_soon(finish.set_result, None)
7983
winner, index, excs = await task
80-
assert len(winners) == 4
81-
assert winners == [0, 1, 2, 3]
84+
assert len(winners) == 1
85+
assert winners == [0]
8286
assert winner == 0
8387
assert index == 0
84-
assert excs == [None, None, None, None]
88+
assert len(excs) == 4
89+
assert excs[0] is None
90+
assert isinstance(excs[1], asyncio.CancelledError)
91+
assert isinstance(excs[2], asyncio.CancelledError)
92+
assert isinstance(excs[3], asyncio.CancelledError)
8593

8694
loop.run_until_complete(run())

tests/test_staggered_cpython.py

+56
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,62 @@ async def main():
141141
loop.run_until_complete(main())
142142
loop.close()
143143

144+
async def test_multiple_winners(self):
145+
event = asyncio.Event()
146+
147+
async def coro(index):
148+
await event.wait()
149+
return index
150+
151+
async def do_set():
152+
event.set()
153+
await asyncio.Event().wait()
154+
155+
winner, index, excs = await staggered_race(
156+
[
157+
lambda: coro(0),
158+
lambda: coro(1),
159+
do_set,
160+
],
161+
delay=0.1,
162+
)
163+
self.assertIs(winner, 0)
164+
self.assertIs(index, 0)
165+
self.assertEqual(len(excs), 3)
166+
self.assertIsNone(excs[0], None)
167+
self.assertIsInstance(excs[1], asyncio.CancelledError)
168+
self.assertIsInstance(excs[2], asyncio.CancelledError)
169+
170+
async def test_cancelled(self):
171+
log = []
172+
with self.assertRaises(TimeoutError):
173+
async with asyncio.timeout(None) as cs_outer, asyncio.timeout(
174+
None
175+
) as cs_inner:
176+
177+
async def coro_fn():
178+
cs_inner.reschedule(-1)
179+
await asyncio.sleep(0)
180+
try:
181+
await asyncio.sleep(0)
182+
except asyncio.CancelledError:
183+
log.append("cancelled 1")
184+
185+
cs_outer.reschedule(-1)
186+
await asyncio.sleep(0)
187+
try:
188+
await asyncio.sleep(0)
189+
except asyncio.CancelledError:
190+
log.append("cancelled 2")
191+
192+
try:
193+
await staggered_race([coro_fn], delay=None)
194+
except asyncio.CancelledError:
195+
log.append("cancelled 3")
196+
raise
197+
198+
self.assertListEqual(log, ["cancelled 1", "cancelled 2", "cancelled 3"])
199+
144200

145201
if __name__ == "__main__":
146202
unittest.main()

0 commit comments

Comments
 (0)