Skip to content

Commit

Permalink
Move get_db_version to pony_utils; add tests for get_db_version
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlovsky committed Oct 16, 2023
1 parent 171620f commit 394d7f6
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 38 deletions.
31 changes: 1 addition & 30 deletions src/tribler/core/upgrade/db8_to_db10.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pony.orm import db_session

from tribler.core.components.metadata_store.db.store import MetadataStore
from tribler.core.utilities.pony_utils import handle_db_if_corrupted, marking_corrupted_db
from tribler.core.utilities.pony_utils import marking_corrupted_db

TABLE_NAMES = (
"ChannelNode", "TorrentState", "TorrentState_TrackerState", "ChannelPeer", "ChannelVote", "TrackerState", "Vsids")
Expand Down Expand Up @@ -240,32 +240,3 @@ def get_table_columns(db_path, table_name):
cursor.execute(f'SELECT * FROM {table_name} LIMIT 1')
names = [description[0] for description in cursor.description]
return names


def table_exists(cursor: sqlite3.Cursor, table_name: str) -> bool:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,))
return cursor.fetchone() is not None


def get_db_version(db_path, default: int = None) -> int:
handle_db_if_corrupted(db_path)
if not db_path.exists():
version = None
else:
with marking_corrupted_db(db_path):
with contextlib.closing(sqlite3.connect(db_path)) as connection, connection:
cursor = connection.cursor()
if not table_exists(cursor, 'MiscData'):
version = None
else:
cursor.execute("SELECT value FROM MiscData WHERE name == 'db_version'")
row = cursor.fetchone()
version = int(row[0]) if row else None

if version is not None:
return version

if default is not None:
return default

raise RuntimeError(f'The version value is not found in database {db_path}')
3 changes: 2 additions & 1 deletion src/tribler/core/upgrade/upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
sql_create_partial_index_torrentstate_last_check,
)
from tribler.core.upgrade.config_converter import convert_config_to_tribler76
from tribler.core.upgrade.db8_to_db10 import PonyToPonyMigration, get_db_version
from tribler.core.upgrade.db8_to_db10 import PonyToPonyMigration
from tribler.core.upgrade.tags_to_knowledge.migration import MigrationTagsToKnowledge
from tribler.core.upgrade.tags_to_knowledge.tags_db import TagDatabase
from tribler.core.utilities.configparser import CallbackConfigParser
from tribler.core.utilities.path_util import Path
from tribler.core.utilities.pony_utils import get_db_version
from tribler.core.utilities.simpledefs import STATEDIR_CHANNELS_DIR, STATEDIR_DB_DIR


Expand Down
31 changes: 30 additions & 1 deletion src/tribler/core/utilities/pony_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
import logging
import sqlite3
import sys
Expand Down Expand Up @@ -37,6 +38,34 @@ class DatabaseIsCorrupted(Exception):
pass


def table_exists(cursor: sqlite3.Cursor, table_name: str) -> bool:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,))
return cursor.fetchone() is not None


def get_db_version(db_path, default: int = None) -> int:
handle_db_if_corrupted(db_path)
if not db_path.exists():
version = None
else:
with marking_corrupted_db(db_path), contextlib.closing(sqlite3.connect(db_path)) as connection, connection:
cursor = connection.cursor()
if not table_exists(cursor, 'MiscData'):
version = None
else:
cursor.execute("SELECT value FROM MiscData WHERE name == 'db_version'")
row = cursor.fetchone()
version = int(row[0]) if row else None

if version is not None:
return version

if default is not None:
return default

raise RuntimeError(f'The version value is not found in database {db_path}')


def handle_db_if_corrupted(db_filename: Union[str, Path]):
db_path = Path(db_filename)
marker_path = _get_corrupted_db_marker_path(db_path)
Expand Down Expand Up @@ -77,7 +106,7 @@ def _is_malformed_db_exception(exception):

def _mark_db_as_corrupted(db_filename: Path):
if not db_filename.exists():
raise RuntimeError(f'Corrupted database file not found: {db_filename!r}')
raise RuntimeError(f'Corrupted database file not found: {db_filename}')

marker_path = _get_corrupted_db_marker_path(db_filename)
marker_path.touch()
Expand Down
82 changes: 76 additions & 6 deletions src/tribler/core/utilities/tests/test_pony_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

