diff options
Diffstat (limited to 'sources/pyside6')
| -rw-r--r-- | sources/pyside6/PySide6/QtAsyncio/tasks.py | 11 | ||||
| -rw-r--r-- | sources/pyside6/tests/QtAsyncio/qasyncio_test_cancel_taskgroup.py | 57 |
2 files changed, 68 insertions, 0 deletions
diff --git a/sources/pyside6/PySide6/QtAsyncio/tasks.py b/sources/pyside6/PySide6/QtAsyncio/tasks.py index bc3d41a73..6777b8bc3 100644 --- a/sources/pyside6/PySide6/QtAsyncio/tasks.py +++ b/sources/pyside6/PySide6/QtAsyncio/tasks.py @@ -29,6 +29,7 @@ class QAsyncioTask(futures.QAsyncioFuture): self._future_to_await: typing.Optional[asyncio.Future] = None self._cancel_message: typing.Optional[str] = None + self._cancelled = False asyncio._register_task(self) # type: ignore[arg-type] @@ -90,6 +91,15 @@ class QAsyncioTask(futures.QAsyncioFuture): result.add_done_callback( self._step, context=self._context) # type: ignore[arg-type] self._future_to_await = result + if self._cancelled: + # If the task was cancelled, then a new future should be + # cancelled as well. Otherwise, in some scenarios like + # a loop inside the task and with bad timing, if the new + # future is not cancelled, the task would continue running + # in this loop despite having been cancelled. This bad + # timing can occur especially if the first future finishes + # very quickly. + self._future_to_await.cancel(self._cancel_message) elif result is None: self._loop.call_soon(self._step, context=self._context) else: @@ -136,6 +146,7 @@ class QAsyncioTask(futures.QAsyncioFuture): self._handle.cancel() if self._future_to_await is not None: self._future_to_await.cancel(msg) + self._cancelled = True return True def uncancel(self) -> None: diff --git a/sources/pyside6/tests/QtAsyncio/qasyncio_test_cancel_taskgroup.py b/sources/pyside6/tests/QtAsyncio/qasyncio_test_cancel_taskgroup.py new file mode 100644 index 000000000..aa8ce4718 --- /dev/null +++ b/sources/pyside6/tests/QtAsyncio/qasyncio_test_cancel_taskgroup.py @@ -0,0 +1,57 @@ +# Copyright (C) 2024 The Qt Company Ltd. +# SPDX-License-Identifier: LicenseRef-Qt-Commercial OR GPL-3.0-only WITH Qt-GPL-exception-1.0 + +'''Test cases for QtAsyncio''' + +import asyncio +import unittest + +import PySide6.QtAsyncio as QtAsyncio + + +class QAsyncioTestCaseCancelTaskGroup(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + # We only reach the end of the loop if the task is not cancelled. + self.loop_end_reached = False + + async def raise_error(self): + raise RuntimeError + + async def loop_short(self): + self._loop_end_reached = False + for _ in range(1000): + await asyncio.sleep(1e-3) + self._loop_end_reached = True + + async def loop_shorter(self): + self._loop_end_reached = False + for _ in range(1000): + await asyncio.sleep(1e-4) + self._loop_end_reached = True + + async def loop_the_shortest(self): + self._loop_end_reached = False + for _ in range(1000): + await asyncio.to_thread(lambda: None) + self._loop_end_reached = True + + async def main(self, coro): + async with asyncio.TaskGroup() as tg: + tg.create_task(coro()) + tg.create_task(self.raise_error()) + + def test_cancel_taskgroup(self): + coros = [self.loop_short, self.loop_shorter, self.loop_the_shortest] + + for coro in coros: + try: + QtAsyncio.run(self.main(coro), keep_running=False) + except ExceptionGroup as e: + self.assertEqual(len(e.exceptions), 1) + self.assertIsInstance(e.exceptions[0], RuntimeError) + self.assertFalse(self._loop_end_reached) + + +if __name__ == '__main__': + unittest.main() |
