diff --git a/airflow/compat/sqlalchemy.py b/airflow/compat/sqlalchemy.py new file mode 100644 index 0000000000000..427db90a73d67 --- /dev/null +++ b/airflow/compat/sqlalchemy.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from sqlalchemy import Table +from sqlalchemy.engine import Connection + +try: + from sqlalchemy import inspect +except AttributeError: + from sqlalchemy.engine.reflection import Inspector + + inspect = Inspector.from_engine + +__all__ = ["has_table", "inspect"] + + +def has_table(conn: Connection, table: Table): + try: + return inspect(conn).has_table(table) + except AttributeError: + return table.exists(conn) diff --git a/airflow/migrations/versions/03afc6b6f902_increase_length_of_fab_ab_view_menu_.py b/airflow/migrations/versions/03afc6b6f902_increase_length_of_fab_ab_view_menu_.py index b03d4019879cf..4378c8bd0c084 100644 --- a/airflow/migrations/versions/03afc6b6f902_increase_length_of_fab_ab_view_menu_.py +++ b/airflow/migrations/versions/03afc6b6f902_increase_length_of_fab_ab_view_menu_.py @@ -26,8 +26,8 @@ import sqlalchemy as sa from alembic import op -from sqlalchemy.engine.reflection import Inspector +from airflow.compat.sqlalchemy import inspect from airflow.migrations.db_types import StringID # revision identifiers, used by Alembic. @@ -41,7 +41,7 @@ def upgrade(): """Apply Increase length of ``Flask-AppBuilder`` ``ab_view_menu.name`` column""" conn = op.get_bind() - inspector = Inspector.from_engine(conn) + inspector = inspect(conn) tables = inspector.get_table_names() if "ab_view_menu" in tables: @@ -72,7 +72,7 @@ def upgrade(): def downgrade(): """Unapply Increase length of ``Flask-AppBuilder`` ``ab_view_menu.name`` column""" conn = op.get_bind() - inspector = Inspector.from_engine(conn) + inspector = inspect(conn) tables = inspector.get_table_names() if "ab_view_menu" in tables: if conn.dialect.name == "sqlite": diff --git a/airflow/migrations/versions/1507a7289a2f_create_is_encrypted.py b/airflow/migrations/versions/1507a7289a2f_create_is_encrypted.py index ffade38b51e76..2d13be3d2e234 100644 --- a/airflow/migrations/versions/1507a7289a2f_create_is_encrypted.py +++ b/airflow/migrations/versions/1507a7289a2f_create_is_encrypted.py @@ -25,7 +25,8 @@ """ import sqlalchemy as sa from alembic import op -from sqlalchemy.engine.reflection import Inspector + +from airflow.compat.sqlalchemy import inspect # revision identifiers, used by Alembic. revision = '1507a7289a2f' @@ -44,7 +45,7 @@ def upgrade(): # true for users who are upgrading from a previous version of Airflow # that predates Alembic integration conn = op.get_bind() - inspector = Inspector.from_engine(conn) + inspector = inspect(conn) # this will only be true if 'connection' already exists in the db, # but not if alembic created it in a previous migration diff --git a/airflow/migrations/versions/33ae817a1ff4_add_kubernetes_resource_checkpointing.py b/airflow/migrations/versions/33ae817a1ff4_add_kubernetes_resource_checkpointing.py index 6273c01f2633c..c37831ac05c0d 100644 --- a/airflow/migrations/versions/33ae817a1ff4_add_kubernetes_resource_checkpointing.py +++ b/airflow/migrations/versions/33ae817a1ff4_add_kubernetes_resource_checkpointing.py @@ -25,7 +25,8 @@ """ import sqlalchemy as sa from alembic import op -from sqlalchemy.engine.reflection import Inspector + +from airflow.compat.sqlalchemy import inspect # revision identifiers, used by Alembic. revision = '33ae817a1ff4' @@ -39,7 +40,7 @@ def upgrade(): conn = op.get_bind() - inspector = Inspector.from_engine(conn) + inspector = inspect(conn) if RESOURCE_TABLE not in inspector.get_table_names(): columns_and_constraints = [ @@ -63,7 +64,7 @@ def upgrade(): def downgrade(): conn = op.get_bind() - inspector = Inspector.from_engine(conn) + inspector = inspect(conn) if RESOURCE_TABLE in inspector.get_table_names(): op.drop_table(RESOURCE_TABLE) diff --git a/airflow/migrations/versions/3c20cacc0044_add_dagrun_run_type.py b/airflow/migrations/versions/3c20cacc0044_add_dagrun_run_type.py index 19c310b62da1b..ce3d090742dd9 100644 --- a/airflow/migrations/versions/3c20cacc0044_add_dagrun_run_type.py +++ b/airflow/migrations/versions/3c20cacc0044_add_dagrun_run_type.py @@ -28,9 +28,9 @@ import sqlalchemy as sa from alembic import op from sqlalchemy import Column, Integer, String -from sqlalchemy.engine.reflection import Inspector from sqlalchemy.ext.declarative import declarative_base +from airflow.compat.sqlalchemy import inspect from airflow.utils.types import DagRunType # revision identifiers, used by Alembic. @@ -58,7 +58,7 @@ def upgrade(): run_type_col_type = sa.String(length=50) conn = op.get_bind() - inspector = Inspector.from_engine(conn) + inspector = inspect(conn) dag_run_columns = [col.get('name') for col in inspector.get_columns("dag_run")] if "run_type" not in dag_run_columns: diff --git a/airflow/migrations/versions/92c57b58940d_add_fab_tables.py b/airflow/migrations/versions/92c57b58940d_add_fab_tables.py index 775f7140fe94f..bd3fe44e9ee87 100644 --- a/airflow/migrations/versions/92c57b58940d_add_fab_tables.py +++ b/airflow/migrations/versions/92c57b58940d_add_fab_tables.py @@ -26,7 +26,8 @@ import sqlalchemy as sa from alembic import op -from sqlalchemy.engine.reflection import Inspector + +from airflow.compat.sqlalchemy import inspect # revision identifiers, used by Alembic. revision = '92c57b58940d' @@ -39,7 +40,7 @@ def upgrade(): """Create FAB Tables""" conn = op.get_bind() - inspector = Inspector.from_engine(conn) + inspector = inspect(conn) tables = inspector.get_table_names() if "ab_permission" not in tables: op.create_table( @@ -153,7 +154,7 @@ def upgrade(): def downgrade(): """Drop FAB Tables""" conn = op.get_bind() - inspector = Inspector.from_engine(conn) + inspector = inspect(conn) tables = inspector.get_table_names() fab_tables = [ "ab_permission", diff --git a/airflow/migrations/versions/bbf4a7ad0465_remove_id_column_from_xcom.py b/airflow/migrations/versions/bbf4a7ad0465_remove_id_column_from_xcom.py index 0c191641b7263..b1abe27472346 100644 --- a/airflow/migrations/versions/bbf4a7ad0465_remove_id_column_from_xcom.py +++ b/airflow/migrations/versions/bbf4a7ad0465_remove_id_column_from_xcom.py @@ -28,7 +28,8 @@ from alembic import op from sqlalchemy import Column, Integer -from sqlalchemy.engine.reflection import Inspector + +from airflow.compat.sqlalchemy import inspect # revision identifiers, used by Alembic. revision = 'bbf4a7ad0465' @@ -97,7 +98,7 @@ def create_constraints(operator, column_name, constraint_dict): def upgrade(): """Apply Remove id column from xcom""" conn = op.get_bind() - inspector = Inspector.from_engine(conn) + inspector = inspect(conn) with op.batch_alter_table('xcom') as bop: xcom_columns = [col.get('name') for col in inspector.get_columns("xcom")] diff --git a/airflow/migrations/versions/bef4f3d11e8b_drop_kuberesourceversion_and_.py b/airflow/migrations/versions/bef4f3d11e8b_drop_kuberesourceversion_and_.py index b11b67b436e76..d2449aa48dc1f 100644 --- a/airflow/migrations/versions/bef4f3d11e8b_drop_kuberesourceversion_and_.py +++ b/airflow/migrations/versions/bef4f3d11e8b_drop_kuberesourceversion_and_.py @@ -26,7 +26,8 @@ import sqlalchemy as sa from alembic import op -from sqlalchemy.engine.reflection import Inspector + +from airflow.compat.sqlalchemy import inspect # revision identifiers, used by Alembic. revision = 'bef4f3d11e8b' @@ -43,7 +44,7 @@ def upgrade(): """Apply Drop ``KubeResourceVersion`` and ``KubeWorkerId``entifier tables""" conn = op.get_bind() - inspector = Inspector.from_engine(conn) + inspector = inspect(conn) tables = inspector.get_table_names() if WORKER_UUID_TABLE in tables: @@ -55,7 +56,7 @@ def upgrade(): def downgrade(): """Unapply Drop ``KubeResourceVersion`` and ``KubeWorkerId``entifier tables""" conn = op.get_bind() - inspector = Inspector.from_engine(conn) + inspector = inspect(conn) tables = inspector.get_table_names() if WORKER_UUID_TABLE not in tables: diff --git a/airflow/migrations/versions/cc1e65623dc7_add_max_tries_column_to_task_instance.py b/airflow/migrations/versions/cc1e65623dc7_add_max_tries_column_to_task_instance.py index a2a89fe690019..f3d1bd8891e67 100644 --- a/airflow/migrations/versions/cc1e65623dc7_add_max_tries_column_to_task_instance.py +++ b/airflow/migrations/versions/cc1e65623dc7_add_max_tries_column_to_task_instance.py @@ -26,10 +26,10 @@ import sqlalchemy as sa from alembic import op from sqlalchemy import Column, Integer, String -from sqlalchemy.engine.reflection import Inspector from sqlalchemy.ext.declarative import declarative_base from airflow import settings +from airflow.compat.sqlalchemy import inspect from airflow.models import DagBag # revision identifiers, used by Alembic. @@ -62,7 +62,7 @@ def upgrade(): # Checking task_instance table exists prevent the error of querying # non-existing task_instance table. connection = op.get_bind() - inspector = Inspector.from_engine(connection) + inspector = inspect(connection) tables = inspector.get_table_names() if 'task_instance' in tables: diff --git a/airflow/migrations/versions/cf5dc11e79ad_drop_user_and_chart.py b/airflow/migrations/versions/cf5dc11e79ad_drop_user_and_chart.py index db51a32be210d..deb2778661f68 100644 --- a/airflow/migrations/versions/cf5dc11e79ad_drop_user_and_chart.py +++ b/airflow/migrations/versions/cf5dc11e79ad_drop_user_and_chart.py @@ -25,7 +25,8 @@ import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import mysql -from sqlalchemy.engine.reflection import Inspector + +from airflow.compat.sqlalchemy import inspect # revision identifiers, used by Alembic. revision = 'cf5dc11e79ad' @@ -43,7 +44,7 @@ def upgrade(): # But before we can delete the users table we need to drop the FK conn = op.get_bind() - inspector = Inspector.from_engine(conn) + inspector = inspect(conn) tables = inspector.get_table_names() if 'known_event' in tables: diff --git a/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py b/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py index aae2113a52919..d98be012b7b52 100644 --- a/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py +++ b/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py @@ -26,8 +26,8 @@ import sqlalchemy as sa from alembic import op from sqlalchemy import func -from sqlalchemy.engine.reflection import Inspector +from airflow.compat.sqlalchemy import inspect from airflow.migrations.db_types import TIMESTAMP, StringID # revision identifiers, used by Alembic. @@ -41,7 +41,7 @@ def upgrade(): conn = op.get_bind() - inspector = Inspector.from_engine(conn) + inspector = inspect(conn) tables = inspector.get_table_names() if 'sensor_instance' in tables: return @@ -74,7 +74,7 @@ def upgrade(): def downgrade(): conn = op.get_bind() - inspector = Inspector.from_engine(conn) + inspector = inspect(conn) tables = inspector.get_table_names() if 'sensor_instance' in tables: op.drop_table('sensor_instance') diff --git a/airflow/migrations/versions/e3a246e0dc1_current_schema.py b/airflow/migrations/versions/e3a246e0dc1_current_schema.py index b4edbd627405d..9824db7dad36f 100644 --- a/airflow/migrations/versions/e3a246e0dc1_current_schema.py +++ b/airflow/migrations/versions/e3a246e0dc1_current_schema.py @@ -27,8 +27,8 @@ import sqlalchemy as sa from alembic import op from sqlalchemy import func -from sqlalchemy.engine.reflection import Inspector +from airflow.compat.sqlalchemy import inspect from airflow.migrations.db_types import StringID # revision identifiers, used by Alembic. @@ -41,7 +41,7 @@ def upgrade(): conn = op.get_bind() - inspector = Inspector.from_engine(conn) + inspector = inspect(conn) tables = inspector.get_table_names() if 'connection' not in tables: diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 2b2b2dd49f234..3ae84f0faec2f 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -29,6 +29,7 @@ from sqlalchemy.orm.session import Session from airflow import settings +from airflow.compat.sqlalchemy import has_table from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.jobs.base_job import BaseJob # noqa: F401 @@ -1265,7 +1266,7 @@ def drop_airflow_models(connection): migration_ctx = MigrationContext.configure(connection) version = migration_ctx._version - if version.exists(connection): + if has_table(connection, version): version.drop(connection) diff --git a/airflow/www/fab_security/sqla/manager.py b/airflow/www/fab_security/sqla/manager.py index 3e1fdf89d0474..9042a22f6a0e4 100644 --- a/airflow/www/fab_security/sqla/manager.py +++ b/airflow/www/fab_security/sqla/manager.py @@ -23,10 +23,10 @@ from flask_appbuilder.models.sqla import Base from flask_appbuilder.models.sqla.interface import SQLAInterface from sqlalchemy import and_, func, literal -from sqlalchemy.engine.reflection import Inspector from sqlalchemy.orm.exc import MultipleResultsFound from werkzeug.security import generate_password_hash +from airflow.compat import sqlalchemy as sqla_compat from airflow.www.fab_security.manager import BaseSecurityManager from airflow.www.fab_security.sqla.models import ( Action, @@ -99,7 +99,7 @@ def register_views(self): def create_db(self): try: engine = self.get_session.get_bind(mapper=None, clause=None) - inspector = Inspector.from_engine(engine) + inspector = sqla_compat.inspect(engine) if "ab_user" not in inspector.get_table_names(): log.info(c.LOGMSG_INF_SEC_NO_DB) Base.metadata.create_all(engine) diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 0a96f887fe7ee..6aa83d6de4fcb 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -888,7 +888,7 @@ def test_extra_serialized_field_and_operator_links( assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == serialized_links # Test all the extra_links are set - assert set(simple_task.extra_links) == set(links.keys()) | {'airflow', 'github', 'google'} + assert set(simple_task.extra_links) == {*links, 'airflow', 'github', 'google'} dr = dag_maker.create_dagrun(execution_date=test_date) (ti,) = dr.task_instances @@ -897,7 +897,7 @@ def test_extra_serialized_field_and_operator_links( value=bash_command, task_id=simple_task.task_id, dag_id=simple_task.dag_id, - execution_date=test_date, + run_id=dr.run_id, ) # Test Deserialized inbuilt link