Merge pull request #7025 from kozlovsky/search_improvements
Improved ranking for search results and updated the search UI, removing the artificial delay at the loading screen
kozlovsky authored Nov 4, 2022
2 parents c31e980 + c3df3a5 commit 4755bee
Showing 17 changed files with 957 additions and 124 deletions.
src/tribler/core/components/gigachannel/community/sync_strategy.py
@@ -3,15 +3,22 @@
from ipv8.peerdiscovery.discovery import DiscoveryStrategy


TARGET_PEERS_NUMBER = 20


class RemovePeers(DiscoveryStrategy):
"""
Synchronization strategy for remote query community.
Remove a random peer if we have enough peers to walk to.
"""

def __init__(self, overlay, target_peers_number=TARGET_PEERS_NUMBER):
super().__init__(overlay)
self.target_peers_number = target_peers_number

def take_step(self):
with self.walk_lock:
peers = self.overlay.get_peers()
if peers and len(peers) > 20:
if peers and len(peers) > self.target_peers_number:
self.overlay.network.remove_peer(choice(peers))
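A minimal sketch of the pruning behavior above, assuming the tribler package and its ipv8 dependency are importable; the stub classes are hypothetical and provide only the two attributes that `take_step()` actually touches (`get_peers()` and `network.remove_peer()`):

```python
from tribler.core.components.gigachannel.community.sync_strategy import RemovePeers


class StubNetwork:
    def remove_peer(self, peer):
        print(f"removed {peer}")


class StubOverlay:
    def __init__(self, peers):
        self._peers = peers
        self.network = StubNetwork()

    def get_peers(self):
        return self._peers


# With 25 peers and the default target of 20, each take_step() call removes one random peer.
overlay = StubOverlay([f"peer-{i}" for i in range(25)])
RemovePeers(overlay).take_step()
```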
63 changes: 55 additions & 8 deletions src/tribler/core/components/metadata_store/db/store.py
@@ -10,6 +10,7 @@

from pony import orm
from pony.orm import db_session, desc, left_join, raw_sql, select
from pony.orm.dbproviders.sqlite import keep_exception

from tribler.core import notifications
from tribler.core.components.metadata_store.db.orm_bindings import (
@@ -50,9 +51,11 @@
from tribler.core.utilities.notifier import Notifier
from tribler.core.utilities.path_util import Path
from tribler.core.utilities.pony_utils import get_max, get_or_create
from tribler.core.utilities.search_utils import torrent_rank
from tribler.core.utilities.unicode import hexlify
from tribler.core.utilities.utilities import MEMORY_DB


BETA_DB_VERSIONS = [0, 1, 2, 3, 4, 5]
CURRENT_DB_VERSION = 14

@@ -167,7 +170,7 @@ def __init__(
# with the static analysis.
# pylint: disable=unused-variable
@self._db.on_connect(provider='sqlite')
def sqlite_disable_sync(_, connection):
def on_connect(_, connection):
cursor = connection.cursor()
cursor.execute("PRAGMA journal_mode = WAL")
cursor.execute("PRAGMA synchronous = NORMAL")
@@ -180,6 +183,10 @@ def sqlite_disable_sync(_, connection):
# losing power during a write will corrupt the database.
cursor.execute("PRAGMA journal_mode = 0")
cursor.execute("PRAGMA synchronous = 0")

sqlite_rank = keep_exception(torrent_rank)
connection.create_function('search_rank', 5, sqlite_rank)

# pylint: enable=unused-variable

self.MiscData = misc.define_binding(self._db)
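For context, `connection.create_function('search_rank', 5, sqlite_rank)` registers a five-argument Python callable as an SQLite scalar function. Below is a self-contained sketch of the same pattern with the standard `sqlite3` module; `keep_exception_sketch` and `toy_rank` are hypothetical stand-ins that only mimic the idea behind pony's `keep_exception` (preserving the original Python exception, which SQLite would otherwise mask as a generic `OperationalError`):

```python
import functools
import sqlite3


def keep_exception_sketch(func):
    # If a user-defined SQL function raises, SQLite reports only a generic
    # sqlite3.OperationalError; remember the original exception so the caller
    # can inspect or re-raise it afterwards.
    @functools.wraps(func)
    def wrapper(*args):
        try:
            return func(*args)
        except Exception as exc:
            wrapper.last_exception = exc
            raise
    wrapper.last_exception = None
    return wrapper


@keep_exception_sketch
def toy_rank(query, title, seeders, leechers, age_seconds):
    # Purely illustrative scoring, not the real torrent_rank()
    return (seeders or 0) + 0.5 * (leechers or 0) - (age_seconds or 0) / 86400


conn = sqlite3.connect(":memory:")
conn.create_function("search_rank", 5, toy_rank)
print(conn.execute("SELECT search_rank('q', 't', 10, 2, 3600)").fetchone()[0])  # ~10.96
```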
@@ -591,7 +598,7 @@ def torrent_exists_in_personal_channel(self, infohash):
)

# pylint: disable=unused-argument
def search_keyword(self, query, lim=100):
def search_keyword(self, query):
# Requires FTS5 table "FtsIndex" to be generated and populated.
# FTS table is maintained automatically by SQL triggers.
# BM25 ranking is embedded in FTS5.
@@ -600,10 +607,11 @@ def search_keyword(self, query, lim=100):
if not query or query == "*":
return []

fts_ids = raw_sql(
"""SELECT rowid FROM ChannelNode WHERE rowid IN (SELECT rowid FROM FtsIndex WHERE FtsIndex MATCH $query
ORDER BY bm25(FtsIndex) LIMIT $lim) GROUP BY coalesce(infohash, rowid)"""
)
fts_ids = raw_sql("""
SELECT rowid FROM ChannelNode
WHERE rowid IN (SELECT rowid FROM FtsIndex WHERE FtsIndex MATCH $query)
GROUP BY coalesce(infohash, rowid)
""")
return left_join(g for g in self.MetadataNode if g.rowid in fts_ids) # pylint: disable=E1135

@db_session
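A self-contained illustration of the FTS5 lookup pattern used in `search_keyword` above, with a hypothetical minimal schema: the `MATCH` subquery selects candidate rowids, and `GROUP BY coalesce(infohash, rowid)` collapses duplicates that share an infohash while keeping infohash-less rows distinct:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE ChannelNode (infohash TEXT, title TEXT);
    CREATE VIRTUAL TABLE FtsIndex USING fts5(title, content='ChannelNode');
""")
rows = [("abc", "big buck bunny"), ("abc", "big buck bunny copy"), (None, "bunny sintel")]
conn.executemany("INSERT INTO ChannelNode VALUES (?, ?)", rows)
# Keep the external-content FTS index in sync with the source table
conn.execute("INSERT INTO FtsIndex(rowid, title) SELECT rowid, title FROM ChannelNode")

matches = conn.execute("""
    SELECT rowid FROM ChannelNode
    WHERE rowid IN (SELECT rowid FROM FtsIndex WHERE FtsIndex MATCH ?)
    GROUP BY coalesce(infohash, rowid)
""", ("bunny",)).fetchall()
print(matches)  # two rows: one per distinct infohash ("abc" is deduplicated)
```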
@@ -639,7 +647,7 @@ def get_entries_query(

if cls is None:
cls = self.ChannelNode
pony_query = self.search_keyword(txt_filter, lim=1000) if txt_filter else left_join(g for g in cls)
pony_query = self.search_keyword(txt_filter) if txt_filter else left_join(g for g in cls)
infohash_set = infohash_set or ({infohash} if infohash else None)
if popular:
if metadata_type != REGULAR_TORRENT:
@@ -728,10 +736,49 @@ def get_entries_query(

if sort_by is None:
if txt_filter:
# pylint: disable=W0105
"""
The following call of `sort_by` produces an ORDER BY expression that looks like this:
ORDER BY
case when "g"."metadata_type" = $CHANNEL_TORRENT then 1
when "g"."metadata_type" = $COLLECTION_NODE then 2
else 3 end,
search_rank(
$QUERY_STRING,
g.title,
torrentstate.seeders,
torrentstate.leechers,
$CURRENT_TIME - strftime('%s', g.torrent_date)
) DESC,
"torrentstate"."last_check" DESC,
So, the channel torrents and channel folders are always on top if they are not filtered out.
Then regular torrents are selected in order of their relevance according to a search_rank() result.
If two torrents have the same search rank, they are ordered by the last time they were checked.
The search_rank() function is called directly from the SQLite query, but is implemented in Python,
it is actually the torrent_rank() function from core/utilities/search_utils.py, wrapped with
keep_exception() to return possible exception from SQLite to Python.
The search_rank() function receives the following arguments:
- the current query string (like "Big Buck Bunny");
- the title of the current torrent;
- the number of seeders;
- the number of leechers;
- the number of seconds since the torrent's creation time.
"""

pony_query = pony_query.sort_by(
f"""
(1 if g.metadata_type == {CHANNEL_TORRENT} else 2 if g.metadata_type == {COLLECTION_NODE} else 3),
desc(g.health.seeders), desc(g.health.leechers)
raw_sql('''search_rank(
$txt_filter, g.title, torrentstate.seeders, torrentstate.leechers,
$int(time()) - strftime('%s', g.torrent_date)
) DESC'''),
desc(g.health.last_check) # just to trigger the TorrentState table inclusion into the left join
"""
)
elif popular:
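To make the ranking step concrete, here is a hypothetical scoring function with the same five-argument signature that `search_rank` is registered with; the real implementation is `torrent_rank()` in core/utilities/search_utils.py, and the weights below are illustrative only:

```python
def search_rank_sketch(query: str, title: str, seeders: int, leechers: int, age_seconds: float) -> float:
    # Reward title-term coverage and a healthy swarm; decay the score with age.
    terms = query.lower().split()
    words = set(title.lower().split())
    coverage = sum(term in words for term in terms) / max(len(terms), 1)
    swarm = (seeders or 0) + 0.5 * (leechers or 0)
    year_seconds = 365 * 24 * 3600
    freshness = 1.0 / (1.0 + max(age_seconds or 0, 0) / year_seconds)
    return coverage * (1.0 + swarm / (100.0 + swarm)) * (0.5 + 0.5 * freshness)


# A title covering every query term outranks a partial match, even one with a bigger swarm:
print(search_rank_sketch("big buck bunny", "Big Buck Bunny 1080p", 50, 10, 3600))  # ~1.35
print(search_rank_sketch("big buck bunny", "bunny compilation", 500, 10, 3600))    # ~0.61
```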
@@ -1,8 +1,10 @@
import json
import struct
import time
from asyncio import Future
from binascii import unhexlify
from typing import List, Optional, Set
from itertools import count
from typing import Any, Dict, List, Optional, Set

from ipv8.lazy_community import lazy_wrapper
from ipv8.messaging.lazy_payload import VariablePayload, vp_compile
@@ -26,7 +28,7 @@
BINARY_FIELDS = ("infohash", "channel_pk")


def sanitize_query(query_dict, cap=100):
def sanitize_query(query_dict: Dict[str, Any], cap=100) -> Dict[str, Any]:
sanitized_dict = dict(query_dict)

# We impose a cap on the max number of returned entries to prevent DDoS-like attacks
@@ -151,6 +153,8 @@ def __init__(self, my_peer, endpoint, network,
self.add_message_handler(SelectResponsePayload, self.on_remote_select_response)

self.eva = EVAProtocol(self, self.on_receive, self.on_send_complete, self.on_error)
self.remote_queries_in_progress = 0
self.next_remote_query_num = count().__next__ # generator of sequential numbers, for logging & debug purposes

async def on_receive(self, result: TransferResult):
self.logger.debug(f"EVA data received: peer {hexlify(result.peer.mid)}, info {result.info}")
@@ -183,16 +187,32 @@ def send_remote_select(self, peer, processing_callback=None, force_eva_response=
self.ez_send(peer, RemoteSelectPayload(*args))
return request

async def process_rpc_query(self, json_bytes: bytes):
def should_limit_rate_for_query(self, sanitized_parameters: Dict[str, Any]) -> bool:
return 'txt_filter' in sanitized_parameters

async def process_rpc_query_rate_limited(self, sanitized_parameters: Dict[str, Any]) -> List:
query_num = self.next_remote_query_num()
if self.remote_queries_in_progress and self.should_limit_rate_for_query(sanitized_parameters):
self.logger.warning(f'Ignore remote query {query_num} as another one is already being processed. '
f'The ignored query: {sanitized_parameters}')
return []

self.logger.info(f'Process remote query {query_num}: {sanitized_parameters}')
self.remote_queries_in_progress += 1
t = time.time()
try:
return await self.process_rpc_query(sanitized_parameters)
finally:
self.remote_queries_in_progress -= 1
self.logger.info(f'Remote query {query_num} processed in {time.time()-t} seconds: {sanitized_parameters}')

async def process_rpc_query(self, sanitized_parameters: Dict[str, Any]) -> List:
"""
Retrieve the result of a database query from a third party. The parameters are already decoded from
raw JSON bytes and sanitized (see `parse_parameters`).
:raises TypeError: if the parameters contain invalid keys.
:raises ValueError: if the parameters contain invalid values.
:raises pony.orm.dbapiprovider.OperationalError: if an illegal query was performed.
"""
parameters = json.loads(json_bytes)
sanitized_parameters = sanitize_query(parameters, self.rqc_settings.max_response_size)

# tags should be extracted because `get_entries_threaded` doesn't expect them as a parameter
tags = sanitized_parameters.pop('tags', None)

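A toy, standalone demonstration of the drop-on-overlap pattern implemented by `process_rpc_query_rate_limited`: a counter of in-flight queries, with new rate-limited work skipped while anything is still in progress (all names below are hypothetical):

```python
import asyncio


class Limiter:
    def __init__(self):
        self.in_progress = 0

    async def run_limited(self, name, delay):
        if self.in_progress:
            print(f"dropped {name}")  # mirrors the 'Ignore remote query' warning above
            return []
        self.in_progress += 1
        try:
            await asyncio.sleep(delay)  # stands in for the database work
            return [f"results of {name}"]
        finally:
            self.in_progress -= 1


async def main():
    limiter = Limiter()
    first, second = await asyncio.gather(
        limiter.run_limited("query 1", 0.1),
        limiter.run_limited("query 2", 0.1),
    )
    print(first, second)  # exactly one of the two queries returns results


asyncio.run(main())
```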
@@ -237,9 +257,14 @@ async def on_remote_select_eva(self, peer, request_payload):
async def on_remote_select(self, peer, request_payload):
await self._on_remote_select_basic(peer, request_payload)

def parse_parameters(self, json_bytes: bytes) -> Dict[str, Any]:
parameters = json.loads(json_bytes)
return sanitize_query(parameters, self.rqc_settings.max_response_size)

async def _on_remote_select_basic(self, peer, request_payload, force_eva_response=False):
try:
db_results = await self.process_rpc_query(request_payload.json)
sanitized_parameters = self.parse_parameters(request_payload.json)
db_results = await self.process_rpc_query_rate_limited(sanitized_parameters)

# When we send our response to a host, we open a window of opportunity
# for it to push back updates
@@ -5,7 +5,16 @@ class RemoteQueryCommunitySettings(TriblerConfigSection):
minimal_blob_size: int = 200
maximum_payload_size: int = 1300
max_entries: int = maximum_payload_size // minimal_blob_size
max_query_peers: int = 5

# The next option is currently used by GigaChannelCommunity only. We should probably move it to the
# GigaChannelCommunity settings or to a dedicated search-related section. The value of this option corresponds
# to the TARGET_PEERS_NUMBER of src/tribler/core/components/gigachannel/community/sync_strategy.py, that is, to
# the number of peers that GigaChannelCommunity will have after a long run (initially, the number of peers in
# GigaChannelCommunity can rise to several hundred due to DiscoveryBooster). The number of parallel remote
# requests should be neither too small (so that results from various remote peers remain diverse) nor too big
# (to avoid flooding the network with an excessively high number of queries). TARGET_PEERS_NUMBER looks like
# a good middle ground here.
max_query_peers: int = 20

max_response_size: int = 100 # Max number of entries returned by SQL query
max_channel_query_back: int = 4 # Max number of entries to query back on receiving an unknown channel
push_updates_back_enabled = True
@@ -1,11 +1,10 @@
import random
import string
import time
from asyncio import sleep
from binascii import unhexlify
from json import dumps
from operator import attrgetter
from os import urandom
from time import time
from unittest.mock import Mock, patch

from ipv8.keyvault.crypto import default_eccrypto
@@ -112,7 +111,7 @@ async def test_remote_select(self):
channel=channel,
seeders=2 * i,
leechers=i,
last_check=int(time()) + i,
last_check=int(time.time()) + i,
)

kwargs_dict = {"txt_filter": "ubuntu*", "metadata_type": [REGULAR_TORRENT]}
@@ -345,7 +344,7 @@ async def test_process_rpc_query_match_many(self):
channel = self.channel_metadata(0).create_channel("a channel", "")
add_random_torrent(self.torrent_metadata(0), name="a torrent", channel=channel)

results = await self.overlay(0).process_rpc_query(dumps({}))
results = await self.overlay(0).process_rpc_query({})
self.assertEqual(2, len(results))

channel_md, torrent_md = results if isinstance(results[0], self.channel_metadata(0)) else results[::-1]
@@ -359,7 +358,7 @@ async def test_process_rpc_query_match_one(self):
with db_session:
self.channel_metadata(0).create_channel("a channel", "")

results = await self.overlay(0).process_rpc_query(dumps({}))
results = await self.overlay(0).process_rpc_query({})
self.assertEqual(1, len(results))

(channel_md,) = results
@@ -369,22 +368,22 @@ async def test_process_rpc_query_match_none(self):
"""
Check if a correct query with no match in our database returns no result.
"""
results = await self.overlay(0).process_rpc_query(dumps({}))
results = await self.overlay(0).process_rpc_query({})
self.assertEqual(0, len(results))

async def test_process_rpc_query_match_empty_json(self):
def test_parse_parameters_match_empty_json(self):
"""
Check if parsing an empty request causes a ValueError (JSONDecodeError) to be raised.
"""
with self.assertRaises(ValueError):
await self.overlay(0).process_rpc_query(b'')
self.overlay(0).parse_parameters(b'')

async def test_process_rpc_query_match_illegal_json(self):
def test_parse_parameters_match_illegal_json(self):
"""
Check if parsing a request with illegal JSON causes a UnicodeDecodeError to be raised.
"""
with self.assertRaises(UnicodeDecodeError):
await self.overlay(0).process_rpc_query(b'{"akey":\x80}')
self.overlay(0).parse_parameters(b'{"akey":\x80}')

async def test_process_rpc_query_match_invalid_json(self):
"""
@@ -394,21 +393,24 @@ async def test_process_rpc_query_match_invalid_json(self):
self.channel_metadata(0).create_channel("a channel", "")
query = b'{"id_":' + b'\x31' * 200 + b'}'
with self.assertRaises(ValueError):
await self.overlay(0).process_rpc_query(query)
parameters = self.overlay(0).parse_parameters(query)
await self.overlay(0).process_rpc_query(parameters)

async def test_process_rpc_query_match_invalid_key(self):
"""
Check if processing a request with an invalid key causes a TypeError to be raised.
"""
with self.assertRaises(TypeError):
await self.overlay(0).process_rpc_query(b'{"bla":":("}')
parameters = self.overlay(0).parse_parameters(b'{"bla":":("}')
await self.overlay(0).process_rpc_query(parameters)

async def test_process_rpc_query_no_column(self):
"""
Check if processing a request with no database columns causes an OperationalError.
"""
with self.assertRaises(OperationalError):
await self.overlay(0).process_rpc_query(b'{"txt_filter":{"key":"bla"}}')
parameters = self.overlay(0).parse_parameters(b'{"txt_filter":{"key":"bla"}}')
await self.overlay(0).process_rpc_query(parameters)

async def test_remote_query_big_response(self):

@@ -574,3 +576,45 @@ async def test_remote_select_force_eva(self):
await self.deliver_messages(timeout=0.5)

self.nodes[1].overlay.eva.send_binary.assert_called_once()

async def test_multiple_parallel_request(self):
peer_a = self.nodes[0].my_peer
a = self.nodes[0].overlay
b = self.nodes[1].overlay

# Peer A has two torrents "foo" and "bar"
with db_session:
add_random_torrent(a.mds.TorrentMetadata, name="foo")
add_random_torrent(a.mds.TorrentMetadata, name="bar")

# Peer B sends two parallel full-text search queries, only one of them should be processed
callback1 = Mock()
kwargs1 = {"txt_filter": "foo", "metadata_type": [REGULAR_TORRENT]}
b.send_remote_select(peer_a, **kwargs1, processing_callback=callback1)

callback2 = Mock()
kwargs2 = {"txt_filter": "bar", "metadata_type": [REGULAR_TORRENT]}
b.send_remote_select(peer_a, **kwargs2, processing_callback=callback2)

original_get_entries = MetadataStore.get_entries
# Add a delay to ensure that the first query is still being processed when the second one arrives
# (the mds.get_entries() method is a synchronous one and is called from a worker thread)

def slow_get_entries(self, *args, **kwargs):
time.sleep(0.1)
return original_get_entries(self, *args, **kwargs)

with patch.object(a, 'logger') as logger, patch.object(MetadataStore, 'get_entries', slow_get_entries):
await self.deliver_messages(timeout=0.5)

torrents1 = list(b.mds.get_entries(**kwargs1))
torrents2 = list(b.mds.get_entries(**kwargs2))

# Both remote queries should return results to the peer B...
assert callback1.called and callback2.called
# ...but one of them should return an empty list, as the database query was not actually executed
assert bool(torrents1) != bool(torrents2)

# Check that on peer A there is exactly one warning about an ignored remote query
warnings = [call.args[0] for call in logger.warning.call_args_list]
assert len([msg for msg in warnings if msg.startswith('Ignore remote query')]) == 1