Skip to content

Commit

Permalink
Optimized rendezvous storage and hook
Browse files Browse the repository at this point in the history
  • Loading branch information
qstokkink committed Oct 13, 2023
1 parent e99dba8 commit c9c5ff1
Show file tree
Hide file tree
Showing 21 changed files with 268 additions and 436 deletions.
2 changes: 1 addition & 1 deletion doc/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ sphinxcontrib-openapi==0.7.0
configobj==5.0.6
mistune==0.8.4 # sphinxcontrib-openapi==0.7.0 cannot work with the latest mistune version (2.0.0)
MarkupSafe==2.0.1 # used by jinja2; 2.1.0 version removes soft_unicode and breaks jinja2-2.11.3
pyipv8==2.8.0
pyipv8==2.11.0
setuptools>=65.5.1 # not directly required, pinned by Snyk to avoid a vulnerability
13 changes: 13 additions & 0 deletions src/tribler/core/components/ipv8/ipv8_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
from ipv8_service import IPv8

from tribler.core.components.component import Component
from tribler.core.components.ipv8.rendezvous.db.database import RendezvousDatabase
from tribler.core.components.ipv8.rendezvous.rendezvous_hook import RendezvousHook
from tribler.core.components.key.key_component import KeyComponent
from tribler.core.utilities.simpledefs import STATEDIR_DB_DIR

INFINITE = -1

Expand All @@ -30,13 +33,21 @@ class Ipv8Component(Component):
_task_manager: TaskManager
_peer_discovery_community: Optional[DiscoveryCommunity] = None

RENDEZVOUS_DB_NAME = 'rendezvous.db'
rendezvous_db: RendezvousDatabase
rendevous_hook: RendezvousHook

async def run(self):
await super().run()

config = self.session.config

self._task_manager = TaskManager()

self.rendezvous_db = RendezvousDatabase(
db_path=self.session.config.state_dir / STATEDIR_DB_DIR / self.RENDEZVOUS_DB_NAME)
self.rendevous_hook = RendezvousHook(self.rendezvous_db)

port = config.ipv8.port
address = config.ipv8.address
self.logger.info('Starting ipv8')
Expand All @@ -60,6 +71,7 @@ async def run(self):
ipv8 = IPv8(ipv8_config_builder.finalize(),
enable_statistics=config.ipv8.statistics and not config.gui_test_mode,
endpoint_override=endpoint)
ipv8.network.add_peer_observer(self.rendevous_hook)
await ipv8.start()
self.ipv8 = ipv8

Expand Down Expand Up @@ -135,5 +147,6 @@ async def shutdown(self):
if overlay:
await self.ipv8.unload_overlay(overlay)

self.rendevous_hook.shutdown(self.ipv8.network)
await self._task_manager.shutdown_task_manager()
await self.ipv8.stop(stop_loop=False)
39 changes: 39 additions & 0 deletions src/tribler/core/components/ipv8/rendezvous/db/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Union

from pony.orm import Database, db_session, select

from ipv8.peer import Peer
from tribler.core.components.ipv8.rendezvous.db.orm_bindings import certificate
from tribler.core.utilities.utilities import MEMORY_DB

if TYPE_CHECKING:
from tribler.core.components.ipv8.rendezvous.db.orm_bindings.certificate import RendezvousCertificate


class RendezvousDatabase:

def __init__(self, db_path: Union[Path, type(MEMORY_DB)]) -> None:
create_db = db_path is MEMORY_DB or not db_path.is_file()
db_path_string = ":memory:" if db_path is MEMORY_DB else str(db_path)

self.database = Database()
self.Certificate = certificate.define_binding(self.database)
self.database.bind(provider='sqlite', filename=db_path_string, create_db=create_db, timeout=120.0)
self.database.generate_mapping(create_tables=create_db)

def add(self, peer: Peer, start_timestamp: float, stop_timestamp: float) -> None:
with db_session(immediate=True):
self.Certificate(public_key=peer.public_key.key_to_bin(),
start=start_timestamp,
stop=stop_timestamp)

def get(self, peer: Peer) -> list[RendezvousCertificate]:
with db_session():
return select(certificate for certificate in self.Certificate
if certificate.public_key == peer.public_key.key_to_bin()).fetch()

def shutdown(self) -> None:
self.database.disconnect()
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import dataclasses
from typing import TYPE_CHECKING

from pony.orm import Required

if TYPE_CHECKING:
@dataclasses.dataclass
class RendezvousCertificate:
public_key: bytes
start: float
stop: float


def define_binding(db):
class RendezvousCertificate(db.Entity):
public_key = Required(bytes, index=True)
start = Required(float)
stop = Required(float)

return RendezvousCertificate
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from typing import Generator

import pytest

from ipv8.keyvault.crypto import default_eccrypto
from ipv8.peer import Peer
from tribler.core.components.ipv8.rendezvous.db.database import RendezvousDatabase
from tribler.core.utilities.utilities import MEMORY_DB


@pytest.fixture(name="memdb", scope="function")
def fixture_memory_database() -> Generator[RendezvousDatabase, None, None]:
db = RendezvousDatabase(MEMORY_DB)

yield db

db.shutdown()


def generate_peer() -> Peer:
public_key = default_eccrypto.generate_key("curve25519").pub()
return Peer(public_key)


@pytest.fixture(name="peer", scope="module")
def fixture_peer() -> Generator[Peer, None, None]:
yield generate_peer()


@pytest.fixture(name="peer2", scope="function")
def fixture_peer2() -> Generator[Peer, None, None]:
yield generate_peer()


def test_retrieve_no_certificates(peer: Peer, memdb: RendezvousDatabase) -> None:
retrieved = memdb.get(peer)

