Skip to content

Commit

Permalink
Fix run_bandwidth_crawler script
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlovsky committed Jan 18, 2022
1 parent cd50ad8 commit aced845
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 38 deletions.
55 changes: 19 additions & 36 deletions src/tribler-core/run_bandwidth_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,60 +2,43 @@
This executable script starts a Tribler instance and joins the BandwidthAccountingCommunity.
"""
import argparse
import logging
import signal
import sys
from asyncio import ensure_future, get_event_loop
from pathlib import Path

from ipv8.loader import IPv8CommunityLoader

from tribler_common.simpledefs import STATEDIR_DB_DIR

from tribler_core.components.bandwidth_accounting.db.database import BandwidthDatabase
from tribler_core.components.bandwidth_accounting.settings import BandwidthAccountingSettings
from tribler_core.components.bandwidth_accounting.bandwidth_accounting_component import BandwidthAccountingComponent
from tribler_core.components.ipv8.ipv8_component import Ipv8Component
from tribler_core.components.key.key_component import KeyComponent
from tribler_core.config.tribler_config import TriblerConfig
from tribler_core.modules.bandwidth_accounting.launcher import BandwidthCommunityLauncher
from tribler_core.start_core import Session


class PortAction(argparse.Action):
def __call__(self, _, namespace, values, option_string=None):
if not 0 < values < 2**16:
def __call__(self, _, namespace, value, option_string=None):
if not 0 < value < 2 ** 16:
raise argparse.ArgumentError(self, "Invalid port number")
setattr(namespace, self.dest, values)


class BandwidthCommunityCrawlerLauncher(BandwidthCommunityLauncher):

def get_kwargs(self, session):
settings = BandwidthAccountingSettings()
settings.outgoing_query_interval = 5
database = BandwidthDatabase(session.config.state_dir / STATEDIR_DB_DIR / "bandwidth.db",
session.trustchain_keypair.pub().key_to_bin(), store_all_transactions=True)
setattr(namespace, self.dest, value)

return {
"database": database,
"settings": settings,
"max_peers": -1
}


async def start_crawler(tribler_config):

# We use our own community loader
loader = IPv8CommunityLoader()
loader.set_launcher(BandwidthCommunityCrawlerLauncher())
session = Session(tribler_config, community_loader=loader)

await session.start_components()
async def crawler_session(session_config: TriblerConfig):
session = Session(session_config,
[KeyComponent(), Ipv8Component(), BandwidthAccountingComponent(crawler_mode=True)])
signal.signal(signal.SIGTERM, lambda signum, stack: session.shutdown_event.set)
async with session.start():
await session.shutdown_event.wait()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description=('Start a crawler in the bandwidth accounting community'))
parser.add_argument('--statedir', '-s', default='bw_crawler', type=str, help='Use an alternate statedir')
parser.add_argument('--restapi', '-p', default=52194, type=str, help='Use an alternate port for the REST API',
parser.add_argument('--restapi', '-p', default=52194, type=int, help='Use an alternate port for the REST API',
action=PortAction, metavar='{0..65535}')
args = parser.parse_args(sys.argv[1:])

logging.basicConfig(level=logging.INFO)

state_dir = Path(args.statedir).absolute()
config = TriblerConfig.load(file=state_dir / 'triblerd.conf', state_dir=state_dir)

Expand All @@ -66,8 +49,8 @@ async def start_crawler(tribler_config):
config.torrent_checking.enabled = False
config.api.http_enabled = True
config.api.http_port = args.restapi
config.bandwidth_accounting.outgoing_query_interval = 5

loop = get_event_loop()
coro = start_crawler(config)
ensure_future(coro)
ensure_future(crawler_session(config))
loop.run_forever()
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ class BandwidthAccountingComponent(Component):
_ipv8_component: Ipv8Component = None
database: BandwidthDatabase = None

def __init__(self, crawler_mode=False):
super().__init__()
self.crawler_mode = crawler_mode

async def run(self):
await super().run()
self._ipv8_component = await self.require_component(Ipv8Component)
Expand All @@ -24,14 +28,25 @@ async def run(self):
else:
bandwidth_cls = BandwidthAccountingCommunity

if self.crawler_mode:
store_all_transactions = True
unlimited_peers = True
else:
store_all_transactions = False
unlimited_peers = False

db_name = "bandwidth_gui_test.db" if config.gui_test_mode else f"{bandwidth_cls.DB_NAME}.db"
database_path = config.state_dir / STATEDIR_DB_DIR / db_name
self.database = BandwidthDatabase(database_path, self._ipv8_component.peer.public_key.key_to_bin())
self.database = BandwidthDatabase(database_path, self._ipv8_component.peer.public_key.key_to_bin(),
store_all_transactions=store_all_transactions)

kwargs = {"max_peers": -1} if unlimited_peers else {}
self.community = bandwidth_cls(self._ipv8_component.peer,
self._ipv8_component.ipv8.endpoint,
self._ipv8_component.ipv8.network,
settings=config.bandwidth_accounting,
database=self.database)
database=self.database,
**kwargs)

self._ipv8_component.initialise_community_by_default(self.community)

Expand Down

0 comments on commit aced845

Please sign in to comment.