Skip to content

Commit

Permalink
Add remote search for tags
Browse files Browse the repository at this point in the history
  • Loading branch information
drew2a committed Jan 6, 2022
1 parent d8cf392 commit cf2bc24
Show file tree
Hide file tree
Showing 12 changed files with 209 additions and 40 deletions.
1 change: 1 addition & 0 deletions src/tribler-common/tribler_common/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def test_extract_tags():
assert extract_tags('####') == (set(), '####')

assert extract_tags('#tag') == ({'tag'}, '')
assert extract_tags('#Tag') == ({'tag'}, '')
assert extract_tags('a #tag in the middle') == ({'tag'}, 'a in the middle')
assert extract_tags('at the end of the query #tag') == ({'tag'}, 'at the end of the query ')
assert extract_tags('multiple tags: #tag1 #tag2#tag3') == ({'tag1', 'tag2', 'tag3'}, 'multiple tags: ')
Expand Down
3 changes: 2 additions & 1 deletion src/tribler-common/tribler_common/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def extract_tags(text: str) -> Tuple[Set[str], str]:
positions = [0]

for m in tags_re.finditer(text):
tags.add(m.group(0)[1:])
tag = m.group(0)[1:]
tags.add(tag.lower())
positions.extend(itertools.chain.from_iterable(m.regs))
positions.append(len(text))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tribler_core.components.ipv8.ipv8_component import INFINITE, Ipv8Component
from tribler_core.components.metadata_store.metadata_store_component import MetadataStoreComponent
from tribler_core.components.reporter.reporter_component import ReporterComponent
from tribler_core.components.tag.tag_component import TagComponent


class GigaChannelComponent(Component):
Expand All @@ -24,6 +25,7 @@ async def run(self):

self._ipv8_component = await self.require_component(Ipv8Component)
metadata_store_component = await self.require_component(MetadataStoreComponent)
tag_component = await self.get_component(TagComponent)

giga_channel_cls = GigaChannelTestnetCommunity if config.general.testnet else GigaChannelCommunity
community = giga_channel_cls(
Expand All @@ -35,6 +37,7 @@ async def run(self):
rqc_settings=config.remote_query_community,
metadata_store=metadata_store_component.mds,
max_peers=50,
tags_db=tag_component.tags_db if tag_component else None
)
self.community = community
self._ipv8_component.initialise_community_by_default(community, default_random_walk_max_peers=30)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import struct
from asyncio import Future
from binascii import unhexlify
from typing import List, Optional, Set

from ipv8.lazy_community import lazy_wrapper
from ipv8.messaging.lazy_payload import VariablePayload, vp_compile
from ipv8.requestcache import NumberCache, RandomNumberCache, RequestCache

from pony.orm import db_session
from pony.orm.dbapiprovider import OperationalError

from tribler_core.components.ipv8.tribler_community import TriblerCommunity
Expand All @@ -17,6 +19,7 @@
from tribler_core.components.metadata_store.remote_query_community.payload_checker import ObjState
from tribler_core.components.metadata_store.remote_query_community.settings import RemoteQueryCommunitySettings
from tribler_core.components.metadata_store.utils import RequestTimeoutException
from tribler_core.components.tag.community.tag_validator import is_valid_tag
from tribler_core.utilities.unicode import hexlify

BINARY_FIELDS = ("infohash", "channel_pk")
Expand Down Expand Up @@ -129,12 +132,13 @@ class RemoteQueryCommunity(TriblerCommunity, EVAProtocolMixin):
def __init__(self, my_peer, endpoint, network,
rqc_settings: RemoteQueryCommunitySettings = None,
metadata_store=None,
tags_db=None,
**kwargs):
super().__init__(my_peer, endpoint, network=network, **kwargs)

self.rqc_settings = rqc_settings
self.mds: MetadataStore = metadata_store

