Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Statically typed notifier #6728

Merged
merged 2 commits into from
Feb 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/tribler-core/run_tunnel_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from asyncio import ensure_future, get_event_loop
from ipaddress import AddressValueError, IPv4Address

from ipv8.messaging.anonymization.tunnel import Circuit
from ipv8.taskmanager import TaskManager

from tribler_core import notifications
from tribler_core.components.bandwidth_accounting.bandwidth_accounting_component import BandwidthAccountingComponent
from tribler_core.components.base import Session
from tribler_core.components.ipv8.ipv8_component import Ipv8Component
Expand All @@ -24,7 +26,6 @@
from tribler_core.config.tribler_config import TriblerConfig
from tribler_core.utilities.osutils import get_root_state_directory
from tribler_core.utilities.path_util import Path
from tribler_core.utilities.simpledefs import NTFY

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -130,7 +131,7 @@ async def signal_handler(sig):
new_strategies.append((strategy, target_peers))
ipv8.strategies = new_strategies

def circuit_removed(self, circuit, additional_info):
def circuit_removed(self, circuit: Circuit, additional_info: str):
ipv8 = Ipv8Component.instance().ipv8
ipv8.network.remove_by_address(circuit.peer.address)
if self.log_circuits:
Expand All @@ -146,7 +147,7 @@ async def start(self, options):
session.set_as_default()

self.log_circuits = options.log_circuits
session.notifier.add_observer(NTFY.TUNNEL_REMOVE.value, self.circuit_removed)
session.notifier.add_observer(notifications.circuit_removed, self.circuit_removed)

await session.start_components()

Expand Down
2 changes: 1 addition & 1 deletion src/tribler-core/tribler_core/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from typing import Dict, List, Optional, Set, Type, TypeVar, Union

from tribler_core.config.tribler_config import TriblerConfig
from tribler_core.notifier import Notifier
from tribler_core.utilities.crypto_patcher import patch_crypto_be_discovery
from tribler_core.utilities.install_dir import get_lib_path
from tribler_core.utilities.network_utils import default_network_utils
from tribler_core.utilities.notifier import Notifier
from tribler_core.utilities.simpledefs import STATEDIR_CHANNELS_DIR, STATEDIR_DB_DIR


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@

from pony.orm import db_session

from tribler_core import notifications
from tribler_core.components.ipv8.discovery_booster import DiscoveryBooster
from tribler_core.components.metadata_store.db.serialization import CHANNEL_TORRENT
from tribler_core.components.metadata_store.remote_query_community.payload_checker import ObjState
from tribler_core.components.metadata_store.remote_query_community.remote_query_community import RemoteQueryCommunity
from tribler_core.components.metadata_store.utils import NoChannelSourcesException
from tribler_core.utilities.simpledefs import CHANNELS_VIEW_UUID, NTFY
from tribler_core.utilities.notifier import Notifier
from tribler_core.utilities.simpledefs import CHANNELS_VIEW_UUID
from tribler_core.utilities.unicode import hexlify

minimal_blob_size = 200
Expand Down Expand Up @@ -92,7 +94,7 @@ def create_introduction_response(
)

def __init__(
self, *args, notifier=None, **kwargs
self, *args, notifier: Notifier = None, **kwargs
): # pylint: disable=unused-argument
# ACHTUNG! We create a separate instance of Network for this community because it
# walks aggressively and wants lots of peers, which can interfere with other communities
Expand Down Expand Up @@ -151,8 +153,7 @@ def on_packet_callback(_, processing_results):
)
]
if self.notifier and results:
self.notifier.notify(NTFY.CHANNEL_DISCOVERED.value,
{"results": results, "uuid": str(CHANNELS_VIEW_UUID)})
self.notifier[notifications.channel_discovered]({"results": results, "uuid": str(CHANNELS_VIEW_UUID)})

request_dict = {
"metadata_type": [CHANNEL_TORRENT],
Expand Down Expand Up @@ -210,10 +211,8 @@ def notify_gui(request, processing_results):
if r.obj_state in (ObjState.NEW_OBJECT, ObjState.UPDATED_LOCAL_VERSION)
]
if self.notifier:
self.notifier.notify(
NTFY.REMOTE_QUERY_RESULTS.value,
{"results": results, "uuid": str(request_uuid), "peer": hexlify(request.peer.mid)},
)
self.notifier[notifications.remote_query_results](
{"results": results, "uuid": str(request_uuid), "peer": hexlify(request.peer.mid)})

# Try sending the request to at least some peers that we know have it
if "channel_pk" in kwargs and "origin_id" in kwargs:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import time
from datetime import datetime
from unittest import mock
from unittest.mock import AsyncMock, Mock, PropertyMock, patch

