Skip to content

Commit

Permalink
Fixes from PR
Browse files Browse the repository at this point in the history
  • Loading branch information
drew2a committed Feb 1, 2022
1 parent d42afef commit 0a5ea31
Show file tree
Hide file tree
Showing 18 changed files with 147 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from pony import orm
from pony.orm.core import DEFAULT, db_session

from tribler_core.exceptions import InvalidChannelNodeException, InvalidSignatureException
from tribler_core.components.metadata_store.db.orm_bindings.discrete_clock import clock
from tribler_core.components.metadata_store.db.serialization import (
CHANNEL_NODE,
ChannelNodePayload,
DELETED,
DeletedMetadataPayload,
)
from tribler_core.exceptions import InvalidChannelNodeException, InvalidSignatureException
from tribler_core.utilities.path_util import Path
from tribler_core.utilities.unicode import hexlify

Expand Down Expand Up @@ -87,8 +87,8 @@ class ChannelNode(db.Entity):
# This attribute holds the names of the class attributes that are used by the serializer for the
# corresponding payload type. We only initialize it once on class creation as an optimization.
payload_arguments = _payload_class.__init__.__code__.co_varnames[
: _payload_class.__init__.__code__.co_argcount
][1:]
: _payload_class.__init__.__code__.co_argcount
][1:]

# A non - personal attribute of an entry is an attribute that would have the same value regardless of where,
# when and who created the entry.
Expand Down Expand Up @@ -139,7 +139,7 @@ def __init__(self, *args, **kwargs):
if not private_key_override and not skip_key_check:
# No key/signature given, sign with our own key.
if ("signature" not in kwargs) and (
("public_key" not in kwargs) or (kwargs["public_key"] == self._my_key.pub().key_to_bin()[10:])
("public_key" not in kwargs) or (kwargs["public_key"] == self._my_key.pub().key_to_bin()[10:])
):
private_key_override = self._my_key

Expand Down Expand Up @@ -298,12 +298,15 @@ def make_copy(self, tgt_parent_id, attributes_override=None):
dst_dict.update({"origin_id": tgt_parent_id, "status": NEW})
return self.__class__(**dst_dict)

def get_type(self) -> int:
return self._discriminator_

def to_simple_dict(self):
"""
Return a basic dictionary with information about the node
"""
simple_dict = {
"type": self._discriminator_,
"type": self.get_type(),
"id": self.id_,
"origin_id": self.origin_id,
"public_key": hexlify(self.public_key),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def tdef_to_metadata_dict(tdef):
}


