Skip to content

Commit

Permalink
Made lt session creation async
Browse files Browse the repository at this point in the history
  • Loading branch information
qstokkink committed Aug 23, 2024
1 parent 8fd1051 commit b05fd19
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 46 deletions.
2 changes: 1 addition & 1 deletion src/tribler/core/libtorrent/download_manager/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ def set_def(self, tdef: TorrentDef) -> None:
self.tdef = tdef

@check_handle(None)
def add_trackers(self, trackers: list[str]) -> None:
def add_trackers(self, trackers: list[bytes]) -> None:
"""
Add the given trackers to the handle.
"""
Expand Down
59 changes: 30 additions & 29 deletions src/tribler/core/libtorrent/download_manager/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import logging
import os
import time
from asyncio import CancelledError, gather, iscoroutine, shield, sleep, wait_for
from asyncio import CancelledError, Future, gather, iscoroutine, shield, sleep, wait_for
from binascii import hexlify, unhexlify
from collections import defaultdict
from copy import deepcopy
Expand Down Expand Up @@ -91,7 +91,7 @@ def __init__(self, config: TriblerConfigManager, notifier: Notifier,

self.state_dir = Path(config.get_version_state_dir())
self.ltsettings: dict[lt.session, dict] = {} # Stores a copy of the settings dict for each libtorrent session
self.ltsessions: dict[int, lt.session] = {}
self.ltsessions: dict[int, Future[lt.session]] = {}
self.dht_health_manager: DHTHealthManager | None = None
self.listen_ports: dict[int, dict[str, int]] = defaultdict(dict)

Expand Down Expand Up @@ -168,7 +168,7 @@ async def _check_dht_ready(self, min_dht_peers: int = 60) -> None:
See https://github.com/Tribler/tribler/issues/5319
"""
while not (self.get_session() and self.get_session().status().dht_nodes > min_dht_peers):
while (await self.get_session()).status().dht_nodes < min_dht_peers:
await asyncio.sleep(1)

def initialize(self) -> None:
Expand All @@ -180,7 +180,7 @@ def initialize(self) -> None:

# Start upnp
if self.config.get("libtorrent/upnp"):
self.get_session().start_upnp()
self.get_session().add_done_callback(lambda s: s.result().start_upnp())

# Register tasks
self.register_task("process_alerts", self._task_process_alerts, interval=1, ignore=(Exception, ))
Expand Down Expand Up @@ -245,13 +245,14 @@ async def shutdown(self, timeout: int = 30) -> None:
if self.has_session():
logger.info("Saving state...")
self.notify_shutdown_state("Writing session state to disk.")
session = await self.get_session()
with open(self.state_dir / LTSTATE_FILENAME, "wb") as ltstate_file: # noqa: ASYNC230
ltstate_file.write(lt.bencode(self.get_session().save_state()))
ltstate_file.write(lt.bencode(session.save_state()))

if self.has_session() and self.config.get("libtorrent/upnp"):
logger.info("Stopping upnp...")
self.notify_shutdown_state("Stopping UPnP.")
self.get_session().stop_upnp()
self.get_session().add_done_callback(lambda s: s.result().stop_upnp())

# Remove metadata temporary directory
if self.metadata_tmpdir:
Expand Down Expand Up @@ -357,12 +358,12 @@ def has_session(self, hops: int = 0) -> bool:
"""
return hops in self.ltsessions

def get_session(self, hops: int = 0) -> lt.session:
def get_session(self, hops: int = 0) -> Future[lt.session]:
"""
Get the session for the given number of anonymization hops.
"""
if hops not in self.ltsessions:
self.ltsessions[hops] = self.create_session(hops)
self.ltsessions[hops] = self.register_executor_task(f"Create session {hops}", self.create_session, hops)

return self.ltsessions[hops]

Expand Down Expand Up @@ -400,15 +401,16 @@ def set_upload_rate_limit(self, rate: int) -> None:
# Pass outgoing_port and num_outgoing_ports to dict due to bug in libtorrent 0.16.18
settings_dict = {"upload_rate_limit": libtorrent_rate, "outgoing_port": 0, "num_outgoing_ports": 1}
for session in self.ltsessions.values():
self.set_session_settings(session, settings_dict)
session.add_done_callback(lambda s: self.set_session_settings(s.result(), settings_dict))

def get_upload_rate_limit(self, hops: int = 0) -> int:
async def get_upload_rate_limit(self, hops: int = 0) -> int:
"""
Get the upload rate limit for the session with the given hop count.
"""
# Rate conversion due to the fact that we had a different system with Swift
# and the old python BitTorrent core: unlimited == 0, stop == -1, else rate in kbytes
libtorrent_rate = self.get_session(hops).upload_rate_limit()
session = await self.get_session(hops)
libtorrent_rate = session.upload_rate_limit()
return self.reverse_convert_rate(rate=libtorrent_rate)

def set_download_rate_limit(self, rate: int) -> None:
Expand All @@ -420,13 +422,14 @@ def set_download_rate_limit(self, rate: int) -> None:
# Pass outgoing_port and num_outgoing_ports to dict due to bug in libtorrent 0.16.18
settings_dict = {"download_rate_limit": libtorrent_rate}
for session in self.ltsessions.values():
self.set_session_settings(session, settings_dict)
session.add_done_callback(lambda s: self.set_session_settings(s.result(), settings_dict))

def get_download_rate_limit(self, hops: int = 0) -> int:
async def get_download_rate_limit(self, hops: int = 0) -> int:
"""
Get the download rate limit for the session with the given hop count.
"""
libtorrent_rate = self.get_session(hops=hops).download_rate_limit()
session = await self.get_session(hops)
libtorrent_rate = session.download_rate_limit()
return self.reverse_convert_rate(rate=libtorrent_rate)

def process_alert(self, alert: lt.alert, hops: int = 0) -> None: # noqa: C901, PLR0912
Expand Down Expand Up @@ -597,14 +600,12 @@ def _task_cleanup_metainfo_cache(self) -> None:

def _request_torrent_updates(self) -> None:
for ltsession in self.ltsessions.values():
if ltsession:
ltsession.post_torrent_updates(0xffffffff)
ltsession.add_done_callback(lambda s: s.result().post_torrent_updates(0xffffffff))

def _task_process_alerts(self) -> None:
async def _task_process_alerts(self) -> None:
for hops, ltsession in list(self.ltsessions.items()):
if ltsession:
for alert in ltsession.pop_alerts():
self.process_alert(alert, hops=hops)
for alert in (await ltsession).pop_alerts():
self.process_alert(alert, hops=hops)

def _map_call_on_ltsessions(self, hops: int | None, funcname: str, *args: Any, **kwargs) -> None: # noqa: ANN401
if hops is None:
Expand Down Expand Up @@ -729,7 +730,7 @@ async def start_handle(self, download: Download, atp: dict) -> None:
if resume_data:
logger.debug("Download resume data: %s", str(atp["resume_data"]))

ltsession = self.get_session(download.config.get_hops())
ltsession = await self.get_session(download.config.get_hops())
infohash = download.get_def().get_infohash()

if infohash in self.metainfo_requests and self.metainfo_requests[infohash].download != download:
Expand Down Expand Up @@ -810,20 +811,20 @@ def update_max_rates_from_config(self) -> None:
This is the extra step necessary to apply a new maximum download/upload rate setting.
:return:
"""
rate = DownloadManager.get_libtorrent_max_upload_rate(self.config)
download_rate = DownloadManager.get_libtorrent_max_download_rate(self.config)
settings = {"download_rate_limit": download_rate,
"upload_rate_limit": rate}
for lt_session in self.ltsessions.values():
rate = DownloadManager.get_libtorrent_max_upload_rate(self.config)
download_rate = DownloadManager.get_libtorrent_max_download_rate(self.config)
settings = {"download_rate_limit": download_rate,
"upload_rate_limit": rate}
self.set_session_settings(lt_session, settings)
lt_session.add_done_callback(lambda s: self.set_session_settings(s.result(), settings))

def post_session_stats(self) -> None:
"""
Gather statistics and cause a ``session_stats_alert``.
"""
logger.info("Post session stats")
for session in self.ltsessions.values():
session.post_session_stats()
session.add_done_callback(lambda s: s.result().post_session_stats())

async def remove_download(self, download: Download, remove_content: bool = False,
remove_checkpoint: bool = True) -> None:
Expand All @@ -842,7 +843,7 @@ async def remove_download(self, download: Download, remove_content: bool = False
download.stream.disable()
logger.debug("Removing handle %s", hexlify(infohash))
ltsession = self.get_session(download.config.get_hops())
ltsession.remove_torrent(handle, int(remove_content))
ltsession.add_done_callback(lambda s: s.result().remove_torrent(handle, int(remove_content)))
else:
logger.debug("Cannot remove handle %s because it does not exists", hexlify(infohash))
await download.shutdown()
Expand Down Expand Up @@ -889,7 +890,7 @@ async def update_hops(self, download: Download, new_hops: int) -> None:

await self.start_download(tdef=download.tdef, config=config)

def update_trackers(self, infohash: bytes, trackers: list[str]) -> None:
def update_trackers(self, infohash: bytes, trackers: list[bytes]) -> None:
"""
Update the trackers for a download.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ def setUp(self) -> None:
"""
super().setUp()
self.manager = DownloadManager(MockTriblerConfigManager(), Notifier(), Mock())
self.manager.ltsessions = {i: Mock(status=Mock(dht_nodes=0), get_torrents=Mock(return_value=[]))
for i in range(4)}
for i in range(4):
fut = Future()
fut.set_result(Mock(status=Mock(dht_nodes=0), get_torrents=Mock(return_value=[])))
self.manager.ltsessions[i] = fut
self.manager.set_download_states_callback(self.manager.sesscb_states_callback)

async def tearDown(self) -> None:
Expand Down Expand Up @@ -158,7 +160,7 @@ async def test_start_download(self) -> None:
mock_alert = type("add_torrent_alert", (object,), {"handle": mock_handle,
"error": Mock(value=Mock(return_value=None)),
"category": MagicMock(return_value=None)})
self.manager.ltsessions[0].async_add_torrent = lambda _: self.manager.process_alert(mock_alert())
self.manager.ltsessions[0].result().async_add_torrent = lambda _: self.manager.process_alert(mock_alert())

with patch.object(self.manager, "remove_download", AsyncMock()):
download = await self.manager.start_download(tdef=TorrentDefNoMetainfo(b"\x01" * 20, b""),
Expand Down Expand Up @@ -201,7 +203,7 @@ async def test_start_download_existing_handle(self) -> None:
"""
mock_handle = Mock(info_hash=Mock(return_value=Mock(to_bytes=Mock(return_value=b"\x01" * 20))),
is_valid=Mock(return_value=True))
self.manager.ltsessions[0].get_torrents = Mock(return_value=[mock_handle])
self.manager.ltsessions[0].result().get_torrents = Mock(return_value=[mock_handle])
download = await self.manager.start_download(tdef=TorrentDefNoMetainfo(b"\x01" * 20, b"name"),
config=self.create_mock_download_config(),
checkpoint_disabled=True)
Expand Down Expand Up @@ -268,19 +270,21 @@ def test_set_proxy_settings(self) -> None:
"""
Test if the proxy settings can be set.
"""
self.manager.set_proxy_settings(self.manager.get_session(0), 0, ("a", "1234"), ("abc", "def"))
self.manager.set_proxy_settings(self.manager.get_session(0).result(), 0, ("a", "1234"), ("abc", "def"))

self.assertEqual(call({"proxy_type": 0, "proxy_hostnames": True, "proxy_peer_connections": True,
"proxy_hostname": "a", "proxy_port": 1234, "proxy_username": "abc",
"proxy_password": "def"}), self.manager.ltsessions[0].apply_settings.call_args)
"proxy_password": "def"}), self.manager.ltsessions[0].result().apply_settings.call_args)

