From 69c5dfff09f4a001955949840dd9887a7cc2438f Mon Sep 17 00:00:00 2001 From: drew2a Date: Tue, 4 Jan 2022 14:10:22 +0100 Subject: [PATCH] Make notifier more generic --- .../community/gigachannel_community.py | 9 +- .../gigachannel_manager.py | 2 +- .../libtorrent/download_manager/download.py | 10 +- .../download_manager/download_manager.py | 29 +++-- .../tests/test_torrentinfo_endpoint.py | 2 +- .../restapi/torrentinfo_endpoint.py | 2 +- .../resource_monitor/implementation/core.py | 6 +- .../tests/test_resource_monitor.py | 4 +- .../torrent_checker/torrent_checker.py | 4 +- .../version_check/versioncheck_manager.py | 9 +- .../components/watch_folder/watch_folder.py | 4 +- src/tribler-core/tribler_core/notifier.py | 73 +++++------- src/tribler-core/tribler_core/start_core.py | 4 +- .../tribler_core/tests/test_notifier.py | 111 ++++++++++++++---- 14 files changed, 166 insertions(+), 103 deletions(-) diff --git a/src/tribler-core/tribler_core/components/gigachannel/community/gigachannel_community.py b/src/tribler-core/tribler_core/components/gigachannel/community/gigachannel_community.py index efe319b6544..371cefe04c8 100644 --- a/src/tribler-core/tribler_core/components/gigachannel/community/gigachannel_community.py +++ b/src/tribler-core/tribler_core/components/gigachannel/community/gigachannel_community.py @@ -13,10 +13,10 @@ from tribler_common.simpledefs import CHANNELS_VIEW_UUID, NTFY from tribler_core.components.gigachannel.community.discovery_booster import DiscoveryBooster -from tribler_core.components.metadata_store.remote_query_community.payload_checker import ObjState from tribler_core.components.metadata_store.db.serialization import CHANNEL_TORRENT -from tribler_core.components.metadata_store.utils import NoChannelSourcesException +from tribler_core.components.metadata_store.remote_query_community.payload_checker import ObjState from tribler_core.components.metadata_store.remote_query_community.remote_query_community import RemoteQueryCommunity +from tribler_core.components.metadata_store.utils import NoChannelSourcesException from tribler_core.utilities.unicode import hexlify minimal_blob_size = 200 @@ -152,7 +152,8 @@ def on_packet_callback(_, processing_results): ) ] if self.notifier and results: - self.notifier.notify(NTFY.CHANNEL_DISCOVERED, {"results": results, "uuid": str(CHANNELS_VIEW_UUID)}) + self.notifier.notify(NTFY.CHANNEL_DISCOVERED.value, + {"results": results, "uuid": str(CHANNELS_VIEW_UUID)}) request_dict = { "metadata_type": [CHANNEL_TORRENT], @@ -211,7 +212,7 @@ def notify_gui(request, processing_results): ] if self.notifier: self.notifier.notify( - NTFY.REMOTE_QUERY_RESULTS, + NTFY.REMOTE_QUERY_RESULTS.value, {"results": results, "uuid": str(request_uuid), "peer": hexlify(request.peer.mid)}, ) diff --git a/src/tribler-core/tribler_core/components/gigachannel_manager/gigachannel_manager.py b/src/tribler-core/tribler_core/components/gigachannel_manager/gigachannel_manager.py index 7f44ec22af6..a69ccdb4dd3 100644 --- a/src/tribler-core/tribler_core/components/gigachannel_manager/gigachannel_manager.py +++ b/src/tribler-core/tribler_core/components/gigachannel_manager/gigachannel_manager.py @@ -291,7 +291,7 @@ def _process_download(): updated_channel = self.mds.ChannelMetadata.get(public_key=channel.public_key, id_=channel.id_) channel_dict = updated_channel.to_simple_dict() if updated_channel else None if updated_channel: - self.notifier.notify(NTFY.CHANNEL_ENTITY_UPDATED, channel_dict) + self.notifier.notify(NTFY.CHANNEL_ENTITY_UPDATED.value, channel_dict) def updated_my_channel(self, tdef): """ diff --git a/src/tribler-core/tribler_core/components/libtorrent/download_manager/download.py b/src/tribler-core/tribler_core/components/libtorrent/download_manager/download.py index 2a278e99667..1b2b9e11d5c 100644 --- a/src/tribler-core/tribler_core/components/libtorrent/download_manager/download.py +++ b/src/tribler-core/tribler_core/components/libtorrent/download_manager/download.py @@ -15,16 +15,16 @@ from tribler_common.osutils import fix_filebasename from tribler_common.simpledefs import DLSTATUS_SEEDING, DLSTATUS_STOPPED, DOWNLOAD, NTFY -from tribler_core.exceptions import SaveResumeDataError from tribler_core.components.libtorrent.download_manager.download_config import DownloadConfig from tribler_core.components.libtorrent.download_manager.download_state import DownloadState -from tribler_core.components.libtorrent.settings import DownloadDefaultsSettings from tribler_core.components.libtorrent.download_manager.stream import Stream +from tribler_core.components.libtorrent.settings import DownloadDefaultsSettings from tribler_core.components.libtorrent.torrentdef import TorrentDef, TorrentDefNoMetainfo -from tribler_core.notifier import Notifier from tribler_core.components.libtorrent.utils.libtorrent_helper import libtorrent as lt -from tribler_core.utilities.path_util import Path from tribler_core.components.libtorrent.utils.torrent_utils import check_handle, get_info_from_handle, require_handle +from tribler_core.exceptions import SaveResumeDataError +from tribler_core.notifier import Notifier +from tribler_core.utilities.path_util import Path from tribler_core.utilities.unicode import ensure_unicode, hexlify from tribler_core.utilities.utilities import bdecode_compat @@ -391,7 +391,7 @@ def on_torrent_finished_alert(self, _): self.checkpoint() if self.get_state().get_total_transferred(DOWNLOAD) > 0 and self.stream is not None: if self.notifier is not None: - self.notifier.notify(NTFY.TORRENT_FINISHED, self.tdef.get_infohash(), + self.notifier.notify(NTFY.TORRENT_FINISHED.value, self.tdef.get_infohash(), self.tdef.get_name_as_unicode(), self.hidden or self.config.get_channel_download()) diff --git a/src/tribler-core/tribler_core/components/libtorrent/download_manager/download_manager.py b/src/tribler-core/tribler_core/components/libtorrent/download_manager/download_manager.py index 77e7d48afda..7b8a15b4dbe 100644 --- a/src/tribler-core/tribler_core/components/libtorrent/download_manager/download_manager.py +++ b/src/tribler-core/tribler_core/components/libtorrent/download_manager/download_manager.py @@ -101,8 +101,8 @@ def __init__(self, self.metainfo_cache = {} # Dictionary that maps infohashes to cached metainfo items self.default_alert_mask = lt.alert.category_t.error_notification | lt.alert.category_t.status_notification | \ - lt.alert.category_t.storage_notification | lt.alert.category_t.performance_warning | \ - lt.alert.category_t.tracker_notification | lt.alert.category_t.debug_notification + lt.alert.category_t.storage_notification | lt.alert.category_t.performance_warning | \ + lt.alert.category_t.tracker_notification | lt.alert.category_t.debug_notification self.session_stats_callback = None self.state_cb_count = 0 @@ -151,19 +151,22 @@ def initialize(self): self.set_download_states_callback(self.sesscb_states_callback) + def notify_shutdown_state(self, state): + self.notifier.notify(NTFY.TRIBLER_SHUTDOWN_STATE.value, state) + async def shutdown(self, timeout=30): if self.downloads: - self.notifier.notify_shutdown_state("Checkpointing Downloads...") + self.notify_shutdown_state("Checkpointing Downloads...") await gather(*[download.stop() for download in self.downloads.values()], return_exceptions=True) - self.notifier.notify_shutdown_state("Shutting down Downloads...") + self.notify_shutdown_state("Shutting down Downloads...") await gather(*[download.shutdown() for download in self.downloads.values()], return_exceptions=True) - self.notifier.notify_shutdown_state("Shutting down Libtorrent Manager...") + self.notify_shutdown_state("Shutting down Libtorrent Manager...") # If libtorrent session has pending disk io, wait until timeout (default: 30 seconds) to let it finish. # In between ask for session stats to check if state is clean for shutdown. # In dummy mode, we immediately shut down the download manager. while not self.dummy_mode and not self.is_shutdown_ready() and timeout >= 1: - self.notifier.notify_shutdown_state("Waiting for Libtorrent to finish...") + self.notify_shutdown_state("Waiting for Libtorrent to finish...") self.post_session_stats() timeout -= 1 await asyncio.sleep(1) @@ -244,7 +247,7 @@ def create_session(self, hops=0, store_listen_port=True): settings['force_proxy'] = True # Anon listen port is never used anywhere, so we let Libtorrent set it - #settings["listen_interfaces"] = "0.0.0.0:%d" % anon_port + # settings["listen_interfaces"] = "0.0.0.0:%d" % anon_port # By default block all IPs except 1.1.1.1 (which is used to ensure libtorrent makes a connection to us) self.update_ip_filter(ltsession, ['1.1.1.1']) @@ -255,7 +258,7 @@ def create_session(self, hops=0, store_listen_port=True): if hops == 0: proxy_settings = DownloadManager.get_libtorrent_proxy_settings(self.config) else: - proxy_settings = [SOCKS5_PROXY_DEF, ("127.0.0.1", self.socks_listen_ports[hops-1]), None] + proxy_settings = [SOCKS5_PROXY_DEF, ("127.0.0.1", self.socks_listen_ports[hops - 1]), None] self.set_proxy_settings(ltsession, *proxy_settings) for extension in extensions: @@ -276,7 +279,7 @@ def create_session(self, hops=0, store_listen_port=True): except Exception as exc: self._logger.info(f"could not load libtorrent state, got exception: {exc!r}. starting from scratch") else: - #ltsession.listen_on(anon_port, anon_port + 20) + # ltsession.listen_on(anon_port, anon_port + 20) rate = DownloadManager.get_libtorrent_max_upload_rate(self.config) download_rate = DownloadManager.get_libtorrent_max_download_rate(self.config) @@ -369,8 +372,8 @@ def process_alert(self, alert, hops=0): download = self.downloads.get(infohash) if download: is_process_alert = (download.handle and download.handle.is_valid()) \ - or (not download.handle and alert_type == 'add_torrent_alert') \ - or (download.handle and alert_type == 'torrent_removed_alert') + or (not download.handle and alert_type == 'add_torrent_alert') \ + or (download.handle and alert_type == 'torrent_removed_alert') if is_process_alert: download.process_alert(alert, alert_type) else: @@ -385,7 +388,7 @@ def process_alert(self, alert, hops=0): self.listen_ports[hops] = getattr(alert, "port", alert.endpoint[1]) elif alert_type == 'peer_disconnected_alert' and self.notifier: - self.notifier.notify(NTFY.PEER_DISCONNECTED_EVENT, alert.pid.to_bytes()) + self.notifier.notify(NTFY.PEER_DISCONNECTED_EVENT.value, alert.pid.to_bytes()) elif alert_type == 'session_stats_alert': queued_disk_jobs = alert.values['disk.queued_disk_jobs'] @@ -803,7 +806,7 @@ async def sesscb_states_callback(self, states_list): if self.state_cb_count % 5 == 0 and download.config.get_hops() == 0 and self.notifier: for peer in download.get_peerlist(): if str(peer["extended_version"]).startswith('Tribler'): - self.notifier.notify(NTFY.TRIBLER_TORRENT_PEER_UPDATE, + self.notifier.notify(NTFY.TRIBLER_TORRENT_PEER_UPDATE.value, unhexlify(peer["id"]), infohash, peer["dtotal"]) if self.state_cb_count % 4 == 0: diff --git a/src/tribler-core/tribler_core/components/libtorrent/restapi/tests/test_torrentinfo_endpoint.py b/src/tribler-core/tribler_core/components/libtorrent/restapi/tests/test_torrentinfo_endpoint.py index 2d4613d540b..99dce69fc51 100644 --- a/src/tribler-core/tribler_core/components/libtorrent/restapi/tests/test_torrentinfo_endpoint.py +++ b/src/tribler-core/tribler_core/components/libtorrent/restapi/tests/test_torrentinfo_endpoint.py @@ -109,7 +109,7 @@ def get_metainfo(infohash, timeout=20, hops=None, url=None): await do_request(rest_api, f'torrentinfo?uri={path}', expected_code=500) # Ensure that correct torrent metadata was sent through notifier (to MetadataStore) - mock_dlmgr.notifier.notify.assert_called_with(NTFY.TORRENT_METADATA_ADDED, metainfo_dict) + mock_dlmgr.notifier.notify.assert_called_with(NTFY.TORRENT_METADATA_ADDED.value, metainfo_dict) mock_dlmgr.get_metainfo = get_metainfo verify_valid_dict(await do_request(rest_api, f'torrentinfo?uri={path}', expected_code=200)) diff --git a/src/tribler-core/tribler_core/components/libtorrent/restapi/torrentinfo_endpoint.py b/src/tribler-core/tribler_core/components/libtorrent/restapi/torrentinfo_endpoint.py index 5f1435aaa1a..fa298b3ca26 100644 --- a/src/tribler-core/tribler_core/components/libtorrent/restapi/torrentinfo_endpoint.py +++ b/src/tribler-core/tribler_core/components/libtorrent/restapi/torrentinfo_endpoint.py @@ -119,7 +119,7 @@ async def get_torrent_info(self, request): # Add the torrent to GigaChannel as a free-for-all entry, so others can search it self.download_manager.notifier.notify( - NTFY.TORRENT_METADATA_ADDED, + NTFY.TORRENT_METADATA_ADDED.value, tdef_to_metadata_dict(TorrentDef.load_from_dict(metainfo))) # TODO(Martijn): store the stuff in a database!!! diff --git a/src/tribler-core/tribler_core/components/resource_monitor/implementation/core.py b/src/tribler-core/tribler_core/components/resource_monitor/implementation/core.py index 988887f4867..c4e22746d5c 100644 --- a/src/tribler-core/tribler_core/components/resource_monitor/implementation/core.py +++ b/src/tribler-core/tribler_core/components/resource_monitor/implementation/core.py @@ -2,10 +2,12 @@ import time from collections import deque +from ipv8.taskmanager import TaskManager + import psutil -from ipv8.taskmanager import TaskManager from tribler_common.simpledefs import NTFY + from tribler_core.components.resource_monitor.implementation.base import ResourceMonitor from tribler_core.components.resource_monitor.implementation.profiler import YappiProfiler from tribler_core.components.resource_monitor.settings import ResourceMonitorSettings @@ -101,7 +103,7 @@ def record_disk_usage(self, recorded_at=None): if disk_usage.free < FREE_DISK_THRESHOLD: self._logger.warning("Warning! Less than 100MB of disk space available") if self.notifier: - self.notifier.notify(NTFY.LOW_SPACE, self.disk_usage_data[-1]) + self.notifier.notify(NTFY.LOW_SPACE.value, self.disk_usage_data[-1]) def get_free_disk_space(self): return psutil.disk_usage(str(self.state_dir)) diff --git a/src/tribler-core/tribler_core/components/resource_monitor/implementation/tests/test_resource_monitor.py b/src/tribler-core/tribler_core/components/resource_monitor/implementation/tests/test_resource_monitor.py index 9f30d433c7f..658ce7ec68a 100644 --- a/src/tribler-core/tribler_core/components/resource_monitor/implementation/tests/test_resource_monitor.py +++ b/src/tribler-core/tribler_core/components/resource_monitor/implementation/tests/test_resource_monitor.py @@ -7,8 +7,8 @@ import pytest from tribler_common.simpledefs import NTFY -from tribler_core.components.resource_monitor.implementation.core import CoreResourceMonitor +from tribler_core.components.resource_monitor.implementation.core import CoreResourceMonitor from tribler_core.components.resource_monitor.settings import ResourceMonitorSettings @@ -86,7 +86,7 @@ def fake_get_free_disk_space(): return namedtuple('sdiskusage', disk.keys())(*disk.values()) def on_notify(subject, *args): - assert subject in [NTFY.LOW_SPACE, NTFY.TRIBLER_SHUTDOWN_STATE] + assert subject in [NTFY.LOW_SPACE.value, NTFY.TRIBLER_SHUTDOWN_STATE.value] resource_monitor.get_free_disk_space = fake_get_free_disk_space resource_monitor.notifier.notify = on_notify diff --git a/src/tribler-core/tribler_core/components/torrent_checker/torrent_checker/torrent_checker.py b/src/tribler-core/tribler_core/components/torrent_checker/torrent_checker/torrent_checker.py index a8ddcc3c908..a5545099eb1 100644 --- a/src/tribler-core/tribler_core/components/torrent_checker/torrent_checker/torrent_checker.py +++ b/src/tribler-core/tribler_core/components/torrent_checker/torrent_checker/torrent_checker.py @@ -297,7 +297,7 @@ def on_torrent_health_check_completed(self, infohash, result): final_response = {} if not result or not isinstance(result, list): self._logger.info("Received invalid torrent checker result") - self.notifier.notify(NTFY.CHANNEL_ENTITY_UPDATED, + self.notifier.notify(NTFY.CHANNEL_ENTITY_UPDATED.value, {"infohash": hexlify(infohash), "num_seeders": 0, "num_leechers": 0, @@ -329,7 +329,7 @@ def on_torrent_health_check_completed(self, infohash, result): self.update_torrents_checked(torrent_update_dict) # TODO: DRY! Stop doing lots of formats, just make REST endpoint automatically encode binary data to hex! - self.notifier.notify(NTFY.CHANNEL_ENTITY_UPDATED, + self.notifier.notify(NTFY.CHANNEL_ENTITY_UPDATED.value, {"infohash": hexlify(infohash), "num_seeders": torrent_update_dict["seeders"], "num_leechers": torrent_update_dict["leechers"], diff --git a/src/tribler-core/tribler_core/components/version_check/versioncheck_manager.py b/src/tribler-core/tribler_core/components/version_check/versioncheck_manager.py index 42b91f7f74a..7a8e350a98a 100644 --- a/src/tribler-core/tribler_core/components/version_check/versioncheck_manager.py +++ b/src/tribler-core/tribler_core/components/version_check/versioncheck_manager.py @@ -2,13 +2,12 @@ import platform from distutils.version import LooseVersion -from aiohttp import ( - ClientSession, - ClientTimeout, -) +from aiohttp import ClientSession, ClientTimeout from ipv8.taskmanager import TaskManager + from tribler_common.simpledefs import NTFY + from tribler_core.notifier import Notifier from tribler_core.version import version_id @@ -64,7 +63,7 @@ async def check_new_version_api(self, version_check_url): response_dict = await response.json(content_type=None) version = response_dict['name'][1:] if LooseVersion(version) > LooseVersion(version_id): - self.notifier.notify(NTFY.TRIBLER_NEW_VERSION, version) + self.notifier.notify(NTFY.TRIBLER_NEW_VERSION.value, version) return True return False diff --git a/src/tribler-core/tribler_core/components/watch_folder/watch_folder.py b/src/tribler-core/tribler_core/components/watch_folder/watch_folder.py index c7f9f0ff8fe..ed69c4902f0 100644 --- a/src/tribler-core/tribler_core/components/watch_folder/watch_folder.py +++ b/src/tribler-core/tribler_core/components/watch_folder/watch_folder.py @@ -3,7 +3,9 @@ from pathlib import Path from ipv8.taskmanager import TaskManager + from tribler_common.simpledefs import NTFY + from tribler_core.components.libtorrent.download_manager.download_manager import DownloadManager from tribler_core.components.libtorrent.torrentdef import TorrentDef from tribler_core.notifier import Notifier @@ -41,7 +43,7 @@ def cleanup_torrent_file(self, root, name): self._logger.warning(f'Cant rename the file to {path}. Exception: {e}') self._logger.warning("Watch folder - corrupt torrent file %s", name) - self.notifier.notify(NTFY.WATCH_FOLDER_CORRUPT_FILE, name) + self.notifier.notify(NTFY.WATCH_FOLDER_CORRUPT_FILE.value, name) def check_watch_folder(self): if not self.watch_folder.is_dir(): diff --git a/src/tribler-core/tribler_core/notifier.py b/src/tribler-core/tribler_core/notifier.py index 9bafc0fb3a3..cdefd7fa902 100644 --- a/src/tribler-core/tribler_core/notifier.py +++ b/src/tribler-core/tribler_core/notifier.py @@ -1,53 +1,44 @@ -""" -Notifier. - -Author(s): Vadim Bulavintsev -""" import logging from asyncio import get_event_loop - -from tribler_common.simpledefs import NTFY +from collections import defaultdict +from typing import Callable, Dict class Notifier: - def __init__(self): - self._logger = logging.getLogger(self.__class__.__name__) - self.observers = {} + self.logger = logging.getLogger(self.__class__.__name__) + self.observers: Dict[str, set] = defaultdict(set) + # @ichorid: # We have to note the event loop reference, because when we call "notify" from an external thread, # we don't know anything about the existence of the event loop, and get_event_loop() can't find # the original event loop from an external thread. - self._loop = get_event_loop() # We remember the event loop from the thread that runs the Notifier # to be able to schedule notifications from external threads + self._loop = get_event_loop() - def add_observer(self, subject, callback): - assert isinstance(subject, NTFY) - self.observers[subject] = self.observers.get(subject, []) - self.observers[subject].append(callback) - self._logger.debug(f"Add observer topic {subject} callback {callback}") - - def remove_observer(self, subject, callback): - if subject not in self.observers: - return - if callback not in self.observers[subject]: - return - - self.observers[subject].remove(callback) - self._logger.debug(f"Remove observer topic {subject} callback {callback}") - - def notify(self, subject, *args): - # We have to call the notifier callbacks through call_soon_threadsafe - # because the notify method could have been called from a non-reactor thread - self._loop.call_soon_threadsafe(self._notify, subject, *args) - - def _notify(self, subject, *args): - if subject not in self.observers: - self._logger.warning(f"Called notification on a non-existing subject {subject}") - return - for callback in self.observers[subject]: - callback(*args) - - def notify_shutdown_state(self, state): - self._logger.info("Tribler shutdown state notification:%s", state) - self.notify(NTFY.TRIBLER_SHUTDOWN_STATE, state) + def add_observer(self, topic: str, callback: Callable): + self.logger.debug(f"Add observer topic {topic}") + self.observers[topic].add(callback) + + def remove_observer(self, topic: str, callback: Callable): + self.logger.debug(f"Remove observer topic {topic}") + self.observers[topic].discard(callback) + + def notify(self, topic: str, *args, **kwargs): + def _notify(_topic, _kwargs, *_args): + for callback in self.observers[_topic]: + try: + callback(*_args, **_kwargs) + except Exception as _e: # pylint: disable=broad-except + self.logger.exception(_e) + + try: + # @ichorid: + # We have to call the notifier callbacks through call_soon_threadsafe + # because the notify method could have been called from a non-reactor thread + self._loop.call_soon_threadsafe(_notify, topic, kwargs, *args) + except RuntimeError as e: + # Raises RuntimeError if called on a loop that’s been closed. + # This can happen on a secondary thread when the main application is shutting down. + # https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.call_soon_threadsafe + self.logger.warning(e) diff --git a/src/tribler-core/tribler_core/start_core.py b/src/tribler-core/tribler_core/start_core.py index cd7b81accd2..39f78eba87f 100644 --- a/src/tribler-core/tribler_core/start_core.py +++ b/src/tribler-core/tribler_core/start_core.py @@ -94,14 +94,14 @@ async def core_session(config: TriblerConfig, components: List[Component]): # If there is a config error, report to the user via GUI notifier if config.error: - session.notifier.notify(NTFY.REPORT_CONFIG_ERROR, config.error) + session.notifier.notify(NTFY.REPORT_CONFIG_ERROR.value, config.error) # SHUTDOWN await session.shutdown_event.wait() await session.shutdown() if not config.gui_test_mode: - session.notifier.notify_shutdown_state("Saving configuration...") + session.notifier.notify(NTFY.TRIBLER_SHUTDOWN_STATE.value, "Saving configuration...") config.write() diff --git a/src/tribler-core/tribler_core/tests/test_notifier.py b/src/tribler-core/tribler_core/tests/test_notifier.py index 32a9e96d3a0..859e3c785bd 100644 --- a/src/tribler-core/tribler_core/tests/test_notifier.py +++ b/src/tribler-core/tribler_core/tests/test_notifier.py @@ -3,40 +3,105 @@ import pytest -from tribler_common.simpledefs import NTFY - from tribler_core.notifier import Notifier -@pytest.fixture(name="notifier") -def fixture_notifier(): +# pylint: disable=redefined-outer-name, protected-access + +@pytest.fixture +def notifier(): return Notifier() +class TestCallback: + def __init__(self, side_effect=None): + self.callback_has_been_called = False + self.callback_has_been_called_with_args = None + self.callback_has_been_called_with_kwargs = None + self.side_effect = side_effect + self.event = asyncio.Event() + + def callback(self, *args, **kwargs): + self.callback_has_been_called_with_args = args + self.callback_has_been_called_with_kwargs = kwargs + self.callback_has_been_called = True + if self.side_effect: + raise self.side_effect() + + self.event.set() + + +@pytest.mark.asyncio +async def test_notifier_add_observer(notifier: Notifier): + def callback(): + ... + + # test that add observer stores topics and callbacks as a set to prevent duplicates + notifier.add_observer('topic', callback) + notifier.add_observer('topic', callback) + + assert len(notifier.observers['topic']) == 1 + + @pytest.mark.asyncio -async def test_notifier(notifier): +async def test_notifier_remove_nonexistent_observer(notifier: Notifier): + # test that `remove_observer` don't crash in case of calling to remove non existed topic/callback + notifier.remove_observer('nonexistent', lambda: None) + assert not notifier.observers['nonexistent'] - mock_foo = Mock() - notifier.add_observer(NTFY.TORRENT_FINISHED, mock_foo.bar) - notifier.notify(NTFY.TORRENT_FINISHED) - # Notifier uses asyncio loop internally, so we must wait at least a single loop cycle - await asyncio.sleep(0) - mock_foo.bar.assert_called_once() +@pytest.mark.asyncio +async def test_notifier_remove_observer(notifier: Notifier): + def callback(): + ... + + notifier.add_observer('topic', lambda: None) + notifier.add_observer('topic', callback) -def test_remove_observer(notifier): - def _f(): - pass + notifier.remove_observer('topic', callback) + assert len(notifier.observers['topic']) == 1 - notifier.add_observer(NTFY.TORRENT_FINISHED, _f) - assert len(notifier.observers) == 1 - assert len(notifier.observers[NTFY.TORRENT_FINISHED]) == 1 - notifier.remove_observer(NTFY.TORRENT_FINISHED, _f) - assert not notifier.observers[NTFY.TORRENT_FINISHED] +@pytest.mark.timeout(1) +@pytest.mark.asyncio +async def test_notify(notifier: Notifier): + # test that notify works as expected + normal_callback = TestCallback() + + notifier.add_observer('topic', normal_callback.callback) + notifier.notify('topic', 'arg', kwarg='value') + + # wait for the callback + await normal_callback.event.wait() + assert normal_callback.callback_has_been_called + assert normal_callback.callback_has_been_called_with_args == ('arg',) + assert normal_callback.callback_has_been_called_with_kwargs == {'kwarg': 'value'} + + +@pytest.mark.asyncio +async def test_notify_with_exception(notifier: Notifier): + # test that notify works as expected even if one of callbacks will raise an exception + + normal_callback = TestCallback() + side_effect_callback = TestCallback(ValueError) + + notifier.add_observer('topic', side_effect_callback.callback) + notifier.add_observer('topic', normal_callback.callback) + + notifier.notify('topic') + + # wait + await asyncio.sleep(1) + + assert normal_callback.callback_has_been_called + assert side_effect_callback.callback_has_been_called + + +@pytest.mark.asyncio +async def test_notify_call_soon_threadsafe_with_exception(notifier: Notifier): + notifier.logger = Mock() + notifier._loop = Mock(call_soon_threadsafe=Mock(side_effect=RuntimeError)) - # raise no error when _f not presents in callbacks - notifier.remove_observer(NTFY.TORRENT_FINISHED, _f) + notifier.notify('topic') - # raise no error when subject not presents in observers - notifier.remove_observer(NTFY.POPULARITY_COMMUNITY_ADD_UNKNOWN_TORRENT, _f) + notifier.logger.warning.assert_called_once()