Skip to content

Commit

Permalink
Fix search
Browse files Browse the repository at this point in the history
Remote search is fixed
Content filter is fixed
  • Loading branch information
drew2a committed Apr 25, 2024
1 parent af1d880 commit 6381b91
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 116 deletions.
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from binascii import hexlify, unhexlify
from binascii import hexlify

Check notice on line 1 in src/tribler/core/components/content_discovery/restapi/search_endpoint.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

src/tribler/core/components/content_discovery/restapi/search_endpoint.py#L1

Similar lines in 2 files

from aiohttp import web
from aiohttp_apispec import docs, querystring_schema
from ipv8.REST.schema import schema
from marshmallow.fields import List, String

from ipv8.REST.schema import schema
from tribler.core.components.content_discovery.community.content_discovery_community import ContentDiscoveryCommunity
from tribler.core.components.content_discovery.restapi.schema import RemoteQueryParameters
from tribler.core.components.database.restapi.database_endpoint import DatabaseEndpoint
from tribler.core.components.restapi.rest.rest_endpoint import HTTP_BAD_REQUEST, MAX_REQUEST_SIZE, RESTEndpoint, \
RESTResponse
from tribler.core.utilities.utilities import froze_it
from tribler.core.utilities.utilities import froze_it, to_fts_query


@froze_it
Expand All @@ -31,14 +32,7 @@ def setup_routes(self):

@classmethod
def sanitize_parameters(cls, parameters):
sanitized = dict(parameters)
if "max_rowid" in parameters:
sanitized["max_rowid"] = int(parameters["max_rowid"])
if "channel_pk" in parameters:
sanitized["channel_pk"] = unhexlify(parameters["channel_pk"])
if "origin_id" in parameters:
sanitized["origin_id"] = int(parameters["origin_id"])
return sanitized
return DatabaseEndpoint.sanitize_parameters(parameters)

@docs(
tags=['Metadata'],
Expand All @@ -58,14 +52,17 @@ def sanitize_parameters(cls, parameters):
)
@querystring_schema(RemoteQueryParameters)
async def remote_search(self, request):
self._logger.info('Create remote search request')
# Query remote results from the GigaChannel Community.
# Results are returned over the Events endpoint.
try:
sanitized = self.sanitize_parameters(request.query)
except (ValueError, KeyError) as e:
return RESTResponse({"error": f"Error processing request parameters: {e}"}, status=HTTP_BAD_REQUEST)
query = request.query.get('fts_text')
if t_filter := request.query.get('filter'):
query += f' {t_filter}'
fts = to_fts_query(query)
sanitized['txt_filter'] = fts
self._logger.info(f'Parameters: {sanitized}')
self._logger.info(f'FTS: {fts}')

request_uuid, peers_list = self.popularity_community.send_search_request(**sanitized)
peers_mid_list = [hexlify(p.mid).decode() for p in peers_list]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,17 @@ def mock_send(**kwargs):
search_txt = "foo"
await do_request(
rest_api,
f'search/remote?txt_filter={search_txt}&max_rowid=1',
'search/remote',
params={
'fts_text': search_txt,
'filter': 'bar',
'max_rowid': 1
},
request_type="PUT",
expected_code=200,
expected_json={"request_uuid": str(request_uuid), "peers": peers},
)
assert sent['txt_filter'] == search_txt
assert sent['txt_filter'] == f'"{search_txt}" "bar"'
sent.clear()

# Test querying channel data by public key, e.g. for channel preview purposes
Expand Down
13 changes: 10 additions & 3 deletions src/tribler/core/components/database/restapi/database_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from tribler.core.components.restapi.rest.rest_endpoint import MAX_REQUEST_SIZE, RESTEndpoint, RESTResponse
from tribler.core.components.torrent_checker.torrent_checker.torrent_checker import TorrentChecker
from tribler.core.utilities.pony_utils import run_threaded
from tribler.core.utilities.utilities import froze_it, parse_bool
from tribler.core.utilities.utilities import froze_it, parse_bool, to_fts_query

