-
Notifications
You must be signed in to change notification settings - Fork 452
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Optimized rendezvous storage and hook
- Loading branch information
Showing
20 changed files
with
267 additions
and
435 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
39 changes: 39 additions & 0 deletions
39
src/tribler/core/components/ipv8/rendezvous/db/database.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from __future__ import annotations | ||
|
||
from pathlib import Path | ||
from typing import TYPE_CHECKING, Union | ||
|
||
from pony.orm import Database, db_session, select | ||
|
||
from ipv8.peer import Peer | ||
from tribler.core.components.ipv8.rendezvous.db.orm_bindings import certificate | ||
from tribler.core.utilities.utilities import MEMORY_DB | ||
|
||
if TYPE_CHECKING: | ||
from tribler.core.components.ipv8.rendezvous.db.orm_bindings.certificate import RendezvousCertificate | ||
|
||
|
||
class RendezvousDatabase: | ||
|
||
def __init__(self, db_path: Union[Path, type(MEMORY_DB)]) -> None: | ||
create_db = db_path is MEMORY_DB or not db_path.is_file() | ||
db_path_string = ":memory:" if db_path is MEMORY_DB else str(db_path) | ||
|
||
self.database = Database() | ||
self.Certificate = certificate.define_binding(self.database) | ||
self.database.bind(provider='sqlite', filename=db_path_string, create_db=create_db, timeout=120.0) | ||
self.database.generate_mapping(create_tables=create_db) | ||
|
||
def add(self, peer: Peer, start_timestamp: float, stop_timestamp: float) -> None: | ||
with db_session(immediate=True): | ||
self.Certificate(public_key=peer.public_key.key_to_bin(), | ||
start=start_timestamp, | ||
stop=stop_timestamp) | ||
|
||
def get(self, peer: Peer) -> list[RendezvousCertificate]: | ||
with db_session(): | ||
return select(certificate for certificate in self.Certificate | ||
if certificate.public_key == peer.public_key.key_to_bin()).fetch() | ||
|
||
def shutdown(self) -> None: | ||
self.database.disconnect() |
File renamed without changes.
20 changes: 20 additions & 0 deletions
20
src/tribler/core/components/ipv8/rendezvous/db/orm_bindings/certificate.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import dataclasses | ||
from typing import TYPE_CHECKING | ||
|
||
from pony.orm import Required | ||
|
||
if TYPE_CHECKING: | ||
@dataclasses.dataclass | ||
class RendezvousCertificate: | ||
public_key: bytes | ||
start: float | ||
stop: float | ||
|
||
|
||
def define_binding(db): | ||
class RendezvousCertificate(db.Entity): | ||
public_key = Required(bytes, index=True) | ||
start = Required(float) | ||
stop = Required(float) | ||
|
||
return RendezvousCertificate |
File renamed without changes.
71 changes: 71 additions & 0 deletions
71
src/tribler/core/components/ipv8/rendezvous/db/tests/test_database.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from typing import Generator | ||
|
||
import pytest | ||
|
||
from ipv8.keyvault.crypto import default_eccrypto | ||
from ipv8.peer import Peer | ||
from tribler.core.components.ipv8.rendezvous.db.database import RendezvousDatabase | ||
from tribler.core.utilities.utilities import MEMORY_DB | ||
|
||
|
||
@pytest.fixture(name="memdb", scope="function") | ||
def fixture_memory_database() -> Generator[RendezvousDatabase, None, None]: | ||
db = RendezvousDatabase(MEMORY_DB) | ||
|
||
yield db | ||
|
||
db.shutdown() | ||
|
||
|
||
def generate_peer() -> Peer: | ||
public_key = default_eccrypto.generate_key("curve25519").pub() | ||
return Peer(public_key) | ||
|
||
|
||
@pytest.fixture(name="peer", scope="module") | ||
def fixture_peer() -> Generator[Peer, None, None]: | ||
yield generate_peer() | ||
|
||
|
||
@pytest.fixture(name="peer2", scope="function") | ||
def fixture_peer2() -> Generator[Peer, None, None]: | ||
yield generate_peer() | ||
|
||
|
||
def test_retrieve_no_certificates(peer: Peer, memdb: RendezvousDatabase) -> None: | ||
retrieved = memdb.get(peer) | ||
|
||
assert len(retrieved) == 0 | ||
|
||
|
||
def test_retrieve_single_certificate(peer: Peer, memdb: RendezvousDatabase) -> None: | ||
start_timestamp, stop_timestamp = range(1, 3) | ||
memdb.add(peer, start_timestamp, stop_timestamp) | ||
|
||
retrieved = memdb.get(peer) | ||
|
||
assert len(retrieved) == 1 | ||
assert retrieved[0].start, retrieved[0].stop == (start_timestamp, stop_timestamp) | ||
|
||
|
||
def test_retrieve_multiple_certificates(peer: Peer, memdb: RendezvousDatabase) -> None: | ||
start_timestamp1, stop_timestamp1, start_timestamp2, stop_timestamp2 = range(1, 5) | ||
memdb.add(peer, start_timestamp1, stop_timestamp1) | ||
memdb.add(peer, start_timestamp2, stop_timestamp2) | ||
|
||
retrieved = memdb.get(peer) | ||
|
||
assert len(retrieved) == 2 | ||
assert retrieved[0].start, retrieved[0].stop == (start_timestamp1, stop_timestamp1) | ||
assert retrieved[1].start, retrieved[1].stop == (start_timestamp2, stop_timestamp2) | ||
|
||
|
||
def test_retrieve_filter_certificates(peer: Peer, peer2: Peer, memdb: RendezvousDatabase) -> None: | ||
start_timestamp1, stop_timestamp1, start_timestamp2, stop_timestamp2 = range(1, 5) | ||
memdb.add(peer, start_timestamp1, stop_timestamp1) | ||
memdb.add(peer2, start_timestamp2, stop_timestamp2) | ||
|
||
retrieved = memdb.get(peer) | ||
|
||
assert len(retrieved) == 1 | ||
assert retrieved[0].start, retrieved[0].stop == (start_timestamp1, stop_timestamp1) |
31 changes: 31 additions & 0 deletions
31
src/tribler/core/components/ipv8/rendezvous/rendezvous_hook.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import logging | ||
import time | ||
|
||
from ipv8.peerdiscovery.network import Network, PeerObserver | ||
from ipv8.types import Peer | ||
from tribler.core.components.ipv8.rendezvous.db.database import RendezvousDatabase | ||
|
||
|
||
class RendezvousHook(PeerObserver): | ||
|
||
def __init__(self, rendezvous_db: RendezvousDatabase) -> None: | ||
self.rendezvous_db = rendezvous_db | ||
|
||
def shutdown(self, network: Network) -> None: | ||
for peer in network.verified_peers: | ||
self.on_peer_removed(peer) | ||
if self.rendezvous_db: | ||
self.rendezvous_db.shutdown() | ||
|
||
@property | ||
def current_time(self) -> float: | ||
return time.time() | ||
|
||
def on_peer_added(self, peer: Peer) -> None: | ||
pass | ||
|
||
def on_peer_removed(self, peer: Peer) -> None: | ||
if self.current_time >= peer.creation_time: | ||
self.rendezvous_db.add(peer, peer.creation_time, self.current_time) | ||
else: | ||
logging.exception("%s was first seen in the future! Something is seriously wrong!", peer) |
Empty file.
90 changes: 90 additions & 0 deletions
90
src/tribler/core/components/ipv8/rendezvous/tests/test_rendezvous_hook.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Generator | ||
|
||
import pytest | ||
|
||
from ipv8.keyvault.crypto import default_eccrypto | ||
from ipv8.peer import Peer | ||
from ipv8.peerdiscovery.network import Network | ||
from tribler.core.components.ipv8.rendezvous.db.database import RendezvousDatabase | ||
from tribler.core.components.ipv8.rendezvous.rendezvous_hook import RendezvousHook | ||
from tribler.core.utilities.utilities import MEMORY_DB | ||
|
||
|
||
class MockedRendezvousHook(RendezvousHook): | ||
|
||
def __init__(self, rendezvous_db: RendezvousDatabase, mocked_time: float | None = None) -> None: | ||
super().__init__(rendezvous_db) | ||
self.mocked_time = mocked_time | ||
|
||
@property | ||
def current_time(self) -> float: | ||
if self.mocked_time is None: | ||
return super().current_time | ||
return self.mocked_time | ||
|
||
|
||
@pytest.fixture(name="memdb", scope="function") | ||
def fixture_memory_database() -> Generator[RendezvousDatabase, None, None]: | ||
db = RendezvousDatabase(MEMORY_DB) | ||
|
||
yield db | ||
|
||
db.shutdown() | ||
|
||
|
||
@pytest.fixture(name="hook", scope="function") | ||
def fixture_hook(memdb: RendezvousDatabase) -> Generator[MockedRendezvousHook, None, None]: | ||
hook = MockedRendezvousHook(memdb) | ||
|
||
yield hook | ||
|
||
hook.shutdown(Network()) | ||
|
||
|
||
@pytest.fixture(name="peer", scope="module") | ||
def fixture_peer() -> Generator[Peer, None, None]: | ||
public_key = default_eccrypto.generate_key("curve25519").pub() | ||
yield Peer(public_key) | ||
|
||
|
||
def test_peer_added(peer: Peer, hook: MockedRendezvousHook, memdb: RendezvousDatabase) -> None: | ||
hook.on_peer_added(peer) | ||
|
||
retrieved = memdb.get(peer) | ||
assert len(retrieved) == 0 | ||
|
||
|
||
def test_peer_removed(peer: Peer, hook: MockedRendezvousHook, memdb: RendezvousDatabase) -> None: | ||
hook.on_peer_added(peer) | ||
|
||
hook.mocked_time = peer.creation_time + 1.0 | ||
hook.on_peer_removed(peer) | ||
|
||
retrieved = memdb.get(peer) | ||
assert len(retrieved) == 1 | ||
assert retrieved[0].start, retrieved[0].stop == (peer.creation_time, hook.mocked_time) | ||
|
||
|
||
def test_peer_store_on_shutdown(peer: Peer, hook: MockedRendezvousHook, memdb: RendezvousDatabase) -> None: | ||
network = Network() | ||
network.add_verified_peer(peer) | ||
hook.on_peer_added(peer) | ||
hook.mocked_time = peer.creation_time + 1.0 | ||
|
||
hook.shutdown(network) | ||
|
||
retrieved = memdb.get(peer) | ||
assert len(retrieved) == 1 | ||
assert retrieved[0].start, retrieved[0].stop == (peer.creation_time, hook.mocked_time) | ||
|
||
|
||
def test_peer_ignore_future(peer: Peer, hook: MockedRendezvousHook, memdb: RendezvousDatabase) -> None: | ||
hook.on_peer_added(peer) | ||
|
||
hook.mocked_time = peer.creation_time - 1.0 | ||
hook.on_peer_removed(peer) | ||
|
||
retrieved = memdb.get(peer) | ||
assert len(retrieved) == 0 |
Oops, something went wrong.