diff --git a/src/tribler-core/tribler_core/components/metadata_store/db/orm_bindings/channel_node.py b/src/tribler-core/tribler_core/components/metadata_store/db/orm_bindings/channel_node.py index 1dfffef204a..e4913e56e42 100644 --- a/src/tribler-core/tribler_core/components/metadata_store/db/orm_bindings/channel_node.py +++ b/src/tribler-core/tribler_core/components/metadata_store/db/orm_bindings/channel_node.py @@ -6,7 +6,6 @@ from pony import orm from pony.orm.core import DEFAULT, db_session -from tribler_core.exceptions import InvalidChannelNodeException, InvalidSignatureException from tribler_core.components.metadata_store.db.orm_bindings.discrete_clock import clock from tribler_core.components.metadata_store.db.serialization import ( CHANNEL_NODE, @@ -14,6 +13,7 @@ DELETED, DeletedMetadataPayload, ) +from tribler_core.exceptions import InvalidChannelNodeException, InvalidSignatureException from tribler_core.utilities.path_util import Path from tribler_core.utilities.unicode import hexlify @@ -87,8 +87,8 @@ class ChannelNode(db.Entity): # This attribute holds the names of the class attributes that are used by the serializer for the # corresponding payload type. We only initialize it once on class creation as an optimization. payload_arguments = _payload_class.__init__.__code__.co_varnames[ - : _payload_class.__init__.__code__.co_argcount - ][1:] + : _payload_class.__init__.__code__.co_argcount + ][1:] # A non - personal attribute of an entry is an attribute that would have the same value regardless of where, # when and who created the entry. @@ -139,7 +139,7 @@ def __init__(self, *args, **kwargs): if not private_key_override and not skip_key_check: # No key/signature given, sign with our own key. 
if ("signature" not in kwargs) and ( - ("public_key" not in kwargs) or (kwargs["public_key"] == self._my_key.pub().key_to_bin()[10:]) + ("public_key" not in kwargs) or (kwargs["public_key"] == self._my_key.pub().key_to_bin()[10:]) ): private_key_override = self._my_key @@ -298,12 +298,15 @@ def make_copy(self, tgt_parent_id, attributes_override=None): dst_dict.update({"origin_id": tgt_parent_id, "status": NEW}) return self.__class__(**dst_dict) + def get_type(self) -> int: + return self._discriminator_ + def to_simple_dict(self): """ Return a basic dictionary with information about the node """ simple_dict = { - "type": self._discriminator_, + "type": self.get_type(), "id": self.id_, "origin_id": self.origin_id, "public_key": hexlify(self.public_key), diff --git a/src/tribler-core/tribler_core/components/metadata_store/db/orm_bindings/torrent_metadata.py b/src/tribler-core/tribler_core/components/metadata_store/db/orm_bindings/torrent_metadata.py index 2ac9dec0be3..e7e15b620d7 100644 --- a/src/tribler-core/tribler_core/components/metadata_store/db/orm_bindings/torrent_metadata.py +++ b/src/tribler-core/tribler_core/components/metadata_store/db/orm_bindings/torrent_metadata.py @@ -46,7 +46,7 @@ def tdef_to_metadata_dict(tdef): } -def define_binding(db, notifier: Notifier, tag_version: int): +def define_binding(db, notifier: Notifier, tag_processor_version: int): class TorrentMetadata(db.MetadataNode): """ This ORM binding class is intended to store Torrent objects, i.e. infohashes along with some related metadata. 
@@ -63,7 +63,7 @@ class TorrentMetadata(db.MetadataNode): # Local xxx = orm.Optional(float, default=0) health = orm.Optional('TorrentState', reverse='metadata') - tag_version = orm.Required(int, default=0) + tag_processor_version = orm.Required(int, default=0) # Special class-level properties _payload_class = TorrentMetadataPayload @@ -93,7 +93,7 @@ def __init__(self, *args, **kwargs): notifier.notify(NEW_TORRENT_METADATA_CREATED, infohash=kwargs.get("infohash"), title=self.title) - self.tag_version = tag_version + self.tag_processor_version = tag_processor_version def add_tracker(self, tracker_url): sanitized_url = get_uniformed_tracker_url(tracker_url) @@ -140,6 +140,7 @@ def to_simple_dict(self): "num_leechers": self.health.leechers, "last_tracker_check": self.health.last_check, "updated": int((self.torrent_date - epoch).total_seconds()), + "tag_processor_version": self.tag_processor_version, } ) diff --git a/src/tribler-core/tribler_core/components/metadata_store/db/store.py b/src/tribler-core/tribler_core/components/metadata_store/db/store.py index 36cb2c6d8be..ae015cd7627 100644 --- a/src/tribler-core/tribler_core/components/metadata_store/db/store.py +++ b/src/tribler-core/tribler_core/components/metadata_store/db/store.py @@ -144,7 +144,7 @@ def __init__( notifier=None, check_tables=True, db_version: int = CURRENT_DB_VERSION, - tag_version: int = 0 + tag_processor_version: int = 0 ): self.notifier = notifier # Reference to app-level notification service self.db_path = db_filename @@ -194,7 +194,7 @@ def sqlite_disable_sync(_, connection): self.TorrentMetadata = torrent_metadata.define_binding( self._db, notifier=notifier, - tag_version=tag_version + tag_processor_version=tag_processor_version ) self.ChannelMetadata = channel_metadata.define_binding(self._db) @@ -773,7 +773,7 @@ def get_entries_count(self, **kwargs): return self.get_entries_query(**kwargs).count() @db_session - def get_max_rowid(self): + def get_max_rowid(self) -> int: return 
select(max(obj.rowid) for obj in self.ChannelNode).get() or 0 fts_keyword_search_re = re.compile(r'\w+', re.UNICODE) diff --git a/src/tribler-core/tribler_core/components/metadata_store/metadata_store_component.py b/src/tribler-core/tribler_core/components/metadata_store/metadata_store_component.py index 965d7afc663..1b317b80d14 100644 --- a/src/tribler-core/tribler_core/components/metadata_store/metadata_store_component.py +++ b/src/tribler-core/tribler_core/components/metadata_store/metadata_store_component.py @@ -41,7 +41,7 @@ async def run(self): key_component.primary_key, notifier=self.session.notifier, disable_sync=config.gui_test_mode, - tag_version=TagRulesProcessor.version + tag_processor_version=TagRulesProcessor.version ) self.mds = metadata_store self.session.notifier.add_observer(NTFY.TORRENT_METADATA_ADDED.value, diff --git a/src/tribler-core/tribler_core/components/metadata_store/remote_query_community/tests/test_remote_search_by_tags.py b/src/tribler-core/tribler_core/components/metadata_store/remote_query_community/tests/test_remote_search_by_tags.py index ef38626dc05..e87916602f1 100644 --- a/src/tribler-core/tribler_core/components/metadata_store/remote_query_community/tests/test_remote_search_by_tags.py +++ b/src/tribler-core/tribler_core/components/metadata_store/remote_query_community/tests/test_remote_search_by_tags.py @@ -44,7 +44,7 @@ def create_node(self, *args, **kwargs): default_eccrypto.generate_key("curve25519"), disable_sync=True, ) - self.tags_db = TagDatabase(str(Path(self.temporary_directory()) / "tags.db"), create_tables=True) + self.tags_db = TagDatabase(str(Path(self.temporary_directory()) / "tags.db")) kwargs['metadata_store'] = self.metadata_store kwargs['tags_db'] = self.tags_db diff --git a/src/tribler-core/tribler_core/components/metadata_store/restapi/channels_endpoint.py b/src/tribler-core/tribler_core/components/metadata_store/restapi/channels_endpoint.py index 4592524f135..9bde25fb89c 100644 --- 
a/src/tribler-core/tribler_core/components/metadata_store/restapi/channels_endpoint.py +++ b/src/tribler-core/tribler_core/components/metadata_store/restapi/channels_endpoint.py @@ -134,7 +134,7 @@ async def get_channels(self, request): ) elif channel_dict["state"] == CHANNEL_STATE.METAINFO_LOOKUP.value: if not self.download_manager.metainfo_requests.get( - bytes(channel.infohash) + bytes(channel.infohash) ) and self.download_manager.download_exists(bytes(channel.infohash)): channel_dict["state"] = CHANNEL_STATE.DOWNLOADING.value @@ -169,6 +169,7 @@ async def get_channels(self, request): }, ) async def get_channel_contents(self, request): + self._logger.info('Get channel content') sanitized = self.sanitize_parameters(request.query) include_total = request.query.get('include_total', '') channel_pk, channel_id = self.get_channel_from_request(request) @@ -180,14 +181,21 @@ async def get_channel_contents(self, request): remote_failed = False if remote: try: + self._logger.info('Receive remote content') contents_list = await self.gigachannel_community.remote_select_channel_contents(**sanitized) except (RequestTimeoutException, NoChannelSourcesException, CancelledError): remote_failed = True + self._logger.info('Remote request failed') + if not remote or remote_failed: + self._logger.info('Receive local content') with db_session: contents = self.mds.get_entries(**sanitized) - contents_list = [c.to_simple_dict() for c in contents] + contents_list = [] + for entry in contents: + self.process_regular_torrent(entry) + contents_list.append(entry.to_simple_dict()) total = self.mds.get_total_count(**sanitized) if include_total else None self.add_download_progress_to_metadata_list(contents_list) self.add_tags_to_metadata_list(contents_list, hide_xxx=sanitized["hide_xxx"]) @@ -390,9 +398,9 @@ async def add_torrent_to_channel(self, request): elif uri.startswith("magnet:"): _, xt, _ = parse_magnetlink(uri) if ( - xt - and is_infohash(codecs.encode(xt, 'hex')) - and 
(self.mds.torrent_exists_in_personal_channel(xt) or channel.copy_torrent_from_infohash(xt)) + xt + and is_infohash(codecs.encode(xt, 'hex')) + and (self.mds.torrent_exists_in_personal_channel(xt) or channel.copy_torrent_from_infohash(xt)) ): return RESTResponse({"added": 1}) @@ -492,7 +500,11 @@ async def get_popular_torrents_channel(self, request): with db_session: contents = self.mds.get_entries(**sanitized) - contents_list = [c.to_simple_dict() for c in contents] + contents_list = [] + for entry in contents: + self.process_regular_torrent(entry) + contents_list.append(entry.to_simple_dict()) + self.add_download_progress_to_metadata_list(contents_list) self.add_tags_to_metadata_list(contents_list, hide_xxx=sanitized["hide_xxx"]) response_dict = { diff --git a/src/tribler-core/tribler_core/components/metadata_store/restapi/metadata_endpoint_base.py b/src/tribler-core/tribler_core/components/metadata_store/restapi/metadata_endpoint_base.py index f4140293c37..5903743a2d1 100644 --- a/src/tribler-core/tribler_core/components/metadata_store/restapi/metadata_endpoint_base.py +++ b/src/tribler-core/tribler_core/components/metadata_store/restapi/metadata_endpoint_base.py @@ -3,14 +3,15 @@ from pony.orm import db_session +# This dict is used to translate JSON fields into the columns used in Pony for _sorting_. +# id_ is not in the list because there is no index on it, so we never really want to sort on it. 
from tribler_core.components.metadata_store.category_filter.family_filter import default_xxx_filter from tribler_core.components.metadata_store.db.serialization import CHANNEL_TORRENT, COLLECTION_NODE, REGULAR_TORRENT from tribler_core.components.metadata_store.db.store import MetadataStore from tribler_core.components.restapi.rest.rest_endpoint import RESTEndpoint from tribler_core.components.tag.db.tag_db import TagDatabase +from tribler_core.components.tag.rules.tag_rules_processor import TagRulesProcessor -# This dict is used to translate JSON fields into the columns used in Pony for _sorting_. -# id_ is not in the list because there is not index on it, so we never really want to sort on it. json2pony_columns = { 'category': "tags", 'name': "title", @@ -37,10 +38,12 @@ class MetadataEndpointBase(RESTEndpoint): - def __init__(self, metadata_store: MetadataStore, *args, tags_db: TagDatabase = None, **kwargs): + def __init__(self, metadata_store: MetadataStore, *args, tags_db: TagDatabase = None, + tag_rules_processor: TagRulesProcessor = None, **kwargs): super().__init__(*args, **kwargs) self.mds = metadata_store self.tags_db: Optional[TagDatabase] = tags_db + self.tag_rules_processor: Optional[TagRulesProcessor] = tag_rules_processor @classmethod def sanitize_parameters(cls, parameters): @@ -68,14 +71,34 @@ def sanitize_parameters(cls, parameters): sanitized['metadata_type'] = frozenset(mtypes) return sanitized + def process_regular_torrent(self, entry): + is_torrent = entry.get_type() == REGULAR_TORRENT + if not is_torrent: + return + + if not self.tag_rules_processor: + return + + is_auto_generated_tags_not_created = entry.tag_processor_version < self.tag_rules_processor.version + if is_auto_generated_tags_not_created: + generated = self.tag_rules_processor.process_torrent_title(infohash=entry.infohash, title=entry.title) + entry.tag_processor_version = self.tag_rules_processor.version + self._logger.info(f'Generated {generated} tags for {entry.infohash}') + 
@db_session def add_tags_to_metadata_list(self, contents_list, hide_xxx=False): if self.tags_db is None: self._logger.error(f'Cannot add tags to metadata list: tags_db is not set in {self.__class__.__name__}') return for torrent in contents_list: - if torrent['type'] == REGULAR_TORRENT: - tags = self.tags_db.get_tags(unhexlify(torrent["infohash"])) - if hide_xxx: - tags = [tag.lower() for tag in tags if not default_xxx_filter.isXXX(tag, isFilename=False)] - torrent["tags"] = tags + is_torrent = torrent['type'] == REGULAR_TORRENT + if not is_torrent: + continue + + infohash_str = torrent['infohash'] + infohash = unhexlify(infohash_str) + + tags = self.tags_db.get_tags(infohash) + if hide_xxx: + tags = [tag.lower() for tag in tags if not default_xxx_filter.isXXX(tag, isFilename=False)] + torrent["tags"] = tags diff --git a/src/tribler-core/tribler_core/components/restapi/restapi_component.py b/src/tribler-core/tribler_core/components/restapi/restapi_component.py index 6e06e8254be..1b488091b0b 100644 --- a/src/tribler-core/tribler_core/components/restapi/restapi_component.py +++ b/src/tribler-core/tribler_core/components/restapi/restapi_component.py @@ -104,12 +104,15 @@ async def run(self): self.maybe_add('/libtorrent', LibTorrentEndpoint, libtorrent_component.download_manager) self.maybe_add('/torrentinfo', TorrentInfoEndpoint, libtorrent_component.download_manager) self.maybe_add('/metadata', MetadataEndpoint, torrent_checker, metadata_store_component.mds, - tags_db=tag_component.tags_db) + tags_db=tag_component.tags_db, tag_rules_processor=tag_component.rules_processor) self.maybe_add('/channels', ChannelsEndpoint, libtorrent_component.download_manager, gigachannel_manager, - gigachannel_component.community, metadata_store_component.mds, tags_db=tag_component.tags_db) + gigachannel_component.community, metadata_store_component.mds, tags_db=tag_component.tags_db, + tag_rules_processor=tag_component.rules_processor) self.maybe_add('/collections', ChannelsEndpoint, 
libtorrent_component.download_manager, gigachannel_manager, - gigachannel_component.community, metadata_store_component.mds, tags_db=tag_component.tags_db) - self.maybe_add('/search', SearchEndpoint, metadata_store_component.mds, tags_db=tag_component.tags_db) + gigachannel_component.community, metadata_store_component.mds, tags_db=tag_component.tags_db, + tag_rules_processor=tag_component.rules_processor) + self.maybe_add('/search', SearchEndpoint, metadata_store_component.mds, tags_db=tag_component.tags_db, + tag_rules_processor=tag_component.rules_processor) self.maybe_add('/remote_query', RemoteQueryEndpoint, gigachannel_component.community, metadata_store_component.mds) self.maybe_add('/tags', TagsEndpoint, db=tag_component.tags_db, community=tag_component.community) diff --git a/src/tribler-core/tribler_core/components/tag/community/tests/test_tag_community.py b/src/tribler-core/tribler_core/components/tag/community/tests/test_tag_community.py index 4edd0ca0bbb..ac35da8cfdb 100644 --- a/src/tribler-core/tribler_core/components/tag/community/tests/test_tag_community.py +++ b/src/tribler-core/tribler_core/components/tag/community/tests/test_tag_community.py @@ -26,7 +26,7 @@ async def tearDown(self): await super().tearDown() def create_node(self, *args, **kwargs): - return MockIPv8("curve25519", TagCommunity, db=TagDatabase(create_tables=True), tags_key=LibNaCLSK(), + return MockIPv8("curve25519", TagCommunity, db=TagDatabase(), tags_key=LibNaCLSK(), request_interval=REQUEST_INTERVAL_FOR_RANDOM_TAGS) def create_operation(self, tag=''): diff --git a/src/tribler-core/tribler_core/components/tag/db/tag_db.py b/src/tribler-core/tribler_core/components/tag/db/tag_db.py index b35a265b697..9ee559f0983 100644 --- a/src/tribler-core/tribler_core/components/tag/db/tag_db.py +++ b/src/tribler-core/tribler_core/components/tag/db/tag_db.py @@ -12,8 +12,6 @@ CLOCK_START_VALUE = 0 -# we picked `-1` as a value because it is allows manually created tags get a higher priority 
-CLOCK_FOR_AUTOGENERATED_TAGS = CLOCK_START_VALUE - 1 PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS = b'auto_generated' SHOW_THRESHOLD = 1 @@ -21,11 +19,12 @@ class TagDatabase: - def __init__(self, filename: Optional[str] = None, **kwargs): + def __init__(self, filename: Optional[str] = None, *, create_tables: bool = True, **generate_mapping_kwargs): self.instance = orm.Database() self.define_binding(self.instance) self.instance.bind('sqlite', filename or ':memory:', create_db=True) - self.instance.generate_mapping(**kwargs) + generate_mapping_kwargs['create_tables'] = create_tables + self.instance.generate_mapping(**generate_mapping_kwargs) self.logger = logging.getLogger(self.__class__.__name__) @staticmethod @@ -95,7 +94,6 @@ class TorrentTagOp(db.Entity): orm.composite_key(torrent_tag, peer) - def add_tag_operation(self, operation: TagOperation, signature: bytes, is_local_peer: bool = False, is_auto_generated: bool = False, counter_increment: int = 1) -> bool: """ Add the operation that will be applied to the tag. 
@@ -136,7 +134,15 @@ def add_tag_operation(self, operation: TagOperation, signature: bytes, is_local_ updated_at=datetime.datetime.utcnow(), auto_generated=is_auto_generated) return True - def add_auto_generated_tag_operation(self, operation: TagOperation): + def add_auto_generated_tag(self, infohash: bytes, tag: str): + operation = TagOperation( + infohash=infohash, + operation=TagOperationEnum.ADD, + clock=CLOCK_START_VALUE, + creator_public_key=PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS, + tag=tag + ) + self.add_tag_operation(operation, signature=b'', is_local_peer=False, is_auto_generated=True, counter_increment=SHOW_THRESHOLD) diff --git a/src/tribler-core/tribler_core/components/tag/db/tests/test_tag_db.py b/src/tribler-core/tribler_core/components/tag/db/tests/test_tag_db.py index 87242a04b27..dba25aba105 100644 --- a/src/tribler-core/tribler_core/components/tag/db/tests/test_tag_db.py +++ b/src/tribler-core/tribler_core/components/tag/db/tests/test_tag_db.py @@ -2,19 +2,16 @@ from dataclasses import dataclass from itertools import count from types import SimpleNamespace +from unittest.mock import Mock, patch from ipv8.test.base import TestBase +from pony import orm from pony.orm import commit, db_session # pylint: disable=protected-access from tribler_core.components.tag.community.tag_payload import TagOperation, TagOperationEnum -from tribler_core.components.tag.db.tag_db import ( - CLOCK_FOR_AUTOGENERATED_TAGS, - PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS, - SHOW_THRESHOLD, - TagDatabase, -) +from tribler_core.components.tag.db.tag_db import PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS, SHOW_THRESHOLD, TagDatabase from tribler_core.utilities.pony_utils import get_or_create @@ -28,7 +25,7 @@ class Tag: class TestTagDB(TestBase): def setUp(self): super().setUp() - self.db = TagDatabase(create_tables=True) + self.db = TagDatabase() async def tearDown(self): if self._outcome.errors: @@ -83,6 +80,16 @@ def generate_n_peer_names(n): for peer in generate_n_peer_names(tag.count): 
TestTagDB.add_operation(tag_db, infohash, tag.name, peer, is_auto_generated=tag.auto_generated) + @patch.object(orm.Database, 'generate_mapping') + def test_constructor_create_tables_true(self, mocked_generate_mapping: Mock): + TagDatabase(':memory:') + mocked_generate_mapping.assert_called_with(create_tables=True) + + @patch.object(orm.Database, 'generate_mapping') + def test_constructor_create_tables_false(self, mocked_generate_mapping: Mock): + TagDatabase(':memory:', create_tables=False) + mocked_generate_mapping.assert_called_with(create_tables=False) + @db_session async def test_get_or_create(self): # Test that function get_or_create() works as expected: @@ -188,19 +195,13 @@ async def test_remote_add_multiple_tag_operations(self): assert self.db.instance.TorrentTag.get().removed_count == 1 @db_session - async def test_add_auto_generated_operation(self): - self.db.add_auto_generated_tag_operation( - operation=TagOperation( - infohash=b'infohash', - operation=TagOperationEnum.ADD, - clock=CLOCK_FOR_AUTOGENERATED_TAGS, - creator_public_key=PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS, - tag='tag' - ) + async def test_add_auto_generated_tag(self): + self.db.add_auto_generated_tag( + infohash=b'infohash', + tag='tag' ) assert self.db.instance.TorrentTagOp.get().auto_generated - assert self.db.instance.TorrentTagOp.get().clock == CLOCK_FOR_AUTOGENERATED_TAGS assert self.db.instance.TorrentTag.get().added_count == SHOW_THRESHOLD assert self.db.instance.Peer.get().public_key == PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS diff --git a/src/tribler-core/tribler_core/components/tag/rules/tag_rules.py b/src/tribler-core/tribler_core/components/tag/rules/tag_rules.py index 1fe007cd463..76b1411e77c 100644 --- a/src/tribler-core/tribler_core/components/tag/rules/tag_rules.py +++ b/src/tribler-core/tribler_core/components/tag/rules/tag_rules.py @@ -59,5 +59,7 @@ def extract_tags(text: str, rules: Optional[RulesList] = None) -> Iterable[str]: def extract_only_valid_tags(text: str, rules: 
Optional[RulesList] = None) -> Iterable[str]: - extracted_tags_gen = (t.lower() for t in extract_tags(text, rules)) - yield from (t for t in extracted_tags_gen if is_valid_tag(t)) + for tag in extract_tags(text, rules): + tag = tag.lower() + if is_valid_tag(tag): + yield tag diff --git a/src/tribler-core/tribler_core/components/tag/rules/tag_rules_processor.py b/src/tribler-core/tribler_core/components/tag/rules/tag_rules_processor.py index 13085e67877..595dc89aff2 100644 --- a/src/tribler-core/tribler_core/components/tag/rules/tag_rules_processor.py +++ b/src/tribler-core/tribler_core/components/tag/rules/tag_rules_processor.py @@ -8,15 +8,13 @@ import tribler_core.components.metadata_store.db.orm_bindings.torrent_metadata as torrent_metadata import tribler_core.components.metadata_store.db.store as MDS from tribler_core.components.metadata_store.db.serialization import REGULAR_TORRENT -from tribler_core.components.tag.community.tag_payload import TagOperation, TagOperationEnum -from tribler_core.components.tag.db.tag_db import ( - CLOCK_FOR_AUTOGENERATED_TAGS, - PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS, - TagDatabase, -) +from tribler_core.components.tag.db.tag_db import TagDatabase from tribler_core.components.tag.rules.tag_rules import extract_only_valid_tags from tribler_core.notifier import Notifier +DEFAULT_INTERVAL = 10 +DEFAULT_BATCH_SIZE = 1000 + LAST_PROCESSED_TORRENT_ID = 'last_processed_torrent_id' @@ -25,7 +23,7 @@ class TagRulesProcessor(TaskManager): version: int = 1 def __init__(self, notifier: Notifier, db: TagDatabase, mds: MDS.MetadataStore, - batch_size: int = 1000, interval: float = 10): + batch_size: int = DEFAULT_BATCH_SIZE, interval: float = DEFAULT_INTERVAL): """ Default values for batch_size and interval are chosen so that tag processing is not too heavy fot CPU and with this values 360k items will be processed within the hour. 
@@ -40,19 +38,29 @@ def __init__(self, notifier: Notifier, db: TagDatabase, mds: MDS.MetadataStore, self.interval = interval self.notifier.add_observer(torrent_metadata.NEW_TORRENT_METADATA_CREATED, callback=self.process_torrent_title) - self.register_task(name=self.process_batch.__name__, - interval=interval, - task=self.process_batch) + @db_session + def start(self): + self.logger.info('Start') + + max_row_id = self.mds.get_max_rowid() + is_finished = self.get_last_processed_torrent_id() >= max_row_id + + if not is_finished: + self.logger.info(f'Register process_batch task with interval: {self.interval} sec') + self.register_task(name=self.process_batch.__name__, + interval=self.interval, + task=self.process_batch) @db_session def process_batch(self) -> int: def query(_start, _end): return lambda t: _start < t.rowid and t.rowid <= _end and \ t.metadata_type == REGULAR_TORRENT and \ - t.tag_version < self.version + t.tag_processor_version < self.version - start = int(self.mds.get_value(LAST_PROCESSED_TORRENT_ID, default='0')) - end = start + self.batch_size + start = self.get_last_processed_torrent_id() + max_row_id = self.mds.get_max_rowid() + end = min(start + self.batch_size, max_row_id) self.logger.info(f'Processing batch [{start}...{end}]') batch = self.mds.TorrentMetadata.select(query(start, end)) @@ -60,17 +68,16 @@ def query(_start, _end): added = 0 for torrent in batch: added += self.process_torrent_title(torrent.infohash, torrent.title) - torrent.tag_version = self.version + torrent.tag_processor_version = self.version processed += 1 + self.mds.set_value(LAST_PROCESSED_TORRENT_ID, str(end)) self.logger.info(f'Processed: {processed} titles. 
Added {added} tags.') - max_row_id = self.mds.get_max_rowid() - is_beyond_the_boundary = end > max_row_id - if is_beyond_the_boundary: - self._schedule_new_process_batch_round() - else: - self.mds.set_value(LAST_PROCESSED_TORRENT_ID, str(end)) + is_finished = end >= max_row_id + if is_finished: + self.logger.info('Finish batch processing, cancel process_batch task') + self.cancel_pending_task(name=self.process_batch.__name__) return processed def process_torrent_title(self, infohash: Optional[bytes] = None, title: Optional[str] = None) -> int: @@ -85,26 +92,7 @@ def process_torrent_title(self, infohash: Optional[bytes] = None, title: Optiona def save_tags(self, infohash: bytes, tags: Set[str]): self.logger.debug(f'Save: {len(tags)} tags') for tag in tags: - operation = TagOperation( - infohash=infohash, - operation=TagOperationEnum.ADD, - clock=CLOCK_FOR_AUTOGENERATED_TAGS, - creator_public_key=PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS, - tag=tag - ) - # we want auto generated operation to act like a normal operation - # therefore we use 2 as a `counter_increment` to immediately pass - # SHOW_THRESHOLD - self.db.add_auto_generated_tag_operation(operation=operation) - - def _schedule_new_process_batch_round(self): - self.logger.info('All items in TorrentMetadata have been processed.') - self.mds.set_value(LAST_PROCESSED_TORRENT_ID, '0') - self.logger.info('Set last_processed_torrent_id to 0') - self.interval *= 2 - self.logger.info(f'Double the interval. New interval: {self.interval}') - self.batch_size *= 2 - self.logger.info(f'Double the batch size. 
New batch size: {self.batch_size}') - self.replace_task(self.process_batch.__name__, - interval=self.interval, - task=self.process_batch) + self.db.add_auto_generated_tag(infohash=infohash, tag=tag) + + def get_last_processed_torrent_id(self) -> int: + return int(self.mds.get_value(LAST_PROCESSED_TORRENT_ID, default='0')) diff --git a/src/tribler-core/tribler_core/components/tag/rules/tests/test_tag_rules_processor.py b/src/tribler-core/tribler_core/components/tag/rules/tests/test_tag_rules_processor.py index d1580381a43..2e972bdab29 100644 --- a/src/tribler-core/tribler_core/components/tag/rules/tests/test_tag_rules_processor.py +++ b/src/tribler-core/tribler_core/components/tag/rules/tests/test_tag_rules_processor.py @@ -4,13 +4,12 @@ import pytest from tribler_core.components.metadata_store.db.orm_bindings.torrent_metadata import NEW_TORRENT_METADATA_CREATED -from tribler_core.components.tag.community.tag_payload import TagOperation, TagOperationEnum -from tribler_core.components.tag.db.tag_db import CLOCK_FOR_AUTOGENERATED_TAGS, PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS from tribler_core.components.tag.rules.tag_rules_processor import LAST_PROCESSED_TORRENT_ID, TagRulesProcessor TEST_BATCH_SIZE = 100 TEST_INTERVAL = 0.1 + # pylint: disable=redefined-outer-name, protected-access @pytest.fixture def tag_rules_processor(): @@ -43,31 +42,16 @@ def test_process_torrent_file(mocked_save_tags: Mock, tag_rules_processor: TagRu def test_save_tags(tag_rules_processor: TagRulesProcessor): # test that tag_rules_processor calls TagDatabase with correct args - expected_calls = [{'operation': TagOperation(infohash=b'infohash', operation=TagOperationEnum.ADD, - clock=CLOCK_FOR_AUTOGENERATED_TAGS, - creator_public_key=PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS, - tag='tag1')}, - {'operation': TagOperation(infohash=b'infohash', operation=TagOperationEnum.ADD, - clock=CLOCK_FOR_AUTOGENERATED_TAGS, - creator_public_key=PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS, - tag='tag2')}] + expected_calls = 
[{'infohash': b'infohash', 'tag': 'tag1'}, + {'infohash': b'infohash', 'tag': 'tag2'}] tag_rules_processor.save_tags(infohash=b'infohash', tags={'tag1', 'tag2'}) - actual_calls = [c.kwargs for c in tag_rules_processor.db.add_auto_generated_tag_operation.mock_calls] + actual_calls = [c.kwargs for c in tag_rules_processor.db.add_auto_generated_tag.mock_calls] # compare two lists of dict assert [c for c in actual_calls if c not in expected_calls] == [] -@patch.object(TagRulesProcessor, 'replace_task') -def test_schedule_new_process_batch_round(mocked_replace_task: Mock, tag_rules_processor: TagRulesProcessor): - tag_rules_processor._schedule_new_process_batch_round() - assert tag_rules_processor.interval == TEST_INTERVAL * 2 - assert tag_rules_processor.batch_size == TEST_BATCH_SIZE * 2 - tag_rules_processor.mds.set_value.assert_called_with(LAST_PROCESSED_TORRENT_ID, '0') - mocked_replace_task.assert_called_once() - - @patch.object(TagRulesProcessor, 'process_torrent_title', new=Mock(return_value=1)) def test_process_batch_within_the_boundary(tag_rules_processor: TagRulesProcessor): # test inner logic of `process_batch` in case this batch located within the boundary @@ -89,22 +73,22 @@ def select(_): tag_rules_processor.mds.set_value.assert_called_with(LAST_PROCESSED_TORRENT_ID, str(TEST_BATCH_SIZE)) -@patch.object(TagRulesProcessor, '_schedule_new_process_batch_round') @patch.object(TagRulesProcessor, 'process_torrent_title', new=Mock(return_value=1)) -def test_process_batch_beyond_the_boundary(mocked_schedule_new_process_batch_round: Mock, - tag_rules_processor: TagRulesProcessor): - # test inner logic of `process_batch` in case this batch located within the boundary +def test_process_batch_beyond_the_boundary(tag_rules_processor: TagRulesProcessor): + # test inner logic of `process_batch` in case this batch located on a border returned_batch_size = TEST_BATCH_SIZE // 2 # let's return a half of requested items + # let's specify `max_rowid` in such a way that it is 
less than end of the current batch + max_rowid = returned_batch_size // 2 + def select(_): return [SimpleNamespace(infohash=i, title=i) for i in range(returned_batch_size)] tag_rules_processor.mds.get_value = lambda *_, **__: 0 # let's start from 0 for LAST_PROCESSED_TORRENT_ID tag_rules_processor.mds.TorrentMetadata.select = select - # let's specify `max_rowid` in such a way that it is less than end of the current batch - tag_rules_processor.mds.get_max_rowid = lambda: returned_batch_size // 2 + tag_rules_processor.mds.get_max_rowid = lambda: max_rowid - # assert that actually returned count of processed items is equal to `returned_batch_size` + # assert that every item returned by `select` is processed, i.e. the count is `returned_batch_size` assert tag_rules_processor.process_batch() == returned_batch_size - mocked_schedule_new_process_batch_round.assert_called_once() + tag_rules_processor.mds.set_value.assert_called_with(LAST_PROCESSED_TORRENT_ID, str(max_rowid)) diff --git a/src/tribler-core/tribler_core/components/tag/tag_component.py b/src/tribler-core/tribler_core/components/tag/tag_component.py index cdd1a2940f8..07898e33e85 100644 --- a/src/tribler-core/tribler_core/components/tag/tag_component.py +++ b/src/tribler-core/tribler_core/components/tag/tag_component.py @@ -42,6 +42,7 @@ async def run(self): db=self.tags_db, mds=mds_component.mds, ) + self.rules_processor.start() self._ipv8_component.initialise_community_by_default(self.community) diff --git a/src/tribler-core/tribler_core/conftest.py b/src/tribler-core/tribler_core/conftest.py index 1907c351705..a4329f76c09 100644 --- a/src/tribler-core/tribler_core/conftest.py +++ b/src/tribler-core/tribler_core/conftest.py @@ -186,7 +186,7 @@ def metadata_store(tmp_path): @pytest.fixture def tags_db(): - db = TagDatabase(create_tables=True) + db = TagDatabase() yield db db.shutdown() diff --git a/src/tribler-core/tribler_core/upgrade/tests/test_upgrader.py b/src/tribler-core/tribler_core/upgrade/tests/test_upgrader.py index 
d3b0a5a8cd2..ea6bc219455 100644 --- a/src/tribler-core/tribler_core/upgrade/tests/test_upgrader.py +++ b/src/tribler-core/tribler_core/upgrade/tests/test_upgrader.py @@ -172,10 +172,10 @@ def test_upgrade_pony13to14(upgrader: TriblerUpgrader, state_dir, channels_dir, upgrader.upgrade_pony_db_13to14() mds = MetadataStore(mds_path, channels_dir, trustchain_keypair, check_tables=False) - tags = TagDatabase(str(tags_path), check_tables=False) + tags = TagDatabase(str(tags_path), create_tables=False, check_tables=False) with db_session: - assert upgrader.column_exists_in_table(mds._db, 'ChannelNode', 'tag_version') + assert upgrader.column_exists_in_table(mds._db, 'ChannelNode', 'tag_processor_version') assert upgrader.column_exists_in_table(tags.instance, 'TorrentTagOp', 'auto_generated') assert mds.get_value('db_version') == '14' diff --git a/src/tribler-core/tribler_core/upgrade/upgrade.py b/src/tribler-core/tribler_core/upgrade/upgrade.py index caede9f28bd..da7074e9adb 100644 --- a/src/tribler-core/tribler_core/upgrade/upgrade.py +++ b/src/tribler-core/tribler_core/upgrade/upgrade.py @@ -107,14 +107,14 @@ def upgrade_pony_db_13to14(self): mds = MetadataStore(mds_path, self.channels_dir, self.trustchain_keypair, disable_sync=True, check_tables=False, db_version=13) if mds_path.exists() else None - tagdb = TagDatabase(str(tagdb_path), check_tables=False) if tagdb_path.exists() else None + tag_db = TagDatabase(str(tagdb_path), create_tables=False, check_tables=False) if tagdb_path.exists() else None - self.do_upgrade_pony_db_13to14(mds, tagdb) + self.do_upgrade_pony_db_13to14(mds, tag_db) if mds: mds.shutdown() - if tagdb: - tagdb.shutdown() + if tag_db: + tag_db.shutdown() def upgrade_pony_db_12to13(self): """ @@ -229,7 +229,7 @@ def do_upgrade_pony_db_12to13(self, mds): db_version.value = str(to_version) def do_upgrade_pony_db_13to14(self, mds: Optional[MetadataStore], tags: Optional[TagDatabase]): - def _alter(db, table_name, column_name, column_type): + def 
add_column(db, table_name, column_name, column_type): if not self.column_exists_in_table(db, table_name, column_name): db.execute(f'ALTER TABLE "{table_name}" ADD "{column_name}" {column_type} DEFAULT 0') @@ -244,8 +244,8 @@ def _alter(db, table_name, column_name, column_type): self._logger.info(f'{version.current}->{version.next}') - _alter(db=mds._db, table_name='ChannelNode', column_name='tag_version', column_type='INT') - _alter(db=tags.instance, table_name='TorrentTagOp', column_name='auto_generated', column_type='BOOLEAN') + add_column(db=mds._db, table_name='ChannelNode', column_name='tag_processor_version', column_type='INT') + add_column(db=tags.instance, table_name='TorrentTagOp', column_name='auto_generated', column_type='BOOLEAN') tags.instance.commit() mds._db.commit()