diff --git a/src/tribler/core/components/database/db/layers/health_data_access_level.py b/src/tribler/core/components/database/db/layers/health_data_access_level.py new file mode 100644 index 00000000000..a1590732538 --- /dev/null +++ b/src/tribler/core/components/database/db/layers/health_data_access_level.py @@ -0,0 +1,57 @@ +import datetime +import logging + +from pony import orm + +from tribler.core.components.torrent_checker.torrent_checker.dataclasses import HealthInfo +from tribler.core.upgrade.tags_to_knowledge.previous_dbs.knowledge_db import ResourceType +from tribler.core.utilities.pony_utils import get_or_create + + +class HealthDataAccessLayer: + def __init__(self): + self.logger = logging.getLogger(self.__class__.__name__) + self.db = None + self.instance = None + + def apply(self, db): + self.db = db.instance + self.instance = db.instance + + return self.define_binding(self.instance) + + @staticmethod + def define_binding(db): + class HealthInfo(db.Entity): + id = orm.PrimaryKey(int, auto=True) + + torrent = orm.Required(lambda: db.Resource, index=True) + + seeders = orm.Required(int, default=0) + leechers = orm.Required(int, default=0) + source = orm.Required(int, default=0) # Source enum + last_check = orm.Required(datetime.datetime, default=datetime.datetime.utcnow) + + return HealthInfo, + + def add_torrent_health(self, torrent_health: HealthInfo): + torrent = get_or_create( + self.instance.Resource, + name=torrent_health.infohash_hex, + type=ResourceType.TORRENT + ) + + health_info_entity = get_or_create( + self.instance.HealthInfo, + torrent=torrent + ) + + health_info_entity.seeders = torrent_health.seeders + health_info_entity.leechers = torrent_health.leechers + health_info_entity.source = torrent_health.source + health_info_entity.last_check = datetime.datetime.utcfromtimestamp(torrent_health.last_check) + + def get_torrent_health(self, infohash: str): + if torrent := self.instance.Resource.get(name=infohash, type=ResourceType.TORRENT): + return self.instance.HealthInfo.get(torrent=torrent) + return None diff --git a/src/tribler/core/components/database/db/layers/knowledge_data_access_layer.py b/src/tribler/core/components/database/db/layers/knowledge_data_access_layer.py index f5812c719a3..32d7a37c4fd 100644 --- a/src/tribler/core/components/database/db/layers/knowledge_data_access_layer.py +++ b/src/tribler/core/components/database/db/layers/knowledge_data_access_layer.py @@ -2,7 +2,7 @@ import logging from dataclasses import dataclass from enum import IntEnum -from typing import Any, Callable, Iterator, List, Optional, Set +from typing import Callable, Iterator, List, Optional, Set from pony import orm from pony.orm import raw_sql @@ -11,7 +11,7 @@ from tribler.core.components.knowledge.community.knowledge_payload import StatementOperation from tribler.core.components.torrent_checker.torrent_checker.dataclasses import HealthInfo -from tribler.core.utilities.pony_utils import TrackedDatabase, get_or_create +from tribler.core.utilities.pony_utils import get_or_create CLOCK_START_VALUE = 0 @@ -64,13 +64,16 @@ class SimpleStatement: class KnowledgeDataAccessLayer: - def __init__(self, filename: Optional[str] = None, *, create_tables: bool = True, **generate_mapping_kwargs): - self.instance = TrackedDatabase() - self.define_binding(self.instance) - self.instance.bind('sqlite', filename or ':memory:', create_db=True) - generate_mapping_kwargs['create_tables'] = create_tables - self.instance.generate_mapping(**generate_mapping_kwargs) + def __init__(self): self.logger = logging.getLogger(self.__class__.__name__) + self.db = None + self.instance = None + + def apply(self, db): + self.db = db + self.instance = db.instance + + return self.define_binding(self.instance) @staticmethod def define_binding(db): @@ -124,7 +127,7 @@ class Resource(db.Entity): subject_statements = orm.Set(lambda: Statement, reverse="subject") object_statements = orm.Set(lambda: Statement, reverse="object") - health_info = orm.Set(lambda: HealthInfo, reverse="torrent") + health_info = orm.Set(lambda: db.HealthInfo, reverse="torrent") orm.composite_key(name, type) @@ -142,19 +145,7 @@ class StatementOp(db.Entity): orm.composite_key(statement, peer) - class Misc(db.Entity): # pylint: disable=unused-variable - name = orm.PrimaryKey(str) - value = orm.Optional(str) - - class HealthInfo(db.Entity): - id = orm.PrimaryKey(int, auto=True) - - torrent = orm.Required(lambda: Resource, index=True) - - seeders = orm.Required(int, default=0) - leechers = orm.Required(int, default=0) - source = orm.Required(int, default=0) # Source enum - last_check = orm.Required(datetime.datetime, default=datetime.datetime.utcnow) + return Peer, Statement, Resource, StatementOp def add_operation(self, operation: StatementOperation, signature: bytes, is_local_peer: bool = False, is_auto_generated: bool = False, counter_increment: int = 1) -> bool: @@ -224,28 +215,6 @@ def add_auto_generated_operation(self, subject_type: ResourceType, subject: str, return self.add_operation(operation, signature=b'', is_local_peer=False, is_auto_generated=True, counter_increment=SHOW_THRESHOLD) - def add_torrent_health(self, torrent_health: HealthInfo): - torrent = get_or_create( - self.instance.Resource, - name=torrent_health.infohash_hex, - type=ResourceType.TORRENT - ) - - health_info_entity = get_or_create( - self.instance.HealthInfo, - torrent=torrent - ) - - health_info_entity.seeders = torrent_health.seeders - health_info_entity.leechers = torrent_health.leechers - health_info_entity.source = torrent_health.source - health_info_entity.last_check = datetime.datetime.utcfromtimestamp(torrent_health.last_check) - - def get_torrent_health(self, infohash: str): - if torrent := self.instance.Resource.get(name=infohash, type=ResourceType.TORRENT): - return self.instance.HealthInfo.get(torrent=torrent) - return None - @staticmethod def _show_condition(s): """This function determines show condition for the statement""" @@ -451,8 +420,7 @@ def get_operations_for_gossip(self, count: int = 10) -> Set[Entity]: count=count ) - def shutdown(self) -> None: - self.instance.disconnect() + def _get_random_operations_by_condition(self, condition: Callable[[Entity], bool], count: int = 5, attempts: int = 100) -> Set[Entity]: @@ -480,11 +448,3 @@ def _get_random_operations_by_condition(self, condition: Callable[[Entity], bool operations.add(operation) return operations - - def get_misc(self, key: str, default: Optional[str] = None) -> Optional[str]: - data = self.instance.Misc.get(name=key) - return data.value if data else default - - def set_misc(self, key: str, value: Any): - key_value = get_or_create(self.instance.Misc, name=key) - key_value.value = str(value) diff --git a/src/tribler/core/components/database/db/layers/tests/test_health_data_access_level.py b/src/tribler/core/components/database/db/layers/tests/test_health_data_access_level.py new file mode 100644 index 00000000000..1844c34db57 --- /dev/null +++ b/src/tribler/core/components/database/db/layers/tests/test_health_data_access_level.py @@ -0,0 +1,26 @@ +import time + +from tribler.core.components.database.db.tests.test_tribler_database import TestTriblerDatabase +from tribler.core.components.torrent_checker.torrent_checker.dataclasses import HealthInfo, Source +from tribler.core.utilities.pony_utils import db_session + + +# pylint: disable=protected-access +class TestHealthAccessLayer(TestTriblerDatabase): + + @db_session + def test_add_torrent_health(self): + """ Test that add_torrent_health works as expected""" + health_info = HealthInfo( + infohash=b'0' * 20, + seeders=10, + leechers=20, + last_check=int(time.time()), + self_checked=True, + source=Source.POPULARITY_COMMUNITY + ) + + self.db.health.add_torrent_health(health_info) + + assert self.db.health.get_torrent_health(health_info.infohash_hex) # add fields validation + assert not self.db.health.get_torrent_health('missed hash') diff --git a/src/tribler/core/components/database/db/layers/tests/knowledge_data_access_layer.py b/src/tribler/core/components/database/db/layers/tests/test_knowledge_data_access_layer.py similarity index 81% rename from src/tribler/core/components/database/db/layers/tests/knowledge_data_access_layer.py rename to src/tribler/core/components/database/db/layers/tests/test_knowledge_data_access_layer.py index 043f7339bfb..4f05bef153a 100644 --- a/src/tribler/core/components/database/db/layers/tests/knowledge_data_access_layer.py +++ b/src/tribler/core/components/database/db/layers/tests/test_knowledge_data_access_layer.py @@ -1,18 +1,18 @@ -import time from types import SimpleNamespace from unittest.mock import Mock, patch from pony.orm import commit, db_session -from tribler.core.components.database.db.tests.test_tribler_database_base import Resource, TestTagDBBase -from tribler.core.components.database.db.tribler_database import Operation, PUBLIC_KEY_FOR_AUTO_GENERATED_OPERATIONS, \ - ResourceType, SHOW_THRESHOLD, SimpleStatement, TriblerDatabase -from tribler.core.components.torrent_checker.torrent_checker.dataclasses import HealthInfo, Source +from tribler.core.components.database.db.layers.knowledge_data_access_layer import KnowledgeDataAccessLayer, \ + Operation, PUBLIC_KEY_FOR_AUTO_GENERATED_OPERATIONS, ResourceType, SHOW_THRESHOLD, SimpleStatement +from tribler.core.components.database.db.layers.tests.test_knowledge_data_access_layer_base import Resource, \ + TestKnowledgeAccessLayerBase +from tribler.core.components.database.db.tribler_database import TriblerDatabase from tribler.core.utilities.pony_utils import TrackedDatabase, get_or_create # pylint: disable=protected-access -class TestTagDB(TestTagDBBase): +class TestKnowledgeAccessLayer(TestKnowledgeAccessLayerBase): @patch.object(TrackedDatabase, 'generate_mapping') def test_constructor_create_tables_true(self, mocked_generate_mapping: Mock): TriblerDatabase(':memory:') @@ -158,7 +158,7 @@ def _get_statement(t: ResourceType): @db_session def test_add_auto_generated_tag(self): - self.db.add_auto_generated_operation( + self.db.knowledge.add_auto_generated_operation( subject_type=ResourceType.TORRENT, subject='infohash', predicate=ResourceType.TAG, @@ -178,11 +178,11 @@ def test_double_add_auto_generated_tag(self): 'predicate': ResourceType.TAG, 'obj': 'tag' } - self.db.add_auto_generated_operation(**kwargs) - self.db.add_auto_generated_operation(**kwargs) + self.db.knowledge.add_auto_generated_operation(**kwargs) + self.db.knowledge.add_auto_generated_operation(**kwargs) - assert len(self.db.instance.Statement.select()) == 1 - assert self.db.instance.Statement.get().added_count == SHOW_THRESHOLD + assert len(self.db.Statement.select()) == 1 + assert self.db.Statement.get().added_count == SHOW_THRESHOLD @db_session def test_multiple_tags(self): @@ -228,9 +228,9 @@ def test_get_objects_added(self): } ) - assert not self.db.get_objects(subject='missed infohash', predicate=ResourceType.TAG) - assert self.db.get_objects(subject='infohash1', predicate=ResourceType.TAG) == ['tag3', 'tag2'] - assert self.db.get_objects(subject='infohash1', predicate=ResourceType.CONTRIBUTOR) == ['Contributor'] + assert not self.db.knowledge.get_objects(subject='missed infohash', predicate=ResourceType.TAG) + assert self.db.knowledge.get_objects(subject='infohash1', predicate=ResourceType.TAG) == ['tag3', 'tag2'] + assert self.db.knowledge.get_objects(subject='infohash1', predicate=ResourceType.CONTRIBUTOR) == ['Contributor'] @db_session def test_get_objects_removed(self): @@ -245,9 +245,9 @@ def test_get_objects_removed(self): ) self.add_operation(self.db, subject='infohash1', predicate=ResourceType.TAG, obj='tag2', peer=b'4', - operation=Operation.REMOVE) + operation=Operation.REMOVE) - assert self.db.get_objects(subject='infohash1', predicate=ResourceType.TAG) == ['tag1'] + assert self.db.knowledge.get_objects(subject='infohash1', predicate=ResourceType.TAG) == ['tag1'] @db_session def test_get_objects_case_insensitive(self): @@ -271,18 +271,18 @@ def test_get_objects_case_insensitive(self): ) all_torrents = ['torrent', 'Torrent', 'TORRENT'] - assert self.db.get_objects(subject='ubuntu', predicate=torrent, case_sensitive=False) == all_torrents - assert self.db.get_objects(subject='Ubuntu', predicate=torrent, case_sensitive=False) == all_torrents + assert self.db.knowledge.get_objects(subject='ubuntu', predicate=torrent, case_sensitive=False) == all_torrents + assert self.db.knowledge.get_objects(subject='Ubuntu', predicate=torrent, case_sensitive=False) == all_torrents - assert self.db.get_objects(subject='ubuntu', predicate=torrent, case_sensitive=True) == ['torrent'] - assert self.db.get_objects(subject='Ubuntu', predicate=torrent, case_sensitive=True) == ['Torrent'] + assert self.db.knowledge.get_objects(subject='ubuntu', predicate=torrent, case_sensitive=True) == ['torrent'] + assert self.db.knowledge.get_objects(subject='Ubuntu', predicate=torrent, case_sensitive=True) == ['Torrent'] all_ubuntu = ['ubuntu', 'Ubuntu', 'UBUNTU'] - assert self.db.get_subjects(obj='torrent', predicate=torrent, case_sensitive=False) == all_ubuntu - assert self.db.get_subjects(obj='Torrent', predicate=torrent, case_sensitive=False) == all_ubuntu + assert self.db.knowledge.get_subjects(obj='torrent', predicate=torrent, case_sensitive=False) == all_ubuntu + assert self.db.knowledge.get_subjects(obj='Torrent', predicate=torrent, case_sensitive=False) == all_ubuntu - assert self.db.get_subjects(obj='torrent', predicate=torrent, case_sensitive=True) == ['ubuntu'] - assert self.db.get_subjects(obj='Torrent', predicate=torrent, case_sensitive=True) == ['Ubuntu'] + assert self.db.knowledge.get_subjects(obj='torrent', predicate=torrent, case_sensitive=True) == ['ubuntu'] + assert self.db.knowledge.get_subjects(obj='Torrent', predicate=torrent, case_sensitive=True) == ['Ubuntu'] @db_session def test_show_local_resources(self): @@ -292,7 +292,7 @@ def test_show_local_resources(self): operation=Operation.REMOVE) self.add_operation(self.db, ResourceType.TORRENT, 'infohash1', ResourceType.TAG, 'tag1', b'peer2', operation=Operation.REMOVE) - assert not self.db.get_objects(subject='infohash1', predicate=ResourceType.TAG) + assert not self.db.knowledge.get_objects(subject='infohash1', predicate=ResourceType.TAG) # test local add self.add_operation(self.db, ResourceType.TORRENT, 'infohash1', ResourceType.TAG, 'tag1', b'peer3', @@ -301,8 +301,8 @@ def test_show_local_resources(self): self.add_operation(self.db, ResourceType.TORRENT, 'infohash1', ResourceType.CONTRIBUTOR, 'contributor', b'peer3', operation=Operation.ADD, is_local_peer=True) - assert self.db.get_objects(subject='infohash1', predicate=ResourceType.TAG) == ['tag1'] - assert self.db.get_objects(subject='infohash1', predicate=ResourceType.CONTRIBUTOR) == ['contributor'] + assert self.db.knowledge.get_objects(subject='infohash1', predicate=ResourceType.TAG) == ['tag1'] + assert self.db.knowledge.get_objects(subject='infohash1', predicate=ResourceType.CONTRIBUTOR) == ['contributor'] @db_session def test_hide_local_tags(self): @@ -310,13 +310,13 @@ def test_hide_local_tags(self): # No matter of other peers opinions, locally removed tag should be not visible. self.add_operation(self.db, ResourceType.TORRENT, 'infohash1', ResourceType.TAG, 'tag1', b'peer1') self.add_operation(self.db, ResourceType.TORRENT, 'infohash1', ResourceType.TAG, 'tag1', b'peer2') - assert self.db.get_objects(subject='infohash1', predicate=ResourceType.TAG) == ['tag1'] + assert self.db.knowledge.get_objects(subject='infohash1', predicate=ResourceType.TAG) == ['tag1'] # test local remove self.add_operation(self.db, ResourceType.TORRENT, 'infohash1', ResourceType.TAG, 'tag1', b'peer3', operation=Operation.REMOVE, is_local_peer=True) - assert self.db.get_objects(subject='infohash1', predicate=ResourceType.TAG) == [] + assert self.db.knowledge.get_objects(subject='infohash1', predicate=ResourceType.TAG) == [] @db_session def test_suggestions(self): @@ -327,30 +327,31 @@ def test_suggestions(self): self.add_operation(self.db, subject='subject', predicate=ResourceType.CONTRIBUTOR, obj='contributor', peer=b'2') - assert self.db.get_suggestions(subject='subject', - predicate=ResourceType.TAG) == [] # This tag now has enough support + assert self.db.knowledge.get_suggestions(subject='subject', + predicate=ResourceType.TAG) == [] # This tag now has enough support self.add_operation(self.db, subject='subject', predicate=ResourceType.TAG, obj='tag1', peer=b'3', operation=Operation.REMOVE) # score:1 self.add_operation(self.db, subject='subject', predicate=ResourceType.TAG, obj='tag1', peer=b'4', operation=Operation.REMOVE) # score:0 - assert self.db.get_suggestions(subject='subject', predicate=ResourceType.TAG) == ["tag1"] + assert self.db.knowledge.get_suggestions(subject='subject', predicate=ResourceType.TAG) == ["tag1"] self.add_operation(self.db, subject='subject', predicate=ResourceType.TAG, obj='tag1', peer=b'5', operation=Operation.REMOVE) # score:-1 self.add_operation(self.db, subject='subject', predicate=ResourceType.TAG, obj='tag1', peer=b'6', operation=Operation.REMOVE) # score:-2 - assert not self.db.get_suggestions(subject='infohash', predicate=ResourceType.TAG) # below the threshold + assert not self.db.knowledge.get_suggestions(subject='infohash', + predicate=ResourceType.TAG) # below the threshold @db_session def test_get_clock_of_operation(self): operation = self.create_operation() - assert self.db.get_clock(operation) == 0 + assert self.db.knowledge.get_clock(operation) == 0 self.add_operation(self.db, subject=operation.subject, predicate=operation.predicate, obj=operation.object, peer=operation.creator_public_key, clock=1) - assert self.db.get_clock(operation) == 1 + assert self.db.knowledge.get_clock(operation) == 1 @db_session def test_get_tags_operations_for_gossip(self): @@ -367,7 +368,7 @@ def test_get_tags_operations_for_gossip(self): } ) - operations = self.db.get_operations_for_gossip(count=2) + operations = self.db.knowledge.get_operations_for_gossip(count=2) assert len(operations) == 2 assert all(not o.auto_generated for o in operations) @@ -390,7 +391,7 @@ def test_get_subjects_intersection_threshold(self): } ) - actual = self.db.get_subjects_intersection( + actual = self.db.knowledge.get_subjects_intersection( subjects_type=ResourceType.TORRENT, objects={'tag1'}, predicate=ResourceType.TAG @@ -431,7 +432,7 @@ def test_get_subjects_intersection(self): # no results def _results(objects, predicate=ResourceType.TAG, case_sensitive=True): - results = self.db.get_subjects_intersection( + results = self.db.knowledge.get_subjects_intersection( subjects_type=ResourceType.TORRENT, objects=objects, predicate=predicate, @@ -454,9 +455,9 @@ def _results(objects, predicate=ResourceType.TAG, case_sensitive=True): @db_session def test_show_condition(self): - assert TriblerDatabase._show_condition(SimpleNamespace(local_operation=Operation.ADD)) - assert TriblerDatabase._show_condition(SimpleNamespace(local_operation=None, score=SHOW_THRESHOLD)) - assert not TriblerDatabase._show_condition(SimpleNamespace(local_operation=None, score=0)) + assert KnowledgeDataAccessLayer._show_condition(SimpleNamespace(local_operation=Operation.ADD)) + assert KnowledgeDataAccessLayer._show_condition(SimpleNamespace(local_operation=None, score=SHOW_THRESHOLD)) + assert not KnowledgeDataAccessLayer._show_condition(SimpleNamespace(local_operation=None, score=0)) @db_session def test_get_random_operations_by_condition_less_than_count(self): @@ -473,7 +474,7 @@ def test_get_random_operations_by_condition_less_than_count(self): ) # request 5 random operations - random_operations = self.db._get_random_operations_by_condition( + random_operations = self.db.knowledge._get_random_operations_by_condition( condition=lambda _: True, count=5, attempts=100 @@ -495,7 +496,7 @@ def test_get_random_operations_by_condition_greater_than_count(self): ) # request 5 random operations - random_operations = self.db._get_random_operations_by_condition( + random_operations = self.db.knowledge._get_random_operations_by_condition( condition=lambda _: True, count=5, attempts=100 @@ -520,7 +521,7 @@ def test_get_random_tag_operations_by_condition(self): ) # request 5 normal tags - random_operations = self.db._get_random_operations_by_condition( + random_operations = self.db.knowledge._get_random_operations_by_condition( condition=lambda so: not so.auto_generated, count=5, attempts=100 @@ -546,7 +547,7 @@ def test_get_random_tag_operations_by_condition_no_results(self): ) # request 5 normal tags - random_operations = self.db._get_random_operations_by_condition( + random_operations = self.db.knowledge._get_random_operations_by_condition( condition=lambda so: not so.auto_generated, count=5, attempts=100 @@ -575,15 +576,16 @@ def test_get_subjects(self): } ) - actual = self.db.get_subjects(subject_type=ResourceType.TORRENT, predicate=ResourceType.CONTENT_ITEM, - obj='missed') + actual = self.db.knowledge.get_subjects(subject_type=ResourceType.TORRENT, predicate=ResourceType.CONTENT_ITEM, + obj='missed') assert actual == [] - actual = self.db.get_subjects(subject_type=ResourceType.TORRENT, predicate=ResourceType.CONTENT_ITEM, - obj='ubuntu') + actual = self.db.knowledge.get_subjects(subject_type=ResourceType.TORRENT, predicate=ResourceType.CONTENT_ITEM, + obj='ubuntu') assert actual == ['infohash1', 'infohash2'] - actual = self.db.get_subjects(subject_type=ResourceType.TORRENT, predicate=ResourceType.TAG, obj='linux') + actual = self.db.knowledge.get_subjects(subject_type=ResourceType.TORRENT, predicate=ResourceType.TAG, + obj='linux') assert actual == ['infohash1', 'infohash2', 'infohash3'] @db_session @@ -611,12 +613,12 @@ def test_get_statements(self): SimpleStatement(subject_type=ResourceType.TORRENT, subject='infohash1', predicate=ResourceType.TYPE, object='linux') ] - assert self.db.get_statements(subject='infohash1') == expected + assert self.db.knowledge.get_statements(subject='infohash1') == expected expected.append( SimpleStatement(subject_type=ResourceType.TORRENT, subject='INFOHASH1', predicate=ResourceType.TYPE, object='case_insensitive')) - assert self.db.get_statements(subject='infohash1', case_sensitive=False) == expected + assert self.db.knowledge.get_statements(subject='infohash1', case_sensitive=False) == expected @db_session def test_various_queries(self): @@ -642,7 +644,7 @@ def test_various_queries(self): # queries def _objects(subject_type=None, subject='', predicate=None): - return set(self.db.get_objects(subject_type=subject_type, subject=subject, predicate=predicate)) + return set(self.db.knowledge.get_objects(subject_type=subject_type, subject=subject, predicate=predicate)) assert _objects(subject='infohash1') == {'ubuntu', 'linux', 'creator'} assert _objects(subject_type=ResourceType.TORRENT) == {'ubuntu', 'linux', 'debian'} @@ -652,40 +654,8 @@ def _objects(subject_type=None, subject='', predicate=None): assert actual == {'linux'} def _subjects(subject_type=None, obj='', predicate=None): - return set(self.db.get_subjects(subject_type=subject_type, predicate=predicate, obj=obj)) + return set(self.db.knowledge.get_subjects(subject_type=subject_type, predicate=predicate, obj=obj)) assert _subjects(obj='linux') == {'infohash1', 'infohash2', 'infohash3'} assert _subjects(predicate=ResourceType.TAG, obj='linux') == {'infohash3'} assert _subjects(predicate=ResourceType.CONTENT_ITEM) == {'infohash1', 'infohash2'} - - @db_session - def test_non_existent_misc(self): - """Test that get_misc returns proper values""" - # None if the key does not exist - assert not self.db.get_misc(key='non existent') - - # A value if the key does exist - assert self.db.get_misc(key='non existent', default=42) == 42 - - @db_session - def test_set_misc(self): - """Test that set_misc works as expected""" - self.db.set_misc(key='key', value='value') - assert self.db.get_misc(key='key') == 'value' - - @db_session - def test_add_torrent_health(self): - """ Test that add_torrent_health works as expected""" - health_info = HealthInfo( - infohash=b'0' * 20, - seeders=10, - leechers=20, - last_check=int(time.time()), - self_checked=True, - source=Source.POPULARITY_COMMUNITY - ) - - self.db.add_torrent_health(health_info) - - assert self.db.get_torrent_health(health_info.infohash_hex) # add fields validation - assert not self.db.get_torrent_health('missed hash') diff --git a/src/tribler/core/components/database/db/tests/test_tribler_database_base.py b/src/tribler/core/components/database/db/layers/tests/test_knowledge_data_access_layer_base.py similarity index 52% rename from src/tribler/core/components/database/db/tests/test_tribler_database_base.py rename to src/tribler/core/components/database/db/layers/tests/test_knowledge_data_access_layer_base.py index c0fde6087b1..87e9be99ac8 100644 --- a/src/tribler/core/components/database/db/tests/test_tribler_database_base.py +++ b/src/tribler/core/components/database/db/layers/tests/test_knowledge_data_access_layer_base.py @@ -1,11 +1,12 @@ from dataclasses import dataclass from itertools import count -from ipv8.test.base import TestBase -from pony.orm import commit, db_session +from pony.orm import commit -from tribler.core.components.database.db.tribler_database import Operation, ResourceType, SHOW_THRESHOLD, \ - TriblerDatabase +from tribler.core.components.database.db.layers.knowledge_data_access_layer import Operation, ResourceType, \ + SHOW_THRESHOLD +from tribler.core.components.database.db.tests.test_tribler_database import TestTriblerDatabase +from tribler.core.components.database.db.tribler_database import TriblerDatabase from tribler.core.components.knowledge.community.knowledge_payload import StatementOperation from tribler.core.utilities.pony_utils import get_or_create @@ -20,37 +21,12 @@ class Resource: auto_generated: bool = False -class TestTagDBBase(TestBase): - def setUp(self): - super().setUp() - self.db = TriblerDatabase() - - async def tearDown(self): - if self._outcome.errors: - self.dump_db() - - await super().tearDown() - - @db_session - def dump_db(self): - print('\nPeer:') - self.db.instance.Peer.select().show() - print('\nResource:') - self.db.instance.Resource.select().show() - print('\nStatement') - self.db.instance.Statement.select().show() - print('\nStatementOp') - self.db.instance.StatementOp.select().show() - print('\nMisc') - self.db.instance.Misc.select().show() - print('\nHealthInfo') - self.db.instance.HealthInfo.select().show() - +class TestKnowledgeAccessLayerBase(TestTriblerDatabase): def create_statement(self, subject='subject', subject_type: ResourceType = ResourceType.TORRENT, predicate: ResourceType = ResourceType.TAG, obj='object'): - subj = get_or_create(self.db.instance.Resource, name=subject, type=subject_type) - obj = get_or_create(self.db.instance.Resource, name=obj, type=predicate) - statement = get_or_create(self.db.instance.Statement, subject=subj, object=obj) + subj = get_or_create(self.db.Resource, name=subject, type=subject_type) + obj = get_or_create(self.db.Resource, name=obj, type=predicate) + statement = get_or_create(self.db.Statement, subject=subj, object=obj) return statement @@ -61,21 +37,22 @@ def create_operation(subject_type: ResourceType = ResourceType.TORRENT, subject= operation=operation, clock=clock, creator_public_key=peer) @staticmethod - def add_operation(tag_db: TriblerDatabase, subject_type: ResourceType = ResourceType.TORRENT, + def add_operation(db: TriblerDatabase, subject_type: ResourceType = ResourceType.TORRENT, subject: str = 'infohash', predicate: ResourceType = ResourceType.TAG, obj: str = 'tag', peer=b'', operation: Operation = None, is_local_peer=False, clock=None, is_auto_generated=False, counter_increment: int = 1): operation = operation or Operation.ADD - operation = TestTagDBBase.create_operation(subject_type, subject, obj, peer, operation, predicate, clock) - operation.clock = clock or tag_db.get_clock(operation) + 1 - result = tag_db.add_operation(operation, signature=b'', is_local_peer=is_local_peer, - is_auto_generated=is_auto_generated, counter_increment=counter_increment) + operation = TestKnowledgeAccessLayerBase.create_operation(subject_type, subject, obj, peer, operation, + predicate, clock) + operation.clock = clock or db.knowledge.get_clock(operation) + 1 + result = db.knowledge.add_operation(operation, signature=b'', is_local_peer=is_local_peer, + is_auto_generated=is_auto_generated, counter_increment=counter_increment) commit() return result @staticmethod - def add_operation_set(tag_db: TriblerDatabase, dictionary): + def add_operation_set(db: TriblerDatabase, dictionary): index = count(0) def generate_n_peer_names(n): @@ -90,5 +67,5 @@ def generate_n_peer_names(n): for obj in objects: for peer in generate_n_peer_names(obj.count): # assume that for test purposes all subject by default could be `Predicate.TORRENT` - TestTagDBBase.add_operation(tag_db, subject_type, subject, obj.predicate, obj.name, peer, - is_auto_generated=obj.auto_generated) + TestKnowledgeAccessLayerBase.add_operation(db, subject_type, subject, obj.predicate, obj.name, peer, + is_auto_generated=obj.auto_generated) diff --git a/src/tribler/core/components/database/db/tests/test_tribler_database.py b/src/tribler/core/components/database/db/tests/test_tribler_database.py new file mode 100644 index 00000000000..c5d561adb75 --- /dev/null +++ b/src/tribler/core/components/database/db/tests/test_tribler_database.py @@ -0,0 +1,48 @@ +from ipv8.test.base import TestBase +from pony.orm import db_session + +from tribler.core.components.database.db.tribler_database import TriblerDatabase + + +# pylint: disable=protected-access + +class TestTriblerDatabase(TestBase): + def setUp(self): + super().setUp() + self.db = TriblerDatabase() + + async def tearDown(self): + if self._outcome.errors: + self.dump_db() + + await super().tearDown() + + @db_session + def dump_db(self): + print('\nPeer:') + self.db.instance.Peer.select().show() + print('\nResource:') + self.db.instance.Resource.select().show() + print('\nStatement') + self.db.instance.Statement.select().show() + print('\nStatementOp') + self.db.instance.StatementOp.select().show() + print('\nMisc') + self.db.instance.Misc.select().show() + print('\nHealthInfo') + self.db.instance.HealthInfo.select().show() + + @db_session + def test_set_misc(self): + """Test that set_misc works as expected""" + self.db.set_misc(key='key', value='value') + assert self.db.get_misc(key='key') == 'value' + + @db_session + def test_non_existent_misc(self): + """Test that get_misc returns proper values""" + # None if the key does not exist + assert not self.db.get_misc(key='non existent') + + # A value if the key does exist + assert self.db.get_misc(key='non existent', default=42) == 42 diff --git a/src/tribler/core/components/database/db/tribler_database.py b/src/tribler/core/components/database/db/tribler_database.py new file mode 100644 index 00000000000..6c90a8d23f5 --- /dev/null +++ b/src/tribler/core/components/database/db/tribler_database.py @@ -0,0 +1,46 @@ +import logging +from typing import Any, Optional + +from pony import orm + +from tribler.core.components.database.db.layers.health_data_access_level import HealthDataAccessLayer +from tribler.core.components.database.db.layers.knowledge_data_access_layer import KnowledgeDataAccessLayer +from tribler.core.utilities.pony_utils import TrackedDatabase, get_or_create + + +class TriblerDatabase: + def __init__(self, filename: Optional[str] = None, *, create_tables: bool = True, **generate_mapping_kwargs): + self.instance = TrackedDatabase() + + self.knowledge = KnowledgeDataAccessLayer() + self.health = HealthDataAccessLayer() + + self.Misc, = self.define_binding(self.instance) + self.Peer, self.Statement, self.Resource, self.StatementOp, = self.knowledge.apply(self) + self.HealthInfo, = self.health.apply(self) + + self.instance.bind('sqlite', filename or ':memory:', create_db=True) + generate_mapping_kwargs['create_tables'] = create_tables + self.instance.generate_mapping(**generate_mapping_kwargs) + self.logger = logging.getLogger(self.__class__.__name__) + + @staticmethod + def define_binding(db): + """ Define common bindings""" + + class Misc(db.Entity): # pylint: disable=unused-variable + name = orm.PrimaryKey(str) + value = orm.Optional(str) + + return Misc, + + def get_misc(self, key: str, default: Optional[str] = None) -> Optional[str]: + data = self.Misc.get(name=key) + return data.value if data else default + + def set_misc(self, key: str, value: Any): + key_value = get_or_create(self.Misc, name=key) + key_value.value = str(value) + + def shutdown(self) -> None: + self.instance.disconnect() diff --git a/src/tribler/core/components/knowledge/community/knowledge_community.py b/src/tribler/core/components/knowledge/community/knowledge_community.py index 6f276fc7962..c966b7f665d 100644 --- a/src/tribler/core/components/knowledge/community/knowledge_community.py +++ b/src/tribler/core/components/knowledge/community/knowledge_community.py @@ -71,7 +71,7 @@ def on_message(self, peer, raw: RawStatementOperationMessage): self.validate_operation(operation) with db_session(): - is_added = self.db.add_operation(operation, signature.signature) + is_added = self.db.knowledge.add_operation(operation, signature.signature) if is_added: s = f'+ operation added ({operation.object!r} "{operation.predicate}" {operation.subject!r})' self.logger.info(s) @@ -89,7 +89,7 @@ def on_request(self, peer, operation): self.logger.info(f'<- peer {peer.mid.hex()} requested {operations_count} operations') with db_session: - random_operations = self.db.get_operations_for_gossip(count=operations_count) + random_operations = self.db.knowledge.get_operations_for_gossip(count=operations_count) self.logger.debug(f'Response {len(random_operations)} operations') sent_operations = [] diff --git a/src/tribler/core/components/knowledge/community/knowledge_validator.py b/src/tribler/core/components/knowledge/community/knowledge_validator.py index df87c6e5e9d..584ad2a983f 100644 --- a/src/tribler/core/components/knowledge/community/knowledge_validator.py +++ b/src/tribler/core/components/knowledge/community/knowledge_validator.py @@ -1,4 +1,4 @@ -from tribler.core.components.database.db.tribler_database import Operation, ResourceType +from tribler.core.components.database.db.layers.knowledge_data_access_layer import Operation, ResourceType from tribler.core.components.knowledge.knowledge_constants import MAX_RESOURCE_LENGTH, MIN_RESOURCE_LENGTH diff --git a/src/tribler/core/components/knowledge/community/tests/test_knowledge_community.py b/src/tribler/core/components/knowledge/community/tests/test_knowledge_community.py index 4f8d94ebd66..9f4f88f6803 100644 --- a/src/tribler/core/components/knowledge/community/tests/test_knowledge_community.py +++ b/src/tribler/core/components/knowledge/community/tests/test_knowledge_community.py @@ -6,9 +6,10 @@ from ipv8.test.mocking.ipv8 import MockIPv8 from pony.orm import db_session +from tribler.core.components.database.db.layers.knowledge_data_access_layer import Operation, ResourceType from tribler.core.components.knowledge.community.knowledge_community import KnowledgeCommunity from tribler.core.components.knowledge.community.knowledge_payload import StatementOperation -from tribler.core.components.database.db.tribler_database import TriblerDatabase, Operation, ResourceType +from tribler.core.components.database.db.tribler_database import TriblerDatabase REQUEST_INTERVAL_FOR_RANDOM_OPERATIONS = 0.1 # in seconds @@ -30,7 +31,7 @@ def create_operation(self, subject='1' * 20, obj=''): operation = StatementOperation(subject_type=ResourceType.TORRENT, subject=subject, predicate=ResourceType.TAG, object=obj, operation=Operation.ADD, clock=0, creator_public_key=community.key.pub().key_to_bin()) - operation.clock = community.db.get_clock(operation) + 1 + operation.clock = community.db.knowledge.get_clock(operation) + 1 return operation @db_session @@ -47,11 +48,11 @@ def fill_db(self): if i >= 5: signature = b'1' * 64 - community.db.add_operation(message, signature) + community.db.knowledge.add_operation(message, signature) # a single entity should be cyrillic cyrillic_message = self.create_operation(subject='Контент', obj='Тэг') - community.db.add_operation(cyrillic_message, community.sign(cyrillic_message)) + community.db.knowledge.add_operation(cyrillic_message, community.sign(cyrillic_message)) # put them into the past for op in community.db.instance.StatementOp.select(): @@ -72,11 +73,11 @@ async def test_on_request_eat_exceptions(self): # ValueError should be eaten silently self.fill_db() # let's "break" the function that will be called on on_request() - self.overlay(0).db.get_operations_for_gossip = Mock(return_value=[MagicMock()]) + self.overlay(0).db.knowledge.get_operations_for_gossip = Mock(return_value=[MagicMock()]) # occurred exception should be ate by community silently await self.introduce_nodes() await self.deliver_messages(timeout=REQUEST_INTERVAL_FOR_RANDOM_OPERATIONS * 2) - self.overlay(0).db.get_operations_for_gossip.assert_called() + self.overlay(0).db.knowledge.get_operations_for_gossip.assert_called() async def test_no_peers(self): # Test that no error occurs in the community, in case there is no peers diff --git a/src/tribler/core/components/knowledge/community/tests/test_knowledge_validator.py b/src/tribler/core/components/knowledge/community/tests/test_knowledge_validator.py index b8663d9fee6..fade63a7e2d 100644 --- a/src/tribler/core/components/knowledge/community/tests/test_knowledge_validator.py +++ b/src/tribler/core/components/knowledge/community/tests/test_knowledge_validator.py @@ -1,8 +1,8 @@ import pytest +from tribler.core.components.database.db.layers.knowledge_data_access_layer import Operation, ResourceType from tribler.core.components.knowledge.community.knowledge_validator import is_valid_resource, validate_operation, \ validate_resource, validate_resource_type -from tribler.core.components.database.db.tribler_database import Operation, ResourceType VALID_TAGS = [ 'nl', diff --git a/src/tribler/core/components/knowledge/restapi/knowledge_endpoint.py b/src/tribler/core/components/knowledge/restapi/knowledge_endpoint.py index ef4f8c6ac99..d0f4438fdd3 100644 --- a/src/tribler/core/components/knowledge/restapi/knowledge_endpoint.py +++ b/src/tribler/core/components/knowledge/restapi/knowledge_endpoint.py @@ -7,10 +7,11 @@ from marshmallow.fields import Boolean, List, String from pony.orm import db_session +from tribler.core.components.database.db.layers.knowledge_data_access_layer import Operation, ResourceType +from tribler.core.components.database.db.tribler_database import TriblerDatabase from tribler.core.components.knowledge.community.knowledge_community import KnowledgeCommunity from tribler.core.components.knowledge.community.knowledge_payload import StatementOperation from tribler.core.components.knowledge.community.knowledge_validator import is_valid_resource -from tribler.core.components.database.db.tribler_database import TriblerDatabase, Operation, ResourceType from tribler.core.components.restapi.rest.rest_endpoint import HTTP_BAD_REQUEST, RESTEndpoint, RESTResponse from tribler.core.components.restapi.rest.schema import HandledErrorSchema from tribler.core.utilities.utilities import froze_it @@ -88,7 +89,7 @@ def modify_statements(self, infohash: str, statements: list): return # First, get the current statements and compute the diff between the old and new statements - old_statements = self.db.get_statements(subject_type=ResourceType.TORRENT, subject=infohash) + old_statements = self.db.knowledge.get_statements(subject_type=ResourceType.TORRENT, subject=infohash) old_statements = {(stmt.predicate, stmt.object) for stmt in old_statements} self._logger.info(f'Old statements: {old_statements}') new_statements = {(stmt["predicate"], stmt["object"]) for stmt in statements} @@ -105,9 +106,9 @@ def modify_statements(self, infohash: str, statements: list): predicate=predicate, object=obj, operation=type_of_operation, clock=0, creator_public_key=public_key) - operation.clock = self.db.get_clock(operation) + 1 + operation.clock = self.db.knowledge.get_clock(operation) + 1 signature = self.community.sign(operation) - self.db.add_operation(operation, signature, is_local_peer=True) + self.db.knowledge.add_operation(operation, signature, is_local_peer=True) self._logger.info(f'Added statements: {added_statements}') self._logger.info(f'Removed statements: {removed_statements}') @@ -134,5 +135,5 @@ async def get_tag_suggestions(self, request): return error_response with db_session: - suggestions = self.db.get_suggestions(subject=infohash, predicate=ResourceType.TAG) + suggestions = self.db.knowledge.get_suggestions(subject=infohash, predicate=ResourceType.TAG) return RESTResponse({"suggestions": suggestions}) diff --git a/src/tribler/core/components/knowledge/restapi/tests/test_knowledge_endpoint.py b/src/tribler/core/components/knowledge/restapi/tests/test_knowledge_endpoint.py index dd3f090902a..744beaa3138 100644 --- a/src/tribler/core/components/knowledge/restapi/tests/test_knowledge_endpoint.py +++ b/src/tribler/core/components/knowledge/restapi/tests/test_knowledge_endpoint.py @@ -6,8 +6,8 @@ from pony.orm import db_session from tribler.core.components.conftest import TEST_PERSONAL_KEY +from tribler.core.components.database.db.layers.knowledge_data_access_layer import Operation, ResourceType from tribler.core.components.knowledge.community.knowledge_payload import StatementOperation -from tribler.core.components.database.db.tribler_database import Operation, ResourceType from tribler.core.components.knowledge.restapi.knowledge_endpoint import KnowledgeEndpoint from tribler.core.components.restapi.rest.base_api_test import do_request from tribler.core.utilities.date_utils import freeze_time @@ -58,7 +58,7 @@ async def test_modify_tags(rest_api, tribler_db): await do_request(rest_api, f'knowledge/{infohash}', request_type="PATCH", expected_code=200, post_data=post_data) with db_session: - tags = tribler_db.get_objects(subject=infohash, predicate=ResourceType.TAG) + tags = tribler_db.knowledge.get_objects(subject=infohash, predicate=ResourceType.TAG) assert len(tags) == 2 # Now remove a tag @@ -67,7 +67,7 @@ async def test_modify_tags(rest_api, tribler_db): await do_request(rest_api, f'knowledge/{infohash}', request_type="PATCH", expected_code=200, post_data=post_data) with db_session: - tags = tribler_db.get_objects(subject=infohash, predicate=ResourceType.TAG) + tags = tribler_db.knowledge.get_objects(subject=infohash, predicate=ResourceType.TAG) assert tags == ["abc"] @@ -77,7 +77,7 @@ async def test_modify_tags_no_community(tribler_db, endpoint): endpoint.modify_statements(infohash, [tag_to_statement("abc"), tag_to_statement("def")]) with db_session: - tags = tribler_db.get_objects(subject=infohash, predicate=ResourceType.TAG) + tags = tribler_db.knowledge.get_objects(subject=infohash, predicate=ResourceType.TAG) assert len(tags) == 0 @@ -107,7 +107,7 @@ def _add_operation(op=Operation.ADD): operation = StatementOperation(subject_type=ResourceType.TORRENT, subject=infohash_str, predicate=ResourceType.TAG, object="test", operation=op, clock=0, creator_public_key=random_key.pub().key_to_bin()) - tribler_db.add_operation(operation, b"") + tribler_db.knowledge.add_operation(operation, b"") _add_operation(op=Operation.ADD) _add_operation(op=Operation.REMOVE) diff --git a/src/tribler/core/components/knowledge/rules/knowledge_rules_processor.py b/src/tribler/core/components/knowledge/rules/knowledge_rules_processor.py index 9ae08db885b..128076ee962 100644 --- a/src/tribler/core/components/knowledge/rules/knowledge_rules_processor.py +++ b/src/tribler/core/components/knowledge/rules/knowledge_rules_processor.py @@ -10,7 +10,8 @@ from pony.orm import db_session from tribler.core import notifications -from tribler.core.components.database.db.tribler_database import TriblerDatabase, ResourceType +from tribler.core.components.database.db.layers.knowledge_data_access_layer import ResourceType +from tribler.core.components.database.db.tribler_database import TriblerDatabase from tribler.core.components.knowledge.rules.rules_content_items import content_items_rules from tribler.core.components.knowledge.rules.rules_general_tags import general_rules from tribler.core.components.knowledge.rules.tag_rules_base import extract_only_valid_tags @@ -221,7 +222,7 @@ async def process_torrent_title(self, infohash: Optional[bytes] = None, title: O def save_statements(self, subject_type: ResourceType, subject: str, predicate: ResourceType, objects: Set[str]): self.logger.debug(f'Save: {len(objects)} objects for "{subject}" with predicate={predicate}') for obj in objects: - self.db.add_auto_generated_operation(subject_type=subject_type, subject=subject, predicate=predicate, obj=obj) + self.db.knowledge.add_auto_generated_operation(subject_type=subject_type, subject=subject, predicate=predicate, obj=obj) @db_session def get_last_processed_torrent_id(self) -> int: diff --git a/src/tribler/core/components/knowledge/rules/tests/test_knowledge_rules_processor.py b/src/tribler/core/components/knowledge/rules/tests/test_knowledge_rules_processor.py index 17a948bec12..0c237544ed9 100644 --- a/src/tribler/core/components/knowledge/rules/tests/test_knowledge_rules_processor.py +++ b/src/tribler/core/components/knowledge/rules/tests/test_knowledge_rules_processor.py @@ -5,7 +5,8 @@ from ipv8.keyvault.private.libnaclkey import LibNaCLSK from pony.orm import db_session -from tribler.core.components.database.db.tribler_database import TriblerDatabase, ResourceType +from tribler.core.components.database.db.layers.knowledge_data_access_layer import ResourceType +from tribler.core.components.database.db.tribler_database import TriblerDatabase from tribler.core.components.knowledge.rules.knowledge_rules_processor import KnowledgeRulesProcessor from tribler.core.components.metadata_store.db.serialization import REGULAR_TORRENT from tribler.core.components.metadata_store.db.store import MetadataStore diff --git a/src/tribler/core/components/metadata_store/remote_query_community/remote_query_community.py b/src/tribler/core/components/metadata_store/remote_query_community/remote_query_community.py index 71d2e459875..a51b953375d 100644 --- a/src/tribler/core/components/metadata_store/remote_query_community/remote_query_community.py +++ b/src/tribler/core/components/metadata_store/remote_query_community/remote_query_community.py @@ -12,11 +12,11 @@ from pony.orm import db_session from pony.orm.dbapiprovider import OperationalError +from tribler.core.components.database.db.layers.knowledge_data_access_layer import ResourceType from tribler.core.components.ipv8.eva.protocol import EVAProtocol from tribler.core.components.ipv8.eva.result import TransferResult from tribler.core.components.ipv8.tribler_community import TriblerCommunity from tribler.core.components.knowledge.community.knowledge_validator import is_valid_resource -from tribler.core.components.database.db.tribler_database import ResourceType from tribler.core.components.metadata_store.db.orm_bindings.channel_metadata import LZ4_EMPTY_ARCHIVE, entries_to_chunk from tribler.core.components.metadata_store.db.serialization import CHANNEL_TORRENT, COLLECTION_NODE, REGULAR_TORRENT from tribler.core.components.metadata_store.db.store import MetadataStore @@ -229,7 +229,7 @@ def search_for_tags(self, tags: Optional[List[str]]) -> Optional[Set[str]]: if not tags or not self.tribler_db: return None valid_tags = {tag for tag in tags if is_valid_resource(tag)} - result = self.tribler_db.get_subjects_intersection( + result = self.tribler_db.knowledge.get_subjects_intersection( subjects_type=ResourceType.TORRENT, objects=valid_tags, predicate=ResourceType.TAG, diff --git a/src/tribler/core/components/metadata_store/remote_query_community/tests/test_remote_search_by_tags.py b/src/tribler/core/components/metadata_store/remote_query_community/tests/test_remote_search_by_tags.py index 6fcf1134af4..def45d438d1 100644 --- a/src/tribler/core/components/metadata_store/remote_query_community/tests/test_remote_search_by_tags.py +++ b/src/tribler/core/components/metadata_store/remote_query_community/tests/test_remote_search_by_tags.py @@ -5,8 +5,11 @@ from ipv8.test.base import TestBase from pony.orm import db_session -from tribler.core.components.database.db.tribler_database import TriblerDatabase, ResourceType, SHOW_THRESHOLD -from tribler.core.components.database.db.tests.test_tribler_database import Resource, TestTagDB +from tribler.core.components.database.db.layers.knowledge_data_access_layer import KnowledgeDataAccessLayer, \ + ResourceType, SHOW_THRESHOLD +from tribler.core.components.database.db.layers.tests.test_knowledge_data_access_layer_base import Resource, \ + TestKnowledgeAccessLayerBase +from tribler.core.components.database.db.tribler_database import TriblerDatabase from tribler.core.components.metadata_store.db.orm_bindings.channel_node import NEW from tribler.core.components.metadata_store.db.store import MetadataStore from tribler.core.components.metadata_store.remote_query_community.remote_query_community import RemoteQueryCommunity @@ -60,7 +63,7 @@ def test_search_for_tags_no_db(self): # test that in case of missed `tribler_db`, function `search_for_tags` returns None assert self.rqc.search_for_tags(tags=['tag']) is None - @patch.object(TriblerDatabase, 'get_subjects_intersection') + @patch.object(KnowledgeDataAccessLayer, 'get_subjects_intersection') def test_search_for_tags_only_valid_tags(self, mocked_get_subjects_intersection: Mock): # test that function `search_for_tags` uses only valid tags self.rqc.search_for_tags(tags=['invalid_tag' * 50, 'valid_tag']) @@ -91,7 +94,7 @@ async def test_process_rpc_query_with_tags(self): @db_session def fill_tags_database(): - TestTagDB.add_operation_set( + TestKnowledgeAccessLayerBase.add_operation_set( self.rqc.tribler_db, { hexlify(infohash1): [ @@ -100,7 +103,8 @@ def fill_tags_database(): hexlify(infohash2): [ Resource(predicate=ResourceType.TAG, name='tag1', count=SHOW_THRESHOLD - 1), ] - }) + } + ) @db_session def fill_mds(): diff --git a/src/tribler/core/components/metadata_store/restapi/metadata_endpoint_base.py b/src/tribler/core/components/metadata_store/restapi/metadata_endpoint_base.py index 76db9923c7a..9d3ff8d03fe 100644 --- a/src/tribler/core/components/metadata_store/restapi/metadata_endpoint_base.py +++ b/src/tribler/core/components/metadata_store/restapi/metadata_endpoint_base.py @@ -3,7 +3,8 @@ from pony.orm import db_session -from tribler.core.components.database.db.tribler_database import TriblerDatabase, ResourceType +from tribler.core.components.database.db.layers.knowledge_data_access_layer import ResourceType +from tribler.core.components.database.db.tribler_database import TriblerDatabase from tribler.core.components.knowledge.rules.knowledge_rules_processor import KnowledgeRulesProcessor from tribler.core.components.metadata_store.category_filter.family_filter import default_xxx_filter from tribler.core.components.metadata_store.db.serialization import CHANNEL_TORRENT, COLLECTION_NODE, REGULAR_TORRENT @@ -82,7 +83,7 @@ def add_statements_to_metadata_list(self, contents_list, hide_xxx=False): return for torrent in contents_list: if torrent['type'] == REGULAR_TORRENT: - raw_statements = self.tribler_db.get_statements( + raw_statements = self.tribler_db.knowledge.get_statements( subject_type=ResourceType.TORRENT, subject=torrent["infohash"] ) diff --git a/src/tribler/core/components/metadata_store/restapi/search_endpoint.py b/src/tribler/core/components/metadata_store/restapi/search_endpoint.py index 787e38b51e5..330b966406a 100644 --- a/src/tribler/core/components/metadata_store/restapi/search_endpoint.py +++ b/src/tribler/core/components/metadata_store/restapi/search_endpoint.py @@ -8,11 +8,11 @@ from marshmallow.fields import Integer, String from pony.orm import db_session -from tribler.core.components.database.db.tribler_database import ResourceType +from tribler.core.components.database.db.layers.knowledge_data_access_layer import ResourceType from tribler.core.components.metadata_store.db.serialization import SNIPPET from tribler.core.components.metadata_store.db.store import MetadataStore from tribler.core.components.metadata_store.restapi.metadata_endpoint import MetadataEndpointBase -from tribler.core.components.metadata_store.restapi.metadata_schema import SearchMetadataParameters, MetadataSchema +from tribler.core.components.metadata_store.restapi.metadata_schema import MetadataSchema, SearchMetadataParameters from tribler.core.components.restapi.rest.rest_endpoint import HTTP_BAD_REQUEST, RESTResponse from tribler.core.utilities.pony_utils import run_threaded from tribler.core.utilities.utilities import froze_it @@ -49,9 +49,9 @@ def build_snippets(self, search_results: List[Dict]) -> List[Dict]: content_to_torrents: Dict[str, list] = defaultdict(list) for search_result in search_results: with db_session: - content_items: List[str] = self.tribler_db.get_objects(subject_type=ResourceType.TORRENT, - subject=search_result["infohash"], - predicate=ResourceType.CONTENT_ITEM) + content_items: List[str] = self.tribler_db.knowledge.get_objects(subject_type=ResourceType.TORRENT, + subject=search_result["infohash"], + predicate=ResourceType.CONTENT_ITEM) if content_items: for content_id in content_items: content_to_torrents[content_id].append(search_result) @@ -150,10 +150,11 @@ def search_db(): try: with db_session: if tags: - infohash_set = self.tribler_db.get_subjects_intersection(subjects_type=ResourceType.TORRENT, - objects=set(tags), - predicate=ResourceType.TAG, - case_sensitive=False) + infohash_set = self.tribler_db.knowledge.get_subjects_intersection( + subjects_type=ResourceType.TORRENT, + objects=set(tags), + predicate=ResourceType.TAG, + case_sensitive=False) if infohash_set: sanitized['infohash_set'] = {bytes.fromhex(s) for s in infohash_set} diff --git a/src/tribler/core/components/metadata_store/restapi/tests/test_channels_endpoint.py b/src/tribler/core/components/metadata_store/restapi/tests/test_channels_endpoint.py index a9776b19450..9d342ca991c 100644 --- a/src/tribler/core/components/metadata_store/restapi/tests/test_channels_endpoint.py +++ b/src/tribler/core/components/metadata_store/restapi/tests/test_channels_endpoint.py @@ -7,8 +7,8 @@ from ipv8.util import succeed from pony.orm import db_session +from tribler.core.components.database.db.layers.knowledge_data_access_layer import ResourceType from tribler.core.components.gigachannel.community.gigachannel_community import NoChannelSourcesException -from tribler.core.components.database.db.tribler_database import ResourceType from tribler.core.components.libtorrent.torrentdef import TorrentDef from tribler.core.components.metadata_store.category_filter.family_filter import default_xxx_filter from tribler.core.components.metadata_store.db.orm_bindings.channel_node import NEW diff --git a/src/tribler/core/components/metadata_store/restapi/tests/test_search_endpoint.py b/src/tribler/core/components/metadata_store/restapi/tests/test_search_endpoint.py index 03eaedfdbcb..1ad8f213209 100644 --- a/src/tribler/core/components/metadata_store/restapi/tests/test_search_endpoint.py +++ b/src/tribler/core/components/metadata_store/restapi/tests/test_search_endpoint.py @@ -6,6 +6,7 @@ import pytest from pony.orm import db_session +from tribler.core.components.database.db.layers.knowledge_data_access_layer import KnowledgeDataAccessLayer from tribler.core.components.database.db.tribler_database import TriblerDatabase from tribler.core.components.metadata_store.db.serialization import REGULAR_TORRENT, SNIPPET from tribler.core.components.metadata_store.restapi.search_endpoint import SearchEndpoint @@ -72,7 +73,7 @@ def mocked_get_subjects_intersection(*_, objects: Set[str], **__): return None return {hexlify(os.urandom(20))} - with patch.object(TriblerDatabase, 'get_subjects_intersection', wraps=mocked_get_subjects_intersection): + with patch.object(KnowledgeDataAccessLayer, 'get_subjects_intersection', wraps=mocked_get_subjects_intersection): parsed = await do_request(rest_api, 'search?txt_filter=needle&tags=real_tag', expected_code=200) assert len(parsed["results"]) == 0 @@ -170,7 +171,7 @@ async def test_single_snippet_in_search(rest_api, metadata_store, tribler_db): def mocked_get_subjects(*_, **__) -> List[str]: return ["Abc"] - with patch.object(TriblerDatabase, 'get_objects', wraps=mocked_get_subjects): + with patch.object(KnowledgeDataAccessLayer, 'get_objects', wraps=mocked_get_subjects): s1 = to_fts_query("abc") results = await do_request(rest_api, f'search?txt_filter={s1}', expected_code=200) @@ -200,7 +201,7 @@ def mocked_get_objects(*__, subject=None, **___) -> List[str]: return ["Content item 2"] return [] - with patch.object(TriblerDatabase, 'get_objects', wraps=mocked_get_objects): + with patch.object(KnowledgeDataAccessLayer, 'get_objects', wraps=mocked_get_objects): s1 = to_fts_query("abc") parsed = await do_request(rest_api, f'search?txt_filter={s1}', expected_code=200) results = parsed["results"] diff --git a/src/tribler/core/components/metadata_store/utils.py b/src/tribler/core/components/metadata_store/utils.py index 7cdc6ba8cde..eecdd4f1c3a 100644 --- a/src/tribler/core/components/metadata_store/utils.py +++ b/src/tribler/core/components/metadata_store/utils.py @@ -5,8 +5,9 @@ from ipv8.keyvault.crypto import default_eccrypto from pony.orm import db_session +from tribler.core.components.database.db.layers.knowledge_data_access_layer import Operation, ResourceType +from tribler.core.components.database.db.tribler_database import TriblerDatabase from tribler.core.components.knowledge.community.knowledge_payload import StatementOperation -from tribler.core.components.database.db.tribler_database import TriblerDatabase, Operation, ResourceType from tribler.core.components.knowledge.knowledge_constants import MIN_RESOURCE_LENGTH from tribler.core.components.metadata_store.db.store import MetadataStore from tribler.core.tests.tools.common import PNG_FILE @@ -39,7 +40,7 @@ def get_random_word(min_length=0): return word -def tag_torrent(infohash, tags_db, tags=None, suggested_tags=None): +def tag_torrent(infohash, db, tags=None, suggested_tags=None): infohash = hexlify(infohash) if tags is None: tags_count = random.randint(2, 6) @@ -60,8 +61,8 @@ def tag_torrent(infohash, tags_db, tags=None, suggested_tags=None): def _add_operation(_obj, _op, _key, _predicate=ResourceType.TAG): operation = StatementOperation(subject_type=ResourceType.TORRENT, subject=infohash, predicate=_predicate, object=_obj, operation=_op, clock=0, creator_public_key=_key.pub().key_to_bin()) - operation.clock = tags_db.get_clock(operation) + 1 - tags_db.add_operation(operation, b"") + operation.clock = db.knowledge.get_clock(operation) + 1 + db.knowledge.add_operation(operation, b"") # Give the torrent some tags for tag in tags: @@ -86,7 +87,7 @@ def _add_operation(_obj, _op, _key, _predicate=ResourceType.TAG): @db_session -def generate_torrent(metadata_store, tags_db, parent, title=None): +def generate_torrent(metadata_store, db, parent, title=None): infohash = random_infohash() # Give each torrent some health information. For now, we assume all torrents are healthy. @@ -97,7 +98,7 @@ def generate_torrent(metadata_store, tags_db, parent, title=None): metadata_store.TorrentMetadata(title=title or generate_title(words_count=4), infohash=infohash, origin_id=parent.id_, health=torrent_state, tags=category) - tag_torrent(infohash, tags_db) + tag_torrent(infohash, db) @db_session @@ -108,7 +109,7 @@ def generate_collection(metadata_store, tags_db, parent): @db_session -def generate_channel(metadata_store: MetadataStore, tags_db: TriblerDatabase, title=None, subscribed=False): +def generate_channel(metadata_store: MetadataStore, db: TriblerDatabase, title=None, subscribed=False): # Remember and restore the original key orig_key = metadata_store.ChannelNode._my_key @@ -119,7 +120,7 @@ def generate_channel(metadata_store: MetadataStore, tags_db: TriblerDatabase, ti # add some collections to the channel for _ in range(0, 3): - generate_collection(metadata_store, tags_db, chan) + generate_collection(metadata_store, db, chan) metadata_store.ChannelNode._my_key = orig_key diff --git a/src/tribler/gui/dialogs/editmetadatadialog.py b/src/tribler/gui/dialogs/editmetadatadialog.py index 14d1d68c874..14a0beb5548 100644 --- a/src/tribler/gui/dialogs/editmetadatadialog.py +++ b/src/tribler/gui/dialogs/editmetadatadialog.py @@ -4,7 +4,7 @@ from PyQt5.QtCore import QModelIndex, QPoint, Qt, pyqtSignal from PyQt5.QtWidgets import QComboBox, QSizePolicy, QWidget -from tribler.core.components.database.db.tribler_database import ResourceType +from tribler.core.components.database.db.layers.knowledge_data_access_layer import ResourceType from tribler.core.components.knowledge.knowledge_constants import MAX_RESOURCE_LENGTH, MIN_RESOURCE_LENGTH from tribler.gui.defs import TAG_HORIZONTAL_MARGIN from tribler.gui.dialogs.dialogcontainer import DialogContainer diff --git a/src/tribler/gui/utilities.py b/src/tribler/gui/utilities.py index bcf4b4f1198..2d121c64c91 100644 --- a/src/tribler/gui/utilities.py +++ b/src/tribler/gui/utilities.py @@ -27,7 +27,7 @@ from PyQt5.QtWidgets import QApplication, QMessageBox import tribler.gui -from tribler.core.components.database.db.tribler_database import ResourceType +from tribler.core.components.database.db.layers.knowledge_data_access_layer import ResourceType from tribler.gui.defs import HEALTH_DEAD, HEALTH_GOOD, HEALTH_MOOT, HEALTH_UNCHECKED # fmt: off diff --git a/src/tribler/gui/widgets/tablecontentdelegate.py b/src/tribler/gui/widgets/tablecontentdelegate.py index 41d37ffdba8..e0d5df3e4b5 100644 --- a/src/tribler/gui/widgets/tablecontentdelegate.py +++ b/src/tribler/gui/widgets/tablecontentdelegate.py @@ -6,7 +6,7 @@ from PyQt5.QtWidgets import QApplication, QComboBox, QStyle, QStyleOptionViewItem, QStyledItemDelegate, QToolTip from psutil import LINUX -from tribler.core.components.database.db.tribler_database import ResourceType +from tribler.core.components.database.db.layers.knowledge_data_access_layer import ResourceType from tribler.core.components.metadata_store.db.orm_bindings.channel_node import LEGACY_ENTRY from tribler.core.components.metadata_store.db.serialization import CHANNEL_TORRENT, COLLECTION_NODE, REGULAR_TORRENT, \ SNIPPET