Skip to content

Commit bdfe2ca

Browse files
kumaraditya303WolframAlph
authored andcommitted
pythongh-128002: fix many thread safety issues in asyncio (python#128147)
* Makes `_asyncio.Task` and `_asyncio.Future` thread-safe by adding critical sections * Add assertions to check for thread safety checking locking of object by critical sections in internal functions * Make `_asyncio.all_tasks` thread safe when eager tasks are used * Add a thread safety test
1 parent e7eb021 commit bdfe2ca

File tree

3 files changed

+951
-175
lines changed

3 files changed

+951
-175
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import asyncio
2+
import unittest
3+
from threading import Thread
4+
from unittest import TestCase
5+
6+
from test.support import threading_helper
7+
8+
threading_helper.requires_working_threading(module=True)
9+
10+
def tearDownModule():
11+
asyncio._set_event_loop_policy(None)
12+
13+
14+
class TestFreeThreading:
15+
def test_all_tasks_race(self) -> None:
16+
async def main():
17+
loop = asyncio.get_running_loop()
18+
future = loop.create_future()
19+
20+
async def coro():
21+
await future
22+
23+
tasks = set()
24+
25+
async with asyncio.TaskGroup() as tg:
26+
for _ in range(100):
27+
tasks.add(tg.create_task(coro()))
28+
29+
all_tasks = self.all_tasks(loop)
30+
self.assertEqual(len(all_tasks), 101)
31+
32+
for task in all_tasks:
33+
self.assertEqual(task.get_loop(), loop)
34+
self.assertFalse(task.done())
35+
36+
current = self.current_task()
37+
self.assertEqual(current.get_loop(), loop)
38+
self.assertSetEqual(all_tasks, tasks | {current})
39+
future.set_result(None)
40+
41+
def runner():
42+
with asyncio.Runner() as runner:
43+
loop = runner.get_loop()
44+
loop.set_task_factory(self.factory)
45+
runner.run(main())
46+
47+
threads = []
48+
49+
for _ in range(10):
50+
thread = Thread(target=runner)
51+
threads.append(thread)
52+
53+
with threading_helper.start_threads(threads):
54+
pass
55+
56+
57+
class TestPyFreeThreading(TestFreeThreading, TestCase):
58+
all_tasks = staticmethod(asyncio.tasks._py_all_tasks)
59+
current_task = staticmethod(asyncio.tasks._py_current_task)
60+
61+
def factory(self, loop, coro, context=None):
62+
return asyncio.tasks._PyTask(coro, loop=loop, context=context)
63+
64+
65+
@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio")
66+
class TestCFreeThreading(TestFreeThreading, TestCase):
67+
all_tasks = staticmethod(getattr(asyncio.tasks, "_c_all_tasks", None))
68+
current_task = staticmethod(getattr(asyncio.tasks, "_c_current_task", None))
69+
70+
def factory(self, loop, coro, context=None):
71+
return asyncio.tasks._CTask(coro, loop=loop, context=context)
72+
73+
74+
class TestEagerPyFreeThreading(TestPyFreeThreading):
75+
def factory(self, loop, coro, context=None):
76+
return asyncio.tasks._PyTask(coro, loop=loop, context=context, eager_start=True)
77+
78+
79+
@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio")
80+
class TestEagerCFreeThreading(TestCFreeThreading, TestCase):
81+
def factory(self, loop, coro, context=None):
82+
return asyncio.tasks._CTask(coro, loop=loop, context=context, eager_start=True)

0 commit comments

Comments
 (0)