aboutsummaryrefslogtreecommitdiffstats
path: root/sources/pyside6
diff options
context:
space:
mode:
Diffstat (limited to 'sources/pyside6')
-rw-r--r--sources/pyside6/PySide6/QtAsyncio/tasks.py11
-rw-r--r--sources/pyside6/tests/QtAsyncio/qasyncio_test_cancel_taskgroup.py57
2 files changed, 68 insertions, 0 deletions
diff --git a/sources/pyside6/PySide6/QtAsyncio/tasks.py b/sources/pyside6/PySide6/QtAsyncio/tasks.py
index bc3d41a73..6777b8bc3 100644
--- a/sources/pyside6/PySide6/QtAsyncio/tasks.py
+++ b/sources/pyside6/PySide6/QtAsyncio/tasks.py
@@ -29,6 +29,7 @@ class QAsyncioTask(futures.QAsyncioFuture):
self._future_to_await: typing.Optional[asyncio.Future] = None
self._cancel_message: typing.Optional[str] = None
+ self._cancelled = False
asyncio._register_task(self) # type: ignore[arg-type]
@@ -90,6 +91,15 @@ class QAsyncioTask(futures.QAsyncioFuture):
result.add_done_callback(
self._step, context=self._context) # type: ignore[arg-type]
self._future_to_await = result
+ if self._cancelled:
+ # If the task was cancelled, then a new future should be
+ # cancelled as well. Otherwise, in some scenarios like
+ # a loop inside the task and with bad timing, if the new
+ # future is not cancelled, the task would continue running
+ # in this loop despite having been cancelled. This bad
+ # timing can occur especially if the first future finishes
+ # very quickly.
+ self._future_to_await.cancel(self._cancel_message)
elif result is None:
self._loop.call_soon(self._step, context=self._context)
else:
@@ -136,6 +146,7 @@ class QAsyncioTask(futures.QAsyncioFuture):
self._handle.cancel()
if self._future_to_await is not None:
self._future_to_await.cancel(msg)
+ self._cancelled = True
return True
def uncancel(self) -> None:
diff --git a/sources/pyside6/tests/QtAsyncio/qasyncio_test_cancel_taskgroup.py b/sources/pyside6/tests/QtAsyncio/qasyncio_test_cancel_taskgroup.py
new file mode 100644
index 000000000..aa8ce4718
--- /dev/null
+++ b/sources/pyside6/tests/QtAsyncio/qasyncio_test_cancel_taskgroup.py
@@ -0,0 +1,57 @@
+# Copyright (C) 2024 The Qt Company Ltd.
+# SPDX-License-Identifier: LicenseRef-Qt-Commercial OR GPL-3.0-only WITH Qt-GPL-exception-1.0
+
+'''Test cases for QtAsyncio'''
+
+import asyncio
+import unittest
+
+import PySide6.QtAsyncio as QtAsyncio
+
+
+class QAsyncioTestCaseCancelTaskGroup(unittest.TestCase):
+ def setUp(self) -> None:
+ super().setUp()
+ # We only reach the end of the loop if the task is not cancelled.
+ self.loop_end_reached = False
+
+ async def raise_error(self):
+ raise RuntimeError
+
+ async def loop_short(self):
+ self._loop_end_reached = False
+ for _ in range(1000):
+ await asyncio.sleep(1e-3)
+ self._loop_end_reached = True
+
+ async def loop_shorter(self):
+ self._loop_end_reached = False
+ for _ in range(1000):
+ await asyncio.sleep(1e-4)
+ self._loop_end_reached = True
+
+ async def loop_the_shortest(self):
+ self._loop_end_reached = False
+ for _ in range(1000):
+ await asyncio.to_thread(lambda: None)
+ self._loop_end_reached = True
+
+ async def main(self, coro):
+ async with asyncio.TaskGroup() as tg:
+ tg.create_task(coro())
+ tg.create_task(self.raise_error())
+
+ def test_cancel_taskgroup(self):
+ coros = [self.loop_short, self.loop_shorter, self.loop_the_shortest]
+
+ for coro in coros:
+ try:
+ QtAsyncio.run(self.main(coro), keep_running=False)
+ except ExceptionGroup as e:
+ self.assertEqual(len(e.exceptions), 1)
+ self.assertIsInstance(e.exceptions[0], RuntimeError)
+ self.assertFalse(self._loop_end_reached)
+
+
+if __name__ == '__main__':
+ unittest.main()