10
10
11
11
import attr
12
12
import outcome
13
+ from attrs import define
13
14
from sniffio import current_async_library_cvar
14
15
15
16
import trio
16
17
17
18
from ._core import (
18
19
RunVar ,
19
20
TrioToken ,
21
+ checkpoint ,
20
22
disable_ki_protection ,
21
23
enable_ki_protection ,
22
24
start_thread_soon ,
23
25
)
24
26
from ._deprecate import warn_deprecated
25
- from ._sync import CapacityLimiter
27
+ from ._sync import CapacityLimiter , Event
26
28
from ._util import coroutine_or_error
27
29
28
30
if TYPE_CHECKING :
29
- from collections .abc import Awaitable , Callable
31
+ from collections .abc import Awaitable , Callable , Generator
30
32
31
33
from trio ._core ._traps import RaiseCancelT
32
34
@@ -52,6 +54,72 @@ class _ParentTaskData(threading.local):
52
54
_thread_counter = count ()
53
55
54
56
57
+ @define
58
+ class _ActiveThreadCount :
59
+ count : int
60
+ event : Event
61
+
62
+
63
+ _active_threads_local : RunVar [_ActiveThreadCount ] = RunVar ("active_threads" )
64
+
65
+
66
+ @contextlib .contextmanager
67
+ def _track_active_thread () -> Generator [None , None , None ]:
68
+ try :
69
+ active_threads_local = _active_threads_local .get ()
70
+ except LookupError :
71
+ active_threads_local = _ActiveThreadCount (0 , Event ())
72
+ _active_threads_local .set (active_threads_local )
73
+
74
+ active_threads_local .count += 1
75
+ try :
76
+ yield
77
+ finally :
78
+ active_threads_local .count -= 1
79
+ if active_threads_local .count == 0 :
80
+ active_threads_local .event .set ()
81
+ active_threads_local .event = Event ()
82
+
83
+
84
+ async def wait_all_threads_completed () -> None :
85
+ """Wait until no threads are still running tasks.
86
+
87
+ This is intended to be used when testing code with trio.to_thread to
88
+ make sure no tasks are still making progress in a thread. See the
89
+ following code for a usage example::
90
+
91
+ async def wait_all_settled():
92
+ while True:
93
+ await trio.testing.wait_all_threads_complete()
94
+ await trio.testing.wait_all_tasks_blocked()
95
+ if trio.testing.active_thread_count() == 0:
96
+ break
97
+ """
98
+
99
+ await checkpoint ()
100
+
101
+ try :
102
+ active_threads_local = _active_threads_local .get ()
103
+ except LookupError :
104
+ # If there would have been active threads, the
105
+ # _active_threads_local would have been set
106
+ return
107
+
108
+ while active_threads_local .count != 0 :
109
+ await active_threads_local .event .wait ()
110
+
111
+
112
+ def active_thread_count () -> int :
113
+ """Returns the number of threads that are currently running a task
114
+
115
+ See `trio.testing.wait_all_threads_completed`
116
+ """
117
+ try :
118
+ return _active_threads_local .get ().count
119
+ except LookupError :
120
+ return 0
121
+
122
+
55
123
def current_default_thread_limiter () -> CapacityLimiter :
56
124
"""Get the default `~trio.CapacityLimiter` used by
57
125
`trio.to_thread.run_sync`.
@@ -375,39 +443,40 @@ def deliver_worker_fn_result(result: outcome.Outcome[RetT]) -> None:
375
443
current_trio_token .run_sync_soon (report_back_in_trio_thread_fn , result )
376
444
377
445
await limiter .acquire_on_behalf_of (placeholder )
378
- try :
379
- start_thread_soon (worker_fn , deliver_worker_fn_result , thread_name )
380
- except :
381
- limiter .release_on_behalf_of (placeholder )
382
- raise
383
-
384
- def abort (raise_cancel : RaiseCancelT ) -> trio .lowlevel .Abort :
385
- # fill so from_thread_check_cancelled can raise
386
- cancel_register [0 ] = raise_cancel
387
- if abandon_bool :
388
- # empty so report_back_in_trio_thread_fn cannot reschedule
389
- task_register [0 ] = None
390
- return trio .lowlevel .Abort .SUCCEEDED
391
- else :
392
- return trio .lowlevel .Abort .FAILED
393
-
394
- while True :
395
- # wait_task_rescheduled return value cannot be typed
396
- msg_from_thread : outcome .Outcome [RetT ] | Run [object ] | RunSync [
397
- object
398
- ] = await trio .lowlevel .wait_task_rescheduled (abort )
399
- if isinstance (msg_from_thread , outcome .Outcome ):
400
- return msg_from_thread .unwrap ()
401
- elif isinstance (msg_from_thread , Run ):
402
- await msg_from_thread .run ()
403
- elif isinstance (msg_from_thread , RunSync ):
404
- msg_from_thread .run_sync ()
405
- else : # pragma: no cover, internal debugging guard TODO: use assert_never
406
- raise TypeError (
407
- "trio.to_thread.run_sync received unrecognized thread message {!r}."
408
- "" .format (msg_from_thread )
409
- )
410
- del msg_from_thread
446
+ with _track_active_thread ():
447
+ try :
448
+ start_thread_soon (worker_fn , deliver_worker_fn_result , thread_name )
449
+ except :
450
+ limiter .release_on_behalf_of (placeholder )
451
+ raise
452
+
453
+ def abort (raise_cancel : RaiseCancelT ) -> trio .lowlevel .Abort :
454
+ # fill so from_thread_check_cancelled can raise
455
+ cancel_register [0 ] = raise_cancel
456
+ if abandon_bool :
457
+ # empty so report_back_in_trio_thread_fn cannot reschedule
458
+ task_register [0 ] = None
459
+ return trio .lowlevel .Abort .SUCCEEDED
460
+ else :
461
+ return trio .lowlevel .Abort .FAILED
462
+
463
+ while True :
464
+ # wait_task_rescheduled return value cannot be typed
465
+ msg_from_thread : outcome .Outcome [RetT ] | Run [object ] | RunSync [
466
+ object
467
+ ] = await trio .lowlevel .wait_task_rescheduled (abort )
468
+ if isinstance (msg_from_thread , outcome .Outcome ):
469
+ return msg_from_thread .unwrap ()
470
+ elif isinstance (msg_from_thread , Run ):
471
+ await msg_from_thread .run ()
472
+ elif isinstance (msg_from_thread , RunSync ):
473
+ msg_from_thread .run_sync ()
474
+ else : # pragma: no cover, internal debugging guard TODO: use assert_never
475
+ raise TypeError (
476
+ "trio.to_thread.run_sync received unrecognized thread message {!r}."
477
+ "" .format (msg_from_thread )
478
+ )
479
+ del msg_from_thread
411
480
412
481
413
482
def from_thread_check_cancelled () -> None :
0 commit comments