diff --git a/src/tribler-core/tribler_core/components/metadata_store/remote_query_community/tests/test_remote_search_by_tags.py b/src/tribler-core/tribler_core/components/metadata_store/remote_query_community/tests/test_remote_search_by_tags.py
index e87916602f1..09e83b61511 100644
--- a/src/tribler-core/tribler_core/components/metadata_store/remote_query_community/tests/test_remote_search_by_tags.py
+++ b/src/tribler-core/tribler_core/components/metadata_store/remote_query_community/tests/test_remote_search_by_tags.py
@@ -91,7 +91,7 @@ def fill_tags_database():
                     Tag(name='tag1', count=SHOW_THRESHOLD),
                 ],
                 b'infohash2': [
-                    Tag(name='tag2', count=SHOW_THRESHOLD - 1),
+                    Tag(name='tag1', count=SHOW_THRESHOLD - 1),
                 ]
             })
 
@@ -110,7 +110,7 @@ def _add(infohash):
         fill_mds()
 
-        # Then we try to query search for three tags: 'tag1', 'tag2', 'tag3'
-        parameters = {'first': 0, 'infohash_set': None, 'last': 100, 'tags': ['tag1', 'tag2', 'tag3']}
+        # Then we try to query search for a single tag: 'tag1'
+        parameters = {'first': 0, 'infohash_set': None, 'last': 100, 'tags': ['tag1']}
         json = dumps(parameters).encode('utf-8')
 
         with db_session:
diff --git a/src/tribler-core/tribler_core/components/tag/db/tag_db.py b/src/tribler-core/tribler_core/components/tag/db/tag_db.py
index 9ee559f0983..df0f38188c5 100644
--- a/src/tribler-core/tribler_core/components/tag/db/tag_db.py
+++ b/src/tribler-core/tribler_core/components/tag/db/tag_db.py
@@ -191,10 +191,32 @@ def show_suggestions_condition(torrent_tag):
         return self._get_tags(infohash, show_suggestions_condition)
 
     def get_infohashes(self, tags: Set[str]) -> List[bytes]:
-        """Get list of infohashes that belongs to the tag. Only tags with condition `_show_condition` will be returned
+        """Get the list of infohashes of the torrents that have all of the given `tags`.
+        Only tags that satisfy `_show_condition` are taken into account.
+        If `tags` contains more than one tag, only torrents that have
+        all of those tags visible will be returned.
         """
-        return select(tt.torrent.infohash for tt in self.instance.TorrentTag
-                      if self._show_condition(tt) and tt.tag.name in tags).fetch()
+
+        # first, get all torrents that have at least one visible tag from `tags`
+        query_results = select(
+            tt.torrent for tt in self.instance.TorrentTag
+            if self._show_condition(tt) and tt.tag.name in tags
+        )
+
+        # second, keep only the torrents whose visible tags include all `tags`
+        result = {}
+        for torrent in query_results:
+            if len(tags) > 1:
+                # a tag that fails `_show_condition` must not count as a match
+                visible_tag_names = {tt.tag.name for tt in torrent.tags if self._show_condition(tt)}
+                if not all(tag in visible_tag_names for tag in tags):
+                    continue
+
+            result[torrent.infohash] = True
+
+        # a dict (instead of a set) keeps the order of the query results,
+        # which makes the order of the returned list deterministic
+        return list(result)
 
     def get_clock(self, operation: TagOperation) -> int:
         """ Get the clock (int) of operation.
diff --git a/src/tribler-core/tribler_core/components/tag/db/tests/test_tag_db.py b/src/tribler-core/tribler_core/components/tag/db/tests/test_tag_db.py
index dba25aba105..439b6324fb0 100644
--- a/src/tribler-core/tribler_core/components/tag/db/tests/test_tag_db.py
+++ b/src/tribler-core/tribler_core/components/tag/db/tests/test_tag_db.py
@@ -348,27 +348,67 @@ async def test_get_tags_operations_for_gossip(self):
         # assert that only one torrent returned (the old and the not auto generated one)
         assert len(self.db.get_tags_operations_for_gossip(time_delta)) == 1
 
+    @db_session
+    async def test_get_infohashes_threshold(self):
+        # test that `get_infohashes` function returns only infohashes with tags
+        # above the threshold
+        self.add_operation_set(
+            self.db,
+            {
+                b'infohash1': [
+                    Tag(name='tag1', count=SHOW_THRESHOLD),
+                ],
+                b'infohash2': [
+                    Tag(name='tag1', count=SHOW_THRESHOLD - 1)
+                ]
+            }
+        )
+
+        assert self.db.get_infohashes({'tag1'}) == [b'infohash1']
+
+    @db_session
+    async def test_get_infohashes_hidden_tag(self):
+        # test that a tag below the threshold does not count as a match
+        # when more than one tag is passed to the function
+        self.add_operation_set(
+            self.db,
+            {
+                b'infohash1': [
+                    Tag(name='tag1', count=SHOW_THRESHOLD),
+                    Tag(name='tag2', count=SHOW_THRESHOLD - 1)
+                ],
+                b'infohash2': [
+                    Tag(name='tag1', count=SHOW_THRESHOLD),
+                    Tag(name='tag2', count=SHOW_THRESHOLD)
+                ]
+            }
+        )
+
+        assert self.db.get_infohashes({'tag1', 'tag2'}) == [b'infohash2']
+
     @db_session
     async def test_get_infohashes(self):
+        # test that `get_infohashes` function returns an intersection of results
+        # in case of more than one tag passed to the function
         self.add_operation_set(
             self.db,
             {
                 b'infohash1': [
                     Tag(name='tag1', count=SHOW_THRESHOLD),
-                    Tag(name='tag2', count=SHOW_THRESHOLD - 1)
+                    Tag(name='tag2', count=SHOW_THRESHOLD)
                 ],
                 b'infohash2': [
                     Tag(name='tag1', count=SHOW_THRESHOLD)
                 ],
                 b'infohash3': [
-                    Tag(name='tag1', count=SHOW_THRESHOLD - 1)
+                    Tag(name='tag2', count=SHOW_THRESHOLD)
                 ]
             }
         )
 
-        # test that only tags above the threshold are associated with infohases
-        assert self.db.get_infohashes('tag1') == [b'infohash1', b'infohash2']
-        assert not self.db.get_infohashes('tag2')
+        assert self.db.get_infohashes({'tag1'}) == [b'infohash1', b'infohash2']
+        assert self.db.get_infohashes({'tag2'}) == [b'infohash1', b'infohash3']
+        assert self.db.get_infohashes({'tag1', 'tag2'}) == [b'infohash1']
 
     @db_session
     async def test_show_condition(self):