def define_binding(db, notifier: Notifier, tag_version: int):
def define_binding(db, notifier: Notifier, tag_processor_version: int):
class TorrentMetadata(db.MetadataNode):
"""
This ORM binding class is intended to store Torrent objects, i.e. infohashes along with some related metadata.
Expand All @@ -63,7 +63,7 @@ class TorrentMetadata(db.MetadataNode):
# Local
xxx = orm.Optional(float, default=0)
health = orm.Optional('TorrentState', reverse='metadata')
tag_version = orm.Required(int, default=0)
tag_processor_version = orm.Required(int, default=0)

# Special class-level properties
_payload_class = TorrentMetadataPayload
Expand Down Expand Up @@ -93,7 +93,7 @@ def __init__(self, *args, **kwargs):
notifier.notify(NEW_TORRENT_METADATA_CREATED,
infohash=kwargs.get("infohash"),
title=self.title)
self.tag_version = tag_version
self.tag_processor_version = tag_processor_version

def add_tracker(self, tracker_url):
sanitized_url = get_uniformed_tracker_url(tracker_url)
Expand Down Expand Up @@ -140,6 +140,7 @@ def to_simple_dict(self):
"num_leechers": self.health.leechers,
"last_tracker_check": self.health.last_check,
"updated": int((self.torrent_date - epoch).total_seconds()),
"tag_processor_version": self.tag_processor_version,
}
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def __init__(
notifier=None,
check_tables=True,
db_version: int = CURRENT_DB_VERSION,
tag_version: int = 0
tag_processor_version: int = 0
):
self.notifier = notifier # Reference to app-level notification service
self.db_path = db_filename
Expand Down Expand Up @@ -194,7 +194,7 @@ def sqlite_disable_sync(_, connection):
self.TorrentMetadata = torrent_metadata.define_binding(
self._db,
notifier=notifier,
tag_version=tag_version
tag_processor_version=tag_processor_version
)
self.ChannelMetadata = channel_metadata.define_binding(self._db)

Expand Down Expand Up @@ -773,7 +773,7 @@ def get_entries_count(self, **kwargs):
return self.get_entries_query(**kwargs).count()

@db_session
def get_max_rowid(self):
def get_max_rowid(self) -> int:
return select(max(obj.rowid) for obj in self.ChannelNode).get() or 0

fts_keyword_search_re = re.compile(r'\w+', re.UNICODE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def run(self):
key_component.primary_key,
notifier=self.session.notifier,
disable_sync=config.gui_test_mode,
tag_version=TagRulesProcessor.version
tag_processor_version=TagRulesProcessor.version
)
self.mds = metadata_store
self.session.notifier.add_observer(NTFY.TORRENT_METADATA_ADDED.value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def create_node(self, *args, **kwargs):
default_eccrypto.generate_key("curve25519"),
disable_sync=True,
)
self.tags_db = TagDatabase(str(Path(self.temporary_directory()) / "tags.db"), create_tables=True)
self.tags_db = TagDatabase(str(Path(self.temporary_directory()) / "tags.db"))

kwargs['metadata_store'] = self.metadata_store
kwargs['tags_db'] = self.tags_db
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ async def get_channels(self, request):
)
elif channel_dict["state"] == CHANNEL_STATE.METAINFO_LOOKUP.value:
if not self.download_manager.metainfo_requests.get(
bytes(channel.infohash)
bytes(channel.infohash)
) and self.download_manager.download_exists(bytes(channel.infohash)):
channel_dict["state"] = CHANNEL_STATE.DOWNLOADING.value

Expand Down Expand Up @@ -169,6 +169,7 @@ async def get_channels(self, request):
},
)
async def get_channel_contents(self, request):
self._logger.info('Get channel content')
sanitized = self.sanitize_parameters(request.query)
include_total = request.query.get('include_total', '')
channel_pk, channel_id = self.get_channel_from_request(request)
Expand All @@ -180,14 +181,21 @@ async def get_channel_contents(self, request):
remote_failed = False
if remote:
try:
self._logger.info('Receive remote content')
contents_list = await self.gigachannel_community.remote_select_channel_contents(**sanitized)
except (RequestTimeoutException, NoChannelSourcesException, CancelledError):
remote_failed = True
self._logger.info('Remote request failed')


if not remote or remote_failed:
self._logger.info('Receive local content')
with db_session:
contents = self.mds.get_entries(**sanitized)
contents_list = [c.to_simple_dict() for c in contents]
contents_list = []
for entry in contents:
self.process_regular_torrent(entry)
contents_list.append(entry.to_simple_dict())
total = self.mds.get_total_count(**sanitized) if include_total else None
self.add_download_progress_to_metadata_list(contents_list)
self.add_tags_to_metadata_list(contents_list, hide_xxx=sanitized["hide_xxx"])
Expand Down Expand Up @@ -390,9 +398,9 @@ async def add_torrent_to_channel(self, request):
elif uri.startswith("magnet:"):
_, xt, _ = parse_magnetlink(uri)
if (
xt
and is_infohash(codecs.encode(xt, 'hex'))
and (self.mds.torrent_exists_in_personal_channel(xt) or channel.copy_torrent_from_infohash(xt))
xt
and is_infohash(codecs.encode(xt, 'hex'))
and (self.mds.torrent_exists_in_personal_channel(xt) or channel.copy_torrent_from_infohash(xt))
):
return RESTResponse({"added": 1})

Expand Down Expand Up @@ -492,7 +500,11 @@ async def get_popular_torrents_channel(self, request):

with db_session:
contents = self.mds.get_entries(**sanitized)
contents_list = [c.to_simple_dict() for c in contents]
contents_list = []
for entry in contents:
self.process_regular_torrent(entry)
contents_list.append(entry.to_simple_dict())

self.add_download_progress_to_metadata_list(contents_list)
self.add_tags_to_metadata_list(contents_list, hide_xxx=sanitized["hide_xxx"])
response_dict = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tribler_core.components.metadata_store.db.store import MetadataStore
from tribler_core.components.restapi.rest.rest_endpoint import RESTEndpoint
from tribler_core.components.tag.db.tag_db import TagDatabase
from tribler_core.components.tag.rules.tag_rules_processor import TagRulesProcessor

# This dict is used to translate JSON fields into the columns used in Pony for _sorting_.
# id_ is not in the list because there is not index on it, so we never really want to sort on it.
Expand Down Expand Up @@ -37,10 +38,12 @@


class MetadataEndpointBase(RESTEndpoint):
def __init__(self, metadata_store: MetadataStore, *args, tags_db: TagDatabase = None, **kwargs):
def __init__(self, metadata_store: MetadataStore, *args, tags_db: TagDatabase = None,
tag_rules_processor: TagRulesProcessor = None, **kwargs):
super().__init__(*args, **kwargs)
self.mds = metadata_store
self.tags_db: Optional[TagDatabase] = tags_db
self.tag_rules_processor: Optional[TagRulesProcessor] = tag_rules_processor

@classmethod
def sanitize_parameters(cls, parameters):
Expand Down Expand Up @@ -68,6 +71,20 @@ def sanitize_parameters(cls, parameters):
sanitized['metadata_type'] = frozenset(mtypes)
return sanitized

def process_regular_torrent(self, entry):
is_torrent = entry.get_type() == REGULAR_TORRENT
if not is_torrent:
return

if not self.tag_rules_processor:
return

is_auto_generated_tags_not_created = entry.tag_processor_version < self.tag_rules_processor.version
if is_auto_generated_tags_not_created:
generated = self.tag_rules_processor.process_torrent_title(infohash=entry.infohash, title=entry.title)
entry.tag_processor_version = self.tag_rules_processor.version
self._logger.info(f'Generated {generated} tags for {entry.infohash}')

@db_session
def add_tags_to_metadata_list(self, contents_list, hide_xxx=False):
if self.tags_db is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,13 @@ async def run(self):
self.maybe_add('/libtorrent', LibTorrentEndpoint, libtorrent_component.download_manager)
self.maybe_add('/torrentinfo', TorrentInfoEndpoint, libtorrent_component.download_manager)
self.maybe_add('/metadata', MetadataEndpoint, torrent_checker, metadata_store_component.mds,
tags_db=tag_component.tags_db)
tags_db=tag_component.tags_db, tag_rules_processor=tag_component.rules_processor)
self.maybe_add('/channels', ChannelsEndpoint, libtorrent_component.download_manager, gigachannel_manager,
gigachannel_component.community, metadata_store_component.mds, tags_db=tag_component.tags_db)
gigachannel_component.community, metadata_store_component.mds, tags_db=tag_component.tags_db,
tag_rules_processor=tag_component.rules_processor)
self.maybe_add('/collections', ChannelsEndpoint, libtorrent_component.download_manager, gigachannel_manager,
gigachannel_component.community, metadata_store_component.mds, tags_db=tag_component.tags_db)
gigachannel_component.community, metadata_store_component.mds, tags_db=tag_component.tags_db,
tag_rules_processor=tag_component.rules_processor)
self.maybe_add('/search', SearchEndpoint, metadata_store_component.mds, tags_db=tag_component.tags_db)
self.maybe_add('/remote_query', RemoteQueryEndpoint, gigachannel_component.community,
metadata_store_component.mds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async def tearDown(self):
await super().tearDown()

def create_node(self, *args, **kwargs):
return MockIPv8("curve25519", TagCommunity, db=TagDatabase(create_tables=True), tags_key=LibNaCLSK(),
return MockIPv8("curve25519", TagCommunity, db=TagDatabase(), tags_key=LibNaCLSK(),
request_interval=REQUEST_INTERVAL_FOR_RANDOM_TAGS)

def create_operation(self, tag=''):
Expand Down
18 changes: 12 additions & 6 deletions src/tribler-core/tribler_core/components/tag/db/tag_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,19 @@

CLOCK_START_VALUE = 0

# we picked `-1` as the value because it allows manually created tags to get a higher priority
CLOCK_FOR_AUTOGENERATED_TAGS = CLOCK_START_VALUE - 1
PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS = b'auto_generated'

SHOW_THRESHOLD = 1
HIDE_THRESHOLD = -2


class TagDatabase:
def __init__(self, filename: Optional[str] = None, **kwargs):
def __init__(self, filename: Optional[str] = None, *, create_tables: bool = True, **generate_mapping_kwargs):
self.instance = orm.Database()
self.define_binding(self.instance)
self.instance.bind('sqlite', filename or ':memory:', create_db=True)
self.instance.generate_mapping(**kwargs)
generate_mapping_kwargs['create_tables'] = create_tables
self.instance.generate_mapping(**generate_mapping_kwargs)
self.logger = logging.getLogger(self.__class__.__name__)

@staticmethod
Expand Down Expand Up @@ -95,7 +94,6 @@ class TorrentTagOp(db.Entity):

orm.composite_key(torrent_tag, peer)


def add_tag_operation(self, operation: TagOperation, signature: bytes, is_local_peer: bool = False,
is_auto_generated: bool = False, counter_increment: int = 1) -> bool:
""" Add the operation that will be applied to the tag.
Expand Down Expand Up @@ -136,7 +134,15 @@ def add_tag_operation(self, operation: TagOperation, signature: bytes, is_local_
updated_at=datetime.datetime.utcnow(), auto_generated=is_auto_generated)
return True

