1
1
import asyncio
2
2
import contextlib
3
3
from typing import (
4
- TYPE_CHECKING ,
5
4
Any ,
6
5
Awaitable ,
7
6
Callable ,
8
7
Iterable ,
9
8
List ,
10
9
Optional ,
11
- Set ,
12
10
Tuple ,
13
11
TypeVar ,
14
12
Union ,
17
15
_T = TypeVar ("_T" )
18
16
19
17
20
- def _set_result (wait_next : "asyncio.Future[None]" ) -> None :
21
- """Set the result of a future if it is not already done."""
22
- if not wait_next .done ():
23
- wait_next .set_result (None )
24
-
25
-
26
- async def _wait_one (
27
- futures : "Iterable[asyncio.Future[Any]]" ,
28
- loop : asyncio .AbstractEventLoop ,
29
- ) -> _T :
30
- """Wait for the first future to complete."""
31
- wait_next = loop .create_future ()
32
-
33
- def _on_completion (fut : "asyncio.Future[Any]" ) -> None :
34
- if not wait_next .done ():
35
- wait_next .set_result (fut )
36
-
37
- for f in futures :
38
- f .add_done_callback (_on_completion )
39
-
40
- try :
41
- return await wait_next
42
- finally :
43
- for f in futures :
44
- f .remove_done_callback (_on_completion )
45
-
46
-
47
18
async def staggered_race (
48
19
coro_fns : Iterable [Callable [[], Awaitable [_T ]]],
49
20
delay : Optional [float ],
@@ -75,18 +46,16 @@ async def staggered_race(
75
46
raise
76
47
77
48
Args:
78
- ----
79
49
coro_fns: an iterable of coroutine functions, i.e. callables that
80
50
return a coroutine object when called. Use ``functools.partial`` or
81
51
lambdas to pass arguments.
82
52
83
53
delay: amount of time, in seconds, between starting coroutines. If
84
54
``None``, the coroutines will run sequentially.
85
55
86
- loop: the event loop to use. If ``None``, the running loop is used.
56
+ loop: the event loop to use.
87
57
88
58
Returns:
89
- -------
90
59
tuple *(winner_result, winner_index, exceptions)* where
91
60
92
61
- *winner_result*: the result of the winning coroutine, or ``None``
@@ -103,100 +72,113 @@ async def staggered_race(
103
72
coroutine's entry is ``None``.
104
73
105
74
"""
75
+ # TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
106
76
loop = loop or asyncio .get_running_loop ()
107
- exceptions : List [Optional [BaseException ]] = []
108
- tasks : Set [asyncio .Task [Optional [Tuple [_T , int ]]]] = set ()
77
+ enum_coro_fns = enumerate (coro_fns )
78
+ winner_result : Optional [Any ] = None
79
+ winner_index : Union [int , None ] = None
80
+ unhandled_exceptions : list [BaseException ] = []
81
+ exceptions : list [Union [BaseException , None ]] = []
82
+ running_tasks : set [asyncio .Task [Any ]] = set ()
83
+ on_completed_fut : Union [asyncio .Future [None ], None ] = None
84
+
85
+ def task_done (task : asyncio .Task [Any ]) -> None :
86
+ running_tasks .discard (task )
87
+ if (
88
+ on_completed_fut is not None
89
+ and not on_completed_fut .done ()
90
+ and not running_tasks
91
+ ):
92
+ on_completed_fut .set_result (None )
93
+
94
+ if task .cancelled ():
95
+ return
96
+
97
+ exc = task .exception ()
98
+ if exc is None :
99
+ return
100
+ unhandled_exceptions .append (exc ) # noqa: F821 - defined in the outer scope
109
101
110
102
async def run_one_coro (
111
- coro_fn : Callable [[], Awaitable [_T ]],
112
- this_index : int ,
113
- start_next : "asyncio.Future[None]" ,
114
- ) -> Optional [Tuple [_T , int ]]:
115
- """
116
- Run a single coroutine.
117
-
118
- If the coroutine fails, set the exception in the exceptions list and
119
- start the next coroutine by setting the result of the start_next.
120
-
121
- If the coroutine succeeds, return the result and the index of the
122
- coroutine in the coro_fns list.
103
+ ok_to_start : asyncio .Event ,
104
+ previous_failed : Union [None , asyncio .Event ],
105
+ ) -> None :
106
+ # in eager tasks this waits for the calling task to append this task
107
+ # to running_tasks, in regular tasks this wait is a no-op that does
108
+ # not yield a future. See gh-124309.
109
+ await ok_to_start .wait ()
110
+ # Wait for the previous task to finish, or for delay seconds
111
+ if previous_failed is not None :
112
+ with contextlib .suppress (asyncio .TimeoutError ):
113
+ # Use asyncio.wait_for() instead of asyncio.wait() here, so
114
+ # that if we get cancelled at this point, Event.wait() is also
115
+ # cancelled, otherwise there will be a "Task destroyed but it is
116
+ # pending" later.
117
+ await asyncio .wait_for (previous_failed .wait (), delay )
118
+ # Get the next coroutine to run
119
+ try :
120
+ this_index , coro_fn = next (enum_coro_fns )
121
+ except StopIteration :
122
+ return
123
+ # Start task that will run the next coroutine
124
+ this_failed = asyncio .Event ()
125
+ next_ok_to_start = asyncio .Event ()
126
+ next_task = loop .create_task (run_one_coro (next_ok_to_start , this_failed ))
127
+ running_tasks .add (next_task )
128
+ next_task .add_done_callback (task_done )
129
+ # next_task has been appended to running_tasks so next_task is ok to
130
+ # start.
131
+ next_ok_to_start .set ()
132
+ # Prepare place to put this coroutine's exceptions if not won
133
+ exceptions .append (None ) # noqa: F821 - defined in the outer scope
134
+ assert len (exceptions ) == this_index + 1 # noqa: F821, S101 - defined in the outer scope
123
135
124
- If SystemExit or KeyboardInterrupt is raised, re-raise it.
125
- """
126
136
try :
127
137
result = await coro_fn ()
128
138
except (SystemExit , KeyboardInterrupt ):
129
139
raise
130
140
except BaseException as e :
131
- exceptions [this_index ] = e
132
- _set_result (start_next ) # Kickstart the next coroutine
133
- return None
134
-
135
- return result , this_index
136
-
137
- start_next_timer : Optional [asyncio .TimerHandle ] = None
138
- start_next : Optional [asyncio .Future [None ]]
139
- task : asyncio .Task [Optional [Tuple [_T , int ]]]
140
- done : Union [asyncio .Future [None ], asyncio .Task [Optional [Tuple [_T , int ]]]]
141
- coro_iter = iter (coro_fns )
142
- this_index = - 1
141
+ exceptions [this_index ] = e # noqa: F821 - defined in the outer scope
142
+ this_failed .set () # Kickstart the next coroutine
143
+ else :
144
+ # Store winner's results
145
+ nonlocal winner_index , winner_result
146
+ assert winner_index is None # noqa: S101
147
+ winner_index = this_index
148
+ winner_result = result
149
+ # Cancel all other tasks. We take care to not cancel the current
150
+ # task as well. If we do so, then since there is no `await` after
151
+ # here and CancelledError are usually thrown at one, we will
152
+ # encounter a curious corner case where the current task will end
153
+ # up as done() == True, cancelled() == False, exception() ==
154
+ # asyncio.CancelledError. This behavior is specified in
155
+ # https://bugs.python.org/issue30048
156
+ current_task = asyncio .current_task (loop )
157
+ for t in running_tasks :
158
+ if t is not current_task :
159
+ t .cancel ()
160
+
161
+ propagate_cancellation_error = None
143
162
try :
144
- while True :
145
- if coro_fn := next (coro_iter , None ):
146
- this_index += 1
147
- exceptions .append (None )
148
- start_next = loop .create_future ()
149
- task = loop .create_task (run_one_coro (coro_fn , this_index , start_next ))
150
- tasks .add (task )
151
- start_next_timer = (
152
- loop .call_later (delay , _set_result , start_next ) if delay else None
153
- )
154
- elif not tasks :
155
- # We exhausted the coro_fns list and no tasks are running
156
- # so we have no winner and all coroutines failed.
157
- break
158
-
159
- while tasks or start_next :
160
- done = await _wait_one (
161
- (* tasks , start_next ) if start_next else tasks , loop
162
- )
163
- if done is start_next :
164
- # The current task has failed or the timer has expired
165
- # so we need to start the next task.
166
- start_next = None
167
- if start_next_timer :
168
- start_next_timer .cancel ()
169
- start_next_timer = None
170
-
171
- # Break out of the task waiting loop to start the next
172
- # task.
173
- break
174
-
175
- if TYPE_CHECKING :
176
- assert isinstance (done , asyncio .Task )
177
-
178
- tasks .remove (done )
179
- if winner := done .result ():
180
- return * winner , exceptions
163
+ ok_to_start = asyncio .Event ()
164
+ first_task = loop .create_task (run_one_coro (ok_to_start , None ))
165
+ running_tasks .add (first_task )
166
+ first_task .add_done_callback (task_done )
167
+ # first_task has been appended to running_tasks so first_task is ok to start.
168
+ ok_to_start .set ()
169
+ propagate_cancellation_error = None
170
+ # Make sure no tasks are left running if we leave this function
171
+ while running_tasks :
172
+ on_completed_fut = loop .create_future ()
173
+ try :
174
+ await on_completed_fut
175
+ except asyncio .CancelledError as ex :
176
+ propagate_cancellation_error = ex
177
+ for task in running_tasks :
178
+ task .cancel (* ex .args )
179
+ on_completed_fut = None
180
+ if propagate_cancellation_error is not None :
181
+ raise propagate_cancellation_error
182
+ return winner_result , winner_index , exceptions
181
183
finally :
182
- # We either have:
183
- # - a winner
184
- # - all tasks failed
185
- # - a KeyboardInterrupt or SystemExit.
186
-
187
- #
188
- # If the timer is still running, cancel it.
189
- #
190
- if start_next_timer :
191
- start_next_timer .cancel ()
192
-
193
- #
194
- # If there are any tasks left, cancel them and than
195
- # wait them so they fill the exceptions list.
196
- #
197
- for task in tasks :
198
- task .cancel ()
199
- with contextlib .suppress (asyncio .CancelledError ):
200
- await task
201
-
202
- return None , None , exceptions
184
+ del exceptions , propagate_cancellation_error , unhandled_exceptions
0 commit comments