From 5bdf408e1959386df50489b867d6e2b2188a9dd2 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 16 Oct 2023 14:53:01 +0200 Subject: [PATCH] Move get_db_version to pony_utils; add tests for get_db_version --- src/tribler/core/upgrade/db8_to_db10.py | 31 +------ src/tribler/core/upgrade/upgrade.py | 3 +- src/tribler/core/utilities/pony_utils.py | 33 +++++++- .../core/utilities/tests/test_pony_utils.py | 82 +++++++++++++++++-- 4 files changed, 111 insertions(+), 38 deletions(-) diff --git a/src/tribler/core/upgrade/db8_to_db10.py b/src/tribler/core/upgrade/db8_to_db10.py index c46b5fe1e16..9151a4075cd 100644 --- a/src/tribler/core/upgrade/db8_to_db10.py +++ b/src/tribler/core/upgrade/db8_to_db10.py @@ -8,7 +8,7 @@ from pony.orm import db_session from tribler.core.components.metadata_store.db.store import MetadataStore -from tribler.core.utilities.pony_utils import handle_db_if_corrupted, marking_corrupted_db +from tribler.core.utilities.pony_utils import marking_corrupted_db TABLE_NAMES = ( "ChannelNode", "TorrentState", "TorrentState_TrackerState", "ChannelPeer", "ChannelVote", "TrackerState", "Vsids") @@ -240,32 +240,3 @@ def get_table_columns(db_path, table_name): cursor.execute(f'SELECT * FROM {table_name} LIMIT 1') names = [description[0] for description in cursor.description] return names - - -def table_exists(cursor: sqlite3.Cursor, table_name: str) -> bool: - cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,)) - return cursor.fetchone() is not None - - -def get_db_version(db_path, default: int = None) -> int: - handle_db_if_corrupted(db_path) - if not db_path.exists(): - version = None - else: - with marking_corrupted_db(db_path): - with contextlib.closing(sqlite3.connect(db_path)) as connection, connection: - cursor = connection.cursor() - if not table_exists(cursor, 'MiscData'): - version = None - else: - cursor.execute("SELECT value FROM MiscData WHERE name == 'db_version'") - row = cursor.fetchone() - version = int(row[0]) if row else None - - if version is not None: - return version - - if default is not None: - return default - - raise RuntimeError(f'The version value is not found in database {db_path}') diff --git a/src/tribler/core/upgrade/upgrade.py b/src/tribler/core/upgrade/upgrade.py index d4ecb6c19dd..e8dbabf4271 100644 --- a/src/tribler/core/upgrade/upgrade.py +++ b/src/tribler/core/upgrade/upgrade.py @@ -18,11 +18,12 @@ sql_create_partial_index_torrentstate_last_check, ) from tribler.core.upgrade.config_converter import convert_config_to_tribler76 -from tribler.core.upgrade.db8_to_db10 import PonyToPonyMigration, get_db_version +from tribler.core.upgrade.db8_to_db10 import PonyToPonyMigration from tribler.core.upgrade.tags_to_knowledge.migration import MigrationTagsToKnowledge from tribler.core.upgrade.tags_to_knowledge.tags_db import TagDatabase from tribler.core.utilities.configparser import CallbackConfigParser from tribler.core.utilities.path_util import Path +from tribler.core.utilities.pony_utils import get_db_version from tribler.core.utilities.simpledefs import STATEDIR_CHANNELS_DIR, STATEDIR_DB_DIR diff --git a/src/tribler/core/utilities/pony_utils.py b/src/tribler/core/utilities/pony_utils.py index e1eace1feea..797ed8e2da5 100644 --- a/src/tribler/core/utilities/pony_utils.py +++ b/src/tribler/core/utilities/pony_utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib import logging import sqlite3 import sys @@ -37,6 +38,36 @@ class DatabaseIsCorrupted(Exception): pass +def table_exists(cursor: sqlite3.Cursor, table_name: str) -> bool: + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,)) + return cursor.fetchone() is not None + + +def get_db_version(db_path, default: int = None) -> int: + handle_db_if_corrupted(db_path) + if not db_path.exists(): + version = None + else: + with marking_corrupted_db(db_path): + with contextlib.closing(sqlite3.connect(db_path)) as connection: + with connection: + cursor = connection.cursor() + if not table_exists(cursor, 'MiscData'): + version = None + else: + cursor.execute("SELECT value FROM MiscData WHERE name == 'db_version'") + row = cursor.fetchone() + version = int(row[0]) if row else None + + if version is not None: + return version + + if default is not None: + return default + + raise RuntimeError(f'The version value is not found in database {db_path}') + + def handle_db_if_corrupted(db_filename: Union[str, Path]): db_path = Path(db_filename) marker_path = _get_corrupted_db_marker_path(db_path) @@ -77,7 +108,7 @@ def _is_malformed_db_exception(exception): def _mark_db_as_corrupted(db_filename: Path): if not db_filename.exists(): - raise RuntimeError(f'Corrupted database file not found: {db_filename!r}') + raise RuntimeError(f'Corrupted database file not found: {db_filename}') marker_path = _get_corrupted_db_marker_path(db_filename) marker_path.touch() diff --git a/src/tribler/core/utilities/tests/test_pony_utils.py b/src/tribler/core/utilities/tests/test_pony_utils.py index 5995dcc1500..3f79adc542f 100644 --- a/src/tribler/core/utilities/tests/test_pony_utils.py +++ b/src/tribler/core/utilities/tests/test_pony_utils.py @@ -6,7 +6,9 @@ from pony.orm.core import QueryStat, Required from tribler.core.utilities import pony_utils -from tribler.core.utilities.pony_utils import DatabaseIsCorrupted, handle_db_if_corrupted, marking_corrupted_db +from tribler.core.utilities.pony_utils import DatabaseIsCorrupted, _mark_db_as_corrupted, get_db_version, \ + handle_db_if_corrupted, \ + marking_corrupted_db, table_exists EMPTY_DICT = {} @@ -125,19 +127,19 @@ def test_format_warning(): @pytest.fixture(name='db_path') -def db_path_fixture(tmp_path): +def db_path_fixture(tmp_path: Path): db_path = tmp_path / 'test.db' db_path.touch() return db_path @patch('tribler.core.utilities.pony_utils._handle_corrupted_db') -def test_handle_db_if_corrupted__not_corrupted(handle_corrupted_db: Mock, db_path): +def test_handle_db_if_corrupted__not_corrupted(handle_corrupted_db: Mock, db_path: Path): handle_db_if_corrupted(db_path) handle_corrupted_db.assert_not_called() -def test_handle_db_if_corrupted__corrupted(db_path): +def test_handle_db_if_corrupted__corrupted(db_path: Path): marker_path = Path(str(db_path) + '.is_corrupted') marker_path.touch() @@ -146,7 +148,7 @@ def test_handle_db_if_corrupted__corrupted(db_path): assert not marker_path.exists() -def test_marking_corrupted_db__not_malformed(db_path): +def test_marking_corrupted_db__not_malformed(db_path: Path): with pytest.raises(ZeroDivisionError): with marking_corrupted_db(db_path): raise ZeroDivisionError() @@ -154,9 +156,77 @@ def test_marking_corrupted_db__not_malformed(db_path): assert not Path(str(db_path) + '.is_corrupted').exists() -def test_marking_corrupted_db__malformed(db_path): +def test_marking_corrupted_db__malformed(db_path: Path): with pytest.raises(DatabaseIsCorrupted): with marking_corrupted_db(db_path): raise sqlite3.DatabaseError('database disk image is malformed') assert Path(str(db_path) + '.is_corrupted').exists() + + +def test_get_db_version__db_does_not_exist(tmp_path: Path): + db_path = tmp_path / 'doesnotexist.db' + + with pytest.raises(RuntimeError, match='^The version value is not found in database .*doesnotexist.db$'): + get_db_version(db_path) + + +def test_get_db_version__db_does_not_exist_default_version(tmp_path: Path): + db_path = tmp_path / 'doesnotexist.db' + + default_version = 123 + version = get_db_version(db_path, default=default_version) + assert version == default_version + + +def test_get_db_version__version_table_does_not_exist(db_path: Path): + with pytest.raises(RuntimeError, match='^The version value is not found in database .*test.db$'): + get_db_version(db_path) + + +def test_get_db_version__version_table_does_not_exist_default_version(db_path: Path): + default_version = 123 + version = get_db_version(db_path, default=default_version) + assert version == default_version + + +def test_get_db_version(db_path: Path): + with sqlite3.connect(db_path) as connection: + connection.execute('create table MiscData(name text primary key, value text)') + connection.execute("insert into MiscData(name, value) values('db_version', '100')") + + version = get_db_version(db_path) + assert version == 100 + + version = get_db_version(db_path, default=99) + assert version == 100 + + version = get_db_version(db_path, default=101) + assert version == 100 + + +def test_get_db_version__corrupted_db(tmp_path: Path): + db_path = tmp_path / 'test.db' + + connection = sqlite3.connect(db_path) + connection.execute('create table MiscData(name text primary key, value text)') + connection.execute("insert into MiscData(name, value) values('db_version', '100')") + connection.commit() + assert table_exists(connection.cursor(), 'MiscData') + connection.close() + + marker_path = Path(str(db_path) + '.is_corrupted') + marker_path.touch() + + default_version = 10 + version = get_db_version(db_path, default=default_version) + assert version == default_version + + with sqlite3.connect(db_path) as connection: + assert not table_exists(connection.cursor(), 'MiscData') + + +def test_mark_db_as_corrupted_file_does_not_exist(tmp_path: Path): + db_path = tmp_path / 'doesnotexist.db' + with pytest.raises(RuntimeError, match='^Corrupted database file not found: .*doesnotexist.db$'): + _mark_db_as_corrupted(db_path)