From 9c17de263dba9342b112fe70b468a0d3ecf8aee6 Mon Sep 17 00:00:00 2001 From: drew2a Date: Tue, 14 Dec 2021 13:40:57 +0100 Subject: [PATCH] Polish PR --- .../tribler_common/tests/test_utils.py | 32 ++++++++++++------- .../tribler_common/utilities.py | 17 ++++------ .../metadata_store/restapi/search_endpoint.py | 6 +++- .../restapi/tests/test_search_endpoint.py | 26 ++++++++++----- .../components/restapi/restapi_component.py | 4 +-- src/tribler-gui/tribler_gui/tribler_window.py | 5 +-- .../widgets/searchresultswidget.py | 25 +++++++++------ 7 files changed, 70 insertions(+), 45 deletions(-) diff --git a/src/tribler-common/tribler_common/tests/test_utils.py b/src/tribler-common/tribler_common/tests/test_utils.py index 1f00d68fdf5..8350af725c6 100644 --- a/src/tribler-common/tribler_common/tests/test_utils.py +++ b/src/tribler-common/tribler_common/tests/test_utils.py @@ -2,13 +2,18 @@ from unittest.mock import MagicMock, patch from tribler_common.patch_import import patch_import -from tribler_common.utilities import Query, extract_plain_fts_query_text, extract_tags, \ - parse_query, show_system_popup, to_fts_query, \ - uri_to_path - +from tribler_common.utilities import ( + Query, + extract_plain_fts_query_text, + extract_tags, + parse_query, + show_system_popup, + to_fts_query, + uri_to_path, +) # pylint: disable=import-outside-toplevel, import-error - +# fmt: off def test_uri_to_path(): path = Path(__file__).parent / "bla%20foo.bar" @@ -26,28 +31,31 @@ def test_to_fts_query(): def test_extract_tags(): - assert extract_tags(None) == (set(), '') assert extract_tags('') == (set(), '') assert extract_tags('text') == (set(), '') assert extract_tags('[text') == (set(), '') assert extract_tags('text]') == (set(), '') assert extract_tags('[]') == (set(), '') + assert extract_tags('[ta]') == (set(), '') + assert extract_tags('[' + 't' * 51 + ']') == (set(), '') assert extract_tags('[tag1[tag2]text]') == (set(), '') + assert extract_tags('[not a tag]') == (set(), '') + + assert extract_tags('[tag]') == ({'tag'}, '[tag]') + assert extract_tags('[tag1][tag2]') == ({'tag1', 'tag2'}, '[tag1][tag2]') + assert extract_tags('[tag_with_underscore][tag-with-dash]') == ({'tag_with_underscore', 'tag-with-dash'}, + '[tag_with_underscore][tag-with-dash]') - assert extract_tags('[tag]text') == ({'tag'}, '[tag]') - assert extract_tags('[tag1][tag2]text') == ({'tag1', 'tag2'}, '[tag1][tag2]') - assert extract_tags(' [tag]text[not_tag]') == ({'tag'}, ' [tag]') + assert extract_tags(' [tag][not tag]for complex query with [not tag at the end]') == ({'tag'}, ' [tag]') def test_extract_plain_fts_query_text(): - assert not extract_plain_fts_query_text(None, '') assert not extract_plain_fts_query_text('', '') assert extract_plain_fts_query_text('query', '') == 'query' assert extract_plain_fts_query_text('[tag] query', '[tag]') == 'query' def test_parse_query(): - assert parse_query(None) == Query(original_query=None) assert parse_query('') == Query(original_query='') actual = parse_query('[tag1][tag2]') @@ -62,7 +70,7 @@ def test_parse_query(): actual = parse_query('[tag1][tag2] fts query with potential [brackets]') expected = Query(original_query='[tag1][tag2] fts query with potential [brackets]', tags={'tag1', 'tag2'}, - fts_text='fts query with potential [brackets]',) + fts_text='fts query with potential [brackets]', ) assert actual == expected diff --git a/src/tribler-common/tribler_common/utilities.py b/src/tribler-common/tribler_common/utilities.py index 03ada628cd3..ecfd645d706 100644 --- a/src/tribler-common/tribler_common/utilities.py +++ b/src/tribler-common/tribler_common/utilities.py @@ -3,7 +3,7 @@ import re import sys from dataclasses import dataclass, field -from typing import Optional, Set, Tuple +from typing import Set, Tuple from urllib.parse import urlparse from urllib.request import url2pathname @@ -29,17 +29,17 @@ def uri_to_path(uri): fts_query_re = re.compile(r'\w+', re.UNICODE) -tags_re = re.compile(r'^\s*(?:\[\w+\]\s*)+') +tags_re = re.compile(r'^\s*(?:\[[^\s\[\]]{3,50}\]\s*)+') @dataclass class Query: - original_query: Optional[str] + original_query: str tags: Set[str] = field(default_factory=set) fts_text: str = '' -def parse_query(query: Optional[str]) -> Query: +def parse_query(query: str) -> Query: """ The query structure: query = [tag1][tag2] text @@ -55,7 +55,7 @@ def parse_query(query: Optional[str]) -> Query: return Query(original_query=query, tags=tags, fts_text=fts_text) -def extract_tags(text: Optional[str]) -> Tuple[Set[str], str]: +def extract_tags(text: str) -> Tuple[Set[str], str]: if not text: return set(), '' if (m := tags_re.match(text)) is not None: @@ -64,11 +64,8 @@ def extract_tags(text: Optional[str]) -> Tuple[Set[str], str]: return set(), '' -def extract_plain_fts_query_text(query: Optional[str], tags_string: str) -> str: - if query is None: - return '' - - return query[len(tags_string):].strip() +def extract_plain_fts_query_text(query: str, tags_string: str) -> str: + return query[len(tags_string) :].strip() def to_fts_query(text): diff --git a/src/tribler-core/tribler_core/components/metadata_store/restapi/search_endpoint.py b/src/tribler-core/tribler_core/components/metadata_store/restapi/search_endpoint.py index fba6905cf09..2adc4274fd4 100644 --- a/src/tribler-core/tribler_core/components/metadata_store/restapi/search_endpoint.py +++ b/src/tribler-core/tribler_core/components/metadata_store/restapi/search_endpoint.py @@ -1,9 +1,13 @@ from aiohttp import web + from aiohttp_apispec import docs, querystring_schema + +from ipv8.REST.schema import schema + from marshmallow.fields import Integer, String + from pony.orm import db_session -from ipv8.REST.schema import schema from tribler_core.components.metadata_store.db.store import MetadataStore from tribler_core.components.metadata_store.restapi.metadata_endpoint import MetadataEndpointBase from tribler_core.components.metadata_store.restapi.metadata_schema import MetadataParameters, MetadataSchema diff --git a/src/tribler-core/tribler_core/components/metadata_store/restapi/tests/test_search_endpoint.py b/src/tribler-core/tribler_core/components/metadata_store/restapi/tests/test_search_endpoint.py index a11ef177045..db237832b20 100644 --- a/src/tribler-core/tribler_core/components/metadata_store/restapi/tests/test_search_endpoint.py +++ b/src/tribler-core/tribler_core/components/metadata_store/restapi/tests/test_search_endpoint.py @@ -1,3 +1,6 @@ +from typing import Set +from unittest.mock import patch + from aiohttp.web_app import Application from pony.orm import db_session @@ -8,9 +11,9 @@ from tribler_core.components.metadata_store.restapi.search_endpoint import SearchEndpoint from tribler_core.components.restapi.rest.base_api_test import do_request +from tribler_core.components.tag.db.tag_db import TagDatabase from tribler_core.utilities.utilities import random_infohash - # pylint: disable=unused-argument, redefined-outer-name @@ -34,13 +37,6 @@ def rest_api(loop, needle_in_haystack_mds, aiohttp_client, tags_db): return loop.run_until_complete(aiohttp_client(app)) -async def test_search_no_query(rest_api): - """ - Testing whether the API returns an error 400 if no query is passed when doing a search - """ - await do_request(rest_api, 'search', expected_code=400) - - async def test_search_wrong_mdtype(rest_api): """ Testing whether the API returns an error 400 if wrong metadata type is passed in the query @@ -73,6 +69,20 @@ async def test_search(rest_api): assert parsed["results"][0]['name'] == "needle2" +async def test_search_by_tags(rest_api): + def mocked_get_infohashes(tags: Set[str]): + if tags.pop() == 'missed_tag': + return None + return {b'infohash'} + + with patch.object(TagDatabase, 'get_infohashes', wraps=mocked_get_infohashes): + parsed = await do_request(rest_api, 'search?txt_filter=needle&tags=real_tag', expected_code=200) + assert len(parsed["results"]) == 0 + + parsed = await do_request(rest_api, 'search?txt_filter=needle&tags=missed_tag', expected_code=200) + assert len(parsed["results"]) == 1 + + async def test_search_with_include_total_and_max_rowid(rest_api): """ Test search queries with include_total and max_rowid options diff --git a/src/tribler-core/tribler_core/components/restapi/restapi_component.py b/src/tribler-core/tribler_core/components/restapi/restapi_component.py index 1f7643b25de..8ae253e3804 100644 --- a/src/tribler-core/tribler_core/components/restapi/restapi_component.py +++ b/src/tribler-core/tribler_core/components/restapi/restapi_component.py @@ -2,7 +2,9 @@ from typing import Type from ipv8.REST.root_endpoint import RootEndpoint as IPV8RootEndpoint + from tribler_common.reported_error import ReportedError + from tribler_core.components.bandwidth_accounting.bandwidth_accounting_component import BandwidthAccountingComponent from tribler_core.components.bandwidth_accounting.restapi.bandwidth_endpoint import BandwidthEndpoint from tribler_core.components.base import Component, NoneComponent @@ -68,7 +70,6 @@ async def run(self): log_dir = config.general.get_path_as_absolute('log_dir', config.state_dir) metadata_store_component = await self.get_component(MetadataStoreComponent) - # fmt: off # pylint: disable=C0301 key_component = await self.require_component(KeyComponent) ipv8_component = await self.maybe_component(Ipv8Component) @@ -119,7 +120,6 @@ async def run(self): for _, endpoint in ipv8_root_endpoint.endpoints.items(): endpoint.initialize(ipv8_component.ipv8) self.root_endpoint.add_endpoint('/ipv8', ipv8_root_endpoint) - # fmt: on # Note: AIOHTTP endpoints cannot be added after the app has been started! rest_manager = RESTManager(config=config.api, root_endpoint=self.root_endpoint, state_dir=config.state_dir) diff --git a/src/tribler-gui/tribler_gui/tribler_window.py b/src/tribler-gui/tribler_gui/tribler_window.py index 94b463e0ff2..9f52410d453 100644 --- a/src/tribler-gui/tribler_gui/tribler_window.py +++ b/src/tribler-gui/tribler_gui/tribler_window.py @@ -1026,10 +1026,11 @@ def clicked_search_bar(self, checked=False): self.stackedWidget.setCurrentIndex(PAGE_SEARCH_RESULTS) def on_top_search_bar_return_pressed(self): - query = parse_query(self.top_search_bar.text()) - if not query.original_query: + query_text = self.top_search_bar.text() + if not query_text: return + query = parse_query(query_text) if self.search_results_page.search(query): self._logger.info(f'Do search for query: {query}') self.deselect_all_menu_buttons() diff --git a/src/tribler-gui/tribler_gui/widgets/searchresultswidget.py b/src/tribler-gui/tribler_gui/widgets/searchresultswidget.py index e6f7e4b0577..90f6757cc2d 100644 --- a/src/tribler-gui/tribler_gui/widgets/searchresultswidget.py +++ b/src/tribler-gui/tribler_gui/widgets/searchresultswidget.py @@ -7,7 +7,9 @@ from tribler_common.sentry_reporter.sentry_mixin import AddBreadcrumbOnShowMixin from tribler_common.utilities import Query, to_fts_query + from tribler_core.components.metadata_store.db.serialization import CHANNEL_TORRENT, COLLECTION_NODE, REGULAR_TORRENT + from tribler_gui.tribler_request_manager import TriblerNetworkRequest from tribler_gui.utilities import connect, get_ui_file_path, tr from tribler_gui.widgets.tablecontentmodel import SearchResultsModel @@ -23,11 +25,11 @@ def format_search_loading_label(search_request): } return ( - tr( - "Remote responses: %(num_complete_peers)i / %(total_peers)i" - "\nNew remote results received: %(num_remote_results)i" - ) - % data + tr( + "Remote responses: %(num_complete_peers)i / %(total_peers)i" + "\nNew remote results received: %(num_remote_results)i" + ) + % data ) @@ -79,8 +81,11 @@ def show_results(self, *_): query = self.search_request.query self.results_page.initialize_root_model( SearchResultsModel( - channel_info={"name": (tr("Search results for %s") % query.original_query) if len( - query.original_query) < 50 else f"{query.original_query[:50]}..."}, + channel_info={ + "name": (tr("Search results for %s") % query.original_query) + if len(query.original_query) < 50 + else f"{query.original_query[:50]}..." + }, endpoint_url="search", hide_xxx=self.results_page.hide_xxx, text_filter=to_fts_query(query.fts_text), @@ -92,9 +97,9 @@ def show_results(self, *_): def check_can_show(self, query): if ( - self.last_search_query == query - and self.last_search_time is not None - and time.time() - self.last_search_time < 1 + self.last_search_query == query + and self.last_search_time is not None + and time.time() - self.last_search_time < 1 ): self._logger.info("Same search query already sent within 500ms so dropping this one") return False