Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix asyncio.create_task() calls #7300

Merged
merged 1 commit into from
Feb 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions scripts/experiments/tunnel_community/hidden_peer_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ def __init__(self, *args, **kwargs):
self.register_task('_graceful_shutdown', self._graceful_shutdown, delay=EXPERIMENT_RUN_TIME)

def _graceful_shutdown(self):
task = asyncio.create_task(self.on_tribler_shutdown())
task.add_done_callback(lambda result: TinyTriblerService._graceful_shutdown(self))
tasks = self.async_group.add(self.on_tribler_shutdown())
shutdown_task = tasks[0]
shutdown_task.add_done_callback(lambda result: TinyTriblerService._graceful_shutdown(self))

async def on_tribler_shutdown(self):
await self.shutdown_task_manager()
Expand Down
5 changes: 3 additions & 2 deletions scripts/experiments/tunnel_community/speed_test_exit.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ def __init__(self, *args, **kwargs):
self.output_file = 'speed_test_exit.txt'

def _graceful_shutdown(self):
task = asyncio.create_task(self.on_tribler_shutdown())
task.add_done_callback(lambda result: TinyTriblerService._graceful_shutdown(self))
tasks = self.async_group.add(self.on_tribler_shutdown())
shutdown_task = tasks[0]
shutdown_task.add_done_callback(lambda result: TinyTriblerService._graceful_shutdown(self))

async def on_tribler_shutdown(self):
await self.shutdown_task_manager()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from asyncio import create_task
from binascii import unhexlify

from aiohttp import ContentTypeError, web
Expand Down Expand Up @@ -225,5 +224,6 @@ async def get_torrent_health(self, request):
return RESTResponse({"error": f"Error processing timeout parameter: {e}"}, status=HTTP_BAD_REQUEST)

infohash = unhexlify(request.match_info['infohash'])
create_task(self.torrent_checker.check_torrent_health(infohash, timeout=timeout, scrape_now=True))
check_coro = self.torrent_checker.check_torrent_health(infohash, timeout=timeout, scrape_now=True)
self.async_group.add(check_coro)
return RESTResponse({'checking': '1'})
20 changes: 8 additions & 12 deletions src/tribler/core/components/restapi/rest/events_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
import time
from asyncio import CancelledError
Expand All @@ -8,7 +9,6 @@
from aiohttp_apispec import docs
from ipv8.REST.schema import schema
from ipv8.messaging.anonymization.tunnel import Circuit
from ipv8.taskmanager import TaskManager, task
from marshmallow.fields import Dict, String

from tribler.core import notifications
Expand Down Expand Up @@ -39,16 +39,15 @@ def passthrough(x):


@froze_it
class EventsEndpoint(RESTEndpoint, TaskManager):
class EventsEndpoint(RESTEndpoint):
"""
Important events in Tribler are returned over the events endpoint. This connection is held open. Each event is
pushed over this endpoint in the form of a JSON dictionary. Each JSON dictionary contains a type field that
indicates the type of the event. Individual events are separated by a newline character.
"""

def __init__(self, notifier: Notifier, public_key: str = None):
RESTEndpoint.__init__(self)
TaskManager.__init__(self)
super().__init__()
self.events_responses: List[RESTStreamResponse] = []
self.app.on_shutdown.append(self.on_shutdown)
self.undelivered_error: Optional[dict] = None
Expand All @@ -59,7 +58,8 @@ def __init__(self, notifier: Notifier, public_key: str = None):

def on_notification(self, topic, *args, **kwargs):
if topic in topics_to_send_to_gui:
self.write_data({"topic": topic.__name__, "args": args, "kwargs": kwargs})
data = {"topic": topic.__name__, "args": args, "kwargs": kwargs}
self.async_group.add(self.write_data(data))

def on_circuit_removed(self, circuit: Circuit, additional_info: str):
# The original notification contains non-JSON-serializable argument, so we send another one to GUI
Expand All @@ -69,10 +69,7 @@ def on_circuit_removed(self, circuit: Circuit, additional_info: str):
additional_info=additional_info)

async def on_shutdown(self, _):
await self.shutdown_task_manager()

async def shutdown(self):
await self.shutdown_task_manager()
await self.shutdown()
drew2a marked this conversation as resolved.
Show resolved Hide resolved

def setup_routes(self):
self.app.add_routes([web.get('', self.get_events)])
Expand Down Expand Up @@ -101,7 +98,6 @@ def encode_message(self, message: dict) -> bytes:
def has_connection_to_gui(self):
return bool(self.events_responses)

