Skip to content

Commit 9148414

Browse files
Add trio.testing.wait_all_threads_completed
This is the equivalent of trio.testing.wait_all_tasks_blocked but for threads managed by trio. This is useful when writing tests that use to_thread
1 parent 556df86 commit 9148414

File tree

5 files changed

+162
-35
lines changed

5 files changed

+162
-35
lines changed

docs/source/reference-testing.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ Inter-task ordering
7272

7373
.. autofunction:: wait_all_tasks_blocked
7474

75+
.. autofunction:: wait_all_threads_completed
76+
77+
.. autofunction:: active_thread_count
78+
7579

7680
.. _testing-streams:
7781

newsfragments/2937.feature.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add `trio.testing.wait_all_threads_completed`, which blocks until no threads are running tasks. This is intended to be used in the same way as `trio.testing.wait_all_tasks_blocked`.

src/trio/_tests/test_threads.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@
3838
from .._core._tests.test_ki import ki_self
3939
from .._core._tests.tutil import slow
4040
from .._threads import (
41+
active_thread_count,
4142
current_default_thread_limiter,
4243
from_thread_check_cancelled,
4344
from_thread_run,
4445
from_thread_run_sync,
4546
to_thread_run_sync,
47+
wait_all_threads_completed,
4648
)
4749
from ..testing import wait_all_tasks_blocked
4850

@@ -1106,3 +1108,50 @@ async def test_cancellable_warns() -> None:
11061108

11071109
with pytest.warns(TrioDeprecationWarning):
11081110
await to_thread_run_sync(bool, cancellable=True)
1111+
1112+
1113+
async def test_wait_all_threads_completed() -> None:
1114+
no_threads_left = False
1115+
e1 = Event()
1116+
e2 = Event()
1117+
1118+
e1_exited = Event()
1119+
e2_exited = Event()
1120+
1121+
async def wait_event(e: Event, e_exit: Event) -> None:
1122+
def thread() -> None:
1123+
from_thread_run(e.wait)
1124+
1125+
await to_thread_run_sync(thread)
1126+
e_exit.set()
1127+
1128+
async def wait_no_threads_left() -> None:
1129+
nonlocal no_threads_left
1130+
await wait_all_threads_completed()
1131+
no_threads_left = True
1132+
1133+
async with _core.open_nursery() as nursery:
1134+
nursery.start_soon(wait_event, e1, e1_exited)
1135+
nursery.start_soon(wait_event, e2, e2_exited)
1136+
await wait_all_tasks_blocked()
1137+
nursery.start_soon(wait_no_threads_left)
1138+
await wait_all_tasks_blocked()
1139+
assert not no_threads_left
1140+
assert active_thread_count() == 2
1141+
1142+
e1.set()
1143+
await e1_exited.wait()
1144+
await wait_all_tasks_blocked()
1145+
assert not no_threads_left
1146+
assert active_thread_count() == 1
1147+
1148+
e2.set()
1149+
await e2_exited.wait()
1150+
await wait_all_tasks_blocked()
1151+
assert no_threads_left
1152+
assert active_thread_count() == 0
1153+
1154+
1155+
async def test_wait_all_threads_completed_no_threads() -> None:
1156+
await wait_all_threads_completed()
1157+
assert active_thread_count() == 0

src/trio/_threads.py

Lines changed: 104 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,25 @@
1010

1111
import attr
1212
import outcome
13+
from attrs import define
1314
from sniffio import current_async_library_cvar
1415

1516
import trio
1617

1718
from ._core import (
1819
RunVar,
1920
TrioToken,
21+
checkpoint,
2022
disable_ki_protection,
2123
enable_ki_protection,
2224
start_thread_soon,
2325
)
2426
from ._deprecate import warn_deprecated
25-
from ._sync import CapacityLimiter
27+
from ._sync import CapacityLimiter, Event
2628
from ._util import coroutine_or_error
2729

2830
if TYPE_CHECKING:
29-
from collections.abc import Awaitable, Callable
31+
from collections.abc import Awaitable, Callable, Generator
3032

3133
from trio._core._traps import RaiseCancelT
3234

@@ -52,6 +54,72 @@ class _ParentTaskData(threading.local):
5254
_thread_counter = count()
5355

5456

