From 3bda1f71fe22b2fa389b525f111d466bd5a76b01 Mon Sep 17 00:00:00 2001 From: drew2a Date: Mon, 16 Oct 2023 17:36:44 +0300 Subject: [PATCH] Add `TriblerDatabaseMigrationChain` --- .../components/database/database_component.py | 2 +- .../tests/test_knowledge_data_access_layer.py | 12 ++- .../db/tests/test_tribler_database.py | 24 +++++- .../database/db/tribler_database.py | 37 +++++++++- .../core/upgrade/tribler_db/conftest.py | 18 +++++ .../core/upgrade/tribler_db/decorator.py | 53 +++++++++++++ .../upgrade/tribler_db/migration_chain.py | 48 ++++++++++++ .../scheme_migrations/scheme_migration_0.py | 23 ++++++ .../tests/test_scheme_migration_0.py | 13 ++++ .../tribler_db/tests/test_decorator.py | 74 +++++++++++++++++++ .../tribler_db/tests/test_migration_chain.py | 55 ++++++++++++++ src/tribler/core/upgrade/upgrade.py | 4 + 12 files changed, 353 insertions(+), 10 deletions(-) create mode 100644 src/tribler/core/upgrade/tribler_db/conftest.py create mode 100644 src/tribler/core/upgrade/tribler_db/decorator.py create mode 100644 src/tribler/core/upgrade/tribler_db/migration_chain.py create mode 100644 src/tribler/core/upgrade/tribler_db/scheme_migrations/scheme_migration_0.py create mode 100644 src/tribler/core/upgrade/tribler_db/scheme_migrations/tests/test_scheme_migration_0.py create mode 100644 src/tribler/core/upgrade/tribler_db/tests/test_decorator.py create mode 100644 src/tribler/core/upgrade/tribler_db/tests/test_migration_chain.py diff --git a/src/tribler/core/components/database/database_component.py b/src/tribler/core/components/database/database_component.py index 124d592a09f..92c300adcb8 100644 --- a/src/tribler/core/components/database/database_component.py +++ b/src/tribler/core/components/database/database_component.py @@ -15,7 +15,7 @@ async def run(self): if self.session.config.gui_test_mode: db_path = ":memory:" - self.db = TriblerDatabase(str(db_path), create_tables=True) + self.db = TriblerDatabase(str(db_path)) async def shutdown(self): await super().shutdown() diff --git a/src/tribler/core/components/database/db/layers/tests/test_knowledge_data_access_layer.py b/src/tribler/core/components/database/db/layers/tests/test_knowledge_data_access_layer.py index 1c3e073e620..9ca7aa42ca1 100644 --- a/src/tribler/core/components/database/db/layers/tests/test_knowledge_data_access_layer.py +++ b/src/tribler/core/components/database/db/layers/tests/test_knowledge_data_access_layer.py @@ -14,13 +14,19 @@ # pylint: disable=protected-access class TestKnowledgeAccessLayer(TestKnowledgeAccessLayerBase): @patch.object(TrackedDatabase, 'generate_mapping') + @patch.object(TriblerDatabase, 'fill_default_data', Mock()) def test_constructor_create_tables_true(self, mocked_generate_mapping: Mock): - TriblerDatabase(':memory:') + """ Test that constructor of TriblerDatabase calls TrackedDatabase.generate_mapping with create_tables=True""" + TriblerDatabase() + mocked_generate_mapping.assert_called_with(create_tables=True) @patch.object(TrackedDatabase, 'generate_mapping') + @patch.object(TriblerDatabase, 'fill_default_data', Mock()) def test_constructor_create_tables_false(self, mocked_generate_mapping: Mock): - TriblerDatabase(':memory:', create_tables=False) + """ Test that constructor of TriblerDatabase calls TrackedDatabase.generate_mapping with create_tables=False""" + TriblerDatabase(create_tables=False) + mocked_generate_mapping.assert_called_with(create_tables=False) @db_session @@ -245,7 +251,7 @@ def test_get_objects_removed(self): ) self.add_operation(self.db, subject='infohash1', predicate=ResourceType.TAG, obj='tag2', peer=b'4', - operation=Operation.REMOVE) + operation=Operation.REMOVE) assert self.db.knowledge.get_objects(subject='infohash1', predicate=ResourceType.TAG) == ['tag1'] diff --git a/src/tribler/core/components/database/db/tests/test_tribler_database.py b/src/tribler/core/components/database/db/tests/test_tribler_database.py index 2c2027b101e..1038069d584 100644 --- a/src/tribler/core/components/database/db/tests/test_tribler_database.py +++ b/src/tribler/core/components/database/db/tests/test_tribler_database.py @@ -1,3 +1,4 @@ +import pytest from ipv8.test.base import TestBase from pony.orm import db_session @@ -37,8 +38,10 @@ def dump_db(self): @db_session def test_set_misc(self): """Test that set_misc works as expected""" - self.db.set_misc(key='key', value='value') - assert self.db.get_misc(key='key') == 'value' + self.db.set_misc(key='string', value='value') + self.db.set_misc(key='integer', value=1) + assert self.db.get_misc(key='string') == 'value' + assert self.db.get_misc(key='integer') == '1' @db_session def test_non_existent_misc(self): @@ -48,3 +51,20 @@ def test_non_existent_misc(self): # A value if the key does exist assert self.db.get_misc(key='non existent', default=42) == 42 + + @db_session + def test_default_version(self): + """ Test that the default version is equal to `CURRENT_VERSION`""" + assert self.db.version == TriblerDatabase.CURRENT_VERSION + + @db_session + def test_version_getter_and_setter(self): + """ Test that the version getter and setter work as expected""" + self.db.version = 42 + assert self.db.version == 42 + + @db_session + def test_version_getter_unsupported_type(self): + """ Test that the version getter raises a TypeError if the type is not supported""" + with pytest.raises(TypeError): + self.db.version = 'string' diff --git a/src/tribler/core/components/database/db/tribler_database.py b/src/tribler/core/components/database/db/tribler_database.py index 847a6b65a65..7a544bafd5b 100644 --- a/src/tribler/core/components/database/db/tribler_database.py +++ b/src/tribler/core/components/database/db/tribler_database.py @@ -1,14 +1,20 @@ import logging +import os from typing import Any, Optional from pony import orm from tribler.core.components.database.db.layers.health_data_access_layer import HealthDataAccessLayer from tribler.core.components.database.db.layers.knowledge_data_access_layer import KnowledgeDataAccessLayer -from tribler.core.utilities.pony_utils import TrackedDatabase, get_or_create +from tribler.core.utilities.pony_utils import TrackedDatabase, db_session, get_or_create + +MEMORY = ':memory:' class TriblerDatabase: + CURRENT_VERSION = 1 + _SCHEME_VERSION_KEY = 'scheme_version' + def __init__(self, filename: Optional[str] = None, *, create_tables: bool = True, **generate_mapping_kwargs): self.instance = TrackedDatabase() @@ -25,21 +31,31 @@ def __init__(self, filename: Optional[str] = None, *, create_tables: bool = True self.TorrentHealth = self.health.TorrentHealth self.Tracker = self.health.Tracker - self.instance.bind('sqlite', filename or ':memory:', create_db=True) + filename = filename or MEMORY + db_does_not_exist = filename == MEMORY or not os.path.isfile(filename) + + self.instance.bind('sqlite', filename, create_db=db_does_not_exist) generate_mapping_kwargs['create_tables'] = create_tables self.instance.generate_mapping(**generate_mapping_kwargs) self.logger = logging.getLogger(self.__class__.__name__) + if db_does_not_exist: + self.fill_default_data() + @staticmethod def define_binding(db): """ Define common bindings""" - - class Misc(db.Entity): # pylint: disable=unused-variable + class Misc(db.Entity): name = orm.PrimaryKey(str) value = orm.Optional(str) return Misc + @db_session + def fill_default_data(self): + self.logger.info('Filling the DB with the default data') + self.set_misc(self._SCHEME_VERSION_KEY, self.CURRENT_VERSION) + def get_misc(self, key: str, default: Optional[str] = None) -> Optional[str]: data = self.Misc.get(name=key) return data.value if data else default @@ -48,5 +64,18 @@ def set_misc(self, key: str, value: Any): key_value = get_or_create(self.Misc, name=key) key_value.value = str(value) + @property + def version(self) -> int: + """ Get the database version""" + return int(self.get_misc(key=self._SCHEME_VERSION_KEY, default=0)) + + @version.setter + def version(self, value: int): + """ Set the database version""" + if not isinstance(value, int): + raise TypeError('DB version should be integer') + + self.set_misc(key=self._SCHEME_VERSION_KEY, value=value) + def shutdown(self) -> None: self.instance.disconnect() diff --git a/src/tribler/core/upgrade/tribler_db/conftest.py b/src/tribler/core/upgrade/tribler_db/conftest.py new file mode 100644 index 00000000000..f18f3d13ca0 --- /dev/null +++ b/src/tribler/core/upgrade/tribler_db/conftest.py @@ -0,0 +1,18 @@ +import pytest + +from tribler.core.components.database.db.tribler_database import TriblerDatabase +from tribler.core.upgrade.tribler_db.migration_chain import TriblerDatabaseMigrationChain +from tribler.core.utilities.path_util import Path +from tribler.core.utilities.simpledefs import STATEDIR_DB_DIR + + +# pylint: disable=redefined-outer-name + + +@pytest.fixture +def migration_chain(tmpdir): + """ Create an empty migration chain with an empty database.""" + db_file_name = Path(tmpdir) / STATEDIR_DB_DIR / 'tribler.db' + db_file_name.parent.mkdir() + TriblerDatabase(filename=str(db_file_name)) + return TriblerDatabaseMigrationChain(state_dir=Path(tmpdir), chain=[]) diff --git a/src/tribler/core/upgrade/tribler_db/decorator.py b/src/tribler/core/upgrade/tribler_db/decorator.py new file mode 100644 index 00000000000..3864c15859b --- /dev/null +++ b/src/tribler/core/upgrade/tribler_db/decorator.py @@ -0,0 +1,53 @@ +import functools +import logging +from typing import Callable, Optional + +from tribler.core.components.database.db.tribler_database import TriblerDatabase +from tribler.core.utilities.pony_utils import db_session + +MIGRATION_METADATA = "_tribler_db_migration" + +logger = logging.getLogger('Migration (TriblerDB)') + + +def migration(execute_only_if_version: int, set_after_successful_execution_version: Optional[int] = None): + """ Decorator for migration functions. + The migration executes in the single transaction. If the migration fails, the transaction is rolled back. + The decorator also sets the metadata attribute to the decorated function. It could be checked by + calling the `has_migration_metadata` function. + Args: + execute_only_if_version: Execute the migration only if the current db version is equal to this value. + set_after_successful_execution_version: Set the db version to this value after the migration is executed. + If it is not specified, then `set_after_successful_execution_version = execute_only_if_version + 1` + """ + + def decorator(func): + @functools.wraps(func) + @db_session + def wrapper(db: TriblerDatabase, **kwargs): + target_version = execute_only_if_version + if target_version != db.version: + logger.info( + f"Function {func.__name__} is not executed because DB version is not equal to {target_version}. " + f"The current db version is {db.version}" + ) + return None + + result = func(db, **kwargs) + + next_version = set_after_successful_execution_version + if next_version is None: + next_version = target_version + 1 + db.version = next_version + + return result + + setattr(wrapper, MIGRATION_METADATA, {}) + return wrapper + + return decorator + + +def has_migration_metadata(f: Callable): + """ Check if the function has migration metadata.""" + return hasattr(f, MIGRATION_METADATA) diff --git a/src/tribler/core/upgrade/tribler_db/migration_chain.py b/src/tribler/core/upgrade/tribler_db/migration_chain.py new file mode 100644 index 00000000000..3ad141001a9 --- /dev/null +++ b/src/tribler/core/upgrade/tribler_db/migration_chain.py @@ -0,0 +1,48 @@ +import logging +from typing import Callable, List, Optional + +from tribler.core.components.database.db.tribler_database import TriblerDatabase +from tribler.core.upgrade.tribler_db.decorator import has_migration_metadata +from tribler.core.upgrade.tribler_db.scheme_migrations.scheme_migration_0 import scheme_migration_0 +from tribler.core.utilities.path_util import Path +from tribler.core.utilities.simpledefs import STATEDIR_DB_DIR + + +class TriblerDatabaseMigrationChain: + """ A chain of migrations that can be executed on a TriblerDatabase. + + To create a new migration, create a new function and decorate it with the `migration` decorator. Then add it to + the `DEFAULT_CHAIN` list. + """ + + DEFAULT_CHAIN = [ + scheme_migration_0, + # add your migration here + ] + + def __init__(self, state_dir: Path, chain: Optional[List[Callable]] = None): + self.logger = logging.getLogger(self.__class__.__name__) + self.state_dir = state_dir + + db_path = self.state_dir / STATEDIR_DB_DIR / 'tribler.db' + self.logger.info(f'Tribler DB path: {db_path}') + self.db = TriblerDatabase(str(db_path), check_tables=False) if db_path.is_file() else None + + self.migrations = chain or self.DEFAULT_CHAIN + + def execute(self) -> bool: + """ Execute all migrations in the chain. + + Returns: True if all migrations were executed successfully, False otherwise. + An exception in any of the migrations will halt the execution chain and be re-raised. + """ + + if not self.db: + return False + + for m in self.migrations: + if not has_migration_metadata(m): + raise NotImplementedError(f'The migration {m} should have `migration` decorator') + m(self.db, state_dir=self.state_dir) + + return True diff --git a/src/tribler/core/upgrade/tribler_db/scheme_migrations/scheme_migration_0.py b/src/tribler/core/upgrade/tribler_db/scheme_migrations/scheme_migration_0.py new file mode 100644 index 00000000000..3b813a9e6fb --- /dev/null +++ b/src/tribler/core/upgrade/tribler_db/scheme_migrations/scheme_migration_0.py @@ -0,0 +1,23 @@ +from tribler.core.components.database.db.tribler_database import TriblerDatabase +from tribler.core.upgrade.tribler_db.decorator import migration + + +@migration(execute_only_if_version=0) +def scheme_migration_0(db: TriblerDatabase, **kwargs): # pylint: disable=unused-argument + """ "This is initial migration, placed here primarily for demonstration purposes. + It doesn't do anything except set the database version to `1`. + + For upcoming migrations, there are some guidelines: + 1. functions should contain a single parameter, `db: TriblerDatabase`, + 2. they should apply the `@migration` decorator. + + + Utilizing plain SQL (as seen in the example below) is considered good practice since it helps prevent potential + inconsistencies in DB schemes in the future (model versions preceding the current one may differ from it). + For more information see: https://github.com/Tribler/tribler/issues/7382 + + The example of a migration: + + db.execute('ALTER TABLE "TorrentState" ADD "has_data" BOOLEAN DEFAULT 0') + db.execute('UPDATE "TorrentState" SET "has_data" = 1 WHERE last_check > 0') + """ diff --git a/src/tribler/core/upgrade/tribler_db/scheme_migrations/tests/test_scheme_migration_0.py b/src/tribler/core/upgrade/tribler_db/scheme_migrations/tests/test_scheme_migration_0.py new file mode 100644 index 00000000000..70d03ec1437 --- /dev/null +++ b/src/tribler/core/upgrade/tribler_db/scheme_migrations/tests/test_scheme_migration_0.py @@ -0,0 +1,13 @@ +from tribler.core.upgrade.tribler_db.migration_chain import TriblerDatabaseMigrationChain +from tribler.core.upgrade.tribler_db.scheme_migrations.scheme_migration_0 import scheme_migration_0 +from tribler.core.utilities.pony_utils import db_session + + +@db_session +def test_scheme_migration_0(migration_chain: TriblerDatabaseMigrationChain): + """ Test that the scheme_migration_0 changes the database version to 1. """ + migration_chain.db.version = 0 + migration_chain.migrations = [scheme_migration_0] + + assert migration_chain.execute() + assert migration_chain.db.version == 1 diff --git a/src/tribler/core/upgrade/tribler_db/tests/test_decorator.py b/src/tribler/core/upgrade/tribler_db/tests/test_decorator.py new file mode 100644 index 00000000000..2be2283d94d --- /dev/null +++ b/src/tribler/core/upgrade/tribler_db/tests/test_decorator.py @@ -0,0 +1,74 @@ +from unittest.mock import Mock + +import pytest + +from tribler.core.upgrade.tribler_db.decorator import has_migration_metadata, migration + + +def test_migration_execute_only_if_version(): + """ Test that migration is executed only if the version of the database is equal to the specified one.""" + + @migration(execute_only_if_version=1) + def test(_: Mock): + return True + + assert test(Mock(version=1)) + assert not test(Mock(version=2)) + + +def test_set_after_successful_execution_version(): + """ Test that the version of the database is set to the specified one after the migration is successfully + executed. + """ + + @migration(execute_only_if_version=1, set_after_successful_execution_version=33) + def test(_: Mock): + ... + + db = Mock(version=1) + test(db) + + assert db.version == 33 + + +def test_set_after_successful_execution_version_not_specified(): + """ Test that if the version is not specified, the version of the database will be set to + execute_only_if_version + 1 + """ + + @migration(execute_only_if_version=1) + def test(_: Mock): + ... + + db = Mock(version=1) + test(db) + + assert db.version == 2 + + +def test_set_after_successful_execution_raise_an_exception(): + """ Test that if an exception is raised during the migration, the version of the database is not changed.""" + + @migration(execute_only_if_version=1, set_after_successful_execution_version=33) + def test(_: Mock): + raise TypeError + + db = Mock(version=1) + with pytest.raises(TypeError): + test(db) + + assert db.version == 1 + + +def test_set_metadata(): + """ Test that the metadata flag is set.""" + + @migration(execute_only_if_version=1) + def simple_migration(_: Mock): + ... + + def no_migration(_: Mock): + ... + + assert has_migration_metadata(simple_migration) + assert not has_migration_metadata(no_migration) diff --git a/src/tribler/core/upgrade/tribler_db/tests/test_migration_chain.py b/src/tribler/core/upgrade/tribler_db/tests/test_migration_chain.py new file mode 100644 index 00000000000..7feb6653deb --- /dev/null +++ b/src/tribler/core/upgrade/tribler_db/tests/test_migration_chain.py @@ -0,0 +1,55 @@ +import pytest + +from tribler.core.upgrade.tribler_db.decorator import migration +from tribler.core.upgrade.tribler_db.migration_chain import TriblerDatabaseMigrationChain +from tribler.core.utilities.path_util import Path +from tribler.core.utilities.pony_utils import db_session + + +def test_db_does_not_exist(tmpdir): + """ Test that the migration chain does not execute if the database does not exist.""" + tribler_db_migration = TriblerDatabaseMigrationChain(state_dir=Path(tmpdir)) + assert not tribler_db_migration.execute() + + +@db_session +def test_db_execute(migration_chain: TriblerDatabaseMigrationChain): + """ Test that the migration chain executes all the migrations step by step.""" + migration_chain.db.version = 0 + + @migration(execute_only_if_version=0) + def migration1(*_, **__): + ... + + @migration(execute_only_if_version=1) + def migration2(*_, **__): + ... + + @migration(execute_only_if_version=99) + def migration99(*_, **__): # this migration should be skipped + ... + + migration_chain.migrations = [ + migration1, + migration2, + migration99, + ] + + # test execution of all the migration + assert migration_chain.execute() + assert migration_chain.db.version == 2 + + +@db_session +def test_db_execute_no_annotation(migration_chain: TriblerDatabaseMigrationChain): + """ Test that the migration chain raises the NotImplementedError if the migration does not have the annotation.""" + + def migration_without_annotation(*_, **__): + ... + + migration_chain.migrations = [ + migration_without_annotation + ] + + with pytest.raises(NotImplementedError): + migration_chain.execute() diff --git a/src/tribler/core/upgrade/upgrade.py b/src/tribler/core/upgrade/upgrade.py index 682f9157836..dbb5abc224e 100644 --- a/src/tribler/core/upgrade/upgrade.py +++ b/src/tribler/core/upgrade/upgrade.py @@ -22,6 +22,7 @@ from tribler.core.upgrade.knowledge_to_triblerdb.migration import MigrationKnowledgeToTriblerDB from tribler.core.upgrade.tags_to_knowledge.migration import MigrationTagsToKnowledge from tribler.core.upgrade.tags_to_knowledge.previous_dbs.tags_db import TagDatabase +from tribler.core.upgrade.tribler_db.migration_chain import TriblerDatabaseMigrationChain from tribler.core.utilities.configparser import CallbackConfigParser from tribler.core.utilities.path_util import Path from tribler.core.utilities.simpledefs import STATEDIR_CHANNELS_DIR, STATEDIR_DB_DIR @@ -107,6 +108,9 @@ def run(self): self.upgrade_pony_db_14to15() self.upgrade_knowledge_to_tribler_db() + migration_chain = TriblerDatabaseMigrationChain(self.state_dir) + migration_chain.execute() + def remove_old_logs(self) -> Tuple[List[Path], List[Path]]: self._logger.info(f'Remove old logs')