Skip to content

Commit

Permalink
Airflow db downgrade cli command (#21596)
Browse files Browse the repository at this point in the history
  • Loading branch information
dstandish authored Feb 27, 2022
1 parent e93820b commit 03d6aaa
Show file tree
Hide file tree
Showing 5 changed files with 250 additions and 9 deletions.
48 changes: 47 additions & 1 deletion airflow/cli/cli_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,10 @@ def string_list_type(val):
ARG_STDOUT = Arg(("--stdout",), help="Redirect stdout to this file")
ARG_LOG_FILE = Arg(("-l", "--log-file"), help="Location of the log file")
ARG_YES = Arg(
("-y", "--yes"), help="Do not prompt to confirm. Use with care!", action="store_true", default=False
("-y", "--yes"),
help="Do not prompt to confirm. Use with care!",
action="store_true",
default=False,
)
ARG_OUTPUT = Arg(
(
Expand Down Expand Up @@ -510,12 +513,42 @@ def string_list_type(val):
ARG_CFG_PATH = Arg(("--cfg-path",), help="Path to config file to use instead of airflow.cfg")
ARG_MAP_INDEX = Arg(('--map-index',), type=int, default=-1, help="Mapped task index")


# database
# Integer timeout option (default 60); read as ``args.migration_wait_timeout``
# by the ``check_migrations`` command.
ARG_MIGRATION_TIMEOUT = Arg(
("-t", "--migration-wait-timeout"),
help="timeout to wait for db to migrate ",
type=int,
default=60,
)
# Target airflow version for ``db downgrade``; the command rejects it when
# ``--revision`` is also supplied.
ARG_DB_VERSION = Arg(
(
"-n",
"--version",
),
help="The airflow version to downgrade to",
)
# Starting airflow version for ``db downgrade``; the command only accepts it
# together with ``--sql-only``.
ARG_DB_FROM_VERSION = Arg(
("--from-version",),
help="(Optional) if generating sql, may supply a _from_ version",
)
# Target alembic revision for ``db downgrade``; alternative to ``--version``.
ARG_DB_REVISION = Arg(
(
"-r",
"--revision",
),
help="The airflow revision to downgrade to",
)
# Starting alembic revision for ``db downgrade``; the command only accepts it
# together with ``--sql-only`` and not together with ``--from-version``.
ARG_DB_FROM_REVISION = Arg(
("--from-revision",),
help="(Optional) if generating sql, may supply a _from_ revision",
)
# Boolean flag (off by default): print the migration SQL instead of running it.
ARG_DB_SQL = Arg(
("-s", "--sql-only"),
help="Don't actually run migrations; just print out sql scripts for offline migration.",
action="store_true",
default=False,
)

# webserver
ARG_PORT = Arg(
Expand Down Expand Up @@ -1327,6 +1360,19 @@ class GroupCommand(NamedTuple):
func=lazy_load_command('airflow.cli.commands.db_command.upgradedb'),
args=(ARG_VERSION_RANGE, ARG_REVISION_RANGE),
),
ActionCommand(
name='downgrade',
help="Downgrade the schema of the metadata database",
func=lazy_load_command('airflow.cli.commands.db_command.downgrade'),
args=(
ARG_DB_REVISION,
ARG_DB_VERSION,
ARG_DB_SQL,
ARG_YES,
ARG_DB_FROM_REVISION,
ARG_DB_FROM_VERSION,
),
),
ActionCommand(
name='shell',
help="Runs a shell to access the database",
Expand Down
46 changes: 46 additions & 0 deletions airflow/cli/commands/db_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from airflow import settings
from airflow.exceptions import AirflowException
from airflow.utils import cli as cli_utils, db
from airflow.utils.db import REVISION_HEADS_MAP
from airflow.utils.db_cleanup import config_dict, run_cleanup
from airflow.utils.process_utils import execute_interactive

Expand Down Expand Up @@ -50,6 +51,51 @@ def upgradedb(args):
print("Upgrades done")


@cli_utils.action_cli(check_db=False)
def downgrade(args):
    """Downgrade the metadata database to an earlier schema revision.

    Validates the option combination, resolves the target (and optional starting)
    alembic revision, then either applies the downgrade or, with ``--sql-only``,
    only generates the SQL. Unless ``--yes`` or ``--sql-only`` is given, the user
    is prompted for confirmation first.

    :param args: parsed CLI namespace; reads ``revision``, ``version``,
        ``from_revision``, ``from_version``, ``sql_only`` and ``yes``.
    :raises SystemExit: on an invalid option combination, on a version that is
        not present in ``REVISION_HEADS_MAP``, or when the prompt is declined.
    """
    if args.revision and args.version:
        raise SystemExit("Cannot supply both `revision` and `version`.")
    if args.from_version and args.from_revision:
        raise SystemExit("`--from-revision` may not be combined with `--from-version`")
    if (args.from_revision or args.from_version) and not args.sql_only:
        raise SystemExit("Args `--from-revision` and `--from-version` may only be used with `--sql-only`")
    if not (args.version or args.revision):
        raise SystemExit("Must provide either revision or version.")

    # Resolve the optional starting point (only meaningful when generating SQL).
    from_revision = None
    if args.from_revision:
        from_revision = args.from_revision
    elif args.from_version:
        from_revision = REVISION_HEADS_MAP.get(args.from_version)
        if not from_revision:
            # Report the value that failed the lookup, i.e. ``--from-version``
            # (not ``--version``, which may not even have been supplied).
            raise SystemExit(f"Unknown version {args.from_version!r} supplied as `--from-version`.")

    # Resolve the target; the checks above guarantee exactly one of version/revision is set.
    if args.version:
        revision = REVISION_HEADS_MAP.get(args.version)
        if not revision:
            raise SystemExit(f"Downgrading to version {args.version} is not supported.")
    elif args.revision:
        revision = args.revision

    if not args.sql_only:
        print("Performing downgrade with database " + repr(settings.engine.url))
    else:
        print("Generating sql for downgrade -- downgrade commands will *not* be submitted.")

    # SQL generation changes nothing, so it never prompts; otherwise require
    # ``--yes`` or an explicit "y"/"Y" answer.
    if args.sql_only or (
        args.yes
        or input(
            "\nWarning: About to reverse schema migrations for the airflow metastore. "
            "Please ensure you have backed up your database before any upgrade or "
            "downgrade operation. Proceed? (y/n)\n"
        ).upper()
        == "Y"
    ):
        db.downgrade(to_revision=revision, from_revision=from_revision, sql=args.sql_only)
        if not args.sql_only:
            print("Downgrade complete")
    else:
        raise SystemExit("Cancelled")


def check_migrations(args):
"""Function to wait for all airflow migrations to complete. Used for launching airflow in k8s"""
db.check_migrations(timeout=args.migration_wait_timeout)
Expand Down
67 changes: 63 additions & 4 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,12 +1034,12 @@ def _check_migration_errors(session: Session = NEW_SESSION) -> Iterable[str]:
session.commit()


def _offline_migration(migration_func: Callable, config, revision):
    """Run a migration function in offline mode (``sql=True``) with logging silenced.

    :param migration_func: callable invoked as ``migration_func(config, revision, sql=True)``;
        callers pass alembic's ``command.upgrade`` or ``command.downgrade``.
    :param config: alembic config object forwarded to ``migration_func``.
    :param revision: revision (or revision range) forwarded to ``migration_func``.
    """
    log.info("Running offline migrations for revision range %s", revision)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # Silence all logging while the migration function runs.
        logging.disable(logging.CRITICAL)
        try:
            migration_func(config, revision, sql=True)
        finally:
            # Always re-enable logging, even when the migration raises; otherwise
            # logging would stay disabled for the rest of the process.
            logging.disable(logging.NOTSET)


Expand Down Expand Up @@ -1134,10 +1134,10 @@ def upgradedb(
revision = _validate_version_range(command, config, version_range)
if not revision:
return
return _offline_migration(command, config, revision)
return _offline_migration(command.upgrade, config, revision)
elif revision_range:
_validate_revision(command, config, revision_range)
return _offline_migration(command, config, revision_range)
return _offline_migration(command.upgrade, config, revision_range)

errors_seen = False
for err in _check_migration_errors(session=session):
Expand Down Expand Up @@ -1172,6 +1172,65 @@ def resetdb(session: Session = NEW_SESSION):
initdb(session=session)


@provide_session
def downgrade(to_revision, sql=False, from_revision=None, session: Session = NEW_SESSION):
    """
    Downgrade the airflow metastore schema to a prior version.

    :param to_revision: The alembic revision to downgrade *to*.
    :param sql: if True, print sql statements but do not run them
    :param from_revision: if supplied, alembic revision to downgrade *from*. This may only
        be used in conjunction with ``sql=True`` because if we actually run the commands,
        we should only downgrade from the *current* revision.
    :param session: sqlalchemy session for connection to airflow metadata database
    :raises ValueError: if ``from_revision`` is given while ``sql`` is false
    :raises RuntimeError: if ``settings.SQL_ALCHEMY_CONN`` is not set
    :raises SystemExit: with status 1 if ``_check_migration_errors`` yields any error
    """
    if from_revision and not sql:
        raise ValueError(
            "`from_revision` can't be combined with `sql=False`. When actually "
            "applying a downgrade (instead of just generating sql), we always "
            "downgrade from current revision."
        )

    if not settings.SQL_ALCHEMY_CONN:
        raise RuntimeError("The settings.SQL_ALCHEMY_CONN not set.")

    # alembic adds significant import time, so we import it lazily
    from alembic import command

    log.info("Attempting downgrade to revision %s", to_revision)

    config = _get_alembic_config()

    # The alembic config uses %-interpolation, so literal '%' must be doubled.
    config.set_main_option('sqlalchemy.url', settings.SQL_ALCHEMY_CONN.replace('%', '%%'))

    errors_seen = False
    for err in _check_migration_errors(session=session):
        if not errors_seen:
            log.error("Automatic migration failed.  You may need to apply downgrades manually.  ")
            errors_seen = True
        log.error("%s", err)

    if errors_seen:
        # Raise SystemExit directly: the ``exit`` builtin is injected by the ``site``
        # module for interactive use and is not guaranteed to exist.
        raise SystemExit(1)

    with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
        if sql:
            log.warning("Generating sql scripts for manual migration.")

            conn = session.connection()

            from alembic.migration import MigrationContext

            migration_ctx = MigrationContext.configure(conn)
            if not from_revision:
                # No explicit start given: use the revision the database is currently at.
                from_revision = migration_ctx.get_current_revision()
            revision_range = f"{from_revision}:{to_revision}"
            _offline_migration(command.downgrade, config=config, revision=revision_range)
        else:
            log.info("Applying downgrade migrations.")
            command.downgrade(config, revision=to_revision, sql=sql)


def drop_airflow_models(connection):
"""
Drops all airflow models.
Expand Down
72 changes: 69 additions & 3 deletions tests/cli/commands/test_db_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.

import unittest
from unittest import mock
from unittest.mock import patch

Expand All @@ -28,9 +27,9 @@
from airflow.exceptions import AirflowException


class TestCliDb(unittest.TestCase):
class TestCliDb:
@classmethod
def setUpClass(cls):
def setup_class(cls):
cls.parser = cli_parser.get_parser()

@mock.patch("airflow.cli.commands.db_command.db.initdb")
Expand Down Expand Up @@ -137,6 +136,73 @@ def test_cli_shell_invalid(self):
with pytest.raises(AirflowException, match=r"Unknown driver: invalid\+psycopg2"):
db_command.shell(self.parser.parse_args(['db', 'shell']))

@pytest.mark.parametrize(
    'args, match',
    [
        (['-y', '--revision', 'abc', '--version', '2.2.0'], 'Cannot supply both'),
        (['-y', '--revision', 'abc1', '--from-revision', 'abc2'], 'only .* with `--sql-only`'),
        (['-y', '--revision', 'abc1', '--from-version', '2.2.2'], 'only .* with `--sql-only`'),
        (['-y', '--version', '2.2.2', '--from-version', '2.2.2'], 'only .* with `--sql-only`'),
        (
            ['-y', '--revision', 'abc', '--from-version', '2.2.0', '--from-revision', 'abc'],
            'may not be combined',
        ),
        (['-y', '--version', 'abc'], r'Downgrading to .* not supported\.'),
        (['-y'], 'Must provide either'),
    ],
)
@mock.patch("airflow.utils.db.downgrade")
def test_cli_downgrade_invalid(self, mock_dg, args, match):
    """Each invalid option combination must abort with a matching ``SystemExit`` message."""
    argv = ['db', 'downgrade'] + list(args)
    with pytest.raises(SystemExit, match=match):
        parsed = self.parser.parse_args(argv)
        db_command.downgrade(parsed)

@pytest.mark.parametrize(
    'args, expected',
    [
        (['-y', '--revision', 'abc1'], {'to_revision': 'abc1'}),
        (
            ['-y', '--revision', 'abc1', '--from-revision', 'abc2', '-s'],
            {'to_revision': 'abc1', 'from_revision': 'abc2', 'sql': True},
        ),
        (
            ['-y', '--revision', 'abc1', '--from-version', '2.2.2', '-s'],
            {'to_revision': 'abc1', 'from_revision': '7b2661a43ba3', 'sql': True},
        ),
        (
            ['-y', '--version', '2.2.2', '--from-version', '2.2.2', '-s'],
            {'to_revision': '7b2661a43ba3', 'from_revision': '7b2661a43ba3', 'sql': True},
        ),
        (['-y', '--version', '2.2.2'], {'to_revision': '7b2661a43ba3'}),
    ],
)
@mock.patch("airflow.utils.db.downgrade")
def test_cli_downgrade_good(self, mock_dg, args, expected):
    """Valid option combinations are forwarded to ``db.downgrade`` with the expected kwargs."""
    parsed = self.parser.parse_args(['db', 'downgrade', *args])
    db_command.downgrade(parsed)
    # Keyword arguments the command passes when the options leave them untouched.
    expected_kwargs = {'from_revision': None, 'sql': False}
    expected_kwargs.update(expected)
    mock_dg.assert_called_with(**expected_kwargs)

@pytest.mark.parametrize(
    'resp, raise_',
    [
        ('y', False),
        ('Y', False),
        ('n', True),
        ('a', True),  # any other value
    ],
)
@mock.patch("airflow.utils.db.downgrade")
@mock.patch("airflow.cli.commands.db_command.input")
def test_cli_downgrade_confirm(self, mock_input, mock_dg, resp, raise_):
    """Only an explicit ``y``/``Y`` answer to the prompt lets the downgrade proceed."""
    mock_input.return_value = resp
    parsed = self.parser.parse_args(['db', 'downgrade', '--revision', 'abc'])
    if raise_:
        # Pin the exit to the cancellation path rather than accepting any SystemExit.
        with pytest.raises(SystemExit, match="Cancelled"):
            db_command.downgrade(parsed)
        # A declined prompt must never reach the actual downgrade.
        mock_dg.assert_not_called()
    else:
        db_command.downgrade(parsed)
        mock_dg.assert_called_with(to_revision='abc', from_revision=None, sql=False)


class TestCLIDBClean:
@classmethod
Expand Down
26 changes: 25 additions & 1 deletion tests/utils/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from airflow.exceptions import AirflowException
from airflow.models import Base as airflow_base
from airflow.settings import engine
from airflow.utils.db import check_migrations, create_default_connections, upgradedb
from airflow.utils.db import check_migrations, create_default_connections, downgrade, upgradedb


class TestDb:
Expand Down Expand Up @@ -213,3 +213,27 @@ def test_versions_without_migration_donot_raise(self):
with mock.patch('alembic.command.upgrade') as mock_alembic_upgrade:
upgradedb("2.1.1:2.1.2")
mock_alembic_upgrade.assert_not_called()

@mock.patch('airflow.utils.db._offline_migration')
def test_downgrade_sql_no_from(self, mock_om):
    """With ``sql=True`` and no ``from_revision``, the range is ``<some revision>:abc``."""
    downgrade(to_revision='abc', sql=True, from_revision=None)
    _, call_kwargs = mock_om.call_args
    revision_range = call_kwargs['revision']
    assert re.match(r'[a-z0-9]+:abc', revision_range) is not None

@mock.patch('airflow.utils.db._offline_migration')
def test_downgrade_sql_with_from(self, mock_om):
    """With ``sql=True`` and an explicit ``from_revision``, the range is ``from:to``."""
    downgrade(to_revision='abc', sql=True, from_revision='123')
    _, call_kwargs = mock_om.call_args
    assert call_kwargs['revision'] == '123:abc'

@mock.patch('alembic.command.downgrade')
def test_downgrade_invalid_combo(self, mock_om):
    """Passing ``from_revision`` without ``sql=True`` is rejected with ``ValueError``."""
    invalid_kwargs = {'to_revision': 'abc', 'from_revision': '123'}
    with pytest.raises(ValueError, match="can't be combined"):
        downgrade(**invalid_kwargs)

@mock.patch('alembic.command.downgrade')
def test_downgrade_with_from(self, mock_om):
    """A plain downgrade (no ``sql``, no ``from_revision``) passes the target revision through.

    Despite the test's name, no ``from_revision`` is supplied here.
    """
    downgrade(to_revision='abc')
    _, call_kwargs = mock_om.call_args
    assert call_kwargs['revision'] == 'abc'

0 comments on commit 03d6aaa

Please sign in to comment.