Skip to content

Commit

Permalink
Merge pull request #7541 from drew2a/fix/libtorrent_tests
Browse files Browse the repository at this point in the history
Make the `libtorrent` tests more responsive
  • Loading branch information
drew2a authored Jul 17, 2023
2 parents 27623ee + cc1e93b commit 6ae3027
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
"""
import asyncio
import base64
import itertools
import logging
from asyncio import CancelledError, Future, iscoroutine, sleep, wait_for
from collections import defaultdict
from contextlib import suppress
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple

from bitarray import bitarray
Expand All @@ -24,13 +26,16 @@
from tribler.core.components.libtorrent.utils.torrent_utils import check_handle, get_info_from_handle, require_handle
from tribler.core.components.reporter.exception_handler import NoCrashException
from tribler.core.exceptions import SaveResumeDataError
from tribler.core.utilities.async_force_switch import switch
from tribler.core.utilities.notifier import Notifier
from tribler.core.utilities.osutils import fix_filebasename
from tribler.core.utilities.path_util import Path
from tribler.core.utilities.simpledefs import DOWNLOAD, DownloadStatus
from tribler.core.utilities.unicode import ensure_unicode, hexlify
from tribler.core.utilities.utilities import bdecode_compat

Getter = Callable[[Any], Any]


class Download(TaskManager):
""" Download subclass that represents a libtorrent download."""
Expand Down Expand Up @@ -64,7 +69,7 @@ def __init__(self,
self.checkpoint_after_next_hashcheck = False
self.tracker_status = {} # {url: [num_peers, status_str]}

self.futures = defaultdict(list)
self.futures: Dict[str, list[tuple[Future, Callable, Optional[Getter]]]] = defaultdict(list)
self.alert_handlers = defaultdict(list)

self.future_added = self.wait_for_alert('add_torrent_alert', lambda a: a.handle)
Expand Down Expand Up @@ -100,7 +105,7 @@ def __init__(self,

def __str__(self):
return "Download <name: '%s' hops: %d checkpoint_disabled: %d>" % \
(self.tdef.get_name(), self.config.get_hops(), self.checkpoint_disabled)
(self.tdef.get_name(), self.config.get_hops(), self.checkpoint_disabled)

def __repr__(self):
return self.__str__()
Expand All @@ -123,8 +128,8 @@ def get_torrent_data(self) -> Optional[object]:
def register_alert_handler(self, alert_type: str, handler: lt.torrent_handle):
self.alert_handlers[alert_type].append(handler)

def wait_for_alert(self, success_type: str, success_getter: Optional[Callable[[Any], Any]] = None,
fail_type: str = None, fail_getter: Optional[Callable[[Any], Any]] = None) -> Future:
def wait_for_alert(self, success_type: str, success_getter: Optional[Getter] = None,
fail_type: str = None, fail_getter: Optional[Getter] = None) -> Future:
future = Future()
if success_type:
self.futures[success_type].append((future, future.set_result, success_getter))
Expand All @@ -134,6 +139,7 @@ def wait_for_alert(self, success_type: str, success_getter: Optional[Callable[[A

async def wait_for_status(self, *status):
while self.get_state().get_status() not in status:
await switch()
await self.wait_for_alert('state_changed_alert')

def get_def(self) -> TorrentDef:
Expand All @@ -143,10 +149,12 @@ def get_handle(self) -> Awaitable[lt.torrent_handle]:
"""
Returns a deferred that fires with a valid libtorrent download handle.
"""
if self.handle and self.handle.is_valid():
if self.handle:
# This block could be safely omitted because `self.future_added` does the same thing.
# However, it is used in tests, therefore it is better to keep it for now.
return succeed(self.handle)

return self.wait_for_alert('add_torrent_alert', lambda a: a.handle)
return self.future_added

def get_atp(self) -> Dict:
save_path = self.config.get_dest_dir()
Expand Down Expand Up @@ -631,12 +639,16 @@ async def state_callback_loop():
return self.register_anonymous_task("downloads_cb", state_callback_loop)

async def shutdown(self):
self._logger.info('Shutting down...')
self.alert_handlers.clear()
if self.stream is not None:
self.stream.close()
for _, futures in self.futures.items():
for future, _, _ in futures:
future.cancel()

active_futures = [f for f, _, _ in itertools.chain(*self.futures.values()) if not f.done()]
for future in active_futures:
future.cancel()
with suppress(CancelledError):
await asyncio.gather(*active_futures) # wait for futures to be actually cancelled
self.futures.clear()
await self.shutdown_task_manager()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def create_session(self, hops=0, store_listen_port=True):
ltsession.add_dht_router(*router)
ltsession.start_lsd()

self._logger.debug("Started libtorrent session for %d hops on port %d", hops, ltsession.listen_port())
self._logger.info(f"Started libtorrent session for {hops} hops on port {ltsession.listen_port()}")
self.lt_session_shutdown_ready[hops] = False

return ltsession
Expand Down
11 changes: 11 additions & 0 deletions src/tribler/core/components/libtorrent/tests/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,3 +490,14 @@ def test_get_tracker_status_get_peer_info_error(test_download: Download):
)
status = test_download.get_tracker_status()
assert status


async def test_shutdown(test_download: Download):
""" Test that the `shutdown` method closes the stream and clears the `futures` list."""
test_download.stream = Mock()
assert len(test_download.futures) == 4

await test_download.shutdown()

assert not test_download.futures
assert test_download.stream.close.called
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async def fake_dlmgr(tmp_path_factory):
dlmgr.metadata_tmpdir = tmp_path_factory.mktemp('metadata_tmpdir')
dlmgr.get_session = lambda *_, **__: MagicMock()
yield dlmgr
await dlmgr.shutdown(timeout=0)
await dlmgr.shutdown()


async def test_get_metainfo_valid_metadata(fake_dlmgr):
Expand Down
7 changes: 6 additions & 1 deletion src/tribler/core/utilities/async_force_switch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
import functools


async def switch():
""" Coroutine that yields control to the event loop."""
await asyncio.sleep(0)


def force_switch(func):
"""Decorator for forced coroutine switch. The switch will occur before calling the function.
Expand All @@ -11,7 +16,7 @@ def force_switch(func):

@functools.wraps(func)
async def wrapper(*args, **kwargs):
await asyncio.sleep(0)
await switch()
return await func(*args, **kwargs)

return wrapper
Expand Down

0 comments on commit 6ae3027

Please sign in to comment.