Skip to content

Commit

Permalink
Replace sqlite_utils with db_corruption_handling.sqlite_replacement
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlovsky committed Nov 23, 2023
1 parent 3cc66a8 commit 2a9fa6a
Show file tree
Hide file tree
Showing 18 changed files with 232 additions and 204 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

from tribler.core.components.bandwidth_accounting.db import history, misc, transaction as db_transaction
from tribler.core.components.bandwidth_accounting.db.transaction import BandwidthTransactionData
from tribler.core.utilities.db_corruption_handling.base import handle_db_if_corrupted
from tribler.core.utilities.pony_utils import TriblerDatabase
from tribler.core.utilities.sqlite_utils import handle_db_if_corrupted
from tribler.core.utilities.utilities import MEMORY_DB


Expand Down
2 changes: 1 addition & 1 deletion src/tribler/core/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from tribler.core.components.exceptions import ComponentStartupException, MissedDependency, NoneComponent
from tribler.core.components.reporter.exception_handler import default_core_exception_handler
from tribler.core.sentry_reporter.sentry_reporter import SentryReporter
from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted
from tribler.core.utilities.exit_codes import EXITCODE_DATABASE_IS_CORRUPTED
from tribler.core.utilities.sqlite_utils import DatabaseIsCorrupted
from tribler.core.utilities.process_manager import get_global_process_manager

if TYPE_CHECKING:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from tribler.core.components.metadata_store.db.orm_bindings.channel_node import COMMITTED
from tribler.core.components.metadata_store.db.serialization import CHANNEL_TORRENT
from tribler.core.components.metadata_store.db.store import MetadataStore
from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted
from tribler.core.utilities.notifier import Notifier
from tribler.core.utilities.pony_utils import run_threaded
from tribler.core.utilities.sqlite_utils import DatabaseIsCorrupted
from tribler.core.utilities.simpledefs import DownloadStatus
from tribler.core.utilities.unicode import hexlify

Expand Down
2 changes: 1 addition & 1 deletion src/tribler/core/components/metadata_store/db/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@
from tribler.core.components.metadata_store.remote_query_community.payload_checker import process_payload
from tribler.core.components.torrent_checker.torrent_checker.dataclasses import HealthInfo
from tribler.core.exceptions import InvalidSignatureException
from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted, handle_db_if_corrupted
from tribler.core.utilities.notifier import Notifier
from tribler.core.utilities.path_util import Path
from tribler.core.utilities.pony_utils import TriblerDatabase, get_max, get_or_create, run_threaded
from tribler.core.utilities.search_utils import torrent_rank
from tribler.core.utilities.sqlite_utils import DatabaseIsCorrupted, handle_db_if_corrupted
from tribler.core.utilities.unicode import hexlify
from tribler.core.utilities.utilities import MEMORY_DB

Expand Down
2 changes: 1 addition & 1 deletion src/tribler/core/components/reporter/exception_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from tribler.core.components.exceptions import ComponentStartupException
from tribler.core.components.reporter.reported_error import ReportedError
from tribler.core.sentry_reporter.sentry_reporter import SentryReporter
from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted
from tribler.core.utilities.exit_codes import EXITCODE_DATABASE_IS_CORRUPTED
from tribler.core.utilities.sqlite_utils import DatabaseIsCorrupted
from tribler.core.utilities.process_manager import get_global_process_manager

# There are some errors that we are ignoring.
Expand Down
6 changes: 3 additions & 3 deletions src/tribler/core/upgrade/db8_to_db10.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pony.orm import db_session

from tribler.core.components.metadata_store.db.store import MetadataStore
from tribler.core.utilities import sqlite_utils
from tribler.core.utilities.db_corruption_handling import sqlite_replacement

TABLE_NAMES = (
"ChannelNode", "TorrentState", "TorrentState_TrackerState", "ChannelPeer", "ChannelVote", "TrackerState", "Vsids")
Expand Down Expand Up @@ -130,7 +130,7 @@ def do_migration(self):
for table_name in TABLE_NAMES:
old_table_columns[table_name] = get_table_columns(self.old_db_path, table_name)

with contextlib.closing(sqlite_utils.connect(self.new_db_path)) as connection:
with contextlib.closing(sqlite_replacement.connect(self.new_db_path)) as connection:
with connection:
cursor = connection.cursor()
cursor.execute("PRAGMA journal_mode = OFF;")
Expand Down Expand Up @@ -234,7 +234,7 @@ def calc_progress(duration_now, duration_half=60.0):