from ipv8.keyvault.crypto import default_eccrypto
Expand All @@ -20,7 +19,7 @@
from tribler_core.components.metadata_store.db.store import MetadataStore
from tribler_core.components.metadata_store.remote_query_community.settings import RemoteQueryCommunitySettings
from tribler_core.components.metadata_store.utils import RequestTimeoutException
from tribler_core.notifier import Notifier
from tribler_core.utilities.notifier import Notifier
from tribler_core.utilities.path_util import Path
from tribler_core.utilities.utilities import random_infohash

Expand Down Expand Up @@ -55,7 +54,7 @@ def create_node(self, *args, **kwargs):
kwargs['metadata_store'] = metadata_store
kwargs['settings'] = ChantSettings()
kwargs['rqc_settings'] = RemoteQueryCommunitySettings()
with mock.patch('tribler_core.components.gigachannel.community.gigachannel_community.DiscoveryBooster'):
with patch('tribler_core.components.gigachannel.community.gigachannel_community.DiscoveryBooster'):
node = super().create_node(*args, **kwargs)
self.count += 1
return node
Expand Down Expand Up @@ -117,27 +116,24 @@ async def test_gigachannel_search(self):
self.nodes[1].overlay.mds.TorrentMetadata(title=U_TORRENT, infohash=random_infohash())
self.nodes[1].overlay.mds.TorrentMetadata(title="debian torrent", infohash=random_infohash())

notification_calls = []

def mock_notify(_, args):
notification_calls.append(args)

self.nodes[2].overlay.notifier = Notifier()
self.nodes[2].overlay.notifier.notify = lambda sub, args: mock_notify(self.nodes[2].overlay, args)
notifier = Notifier(loop=self.loop)
notifier.notify = Mock()
self.nodes[2].overlay.notifier = notifier

self.nodes[2].overlay.send_search_request(**{"txt_filter": "ubuntu*"})

await self.deliver_messages(timeout=0.5)

# Check that the notifier callback was called on both entries
titles = sorted(call.args[1]["results"][0]["name"] for call in notifier.notify.call_args_list)
assert titles == [U_CHANNEL, U_TORRENT]

with db_session:
assert self.nodes[2].overlay.mds.ChannelNode.select().count() == 2
assert (
self.nodes[2].overlay.mds.ChannelNode.select(lambda g: g.title in (U_CHANNEL, U_TORRENT)).count() == 2
)

# Check that the notifier callback was called on both entries
assert [U_CHANNEL, U_TORRENT] == sorted([c["results"][0]["name"] for c in notification_calls])

def test_query_on_introduction(self):
"""
Test querying a peer that was just introduced to us.
Expand Down Expand Up @@ -198,18 +194,18 @@ async def test_remote_select_subscribed_channels(self):
channel_uns = self.nodes[0].overlay.mds.ChannelMetadata.create_channel("channel unsub", "")
channel_uns.subscribed = False

def mock_notify(overlay, args):
overlay.notified_results = True
self.assertTrue("results" in args)

self.nodes[1].overlay.notifier = Notifier()
self.nodes[1].overlay.notifier.notify = lambda sub, args: mock_notify(self.nodes[1].overlay, args)
notifier = Notifier(loop=self.loop)
notifier.notify = Mock()
self.nodes[1].overlay.notifier = notifier

peer = self.nodes[0].my_peer
await self.introduce_nodes()

await self.deliver_messages(timeout=0.5)

# Check that the notifier callback is called on new channel entries
notifier.notify.assert_called()
assert "results" in notifier.notify.call_args.args[1]

with db_session:
received_channels = self.nodes[1].overlay.mds.ChannelMetadata.select(lambda g: g.title == "channel sub")
self.assertEqual(num_channels, received_channels.count())
Expand All @@ -225,9 +221,6 @@ def mock_notify(overlay, args):
for chan in self.nodes[1].overlay.mds.ChannelMetadata.select():
self.assertTrue(chan.votes > 0.0)

# Check that the notifier callback is called on new channel entries
self.assertTrue(self.nodes[1].overlay.notified_results)

def test_channels_peers_mapping_drop_excess_peers(self):
"""
Test dropping old excess peers from a channel to peers mapping
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@

from pony.orm import db_session

from tribler_core import notifications
from tribler_core.components.libtorrent.download_manager.download_config import DownloadConfig
from tribler_core.components.libtorrent.download_manager.download_manager import DownloadManager
from tribler_core.components.libtorrent.torrentdef import TorrentDef
from tribler_core.components.metadata_store.db.orm_bindings.channel_node import COMMITTED
from tribler_core.components.metadata_store.db.serialization import CHANNEL_TORRENT
from tribler_core.components.metadata_store.db.store import MetadataStore
from tribler_core.notifier import Notifier
from tribler_core.utilities.notifier import Notifier
from tribler_core.utilities.simpledefs import DLSTATUS_SEEDING, NTFY
from tribler_core.utilities.unicode import hexlify

