From 03d6aaa92640cd0512ef8be0583e4ce1b1bbd9eb Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Sun, 27 Feb 2022 11:10:28 -0800 Subject: [PATCH] Airflow `db downgrade` cli command (#21596) --- airflow/cli/cli_parser.py | 48 +++++++++++++++++- airflow/cli/commands/db_command.py | 46 +++++++++++++++++ airflow/utils/db.py | 67 +++++++++++++++++++++++-- tests/cli/commands/test_db_command.py | 72 +++++++++++++++++++++++++-- tests/utils/test_db.py | 26 +++++++++- 5 files changed, 250 insertions(+), 9 deletions(-) diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py index cb0bcd29cc697..e74edec9a5b53 100644 --- a/airflow/cli/cli_parser.py +++ b/airflow/cli/cli_parser.py @@ -212,7 +212,10 @@ def string_list_type(val): ARG_STDOUT = Arg(("--stdout",), help="Redirect stdout to this file") ARG_LOG_FILE = Arg(("-l", "--log-file"), help="Location of the log file") ARG_YES = Arg( - ("-y", "--yes"), help="Do not prompt to confirm. Use with care!", action="store_true", default=False + ("-y", "--yes"), + help="Do not prompt to confirm. 
Use with care!", + action="store_true", + default=False, ) ARG_OUTPUT = Arg( ( @@ -510,12 +513,42 @@ def string_list_type(val): ARG_CFG_PATH = Arg(("--cfg-path",), help="Path to config file to use instead of airflow.cfg") ARG_MAP_INDEX = Arg(('--map-index',), type=int, default=-1, help="Mapped task index") + +# database ARG_MIGRATION_TIMEOUT = Arg( ("-t", "--migration-wait-timeout"), help="timeout to wait for db to migrate ", type=int, default=60, ) +ARG_DB_VERSION = Arg( + ( + "-n", + "--version", + ), + help="The airflow version to downgrade to", +) +ARG_DB_FROM_VERSION = Arg( + ("--from-version",), + help="(Optional) if generating sql, may supply a _from_ version", +) +ARG_DB_REVISION = Arg( + ( + "-r", + "--revision", + ), + help="The airflow revision to downgrade to", +) +ARG_DB_FROM_REVISION = Arg( + ("--from-revision",), + help="(Optional) if generating sql, may supply a _from_ revision", +) +ARG_DB_SQL = Arg( + ("-s", "--sql-only"), + help="Don't actually run migrations; just print out sql scripts for offline migration.", + action="store_true", + default=False, +) # webserver ARG_PORT = Arg( @@ -1327,6 +1360,19 @@ class GroupCommand(NamedTuple): func=lazy_load_command('airflow.cli.commands.db_command.upgradedb'), args=(ARG_VERSION_RANGE, ARG_REVISION_RANGE), ), + ActionCommand( + name='downgrade', + help="Downgrade the schema of the metadata database", + func=lazy_load_command('airflow.cli.commands.db_command.downgrade'), + args=( + ARG_DB_REVISION, + ARG_DB_VERSION, + ARG_DB_SQL, + ARG_YES, + ARG_DB_FROM_REVISION, + ARG_DB_FROM_VERSION, + ), + ), ActionCommand( name='shell', help="Runs a shell to access the database", diff --git a/airflow/cli/commands/db_command.py b/airflow/cli/commands/db_command.py index 02811d0246387..828f352b1aa96 100644 --- a/airflow/cli/commands/db_command.py +++ b/airflow/cli/commands/db_command.py @@ -22,6 +22,7 @@ from airflow import settings from airflow.exceptions import AirflowException from airflow.utils import cli as 
cli_utils, db +from airflow.utils.db import REVISION_HEADS_MAP from airflow.utils.db_cleanup import config_dict, run_cleanup from airflow.utils.process_utils import execute_interactive @@ -50,6 +51,51 @@ def upgradedb(args): print("Upgrades done") +@cli_utils.action_cli(check_db=False) +def downgrade(args): + """Downgrades the metadata database""" + if args.revision and args.version: + raise SystemExit("Cannot supply both `revision` and `version`.") + if args.from_version and args.from_revision: + raise SystemExit("`--from-revision` may not be combined with `--from-version`") + if (args.from_revision or args.from_version) and not args.sql_only: + raise SystemExit("Args `--from-revision` and `--from-version` may only be used with `--sql-only`") + if not (args.version or args.revision): + raise SystemExit("Must provide either revision or version.") + from_revision = None + if args.from_revision: + from_revision = args.from_revision + elif args.from_version: + from_revision = REVISION_HEADS_MAP.get(args.from_version) + if not from_revision: + raise SystemExit(f"Unknown version {args.version!r} supplied as `--from-version`.") + if args.version: + revision = REVISION_HEADS_MAP.get(args.version) + if not revision: + raise SystemExit(f"Downgrading to version {args.version} is not supported.") + elif args.revision: + revision = args.revision + if not args.sql_only: + print("Performing downgrade with database " + repr(settings.engine.url)) + else: + print("Generating sql for downgrade -- downgrade commands will *not* be submitted.") + + if args.sql_only or ( + args.yes + or input( + "\nWarning: About to reverse schema migrations for the airflow metastore. " + "Please ensure you have backed up your database before any upgrade or " + "downgrade operation. Proceed? 
(y/n)\n" + ).upper() + == "Y" + ): + db.downgrade(to_revision=revision, from_revision=from_revision, sql=args.sql_only) + if not args.sql_only: + print("Downgrade complete") + else: + raise SystemExit("Cancelled") + + def check_migrations(args): """Function to wait for all airflow migrations to complete. Used for launching airflow in k8s""" db.check_migrations(timeout=args.migration_wait_timeout) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 2038d92cbc5b8..2b2b2dd49f234 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -1034,12 +1034,12 @@ def _check_migration_errors(session: Session = NEW_SESSION) -> Iterable[str]: session.commit() -def _offline_migration(command, config, revision): +def _offline_migration(migration_func: Callable, config, revision): log.info("Running offline migrations for revision range %s", revision) with warnings.catch_warnings(): warnings.simplefilter("ignore") logging.disable(logging.CRITICAL) - command.upgrade(config, revision, sql=True) + migration_func(config, revision, sql=True) logging.disable(logging.NOTSET) @@ -1134,10 +1134,10 @@ def upgradedb( revision = _validate_version_range(command, config, version_range) if not revision: return - return _offline_migration(command, config, revision) + return _offline_migration(command.upgrade, config, revision) elif revision_range: _validate_revision(command, config, revision_range) - return _offline_migration(command, config, revision_range) + return _offline_migration(command.upgrade, config, revision_range) errors_seen = False for err in _check_migration_errors(session=session): @@ -1172,6 +1172,65 @@ def resetdb(session: Session = NEW_SESSION): initdb(session=session) +@provide_session +def downgrade(to_revision, sql=False, from_revision=None, session: Session = NEW_SESSION): + """ + Downgrade the airflow metastore schema to a prior version. + + :param to_revision: The alembic revision to downgrade *to*. 
+ :param sql: if True, print sql statements but do not run them + :param from_revision: if supplied, alembic revision to downgrade *from*. This may only + be used in conjunction with ``sql=True`` because if we actually run the commands, + we should only downgrade from the *current* revision. + :param session: sqlalchemy session for connection to airflow metadata database + """ + if from_revision and not sql: + raise ValueError( + "`from_revision` can't be combined with `sql=False`. When actually " + "applying a downgrade (instead of just generating sql), we always " + "downgrade from current revision." + ) + + if not settings.SQL_ALCHEMY_CONN: + raise RuntimeError("The settings.SQL_ALCHEMY_CONN not set.") + + # alembic adds significant import time, so we import it lazily + from alembic import command + + log.info("Attempting downgrade to revision %s", to_revision) + + config = _get_alembic_config() + + config.set_main_option('sqlalchemy.url', settings.SQL_ALCHEMY_CONN.replace('%', '%%')) + + errors_seen = False + for err in _check_migration_errors(session=session): + if not errors_seen: + log.error("Automatic migration failed. You may need to apply downgrades manually. ") + errors_seen = True + log.error("%s", err) + + if errors_seen: + exit(1) + + with create_global_lock(session=session, lock=DBLocks.MIGRATIONS): + if sql: + log.warning("Generating sql scripts for manual migration.") + + conn = session.connection() + + from alembic.migration import MigrationContext + + migration_ctx = MigrationContext.configure(conn) + if not from_revision: + from_revision = migration_ctx.get_current_revision() + revision_range = f"{from_revision}:{to_revision}" + _offline_migration(command.downgrade, config=config, revision=revision_range) + else: + log.info("Applying downgrade migrations.") + command.downgrade(config, revision=to_revision, sql=sql) + + def drop_airflow_models(connection): """ Drops all airflow models. 
diff --git a/tests/cli/commands/test_db_command.py b/tests/cli/commands/test_db_command.py index 09cb14ae59af0..c8f00593b5e9a 100644 --- a/tests/cli/commands/test_db_command.py +++ b/tests/cli/commands/test_db_command.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -import unittest from unittest import mock from unittest.mock import patch @@ -28,9 +27,9 @@ from airflow.exceptions import AirflowException -class TestCliDb(unittest.TestCase): +class TestCliDb: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.parser = cli_parser.get_parser() @mock.patch("airflow.cli.commands.db_command.db.initdb") @@ -137,6 +136,73 @@ def test_cli_shell_invalid(self): with pytest.raises(AirflowException, match=r"Unknown driver: invalid\+psycopg2"): db_command.shell(self.parser.parse_args(['db', 'shell'])) + @pytest.mark.parametrize( + 'args, match', + [ + (['-y', '--revision', 'abc', '--version', '2.2.0'], 'Cannot supply both'), + (['-y', '--revision', 'abc1', '--from-revision', 'abc2'], 'only .* with `--sql-only`'), + (['-y', '--revision', 'abc1', '--from-version', '2.2.2'], 'only .* with `--sql-only`'), + (['-y', '--version', '2.2.2', '--from-version', '2.2.2'], 'only .* with `--sql-only`'), + ( + ['-y', '--revision', 'abc', '--from-version', '2.2.0', '--from-revision', 'abc'], + 'may not be combined', + ), + (['-y', '--version', 'abc'], r'Downgrading to .* not supported\.'), + (['-y'], 'Must provide either'), + ], + ) + @mock.patch("airflow.utils.db.downgrade") + def test_cli_downgrade_invalid(self, mock_dg, args, match): + """We test some options that should produce an error""" + + with pytest.raises(SystemExit, match=match): + db_command.downgrade(self.parser.parse_args(['db', 'downgrade', *args])) + + @pytest.mark.parametrize( + 'args, expected', + [ + (['-y', '--revision', 'abc1'], dict(to_revision='abc1')), + ( + ['-y', '--revision', 'abc1', '--from-revision', 'abc2', '-s'], + dict(to_revision='abc1', 
from_revision='abc2', sql=True), + ), + ( + ['-y', '--revision', 'abc1', '--from-version', '2.2.2', '-s'], + dict(to_revision='abc1', from_revision='7b2661a43ba3', sql=True), + ), + ( + ['-y', '--version', '2.2.2', '--from-version', '2.2.2', '-s'], + dict(to_revision='7b2661a43ba3', from_revision='7b2661a43ba3', sql=True), + ), + (['-y', '--version', '2.2.2'], dict(to_revision='7b2661a43ba3')), + ], + ) + @mock.patch("airflow.utils.db.downgrade") + def test_cli_downgrade_good(self, mock_dg, args, expected): + defaults = dict(from_revision=None, sql=False) + db_command.downgrade(self.parser.parse_args(['db', 'downgrade', *args])) + mock_dg.assert_called_with(**{**defaults, **expected}) + + @pytest.mark.parametrize( + 'resp, raise_', + [ + ('y', False), + ('Y', False), + ('n', True), + ('a', True), # any other value + ], + ) + @mock.patch("airflow.utils.db.downgrade") + @mock.patch("airflow.cli.commands.db_command.input") + def test_cli_downgrade_confirm(self, mock_input, mock_dg, resp, raise_): + mock_input.return_value = resp + if raise_: + with pytest.raises(SystemExit): + db_command.downgrade(self.parser.parse_args(['db', 'downgrade', '--revision', 'abc'])) + else: + db_command.downgrade(self.parser.parse_args(['db', 'downgrade', '--revision', 'abc'])) + mock_dg.assert_called_with(to_revision='abc', from_revision=None, sql=False) + class TestCLIDBClean: @classmethod diff --git a/tests/utils/test_db.py b/tests/utils/test_db.py index 5d828264f2e41..3d12127007207 100644 --- a/tests/utils/test_db.py +++ b/tests/utils/test_db.py @@ -31,7 +31,7 @@ from airflow.exceptions import AirflowException from airflow.models import Base as airflow_base from airflow.settings import engine -from airflow.utils.db import check_migrations, create_default_connections, upgradedb +from airflow.utils.db import check_migrations, create_default_connections, downgrade, upgradedb class TestDb: @@ -213,3 +213,27 @@ def test_versions_without_migration_donot_raise(self): with 
mock.patch('alembic.command.upgrade') as mock_alembic_upgrade: upgradedb("2.1.1:2.1.2") mock_alembic_upgrade.assert_not_called() + + @mock.patch('airflow.utils.db._offline_migration') + def test_downgrade_sql_no_from(self, mock_om): + downgrade(to_revision='abc', sql=True, from_revision=None) + actual = mock_om.call_args[1]['revision'] + assert re.match(r'[a-z0-9]+:abc', actual) is not None + + @mock.patch('airflow.utils.db._offline_migration') + def test_downgrade_sql_with_from(self, mock_om): + downgrade(to_revision='abc', sql=True, from_revision='123') + actual = mock_om.call_args[1]['revision'] + assert actual == '123:abc' + + @mock.patch('alembic.command.downgrade') + def test_downgrade_invalid_combo(self, mock_om): + """can't combine `sql=False` and `from_revision`""" + with pytest.raises(ValueError, match="can't be combined"): + downgrade(to_revision='abc', from_revision='123') + + @mock.patch('alembic.command.downgrade') + def test_downgrade_with_from(self, mock_om): + downgrade(to_revision='abc') + actual = mock_om.call_args[1]['revision'] + assert actual == 'abc'