def add_auto_generated_tag_operation(self, operation: TagOperation):
    def add_auto_generated_tag(self, infohash: bytes, tag: str):
        """Record *tag* for *infohash* as an automatically generated tag.

        The operation is attributed to the reserved auto-generation public key
        rather than to any real peer, and starts at CLOCK_START_VALUE.
        """
        operation = TagOperation(
            infohash=infohash,
            operation=TagOperationEnum.ADD,
            clock=CLOCK_START_VALUE,
            creator_public_key=PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS,
            tag=tag
        )

        # No signature and not the local peer; counter_increment=SHOW_THRESHOLD —
        # presumably so the tag reaches the display threshold immediately
        # (TODO confirm against the TorrentTag counting logic).
        self.add_tag_operation(operation, signature=b'', is_local_peer=False, is_auto_generated=True,
                               counter_increment=SHOW_THRESHOLD)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,16 @@
from dataclasses import dataclass
from itertools import count
from types import SimpleNamespace
from unittest.mock import Mock, patch

from ipv8.test.base import TestBase

from pony import orm
from pony.orm import commit, db_session

# pylint: disable=protected-access
from tribler_core.components.tag.community.tag_payload import TagOperation, TagOperationEnum
from tribler_core.components.tag.db.tag_db import (
CLOCK_FOR_AUTOGENERATED_TAGS,
PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS,
SHOW_THRESHOLD,
TagDatabase,
)
from tribler_core.components.tag.db.tag_db import PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS, SHOW_THRESHOLD, TagDatabase
from tribler_core.utilities.pony_utils import get_or_create


