Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow optional dependencies between components #6291

Merged
merged 7 commits into from
Aug 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions src/tribler-core/tribler_core/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,6 @@ async def start(self, failfast=True):
async def shutdown(self):
await gather(*[create_task(component.stop()) for component in self.components.values()])

def get(self, interface: Type[T]) -> T:
imp = self.components.get(interface)
if imp is None:
raise ComponentError(f"{interface.__name__} implementation not found in {self}")
return imp

def __enter__(self):
Session._stack.append(self)

Expand Down Expand Up @@ -140,19 +134,28 @@ def make_implementation(cls: Type[T], config, enable) -> T:
assert False, f"Abstract classmethod make_implementation not implemented in class {cls.__name__}"

@classmethod
def _find_implementation(cls: Type[T]) -> T:
def _find_implementation(cls: Type[T], required=True) -> T:
session = Session.current()
return session.get(cls)
imp = session.components.get(cls)
if imp is None:
if required:
raise ComponentError(f"{cls.__name__} implementation not found in {session}")
imp = cls.make_implementation(session.config, enable=False) # dummy implementation
session.register(cls, imp)
imp.started.set()
return imp

@classmethod
def imp(cls: Type[T]) -> T:
return cls._find_implementation()
def imp(cls: Type[T], required=True) -> T:
return cls._find_implementation(required=required)

async def start(self):
try:
await self.run()
except Exception as e:
print(f'\n*** Exception in {self.__class__.__name__}.start(): {type(e).__name__}:{e}\n')
# Writing to stderr is for the case when logger is not configured properly (as my happen in local tests,
# for example) to avoid silent suppression of the important exceptions
sys.stderr.write(f'\nException in {self.__class__.__name__}.start(): {type(e).__name__}:{e}\n')
self.logger.exception(f'Exception in {self.__class__.__name__}.start(): {type(e).__name__}:{e}')
self.failed = True
self.started.set()
Expand All @@ -175,8 +178,8 @@ async def run(self):
async def shutdown(self):
pass

async def use(self, dependency: Type[T]) -> T:
dep = dependency.imp()
async def use(self, dependency: Type[T], required=True) -> T:
dep = dependency.imp(required=required)
await dep.started.wait()
if dep.failed:
raise ComponentError(f'Component {self.__class__.__name__} has failed dependency {dep.__class__.__name__}')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ class BandwidthAccountingComponentImp(BandwidthAccountingComponent):
rest_manager: RESTManager

async def run(self):
await self.use(ReporterComponent)
await self.use(UpgradeComponent)
await self.use(ReporterComponent, required=False)
await self.use(UpgradeComponent, required=False)
config = self.session.config

ipv8_component = await self.use(Ipv8Component)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class GigaChannelComponentImp(GigaChannelComponent):
rest_manager: RESTManager

async def run(self):
await self.use(ReporterComponent)
await self.use(ReporterComponent, required=False)

config = self.session.config
notifier = self.session.notifier
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class GigachannelManagerComponentImp(GigachannelManagerComponent):
rest_manager: RESTManager

async def run(self):
await self.use(ReporterComponent)
await self.use(ReporterComponent, required=False)

config = self.session.config
notifier = self.session.notifier
Expand Down
24 changes: 15 additions & 9 deletions src/tribler-core/tribler_core/components/implementation/ipv8.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from ipv8.bootstrapping.dispersy.bootstrapper import DispersyBootstrapper
from ipv8.configuration import ConfigBuilder, DISPERSY_BOOTSTRAPPER
from ipv8.dht.churn import PingChurn
Expand All @@ -24,15 +26,16 @@

class Ipv8ComponentImp(Ipv8Component):
task_manager: TaskManager
rest_manager: RESTManager
rest_manager: Optional[RESTManager]

async def run(self):
await self.use(ReporterComponent)
await self.use(ReporterComponent, required=False)

config = self.session.config

rest_component = await self.use(RESTComponent)
rest_manager = self.rest_manager = rest_component.rest_manager
rest_component = await self.use(RESTComponent, required=False)
rest_manager = self.rest_manager = rest_component.rest_manager if rest_component.enabled else None

self.task_manager = TaskManager()

port = config.ipv8.port
Expand Down Expand Up @@ -76,7 +79,8 @@ async def run(self):
config.ipv8.walk_interval,
config.ipv8.walk_scaling_upper_limit).start(self.task_manager)

rest_manager.get_endpoint('statistics').ipv8 = ipv8
if rest_manager:
rest_manager.get_endpoint('statistics').ipv8 = ipv8

self.peer_discovery_community = self.dht_discovery_community = None

Expand All @@ -93,9 +97,10 @@ async def run(self):
endpoints_to_init = ['/asyncio', '/attestation', '/dht', '/identity',
'/isolation', '/network', '/noblockdht', '/overlays']

for path, endpoint in rest_manager.get_endpoint('ipv8').endpoints.items():
if path in endpoints_to_init:
endpoint.initialize(ipv8)
if rest_manager:
for path, endpoint in rest_manager.get_endpoint('ipv8').endpoints.items():
if path in endpoints_to_init:
endpoint.initialize(ipv8)

def make_bootstrapper(self) -> DispersyBootstrapper:
args = DISPERSY_BOOTSTRAPPER['init']
Expand Down Expand Up @@ -124,7 +129,8 @@ def init_dht_discovery_community(self):
self.dht_discovery_community = community