TORRENT_CHECK_TIMEOUT = 20
SNIPPETS_TO_SHOW = 3 # The number of snippets we return from the search results
Expand Down Expand Up @@ -86,10 +86,10 @@ def sanitize_parameters(cls, parameters):
"last": int(parameters.get('last', 50)),
"sort_by": json2pony_columns.get(parameters.get('sort_by')),
"sort_desc": parse_bool(parameters.get('sort_desc', True)),
"txt_filter": parameters.get('txt_filter'),
"hide_xxx": parse_bool(parameters.get('hide_xxx', False)),
"category": parameters.get('category'),
}

if 'tags' in parameters:
sanitized['tags'] = parameters.getall('tags')
if "max_rowid" in parameters:
Expand Down Expand Up @@ -192,7 +192,8 @@ async def get_popular_torrents(self, request):
sanitized = self.sanitize_parameters(request.query)
sanitized["metadata_type"] = REGULAR_TORRENT
sanitized["popular"] = True

if t_filter := request.query.get('filter'):
sanitized["txt_filter"] = t_filter
with db_session:
contents = self.mds.get_entries(**sanitized)
contents_list = []
Expand Down Expand Up @@ -236,6 +237,12 @@ async def local_search(self, request):
return RESTResponse({"error": "Error processing request parameters"}, status=HTTP_BAD_REQUEST)

include_total = request.query.get('include_total', '')
query = request.query.get('fts_text')
if t_filter := request.query.get('filter'):
query += f' {t_filter}'
fts = to_fts_query(query)
sanitized['txt_filter'] = fts
self._logger.info(f'FTS: {fts}')

mds: MetadataStore = self.mds

Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,41 @@
import os
from typing import List, Set
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from time import time
from typing import Set
from unittest.mock import MagicMock, Mock, patch

import pytest
from pony.orm import db_session

from tribler.core.components.database.category_filter.family_filter import default_xxx_filter
from tribler.core.components.database.db.layers.knowledge_data_access_layer import KnowledgeDataAccessLayer
from tribler.core.components.database.db.serialization import REGULAR_TORRENT, SNIPPET
from tribler.core.components.database.db.serialization import REGULAR_TORRENT
from tribler.core.components.database.restapi.database_endpoint import DatabaseEndpoint, TORRENT_CHECK_TIMEOUT
from tribler.core.components.restapi.rest.base_api_test import do_request
from tribler.core.components.torrent_checker.torrent_checker.torrent_checker import TorrentChecker
from tribler.core.config.tribler_config import TriblerConfig
from tribler.core.utilities.unicode import hexlify
from tribler.core.utilities.utilities import random_infohash, to_fts_query

LOCAL_ENDPOINT = 'metadata/search/local'
POPULAR_ENDPOINT = "metadata/torrents/popular"


@pytest.fixture(name="needle_in_haystack_mds")
def fixture_needle_in_haystack_mds(metadata_store):
num_hay = 100

def _put_torrent_with_seeders(name):
infohash = random_infohash()
state = metadata_store.TorrentState(infohash=infohash, seeders=100, leechers=100, has_data=1,
last_check=int(time()))
metadata_store.TorrentMetadata(title=name, infohash=infohash, public_key=b'', health=state,
metadata_type=REGULAR_TORRENT)

with db_session:
for x in range(0, num_hay):
metadata_store.TorrentMetadata(title='hay ' + str(x), infohash=random_infohash(), public_key=b'')
metadata_store.TorrentMetadata(title='needle', infohash=random_infohash(), public_key=b'')
metadata_store.TorrentMetadata(title='needle2', infohash=random_infohash(), public_key=b'')
_put_torrent_with_seeders('needle 1')
_put_torrent_with_seeders('needle 2')
return metadata_store


Expand Down Expand Up @@ -83,33 +95,18 @@ async def test_check_torrent_query(rest_api):
await do_request(rest_api, f"metadata/torrents/{infohash}/health?timeout=wrong_value&refresh=1", expected_code=400)