def get_table_columns(db_path, table_name):
with contextlib.closing(sqlite_utils.connect(db_path)) as connection, connection:
with contextlib.closing(sqlite_replacement.connect(db_path)) as connection, connection:
cursor = connection.cursor()
cursor.execute(f'SELECT * FROM {table_name} LIMIT 1')
names = [description[0] for description in cursor.description]
Expand Down
2 changes: 1 addition & 1 deletion src/tribler/core/upgrade/tests/test_upgrader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from tribler.core.upgrade.upgrade import TriblerUpgrader, catch_db_is_corrupted_exception, \
cleanup_noncompliant_channel_torrents
from tribler.core.utilities.configparser import CallbackConfigParser
from tribler.core.utilities.sqlite_utils import DatabaseIsCorrupted
from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted
from tribler.core.utilities.utilities import random_infohash


Expand Down
2 changes: 1 addition & 1 deletion src/tribler/core/upgrade/upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
from tribler.core.upgrade.tags_to_knowledge.migration import MigrationTagsToKnowledge
from tribler.core.upgrade.tags_to_knowledge.tags_db import TagDatabase
from tribler.core.utilities.configparser import CallbackConfigParser
from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted
from tribler.core.utilities.path_util import Path
from tribler.core.utilities.pony_utils import get_db_version
from tribler.core.utilities.sqlite_utils import DatabaseIsCorrupted
from tribler.core.utilities.simpledefs import STATEDIR_CHANNELS_DIR, STATEDIR_DB_DIR


Expand Down
Empty file.
59 changes: 59 additions & 0 deletions src/tribler/core/utilities/db_corruption_handling/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import annotations

import logging
import sqlite3
from contextlib import contextmanager
from pathlib import Path
from typing import Union

logger = logging.getLogger('db_corruption_handling')


class DatabaseIsCorrupted(Exception):
pass


@contextmanager
def handling_malformed_db_error(db_filepath: Path):
# Used in all methods of Connection and Cursor classes where the database corruption error can occur
try:
yield
except Exception as e:
if _is_malformed_db_exception(e):
_mark_db_as_corrupted(db_filepath)
raise DatabaseIsCorrupted(str(db_filepath)) from e
raise


def handle_db_if_corrupted(db_filename: Union[str, Path]):
# Checks if the database is marked as corrupted and handles it by removing the database file and the marker file
db_path = Path(db_filename)
marker_path = get_corrupted_db_marker_path(db_path)
if marker_path.exists():
_handle_corrupted_db(db_path)


def get_corrupted_db_marker_path(db_filepath: Path) -> Path:
return Path(str(db_filepath) + '.is_corrupted')


def _is_malformed_db_exception(exception):
return isinstance(exception, sqlite3.DatabaseError) and 'malformed' in str(exception)


def _mark_db_as_corrupted(db_filepath: Path):
# Creates a new `*.is_corrupted` marker file alongside the database file
marker_path = get_corrupted_db_marker_path(db_filepath)
marker_path.touch()


def _handle_corrupted_db(db_path: Path):
# Removes the database file and the marker file
if db_path.exists():
logger.warning(f'Database file was marked as corrupted, removing it: {db_path}')
db_path.unlink()

marker_path = get_corrupted_db_marker_path(db_path)
if marker_path.exists():
logger.warning(f'Removing the corrupted database marker: {marker_path}')
marker_path.unlink()
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from __future__ import annotations

import logging
import sqlite3
import sys
from contextlib import contextmanager
from pathlib import Path
from sqlite3 import DataError, DatabaseError, Error, IntegrityError, InterfaceError, InternalError, NotSupportedError, \
OperationalError, ProgrammingError, Warning, sqlite_version_info # pylint: disable=unused-import, redefined-builtin
from typing import Any, Generator, List, Literal, TypeVar, Union

from tribler.core.utilities.db_corruption_handling.base import handling_malformed_db_error


# This module serves as a replacement to the sqlite3 module and handles the case when the database is corrupted.
Expand All @@ -20,71 +19,18 @@
# After that, the database is recreated upon the next attempt to connect to it.


logger = logging.getLogger(__name__)


class DatabaseIsCorrupted(Exception):
pass


@contextmanager
def _handling_malformed_db_error(db_filepath: Path):
# Used in all methods of Connection and Cursor classes where the database corruption error can occur
try:
yield
except Exception as e:
if _is_malformed_db_exception(e):
_mark_db_as_corrupted(db_filepath)
raise DatabaseIsCorrupted(str(db_filepath)) from e
raise