assert len(retrieved) == 0


def test_retrieve_single_certificate(peer: Peer, memdb: RendezvousDatabase) -> None:
start_timestamp, stop_timestamp = range(1, 3)
memdb.add(peer, start_timestamp, stop_timestamp)

retrieved = memdb.get(peer)

assert len(retrieved) == 1
assert retrieved[0].start, retrieved[0].stop == (start_timestamp, stop_timestamp)


def test_retrieve_multiple_certificates(peer: Peer, memdb: RendezvousDatabase) -> None:
start_timestamp1, stop_timestamp1, start_timestamp2, stop_timestamp2 = range(1, 5)
memdb.add(peer, start_timestamp1, stop_timestamp1)
memdb.add(peer, start_timestamp2, stop_timestamp2)

retrieved = memdb.get(peer)

assert len(retrieved) == 2
assert retrieved[0].start, retrieved[0].stop == (start_timestamp1, stop_timestamp1)
assert retrieved[1].start, retrieved[1].stop == (start_timestamp2, stop_timestamp2)


def test_retrieve_filter_certificates(peer: Peer, peer2: Peer, memdb: RendezvousDatabase) -> None:
start_timestamp1, stop_timestamp1, start_timestamp2, stop_timestamp2 = range(1, 5)
memdb.add(peer, start_timestamp1, stop_timestamp1)
memdb.add(peer2, start_timestamp2, stop_timestamp2)

retrieved = memdb.get(peer)

assert len(retrieved) == 1
assert retrieved[0].start, retrieved[0].stop == (start_timestamp1, stop_timestamp1)
31 changes: 31 additions & 0 deletions src/tribler/core/components/ipv8/rendezvous/rendezvous_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import logging
import time

from ipv8.peerdiscovery.network import Network, PeerObserver
from ipv8.types import Peer
from tribler.core.components.ipv8.rendezvous.db.database import RendezvousDatabase


class RendezvousHook(PeerObserver):

def __init__(self, rendezvous_db: RendezvousDatabase) -> None:
self.rendezvous_db = rendezvous_db

def shutdown(self, network: Network) -> None:
for peer in network.verified_peers:
self.on_peer_removed(peer)
if self.rendezvous_db:
self.rendezvous_db.shutdown()

@property
def current_time(self) -> float:
return time.time()

def on_peer_added(self, peer: Peer) -> None:
pass

def on_peer_removed(self, peer: Peer) -> None:
if self.current_time >= peer.creation_time:
self.rendezvous_db.add(peer, peer.creation_time, self.current_time)
else:
logging.exception("%s was first seen in the future! Something is seriously wrong!", peer)
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from __future__ import annotations

from typing import Generator

import pytest

from ipv8.keyvault.crypto import default_eccrypto
from ipv8.peer import Peer
from ipv8.peerdiscovery.network import Network
from tribler.core.components.ipv8.rendezvous.db.database import RendezvousDatabase
from tribler.core.components.ipv8.rendezvous.rendezvous_hook import RendezvousHook
from tribler.core.utilities.utilities import MEMORY_DB


class MockedRendezvousHook(RendezvousHook):

def __init__(self, rendezvous_db: RendezvousDatabase, mocked_time: float | None = None) -> None:
super().__init__(rendezvous_db)
self.mocked_time = mocked_time

@property
def current_time(self) -> float:
if self.mocked_time is None:
return super().current_time
return self.mocked_time


@pytest.fixture(name="memdb", scope="function")
def fixture_memory_database() -> Generator[RendezvousDatabase, None, None]:
db = RendezvousDatabase(MEMORY_DB)

yield db

db.shutdown()


@pytest.fixture(name="hook", scope="function")
def fixture_hook(memdb: RendezvousDatabase) -> Generator[MockedRendezvousHook, None, None]:
hook = MockedRendezvousHook(memdb)

yield hook

hook.shutdown(Network())


@pytest.fixture(name="peer", scope="module")
def fixture_peer() -> Generator[Peer, None, None]:
public_key = default_eccrypto.generate_key("curve25519").pub()
yield Peer(public_key)


def test_peer_added(peer: Peer, hook: MockedRendezvousHook, memdb: RendezvousDatabase) -> None:
hook.on_peer_added(peer)

retrieved = memdb.get(peer)
assert len(retrieved) == 0


def test_peer_removed(peer: Peer, hook: MockedRendezvousHook, memdb: RendezvousDatabase) -> None:
hook.on_peer_added(peer)

hook.mocked_time = peer.creation_time + 1.0
hook.on_peer_removed(peer)

retrieved = memdb.get(peer)
assert len(retrieved) == 1
assert retrieved[0].start, retrieved[0].stop == (peer.creation_time, hook.mocked_time)


def test_peer_store_on_shutdown(peer: Peer, hook: MockedRendezvousHook, memdb: RendezvousDatabase) -> None:
network = Network()
network.add_verified_peer(peer)
hook.on_peer_added(peer)
hook.mocked_time = peer.creation_time + 1.0

hook.shutdown(network)

retrieved = memdb.get(peer)
assert len(retrieved) == 1
assert retrieved[0].start, retrieved[0].stop == (peer.creation_time, hook.mocked_time)


def test_peer_ignore_future(peer: Peer, hook: MockedRendezvousHook, memdb: RendezvousDatabase) -> None:
hook.on_peer_added(peer)

hook.mocked_time = peer.creation_time - 1.0
hook.on_peer_removed(peer)

retrieved = memdb.get(peer)
assert len(retrieved) == 0
Loading

0 comments on commit c9c5ff1

Please sign in to comment.