Skip to content

Commit

Permalink
Do not perform more than one full-text remote search in parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlovsky committed Nov 4, 2022
1 parent dabb40d commit c3df3a5
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import json
import struct
import time
from asyncio import Future
from binascii import unhexlify
from typing import List, Optional, Set
from itertools import count
from typing import Any, Dict, List, Optional, Set

from ipv8.lazy_community import lazy_wrapper
from ipv8.messaging.lazy_payload import VariablePayload, vp_compile
Expand All @@ -26,7 +28,7 @@
BINARY_FIELDS = ("infohash", "channel_pk")


def sanitize_query(query_dict, cap=100):
def sanitize_query(query_dict: Dict[str, Any], cap=100) -> Dict[str, Any]:
sanitized_dict = dict(query_dict)

# We impose a cap on max numbers of returned entries to prevent DDOS-like attacks
Expand Down Expand Up @@ -151,6 +153,8 @@ def __init__(self, my_peer, endpoint, network,
self.add_message_handler(SelectResponsePayload, self.on_remote_select_response)

self.eva = EVAProtocol(self, self.on_receive, self.on_send_complete, self.on_error)
self.remote_queries_in_progress = 0
self.next_remote_query_num = count().__next__ # generator of sequential numbers, for logging & debug purposes

async def on_receive(self, result: TransferResult):
self.logger.debug(f"EVA data received: peer {hexlify(result.peer.mid)}, info {result.info}")
Expand Down Expand Up @@ -183,16 +187,32 @@ def send_remote_select(self, peer, processing_callback=None, force_eva_response=
self.ez_send(peer, RemoteSelectPayload(*args))
return request

async def process_rpc_query(self, json_bytes: bytes):
def should_limit_rate_for_query(self, sanitized_parameters: Dict[str, Any]) -> bool:
    """
    Decide whether a remote query should be subject to rate limiting.

    Only queries carrying a 'txt_filter' key (i.e., full-text search
    queries) are rate-limited; all other queries are always processed.

    :param sanitized_parameters: query parameters already passed through sanitize_query.
    :return: True if the query should be rate-limited, False otherwise.
    """
    query_keys = sanitized_parameters.keys()
    return 'txt_filter' in query_keys

async def process_rpc_query_rate_limited(self, sanitized_parameters: Dict[str, Any]) -> List:
    """
    Process a remote database query, dropping it when rate limiting applies.

    A query is dropped (returning an empty list) when should_limit_rate_for_query()
    says it is a rate-limited kind of query AND at least one other remote query is
    currently being processed. This prevents more than one expensive full-text
    search from running in parallel.

    :param sanitized_parameters: query parameters already passed through sanitize_query.
    :return: the list of query results, or an empty list if the query was dropped.
    """
    # Sequential number used purely to correlate log lines for one query.
    query_num = self.next_remote_query_num()
    if self.remote_queries_in_progress and self.should_limit_rate_for_query(sanitized_parameters):
        self.logger.warning(f'Ignore remote query {query_num} as another one is already processing. '
                            f'The ignored query: {sanitized_parameters}')
        return []

    self.logger.info(f'Process remote query {query_num}: {sanitized_parameters}')
    self.remote_queries_in_progress += 1
    t = time.time()
    try:
        return await self.process_rpc_query(sanitized_parameters)
    finally:
        # Decrement even when process_rpc_query raises, so a failed query
        # can never permanently block subsequent rate-limited queries.
        self.remote_queries_in_progress -= 1
        self.logger.info(f'Remote query {query_num} processed in {time.time()-t} seconds: {sanitized_parameters}')

async def process_rpc_query(self, sanitized_parameters: Dict[str, Any]) -> List:
"""
Retrieve the result of a database query from a third party, encoded as raw JSON bytes (through `dumps`).
:raises TypeError: if the JSON contains invalid keys.
:raises ValueError: if no JSON could be decoded.
:raises pony.orm.dbapiprovider.OperationalError: if an illegal query was performed.
"""
parameters = json.loads(json_bytes)
sanitized_parameters = sanitize_query(parameters, self.rqc_settings.max_response_size)

# tags should be extracted because `get_entries_threaded` doesn't expect them as a parameter
tags = sanitized_parameters.pop('tags', None)

Expand Down Expand Up @@ -237,9 +257,14 @@ async def on_remote_select_eva(self, peer, request_payload):
async def on_remote_select(self, peer, request_payload):
await self._on_remote_select_basic(peer, request_payload)

def parse_parameters(self, json_bytes: bytes) -> Dict[str, Any]:
    """
    Decode raw JSON query bytes and sanitize the resulting parameter dict.

    :param json_bytes: UTF-8 encoded JSON object with query parameters.
    :return: the sanitized parameter dict, capped at the configured max response size.
    :raises ValueError: if json_bytes cannot be decoded as JSON.
    """
    raw_parameters = json.loads(json_bytes)
    size_cap = self.rqc_settings.max_response_size
    return sanitize_query(raw_parameters, size_cap)

async def _on_remote_select_basic(self, peer, request_payload, force_eva_response=False):
try:
db_results = await self.process_rpc_query(request_payload.json)
sanitized_parameters = self.parse_parameters(request_payload.json)
db_results = await self.process_rpc_query_rate_limited(sanitized_parameters)

# When we send our response to a host, we open a window of opportunity
# for it to push back updates
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import random
import string
import time
from asyncio import sleep
from binascii import unhexlify
from json import dumps
from operator import attrgetter
from os import urandom
from time import time
from unittest.mock import Mock, patch

from ipv8.keyvault.crypto import default_eccrypto
Expand Down Expand Up @@ -112,7 +111,7 @@ async def test_remote_select(self):
channel=channel,
seeders=2 * i,
leechers=i,
last_check=int(time()) + i,
last_check=int(time.time()) + i,
)

kwargs_dict = {"txt_filter": "ubuntu*", "metadata_type": [REGULAR_TORRENT]}
Expand Down Expand Up @@ -345,7 +344,7 @@ async def test_process_rpc_query_match_many(self):
channel = self.channel_metadata(0).create_channel("a channel", "")
add_random_torrent(self.torrent_metadata(0), name="a torrent", channel=channel)

results = await self.overlay(0).process_rpc_query(dumps({}))
results = await self.overlay(0).process_rpc_query({})
self.assertEqual(2, len(results))

channel_md, torrent_md = results if isinstance(results[0], self.channel_metadata(0)) else results[::-1]
Expand All @@ -359,7 +358,7 @@ async def test_process_rpc_query_match_one(self):
with db_session:
self.channel_metadata(0).create_channel("a channel", "")

results = await self.overlay(0).process_rpc_query(dumps({}))
results = await self.overlay(0).process_rpc_query({})
self.assertEqual(1, len(results))

(channel_md,) = results
Expand All @@ -369,22 +368,22 @@ async def test_process_rpc_query_match_none(self):
"""
Check if a correct query with no match in our database returns no result.
"""
results = await self.overlay(0).process_rpc_query(dumps({}))
results = await self.overlay(0).process_rpc_query({})
self.assertEqual(0, len(results))

def test_parse_parameters_match_empty_json(self):
    """
    Check that parsing an empty byte string raises a ValueError (JSONDecodeError).
    """
    overlay = self.overlay(0)
    with self.assertRaises(ValueError):
        overlay.parse_parameters(b'')

def test_parse_parameters_match_illegal_json(self):
    """
    Check that parsing bytes that are not valid UTF-8 raises a UnicodeDecodeError.
    """
    overlay = self.overlay(0)
    with self.assertRaises(UnicodeDecodeError):
        overlay.parse_parameters(b'{"akey":\x80}')

async def test_process_rpc_query_match_invalid_json(self):
"""
Expand All @@ -394,21 +393,24 @@ async def test_process_rpc_query_match_invalid_json(self):
self.channel_metadata(0).create_channel("a channel", "")
query = b'{"id_":' + b'\x31' * 200 + b'}'
with self.assertRaises(ValueError):
await self.overlay(0).process_rpc_query(query)
parameters = self.overlay(0).parse_parameters(query)
await self.overlay(0).process_rpc_query(parameters)

async def test_process_rpc_query_match_invalid_key(self):
    """
    Check if processing a request with an unknown key ("bla") causes a TypeError to be raised.
    """
    with self.assertRaises(TypeError):
        parameters = self.overlay(0).parse_parameters(b'{"bla":":("}')
        await self.overlay(0).process_rpc_query(parameters)

async def test_process_rpc_query_no_column(self):
    """
    Check that a query referencing no existing database column raises an OperationalError.
    """
    query_bytes = b'{"txt_filter":{"key":"bla"}}'
    with self.assertRaises(OperationalError):
        parameters = self.overlay(0).parse_parameters(query_bytes)
        await self.overlay(0).process_rpc_query(parameters)

async def test_remote_query_big_response(self):

Expand Down Expand Up @@ -574,3 +576,45 @@ async def test_remote_select_force_eva(self):
await self.deliver_messages(timeout=0.5)

self.nodes[1].overlay.eva.send_binary.assert_called_once()

async def test_multiple_parallel_request(self):
    """
    Check that when two full-text remote queries arrive at a peer in parallel,
    only one of them is actually executed; the other is ignored with exactly
    one 'Ignore remote query' warning logged on the receiving peer.
    """
    peer_a = self.nodes[0].my_peer
    a = self.nodes[0].overlay
    b = self.nodes[1].overlay

    # Peer A has two torrents "foo" and "bar"
    with db_session:
        add_random_torrent(a.mds.TorrentMetadata, name="foo")
        add_random_torrent(a.mds.TorrentMetadata, name="bar")

    # Peer B sends two parallel full-text search queries, only one of them should be processed
    callback1 = Mock()
    kwargs1 = {"txt_filter": "foo", "metadata_type": [REGULAR_TORRENT]}
    b.send_remote_select(peer_a, **kwargs1, processing_callback=callback1)

    callback2 = Mock()
    kwargs2 = {"txt_filter": "bar", "metadata_type": [REGULAR_TORRENT]}
    b.send_remote_select(peer_a, **kwargs2, processing_callback=callback2)

    original_get_entries = MetadataStore.get_entries
    # Add a delay to ensure that the first query is still being processed when the second one arrives
    # (the mds.get_entries() method is a synchronous one and is called from a worker thread)

    def slow_get_entries(self, *args, **kwargs):
        # Wrap the real get_entries with a short sleep so both queries overlap in time.
        time.sleep(0.1)
        return original_get_entries(self, *args, **kwargs)

    # Patch A's logger to capture the rate-limit warning, and slow down its DB access.
    with patch.object(a, 'logger') as logger, patch.object(MetadataStore, 'get_entries', slow_get_entries):
        await self.deliver_messages(timeout=0.5)

    torrents1 = list(b.mds.get_entries(**kwargs1))
    torrents2 = list(b.mds.get_entries(**kwargs2))

    # Both remote queries should return results to the peer B...
    assert callback1.called and callback2.called
    # ...but one of them should return an empty list, as the database query was not actually executed
    assert bool(torrents1) != bool(torrents2)

    # Check that on peer A there is exactly one warning about an ignored remote query
    warnings = [call.args[0] for call in logger.warning.call_args_list]
    assert len([msg for msg in warnings if msg.startswith('Ignore remote query')]) == 1
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ def test_search_for_tags_only_valid_tags(self, mocked_get_subjects_intersection:
async def test_process_rpc_query_no_tags(self, mocked_get_entries_threaded: AsyncMock):
# test that in case of missed tags, the remote search works like normal remote search
parameters = {'first': 0, 'infohash_set': None, 'last': 100}
json = dumps(parameters).encode('utf-8')

await self.rqc.process_rpc_query(json)
await self.rqc.process_rpc_query(parameters)

expected_parameters = {'infohash_set': None}
expected_parameters.update(parameters)
Expand Down Expand Up @@ -117,10 +115,8 @@ def _add(infohash):

# Then we try to query search for three tags: 'tag1', 'tag2', 'tag3'
parameters = {'first': 0, 'infohash_set': None, 'last': 100, 'tags': ['tag1']}
json = dumps(parameters).encode('utf-8')

with db_session:
query_results = [r.to_dict() for r in await self.rqc.process_rpc_query(json)]
query_results = [r.to_dict() for r in await self.rqc.process_rpc_query(parameters)]

# Expected results: only one infohash (b'infohash1') should be returned.
result_infohash_list = [r['infohash'] for r in query_results]
Expand Down

0 comments on commit c3df3a5

Please sign in to comment.