@patch.object(DatabaseEndpoint, 'add_download_progress_to_metadata_list', Mock())
async def test_get_popular_torrents(rest_api, endpoint, metadata_store):
"""
Test that the endpoint responds with its known entries.
"""
fake_entry = {
"name": "Torrent Name",
"category": "",
"infohash": "ab" * 20,
"size": 1,
"num_seeders": 1234,
"num_leechers": 123,
"last_tracker_check": 17000000,
"created": 15000000,
"tag_processor_version": 1,
"type": REGULAR_TORRENT,
"id": 0,
"origin_id": 0,
"public_key": "ab" * 64,
"status": 2,
"statements": []
}
fake_state = Mock(return_value=Mock(get_progress=Mock(return_value=0.5)))
metadata_store.get_entries = Mock(return_value=[Mock(to_simple_dict=Mock(return_value=fake_entry.copy()))])
endpoint.download_manager.get_download = Mock(return_value=Mock(get_state=fake_state))
response = await do_request(rest_api, "metadata/torrents/popular")
""" Test that the endpoint responds with its known entries."""
response = await do_request(rest_api, POPULAR_ENDPOINT)
assert len(response['results']) == 2 # as there are two torrents with seeders and leechers

assert response == {'results': [{**fake_entry, **{"progress": 0.5}}], 'first': 1, 'last': 50}

@patch.object(DatabaseEndpoint, 'add_download_progress_to_metadata_list', Mock())
async def test_get_popular_torrents_with_filter(rest_api, endpoint, metadata_store):
""" Test that the endpoint responds with its known entries with a filter."""
response = await do_request(rest_api, POPULAR_ENDPOINT, params={'filter': '2'})
assert response['results'][0]['name'] == 'needle 2'


async def test_get_popular_torrents_filter_xxx(rest_api, endpoint, metadata_store):
Expand All @@ -136,7 +133,7 @@ async def test_get_popular_torrents_filter_xxx(rest_api, endpoint, metadata_stor
fake_state = Mock(return_value=Mock(get_progress=Mock(return_value=0.5)))
metadata_store.get_entries = Mock(return_value=[Mock(to_simple_dict=Mock(return_value=fake_entry.copy()))])
endpoint.download_manager.get_download = Mock(return_value=Mock(get_state=fake_state))
response = await do_request(rest_api, "metadata/torrents/popular", params={"hide_xxx": 1})
response = await do_request(rest_api, POPULAR_ENDPOINT, params={"hide_xxx": 1})

fake_entry["statements"] = [] # Should be stripped
assert response == {'results': [{**fake_entry, **{"progress": 0.5}}], 'first': 1, 'last': 50}
Expand Down Expand Up @@ -167,7 +164,7 @@ async def test_get_popular_torrents_no_db(rest_api, endpoint, metadata_store):
metadata_store.get_entries = Mock(return_value=[Mock(to_simple_dict=Mock(return_value=fake_entry.copy()))])
endpoint.download_manager.get_download = Mock(return_value=Mock(get_state=fake_state))
endpoint.tribler_db = None
response = await do_request(rest_api, "metadata/torrents/popular")
response = await do_request(rest_api, POPULAR_ENDPOINT)

assert response == {'results': [{**fake_entry, **{"progress": 0.5}}], 'first': 1, 'last': 50}

Expand All @@ -176,23 +173,17 @@ async def test_search(rest_api):
"""
Test a search query that should return a few new type channels
"""
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle', expected_code=200)
assert len(parsed["results"]) == 2

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle', expected_code=200)
assert len(parsed["results"]) == 1

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=hay', expected_code=200)
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=hay', expected_code=200)
assert len(parsed["results"]) == 50

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&type=torrent', expected_code=200)
assert parsed["results"][0]['name'] == 'needle'

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&sort_by=name', expected_code=200)
assert len(parsed["results"]) == 1
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle&type=torrent', expected_code=200)
assert parsed["results"][0]['name'] == 'needle 2'

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle%2A&sort_by=name&sort_desc=1',
expected_code=200)
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle&sort_by=name', expected_code=200)
assert len(parsed["results"]) == 2
assert parsed["results"][0]['name'] == "needle2"


async def test_search_by_tags(rest_api):
Expand All @@ -202,52 +193,65 @@ def mocked_get_subjects_intersection(*_, objects: Set[str], **__):
return {hexlify(os.urandom(20))}

with patch.object(KnowledgeDataAccessLayer, 'get_subjects_intersection', wraps=mocked_get_subjects_intersection):
parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&tags=real_tag', expected_code=200)
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle&tags=real_tag', expected_code=200)

