Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tags auto generation #6718

Merged
merged 5 commits into from
Feb 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/tribler-core/run_tribler_upgrader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import signal
import sys

Expand All @@ -10,17 +11,21 @@
from tribler_core.upgrade.upgrade import TriblerUpgrader
from tribler_core.utilities.path_util import Path

logger = logging.getLogger(__name__)


def upgrade_state_dir(root_state_dir: Path,
update_status_callback=None,
interrupt_upgrade_event=None):
logger.info('Upgrade state dir')
# Before any upgrade, prepare a separate state directory for the update version so it does not
# affect the older version state directory. This allows for safe rollback.
version_history = VersionHistory(root_state_dir)
version_history.fork_state_directory_if_necessary()
version_history.save_if_necessary()
state_dir = version_history.code_version.directory
if not state_dir.exists():
logger.info('State dir does not exist. Exit upgrade procedure.')
return

config = TriblerConfig.load(file=state_dir / CONFIG_FILE_NAME, state_dir=state_dir, reset_config_on_error=True)
Expand All @@ -37,15 +42,19 @@ def upgrade_state_dir(root_state_dir: Path,


if __name__ == "__main__":

logger.info('Run')
_upgrade_interrupted_event = []


def interrupt_upgrade(*_args):
    """Signal handler: record that the upgrade was interrupted (SIGINT/SIGTERM)."""
    logger.info('Interrupt upgrade')
    # Appending to the shared list makes upgrade_interrupted() return True.
    _upgrade_interrupted_event += [True]


def upgrade_interrupted():
    """Return True once the interrupt signal handler has fired at least once."""
    return len(_upgrade_interrupted_event) > 0


signal.signal(signal.SIGINT, interrupt_upgrade)
signal.signal(signal.SIGTERM, interrupt_upgrade)
_root_state_dir = Path(sys.argv[1])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from pony import orm
from pony.orm.core import DEFAULT, db_session

from tribler_core.exceptions import InvalidChannelNodeException, InvalidSignatureException
from tribler_core.components.metadata_store.db.orm_bindings.discrete_clock import clock
from tribler_core.components.metadata_store.db.serialization import (
CHANNEL_NODE,
ChannelNodePayload,
DELETED,
DeletedMetadataPayload,
)
from tribler_core.exceptions import InvalidChannelNodeException, InvalidSignatureException
from tribler_core.utilities.path_util import Path
from tribler_core.utilities.unicode import hexlify

Expand Down Expand Up @@ -87,8 +87,8 @@ class ChannelNode(db.Entity):
# This attribute holds the names of the class attributes that are used by the serializer for the
# corresponding payload type. We only initialize it once on class creation as an optimization.
payload_arguments = _payload_class.__init__.__code__.co_varnames[
: _payload_class.__init__.__code__.co_argcount
][1:]
: _payload_class.__init__.__code__.co_argcount
][1:]

# A non - personal attribute of an entry is an attribute that would have the same value regardless of where,
# when and who created the entry.
Expand Down Expand Up @@ -139,7 +139,7 @@ def __init__(self, *args, **kwargs):
if not private_key_override and not skip_key_check:
# No key/signature given, sign with our own key.
if ("signature" not in kwargs) and (
("public_key" not in kwargs) or (kwargs["public_key"] == self._my_key.pub().key_to_bin()[10:])
("public_key" not in kwargs) or (kwargs["public_key"] == self._my_key.pub().key_to_bin()[10:])
):
private_key_override = self._my_key

Expand Down Expand Up @@ -298,12 +298,15 @@ def make_copy(self, tgt_parent_id, attributes_override=None):
dst_dict.update({"origin_id": tgt_parent_id, "status": NEW})
return self.__class__(**dst_dict)

def get_type(self) -> int:
    """Return the metadata-type discriminator (serialization type id) of this entry."""
    type_id = self._discriminator_
    return type_id

def to_simple_dict(self):
"""
Return a basic dictionary with information about the node
"""
simple_dict = {
"type": self._discriminator_,
"type": self.get_type(),
"id": self.id_,
"origin_id": self.origin_id,
"public_key": hexlify(self.public_key),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
from tribler_core.components.metadata_store.category_filter.family_filter import default_xxx_filter
from tribler_core.components.metadata_store.db.orm_bindings.channel_node import COMMITTED
from tribler_core.components.metadata_store.db.serialization import EPOCH, REGULAR_TORRENT, TorrentMetadataPayload
from tribler_core.notifier import Notifier
from tribler_core.utilities.tracker_utils import get_uniformed_tracker_url
from tribler_core.utilities.unicode import ensure_unicode, hexlify

NULL_KEY_SUBST = b"\00"
NEW_TORRENT_METADATA_CREATED: str = 'TorrentMetadata:new_torrent_metadata_created'


# This function is used to devise id_ from infohash in deterministic way. Used in FFA channels.
Expand Down Expand Up @@ -44,7 +46,7 @@ def tdef_to_metadata_dict(tdef):
}


