Skip to content

Commit

Permalink
Add TriblerDatabaseMigrationChain
Browse files Browse the repository at this point in the history
  • Loading branch information
drew2a committed Oct 16, 2023
1 parent 20fb224 commit ef744cb
Show file tree
Hide file tree
Showing 12 changed files with 352 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/tribler/core/components/database/database_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ async def run(self):
if self.session.config.gui_test_mode:
db_path = ":memory:"

self.db = TriblerDatabase(str(db_path), create_tables=True)
self.db = TriblerDatabase(str(db_path))

async def shutdown(self):
await super().shutdown()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@
# pylint: disable=protected-access
class TestKnowledgeAccessLayer(TestKnowledgeAccessLayerBase):
@patch.object(TrackedDatabase, 'generate_mapping')
@patch.object(TriblerDatabase, 'fill_default_data', Mock())
def test_constructor_create_tables_true(self, mocked_generate_mapping: Mock):
TriblerDatabase(':memory:')
""" Test that constructor of TriblerDatabase calls TrackedDatabase.generate_mapping with create_tables=True"""
TriblerDatabase()

mocked_generate_mapping.assert_called_with(create_tables=True)

@patch.object(TrackedDatabase, 'generate_mapping')
@patch.object(TriblerDatabase, 'fill_default_data', Mock())
def test_constructor_create_tables_false(self, mocked_generate_mapping: Mock):
TriblerDatabase(':memory:', create_tables=False)
""" Test that constructor of TriblerDatabase calls TrackedDatabase.generate_mapping with create_tables=False"""
TriblerDatabase(create_tables=False)

mocked_generate_mapping.assert_called_with(create_tables=False)

@db_session
Expand Down Expand Up @@ -245,7 +251,7 @@ def test_get_objects_removed(self):
)

self.add_operation(self.db, subject='infohash1', predicate=ResourceType.TAG, obj='tag2', peer=b'4',
operation=Operation.REMOVE)
operation=Operation.REMOVE)

assert self.db.knowledge.get_objects(subject='infohash1', predicate=ResourceType.TAG) == ['tag1']

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from ipv8.test.base import TestBase
from pony.orm import db_session

Expand Down Expand Up @@ -37,8 +38,10 @@ def dump_db(self):
@db_session
def test_set_misc(self):
"""Test that set_misc works as expected"""
self.db.set_misc(key='key', value='value')
assert self.db.get_misc(key='key') == 'value'
self.db.set_misc(key='string', value='value')
self.db.set_misc(key='integer', value=1)
assert self.db.get_misc(key='string') == 'value'
assert self.db.get_misc(key='integer') == '1'

@db_session
def test_non_existent_misc(self):
Expand All @@ -48,3 +51,20 @@ def test_non_existent_misc(self):

# A value if the key does exist
assert self.db.get_misc(key='non existent', default=42) == 42

@db_session
def test_default_version(self):
""" Test that the default version is equal to `CURRENT_VERSION`"""
assert self.db.version == TriblerDatabase.CURRENT_VERSION

@db_session
def test_version_getter_and_setter(self):
""" Test that the version getter and setter work as expected"""
self.db.version = 42
assert self.db.version == 42

@db_session
def test_version_getter_unsupported_type(self):
""" Test that the version getter raises a TypeError if the type is not supported"""
with pytest.raises(TypeError):
self.db.version = 'string'
37 changes: 33 additions & 4 deletions src/tribler/core/components/database/db/tribler_database.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import logging
import os
from typing import Any, Optional

from pony import orm

from tribler.core.components.database.db.layers.health_data_access_layer import HealthDataAccessLayer
from tribler.core.components.database.db.layers.knowledge_data_access_layer import KnowledgeDataAccessLayer
from tribler.core.utilities.pony_utils import TrackedDatabase, get_or_create
from tribler.core.utilities.pony_utils import TrackedDatabase, db_session, get_or_create

MEMORY = ':memory:'


class TriblerDatabase:
CURRENT_VERSION = 1
_SCHEME_VERSION = 'scheme_version'

def __init__(self, filename: Optional[str] = None, *, create_tables: bool = True, **generate_mapping_kwargs):
self.instance = TrackedDatabase()