async def shutdown(self):
self.rest_manager.get_endpoint('statistics').ipv8 = None
if self.rest_manager:
self.rest_manager.get_endpoint('statistics').ipv8 = None
await self.release(RESTComponent)

await self.unused.wait()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class LibtorrentComponentImp(LibtorrentComponent):
rest_manager: RESTManager

async def run(self):
await self.use(ReporterComponent)
await self.use(UpgradeComponent)
await self.use(ReporterComponent, required=False)
await self.use(UpgradeComponent, required=False)
socks_ports = (await self.use(SocksServersComponent)).socks_ports
masterkey = await self.use(MasterKeyComponent)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ class MetadataStoreComponentImp(MetadataStoreComponent):
rest_manager: RESTManager

async def run(self):
await self.use(ReporterComponent)
await self.use(UpgradeComponent)
await self.use(ReporterComponent, required=False)
await self.use(UpgradeComponent, required=False)
rest_manager = self.rest_manager = (await self.use(RESTComponent)).rest_manager

config = self.session.config
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from ipv8.dht.routing import RoutingTable
from ipv8.messaging.interfaces.udp.endpoint import UDPv4Address

from tribler_common.simpledefs import NTFY

from tribler_core.components.interfaces.bandwidth_accounting import BandwidthAccountingComponent
Expand All @@ -14,7 +11,7 @@

class PayoutComponentImp(PayoutComponent):
async def run(self):
await self.use(ReporterComponent)
await self.use(ReporterComponent, required=False)

config = self.session.config

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

class PopularityComponentImp(PopularityComponent):
async def run(self):
await self.use(ReporterComponent)
await self.use(ReporterComponent, required=False)

config = self.session.config
ipv8_component = await self.use(Ipv8Component)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

class ResourceMonitorComponentImp(ResourceMonitorComponent):
async def run(self):
await self.use(ReporterComponent)
await self.use(UpgradeComponent)
await self.use(ReporterComponent, required=False)
await self.use(UpgradeComponent, required=False)
tunnel_community = (await self.use(TunnelsComponent)).community

config = self.session.config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class RESTComponentImp(RESTComponent):

async def run(self):
await self.use(ReporterComponent)
await self.use(ReporterComponent, required=False)
session = self.session
config = session.config
notifier = session.notifier
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class SocksServersComponentImp(SocksServersComponent):
async def run(self):
await self.use(ReporterComponent)
await self.use(ReporterComponent, required=False)
self.socks_servers = []
self.socks_ports = []
# Start the SOCKS5 servers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class TorrentCheckerComponentImp(TorrentCheckerComponent):
rest_manager: RESTManager

async def run(self):
await self.use(ReporterComponent)
await self.use(ReporterComponent, required=False)

config = self.session.config

Expand Down
23 changes: 16 additions & 7 deletions src/tribler-core/tribler_core/components/implementation/tunnels.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,25 @@

class TunnelsComponentImp(TunnelsComponent):
async def run(self):
await self.use(ReporterComponent)
await self.use(ReporterComponent, required=False)

config = self.session.config
ipv8_component = await self.use(Ipv8Component)
ipv8 = ipv8_component.ipv8
peer = ipv8_component.peer
dht_discovery_community = ipv8_component.dht_discovery_community

bandwidth_community = (await self.use(BandwidthAccountingComponent)).community
download_manager = (await self.use(LibtorrentComponent)).download_manager
rest_manager = (await self.use(RESTComponent)).rest_manager
socks_servers = (await self.use(SocksServersComponent)).socks_servers
bandwidth_component = await self.use(BandwidthAccountingComponent, required=False)
bandwidth_community = bandwidth_component.community if bandwidth_component.enabled else None

download_component = await self.use(LibtorrentComponent, required=False)
download_manager = download_component.download_manager if download_component.enabled else None

rest_component = await self.use(RESTComponent, required=False)
rest_manager = rest_component.rest_manager if rest_component.enabled else None

socks_servers_component = await self.use(SocksServersComponent, required=False)
socks_servers = socks_servers_component.socks_servers if socks_servers_component.enabled else None

settings = TunnelSettings()
settings.min_circuits = config.tunnel_community.min_circuits
Expand Down Expand Up @@ -56,8 +63,10 @@ async def run(self):

self.community = community

rest_manager.get_endpoint('downloads').tunnel_community = community
rest_manager.get_endpoint('ipv8').endpoints['/tunnel'].initialize(ipv8)
if rest_component.enabled:
rest_manager.get_endpoint('ipv8').endpoints['/tunnel'].initialize(ipv8)
if download_component.enabled:
rest_manager.get_endpoint('downloads').tunnel_community = community

async def shutdown(self):
await self.community.unload()
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class UpgradeComponentImp(UpgradeComponent):
async def run(self):
await self.use(ReporterComponent)
await self.use(ReporterComponent, required=False)
config = self.session.config
notifier = self.session.notifier
masterkey = await self.use(MasterKeyComponent)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

class VersionCheckComponentImp(VersionCheckComponent):
async def run(self):
await self.use(ReporterComponent)
await self.use(UpgradeComponent)
await self.use(ReporterComponent, required=False)
await self.use(UpgradeComponent, required=False)

notifier = self.session.notifier

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class WatchFolderComponentImp(WatchFolderComponent):
async def run(self):
await self.use(ReporterComponent)
await self.use(ReporterComponent, required=False)
config = self.session.config
notifier = self.session.notifier
download_manager = (await self.use(LibtorrentComponent)).download_manager
Expand Down
Loading