diff --git a/scripts/experiments/tunnel_community/hidden_peer_discovery.py b/scripts/experiments/tunnel_community/hidden_peer_discovery.py index 744f212bebb..d0e30ec48ff 100644 --- a/scripts/experiments/tunnel_community/hidden_peer_discovery.py +++ b/scripts/experiments/tunnel_community/hidden_peer_discovery.py @@ -29,8 +29,9 @@ def __init__(self, *args, **kwargs): self.register_task('_graceful_shutdown', self._graceful_shutdown, delay=EXPERIMENT_RUN_TIME) def _graceful_shutdown(self): - task = asyncio.create_task(self.on_tribler_shutdown()) - task.add_done_callback(lambda result: TinyTriblerService._graceful_shutdown(self)) + tasks = self.async_group.add(self.on_tribler_shutdown()) + shutdown_task = tasks[0] + shutdown_task.add_done_callback(lambda result: TinyTriblerService._graceful_shutdown(self)) async def on_tribler_shutdown(self): await self.shutdown_task_manager() diff --git a/scripts/experiments/tunnel_community/speed_test_exit.py b/scripts/experiments/tunnel_community/speed_test_exit.py index 743d16113c8..39039e251e8 100644 --- a/scripts/experiments/tunnel_community/speed_test_exit.py +++ b/scripts/experiments/tunnel_community/speed_test_exit.py @@ -29,8 +29,9 @@ def __init__(self, *args, **kwargs): self.output_file = 'speed_test_exit.txt' def _graceful_shutdown(self): - task = asyncio.create_task(self.on_tribler_shutdown()) - task.add_done_callback(lambda result: TinyTriblerService._graceful_shutdown(self)) + tasks = self.async_group.add(self.on_tribler_shutdown()) + shutdown_task = tasks[0] + shutdown_task.add_done_callback(lambda result: TinyTriblerService._graceful_shutdown(self)) async def on_tribler_shutdown(self): await self.shutdown_task_manager() diff --git a/src/tribler/core/components/metadata_store/restapi/metadata_endpoint.py b/src/tribler/core/components/metadata_store/restapi/metadata_endpoint.py index 2b8426e05c2..9a192ec741a 100644 --- a/src/tribler/core/components/metadata_store/restapi/metadata_endpoint.py +++ b/src/tribler/core/components/metadata_store/restapi/metadata_endpoint.py @@ -1,4 +1,3 @@ -from asyncio import create_task from binascii import unhexlify from aiohttp import ContentTypeError, web @@ -225,5 +224,6 @@ async def get_torrent_health(self, request): return RESTResponse({"error": f"Error processing timeout parameter: {e}"}, status=HTTP_BAD_REQUEST) infohash = unhexlify(request.match_info['infohash']) - create_task(self.torrent_checker.check_torrent_health(infohash, timeout=timeout, scrape_now=True)) + check_coro = self.torrent_checker.check_torrent_health(infohash, timeout=timeout, scrape_now=True) + self.async_group.add(check_coro) return RESTResponse({'checking': '1'}) diff --git a/src/tribler/core/components/restapi/rest/events_endpoint.py b/src/tribler/core/components/restapi/rest/events_endpoint.py index 30f4e9364e6..d77c70df12b 100644 --- a/src/tribler/core/components/restapi/rest/events_endpoint.py +++ b/src/tribler/core/components/restapi/rest/events_endpoint.py @@ -1,3 +1,4 @@ +import asyncio import json import time from asyncio import CancelledError @@ -8,7 +9,6 @@ from aiohttp_apispec import docs from ipv8.REST.schema import schema from ipv8.messaging.anonymization.tunnel import Circuit -from ipv8.taskmanager import TaskManager, task from marshmallow.fields import Dict, String from tribler.core import notifications @@ -39,7 +39,7 @@ def passthrough(x): @froze_it -class EventsEndpoint(RESTEndpoint, TaskManager): +class EventsEndpoint(RESTEndpoint): """ Important events in Tribler are returned over the events endpoint. This connection is held open. Each event is pushed over this endpoint in the form of a JSON dictionary. Each JSON dictionary contains a type field that @@ -47,8 +47,7 @@ class EventsEndpoint(RESTEndpoint, TaskManager): """ def __init__(self, notifier: Notifier, public_key: str = None): - RESTEndpoint.__init__(self) - TaskManager.__init__(self) + super().__init__() self.events_responses: List[RESTStreamResponse] = [] self.app.on_shutdown.append(self.on_shutdown) self.undelivered_error: Optional[dict] = None @@ -59,7 +58,8 @@ def __init__(self, notifier: Notifier, public_key: str = None): def on_notification(self, topic, *args, **kwargs): if topic in topics_to_send_to_gui: - self.write_data({"topic": topic.__name__, "args": args, "kwargs": kwargs}) + data = {"topic": topic.__name__, "args": args, "kwargs": kwargs} + self.async_group.add(self.write_data(data)) def on_circuit_removed(self, circuit: Circuit, additional_info: str): # The original notification contains non-JSON-serializable argument, so we send another one to GUI @@ -69,10 +69,7 @@ def on_circuit_removed(self, circuit: Circuit, additional_info: str): additional_info=additional_info) async def on_shutdown(self, _): - await self.shutdown_task_manager() - - async def shutdown(self): - await self.shutdown_task_manager() + await self.shutdown() def setup_routes(self): self.app.add_routes([web.get('', self.get_events)]) @@ -101,7 +98,6 @@ def encode_message(self, message: dict) -> bytes: def has_connection_to_gui(self): return bool(self.events_responses) - @task async def write_data(self, message): """ Write data over the event socket if it's open. @@ -124,7 +120,7 @@ async def write_data(self, message): def on_tribler_exception(self, reported_error: ReportedError): message = self.error_message(reported_error) if self.has_connection_to_gui(): - self.write_data(message) + self.async_group.add(self.write_data(message)) elif not self.undelivered_error: # If there are several undelivered errors, we store the first error as more important and skip other self.undelivered_error = message @@ -170,7 +166,7 @@ async def get_events(self, request): try: while True: - await self.register_anonymous_task('event_sleep', lambda: None, delay=3600) + await asyncio.sleep(3600) except CancelledError: self.events_responses.remove(response) return response diff --git a/src/tribler/core/components/restapi/rest/rest_endpoint.py b/src/tribler/core/components/restapi/rest/rest_endpoint.py index f3522a87b2b..97bc6f57326 100644 --- a/src/tribler/core/components/restapi/rest/rest_endpoint.py +++ b/src/tribler/core/components/restapi/rest/rest_endpoint.py @@ -1,8 +1,17 @@ +from __future__ import annotations + import json import logging +from typing import Dict, TYPE_CHECKING from aiohttp import web +from tribler.core.utilities.async_group import AsyncGroup + +if TYPE_CHECKING: + from tribler.core.components.restapi.rest.events_endpoint import EventsEndpoint + from ipv8.REST.root_endpoint import RootEndpoint as IPV8RootEndpoint + HTTP_BAD_REQUEST = 400 HTTP_UNAUTHORIZED = 401 HTTP_NOT_FOUND = 404 @@ -14,16 +23,32 @@ class RESTEndpoint: def __init__(self, middlewares=()): self._logger = logging.getLogger(self.__class__.__name__) self.app = web.Application(middlewares=middlewares, client_max_size=2 * 1024 ** 2) - self.endpoints = {} + self.endpoints: Dict[str, RESTEndpoint] = {} + self.async_group = AsyncGroup() self.setup_routes() + self._shutdown = False + def setup_routes(self): pass - def add_endpoint(self, prefix, endpoint): + def add_endpoint(self, prefix: str, endpoint: RESTEndpoint | EventsEndpoint | IPV8RootEndpoint): self.endpoints[prefix] = endpoint self.app.add_subapp(prefix, endpoint.app) + async def shutdown(self): + if self._shutdown: + return + self._shutdown = True + + shutdown_group = AsyncGroup() + for endpoint in self.endpoints.values(): + if isinstance(endpoint, RESTEndpoint): + shutdown_group.add(endpoint.shutdown()) # IPV8RootEndpoint doesn't have a shutdown method + + await shutdown_group.wait() + await self.async_group.cancel() + class RESTResponse(web.Response): diff --git a/src/tribler/core/components/restapi/rest/shutdown_endpoint.py b/src/tribler/core/components/restapi/rest/shutdown_endpoint.py index 1046fae7b29..c4b7d613e48 100644 --- a/src/tribler/core/components/restapi/rest/shutdown_endpoint.py +++ b/src/tribler/core/components/restapi/rest/shutdown_endpoint.py @@ -18,7 +18,7 @@ def __init__(self, shutdown_callback): self.shutdown_callback = shutdown_callback def setup_routes(self): - self.app.add_routes([web.put('', self.shutdown)]) + self.app.add_routes([web.put('', self.shutdown_request)]) @docs( tags=["General"], @@ -31,7 +31,7 @@ def setup_routes(self): } } ) - async def shutdown(self, request): + async def shutdown_request(self, _): self._logger.info('Received a shutdown request from GUI') self.shutdown_callback() return RESTResponse({"shutdown": True}) diff --git a/src/tribler/core/components/restapi/rest/tests/test_events_endpoint.py b/src/tribler/core/components/restapi/rest/tests/test_events_endpoint.py index 182c1c6f78f..8086b93b220 100644 --- a/src/tribler/core/components/restapi/rest/tests/test_events_endpoint.py +++ b/src/tribler/core/components/restapi/rest/tests/test_events_endpoint.py @@ -150,7 +150,7 @@ async def test_on_tribler_exception_stores_only_first_error(endpoint, reported_e assert endpoint.undelivered_error == endpoint.error_message(first_reported_error) -@patch.object(EventsEndpoint, 'register_anonymous_task', new=AsyncMock(side_effect=CancelledError)) +@patch('asyncio.sleep', new=AsyncMock(side_effect=CancelledError)) @patch.object(RESTStreamResponse, 'prepare', new=AsyncMock()) @patch.object(RESTStreamResponse, 'write', new_callable=AsyncMock) @patch.object(EventsEndpoint, 'encode_message') diff --git a/src/tribler/core/components/restapi/rest/tests/test_rest_endpoint.py b/src/tribler/core/components/restapi/rest/tests/test_rest_endpoint.py new file mode 100644 index 00000000000..be7e6cbb365 --- /dev/null +++ b/src/tribler/core/components/restapi/rest/tests/test_rest_endpoint.py @@ -0,0 +1,46 @@ +from unittest.mock import AsyncMock, patch + +from tribler.core.components.restapi.rest.rest_endpoint import RESTEndpoint +from tribler.core.utilities.async_group import AsyncGroup + + +# pylint: disable=protected-access + +async def test_shutdown(): + # In this test we check that all coros related to the Root Endpoint are cancelled + # during shutdown + + async def coro(): + ... + + root_endpoint = RESTEndpoint() + root_endpoint.async_group.add(coro()) + + # add 2 child endpoints with a single coro in each: + child_endpoints = [RESTEndpoint(), RESTEndpoint()] + for index, child_endpoint in enumerate(child_endpoints): + root_endpoint.add_endpoint(prefix=f'/{index}', endpoint=child_endpoint) + child_endpoint.async_group.add(coro()) + + def total_coro_count(): + count = 0 + for endpoint in child_endpoints + [root_endpoint]: + count += len(endpoint.async_group._futures) + return count + + assert total_coro_count() == 3 + + await root_endpoint.shutdown() + + assert total_coro_count() == 0 + + +@patch.object(AsyncGroup, 'cancel', new_callable=AsyncMock) +async def test_multiple_shutdown_calls(async_group_cancel: AsyncMock): + # Test that if shutdown calls twice, only one call is processed + endpoint = RESTEndpoint() + + await endpoint.shutdown() + await endpoint.shutdown() + + async_group_cancel.assert_called_once() diff --git a/src/tribler/core/components/restapi/restapi_component.py b/src/tribler/core/components/restapi/restapi_component.py index 341718d7d6c..3aadb622e9a 100644 --- a/src/tribler/core/components/restapi/restapi_component.py +++ b/src/tribler/core/components/restapi/restapi_component.py @@ -150,8 +150,8 @@ def report_callback(reported_error: ReportedError): async def shutdown(self): await super().shutdown() - if self._events_endpoint: - await self._events_endpoint.shutdown() + if self.root_endpoint: + await self.root_endpoint.shutdown() if self._core_exception_handler: self._core_exception_handler.report_callback = None diff --git a/src/tribler/core/components/session.py b/src/tribler/core/components/session.py index 9e8621676b5..41edd38df2c 100644 --- a/src/tribler/core/components/session.py +++ b/src/tribler/core/components/session.py @@ -10,6 +10,7 @@ from tribler.core.components.component import Component, ComponentError, ComponentStartupException, \ MultipleComponentsFound from tribler.core.config.tribler_config import TriblerConfig +from tribler.core.utilities.async_group import AsyncGroup from tribler.core.utilities.crypto_patcher import patch_crypto_be_discovery from tribler.core.utilities.install_dir import get_lib_path from tribler.core.utilities.network_utils import default_network_utils @@ -33,6 +34,7 @@ def __init__(self, config: TriblerConfig = None, components: List[Component] = ( self.config: TriblerConfig = config or TriblerConfig() self.shutdown_event: Event = shutdown_event or Event() self.notifier: Notifier = notifier or Notifier(loop=get_event_loop()) + self.async_group = AsyncGroup() self.components: Dict[Type[Component], Component] = {} for component in components: self.register(component.__class__, component) @@ -104,7 +106,7 @@ async def exception_reraiser(): self.logger.info(f'Reraise startup exception: {self._startup_exception}') raise self._startup_exception - get_event_loop().create_task(exception_reraiser()) + self.async_group.add(exception_reraiser()) def set_startup_exception(self, exc: Exception): if not self._startup_exception: @@ -113,6 +115,7 @@ def set_startup_exception(self, exc: Exception): async def shutdown(self): self.logger.info("Stopping components") await gather(*[create_task(component.stop()) for component in self.components.values()]) + await self.async_group.cancel() self.logger.info("All components are stopped") diff --git a/src/tribler/core/utilities/async_group.py b/src/tribler/core/utilities/async_group.py index 349510e8bf1..a5cff104f32 100644 --- a/src/tribler/core/utilities/async_group.py +++ b/src/tribler/core/utilities/async_group.py @@ -1,5 +1,5 @@ import asyncio -from asyncio import CancelledError, Future +from asyncio import CancelledError, Future, Task from contextlib import suppress from typing import Iterable, List, Set @@ -24,13 +24,17 @@ class AsyncGroup: def __init__(self): self._futures: Set[Future] = set() - def add(self, *coroutines): + def add(self, *coroutines) -> List[Task]: """Add a coroutine to the group. """ + result = [] for coroutine in coroutines: task = asyncio.create_task(coroutine) self._futures.add(task) task.add_done_callback(self._done_callback) + result.append(task) + + return result async def wait(self): """ Wait for completion of all futures diff --git a/src/tribler/core/utilities/tests/test_async_group.py b/src/tribler/core/utilities/tests/test_async_group.py index 951ea72eb57..788f63cfbf1 100644 --- a/src/tribler/core/utilities/tests/test_async_group.py +++ b/src/tribler/core/utilities/tests/test_async_group.py @@ -30,21 +30,23 @@ async def raise_exception(): async def test_add_single_coro(group: AsyncGroup): - group.add( + tasks = group.add( void() ) assert len(group._futures) == 1 + assert len(tasks) == 1 async def test_add_iterable(group: AsyncGroup): - group.add( + tasks = group.add( void(), void(), void() ) assert len(group._futures) == 3 + assert len(tasks) == 3 async def test_cancel(group: AsyncGroup): diff --git a/src/tribler/core/utilities/tiny_tribler_service.py b/src/tribler/core/utilities/tiny_tribler_service.py index 9eef49fcfd4..c061505a83d 100644 --- a/src/tribler/core/utilities/tiny_tribler_service.py +++ b/src/tribler/core/utilities/tiny_tribler_service.py @@ -7,6 +7,7 @@ from tribler.core.components.component import Component from tribler.core.components.session import Session from tribler.core.config.tribler_config import TriblerConfig +from tribler.core.utilities.async_group import AsyncGroup from tribler.core.utilities.osutils import get_root_state_directory from tribler.core.utilities.process_manager import ProcessKind, ProcessManager, TriblerProcess, \ set_global_process_manager @@ -27,6 +28,8 @@ def __init__(self, components: List[Component], timeout_in_sec=None, state_dir=P self.config = TriblerConfig(state_dir=state_dir.absolute()) self.timeout_in_sec = timeout_in_sec self.components = components + self.async_group = AsyncGroup() + self._main_task = None async def on_tribler_started(self): """Function will calls after the Tribler session is started @@ -42,7 +45,7 @@ async def start_tribler(): await self._start_session() if self.timeout_in_sec: - asyncio.create_task(self._terminate_by_timeout()) + self.async_group.add(self._terminate_by_timeout()) self._enable_graceful_shutdown() await self.on_tribler_started() @@ -51,7 +54,9 @@ async def start_tribler(): if fragile: make_async_loop_fragile(loop) - loop.create_task(start_tribler()) + # the variable `self._main_task` is used here to prevent a naked `loop.create_task()` call + # more details: https://github.com/Tribler/tribler/issues/7299 + self._main_task = loop.create_task(start_tribler()) try: loop.run_forever() finally: @@ -97,8 +102,9 @@ async def _terminate_by_timeout(self): def _graceful_shutdown(self): self.logger.info("Shutdown gracefully") - task = asyncio.create_task(self.session.shutdown()) - task.add_done_callback(lambda result: self._stop_event_loop()) + tasks = self.async_group.add(self.session.shutdown()) + shutdown_task = tasks[0] + shutdown_task.add_done_callback(lambda result: self._stop_event_loop()) def _stop_event_loop(self): asyncio.get_running_loop().stop()