Skip to content

Commit

Permalink
Handle corrupted databases
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlovsky committed Oct 10, 2023
1 parent 166d3e8 commit 4ab79e3
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, db_path: Union[Path, type(MEMORY_DB)], my_pub_key: bytes,
# with the static analysis.
# pylint: disable=unused-variable

@self.database.on_connect(provider='sqlite')
@self.database.on_connect
def sqlite_sync_pragmas(_, connection):
cursor = connection.cursor()
cursor.execute("PRAGMA journal_mode = WAL")
Expand Down
2 changes: 1 addition & 1 deletion src/tribler/core/components/knowledge/db/knowledge_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class KnowledgeDatabase:
def __init__(self, filename: Optional[str] = None, *, create_tables: bool = True, **generate_mapping_kwargs):
self.instance = TriblerDatabase()
self.define_binding(self.instance)
self.instance.bind('sqlite', filename or ':memory:', create_db=True)
self.instance.bind(provider='sqlite', filename=filename or ':memory:', create_db=True)
generate_mapping_kwargs['create_tables'] = create_tables
self.instance.generate_mapping(**generate_mapping_kwargs)
self.logger = logging.getLogger(self.__class__.__name__)
Expand Down
2 changes: 1 addition & 1 deletion src/tribler/core/components/metadata_store/db/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def __init__(
# This attribute is internally called by Pony on startup, though pylint cannot detect it
# with the static analysis.
# pylint: disable=unused-variable
@self.db.on_connect(provider='sqlite')
@self.db.on_connect
def on_connect(_, connection):
cursor = connection.cursor()
cursor.execute("PRAGMA journal_mode = WAL")
Expand Down
60 changes: 31 additions & 29 deletions src/tribler/core/upgrade/db8_to_db10.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pony.orm import db_session

from tribler.core.components.metadata_store.db.store import MetadataStore
from tribler.core.utilities.pony_utils import marking_corrupted_db

TABLE_NAMES = (
"ChannelNode", "TorrentState", "TorrentState_TrackerState", "ChannelPeer", "ChannelVote", "TrackerState", "Vsids")
Expand Down Expand Up @@ -126,31 +127,31 @@ def convert_command(offset, batch_size):
def do_migration(self):
result = None # estimated duration in seconds of ChannelNode table copying time
try:

old_table_columns = {}
for table_name in TABLE_NAMES:
old_table_columns[table_name] = get_table_columns(self.old_db_path, table_name)

with contextlib.closing(sqlite3.connect(self.new_db_path)) as connection, connection:
cursor = connection.cursor()
cursor.execute("PRAGMA journal_mode = OFF;")
cursor.execute("PRAGMA synchronous = OFF;")
cursor.execute("PRAGMA foreign_keys = OFF;")
cursor.execute("PRAGMA temp_store = MEMORY;")
cursor.execute("PRAGMA cache_size = -204800;")
cursor.execute(f'ATTACH DATABASE "{self.old_db_path}" as old_db;')

with marking_corrupted_db(self.old_db_path):
old_table_columns = {}
for table_name in TABLE_NAMES:
t1 = now()
cursor.execute("BEGIN TRANSACTION;")
if not self.must_shutdown():
self.convert_table(cursor, table_name, old_table_columns[table_name])
cursor.execute("COMMIT;")
duration = now() - t1
self._logger.info(f"Upgrade: copied table {table_name} in {duration:.2f} seconds")

if table_name == 'ChannelNode':
result = duration
old_table_columns[table_name] = get_table_columns(self.old_db_path, table_name)

with contextlib.closing(sqlite3.connect(self.new_db_path)) as connection, connection:
cursor = connection.cursor()
cursor.execute("PRAGMA journal_mode = OFF;")
cursor.execute("PRAGMA synchronous = OFF;")
cursor.execute("PRAGMA foreign_keys = OFF;")
cursor.execute("PRAGMA temp_store = MEMORY;")
cursor.execute("PRAGMA cache_size = -204800;")
cursor.execute(f'ATTACH DATABASE "{self.old_db_path}" as old_db;')

for table_name in TABLE_NAMES:
t1 = now()
cursor.execute("BEGIN TRANSACTION;")
if not self.must_shutdown():
self.convert_table(cursor, table_name, old_table_columns[table_name])
cursor.execute("COMMIT;")
duration = now() - t1
self._logger.info(f"Upgrade: copied table {table_name} in {duration:.2f} seconds")

if table_name == 'ChannelNode':
result = duration

self.update_status("Synchronizing the upgraded DB to disk, please wait.")
except Exception as e:
Expand Down Expand Up @@ -242,8 +243,9 @@ def get_table_columns(db_path, table_name):


def get_db_version(db_path):
with contextlib.closing(sqlite3.connect(db_path)) as connection, connection:
cursor = connection.cursor()
cursor.execute('SELECT value FROM MiscData WHERE name == "db_version"')
version = int(cursor.fetchone()[0])
return version
with marking_corrupted_db(db_path):
with contextlib.closing(sqlite3.connect(db_path)) as connection, connection:
cursor = connection.cursor()
cursor.execute('SELECT value FROM MiscData WHERE name == "db_version"')
version = int(cursor.fetchone()[0])
return version
10 changes: 10 additions & 0 deletions src/tribler/core/upgrade/upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tribler.core.upgrade.tags_to_knowledge.tags_db import TagDatabase
from tribler.core.utilities.configparser import CallbackConfigParser
from tribler.core.utilities.path_util import Path
from tribler.core.utilities.pony_utils import handle_db_if_corrupted
from tribler.core.utilities.simpledefs import STATEDIR_CHANNELS_DIR, STATEDIR_DB_DIR


Expand Down Expand Up @@ -134,6 +135,7 @@ def upgrade_tags_to_knowledge(self):
def upgrade_pony_db_14to15(self):
self._logger.info('Upgrade Pony DB from version 14 to version 15')
mds_path = self.state_dir / STATEDIR_DB_DIR / 'metadata.db'
handle_db_if_corrupted(mds_path)

mds = MetadataStore(mds_path, self.channels_dir, self.primary_key, disable_sync=True,
check_tables=False, db_version=14) if mds_path.exists() else None
Expand All @@ -147,6 +149,9 @@ def upgrade_pony_db_13to14(self):
mds_path = self.state_dir / STATEDIR_DB_DIR / 'metadata.db'
tagdb_path = self.state_dir / STATEDIR_DB_DIR / 'tags.db'

handle_db_if_corrupted(mds_path)
handle_db_if_corrupted(tagdb_path)

mds = MetadataStore(mds_path, self.channels_dir, self.primary_key, disable_sync=True,
check_tables=False, db_version=13) if mds_path.exists() else None
tag_db = TagDatabase(str(tagdb_path), create_tables=False,
Expand All @@ -166,6 +171,7 @@ def upgrade_pony_db_12to13(self):
self._logger.info('Upgrade Pony DB 12 to 13')
# We have to create the Metadata Store object because Session-managed Store has not been started yet
database_path = self.state_dir / STATEDIR_DB_DIR / 'metadata.db'
handle_db_if_corrupted(database_path)
if database_path.exists():
mds = MetadataStore(database_path, self.channels_dir, self.primary_key,
disable_sync=True, check_tables=False, db_version=12)
Expand All @@ -181,6 +187,7 @@ def upgrade_pony_db_11to12(self):
self._logger.info('Upgrade Pony DB 11 to 12')
# We have to create the Metadata Store object because Session-managed Store has not been started yet
database_path = self.state_dir / STATEDIR_DB_DIR / 'metadata.db'
handle_db_if_corrupted(database_path)
if not database_path.exists():
return
mds = MetadataStore(database_path, self.channels_dir, self.primary_key,
Expand All @@ -197,6 +204,7 @@ def upgrade_pony_db_10to11(self):
self._logger.info('Upgrade Pony DB 10 to 11')
# We have to create the Metadata Store object because Session-managed Store has not been started yet
database_path = self.state_dir / STATEDIR_DB_DIR / 'metadata.db'
handle_db_if_corrupted(database_path)
if not database_path.exists():
return
# code of the migration
Expand All @@ -215,6 +223,7 @@ def upgrade_bw_accounting_db_8to9(self):
to_version = 9

database_path = self.state_dir / STATEDIR_DB_DIR / 'bandwidth.db'
handle_db_if_corrupted(database_path)
if not database_path.exists() or get_db_version(database_path) >= 9:
return # No need to update if the database does not exist or is already updated
self._logger.info('bw8->9')
Expand Down Expand Up @@ -377,6 +386,7 @@ def upgrade_pony_db_8to10(self):
"""
self._logger.info('Upgrading GigaChannel DB from version 8 to 10')
database_path = self.state_dir / STATEDIR_DB_DIR / 'metadata.db'
handle_db_if_corrupted(database_path)
if not database_path.exists() or get_db_version(database_path) >= 10:
# Either no old db exists, or the old db version is up to date - nothing to do
return
Expand Down
102 changes: 95 additions & 7 deletions src/tribler/core/utilities/pony_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import sqlite3
import sys
import threading
import time
Expand All @@ -9,15 +10,19 @@
from dataclasses import dataclass
from io import StringIO
from operator import attrgetter
from pathlib import Path
from types import FrameType
from typing import Callable, Dict, Iterable, Optional, Type
from typing import Callable, Dict, Iterable, Optional, Type, Union
from weakref import WeakSet

from contextlib import contextmanager

from pony import orm
from pony.orm import core
from pony.orm.core import Database, select
from pony.orm.dbproviders import sqlite
from pony.utils import cut_traceback, localbase
from pony.orm.dbproviders.sqlite import SQLitePool
from pony.utils import absolutize_path, cut_traceback, cut_traceback_depth, localbase

SLOW_DB_SESSION_DURATION_THRESHOLD = 1.0

Expand All @@ -28,6 +33,56 @@
StatDict = Dict[Optional[str], core.QueryStat]


class DatabaseIsMalformed(Exception):
    """Signals that a database file was detected as malformed (corrupted).

    Raised by `marking_corrupted_db` in place of the original database error;
    the original error is attached as the exception's cause.
    """


# A database location, accepted either as a plain string or as a `pathlib.Path`.
Filename = Union[str, Path]


def handle_db_if_corrupted(db_filename: Filename):
    """Clean up a database that was previously flagged as corrupted.

    Checks whether the `.is_corrupted` marker file for `db_filename` exists and,
    if it does, hands the database over to `_handle_corrupted_db`. When no
    marker is present, nothing happens.

    :param db_filename: path of the database file to check.
    """
    if not _get_corrupted_db_marker_path(db_filename).exists():
        return
    _handle_corrupted_db(db_filename)


def _handle_corrupted_db(db_filename: Filename):
    """Remove a corrupted database file together with its marker file.

    The database file is removed first and the `.is_corrupted` marker second;
    each one is only unlinked when it actually exists.

    :param db_filename: path of the corrupted database file.
    """
    doomed_paths = (Path(db_filename), _get_corrupted_db_marker_path(db_filename))
    for doomed in doomed_paths:
        if doomed.exists():
            doomed.unlink()


def _get_corrupted_db_marker_path(db_filename: Filename) -> Path:
    """Return the path of the marker file that flags `db_filename` as corrupted.

    The marker lives next to the database and has the same name with the
    `.is_corrupted` suffix appended.
    """
    marker_name = str(db_filename) + '.is_corrupted'
    return Path(marker_name)


@contextmanager
def marking_corrupted_db(db_filename: Filename):
    """Context manager that flags `db_filename` as corrupted on a "malformed" error.

    If the wrapped code raises an exception recognised by
    `_is_malformed_db_exception`, the database is marked as corrupted via
    `_mark_db_as_corrupted` and `DatabaseIsMalformed` is raised, chained from
    the original exception. Any other exception propagates unchanged.

    :param db_filename: path of the database the wrapped code operates on.
    """
    try:
        yield
    except Exception as exc:
        if not _is_malformed_db_exception(exc):
            # Not a corruption problem: let the caller see the original error.
            raise
        _mark_db_as_corrupted(db_filename)
        raise DatabaseIsMalformed(str(exc)) from exc


def _is_malformed_db_exception(exception):
    """Tell whether `exception` reports a malformed database.

    True only when the exception is a Pony `core.DatabaseError` or a
    `sqlite3.DatabaseError` and its message contains the word 'malformed'.
    """
    db_error_types = (core.DatabaseError, sqlite3.DatabaseError)
    if not isinstance(exception, db_error_types):
        return False
    return 'malformed' in str(exception)


def _mark_db_as_corrupted(db_filename: Filename):
    """Create the `.is_corrupted` marker file for an existing database file.

    :param db_filename: path of the database file to flag.
    :raises RuntimeError: if the database file itself does not exist.
    """
    db_exists = Path(db_filename).exists()
    if db_exists:
        _get_corrupted_db_marker_path(db_filename).touch()
        return
    raise RuntimeError(f'Corrupted database file not found: {db_filename!r}')



# pylint: disable=bad-staticmethod-argument
def get_or_create(cls: Type[core.Entity], create_kwargs=None, **kwargs) -> core.Entity:
"""Get or create db entity.
Expand Down Expand Up @@ -271,6 +326,7 @@ def _merge_stats(stats_iter: Iterable[StatDict]) -> StatDict:


class TriblerSQLiteProvider(sqlite.SQLiteProvider):
pool: TriblerPool

# It is impossible to override the __init__ method without breaking the `SQLiteProvider.get_pool` method's logic.
# Therefore, we don't initialize a new attribute `_acquire_time` inside a class constructor method.
Expand Down Expand Up @@ -298,14 +354,45 @@ def release_lock(self):
lock_hold_duration = time.time() - acquire_time
info.lock_hold_total_duration += lock_hold_duration

def set_transaction_mode(self, connection, cache):
    """Run the base `set_transaction_mode` under corrupted-database detection.

    The call is wrapped in `marking_corrupted_db` for this provider's pool
    file; the base method's result is returned unchanged.
    """
    with marking_corrupted_db(self.pool.filename):
        outcome = super().set_transaction_mode(connection, cache)
    return outcome

def execute(self, cursor, sql, arguments=None, returning_id=False):
    """Run the base `execute` under corrupted-database detection.

    The call is wrapped in `marking_corrupted_db` for this provider's pool
    file; the base method's result is returned unchanged.
    """
    with marking_corrupted_db(self.pool.filename):
        outcome = super().execute(cursor, sql, arguments, returning_id)
    return outcome

def mark_db_as_malformed(self):
    """Flag this provider's database file as corrupted.

    Delegates to the module-level `_mark_db_as_corrupted` helper, so the
    marker file name is derived in a single place
    (`_get_corrupted_db_marker_path`) instead of being rebuilt by hand here.

    :raises RuntimeError: if the database file does not exist.
    """
    _mark_db_as_corrupted(self.pool.filename)

def get_pool(self, is_shared_memory_db, filename, create_db=False, **kwargs):
    """Create the `TriblerPool` for this provider.

    For an on-disk database the filename is first made absolute and
    `handle_db_if_corrupted` is called for it before the pool is created.
    Shared-memory and ':memory:' databases skip both steps.
    """
    in_memory = is_shared_memory_db or filename == ':memory:'
    if not in_memory:
        # see the base method for details
        filename = absolutize_path(filename, frame_depth=cut_traceback_depth+5)
        handle_db_if_corrupted(filename)
    return TriblerPool(is_shared_memory_db, filename, create_db, **kwargs)


class TriblerPool(SQLitePool):
    """SQLite connection pool that detects a corrupted database while connecting."""

    def _connect(self):
        """Run the base `_connect` inside `marking_corrupted_db` for the pool's file."""
        with marking_corrupted_db(self.filename):
            outcome = super()._connect()
        return outcome


# Create a TriblerDbSession instance and rebind both `orm.db_session` and
# `orm.core.db_session` to it, so those names refer to this instance.
db_session = TriblerDbSession()
orm.db_session = orm.core.db_session = db_session


class TriblerDatabase(Database):
# If a developer wants to track the slow execution of the database, he should create an instance of TriblerDatabase
# instead of the usual pony.orm.Database.
# TriblerDatabase extends the functionality of the Database class in the following ways:
# * It adds handling of DatabaseError when the database file is corrupted
# * It accumulates and shows statistics on slow database queries

def __init__(self):
databases_to_track.add(self)
Expand All @@ -314,11 +401,12 @@ def __init__(self):
@cut_traceback
def bind(self, **kwargs):
if 'provider' in kwargs:
raise TypeError('You should not explicitly specify the `provider` keyword argument for TriblerDatabase')
provider = kwargs['provider']
if provider != 'sqlite':
raise TypeError(f'Invalid `provider` argument for TriblerDatabase: {provider!r}')
kwargs.pop('provider')

self._bind(TriblerSQLiteProvider, **kwargs)


def track_slow_db_sessions():
    """Set the class-level flag `TriblerDbSession.track_slow_db_sessions` to True.

    NOTE(review): presumably this switches on reporting of slow db sessions;
    confirm against the TriblerDbSession implementation.
    """
    TriblerDbSession.track_slow_db_sessions = True

Loading

0 comments on commit 4ab79e3

Please sign in to comment.