From 2c94374c211e7e1c4368e3094d73ac502d8b9ae7 Mon Sep 17 00:00:00 2001
From: Alexander Kozlovsky
Date: Thu, 3 Nov 2022 16:49:11 +0100
Subject: [PATCH] Do not perform more than one full-text remote search in
 parallel

---
 .../remote_query_community.py                 | 39 +++++++++--
 .../tests/test_remote_query_community.py      | 70 +++++++++++++++----
 2 files changed, 89 insertions(+), 20 deletions(-)

diff --git a/src/tribler/core/components/metadata_store/remote_query_community/remote_query_community.py b/src/tribler/core/components/metadata_store/remote_query_community/remote_query_community.py
index cf443e8d47a..206547453d7 100644
--- a/src/tribler/core/components/metadata_store/remote_query_community/remote_query_community.py
+++ b/src/tribler/core/components/metadata_store/remote_query_community/remote_query_community.py
@@ -1,8 +1,10 @@
 import json
 import struct
+import time
 from asyncio import Future
 from binascii import unhexlify
-from typing import List, Optional, Set
+from itertools import count
+from typing import Any, Dict, List, Optional, Set
 
 from ipv8.lazy_community import lazy_wrapper
 from ipv8.messaging.lazy_payload import VariablePayload, vp_compile
@@ -26,7 +28,7 @@
 BINARY_FIELDS = ("infohash", "channel_pk")
 
 
-def sanitize_query(query_dict, cap=100):
+def sanitize_query(query_dict: Dict[str, Any], cap=100) -> Dict[str, Any]:
     sanitized_dict = dict(query_dict)
 
     # We impose a cap on max numbers of returned entries to prevent DDOS-like attacks
@@ -151,6 +153,8 @@ def __init__(self, my_peer, endpoint, network,
         self.add_message_handler(SelectResponsePayload, self.on_remote_select_response)
 
         self.eva = EVAProtocol(self, self.on_receive, self.on_send_complete, self.on_error)
+        self.remote_queries_in_progress = 0
+        self.next_remote_query_num = count().__next__  # generator of sequential numbers, for logging & debug purposes
 
     async def on_receive(self, result: TransferResult):
         self.logger.debug(f"EVA data received: peer {hexlify(result.peer.mid)}, info {result.info}")
@@ -183,16 +187,32 @@ def send_remote_select(self, peer, processing_callback=None, force_eva_response=
         self.ez_send(peer, RemoteSelectPayload(*args))
         return request
 
-    async def process_rpc_query(self, json_bytes: bytes):
+    def should_limit_rate(self, sanitized_parameters: Dict[str, Any]) -> bool:
+        return 'txt_filter' in sanitized_parameters
+
+    async def process_rpc_query_rate_limited(self, sanitized_parameters: Dict[str, Any]) -> List:
+        query_num = self.next_remote_query_num()
+        if self.should_limit_rate(sanitized_parameters) and self.remote_queries_in_progress:
+            self.logger.warning(f'Ignore remote query {query_num} as another one is already processing. '
+                                f'The ignored query: {sanitized_parameters}')
+            return []
+
+        self.logger.info(f'Process remote query {query_num}: {sanitized_parameters}')
+        self.remote_queries_in_progress += 1
+        t = time.time()
+        try:
+            return await self.process_rpc_query(sanitized_parameters)
+        finally:
+            self.remote_queries_in_progress -= 1
+            self.logger.info(f'Remote query {query_num} processed in {time.time()-t} seconds: {sanitized_parameters}')
+
+    async def process_rpc_query(self, sanitized_parameters: Dict[str, Any]) -> List:
         """
         Retrieve the result of a database query from a third party, encoded as raw JSON bytes (through `dumps`).
         :raises TypeError: if the JSON contains invalid keys.
         :raises ValueError: if no JSON could be decoded.
         :raises pony.orm.dbapiprovider.OperationalError: if an illegal query was performed.
         """
-        parameters = json.loads(json_bytes)
-        sanitized_parameters = sanitize_query(parameters, self.rqc_settings.max_response_size)
-
         # tags should be extracted because `get_entries_threaded` doesn't expect them as a parameter
         tags = sanitized_parameters.pop('tags', None)
 
@@ -237,9 +257,14 @@ async def on_remote_select_eva(self, peer, request_payload):
     async def on_remote_select(self, peer, request_payload):
         await self._on_remote_select_basic(peer, request_payload)
 
+    def parse_parameters(self, json_bytes: bytes) -> Dict[str, Any]:
+        parameters = json.loads(json_bytes)
+        return sanitize_query(parameters, self.rqc_settings.max_response_size)
+
     async def _on_remote_select_basic(self, peer, request_payload, force_eva_response=False):
         try:
-            db_results = await self.process_rpc_query(request_payload.json)
+            sanitized_parameters = self.parse_parameters(request_payload.json)
+            db_results = await self.process_rpc_query_rate_limited(sanitized_parameters)
 
             # When we send our response to a host, we open a window of opportunity
             # for it to push back updates
diff --git a/src/tribler/core/components/metadata_store/remote_query_community/tests/test_remote_query_community.py b/src/tribler/core/components/metadata_store/remote_query_community/tests/test_remote_query_community.py
index 2a82b8c666d..fbd6d25d5bb 100644
--- a/src/tribler/core/components/metadata_store/remote_query_community/tests/test_remote_query_community.py
+++ b/src/tribler/core/components/metadata_store/remote_query_community/tests/test_remote_query_community.py
@@ -1,11 +1,10 @@
 import random
 import string
+import time
 from asyncio import sleep
 from binascii import unhexlify
-from json import dumps
 from operator import attrgetter
 from os import urandom
-from time import time
 from unittest.mock import Mock, patch
 
 from ipv8.keyvault.crypto import default_eccrypto
@@ -112,7 +111,7 @@ async def test_remote_select(self):
                     channel=channel,
                     seeders=2 * i,
                     leechers=i,
-                    last_check=int(time()) + i,
+                    last_check=int(time.time()) + i,
                 )
 
         kwargs_dict = {"txt_filter": "ubuntu*", "metadata_type": [REGULAR_TORRENT]}
@@ -345,7 +344,7 @@ async def test_process_rpc_query_match_many(self):
             channel = self.channel_metadata(0).create_channel("a channel", "")
             add_random_torrent(self.torrent_metadata(0), name="a torrent", channel=channel)
 
-        results = await self.overlay(0).process_rpc_query(dumps({}))
+        results = await self.overlay(0).process_rpc_query({})
         self.assertEqual(2, len(results))
 
         channel_md, torrent_md = results if isinstance(results[0], self.channel_metadata(0)) else results[::-1]
@@ -359,7 +358,7 @@ async def test_process_rpc_query_match_one(self):
         with db_session:
             self.channel_metadata(0).create_channel("a channel", "")
 
-        results = await self.overlay(0).process_rpc_query(dumps({}))
+        results = await self.overlay(0).process_rpc_query({})
         self.assertEqual(1, len(results))
 
         (channel_md,) = results
@@ -369,22 +368,22 @@ async def test_process_rpc_query_match_none(self):
         """
         Check if a correct query with no match in our database returns no result.
         """
-        results = await self.overlay(0).process_rpc_query(dumps({}))
+        results = await self.overlay(0).process_rpc_query({})
         self.assertEqual(0, len(results))
 
-    async def test_process_rpc_query_match_empty_json(self):
+    def test_parse_parameters_match_empty_json(self):
         """
         Check if processing an empty request causes a ValueError (JSONDecodeError) to be raised.
         """
         with self.assertRaises(ValueError):
-            await self.overlay(0).process_rpc_query(b'')
+            self.overlay(0).parse_parameters(b'')
 
-    async def test_process_rpc_query_match_illegal_json(self):
+    def test_parse_parameters_match_illegal_json(self):
         """
         Check if processing a request with illegal JSON causes a UnicodeDecodeError to be raised.
         """
         with self.assertRaises(UnicodeDecodeError):
-            await self.overlay(0).process_rpc_query(b'{"akey":\x80}')
+            self.overlay(0).parse_parameters(b'{"akey":\x80}')
 
     async def test_process_rpc_query_match_invalid_json(self):
         """
@@ -394,21 +393,24 @@ async def test_process_rpc_query_match_invalid_json(self):
             self.channel_metadata(0).create_channel("a channel", "")
         query = b'{"id_":' + b'\x31' * 200 + b'}'
         with self.assertRaises(ValueError):
-            await self.overlay(0).process_rpc_query(query)
+            parameters = self.overlay(0).parse_parameters(query)
+            await self.overlay(0).process_rpc_query(parameters)
 
     async def test_process_rpc_query_match_invalid_key(self):
         """
         Check if processing a request with invalid flags causes a UnicodeDecodeError to be raised.
         """
         with self.assertRaises(TypeError):
-            await self.overlay(0).process_rpc_query(b'{"bla":":("}')
+            parameters = self.overlay(0).parse_parameters(b'{"bla":":("}')
+            await self.overlay(0).process_rpc_query(parameters)
 
     async def test_process_rpc_query_no_column(self):
         """
         Check if processing a request with no database columns causes an OperationalError.
         """
         with self.assertRaises(OperationalError):
-            await self.overlay(0).process_rpc_query(b'{"txt_filter":{"key":"bla"}}')
+            parameters = self.overlay(0).parse_parameters(b'{"txt_filter":{"key":"bla"}}')
+            await self.overlay(0).process_rpc_query(parameters)
 
     async def test_remote_query_big_response(self):
 
@@ -574,3 +576,45 @@ async def test_remote_select_force_eva(self):
         await self.deliver_messages(timeout=0.5)
 
         self.nodes[1].overlay.eva.send_binary.assert_called_once()
+
+    async def test_multiple_parallel_request(self):
+        peer_a = self.nodes[0].my_peer
+        a = self.nodes[0].overlay
+        b = self.nodes[1].overlay
+
+        # Peer A has two torrents "foo" and "bar"
+        with db_session:
+            add_random_torrent(a.mds.TorrentMetadata, name="foo")
+            add_random_torrent(a.mds.TorrentMetadata, name="bar")
+
+        # Peer B sends two parallel full-text search queries, only one of them should be processed
+        callback1 = Mock()
+        kwargs1 = {"txt_filter": "foo", "metadata_type": [REGULAR_TORRENT]}
+        b.send_remote_select(peer_a, **kwargs1, processing_callback=callback1)
+
+        callback2 = Mock()
+        kwargs2 = {"txt_filter": "bar", "metadata_type": [REGULAR_TORRENT]}
+        b.send_remote_select(peer_a, **kwargs2, processing_callback=callback2)
+
+        original_get_entries = MetadataStore.get_entries
+        # Add a delay to ensure that the first query is still being processed when the second one arrives
+        # (the mds.get_entries() method is a synchronous one and is called from a worker thread)
+
+        def slow_get_entries(self, *args, **kwargs):
+            time.sleep(0.1)
+            return original_get_entries(self, *args, **kwargs)
+
+        with patch.object(a, 'logger') as logger, patch.object(MetadataStore, 'get_entries', slow_get_entries):
+            await self.deliver_messages(timeout=0.5)
+
+        torrents1 = list(b.mds.get_entries(**kwargs1))
+        torrents2 = list(b.mds.get_entries(**kwargs2))
+
+        # Both remote queries should return results to the peer B...
+        assert callback1.called and callback2.called
+        # ...but one of them should return an empty list, as the database query was not actually executed
+        assert bool(torrents1) != bool(torrents2)
+
+        # Check that on peer A there is exactly one warning about an ignored remote query
+        warnings = [call.args[0] for call in logger.warning.call_args_list]
+        assert len([msg for msg in warnings if msg.startswith('Ignore remote query')]) == 1
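
Note on the rate-limiting approach: only requests whose sanitized parameters contain 'txt_filter' are gated, since those are the full-text searches named in the subject; lookups by infohash or channel key are still processed concurrently. The gate is a plain in-flight counter rather than a queue: a second full-text query arriving while one is running gets an empty result list instead of being delayed. The following is a minimal standalone sketch of the same pattern using only the Python standard library; the names ToyRateLimiter, process and slow_handler are illustrative and are not Tribler APIs.

import asyncio


class ToyRateLimiter:
    """Illustrative stand-in for the gating logic in process_rpc_query_rate_limited."""

    def __init__(self):
        self.queries_in_progress = 0

    def should_limit_rate(self, params):
        # Only full-text searches (requests carrying 'txt_filter') are limited
        return 'txt_filter' in params

    async def process(self, params, handler):
        if self.should_limit_rate(params) and self.queries_in_progress:
            return []  # drop the query instead of queueing it
        self.queries_in_progress += 1
        try:
            return await handler(params)
        finally:
            self.queries_in_progress -= 1


async def main():
    limiter = ToyRateLimiter()

    async def slow_handler(params):
        await asyncio.sleep(0.1)  # simulates a slow database search
        return [params['txt_filter']]

    results = await asyncio.gather(
        limiter.process({'txt_filter': 'foo'}, slow_handler),
        limiter.process({'txt_filter': 'bar'}, slow_handler),
    )
    # Exactly one of the two concurrent queries is actually executed,
    # mirroring the assertion in test_multiple_parallel_request
    assert bool(results[0]) != bool(results[1])


asyncio.run(main())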
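On the line `self.next_remote_query_num = count().__next__`: binding the bound __next__ method of an itertools.count object gives a callable that returns 0, 1, 2, ... on successive calls, which the patch uses only to tag log messages with a per-instance query number. A quick standard-library illustration:

from itertools import count

next_remote_query_num = count().__next__
assert next_remote_query_num() == 0
assert next_remote_query_num() == 1
assert next_remote_query_num() == 2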