Expand All @@ -28,7 +25,7 @@ class Tag:
class TestTagDB(TestBase):
def setUp(self):
super().setUp()
self.db = TagDatabase(create_tables=True)
self.db = TagDatabase()

async def tearDown(self):
if self._outcome.errors:
Expand Down Expand Up @@ -83,6 +80,16 @@ def generate_n_peer_names(n):
for peer in generate_n_peer_names(tag.count):
TestTagDB.add_operation(tag_db, infohash, tag.name, peer, is_auto_generated=tag.auto_generated)

    @patch.object(orm.Database, 'generate_mapping')
    def test_constructor_create_tables_true(self, mocked_generate_mapping: Mock):
        # By default the constructor must ask Pony to create the tables.
        TagDatabase(':memory:')
        mocked_generate_mapping.assert_called_with(create_tables=True)

    @patch.object(orm.Database, 'generate_mapping')
    def test_constructor_create_tables_false(self, mocked_generate_mapping: Mock):
        # An explicit create_tables=False must be forwarded to generate_mapping.
        TagDatabase(':memory:', create_tables=False)
        mocked_generate_mapping.assert_called_with(create_tables=False)

@db_session
async def test_get_or_create(self):
# Test that function get_or_create() works as expected:
Expand Down Expand Up @@ -188,19 +195,13 @@ async def test_remote_add_multiple_tag_operations(self):
assert self.db.instance.TorrentTag.get().removed_count == 1

@db_session
async def test_add_auto_generated_operation(self):
self.db.add_auto_generated_tag_operation(
operation=TagOperation(
infohash=b'infohash',
operation=TagOperationEnum.ADD,
clock=CLOCK_FOR_AUTOGENERATED_TAGS,
creator_public_key=PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS,
tag='tag'
)
async def test_add_auto_generated_tag(self):
self.db.add_auto_generated_tag(
infohash=b'infohash',
tag='tag'
)

assert self.db.instance.TorrentTagOp.get().auto_generated
assert self.db.instance.TorrentTagOp.get().clock == CLOCK_FOR_AUTOGENERATED_TAGS
assert self.db.instance.TorrentTag.get().added_count == SHOW_THRESHOLD
assert self.db.instance.Peer.get().public_key == PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,7 @@ def extract_tags(text: str, rules: Optional[RulesList] = None) -> Iterable[str]:


def extract_only_valid_tags(text: str, rules: Optional[RulesList] = None) -> Iterable[str]:
    """Yield the lowercased tags extracted from *text* that pass validity checks.

    :param text: the text to extract tags from.
    :param rules: optional rule list forwarded to ``extract_tags``.
    """
    # The diff-rendered span contained both the old generator-expression body and
    # the new loop body back-to-back, which would yield every valid tag twice;
    # keep a single implementation.
    for tag in extract_tags(text, rules):
        tag = tag.lower()
        if is_valid_tag(tag):
            yield tag
Loading

0 comments on commit 0a5ea31

Please sign in to comment.