Expand All @@ -25,21 +31,31 @@ def __init__(self, filename: Optional[str] = None, *, create_tables: bool = True
self.TorrentHealth = self.health.TorrentHealth
self.Tracker = self.health.Tracker

self.instance.bind('sqlite', filename or ':memory:', create_db=True)
filename = filename or MEMORY
db_does_not_exist = filename == MEMORY or not os.path.isfile(filename)

self.instance.bind('sqlite', filename, create_db=db_does_not_exist)
generate_mapping_kwargs['create_tables'] = create_tables
self.instance.generate_mapping(**generate_mapping_kwargs)
self.logger = logging.getLogger(self.__class__.__name__)

if db_does_not_exist:
self.fill_default_data()

@staticmethod
def define_binding(db):
""" Define common bindings"""

class Misc(db.Entity): # pylint: disable=unused-variable
class Misc(db.Entity):
name = orm.PrimaryKey(str)
value = orm.Optional(str)

return Misc

@db_session
def fill_default_data(self):
self.logger.info('Filling the DB with the default data')
self.set_misc(self._SCHEME_VERSION, self.CURRENT_VERSION)

def get_misc(self, key: str, default: Optional[str] = None) -> Optional[str]:
data = self.Misc.get(name=key)
return data.value if data else default
Expand All @@ -48,5 +64,18 @@ def set_misc(self, key: str, value: Any):
key_value = get_or_create(self.Misc, name=key)
key_value.value = str(value)

@property
def version(self) -> int:
""" Get the database version"""
return int(self.get_misc(key=self._SCHEME_VERSION, default=0))

@version.setter
def version(self, value: int):
""" Set the database version"""
if not isinstance(value, int):
raise TypeError('DB version should be integer')

self.set_misc(key=self._SCHEME_VERSION, value=value)

def shutdown(self) -> None:
self.instance.disconnect()
18 changes: 18 additions & 0 deletions src/tribler/core/upgrade/tribler_db/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest

from tribler.core.components.database.db.tribler_database import TriblerDatabase
from tribler.core.upgrade.tribler_db.migration_chain import TriblerDatabaseMigrationChain
from tribler.core.utilities.path_util import Path
from tribler.core.utilities.simpledefs import STATEDIR_DB_DIR


# pylint: disable=redefined-outer-name


@pytest.fixture
def migration_chain(tmpdir):
""" Create an empty migration chain with an empty database."""
db_file_name = Path(tmpdir) / STATEDIR_DB_DIR / 'tribler.db'
db_file_name.parent.mkdir()
TriblerDatabase(filename=str(db_file_name))
return TriblerDatabaseMigrationChain(state_dir=Path(tmpdir), chain=[])
51 changes: 51 additions & 0 deletions src/tribler/core/upgrade/tribler_db/decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import functools
import logging
from typing import Callable

from tribler.core.components.database.db.tribler_database import TriblerDatabase
from tribler.core.utilities.pony_utils import db_session

MIGRATION_METADATA = "_tribler_db_migration"

logger = logging.getLogger('Migration (TriblerDB)')


def migration(execute_only_if_version: int):
""" Decorator for migration functions.
The migration executes in the single transaction. If the migration fails, the transaction is rolled back.
The decorator also sets the metadata attribute to the decorated function. It could be checked by
calling the `has_migration_metadata` function.
After the successful migration, db version is set to the value equal to `execute_only_if_version + 1`.
Args:
execute_only_if_version: Execute the migration only if the current db version is equal to this value.
"""

def decorator(func):
@functools.wraps(func)
@db_session
def wrapper(db: TriblerDatabase, **kwargs):
target_version = execute_only_if_version
if target_version != db.version:
logger.info(
f"Function {func.__name__} is not executed because DB version is not equal to {target_version}. "
f"The current db version is {db.version}"
)
return None

result = func(db, **kwargs)
db.version = target_version + 1

return result

setattr(wrapper, MIGRATION_METADATA, {})
return wrapper

return decorator


def has_migration_metadata(f: Callable):
""" Check if the function has migration metadata."""
return hasattr(f, MIGRATION_METADATA)
53 changes: 53 additions & 0 deletions src/tribler/core/upgrade/tribler_db/migration_chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import logging
from typing import Callable, Iterator, List, Optional

from tribler.core.components.database.db.tribler_database import TriblerDatabase
from tribler.core.upgrade.tribler_db.decorator import has_migration_metadata
from tribler.core.upgrade.tribler_db.scheme_migrations.scheme_migration_0 import scheme_migration_0
from tribler.core.utilities.path_util import Path
from tribler.core.utilities.simpledefs import STATEDIR_DB_DIR


class TriblerDatabaseMigrationChain:
""" A chain of migrations that can be executed on a TriblerDatabase.
To create a new migration, create a new function and decorate it with the `migration` decorator. Then add it to
the `DEFAULT_CHAIN` list.
"""

DEFAULT_CHAIN = [
scheme_migration_0,
# add your migration here
]

def __init__(self, state_dir: Path, chain: Optional[List[Callable]] = None):
self.logger = logging.getLogger(self.__class__.__name__)
self.state_dir = state_dir

db_path = self.state_dir / STATEDIR_DB_DIR / 'tribler.db'
self.logger.info(f'Tribler DB path: {db_path}')
self.db = TriblerDatabase(str(db_path), check_tables=False) if db_path.is_file() else None

self.migrations = chain or self.DEFAULT_CHAIN

def execute(self) -> bool:
""" Execute all migrations in the chain.
Returns: True if all migrations were executed successfully, False otherwise.
An exception in any of the migrations will halt the execution chain and be re-raised.
"""

if not self.db:
return False

for _ in self.steps():
...

return True

def steps(self) -> Iterator:
""" Execute migrations step by step."""
for m in self.migrations:
if not has_migration_metadata(m):
raise NotImplementedError(f'The migration {m} should have `migration` decorator')
yield m(self.db, state_dir=self.state_dir)
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from tribler.core.components.database.db.tribler_database import TriblerDatabase
from tribler.core.upgrade.tribler_db.decorator import migration


@migration(execute_only_if_version=0)
def scheme_migration_0(db: TriblerDatabase, **kwargs): # pylint: disable=unused-argument
""" "This is initial migration, placed here primarily for demonstration purposes.
It doesn't do anything except set the database version to `1`.
For upcoming migrations, there are some guidelines:
1. functions should contain a single parameter, `db: TriblerDatabase`,
2. they should apply the `@migration` decorator.
Utilizing plain SQL (as seen in the example below) is considered good practice since it helps prevent potential
inconsistencies in DB schemes in the future (model versions preceding the current one may differ from it).
For more information see: https://github.com/Tribler/tribler/issues/7382
The example of a migration:
db.execute('ALTER TABLE "TorrentState" ADD "has_data" BOOLEAN DEFAULT 0')
db.execute('UPDATE "TorrentState" SET "has_data" = 1 WHERE last_check > 0')
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from tribler.core.upgrade.tribler_db.migration_chain import TriblerDatabaseMigrationChain
from tribler.core.upgrade.tribler_db.scheme_migrations.scheme_migration_0 import scheme_migration_0
from tribler.core.utilities.pony_utils import db_session


@db_session
def test_scheme_migration_0(migration_chain: TriblerDatabaseMigrationChain):
""" Test that the scheme_migration_0 changes the database version to 1. """
migration_chain.db.version = 0
migration_chain.migrations = [scheme_migration_0]

assert migration_chain.execute()
assert migration_chain.db.version == 1
56 changes: 56 additions & 0 deletions src/tribler/core/upgrade/tribler_db/tests/test_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from unittest.mock import Mock

import pytest

from tribler.core.upgrade.tribler_db.decorator import has_migration_metadata, migration


def test_migration_execute_only_if_version():
""" Test that migration is executed only if the version of the database is equal to the specified one."""

@migration(execute_only_if_version=1)
def test(_: Mock):
return True

assert test(Mock(version=1))
assert not test(Mock(version=2))


def test_migration_set_next_version():
""" Test that the version of the database is set to the next one after the successful migration."""

@migration(execute_only_if_version=1)
def test(_: Mock):
return True

db = Mock(version=1)
assert test(db)
assert db.version == 2


def test_migration_set_next_version_raise_an_exception():
""" Test that if an exception is raised during the migration, the version of the database is not changed."""

@migration(execute_only_if_version=1)
def test(_: Mock):
raise TypeError

db = Mock(version=1)
with pytest.raises(TypeError):
test(db)

assert db.version == 1


def test_set_metadata():
""" Test that the metadata flag is set."""

@migration(execute_only_if_version=1)
def simple_migration(_: Mock):
...

Check warning on line 50 in src/tribler/core/upgrade/tribler_db/tests/test_decorator.py

View check run for this annotation

Codecov / codecov/patch

src/tribler/core/upgrade/tribler_db/tests/test_decorator.py#L50

Added line #L50 was not covered by tests

def no_migration(_: Mock):
...

Check warning on line 53 in src/tribler/core/upgrade/tribler_db/tests/test_decorator.py

View check run for this annotation

Codecov / codecov/patch

src/tribler/core/upgrade/tribler_db/tests/test_decorator.py#L53

Added line #L53 was not covered by tests

assert has_migration_metadata(simple_migration)
assert not has_migration_metadata(no_migration)
Loading

0 comments on commit ef744cb

Please sign in to comment.