10
10
11
11
import attrs
12
12
import outcome
13
+ from attrs import define
13
14
from sniffio import current_async_library_cvar
14
15
15
16
import trio
16
17
17
18
from ._core import (
18
19
RunVar ,
19
20
TrioToken ,
21
+ checkpoint ,
20
22
disable_ki_protection ,
21
23
enable_ki_protection ,
22
24
start_thread_soon ,
23
25
)
24
26
from ._deprecate import warn_deprecated
25
- from ._sync import CapacityLimiter
27
+ from ._sync import CapacityLimiter , Event
26
28
from ._util import coroutine_or_error
27
29
28
30
if TYPE_CHECKING :
29
- from collections .abc import Awaitable , Callable
31
+ from collections .abc import Awaitable , Callable , Generator
30
32
31
33
from trio ._core ._traps import RaiseCancelT
32
34
@@ -52,6 +54,72 @@ class _ParentTaskData(threading.local):
52
54
_thread_counter = count ()
53
55
54
56
57
+ @define
58
+ class _ActiveThreadCount :
59
+ count : int
60
+ event : Event
61
+
62
+
63
+ _active_threads_local : RunVar [_ActiveThreadCount ] = RunVar ("active_threads" )
64
+
65
+
66
+ @contextlib .contextmanager
67
+ def _track_active_thread () -> Generator [None , None , None ]:
68
+ try :
69
+ active_threads_local = _active_threads_local .get ()
70
+ except LookupError :
71
+ active_threads_local = _ActiveThreadCount (0 , Event ())
72
+ _active_threads_local .set (active_threads_local )
73
+
74
+ active_threads_local .count += 1
75
+ try :
76
+ yield
77
+ finally :
78
+ active_threads_local .count -= 1
79
+ if active_threads_local .count == 0 :
80
+ active_threads_local .event .set ()
81
+ active_threads_local .event = Event ()
82
+
83
+
84
+ async def wait_all_threads_completed () -> None :
85
+ """Wait until no threads are still running tasks.
86
+
87
+ This is intended to be used when testing code with trio.to_thread to
88
+ make sure no tasks are still making progress in a thread. See the
89
+ following code for a usage example::
90
+
91
+ async def wait_all_settled():
92
+ while True:
93
+ await trio.testing.wait_all_threads_complete()
94
+ await trio.testing.wait_all_tasks_blocked()
95
+ if trio.testing.active_thread_count() == 0:
96
+ break
97
+ """
98
+
99
+ await checkpoint ()
100
+
101
+ try :
102
+ active_threads_local = _active_threads_local .get ()
103
+ except LookupError :
104
+ # If there would have been active threads, the
105
+ # _active_threads_local would have been set
106
+ return
107
+
108
+ while active_threads_local .count != 0 :
109
+ await active_threads_local .event .wait ()
110
+
111
+
112
+ def active_thread_count () -> int :
113
+ """Returns the number of threads that are currently running a task
114
+
115
+ See `trio.testing.wait_all_threads_completed`
116
+ """
117
+ try :
118
+ return _active_threads_local .get ().count
119
+ except LookupError :
120
+ return 0
121
+
122
+
55
123
def current_default_thread_limiter () -> CapacityLimiter :
56
124
"""Get the default `~trio.CapacityLimiter` used by
57
125
`trio.to_thread.run_sync`.
@@ -377,39 +445,40 @@ def deliver_worker_fn_result(result: outcome.Outcome[RetT]) -> None:
377
445
current_trio_token .run_sync_soon (report_back_in_trio_thread_fn , result )
378
446
379
447
await limiter .acquire_on_behalf_of (placeholder )
380
- try :
381
- start_thread_soon (worker_fn , deliver_worker_fn_result , thread_name )
382
- except :
383
- limiter .release_on_behalf_of (placeholder )
384
- raise
385
-
386
- def abort (raise_cancel : RaiseCancelT ) -> trio .lowlevel .Abort :
387
- # fill so from_thread_check_cancelled can raise
388
- cancel_register [0 ] = raise_cancel
389
- if abandon_bool :
390
- # empty so report_back_in_trio_thread_fn cannot reschedule
391
- task_register [0 ] = None
392
- return trio .lowlevel .Abort .SUCCEEDED
393
- else :
394
- return trio .lowlevel .Abort .FAILED
395
-
396
- while True :
397
- # wait_task_rescheduled return value cannot be typed
398
- msg_from_thread : outcome .Outcome [RetT ] | Run [object ] | RunSync [object ] = (
399
- await trio .lowlevel .wait_task_rescheduled (abort )
400
- )
401
- if isinstance (msg_from_thread , outcome .Outcome ):
402
- return msg_from_thread .unwrap ()
403
- elif isinstance (msg_from_thread , Run ):
404
- await msg_from_thread .run ()
405
- elif isinstance (msg_from_thread , RunSync ):
406
- msg_from_thread .run_sync ()
407
- else : # pragma: no cover, internal debugging guard TODO: use assert_never
408
- raise TypeError (
409
- "trio.to_thread.run_sync received unrecognized thread message {!r}."
410
- "" .format (msg_from_thread )
448
+ with _track_active_thread ():
449
+ try :
450
+ start_thread_soon (worker_fn , deliver_worker_fn_result , thread_name )
451
+ except :
452
+ limiter .release_on_behalf_of (placeholder )
453
+ raise
454
+
455
+ def abort (raise_cancel : RaiseCancelT ) -> trio .lowlevel .Abort :
456
+ # fill so from_thread_check_cancelled can raise
457
+ cancel_register [0 ] = raise_cancel
458
+ if abandon_bool :
459
+ # empty so report_back_in_trio_thread_fn cannot reschedule
460
+ task_register [0 ] = None
461
+ return trio .lowlevel .Abort .SUCCEEDED
462
+ else :
463
+ return trio .lowlevel .Abort .FAILED
464
+
465
+ while True :
466
+ # wait_task_rescheduled return value cannot be typed
467
+ msg_from_thread : outcome .Outcome [RetT ] | Run [object ] | RunSync [object ] = (
468
+ await trio .lowlevel .wait_task_rescheduled (abort )
411
469
)
412
- del msg_from_thread
470
+ if isinstance (msg_from_thread , outcome .Outcome ):
471
+ return msg_from_thread .unwrap ()
472
+ elif isinstance (msg_from_thread , Run ):
473
+ await msg_from_thread .run ()
474
+ elif isinstance (msg_from_thread , RunSync ):
475
+ msg_from_thread .run_sync ()
476
+ else : # pragma: no cover, internal debugging guard TODO: use assert_never
477
+ raise TypeError (
478
+ "trio.to_thread.run_sync received unrecognized thread message {!r}."
479
+ "" .format (msg_from_thread )
480
+ )
481
+ del msg_from_thread
413
482
414
483
415
484
def from_thread_check_cancelled () -> None :
0 commit comments