Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the libtorrent tests more responsive #7541

Merged
merged 4 commits into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
"""
import asyncio
import base64
import itertools
import logging
from asyncio import CancelledError, Future, iscoroutine, sleep, wait_for
from collections import defaultdict
from contextlib import suppress
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple

from bitarray import bitarray
Expand All @@ -24,13 +26,16 @@
from tribler.core.components.libtorrent.utils.torrent_utils import check_handle, get_info_from_handle, require_handle
from tribler.core.components.reporter.exception_handler import NoCrashException
from tribler.core.exceptions import SaveResumeDataError
from tribler.core.utilities.async_force_switch import switch
from tribler.core.utilities.notifier import Notifier
from tribler.core.utilities.osutils import fix_filebasename
from tribler.core.utilities.path_util import Path
from tribler.core.utilities.simpledefs import DOWNLOAD, DownloadStatus
from tribler.core.utilities.unicode import ensure_unicode, hexlify
from tribler.core.utilities.utilities import bdecode_compat

Getter = Optional[Callable[[Any], Any]]


class Download(TaskManager):
""" Download subclass that represents a libtorrent download."""
Expand Down Expand Up @@ -64,7 +69,7 @@ def __init__(self,
self.checkpoint_after_next_hashcheck = False
self.tracker_status = {} # {url: [num_peers, status_str]}

self.futures = defaultdict(list)
self.futures: Dict[str, list[tuple[Future, Callable, Getter]]] = defaultdict(list)
self.alert_handlers = defaultdict(list)

self.future_added = self.wait_for_alert('add_torrent_alert', lambda a: a.handle)
Expand Down Expand Up @@ -100,7 +105,7 @@ def __init__(self,

def __str__(self):
return "Download <name: '%s' hops: %d checkpoint_disabled: %d>" % \
(self.tdef.get_name(), self.config.get_hops(), self.checkpoint_disabled)
(self.tdef.get_name(), self.config.get_hops(), self.checkpoint_disabled)

def __repr__(self):
return self.__str__()
Expand All @@ -123,8 +128,8 @@ def get_torrent_data(self) -> Optional[object]:
def register_alert_handler(self, alert_type: str, handler: lt.torrent_handle):
self.alert_handlers[alert_type].append(handler)

def wait_for_alert(self, success_type: str, success_getter: Optional[Callable[[Any], Any]] = None,
fail_type: str = None, fail_getter: Optional[Callable[[Any], Any]] = None) -> Future:
def wait_for_alert(self, success_type: str, success_getter: Getter = None,
drew2a marked this conversation as resolved.
Show resolved Hide resolved
fail_type: str = None, fail_getter: Getter = None) -> Future:
future = Future()
if success_type:
self.futures[success_type].append((future, future.set_result, success_getter))
Expand All @@ -134,6 +139,7 @@ def wait_for_alert(self, success_type: str, success_getter: Optional[Callable[[A

async def wait_for_status(self, *status):
while self.get_state().get_status() not in status:
await switch()
await self.wait_for_alert('state_changed_alert')

def get_def(self) -> TorrentDef:
Expand All @@ -143,10 +149,12 @@ def get_handle(self) -> Awaitable[lt.torrent_handle]:
"""
Returns a deferred that fires with a valid libtorrent download handle.
"""
if self.handle and self.handle.is_valid():
if self.handle:
# This block could be safely omitted because `self.future_added` does the same thing.
# However, it is used in tests, therefore it is better to keep it for now.
return succeed(self.handle)

return self.wait_for_alert('add_torrent_alert', lambda a: a.handle)
return self.future_added

def get_atp(self) -> Dict:
save_path = self.config.get_dest_dir()
Expand Down Expand Up @@ -631,12 +639,16 @@ async def state_callback_loop():
return self.register_anonymous_task("downloads_cb", state_callback_loop)

async def shutdown(self):
self._logger.info('Shutting down...')
self.alert_handlers.clear()
if self.stream is not None:
self.stream.close()
for _, futures in self.futures.items():
for future, _, _ in futures:
future.cancel()

active_futures = [f for f, _, _ in itertools.chain(*self.futures.values()) if not f.done()]
for future in active_futures:
future.cancel()
with suppress(CancelledError):
await asyncio.gather(*active_futures) # wait for futures to be actually cancelled
self.futures.clear()
await self.shutdown_task_manager()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def create_session(self, hops=0, store_listen_port=True):
ltsession.add_dht_router(*router)
ltsession.start_lsd()

self._logger.debug("Started libtorrent session for %d hops on port %d", hops, ltsession.listen_port())
self._logger.info(f"Started libtorrent session for {hops} hops on port {ltsession.listen_port()}")
self.lt_session_shutdown_ready[hops] = False

return ltsession
Expand Down
11 changes: 11 additions & 0 deletions src/tribler/core/components/libtorrent/tests/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,3 +490,14 @@ def test_get_tracker_status_get_peer_info_error(test_download: Download):
)
status = test_download.get_tracker_status()
assert status


async def test_shutdown(test_download: Download):
""" Test that the `shutdown` method closes the stream and clears the `futures` list."""
test_download.stream = Mock()
assert len(test_download.futures) == 4

await test_download.shutdown()

assert not test_download.futures
assert test_download.stream.close.called
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async def fake_dlmgr(tmp_path_factory):
dlmgr.metadata_tmpdir = tmp_path_factory.mktemp('metadata_tmpdir')
dlmgr.get_session = lambda *_, **__: MagicMock()
yield dlmgr
await dlmgr.shutdown(timeout=0)
await dlmgr.shutdown()


async def test_get_metainfo_valid_metadata(fake_dlmgr):
Expand Down
7 changes: 6 additions & 1 deletion src/tribler/core/utilities/async_force_switch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
import functools


async def switch():
""" Coroutine that yields control to the event loop."""
await asyncio.sleep(0)


def force_switch(func):
"""Decorator for forced coroutine switch. The switch will occur before calling the function.

Expand All @@ -11,7 +16,7 @@ def force_switch(func):

@functools.wraps(func)
async def wrapper(*args, **kwargs):
await asyncio.sleep(0)
await switch()
return await func(*args, **kwargs)

return wrapper
Expand Down