From c17e726a72298f931ea0cb6f4f9d8b57444e9ebd Mon Sep 17 00:00:00 2001 From: qstokkink Date: Wed, 4 Oct 2023 16:22:33 +0200 Subject: [PATCH] Ported communities to CommunitySettings --- .../bandwidth_accounting_community.py | 13 +++-- .../tests/test_community.py | 18 +++++-- .../community/gigachannel_community.py | 18 ++++--- .../tests/test_gigachannel_community.py | 18 ++++--- .../core/components/ipv8/ipv8_component.py | 17 +++++-- .../core/components/ipv8/tribler_community.py | 16 +++++-- .../remote_query_community.py | 25 ++++++---- .../tests/test_remote_query_community.py | 14 ++++-- .../community/popularity_community.py | 14 ++++-- .../tests/test_popularity_community.py | 19 +++++--- .../tunnel/community/tunnel_community.py | 47 ++++++++++++------- .../tests/test_triblertunnel_community.py | 31 ++++++++---- 12 files changed, 170 insertions(+), 80 deletions(-) diff --git a/src/tribler/core/components/bandwidth_accounting/community/bandwidth_accounting_community.py b/src/tribler/core/components/bandwidth_accounting/community/bandwidth_accounting_community.py index 4c3dd4df1a7..7d55565ad62 100644 --- a/src/tribler/core/components/bandwidth_accounting/community/bandwidth_accounting_community.py +++ b/src/tribler/core/components/bandwidth_accounting/community/bandwidth_accounting_community.py @@ -16,10 +16,14 @@ ) from tribler.core.components.bandwidth_accounting.db.database import BandwidthDatabase from tribler.core.components.bandwidth_accounting.db.transaction import BandwidthTransactionData, EMPTY_SIGNATURE -from tribler.core.components.ipv8.tribler_community import TriblerCommunity +from tribler.core.components.ipv8.tribler_community import TriblerCommunity, TriblerSettings from tribler.core.utilities.unicode import hexlify +class BandwidthCommunitySettings(TriblerSettings): + database: BandwidthDatabase | None = None + + class BandwidthAccountingCommunity(TriblerCommunity): """ Community around bandwidth accounting and payouts. @@ -27,17 +31,18 @@ class BandwidthAccountingCommunity(TriblerCommunity): community_id = unhexlify('79b25f2867739261780faefede8f25038de9975d') DB_NAME = 'bandwidth' version = b'\x02' + settings_class = BandwidthCommunitySettings - def __init__(self, *args, **kwargs) -> None: + def __init__(self, settings: BandwidthCommunitySettings) -> None: """ Initialize the community. :param persistence: The database that stores transactions, will be created if not provided. :param database_path: The path at which the database will be created. Defaults to the current working directory. """ - self.database: BandwidthDatabase = kwargs.pop('database', None) + self.database = settings.database self.random = Random() - super().__init__(*args, **kwargs) + super().__init__(settings) self.request_cache = RequestCache() self.my_pk = self.my_peer.public_key.key_to_bin() diff --git a/src/tribler/core/components/bandwidth_accounting/tests/test_community.py b/src/tribler/core/components/bandwidth_accounting/tests/test_community.py index 3b92360fa49..e47601ead61 100644 --- a/src/tribler/core/components/bandwidth_accounting/tests/test_community.py +++ b/src/tribler/core/components/bandwidth_accounting/tests/test_community.py @@ -1,10 +1,12 @@ +from __future__ import annotations + from ipv8.keyvault.crypto import default_eccrypto from ipv8.peer import Peer from ipv8.test.base import TestBase from ipv8.test.mocking.ipv8 import MockIPv8 from tribler.core.components.bandwidth_accounting.community.bandwidth_accounting_community import ( - BandwidthAccountingCommunity, + BandwidthAccountingCommunity, BandwidthCommunitySettings, ) from tribler.core.components.bandwidth_accounting.community.cache import BandwidthTransactionSignCache from tribler.core.components.bandwidth_accounting.db.database import BandwidthDatabase @@ -21,11 +23,17 @@ def setUp(self): super().setUp() self.initialize(BandwidthAccountingCommunity, 2) - def create_node(self): + def create_node(self, settings: BandwidthCommunitySettings | None = None, + create_dht: bool = False, enable_statistics: bool = False): peer = Peer(default_eccrypto.generate_key("curve25519"), address=("1.2.3.4", 5)) - db = BandwidthDatabase(db_path=MEMORY_DB, my_pub_key=peer.public_key.key_to_bin()) - ipv8 = MockIPv8(peer, BandwidthAccountingCommunity, database=db, - settings=BandwidthAccountingSettings()) + db = BandwidthDatabase(db_path=MEMORY_DB, + my_pub_key=peer.public_key.key_to_bin()) + ipv8 = MockIPv8(peer, + BandwidthAccountingCommunity, + BandwidthCommunitySettings( + database=db, + settings=BandwidthAccountingSettings() + )) return ipv8 def database(self, i): diff --git a/src/tribler/core/components/gigachannel/community/gigachannel_community.py b/src/tribler/core/components/gigachannel/community/gigachannel_community.py index f00d939535c..6dd962c240e 100644 --- a/src/tribler/core/components/gigachannel/community/gigachannel_community.py +++ b/src/tribler/core/components/gigachannel/community/gigachannel_community.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import time import uuid from binascii import unhexlify @@ -13,7 +15,8 @@ from tribler.core.components.ipv8.discovery_booster import DiscoveryBooster from tribler.core.components.metadata_store.db.serialization import CHANNEL_TORRENT from tribler.core.components.metadata_store.remote_query_community.payload_checker import ObjState -from tribler.core.components.metadata_store.remote_query_community.remote_query_community import RemoteQueryCommunity +from tribler.core.components.metadata_store.remote_query_community.remote_query_community import \ + RemoteCommunitySettings, RemoteQueryCommunity from tribler.core.components.metadata_store.utils import NoChannelSourcesException from tribler.core.utilities.notifier import Notifier from tribler.core.utilities.simpledefs import CHANNELS_VIEW_UUID @@ -70,8 +73,13 @@ def get_last_seen_peers_for_channel(self, channel_pk: bytes, channel_id: int, li return sorted(channel_peers, key=lambda x: x.last_response, reverse=True)[0:limit] +class GigaCommunitySettings(RemoteCommunitySettings): + notifier: Notifier | None = None + + class GigaChannelCommunity(RemoteQueryCommunity): community_id = unhexlify('d3512d0ff816d8ac672eab29a9c1a3a32e17cb13') + settings_class = GigaCommunitySettings def create_introduction_response( self, @@ -94,14 +102,12 @@ def create_introduction_response( new_style=new_style, ) - def __init__( - self, *args, notifier: Notifier = None, **kwargs - ): # pylint: disable=unused-argument + def __init__(self, settings: GigaCommunitySettings): # pylint: disable=unused-argument # ACHTUNG! We create a separate instance of Network for this community because it # walks aggressively and wants lots of peers, which can interfere with other communities - super().__init__(*args, **kwargs) + super().__init__(settings) - self.notifier = notifier + self.notifier = settings.notifier # This set contains all the peers that we queried for subscribed channels over time. # It is emptied regularly. The purpose of this set is to work as a filter so we never query the same diff --git a/src/tribler/core/components/gigachannel/community/tests/test_gigachannel_community.py b/src/tribler/core/components/gigachannel/community/tests/test_gigachannel_community.py index a4938f05fb9..5299d7054bf 100644 --- a/src/tribler/core/components/gigachannel/community/tests/test_gigachannel_community.py +++ b/src/tribler/core/components/gigachannel/community/tests/test_gigachannel_community.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import time from collections.abc import Mapping from dataclasses import asdict, dataclass, fields @@ -5,6 +7,8 @@ from unittest.mock import AsyncMock, Mock import pytest + +from ipv8.community import CommunitySettings from ipv8.keyvault.crypto import default_eccrypto from ipv8.peer import Peer from ipv8.test.base import TestBase @@ -13,7 +17,7 @@ from tribler.core.components.gigachannel.community.gigachannel_community import ( ChannelsPeersMapping, GigaChannelCommunity, - NoChannelSourcesException, + GigaCommunitySettings, NoChannelSourcesException, happy_eyeballs_delay ) from tribler.core.components.gigachannel.community.settings import ChantSettings @@ -66,7 +70,8 @@ async def tearDown(self): metadata_store.shutdown() await super().tearDown() - def create_node(self, *args, **kwargs): + def create_node(self, settings: CommunitySettings | None = None, create_dht: bool = False, + enable_statistics: bool = False): metadata_store = MetadataStore( Path(self.temporary_directory()) / f"{self.count}.db", Path(self.temporary_directory()), @@ -74,10 +79,11 @@ def create_node(self, *args, **kwargs): disable_sync=True, ) self.metadata_store_set.add(metadata_store) - kwargs['metadata_store'] = metadata_store - kwargs['settings'] = ChantSettings() - kwargs['rqc_settings'] = RemoteQueryCommunitySettings() - node = super().create_node(*args, **kwargs) + node = super().create_node(GigaCommunitySettings( + metadata_store=metadata_store, + settings=ChantSettings(), + rqc_settings=RemoteQueryCommunitySettings() + )) node.overlay.discovery_booster.finish() notifier = Notifier(loop=self.loop) diff --git a/src/tribler/core/components/ipv8/ipv8_component.py b/src/tribler/core/components/ipv8/ipv8_component.py index 7b07d4a332e..d123f2c980e 100644 --- a/src/tribler/core/components/ipv8/ipv8_component.py +++ b/src/tribler/core/components/ipv8/ipv8_component.py @@ -1,6 +1,7 @@ from typing import Optional from ipv8.bootstrapping.dispersy.bootstrapper import DispersyBootstrapper +from ipv8.community import CommunitySettings from ipv8.configuration import ConfigBuilder, DISPERSY_BOOTSTRAPPER from ipv8.dht.churn import PingChurn from ipv8.dht.discovery import DHTDiscoveryCommunity @@ -112,7 +113,12 @@ def make_bootstrapper(self) -> DispersyBootstrapper: def _init_peer_discovery_community(self): ipv8 = self.ipv8 - community = DiscoveryCommunity(self.peer, ipv8.endpoint, ipv8.network, max_peers=100) + community = DiscoveryCommunity(CommunitySettings( + my_peer=self.peer, + endpoint=ipv8.endpoint, + network=ipv8.network, + max_peers=100 + )) self.initialise_community_by_default(community) ipv8.add_strategy(community, RandomChurn(community), INFINITE) ipv8.add_strategy(community, PeriodicSimilarity(community), INFINITE) @@ -120,7 +126,12 @@ def _init_peer_discovery_community(self): def _init_dht_discovery_community(self): ipv8 = self.ipv8 - community = DHTDiscoveryCommunity(self.peer, ipv8.endpoint, ipv8.network, max_peers=60) + community = DHTDiscoveryCommunity(CommunitySettings( + my_peer=self.peer, + endpoint=ipv8.endpoint, + network=ipv8.network, + max_peers=60 + )) self.initialise_community_by_default(community) ipv8.add_strategy(community, PingChurn(community), INFINITE) self.dht_discovery_community = community @@ -136,4 +147,4 @@ async def shutdown(self): await self.ipv8.unload_overlay(overlay) await self._task_manager.shutdown_task_manager() - await self.ipv8.stop(stop_loop=False) + await self.ipv8.stop() diff --git a/src/tribler/core/components/ipv8/tribler_community.py b/src/tribler/core/components/ipv8/tribler_community.py index 018d4ecdd56..a8a9779dc79 100644 --- a/src/tribler/core/components/ipv8/tribler_community.py +++ b/src/tribler/core/components/ipv8/tribler_community.py @@ -1,13 +1,21 @@ -from ipv8.community import Community +from __future__ import annotations + +from ipv8.community import Community, CommunitySettings from tribler.core.config.tribler_config_section import TriblerConfigSection +class TriblerSettings(CommunitySettings): + settings: TriblerConfigSection | None = None + + class TriblerCommunity(Community): """Base class for Tribler communities. """ - def __init__(self, *args, settings: TriblerConfigSection = None, **kwargs): - super().__init__(*args, **kwargs) - self.settings = settings + settings_class = TriblerSettings + + def __init__(self, settings: TriblerSettings): + super().__init__(settings) + self.settings = settings.settings self.logger.info(f'Init. Settings: {settings}.') diff --git a/src/tribler/core/components/metadata_store/remote_query_community/remote_query_community.py b/src/tribler/core/components/metadata_store/remote_query_community/remote_query_community.py index 71d2e459875..367f5d021b8 100644 --- a/src/tribler/core/components/metadata_store/remote_query_community/remote_query_community.py +++ b/src/tribler/core/components/metadata_store/remote_query_community/remote_query_community.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import struct import time @@ -14,7 +16,7 @@ from tribler.core.components.ipv8.eva.protocol import EVAProtocol from tribler.core.components.ipv8.eva.result import TransferResult -from tribler.core.components.ipv8.tribler_community import TriblerCommunity +from tribler.core.components.ipv8.tribler_community import TriblerCommunity, TriblerSettings from tribler.core.components.knowledge.community.knowledge_validator import is_valid_resource from tribler.core.components.database.db.tribler_database import ResourceType from tribler.core.components.metadata_store.db.orm_bindings.channel_metadata import LZ4_EMPTY_ARCHIVE, entries_to_chunk @@ -128,21 +130,24 @@ def on_timeout(self): pass +class RemoteCommunitySettings(TriblerSettings): + rqc_settings: RemoteQueryCommunitySettings | None = None + metadata_store: MetadataStore | None = None + tribler_db = None + + class RemoteQueryCommunity(TriblerCommunity): """ Community for general purpose SELECT-like queries into remote Channels database """ + settings_class = RemoteCommunitySettings - def __init__(self, my_peer, endpoint, network, - rqc_settings: RemoteQueryCommunitySettings = None, - metadata_store=None, - tribler_db=None, - **kwargs): - super().__init__(my_peer, endpoint, network=network, **kwargs) + def __init__(self, settings: RemoteCommunitySettings): + super().__init__(settings) - self.rqc_settings = rqc_settings - self.mds: MetadataStore = metadata_store - self.tribler_db = tribler_db + self.rqc_settings = settings.rqc_settings + self.mds = settings.metadata_store + self.tribler_db = settings.tribler_db # This object stores requests for "select" queries that we sent to other hosts. # We keep track of peers we actually requested for data so people can't randomly push spam at us. # Also, this keeps track of hosts we responded to. There is a possibility that diff --git a/src/tribler/core/components/metadata_store/remote_query_community/tests/test_remote_query_community.py b/src/tribler/core/components/metadata_store/remote_query_community/tests/test_remote_query_community.py index a279b6ffaa3..6bca64f3a59 100644 --- a/src/tribler/core/components/metadata_store/remote_query_community/tests/test_remote_query_community.py +++ b/src/tribler/core/components/metadata_store/remote_query_community/tests/test_remote_query_community.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import random import string @@ -9,6 +11,8 @@ from unittest.mock import Mock, patch import pytest + +from ipv8.community import CommunitySettings from ipv8.keyvault.crypto import default_eccrypto from ipv8.test.base import TestBase from pony.orm import db_session @@ -18,7 +22,7 @@ from tribler.core.components.metadata_store.db.serialization import CHANNEL_THUMBNAIL, CHANNEL_TORRENT, REGULAR_TORRENT from tribler.core.components.metadata_store.db.store import MetadataStore from tribler.core.components.metadata_store.remote_query_community.remote_query_community import ( - RemoteQueryCommunity, + RemoteCommunitySettings, RemoteQueryCommunity, sanitize_query, ) from tribler.core.components.metadata_store.remote_query_community.settings import RemoteQueryCommunitySettings @@ -73,7 +77,8 @@ async def tearDown(self): metadata_store.shutdown() await super().tearDown() - def create_node(self, *args, **kwargs): + def create_node(self, settings: CommunitySettings | None = None, create_dht: bool = False, + enable_statistics: bool = False): metadata_store = MetadataStore( Path(self.temporary_directory()) / f"{self.count}.db", Path(self.temporary_directory()), @@ -81,9 +86,8 @@ def create_node(self, *args, **kwargs): disable_sync=True, ) self.metadata_store_set.add(metadata_store) - kwargs['metadata_store'] = metadata_store - kwargs['rqc_settings'] = RemoteQueryCommunitySettings() - node = super().create_node(*args, **kwargs) + node = super().create_node(RemoteCommunitySettings(metadata_store=metadata_store, + rqc_settings=RemoteQueryCommunitySettings())) self.count += 1 return node diff --git a/src/tribler/core/components/popularity/community/popularity_community.py b/src/tribler/core/components/popularity/community/popularity_community.py index 4e7e31c90c4..c070756b2a5 100644 --- a/src/tribler/core/components/popularity/community/popularity_community.py +++ b/src/tribler/core/components/popularity/community/popularity_community.py @@ -7,7 +7,8 @@ from ipv8.lazy_community import lazy_wrapper from pony.orm import db_session -from tribler.core.components.metadata_store.remote_query_community.remote_query_community import RemoteQueryCommunity +from tribler.core.components.metadata_store.remote_query_community.remote_query_community import \ + RemoteCommunitySettings, RemoteQueryCommunity from tribler.core.components.popularity.community.payload import PopularTorrentsRequest, TorrentsHealthPayload from tribler.core.components.popularity.community.version_community_mixin import VersionCommunityMixin from tribler.core.components.torrent_checker.torrent_checker.dataclasses import HealthInfo @@ -19,6 +20,10 @@ from tribler.core.components.torrent_checker.torrent_checker.torrent_checker import TorrentChecker +class PopularityCommunitySettings(RemoteCommunitySettings): + torrent_checker: TorrentChecker | None = None + + class PopularityCommunity(RemoteQueryCommunity, VersionCommunityMixin): """ Community for disseminating the content across the network. @@ -36,11 +41,12 @@ class PopularityCommunity(RemoteQueryCommunity, VersionCommunityMixin): GOSSIP_RANDOM_TORRENT_COUNT = 10 community_id = unhexlify('9aca62f878969c437da9844cba29a134917e1648') + settings_class = PopularityCommunitySettings - def __init__(self, *args, torrent_checker=None, **kwargs): + def __init__(self, settings: PopularityCommunitySettings): # Creating a separate instance of Network for this community to find more peers - super().__init__(*args, **kwargs) - self.torrent_checker: TorrentChecker = torrent_checker + super().__init__(settings) + self.torrent_checker = settings.torrent_checker self.add_message_handler(TorrentsHealthPayload, self.on_torrents_health) self.add_message_handler(PopularTorrentsRequest, self.on_popular_torrents_request) diff --git a/src/tribler/core/components/popularity/community/tests/test_popularity_community.py b/src/tribler/core/components/popularity/community/tests/test_popularity_community.py index 78415f61ed8..8d5046f8bbe 100644 --- a/src/tribler/core/components/popularity/community/tests/test_popularity_community.py +++ b/src/tribler/core/components/popularity/community/tests/test_popularity_community.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import time from random import randint from typing import List from unittest.mock import Mock +from ipv8.community import CommunitySettings from ipv8.keyvault.crypto import default_eccrypto from ipv8.test.base import TestBase from ipv8.test.mocking.ipv8 import MockIPv8 @@ -10,7 +13,8 @@ from tribler.core.components.metadata_store.db.store import MetadataStore from tribler.core.components.metadata_store.remote_query_community.settings import RemoteQueryCommunitySettings -from tribler.core.components.popularity.community.popularity_community import PopularityCommunity +from tribler.core.components.popularity.community.popularity_community import PopularityCommunity, \ + PopularityCommunitySettings from tribler.core.components.torrent_checker.torrent_checker.torrentchecker_session import HealthInfo from tribler.core.tests.tools.base_test import MockObject from tribler.core.utilities.path_util import Path @@ -53,7 +57,8 @@ async def tearDown(self): metadata_store.shutdown() await super().tearDown() - def create_node(self, *args, **kwargs): + def create_node(self, settings: CommunitySettings | None = None, create_dht: bool = False, + enable_statistics: bool = False): mds = MetadataStore(Path(self.temporary_directory()) / f"{self.count}", Path(self.temporary_directory()), default_eccrypto.generate_key("curve25519")) @@ -64,10 +69,12 @@ def create_node(self, *args, **kwargs): self.count += 1 rqc_settings = RemoteQueryCommunitySettings() - return MockIPv8("curve25519", PopularityCommunity, metadata_store=mds, - torrent_checker=torrent_checker, - rqc_settings=rqc_settings - ) + return MockIPv8("curve25519", PopularityCommunity, + PopularityCommunitySettings( + metadata_store=mds, + torrent_checker=torrent_checker, + rqc_settings=rqc_settings + )) @db_session def fill_database(self, metadata_store, last_check_now=False): diff --git a/src/tribler/core/components/tunnel/community/tunnel_community.py b/src/tribler/core/components/tunnel/community/tunnel_community.py index c9896ec31fa..32422cc873e 100644 --- a/src/tribler/core/components/tunnel/community/tunnel_community.py +++ b/src/tribler/core/components/tunnel/community/tunnel_community.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import hashlib import math import sys @@ -11,7 +13,7 @@ import async_timeout from ipv8.messaging.anonymization.caches import CreateRequestCache from ipv8.messaging.anonymization.community import unpack_cell -from ipv8.messaging.anonymization.hidden_services import HiddenTunnelCommunity +from ipv8.messaging.anonymization.hidden_services import HiddenTunnelCommunity, HiddenTunnelSettings from ipv8.messaging.anonymization.payload import EstablishIntroPayload, NO_CRYPTO_PACKETS from ipv8.messaging.anonymization.tunnel import ( CIRCUIT_STATE_CLOSING, @@ -34,6 +36,7 @@ from tribler.core import notifications from tribler.core.components.bandwidth_accounting.db.transaction import BandwidthTransactionData +from tribler.core.components.ipv8.tribler_community import TriblerSettings from tribler.core.components.socks_servers.socks5.server import Socks5Server from tribler.core.components.tunnel.community.caches import BalanceRequestCache, HTTPRequestCache from tribler.core.components.tunnel.community.discovery import GoldenRatioStrategy @@ -61,24 +64,34 @@ MAX_HTTP_PACKET_SIZE = 1400 +class TriblerTunnelCommunitySettings(HiddenTunnelSettings, TriblerSettings): + bandwidth_community = None + exitnode_cache: Optional[Path] = None + config = None + notifier = None + dlmgr = None + socks_servers: List[Socks5Server] = [] + + class TriblerTunnelCommunity(HiddenTunnelCommunity): """ This community is built upon the anonymous messaging layer in IPv8. It adds support for libtorrent anonymous downloads and bandwidth token payout when closing circuits. """ community_id = unhexlify('a3591a6bd89bbaca0974062a1287afcfbc6fd6bb') - - def __init__(self, *args, **kwargs): - self.bandwidth_community = kwargs.pop('bandwidth_community', None) - self.exitnode_cache: Optional[Path] = kwargs.pop('exitnode_cache', None) - self.config = kwargs.pop('config', None) - self.notifier = kwargs.pop('notifier', None) - self.download_manager = kwargs.pop('dlmgr', None) - self.socks_servers: List[Socks5Server] = kwargs.pop('socks_servers', []) + settings_class = TriblerTunnelCommunitySettings + + def __init__(self, settings: TriblerTunnelCommunitySettings): + self.bandwidth_community = settings.bandwidth_community + self.exitnode_cache = settings.exitnode_cache + self.config = settings.config + self.notifier = settings.notifier + self.download_manager = settings.dlmgr + self.socks_servers = settings.socks_servers num_competing_slots = self.config.competing_slots num_random_slots = self.config.random_slots - super().__init__(*args, **kwargs) + super().__init__(settings) self._use_main_thread = True if self.config.exitnode_enabled: @@ -382,7 +395,8 @@ def clean_from_slots(self, circuit_id): if tup[1] == circuit_id: self.competing_slots[ind] = (0, None) - def remove_circuit(self, circuit_id, additional_info='', remove_now=False, destroy=False): + def remove_circuit(self, circuit_id: int, additional_info: str = '', remove_now: bool = False, + destroy: bool | int = False): if circuit_id not in self.circuits: self.logger.warning("Circuit %d not found when trying to remove it", circuit_id) return succeed(None) @@ -423,14 +437,12 @@ def remove_circuit(self, circuit_id, additional_info='', remove_now=False, destr remove_now=remove_now, destroy=destroy) @task - async def remove_relay(self, circuit_id, additional_info='', remove_now=False, destroy=False, - got_destroy_from=None, both_sides=True): + async def remove_relay(self, circuit_id: int, additional_info: str = '', remove_now: bool = False, + destroy: bool = False): removed_relays = await super().remove_relay(circuit_id, additional_info=additional_info, remove_now=remove_now, - destroy=destroy, - got_destroy_from=got_destroy_from, - both_sides=both_sides) + destroy=destroy) self.clean_from_slots(circuit_id) @@ -438,7 +450,8 @@ async def remove_relay(self, circuit_id, additional_info='', remove_now=False, d for removed_relay in removed_relays: self.notifier[notifications.circuit_removed](removed_relay, additional_info) - def remove_exit_socket(self, circuit_id, additional_info='', remove_now=False, destroy=False): + def remove_exit_socket(self, circuit_id: int, additional_info: str = '', remove_now: bool = False, + destroy: bool = False): if circuit_id in self.exit_sockets and self.notifier: exit_socket = self.exit_sockets[circuit_id] self.notifier[notifications.circuit_removed](exit_socket, additional_info) diff --git a/src/tribler/core/components/tunnel/tests/test_triblertunnel_community.py b/src/tribler/core/components/tunnel/tests/test_triblertunnel_community.py index d8ccdde5d79..9ce12b4934a 100644 --- a/src/tribler/core/components/tunnel/tests/test_triblertunnel_community.py +++ b/src/tribler/core/components/tunnel/tests/test_triblertunnel_community.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os from asyncio import Future, TimeoutError as AsyncTimeoutError, sleep, wait_for from collections import defaultdict @@ -5,6 +7,8 @@ from unittest.mock import MagicMock, Mock, patch import pytest + +from ipv8.community import CommunitySettings from ipv8.messaging.anonymization.payload import EstablishIntroPayload from ipv8.messaging.anonymization.tunnel import ( CIRCUIT_STATE_READY, @@ -22,12 +26,13 @@ from ipv8.util import succeed from tribler.core.components.bandwidth_accounting.community.bandwidth_accounting_community import ( - BandwidthAccountingCommunity, + BandwidthAccountingCommunity, BandwidthCommunitySettings, ) from tribler.core.components.bandwidth_accounting.db.database import BandwidthDatabase from tribler.core.components.bandwidth_accounting.settings import BandwidthAccountingSettings from tribler.core.components.tunnel.community.payload import BandwidthTransactionPayload -from tribler.core.components.tunnel.community.tunnel_community import PEER_FLAG_EXIT_HTTP, TriblerTunnelCommunity +from tribler.core.components.tunnel.community.tunnel_community import PEER_FLAG_EXIT_HTTP, TriblerTunnelCommunity, \ + TriblerTunnelCommunitySettings from tribler.core.components.tunnel.settings import TunnelCommunitySettings from tribler.core.tests.tools.base_test import MockObject from tribler.core.tests.tools.tracker.http_tracker import HTTPTracker @@ -67,21 +72,27 @@ async def tearDown(self): await node.overlay.bandwidth_community.unload() await super().tearDown() - def create_node(self): + def create_node(self, settings: CommunitySettings | None = None, create_dht: bool = False, + enable_statistics: bool = False): config = TunnelCommunitySettings() mock_ipv8 = MockIPv8("curve25519", TriblerTunnelCommunity, - settings={'remove_tunnel_delay': 0}, - config=config, - exitnode_cache=Path(self.temporary_directory()) / "exitnode_cache.dat" - ) + TriblerTunnelCommunitySettings( + settings={'remove_tunnel_delay': 0}, + config=config, + exitnode_cache=Path(self.temporary_directory()) / "exitnode_cache.dat" + )) mock_ipv8.overlay.settings.max_circuits = 1 db = BandwidthDatabase(db_path=MEMORY_DB, my_pub_key=mock_ipv8.my_peer.public_key.key_to_bin()) # Load the bandwidth accounting community - mock_ipv8.overlay.bandwidth_community = BandwidthAccountingCommunity( - mock_ipv8.my_peer, mock_ipv8.endpoint, mock_ipv8.network, - settings=BandwidthAccountingSettings(), database=db) + mock_ipv8.overlay.bandwidth_community = BandwidthAccountingCommunity(BandwidthCommunitySettings( + my_peer=mock_ipv8.my_peer, + endpoint=mock_ipv8.endpoint, + network=mock_ipv8.network, + settings=BandwidthAccountingSettings(), + database=db + )) mock_ipv8.overlay.dht_provider = MockDHTProvider(Peer(mock_ipv8.overlay.my_peer.key, mock_ipv8.overlay.my_estimated_wan))