diff --git a/src/tribler/core/libtorrent/download_manager/download.py b/src/tribler/core/libtorrent/download_manager/download.py index cbf638510e..c6051cf6c0 100644 --- a/src/tribler/core/libtorrent/download_manager/download.py +++ b/src/tribler/core/libtorrent/download_manager/download.py @@ -864,7 +864,7 @@ def set_def(self, tdef: TorrentDef) -> None: self.tdef = tdef @check_handle(None) - def add_trackers(self, trackers: list[str]) -> None: + def add_trackers(self, trackers: list[bytes]) -> None: """ Add the given trackers to the handle. """ diff --git a/src/tribler/core/libtorrent/download_manager/download_manager.py b/src/tribler/core/libtorrent/download_manager/download_manager.py index 2811241728..f2a011fbc4 100644 --- a/src/tribler/core/libtorrent/download_manager/download_manager.py +++ b/src/tribler/core/libtorrent/download_manager/download_manager.py @@ -10,7 +10,7 @@ import logging import os import time -from asyncio import CancelledError, gather, iscoroutine, shield, sleep, wait_for +from asyncio import CancelledError, Future, gather, iscoroutine, shield, sleep, wait_for from binascii import hexlify, unhexlify from collections import defaultdict from copy import deepcopy @@ -91,7 +91,7 @@ def __init__(self, config: TriblerConfigManager, notifier: Notifier, self.state_dir = Path(config.get_version_state_dir()) self.ltsettings: dict[lt.session, dict] = {} # Stores a copy of the settings dict for each libtorrent session - self.ltsessions: dict[int, lt.session] = {} + self.ltsessions: dict[int, Future[lt.session]] = {} self.dht_health_manager: DHTHealthManager | None = None self.listen_ports: dict[int, dict[str, int]] = defaultdict(dict) @@ -168,7 +168,7 @@ async def _check_dht_ready(self, min_dht_peers: int = 60) -> None: See https://github.com/Tribler/tribler/issues/5319 """ - while not (self.get_session() and self.get_session().status().dht_nodes > min_dht_peers): + while (await self.get_session()).status().dht_nodes < min_dht_peers: await asyncio.sleep(1) def initialize(self) -> None: @@ -180,7 +180,7 @@ def initialize(self) -> None: # Start upnp if self.config.get("libtorrent/upnp"): - self.get_session().start_upnp() + self.get_session().add_done_callback(lambda s: s.result().start_upnp()) # Register tasks self.register_task("process_alerts", self._task_process_alerts, interval=1, ignore=(Exception, )) @@ -245,13 +245,14 @@ async def shutdown(self, timeout: int = 30) -> None: if self.has_session(): logger.info("Saving state...") self.notify_shutdown_state("Writing session state to disk.") + session = await self.get_session() with open(self.state_dir / LTSTATE_FILENAME, "wb") as ltstate_file: # noqa: ASYNC230 - ltstate_file.write(lt.bencode(self.get_session().save_state())) + ltstate_file.write(lt.bencode(session.save_state())) if self.has_session() and self.config.get("libtorrent/upnp"): logger.info("Stopping upnp...") self.notify_shutdown_state("Stopping UPnP.") - self.get_session().stop_upnp() + self.get_session().add_done_callback(lambda s: s.result().stop_upnp()) # Remove metadata temporary directory if self.metadata_tmpdir: @@ -357,12 +358,12 @@ def has_session(self, hops: int = 0) -> bool: """ return hops in self.ltsessions - def get_session(self, hops: int = 0) -> lt.session: + def get_session(self, hops: int = 0) -> Future[lt.session]: """ Get the session for the given number of anonymization hops. """ if hops not in self.ltsessions: - self.ltsessions[hops] = self.create_session(hops) + self.ltsessions[hops] = self.register_executor_task(f"Create session {hops}", self.create_session, hops) return self.ltsessions[hops] @@ -400,15 +401,16 @@ def set_upload_rate_limit(self, rate: int) -> None: # Pass outgoing_port and num_outgoing_ports to dict due to bug in libtorrent 0.16.18 settings_dict = {"upload_rate_limit": libtorrent_rate, "outgoing_port": 0, "num_outgoing_ports": 1} for session in self.ltsessions.values(): - self.set_session_settings(session, settings_dict) + session.add_done_callback(lambda s: self.set_session_settings(s.result(), settings_dict)) - def get_upload_rate_limit(self, hops: int = 0) -> int: + async def get_upload_rate_limit(self, hops: int = 0) -> int: """ Get the upload rate limit for the session with the given hop count. """ # Rate conversion due to the fact that we had a different system with Swift # and the old python BitTorrent core: unlimited == 0, stop == -1, else rate in kbytes - libtorrent_rate = self.get_session(hops).upload_rate_limit() + session = await self.get_session(hops) + libtorrent_rate = session.upload_rate_limit() return self.reverse_convert_rate(rate=libtorrent_rate) def set_download_rate_limit(self, rate: int) -> None: @@ -420,13 +422,14 @@ def set_download_rate_limit(self, rate: int) -> None: # Pass outgoing_port and num_outgoing_ports to dict due to bug in libtorrent 0.16.18 settings_dict = {"download_rate_limit": libtorrent_rate} for session in self.ltsessions.values(): - self.set_session_settings(session, settings_dict) + session.add_done_callback(lambda s: self.set_session_settings(s.result(), settings_dict)) - def get_download_rate_limit(self, hops: int = 0) -> int: + async def get_download_rate_limit(self, hops: int = 0) -> int: """ Get the download rate limit for the session with the given hop count. """ - libtorrent_rate = self.get_session(hops=hops).download_rate_limit() + session = await self.get_session(hops) + libtorrent_rate = session.download_rate_limit() return self.reverse_convert_rate(rate=libtorrent_rate) def process_alert(self, alert: lt.alert, hops: int = 0) -> None: # noqa: C901, PLR0912 @@ -597,14 +600,12 @@ def _task_cleanup_metainfo_cache(self) -> None: def _request_torrent_updates(self) -> None: for ltsession in self.ltsessions.values(): - if ltsession: - ltsession.post_torrent_updates(0xffffffff) + ltsession.add_done_callback(lambda s: s.result().post_torrent_updates(0xffffffff)) - def _task_process_alerts(self) -> None: + async def _task_process_alerts(self) -> None: for hops, ltsession in list(self.ltsessions.items()): - if ltsession: - for alert in ltsession.pop_alerts(): - self.process_alert(alert, hops=hops) + for alert in (await ltsession).pop_alerts(): + self.process_alert(alert, hops=hops) def _map_call_on_ltsessions(self, hops: int | None, funcname: str, *args: Any, **kwargs) -> None: # noqa: ANN401 if hops is None: @@ -729,7 +730,7 @@ async def start_handle(self, download: Download, atp: dict) -> None: if resume_data: logger.debug("Download resume data: %s", str(atp["resume_data"])) - ltsession = self.get_session(download.config.get_hops()) + ltsession = await self.get_session(download.config.get_hops()) infohash = download.get_def().get_infohash() if infohash in self.metainfo_requests and self.metainfo_requests[infohash].download != download: @@ -815,7 +816,7 @@ def update_max_rates_from_config(self) -> None: download_rate = DownloadManager.get_libtorrent_max_download_rate(self.config) settings = {"download_rate_limit": download_rate, "upload_rate_limit": rate} - self.set_session_settings(lt_session, settings) + lt_session.add_done_callback(lambda s: self.set_session_settings(s.result(), settings)) def post_session_stats(self) -> None: """ @@ -823,7 +824,7 @@ def post_session_stats(self) -> None: """ logger.info("Post session stats") for session in self.ltsessions.values(): - session.post_session_stats() + session.add_done_callback(lambda s: s.result().post_session_stats()) async def remove_download(self, download: Download, remove_content: bool = False, remove_checkpoint: bool = True) -> None: @@ -842,7 +843,7 @@ async def remove_download(self, download: Download, remove_content: bool = False download.stream.disable() logger.debug("Removing handle %s", hexlify(infohash)) ltsession = self.get_session(download.config.get_hops()) - ltsession.remove_torrent(handle, int(remove_content)) + ltsession.add_done_callback(lambda s: s.result().remove_torrent(handle, int(remove_content))) else: logger.debug("Cannot remove handle %s because it does not exists", hexlify(infohash)) await download.shutdown() @@ -889,7 +890,7 @@ async def update_hops(self, download: Download, new_hops: int) -> None: await self.start_download(tdef=download.tdef, config=config) - def update_trackers(self, infohash: bytes, trackers: list[str]) -> None: + def update_trackers(self, infohash: bytes, trackers: list[bytes]) -> None: """ Update the trackers for a download. diff --git a/src/tribler/test_unit/core/libtorrent/download_manager/test_download_manager.py b/src/tribler/test_unit/core/libtorrent/download_manager/test_download_manager.py index 0a1ec846e2..5438dbadeb 100644 --- a/src/tribler/test_unit/core/libtorrent/download_manager/test_download_manager.py +++ b/src/tribler/test_unit/core/libtorrent/download_manager/test_download_manager.py @@ -44,8 +44,10 @@ def setUp(self) -> None: """ super().setUp() self.manager = DownloadManager(MockTriblerConfigManager(), Notifier(), Mock()) - self.manager.ltsessions = {i: Mock(status=Mock(dht_nodes=0), get_torrents=Mock(return_value=[])) - for i in range(4)} + for i in range(4): + fut = Future() + fut.set_result(Mock(status=Mock(dht_nodes=0), get_torrents=Mock(return_value=[]))) + self.manager.ltsessions[i] = fut self.manager.set_download_states_callback(self.manager.sesscb_states_callback) async def tearDown(self) -> None: @@ -158,7 +160,7 @@ async def test_start_download(self) -> None: mock_alert = type("add_torrent_alert", (object,), {"handle": mock_handle, "error": Mock(value=Mock(return_value=None)), "category": MagicMock(return_value=None)}) - self.manager.ltsessions[0].async_add_torrent = lambda _: self.manager.process_alert(mock_alert()) + self.manager.ltsessions[0].result().async_add_torrent = lambda _: self.manager.process_alert(mock_alert()) with patch.object(self.manager, "remove_download", AsyncMock()): download = await self.manager.start_download(tdef=TorrentDefNoMetainfo(b"\x01" * 20, b""), @@ -201,7 +203,7 @@ async def test_start_download_existing_handle(self) -> None: """ mock_handle = Mock(info_hash=Mock(return_value=Mock(to_bytes=Mock(return_value=b"\x01" * 20))), is_valid=Mock(return_value=True)) - self.manager.ltsessions[0].get_torrents = Mock(return_value=[mock_handle]) + self.manager.ltsessions[0].result().get_torrents = Mock(return_value=[mock_handle]) download = await self.manager.start_download(tdef=TorrentDefNoMetainfo(b"\x01" * 20, b"name"), config=self.create_mock_download_config(), checkpoint_disabled=True) @@ -268,19 +270,21 @@ def test_set_proxy_settings(self) -> None: """ Test if the proxy settings can be set. """ - self.manager.set_proxy_settings(self.manager.get_session(0), 0, ("a", "1234"), ("abc", "def")) + self.manager.set_proxy_settings(self.manager.get_session(0).result(), 0, ("a", "1234"), ("abc", "def")) self.assertEqual(call({"proxy_type": 0, "proxy_hostnames": True, "proxy_peer_connections": True, "proxy_hostname": "a", "proxy_port": 1234, "proxy_username": "abc", - "proxy_password": "def"}), self.manager.ltsessions[0].apply_settings.call_args) + "proxy_password": "def"}), self.manager.ltsessions[0].result().apply_settings.call_args) - def test_post_session_stats(self) -> None: + async def test_post_session_stats(self) -> None: """ Test if post_session_stats actually updates the state of libtorrent readiness for clean shutdown. """ self.manager.post_session_stats() - self.manager.ltsessions[0].post_session_stats.assert_called_once() + await sleep(0) + + self.manager.ltsessions[0].result().post_session_stats.assert_called_once() async def test_load_checkpoint_no_metainfo(self) -> None: """ @@ -475,26 +479,30 @@ def test_update_trackers_list_append(self) -> None: self.assertSetEqual({f"127.0.0.1/test-announce{i}".encode() for i in range(2)}, {announce_url[0] for announce_url in download.tdef.metainfo[b"announce-list"]}) - def test_get_download_rate_limit(self) -> None: + async def test_get_download_rate_limit(self) -> None: """ Test if the download rate limit can be set. """ settings = {} - self.manager.ltsessions[0].get_settings = Mock(return_value=settings) - self.manager.ltsessions[0].download_rate_limit = functools.partial(settings.get, "download_rate_limit") + self.manager.ltsessions[0].result().get_settings = Mock(return_value=settings) + self.manager.ltsessions[0].result().download_rate_limit = functools.partial(settings.get, "download_rate_limit") self.manager.set_download_rate_limit(42) - self.assertEqual(42, self.manager.get_download_rate_limit()) + await sleep(0) + + self.assertEqual(42, await self.manager.get_download_rate_limit()) - def test_get_upload_rate_limit(self) -> None: + async def test_get_upload_rate_limit(self) -> None: """ Test if the upload rate limit can be set. """ settings = {} - self.manager.ltsessions[0].get_settings = Mock(return_value=settings) - self.manager.ltsessions[0].upload_rate_limit = functools.partial(settings.get, "upload_rate_limit") + self.manager.ltsessions[0].result().get_settings = Mock(return_value=settings) + self.manager.ltsessions[0].result().upload_rate_limit = functools.partial(settings.get, "upload_rate_limit") self.manager.set_upload_rate_limit(42) - self.assertEqual(42, self.manager.get_upload_rate_limit()) + await sleep(0) + + self.assertEqual(42, await self.manager.get_upload_rate_limit())