From aced84587f9f1e15217bec9390652e76c41f2b29 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 18 Jan 2022 10:18:58 +0100 Subject: [PATCH] Fix run_bandwidth_crawler script --- src/tribler-core/run_bandwidth_crawler.py | 55 +++++++------------ .../bandwidth_accounting_component.py | 19 ++++++- 2 files changed, 36 insertions(+), 38 deletions(-) diff --git a/src/tribler-core/run_bandwidth_crawler.py b/src/tribler-core/run_bandwidth_crawler.py index 494348c7951..ba057adbe6d 100644 --- a/src/tribler-core/run_bandwidth_crawler.py +++ b/src/tribler-core/run_bandwidth_crawler.py @@ -2,60 +2,43 @@ This executable script starts a Tribler instance and joins the BandwidthAccountingCommunity. """ import argparse +import logging +import signal import sys from asyncio import ensure_future, get_event_loop from pathlib import Path -from ipv8.loader import IPv8CommunityLoader - -from tribler_common.simpledefs import STATEDIR_DB_DIR - -from tribler_core.components.bandwidth_accounting.db.database import BandwidthDatabase -from tribler_core.components.bandwidth_accounting.settings import BandwidthAccountingSettings +from tribler_core.components.bandwidth_accounting.bandwidth_accounting_component import BandwidthAccountingComponent +from tribler_core.components.ipv8.ipv8_component import Ipv8Component +from tribler_core.components.key.key_component import KeyComponent from tribler_core.config.tribler_config import TriblerConfig -from tribler_core.modules.bandwidth_accounting.launcher import BandwidthCommunityLauncher from tribler_core.start_core import Session class PortAction(argparse.Action): - def __call__(self, _, namespace, values, option_string=None): - if not 0 < values < 2**16: + def __call__(self, _, namespace, value, option_string=None): + if not 0 < value < 2 ** 16: raise argparse.ArgumentError(self, "Invalid port number") - setattr(namespace, self.dest, values) - - -class BandwidthCommunityCrawlerLauncher(BandwidthCommunityLauncher): - - def get_kwargs(self, session): - settings = BandwidthAccountingSettings() - settings.outgoing_query_interval = 5 - database = BandwidthDatabase(session.config.state_dir / STATEDIR_DB_DIR / "bandwidth.db", - session.trustchain_keypair.pub().key_to_bin(), store_all_transactions=True) + setattr(namespace, self.dest, value) - return { - "database": database, - "settings": settings, - "max_peers": -1 - } - -async def start_crawler(tribler_config): - - # We use our own community loader - loader = IPv8CommunityLoader() - loader.set_launcher(BandwidthCommunityCrawlerLauncher()) - session = Session(tribler_config, community_loader=loader) - - await session.start_components() +async def crawler_session(session_config: TriblerConfig): + session = Session(session_config, + [KeyComponent(), Ipv8Component(), BandwidthAccountingComponent(crawler_mode=True)]) + signal.signal(signal.SIGTERM, lambda signum, stack: session.shutdown_event.set) + async with session.start(): + await session.shutdown_event.wait() if __name__ == "__main__": parser = argparse.ArgumentParser(description=('Start a crawler in the bandwidth accounting community')) parser.add_argument('--statedir', '-s', default='bw_crawler', type=str, help='Use an alternate statedir') - parser.add_argument('--restapi', '-p', default=52194, type=str, help='Use an alternate port for the REST API', + parser.add_argument('--restapi', '-p', default=52194, type=int, help='Use an alternate port for the REST API', action=PortAction, metavar='{0..65535}') args = parser.parse_args(sys.argv[1:]) + logging.basicConfig(level=logging.INFO) + state_dir = Path(args.statedir).absolute() config = TriblerConfig.load(file=state_dir / 'triblerd.conf', state_dir=state_dir) @@ -66,8 +49,8 @@ async def start_crawler(tribler_config): config.torrent_checking.enabled = False config.api.http_enabled = True config.api.http_port = args.restapi + config.bandwidth_accounting.outgoing_query_interval = 5 loop = get_event_loop() - coro = start_crawler(config) - ensure_future(coro) + ensure_future(crawler_session(config)) loop.run_forever() diff --git a/src/tribler-core/tribler_core/components/bandwidth_accounting/bandwidth_accounting_component.py b/src/tribler-core/tribler_core/components/bandwidth_accounting/bandwidth_accounting_component.py index 5327c20113e..88b7afc302a 100644 --- a/src/tribler-core/tribler_core/components/bandwidth_accounting/bandwidth_accounting_component.py +++ b/src/tribler-core/tribler_core/components/bandwidth_accounting/bandwidth_accounting_component.py @@ -14,6 +14,10 @@ class BandwidthAccountingComponent(Component): _ipv8_component: Ipv8Component = None database: BandwidthDatabase = None + def __init__(self, crawler_mode=False): + super().__init__() + self.crawler_mode = crawler_mode + async def run(self): await super().run() self._ipv8_component = await self.require_component(Ipv8Component) @@ -24,14 +28,25 @@ async def run(self): else: bandwidth_cls = BandwidthAccountingCommunity + if self.crawler_mode: + store_all_transactions = True + unlimited_peers = True + else: + store_all_transactions = False + unlimited_peers = False + db_name = "bandwidth_gui_test.db" if config.gui_test_mode else f"{bandwidth_cls.DB_NAME}.db" database_path = config.state_dir / STATEDIR_DB_DIR / db_name - self.database = BandwidthDatabase(database_path, self._ipv8_component.peer.public_key.key_to_bin()) + self.database = BandwidthDatabase(database_path, self._ipv8_component.peer.public_key.key_to_bin(), + store_all_transactions=store_all_transactions) + + kwargs = {"max_peers": -1} if unlimited_peers else {} self.community = bandwidth_cls(self._ipv8_component.peer, self._ipv8_component.ipv8.endpoint, self._ipv8_component.ipv8.network, settings=config.bandwidth_accounting, - database=self.database) + database=self.database, + **kwargs) self._ipv8_component.initialise_community_by_default(self.community)