Skip to content

Commit

Permalink
Fix shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
drew2a committed Jul 14, 2023
1 parent 27623ee commit 75b3ca8
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
"""
import asyncio
import base64
import itertools
import logging
from asyncio import CancelledError, Future, iscoroutine, sleep, wait_for
from collections import defaultdict
from contextlib import suppress
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple

from bitarray import bitarray
Expand All @@ -31,6 +33,8 @@
from tribler.core.utilities.unicode import ensure_unicode, hexlify
from tribler.core.utilities.utilities import bdecode_compat

Getter = Optional[Callable[[Any], Any]]


class Download(TaskManager):
""" Download subclass that represents a libtorrent download."""
Expand Down Expand Up @@ -64,7 +68,7 @@ def __init__(self,
self.checkpoint_after_next_hashcheck = False
self.tracker_status = {} # {url: [num_peers, status_str]}

self.futures = defaultdict(list)
self.futures: Dict[str, list[tuple[Future, Callable, Getter]]] = defaultdict(list)
self.alert_handlers = defaultdict(list)

self.future_added = self.wait_for_alert('add_torrent_alert', lambda a: a.handle)
Expand Down Expand Up @@ -100,7 +104,7 @@ def __init__(self,

def __str__(self):
return "Download <name: '%s' hops: %d checkpoint_disabled: %d>" % \
(self.tdef.get_name(), self.config.get_hops(), self.checkpoint_disabled)
(self.tdef.get_name(), self.config.get_hops(), self.checkpoint_disabled)

def __repr__(self):
return self.__str__()
Expand All @@ -123,8 +127,8 @@ def get_torrent_data(self) -> Optional[object]:
def register_alert_handler(self, alert_type: str, handler: lt.torrent_handle):
self.alert_handlers[alert_type].append(handler)

def wait_for_alert(self, success_type: str, success_getter: Optional[Callable[[Any], Any]] = None,
fail_type: str = None, fail_getter: Optional[Callable[[Any], Any]] = None) -> Future:
def wait_for_alert(self, success_type: str, success_getter: Getter = None,
fail_type: str = None, fail_getter: Getter = None) -> Future:
future = Future()
if success_type:
self.futures[success_type].append((future, future.set_result, success_getter))
Expand Down Expand Up @@ -631,12 +635,16 @@ async def state_callback_loop():
return self.register_anonymous_task("downloads_cb", state_callback_loop)

async def shutdown(self):
self._logger.info('Shutting down...')
self.alert_handlers.clear()
if self.stream is not None:
self.stream.close()
for _, futures in self.futures.items():
for future, _, _ in futures:
future.cancel()

active_futures = [f for f, _, _ in itertools.chain(*self.futures.values()) if not f.done()]
for future in active_futures:
future.cancel()
with suppress(CancelledError):
await asyncio.gather(*active_futures) # wait for futures to be actually cancelled
self.futures.clear()
await self.shutdown_task_manager()

Expand Down
11 changes: 11 additions & 0 deletions src/tribler/core/components/libtorrent/tests/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,3 +490,14 @@ def test_get_tracker_status_get_peer_info_error(test_download: Download):
)
status = test_download.get_tracker_status()
assert status


async def test_shutdown(test_download: Download):
""" Test that the `shutdown` method closes the stream and clears the `futures` list."""
test_download.stream = Mock()
assert len(test_download.futures) == 4

await test_download.shutdown()

assert not test_download.futures
assert test_download.stream.close.called

0 comments on commit 75b3ca8

Please sign in to comment.