@task
async def write_data(self, message):
"""
Write data over the event socket if it's open.
Expand All @@ -124,7 +120,7 @@ async def write_data(self, message):
def on_tribler_exception(self, reported_error: ReportedError):
message = self.error_message(reported_error)
if self.has_connection_to_gui():
self.write_data(message)
self.async_group.add(self.write_data(message))
elif not self.undelivered_error:
# If there are several undelivered errors, we store the first error as more important and skip other
self.undelivered_error = message
Expand Down Expand Up @@ -170,7 +166,7 @@ async def get_events(self, request):

try:
while True:
await self.register_anonymous_task('event_sleep', lambda: None, delay=3600)
await asyncio.sleep(3600)
except CancelledError:
self.events_responses.remove(response)
return response
29 changes: 27 additions & 2 deletions src/tribler/core/components/restapi/rest/rest_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
from __future__ import annotations

import json
import logging
from typing import Dict, TYPE_CHECKING

from aiohttp import web

from tribler.core.utilities.async_group import AsyncGroup

if TYPE_CHECKING:
from tribler.core.components.restapi.rest.events_endpoint import EventsEndpoint
from ipv8.REST.root_endpoint import RootEndpoint as IPV8RootEndpoint

HTTP_BAD_REQUEST = 400
HTTP_UNAUTHORIZED = 401
HTTP_NOT_FOUND = 404
Expand All @@ -14,16 +23,32 @@ class RESTEndpoint:
def __init__(self, middlewares=()):
self._logger = logging.getLogger(self.__class__.__name__)
self.app = web.Application(middlewares=middlewares, client_max_size=2 * 1024 ** 2)
self.endpoints = {}
self.endpoints: Dict[str, RESTEndpoint] = {}
self.async_group = AsyncGroup()
self.setup_routes()

self._shutdown = False

def setup_routes(self):
pass

def add_endpoint(self, prefix, endpoint):
def add_endpoint(self, prefix: str, endpoint: RESTEndpoint | EventsEndpoint | IPV8RootEndpoint):
self.endpoints[prefix] = endpoint
self.app.add_subapp(prefix, endpoint.app)

async def shutdown(self):
drew2a marked this conversation as resolved.
Show resolved Hide resolved
if self._shutdown:
return
self._shutdown = True

shutdown_group = AsyncGroup()
for endpoint in self.endpoints.values():
if isinstance(endpoint, RESTEndpoint):
shutdown_group.add(endpoint.shutdown()) # IPV8RootEndpoint doesn't have a shutdown method

await shutdown_group.wait()
await self.async_group.cancel()


class RESTResponse(web.Response):

Expand Down
4 changes: 2 additions & 2 deletions src/tribler/core/components/restapi/rest/shutdown_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, shutdown_callback):
self.shutdown_callback = shutdown_callback

def setup_routes(self):
self.app.add_routes([web.put('', self.shutdown)])
self.app.add_routes([web.put('', self.shutdown_request)])

@docs(
tags=["General"],
Expand All @@ -31,7 +31,7 @@ def setup_routes(self):
}
}
)
async def shutdown(self, request):
async def shutdown_request(self, _):
self._logger.info('Received a shutdown request from GUI')
self.shutdown_callback()
return RESTResponse({"shutdown": True})
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ async def test_on_tribler_exception_stores_only_first_error(endpoint, reported_e
assert endpoint.undelivered_error == endpoint.error_message(first_reported_error)


@patch.object(EventsEndpoint, 'register_anonymous_task', new=AsyncMock(side_effect=CancelledError))
@patch('asyncio.sleep', new=AsyncMock(side_effect=CancelledError))
@patch.object(RESTStreamResponse, 'prepare', new=AsyncMock())
@patch.object(RESTStreamResponse, 'write', new_callable=AsyncMock)
@patch.object(EventsEndpoint, 'encode_message')
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from unittest.mock import AsyncMock, patch

from tribler.core.components.restapi.rest.rest_endpoint import RESTEndpoint
from tribler.core.utilities.async_group import AsyncGroup


# pylint: disable=protected-access

async def test_shutdown():
# In this test we check that all coros related to the Root Endpoint are cancelled
# during shutdown

async def coro():
...

root_endpoint = RESTEndpoint()
root_endpoint.async_group.add(coro())

# add 2 child endpoints with a single coro in each:
child_endpoints = [RESTEndpoint(), RESTEndpoint()]
for index, child_endpoint in enumerate(child_endpoints):
root_endpoint.add_endpoint(prefix=f'/{index}', endpoint=child_endpoint)
child_endpoint.async_group.add(coro())

def total_coro_count():
count = 0
for endpoint in child_endpoints + [root_endpoint]:
count += len(endpoint.async_group._futures)
return count

assert total_coro_count() == 3

await root_endpoint.shutdown()

assert total_coro_count() == 0


@patch.object(AsyncGroup, 'cancel', new_callable=AsyncMock)
async def test_multiple_shutdown_calls(async_group_cancel: AsyncMock):
# Test that if shutdown calls twice, only one call is processed
endpoint = RESTEndpoint()

await endpoint.shutdown()
await endpoint.shutdown()

async_group_cancel.assert_called_once()
4 changes: 2 additions & 2 deletions src/tribler/core/components/restapi/restapi_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ def report_callback(reported_error: ReportedError):
async def shutdown(self):
await super().shutdown()

if self._events_endpoint:
await self._events_endpoint.shutdown()
if self.root_endpoint:
await self.root_endpoint.shutdown()

if self._core_exception_handler:
self._core_exception_handler.report_callback = None
Expand Down
5 changes: 4 additions & 1 deletion src/tribler/core/components/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from tribler.core.components.component import Component, ComponentError, ComponentStartupException, \
MultipleComponentsFound
from tribler.core.config.tribler_config import TriblerConfig
from tribler.core.utilities.async_group import AsyncGroup
from tribler.core.utilities.crypto_patcher import patch_crypto_be_discovery
from tribler.core.utilities.install_dir import get_lib_path
from tribler.core.utilities.network_utils import default_network_utils
Expand All @@ -33,6 +34,7 @@ def __init__(self, config: TriblerConfig = None, components: List[Component] = (
self.config: TriblerConfig = config or TriblerConfig()
self.shutdown_event: Event = shutdown_event or Event()
self.notifier: Notifier = notifier or Notifier(loop=get_event_loop())
self.async_group = AsyncGroup()
self.components: Dict[Type[Component], Component] = {}
for component in components:
self.register(component.__class__, component)
Expand Down Expand Up @@ -104,7 +106,7 @@ async def exception_reraiser():
self.logger.info(f'Reraise startup exception: {self._startup_exception}')
raise self._startup_exception

get_event_loop().create_task(exception_reraiser())
self.async_group.add(exception_reraiser())

def set_startup_exception(self, exc: Exception):
if not self._startup_exception:
Expand All @@ -113,6 +115,7 @@ def set_startup_exception(self, exc: Exception):
async def shutdown(self):
self.logger.info("Stopping components")
await gather(*[create_task(component.stop()) for component in self.components.values()])
await self.async_group.cancel()
self.logger.info("All components are stopped")


Expand Down
8 changes: 6 additions & 2 deletions src/tribler/core/utilities/async_group.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from asyncio import CancelledError, Future
from asyncio import CancelledError, Future, Task
from contextlib import suppress
from typing import Iterable, List, Set

Expand All @@ -24,13 +24,17 @@ class AsyncGroup:
def __init__(self):
self._futures: Set[Future] = set()

def add(self, *coroutines):
def add(self, *coroutines) -> List[Task]:
"""Add a coroutine to the group.
"""
result = []
for coroutine in coroutines:
task = asyncio.create_task(coroutine)
self._futures.add(task)
task.add_done_callback(self._done_callback)
result.append(task)

return result

async def wait(self):
""" Wait for completion of all futures
Expand Down
6 changes: 4 additions & 2 deletions src/tribler/core/utilities/tests/test_async_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,23 @@ async def raise_exception():


