diff --git a/Lib/test/test_asyncio/test_free_threading.py b/Lib/test/test_asyncio/test_free_threading.py index 90bddbf3a9dda1..8f4bba5f3b97d9 100644 --- a/Lib/test/test_asyncio/test_free_threading.py +++ b/Lib/test/test_asyncio/test_free_threading.py @@ -7,6 +7,11 @@ threading_helper.requires_working_threading(module=True) + +class MyException(Exception): + pass + + def tearDownModule(): asyncio._set_event_loop_policy(None) @@ -53,6 +58,55 @@ def runner(): with threading_helper.start_threads(threads): pass + def test_run_coroutine_threadsafe(self) -> None: + results = [] + + def in_thread(loop: asyncio.AbstractEventLoop): + coro = asyncio.sleep(0.1, result=42) + fut = asyncio.run_coroutine_threadsafe(coro, loop) + result = fut.result() + self.assertEqual(result, 42) + results.append(result) + + async def main(): + loop = asyncio.get_running_loop() + async with asyncio.TaskGroup() as tg: + for _ in range(10): + tg.create_task(asyncio.to_thread(in_thread, loop)) + self.assertEqual(results, [42] * 10) + + with asyncio.Runner() as r: + loop = r.get_loop() + loop.set_task_factory(self.factory) + r.run(main()) + + def test_run_coroutine_threadsafe_exception(self) -> None: + async def coro(): + await asyncio.sleep(0) + raise MyException("test") + + def in_thread(loop: asyncio.AbstractEventLoop): + fut = asyncio.run_coroutine_threadsafe(coro(), loop) + return fut.result() + + async def main(): + loop = asyncio.get_running_loop() + tasks = [] + for _ in range(10): + task = loop.create_task(asyncio.to_thread(in_thread, loop)) + tasks.append(task) + results = await asyncio.gather(*tasks, return_exceptions=True) + + self.assertEqual(len(results), 10) + for result in results: + self.assertIsInstance(result, MyException) + self.assertEqual(str(result), "test") + + with asyncio.Runner() as r: + loop = r.get_loop() + loop.set_task_factory(self.factory) + r.run(main()) + class TestPyFreeThreading(TestFreeThreading, TestCase): all_tasks = staticmethod(asyncio.tasks._py_all_tasks)