assert len(parsed["results"]) == 0

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&tags=missed_tag',
parsed = await do_request(rest_api, 'metadata/search/local?fts_text=needle&tags=missed_tag',
expected_code=200)
assert len(parsed["results"]) == 1
assert len(parsed["results"]) == 2


async def test_search_with_include_total_and_max_rowid(rest_api):
"""
Test search queries with include_total and max_rowid options
"""

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle', expected_code=200)
assert len(parsed["results"]) == 1
parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'needle'})
assert len(parsed["results"]) == 2
assert "total" not in parsed
assert "max_rowid" not in parsed

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&include_total=1', expected_code=200)
assert parsed["total"] == 1
parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'needle', 'include_total': 1})
assert parsed["total"] == 2
assert parsed["max_rowid"] == 102

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=hay&include_total=1', expected_code=200)
parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'hay', 'include_total': 1})
assert parsed["total"] == 100
assert parsed["max_rowid"] == 102

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=hay', expected_code=200)
parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'hay'})
assert len(parsed["results"]) == 50

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=hay&max_rowid=0', expected_code=200)
parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'needle', 'max_rowid': 0})
assert len(parsed["results"]) == 0

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=hay&max_rowid=19', expected_code=200)
parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'hay', 'max_rowid': 19})
assert len(parsed["results"]) == 19

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&sort_by=name', expected_code=200)
assert len(parsed["results"]) == 1
parsed = await do_request(rest_api, LOCAL_ENDPOINT, params={'fts_text': 'needle', 'sort_by': 'name'})
assert len(parsed["results"]) == 2

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&sort_by=name&max_rowid=20',
expected_code=200)
parsed = await do_request(rest_api, LOCAL_ENDPOINT,
params={'fts_text': 'needle', 'sort_by': 'name', 'max_rowid': 20})
assert len(parsed["results"]) == 0

parsed = await do_request(rest_api, 'metadata/search/local?txt_filter=needle&sort_by=name&max_rowid=200',
expected_code=200)
assert len(parsed["results"]) == 1
parsed = await do_request(rest_api, LOCAL_ENDPOINT,
params={'fts_text': 'needle', 'sort_by': 'name', 'max_rowid': 200})
assert len(parsed["results"]) == 2


async def test_search_with_filter(rest_api):
""" Test search queries with a filter """
response = await do_request(
rest_api,
'metadata/search/local',
params={
'fts_text': 'needle',
'filter': '1'
},
expected_code=200)
assert response["results"][0]['name'] == 'needle 1'


async def test_completions_no_query(rest_api):
Expand Down Expand Up @@ -282,11 +286,10 @@ async def test_search_with_space(rest_api, metadata_store):
ss2 = to_fts_query(s2)
assert ss2 == s2

parsed = await do_request(rest_api, f'metadata/search/local?txt_filter={s1}', expected_code=200)
parsed = await do_request(rest_api, f'metadata/search/local?fts_text={s1}', expected_code=200)
results = {item["name"] for item in parsed["results"]}
assert results == {'abc', 'abc.def', 'abc def', 'abc defxyz'}

parsed = await do_request(rest_api, f'metadata/search/local?txt_filter={s2}', expected_code=200)
parsed = await do_request(rest_api, f'metadata/search/local?fts_text={s2}', expected_code=200)
results = {item["name"] for item in parsed["results"]}
assert results == {'abc.def', 'abc def'} # but not 'abcxyz def'

10 changes: 5 additions & 5 deletions src/tribler/gui/widgets/search_results_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@


class SearchResultsModel(ChannelContentModel):
def __init__(self, original_query, **kwargs):
self.original_query = original_query
def __init__(self, **kwargs):
self.remote_results = {}
title = self.format_title()
title = self.format_title(**kwargs)
super().__init__(channel_info={"name": title}, **kwargs)
self.remote_results_received = False
self.postponed_remote_results = []
self.highlight_remote_results = True
self.sort_by_rank = True
self.original_search_results = []

def format_title(self):
q = self.original_query
def format_title(self,**kwargs):
original_query = kwargs.get('original_query', '')
q = original_query
q = q if len(q) < 50 else q[:50] + '...'
return f'Search results for {q}'

Expand Down
Loading

0 comments on commit 6381b91

Please sign in to comment.