async def test_add_single_coro(group: AsyncGroup):
group.add(
tasks = group.add(
void()
)

assert len(group._futures) == 1
assert len(tasks) == 1


async def test_add_iterable(group: AsyncGroup):
group.add(
tasks = group.add(
void(),
void(),
void()
)

assert len(group._futures) == 3
assert len(tasks) == 3


async def test_cancel(group: AsyncGroup):
Expand Down
14 changes: 10 additions & 4 deletions src/tribler/core/utilities/tiny_tribler_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tribler.core.components.component import Component
from tribler.core.components.session import Session
from tribler.core.config.tribler_config import TriblerConfig
from tribler.core.utilities.async_group import AsyncGroup
from tribler.core.utilities.osutils import get_root_state_directory
from tribler.core.utilities.process_manager import ProcessKind, ProcessManager, TriblerProcess, \
set_global_process_manager
Expand All @@ -27,6 +28,8 @@ def __init__(self, components: List[Component], timeout_in_sec=None, state_dir=P
self.config = TriblerConfig(state_dir=state_dir.absolute())
self.timeout_in_sec = timeout_in_sec
self.components = components
self.async_group = AsyncGroup()
self._main_task = None

async def on_tribler_started(self):
"""Function will calls after the Tribler session is started
Expand All @@ -42,7 +45,7 @@ async def start_tribler():
await self._start_session()

if self.timeout_in_sec:
asyncio.create_task(self._terminate_by_timeout())
self.async_group.add(self._terminate_by_timeout())

self._enable_graceful_shutdown()
await self.on_tribler_started()
Expand All @@ -51,7 +54,9 @@ async def start_tribler():
if fragile:
make_async_loop_fragile(loop)

loop.create_task(start_tribler())
# the variable `self._main_task` is used here to prevent a naked `loop.create_task()` call
# more details: https://github.com/Tribler/tribler/issues/7299
self._main_task = loop.create_task(start_tribler())
try:
loop.run_forever()
finally:
Expand Down Expand Up @@ -97,8 +102,9 @@ async def _terminate_by_timeout(self):

def _graceful_shutdown(self):
self.logger.info("Shutdown gracefully")
task = asyncio.create_task(self.session.shutdown())
task.add_done_callback(lambda result: self._stop_event_loop())
tasks = self.async_group.add(self.session.shutdown())
drew2a marked this conversation as resolved.
Show resolved Hide resolved
shutdown_task = tasks[0]
shutdown_task.add_done_callback(lambda result: self._stop_event_loop())

def _stop_event_loop(self):
asyncio.get_running_loop().stop()
Expand Down