From 7344c48bc3db871a1a6e9459e6c1d69c922e91ec Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Sat, 4 Jan 2025 11:43:55 +0000 Subject: [PATCH 1/3] add more thread safety tests --- Lib/test/test_asyncio/test_free_threading.py | 86 ++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/Lib/test/test_asyncio/test_free_threading.py b/Lib/test/test_asyncio/test_free_threading.py index 90bddbf3a9dda1..8fb5ea5f5612ea 100644 --- a/Lib/test/test_asyncio/test_free_threading.py +++ b/Lib/test/test_asyncio/test_free_threading.py @@ -7,6 +7,10 @@ threading_helper.requires_working_threading(module=True) +class MyException(Exception): + pass + + def tearDownModule(): asyncio._set_event_loop_policy(None) @@ -53,6 +57,88 @@ def runner(): with threading_helper.start_threads(threads): pass + def test_run_coroutine_threadsafe(self) -> None: + results = [] + + def in_thread(loop: asyncio.AbstractEventLoop): + coro = asyncio.sleep(0.1, result=42) + fut = asyncio.run_coroutine_threadsafe(coro, loop) + result = fut.result() + self.assertEqual(result, 42) + results.append(result) + + async def main(): + loop = asyncio.get_running_loop() + async with asyncio.TaskGroup() as tg: + for _ in range(10): + tg.create_task(asyncio.to_thread(in_thread, loop)) + self.assertEqual(results, [42] * 10) + + with asyncio.Runner() as r: + loop = r.get_loop() + loop.set_task_factory(self.factory) + r.run(main()) + + def test_run_coroutine_threadsafe_exception_caught(self) -> None: + exc = MyException("test") + + async def coro(): + await asyncio.sleep(0.1) + raise exc + + def in_thread(loop: asyncio.AbstractEventLoop): + fut = asyncio.run_coroutine_threadsafe(coro(), loop) + self.assertEqual(fut.exception(), exc) + return exc + + async def main(): + loop = asyncio.get_running_loop() + tasks = [] + async with asyncio.TaskGroup() as tg: + for _ in range(10): + task = tg.create_task(asyncio.to_thread(in_thread, loop)) + tasks.append(task) + for task in tasks: + self.assertEqual(await task, exc) + + with asyncio.Runner() as r: + loop = r.get_loop() + loop.set_task_factory(self.factory) + r.run(main()) + + def test_run_coroutine_threadsafe_exception_uncaught(self) -> None: + async def coro(): + await asyncio.sleep(1) + raise MyException("test") + + def in_thread(loop: asyncio.AbstractEventLoop): + fut = asyncio.run_coroutine_threadsafe(coro(), loop) + return fut.result() + + async def main(): + loop = asyncio.get_running_loop() + tasks = [] + try: + async with asyncio.TaskGroup() as tg: + for _ in range(10): + task = tg.create_task(asyncio.to_thread(in_thread, loop)) + tasks.append(task) + except ExceptionGroup: + for task in tasks: + try: + await task + except (MyException, asyncio.CancelledError): + pass + else: + self.fail("Task should have raised an exception") + else: + self.fail("TaskGroup should have raised an exception") + + with asyncio.Runner() as r: + loop = r.get_loop() + loop.set_task_factory(self.factory) + r.run(main()) + class TestPyFreeThreading(TestFreeThreading, TestCase): all_tasks = staticmethod(asyncio.tasks._py_all_tasks) From 7984147606696f10da83722255fc18778d3043ca Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Thu, 9 Jan 2025 16:31:26 +0000 Subject: [PATCH 2/3] fix tests --- Lib/test/test_asyncio/test_free_threading.py | 53 ++++---------------- 1 file changed, 11 insertions(+), 42 deletions(-) diff --git a/Lib/test/test_asyncio/test_free_threading.py b/Lib/test/test_asyncio/test_free_threading.py index 8fb5ea5f5612ea..b19a8ffc7953e7 100644 --- a/Lib/test/test_asyncio/test_free_threading.py +++ b/Lib/test/test_asyncio/test_free_threading.py @@ -7,6 +7,7 @@ threading_helper.requires_working_threading(module=True) + class MyException(Exception): pass @@ -79,38 +80,13 @@ async def main(): loop.set_task_factory(self.factory) r.run(main()) - def test_run_coroutine_threadsafe_exception_caught(self) -> None: + def test_run_coroutine_threadsafe_exception(self) -> None: exc = MyException("test") async def coro(): - await asyncio.sleep(0.1) + await asyncio.sleep(0) raise exc - def in_thread(loop: asyncio.AbstractEventLoop): - fut = asyncio.run_coroutine_threadsafe(coro(), loop) - self.assertEqual(fut.exception(), exc) - return exc - - async def main(): - loop = asyncio.get_running_loop() - tasks = [] - async with asyncio.TaskGroup() as tg: - for _ in range(10): - task = tg.create_task(asyncio.to_thread(in_thread, loop)) - tasks.append(task) - for task in tasks: - self.assertEqual(await task, exc) - - with asyncio.Runner() as r: - loop = r.get_loop() - loop.set_task_factory(self.factory) - r.run(main()) - - def test_run_coroutine_threadsafe_exception_uncaught(self) -> None: - async def coro(): - await asyncio.sleep(1) - raise MyException("test") - def in_thread(loop: asyncio.AbstractEventLoop): fut = asyncio.run_coroutine_threadsafe(coro(), loop) return fut.result() @@ -118,21 +94,14 @@ def in_thread(loop: asyncio.AbstractEventLoop): async def main(): loop = asyncio.get_running_loop() tasks = [] - try: - async with asyncio.TaskGroup() as tg: - for _ in range(10): - task = tg.create_task(asyncio.to_thread(in_thread, loop)) - tasks.append(task) - except ExceptionGroup: - for task in tasks: - try: - await task - except (MyException, asyncio.CancelledError): - pass - else: - self.fail("Task should have raised an exception") - else: - self.fail("TaskGroup should have raised an exception") + for _ in range(10): + task = loop.create_task(asyncio.to_thread(in_thread, loop)) + tasks.append(task) + results = await asyncio.gather(*tasks, return_exceptions=True) + + self.assertEqual(len(results), 10) + for result in results: + self.assertIs(result, exc) with asyncio.Runner() as r: loop = r.get_loop() From 025fd626a12501548e3f8494fa36f372a8afcbac Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Mon, 13 Jan 2025 15:12:59 +0000 Subject: [PATCH 3/3] fix test --- Lib/test/test_asyncio/test_free_threading.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/Lib/test/test_asyncio/test_free_threading.py b/Lib/test/test_asyncio/test_free_threading.py index b19a8ffc7953e7..8f4bba5f3b97d9 100644 --- a/Lib/test/test_asyncio/test_free_threading.py +++ b/Lib/test/test_asyncio/test_free_threading.py @@ -81,11 +81,9 @@ async def main(): r.run(main()) def test_run_coroutine_threadsafe_exception(self) -> None: - exc = MyException("test") - async def coro(): await asyncio.sleep(0) - raise exc + raise MyException("test") def in_thread(loop: asyncio.AbstractEventLoop): fut = asyncio.run_coroutine_threadsafe(coro(), loop) @@ -101,7 +99,8 @@ async def main(): self.assertEqual(len(results), 10) for result in results: - self.assertIs(result, exc) + self.assertIsInstance(result, MyException) + self.assertEqual(str(result), "test") with asyncio.Runner() as r: loop = r.get_loop()