diff --git a/src/tribler/core/components/bandwidth_accounting/db/database.py b/src/tribler/core/components/bandwidth_accounting/db/database.py index fcc6435eded..c755f95a64d 100644 --- a/src/tribler/core/components/bandwidth_accounting/db/database.py +++ b/src/tribler/core/components/bandwidth_accounting/db/database.py @@ -5,7 +5,7 @@ from tribler.core.components.bandwidth_accounting.db import history, misc, transaction as db_transaction from tribler.core.components.bandwidth_accounting.db.transaction import BandwidthTransactionData -from tribler.core.utilities.pony_utils import TriblerDatabase +from tribler.core.utilities.pony_utils import TrackedDatabase from tribler.core.utilities.utilities import MEMORY_DB @@ -28,7 +28,7 @@ def __init__(self, db_path: Union[Path, type(MEMORY_DB)], my_pub_key: bytes, self.my_pub_key = my_pub_key self.store_all_transactions = store_all_transactions - self.database = TriblerDatabase() + self.database = TrackedDatabase() # This attribute is internally called by Pony on startup, though pylint cannot detect it # with the static analysis. diff --git a/src/tribler/core/components/conftest.py b/src/tribler/core/components/conftest.py index 35a57027b10..c4b52ad9aa5 100644 --- a/src/tribler/core/components/conftest.py +++ b/src/tribler/core/components/conftest.py @@ -4,7 +4,7 @@ from ipv8.keyvault.private.libnaclkey import LibNaCLSK from ipv8.util import succeed -from tribler.core.components.knowledge.db.knowledge_db import KnowledgeDatabase +from tribler.core.components.database.db.tribler_database import TriblerDatabase from tribler.core.components.libtorrent.download_manager.download_config import DownloadConfig from tribler.core.components.libtorrent.download_manager.download_manager import DownloadManager from tribler.core.components.libtorrent.settings import LibtorrentSettings @@ -107,8 +107,8 @@ def metadata_store(tmp_path): @pytest.fixture -def knowledge_db(): - db = KnowledgeDatabase() +def tribler_db(): + db = TriblerDatabase() yield db db.shutdown() diff --git a/src/tribler/core/components/database/__init__.py b/src/tribler/core/components/database/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/tribler/core/components/database/database_component.py b/src/tribler/core/components/database/database_component.py new file mode 100644 index 00000000000..124d592a09f --- /dev/null +++ b/src/tribler/core/components/database/database_component.py @@ -0,0 +1,23 @@ +from tribler.core.components.component import Component +from tribler.core.components.database.db.tribler_database import TriblerDatabase +from tribler.core.utilities.simpledefs import STATEDIR_DB_DIR + + +class DatabaseComponent(Component): + tribler_should_stop_on_component_error = True + + db: TriblerDatabase = None + + async def run(self): + await super().run() + + db_path = self.session.config.state_dir / STATEDIR_DB_DIR / "tribler.db" + if self.session.config.gui_test_mode: + db_path = ":memory:" + + self.db = TriblerDatabase(str(db_path), create_tables=True) + + async def shutdown(self): + await super().shutdown() + if self.db: + self.db.shutdown() diff --git a/src/tribler/core/components/database/db/__init__.py b/src/tribler/core/components/database/db/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/tribler/core/components/database/db/tests/__init__.py b/src/tribler/core/components/database/db/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/tribler/core/components/knowledge/db/tests/test_knowledge_db.py b/src/tribler/core/components/database/db/tests/test_tribler_database.py similarity index 97% rename from src/tribler/core/components/knowledge/db/tests/test_knowledge_db.py rename to src/tribler/core/components/database/db/tests/test_tribler_database.py index 66716f22044..2e4edb9c72c 100644 --- a/src/tribler/core/components/knowledge/db/tests/test_knowledge_db.py +++ b/src/tribler/core/components/database/db/tests/test_tribler_database.py @@ -3,22 +3,22 @@ from pony.orm import commit, db_session -from tribler.core.components.knowledge.db.knowledge_db import KnowledgeDatabase, Operation, \ +from tribler.core.components.database.db.tribler_database import TriblerDatabase, Operation, \ PUBLIC_KEY_FOR_AUTO_GENERATED_OPERATIONS, ResourceType, SHOW_THRESHOLD, SimpleStatement -from tribler.core.components.knowledge.db.tests.test_knowledge_db_base import Resource, TestTagDBBase -from tribler.core.utilities.pony_utils import TriblerDatabase, get_or_create +from tribler.core.components.database.db.tests.test_tribler_database_base import Resource, TestTagDBBase +from tribler.core.utilities.pony_utils import TrackedDatabase, get_or_create # pylint: disable=protected-access class TestTagDB(TestTagDBBase): - @patch.object(TriblerDatabase, 'generate_mapping') + @patch.object(TrackedDatabase, 'generate_mapping') def test_constructor_create_tables_true(self, mocked_generate_mapping: Mock): - KnowledgeDatabase(':memory:') + TriblerDatabase(':memory:') mocked_generate_mapping.assert_called_with(create_tables=True) - @patch.object(TriblerDatabase, 'generate_mapping') + @patch.object(TrackedDatabase, 'generate_mapping') def test_constructor_create_tables_false(self, mocked_generate_mapping: Mock): - KnowledgeDatabase(':memory:', create_tables=False) + TriblerDatabase(':memory:', create_tables=False) mocked_generate_mapping.assert_called_with(create_tables=False) @db_session @@ -446,9 +446,9 @@ def _results(objects, predicate=ResourceType.TAG, case_sensitive=True): @db_session def test_show_condition(self): - assert KnowledgeDatabase._show_condition(SimpleNamespace(local_operation=Operation.ADD)) - assert KnowledgeDatabase._show_condition(SimpleNamespace(local_operation=None, score=SHOW_THRESHOLD)) - assert not KnowledgeDatabase._show_condition(SimpleNamespace(local_operation=None, score=0)) + assert TriblerDatabase._show_condition(SimpleNamespace(local_operation=Operation.ADD)) + assert TriblerDatabase._show_condition(SimpleNamespace(local_operation=None, score=SHOW_THRESHOLD)) + assert not TriblerDatabase._show_condition(SimpleNamespace(local_operation=None, score=0)) @db_session def test_get_random_operations_by_condition_less_than_count(self): diff --git a/src/tribler/core/components/knowledge/db/tests/test_knowledge_db_base.py b/src/tribler/core/components/database/db/tests/test_tribler_database_base.py similarity index 91% rename from src/tribler/core/components/knowledge/db/tests/test_knowledge_db_base.py rename to src/tribler/core/components/database/db/tests/test_tribler_database_base.py index f0b5755d4c1..49bc876653c 100644 --- a/src/tribler/core/components/knowledge/db/tests/test_knowledge_db_base.py +++ b/src/tribler/core/components/database/db/tests/test_tribler_database_base.py @@ -4,8 +4,9 @@ from ipv8.test.base import TestBase from pony.orm import commit, db_session +from tribler.core.components.database.db.tribler_database import Operation, ResourceType, SHOW_THRESHOLD, \ + TriblerDatabase from tribler.core.components.knowledge.community.knowledge_payload import StatementOperation -from tribler.core.components.knowledge.db.knowledge_db import KnowledgeDatabase, Operation, ResourceType, SHOW_THRESHOLD from tribler.core.utilities.pony_utils import get_or_create @@ -22,7 +23,7 @@ class Resource: class TestTagDBBase(TestBase): def setUp(self): super().setUp() - self.db = KnowledgeDatabase() + self.db = TriblerDatabase() async def tearDown(self): if self._outcome.errors: @@ -56,7 +57,7 @@ def create_operation(subject_type: ResourceType = ResourceType.TORRENT, subject= operation=operation, clock=clock, creator_public_key=peer) @staticmethod - def add_operation(tag_db: KnowledgeDatabase, subject_type: ResourceType = ResourceType.TORRENT, + def add_operation(tag_db: TriblerDatabase, subject_type: ResourceType = ResourceType.TORRENT, subject: str = 'infohash', predicate: ResourceType = ResourceType.TAG, obj: str = 'tag', peer=b'', operation: Operation = None, @@ -70,7 +71,7 @@ def add_operation(tag_db: KnowledgeDatabase, subject_type: ResourceType = Resour return result @staticmethod - def add_operation_set(tag_db: KnowledgeDatabase, dictionary): + def add_operation_set(tag_db: TriblerDatabase, dictionary): index = count(0) def generate_n_peer_names(n): diff --git a/src/tribler/core/components/database/db/tribler_database.py b/src/tribler/core/components/database/db/tribler_database.py new file mode 100644 index 00000000000..50504eaab09 --- /dev/null +++ b/src/tribler/core/components/database/db/tribler_database.py @@ -0,0 +1,456 @@ +import datetime +import logging +from dataclasses import dataclass +from enum import IntEnum +from typing import Any, Callable, Iterator, List, Optional, Set + +from pony import orm +from pony.orm import raw_sql +from pony.orm.core import Entity, Query, select +from pony.utils import between + +from tribler.core.components.knowledge.community.knowledge_payload import StatementOperation +from tribler.core.utilities.pony_utils import TrackedDatabase, get_or_create + +CLOCK_START_VALUE = 0 + +PUBLIC_KEY_FOR_AUTO_GENERATED_OPERATIONS = b'auto_generated' + +SHOW_THRESHOLD = 1 # how many operation needed for showing a knowledge graph statement in the UI +HIDE_THRESHOLD = -2 # how many operation needed for hiding a knowledge graph statement in the UI + + +class Operation(IntEnum): + """ Available types of statement operations.""" + ADD = 1 # +1 operation + REMOVE = 2 # -1 operation + + +class ResourceType(IntEnum): + """ Description of available resources within the Knowledge Graph. + These types are also using as a predicate for the statements. + + Based on https://en.wikipedia.org/wiki/Dublin_Core + """ + CONTRIBUTOR = 1 + COVERAGE = 2 + CREATOR = 3 + DATE = 4 + DESCRIPTION = 5 + FORMAT = 6 + IDENTIFIER = 7 + LANGUAGE = 8 + PUBLISHER = 9 + RELATION = 10 + RIGHTS = 11 + SOURCE = 12 + SUBJECT = 13 + TITLE = 14 + TYPE = 15 + + # this is a section for extra types + TAG = 101 + TORRENT = 102 + CONTENT_ITEM = 103 + + +@dataclass +class SimpleStatement: + subject_type: ResourceType + object: str + predicate: ResourceType + subject: str + + +class TriblerDatabase: + def __init__(self, filename: Optional[str] = None, *, create_tables: bool = True, **generate_mapping_kwargs): + self.instance = TrackedDatabase() + self.define_binding(self.instance) + self.instance.bind('sqlite', filename or ':memory:', create_db=True) + generate_mapping_kwargs['create_tables'] = create_tables + self.instance.generate_mapping(**generate_mapping_kwargs) + self.logger = logging.getLogger(self.__class__.__name__) + + @staticmethod + def define_binding(db): + class Peer(db.Entity): + id = orm.PrimaryKey(int, auto=True) + public_key = orm.Required(bytes, unique=True) + added_at = orm.Optional(datetime.datetime, default=datetime.datetime.utcnow) + operations = orm.Set(lambda: StatementOp) + + class Statement(db.Entity): + id = orm.PrimaryKey(int, auto=True) + + subject = orm.Required(lambda: Resource) + object = orm.Required(lambda: Resource, index=True) + + operations = orm.Set(lambda: StatementOp) + + added_count = orm.Required(int, default=0) + removed_count = orm.Required(int, default=0) + + local_operation = orm.Optional(int) # in case user don't (or do) want to see it locally + + orm.composite_key(subject, object) + + @property + def score(self): + return self.added_count - self.removed_count + + def update_counter(self, operation: Operation, increment: int = 1, is_local_peer: bool = False): + """ Update Statement's counter + Args: + operation: Resource operation + increment: + is_local_peer: The flag indicates whether do we performs operations from a local user or from + a remote user. In case of the local user, his operations will be considered as + authoritative for his (only) local Tribler instance. + + Returns: + """ + if is_local_peer: + self.local_operation = operation + if operation == Operation.ADD: + self.added_count += increment + if operation == Operation.REMOVE: + self.removed_count += increment + + class Resource(db.Entity): + id = orm.PrimaryKey(int, auto=True) + name = orm.Required(str) + type = orm.Required(int) # ResourceType enum + + subject_statements = orm.Set(lambda: Statement, reverse="subject") + object_statements = orm.Set(lambda: Statement, reverse="object") + + orm.composite_key(name, type) + + class StatementOp(db.Entity): + id = orm.PrimaryKey(int, auto=True) + + statement = orm.Required(lambda: Statement) + peer = orm.Required(lambda: Peer) + + operation = orm.Required(int) + clock = orm.Required(int) + signature = orm.Required(bytes) + updated_at = orm.Required(datetime.datetime, default=datetime.datetime.utcnow) + auto_generated = orm.Required(bool, default=False) + + orm.composite_key(statement, peer) + + class Misc(db.Entity): # pylint: disable=unused-variable + name = orm.PrimaryKey(str) + value = orm.Optional(str) + + def add_operation(self, operation: StatementOperation, signature: bytes, is_local_peer: bool = False, + is_auto_generated: bool = False, counter_increment: int = 1) -> bool: + """ Add the operation that will be applied to a statement. + Args: + operation: the class describes the adding operation + signature: the signature of the operation + is_local_peer: local operations processes differently than remote operations. They affects + `Statement.local_operation` field which is used in `self.get_tags()` function. + is_auto_generated: the indicator of whether this resource was generated automatically or not + counter_increment: the counter or "numbers" of adding operations + + Returns: True if the operation has been added/updated, False otherwise. + """ + self.logger.debug(f'Add operation. {operation.subject} "{operation.predicate}" {operation.object}') + peer = get_or_create(self.instance.Peer, public_key=operation.creator_public_key) + subject = get_or_create(self.instance.Resource, name=operation.subject, type=operation.subject_type) + obj = get_or_create(self.instance.Resource, name=operation.object, type=operation.predicate) + statement = get_or_create(self.instance.Statement, subject=subject, object=obj) + op = self.instance.StatementOp.get_for_update(statement=statement, peer=peer) + + if not op: # then insert + self.instance.StatementOp(statement=statement, peer=peer, operation=operation.operation, + clock=operation.clock, signature=signature, auto_generated=is_auto_generated) + statement.update_counter(operation.operation, increment=counter_increment, is_local_peer=is_local_peer) + return True + + # if it is a message from the past, then return + if operation.clock <= op.clock: + return False + + # To prevent endless incrementing of the operation, we apply the following logic: + + # 1. Decrement previous operation + statement.update_counter(op.operation, increment=-counter_increment, is_local_peer=is_local_peer) + # 2. Increment new operation + statement.update_counter(operation.operation, increment=counter_increment, is_local_peer=is_local_peer) + + # 3. Update the operation entity + op.set(operation=operation.operation, clock=operation.clock, signature=signature, + updated_at=datetime.datetime.utcnow(), auto_generated=is_auto_generated) + return True + + def add_auto_generated(self, subject_type: ResourceType, subject: str, predicate: ResourceType, obj: str) -> bool: + """ Add an autogenerated operation. + + The difference between "normal" and "autogenerated" operation is that the autogenerated operation will be added + with the flag `is_auto_generated=True` and with the `PUBLIC_KEY_FOR_AUTO_GENERATED_TAGS` public key. + + Args: + subject_type: a type of adding subject. See: ResourceType enum. + subject: a string that represents a subject of adding operation. + predicate: the enum that represents a predicate of adding operation. + obj: a string that represents an object of adding operation. + """ + operation = StatementOperation( + subject_type=subject_type, + subject=subject, + predicate=predicate, + object=obj, + operation=Operation.ADD, + clock=CLOCK_START_VALUE, + creator_public_key=PUBLIC_KEY_FOR_AUTO_GENERATED_OPERATIONS, + ) + + return self.add_operation(operation, signature=b'', is_local_peer=False, is_auto_generated=True, + counter_increment=SHOW_THRESHOLD) + + @staticmethod + def _show_condition(s): + """This function determines show condition for the statement""" + return s.local_operation == Operation.ADD.value or not s.local_operation and s.score >= SHOW_THRESHOLD + + def _get_resources(self, resource_type: Optional[ResourceType], name: Optional[str], case_sensitive: bool) -> Query: + """ Get resources + + Args: + resource_type: type of resources + name: name of resources + case_sensitive: if True, then Resources are selected in a case-sensitive manner. if False, then Resources + are selected in a case-insensitive manner. + + Returns: a Query object for requested resources + """ + + results = self.instance.Resource.select() + if name: + results = results.filter( + (lambda r: r.name == name) if case_sensitive else (lambda r: r.name.lower() == name.lower()) + ) + if resource_type: + results = results.filter(lambda r: r.type == resource_type.value) + return results + + def _get_statements(self, source_type: Optional[ResourceType], source_name: Optional[str], + statements_getter: Callable[[Entity], Entity], + target_condition: Callable[[], bool], condition: Callable[[], bool], + case_sensitive: bool, ) -> Iterator[str]: + """ Get entities that satisfies the given condition. + """ + + for resource in self._get_resources(source_type, source_name, case_sensitive): + results = orm.select(_ for _ in statements_getter(resource) + .select(condition) + .filter(target_condition) + .order_by(lambda s: orm.desc(s.score))) + + yield from list(results) + + def get_objects(self, subject_type: Optional[ResourceType] = None, subject: Optional[str] = '', + predicate: Optional[ResourceType] = None, case_sensitive: bool = True, + condition: Callable[[], bool] = None) -> List[str]: + """ Get objects that satisfy the given subject and predicate. + + To understand the order of parameters, keep in ming the following generic construction: + (, , , ). + + So in the case of retrieving objects this construction becomes + (, , , ?). + + Args: + subject_type: a type of the subject. + subject: a string that represents the subject. + predicate: the enum that represents a predicate of querying operations. + case_sensitive: if True, then Resources are selected in a case-sensitive manner. if False, then Resources + are selected in a case-insensitive manner. + + Returns: a list of the strings representing the objects. + """ + self.logger.debug(f'Get subjects for {subject} with {predicate}') + + statements = self._get_statements( + source_type=subject_type, + source_name=subject, + statements_getter=lambda r: r.subject_statements, + target_condition=(lambda s: s.object.type == predicate.value) if predicate else (lambda _: True), + condition=condition or self._show_condition, + case_sensitive=case_sensitive, + ) + return [s.object.name for s in statements] + + def get_subjects(self, subject_type: Optional[ResourceType] = None, predicate: Optional[ResourceType] = None, + obj: Optional[str] = '', case_sensitive: bool = True) -> List[str]: + """ Get subjects that satisfy the given object and predicate. + To understand the order of parameters, keep in ming the following generic construction: + + (, , , ). + + So in the case of retrieving subjects this construction becomes + (, ?, , ). + + Args: + subject_type: a type of the subject. + obj: a string that represents the object. + predicate: the enum that represents a predicate of querying operations. + case_sensitive: if True, then Resources are selected in a case-sensitive manner. if False, then Resources + are selected in a case-insensitive manner. + + Returns: a list of the strings representing the subjects. + """ + self.logger.debug(f'Get linked back resources for {obj} with {predicate}') + + statements = self._get_statements( + source_type=predicate, + source_name=obj, + statements_getter=lambda r: r.object_statements, + target_condition=(lambda s: s.subject.type == subject_type.value) if subject_type else (lambda _: True), + condition=self._show_condition, + case_sensitive=case_sensitive, + ) + + return [s.subject.name for s in statements] + + def get_statements(self, subject_type: Optional[ResourceType] = None, subject: Optional[str] = '', + case_sensitive: bool = True) -> List[SimpleStatement]: + + statements = self._get_statements( + source_type=subject_type, + source_name=subject, + statements_getter=lambda r: r.subject_statements, + target_condition=lambda _: True, + condition=self._show_condition, + case_sensitive=case_sensitive, + ) + + statements = map(lambda s: SimpleStatement( + subject_type=s.subject.type, + subject=s.subject.name, + predicate=s.object.type, + object=s.object.name + ), statements) + + return list(statements) + + def get_suggestions(self, subject_type: Optional[ResourceType] = None, subject: Optional[str] = '', + predicate: Optional[ResourceType] = None, case_sensitive: bool = True) -> List[str]: + """ Get all suggestions for a particular subject. + + Args: + subject_type: a type of the subject. + subject: a string that represents the subject. + predicate: the enum that represents a predicate of querying operations. + case_sensitive: if True, then Resources are selected in a case-sensitive manner. if False, then Resources + are selected in a case-insensitive manner. + + Returns: a list of the strings representing the objects. + """ + self.logger.debug(f"Getting suggestions for {subject} with {predicate}") + + suggestions = self.get_objects( + subject_type=subject_type, + subject=subject, + predicate=predicate, + case_sensitive=case_sensitive, + condition=lambda s: not s.local_operation and between(s.score, HIDE_THRESHOLD + 1, SHOW_THRESHOLD - 1) + ) + return suggestions + + def get_subjects_intersection(self, subjects_type: Optional[ResourceType], objects: Set[str], + predicate: Optional[ResourceType], + case_sensitive: bool = True) -> Set[str]: + if not objects: + return set() + + if case_sensitive: + name_condition = '"obj"."name" = $obj_name' + else: + name_condition = 'py_lower("obj"."name") = py_lower($obj_name)' + + query = select(r.name for r in self.instance.Resource) + for obj_name in objects: + query = query.filter(raw_sql(f""" + r.id IN ( + SELECT "s"."subject" + FROM "Statement" "s" + WHERE ( + "s"."local_operation" = $(Operation.ADD.value) + OR + ("s"."local_operation" = 0 OR "s"."local_operation" IS NULL) + AND ("s"."added_count" - "s"."removed_count") >= $SHOW_THRESHOLD + ) AND "s"."object" IN ( + SELECT "obj"."id" FROM "Resource" "obj" + WHERE "obj"."type" = $(predicate.value) AND {name_condition} + ) + )""")) + return set(query) + + def get_clock(self, operation: StatementOperation) -> int: + """ Get the clock (int) of operation. + """ + peer = self.instance.Peer.get(public_key=operation.creator_public_key) + subject = self.instance.Resource.get(name=operation.subject, type=operation.subject_type) + obj = self.instance.Resource.get(name=operation.object, type=operation.predicate) + if not subject or not obj or not peer: + return CLOCK_START_VALUE + + statement = self.instance.Statement.get(subject=subject, object=obj) + if not statement: + return CLOCK_START_VALUE + + op = self.instance.StatementOp.get(statement=statement, peer=peer) + return op.clock if op else CLOCK_START_VALUE + + def get_operations_for_gossip(self, count: int = 10) -> Set[Entity]: + """ Get random operations from the DB that older than time_delta. + + Args: + count: a limit for a resulting query + """ + return self._get_random_operations_by_condition( + condition=lambda so: not so.auto_generated, + count=count + ) + + def shutdown(self) -> None: + self.instance.disconnect() + + def _get_random_operations_by_condition(self, condition: Callable[[Entity], bool], count: int = 5, + attempts: int = 100) -> Set[Entity]: + """ Get `count` random operations that satisfy the given condition. + + This method were introduce as an fast alternative for native Pony `random` method. + + + Args: + condition: the condition by which the entities will be queried. + count: the amount of entities to return. + attempts: maximum attempt count for requesting the DB. + + Returns: a set of random operations + """ + operations = set() + for _ in range(attempts): + if len(operations) == count: + return operations + + random_operations_list = self.instance.StatementOp.select_random(1) + if random_operations_list: + operation = random_operations_list[0] + if condition(operation): + operations.add(operation) + + return operations + + def get_misc(self, key: str, default: Optional[str] = None) -> Optional[str]: + data = self.instance.Misc.get(name=key) + return data.value if data else default + + def set_misc(self, key: str, value: Any): + key_value = get_or_create(self.instance.Misc, name=key) + key_value.value = str(value) diff --git a/src/tribler/core/components/gigachannel/gigachannel_component.py b/src/tribler/core/components/gigachannel/gigachannel_component.py index 3a5891c1c35..527f57a75d9 100644 --- a/src/tribler/core/components/gigachannel/gigachannel_component.py +++ b/src/tribler/core/components/gigachannel/gigachannel_component.py @@ -1,6 +1,7 @@ from ipv8.peerdiscovery.network import Network from tribler.core.components.component import Component +from tribler.core.components.database.database_component import DatabaseComponent from tribler.core.components.gigachannel.community.gigachannel_community import ( GigaChannelCommunity, GigaChannelTestnetCommunity, @@ -25,7 +26,7 @@ async def run(self): self._ipv8_component = await self.require_component(Ipv8Component) metadata_store_component = await self.require_component(MetadataStoreComponent) - knowledge_component = await self.get_component(KnowledgeComponent) + db_component = await self.get_component(DatabaseComponent) giga_channel_cls = GigaChannelTestnetCommunity if config.general.testnet else GigaChannelCommunity community = giga_channel_cls( @@ -37,7 +38,7 @@ async def run(self): rqc_settings=config.remote_query_community, metadata_store=metadata_store_component.mds, max_peers=50, - knowledge_db=knowledge_component.knowledge_db if knowledge_component else None + tribler_db=db_component.db if db_component else None ) self.community = community self._ipv8_component.initialise_community_by_default(community, default_random_walk_max_peers=30) diff --git a/src/tribler/core/components/gigachannel/tests/test_gigachannel_component.py b/src/tribler/core/components/gigachannel/tests/test_gigachannel_component.py index 90114ac8740..bd5e75b66bc 100644 --- a/src/tribler/core/components/gigachannel/tests/test_gigachannel_component.py +++ b/src/tribler/core/components/gigachannel/tests/test_gigachannel_component.py @@ -1,3 +1,4 @@ +from tribler.core.components.database.database_component import DatabaseComponent from tribler.core.components.gigachannel.gigachannel_component import GigaChannelComponent from tribler.core.components.ipv8.ipv8_component import Ipv8Component from tribler.core.components.key.key_component import KeyComponent @@ -13,7 +14,7 @@ async def test_giga_channel_component(tribler_config): tribler_config.ipv8.enabled = True tribler_config.libtorrent.enabled = True tribler_config.chant.enabled = True - components = [KnowledgeComponent(), MetadataStoreComponent(), KeyComponent(), Ipv8Component(), + components = [DatabaseComponent(), KnowledgeComponent(), MetadataStoreComponent(), KeyComponent(), Ipv8Component(), GigaChannelComponent()] async with Session(tribler_config, components) as session: comp = session.get_instance(GigaChannelComponent) diff --git a/src/tribler/core/components/gigachannel_manager/tests/test_gigachannel_manager_component.py b/src/tribler/core/components/gigachannel_manager/tests/test_gigachannel_manager_component.py index 0885d237312..78435fa18f0 100644 --- a/src/tribler/core/components/gigachannel_manager/tests/test_gigachannel_manager_component.py +++ b/src/tribler/core/components/gigachannel_manager/tests/test_gigachannel_manager_component.py @@ -1,3 +1,4 @@ +from tribler.core.components.database.database_component import DatabaseComponent from tribler.core.components.gigachannel_manager.gigachannel_manager_component import GigachannelManagerComponent from tribler.core.components.ipv8.ipv8_component import Ipv8Component from tribler.core.components.key.key_component import KeyComponent @@ -12,9 +13,8 @@ async def test_gigachannel_manager_component(tribler_config): - components = [Ipv8Component(), KnowledgeComponent(), SocksServersComponent(), KeyComponent(), - MetadataStoreComponent(), - LibtorrentComponent(), GigachannelManagerComponent()] + components = [DatabaseComponent(), Ipv8Component(), KnowledgeComponent(), SocksServersComponent(), KeyComponent(), + MetadataStoreComponent(), LibtorrentComponent(), GigachannelManagerComponent()] async with Session(tribler_config, components) as session: comp = session.get_instance(GigachannelManagerComponent) assert comp.started_event.is_set() and not comp.failed diff --git a/src/tribler/core/components/knowledge/community/knowledge_community.py b/src/tribler/core/components/knowledge/community/knowledge_community.py index 2c342ee4823..6f276fc7962 100644 --- a/src/tribler/core/components/knowledge/community/knowledge_community.py +++ b/src/tribler/core/components/knowledge/community/knowledge_community.py @@ -18,7 +18,7 @@ from tribler.core.components.knowledge.community.knowledge_validator import validate_operation, validate_resource, \ validate_resource_type from tribler.core.components.knowledge.community.operations_requests import OperationsRequests, PeerValidationError -from tribler.core.components.knowledge.db.knowledge_db import KnowledgeDatabase +from tribler.core.components.database.db.tribler_database import TriblerDatabase REQUESTED_OPERATIONS_COUNT = 10 @@ -34,7 +34,7 @@ class KnowledgeCommunity(TriblerCommunity): community_id = unhexlify('d7f7bdc8bcd3d9ad23f06f25aa8aab6754eb23a0') - def __init__(self, *args, db: KnowledgeDatabase, key: LibNaCLSK, request_interval=REQUEST_INTERVAL, + def __init__(self, *args, db: TriblerDatabase, key: LibNaCLSK, request_interval=REQUEST_INTERVAL, **kwargs): super().__init__(*args, **kwargs) self.db = db diff --git a/src/tribler/core/components/knowledge/community/knowledge_validator.py b/src/tribler/core/components/knowledge/community/knowledge_validator.py index ff70b3bf3e9..df87c6e5e9d 100644 --- a/src/tribler/core/components/knowledge/community/knowledge_validator.py +++ b/src/tribler/core/components/knowledge/community/knowledge_validator.py @@ -1,4 +1,4 @@ -from tribler.core.components.knowledge.db.knowledge_db import Operation, ResourceType +from tribler.core.components.database.db.tribler_database import Operation, ResourceType from tribler.core.components.knowledge.knowledge_constants import MAX_RESOURCE_LENGTH, MIN_RESOURCE_LENGTH diff --git a/src/tribler/core/components/knowledge/community/tests/test_knowledge_community.py b/src/tribler/core/components/knowledge/community/tests/test_knowledge_community.py index 5833e1a265e..4f8d94ebd66 100644 --- a/src/tribler/core/components/knowledge/community/tests/test_knowledge_community.py +++ b/src/tribler/core/components/knowledge/community/tests/test_knowledge_community.py @@ -8,7 +8,7 @@ from tribler.core.components.knowledge.community.knowledge_community import KnowledgeCommunity from tribler.core.components.knowledge.community.knowledge_payload import StatementOperation -from tribler.core.components.knowledge.db.knowledge_db import KnowledgeDatabase, Operation, ResourceType +from tribler.core.components.database.db.tribler_database import TriblerDatabase, Operation, ResourceType REQUEST_INTERVAL_FOR_RANDOM_OPERATIONS = 0.1 # in seconds @@ -22,7 +22,7 @@ async def tearDown(self): await super().tearDown() def create_node(self, *args, **kwargs): - return MockIPv8("curve25519", KnowledgeCommunity, db=KnowledgeDatabase(), key=LibNaCLSK(), + return MockIPv8("curve25519", KnowledgeCommunity, db=TriblerDatabase(), key=LibNaCLSK(), request_interval=REQUEST_INTERVAL_FOR_RANDOM_OPERATIONS) def create_operation(self, subject='1' * 20, obj=''): diff --git a/src/tribler/core/components/knowledge/community/tests/test_knowledge_validator.py b/src/tribler/core/components/knowledge/community/tests/test_knowledge_validator.py index a4159822b05..b8663d9fee6 100644 --- a/src/tribler/core/components/knowledge/community/tests/test_knowledge_validator.py +++ b/src/tribler/core/components/knowledge/community/tests/test_knowledge_validator.py @@ -2,7 +2,7 @@ from tribler.core.components.knowledge.community.knowledge_validator import is_valid_resource, validate_operation, \ validate_resource, validate_resource_type -from tribler.core.components.knowledge.db.knowledge_db import Operation, ResourceType +from tribler.core.components.database.db.tribler_database import Operation, ResourceType VALID_TAGS = [ 'nl', diff --git a/src/tribler/core/components/knowledge/knowledge_component.py b/src/tribler/core/components/knowledge/knowledge_component.py index d5c8d9745ec..d1cfc487853 100644 --- a/src/tribler/core/components/knowledge/knowledge_component.py +++ b/src/tribler/core/components/knowledge/knowledge_component.py @@ -1,19 +1,17 @@ import tribler.core.components.metadata_store.metadata_store_component as metadata_store_component from tribler.core.components.component import Component +from tribler.core.components.database.database_component import DatabaseComponent from tribler.core.components.ipv8.ipv8_component import Ipv8Component from tribler.core.components.key.key_component import KeyComponent from tribler.core.components.knowledge.community.knowledge_community import KnowledgeCommunity -from tribler.core.components.knowledge.db.knowledge_db import KnowledgeDatabase from tribler.core.components.knowledge.rules.knowledge_rules_processor import KnowledgeRulesProcessor from tribler.core.components.metadata_store.utils import generate_test_channels -from tribler.core.utilities.simpledefs import STATEDIR_DB_DIR class KnowledgeComponent(Component): tribler_should_stop_on_component_error = False community: KnowledgeCommunity = None - knowledge_db: KnowledgeDatabase = None rules_processor: KnowledgeRulesProcessor = None _ipv8_component: Ipv8Component = None @@ -23,22 +21,18 @@ async def run(self): self._ipv8_component = await self.require_component(Ipv8Component) key_component = await self.require_component(KeyComponent) mds_component = await self.require_component(metadata_store_component.MetadataStoreComponent) + db_component = await self.require_component(DatabaseComponent) - db_path = self.session.config.state_dir / STATEDIR_DB_DIR / "knowledge.db" - if self.session.config.gui_test_mode: - db_path = ":memory:" - - self.knowledge_db = KnowledgeDatabase(str(db_path), create_tables=True) self.community = KnowledgeCommunity( self._ipv8_component.peer, self._ipv8_component.ipv8.endpoint, self._ipv8_component.ipv8.network, - db=self.knowledge_db, + db=db_component.db, key=key_component.secondary_key ) self.rules_processor = KnowledgeRulesProcessor( notifier=self.session.notifier, - db=self.knowledge_db, + db=db_component.db, mds=mds_component.mds, ) self.rules_processor.start() @@ -46,7 +40,7 @@ async def run(self): self._ipv8_component.initialise_community_by_default(self.community) if self.session.config.gui_test_mode: - generate_test_channels(mds_component.mds, self.knowledge_db) + generate_test_channels(mds_component.mds, db_component.db) async def shutdown(self): await super().shutdown() @@ -54,5 +48,3 @@ async def shutdown(self): await self._ipv8_component.unload_community(self.community) if self.rules_processor: await self.rules_processor.shutdown() - if self.knowledge_db: - self.knowledge_db.shutdown() diff --git a/src/tribler/core/components/knowledge/restapi/knowledge_endpoint.py b/src/tribler/core/components/knowledge/restapi/knowledge_endpoint.py index af73460657f..ef4f8c6ac99 100644 --- a/src/tribler/core/components/knowledge/restapi/knowledge_endpoint.py +++ b/src/tribler/core/components/knowledge/restapi/knowledge_endpoint.py @@ -10,7 +10,7 @@ from tribler.core.components.knowledge.community.knowledge_community import KnowledgeCommunity from tribler.core.components.knowledge.community.knowledge_payload import StatementOperation from tribler.core.components.knowledge.community.knowledge_validator import is_valid_resource -from tribler.core.components.knowledge.db.knowledge_db import KnowledgeDatabase, Operation, ResourceType +from tribler.core.components.database.db.tribler_database import TriblerDatabase, Operation, ResourceType from tribler.core.components.restapi.rest.rest_endpoint import HTTP_BAD_REQUEST, RESTEndpoint, RESTResponse from tribler.core.components.restapi.rest.schema import HandledErrorSchema from tribler.core.utilities.utilities import froze_it @@ -23,9 +23,9 @@ class KnowledgeEndpoint(RESTEndpoint): """ path = '/knowledge' - def __init__(self, db: KnowledgeDatabase, community: KnowledgeCommunity): + def __init__(self, db: TriblerDatabase, community: KnowledgeCommunity): super().__init__() - self.db: KnowledgeDatabase = db + self.db: TriblerDatabase = db self.community: KnowledgeCommunity = community @staticmethod diff --git a/src/tribler/core/components/knowledge/restapi/tests/test_knowledge_endpoint.py b/src/tribler/core/components/knowledge/restapi/tests/test_knowledge_endpoint.py index ed984168e5a..dd3f090902a 100644 --- a/src/tribler/core/components/knowledge/restapi/tests/test_knowledge_endpoint.py +++ b/src/tribler/core/components/knowledge/restapi/tests/test_knowledge_endpoint.py @@ -7,7 +7,7 @@ from tribler.core.components.conftest import TEST_PERSONAL_KEY from tribler.core.components.knowledge.community.knowledge_payload import StatementOperation -from tribler.core.components.knowledge.db.knowledge_db import Operation, ResourceType +from tribler.core.components.database.db.tribler_database import Operation, ResourceType from tribler.core.components.knowledge.restapi.knowledge_endpoint import KnowledgeEndpoint from tribler.core.components.restapi.rest.base_api_test import do_request from tribler.core.utilities.date_utils import freeze_time @@ -17,8 +17,8 @@ # pylint: disable=redefined-outer-name @pytest.fixture -def endpoint(knowledge_db): - return KnowledgeEndpoint(knowledge_db, Mock(key=TEST_PERSONAL_KEY, sign=Mock(return_value=b''))) +def endpoint(tribler_db): + return KnowledgeEndpoint(tribler_db, Mock(key=TEST_PERSONAL_KEY, sign=Mock(return_value=b''))) def tag_to_statement(tag: str) -> Dict: @@ -48,7 +48,7 @@ async def test_add_invalid_tag(rest_api): post_data=post_data) -async def test_modify_tags(rest_api, knowledge_db): +async def test_modify_tags(rest_api, tribler_db): """ Test modifying tags """ @@ -58,7 +58,7 @@ async def test_modify_tags(rest_api, knowledge_db): await do_request(rest_api, f'knowledge/{infohash}', request_type="PATCH", expected_code=200, post_data=post_data) with db_session: - tags = knowledge_db.get_objects(subject=infohash, predicate=ResourceType.TAG) + tags = tribler_db.get_objects(subject=infohash, predicate=ResourceType.TAG) assert len(tags) == 2 # Now remove a tag @@ -67,17 +67,17 @@ async def test_modify_tags(rest_api, knowledge_db): await do_request(rest_api, f'knowledge/{infohash}', request_type="PATCH", expected_code=200, post_data=post_data) with db_session: - tags = knowledge_db.get_objects(subject=infohash, predicate=ResourceType.TAG) + tags = tribler_db.get_objects(subject=infohash, predicate=ResourceType.TAG) assert tags == ["abc"] -async def test_modify_tags_no_community(knowledge_db, endpoint): +async def test_modify_tags_no_community(tribler_db, endpoint): endpoint.community = None infohash = 'a' * 20 endpoint.modify_statements(infohash, [tag_to_statement("abc"), tag_to_statement("def")]) with db_session: - tags = knowledge_db.get_objects(subject=infohash, predicate=ResourceType.TAG) + tags = tribler_db.get_objects(subject=infohash, predicate=ResourceType.TAG) assert len(tags) == 0 @@ -90,7 +90,7 @@ async def test_get_suggestions_invalid_infohash(rest_api): await do_request(rest_api, 'knowledge/3f3f/tag_suggestions', expected_code=400) -async def test_get_suggestions(rest_api, knowledge_db): +async def test_get_suggestions(rest_api, tribler_db): """ Test whether we can successfully fetch suggestions from content """ @@ -107,7 +107,7 @@ def _add_operation(op=Operation.ADD): operation = StatementOperation(subject_type=ResourceType.TORRENT, subject=infohash_str, predicate=ResourceType.TAG, object="test", operation=op, clock=0, creator_public_key=random_key.pub().key_to_bin()) - knowledge_db.add_operation(operation, b"") + tribler_db.add_operation(operation, b"") _add_operation(op=Operation.ADD) _add_operation(op=Operation.REMOVE) diff --git a/src/tribler/core/components/knowledge/rules/knowledge_rules_processor.py b/src/tribler/core/components/knowledge/rules/knowledge_rules_processor.py index a53683ab010..fde487ee5df 100644 --- a/src/tribler/core/components/knowledge/rules/knowledge_rules_processor.py +++ b/src/tribler/core/components/knowledge/rules/knowledge_rules_processor.py @@ -10,7 +10,7 @@ from pony.orm import db_session from tribler.core import notifications -from tribler.core.components.knowledge.db.knowledge_db import KnowledgeDatabase, ResourceType +from tribler.core.components.database.db.tribler_database import TriblerDatabase, ResourceType from tribler.core.components.knowledge.rules.rules_content_items import content_items_rules from tribler.core.components.knowledge.rules.rules_general_tags import general_rules from tribler.core.components.knowledge.rules.tag_rules_base import extract_only_valid_tags @@ -41,7 +41,7 @@ class KnowledgeRulesProcessor(TaskManager): # this value must be incremented in the case of new rules set has been applied version: int = 5 - def __init__(self, notifier: Notifier, db: KnowledgeDatabase, mds: MetadataStore, + def __init__(self, notifier: Notifier, db: TriblerDatabase, mds: MetadataStore, batch_size: int = DEFAULT_BATCH_SIZE, batch_interval: float = DEFAULT_BATCH_INTERVAL, queue_interval: float = DEFAULT_QUEUE_INTERVAL, queue_batch_size: float = DEFAULT_QUEUE_BATCH_SIZE, queue_max_size: int = DEFAULT_QUEUE_MAX_SIZE): diff --git a/src/tribler/core/components/knowledge/rules/tests/test_knowledge_rules_processor.py b/src/tribler/core/components/knowledge/rules/tests/test_knowledge_rules_processor.py index aa24e3d884d..4c12684b11a 100644 --- a/src/tribler/core/components/knowledge/rules/tests/test_knowledge_rules_processor.py +++ b/src/tribler/core/components/knowledge/rules/tests/test_knowledge_rules_processor.py @@ -5,7 +5,7 @@ from ipv8.keyvault.private.libnaclkey import LibNaCLSK from pony.orm import db_session -from tribler.core.components.knowledge.db.knowledge_db import KnowledgeDatabase, ResourceType +from tribler.core.components.database.db.tribler_database import TriblerDatabase, ResourceType from tribler.core.components.knowledge.rules.knowledge_rules_processor import KnowledgeRulesProcessor from tribler.core.components.metadata_store.db.serialization import REGULAR_TORRENT from tribler.core.components.metadata_store.db.store import MetadataStore @@ -20,8 +20,8 @@ @pytest.fixture async def tag_rules_processor(tmp_path: Path): mds = MetadataStore(db_filename=MEMORY_DB, channels_dir=tmp_path, my_key=LibNaCLSK()) - knowledge_db = KnowledgeDatabase(filename=':memory:') - processor = KnowledgeRulesProcessor(notifier=MagicMock(), db=knowledge_db, mds=mds, + db = TriblerDatabase(filename=':memory:') + processor = KnowledgeRulesProcessor(notifier=MagicMock(), db=db, mds=mds, batch_size=TEST_BATCH_SIZE, batch_interval=TEST_INTERVAL) yield processor diff --git a/src/tribler/core/components/knowledge/tests/test_knowledge_component.py b/src/tribler/core/components/knowledge/tests/test_knowledge_component.py index 046af11743e..d0568b100ac 100644 --- a/src/tribler/core/components/knowledge/tests/test_knowledge_component.py +++ b/src/tribler/core/components/knowledge/tests/test_knowledge_component.py @@ -1,3 +1,4 @@ +from tribler.core.components.database.database_component import DatabaseComponent from tribler.core.components.ipv8.ipv8_component import Ipv8Component from tribler.core.components.key.key_component import KeyComponent from tribler.core.components.knowledge.knowledge_component import KnowledgeComponent @@ -9,7 +10,7 @@ async def test_tag_component(tribler_config): - components = [MetadataStoreComponent(), KeyComponent(), Ipv8Component(), KnowledgeComponent()] + components = [DatabaseComponent(), MetadataStoreComponent(), KeyComponent(), Ipv8Component(), KnowledgeComponent()] async with Session(tribler_config, components) as session: comp = session.get_instance(KnowledgeComponent) assert comp.started_event.is_set() and not comp.failed diff --git a/src/tribler/core/components/metadata_store/db/store.py b/src/tribler/core/components/metadata_store/db/store.py index 7f8e79e696b..d03672dd3a3 100644 --- a/src/tribler/core/components/metadata_store/db/store.py +++ b/src/tribler/core/components/metadata_store/db/store.py @@ -48,7 +48,7 @@ from tribler.core.exceptions import InvalidSignatureException from tribler.core.utilities.notifier import Notifier from tribler.core.utilities.path_util import Path -from tribler.core.utilities.pony_utils import TriblerDatabase, get_max, get_or_create, run_threaded +from tribler.core.utilities.pony_utils import TrackedDatabase, get_max, get_or_create, run_threaded from tribler.core.utilities.search_utils import torrent_rank from tribler.core.utilities.unicode import hexlify from tribler.core.utilities.utilities import MEMORY_DB @@ -161,7 +161,7 @@ def __init__( # We have to dynamically define/init ORM-managed entities here to be able to support # multiple sessions in Tribler. ORM-managed classes are bound to the database instance # at definition. - self.db = TriblerDatabase() + self.db = TrackedDatabase() # This attribute is internally called by Pony on startup, though pylint cannot detect it # with the static analysis. diff --git a/src/tribler/core/components/metadata_store/remote_query_community/remote_query_community.py b/src/tribler/core/components/metadata_store/remote_query_community/remote_query_community.py index b0179e00a7b..71d2e459875 100644 --- a/src/tribler/core/components/metadata_store/remote_query_community/remote_query_community.py +++ b/src/tribler/core/components/metadata_store/remote_query_community/remote_query_community.py @@ -16,7 +16,7 @@ from tribler.core.components.ipv8.eva.result import TransferResult from tribler.core.components.ipv8.tribler_community import TriblerCommunity from tribler.core.components.knowledge.community.knowledge_validator import is_valid_resource -from tribler.core.components.knowledge.db.knowledge_db import ResourceType +from tribler.core.components.database.db.tribler_database import ResourceType from tribler.core.components.metadata_store.db.orm_bindings.channel_metadata import LZ4_EMPTY_ARCHIVE, entries_to_chunk from tribler.core.components.metadata_store.db.serialization import CHANNEL_TORRENT, COLLECTION_NODE, REGULAR_TORRENT from tribler.core.components.metadata_store.db.store import MetadataStore @@ -136,13 +136,13 @@ class RemoteQueryCommunity(TriblerCommunity): def __init__(self, my_peer, endpoint, network, rqc_settings: RemoteQueryCommunitySettings = None, metadata_store=None, - knowledge_db=None, + tribler_db=None, **kwargs): super().__init__(my_peer, endpoint, network=network, **kwargs) self.rqc_settings = rqc_settings self.mds: MetadataStore = metadata_store - self.knowledge_db = knowledge_db + self.tribler_db = tribler_db # This object stores requests for "select" queries that we sent to other hosts. # We keep track of peers we actually requested for data so people can't randomly push spam at us. # Also, this keeps track of hosts we responded to. There is a possibility that @@ -214,11 +214,11 @@ async def process_rpc_query(self, sanitized_parameters: Dict[str, Any]) -> List: :raises ValueError: if no JSON could be decoded. :raises pony.orm.dbapiprovider.OperationalError: if an illegal query was performed. """ - if self.knowledge_db: + if self.tribler_db: # tags should be extracted because `get_entries_threaded` doesn't expect them as a parameter tags = sanitized_parameters.pop('tags', None) - infohash_set = await run_threaded(self.knowledge_db.instance, self.search_for_tags, tags) + infohash_set = await run_threaded(self.tribler_db.instance, self.search_for_tags, tags) if infohash_set: sanitized_parameters['infohash_set'] = {bytes.fromhex(s) for s in infohash_set} @@ -226,10 +226,10 @@ async def process_rpc_query(self, sanitized_parameters: Dict[str, Any]) -> List: @db_session def search_for_tags(self, tags: Optional[List[str]]) -> Optional[Set[str]]: - if not tags or not self.knowledge_db: + if not tags or not self.tribler_db: return None valid_tags = {tag for tag in tags if is_valid_resource(tag)} - result = self.knowledge_db.get_subjects_intersection( + result = self.tribler_db.get_subjects_intersection( subjects_type=ResourceType.TORRENT, objects=valid_tags, predicate=ResourceType.TAG, diff --git a/src/tribler/core/components/metadata_store/remote_query_community/tests/test_remote_search_by_tags.py b/src/tribler/core/components/metadata_store/remote_query_community/tests/test_remote_search_by_tags.py index a15ff9a35f2..6fcf1134af4 100644 --- a/src/tribler/core/components/metadata_store/remote_query_community/tests/test_remote_search_by_tags.py +++ b/src/tribler/core/components/metadata_store/remote_query_community/tests/test_remote_search_by_tags.py @@ -5,8 +5,8 @@ from ipv8.test.base import TestBase from pony.orm import db_session -from tribler.core.components.knowledge.db.knowledge_db import KnowledgeDatabase, ResourceType, SHOW_THRESHOLD -from tribler.core.components.knowledge.db.tests.test_knowledge_db import Resource, TestTagDB +from tribler.core.components.database.db.tribler_database import TriblerDatabase, ResourceType, SHOW_THRESHOLD +from tribler.core.components.database.db.tests.test_tribler_database import Resource, TestTagDB from tribler.core.components.metadata_store.db.orm_bindings.channel_node import NEW from tribler.core.components.metadata_store.db.store import MetadataStore from tribler.core.components.metadata_store.remote_query_community.remote_query_community import RemoteQueryCommunity @@ -26,14 +26,14 @@ class TestRemoteSearchByTags(TestBase): def setUp(self): super().setUp() self.metadata_store = None - self.knowledge_db = None + self.tribler_db = None self.initialize(BasicRemoteQueryCommunity, 1) async def tearDown(self): if self.metadata_store: self.metadata_store.shutdown() - if self.knowledge_db: - self.knowledge_db.shutdown() + if self.tribler_db: + self.tribler_db.shutdown() await super().tearDown() @@ -44,10 +44,10 @@ def create_node(self, *args, **kwargs): default_eccrypto.generate_key("curve25519"), disable_sync=True, ) - self.knowledge_db = KnowledgeDatabase(str(Path(self.temporary_directory()) / "tags.db")) + self.tribler_db = TriblerDatabase(str(Path(self.temporary_directory()) / "tags.db")) kwargs['metadata_store'] = self.metadata_store - kwargs['knowledge_db'] = self.knowledge_db + kwargs['tribler_db'] = self.tribler_db kwargs['rqc_settings'] = RemoteQueryCommunitySettings() return super().create_node(*args, **kwargs) @@ -55,12 +55,12 @@ def create_node(self, *args, **kwargs): def rqc(self) -> RemoteQueryCommunity: return self.overlay(0) - @patch.object(RemoteQueryCommunity, 'knowledge_db', new=PropertyMock(return_value=None), create=True) + @patch.object(RemoteQueryCommunity, 'tribler_db', new=PropertyMock(return_value=None), create=True) def test_search_for_tags_no_db(self): - # test that in case of missed `knowledge_db`, function `search_for_tags` returns None + # test that in case of missed `tribler_db`, function `search_for_tags` returns None assert self.rqc.search_for_tags(tags=['tag']) is None - @patch.object(KnowledgeDatabase, 'get_subjects_intersection') + @patch.object(TriblerDatabase, 'get_subjects_intersection') def test_search_for_tags_only_valid_tags(self, mocked_get_subjects_intersection: Mock): # test that function `search_for_tags` uses only valid tags self.rqc.search_for_tags(tags=['invalid_tag' * 50, 'valid_tag']) @@ -92,7 +92,7 @@ async def test_process_rpc_query_with_tags(self): @db_session def fill_tags_database(): TestTagDB.add_operation_set( - self.rqc.knowledge_db, + self.rqc.tribler_db, { hexlify(infohash1): [ Resource(predicate=ResourceType.TAG, name='tag1', count=SHOW_THRESHOLD), diff --git a/src/tribler/core/components/metadata_store/restapi/metadata_endpoint_base.py b/src/tribler/core/components/metadata_store/restapi/metadata_endpoint_base.py index 6a9d2879b41..76db9923c7a 100644 --- a/src/tribler/core/components/metadata_store/restapi/metadata_endpoint_base.py +++ b/src/tribler/core/components/metadata_store/restapi/metadata_endpoint_base.py @@ -3,7 +3,7 @@ from pony.orm import db_session -from tribler.core.components.knowledge.db.knowledge_db import KnowledgeDatabase, ResourceType +from tribler.core.components.database.db.tribler_database import TriblerDatabase, ResourceType from tribler.core.components.knowledge.rules.knowledge_rules_processor import KnowledgeRulesProcessor from tribler.core.components.metadata_store.category_filter.family_filter import default_xxx_filter from tribler.core.components.metadata_store.db.serialization import CHANNEL_TORRENT, COLLECTION_NODE, REGULAR_TORRENT @@ -41,11 +41,11 @@ class MetadataEndpointBase(RESTEndpoint): - def __init__(self, metadata_store: MetadataStore, *args, knowledge_db: KnowledgeDatabase = None, + def __init__(self, metadata_store: MetadataStore, *args, tribler_db: TriblerDatabase = None, tag_rules_processor: KnowledgeRulesProcessor = None, **kwargs): super().__init__(*args, **kwargs) self.mds = metadata_store - self.knowledge_db: Optional[KnowledgeDatabase] = knowledge_db + self.tribler_db: Optional[TriblerDatabase] = tribler_db self.tag_rules_processor: Optional[KnowledgeRulesProcessor] = tag_rules_processor @classmethod @@ -76,13 +76,13 @@ def sanitize_parameters(cls, parameters): @db_session def add_statements_to_metadata_list(self, contents_list, hide_xxx=False): - if self.knowledge_db is None: + if self.tribler_db is None: self._logger.error(f'Cannot add statements to metadata list: ' - f'knowledge_db is not set in {self.__class__.__name__}') + f'tribler_db is not set in {self.__class__.__name__}') return for torrent in contents_list: if torrent['type'] == REGULAR_TORRENT: - raw_statements = self.knowledge_db.get_statements( + raw_statements = self.tribler_db.get_statements( subject_type=ResourceType.TORRENT, subject=torrent["infohash"] ) diff --git a/src/tribler/core/components/metadata_store/restapi/search_endpoint.py b/src/tribler/core/components/metadata_store/restapi/search_endpoint.py index 26dc941e265..787e38b51e5 100644 --- a/src/tribler/core/components/metadata_store/restapi/search_endpoint.py +++ b/src/tribler/core/components/metadata_store/restapi/search_endpoint.py @@ -8,7 +8,7 @@ from marshmallow.fields import Integer, String from pony.orm import db_session -from tribler.core.components.knowledge.db.knowledge_db import ResourceType +from tribler.core.components.database.db.tribler_database import ResourceType from tribler.core.components.metadata_store.db.serialization import SNIPPET from tribler.core.components.metadata_store.db.store import MetadataStore from tribler.core.components.metadata_store.restapi.metadata_endpoint import MetadataEndpointBase @@ -49,9 +49,9 @@ def build_snippets(self, search_results: List[Dict]) -> List[Dict]: content_to_torrents: Dict[str, list] = defaultdict(list) for search_result in search_results: with db_session: - content_items: List[str] = self.knowledge_db.get_objects(subject_type=ResourceType.TORRENT, - subject=search_result["infohash"], - predicate=ResourceType.CONTENT_ITEM) + content_items: List[str] = self.tribler_db.get_objects(subject_type=ResourceType.TORRENT, + subject=search_result["infohash"], + predicate=ResourceType.CONTENT_ITEM) if content_items: for content_id in content_items: content_to_torrents[content_id].append(search_result) @@ -150,10 +150,10 @@ def search_db(): try: with db_session: if tags: - infohash_set = self.knowledge_db.get_subjects_intersection(subjects_type=ResourceType.TORRENT, - objects=set(tags), - predicate=ResourceType.TAG, - case_sensitive=False) + infohash_set = self.tribler_db.get_subjects_intersection(subjects_type=ResourceType.TORRENT, + objects=set(tags), + predicate=ResourceType.TAG, + case_sensitive=False) if infohash_set: sanitized['infohash_set'] = {bytes.fromhex(s) for s in infohash_set} diff --git a/src/tribler/core/components/metadata_store/restapi/tests/conftest.py b/src/tribler/core/components/metadata_store/restapi/tests/conftest.py index 56d7d5fbc6c..897ed1c6786 100644 --- a/src/tribler/core/components/metadata_store/restapi/tests/conftest.py +++ b/src/tribler/core/components/metadata_store/restapi/tests/conftest.py @@ -51,7 +51,7 @@ def add_fake_torrents_channels(metadata_store): @pytest.fixture -def my_channel(metadata_store, knowledge_db): +def my_channel(metadata_store, tribler_db): """ Generate a channel with some torrents. Also add a few (random) tags to these torrents. """ @@ -62,11 +62,11 @@ def my_channel(metadata_store, knowledge_db): _ = metadata_store.TorrentMetadata( origin_id=chan.id_, title='torrent%d' % ind, status=NEW, infohash=infohash ) - tag_torrent(infohash, knowledge_db) + tag_torrent(infohash, tribler_db) for ind in range(5, 9): infohash = random_infohash() _ = metadata_store.TorrentMetadata(origin_id=chan.id_, title='torrent%d' % ind, infohash=infohash) - tag_torrent(infohash, knowledge_db) + tag_torrent(infohash, tribler_db) chan2 = metadata_store.ChannelMetadata.create_channel('test2', 'test2') for ind in range(5): @@ -74,11 +74,11 @@ def my_channel(metadata_store, knowledge_db): _ = metadata_store.TorrentMetadata( origin_id=chan2.id_, title='torrentB%d' % ind, status=NEW, infohash=infohash ) - tag_torrent(infohash, knowledge_db) + tag_torrent(infohash, tribler_db) for ind in range(5, 9): infohash = random_infohash() _ = metadata_store.TorrentMetadata( origin_id=chan2.id_, title='torrentB%d' % ind, infohash=random_infohash() ) - tag_torrent(infohash, knowledge_db) + tag_torrent(infohash, tribler_db) return chan diff --git a/src/tribler/core/components/metadata_store/restapi/tests/test_channels_endpoint.py b/src/tribler/core/components/metadata_store/restapi/tests/test_channels_endpoint.py index 4b16b29a9b0..a9776b19450 100644 --- a/src/tribler/core/components/metadata_store/restapi/tests/test_channels_endpoint.py +++ b/src/tribler/core/components/metadata_store/restapi/tests/test_channels_endpoint.py @@ -8,7 +8,7 @@ from pony.orm import db_session from tribler.core.components.gigachannel.community.gigachannel_community import NoChannelSourcesException -from tribler.core.components.knowledge.db.knowledge_db import ResourceType +from tribler.core.components.database.db.tribler_database import ResourceType from tribler.core.components.libtorrent.torrentdef import TorrentDef from tribler.core.components.metadata_store.category_filter.family_filter import default_xxx_filter from tribler.core.components.metadata_store.db.orm_bindings.channel_node import NEW @@ -32,7 +32,7 @@ # pylint: disable=unused-argument, redefined-outer-name @pytest.fixture -def endpoint(mock_dlmgr, metadata_store, knowledge_db): +def endpoint(mock_dlmgr, metadata_store, tribler_db): def return_exc(*args, **kwargs): raise RequestTimeoutException @@ -43,7 +43,7 @@ def return_exc(*args, **kwargs): Mock(), Mock(remote_select_channel_contents=return_exc), metadata_store, - knowledge_db=knowledge_db + tribler_db=tribler_db ) @@ -636,7 +636,7 @@ async def test_get_my_channel_tags(metadata_store, mock_dlmgr_get_download, my_c assert len(item["statements"]) >= 2 -async def test_get_my_channel_tags_xxx(metadata_store, knowledge_db, mock_dlmgr_get_download, my_channel, +async def test_get_my_channel_tags_xxx(metadata_store, tribler_db, mock_dlmgr_get_download, my_channel, rest_api): # pylint: disable=redefined-outer-name """ Test whether XXX tags are correctly filtered @@ -649,7 +649,7 @@ async def test_get_my_channel_tags_xxx(metadata_store, knowledge_db, mock_dlmgr_ # Add a few tags to our new torrent tags = ["totally safe", "wrongterm", "wRonGtErM", "a wrongterm b"] - tag_torrent(infohash, knowledge_db, tags=tags) + tag_torrent(infohash, tribler_db, tags=tags) json_dict = await do_request( rest_api, diff --git a/src/tribler/core/components/metadata_store/restapi/tests/test_search_endpoint.py b/src/tribler/core/components/metadata_store/restapi/tests/test_search_endpoint.py index 3c6a91a1fc9..03eaedfdbcb 100644 --- a/src/tribler/core/components/metadata_store/restapi/tests/test_search_endpoint.py +++ b/src/tribler/core/components/metadata_store/restapi/tests/test_search_endpoint.py @@ -6,7 +6,7 @@ import pytest from pony.orm import db_session -from tribler.core.components.knowledge.db.knowledge_db import KnowledgeDatabase +from tribler.core.components.database.db.tribler_database import TriblerDatabase from tribler.core.components.metadata_store.db.serialization import REGULAR_TORRENT, SNIPPET from tribler.core.components.metadata_store.restapi.search_endpoint import SearchEndpoint from tribler.core.components.restapi.rest.base_api_test import do_request @@ -30,8 +30,8 @@ def needle_in_haystack_mds(metadata_store): @pytest.fixture -def endpoint(needle_in_haystack_mds, knowledge_db): - return SearchEndpoint(needle_in_haystack_mds, knowledge_db=knowledge_db) +def endpoint(needle_in_haystack_mds, tribler_db): + return SearchEndpoint(needle_in_haystack_mds, tribler_db=tribler_db) async def test_search_wrong_mdtype(rest_api): @@ -72,7 +72,7 @@ def mocked_get_subjects_intersection(*_, objects: Set[str], **__): return None return {hexlify(os.urandom(20))} - with patch.object(KnowledgeDatabase, 'get_subjects_intersection', wraps=mocked_get_subjects_intersection): + with patch.object(TriblerDatabase, 'get_subjects_intersection', wraps=mocked_get_subjects_intersection): parsed = await do_request(rest_api, 'search?txt_filter=needle&tags=real_tag', expected_code=200) assert len(parsed["results"]) == 0 @@ -159,7 +159,7 @@ async def test_search_with_space(rest_api, metadata_store): assert results == {'abc.def', 'abc def'} # but not 'abcxyz def' -async def test_single_snippet_in_search(rest_api, metadata_store, knowledge_db): +async def test_single_snippet_in_search(rest_api, metadata_store, tribler_db): """ Test building a simple snippet of a single item. """ @@ -170,7 +170,7 @@ async def test_single_snippet_in_search(rest_api, metadata_store, knowledge_db): def mocked_get_subjects(*_, **__) -> List[str]: return ["Abc"] - with patch.object(KnowledgeDatabase, 'get_objects', wraps=mocked_get_subjects): + with patch.object(TriblerDatabase, 'get_objects', wraps=mocked_get_subjects): s1 = to_fts_query("abc") results = await do_request(rest_api, f'search?txt_filter={s1}', expected_code=200) @@ -182,7 +182,7 @@ def mocked_get_subjects(*_, **__) -> List[str]: assert snippet["torrents_in_snippet"][0]["infohash"] == hexlify(content_ih) -async def test_multiple_snippets_in_search(rest_api, metadata_store, knowledge_db): +async def test_multiple_snippets_in_search(rest_api, metadata_store, tribler_db): """ Test two snippets with two torrents in each snippet. """ @@ -200,7 +200,7 @@ def mocked_get_objects(*__, subject=None, **___) -> List[str]: return ["Content item 2"] return [] - with patch.object(KnowledgeDatabase, 'get_objects', wraps=mocked_get_objects): + with patch.object(TriblerDatabase, 'get_objects', wraps=mocked_get_objects): s1 = to_fts_query("abc") parsed = await do_request(rest_api, f'search?txt_filter={s1}', expected_code=200) results = parsed["results"] diff --git a/src/tribler/core/components/metadata_store/tests/test_metadata_store_component.py b/src/tribler/core/components/metadata_store/tests/test_metadata_store_component.py index fdde51ffe50..3aaae2bc544 100644 --- a/src/tribler/core/components/metadata_store/tests/test_metadata_store_component.py +++ b/src/tribler/core/components/metadata_store/tests/test_metadata_store_component.py @@ -1,3 +1,4 @@ +from tribler.core.components.database.database_component import DatabaseComponent from tribler.core.components.ipv8.ipv8_component import Ipv8Component from tribler.core.components.key.key_component import KeyComponent from tribler.core.components.knowledge.knowledge_component import KnowledgeComponent @@ -9,7 +10,7 @@ async def test_metadata_store_component(tribler_config): - components = [KnowledgeComponent(), Ipv8Component(), KeyComponent(), MetadataStoreComponent()] + components = [DatabaseComponent(), KnowledgeComponent(), Ipv8Component(), KeyComponent(), MetadataStoreComponent()] async with Session(tribler_config, components) as session: comp = session.get_instance(MetadataStoreComponent) assert comp.started_event.is_set() and not comp.failed diff --git a/src/tribler/core/components/metadata_store/utils.py b/src/tribler/core/components/metadata_store/utils.py index 53636ae45eb..7cdc6ba8cde 100644 --- a/src/tribler/core/components/metadata_store/utils.py +++ b/src/tribler/core/components/metadata_store/utils.py @@ -6,7 +6,7 @@ from pony.orm import db_session from tribler.core.components.knowledge.community.knowledge_payload import StatementOperation -from tribler.core.components.knowledge.db.knowledge_db import KnowledgeDatabase, Operation, ResourceType +from tribler.core.components.database.db.tribler_database import TriblerDatabase, Operation, ResourceType from tribler.core.components.knowledge.knowledge_constants import MIN_RESOURCE_LENGTH from tribler.core.components.metadata_store.db.store import MetadataStore from tribler.core.tests.tools.common import PNG_FILE @@ -108,7 +108,7 @@ def generate_collection(metadata_store, tags_db, parent): @db_session -def generate_channel(metadata_store: MetadataStore, tags_db: KnowledgeDatabase, title=None, subscribed=False): +def generate_channel(metadata_store: MetadataStore, tags_db: TriblerDatabase, title=None, subscribed=False): # Remember and restore the original key orig_key = metadata_store.ChannelNode._my_key diff --git a/src/tribler/core/components/popularity/tests/test_popularity_component.py b/src/tribler/core/components/popularity/tests/test_popularity_component.py index b267d3f4eb2..64f7309699d 100644 --- a/src/tribler/core/components/popularity/tests/test_popularity_component.py +++ b/src/tribler/core/components/popularity/tests/test_popularity_component.py @@ -1,3 +1,4 @@ +from tribler.core.components.database.database_component import DatabaseComponent from tribler.core.components.ipv8.ipv8_component import Ipv8Component from tribler.core.components.key.key_component import KeyComponent from tribler.core.components.knowledge.knowledge_component import KnowledgeComponent @@ -13,8 +14,9 @@ async def test_popularity_component(tribler_config): - components = [SocksServersComponent(), LibtorrentComponent(), TorrentCheckerComponent(), KnowledgeComponent(), - MetadataStoreComponent(), KeyComponent(), Ipv8Component(), PopularityComponent()] + components = [DatabaseComponent(), SocksServersComponent(), LibtorrentComponent(), TorrentCheckerComponent(), + KnowledgeComponent(), MetadataStoreComponent(), KeyComponent(), Ipv8Component(), + PopularityComponent()] async with Session(tribler_config, components) as session: comp = session.get_instance(PopularityComponent) assert comp.community diff --git a/src/tribler/core/components/restapi/restapi_component.py b/src/tribler/core/components/restapi/restapi_component.py index e1e6073a0a0..603f468a5ba 100644 --- a/src/tribler/core/components/restapi/restapi_component.py +++ b/src/tribler/core/components/restapi/restapi_component.py @@ -6,6 +6,7 @@ from tribler.core.components.bandwidth_accounting.bandwidth_accounting_component import BandwidthAccountingComponent from tribler.core.components.bandwidth_accounting.restapi.bandwidth_endpoint import BandwidthEndpoint from tribler.core.components.component import Component +from tribler.core.components.database.database_component import DatabaseComponent from tribler.core.components.exceptions import NoneComponent from tribler.core.components.gigachannel.gigachannel_component import GigaChannelComponent from tribler.core.components.gigachannel_manager.gigachannel_manager_component import GigachannelManagerComponent @@ -84,6 +85,7 @@ async def run(self): tunnel_component = await self.maybe_component(TunnelsComponent) torrent_checker_component = await self.maybe_component(TorrentCheckerComponent) gigachannel_manager_component = await self.maybe_component(GigachannelManagerComponent) + db_component = await self.maybe_component(DatabaseComponent) public_key = key_component.primary_key.key.pk if not isinstance(key_component, NoneComponent) else b'' self._events_endpoint = EventsEndpoint(notifier, public_key=hexlify(public_key)) @@ -109,15 +111,15 @@ async def run(self): self.maybe_add(LibTorrentEndpoint, libtorrent_component.download_manager) self.maybe_add(TorrentInfoEndpoint, libtorrent_component.download_manager) self.maybe_add(MetadataEndpoint, torrent_checker, metadata_store_component.mds, - knowledge_db=knowledge_component.knowledge_db, + tribler_db=db_component.db, tag_rules_processor=knowledge_component.rules_processor) self.maybe_add(ChannelsEndpoint, libtorrent_component.download_manager, gigachannel_manager, gigachannel_component.community, metadata_store_component.mds, - knowledge_db=knowledge_component.knowledge_db, + tribler_db=db_component.db, tag_rules_processor=knowledge_component.rules_processor) - self.maybe_add(SearchEndpoint, metadata_store_component.mds, knowledge_db=knowledge_component.knowledge_db) + self.maybe_add(SearchEndpoint, metadata_store_component.mds, tribler_db=db_component.db) self.maybe_add(RemoteQueryEndpoint, gigachannel_component.community, metadata_store_component.mds) - self.maybe_add(KnowledgeEndpoint, db=knowledge_component.knowledge_db, community=knowledge_component.community) + self.maybe_add(KnowledgeEndpoint, db=db_component.db, community=knowledge_component.community) if not isinstance(ipv8_component, NoneComponent): ipv8_root_endpoint = IPV8RootEndpoint() diff --git a/src/tribler/core/components/restapi/tests/test_restapi_component.py b/src/tribler/core/components/restapi/tests/test_restapi_component.py index d9da02ce6c8..fe9f8fc8f8a 100644 --- a/src/tribler/core/components/restapi/tests/test_restapi_component.py +++ b/src/tribler/core/components/restapi/tests/test_restapi_component.py @@ -3,6 +3,7 @@ import pytest from tribler.core.components.bandwidth_accounting.bandwidth_accounting_component import BandwidthAccountingComponent +from tribler.core.components.database.database_component import DatabaseComponent from tribler.core.components.exceptions import NoneComponent from tribler.core.components.gigachannel.gigachannel_component import GigaChannelComponent from tribler.core.components.ipv8.ipv8_component import Ipv8Component @@ -22,7 +23,7 @@ async def test_rest_component(tribler_config): components = [KeyComponent(), RESTComponent(), Ipv8Component(), LibtorrentComponent(), ResourceMonitorComponent(), BandwidthAccountingComponent(), GigaChannelComponent(), KnowledgeComponent(), SocksServersComponent(), - MetadataStoreComponent()] + MetadataStoreComponent(), DatabaseComponent()] async with Session(tribler_config, components) as session: # Test REST component starts normally comp = session.get_instance(RESTComponent) diff --git a/src/tribler/core/components/torrent_checker/tests/test_torrent_checker_component.py b/src/tribler/core/components/torrent_checker/tests/test_torrent_checker_component.py index 28ef90bb1de..b66880a811a 100644 --- a/src/tribler/core/components/torrent_checker/tests/test_torrent_checker_component.py +++ b/src/tribler/core/components/torrent_checker/tests/test_torrent_checker_component.py @@ -1,3 +1,4 @@ +from tribler.core.components.database.database_component import DatabaseComponent from tribler.core.components.ipv8.ipv8_component import Ipv8Component from tribler.core.components.key.key_component import KeyComponent from tribler.core.components.knowledge.knowledge_component import KnowledgeComponent @@ -10,7 +11,7 @@ # pylint: disable=protected-access async def test_torrent_checker_component(tribler_config): - components = [SocksServersComponent(), LibtorrentComponent(), KeyComponent(), + components = [DatabaseComponent(), SocksServersComponent(), LibtorrentComponent(), KeyComponent(), Ipv8Component(), KnowledgeComponent(), MetadataStoreComponent(), TorrentCheckerComponent()] async with Session(tribler_config, components) as session: comp = session.get_instance(TorrentCheckerComponent) diff --git a/src/tribler/core/start_core.py b/src/tribler/core/start_core.py index 746a1cc577e..ee488684366 100644 --- a/src/tribler/core/start_core.py +++ b/src/tribler/core/start_core.py @@ -14,6 +14,7 @@ ) from tribler.core.components.bandwidth_accounting.bandwidth_accounting_component import BandwidthAccountingComponent from tribler.core.components.component import Component +from tribler.core.components.database.database_component import DatabaseComponent from tribler.core.components.gigachannel.gigachannel_component import GigaChannelComponent from tribler.core.components.gigachannel_manager.gigachannel_manager_component import GigachannelManagerComponent from tribler.core.components.gui_process_watcher.gui_process_watcher import GuiProcessWatcher @@ -55,6 +56,7 @@ def components_gen(config: TriblerConfig): """ yield ReporterComponent() yield GuiProcessWatcherComponent() + yield DatabaseComponent() yield RESTComponent() if config.chant.enabled or config.torrent_checking.enabled: yield MetadataStoreComponent() diff --git a/src/tribler/core/upgrade/knowledge_to_triblerdb/migration.py b/src/tribler/core/upgrade/knowledge_to_triblerdb/migration.py new file mode 100644 index 00000000000..282bc5f2d54 --- /dev/null +++ b/src/tribler/core/upgrade/knowledge_to_triblerdb/migration.py @@ -0,0 +1,32 @@ +import logging +import shutil + +from tribler.core.utilities.path_util import Path +from tribler.core.utilities.simpledefs import STATEDIR_DB_DIR + + +class MigrationKnowledgeToTriblerDB: + def __init__(self, state_dir: Path): + self.logger = logging.getLogger(self.__class__.__name__) + self.state_dir = state_dir + + self.knowledge_db_path = self.state_dir / STATEDIR_DB_DIR / 'knowledge.db' + self.tribler_db_path = self.state_dir / STATEDIR_DB_DIR / 'tribler.db' + + self.logger.info(f'Knowledge DB path: {self.knowledge_db_path}') + self.logger.info(f'Tribler DB path: {self.tribler_db_path}') + + def run(self) -> bool: + if not self.knowledge_db_path.exists(): + self.logger.info("Knowledge DB doesn't exist. Stop procedure.") + return False + + try: + # move self.knowledge_db_path to self.tribler_db_path + shutil.move(str(self.knowledge_db_path), str(self.tribler_db_path)) + except OSError as e: + self.logger.error(f"Failed to move the file: {e}") + return False + + self.logger.info("File moved successfully.") + return True diff --git a/src/tribler/core/upgrade/knowledge_to_triblerdb/tests/test_knowledge_to_tribler_db_migration.py b/src/tribler/core/upgrade/knowledge_to_triblerdb/tests/test_knowledge_to_tribler_db_migration.py new file mode 100644 index 00000000000..61a524026d7 --- /dev/null +++ b/src/tribler/core/upgrade/knowledge_to_triblerdb/tests/test_knowledge_to_tribler_db_migration.py @@ -0,0 +1,56 @@ +from unittest.mock import Mock, patch + +import pytest + +from tribler.core.upgrade.knowledge_to_triblerdb.migration import MigrationKnowledgeToTriblerDB +from tribler.core.upgrade.tags_to_knowledge.previous_dbs.knowledge_db import KnowledgeDatabase +from tribler.core.utilities.path_util import Path +from tribler.core.utilities.simpledefs import STATEDIR_DB_DIR + + +# pylint: disable=redefined-outer-name +@pytest.fixture +def migration(tmp_path: Path): + db_dir = tmp_path / STATEDIR_DB_DIR + db_dir.mkdir() + migration = MigrationKnowledgeToTriblerDB(tmp_path) + return migration + + +def test_no_knowledge_db(migration: MigrationKnowledgeToTriblerDB): + # test that in the case of missed `knowledge.db`, migration.run() returns False + assert not migration.run() + assert not migration.knowledge_db_path.exists() + assert not migration.tribler_db_path.exists() + + +def test_move_file(migration: MigrationKnowledgeToTriblerDB): + # Test that the migration moves the `knowledge.db` to `tribler.db` + + # create DB file + KnowledgeDatabase(str(migration.knowledge_db_path)).shutdown() + + assert migration.knowledge_db_path.exists() + assert not migration.tribler_db_path.exists() + + # run migration + assert migration.run() + assert not migration.knowledge_db_path.exists() + assert migration.tribler_db_path.exists() + + +@patch('tribler.core.upgrade.knowledge_to_triblerdb.migration.shutil.move', Mock(side_effect=FileNotFoundError)) +def test_exception(migration: MigrationKnowledgeToTriblerDB): + # Test that the migration doesn't move the `knowledge.db` to `tribler.db` after unsuccessful migration procedure. + + # create DB file + KnowledgeDatabase(str(migration.knowledge_db_path)).shutdown() + + assert migration.knowledge_db_path.exists() + assert not migration.tribler_db_path.exists() + + # run migration + assert not migration.run() + + assert migration.knowledge_db_path.exists() + assert not migration.tribler_db_path.exists() diff --git a/src/tribler/core/upgrade/tags_to_knowledge/migration.py b/src/tribler/core/upgrade/tags_to_knowledge/migration.py index 6712a5efbf1..199f8f6b6be 100644 --- a/src/tribler/core/upgrade/tags_to_knowledge/migration.py +++ b/src/tribler/core/upgrade/tags_to_knowledge/migration.py @@ -7,8 +7,8 @@ from pony.orm import db_session, select from tribler.core.components.knowledge.community.knowledge_payload import StatementOperation -from tribler.core.components.knowledge.db.knowledge_db import KnowledgeDatabase, ResourceType -from tribler.core.upgrade.tags_to_knowledge.tags_db import TagDatabase +from tribler.core.upgrade.tags_to_knowledge.previous_dbs.knowledge_db import KnowledgeDatabase, ResourceType +from tribler.core.upgrade.tags_to_knowledge.previous_dbs.tags_db import TagDatabase from tribler.core.utilities.path_util import Path from tribler.core.utilities.simpledefs import STATEDIR_DB_DIR from tribler.core.utilities.unicode import hexlify diff --git a/src/tribler/core/upgrade/tags_to_knowledge/previous_dbs/__init__.py b/src/tribler/core/upgrade/tags_to_knowledge/previous_dbs/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/tribler/core/components/knowledge/db/knowledge_db.py b/src/tribler/core/upgrade/tags_to_knowledge/previous_dbs/knowledge_db.py similarity index 99% rename from src/tribler/core/components/knowledge/db/knowledge_db.py rename to src/tribler/core/upgrade/tags_to_knowledge/previous_dbs/knowledge_db.py index 4950cd1a1ee..b593aed6c1f 100644 --- a/src/tribler/core/components/knowledge/db/knowledge_db.py +++ b/src/tribler/core/upgrade/tags_to_knowledge/previous_dbs/knowledge_db.py @@ -10,7 +10,7 @@ from pony.utils import between from tribler.core.components.knowledge.community.knowledge_payload import StatementOperation -from tribler.core.utilities.pony_utils import TriblerDatabase, get_or_create +from tribler.core.utilities.pony_utils import TrackedDatabase, get_or_create CLOCK_START_VALUE = 0 @@ -64,7 +64,7 @@ class SimpleStatement: class KnowledgeDatabase: def __init__(self, filename: Optional[str] = None, *, create_tables: bool = True, **generate_mapping_kwargs): - self.instance = TriblerDatabase() + self.instance = TrackedDatabase() self.define_binding(self.instance) self.instance.bind('sqlite', filename or ':memory:', create_db=True) generate_mapping_kwargs['create_tables'] = create_tables diff --git a/src/tribler/core/upgrade/tags_to_knowledge/tags_db.py b/src/tribler/core/upgrade/tags_to_knowledge/previous_dbs/tags_db.py similarity index 98% rename from src/tribler/core/upgrade/tags_to_knowledge/tags_db.py rename to src/tribler/core/upgrade/tags_to_knowledge/previous_dbs/tags_db.py index f9890118e60..3df9229202e 100644 --- a/src/tribler/core/upgrade/tags_to_knowledge/tags_db.py +++ b/src/tribler/core/upgrade/tags_to_knowledge/previous_dbs/tags_db.py @@ -3,12 +3,12 @@ from pony import orm -from tribler.core.utilities.pony_utils import TriblerDatabase, get_or_create +from tribler.core.utilities.pony_utils import TrackedDatabase, get_or_create class TagDatabase: def __init__(self, filename: Optional[str] = None, *, create_tables: bool = True, **generate_mapping_kwargs): - self.instance = TriblerDatabase() + self.instance = TrackedDatabase() self.define_binding(self.instance) self.instance.bind('sqlite', filename or ':memory:', create_db=True) generate_mapping_kwargs['create_tables'] = create_tables diff --git a/src/tribler/core/upgrade/tags_to_knowledge/tests/test_migration.py b/src/tribler/core/upgrade/tags_to_knowledge/tests/test_tags_to_knowledge_migration.py similarity index 96% rename from src/tribler/core/upgrade/tags_to_knowledge/tests/test_migration.py rename to src/tribler/core/upgrade/tags_to_knowledge/tests/test_tags_to_knowledge_migration.py index 47b79f00fe8..4eba0ab489a 100644 --- a/src/tribler/core/upgrade/tags_to_knowledge/tests/test_migration.py +++ b/src/tribler/core/upgrade/tags_to_knowledge/tests/test_tags_to_knowledge_migration.py @@ -7,9 +7,9 @@ from pony.orm import db_session from tribler.core.components.knowledge.community.knowledge_payload import StatementOperation -from tribler.core.components.knowledge.db.knowledge_db import KnowledgeDatabase +from tribler.core.upgrade.tags_to_knowledge.previous_dbs.knowledge_db import KnowledgeDatabase from tribler.core.upgrade.tags_to_knowledge.migration import MigrationTagsToKnowledge -from tribler.core.upgrade.tags_to_knowledge.tags_db import TagDatabase +from tribler.core.upgrade.tags_to_knowledge.previous_dbs.tags_db import TagDatabase from tribler.core.utilities.simpledefs import STATEDIR_DB_DIR diff --git a/src/tribler/core/upgrade/tests/test_upgrader.py b/src/tribler/core/upgrade/tests/test_upgrader.py index e6b1a369c92..baf2c337f6b 100644 --- a/src/tribler/core/upgrade/tests/test_upgrader.py +++ b/src/tribler/core/upgrade/tests/test_upgrader.py @@ -14,7 +14,7 @@ from tribler.core.components.metadata_store.db.store import CURRENT_DB_VERSION, MetadataStore from tribler.core.tests.tools.common import TESTS_DATA_DIR from tribler.core.upgrade.db8_to_db10 import calc_progress -from tribler.core.upgrade.tags_to_knowledge.tags_db import TagDatabase +from tribler.core.upgrade.tags_to_knowledge.previous_dbs.tags_db import TagDatabase from tribler.core.upgrade.upgrade import TriblerUpgrader, cleanup_noncompliant_channel_torrents from tribler.core.utilities.configparser import CallbackConfigParser from tribler.core.utilities.utilities import random_infohash diff --git a/src/tribler/core/upgrade/upgrade.py b/src/tribler/core/upgrade/upgrade.py index 3d490e1d537..682f9157836 100644 --- a/src/tribler/core/upgrade/upgrade.py +++ b/src/tribler/core/upgrade/upgrade.py @@ -19,8 +19,9 @@ ) from tribler.core.upgrade.config_converter import convert_config_to_tribler76 from tribler.core.upgrade.db8_to_db10 import PonyToPonyMigration, get_db_version +from tribler.core.upgrade.knowledge_to_triblerdb.migration import MigrationKnowledgeToTriblerDB from tribler.core.upgrade.tags_to_knowledge.migration import MigrationTagsToKnowledge -from tribler.core.upgrade.tags_to_knowledge.tags_db import TagDatabase +from tribler.core.upgrade.tags_to_knowledge.previous_dbs.tags_db import TagDatabase from tribler.core.utilities.configparser import CallbackConfigParser from tribler.core.utilities.path_util import Path from tribler.core.utilities.simpledefs import STATEDIR_CHANNELS_DIR, STATEDIR_DB_DIR @@ -104,6 +105,7 @@ def run(self): self.upgrade_tags_to_knowledge() self.remove_old_logs() self.upgrade_pony_db_14to15() + self.upgrade_knowledge_to_tribler_db() def remove_old_logs(self) -> Tuple[List[Path], List[Path]]: self._logger.info(f'Remove old logs') @@ -417,3 +419,8 @@ def update_status(self, status_text): self._logger.info(status_text) if self._update_status_callback: self._update_status_callback(status_text) + + def upgrade_knowledge_to_tribler_db(self): + self._logger.info('Upgrade knowledge to tribler.db') + migration = MigrationKnowledgeToTriblerDB(self.state_dir) + migration.run() diff --git a/src/tribler/core/utilities/pony_utils.py b/src/tribler/core/utilities/pony_utils.py index 1b1cd1ddcf9..1069635e9a9 100644 --- a/src/tribler/core/utilities/pony_utils.py +++ b/src/tribler/core/utilities/pony_utils.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) -databases_to_track: WeakSet[TriblerDatabase] = WeakSet() +databases_to_track: WeakSet[TrackedDatabase] = WeakSet() StatDict = Dict[Optional[str], core.QueryStat] @@ -292,7 +292,7 @@ def release_lock(self): db_session = TriblerDbSession() -class TriblerDatabase(Database): +class TrackedDatabase(Database): # If a developer what to track the slow execution of the database, he should create an instance of TriblerDatabase # instead of the usual pony.orm.Database. diff --git a/src/tribler/core/utilities/tests/test_pony_utils.py b/src/tribler/core/utilities/tests/test_pony_utils.py index 4d93e8c2dfa..7a2b540c94b 100644 --- a/src/tribler/core/utilities/tests/test_pony_utils.py +++ b/src/tribler/core/utilities/tests/test_pony_utils.py @@ -44,7 +44,7 @@ def test_patched_db_session(tmp_path): # The test is added for better coverage of TriblerDbSession methods with patch('pony.orm.dbproviders.sqlite.provider_cls', pony_utils.PatchedSQLiteProvider): - db = pony_utils.TriblerDatabase() + db = pony_utils.TrackedDatabase() db.bind('sqlite', str(tmp_path / 'db.sqlite'), create_db=True) class Entity1(db.Entity): @@ -82,7 +82,7 @@ def test_patched_db_session_default_duration_threshold(tmp_path): # if no duration_threshold was explicitly specified for db_session with patch('pony.orm.dbproviders.sqlite.provider_cls', pony_utils.PatchedSQLiteProvider): - db = pony_utils.TriblerDatabase() + db = pony_utils.TrackedDatabase() db.bind('sqlite', str(tmp_path / 'db.sqlite'), create_db=True) class Entity1(db.Entity): diff --git a/src/tribler/gui/dialogs/editmetadatadialog.py b/src/tribler/gui/dialogs/editmetadatadialog.py index 5eae9cf6837..14d1d68c874 100644 --- a/src/tribler/gui/dialogs/editmetadatadialog.py +++ b/src/tribler/gui/dialogs/editmetadatadialog.py @@ -4,7 +4,7 @@ from PyQt5.QtCore import QModelIndex, QPoint, Qt, pyqtSignal from PyQt5.QtWidgets import QComboBox, QSizePolicy, QWidget -from tribler.core.components.knowledge.db.knowledge_db import ResourceType +from tribler.core.components.database.db.tribler_database import ResourceType from tribler.core.components.knowledge.knowledge_constants import MAX_RESOURCE_LENGTH, MIN_RESOURCE_LENGTH from tribler.gui.defs import TAG_HORIZONTAL_MARGIN from tribler.gui.dialogs.dialogcontainer import DialogContainer diff --git a/src/tribler/gui/tests/test_gui.py b/src/tribler/gui/tests/test_gui.py index 5fcf30050b8..9efbb346e52 100644 --- a/src/tribler/gui/tests/test_gui.py +++ b/src/tribler/gui/tests/test_gui.py @@ -10,7 +10,7 @@ from PyQt5.QtWidgets import QListWidget, QTableView, QTextEdit, QTreeWidget, QTreeWidgetItem import tribler.gui -from tribler.core.components.knowledge.db.knowledge_db import ResourceType +from tribler.core.components.database.db.tribler_database import ResourceType from tribler.core.components.knowledge.knowledge_constants import MIN_RESOURCE_LENGTH from tribler.core.components.reporter.reported_error import ReportedError from tribler.core.sentry_reporter.sentry_reporter import SentryReporter diff --git a/src/tribler/gui/utilities.py b/src/tribler/gui/utilities.py index bbaeff9da55..bcf4b4f1198 100644 --- a/src/tribler/gui/utilities.py +++ b/src/tribler/gui/utilities.py @@ -27,7 +27,7 @@ from PyQt5.QtWidgets import QApplication, QMessageBox import tribler.gui -from tribler.core.components.knowledge.db.knowledge_db import ResourceType +from tribler.core.components.database.db.tribler_database import ResourceType from tribler.gui.defs import HEALTH_DEAD, HEALTH_GOOD, HEALTH_MOOT, HEALTH_UNCHECKED # fmt: off diff --git a/src/tribler/gui/widgets/tablecontentdelegate.py b/src/tribler/gui/widgets/tablecontentdelegate.py index 90ea1fab6f1..41d37ffdba8 100644 --- a/src/tribler/gui/widgets/tablecontentdelegate.py +++ b/src/tribler/gui/widgets/tablecontentdelegate.py @@ -6,7 +6,7 @@ from PyQt5.QtWidgets import QApplication, QComboBox, QStyle, QStyleOptionViewItem, QStyledItemDelegate, QToolTip from psutil import LINUX -from tribler.core.components.knowledge.db.knowledge_db import ResourceType +from tribler.core.components.database.db.tribler_database import ResourceType from tribler.core.components.metadata_store.db.orm_bindings.channel_node import LEGACY_ENTRY from tribler.core.components.metadata_store.db.serialization import CHANNEL_TORRENT, COLLECTION_NODE, REGULAR_TORRENT, \ SNIPPET