Feature/add local search for tags #6617

Merged: 5 commits, Dec 16, 2021
40 changes: 38 additions & 2 deletions src/tribler-common/tribler_common/tests/test_utils.py
@@ -2,10 +2,10 @@
from unittest.mock import MagicMock, patch

from tribler_common.patch_import import patch_import
from tribler_common.utilities import show_system_popup, to_fts_query, uri_to_path
from tribler_common.utilities import Query, extract_tags, parse_query, show_system_popup, to_fts_query, uri_to_path

# pylint: disable=import-outside-toplevel, import-error

# fmt: off

def test_uri_to_path():
path = Path(__file__).parent / "bla%20foo.bar"
@@ -22,6 +22,42 @@ def test_to_fts_query():
assert to_fts_query('[abc, def]: xyz?!') == '"abc" "def" "xyz"*'


def test_extract_tags():
assert extract_tags('') == (set(), '')
assert extract_tags('text') == (set(), 'text')
assert extract_tags('#') == (set(), '#')
assert extract_tags('# ') == (set(), '# ')
assert extract_tags('#t ') == (set(), '#t ')
assert extract_tags('#' + 't' * 51) == (set(), '#' + 't' * 51)
assert extract_tags('####') == (set(), '####')

assert extract_tags('#tag') == ({'tag'}, '')
assert extract_tags('a #tag in the middle') == ({'tag'}, 'a in the middle')
assert extract_tags('at the end of the query #tag') == ({'tag'}, 'at the end of the query ')
assert extract_tags('multiple tags: #tag1 #tag2#tag3') == ({'tag1', 'tag2', 'tag3'}, 'multiple tags: ')
assert extract_tags('#tag_with_underscores #tag-with-dashes') == ({'tag_with_underscores', 'tag-with-dashes'}, ' ')


def test_parse_query():
assert parse_query('') == Query(original_query='')

actual = parse_query('#tag1 #tag2')
expected = Query(original_query='#tag1 #tag2', tags={'tag1', 'tag2'}, fts_text='')
assert actual == expected

actual = parse_query('query without tags')
expected = Query(original_query='query without tags',
tags=set(),
fts_text='query without tags')
assert actual == expected

actual = parse_query('query with #tag1 and #tag2')
expected = Query(original_query='query with #tag1 and #tag2',
tags={'tag1', 'tag2'},
fts_text='query with and')
assert actual == expected


@patch_import(modules=['win32api'], MessageBox=MagicMock())
@patch('platform.system', new=MagicMock(return_value='Windows'))
@patch('tribler_common.utilities.print', new=MagicMock)
41 changes: 41 additions & 0 deletions src/tribler-common/tribler_common/utilities.py
@@ -1,7 +1,10 @@
import itertools
import os
import platform
import re
import sys
from dataclasses import dataclass, field
from typing import Set, Tuple
from urllib.parse import urlparse
from urllib.request import url2pathname

@@ -27,6 +30,44 @@ def uri_to_path(uri):


fts_query_re = re.compile(r'\w+', re.UNICODE)
tags_re = re.compile(r'#[^\s^#]{3,50}(?=[#\s]|$)')


@dataclass
class Query:
original_query: str
tags: Set[str] = field(default_factory=set)
fts_text: str = ''


def parse_query(query: str) -> Query:
"""
The query structure:
query = [tag1][tag2] text
^ ^
tags fts query
"""
if not query:
return Query(original_query=query)

tags, remaining_text = extract_tags(query)
return Query(original_query=query, tags=tags, fts_text=remaining_text.strip())


def extract_tags(text: str) -> Tuple[Set[str], str]:
if not text:
return set(), ''

tags = set()
positions = [0]

for m in tags_re.finditer(text):
tags.add(m.group(0)[1:])
positions.extend(itertools.chain.from_iterable(m.regs))
positions.append(len(text))

remaining_text = ''.join(text[positions[i] : positions[i + 1]] for i in range(0, len(positions) - 1, 2))
return tags, remaining_text


def to_fts_query(text):
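A quick illustrative sketch (not part of the diff) of how the new helpers behave, assuming tribler_common.utilities is importable as in the tests above:

from tribler_common.utilities import parse_query

# '#linux' and '#distro' become tags; the rest is kept as the full-text-search string.
query = parse_query('ubuntu iso #linux #distro')
assert query.tags == {'linux', 'distro'}
assert query.fts_text == 'ubuntu iso'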
(Diff for another file; filename not shown in this view)
@@ -610,6 +610,7 @@ def get_entries_query(
category=None,
attribute_ranges=None,
infohash=None,
infohash_set=None,
id_=None,
complete_channel=None,
self_checked_torrent=None,
@@ -626,7 +627,7 @@
if cls is None:
cls = self.ChannelNode
pony_query = self.search_keyword(txt_filter, lim=1000) if txt_filter else left_join(g for g in cls)

infohash_set = infohash_set or ({infohash} if infohash else None)
if popular:
if metadata_type != REGULAR_TORRENT:
raise TypeError('With `popular=True`, only `metadata_type=REGULAR_TORRENT` is allowed')
@@ -678,7 +679,7 @@ def get_entries_query(
pony_query = pony_query.where(lambda g: g.status != TODELETE) if exclude_deleted else pony_query
pony_query = pony_query.where(lambda g: g.xxx == 0) if hide_xxx else pony_query
pony_query = pony_query.where(lambda g: g.status != LEGACY_ENTRY) if exclude_legacy else pony_query
pony_query = pony_query.where(lambda g: g.infohash == infohash) if infohash else pony_query
pony_query = pony_query.where(lambda g: g.infohash in infohash_set) if infohash_set else pony_query
pony_query = (
pony_query.where(lambda g: g.health.self_checked == self_checked_torrent)
if self_checked_torrent is not None
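Read in isolation, the normalisation added above promotes a single infohash to a one-element set unless an explicit infohash_set is already given, so the Pony filter only ever deals with sets. A standalone sketch of that precedence (illustrative only, not part of the diff):

def normalize(infohash=None, infohash_set=None):
    # An explicit set wins; otherwise a single infohash is wrapped into a one-element set.
    return infohash_set or ({infohash} if infohash else None)

assert normalize(infohash=b'aa') == {b'aa'}
assert normalize(infohash=b'aa', infohash_set={b'aa', b'bb'}) == {b'aa', b'bb'}
assert normalize() is None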
(Diff for another file; filename not shown in this view)
@@ -13,6 +13,7 @@
from tribler_core.components.metadata_store.db.orm_bindings.discrete_clock import clock
from tribler_core.components.metadata_store.db.orm_bindings.torrent_metadata import tdef_to_metadata_dict
from tribler_core.components.metadata_store.db.serialization import CHANNEL_TORRENT, REGULAR_TORRENT
from tribler_core.conftest import TEST_PERSONAL_KEY
from tribler_core.tests.tools.common import TORRENT_UBUNTU_FILE
from tribler_core.utilities.utilities import random_infohash

@@ -224,6 +225,30 @@ def test_get_autocomplete_terms_max(metadata_store):
autocomplete_terms = metadata_store.get_auto_complete_terms(".", 2)


@db_session
def test_get_entries_for_infohashes(metadata_store):
infohash1 = random_infohash()
infohash2 = random_infohash()
infohash3 = random_infohash()

metadata_store.TorrentMetadata(title='title', infohash=infohash1, size=0, sign_with=TEST_PERSONAL_KEY)
metadata_store.TorrentMetadata(title='title', infohash=infohash2, size=0, sign_with=TEST_PERSONAL_KEY)

def count(*args, **kwargs):
return len(metadata_store.get_entries_query(*args, **kwargs))

# infohash can be passed as a single object
assert count(infohash=infohash3) == 0
assert count(infohash=infohash1) == 1

# infohashes can be passed as a set
assert count(infohash_set={infohash1, infohash2}) == 2

# if both arguments are passed, only `infohash_set` is taken into account
assert count(infohash=infohash1, infohash_set={infohash1, infohash2}) == 2


@db_session
def test_get_entries(metadata_store):
"""
(Diff for another file; filename not shown in this view)
@@ -57,6 +57,8 @@ def sanitize_parameters(cls, parameters):
"category": parameters.get('category'),
"exclude_deleted": bool(int(parameters.get('exclude_deleted', 0)) > 0),
}
if 'tags' in parameters:
sanitized['tags'] = parameters.getall('tags')
if "remote" in parameters:
sanitized["remote"] = (bool(int(parameters.get('remote', 0)) > 0),)
if 'metadata_type' in parameters:
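Because tags can appear multiple times in the query string, the endpoint uses getall rather than get. A small sketch of that behaviour on aiohttp-style multidicts (assumed setup, not part of the diff):

from multidict import MultiDict

# aiohttp exposes request.query as a MultiDict, so repeated parameters
# such as ?tags=linux&tags=distro are kept as separate values.
parameters = MultiDict([('tags', 'linux'), ('tags', 'distro'), ('first', '0')])
assert parameters.getall('tags') == ['linux', 'distro']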
(Diff for another file; filename not shown in this view)
@@ -53,12 +53,10 @@ def sanitize_parameters(cls, parameters):
async def search(self, request):
try:
sanitized = self.sanitize_parameters(request.query)
tags = sanitized.pop('tags', None)
except (ValueError, KeyError):
return RESTResponse({"error": "Error processing request parameters"}, status=HTTP_BAD_REQUEST)

if not sanitized["txt_filter"]:
return RESTResponse({"error": "Filter parameter missing"}, status=HTTP_BAD_REQUEST)

include_total = request.query.get('include_total', '')

mds: MetadataStore = self.mds
@@ -75,9 +73,15 @@ def search_db():
return search_results, total, max_rowid

try:
with db_session:
if tags:
lower_tags = {tag.lower() for tag in tags}
infohash_set = self.tags_db.get_infohashes(lower_tags)
sanitized['infohash_set'] = infohash_set

search_results, total, max_rowid = await mds.run_threaded(search_db)
except Exception as e: # pylint: disable=broad-except; # pragma: no cover
self._logger.error("Error while performing DB search: %s: %s", type(e).__name__, e)
self._logger.exception("Error while performing DB search: %s: %s", type(e).__name__, e)
return RESTResponse(status=HTTP_BAD_REQUEST)

self.add_tags_to_metadata_list(search_results, hide_xxx=sanitized["hide_xxx"])
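A hedged client-side sketch of the resulting API: the search endpoint now accepts one or more tags parameters next to txt_filter and filters results by the infohashes associated with those tags. The base URL, port and API-key header below are assumptions about a local Tribler instance, not something stated in this diff:

import requests

BASE_URL = 'http://localhost:52194'          # assumed local REST API address
HEADERS = {'X-Api-Key': '<your api key>'}    # assumed API key header

response = requests.get(
    f'{BASE_URL}/search',
    params=[('txt_filter', 'ubuntu'), ('tags', 'linux'), ('tags', 'distro')],
    headers=HEADERS,
)
results = response.json().get('results', [])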
(Diff for another file; filename not shown in this view)
@@ -1,3 +1,6 @@
from typing import Set
from unittest.mock import patch

from aiohttp.web_app import Application

from pony.orm import db_session
@@ -8,9 +11,9 @@

from tribler_core.components.metadata_store.restapi.search_endpoint import SearchEndpoint
from tribler_core.components.restapi.rest.base_api_test import do_request
from tribler_core.components.tag.db.tag_db import TagDatabase
from tribler_core.utilities.utilities import random_infohash


# pylint: disable=unused-argument, redefined-outer-name


@@ -34,13 +37,6 @@ def rest_api(loop, needle_in_haystack_mds, aiohttp_client, tags_db):
return loop.run_until_complete(aiohttp_client(app))


async def test_search_no_query(rest_api):
"""
Testing whether the API returns an error 400 if no query is passed when doing a search
"""
await do_request(rest_api, 'search', expected_code=400)


async def test_search_wrong_mdtype(rest_api):
"""
Testing whether the API returns an error 400 if wrong metadata type is passed in the query
@@ -73,6 +69,20 @@ async def test_search(rest_api):
assert parsed["results"][0]['name'] == "needle2"


async def test_search_by_tags(rest_api):
def mocked_get_infohashes(tags: Set[str]):
if tags.pop() == 'missed_tag':
return None
return {b'infohash'}

with patch.object(TagDatabase, 'get_infohashes', wraps=mocked_get_infohashes):
parsed = await do_request(rest_api, 'search?txt_filter=needle&tags=real_tag', expected_code=200)
assert len(parsed["results"]) == 0

parsed = await do_request(rest_api, 'search?txt_filter=needle&tags=missed_tag', expected_code=200)
assert len(parsed["results"]) == 1


async def test_search_with_include_total_and_max_rowid(rest_api):
"""
Test search queries with include_total and max_rowid options
(Diff for another file; filename not shown in this view)
@@ -70,18 +70,17 @@ async def run(self):
log_dir = config.general.get_path_as_absolute('log_dir', config.state_dir)
metadata_store_component = await self.get_component(MetadataStoreComponent)

# fmt: off
# pylint: disable=C0301
key_component = await self.require_component(KeyComponent)
ipv8_component = await self.maybe_component(Ipv8Component)
libtorrent_component = await self.maybe_component(LibtorrentComponent)
resource_monitor_component = await self.maybe_component(ResourceMonitorComponent)
key_component = await self.require_component(KeyComponent)
ipv8_component = await self.maybe_component(Ipv8Component)
libtorrent_component = await self.maybe_component(LibtorrentComponent)
resource_monitor_component = await self.maybe_component(ResourceMonitorComponent)
bandwidth_accounting_component = await self.maybe_component(BandwidthAccountingComponent)
gigachannel_component = await self.maybe_component(GigaChannelComponent)
tag_component = await self.maybe_component(TagComponent)
tunnel_component = await self.maybe_component(TunnelsComponent)
torrent_checker_component = await self.maybe_component(TorrentCheckerComponent)
gigachannel_manager_component = await self.maybe_component(GigachannelManagerComponent)
gigachannel_component = await self.maybe_component(GigaChannelComponent)
tag_component = await self.maybe_component(TagComponent)
tunnel_component = await self.maybe_component(TunnelsComponent)
torrent_checker_component = await self.maybe_component(TorrentCheckerComponent)
gigachannel_manager_component = await self.maybe_component(GigachannelManagerComponent)

self._events_endpoint = EventsEndpoint(notifier, public_key=hexlify(key_component.primary_key.key.pk))
self.root_endpoint = RootEndpoint(middlewares=[ApiKeyMiddleware(config.api.key), error_middleware])
@@ -92,32 +91,37 @@ async def run(self):

# add endpoints
self.root_endpoint.add_endpoint('/events', self._events_endpoint)
self.maybe_add('/settings', SettingsEndpoint, config, download_manager=libtorrent_component.download_manager)
self.maybe_add('/shutdown', ShutdownEndpoint, shutdown_event.set)
self.maybe_add('/debug', DebugEndpoint, config.state_dir, log_dir, tunnel_community=tunnel_community, resource_monitor=resource_monitor_component.resource_monitor)
self.maybe_add('/bandwidth', BandwidthEndpoint, bandwidth_accounting_component.community)
self.maybe_add('/trustview', TrustViewEndpoint, bandwidth_accounting_component.database)
self.maybe_add('/downloads', DownloadsEndpoint, libtorrent_component.download_manager, metadata_store=metadata_store_component.mds, tunnel_community=tunnel_community)
self.maybe_add('/settings', SettingsEndpoint, config, download_manager=libtorrent_component.download_manager)
self.maybe_add('/shutdown', ShutdownEndpoint, shutdown_event.set)
self.maybe_add('/debug', DebugEndpoint, config.state_dir, log_dir, tunnel_community=tunnel_community,
resource_monitor=resource_monitor_component.resource_monitor)
self.maybe_add('/bandwidth', BandwidthEndpoint, bandwidth_accounting_component.community)
self.maybe_add('/trustview', TrustViewEndpoint, bandwidth_accounting_component.database)
self.maybe_add('/downloads', DownloadsEndpoint, libtorrent_component.download_manager,
metadata_store=metadata_store_component.mds, tunnel_community=tunnel_community)
self.maybe_add('/createtorrent', CreateTorrentEndpoint, libtorrent_component.download_manager)
self.maybe_add('/statistics', StatisticsEndpoint, ipv8=ipv8_component.ipv8, metadata_store=metadata_store_component.mds)
self.maybe_add('/libtorrent', LibTorrentEndpoint, libtorrent_component.download_manager)
self.maybe_add('/torrentinfo', TorrentInfoEndpoint, libtorrent_component.download_manager)
self.maybe_add('/metadata', MetadataEndpoint, torrent_checker, metadata_store_component.mds, tags_db=tag_component.tags_db)
self.maybe_add('/channels', ChannelsEndpoint, libtorrent_component.download_manager, gigachannel_manager, gigachannel_component.community, metadata_store_component.mds, tags_db=tag_component.tags_db)
self.maybe_add('/collections', ChannelsEndpoint, libtorrent_component.download_manager, gigachannel_manager, gigachannel_component.community, metadata_store_component.mds, tags_db=tag_component.tags_db)
self.maybe_add('/search', SearchEndpoint, metadata_store_component.mds, tags_db=tag_component.tags_db)
self.maybe_add('/remote_query', RemoteQueryEndpoint, gigachannel_component.community, metadata_store_component.mds)
self.maybe_add('/tags', TagsEndpoint, tag_component.tags_db, tag_component.community)
self.maybe_add('/statistics', StatisticsEndpoint, ipv8=ipv8_component.ipv8,
metadata_store=metadata_store_component.mds)
self.maybe_add('/libtorrent', LibTorrentEndpoint, libtorrent_component.download_manager)
self.maybe_add('/torrentinfo', TorrentInfoEndpoint, libtorrent_component.download_manager)
self.maybe_add('/metadata', MetadataEndpoint, torrent_checker, metadata_store_component.mds,
tags_db=tag_component.tags_db)
self.maybe_add('/channels', ChannelsEndpoint, libtorrent_component.download_manager, gigachannel_manager,
gigachannel_component.community, metadata_store_component.mds, tags_db=tag_component.tags_db)
self.maybe_add('/collections', ChannelsEndpoint, libtorrent_component.download_manager, gigachannel_manager,
gigachannel_component.community, metadata_store_component.mds, tags_db=tag_component.tags_db)
self.maybe_add('/search', SearchEndpoint, metadata_store_component.mds, tags_db=tag_component.tags_db)
self.maybe_add('/remote_query', RemoteQueryEndpoint, gigachannel_component.community,
metadata_store_component.mds)
self.maybe_add('/tags', TagsEndpoint, db=tag_component.tags_db, community=tag_component.community)

# pylint: enable=C0301
ipv8_root_endpoint = IPV8RootEndpoint()
for _, endpoint in ipv8_root_endpoint.endpoints.items():
endpoint.initialize(ipv8_component.ipv8)
self.root_endpoint.add_endpoint('/ipv8', ipv8_root_endpoint)
# fmt: on

# ACHTUNG!
# AIOHTTP endpoints cannot be added after the app has been started!
# Note: AIOHTTP endpoints cannot be added after the app has been started!
rest_manager = RESTManager(config=config.api, root_endpoint=self.root_endpoint, state_dir=config.state_dir)
await rest_manager.start()
self.rest_manager = rest_manager