def _is_malformed_db_exception(exception):
return isinstance(exception, sqlite3.DatabaseError) and 'malformed' in str(exception)


def _mark_db_as_corrupted(db_filepath: Path):
# Creates a new `*.is_corrupted` marker file alongside the database file
marker_path = get_corrupted_db_marker_path(db_filepath)
marker_path.touch()


def get_corrupted_db_marker_path(db_filepath: Path) -> Path:
return Path(str(db_filepath) + '.is_corrupted')


def handle_db_if_corrupted(db_filename: Union[str, Path]):
# Checks if the database is marked as corrupted and handles it by removing the database file and the marker file
db_path = Path(db_filename)
marker_path = get_corrupted_db_marker_path(db_path)
if marker_path.exists():
_handle_corrupted_db(db_path)


def _handle_corrupted_db(db_path: Path):
# Removes the database file and the marker file
if db_path.exists():
logger.warning(f'Database file was marked as corrupted, removing it: {db_path}')
db_path.unlink()

marker_path = get_corrupted_db_marker_path(db_path)
if marker_path.exists():
logger.warning(f'Removing the corrupted database marker: {marker_path}')
marker_path.unlink()


def connect(db_filename: str, **kwargs) -> sqlite3.Connection:
# Replaces the sqlite3.connect function
kwargs['factory'] = Connection
with _handling_malformed_db_error(Path(db_filename)):
with handling_malformed_db_error(Path(db_filename)):
return sqlite3.connect(db_filename, **kwargs)


def _add_method_wrapper_that_handles_malformed_db_exception(cls, method_name: str):
# Creates a wrapper for the given method that handles the case when the database is corrupted

def wrapper(self, *args, **kwargs):
with _handling_malformed_db_error(self._db_filepath):
with handling_malformed_db_error(self._db_filepath):
return getattr(super(cls, self), method_name)(*args, **kwargs)

wrapper.__name__ = method_name
Expand Down Expand Up @@ -128,7 +74,7 @@ def iterdump(self):
raise NotImplementedError

def blobopen(self, *args, **kwargs) -> Blob: # Works for Python >= 3.11
with _handling_malformed_db_error(self._db_filepath):
with handling_malformed_db_error(self._db_filepath):
blob = super().blobopen(*args, **kwargs)
return Blob(blob, self._db_filepath)

Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest

from tribler.core.utilities.db_corruption_handling.sqlite_replacement import connect


@pytest.fixture(name='db_filepath')
def db_filepath_fixture(tmp_path):
return tmp_path / 'test.db'


@pytest.fixture(name='connection')
def connection_fixture(db_filepath):
connection = connect(str(db_filepath))
yield connection
connection.close()
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import sqlite3
from pathlib import Path
from unittest.mock import Mock, patch

import pytest


from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted, handle_db_if_corrupted, \
handling_malformed_db_error

malformed_error = sqlite3.DatabaseError('database disk image is malformed')


def test_handling_malformed_db_error__no_error(db_filepath):
# If no error is raised, the database should not be marked as corrupted
with handling_malformed_db_error(db_filepath):
pass

assert not Path(str(db_filepath) + '.is_corrupted').exists()


def test_handling_malformed_db_error__malformed_error(db_filepath):
# Malformed database errors should be handled by marking the database as corrupted
with pytest.raises(DatabaseIsCorrupted):
with handling_malformed_db_error(db_filepath):
raise malformed_error

assert Path(str(db_filepath) + '.is_corrupted').exists()


def test_handling_malformed_db_error__other_error(db_filepath):
# Other errors should not be handled like malformed database errors
class TestError(Exception):
pass

with pytest.raises(TestError):
with handling_malformed_db_error(db_filepath):
raise TestError()

assert not Path(str(db_filepath) + '.is_corrupted').exists()


def test_handle_db_if_corrupted__corrupted(db_filepath: Path):
# If the corruption marker is found, the corrupted database file is removed
marker_path = Path(str(db_filepath) + '.is_corrupted')
marker_path.touch()

handle_db_if_corrupted(db_filepath)
assert not db_filepath.exists()
assert not marker_path.exists()


@patch('tribler.core.utilities.db_corruption_handling.base._handle_corrupted_db')
def test_handle_db_if_corrupted__not_corrupted(handle_corrupted_db: Mock, db_filepath: Path):
# If the corruption marker is not found, the handling of the database is not performed
handle_db_if_corrupted(db_filepath)
handle_corrupted_db.assert_not_called()
Loading

0 comments on commit 2a9fa6a

Please sign in to comment.