Skip to content

Commit

Permalink
Add upgrade procedure
Browse files Browse the repository at this point in the history
  • Loading branch information
drew2a committed Jan 25, 2022
1 parent 380fc33 commit 3ca6a6d
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 39 deletions.
11 changes: 10 additions & 1 deletion src/tribler-core/run_tribler_upgrader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import signal
import sys

Expand All @@ -10,17 +11,21 @@
from tribler_core.upgrade.upgrade import TriblerUpgrader
from tribler_core.utilities.path_util import Path

logger = logging.getLogger(__name__)


def upgrade_state_dir(root_state_dir: Path,
update_status_callback=None,
interrupt_upgrade_event=None):
logger.info('Upgrade state dir')
# Before any upgrade, prepare a separate state directory for the update version so it does not
# affect the older version state directory. This allows for safe rollback.
version_history = VersionHistory(root_state_dir)
version_history.fork_state_directory_if_necessary()
version_history.save_if_necessary()
state_dir = version_history.code_version.directory
if not state_dir.exists():
logger.info('State dir does not exist. Exit upgrade procedure.')
return

config = TriblerConfig.load(file=state_dir / CONFIG_FILE_NAME, state_dir=state_dir, reset_config_on_error=True)
Expand All @@ -37,15 +42,19 @@ def upgrade_state_dir(root_state_dir: Path,


if __name__ == "__main__":

logger.info('Run')
_upgrade_interrupted_event = []


def interrupt_upgrade(*_):
logger.info('Interrupt upgrade')
_upgrade_interrupted_event.append(True)


def upgrade_interrupted():
return bool(_upgrade_interrupted_event)


signal.signal(signal.SIGINT, interrupt_upgrade)
signal.signal(signal.SIGTERM, interrupt_upgrade)
_root_state_dir = Path(sys.argv[1])
Expand Down
Binary file not shown.
Binary file not shown.
1 change: 1 addition & 0 deletions src/tribler-core/tribler_core/upgrade/config_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def convert_config_to_tribler76(state_dir):
"""
config = ConfigObj(infile=(str(state_dir / 'triblerd.conf')), default_encoding='utf-8')
if 'http_api' in config:
logger.info('Convert config')
config['api'] = {}
config['api']['http_enabled'] = config['http_api'].get('enabled', False)
config['api']['http_port'] = config['http_api'].get('port', -1)
Expand Down
91 changes: 56 additions & 35 deletions src/tribler-core/tribler_core/upgrade/tests/test_upgrader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
from tribler_core.components.bandwidth_accounting.db.database import BandwidthDatabase
from tribler_core.components.metadata_store.db.orm_bindings.channel_metadata import CHANNEL_DIR_NAME_LENGTH
from tribler_core.components.metadata_store.db.store import CURRENT_DB_VERSION, MetadataStore
from tribler_core.components.tag.db.tag_db import TagDatabase
from tribler_core.notifier import Notifier
from tribler_core.tests.tools.common import TESTS_DATA_DIR
from tribler_core.upgrade.db8_to_db10 import calc_progress
from tribler_core.upgrade.upgrade import TriblerUpgrader, cleanup_noncompliant_channel_torrents
from tribler_core.utilities.configparser import CallbackConfigParser

# pylint: disable=redefined-outer-name, protected-access

@pytest.fixture
def state_dir(tmp_path):
Expand Down Expand Up @@ -46,16 +48,28 @@ def notifier():
return Notifier()


def test_upgrade_pony_db_complete(upgrader, channels_dir, state_dir, trustchain_keypair): # pylint: disable=W0621
@pytest.fixture
def mds_path(state_dir):
return state_dir / 'sqlite/metadata.db'


def _copy(source_name, target):
source = TESTS_DATA_DIR / 'upgrade_databases' / source_name
shutil.copyfile(source, target)


def test_upgrade_pony_db_complete(upgrader, channels_dir, state_dir, trustchain_keypair,
mds_path): # pylint: disable=W0621
"""
Test complete update sequence for Pony DB (e.g. 6->7->8)
"""
old_db_sample = TESTS_DATA_DIR / 'upgrade_databases' / 'pony_v8.db'
old_database_path = state_dir / 'sqlite' / 'metadata.db'
shutil.copyfile(old_db_sample, old_database_path)
tags_path = state_dir / 'sqlite/tags.db'

_copy(source_name='pony_v8.db', target=mds_path)
_copy(source_name='tags_v13.db', target=tags_path)

upgrader.run()
mds = MetadataStore(old_database_path, channels_dir, trustchain_keypair)
mds = MetadataStore(mds_path, channels_dir, trustchain_keypair)
db = mds._db # pylint: disable=protected-access

existing_indexes = [
Expand All @@ -80,7 +94,7 @@ def test_upgrade_pony_db_complete(upgrader, channels_dir, state_dir, trustchain_
with db_session:
assert mds.TorrentMetadata.select().count() == 23
assert mds.ChannelMetadata.select().count() == 2
assert int(mds.MiscData.get(name="db_version").value) == CURRENT_DB_VERSION
assert mds.get_value("db_version") == str(CURRENT_DB_VERSION)
for index_name in existing_indexes:
assert list(db.execute(f'PRAGMA index_info("{index_name}")'))
for index_name in removed_indexes:
Expand Down Expand Up @@ -112,57 +126,65 @@ def test_delete_noncompliant_state(tmpdir):
assert CHANNEL_DIR_NAME_LENGTH == len(pstate.get('state', 'metainfo')['info']['name'])


def test_upgrade_pony_8to10(upgrader, channels_dir, state_dir, trustchain_keypair): # pylint: disable=W0621
old_db_sample = TESTS_DATA_DIR / 'upgrade_databases' / 'pony_v8.db'
database_path = state_dir / 'sqlite' / 'metadata.db'
shutil.copyfile(old_db_sample, database_path)
def test_upgrade_pony_8to10(upgrader, channels_dir, mds_path, trustchain_keypair): # pylint: disable=W0621
_copy('pony_v8.db', mds_path)

upgrader.upgrade_pony_db_8to10()
mds = MetadataStore(database_path, channels_dir, trustchain_keypair, check_tables=False, db_version=10)
mds = MetadataStore(mds_path, channels_dir, trustchain_keypair, check_tables=False, db_version=10)
with db_session:
assert int(mds.MiscData.get(name="db_version").value) == 10
assert mds.get_value("db_version") == '10'
assert mds.ChannelNode.select().count() == 23
mds.shutdown()


@pytest.mark.asyncio
async def test_upgrade_pony_10to11(upgrader, channels_dir, state_dir, trustchain_keypair):
old_db_sample = TESTS_DATA_DIR / 'upgrade_databases' / 'pony_v10.db'
database_path = state_dir / 'sqlite' / 'metadata.db'
shutil.copyfile(old_db_sample, database_path)
async def test_upgrade_pony_10to11(upgrader, channels_dir, mds_path, trustchain_keypair):
_copy('pony_v10.db', mds_path)

upgrader.upgrade_pony_db_10to11()
mds = MetadataStore(database_path, channels_dir, trustchain_keypair, check_tables=False, db_version=11)
mds = MetadataStore(mds_path, channels_dir, trustchain_keypair, check_tables=False, db_version=11)
with db_session:
# pylint: disable=protected-access
assert upgrader.column_exists_in_table(mds._db, 'TorrentState', 'self_checked')
assert int(mds.MiscData.get(name="db_version").value) == 11
assert mds.get_value("db_version") == '11'
mds.shutdown()


def test_upgrade_pony11to12(upgrader, channels_dir, state_dir, trustchain_keypair):
old_db_sample = TESTS_DATA_DIR / 'upgrade_databases' / 'pony_v11.db'
database_path = state_dir / 'sqlite' / 'metadata.db'
shutil.copyfile(old_db_sample, database_path)
def test_upgrade_pony11to12(upgrader, channels_dir, mds_path, trustchain_keypair):
_copy('pony_v11.db', mds_path)

upgrader.upgrade_pony_db_11to12()
mds = MetadataStore(database_path, channels_dir, trustchain_keypair, check_tables=False, db_version=11)
mds = MetadataStore(mds_path, channels_dir, trustchain_keypair, check_tables=False, db_version=11)
with db_session:
# pylint: disable=protected-access
assert upgrader.column_exists_in_table(mds._db, 'ChannelNode', 'json_text')
assert upgrader.column_exists_in_table(mds._db, 'ChannelNode', 'binary_data')
assert upgrader.column_exists_in_table(mds._db, 'ChannelNode', 'data_type')
assert int(mds.MiscData.get(name="db_version").value) == 12
assert mds.get_value("db_version") == '12'
mds.shutdown()


def test_upgrade_pony12to13(upgrader, channels_dir, state_dir, trustchain_keypair): # pylint: disable=W0621
old_db_sample = TESTS_DATA_DIR / 'upgrade_databases' / 'pony_v12.db'
database_path = state_dir / 'sqlite' / 'metadata.db'
shutil.copyfile(old_db_sample, database_path)
def test_upgrade_pony13to14(upgrader: TriblerUpgrader, state_dir, channels_dir, trustchain_keypair, mds_path):
tags_path = state_dir / 'sqlite/tags.db'

_copy(source_name='pony_v13.db', target=mds_path)
_copy(source_name='tags_v13.db', target=tags_path)

upgrader.upgrade_pony_db_13to14()
mds = MetadataStore(mds_path, channels_dir, trustchain_keypair, check_tables=False)
tags = TagDatabase(str(tags_path), check_tables=False)

with db_session:
assert upgrader.column_exists_in_table(mds._db, 'ChannelNode', 'tag_version')
assert upgrader.column_exists_in_table(tags.instance, 'TorrentTagOp', 'auto_generated')
assert mds.get_value('db_version') == '14'


def test_upgrade_pony12to13(upgrader, channels_dir, mds_path, trustchain_keypair): # pylint: disable=W0621
_copy('pony_v12.db', mds_path)

upgrader.upgrade_pony_db_12to13()
mds = MetadataStore(database_path, channels_dir, trustchain_keypair, check_tables=False, db_version=12)
mds = MetadataStore(mds_path, channels_dir, trustchain_keypair, check_tables=False, db_version=12)
db = mds._db # pylint: disable=protected-access

existing_indexes = [
Expand All @@ -187,7 +209,7 @@ def test_upgrade_pony12to13(upgrader, channels_dir, state_dir, trustchain_keypai
with db_session:
assert mds.TorrentMetadata.select().count() == 23
assert mds.ChannelMetadata.select().count() == 2
assert int(mds.MiscData.get(name="db_version").value) == CURRENT_DB_VERSION
assert mds.get_value("db_version") == '13'
for index_name in existing_indexes:
assert list(db.execute(f'PRAGMA index_info("{index_name}")')), index_name
for index_name in removed_indexes:
Expand Down Expand Up @@ -216,13 +238,12 @@ def test_calc_progress():


@pytest.mark.asyncio
async def test_upgrade_bw_accounting_db_8to9(upgrader, channels_dir, state_dir, trustchain_keypair):
old_db_sample = TESTS_DATA_DIR / 'upgrade_databases' / 'bandwidth_v8.db'
database_path = state_dir / 'sqlite' / 'bandwidth.db'
shutil.copyfile(old_db_sample, database_path)
async def test_upgrade_bw_accounting_db_8to9(upgrader, state_dir, trustchain_keypair):
bandwidth_path = state_dir / 'sqlite/bandwidth.db'
_copy('bandwidth_v8.db', bandwidth_path)

upgrader.upgrade_bw_accounting_db_8to9()
db = BandwidthDatabase(database_path, trustchain_keypair.key.pk)
db = BandwidthDatabase(bandwidth_path, trustchain_keypair.key.pk)
with db_session:
assert not list(select(tx for tx in db.BandwidthTransaction))
assert not list(select(item for item in db.BandwidthHistory))
Expand Down
52 changes: 49 additions & 3 deletions src/tribler-core/tribler_core/upgrade/upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import os
import shutil
from configparser import MissingSectionHeaderError, ParsingError
from types import SimpleNamespace
from typing import Optional

from ipv8.keyvault.private.libnaclkey import LibNaCLSK

Expand All @@ -17,11 +19,13 @@
sql_create_partial_index_channelnode_subscribed,
sql_create_partial_index_torrentstate_last_check,
)
from tribler_core.components.tag.db.tag_db import TagDatabase
from tribler_core.upgrade.config_converter import convert_config_to_tribler76
from tribler_core.upgrade.db8_to_db10 import PonyToPonyMigration, get_db_version
from tribler_core.utilities.configparser import CallbackConfigParser
from tribler_core.utilities.path_util import Path

# pylint: disable=protected-access

def cleanup_noncompliant_channel_torrents(state_dir):
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -87,13 +91,30 @@ def run(self):
"""
Run the upgrader if it is enabled in the config.
"""
self._logger.info('Run')

self.upgrade_pony_db_8to10()
self.upgrade_pony_db_10to11()
convert_config_to_tribler76(self.state_dir)
self.upgrade_bw_accounting_db_8to9()
self.upgrade_pony_db_11to12()
self.upgrade_pony_db_12to13()
self.upgrade_pony_db_13to14()

def upgrade_pony_db_13to14(self):
mds_path = self.state_dir / STATEDIR_DB_DIR / 'metadata.db'
tagdb_path = self.state_dir / STATEDIR_DB_DIR / 'tags.db'

mds = MetadataStore(mds_path, self.channels_dir, self.trustchain_keypair, disable_sync=True,
check_tables=False, db_version=13) if mds_path.exists() else None
tagdb = TagDatabase(str(tagdb_path), check_tables=False) if tagdb_path.exists() else None

self.do_upgrade_pony_db_13to14(mds, tagdb)

if mds:
mds.shutdown()
if tagdb:
tagdb.shutdown()

def upgrade_pony_db_12to13(self):
"""
Expand Down Expand Up @@ -149,6 +170,7 @@ def upgrade_bw_accounting_db_8to9(self):
database_path = self.state_dir / STATEDIR_DB_DIR / 'bandwidth.db'
if not database_path.exists() or get_db_version(database_path) >= 9:
return # No need to update if the database does not exist or is already updated
self._logger.info('bw8->9')
db = BandwidthDatabase(database_path, self.trustchain_keypair.key.pk)

# Wipe all transactions and bandwidth history
Expand Down Expand Up @@ -183,7 +205,7 @@ def do_upgrade_pony_db_12to13(self, mds):
db_version = mds.MiscData.get(name="db_version")
if int(db_version.value) != from_version:
return

self._logger.info(f'{from_version}->{to_version}')
db.execute('DROP INDEX IF EXISTS idx_channelnode__public_key')
db.execute('DROP INDEX IF EXISTS idx_channelnode__status')
db.execute('DROP INDEX IF EXISTS idx_channelnode__size')
Expand All @@ -206,6 +228,29 @@ def do_upgrade_pony_db_12to13(self, mds):

db_version.value = str(to_version)

def do_upgrade_pony_db_13to14(self, mds: Optional[MetadataStore], tags: Optional[TagDatabase]):
def _alter(db, table_name, column_name, column_type):
if not self.column_exists_in_table(db, table_name, column_name):
db.execute(f'ALTER TABLE "{table_name}" ADD "{column_name}" {column_type} DEFAULT 0')

if not mds:
return

version = SimpleNamespace(current='13', next='14')
with db_session:
db_version = mds.get_value(key='db_version')
if db_version != version.current:
return

self._logger.info(f'{version.current}->{version.next}')

_alter(db=mds._db, table_name='ChannelNode', column_name='tag_version', column_type='INT')
_alter(db=tags.instance, table_name='TorrentTagOp', column_name='auto_generated', column_type='BOOLEAN')

tags.instance.commit()
mds._db.commit()
mds.set_value(key='db_version', value=version.next)

def do_upgrade_pony_db_11to12(self, mds):
from_version = 11
to_version = 12
Expand All @@ -214,7 +259,7 @@ def do_upgrade_pony_db_11to12(self, mds):
db_version = mds.MiscData.get(name="db_version")
if int(db_version.value) != from_version:
return

self._logger.info(f'{from_version}->{to_version}')
# Just in case, we skip altering table if the column is somehow already there
table_name = "ChannelNode"
new_columns = [("json_text", "TEXT1"),
Expand All @@ -238,6 +283,7 @@ def do_upgrade_pony_db_10to11(self, mds):
db_version = mds.MiscData.get(name="db_version")
if int(db_version.value) != from_version:
return
self._logger.info(f'{from_version}->{to_version}')

# Just in case, we skip altering table if the column is somehow already there
table_name = "TorrentState"
Expand All @@ -260,7 +306,7 @@ def upgrade_pony_db_8to10(self):
if not database_path.exists() or get_db_version(database_path) >= 10:
# Either no old db exists, or the old db version is up to date - nothing to do
return

self._logger.info('8->10')
# Otherwise, start upgrading
self.update_status("STARTING")
tmp_database_path = database_path.parent / 'metadata_upgraded.db'
Expand Down

0 comments on commit 3ca6a6d

Please sign in to comment.