diff --git a/src/tribler/core/components/bandwidth_accounting/db/database.py b/src/tribler/core/components/bandwidth_accounting/db/database.py index fcc6435eded..6594b95c9b7 100644 --- a/src/tribler/core/components/bandwidth_accounting/db/database.py +++ b/src/tribler/core/components/bandwidth_accounting/db/database.py @@ -5,6 +5,7 @@ from tribler.core.components.bandwidth_accounting.db import history, misc, transaction as db_transaction from tribler.core.components.bandwidth_accounting.db.transaction import BandwidthTransactionData +from tribler.core.utilities.db_corruption_handling.base import handle_db_if_corrupted from tribler.core.utilities.pony_utils import TriblerDatabase from tribler.core.utilities.utilities import MEMORY_DB @@ -34,7 +35,7 @@ def __init__(self, db_path: Union[Path, type(MEMORY_DB)], my_pub_key: bytes, # with the static analysis. # pylint: disable=unused-variable - @self.database.on_connect(provider='sqlite') + @self.database.on_connect def sqlite_sync_pragmas(_, connection): cursor = connection.cursor() cursor.execute("PRAGMA journal_mode = WAL") @@ -50,6 +51,8 @@ def sqlite_sync_pragmas(_, connection): create_db = True db_path_string = ":memory:" else: + # We need to handle the database corruption case before determining the state of the create_db flag. + handle_db_if_corrupted(db_path) create_db = not db_path.is_file() db_path_string = str(db_path) diff --git a/src/tribler/core/components/component.py b/src/tribler/core/components/component.py index 7a10dcee4f3..513fa737526 100644 --- a/src/tribler/core/components/component.py +++ b/src/tribler/core/components/component.py @@ -9,6 +9,9 @@ from tribler.core.components.exceptions import ComponentStartupException, MissedDependency, NoneComponent from tribler.core.components.reporter.exception_handler import default_core_exception_handler from tribler.core.sentry_reporter.sentry_reporter import SentryReporter +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted +from tribler.core.utilities.exit_codes import EXITCODE_DATABASE_IS_CORRUPTED +from tribler.core.utilities.process_manager import get_global_process_manager if TYPE_CHECKING: from tribler.core.components.session import Session, T @@ -47,8 +50,26 @@ async def start(self): self._set_component_status(msg, logging.ERROR, exc_info=exc_info) self.failed = True self.started_event.set() + + if isinstance(e, DatabaseIsCorrupted): + # When the database corruption is detected, we should stop the process immediately. + # Tribler GUI will restart the process and the database will be recreated. + + # Usually we wrap an exception into ComponentStartupException, and allow + # CoreExceptionHandler.unhandled_error_observer to handle it after all components are started, + # but in this case we don't do it. The reason is that handling ComponentStartupException + # starts the shutting down of Tribler, and due to some obscure reasons it is not possible to + # raise any exception, even SystemExit, from CoreExceptionHandler.unhandled_error_observer when + # Tribler is shutting down. It looks like in this case unhandled_error_observer is called from + # Task.__del__ method and all exceptions that are raised from __del__ are ignored. + # See https://bugs.python.org/issue25489 for similar case. 
+ process_manager = get_global_process_manager() + process_manager.sys_exit(EXITCODE_DATABASE_IS_CORRUPTED, e) + return # Added for clarity; actually, the code raised SystemExit on the previous line + if self.session.failfast: raise e + self.session.set_startup_exception(ComponentStartupException(self, e)) self.started_event.set() diff --git a/src/tribler/core/components/gigachannel_manager/gigachannel_manager.py b/src/tribler/core/components/gigachannel_manager/gigachannel_manager.py index 0db52720f71..529b6048e62 100644 --- a/src/tribler/core/components/gigachannel_manager/gigachannel_manager.py +++ b/src/tribler/core/components/gigachannel_manager/gigachannel_manager.py @@ -12,6 +12,7 @@ from tribler.core.components.metadata_store.db.orm_bindings.channel_node import COMMITTED from tribler.core.components.metadata_store.db.serialization import CHANNEL_TORRENT from tribler.core.components.metadata_store.db.store import MetadataStore +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted from tribler.core.utilities.notifier import Notifier from tribler.core.utilities.pony_utils import run_threaded from tribler.core.utilities.simpledefs import DownloadStatus @@ -80,6 +81,8 @@ async def check_and_regen_personal_channels(self): channel.id_, channel_download, ) + except DatabaseIsCorrupted: + raise # re-raise this exception and terminate the Core process except Exception: self._logger.exception("Error when tried to resume personal channel seeding on GigaChannel Manager startup") diff --git a/src/tribler/core/components/knowledge/db/knowledge_db.py b/src/tribler/core/components/knowledge/db/knowledge_db.py index 4950cd1a1ee..4feab9994c5 100644 --- a/src/tribler/core/components/knowledge/db/knowledge_db.py +++ b/src/tribler/core/components/knowledge/db/knowledge_db.py @@ -66,7 +66,7 @@ class KnowledgeDatabase: def __init__(self, filename: Optional[str] = None, *, create_tables: bool = True, **generate_mapping_kwargs): self.instance = TriblerDatabase() self.define_binding(self.instance) - self.instance.bind('sqlite', filename or ':memory:', create_db=True) + self.instance.bind(provider='sqlite', filename=filename or ':memory:', create_db=True) generate_mapping_kwargs['create_tables'] = create_tables self.instance.generate_mapping(**generate_mapping_kwargs) self.logger = logging.getLogger(self.__class__.__name__) diff --git a/src/tribler/core/components/metadata_store/db/store.py b/src/tribler/core/components/metadata_store/db/store.py index 7f8e79e696b..b8addf25816 100644 --- a/src/tribler/core/components/metadata_store/db/store.py +++ b/src/tribler/core/components/metadata_store/db/store.py @@ -46,6 +46,7 @@ from tribler.core.components.metadata_store.remote_query_community.payload_checker import process_payload from tribler.core.components.torrent_checker.torrent_checker.dataclasses import HealthInfo from tribler.core.exceptions import InvalidSignatureException +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted, handle_db_if_corrupted from tribler.core.utilities.notifier import Notifier from tribler.core.utilities.path_util import Path from tribler.core.utilities.pony_utils import TriblerDatabase, get_max, get_or_create, run_threaded @@ -166,7 +167,7 @@ def __init__( # This attribute is internally called by Pony on startup, though pylint cannot detect it # with the static analysis. 
# pylint: disable=unused-variable - @self.db.on_connect(provider='sqlite') + @self.db.on_connect def on_connect(_, connection): cursor = connection.cursor() cursor.execute("PRAGMA journal_mode = WAL") @@ -218,6 +219,8 @@ def on_connect(_, connection): create_db = True db_path_string = ":memory:" else: + # We need to handle the database corruption case before determining the state of the create_db flag. + handle_db_if_corrupted(db_filename) create_db = not db_filename.is_file() db_path_string = str(db_filename) @@ -450,9 +453,12 @@ def process_mdblob_file(self, filepath, **kwargs): async def process_compressed_mdblob_threaded(self, compressed_data, **kwargs): try: return await run_threaded(self.db, self.process_compressed_mdblob, compressed_data, **kwargs) + except DatabaseIsCorrupted: + raise # re-raise this exception and terminate the Core process except Exception as e: # pylint: disable=broad-except # pragma: no cover - self._logger.warning("DB transaction error when tried to process compressed mdblob: %s", str(e)) - return None + self._logger.exception("DB transaction error when tried to process compressed mdblob: " + f"{e.__class__.__name__}: {e}", exc_info=e) + return [] def process_compressed_mdblob(self, compressed_data, **kwargs): try: diff --git a/src/tribler/core/components/reporter/exception_handler.py b/src/tribler/core/components/reporter/exception_handler.py index 1b4d4a0be44..018dcc67f3d 100644 --- a/src/tribler/core/components/reporter/exception_handler.py +++ b/src/tribler/core/components/reporter/exception_handler.py @@ -10,6 +10,8 @@ from tribler.core.components.exceptions import ComponentStartupException from tribler.core.components.reporter.reported_error import ReportedError from tribler.core.sentry_reporter.sentry_reporter import SentryReporter +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted +from tribler.core.utilities.exit_codes import EXITCODE_DATABASE_IS_CORRUPTED from tribler.core.utilities.process_manager import get_global_process_manager # There are some errors that we are ignoring. 
@@ -93,12 +95,18 @@ def unhandled_error_observer(self, _, context): should_stop = context.pop('should_stop', True) message = context.pop('message', 'no message') exception = context.pop('exception', None) or self._create_exception_from(message) - # Exception - text = str(exception) + + self.logger.exception(f'{exception.__class__.__name__}: {exception}', exc_info=exception) + + if isinstance(exception, DatabaseIsCorrupted): + process_manager.sys_exit(EXITCODE_DATABASE_IS_CORRUPTED, exception) + return # Added for clarity; actually, the code raised SystemExit on the previous line + if isinstance(exception, ComponentStartupException): self.logger.info('The exception is ComponentStartupException') should_stop = exception.component.tribler_should_stop_on_component_error exception = exception.__cause__ + if isinstance(exception, NoCrashException): self.logger.info('The exception is NoCrashException') should_stop = False @@ -113,7 +121,7 @@ def unhandled_error_observer(self, _, context): reported_error = ReportedError( type=exception.__class__.__name__, - text=text, + text=str(exception), long_text=long_text, context=str(context), event=self.sentry_reporter.event_from_exception(exception) or {}, diff --git a/src/tribler/core/components/reporter/tests/test_exception_handler.py b/src/tribler/core/components/reporter/tests/test_exception_handler.py index 41bf79779b7..3b2ecb6cf38 100644 --- a/src/tribler/core/components/reporter/tests/test_exception_handler.py +++ b/src/tribler/core/components/reporter/tests/test_exception_handler.py @@ -7,6 +7,7 @@ from tribler.core.components.reporter.exception_handler import CoreExceptionHandler from tribler.core.sentry_reporter import sentry_reporter from tribler.core.sentry_reporter.sentry_reporter import SentryReporter +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted # pylint: disable=protected-access, redefined-outer-name @@ -85,6 +86,17 @@ def test_unhandled_error_observer_exception(exception_handler): assert reported_error.should_stop +@patch('tribler.core.components.reporter.exception_handler.get_global_process_manager') +def test_unhandled_error_observer_database_corrupted(get_global_process_manager, exception_handler): + # test that database corruption exception reported to the GUI + exception = DatabaseIsCorrupted('db_path_string') + exception_handler.report_callback = MagicMock() + exception_handler.unhandled_error_observer(None, {'exception': exception}) + + get_global_process_manager().sys_exit.assert_called_once_with(99, exception) + exception_handler.report_callback.assert_not_called() + + def test_unhandled_error_observer_only_message(exception_handler): # test that unhandled exception, represented by message, reported to the GUI context = {'message': 'Any'} diff --git a/src/tribler/core/components/restapi/rest/rest_manager.py b/src/tribler/core/components/restapi/rest/rest_manager.py index 413e41ac756..6f76657d2b0 100644 --- a/src/tribler/core/components/restapi/rest/rest_manager.py +++ b/src/tribler/core/components/restapi/rest/rest_manager.py @@ -67,9 +67,7 @@ async def error_middleware(request, handler): 'message': f'Request size is larger than {MAX_REQUEST_SIZE} bytes' }}, status=HTTP_REQUEST_ENTITY_TOO_LARGE) except Exception as e: - logger.exception(e) full_exception = traceback.format_exc() - default_core_exception_handler.unhandled_error_observer(None, {'exception': e, 'should_stop': False}) return RESTResponse({"error": { diff --git a/src/tribler/core/components/tests/test_base_component.py 
b/src/tribler/core/components/tests/test_base_component.py index 49cfd7a0740..4f3bf0683b2 100644 --- a/src/tribler/core/components/tests/test_base_component.py +++ b/src/tribler/core/components/tests/test_base_component.py @@ -1,9 +1,12 @@ +from unittest.mock import patch + import pytest from tribler.core.components.component import Component from tribler.core.components.exceptions import MissedDependency, MultipleComponentsFound, NoneComponent from tribler.core.components.session import Session from tribler.core.config.tribler_config import TriblerConfig +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted class ComponentTestException(Exception): @@ -46,6 +49,20 @@ class TestComponentB(TestComponent): assert component.stopped +@patch('tribler.core.components.component.get_global_process_manager') +async def test_session_start_database_corruption_detected(get_global_process_manager): + exception = DatabaseIsCorrupted('db_path_string') + + class TestComponent(Component): + async def run(self): + raise exception + + component = TestComponent() + + await component.start() + get_global_process_manager().sys_exit.assert_called_once_with(99, exception) + + class ComponentA(Component): pass diff --git a/src/tribler/core/upgrade/db8_to_db10.py b/src/tribler/core/upgrade/db8_to_db10.py index a6ee17ae250..68534b56456 100644 --- a/src/tribler/core/upgrade/db8_to_db10.py +++ b/src/tribler/core/upgrade/db8_to_db10.py @@ -1,13 +1,13 @@ import contextlib import datetime import logging -import sqlite3 from collections import deque from time import time as now from pony.orm import db_session from tribler.core.components.metadata_store.db.store import MetadataStore +from tribler.core.utilities.db_corruption_handling import sqlite_replacement TABLE_NAMES = ( "ChannelNode", "TorrentState", "TorrentState_TrackerState", "ChannelPeer", "ChannelVote", "TrackerState", "Vsids") @@ -126,31 +126,31 @@ def convert_command(offset, batch_size): def do_migration(self): result = None # estimated duration in seconds of ChannelNode table copying time try: - old_table_columns = {} for table_name in TABLE_NAMES: old_table_columns[table_name] = get_table_columns(self.old_db_path, table_name) - with contextlib.closing(sqlite3.connect(self.new_db_path)) as connection, connection: - cursor = connection.cursor() - cursor.execute("PRAGMA journal_mode = OFF;") - cursor.execute("PRAGMA synchronous = OFF;") - cursor.execute("PRAGMA foreign_keys = OFF;") - cursor.execute("PRAGMA temp_store = MEMORY;") - cursor.execute("PRAGMA cache_size = -204800;") - cursor.execute(f'ATTACH DATABASE "{self.old_db_path}" as old_db;') - - for table_name in TABLE_NAMES: - t1 = now() - cursor.execute("BEGIN TRANSACTION;") - if not self.must_shutdown(): - self.convert_table(cursor, table_name, old_table_columns[table_name]) - cursor.execute("COMMIT;") - duration = now() - t1 - self._logger.info(f"Upgrade: copied table {table_name} in {duration:.2f} seconds") - - if table_name == 'ChannelNode': - result = duration + with contextlib.closing(sqlite_replacement.connect(self.new_db_path)) as connection: + with connection: + cursor = connection.cursor() + cursor.execute("PRAGMA journal_mode = OFF;") + cursor.execute("PRAGMA synchronous = OFF;") + cursor.execute("PRAGMA foreign_keys = OFF;") + cursor.execute("PRAGMA temp_store = MEMORY;") + cursor.execute("PRAGMA cache_size = -204800;") + cursor.execute(f'ATTACH DATABASE "{self.old_db_path}" as old_db;') + + for table_name in TABLE_NAMES: + t1 = now() + cursor.execute("BEGIN 
TRANSACTION;") + if not self.must_shutdown(): + self.convert_table(cursor, table_name, old_table_columns[table_name]) + cursor.execute("COMMIT;") + duration = now() - t1 + self._logger.info(f"Upgrade: copied table {table_name} in {duration:.2f} seconds") + + if table_name == 'ChannelNode': + result = duration self.update_status("Synchronizing the upgraded DB to disk, please wait.") except Exception as e: @@ -234,16 +234,8 @@ def calc_progress(duration_now, duration_half=60.0): def get_table_columns(db_path, table_name): - with contextlib.closing(sqlite3.connect(db_path)) as connection, connection: + with contextlib.closing(sqlite_replacement.connect(db_path)) as connection, connection: cursor = connection.cursor() cursor.execute(f'SELECT * FROM {table_name} LIMIT 1') names = [description[0] for description in cursor.description] return names - - -def get_db_version(db_path): - with contextlib.closing(sqlite3.connect(db_path)) as connection, connection: - cursor = connection.cursor() - cursor.execute('SELECT value FROM MiscData WHERE name == "db_version"') - version = int(cursor.fetchone()[0]) - return version diff --git a/src/tribler/core/upgrade/tags_to_knowledge/tags_db.py b/src/tribler/core/upgrade/tags_to_knowledge/tags_db.py index f9890118e60..7f00ab93742 100644 --- a/src/tribler/core/upgrade/tags_to_knowledge/tags_db.py +++ b/src/tribler/core/upgrade/tags_to_knowledge/tags_db.py @@ -10,7 +10,7 @@ class TagDatabase: def __init__(self, filename: Optional[str] = None, *, create_tables: bool = True, **generate_mapping_kwargs): self.instance = TriblerDatabase() self.define_binding(self.instance) - self.instance.bind('sqlite', filename or ':memory:', create_db=True) + self.instance.bind(provider='sqlite', filename=filename or ':memory:', create_db=True) generate_mapping_kwargs['create_tables'] = create_tables self.instance.generate_mapping(**generate_mapping_kwargs) diff --git a/src/tribler/core/upgrade/tests/test_upgrader.py b/src/tribler/core/upgrade/tests/test_upgrader.py index e6b1a369c92..91d7bb6ab66 100644 --- a/src/tribler/core/upgrade/tests/test_upgrader.py +++ b/src/tribler/core/upgrade/tests/test_upgrader.py @@ -3,7 +3,7 @@ import time from pathlib import Path from typing import Set -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest from ipv8.keyvault.private.libnaclkey import LibNaCLSK @@ -15,8 +15,10 @@ from tribler.core.tests.tools.common import TESTS_DATA_DIR from tribler.core.upgrade.db8_to_db10 import calc_progress from tribler.core.upgrade.tags_to_knowledge.tags_db import TagDatabase -from tribler.core.upgrade.upgrade import TriblerUpgrader, cleanup_noncompliant_channel_torrents +from tribler.core.upgrade.upgrade import TriblerUpgrader, catch_db_is_corrupted_exception, \ + cleanup_noncompliant_channel_torrents from tribler.core.utilities.configparser import CallbackConfigParser +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted from tribler.core.utilities.utilities import random_infohash @@ -55,6 +57,35 @@ def _copy(source_name, target): shutil.copyfile(source, target) +def test_catch_db_is_corrupted_exception_with_exception(): + upgrader = Mock(_db_is_corrupted_exception=None) + upgrader_method = Mock(side_effect=DatabaseIsCorrupted()) + decorated_method = catch_db_is_corrupted_exception(upgrader_method) + + # Call the decorated method and expect it to catch the exception + decorated_method(upgrader) + upgrader_method.assert_called_once() + + # Check if the exception was caught and stored + 
assert isinstance(upgrader._db_is_corrupted_exception, DatabaseIsCorrupted) + upgrader._logger.exception.assert_called_once() + + +def test_catch_db_is_corrupted_exception_without_exception(): + upgrader = Mock(_db_is_corrupted_exception=None) + upgrader_method = Mock() + decorated_method = catch_db_is_corrupted_exception(upgrader_method) + + # Call the decorated method and expect it to run without exceptions + decorated_method(upgrader) + + # Check if the method was called and no exception was stored + upgrader_method.assert_called_once() + assert upgrader._db_is_corrupted_exception is None + upgrader._logger.exception.assert_not_called() + + def test_upgrade_pony_db_complete(upgrader, channels_dir, state_dir, trustchain_keypair, mds_path): # pylint: disable=W0621 """ diff --git a/src/tribler/core/upgrade/upgrade.py b/src/tribler/core/upgrade/upgrade.py index 3d490e1d537..0a8d4e9f478 100644 --- a/src/tribler/core/upgrade/upgrade.py +++ b/src/tribler/core/upgrade/upgrade.py @@ -3,6 +3,7 @@ import shutil import time from configparser import MissingSectionHeaderError, ParsingError +from functools import wraps from types import SimpleNamespace from typing import List, Optional, Tuple @@ -12,17 +13,19 @@ from tribler.core.components.bandwidth_accounting.db.database import BandwidthDatabase from tribler.core.components.metadata_store.db.orm_bindings.channel_metadata import CHANNEL_DIR_NAME_LENGTH from tribler.core.components.metadata_store.db.store import ( - MetadataStore, + CURRENT_DB_VERSION, MetadataStore, sql_create_partial_index_channelnode_metadata_type, sql_create_partial_index_channelnode_subscribed, sql_create_partial_index_torrentstate_last_check, ) from tribler.core.upgrade.config_converter import convert_config_to_tribler76 -from tribler.core.upgrade.db8_to_db10 import PonyToPonyMigration, get_db_version +from tribler.core.upgrade.db8_to_db10 import PonyToPonyMigration from tribler.core.upgrade.tags_to_knowledge.migration import MigrationTagsToKnowledge from tribler.core.upgrade.tags_to_knowledge.tags_db import TagDatabase from tribler.core.utilities.configparser import CallbackConfigParser +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted from tribler.core.utilities.path_util import Path +from tribler.core.utilities.pony_utils import get_db_version from tribler.core.utilities.simpledefs import STATEDIR_CHANNELS_DIR, STATEDIR_DB_DIR @@ -68,6 +71,35 @@ def cleanup_noncompliant_channel_torrents(state_dir): file_path) +def catch_db_is_corrupted_exception(upgrader_method): + # This decorator is applied to TriblerUpgrader methods. It suppresses and remembers the DatabaseIsCorrupted exception. + # As a result, if one upgrade method raises an exception, the following upgrade methods are still executed. + # + # The reason for this is the following: it is possible that one upgrade method upgrades database A + # while the following upgrade method upgrades database B. If a corruption is detected in the database A, + # the database B still needs to be upgraded. So, we want to temporarily suppress the DatabaseIsCorrupted exception + # until all upgrades are executed. + # + # If an upgrade finds the database to be corrupted, the database is marked as corrupted. Then, the next upgrade + # will remove the corrupted database file (the get_db_version call handles this) and immediately return because + # there is no database to upgrade.
So, if one upgrade function detects database corruption, all the following + # upgrade functions for this specific database will skip the actual upgrade. As a result, a new database with + # the current DB version will be created on the Tribler Core start. + + @wraps(upgrader_method) + def new_method(*args, **kwargs): + try: + upgrader_method(*args, **kwargs) + except DatabaseIsCorrupted as exc: + self: TriblerUpgrader = args[0] + self._logger.exception(exc) + + if not self._db_is_corrupted_exception: + self._db_is_corrupted_exception = exc # Suppress and remember the exception to re-raise it later + + return new_method + + class TriblerUpgrader: def __init__(self, state_dir: Path, channels_dir: Path, primary_key: LibNaCLSK, secondary_key: Optional[LibNaCLSK], @@ -83,6 +115,7 @@ def __init__(self, state_dir: Path, channels_dir: Path, primary_key: LibNaCLSK, self.failed = True self._pony2pony = None + self._db_is_corrupted_exception: Optional[DatabaseIsCorrupted] = None @property def shutting_down(self): @@ -105,6 +138,12 @@ def run(self): self.remove_old_logs() self.upgrade_pony_db_14to15() + if self._db_is_corrupted_exception: + # The current code is executed in the worker's thread. After all upgrade methods are executed, + # we re-raise the delayed exception, and then it is received and handled in the main thread + # by the UpgradeManager.on_worker_finished signal handler. + raise self._db_is_corrupted_exception # pylint: disable=raising-bad-type + def remove_old_logs(self) -> Tuple[List[Path], List[Path]]: self._logger.info(f'Remove old logs') @@ -126,14 +165,19 @@ def remove_old_logs(self) -> Tuple[List[Path], List[Path]]: return removed_files, left_files + @catch_db_is_corrupted_exception def upgrade_tags_to_knowledge(self): self._logger.info('Upgrade tags to knowledge') migration = MigrationTagsToKnowledge(self.state_dir, self.secondary_key) migration.run() + @catch_db_is_corrupted_exception def upgrade_pony_db_14to15(self): self._logger.info('Upgrade Pony DB from version 14 to version 15') mds_path = self.state_dir / STATEDIR_DB_DIR / 'metadata.db' + if not mds_path.exists() or get_db_version(mds_path, CURRENT_DB_VERSION) > 14: + # No need to update if the database does not exist or is already updated + return # pragma: no cover mds = MetadataStore(mds_path, self.channels_dir, self.primary_key, disable_sync=True, check_tables=False, db_version=14) if mds_path.exists() else None @@ -142,11 +186,16 @@ def upgrade_pony_db_14to15(self): if mds: mds.shutdown() + @catch_db_is_corrupted_exception def upgrade_pony_db_13to14(self): self._logger.info('Upgrade Pony DB from version 13 to version 14') mds_path = self.state_dir / STATEDIR_DB_DIR / 'metadata.db' tagdb_path = self.state_dir / STATEDIR_DB_DIR / 'tags.db' + if not mds_path.exists() or get_db_version(mds_path, CURRENT_DB_VERSION) > 13: + # No need to update if the database does not exist or is already updated + return # pragma: no cover + mds = MetadataStore(mds_path, self.channels_dir, self.primary_key, disable_sync=True, check_tables=False, db_version=13) if mds_path.exists() else None tag_db = TagDatabase(str(tagdb_path), create_tables=False, @@ -158,6 +207,7 @@ def upgrade_pony_db_13to14(self): if tag_db: tag_db.shutdown() + @catch_db_is_corrupted_exception def upgrade_pony_db_12to13(self): """ Upgrade GigaChannel DB from version 12 (7.9.x) to version 13 (7.11.x). 
@@ -166,12 +216,16 @@ def upgrade_pony_db_12to13(self): self._logger.info('Upgrade Pony DB 12 to 13') # We have to create the Metadata Store object because Session-managed Store has not been started yet database_path = self.state_dir / STATEDIR_DB_DIR / 'metadata.db' - if database_path.exists(): - mds = MetadataStore(database_path, self.channels_dir, self.primary_key, - disable_sync=True, check_tables=False, db_version=12) - self.do_upgrade_pony_db_12to13(mds) - mds.shutdown() + if not database_path.exists() or get_db_version(database_path, CURRENT_DB_VERSION) > 12: + # No need to update if the database does not exist or is already updated + return # pragma: no cover + + mds = MetadataStore(database_path, self.channels_dir, self.primary_key, + disable_sync=True, check_tables=False, db_version=12) + self.do_upgrade_pony_db_12to13(mds) + mds.shutdown() + @catch_db_is_corrupted_exception def upgrade_pony_db_11to12(self): """ Upgrade GigaChannel DB from version 11 (7.8.x) to version 12 (7.9.x). @@ -181,13 +235,16 @@ def upgrade_pony_db_11to12(self): self._logger.info('Upgrade Pony DB 11 to 12') # We have to create the Metadata Store object because Session-managed Store has not been started yet database_path = self.state_dir / STATEDIR_DB_DIR / 'metadata.db' - if not database_path.exists(): - return + if not database_path.exists() or get_db_version(database_path, CURRENT_DB_VERSION) > 11: + # No need to update if the database does not exist or is already updated + return # pragma: no cover + mds = MetadataStore(database_path, self.channels_dir, self.primary_key, disable_sync=True, check_tables=False, db_version=11) self.do_upgrade_pony_db_11to12(mds) mds.shutdown() + @catch_db_is_corrupted_exception def upgrade_pony_db_10to11(self): """ Upgrade GigaChannel DB from version 10 (7.6.x) to version 11 (7.7.x). @@ -197,14 +254,17 @@ def upgrade_pony_db_10to11(self): self._logger.info('Upgrade Pony DB 10 to 11') # We have to create the Metadata Store object because Session-managed Store has not been started yet database_path = self.state_dir / STATEDIR_DB_DIR / 'metadata.db' - if not database_path.exists(): - return + if not database_path.exists() or get_db_version(database_path, CURRENT_DB_VERSION) > 10: + # No need to update if the database does not exist or is already updated + return # pragma: no cover + # code of the migration mds = MetadataStore(database_path, self.channels_dir, self.primary_key, disable_sync=True, check_tables=False, db_version=10) self.do_upgrade_pony_db_10to11(mds) mds.shutdown() + @catch_db_is_corrupted_exception def upgrade_bw_accounting_db_8to9(self): """ Upgrade the database with bandwidth accounting information from 8 to 9. 
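Each metadata-store upgrade step above now repeats the same version gate. As an illustrative sketch only (not part of the changeset), the shared guard could be expressed as the hypothetical helper below; get_db_version, CURRENT_DB_VERSION, STATEDIR_DB_DIR and MetadataStore are the names already used in upgrade.py:

    def _upgrade_metadata_db_step(self, from_version: int):
        # Hypothetical helper (not in the diff) illustrating the guard shared by the upgrade_pony_db_* methods.
        mds_path = self.state_dir / STATEDIR_DB_DIR / 'metadata.db'
        # get_db_version() calls handle_db_if_corrupted() first, so a database marked as corrupted by an
        # earlier step is discarded here, the default CURRENT_DB_VERSION is returned, and the step is skipped;
        # a fresh database with the current schema is then created when the Core starts.
        if not mds_path.exists() or get_db_version(mds_path, CURRENT_DB_VERSION) > from_version:
            return None
        return MetadataStore(mds_path, self.channels_dir, self.primary_key,
                             disable_sync=True, check_tables=False, db_version=from_version)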
@@ -215,8 +275,10 @@ def upgrade_bw_accounting_db_8to9(self): to_version = 9 database_path = self.state_dir / STATEDIR_DB_DIR / 'bandwidth.db' - if not database_path.exists() or get_db_version(database_path) >= 9: - return # No need to update if the database does not exist or is already updated + if not database_path.exists() or get_db_version(database_path, BandwidthDatabase.CURRENT_DB_VERSION) > 8: + # No need to update if the database does not exist or is already updated + return # pragma: no cover + self._logger.info('bw8->9') db = BandwidthDatabase(database_path, self.primary_key.key.pk) @@ -370,6 +432,7 @@ def do_upgrade_pony_db_10to11(self, mds): db_version = mds.MiscData.get(name="db_version") db_version.value = str(to_version) + @catch_db_is_corrupted_exception def upgrade_pony_db_8to10(self): """ Upgrade GigaChannel DB from version 8 (7.5.x) to version 10 (7.6.x). @@ -377,9 +440,11 @@ def upgrade_pony_db_8to10(self): """ self._logger.info('Upgrading GigaChannel DB from version 8 to 10') database_path = self.state_dir / STATEDIR_DB_DIR / 'metadata.db' - if not database_path.exists() or get_db_version(database_path) >= 10: + + if not database_path.exists() or get_db_version(database_path, CURRENT_DB_VERSION) >= 10: # Either no old db exists, or the old db version is up to date - nothing to do return + self._logger.info('8->10') # Otherwise, start upgrading self.update_status("STARTING") diff --git a/src/tribler/core/upgrade/version_manager.py b/src/tribler/core/upgrade/version_manager.py index ddf40686a79..431d4f4d503 100644 --- a/src/tribler/core/upgrade/version_manager.py +++ b/src/tribler/core/upgrade/version_manager.py @@ -280,9 +280,10 @@ def __init__(self, root_state_dir: Path, code_version_id: Optional[str] = None): code_version.should_be_copied = True code_version.should_recreate_directory = True else: + prev_version_str = code_version.can_be_copied_from.version_str self.logger.info(f"The state directory for the current version {code_version.version_str} " - f"is present, but the version does not listed in version history. Will not copy state " - f"from a previous version {code_version.can_be_copied_from.version_str}") + "is present, but the version is not listed in version history.
" + f"Will not copy state from a previous version {prev_version_str}") else: self.logger.info("Cannot find the previous suitable version to copy state directory") diff --git a/src/tribler/core/utilities/db_corruption_handling/__init__.py b/src/tribler/core/utilities/db_corruption_handling/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/tribler/core/utilities/db_corruption_handling/base.py b/src/tribler/core/utilities/db_corruption_handling/base.py new file mode 100644 index 00000000000..d4016bba6ea --- /dev/null +++ b/src/tribler/core/utilities/db_corruption_handling/base.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import logging +import sqlite3 +from contextlib import contextmanager +from pathlib import Path +from typing import Union + +logger = logging.getLogger('db_corruption_handling') + + +class DatabaseIsCorrupted(Exception): + pass + + +@contextmanager +def handling_malformed_db_error(db_filepath: Path): + # Used in all methods of Connection and Cursor classes where the database corruption error can occur + try: + yield + except Exception as e: + if _is_malformed_db_exception(e): + _mark_db_as_corrupted(db_filepath) + raise DatabaseIsCorrupted(str(db_filepath)) from e + raise + + +def handle_db_if_corrupted(db_filename: Union[str, Path]): + # Checks if the database is marked as corrupted and handles it by removing the database file and the marker file + db_path = Path(db_filename) + marker_path = get_corrupted_db_marker_path(db_path) + if marker_path.exists(): + _handle_corrupted_db(db_path) + + +def get_corrupted_db_marker_path(db_filepath: Path) -> Path: + return Path(str(db_filepath) + '.is_corrupted') + + +def _is_malformed_db_exception(exception): + return isinstance(exception, sqlite3.DatabaseError) and 'malformed' in str(exception) + + +def _mark_db_as_corrupted(db_filepath: Path): + # Creates a new `*.is_corrupted` marker file alongside the database file + marker_path = get_corrupted_db_marker_path(db_filepath) + marker_path.touch() + + +def _handle_corrupted_db(db_path: Path): + # Removes the database file and the marker file + if db_path.exists(): + logger.warning(f'Database file was marked as corrupted, removing it: {db_path}') + db_path.unlink() + + marker_path = get_corrupted_db_marker_path(db_path) + if marker_path.exists(): + logger.warning(f'Removing the corrupted database marker: {marker_path}') + marker_path.unlink() diff --git a/src/tribler/core/utilities/db_corruption_handling/sqlite_replacement.py b/src/tribler/core/utilities/db_corruption_handling/sqlite_replacement.py new file mode 100644 index 00000000000..12d530f32da --- /dev/null +++ b/src/tribler/core/utilities/db_corruption_handling/sqlite_replacement.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import sqlite3 +import sys +from pathlib import Path +from sqlite3 import DataError, DatabaseError, Error, IntegrityError, InterfaceError, InternalError, NotSupportedError, \ + OperationalError, ProgrammingError, Warning, sqlite_version_info # pylint: disable=unused-import, redefined-builtin + +from tribler.core.utilities.db_corruption_handling.base import handling_malformed_db_error + + +# This module serves as a replacement to the sqlite3 module and handles the case when the database is corrupted. +# It provides the `connect` function that should be used instead of `sqlite3.connect` and the `Cursor` and `Connection` +# classes that replaces `sqlite3.Cursor` and `sqlite3.Connection` classes respectively. 
If the `connect` function or + any Connection or Cursor method is called and the database is corrupted, the database is marked as corrupted and + the DatabaseIsCorrupted exception is raised. It should be handled by terminating the Tribler Core with the exit code + EXITCODE_DATABASE_IS_CORRUPTED (99). After the Core restarts, the `handle_db_if_corrupted` function checks the + presence of the database corruption marker and handles it by removing the database file and the corruption marker. + After that, the database is recreated upon the next attempt to connect to it. + + +def connect(db_filename: str, **kwargs) -> sqlite3.Connection: + # Replaces the sqlite3.connect function + kwargs['factory'] = Connection + with handling_malformed_db_error(Path(db_filename)): + return sqlite3.connect(db_filename, **kwargs) + + +def _add_method_wrapper_that_handles_malformed_db_exception(cls, method_name: str): + # Creates a wrapper for the given method that handles the case when the database is corrupted + + def wrapper(self, *args, **kwargs): + with handling_malformed_db_error(self._db_filepath): # pylint: disable=protected-access + return getattr(super(cls, self), method_name)(*args, **kwargs) + + wrapper.__name__ = method_name + wrapper.is_wrapped = True # for testing purposes + setattr(cls, method_name, wrapper) + + +class Cursor(sqlite3.Cursor): + # Handles the case when the database is corrupted in all relevant methods. + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._db_filepath = self.connection._db_filepath + + +for method_name_ in ['execute', 'executemany', 'executescript', 'fetchall', 'fetchmany', 'fetchone', '__next__']: + _add_method_wrapper_that_handles_malformed_db_exception(Cursor, method_name_) + + + +class ConnectionBase(sqlite3.Connection): + # This class simplifies testing of the Connection class by allowing mocking of base class methods. + # Direct mocking of sqlite3.Connection methods is not possible because they are C functions. + + if sys.version_info < (3, 11): + def blobopen(self, *args, **kwargs) -> Blob: + raise NotImplementedError + + +class Connection(ConnectionBase): + # Handles the case when the database is corrupted in all relevant methods. + def __init__(self, db_filepath: str, *args, **kwargs): + super().__init__(db_filepath, *args, **kwargs) + self._db_filepath = Path(db_filepath) + + def cursor(self, factory=None) -> Cursor: + return super().cursor(factory or Cursor) + + def iterdump(self): + # Not implemented because it is not used in Tribler. + # Can be added later with an iterator class that handles the malformed db error during the iteration + raise NotImplementedError + + def blobopen(self, *args, **kwargs) -> Blob: # Works for Python >= 3.11 + with handling_malformed_db_error(self._db_filepath): + blob = super().blobopen(*args, **kwargs) + return Blob(blob, self._db_filepath) + + +for method_name_ in ['commit', 'execute', 'executemany', 'executescript', 'backup', '__enter__', '__exit__', + 'serialize', 'deserialize']: + _add_method_wrapper_that_handles_malformed_db_exception(Connection, method_name_) + + +class Blob: # For Python >= 3.11. Added now, so we do not forget to add it later when upgrading to 3.11.
+ def __init__(self, blob, db_filepath: Path): + self._blob = blob + self._db_filepath = db_filepath + + +for method_name_ in ['close', 'read', 'write', 'seek', '__len__', '__enter__', '__exit__', '__getitem__', + '__setitem__']: + _add_method_wrapper_that_handles_malformed_db_exception(Blob, method_name_) diff --git a/src/tribler/core/utilities/db_corruption_handling/tests/__init__.py b/src/tribler/core/utilities/db_corruption_handling/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/tribler/core/utilities/db_corruption_handling/tests/conftest.py b/src/tribler/core/utilities/db_corruption_handling/tests/conftest.py new file mode 100644 index 00000000000..ea2f0bfa8b1 --- /dev/null +++ b/src/tribler/core/utilities/db_corruption_handling/tests/conftest.py @@ -0,0 +1,15 @@ +import pytest + +from tribler.core.utilities.db_corruption_handling.sqlite_replacement import connect + + +@pytest.fixture(name='db_filepath') +def db_filepath_fixture(tmp_path): + return tmp_path / 'test.db' + + +@pytest.fixture(name='connection') +def connection_fixture(db_filepath): + connection = connect(str(db_filepath)) + yield connection + connection.close() diff --git a/src/tribler/core/utilities/db_corruption_handling/tests/test_base.py b/src/tribler/core/utilities/db_corruption_handling/tests/test_base.py new file mode 100644 index 00000000000..79448f2da0d --- /dev/null +++ b/src/tribler/core/utilities/db_corruption_handling/tests/test_base.py @@ -0,0 +1,57 @@ +import sqlite3 +from pathlib import Path +from unittest.mock import Mock, patch + +import pytest + + +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted, handle_db_if_corrupted, \ + handling_malformed_db_error + +malformed_error = sqlite3.DatabaseError('database disk image is malformed') + + +def test_handling_malformed_db_error__no_error(db_filepath): + # If no error is raised, the database should not be marked as corrupted + with handling_malformed_db_error(db_filepath): + pass + + assert not Path(str(db_filepath) + '.is_corrupted').exists() + + +def test_handling_malformed_db_error__malformed_error(db_filepath): + # Malformed database errors should be handled by marking the database as corrupted + with pytest.raises(DatabaseIsCorrupted): + with handling_malformed_db_error(db_filepath): + raise malformed_error + + assert Path(str(db_filepath) + '.is_corrupted').exists() + + +def test_handling_malformed_db_error__other_error(db_filepath): + # Other errors should not be handled like malformed database errors + class TestError(Exception): + pass + + with pytest.raises(TestError): + with handling_malformed_db_error(db_filepath): + raise TestError() + + assert not Path(str(db_filepath) + '.is_corrupted').exists() + + +def test_handle_db_if_corrupted__corrupted(db_filepath: Path): + # If the corruption marker is found, the corrupted database file is removed + marker_path = Path(str(db_filepath) + '.is_corrupted') + marker_path.touch() + + handle_db_if_corrupted(db_filepath) + assert not db_filepath.exists() + assert not marker_path.exists() + + +@patch('tribler.core.utilities.db_corruption_handling.base._handle_corrupted_db') +def test_handle_db_if_corrupted__not_corrupted(handle_corrupted_db: Mock, db_filepath: Path): + # If the corruption marker is not found, the handling of the database is not performed + handle_db_if_corrupted(db_filepath) + handle_corrupted_db.assert_not_called() diff --git a/src/tribler/core/utilities/db_corruption_handling/tests/test_sqlite_replacement.py 
b/src/tribler/core/utilities/db_corruption_handling/tests/test_sqlite_replacement.py new file mode 100644 index 00000000000..3b36b067a09 --- /dev/null +++ b/src/tribler/core/utilities/db_corruption_handling/tests/test_sqlite_replacement.py @@ -0,0 +1,77 @@ +import sqlite3 +from unittest.mock import Mock, patch + +import pytest + +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted +from tribler.core.utilities.db_corruption_handling.sqlite_replacement import Blob, Connection, \ + Cursor, _add_method_wrapper_that_handles_malformed_db_exception, connect + + +# pylint: disable=protected-access + + +malformed_error = sqlite3.DatabaseError('database disk image is malformed') + + +def test_connect(db_filepath): + connection = connect(str(db_filepath)) + assert isinstance(connection, Connection) + connection.close() + + +def test_make_method_that_handles_malformed_db_exception(db_filepath): + # Tests that the _make_method_that_handles_malformed_db_exception function creates a method that handles + # the malformed database exception + + class BaseClass: + method1 = Mock(return_value=Mock()) + + class TestClass(BaseClass): + _db_filepath = db_filepath + + _add_method_wrapper_that_handles_malformed_db_exception(TestClass, 'method1') + + # The method should be successfully wrapped + assert TestClass.method1.is_wrapped + assert TestClass.method1.__name__ == 'method1' + + test_instance = TestClass() + result = test_instance.method1(1, 2, x=3, y=4) + + # *args and **kwargs should be passed to the original method, and the result should be returned + BaseClass.method1.assert_called_once_with(1, 2, x=3, y=4) + assert result is BaseClass.method1.return_value + + # When the base method raises a malformed database exception, the DatabaseIsCorrupted exception should be raised + BaseClass.method1.side_effect = malformed_error + with pytest.raises(DatabaseIsCorrupted): + test_instance.method1(1, 2, x=3, y=4) + + +def test_connection_cursor(connection): + cursor = connection.cursor() + assert isinstance(cursor, Cursor) + + +def test_connection_iterdump(connection): + with pytest.raises(NotImplementedError): + connection.iterdump() + + +@patch('tribler.core.utilities.db_corruption_handling.sqlite_replacement.ConnectionBase.blobopen', + Mock(side_effect=malformed_error)) +def test_connection_blobopen__exception(connection): + with pytest.raises(DatabaseIsCorrupted): + connection.blobopen() + + +@patch('tribler.core.utilities.db_corruption_handling.sqlite_replacement.ConnectionBase.blobopen') +def test_connection_blobopen__no_exception(blobopen, connection): + blobopen.return_value = Mock() + result = connection.blobopen() + + blobopen.assert_called_once() + assert isinstance(result, Blob) + assert result._blob is blobopen.return_value + assert result._db_filepath == connection._db_filepath diff --git a/src/tribler/core/utilities/exit_codes.py b/src/tribler/core/utilities/exit_codes.py new file mode 100644 index 00000000000..4d034554fd6 --- /dev/null +++ b/src/tribler/core/utilities/exit_codes.py @@ -0,0 +1,3 @@ + +# Valid range for custom errors is 1..127 +EXITCODE_DATABASE_IS_CORRUPTED = 99 # If the Core process finishes with this error, the GUI process restarts it. 
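Taken together, the helpers above define the corruption-handling contract used by the rest of this changeset. The following sketch is hypothetical glue code (not part of the diff) that only combines functions introduced above, to show how they are meant to compose:

    from pathlib import Path

    from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted, handle_db_if_corrupted
    from tribler.core.utilities.db_corruption_handling.sqlite_replacement import connect
    from tribler.core.utilities.exit_codes import EXITCODE_DATABASE_IS_CORRUPTED
    from tribler.core.utilities.process_manager import get_global_process_manager


    def open_database(db_path: Path):
        # If a previous run left an ".is_corrupted" marker, the damaged file and the marker are removed here,
        # so a fresh database can be created on the next connect.
        handle_db_if_corrupted(db_path)
        try:
            # connect() returns the wrapped Connection; a "database disk image is malformed" error raised by it
            # (or later by its cursors) creates the marker file and is re-raised as DatabaseIsCorrupted.
            return connect(str(db_path))
        except DatabaseIsCorrupted as exc:
            # sys_exit() raises SystemExit with code 99 (EXITCODE_DATABASE_IS_CORRUPTED); the GUI recognizes
            # the exit code, restarts the Core, and the database is recreated from scratch.
            get_global_process_manager().sys_exit(EXITCODE_DATABASE_IS_CORRUPTED, exc)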
diff --git a/src/tribler/core/utilities/pony_utils.py b/src/tribler/core/utilities/pony_utils.py index 1b1cd1ddcf9..3d4c7d30fd8 100644 --- a/src/tribler/core/utilities/pony_utils.py +++ b/src/tribler/core/utilities/pony_utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib import logging import sys import threading @@ -9,6 +10,7 @@ from dataclasses import dataclass from io import StringIO from operator import attrgetter +from pathlib import Path from types import FrameType from typing import Callable, Dict, Iterable, Optional, Type from weakref import WeakSet @@ -17,7 +19,18 @@ from pony.orm import core from pony.orm.core import Database, select from pony.orm.dbproviders import sqlite -from pony.utils import localbase +from pony.utils import cut_traceback, localbase +from tribler.core.utilities.db_corruption_handling import sqlite_replacement +from tribler.core.utilities.db_corruption_handling.base import handle_db_if_corrupted + +# Inject sqlite replacement to PonyORM sqlite database provider to use augmented version of Connection and Cursor +# classes that handle database corruption errors. All connection and cursor methods, such as execute and fetchone, +# raise DatabaseIsCorrupted exception if the database is corrupted. Also, the marker file with ".is_corrupted" +# extension is created alongside the corrupted database file. As a result of exception, the Tribler Core immediately +# stops with the error code 99. Tribler GUI handles this error code by showing the message to the user and automatically +# restarting the Core. After the Core is restarted, the database is re-created from scratch. +sqlite.sqlite = sqlite_replacement + SLOW_DB_SESSION_DURATION_THRESHOLD = 1.0 @@ -28,6 +41,33 @@ StatDict = Dict[Optional[str], core.QueryStat] +def table_exists(cursor: sqlite_replacement.Cursor, table_name: str) -> bool: + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,)) + return cursor.fetchone() is not None + + +def get_db_version(db_path, default: int = None) -> int: + handle_db_if_corrupted(db_path) + version = None + + if db_path.exists(): + with contextlib.closing(sqlite_replacement.connect(db_path)) as connection: + with connection: + cursor = connection.cursor() + if table_exists(cursor, 'MiscData'): + cursor.execute("SELECT value FROM MiscData WHERE name == 'db_version'") + row = cursor.fetchone() + version = int(row[0]) if row else None + + if version is not None: + return version + + if default is not None: + return default + + raise RuntimeError(f'The version value is not found in database {db_path}') + + # pylint: disable=bad-staticmethod-argument def get_or_create(cls: Type[core.Entity], create_kwargs=None, **kwargs) -> core.Entity: """Get or create db entity. @@ -132,6 +172,8 @@ class DbSessionInfo: class TriblerDbSession(core.DBSessionContextManager): + track_slow_db_sessions = False + def __init__(self, *args, duration_threshold: Optional[float] = None, **kwargs): super().__init__(*args, **kwargs) # `duration_threshold` specifies how long db_session should be to trigger the long db_session warning. 
@@ -144,16 +186,42 @@ def _enter(self): if is_top_level_db_session: self._start_tracking() + def __exit__(self, exc_type=None, exc=None, tb=None): + try: + super().__exit__(exc_type, exc, tb) + finally: + was_top_level_db_session = core.local.db_session is None + if was_top_level_db_session: + self._stop_tracking() + def _start_tracking(self): for db in databases_to_track: - # Clear the local statistics for all databases, so we can accumulate new local statistics in db_session + # Clear the local statistics for all databases, so we can accumulate new local statistics in db session db.merge_local_stats() + # If the tracking of slow db_sessions is not enabled, we still create the DbSessionInfo instance, but without + # the current db session stack. It is fast, and this way it is easier to avoid race conditions for the case + # when the track_slow_db_sessions value is changed on the fly during the db session execution. local.db_session_info = DbSessionInfo( - current_db_session_stack=self._extract_stack(), + current_db_session_stack=self._extract_stack() if self.track_slow_db_sessions else None, start_time=time.time() ) + def _stop_tracking(self): + info: DbSessionInfo = local.db_session_info + local.db_session_info = None + + if info.current_db_session_stack is None: + # The tracking of slow db sessions was not enabled when the db session was started, so we skip analyzing it + return + + start_time = info.start_time + db_session_duration = time.time() - start_time + + threshold = SLOW_DB_SESSION_DURATION_THRESHOLD if self.duration_threshold is None else self.duration_threshold + if db_session_duration > threshold: + self._log_warning(db_session_duration, info) + @staticmethod def _extract_stack() -> traceback.StackSummary: current_frame: FrameType = sys._getframe() # pylint: disable=protected-access @@ -178,25 +246,6 @@ def _extract_stack() -> traceback.StackSummary: stack.reverse() return stack - def __exit__(self, exc_type=None, exc=None, tb=None): - try: - super().__exit__(exc_type, exc, tb) - finally: - was_top_level_db_session = core.local.db_session is None - if was_top_level_db_session: - self._stop_tracking() - - def _stop_tracking(self): - info: DbSessionInfo = local.db_session_info - local.db_session_info = None - - start_time = info.start_time - db_session_duration = time.time() - start_time - - threshold = SLOW_DB_SESSION_DURATION_THRESHOLD if self.duration_threshold is None else self.duration_threshold - if db_session_duration > threshold: - self._log_warning(db_session_duration, info) - def _log_warning(self, db_session_duration: float, info: DbSessionInfo): db_session_query_statistics = self._summarize_stat(db.local_stats for db in databases_to_track) @@ -261,15 +310,15 @@ def _merge_stats(stats_iter: Iterable[StatDict]) -> StatDict: return result -class PatchedSQLiteProvider(sqlite.SQLiteProvider): - _acquire_time: float = 0 # A time when the current provider were able to acquire the database lock +class TriblerSQLiteProvider(sqlite.SQLiteProvider): - # It is impossible to override the __init__ method of `SQLiteProvider` without breaking - # the `SQLiteProvider.get_pool` method's logic. Therefore, we don't initialize - # a new attribute `_acquire_time` inside a class constructor method; - # instead, we set its initial value at a class level. + # It is impossible to override the __init__ method without breaking the `SQLiteProvider.get_pool` method's logic. + # Therefore, we don't initialize a new attribute `_acquire_time` inside a class constructor method. 
+ # Instead, we set its initial value at a class level. + _acquire_time: float = 0 # A time when the current provider was able to acquire the database lock def acquire_lock(self): + # Adds tracking of a db_session's lock wait duration and lock acquire count t1 = time.time() super().acquire_lock() info = local.db_session_info @@ -281,6 +330,7 @@ def acquire_lock(self): info.lock_wait_total_duration += lock_wait_duration def release_lock(self): + # Adds tracking of a db_session's total lock hold duration super().release_lock() info = local.db_session_info if info is not None: @@ -290,18 +340,34 @@ def release_lock(self): db_session = TriblerDbSession() +orm.db_session = orm.core.db_session = db_session class TriblerDatabase(Database): - # If a developer what to track the slow execution of the database, he should create an instance of TriblerDatabase - # instead of the usual pony.orm.Database. + # TriblerDatabase extends the functionality of the Database class in the following ways: + # * It adds handling of the case when the database file is corrupted + # * It accumulates and shows statistics on slow database queries def __init__(self): databases_to_track.add(self) super().__init__() + @cut_traceback + def bind(self, **kwargs): + provider = kwargs.pop('provider', None) + if provider and provider != 'sqlite': + raise TypeError(f"Invalid 'provider' argument for TriblerDatabase: {provider!r}") + + filename = kwargs.get('filename', None) + if filename and filename not in {':memory:', ':sharedmemory:'}: + db_path = Path(filename) + if not db_path.is_absolute(): + raise ValueError(f"The 'filename' attribute is expected to be an absolute path. Got: {filename}") + + handle_db_if_corrupted(db_path) + + self._bind(TriblerSQLiteProvider, **kwargs) + def track_slow_db_sessions(): - # The method enables tracking of slow db_sessions - orm.db_session = orm.core.db_session = db_session - sqlite.provider_cls = PatchedSQLiteProvider + TriblerDbSession.track_slow_db_sessions = True diff --git a/src/tribler/core/utilities/tests/test_pony_utils.py b/src/tribler/core/utilities/tests/test_pony_utils.py index 4d93e8c2dfa..99b3e30fdcb 100644 --- a/src/tribler/core/utilities/tests/test_pony_utils.py +++ b/src/tribler/core/utilities/tests/test_pony_utils.py @@ -1,10 +1,12 @@ +import sqlite3 +from pathlib import Path from unittest.mock import patch import pytest from pony.orm.core import QueryStat, Required from tribler.core.utilities import pony_utils - +from tribler.core.utilities.pony_utils import get_db_version, table_exists EMPTY_DICT = {} @@ -43,9 +45,9 @@ def test_merge_stats(): def test_patched_db_session(tmp_path): # The test is added for better coverage of TriblerDbSession methods - with patch('pony.orm.dbproviders.sqlite.provider_cls', pony_utils.PatchedSQLiteProvider): + with patch('tribler.core.utilities.pony_utils.TriblerDbSession.track_slow_db_sessions', True): db = pony_utils.TriblerDatabase() - db.bind('sqlite', str(tmp_path / 'db.sqlite'), create_db=True) + db.bind(provider='sqlite', filename=str(tmp_path / 'db.sqlite'), create_db=True) class Entity1(db.Entity): a = Required(int) @@ -81,9 +83,9 @@ def test_patched_db_session_default_duration_threshold(tmp_path): # The test checks that db_session uses the current dynamic value of SLOW_DB_SESSION_DURATION_THRESHOLD # if no duration_threshold was explicitly specified for db_session - with patch('pony.orm.dbproviders.sqlite.provider_cls', pony_utils.PatchedSQLiteProvider): + with
patch('tribler.core.utilities.pony_utils.TriblerDbSession.track_slow_db_sessions', True): db = pony_utils.TriblerDatabase() - db.bind('sqlite', str(tmp_path / 'db.sqlite'), create_db=True) + db.bind(provider='sqlite', filename=str(tmp_path / 'db.sqlite'), create_db=True) class Entity1(db.Entity): a = Required(int) @@ -120,3 +122,81 @@ def test_format_warning(): Queries statistics for the entire application: """ + + +@pytest.fixture(name='db_path') +def db_path_fixture(tmp_path: Path): + db_path = tmp_path / 'test.db' + db_path.touch() + return db_path + + +def test_get_db_version__db_does_not_exist(tmp_path: Path): + # When the database does not exist, the call to get_db_version generates RuntimeError + db_path = tmp_path / 'doesnotexist.db' + + with pytest.raises(RuntimeError, match='^The version value is not found in database .*doesnotexist.db$'): + get_db_version(db_path) + + +def test_get_db_version__db_does_not_exist_default_version(tmp_path: Path): + # When the database does not exist, and the default is specified, get_db_version returns the default value + db_path = tmp_path / 'doesnotexist.db' + + default_version = 123 + version = get_db_version(db_path, default=default_version) + assert version == default_version + + +def test_get_db_version__version_table_does_not_exist(db_path: Path): + # When the version table does not exist, the call to get_db_version generates RuntimeError + with pytest.raises(RuntimeError, match='^The version value is not found in database .*test.db$'): + get_db_version(db_path) + + +def test_get_db_version__version_table_does_not_exist_default_version(db_path: Path): + # When the version table does not exist, and the default is specified, get_db_version returns the default value + default_version = 123 + version = get_db_version(db_path, default=default_version) + assert version == default_version + + +def test_get_db_version(db_path: Path): + # Tests that if the database schema version is specified in the database, the default value is completely ignored. + + with sqlite3.connect(db_path) as connection: + connection.execute('create table MiscData(name text primary key, value text)') + connection.execute("insert into MiscData(name, value) values('db_version', '100')") + + version = get_db_version(db_path) + assert version == 100 + + version = get_db_version(db_path, default=99) + assert version == 100 + + version = get_db_version(db_path, default=101) + assert version == 100 + + +def test_get_db_version__corrupted_db(tmp_path: Path): + # Tests that if the database file is marked as corrupted, then the version from the db is ignored, + # and default version is used after the database file is re-created. 
+ + db_path = tmp_path / 'test.db' + + connection = sqlite3.connect(db_path) + connection.execute('create table MiscData(name text primary key, value text)') + connection.execute("insert into MiscData(name, value) values('db_version', '100')") + connection.commit() + assert table_exists(connection.cursor(), 'MiscData') + connection.close() + + marker_path = Path(str(db_path) + '.is_corrupted') + marker_path.touch() + + default_version = 10 + version = get_db_version(db_path, default=default_version) + assert version == default_version + + with sqlite3.connect(db_path) as connection: + assert not table_exists(connection.cursor(), 'MiscData') diff --git a/src/tribler/gui/core_manager.py b/src/tribler/gui/core_manager.py index 80baf5e5ec8..a93de797f17 100644 --- a/src/tribler/gui/core_manager.py +++ b/src/tribler/gui/core_manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import os import re @@ -10,13 +12,14 @@ from PyQt5.QtCore import QObject, QProcess, QProcessEnvironment, QTimer from PyQt5.QtNetwork import QNetworkRequest +from tribler.core.utilities.exit_codes import EXITCODE_DATABASE_IS_CORRUPTED from tribler.core.utilities.process_manager import ProcessManager from tribler.gui import gui_sentry_reporter from tribler.gui.app_manager import AppManager from tribler.gui.event_request_manager import EventRequestManager from tribler.gui.exceptions import CoreConnectTimeoutError, CoreCrashedError from tribler.gui.network.request_manager import SHUTDOWN_ENDPOINT, request_manager -from tribler.gui.utilities import connect +from tribler.gui.utilities import connect, show_message_corrupted_database_was_fixed API_PORT_CHECK_INTERVAL = 100 # 0.1 seconds between attempts to retrieve Core API port API_PORT_CHECK_TIMEOUT = 120 # Stop trying to determine API port after this number of seconds @@ -288,6 +291,14 @@ def format_error_message(exit_code: int, exit_status: int) -> str: def on_core_finished(self, exit_code, exit_status): self._logger.info("Core process finished") self.core_running = False + self.core_connected = False + if exit_code == EXITCODE_DATABASE_IS_CORRUPTED: + self._logger.error(f"Core process crashed with code {exit_code}: the database is corrupted. " + "Restarting Core...") + self.start_tribler_core() + show_message_corrupted_database_was_fixed() + return + self.core_finished = True if self.shutting_down: if self.should_quit_app_on_core_finished: diff --git a/src/tribler/gui/defs.py b/src/tribler/gui/defs.py index 40cc4a8bd63..c1fca85cd0c 100644 --- a/src/tribler/gui/defs.py +++ b/src/tribler/gui/defs.py @@ -210,6 +210,9 @@ def get(cls, item, default=None): UPGRADE_CANCELLED_ERROR_TITLE = "Tribler Upgrade cancelled" + NO_DISK_SPACE_ERROR_MESSAGE = "Not enough storage space available. \n" \ "Tribler requires at least %s space to continue. \n\n" \ "Please free up the required space and re-run Tribler. " + +CORRUPTED_DB_WAS_FIXED_MESSAGE = "The corrupted database file was fixed" diff --git a/src/tribler/gui/event_request_manager.py b/src/tribler/gui/event_request_manager.py index 8b26321de6a..88f68108c57 100644 --- a/src/tribler/gui/event_request_manager.py +++ b/src/tribler/gui/event_request_manager.py @@ -130,11 +130,10 @@ def on_error(self, error: int, reschedule_on_err: bool): return if self.receiving_data: - # We are performing reconnect on the initial connection error only. 
- # In the future, if we consider it useful, we can immediately call here - # `self.reconnect(reschedule_on_err=False)` - # and raise an exception if it fails to reconnect - raise CoreConnectionError('The connection to the Tribler Core was lost') + # Most probably Core is crashed. If CoreManager decides to restart the core, + # it will also call event_manager.connect_to_core() + self._logger.error('The connection to the Tribler Core was lost') + return should_retry = reschedule_on_err and time.time() < self.start_time + CORE_CONNECTION_TIMEOUT error_name = self.network_errors.get(error, error) diff --git a/src/tribler/gui/tribler_window.py b/src/tribler/gui/tribler_window.py index f7b0d51a765..0a914e53cb6 100644 --- a/src/tribler/gui/tribler_window.py +++ b/src/tribler/gui/tribler_window.py @@ -545,7 +545,6 @@ def tray_show_message(self, title, message): def on_core_connected(self, version): if self.core_connected: self._logger.warning("Received duplicate Tribler Core connected event") - return self._logger.info("Core connected") self.core_connected = True @@ -559,6 +558,10 @@ def on_receive_settings(self, settings): self.start_ui() def start_ui(self): + if self.ui_started: + self._logger.info("UI already started") + return + self.top_menu_button.setHidden(False) self.left_menu.setHidden(False) # self.token_balance_widget.setHidden(False) # restore it after the token balance calculation is fixed diff --git a/src/tribler/gui/upgrade_manager.py b/src/tribler/gui/upgrade_manager.py index 9c7e7597c11..fea476870ac 100644 --- a/src/tribler/gui/upgrade_manager.py +++ b/src/tribler/gui/upgrade_manager.py @@ -11,10 +11,12 @@ from tribler.core.config.tribler_config import TriblerConfig from tribler.core.upgrade.upgrade import TriblerUpgrader from tribler.core.upgrade.version_manager import TriblerVersion, VersionHistory, NoDiskSpaceAvailableError -from tribler.gui.defs import BUTTON_TYPE_NORMAL, NO_DISK_SPACE_ERROR_MESSAGE, UPGRADE_CANCELLED_ERROR_TITLE +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted +from tribler.gui.defs import BUTTON_TYPE_NORMAL, CORRUPTED_DB_WAS_FIXED_MESSAGE, NO_DISK_SPACE_ERROR_MESSAGE, \ + UPGRADE_CANCELLED_ERROR_TITLE from tribler.gui.dialogs.confirmationdialog import ConfirmationDialog from tribler.gui.exceptions import UpgradeError -from tribler.gui.utilities import connect, format_size, tr +from tribler.gui.utilities import connect, format_size, show_message_corrupted_database_was_fixed, tr if TYPE_CHECKING: from tribler.gui.tribler_window import TriblerWindow @@ -60,7 +62,8 @@ def run(self): self.logger.info('Finished') self.finished.emit(None) - def format_no_disk_space_available_error(self, disk_error: NoDiskSpaceAvailableError) -> str: + @staticmethod + def format_no_disk_space_available_error(disk_error: NoDiskSpaceAvailableError) -> str: diff_space = format_size(disk_error.space_required - disk_error.space_available) formatted_error = tr(NO_DISK_SPACE_ERROR_MESSAGE) % diff_space return formatted_error @@ -247,9 +250,18 @@ def on_worker_finished(self, exc): self.stop_worker() if exc is None: self.upgrader_finished.emit() + elif isinstance(exc, DatabaseIsCorrupted): + self.upgrader_finished.emit() + show_message_corrupted_database_was_fixed(db_path=str(exc)) else: raise UpgradeError(f'{exc.__class__.__name__}: {exc}') from exc + @staticmethod + def _format_database_corruption_fixed_message(exc: DatabaseIsCorrupted) -> str: + message = tr(CORRUPTED_DB_WAS_FIXED_MESSAGE) + formatted_error = f'{message}:\n\n{exc}' + return 
formatted_error + def on_worker_cancelled(self, reason: str): self.stop_worker() self.upgrader_cancelled.emit(reason) diff --git a/src/tribler/gui/utilities.py b/src/tribler/gui/utilities.py index 200809e0630..6eb34575f78 100644 --- a/src/tribler/gui/utilities.py +++ b/src/tribler/gui/utilities.py @@ -9,7 +9,7 @@ import types from datetime import datetime, timedelta from pathlib import Path -from typing import Callable, Dict, List +from typing import Callable, Dict, List, Optional from urllib.parse import quote_plus from uuid import uuid4 @@ -27,7 +27,7 @@ import tribler.gui from tribler.core.components.knowledge.db.knowledge_db import ResourceType -from tribler.gui.defs import HEALTH_DEAD, HEALTH_GOOD, HEALTH_MOOT, HEALTH_UNCHECKED +from tribler.gui.defs import CORRUPTED_DB_WAS_FIXED_MESSAGE, HEALTH_DEAD, HEALTH_GOOD, HEALTH_MOOT, HEALTH_UNCHECKED # fmt: off @@ -521,6 +521,16 @@ def show_message_box(text: str = '', title: str = 'Error', icon: QMessageBox.Ico message_box.exec_() +def show_message_corrupted_database_was_fixed(db_path: Optional[str] = None): + text = tr(CORRUPTED_DB_WAS_FIXED_MESSAGE) + if db_path: + text = f'{text}:\n\n{db_path}' + + message_box = QMessageBox(icon=QMessageBox.Critical, text=text) + message_box.setWindowTitle(tr("Database corruption detected")) + message_box.exec() + + def make_network_errors_dict() -> Dict[int, str]: network_errors = {} for name in dir(QNetworkReply):
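
Editorial note (not part of the patch above): the new test_get_db_version__corrupted_db test relies on the marker-file convention behind handle_db_if_corrupted, where a file named '<database>.is_corrupted' next to the database signals that the database must be dropped and re-created. The sketch below only illustrates that convention, under the assumption that handling amounts to deleting both files; the actual helper in tribler.core.utilities.db_corruption_handling.base may differ in details such as logging and error reporting.

# Illustrative sketch only -- an approximation of the marker-file handling that the
# corrupted-database test above exercises; not the actual Tribler implementation.
from pathlib import Path


def handle_db_if_corrupted_sketch(db_path: Path) -> bool:
    """If a '<db_path>.is_corrupted' marker exists, remove the database so it is re-created empty."""
    marker_path = Path(str(db_path) + '.is_corrupted')
    if not marker_path.exists():
        return False  # nothing to do: the database was not marked as corrupted

    db_path.unlink(missing_ok=True)      # drop the corrupted database file
    marker_path.unlink(missing_ok=True)  # drop the marker so the handling runs only once
    return True

Under this assumed convention, the old 'db_version' row is gone once handling has run, which is why the test expects get_db_version to fall back to the default value and to find an empty, freshly created database.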