diff --git a/src/tribler-core/tribler_core/components/metadata_store/db/orm_bindings/torrent_metadata.py b/src/tribler-core/tribler_core/components/metadata_store/db/orm_bindings/torrent_metadata.py index 725ed5b4943..2ac9dec0be3 100644 --- a/src/tribler-core/tribler_core/components/metadata_store/db/orm_bindings/torrent_metadata.py +++ b/src/tribler-core/tribler_core/components/metadata_store/db/orm_bindings/torrent_metadata.py @@ -8,10 +8,12 @@ from tribler_core.components.metadata_store.category_filter.family_filter import default_xxx_filter from tribler_core.components.metadata_store.db.orm_bindings.channel_node import COMMITTED from tribler_core.components.metadata_store.db.serialization import EPOCH, REGULAR_TORRENT, TorrentMetadataPayload +from tribler_core.notifier import Notifier from tribler_core.utilities.tracker_utils import get_uniformed_tracker_url from tribler_core.utilities.unicode import ensure_unicode, hexlify NULL_KEY_SUBST = b"\00" +NEW_TORRENT_METADATA_CREATED: str = 'TorrentMetadata:new_torrent_metadata_created' # This function is used to devise id_ from infohash in deterministic way. Used in FFA channels. @@ -44,7 +46,7 @@ def tdef_to_metadata_dict(tdef): } -def define_binding(db): +def define_binding(db, notifier: Notifier, tag_version: int): class TorrentMetadata(db.MetadataNode): """ This ORM binding class is intended to store Torrent objects, i.e. infohashes along with some related metadata. @@ -61,12 +63,13 @@ class TorrentMetadata(db.MetadataNode): # Local xxx = orm.Optional(float, default=0) health = orm.Optional('TorrentState', reverse='metadata') + tag_version = orm.Required(int, default=0) # Special class-level properties _payload_class = TorrentMetadataPayload payload_arguments = _payload_class.__init__.__code__.co_varnames[ - : _payload_class.__init__.__code__.co_argcount - ][1:] + : _payload_class.__init__.__code__.co_argcount + ][1:] nonpersonal_attributes = db.MetadataNode.nonpersonal_attributes + ( 'infohash', 'size', @@ -86,6 +89,11 @@ def __init__(self, *args, **kwargs): if 'tracker_info' in kwargs: self.add_tracker(kwargs["tracker_info"]) + if notifier: + notifier.notify(NEW_TORRENT_METADATA_CREATED, + infohash=kwargs.get("infohash"), + title=self.title) + self.tag_version = tag_version def add_tracker(self, tracker_url): sanitized_url = get_uniformed_tracker_url(tracker_url) diff --git a/src/tribler-core/tribler_core/components/metadata_store/db/store.py b/src/tribler-core/tribler_core/components/metadata_store/db/store.py index e646133f75f..36cb2c6d8be 100644 --- a/src/tribler-core/tribler_core/components/metadata_store/db/store.py +++ b/src/tribler-core/tribler_core/components/metadata_store/db/store.py @@ -4,7 +4,7 @@ from asyncio import get_event_loop from datetime import datetime, timedelta from time import sleep, time -from typing import Union +from typing import Optional, Union from lz4.frame import LZ4FrameDecompressor @@ -49,11 +49,12 @@ from tribler_core.components.metadata_store.remote_query_community.payload_checker import process_payload from tribler_core.exceptions import InvalidSignatureException from tribler_core.utilities.path_util import Path +from tribler_core.utilities.pony_utils import get_or_create from tribler_core.utilities.unicode import hexlify from tribler_core.utilities.utilities import MEMORY_DB BETA_DB_VERSIONS = [0, 1, 2, 3, 4, 5] -CURRENT_DB_VERSION = 13 +CURRENT_DB_VERSION = 14 MIN_BATCH_SIZE = 10 MAX_BATCH_SIZE = 1000 @@ -61,7 +62,6 @@ POPULAR_TORRENTS_FRESHNESS_PERIOD = 60 * 60 * 24 # Last day POPULAR_TORRENTS_COUNT = 100 - # This table should never be used from ORM directly. # It is created as a VIRTUAL table by raw SQL and # maintained by SQL triggers. @@ -136,14 +136,15 @@ class MetadataStore: def __init__( - self, - db_filename: Union[Path, type(MEMORY_DB)], - channels_dir, - my_key, - disable_sync=False, - notifier=None, - check_tables=True, - db_version: int = CURRENT_DB_VERSION, + self, + db_filename: Union[Path, type(MEMORY_DB)], + channels_dir, + my_key, + disable_sync=False, + notifier=None, + check_tables=True, + db_version: int = CURRENT_DB_VERSION, + tag_version: int = 0 ): self.notifier = notifier # Reference to app-level notification service self.db_path = db_filename @@ -190,7 +191,11 @@ def sqlite_disable_sync(_, connection): self.MetadataNode = metadata_node.define_binding(self._db) self.CollectionNode = collection_node.define_binding(self._db) - self.TorrentMetadata = torrent_metadata.define_binding(self._db) + self.TorrentMetadata = torrent_metadata.define_binding( + self._db, + notifier=notifier, + tag_version=tag_version + ) self.ChannelMetadata = channel_metadata.define_binding(self._db) self.JsonNode = json_node.define_binding(self._db, db_version) @@ -242,6 +247,14 @@ def wrapper(): return await get_event_loop().run_in_executor(None, wrapper) + def set_value(self, key: str, value: str): + key_value = get_or_create(self.MiscData, name=key) + key_value.value = value + + def get_value(self, key: str, default: Optional[str] = None) -> Optional[str]: + data = self.MiscData.get(name=key) + return data.value if data else default + def drop_indexes(self): cursor = self._db.get_connection().cursor() cursor.execute("select name from sqlite_master where type='index' and name like 'idx_%'") @@ -391,9 +404,9 @@ def process_channel_dir(self, dirname, public_key, id_, **kwargs): if not channel: return if ( - blob_sequence_number <= channel.start_timestamp - or blob_sequence_number <= channel.local_version - or blob_sequence_number > channel.timestamp + blob_sequence_number <= channel.start_timestamp + or blob_sequence_number <= channel.local_version + or blob_sequence_number > channel.timestamp ): continue try: @@ -595,28 +608,28 @@ def search_keyword(self, query, lim=100): @db_session def get_entries_query( - self, - metadata_type=None, - channel_pk=None, - exclude_deleted=False, - hide_xxx=False, - exclude_legacy=False, - origin_id=None, - sort_by=None, - sort_desc=True, - max_rowid=None, - txt_filter=None, - subscribed=None, - category=None, - attribute_ranges=None, - infohash=None, - infohash_set=None, - id_=None, - complete_channel=None, - self_checked_torrent=None, - cls=None, - health_checked_after=None, - popular=None, + self, + metadata_type=None, + channel_pk=None, + exclude_deleted=False, + hide_xxx=False, + exclude_legacy=False, + origin_id=None, + sort_by=None, + sort_desc=True, + max_rowid=None, + txt_filter=None, + subscribed=None, + category=None, + attribute_ranges=None, + infohash=None, + infohash_set=None, + id_=None, + complete_channel=None, + self_checked_torrent=None, + cls=None, + health_checked_after=None, + popular=None, ): """ This method implements REST-friendly way to get entries from the database. @@ -662,8 +675,8 @@ def get_entries_query( if attribute_ranges is not None: for attr, left, right in attribute_ranges: if ( - self.ChannelNode._adict_.get(attr) # pylint: disable=W0212 - or self.ChannelNode._subclass_adict_.get(attr) # pylint: disable=W0212 + self.ChannelNode._adict_.get(attr) # pylint: disable=W0212 + or self.ChannelNode._subclass_adict_.get(attr) # pylint: disable=W0212 ) is None: # Check against code injection raise AttributeError("Tried to query for non-existent attribute") if left is not None: @@ -737,7 +750,7 @@ def get_entries(self, first=1, last=None, **kwargs): :return: A list of class members """ pony_query = self.get_entries_query(**kwargs) - result = pony_query[(first or 1) - 1 : last] + result = pony_query[(first or 1) - 1: last] for entry in result: # ACHTUNG! This is necessary in order to load entry.health inside db_session, # to be able to perform successfully `entry.to_simple_dict()` later diff --git a/src/tribler-core/tribler_core/components/metadata_store/metadata_store_component.py b/src/tribler-core/tribler_core/components/metadata_store/metadata_store_component.py index 493cba017df..965d7afc663 100644 --- a/src/tribler-core/tribler_core/components/metadata_store/metadata_store_component.py +++ b/src/tribler-core/tribler_core/components/metadata_store/metadata_store_component.py @@ -3,8 +3,7 @@ from tribler_core.components.base import Component from tribler_core.components.key.key_component import KeyComponent from tribler_core.components.metadata_store.db.store import MetadataStore -from tribler_core.components.metadata_store.utils import generate_test_channels -from tribler_core.components.tag.tag_component import TagComponent +from tribler_core.components.tag.rules.tag_rules_processor import TagRulesProcessor class MetadataStoreComponent(Component): @@ -42,13 +41,11 @@ async def run(self): key_component.primary_key, notifier=self.session.notifier, disable_sync=config.gui_test_mode, + tag_version=TagRulesProcessor.version ) self.mds = metadata_store - self.session.notifier.add_observer(NTFY.TORRENT_METADATA_ADDED, + self.session.notifier.add_observer(NTFY.TORRENT_METADATA_ADDED.value, metadata_store.TorrentMetadata.add_ffa_from_dict) - if config.gui_test_mode: - tag_component = await self.require_component(TagComponent) - generate_test_channels(metadata_store, tag_component.tags_db) async def shutdown(self): await super().shutdown() diff --git a/src/tribler-core/tribler_core/components/tag/rules/tag_rules.py b/src/tribler-core/tribler_core/components/tag/rules/tag_rules.py new file mode 100644 index 00000000000..1fe007cd463 --- /dev/null +++ b/src/tribler-core/tribler_core/components/tag/rules/tag_rules.py @@ -0,0 +1,63 @@ +import re +from typing import AnyStr, Iterable, Optional, Pattern, Sequence + +from tribler_core.components.tag.community.tag_validator import is_valid_tag + +# Each regex expression should contain just a single capturing group: +square_brackets_re = re.compile(r'\[([^\[\]]+)]') +parentheses_re = re.compile(r'\(([^()]+)\)') +extension_re = re.compile(r'\.(\w{3,4})$') +delimiter_re = re.compile(r'([^\s.,/|]+)') + +tags_in_square_brackets = [ + square_brackets_re, # extract content from square brackets + delimiter_re # divide content by "," or "." or " " or "/" +] + +tags_in_parentheses = [ + parentheses_re, # extract content from brackets + delimiter_re # divide content by "," or "." or " " or "/" +] + +tags_in_extension = [ + extension_re # extract an extension +] + +RulesList = Sequence[Sequence[Pattern[AnyStr]]] +default_rules: RulesList = [ + tags_in_square_brackets, + tags_in_parentheses, + tags_in_extension +] + + +def extract_tags(text: str, rules: Optional[RulesList] = None) -> Iterable[str]: + """ Extract tags by using the giving rules. + + Rules are represented by an array of an array of regexes. + Each rule contains one or more regex expressions. + + During the `text` processing, each rule will be applied to the `text` value. + All extracted tags will be returned. + + During application of the particular rule, `text` will be split into + tokens by application of the first regex expression. Then, second regex + expression will be applied to each tokens that were extracted on the + previous step. + This process will be repeated until regex expression ends. + """ + rules = rules or default_rules + for rule in rules: + text_set = {text} + for regex in rule: + next_text_set = set() + for token in text_set: + for match in regex.finditer(token): + next_text_set |= set(match.groups()) + text_set = next_text_set + yield from text_set + + +def extract_only_valid_tags(text: str, rules: Optional[RulesList] = None) -> Iterable[str]: + extracted_tags_gen = (t.lower() for t in extract_tags(text, rules)) + yield from (t for t in extracted_tags_gen if is_valid_tag(t)) diff --git a/src/tribler-core/tribler_core/components/tag/rules/tag_rules_processor.py b/src/tribler-core/tribler_core/components/tag/rules/tag_rules_processor.py new file mode 100644 index 00000000000..13085e67877 --- /dev/null +++ b/src/tribler-core/tribler_core/components/tag/rules/tag_rules_processor.py @@ -0,0 +1,110 @@ +import logging +from typing import Optional, Set + +from ipv8.taskmanager import TaskManager + +from pony.orm import db_session + +import tribler_core.components.metadata_store.db.orm_bindings.torrent_metadata as torrent_metadata +import tribler_core.components.metadata_store.db.store as MDS +from tribler_core.components.metadata_store.db.serialization import REGULAR_TORRENT +from tribler_core.components.tag.community.tag_payload import TagOperation, TagOperationEnum +from tribler_core.components.tag.db.tag_db import ( + CLOCK_FOR_AUTOGENERATED_TAGS, + PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS, + TagDatabase, +) +from tribler_core.components.tag.rules.tag_rules import extract_only_valid_tags +from tribler_core.notifier import Notifier + +LAST_PROCESSED_TORRENT_ID = 'last_processed_torrent_id' + + +class TagRulesProcessor(TaskManager): + # this value must be incremented in the case of new rules set has been applied + version: int = 1 + + def __init__(self, notifier: Notifier, db: TagDatabase, mds: MDS.MetadataStore, + batch_size: int = 1000, interval: float = 10): + """ + Default values for batch_size and interval are chosen so that tag processing is not too heavy + fot CPU and with this values 360k items will be processed within the hour. + """ + super().__init__() + self.logger = logging.getLogger(self.__class__.__name__) + + self.notifier = notifier + self.db = db + self.mds = mds + self.batch_size = batch_size + self.interval = interval + self.notifier.add_observer(torrent_metadata.NEW_TORRENT_METADATA_CREATED, + callback=self.process_torrent_title) + self.register_task(name=self.process_batch.__name__, + interval=interval, + task=self.process_batch) + + @db_session + def process_batch(self) -> int: + def query(_start, _end): + return lambda t: _start < t.rowid and t.rowid <= _end and \ + t.metadata_type == REGULAR_TORRENT and \ + t.tag_version < self.version + + start = int(self.mds.get_value(LAST_PROCESSED_TORRENT_ID, default='0')) + end = start + self.batch_size + self.logger.info(f'Processing batch [{start}...{end}]') + + batch = self.mds.TorrentMetadata.select(query(start, end)) + processed = 0 + added = 0 + for torrent in batch: + added += self.process_torrent_title(torrent.infohash, torrent.title) + torrent.tag_version = self.version + processed += 1 + + self.logger.info(f'Processed: {processed} titles. Added {added} tags.') + max_row_id = self.mds.get_max_rowid() + + is_beyond_the_boundary = end > max_row_id + if is_beyond_the_boundary: + self._schedule_new_process_batch_round() + else: + self.mds.set_value(LAST_PROCESSED_TORRENT_ID, str(end)) + return processed + + def process_torrent_title(self, infohash: Optional[bytes] = None, title: Optional[str] = None) -> int: + if not infohash or not title: + return 0 + tags = set(extract_only_valid_tags(title)) + if tags: + self.save_tags(infohash, tags) + return len(tags) + + @db_session + def save_tags(self, infohash: bytes, tags: Set[str]): + self.logger.debug(f'Save: {len(tags)} tags') + for tag in tags: + operation = TagOperation( + infohash=infohash, + operation=TagOperationEnum.ADD, + clock=CLOCK_FOR_AUTOGENERATED_TAGS, + creator_public_key=PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS, + tag=tag + ) + # we want auto generated operation to act like a normal operation + # therefore we use 2 as a `counter_increment` to immediately pass + # SHOW_THRESHOLD + self.db.add_auto_generated_tag_operation(operation=operation) + + def _schedule_new_process_batch_round(self): + self.logger.info('All items in TorrentMetadata have been processed.') + self.mds.set_value(LAST_PROCESSED_TORRENT_ID, '0') + self.logger.info('Set last_processed_torrent_id to 0') + self.interval *= 2 + self.logger.info(f'Double the interval. New interval: {self.interval}') + self.batch_size *= 2 + self.logger.info(f'Double the batch size. New batch size: {self.batch_size}') + self.replace_task(self.process_batch.__name__, + interval=self.interval, + task=self.process_batch) diff --git a/src/tribler-core/tribler_core/components/tag/rules/tests/test_general_rules.py b/src/tribler-core/tribler_core/components/tag/rules/tests/test_general_rules.py new file mode 100644 index 00000000000..fa6b93e149c --- /dev/null +++ b/src/tribler-core/tribler_core/components/tag/rules/tests/test_general_rules.py @@ -0,0 +1,91 @@ +import pytest + +from tribler_core.components.tag.rules.tag_rules import ( + delimiter_re, + extension_re, + extract_only_valid_tags, + extract_tags, + parentheses_re, + square_brackets_re, + tags_in_parentheses, + tags_in_square_brackets, +) + +DELIMITERS = [ + ('word1 word2 word3', ['word1', 'word2', 'word3']), + ('word1,word2,word3', ['word1', 'word2', 'word3']), + ('word1/word2/word3', ['word1', 'word2', 'word3']), + ('word1|word2|word3', ['word1', 'word2', 'word3']), + ('word1 /.,word2', ['word1', 'word2']), +] + +SQUARE_BRACKETS = [ + ('[word1] [word2 word3]', ['word1', 'word2 word3']), + ('[word1 [word2] word3]', ['word2']), +] + +PARENTHESES = [ + ('(word1) (word2 word3)', ['word1', 'word2 word3']), + ('(word1 (word2) word3)', ['word2']), +] + +EXTENSIONS = [ + ('some.ext', ['ext']), + ('some.ext4', ['ext4']), + ('some', []), + ('some. ext', []), + ('some.ext ', []), +] + + +@pytest.mark.parametrize('text, words', DELIMITERS) +def test_delimiter(text, words): + assert delimiter_re.findall(text) == words + + +@pytest.mark.parametrize('text, words', SQUARE_BRACKETS) +def test_square_brackets(text, words): + assert square_brackets_re.findall(text) == words + + +@pytest.mark.parametrize('text, words', PARENTHESES) +def test_parentheses(text, words): + assert parentheses_re.findall(text) == words + + +@pytest.mark.parametrize('text, words', EXTENSIONS) +def test_extension(text, words): + # test regex + assert extension_re.findall(text) == words + + +def test_tags_in_square_brackets(): + # test that tags_in_square_brackets rule works correctly with extract_tags function + text = 'text [tag1, tag2] text1 [tag3|tag4] text2, (tag5, tag6)' + expected_tags = {'tag1', 'tag2', 'tag3', 'tag4'} + + actual_tags = set(extract_tags(text, rules=[tags_in_square_brackets])) + assert actual_tags == expected_tags + + +def test_tags_in_parentheses(): + # test that tags_in_parentheses rule works correctly with extract_tags function + text = 'text (tag1, tag2) text1 (tag3|tag4) text2, [tag5, tag6]' + expected_tags = {'tag1', 'tag2', 'tag3', 'tag4'} + + actual_tags = set(extract_tags(text, rules=[tags_in_parentheses])) + assert actual_tags == expected_tags + + +def test_default_rules(): + # test that default_rules works correctly with extract_tags function + text = 'text (tag1, tag2) text1 (tag3|tag4) text2, [tag5, tag6].ext' + expected_tags = {'tag1', 'tag2', 'tag3', 'tag4', 'tag5', 'tag6', 'ext'} + + actual_tags = set(extract_tags(text)) + assert actual_tags == expected_tags + + +def test_extract_only_valid_tags(): + # test that extract_only_valid_tags extracts only valid tags + assert set(extract_only_valid_tags('[valid-tag, in va li d]')) == {'valid-tag'} diff --git a/src/tribler-core/tribler_core/components/tag/rules/tests/test_tag_rules_processor.py b/src/tribler-core/tribler_core/components/tag/rules/tests/test_tag_rules_processor.py new file mode 100644 index 00000000000..d1580381a43 --- /dev/null +++ b/src/tribler-core/tribler_core/components/tag/rules/tests/test_tag_rules_processor.py @@ -0,0 +1,110 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest + +from tribler_core.components.metadata_store.db.orm_bindings.torrent_metadata import NEW_TORRENT_METADATA_CREATED +from tribler_core.components.tag.community.tag_payload import TagOperation, TagOperationEnum +from tribler_core.components.tag.db.tag_db import CLOCK_FOR_AUTOGENERATED_TAGS, PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS +from tribler_core.components.tag.rules.tag_rules_processor import LAST_PROCESSED_TORRENT_ID, TagRulesProcessor + +TEST_BATCH_SIZE = 100 +TEST_INTERVAL = 0.1 + +# pylint: disable=redefined-outer-name, protected-access +@pytest.fixture +def tag_rules_processor(): + return TagRulesProcessor(notifier=Mock(), db=Mock(), mds=Mock(), batch_size=TEST_BATCH_SIZE, interval=TEST_INTERVAL) + + +def test_constructor(tag_rules_processor: TagRulesProcessor): + # test that constructor of TagRulesProcessor works as expected + assert tag_rules_processor.batch_size == TEST_BATCH_SIZE + assert tag_rules_processor.interval == TEST_INTERVAL + + m: Mock = tag_rules_processor.notifier.add_observer + m.assert_called_with(NEW_TORRENT_METADATA_CREATED, callback=tag_rules_processor.process_torrent_title) + + +@patch.object(TagRulesProcessor, 'save_tags') +def test_process_torrent_file(mocked_save_tags: Mock, tag_rules_processor: TagRulesProcessor): + # test on None + assert not tag_rules_processor.process_torrent_title(infohash=None, title='title') + assert not tag_rules_processor.process_torrent_title(infohash=b'infohash', title=None) + + # test that process_torrent_title doesn't find any tags in the title + assert not tag_rules_processor.process_torrent_title(infohash=b'infohash', title='title') + mocked_save_tags.assert_not_called() + + # test that process_torrent_title does find tags in the title + assert tag_rules_processor.process_torrent_title(infohash=b'infohash', title='title [tag]') == 1 + mocked_save_tags.assert_called_with(b'infohash', {'tag'}) + + +def test_save_tags(tag_rules_processor: TagRulesProcessor): + # test that tag_rules_processor calls TagDatabase with correct args + expected_calls = [{'operation': TagOperation(infohash=b'infohash', operation=TagOperationEnum.ADD, + clock=CLOCK_FOR_AUTOGENERATED_TAGS, + creator_public_key=PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS, + tag='tag1')}, + {'operation': TagOperation(infohash=b'infohash', operation=TagOperationEnum.ADD, + clock=CLOCK_FOR_AUTOGENERATED_TAGS, + creator_public_key=PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS, + tag='tag2')}] + + tag_rules_processor.save_tags(infohash=b'infohash', tags={'tag1', 'tag2'}) + actual_calls = [c.kwargs for c in tag_rules_processor.db.add_auto_generated_tag_operation.mock_calls] + + # compare two lists of dict + assert [c for c in actual_calls if c not in expected_calls] == [] + + +@patch.object(TagRulesProcessor, 'replace_task') +def test_schedule_new_process_batch_round(mocked_replace_task: Mock, tag_rules_processor: TagRulesProcessor): + tag_rules_processor._schedule_new_process_batch_round() + assert tag_rules_processor.interval == TEST_INTERVAL * 2 + assert tag_rules_processor.batch_size == TEST_BATCH_SIZE * 2 + tag_rules_processor.mds.set_value.assert_called_with(LAST_PROCESSED_TORRENT_ID, '0') + mocked_replace_task.assert_called_once() + + +@patch.object(TagRulesProcessor, 'process_torrent_title', new=Mock(return_value=1)) +def test_process_batch_within_the_boundary(tag_rules_processor: TagRulesProcessor): + # test inner logic of `process_batch` in case this batch located within the boundary + returned_batch_size = TEST_BATCH_SIZE // 2 # let's return a half of requested items + + def select(_): + return [SimpleNamespace(infohash=i, title=i) for i in range(returned_batch_size)] + + tag_rules_processor.mds.TorrentMetadata.select = select + tag_rules_processor.mds.get_value = lambda *_, **__: 0 # let's start from 0 for LAST_PROCESSED_TORRENT_ID + + # let's specify `max_rowid` in such a way that it is far more than end of the current batch + tag_rules_processor.mds.get_max_rowid = lambda: TEST_BATCH_SIZE * 10 + + # assert that actually returned count of processed items is equal to `returned_batch_size` + assert tag_rules_processor.process_batch() == returned_batch_size + + # assert that actually stored last_processed_torrent_id is equal to `TEST_BATCH_SIZE` + tag_rules_processor.mds.set_value.assert_called_with(LAST_PROCESSED_TORRENT_ID, str(TEST_BATCH_SIZE)) + + +@patch.object(TagRulesProcessor, '_schedule_new_process_batch_round') +@patch.object(TagRulesProcessor, 'process_torrent_title', new=Mock(return_value=1)) +def test_process_batch_beyond_the_boundary(mocked_schedule_new_process_batch_round: Mock, + tag_rules_processor: TagRulesProcessor): + # test inner logic of `process_batch` in case this batch located within the boundary + returned_batch_size = TEST_BATCH_SIZE // 2 # let's return a half of requested items + + def select(_): + return [SimpleNamespace(infohash=i, title=i) for i in range(returned_batch_size)] + + tag_rules_processor.mds.get_value = lambda *_, **__: 0 # let's start from 0 for LAST_PROCESSED_TORRENT_ID + tag_rules_processor.mds.TorrentMetadata.select = select + + # let's specify `max_rowid` in such a way that it is less than end of the current batch + tag_rules_processor.mds.get_max_rowid = lambda: returned_batch_size // 2 + + # assert that actually returned count of processed items is equal to `returned_batch_size` + assert tag_rules_processor.process_batch() == returned_batch_size + mocked_schedule_new_process_batch_round.assert_called_once() diff --git a/src/tribler-core/tribler_core/components/tag/tag_component.py b/src/tribler-core/tribler_core/components/tag/tag_component.py index 44eadb1d949..cdd1a2940f8 100644 --- a/src/tribler-core/tribler_core/components/tag/tag_component.py +++ b/src/tribler-core/tribler_core/components/tag/tag_component.py @@ -1,10 +1,13 @@ from tribler_common.simpledefs import STATEDIR_DB_DIR +import tribler_core.components.metadata_store.metadata_store_component as metadata_store_component from tribler_core.components.base import Component from tribler_core.components.ipv8.ipv8_component import Ipv8Component from tribler_core.components.key.key_component import KeyComponent +from tribler_core.components.metadata_store.utils import generate_test_channels from tribler_core.components.tag.community.tag_community import TagCommunity from tribler_core.components.tag.db.tag_db import TagDatabase +from tribler_core.components.tag.rules.tag_rules_processor import TagRulesProcessor class TagComponent(Component): @@ -12,6 +15,7 @@ class TagComponent(Component): community: TagCommunity = None tags_db: TagDatabase = None + rules_processor: TagRulesProcessor = None _ipv8_component: Ipv8Component = None async def run(self): @@ -19,12 +23,13 @@ async def run(self): self._ipv8_component = await self.require_component(Ipv8Component) key_component = await self.require_component(KeyComponent) + mds_component = await self.require_component(metadata_store_component.MetadataStoreComponent) db_path = self.session.config.state_dir / STATEDIR_DB_DIR / "tags.db" if self.session.config.gui_test_mode: db_path = ":memory:" - self.tags_db = TagDatabase(str(db_path)) + self.tags_db = TagDatabase(str(db_path), create_tables=True) self.community = TagCommunity( self._ipv8_component.peer, self._ipv8_component.ipv8.endpoint, @@ -32,9 +37,17 @@ async def run(self): db=self.tags_db, tags_key=key_component.secondary_key ) + self.rules_processor = TagRulesProcessor( + notifier=self.session.notifier, + db=self.tags_db, + mds=mds_component.mds, + ) self._ipv8_component.initialise_community_by_default(self.community) + if self.session.config.gui_test_mode: + generate_test_channels(mds_component.mds, self.tags_db) + async def shutdown(self): await super().shutdown() if self._ipv8_component and self.community: diff --git a/src/tribler-core/tribler_core/components/tag/tests/test_tag_component.py b/src/tribler-core/tribler_core/components/tag/tests/test_tag_component.py index 16cb000d1aa..7e341714809 100644 --- a/src/tribler-core/tribler_core/components/tag/tests/test_tag_component.py +++ b/src/tribler-core/tribler_core/components/tag/tests/test_tag_component.py @@ -3,6 +3,7 @@ from tribler_core.components.base import Session from tribler_core.components.ipv8.ipv8_component import Ipv8Component from tribler_core.components.key.key_component import KeyComponent +from tribler_core.components.metadata_store.metadata_store_component import MetadataStoreComponent from tribler_core.components.tag.tag_component import TagComponent # pylint: disable=protected-access @@ -10,7 +11,7 @@ @pytest.mark.asyncio async def test_tag_component(tribler_config): - components = [KeyComponent(), Ipv8Component(), TagComponent()] + components = [MetadataStoreComponent(), KeyComponent(), Ipv8Component(), TagComponent()] async with Session(tribler_config, components).start(): comp = TagComponent.instance() assert comp.started_event.is_set() and not comp.failed