Add tracking of peers subscribed to channels (#6034)
ichorid authored Mar 31, 2021
1 parent 52b29cb commit e1e2fab
Showing 9 changed files with 483 additions and 25 deletions.
tribler_core/modules/metadata_store/community/gigachannel_community.py
@@ -1,9 +1,11 @@
import uuid
from binascii import unhexlify
from collections import defaultdict
from dataclasses import dataclass
from random import sample

from ipv8.peerdiscovery.network import Network
from ipv8.types import Peer

from pony.orm import db_session

@@ -22,12 +24,55 @@
max_entries = maximum_payload_size // minimal_blob_size
max_search_peers = 5


MAGIC_GIGACHAN_VERSION_MARK = b'\x01'


@dataclass
class ChannelEntry:
timestamp: float
channel_version: int


class ChannelsPeersMapping:
def __init__(self, max_peers_per_channel=10):
self.max_peers_per_channel = max_peers_per_channel
self._channels_dict = defaultdict(set)
# Reverse mapping from peers to channels
self._peers_channels = defaultdict(set)

def add(self, peer: Peer, channel_pk: bytes, channel_id: int):
id_tuple = (channel_pk, channel_id)
channel_peers = self._channels_dict[id_tuple]

channel_peers.add(peer)
self._peers_channels[peer].add(id_tuple)

if len(channel_peers) > self.max_peers_per_channel:
removed_peer = min(channel_peers, key=lambda x: x.last_response)
channel_peers.remove(removed_peer)
# Maintain the reverse mapping
self._peers_channels[removed_peer].remove(id_tuple)
if not self._peers_channels[removed_peer]:
self._peers_channels.pop(removed_peer)

    def remove_peer(self, peer):
        for id_tuple in self._peers_channels[peer]:
            # Drop only this peer; forget the channel entry when no peers remain
            self._channels_dict[id_tuple].discard(peer)
            if not self._channels_dict[id_tuple]:
                self._channels_dict.pop(id_tuple, None)
        self._peers_channels.pop(peer)

def get_last_seen_peers_for_channel(self, channel_pk: bytes, channel_id: int, limit=None):
id_tuple = (channel_pk, channel_id)
channel_peers = self._channels_dict.get(id_tuple, [])
return sorted(channel_peers, key=lambda x: x.last_response, reverse=True)[0:limit]
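
To make the behavior above concrete, here is a minimal usage sketch of ChannelsPeersMapping; StubPeer is a hypothetical stand-in for ipv8's Peer, of which only the last_response attribute matters to the mapping:

    import time
    from dataclasses import dataclass, field

    @dataclass(frozen=True)
    class StubPeer:
        name: str
        last_response: float = field(default_factory=time.time, compare=False)

    mapping = ChannelsPeersMapping(max_peers_per_channel=2)
    channel_pk, channel_id = b"\x01" * 64, 123

    oldest = StubPeer("a")
    mapping.add(oldest, channel_pk, channel_id)
    time.sleep(0.01)
    mapping.add(StubPeer("b"), channel_pk, channel_id)
    time.sleep(0.01)
    # Adding a third peer exceeds the cap of 2 and evicts the stalest peer
    mapping.add(StubPeer("c"), channel_pk, channel_id)

    fresh = mapping.get_last_seen_peers_for_channel(channel_pk, channel_id)
    assert len(fresh) == 2 and oldest not in fresh  # newest-first, capped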


@dataclass
class GigaChannelCommunitySettings(RemoteQueryCommunitySettings):
queried_peers_limit: int = 1000
    # The maximum number of peers taken from the channels-to-peers mapping
    # that are queried in addition to the randomly sampled peers
    max_mapped_query_peers: int = 3


class GigaChannelCommunity(RemoteQueryCommunity):
@@ -55,12 +100,13 @@ def __init__(self, my_peer, endpoint, network, metadata_store, **kwargs):
        # This set contains all the peers that we queried for subscribed channels over time.
        # It is emptied regularly and works as a filter, so we never query the same peer
        # twice; if we ever do, it should happen only rarely.
# TODO: use Bloom filter here instead. We actually *want* it to be all-false-positives eventually.
self.queried_peers = set()

self.discovery_booster = DiscoveryBooster(timeout_in_sec=30)
self.discovery_booster = DiscoveryBooster(timeout_in_sec=60)
self.discovery_booster.apply(self)

self.channels_peers = ChannelsPeersMapping()

def get_random_peers(self, sample_size=None):
# Randomly sample sample_size peers from the complete list of our peers
all_peers = self.get_peers()
@@ -82,6 +128,7 @@ def on_packet_callback(_, processing_results):
with db_session:
for c in (r.md_obj for r in processing_results if r.md_obj.metadata_type == CHANNEL_TORRENT):
self.mds.vote_bump(c.public_key, c.id_, peer.public_key.key_to_bin()[10:])
self.channels_peers.add(peer, c.public_key, c.id_)

# Notify GUI about the new channels
results = [
@@ -117,11 +164,32 @@ def notify_gui(_, processing_results):
if self.notifier and results:
self.notifier.notify(NTFY.REMOTE_QUERY_RESULTS, {"results": results, "uuid": str(request_uuid)})

for p in self.get_random_peers(self.settings.max_query_peers):
        # Try sending the request to at least some peers that we know are subscribed to the channel
if "channel_pk" in kwargs and "origin_id" in kwargs:
peers_to_query = self.get_known_subscribed_peers_for_node(
unhexlify(kwargs["channel_pk"]), kwargs["origin_id"], self.settings.max_mapped_query_peers
)
else:
peers_to_query = self.get_random_peers(self.settings.max_query_peers)

for p in peers_to_query:
self.send_remote_select(p, **kwargs, processing_callback=notify_gui)

return request_uuid

def get_known_subscribed_peers_for_node(self, node_pk, node_id, limit=None):
        # Determine the top-level parent channel
with db_session:
node = self.mds.ChannelNode.get(public_key=node_pk, id_=node_id)
root_id = next((value for value in node.get_parents_ids() if value != 0), node_id) if node else node_id

return self.channels_peers.get_last_seen_peers_for_channel(node_pk, root_id, limit)
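
The root resolution above assumes get_parents_ids() yields the node's ancestor ids including a 0 sentinel for the top level; the first non-zero id is the channel that peers actually subscribe to. A tiny illustration of the idiom with stub values (the tuple below is hypothetical):

    parents_ids = (0, 456)  # stub ancestor chain for a nested collection
    node_id = 789
    root_id = next((value for value in parents_ids if value != 0), node_id)
    assert root_id == 456   # falls back to node_id for a top-level node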

def _on_query_timeout(self, request_cache):
if not request_cache.peer_responded:
self.channels_peers.remove_peer(request_cache.peer)
super()._on_query_timeout(request_cache)


class GigaChannelTestnetCommunity(GigaChannelCommunity):
"""
tribler_core/modules/metadata_store/community/remote_query_community.py
@@ -11,7 +11,7 @@
from pony.orm.dbapiprovider import OperationalError

from tribler_core.modules.metadata_store.community.eva_protocol import EVAProtocolMixin
from tribler_core.modules.metadata_store.orm_bindings.channel_metadata import entries_to_chunk
from tribler_core.modules.metadata_store.orm_bindings.channel_metadata import LZ4_EMPTY_ARCHIVE, entries_to_chunk
from tribler_core.modules.metadata_store.serialization import CHANNEL_TORRENT, COLLECTION_NODE, REGULAR_TORRENT
from tribler_core.modules.metadata_store.store import ObjState
from tribler_core.utilities.unicode import hexlify
@@ -59,7 +59,7 @@ class SelectResponsePayload(VariablePayload):


class SelectRequest(RandomNumberCache):
def __init__(self, request_cache, prefix, request_kwargs, processing_callback=None):
def __init__(self, request_cache, prefix, request_kwargs, peer, processing_callback=None, timeout_callback=None):
super().__init__(request_cache, prefix)
self.request_kwargs = request_kwargs
# The callback to call on results of processing of the response payload
@@ -68,8 +68,15 @@ def __init__(self, request_cache, prefix, request_kwargs, processing_callback=No
# This limit is imposed as a safety precaution to prevent spam/flooding
self.packets_limit = 10

self.peer = peer
# Indicate if at least a single packet was returned by the queried peer.
self.peer_responded = False

self.timeout_callback = timeout_callback

def on_timeout(self):
pass
if self.timeout_callback is not None:
self.timeout_callback(self)
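
The effect of this plumbing is that a timeout only has consequences when the queried peer never sent a single packet back. A self-contained sketch of the pattern, with hypothetical stub names standing in for the real request-cache machinery:

    class StubSelectRequest:
        def __init__(self, peer, timeout_callback=None):
            self.peer = peer
            self.peer_responded = False  # flipped when the first response packet arrives
            self.timeout_callback = timeout_callback

        def on_timeout(self):
            if self.timeout_callback is not None:
                self.timeout_callback(self)

    dropped = []

    def drop_if_silent(request):
        # Stands in for the community's _on_query_timeout handler
        if not request.peer_responded:
            dropped.append(request.peer)

    silent = StubSelectRequest("silent-peer", timeout_callback=drop_if_silent)
    silent.on_timeout()

    responsive = StubSelectRequest("responsive-peer", timeout_callback=drop_if_silent)
    responsive.peer_responded = True  # a packet arrived before the deadline
    responsive.on_timeout()

    assert dropped == ["silent-peer"]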


class PushbackWindow(NumberCache):
@@ -137,7 +144,14 @@ def on_error(self, peer, exception):
self.logger.warning(f"EVA transfer error: peer {hexlify(peer.mid)}, exception: {exception}")

def send_remote_select(self, peer, processing_callback=None, force_eva_response=False, **kwargs):
request = SelectRequest(self.request_cache, hexlify(peer.mid), kwargs, processing_callback)
request = SelectRequest(
self.request_cache,
hexlify(peer.mid),
kwargs,
peer,
processing_callback=processing_callback,
timeout_callback=self._on_query_timeout,
)
self.request_cache.add(request)

self.logger.info(f"Select to {hexlify(peer.mid)} with ({kwargs})")
@@ -158,6 +172,12 @@ async def process_rpc_query(self, json_bytes: bytes):
return await self.mds.get_entries_threaded(**request_sanitized)

def send_db_results(self, peer, request_payload_id, db_results, force_eva_response=False):

        # Special case of an empty results list: send an empty LZ4 archive
if len(db_results) == 0:
self.ez_send(peer, SelectResponsePayload(request_payload_id, LZ4_EMPTY_ARCHIVE))
return

index = 0
while index < len(db_results):
transfer_size = (
@@ -207,6 +227,10 @@ async def on_remote_select_response(self, peer, response_payload):
if request is None:
return

        # Remember that at least a single packet was received from the queried peer.
if isinstance(request, SelectRequest):
request.peer_responded = True

# Check for limit on the number of packets per request
if request.packets_limit > 1:
request.packets_limit -= 1
@@ -249,6 +273,16 @@ async def on_remote_select_response(self, peer, response_payload):
if isinstance(request, SelectRequest) and request.processing_callback:
request.processing_callback(request, processing_results)

def _on_query_timeout(self, request_cache):
if not request_cache.peer_responded:
self.logger.info(
"Remote query timeout, deleting peer: %s %s %s",
str(request_cache.peer.address),
hexlify(request_cache.peer.mid),
str(request_cache.request_kwargs),
)
self.network.remove_peer(request_cache.peer)

async def unload(self):
await self.request_cache.shutdown()
await super().unload()
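
The empty-archive special case in send_db_results and the silent-peer eviction above are two halves of one fix: since even a zero-result query now produces a response packet, a peer that simply has nothing matching is no longer dropped as unresponsive. As a rough illustration with the lz4 package, an empty archive is just a valid frame that decompresses to nothing (the actual LZ4_EMPTY_ARCHIVE constant is defined in channel_metadata and may be produced differently):

    import lz4.frame

    empty_archive = lz4.frame.compress(b"")  # a valid LZ4 frame with zero payload
    assert lz4.frame.decompress(empty_archive) == b""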
tribler_core/modules/metadata_store/community/tests/test_gigachannel_community.py
@@ -1,6 +1,7 @@
import time
from datetime import datetime
from unittest import mock
from unittest.mock import Mock
from unittest.mock import Mock, PropertyMock, patch

from ipv8.database import database_blob
from ipv8.keyvault.crypto import default_eccrypto
@@ -9,17 +10,23 @@

from pony.orm import db_session

import pytest

from tribler_core.modules.metadata_store.community.gigachannel_community import (
ChannelsPeersMapping,
GigaChannelCommunity,
MAGIC_GIGACHAN_VERSION_MARK,
)
from tribler_core.modules.metadata_store.store import MetadataStore
from tribler_core.notifier import Notifier
from tribler_core.utilities.path_util import Path
from tribler_core.utilities.random_utils import random_infohash
from tribler_core.utilities.unicode import hexlify

EMPTY_BLOB = database_blob(b"")

# pylint:disable=protected-access


class TestGigaChannelUnits(TestBase):
def setUp(self):
@@ -178,3 +185,103 @@ def mock_notify(overlay, args):

# Check that the notifier callback is called on new channel entries
self.assertTrue(self.nodes[1].overlay.notified_results)

def test_channels_peers_mapping_drop_excess_peers(self):
"""
        Test dropping old excess peers from the channels-to-peers mapping
"""
mapping = ChannelsPeersMapping()
chan_pk = Mock()
chan_id = 123

num_excess_peers = 20
first_peer_timestamp = None
for k in range(0, mapping.max_peers_per_channel + num_excess_peers):
peer = Peer(default_eccrypto.generate_key("very-low"), ("1.2.3.4", 5))
peer.last_response = time.time()
mapping.add(peer, chan_pk, chan_id)
if k == 0:
first_peer_timestamp = peer.last_response

chan_peers_3 = mapping.get_last_seen_peers_for_channel(chan_pk, chan_id, 3)
assert len(chan_peers_3) == 3

chan_peers = mapping.get_last_seen_peers_for_channel(chan_pk, chan_id)
assert len(chan_peers) == mapping.max_peers_per_channel

assert chan_peers_3 == chan_peers[0:3]
assert chan_peers == sorted(chan_peers, key=lambda x: x.last_response, reverse=True)

# Make sure only the older peers are dropped as excess
for p in chan_peers:
assert p.last_response > first_peer_timestamp

# Test removing a peer directly, e.g. as a result of a query timeout
peer = Peer(default_eccrypto.generate_key("very-low"), ("1.2.3.4", 5))
mapping.add(peer, chan_pk, chan_id)
mapping.remove_peer(peer)
for p in chan_peers:
mapping.remove_peer(p)

assert mapping.get_last_seen_peers_for_channel(chan_pk, chan_id) == []

        # Make sure the internal mappings are fully cleaned up
assert len(mapping._peers_channels) == 0
assert len(mapping._channels_dict) == 0

@pytest.mark.timeout(0)
async def test_remote_search_mapped_peers(self):
"""
Test using mapped peers for channel queries.
"""
key = default_eccrypto.generate_key("curve25519")
channel_pk = key.pub().key_to_bin()[10:]
channel_id = 123
kwargs = {"channel_pk": f"{hexlify(channel_pk)}", "origin_id": channel_id}

await self.introduce_nodes()

source_peer = self.nodes[2].overlay.get_peers()[0]
self.nodes[2].overlay.channels_peers.add(source_peer, channel_pk, channel_id)

self.nodes[2].overlay.notifier = None

        # Disable random peer selection, so the channels-to-peers mapping is the only source of peers
self.nodes[2].overlay.get_random_peers = lambda _: []

self.nodes[2].overlay.send_remote_select = Mock()
self.nodes[2].overlay.send_search_request(**kwargs)

        # The node must have queried at least one of the mapped peers
self.nodes[2].overlay.send_remote_select.assert_called()

@pytest.mark.timeout(5)
    async def test_drop_silent_peer_from_channels_map(self):
        """
        Test that a peer that keeps silent on queries is dropped from the channels-to-peers mapping
        """
        # We do not want the query-back mechanism to interfere with this test
        self.nodes[1].overlay.settings.max_channel_query_back = 0

kwargs_dict = {"txt_filter": "ubuntu*"}

basic_path = 'tribler_core.modules.metadata_store.community'

with patch(
basic_path + '.remote_query_community.SelectRequest.timeout_delay', new_callable=PropertyMock
) as delay_mock:
# Change query timeout to a really low value
delay_mock.return_value = 0.3

# Stop peer 0 from responding
with patch(basic_path + '.remote_query_community.RemoteQueryCommunity._on_remote_select_basic'):
self.nodes[1].overlay.channels_peers.remove_peer = Mock()
self.nodes[1].overlay.send_remote_select(self.nodes[0].my_peer, **kwargs_dict)

await self.deliver_messages(timeout=1)
                # Node 1 must have called remove_peer because of the query timeout
self.nodes[1].overlay.channels_peers.remove_peer.assert_called()

            # Now test that remove_peer is not called when the peer responds, even with an empty response packet
self.nodes[1].overlay.channels_peers.remove_peer = Mock()
self.nodes[1].overlay.send_remote_select(self.nodes[0].my_peer, **kwargs_dict)
await self.deliver_messages(timeout=1)
self.nodes[1].overlay.channels_peers.remove_peer.assert_not_called()