57+
@define
58+
class _ActiveThreadCount:
59+
count: int
60+
event: Event
61+
62+
63+
_active_threads_local: RunVar[_ActiveThreadCount] = RunVar("active_threads")
64+
65+
66+
@contextlib.contextmanager
67+
def _track_active_thread() -> Generator[None, None, None]:
68+
try:
69+
active_threads_local = _active_threads_local.get()
70+
except LookupError:
71+
active_threads_local = _ActiveThreadCount(0, Event())
72+
_active_threads_local.set(active_threads_local)
73+
74+
active_threads_local.count += 1
75+
try:
76+
yield
77+
finally:
78+
active_threads_local.count -= 1
79+
if active_threads_local.count == 0:
80+
active_threads_local.event.set()
81+
active_threads_local.event = Event()
82+
83+
84+
async def wait_all_threads_completed() -> None:
85+
"""Wait until no threads are still running tasks.
86+
87+
This is intended to be used when testing code with trio.to_thread to
88+
make sure no tasks are still making progress in a thread. See the
89+
following code for a usage example::
90+
91+
async def wait_all_settled():
92+
while True:
93+
await trio.testing.wait_all_threads_complete()
94+
await trio.testing.wait_all_tasks_blocked()
95+
if trio.testing.active_thread_count() == 0:
96+
break
97+
"""
98+
99+
await checkpoint()
100+
101+
try:
102+
active_threads_local = _active_threads_local.get()
103+
except LookupError:
104+
# If there would have been active threads, the
105+
# _active_threads_local would have been set
106+
return
107+
108+
while active_threads_local.count != 0:
109+
await active_threads_local.event.wait()
110+
111+
112+
def active_thread_count() -> int:
113+
"""Returns the number of threads that are currently running a task
114+
115+
See `trio.testing.wait_all_threads_completed`
116+
"""
117+
try:
118+
return _active_threads_local.get().count
119+
except LookupError:
120+
return 0
121+
122+
55123
def current_default_thread_limiter() -> CapacityLimiter:
56124
"""Get the default `~trio.CapacityLimiter` used by
57125
`trio.to_thread.run_sync`.
@@ -375,39 +443,40 @@ def deliver_worker_fn_result(result: outcome.Outcome[RetT]) -> None:
375443
current_trio_token.run_sync_soon(report_back_in_trio_thread_fn, result)
376444

377445
await limiter.acquire_on_behalf_of(placeholder)
378-
try:
379-
start_thread_soon(worker_fn, deliver_worker_fn_result, thread_name)
380-
except:
381-
limiter.release_on_behalf_of(placeholder)
382-
raise
383-
384-
def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort:
385-
# fill so from_thread_check_cancelled can raise
386-
cancel_register[0] = raise_cancel
387-
if abandon_bool:
388-
# empty so report_back_in_trio_thread_fn cannot reschedule
389-
task_register[0] = None
390-
return trio.lowlevel.Abort.SUCCEEDED
391-
else:
392-
return trio.lowlevel.Abort.FAILED
393-
394-
while True:
395-
# wait_task_rescheduled return value cannot be typed
396-
msg_from_thread: outcome.Outcome[RetT] | Run[object] | RunSync[
397-
object
398-
] = await trio.lowlevel.wait_task_rescheduled(abort)
399-
if isinstance(msg_from_thread, outcome.Outcome):
400-
return msg_from_thread.unwrap()
401-
elif isinstance(msg_from_thread, Run):
402-
await msg_from_thread.run()
403-
elif isinstance(msg_from_thread, RunSync):
404-
msg_from_thread.run_sync()
405-
else: # pragma: no cover, internal debugging guard TODO: use assert_never
406-
raise TypeError(
407-
"trio.to_thread.run_sync received unrecognized thread message {!r}."
408-
"".format(msg_from_thread)
409-
)
410-
del msg_from_thread
446+
with _track_active_thread():
447+
try:
448+
start_thread_soon(worker_fn, deliver_worker_fn_result, thread_name)
449+
except:
450+
limiter.release_on_behalf_of(placeholder)
451+
raise
452+
453+
def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort:
454+
# fill so from_thread_check_cancelled can raise
455+
cancel_register[0] = raise_cancel
456+
if abandon_bool:
457+
# empty so report_back_in_trio_thread_fn cannot reschedule
458+
task_register[0] = None
459+
return trio.lowlevel.Abort.SUCCEEDED
460+
else:
461+
return trio.lowlevel.Abort.FAILED
462+
463+
while True:
464+
# wait_task_rescheduled return value cannot be typed
465+
msg_from_thread: outcome.Outcome[RetT] | Run[object] | RunSync[
466+
object
467+
] = await trio.lowlevel.wait_task_rescheduled(abort)
468+
if isinstance(msg_from_thread, outcome.Outcome):
469+
return msg_from_thread.unwrap()
470+
elif isinstance(msg_from_thread, Run):
471+
await msg_from_thread.run()
472+
elif isinstance(msg_from_thread, RunSync):
473+
msg_from_thread.run_sync()
474+
else: # pragma: no cover, internal debugging guard TODO: use assert_never
475+
raise TypeError(
476+
"trio.to_thread.run_sync received unrecognized thread message {!r}."
477+
"".format(msg_from_thread)
478+
)
479+
del msg_from_thread
411480

412481

413482
def from_thread_check_cancelled() -> None:

src/trio/testing/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
MockClock as MockClock,
55
wait_all_tasks_blocked as wait_all_tasks_blocked,
66
)
7+
from .._threads import (
8+
active_thread_count as active_thread_count,
9+
wait_all_threads_completed as wait_all_threads_completed,
10+
)
711
from .._util import fixup_module_metadata
812
from ._check_streams import (
913
check_half_closeable_stream as check_half_closeable_stream,

0 commit comments

Comments
 (0)