Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP Add dependency-injector to Tribler #6197

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion experiment/popularity_community/crawl_torrents.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def create_config(working_dir, config_path):
async def on_tribler_started(self):
await super().on_tribler_started()
session = self.session
peer = Peer(session.trustchain_keypair)
peer = Peer(session.trustchain_keys.keypair)

crawler_settings = SimpleNamespace(output_file_path=self._output_file_path,
peers_count_csv_file_path=self._peers_count_csv_file_path)
Expand Down
13 changes: 10 additions & 3 deletions experiment/popularity_community/initial_filling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@

from ipv8.peer import Peer
from ipv8.peerdiscovery.discovery import RandomWalk
from tribler_core import containers
from tribler_core.modules.popularity.community import PopularityCommunity
from tribler_core.modules.remote_query_community.settings import RemoteQueryCommunitySettings
from tribler_core.modules.popularity.popularity_community import PopularityCommunity
from tribler_core.session import Session
from tribler_core.utilities.tiny_tribler_service import TinyTriblerService

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -67,12 +69,17 @@ def check(self):

class Service(TinyTriblerService):
def __init__(self, interval_in_sec, output_file_path, timeout_in_sec, working_dir, config_path):
super().__init__(Service.create_config(working_dir, config_path), timeout_in_sec,
config = Service.create_config(working_dir, config_path)
super().__init__(config, timeout_in_sec,
working_dir, config_path)

self._interval_in_sec = interval_in_sec
self._output_file_path = output_file_path

self.application = containers.ApplicationContainer(state_dir=config.state_dir)
self.application.config.from_pydantic(config)
self.application.wire(modules=[Session])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A good start point to discover PR.


@staticmethod
def create_config(working_dir, config_path):
config = TinyTriblerService.create_default_config(working_dir, config_path)
Expand All @@ -85,7 +92,7 @@ async def on_tribler_started(self):
await super().on_tribler_started()

session = self.session
peer = Peer(session.trustchain_keypair)
peer = Peer(session.trustchain_keys.keypair)

session.popularity_community = ObservablePopularityCommunity(self._interval_in_sec,
self._output_file_path,
Expand Down
1 change: 1 addition & 0 deletions src/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ netifaces
pyqtgraph
yappi
pydantic
dependency-injector
11 changes: 11 additions & 0 deletions src/run_tribler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from tribler_common.sentry_reporter.sentry_reporter import SentryReporter, SentryStrategy
from tribler_common.sentry_reporter.sentry_scrubber import SentryScrubber
from tribler_common.version_manager import VersionHistory
from tribler_core import containers
from tribler_core.dependencies import check_for_missing_dependencies
from tribler_core.modules.community_loader import create_default_loader
from tribler_core.utilities.osutils import get_root_state_directory
Expand Down Expand Up @@ -92,6 +93,16 @@ async def start_tribler():
trace_logger = check_and_enable_code_tracing('core', log_dir)

community_loader = create_default_loader(config)

application = containers.ApplicationContainer(state_dir=config.state_dir)
application.config.from_pydantic(config)

if core_test_mode:
from ipv8.messaging.interfaces.dispatcher.endpoint import DispatcherEndpoint
application.ipv8_container.endpoint = DispatcherEndpoint([])

application.wire(modules=[Session])

session = Session(config, core_test_mode=core_test_mode, community_loader=community_loader)

signal.signal(signal.SIGTERM, lambda signum, stack: shutdown(session, signum, stack))
Expand Down
2 changes: 1 addition & 1 deletion src/tribler-core/run_bandwidth_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_kwargs(self, session):
settings = BandwidthAccountingSettings()
settings.outgoing_query_interval = 5
database = BandwidthDatabase(session.config.state_dir / "sqlite" / "bandwidth.db",
session.trustchain_keypair.pub().key_to_bin(), store_all_transactions=True)
session.trustchain_keys.keypair.pub().key_to_bin(), store_all_transactions=True)

return {
"database": database,
Expand Down
35 changes: 35 additions & 0 deletions src/tribler-core/tribler_core/containers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from dependency_injector import containers, providers

from ipv8.messaging.interfaces.dispatcher.endpoint import DispatcherEndpoint
from ipv8.peer import Peer
from ipv8_service import IPv8
from tribler_core.ipv8_config import Ipv8Config
from tribler_core.trustchain_keys import TrustChainKeys


class Ipv8Container(containers.DeclarativeContainer):
state_dir = providers.Dependency()
config = providers.Configuration()

ipv8_config = providers.Singleton(Ipv8Config, state_dir=state_dir, config=config.provider)

endpoint = providers.Singleton(
DispatcherEndpoint,
["UDPIPv4"],
UDPIPv4=providers.Dict(port=config.port, ip=config.address),

)

ipv8 = providers.Singleton(
IPv8, ipv8_config.provided.value, enable_statistics=config.statistics, endpoint_override=endpoint
)


class ApplicationContainer(containers.DeclarativeContainer):
state_dir = providers.Dependency()
config = providers.Configuration()

trustchain_keys = providers.Singleton(TrustChainKeys, state_dir=state_dir, config=config.provider)
peer = providers.Singleton(Peer, trustchain_keys.provided.trustchain_keypair)

ipv8_container = providers.Container(Ipv8Container, state_dir=state_dir, config=config.ipv8)
15 changes: 15 additions & 0 deletions src/tribler-core/tribler_core/ipv8_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from ipv8.configuration import ConfigBuilder


class Ipv8Config:
"""Extracted from session.py"""

def __init__(self, state_dir=None, config=None):
self.value = (ConfigBuilder()
.set_port(config.port())
.set_address(config.address())
.clear_overlays()
.clear_keys() # We load the keys ourselves
.set_working_directory(str(state_dir))
.set_walker_interval(config.walk_interval())
.finalize())
8 changes: 7 additions & 1 deletion src/tribler-core/tribler_core/modules/community_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# pylint: disable=import-outside-toplevel
from dependency_injector.wiring import Provide

from ipv8.loader import (
CommunityLauncher,
IPv8CommunityLoader,
Expand All @@ -9,6 +11,7 @@
walk_strategy,
)
from ipv8.peer import Peer
from tribler_core import containers

from tribler_core.config.tribler_config import TriblerConfig

Expand All @@ -22,7 +25,7 @@ class TriblerCommunityLauncher(CommunityLauncher):
def get_my_peer(self, ipv8, session):
trustchain_testnet = session.config.general.testnet or session.config.trustchain.testnet
return (Peer(session.trustchain_testnet_keypair) if trustchain_testnet
else Peer(session.trustchain_keypair))
else Peer(session.trustchain_keys.keypair))

def get_bootstrappers(self, session):
from ipv8.bootstrapping.dispersy.bootstrapper import DispersyBootstrapper
Expand Down Expand Up @@ -89,6 +92,9 @@ class IPv8DiscoveryCommunityLauncher(TriblerCommunityLauncher):
class DHTCommunityLauncher(TriblerCommunityLauncher):
pass

# def load_communities(ipv8 = Provide[containers.Ipv8Container.ipv8], peer = Provide[containers.Ipv8Container.peer]):
#
# pass

def create_default_loader(config: TriblerConfig):
loader = IPv8CommunityLoader()
Expand Down
5 changes: 5 additions & 0 deletions src/tribler-core/tribler_core/modules/container.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from dependency_injector import containers


class CommunityContainer(containers.DeclarativeContainer):
...
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def create_session(self, hops=0, store_listen_port=True):
pe_settings.prefer_rc4 = True
ltsession.set_pe_settings(pe_settings)

mid = self.tribler_session.trustchain_keypair.key_to_hash()
mid = self.tribler_session.trustchain_keys.keypair.key_to_hash()
settings['peer_fingerprint'] = mid
settings['handshake_client_version'] = 'Tribler/' + version_id + '/' + hexlify(mid)
else:
Expand Down
27 changes: 27 additions & 0 deletions src/tribler-core/tribler_core/modules/popularity/containers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from dependency_injector import providers

from ipv8.peerdiscovery.discovery import RandomWalk
from tribler_core.modules import container
from tribler_core.modules.metadata_store.community.sync_strategy import RemovePeers
from tribler_core.modules.popularity.community import PopularityCommunity


class PopularityCommunityContainer(container.CommunityContainer):
config = providers.Configuration()
rqc_config = providers.Configuration()

metadata_store = providers.Dependency()
torrent_checker = providers.Dependency()

peer = providers.Dependency()
endpoint = providers.Dependency()
network = providers.Dependency()

community = providers.Factory(PopularityCommunity, peer, endpoint, network,
settings=config, rqc_settings=rqc_config,
metadata_store=metadata_store, torrent_checker=torrent_checker)

strategies = providers.List(
providers.Factory(RandomWalk, community, target_peers=30),
providers.Factory(RemovePeers, community, target_peers=-1),
)
75 changes: 13 additions & 62 deletions src/tribler-core/tribler_core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
from io import StringIO
from traceback import print_exception

import tribler_core.utilities.permid as permid_module
from dependency_injector.wiring import Provide, inject

from ipv8.loader import IPv8CommunityLoader
from ipv8.taskmanager import TaskManager
from ipv8.types import Peer
from ipv8_service import IPv8
from tribler_common.network_utils import NetworkUtils
from tribler_common.sentry_reporter.sentry_reporter import SentryReporter
Expand All @@ -32,7 +34,9 @@
STATE_START_WATCH_FOLDER,
STATE_UPGRADING_READABLE,
)
from tribler_core import containers
from tribler_core.config.tribler_config import TriblerConfig
from tribler_core.containers import TrustChainKeys
from tribler_core.modules.community_loader import create_default_loader
from tribler_core.modules.metadata_store.utils import generate_test_channels
from tribler_core.modules.tracker_manager import TrackerManager
Expand Down Expand Up @@ -70,6 +74,10 @@ class Session(TaskManager):
"""
__single = None

trustchain_keys: TrustChainKeys = Provide[containers.ApplicationContainer.trustchain_keys]
# peer: Peer = Provide[containers.ApplicationContainer.peer]
ipv8: IPv8 = Provide[containers.ApplicationContainer.ipv8_container.ipv8]

def __init__(self, config: TriblerConfig, core_test_mode: bool = False,
community_loader: IPv8CommunityLoader = None):
"""
Expand All @@ -93,7 +101,6 @@ def __init__(self, config: TriblerConfig, core_test_mode: bool = False,
self.upgrader = None
self.readable_status = '' # Human-readable string to indicate the status during startup/shutdown of Tribler

self.ipv8 = None
self.ipv8_start_time = 0

self._logger = logging.getLogger(self.__class__.__name__)
Expand Down Expand Up @@ -121,8 +128,6 @@ def __init__(self, config: TriblerConfig, core_test_mode: bool = False,
self.wallets = {}
self.popularity_community = None
self.gigachannel_community = None
self.trustchain_keypair = None
self.trustchain_testnet_keypair = None

self.dht_community = None
self.payout_manager = None
Expand Down Expand Up @@ -169,34 +174,6 @@ def create_in_state_dir(path):
create_in_state_dir(STATEDIR_CHECKPOINT_DIR)
create_in_state_dir(STATEDIR_CHANNELS_DIR)

def init_keypair(self):
"""
Set parameters that depend on state_dir.
"""
keypair_filename = self.config.trustchain.get_path_as_absolute('ec_keypair_filename', self.config.state_dir)
state_dir = self.config.state_dir
if keypair_filename.exists():
self.trustchain_keypair = permid_module.read_keypair_trustchain(keypair_filename)
else:
self.trustchain_keypair = permid_module.generate_keypair_trustchain()

# Save keypair
trustchain_pubfilename = state_dir / 'ecpub_multichain.pem'
permid_module.save_keypair_trustchain(self.trustchain_keypair, keypair_filename)
permid_module.save_pub_key_trustchain(self.trustchain_keypair, trustchain_pubfilename)

testnet_keypair_filename = self.config.trustchain.get_path_as_absolute('testnet_keypair_filename',
self.config.state_dir)
if testnet_keypair_filename.exists():
self.trustchain_testnet_keypair = permid_module.read_keypair_trustchain(testnet_keypair_filename)
else:
self.trustchain_testnet_keypair = permid_module.generate_keypair_trustchain()

# Save keypair
trustchain_testnet_pubfilename = state_dir / 'ecpub_trustchain_testnet.pem'
permid_module.save_keypair_trustchain(self.trustchain_testnet_keypair, testnet_keypair_filename)
permid_module.save_pub_key_trustchain(self.trustchain_testnet_keypair, trustchain_testnet_pubfilename)

def unhandled_error_observer(self, loop, context):
"""
This method is called when an unhandled error in Tribler is observed.
Expand Down Expand Up @@ -261,14 +238,14 @@ async def start(self):

:param config: a TriblerConfig object
"""

state_dir = self.config.state_dir
self._logger.info("Session is using state directory: %s", state_dir)
self.create_state_directory_structure()
self.init_keypair()

# we have to represent `user_id` as a string to make it equal to the
# `user_id` on the GUI side
user_id_str = hexlify(self.trustchain_keypair.key.pk).encode('utf-8')
user_id_str = hexlify(self.trustchain_keys.keypair.key.pk).encode('utf-8')
SentryReporter.set_user(user_id_str)

# Start the REST API before the upgrader since we want to send interesting upgrader events over the socket
Expand Down Expand Up @@ -305,39 +282,14 @@ async def start(self):
chant_testnet = self.config.general.testnet or self.config.chant.testnet
metadata_db_name = 'metadata.db' if not chant_testnet else 'metadata_testnet.db'
database_path = state_dir / 'sqlite' / metadata_db_name
self.mds = MetadataStore(database_path, channels_dir, self.trustchain_keypair,
self.mds = MetadataStore(database_path, channels_dir, self.trustchain_keys.keypair,
notifier=self.notifier,
disable_sync=self.core_test_mode)
if self.core_test_mode:
generate_test_channels(self.mds)

# IPv8
if self.config.ipv8.enabled:
from ipv8.configuration import ConfigBuilder
from ipv8.messaging.interfaces.dispatcher.endpoint import DispatcherEndpoint
port = self.config.ipv8.port
address = self.config.ipv8.address
self._logger.info('Starting ipv8')
self._logger.info(f'Port: {port}. Address: {address}')
ipv8_config_builder = (ConfigBuilder()
.set_port(port)
.set_address(address)
.clear_overlays()
.clear_keys() # We load the keys ourselves
.set_working_directory(str(state_dir))
.set_walker_interval(self.config.ipv8.walk_interval))

if self.core_test_mode:
endpoint = DispatcherEndpoint([])
else:
# IPv8 includes IPv6 support by default.
# We only load IPv4 to not kill all Tribler overlays (currently, it would instantly crash all users).
# If you want to test IPv6 in Tribler you can set ``endpoint = None`` here.
endpoint = DispatcherEndpoint(["UDPIPv4"], UDPIPv4={'port': port,
'ip': address})
self.ipv8 = IPv8(ipv8_config_builder.finalize(),
enable_statistics=self.config.ipv8.statistics and not self.core_test_mode,
endpoint_override=endpoint)
await self.ipv8.start()

anon_proxy_ports = self.config.tunnel_community.socks5_listen_ports
Expand Down Expand Up @@ -412,7 +364,7 @@ async def start(self):
if self.config.bootstrap.enabled and not self.core_test_mode:
self.register_task('bootstrap_download', self.start_bootstrap_download)

self.notifier.notify(NTFY.TRIBLER_STARTED, self.trustchain_keypair.key.pk)
self.notifier.notify(NTFY.TRIBLER_STARTED, self.trustchain_keys.keypair.key.pk)

# If there is a config error, report to the user via GUI notifier
if self.config.error:
Expand Down Expand Up @@ -475,7 +427,6 @@ async def shutdown(self):
if self.ipv8:
self.notify_shutdown_state("Shutting down IPv8...")
await self.ipv8.stop(stop_loop=False)
self.ipv8 = None

if self.payout_manager:
await self.payout_manager.shutdown()
Expand Down
Loading