self.tags_db = tags_db
# This object stores requests for "select" queries that we sent to other hosts.
# We keep track of peers we actually requested for data so people can't randomly push spam at us.
# Also, this keeps track of hosts we responded to. There is a possibility that
Expand Down Expand Up @@ -188,8 +192,23 @@ async def process_rpc_query(self, json_bytes: bytes):
:raises ValueError: if no JSON could be decoded.
:raises pony.orm.dbapiprovider.OperationalError: if an illegal query was performed.
"""
request_sanitized = sanitize_query(json.loads(json_bytes), self.rqc_settings.max_response_size)
return await self.mds.get_entries_threaded(**request_sanitized)
parameters = json.loads(json_bytes)
sanitized_parameters = sanitize_query(parameters, self.rqc_settings.max_response_size)

# tags should be extracted because `get_entries_threaded` doesn't expect them as a parameter
tags = sanitized_parameters.pop('tags', None)

infohash_set = await self.mds.run_threaded(self.search_for_tags, tags)
sanitized_parameters['infohash_set'] = infohash_set # it could be None, it is expected

return await self.mds.get_entries_threaded(**sanitized_parameters)

@db_session
def search_for_tags(self, tags: Optional[List[str]]) -> Optional[Set[bytes]]:
if not tags or not self.tags_db:
return None
valid_tags = {tag for tag in tags if is_valid_tag(tag)}
return self.tags_db.get_infohashes(valid_tags)

def send_db_results(self, peer, request_payload_id, db_results, force_eva_response=False):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def setUp(self):
self.count = 0
self.metadata_store_set = set()
self.initialize(BasicRemoteQueryCommunity, 2)
self.torrent_template = {"title": "", "infohash": b"", "torrent_date": datetime(1970, 1, 1), "tags": "video"}

async def tearDown(self):
for metadata_store in self.metadata_store_set:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from json import dumps
from unittest.mock import AsyncMock, Mock, PropertyMock, patch

from ipv8.keyvault.crypto import default_eccrypto
from ipv8.test.base import TestBase

from pony.orm import db_session

from tribler_core.components.metadata_store.db.orm_bindings.channel_node import NEW
from tribler_core.components.metadata_store.db.store import MetadataStore
from tribler_core.components.metadata_store.remote_query_community.remote_query_community import RemoteQueryCommunity
from tribler_core.components.metadata_store.remote_query_community.settings import RemoteQueryCommunitySettings
from tribler_core.components.metadata_store.remote_query_community.tests.test_remote_query_community import (
BasicRemoteQueryCommunity,
)
from tribler_core.components.tag.db.tag_db import TagDatabase
from tribler_core.components.tag.db.tests.test_tag_db import Tag, TestTagDB
from tribler_core.utilities.path_util import Path


class TestRemoteSearchByTags(TestBase):
""" In this test set we will use only one node's instance as it is sufficient
for testing remote search by tags
"""

def setUp(self):
super().setUp()
self.metadata_store = None
self.tags_db = None
self.initialize(BasicRemoteQueryCommunity, 1)

async def tearDown(self):
if self.metadata_store:
self.metadata_store.shutdown()
if self.tags_db:
self.tags_db.shutdown()

await super().tearDown()

def create_node(self, *args, **kwargs):
self.metadata_store = MetadataStore(
Path(self.temporary_directory()) / "mds.db",
Path(self.temporary_directory()),
default_eccrypto.generate_key("curve25519"),
disable_sync=True,
)
self.tags_db = TagDatabase(str(Path(self.temporary_directory()) / "tags.db"))

kwargs['metadata_store'] = self.metadata_store
kwargs['tags_db'] = self.tags_db
kwargs['rqc_settings'] = RemoteQueryCommunitySettings()
return super().create_node(*args, **kwargs)

@property
def rqc(self) -> RemoteQueryCommunity:
return self.overlay(0)

@patch.object(RemoteQueryCommunity, 'tags_db', new=PropertyMock(return_value=None), create=True)
async def test_search_for_tags_no_db(self):
# test that in case of missed `tags_db`, function `search_for_tags` returns None
assert self.rqc.search_for_tags(tags=['tag']) is None

@patch.object(TagDatabase, 'get_infohashes')
async def test_search_for_tags_only_valid_tags(self, mocked_get_infohashes: Mock):
# test that function `search_for_tags` uses only valid tags
self.rqc.search_for_tags(tags=['invalid tag', 'valid_tag'])
mocked_get_infohashes.assert_called_with({'valid_tag'})

@patch.object(MetadataStore, 'get_entries_threaded', new_callable=AsyncMock)
async def test_process_rpc_query_no_tags(self, mocked_get_entries_threaded: AsyncMock):
# test that in case of missed tags, the remote search works like normal remote search
parameters = {'first': 0, 'infohash_set': None, 'last': 100}
json = dumps(parameters).encode('utf-8')

await self.rqc.process_rpc_query(json)

expected_parameters = {'infohash_set': None}
expected_parameters.update(parameters)
mocked_get_entries_threaded.assert_called_with(**expected_parameters)

async def test_process_rpc_query_with_tags(self):
# This is full test that checked whether search by tags works or not
#
# Test assumes that two databases were filled by the following data (TagsDatabase and MDS):
@db_session
def fill_tags_database():
TestTagDB.add_operation_set(
self.rqc.tags_db,
{
b'infohash1': [
Tag(name='tag1', count=2),
],
b'infohash2': [
Tag(name='tag2', count=1),
]
})

@db_session
def fill_mds():
with db_session:
def _add(infohash):
torrent = {"infohash": infohash, "title": 'title', "tags": "", "size": 1, "status": NEW}
self.rqc.mds.TorrentMetadata.from_dict(torrent)

_add(b'infohash1')
_add(b'infohash2')
_add(b'infohash3')

fill_tags_database()
fill_mds()

# Then we try to query search for three tags: 'tag1', 'tag2', 'tag3'
parameters = {'first': 0, 'infohash_set': None, 'last': 100, 'tags': ['tag1', 'tag2', 'tag3']}
json = dumps(parameters).encode('utf-8')

with db_session:
query_results = [r.to_dict() for r in await self.rqc.process_rpc_query(json)]

# Expected results: only one infohash (b'infohash1') should be returned.
result_infohash_list = [r['infohash'] for r in query_results]
assert result_infohash_list == [b'infohash1']
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,14 @@ def sanitize_parameters(self, parameters):
)
@querystring_schema(RemoteQueryParameters)
async def create_remote_search_request(self, request):
self._logger.info('Create remote search request')
# Query remote results from the GigaChannel Community.
# Results are returned over the Events endpoint.
try:
sanitized = self.sanitize_parameters(request.query)
except (ValueError, KeyError) as e:
return RESTResponse({"error": f"Error processing request parameters: {e}"}, status=HTTP_BAD_REQUEST)
self._logger.info(f'Parameters: {sanitized}')

request_uuid, peers_list = self.gigachannel_community.send_search_request(**sanitized)
peers_mid_list = [hexlify(p.mid) for p in peers_list]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ def search_db():
try:
with db_session:
if tags:
lower_tags = {tag.lower() for tag in tags}
infohash_set = self.tags_db.get_infohashes(lower_tags)
infohash_set = self.tags_db.get_infohashes(set(tags))
sanitized['infohash_set'] = infohash_set

search_results, total, max_rowid = await mds.run_threaded(search_db)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,13 @@ def validate_tag(tag: str):
raise ValueError('Tag should not contain any spaces')


def is_valid_tag(tag: str) -> bool:
try:
validate_tag(tag)
except ValueError:
return False
return True


def validate_operation(operation: int):
TagOperationEnum(operation)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from tribler_core.components.tag.community.tag_payload import TagOperationEnum
from tribler_core.components.tag.community.tag_validator import validate_operation, validate_tag
from tribler_core.components.tag.community.tag_validator import is_valid_tag, validate_operation, validate_tag

pytestmark = pytest.mark.asyncio

Expand Down Expand Up @@ -49,3 +49,10 @@ async def test_contains_upper_case_not_latin():
async def test_contain_any_space():
with pytest.raises(ValueError):
validate_tag('tag with space')


async def test_is_valid_tag():
# test that is_valid_tag works similar to validate_tag but it returns `bool`
# instead of raise the ValueError exception
assert is_valid_tag('valid-tag')
assert not is_valid_tag('invalid tag')
Loading

0 comments on commit cf2bc24

Please sign in to comment.