def test_post_session_stats(self) -> None:
async def test_post_session_stats(self) -> None:
"""
Test if post_session_stats actually updates the state of libtorrent readiness for clean shutdown.
"""
self.manager.post_session_stats()

self.manager.ltsessions[0].post_session_stats.assert_called_once()
await sleep(0)

self.manager.ltsessions[0].result().post_session_stats.assert_called_once()

async def test_load_checkpoint_no_metainfo(self) -> None:
"""
Expand Down Expand Up @@ -475,26 +479,30 @@ def test_update_trackers_list_append(self) -> None:
self.assertSetEqual({f"127.0.0.1/test-announce{i}".encode() for i in range(2)},
{announce_url[0] for announce_url in download.tdef.metainfo[b"announce-list"]})

def test_get_download_rate_limit(self) -> None:
async def test_get_download_rate_limit(self) -> None:
"""
Test if the download rate limit can be set.
"""
settings = {}
self.manager.ltsessions[0].get_settings = Mock(return_value=settings)
self.manager.ltsessions[0].download_rate_limit = functools.partial(settings.get, "download_rate_limit")
self.manager.ltsessions[0].result().get_settings = Mock(return_value=settings)
self.manager.ltsessions[0].result().download_rate_limit = functools.partial(settings.get, "download_rate_limit")

self.manager.set_download_rate_limit(42)

self.assertEqual(42, self.manager.get_download_rate_limit())
await sleep(0)

self.assertEqual(42, await self.manager.get_download_rate_limit())

def test_get_upload_rate_limit(self) -> None:
async def test_get_upload_rate_limit(self) -> None:
"""
Test if the upload rate limit can be set.
"""
settings = {}
self.manager.ltsessions[0].get_settings = Mock(return_value=settings)
self.manager.ltsessions[0].upload_rate_limit = functools.partial(settings.get, "upload_rate_limit")
self.manager.ltsessions[0].result().get_settings = Mock(return_value=settings)
self.manager.ltsessions[0].result().upload_rate_limit = functools.partial(settings.get, "upload_rate_limit")

self.manager.set_upload_rate_limit(42)

self.assertEqual(42, self.manager.get_upload_rate_limit())
await sleep(0)

self.assertEqual(42, await self.manager.get_upload_rate_limit())

0 comments on commit b05fd19

Please sign in to comment.