Skip to content

Commit

Permalink
Polish PR
Browse files Browse the repository at this point in the history
  • Loading branch information
drew2a committed Dec 14, 2021
1 parent 9be215b commit 9c17de2
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 45 deletions.
32 changes: 20 additions & 12 deletions src/tribler-common/tribler_common/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@
from unittest.mock import MagicMock, patch

from tribler_common.patch_import import patch_import
from tribler_common.utilities import Query, extract_plain_fts_query_text, extract_tags, \
parse_query, show_system_popup, to_fts_query, \
uri_to_path

from tribler_common.utilities import (
Query,
extract_plain_fts_query_text,
extract_tags,
parse_query,
show_system_popup,
to_fts_query,
uri_to_path,
)

# pylint: disable=import-outside-toplevel, import-error

# fmt: off

def test_uri_to_path():
path = Path(__file__).parent / "bla%20foo.bar"
Expand All @@ -26,28 +31,31 @@ def test_to_fts_query():


def test_extract_tags():
assert extract_tags(None) == (set(), '')
assert extract_tags('') == (set(), '')
assert extract_tags('text') == (set(), '')
assert extract_tags('[text') == (set(), '')
assert extract_tags('text]') == (set(), '')
assert extract_tags('[]') == (set(), '')
assert extract_tags('[ta]') == (set(), '')
assert extract_tags('[' + 't' * 51 + ']') == (set(), '')
assert extract_tags('[tag1[tag2]text]') == (set(), '')
assert extract_tags('[not a tag]') == (set(), '')

assert extract_tags('[tag]') == ({'tag'}, '[tag]')
assert extract_tags('[tag1][tag2]') == ({'tag1', 'tag2'}, '[tag1][tag2]')
assert extract_tags('[tag_with_underscore][tag-with-dash]') == ({'tag_with_underscore', 'tag-with-dash'},
'[tag_with_underscore][tag-with-dash]')

assert extract_tags('[tag]text') == ({'tag'}, '[tag]')
assert extract_tags('[tag1][tag2]text') == ({'tag1', 'tag2'}, '[tag1][tag2]')
assert extract_tags(' [tag]text[not_tag]') == ({'tag'}, ' [tag]')
assert extract_tags(' [tag][not tag]for complex query with [not tag at the end]') == ({'tag'}, ' [tag]')


def test_extract_plain_fts_query_text():
assert not extract_plain_fts_query_text(None, '')
assert not extract_plain_fts_query_text('', '')
assert extract_plain_fts_query_text('query', '') == 'query'
assert extract_plain_fts_query_text('[tag] query', '[tag]') == 'query'


def test_parse_query():
assert parse_query(None) == Query(original_query=None)
assert parse_query('') == Query(original_query='')

actual = parse_query('[tag1][tag2]')
Expand All @@ -62,7 +70,7 @@ def test_parse_query():
actual = parse_query('[tag1][tag2] fts query with potential [brackets]')
expected = Query(original_query='[tag1][tag2] fts query with potential [brackets]',
tags={'tag1', 'tag2'},
fts_text='fts query with potential [brackets]',)
fts_text='fts query with potential [brackets]', )
assert actual == expected


Expand Down
17 changes: 7 additions & 10 deletions src/tribler-common/tribler_common/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
import sys
from dataclasses import dataclass, field
from typing import Optional, Set, Tuple
from typing import Set, Tuple
from urllib.parse import urlparse
from urllib.request import url2pathname

Expand All @@ -29,17 +29,17 @@ def uri_to_path(uri):


fts_query_re = re.compile(r'\w+', re.UNICODE)
tags_re = re.compile(r'^\s*(?:\[\w+\]\s*)+')
tags_re = re.compile(r'^\s*(?:\[[^\s\[\]]{3,50}\]\s*)+')


@dataclass
class Query:
original_query: Optional[str]
original_query: str
tags: Set[str] = field(default_factory=set)
fts_text: str = ''


def parse_query(query: Optional[str]) -> Query:
def parse_query(query: str) -> Query:
"""
The query structure:
query = [tag1][tag2] text
Expand All @@ -55,7 +55,7 @@ def parse_query(query: Optional[str]) -> Query:
return Query(original_query=query, tags=tags, fts_text=fts_text)


def extract_tags(text: Optional[str]) -> Tuple[Set[str], str]:
def extract_tags(text: str) -> Tuple[Set[str], str]:
if not text:
return set(), ''
if (m := tags_re.match(text)) is not None:
Expand All @@ -64,11 +64,8 @@ def extract_tags(text: Optional[str]) -> Tuple[Set[str], str]:
return set(), ''


def extract_plain_fts_query_text(query: Optional[str], tags_string: str) -> str:
if query is None:
return ''

return query[len(tags_string):].strip()
def extract_plain_fts_query_text(query: str, tags_string: str) -> str:
return query[len(tags_string) :].strip()


def to_fts_query(text):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from aiohttp import web

from aiohttp_apispec import docs, querystring_schema

from ipv8.REST.schema import schema

from marshmallow.fields import Integer, String

from pony.orm import db_session

from ipv8.REST.schema import schema
from tribler_core.components.metadata_store.db.store import MetadataStore
from tribler_core.components.metadata_store.restapi.metadata_endpoint import MetadataEndpointBase
from tribler_core.components.metadata_store.restapi.metadata_schema import MetadataParameters, MetadataSchema
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Set
from unittest.mock import patch

from aiohttp.web_app import Application

from pony.orm import db_session
Expand All @@ -8,9 +11,9 @@

from tribler_core.components.metadata_store.restapi.search_endpoint import SearchEndpoint
from tribler_core.components.restapi.rest.base_api_test import do_request
from tribler_core.components.tag.db.tag_db import TagDatabase
from tribler_core.utilities.utilities import random_infohash


# pylint: disable=unused-argument, redefined-outer-name


Expand All @@ -34,13 +37,6 @@ def rest_api(loop, needle_in_haystack_mds, aiohttp_client, tags_db):
return loop.run_until_complete(aiohttp_client(app))


async def test_search_no_query(rest_api):
"""
Testing whether the API returns an error 400 if no query is passed when doing a search
"""
await do_request(rest_api, 'search', expected_code=400)


async def test_search_wrong_mdtype(rest_api):
"""
Testing whether the API returns an error 400 if wrong metadata type is passed in the query
Expand Down Expand Up @@ -73,6 +69,20 @@ async def test_search(rest_api):
assert parsed["results"][0]['name'] == "needle2"