Expand Down Expand Up @@ -290,7 +291,7 @@ def _process_download():
updated_channel = self.mds.ChannelMetadata.get(public_key=channel.public_key, id_=channel.id_)
channel_dict = updated_channel.to_simple_dict() if updated_channel else None
if updated_channel:
self.notifier.notify(NTFY.CHANNEL_ENTITY_UPDATED.value, channel_dict)
self.notifier[notifications.channel_entity_updated](channel_dict)

def updated_my_channel(self, tdef):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from asyncio import Future
from datetime import datetime
from pathlib import Path
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, patch

from ipv8.util import succeed

Expand Down Expand Up @@ -42,8 +42,8 @@ async def gigachannel_manager(metadata_store):
chanman = GigaChannelManager(
state_dir=metadata_store.channels_dir.parent,
metadata_store=metadata_store,
download_manager=Mock(),
notifier=Mock(),
download_manager=MagicMock(),
notifier=MagicMock()
)
yield chanman
await chanman.shutdown()
Expand All @@ -55,7 +55,7 @@ async def test_regen_personal_channel_no_torrent(personal_channel, gigachannel_m
Test regenerating a non-existing personal channel torrent at startup
"""
gigachannel_manager.download_manager.get_download = lambda _: None
gigachannel_manager.regenerate_channel_torrent = Mock()
gigachannel_manager.regenerate_channel_torrent = MagicMock()
await gigachannel_manager.check_and_regen_personal_channels()
gigachannel_manager.regenerate_channel_torrent.assert_called_once()

Expand Down Expand Up @@ -86,8 +86,8 @@ async def test_regenerate_channel_torrent(personal_channel, metadata_store, giga
# Test trying to regenerate a non-existing channel
assert await gigachannel_manager.regenerate_channel_torrent(chan_pk, chan_id + 1) is None

# Mock existing downloads removal-related functions
gigachannel_manager.download_manager.get_downloads_by_name = lambda *_: [Mock()]
# MagicMock existing downloads removal-related functions
gigachannel_manager.download_manager.get_downloads_by_name = lambda *_: [MagicMock()]
downloads_to_remove = []

async def mock_remove_download(download_obj, **_):
Expand All @@ -101,16 +101,16 @@ async def mock_remove_download(download_obj, **_):
assert len(downloads_to_remove) == 1

# Test regenerating a non-empty channel
gigachannel_manager.updated_my_channel = Mock()
metadata_store.ChannelMetadata.consolidate_channel_torrent = lambda *_: Mock()
gigachannel_manager.updated_my_channel = MagicMock()
metadata_store.ChannelMetadata.consolidate_channel_torrent = lambda *_: MagicMock()
with patch("tribler_core.components.libtorrent.torrentdef.TorrentDef.load_from_dict"):
await gigachannel_manager.regenerate_channel_torrent(chan_pk, chan_id)
gigachannel_manager.updated_my_channel.assert_called_once()


def test_updated_my_channel(personal_channel, gigachannel_manager, tmpdir):
tdef = TorrentDef.load_from_dict(update_metainfo)
gigachannel_manager.download_manager.start_download = Mock()
gigachannel_manager.download_manager.start_download = MagicMock()
gigachannel_manager.download_manager.download_exists = lambda *_: False
gigachannel_manager.updated_my_channel(tdef)
gigachannel_manager.download_manager.start_download.assert_called_once()
Expand All @@ -120,7 +120,7 @@ def test_updated_my_channel(personal_channel, gigachannel_manager, tmpdir):
async def test_check_and_regen_personal_channel_torrent(personal_channel, gigachannel_manager):
with db_session:
chan_pk, chan_id = personal_channel.public_key, personal_channel.id_
chan_download = Mock()
chan_download = MagicMock()

async def mock_wait(*_):
pass
Expand All @@ -135,7 +135,7 @@ async def mock_wait_2(*_):
chan_download.wait_for_status = mock_wait_2
# Test timeout waiting for seeding state and then regen

f = Mock()
f = MagicMock()

async def mock_regen(*_):
f()
Expand Down Expand Up @@ -212,7 +212,7 @@ def fake_get_metainfo(infohash, **_):
gigachannel_manager.download_manager = MockObject()
gigachannel_manager.download_manager.download_exists = lambda _: True

mock_download = Mock()
mock_download = MagicMock()
mock_download.get_state.get_status = DLSTATUS_SEEDING

gigachannel_manager.download_manager.get_download = lambda _: mock_download
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ipv8.taskmanager import TaskManager, task
from ipv8.util import int2byte, succeed

from tribler_core import notifications
from tribler_core.components.libtorrent.download_manager.download_config import DownloadConfig
from tribler_core.components.libtorrent.download_manager.download_state import DownloadState
from tribler_core.components.libtorrent.download_manager.stream import Stream
Expand All @@ -20,7 +21,7 @@
from tribler_core.components.libtorrent.utils.libtorrent_helper import libtorrent as lt
from tribler_core.components.libtorrent.utils.torrent_utils import check_handle, get_info_from_handle, require_handle
from tribler_core.exceptions import SaveResumeDataError
from tribler_core.notifier import Notifier
from tribler_core.utilities.notifier import Notifier
from tribler_core.utilities.osutils import fix_filebasename
from tribler_core.utilities.path_util import Path
from tribler_core.utilities.simpledefs import DLSTATUS_SEEDING, DLSTATUS_STOPPED, DOWNLOAD, NTFY
Expand Down Expand Up @@ -388,11 +389,11 @@ def on_torrent_checked_alert(self, _):
def on_torrent_finished_alert(self, _):
self.update_lt_status(self.handle.status())
self.checkpoint()
if self.get_state().get_total_transferred(DOWNLOAD) > 0 and self.stream is not None:
if self.notifier is not None:
self.notifier.notify(NTFY.TORRENT_FINISHED.value, self.tdef.get_infohash(),
self.tdef.get_name_as_unicode(), self.hidden or
self.config.get_channel_download())
downloaded = self.get_state().get_total_transferred(DOWNLOAD)
if downloaded > 0 and self.stream is not None and self.notifier is not None:
self.notifier[notifications.torrent_finished](infohash=self.tdef.get_infohash().hex(),
name=self.tdef.get_name_as_unicode(),
hidden=self.hidden or self.config.get_channel_download())

def update_lt_status(self, lt_status):
""" Update libtorrent stats and check if the download should be stopped."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@

from ipv8.taskmanager import TaskManager, task

from tribler_core import notifications
from tribler_core.components.libtorrent.download_manager.dht_health_manager import DHTHealthManager
from tribler_core.components.libtorrent.download_manager.download import Download
from tribler_core.components.libtorrent.download_manager.download_config import DownloadConfig
from tribler_core.components.libtorrent.settings import DownloadDefaultsSettings, LibtorrentSettings
from tribler_core.components.libtorrent.torrentdef import TorrentDef, TorrentDefNoMetainfo
from tribler_core.components.libtorrent.utils import torrent_utils
from tribler_core.components.libtorrent.utils.libtorrent_helper import libtorrent as lt
from tribler_core.notifier import Notifier
from tribler_core.utilities import path_util
from tribler_core.utilities.network_utils import default_network_utils
from tribler_core.utilities.notifier import Notifier
from tribler_core.utilities.path_util import Path
from tribler_core.utilities.rest_utils import (
FILE_SCHEME,
Expand Down Expand Up @@ -158,7 +159,7 @@ def initialize(self):
self.set_download_states_callback(self.sesscb_states_callback)

def notify_shutdown_state(self, state):
self.notifier.notify(NTFY.TRIBLER_SHUTDOWN_STATE.value, state)
self.notifier[notifications.tribler_shutdown_state](state)

async def shutdown(self, timeout=30):
if self.downloads:
Expand Down Expand Up @@ -393,8 +394,8 @@ def process_alert(self, alert, hops=0):
# We use the now-deprecated ``endpoint`` attribute for these older versions.
self.listen_ports[hops] = getattr(alert, "port", alert.endpoint[1])

elif alert_type == 'peer_disconnected_alert' and self.notifier:
self.notifier.notify(NTFY.PEER_DISCONNECTED_EVENT.value, alert.pid.to_bytes())
elif alert_type == 'peer_disconnected_alert':
self.notifier[notifications.peer_disconnected](alert.pid.to_bytes())

elif alert_type == 'session_stats_alert':
queued_disk_jobs = alert.values['disk.queued_disk_jobs']
Expand Down Expand Up @@ -814,8 +815,8 @@ async def sesscb_states_callback(self, states_list):
if self.state_cb_count % 5 == 0 and download.config.get_hops() == 0 and self.notifier:
for peer in download.get_peerlist():
if str(peer["extended_version"]).startswith('Tribler'):
self.notifier.notify(NTFY.TRIBLER_TORRENT_PEER_UPDATE.value,
unhexlify(peer["id"]), infohash, peer["dtotal"])
self.notifier[notifications.tribler_torrent_peer_update](unhexlify(peer["id"]), infohash,
peer["dtotal"])

if self.state_cb_count % 4 == 0:
self._last_states_list = states_list
Expand Down
Loading