From 0526d7eeac12e0a6d89a4bb03461da758cb1f17e Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 27 Sep 2023 11:09:39 +0200 Subject: [PATCH] Handle corrupted databases --- .../bandwidth_accounting/db/database.py | 2 +- .../components/knowledge/db/knowledge_db.py | 2 +- .../components/metadata_store/db/store.py | 2 +- src/tribler/core/upgrade/db8_to_db10.py | 58 +++++----- .../core/upgrade/tags_to_knowledge/tags_db.py | 2 +- src/tribler/core/upgrade/upgrade.py | 10 ++ src/tribler/core/utilities/pony_utils.py | 102 ++++++++++++++++-- .../core/utilities/tests/test_pony_utils.py | 52 +++++++-- 8 files changed, 185 insertions(+), 45 deletions(-) diff --git a/src/tribler/core/components/bandwidth_accounting/db/database.py b/src/tribler/core/components/bandwidth_accounting/db/database.py index fcc6435eded..6ad33bd200c 100644 --- a/src/tribler/core/components/bandwidth_accounting/db/database.py +++ b/src/tribler/core/components/bandwidth_accounting/db/database.py @@ -34,7 +34,7 @@ def __init__(self, db_path: Union[Path, type(MEMORY_DB)], my_pub_key: bytes, # with the static analysis. # pylint: disable=unused-variable - @self.database.on_connect(provider='sqlite') + @self.database.on_connect def sqlite_sync_pragmas(_, connection): cursor = connection.cursor() cursor.execute("PRAGMA journal_mode = WAL") diff --git a/src/tribler/core/components/knowledge/db/knowledge_db.py b/src/tribler/core/components/knowledge/db/knowledge_db.py index 4950cd1a1ee..4feab9994c5 100644 --- a/src/tribler/core/components/knowledge/db/knowledge_db.py +++ b/src/tribler/core/components/knowledge/db/knowledge_db.py @@ -66,7 +66,7 @@ class KnowledgeDatabase: def __init__(self, filename: Optional[str] = None, *, create_tables: bool = True, **generate_mapping_kwargs): self.instance = TriblerDatabase() self.define_binding(self.instance) - self.instance.bind('sqlite', filename or ':memory:', create_db=True) + self.instance.bind(provider='sqlite', filename=filename or ':memory:', create_db=True) generate_mapping_kwargs['create_tables'] = create_tables self.instance.generate_mapping(**generate_mapping_kwargs) self.logger = logging.getLogger(self.__class__.__name__) diff --git a/src/tribler/core/components/metadata_store/db/store.py b/src/tribler/core/components/metadata_store/db/store.py index 7f8e79e696b..a2d6811b8b3 100644 --- a/src/tribler/core/components/metadata_store/db/store.py +++ b/src/tribler/core/components/metadata_store/db/store.py @@ -166,7 +166,7 @@ def __init__( # This attribute is internally called by Pony on startup, though pylint cannot detect it # with the static analysis. # pylint: disable=unused-variable - @self.db.on_connect(provider='sqlite') + @self.db.on_connect def on_connect(_, connection): cursor = connection.cursor() cursor.execute("PRAGMA journal_mode = WAL") diff --git a/src/tribler/core/upgrade/db8_to_db10.py b/src/tribler/core/upgrade/db8_to_db10.py index a6ee17ae250..f91284ec748 100644 --- a/src/tribler/core/upgrade/db8_to_db10.py +++ b/src/tribler/core/upgrade/db8_to_db10.py @@ -8,6 +8,7 @@ from pony.orm import db_session from tribler.core.components.metadata_store.db.store import MetadataStore +from tribler.core.utilities.pony_utils import marking_corrupted_db TABLE_NAMES = ( "ChannelNode", "TorrentState", "TorrentState_TrackerState", "ChannelPeer", "ChannelVote", "TrackerState", "Vsids") @@ -126,31 +127,31 @@ def convert_command(offset, batch_size): def do_migration(self): result = None # estimated duration in seconds of ChannelNode table copying time try: - - old_table_columns = {} - for table_name in TABLE_NAMES: - old_table_columns[table_name] = get_table_columns(self.old_db_path, table_name) - - with contextlib.closing(sqlite3.connect(self.new_db_path)) as connection, connection: - cursor = connection.cursor() - cursor.execute("PRAGMA journal_mode = OFF;") - cursor.execute("PRAGMA synchronous = OFF;") - cursor.execute("PRAGMA foreign_keys = OFF;") - cursor.execute("PRAGMA temp_store = MEMORY;") - cursor.execute("PRAGMA cache_size = -204800;") - cursor.execute(f'ATTACH DATABASE "{self.old_db_path}" as old_db;') - + with marking_corrupted_db(self.old_db_path): + old_table_columns = {} for table_name in TABLE_NAMES: - t1 = now() - cursor.execute("BEGIN TRANSACTION;") - if not self.must_shutdown(): - self.convert_table(cursor, table_name, old_table_columns[table_name]) - cursor.execute("COMMIT;") - duration = now() - t1 - self._logger.info(f"Upgrade: copied table {table_name} in {duration:.2f} seconds") - - if table_name == 'ChannelNode': - result = duration + old_table_columns[table_name] = get_table_columns(self.old_db_path, table_name) + + with contextlib.closing(sqlite3.connect(self.new_db_path)) as connection, connection: + cursor = connection.cursor() + cursor.execute("PRAGMA journal_mode = OFF;") + cursor.execute("PRAGMA synchronous = OFF;") + cursor.execute("PRAGMA foreign_keys = OFF;") + cursor.execute("PRAGMA temp_store = MEMORY;") + cursor.execute("PRAGMA cache_size = -204800;") + cursor.execute(f'ATTACH DATABASE "{self.old_db_path}" as old_db;') + + for table_name in TABLE_NAMES: + t1 = now() + cursor.execute("BEGIN TRANSACTION;") + if not self.must_shutdown(): + self.convert_table(cursor, table_name, old_table_columns[table_name]) + cursor.execute("COMMIT;") + duration = now() - t1 + self._logger.info(f"Upgrade: copied table {table_name} in {duration:.2f} seconds") + + if table_name == 'ChannelNode': + result = duration self.update_status("Synchronizing the upgraded DB to disk, please wait.") except Exception as e: @@ -242,8 +243,9 @@ def get_table_columns(db_path, table_name): def get_db_version(db_path): - with contextlib.closing(sqlite3.connect(db_path)) as connection, connection: - cursor = connection.cursor() - cursor.execute('SELECT value FROM MiscData WHERE name == "db_version"') - version = int(cursor.fetchone()[0]) + with marking_corrupted_db(db_path): + with contextlib.closing(sqlite3.connect(db_path)) as connection, connection: + cursor = connection.cursor() + cursor.execute('SELECT value FROM MiscData WHERE name == "db_version"') + version = int(cursor.fetchone()[0]) return version diff --git a/src/tribler/core/upgrade/tags_to_knowledge/tags_db.py b/src/tribler/core/upgrade/tags_to_knowledge/tags_db.py index f9890118e60..7f00ab93742 100644 --- a/src/tribler/core/upgrade/tags_to_knowledge/tags_db.py +++ b/src/tribler/core/upgrade/tags_to_knowledge/tags_db.py @@ -10,7 +10,7 @@ class TagDatabase: def __init__(self, filename: Optional[str] = None, *, create_tables: bool = True, **generate_mapping_kwargs): self.instance = TriblerDatabase() self.define_binding(self.instance) - self.instance.bind('sqlite', filename or ':memory:', create_db=True) + self.instance.bind(provider='sqlite', filename=filename or ':memory:', create_db=True) generate_mapping_kwargs['create_tables'] = create_tables self.instance.generate_mapping(**generate_mapping_kwargs) diff --git a/src/tribler/core/upgrade/upgrade.py b/src/tribler/core/upgrade/upgrade.py index 3d490e1d537..fa756e4ae66 100644 --- a/src/tribler/core/upgrade/upgrade.py +++ b/src/tribler/core/upgrade/upgrade.py @@ -23,6 +23,7 @@ from tribler.core.upgrade.tags_to_knowledge.tags_db import TagDatabase from tribler.core.utilities.configparser import CallbackConfigParser from tribler.core.utilities.path_util import Path +from tribler.core.utilities.pony_utils import handle_db_if_corrupted from tribler.core.utilities.simpledefs import STATEDIR_CHANNELS_DIR, STATEDIR_DB_DIR @@ -134,6 +135,7 @@ def upgrade_tags_to_knowledge(self): def upgrade_pony_db_14to15(self): self._logger.info('Upgrade Pony DB from version 14 to version 15') mds_path = self.state_dir / STATEDIR_DB_DIR / 'metadata.db' + handle_db_if_corrupted(mds_path) mds = MetadataStore(mds_path, self.channels_dir, self.primary_key, disable_sync=True, check_tables=False, db_version=14) if mds_path.exists() else None @@ -147,6 +149,9 @@ def upgrade_pony_db_13to14(self): mds_path = self.state_dir / STATEDIR_DB_DIR / 'metadata.db' tagdb_path = self.state_dir / STATEDIR_DB_DIR / 'tags.db' + handle_db_if_corrupted(mds_path) + handle_db_if_corrupted(tagdb_path) + mds = MetadataStore(mds_path, self.channels_dir, self.primary_key, disable_sync=True, check_tables=False, db_version=13) if mds_path.exists() else None tag_db = TagDatabase(str(tagdb_path), create_tables=False, @@ -166,6 +171,7 @@ def upgrade_pony_db_12to13(self): self._logger.info('Upgrade Pony DB 12 to 13') # We have to create the Metadata Store object because Session-managed Store has not been started yet database_path = self.state_dir / STATEDIR_DB_DIR / 'metadata.db' + handle_db_if_corrupted(database_path) if database_path.exists(): mds = MetadataStore(database_path, self.channels_dir, self.primary_key, disable_sync=True, check_tables=False, db_version=12) @@ -181,6 +187,7 @@ def upgrade_pony_db_11to12(self): self._logger.info('Upgrade Pony DB 11 to 12') # We have to create the Metadata Store object because Session-managed Store has not been started yet database_path = self.state_dir / STATEDIR_DB_DIR / 'metadata.db' + handle_db_if_corrupted(database_path) if not database_path.exists(): return mds = MetadataStore(database_path, self.channels_dir, self.primary_key, @@ -197,6 +204,7 @@ def upgrade_pony_db_10to11(self): self._logger.info('Upgrade Pony DB 10 to 11') # We have to create the Metadata Store object because Session-managed Store has not been started yet database_path = self.state_dir / STATEDIR_DB_DIR / 'metadata.db' + handle_db_if_corrupted(database_path) if not database_path.exists(): return # code of the migration @@ -215,6 +223,7 @@ def upgrade_bw_accounting_db_8to9(self): to_version = 9 database_path = self.state_dir / STATEDIR_DB_DIR / 'bandwidth.db' + handle_db_if_corrupted(database_path) if not database_path.exists() or get_db_version(database_path) >= 9: return # No need to update if the database does not exist or is already updated self._logger.info('bw8->9') @@ -377,6 +386,7 @@ def upgrade_pony_db_8to10(self): """ self._logger.info('Upgrading GigaChannel DB from version 8 to 10') database_path = self.state_dir / STATEDIR_DB_DIR / 'metadata.db' + handle_db_if_corrupted(database_path) if not database_path.exists() or get_db_version(database_path) >= 10: # Either no old db exists, or the old db version is up to date - nothing to do return diff --git a/src/tribler/core/utilities/pony_utils.py b/src/tribler/core/utilities/pony_utils.py index 6ab3d1475f4..76a9e4e44ae 100644 --- a/src/tribler/core/utilities/pony_utils.py +++ b/src/tribler/core/utilities/pony_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import sqlite3 import sys import threading import time @@ -9,15 +10,19 @@ from dataclasses import dataclass from io import StringIO from operator import attrgetter +from pathlib import Path from types import FrameType -from typing import Callable, Dict, Iterable, Optional, Type +from typing import Callable, Dict, Iterable, Optional, Type, Union from weakref import WeakSet +from contextlib import contextmanager + from pony import orm from pony.orm import core from pony.orm.core import Database, select from pony.orm.dbproviders import sqlite -from pony.utils import cut_traceback, localbase +from pony.orm.dbproviders.sqlite import SQLitePool +from pony.utils import absolutize_path, cut_traceback, cut_traceback_depth, localbase SLOW_DB_SESSION_DURATION_THRESHOLD = 1.0 @@ -28,6 +33,56 @@ StatDict = Dict[Optional[str], core.QueryStat] +class DatabaseIsCorrupted(Exception): + pass + + +Filename = Union[str, Path] + + +def handle_db_if_corrupted(db_filename: Filename): + marker_path = _get_corrupted_db_marker_path(db_filename) + if marker_path.exists(): + _handle_corrupted_db(db_filename) + + +def _handle_corrupted_db(db_filename: Filename): + db_path = Path(db_filename) + if db_path.exists(): + db_path.unlink() + marker_path = _get_corrupted_db_marker_path(db_filename) + if marker_path.exists(): + marker_path.unlink() + + +def _get_corrupted_db_marker_path(db_filename: Filename) -> Path: + return Path(str(db_filename) + '.is_corrupted') + + +@contextmanager +def marking_corrupted_db(db_filename: Filename): + try: + yield + except Exception as e: + if _is_malformed_db_exception(e): + _mark_db_as_corrupted(db_filename) + raise DatabaseIsCorrupted(str(db_filename)) from e + raise + + +def _is_malformed_db_exception(exception): + return isinstance(exception, (core.DatabaseError, sqlite3.DatabaseError)) and 'malformed' in str(exception) + + +def _mark_db_as_corrupted(db_filename: Filename): + if not Path(db_filename).exists(): + raise RuntimeError(f'Corrupted database file not found: {db_filename!r}') + + marker_path = _get_corrupted_db_marker_path(db_filename) + marker_path.touch() + + + # pylint: disable=bad-staticmethod-argument def get_or_create(cls: Type[core.Entity], create_kwargs=None, **kwargs) -> core.Entity: """Get or create db entity. @@ -271,6 +326,7 @@ def _merge_stats(stats_iter: Iterable[StatDict]) -> StatDict: class TriblerSQLiteProvider(sqlite.SQLiteProvider): + pool: TriblerPool # It is impossible to override the __init__ method without breaking the `SQLiteProvider.get_pool` method's logic. # Therefore, we don't initialize a new attribute `_acquire_time` inside a class constructor method. @@ -298,14 +354,45 @@ def release_lock(self): lock_hold_duration = time.time() - acquire_time info.lock_hold_total_duration += lock_hold_duration + def set_transaction_mode(self, connection, cache): + with marking_corrupted_db(self.pool.filename): + return super().set_transaction_mode(connection, cache) + + def execute(self, cursor, sql, arguments=None, returning_id=False): + with marking_corrupted_db(self.pool.filename): + return super().execute(cursor, sql, arguments, returning_id) + + def mark_db_as_malformed(self): + filename = self.pool.filename + if not Path(filename).exists(): + raise RuntimeError(f'Corrupted database file not found: {filename!r}') + + marker_filename = filename + '.is_corrupted' + Path(marker_filename).touch() + + def get_pool(self, is_shared_memory_db, filename, create_db=False, **kwargs): + if is_shared_memory_db or filename == ':memory:': + pass + else: + filename = absolutize_path(filename, frame_depth=cut_traceback_depth+5) # see the base method for details + handle_db_if_corrupted(filename) + return TriblerPool(is_shared_memory_db, filename, create_db, **kwargs) + + +class TriblerPool(SQLitePool): + def _connect(self): + with marking_corrupted_db(self.filename): + return super()._connect() + db_session = TriblerDbSession() orm.db_session = orm.core.db_session = db_session class TriblerDatabase(Database): - # If a developer what to track the slow execution of the database, he should create an instance of TriblerDatabase - # instead of the usual pony.orm.Database. + # TriblerDatabase extends the functionality of the Database class in the following ways: + # * It adds handling of DatabaseError when the database file is corrupted + # * It accumulates and shows statistics on slow database queries def __init__(self): databases_to_track.add(self) @@ -314,11 +401,12 @@ def __init__(self): @cut_traceback def bind(self, **kwargs): if 'provider' in kwargs: - raise TypeError('You should not explicitly specify the `provider` keyword argument for TriblerDatabase') + provider = kwargs['provider'] + if provider != 'sqlite': + raise TypeError(f'Invalid `provider` argument for TriblerDatabase: {provider!r}') + kwargs.pop('provider') self._bind(TriblerSQLiteProvider, **kwargs) - def track_slow_db_sessions(): TriblerDbSession.track_slow_db_sessions = True - diff --git a/src/tribler/core/utilities/tests/test_pony_utils.py b/src/tribler/core/utilities/tests/test_pony_utils.py index 129253d5f08..5995dcc1500 100644 --- a/src/tribler/core/utilities/tests/test_pony_utils.py +++ b/src/tribler/core/utilities/tests/test_pony_utils.py @@ -1,10 +1,12 @@ -from unittest.mock import patch +import sqlite3 +from pathlib import Path +from unittest.mock import patch, Mock import pytest from pony.orm.core import QueryStat, Required from tribler.core.utilities import pony_utils - +from tribler.core.utilities.pony_utils import DatabaseIsCorrupted, handle_db_if_corrupted, marking_corrupted_db EMPTY_DICT = {} @@ -43,9 +45,9 @@ def test_merge_stats(): def test_patched_db_session(tmp_path): # The test is added for better coverage of TriblerDbSession methods - with patch('pony.orm.dbproviders.sqlite.provider_cls', pony_utils.TriblerSQLiteProvider): + with patch('tribler.core.utilities.pony_utils.TriblerDbSession.track_slow_db_sessions', True): db = pony_utils.TriblerDatabase() - db.bind('sqlite', str(tmp_path / 'db.sqlite'), create_db=True) + db.bind(provider='sqlite', filename=str(tmp_path / 'db.sqlite'), create_db=True) class Entity1(db.Entity): a = Required(int) @@ -81,9 +83,9 @@ def test_patched_db_session_default_duration_threshold(tmp_path): # The test checks that db_session uses the current dynamic value of SLOW_DB_SESSION_DURATION_THRESHOLD # if no duration_threshold was explicitly specified for db_session - with patch('pony.orm.dbproviders.sqlite.provider_cls', pony_utils.TriblerSQLiteProvider): + with patch('tribler.core.utilities.pony_utils.TriblerDbSession.track_slow_db_sessions', True): db = pony_utils.TriblerDatabase() - db.bind('sqlite', str(tmp_path / 'db.sqlite'), create_db=True) + db.bind(provider='sqlite', filename=str(tmp_path / 'db.sqlite'), create_db=True) class Entity1(db.Entity): a = Required(int) @@ -120,3 +122,41 @@ def test_format_warning(): Queries statistics for the entire application: """ + + +@pytest.fixture(name='db_path') +def db_path_fixture(tmp_path): + db_path = tmp_path / 'test.db' + db_path.touch() + return db_path + + +@patch('tribler.core.utilities.pony_utils._handle_corrupted_db') +def test_handle_db_if_corrupted__not_corrupted(handle_corrupted_db: Mock, db_path): + handle_db_if_corrupted(db_path) + handle_corrupted_db.assert_not_called() + + +def test_handle_db_if_corrupted__corrupted(db_path): + marker_path = Path(str(db_path) + '.is_corrupted') + marker_path.touch() + + handle_db_if_corrupted(db_path) + assert not db_path.exists() + assert not marker_path.exists() + + +def test_marking_corrupted_db__not_malformed(db_path): + with pytest.raises(ZeroDivisionError): + with marking_corrupted_db(db_path): + raise ZeroDivisionError() + + assert not Path(str(db_path) + '.is_corrupted').exists() + + +def test_marking_corrupted_db__malformed(db_path): + with pytest.raises(DatabaseIsCorrupted): + with marking_corrupted_db(db_path): + raise sqlite3.DatabaseError('database disk image is malformed') + + assert Path(str(db_path) + '.is_corrupted').exists()