async def test_search_by_tags(rest_api):
def mocked_get_infohashes(tags: Set[str]):
if tags.pop() == 'missed_tag':
return None
return {b'infohash'}

with patch.object(TagDatabase, 'get_infohashes', wraps=mocked_get_infohashes):
parsed = await do_request(rest_api, 'search?txt_filter=needle&tags=real_tag', expected_code=200)
assert len(parsed["results"]) == 0

parsed = await do_request(rest_api, 'search?txt_filter=needle&tags=missed_tag', expected_code=200)
assert len(parsed["results"]) == 1


async def test_search_with_include_total_and_max_rowid(rest_api):
"""
Test search queries with include_total and max_rowid options
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from typing import Type

from ipv8.REST.root_endpoint import RootEndpoint as IPV8RootEndpoint

from tribler_common.reported_error import ReportedError

from tribler_core.components.bandwidth_accounting.bandwidth_accounting_component import BandwidthAccountingComponent
from tribler_core.components.bandwidth_accounting.restapi.bandwidth_endpoint import BandwidthEndpoint
from tribler_core.components.base import Component, NoneComponent
Expand Down Expand Up @@ -68,7 +70,6 @@ async def run(self):
log_dir = config.general.get_path_as_absolute('log_dir', config.state_dir)
metadata_store_component = await self.get_component(MetadataStoreComponent)

# fmt: off
# pylint: disable=C0301
key_component = await self.require_component(KeyComponent)
ipv8_component = await self.maybe_component(Ipv8Component)
Expand Down Expand Up @@ -119,7 +120,6 @@ async def run(self):
for _, endpoint in ipv8_root_endpoint.endpoints.items():
endpoint.initialize(ipv8_component.ipv8)
self.root_endpoint.add_endpoint('/ipv8', ipv8_root_endpoint)
# fmt: on

# Note: AIOHTTP endpoints cannot be added after the app has been started!
rest_manager = RESTManager(config=config.api, root_endpoint=self.root_endpoint, state_dir=config.state_dir)
Expand Down
5 changes: 3 additions & 2 deletions src/tribler-gui/tribler_gui/tribler_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,10 +1026,11 @@ def clicked_search_bar(self, checked=False):
self.stackedWidget.setCurrentIndex(PAGE_SEARCH_RESULTS)

def on_top_search_bar_return_pressed(self):
query = parse_query(self.top_search_bar.text())
if not query.original_query:
query_text = self.top_search_bar.text()
if not query_text:
return

query = parse_query(query_text)
if self.search_results_page.search(query):
self._logger.info(f'Do search for query: {query}')
self.deselect_all_menu_buttons()
Expand Down
25 changes: 15 additions & 10 deletions src/tribler-gui/tribler_gui/widgets/searchresultswidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

from tribler_common.sentry_reporter.sentry_mixin import AddBreadcrumbOnShowMixin
from tribler_common.utilities import Query, to_fts_query

from tribler_core.components.metadata_store.db.serialization import CHANNEL_TORRENT, COLLECTION_NODE, REGULAR_TORRENT

from tribler_gui.tribler_request_manager import TriblerNetworkRequest
from tribler_gui.utilities import connect, get_ui_file_path, tr
from tribler_gui.widgets.tablecontentmodel import SearchResultsModel
Expand All @@ -23,11 +25,11 @@ def format_search_loading_label(search_request):
}

return (
tr(
"Remote responses: %(num_complete_peers)i / %(total_peers)i"
"\nNew remote results received: %(num_remote_results)i"
)
% data
tr(
"Remote responses: %(num_complete_peers)i / %(total_peers)i"
"\nNew remote results received: %(num_remote_results)i"
)
% data
)


Expand Down Expand Up @@ -79,8 +81,11 @@ def show_results(self, *_):
query = self.search_request.query
self.results_page.initialize_root_model(
SearchResultsModel(
channel_info={"name": (tr("Search results for %s") % query.original_query) if len(
query.original_query) < 50 else f"{query.original_query[:50]}..."},
channel_info={
"name": (tr("Search results for %s") % query.original_query)
if len(query.original_query) < 50
else f"{query.original_query[:50]}..."
},
endpoint_url="search",
hide_xxx=self.results_page.hide_xxx,
text_filter=to_fts_query(query.fts_text),
Expand All @@ -92,9 +97,9 @@ def show_results(self, *_):

def check_can_show(self, query):
if (
self.last_search_query == query
and self.last_search_time is not None
and time.time() - self.last_search_time < 1
self.last_search_query == query
and self.last_search_time is not None
and time.time() - self.last_search_time < 1
):
self._logger.info("Same search query already sent within 500ms so dropping this one")
return False
Expand Down

0 comments on commit 9c17de2

Please sign in to comment.