from pony.orm.core import QueryStat, Required
from tribler.core.utilities import pony_utils
from tribler.core.utilities.pony_utils import DatabaseIsCorrupted, handle_db_if_corrupted, marking_corrupted_db
from tribler.core.utilities.pony_utils import DatabaseIsCorrupted, _mark_db_as_corrupted, get_db_version, \
handle_db_if_corrupted, \
marking_corrupted_db, table_exists

EMPTY_DICT = {}

Expand Down Expand Up @@ -125,19 +127,19 @@ def test_format_warning():


@pytest.fixture(name='db_path')
def db_path_fixture(tmp_path):
def db_path_fixture(tmp_path: Path):
db_path = tmp_path / 'test.db'
db_path.touch()
return db_path


@patch('tribler.core.utilities.pony_utils._handle_corrupted_db')
def test_handle_db_if_corrupted__not_corrupted(handle_corrupted_db: Mock, db_path):
def test_handle_db_if_corrupted__not_corrupted(handle_corrupted_db: Mock, db_path: Path):
handle_db_if_corrupted(db_path)
handle_corrupted_db.assert_not_called()


def test_handle_db_if_corrupted__corrupted(db_path):
def test_handle_db_if_corrupted__corrupted(db_path: Path):
marker_path = Path(str(db_path) + '.is_corrupted')
marker_path.touch()

Expand All @@ -146,17 +148,85 @@ def test_handle_db_if_corrupted__corrupted(db_path):
assert not marker_path.exists()


def test_marking_corrupted_db__not_malformed(db_path):
def test_marking_corrupted_db__not_malformed(db_path: Path):
with pytest.raises(ZeroDivisionError):
with marking_corrupted_db(db_path):
raise ZeroDivisionError()

assert not Path(str(db_path) + '.is_corrupted').exists()


def test_marking_corrupted_db__malformed(db_path):
def test_marking_corrupted_db__malformed(db_path: Path):
with pytest.raises(DatabaseIsCorrupted):
with marking_corrupted_db(db_path):
raise sqlite3.DatabaseError('database disk image is malformed')

assert Path(str(db_path) + '.is_corrupted').exists()


def test_get_db_version__db_does_not_exist(tmp_path: Path):
db_path = tmp_path / 'doesnotexist.db'

with pytest.raises(RuntimeError, match='^The version value is not found in database .*doesnotexist.db$'):
get_db_version(db_path)


def test_get_db_version__db_does_not_exist_default_version(tmp_path: Path):
db_path = tmp_path / 'doesnotexist.db'

default_version = 123
version = get_db_version(db_path, default=default_version)
assert version == default_version


def test_get_db_version__version_table_does_not_exist(db_path: Path):
with pytest.raises(RuntimeError, match='^The version value is not found in database .*test.db$'):
get_db_version(db_path)


def test_get_db_version__version_table_does_not_exist_default_version(db_path: Path):
default_version = 123
version = get_db_version(db_path, default=default_version)
assert version == default_version


def test_get_db_version(db_path: Path):
with sqlite3.connect(db_path) as connection:
connection.execute('create table MiscData(name text primary key, value text)')
connection.execute("insert into MiscData(name, value) values('db_version', '100')")

version = get_db_version(db_path)
assert version == 100

version = get_db_version(db_path, default=99)
assert version == 100

version = get_db_version(db_path, default=101)
assert version == 100


def test_get_db_version__corrupted_db(tmp_path: Path):
db_path = tmp_path / 'test.db'

connection = sqlite3.connect(db_path)
connection.execute('create table MiscData(name text primary key, value text)')
connection.execute("insert into MiscData(name, value) values('db_version', '100')")
connection.commit()
assert table_exists(connection.cursor(), 'MiscData')
connection.close()

marker_path = Path(str(db_path) + '.is_corrupted')
marker_path.touch()

default_version = 10
version = get_db_version(db_path, default=default_version)
assert version == default_version

with sqlite3.connect(db_path) as connection:
assert not table_exists(connection.cursor(), 'MiscData')


def test_mark_db_as_corrupted_file_does_not_exist(tmp_path: Path):
db_path = tmp_path / 'doesnotexist.db'
with pytest.raises(RuntimeError, match='^Corrupted database file not found: .*doesnotexist.db$'):
_mark_db_as_corrupted(db_path)

0 comments on commit 394d7f6

Please sign in to comment.