From d96854adfac5ea66d729a480bb15a724656cd07a Mon Sep 17 00:00:00 2001 From: drew2a Date: Fri, 7 Jul 2023 15:03:27 +0200 Subject: [PATCH] Fix `shutdown` --- .../libtorrent/download_manager/download.py | 22 +++++++++++++------ .../libtorrent/tests/test_download.py | 11 ++++++++++ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/src/tribler/core/components/libtorrent/download_manager/download.py b/src/tribler/core/components/libtorrent/download_manager/download.py index e325d1fea38..8d6c09a8af3 100644 --- a/src/tribler/core/components/libtorrent/download_manager/download.py +++ b/src/tribler/core/components/libtorrent/download_manager/download.py @@ -5,9 +5,11 @@ """ import asyncio import base64 +import itertools import logging from asyncio import CancelledError, Future, iscoroutine, sleep, wait_for from collections import defaultdict +from contextlib import suppress from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple from bitarray import bitarray @@ -31,6 +33,8 @@ from tribler.core.utilities.unicode import ensure_unicode, hexlify from tribler.core.utilities.utilities import bdecode_compat +Getter = Optional[Callable[[Any], Any]] + class Download(TaskManager): """ Download subclass that represents a libtorrent download.""" @@ -68,7 +72,7 @@ def __init__(self, self.tracker_status = {} # {url: [num_peers, status_str]} self.checkpoint_disabled = self.dummy - self.futures = defaultdict(list) + self.futures: Dict[str, list[tuple[Future, Callable, Getter]]] = defaultdict(list) self.alert_handlers = defaultdict(list) self.future_added = self.wait_for_alert('add_torrent_alert', lambda a: a.handle) @@ -103,7 +107,7 @@ def __init__(self, def __str__(self): return "Download " % \ - (self.tdef.get_name(), self.config.get_hops(), self.checkpoint_disabled) + (self.tdef.get_name(), self.config.get_hops(), self.checkpoint_disabled) def __repr__(self): return self.__str__() @@ -126,8 +130,8 @@ def get_torrent_data(self) -> Optional[object]: def register_alert_handler(self, alert_type: str, handler: lt.torrent_handle): self.alert_handlers[alert_type].append(handler) - def wait_for_alert(self, success_type: str, success_getter: Optional[Callable[[Any], Any]] = None, - fail_type: str = None, fail_getter: Optional[Callable[[Any], Any]] = None) -> Future: + def wait_for_alert(self, success_type: str, success_getter: Getter = None, + fail_type: str = None, fail_getter: Getter = None) -> Future: future = Future() if success_type: self.futures[success_type].append((future, future.set_result, success_getter)) @@ -634,12 +638,16 @@ async def state_callback_loop(): return self.register_anonymous_task("downloads_cb", state_callback_loop) async def shutdown(self): + self._logger.info('Shutting down...') self.alert_handlers.clear() if self.stream is not None: self.stream.close() - for _, futures in self.futures.items(): - for future, _, _ in futures: - future.cancel() + + active_futures = [f for f, _, _ in itertools.chain(*self.futures.values()) if not f.done()] + for future in active_futures: + future.cancel() + with suppress(CancelledError): + await asyncio.gather(*active_futures) # wait for futures to be actually cancelled self.futures.clear() await self.shutdown_task_manager() diff --git a/src/tribler/core/components/libtorrent/tests/test_download.py b/src/tribler/core/components/libtorrent/tests/test_download.py index 5f782b5039c..caf53284dc7 100644 --- a/src/tribler/core/components/libtorrent/tests/test_download.py +++ b/src/tribler/core/components/libtorrent/tests/test_download.py @@ -490,3 +490,14 @@ def test_get_tracker_status_get_peer_info_error(test_download: Download): ) status = test_download.get_tracker_status() assert status + + +async def test_shutdown(test_download: Download): + """ Test that the `shutdown` method closes the stream and clears the `futures` list.""" + test_download.stream = Mock() + assert len(test_download.futures) == 4 + + await test_download.shutdown() + + assert not test_download.futures + assert test_download.stream.close.called