From 3c998103907e2599822b92ed946d1f4ec1fe1702 Mon Sep 17 00:00:00 2001 From: drew2a Date: Thu, 9 Dec 2021 19:00:04 +0100 Subject: [PATCH] Fix logic --- .../components/metadata_store/db/store.py | 2 +- .../components/tag/restapi/tags_endpoint.py | 24 +++++---- .../tag/restapi/tests/test_tags_endpoint.py | 49 ++++++++++++------- .../widgets/searchresultswidget.py | 3 +- 4 files changed, 50 insertions(+), 28 deletions(-) diff --git a/src/tribler-core/tribler_core/components/metadata_store/db/store.py b/src/tribler-core/tribler_core/components/metadata_store/db/store.py index 50f9e87c13e..3d71d6d66cc 100644 --- a/src/tribler-core/tribler_core/components/metadata_store/db/store.py +++ b/src/tribler-core/tribler_core/components/metadata_store/db/store.py @@ -627,7 +627,7 @@ def get_entries_query( if cls is None: cls = self.ChannelNode pony_query = self.search_keyword(txt_filter, lim=1000) if txt_filter else left_join(g for g in cls) - infohash_set = infohash_set or {infohash} + infohash_set = infohash_set or ({infohash} if infohash else None) if popular: if metadata_type != REGULAR_TORRENT: raise TypeError('With `popular=True`, only `metadata_type=REGULAR_TORRENT` is allowed') diff --git a/src/tribler-core/tribler_core/components/tag/restapi/tags_endpoint.py b/src/tribler-core/tribler_core/components/tag/restapi/tags_endpoint.py index 28dd1064281..644d69dd303 100644 --- a/src/tribler-core/tribler_core/components/tag/restapi/tags_endpoint.py +++ b/src/tribler-core/tribler_core/components/tag/restapi/tags_endpoint.py @@ -1,3 +1,5 @@ +from itertools import islice + import binascii from binascii import unhexlify from typing import Optional, Sequence, Set, Tuple @@ -12,6 +14,7 @@ from pony.orm import db_session +from tribler_core.components.metadata_store.db.serialization import REGULAR_TORRENT from tribler_core.components.metadata_store.db.store import MetadataStore from tribler_core.components.restapi.rest.rest_endpoint import HTTP_BAD_REQUEST, RESTEndpoint, RESTResponse from tribler_core.components.restapi.rest.schema import HandledErrorSchema @@ -156,14 +159,14 @@ async def search(self, request): most_relevant, less_relevant = self.search_by_tags(tags) items = [] if most_relevant or less_relevant: - infohash_list = (sorted(most_relevant) + sorted(less_relevant))[:limit] with db_session: - items = list(self.replace_infohash_list_by_metadata_gen(infohash_list)) + gen = self.get_metadata_gen([most_relevant, less_relevant]) + items = list(islice(gen, limit)) return RESTResponse({"items": items}) @db_session - def search_by_tags(self, tags: Set[str]) -> Tuple[Set[str], Set[str]]: + def search_by_tags(self, tags: Set[str]) -> Tuple[Set[bytes], Set[bytes]]: """ Search infohashes for given tags Function combines the results in the following way: @@ -182,10 +185,13 @@ def search_by_tags(self, tags: Set[str]) -> Tuple[Set[str], Set[str]]: return intersection, union - intersection - def replace_infohash_list_by_metadata_gen(self, list_of_infohash_list: Sequence[Sequence[bytes]]): - for infohash_list in list_of_infohash_list: - entries = self.mds.get_entries_query(infohash_set=infohash_list) - items_without_tags = ((entry.to_simple_dict(), entry.infohash) for entry in entries) - for item, infohash in items_without_tags: - item['tags'] = self.db.get_tags(infohash) + def get_metadata_gen(self, list_of_infohash_set: Sequence[Set[bytes]]): + for infohash_set in list_of_infohash_set: + if not infohash_set: + continue + + entries = self.mds.get_entries_query(infohash_set=infohash_set, metadata_type=REGULAR_TORRENT) + for entry in entries: + item = entry.to_simple_dict() + item['tags'] = self.db.get_tags(entry.infohash) yield item diff --git a/src/tribler-core/tribler_core/components/tag/restapi/tests/test_tags_endpoint.py b/src/tribler-core/tribler_core/components/tag/restapi/tests/test_tags_endpoint.py index 95082f24658..a671adea917 100644 --- a/src/tribler-core/tribler_core/components/tag/restapi/tests/test_tags_endpoint.py +++ b/src/tribler-core/tribler_core/components/tag/restapi/tests/test_tags_endpoint.py @@ -1,21 +1,18 @@ from unittest.mock import Mock, patch +import pytest from aiohttp.web_app import Application - from freezegun import freeze_time - -from ipv8.keyvault.crypto import default_eccrypto - from pony.orm import db_session -import pytest - +from ipv8.keyvault.crypto import default_eccrypto from tribler_core.components.restapi.rest.base_api_test import do_request from tribler_core.components.tag.community.tag_payload import TagOperation, TagOperationEnum from tribler_core.components.tag.restapi.tags_endpoint import TagsEndpoint from tribler_core.conftest import TEST_PERSONAL_KEY from tribler_core.utilities.unicode import hexlify + # pylint: disable=redefined-outer-name @pytest.fixture @@ -110,25 +107,43 @@ async def test_get_suggestions(rest_api, tags_db): async def test_search(rest_api, tags_db): - infohash = b'a' * 20 + infohash1 = b'1' * 20 + infohash2 = b'2' * 20 # add tag operations to the db with db_session: for _ in range(2): random_key = default_eccrypto.generate_key('low') - operation = TagOperation(infohash=infohash, tag="test", operation=TagOperationEnum.ADD, clock=0, - creator_public_key=random_key.pub().key_to_bin()) - tags_db.add_tag_operation(operation, b"") + tags_db.add_tag_operation(TagOperation(infohash=infohash1, tag="tag1", operation=TagOperationEnum.ADD, + clock=0, creator_public_key=random_key.pub().key_to_bin()), b'') + tags_db.add_tag_operation(TagOperation(infohash=infohash2, tag="tag2", operation=TagOperationEnum.ADD, + clock=0, creator_public_key=random_key.pub().key_to_bin()), b'') + tags_db.add_tag_operation(TagOperation(infohash=infohash2, tag="tag1", operation=TagOperationEnum.ADD, + clock=0, creator_public_key=random_key.pub().key_to_bin()), b'') def mocked_replace(infohash_list): - return (hexlify(infohash) for infohash in infohash_list) + result = [] + for s in infohash_list: + for infohash in s: + result.append(hexlify(infohash)) + return result # patch `replace_infohash_list_by_metadata_gen` in such a way that it will # return `infohash` instead of `metadata` - with patch.object(TagsEndpoint, 'replace_infohash_list_by_metadata_gen', wraps=mocked_replace): - response = await do_request(rest_api, 'tags/search', post_data={'tags': ['test'], 'limit': 1}) + with patch.object(TagsEndpoint, 'get_metadata_gen', wraps=mocked_replace): + response = await do_request(rest_api, 'tags/search', post_data={'tags': ['tag1'], 'limit': 2}) + + # order inside items in response is not stable, so check that results in items + assert hexlify(infohash1) in response['items'] + assert hexlify(infohash2) in response['items'] + + # in this case order is stable as infohash2 is are in priority list, and infohash1 are in less priority list + response = await do_request(rest_api, 'tags/search', post_data={'tags': ['tag1', 'tag2'], 'limit': 2}) + assert response['items'] == [hexlify(infohash2), hexlify(infohash1)] - assert response['items'] == [hexlify(infohash)] + # check that limit is work in proper way + response = await do_request(rest_api, 'tags/search', post_data={'tags': ['tag1', 'tag2'], 'limit': 1}) + assert response['items'] == [hexlify(infohash2)] async def test_search_by_tags(): @@ -156,7 +171,7 @@ def get_infohashes(self, tag): assert all(a in less_relevant for a in other_values) -async def test_replace_infohash_list_by_metadata_gen(): +async def test_get_metadata_gen(): class MetadataEntry: def __init__(self, infohash): self.infohash = infohash @@ -174,7 +189,7 @@ def get_tags(self, infohash): endpoint = TagsEndpoint(db=TagsDB(), community=Mock(), mds=MDS()) - list_of_infohash_list = [[b'infohash1'], [b'infohash2', b'infohash2']] - result = list(endpoint.replace_infohash_list_by_metadata_gen(list_of_infohash_list)) + list_of_infohash_set = [{b'infohash1'}, {b'infohash2', b'infohash3'}] + result = list(endpoint.get_metadata_gen(list_of_infohash_set)) assert len(result) == 3 assert all('tags' in item for item in result) diff --git a/src/tribler-gui/tribler_gui/widgets/searchresultswidget.py b/src/tribler-gui/tribler_gui/widgets/searchresultswidget.py index 16cbfb1c9bd..b0cc39975a5 100644 --- a/src/tribler-gui/tribler_gui/widgets/searchresultswidget.py +++ b/src/tribler-gui/tribler_gui/widgets/searchresultswidget.py @@ -103,7 +103,8 @@ def search_by_tags(self, tags): def response(r): logging.info(f'Response from "tags/search": {r}') self.setCurrentWidget(self.results_page) - model = TagsSearchResultsModel(r['items'], channel_info={'name': f'Tags: {tags}'}) + name = f'Tags: [{", ".join(tags)}]' + model = TagsSearchResultsModel(r['items'], channel_info={'name': name}) self.results_page.initialize_root_model(model) self.results_page.controller.brain_dead_refresh() model.update()