def define_binding(db):
def define_binding(db, notifier: Notifier, tag_processor_version: int):
class TorrentMetadata(db.MetadataNode):
"""
This ORM binding class is intended to store Torrent objects, i.e. infohashes along with some related metadata.
Expand All @@ -61,12 +63,13 @@ class TorrentMetadata(db.MetadataNode):
# Local
xxx = orm.Optional(float, default=0)
health = orm.Optional('TorrentState', reverse='metadata')
tag_processor_version = orm.Required(int, default=0)

# Special class-level properties
_payload_class = TorrentMetadataPayload
payload_arguments = _payload_class.__init__.__code__.co_varnames[
: _payload_class.__init__.__code__.co_argcount
][1:]
: _payload_class.__init__.__code__.co_argcount
][1:]
nonpersonal_attributes = db.MetadataNode.nonpersonal_attributes + (
'infohash',
'size',
Expand All @@ -86,6 +89,11 @@ def __init__(self, *args, **kwargs):

if 'tracker_info' in kwargs:
self.add_tracker(kwargs["tracker_info"])
if notifier:
notifier.notify(NEW_TORRENT_METADATA_CREATED,
infohash=kwargs.get("infohash"),
title=self.title)
self.tag_processor_version = tag_processor_version

def add_tracker(self, tracker_url):
sanitized_url = get_uniformed_tracker_url(tracker_url)
Expand Down Expand Up @@ -132,6 +140,7 @@ def to_simple_dict(self):
"num_leechers": self.health.leechers,
"last_tracker_check": self.health.last_check,
"updated": int((self.torrent_date - epoch).total_seconds()),
"tag_processor_version": self.tag_processor_version,
}
)

Expand Down
95 changes: 54 additions & 41 deletions src/tribler-core/tribler_core/components/metadata_store/db/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from asyncio import get_event_loop
from datetime import datetime, timedelta
from time import sleep, time
from typing import Union
from typing import Optional, Union

from lz4.frame import LZ4FrameDecompressor

Expand Down Expand Up @@ -49,19 +49,19 @@
from tribler_core.components.metadata_store.remote_query_community.payload_checker import process_payload
from tribler_core.exceptions import InvalidSignatureException
from tribler_core.utilities.path_util import Path
from tribler_core.utilities.pony_utils import get_or_create
from tribler_core.utilities.unicode import hexlify
from tribler_core.utilities.utilities import MEMORY_DB

BETA_DB_VERSIONS = [0, 1, 2, 3, 4, 5]
CURRENT_DB_VERSION = 13
CURRENT_DB_VERSION = 14

MIN_BATCH_SIZE = 10
MAX_BATCH_SIZE = 1000

POPULAR_TORRENTS_FRESHNESS_PERIOD = 60 * 60 * 24 # Last day
POPULAR_TORRENTS_COUNT = 100


# This table should never be used from ORM directly.
# It is created as a VIRTUAL table by raw SQL and
# maintained by SQL triggers.
Expand Down Expand Up @@ -136,14 +136,15 @@

class MetadataStore:
def __init__(
self,
db_filename: Union[Path, type(MEMORY_DB)],
channels_dir,
my_key,
disable_sync=False,
notifier=None,
check_tables=True,
db_version: int = CURRENT_DB_VERSION,
self,
devos50 marked this conversation as resolved.
Show resolved Hide resolved
db_filename: Union[Path, type(MEMORY_DB)],
channels_dir,
my_key,
disable_sync=False,
notifier=None,
check_tables=True,
db_version: int = CURRENT_DB_VERSION,
tag_processor_version: int = 0
):
self.notifier = notifier # Reference to app-level notification service
self.db_path = db_filename
Expand Down Expand Up @@ -190,7 +191,11 @@ def sqlite_disable_sync(_, connection):

self.MetadataNode = metadata_node.define_binding(self._db)
self.CollectionNode = collection_node.define_binding(self._db)
self.TorrentMetadata = torrent_metadata.define_binding(self._db)
self.TorrentMetadata = torrent_metadata.define_binding(
self._db,
notifier=notifier,
tag_processor_version=tag_processor_version
)
self.ChannelMetadata = channel_metadata.define_binding(self._db)

self.JsonNode = json_node.define_binding(self._db, db_version)
Expand Down Expand Up @@ -242,6 +247,14 @@ def wrapper():

return await get_event_loop().run_in_executor(None, wrapper)

def set_value(self, key: str, value: str):
    """Store *value* in the MiscData key/value table under *key*, creating the row if absent."""
    get_or_create(self.MiscData, name=key).value = value

def get_value(self, key: str, default: Optional[str] = None) -> Optional[str]:
    """Look up *key* in the MiscData table; return its value, or *default* if no row exists."""
    row = self.MiscData.get(name=key)
    if row:
        return row.value
    return default

def drop_indexes(self):
cursor = self._db.get_connection().cursor()
cursor.execute("select name from sqlite_master where type='index' and name like 'idx_%'")
Expand Down Expand Up @@ -391,9 +404,9 @@ def process_channel_dir(self, dirname, public_key, id_, **kwargs):
if not channel:
return
if (
blob_sequence_number <= channel.start_timestamp
or blob_sequence_number <= channel.local_version
or blob_sequence_number > channel.timestamp
blob_sequence_number <= channel.start_timestamp
or blob_sequence_number <= channel.local_version
or blob_sequence_number > channel.timestamp
):
continue
try:
Expand Down Expand Up @@ -595,28 +608,28 @@ def search_keyword(self, query, lim=100):

@db_session
def get_entries_query(
self,
metadata_type=None,
channel_pk=None,
exclude_deleted=False,
hide_xxx=False,
exclude_legacy=False,
origin_id=None,
sort_by=None,
sort_desc=True,
max_rowid=None,
txt_filter=None,
subscribed=None,
category=None,
attribute_ranges=None,
infohash=None,
infohash_set=None,
id_=None,
complete_channel=None,
self_checked_torrent=None,
cls=None,
health_checked_after=None,
popular=None,
self,
metadata_type=None,
channel_pk=None,
exclude_deleted=False,
hide_xxx=False,
exclude_legacy=False,
origin_id=None,
sort_by=None,
sort_desc=True,
max_rowid=None,
txt_filter=None,
subscribed=None,
category=None,
attribute_ranges=None,
infohash=None,
infohash_set=None,
id_=None,
complete_channel=None,
self_checked_torrent=None,
cls=None,
health_checked_after=None,
popular=None,
):
"""
This method implements REST-friendly way to get entries from the database.
Expand Down Expand Up @@ -662,8 +675,8 @@ def get_entries_query(
if attribute_ranges is not None:
for attr, left, right in attribute_ranges:
if (
self.ChannelNode._adict_.get(attr) # pylint: disable=W0212
or self.ChannelNode._subclass_adict_.get(attr) # pylint: disable=W0212
self.ChannelNode._adict_.get(attr) # pylint: disable=W0212
or self.ChannelNode._subclass_adict_.get(attr) # pylint: disable=W0212
) is None: # Check against code injection
raise AttributeError("Tried to query for non-existent attribute")
if left is not None:
Expand Down Expand Up @@ -737,7 +750,7 @@ def get_entries(self, first=1, last=None, **kwargs):
:return: A list of class members
"""
pony_query = self.get_entries_query(**kwargs)
result = pony_query[(first or 1) - 1 : last]
result = pony_query[(first or 1) - 1: last]
for entry in result:
# ACHTUNG! This is necessary in order to load entry.health inside db_session,
# to be able to perform successfully `entry.to_simple_dict()` later
Expand All @@ -760,7 +773,7 @@ def get_entries_count(self, **kwargs):
return self.get_entries_query(**kwargs).count()

@db_session
def get_max_rowid(self):
def get_max_rowid(self) -> int:
    # Highest rowid among all ChannelNode rows. The Pony ORM select(...).get()
    # yields None for an empty table, so `or 0` normalizes that to 0.
    return select(max(obj.rowid) for obj in self.ChannelNode).get() or 0

fts_keyword_search_re = re.compile(r'\w+', re.UNICODE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from tribler_core.components.base import Component
from tribler_core.components.key.key_component import KeyComponent
from tribler_core.components.metadata_store.db.store import MetadataStore
from tribler_core.components.metadata_store.utils import generate_test_channels
from tribler_core.components.tag.tag_component import TagComponent
from tribler_core.components.tag.rules.tag_rules_processor import TagRulesProcessor


class MetadataStoreComponent(Component):
Expand Down Expand Up @@ -42,13 +41,11 @@ async def run(self):
key_component.primary_key,
notifier=self.session.notifier,
disable_sync=config.gui_test_mode,
tag_processor_version=TagRulesProcessor.version
)
self.mds = metadata_store
self.session.notifier.add_observer(NTFY.TORRENT_METADATA_ADDED,
self.session.notifier.add_observer(NTFY.TORRENT_METADATA_ADDED.value,
metadata_store.TorrentMetadata.add_ffa_from_dict)
if config.gui_test_mode:
tag_component = await self.require_component(TagComponent)
generate_test_channels(metadata_store, tag_component.tags_db)

async def shutdown(self):
await super().shutdown()
Expand Down
Loading