From 64d3f988711ba7204845634434c246de4524bffa Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Thu, 4 Nov 2021 17:02:37 +0000 Subject: [PATCH] Define datetime and StringID column types centrally in migrations We have various flavours of the code all over the place in many migration files -- which leads to duplication and things not being in sync. This pulls them once in to a central location. --- airflow/migrations/db_types.py | 86 +++++++++++++++++++ airflow/migrations/db_types.pyi | 28 ++++++ ...02_increase_length_of_fab_ab_view_menu_.py | 4 +- .../0a2a5b66e19d_add_task_reschedule_table.py | 41 +++------ ...terval_start_end_to_dagmodel_and_dagrun.py | 35 ++------ .../versions/1b38cef5b76e_add_dagrun.py | 6 +- .../3c20cacc0044_add_dagrun_run_type.py | 10 ++- ...de9cddf6c9_add_task_fails_journal_table.py | 6 +- ...4a4_make_taskinstance_pool_not_nullable.py | 6 +- .../7939bcff74ba_add_dagtags_table.py | 4 +- ...2661a43ba3_taskinstance_keyed_to_dagrun.py | 23 ++--- ...3f031fd9f1c_improve_mssql_compatibility.py | 64 ++++---------- ...add_rendered_task_instance_fields_table.py | 6 +- ...922c8a04_change_default_pool_slots_to_1.py | 9 +- ...b8_add_queued_at_column_to_dagrun_table.py | 9 +- ..._add_scheduling_decision_to_dagrun_and_.py | 30 ++----- ...7_add_max_tries_column_to_task_instance.py | 9 +- .../d38e04c12aa2_add_serialized_dag_table.py | 4 +- ...e357a868_update_schema_for_smart_sensor.py | 34 ++------ .../versions/e3a246e0dc1_current_schema.py | 30 +++---- ...1f0_make_xcom_pkey_columns_non_nullable.py | 32 ++----- .../f2ca10b85618_add_dag_stats_table.py | 4 +- airflow/models/base.py | 7 +- tests/utils/test_db.py | 15 +++- 24 files changed, 238 insertions(+), 264 deletions(-) create mode 100644 airflow/migrations/db_types.py create mode 100644 airflow/migrations/db_types.pyi diff --git a/airflow/migrations/db_types.py b/airflow/migrations/db_types.py new file mode 100644 index 0000000000000..80394d02697b8 --- /dev/null +++ b/airflow/migrations/db_types.py @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import sys + +import sqlalchemy as sa +from alembic import context +from lazy_object_proxy import Proxy + +###################################### +# Note about this module: +# +# It loads the specific type dynamically at runtime. For IDE/typing support +# there is an associated db_types.pyi. If you add a new type in here, add a +# simple version in there too. +###################################### + + +def _mssql_use_date_time2(): + conn = context.get_bind() + result = conn.execute( + """SELECT CASE WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion')) + like '8%' THEN '2000' WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion')) + like '9%' THEN '2005' ELSE '2005Plus' END AS MajorVersion""" + ).fetchone() + mssql_version = result[0] + return mssql_version not in ("2000", "2005") + + +MSSQL_USE_DATE_TIME2 = Proxy(_mssql_use_date_time2) + + +def _mssql_TIMESTAMP(): + from sqlalchemy.dialects import mssql + + return mssql.DATETIME2(precision=6) if MSSQL_USE_DATE_TIME2 else mssql.DATETIME + + +def _mysql_TIMESTAMP(): + from sqlalchemy.dialects import mysql + + return mysql.TIMESTAMP(fsp=6, timezone=True) + + +def _sa_TIMESTAMP(): + return sa.TIMESTAMP(timezone=True) + + +def _sa_StringID(): + from airflow.models.base import StringID + + return StringID + + +def __getattr__(name): + if name in ["TIMESTAMP", "StringID"]: + dialect = context.get_bind().dialect.name + module = globals() + + # Lookup the type based on the dialect specific type, or fallback to the generic type + type_ = module.get(f'_{dialect}_{name}', None) or module.get(f'_sa_{name}') + val = module[name] = type_() + return val + + raise AttributeError(f"module {__name__} has no attribute {name}") + + +if sys.version_info < (3, 7): + from pep562 import Pep562 + + Pep562(__name__) diff --git a/airflow/migrations/db_types.pyi b/airflow/migrations/db_types.pyi new file mode 100644 index 0000000000000..0a56b5f408174 --- /dev/null +++ b/airflow/migrations/db_types.pyi @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sqlalchemy as sa + +TIMESTAMP = sa.TIMESTAMP +"""Database specific timestamp with timezone""" + +StringID = sa.String +"""String column type with correct DB collation applied""" + +MSSQL_USE_DATE_TIME2: bool diff --git a/airflow/migrations/versions/03afc6b6f902_increase_length_of_fab_ab_view_menu_.py b/airflow/migrations/versions/03afc6b6f902_increase_length_of_fab_ab_view_menu_.py index f7484bf979c83..3b78f02c7c9f7 100644 --- a/airflow/migrations/versions/03afc6b6f902_increase_length_of_fab_ab_view_menu_.py +++ b/airflow/migrations/versions/03afc6b6f902_increase_length_of_fab_ab_view_menu_.py @@ -28,7 +28,7 @@ from alembic import op from sqlalchemy.engine.reflection import Inspector -from airflow.models.base import COLLATION_ARGS +from airflow.migrations.db_types import StringID # revision identifiers, used by Alembic. revision = '03afc6b6f902' @@ -63,7 +63,7 @@ def upgrade(): op.alter_column( table_name='ab_view_menu', column_name='name', - type_=sa.String(length=250, **COLLATION_ARGS), + type_=StringID(length=250), nullable=False, ) diff --git a/airflow/migrations/versions/0a2a5b66e19d_add_task_reschedule_table.py b/airflow/migrations/versions/0a2a5b66e19d_add_task_reschedule_table.py index 2133eb148c30b..b7e62413ec1a3 100644 --- a/airflow/migrations/versions/0a2a5b66e19d_add_task_reschedule_table.py +++ b/airflow/migrations/versions/0a2a5b66e19d_add_task_reschedule_table.py @@ -24,9 +24,8 @@ """ import sqlalchemy as sa from alembic import op -from sqlalchemy.dialects import mysql -from airflow.models.base import COLLATION_ARGS +from airflow.migrations.db_types import TIMESTAMP, StringID # revision identifiers, used by Alembic. revision = '0a2a5b66e19d' @@ -38,43 +37,25 @@ INDEX_NAME = 'idx_' + TABLE_NAME + '_dag_task_date' -# For Microsoft SQL Server, TIMESTAMP is a row-id type, -# having nothing to do with date-time. DateTime() will -# be sufficient. -def mssql_timestamp(): - return sa.DateTime() - - -def mysql_timestamp(): - return mysql.TIMESTAMP(fsp=6) - - -def sa_timestamp(): - return sa.TIMESTAMP(timezone=True) - - def upgrade(): # See 0e2a74e0fc9f_add_time_zone_awareness - conn = op.get_bind() - if conn.dialect.name == 'mysql': - timestamp = mysql_timestamp - elif conn.dialect.name == 'mssql': - timestamp = mssql_timestamp - else: - timestamp = sa_timestamp + timestamp = TIMESTAMP + if op.get_bind().dialect.name == 'mssql': + # We need to keep this as it was for this old migration on mssql + timestamp = sa.DateTime() op.create_table( TABLE_NAME, sa.Column('id', sa.Integer(), nullable=False), - sa.Column('task_id', sa.String(length=250, **COLLATION_ARGS), nullable=False), - sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=False), + sa.Column('task_id', StringID(), nullable=False), + sa.Column('dag_id', StringID(), nullable=False), # use explicit server_default=None otherwise mysql implies defaults for first timestamp column - sa.Column('execution_date', timestamp(), nullable=False, server_default=None), + sa.Column('execution_date', timestamp, nullable=False, server_default=None), sa.Column('try_number', sa.Integer(), nullable=False), - sa.Column('start_date', timestamp(), nullable=False), - sa.Column('end_date', timestamp(), nullable=False), + sa.Column('start_date', timestamp, nullable=False), + sa.Column('end_date', timestamp, nullable=False), sa.Column('duration', sa.Integer(), nullable=False), - sa.Column('reschedule_date', timestamp(), nullable=False), + sa.Column('reschedule_date', timestamp, nullable=False), sa.PrimaryKeyConstraint('id'), sa.ForeignKeyConstraint( ['task_id', 'dag_id', 'execution_date'], diff --git a/airflow/migrations/versions/142555e44c17_add_data_interval_start_end_to_dagmodel_and_dagrun.py b/airflow/migrations/versions/142555e44c17_add_data_interval_start_end_to_dagmodel_and_dagrun.py index 2eedcb81c6444..6be37326cd29c 100644 --- a/airflow/migrations/versions/142555e44c17_add_data_interval_start_end_to_dagmodel_and_dagrun.py +++ b/airflow/migrations/versions/142555e44c17_add_data_interval_start_end_to_dagmodel_and_dagrun.py @@ -25,8 +25,9 @@ """ from alembic import op -from sqlalchemy import TIMESTAMP, Column -from sqlalchemy.dialects import mssql, mysql +from sqlalchemy import Column + +from airflow.migrations.db_types import TIMESTAMP # Revision identifiers, used by Alembic. revision = "142555e44c17" @@ -35,36 +36,14 @@ depends_on = None -def _use_date_time2(conn): - result = conn.execute( - """SELECT CASE WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion')) - like '8%' THEN '2000' WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion')) - like '9%' THEN '2005' ELSE '2005Plus' END AS MajorVersion""" - ).fetchone() - mssql_version = result[0] - return mssql_version not in ("2000", "2005") - - -def _get_timestamp(conn): - dialect_name = conn.dialect.name - if dialect_name == "mysql": - return mysql.TIMESTAMP(fsp=6, timezone=True) - if dialect_name != "mssql": - return TIMESTAMP(timezone=True) - if _use_date_time2(conn): - return mssql.DATETIME2(precision=6) - return mssql.DATETIME - - def upgrade(): """Apply data_interval fields to DagModel and DagRun.""" - column_type = _get_timestamp(op.get_bind()) with op.batch_alter_table("dag_run") as batch_op: - batch_op.add_column(Column("data_interval_start", column_type)) - batch_op.add_column(Column("data_interval_end", column_type)) + batch_op.add_column(Column("data_interval_start", TIMESTAMP)) + batch_op.add_column(Column("data_interval_end", TIMESTAMP)) with op.batch_alter_table("dag") as batch_op: - batch_op.add_column(Column("next_dagrun_data_interval_start", column_type)) - batch_op.add_column(Column("next_dagrun_data_interval_end", column_type)) + batch_op.add_column(Column("next_dagrun_data_interval_start", TIMESTAMP)) + batch_op.add_column(Column("next_dagrun_data_interval_end", TIMESTAMP)) def downgrade(): diff --git a/airflow/migrations/versions/1b38cef5b76e_add_dagrun.py b/airflow/migrations/versions/1b38cef5b76e_add_dagrun.py index 095b496bddcab..441a4d8e7159b 100644 --- a/airflow/migrations/versions/1b38cef5b76e_add_dagrun.py +++ b/airflow/migrations/versions/1b38cef5b76e_add_dagrun.py @@ -27,7 +27,7 @@ import sqlalchemy as sa from alembic import op -from airflow.models.base import COLLATION_ARGS +from airflow.migrations.db_types import StringID # revision identifiers, used by Alembic. revision = '1b38cef5b76e' @@ -40,10 +40,10 @@ def upgrade(): op.create_table( 'dag_run', sa.Column('id', sa.Integer(), nullable=False), - sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=True), + sa.Column('dag_id', StringID(), nullable=True), sa.Column('execution_date', sa.DateTime(), nullable=True), sa.Column('state', sa.String(length=50), nullable=True), - sa.Column('run_id', sa.String(length=250, **COLLATION_ARGS), nullable=True), + sa.Column('run_id', StringID(), nullable=True), sa.Column('external_trigger', sa.Boolean(), nullable=True), sa.PrimaryKeyConstraint('id'), sa.UniqueConstraint('dag_id', 'execution_date'), diff --git a/airflow/migrations/versions/3c20cacc0044_add_dagrun_run_type.py b/airflow/migrations/versions/3c20cacc0044_add_dagrun_run_type.py index b7f6650825a51..3a9f0a2097abe 100644 --- a/airflow/migrations/versions/3c20cacc0044_add_dagrun_run_type.py +++ b/airflow/migrations/versions/3c20cacc0044_add_dagrun_run_type.py @@ -31,7 +31,7 @@ from sqlalchemy.engine.reflection import Inspector from sqlalchemy.ext.declarative import declarative_base -from airflow.models.base import ID_LEN +from airflow.migrations.db_types import StringID from airflow.utils import timezone from airflow.utils.sqlalchemy import UtcDateTime from airflow.utils.state import State @@ -55,12 +55,12 @@ class DagRun(Base): # type: ignore __tablename__ = "dag_run" id = Column(Integer, primary_key=True) - dag_id = Column(String(ID_LEN)) + dag_id = Column(StringID()) execution_date = Column(UtcDateTime, default=timezone.utcnow) start_date = Column(UtcDateTime, default=timezone.utcnow) end_date = Column(UtcDateTime) _state = Column('state', String(50), default=State.RUNNING) - run_id = Column(String(ID_LEN)) + run_id = Column(StringID()) external_trigger = Column(Boolean, default=True) run_type = Column(String(50), nullable=False) conf = Column(PickleType) @@ -96,7 +96,9 @@ def upgrade(): # Make run_type not nullable with op.batch_alter_table("dag_run") as batch_op: - batch_op.alter_column("run_type", type_=run_type_col_type, nullable=False) + batch_op.alter_column( + "run_type", existing_type=run_type_col_type, type_=run_type_col_type, nullable=False + ) def downgrade(): diff --git a/airflow/migrations/versions/64de9cddf6c9_add_task_fails_journal_table.py b/airflow/migrations/versions/64de9cddf6c9_add_task_fails_journal_table.py index 4243e3a3b40fd..c2dbb647089a2 100644 --- a/airflow/migrations/versions/64de9cddf6c9_add_task_fails_journal_table.py +++ b/airflow/migrations/versions/64de9cddf6c9_add_task_fails_journal_table.py @@ -26,7 +26,7 @@ import sqlalchemy as sa from alembic import op -from airflow.models.base import COLLATION_ARGS +from airflow.migrations.db_types import StringID # revision identifiers, used by Alembic. revision = '64de9cddf6c9' @@ -39,8 +39,8 @@ def upgrade(): op.create_table( 'task_fail', sa.Column('id', sa.Integer(), nullable=False), - sa.Column('task_id', sa.String(length=250, **COLLATION_ARGS), nullable=False), - sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=False), + sa.Column('task_id', StringID(), nullable=False), + sa.Column('dag_id', StringID(), nullable=False), sa.Column('execution_date', sa.DateTime(), nullable=False), sa.Column('start_date', sa.DateTime(), nullable=True), sa.Column('end_date', sa.DateTime(), nullable=True), diff --git a/airflow/migrations/versions/6e96a59344a4_make_taskinstance_pool_not_nullable.py b/airflow/migrations/versions/6e96a59344a4_make_taskinstance_pool_not_nullable.py index c59c52225757f..dd42e25c949be 100644 --- a/airflow/migrations/versions/6e96a59344a4_make_taskinstance_pool_not_nullable.py +++ b/airflow/migrations/versions/6e96a59344a4_make_taskinstance_pool_not_nullable.py @@ -30,7 +30,7 @@ from sqlalchemy import Column, Float, Integer, PickleType, String from sqlalchemy.ext.declarative import declarative_base -from airflow.models.base import COLLATION_ARGS +from airflow.migrations.db_types import StringID from airflow.utils.session import create_session from airflow.utils.sqlalchemy import UtcDateTime @@ -60,8 +60,8 @@ class TaskInstance(Base): # type: ignore __tablename__ = "task_instance" - task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True) - dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True) + task_id = Column(StringID(), primary_key=True) + dag_id = Column(StringID(), primary_key=True) execution_date = Column(UtcDateTime, primary_key=True) start_date = Column(UtcDateTime) end_date = Column(UtcDateTime) diff --git a/airflow/migrations/versions/7939bcff74ba_add_dagtags_table.py b/airflow/migrations/versions/7939bcff74ba_add_dagtags_table.py index b207b717f03ab..a19fee24f3772 100644 --- a/airflow/migrations/versions/7939bcff74ba_add_dagtags_table.py +++ b/airflow/migrations/versions/7939bcff74ba_add_dagtags_table.py @@ -27,7 +27,7 @@ import sqlalchemy as sa from alembic import op -from airflow.models.base import COLLATION_ARGS +from airflow.migrations.db_types import StringID # revision identifiers, used by Alembic. revision = '7939bcff74ba' @@ -41,7 +41,7 @@ def upgrade(): op.create_table( 'dag_tag', sa.Column('name', sa.String(length=100), nullable=False), - sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=False), + sa.Column('dag_id', StringID(), nullable=False), sa.ForeignKeyConstraint( ['dag_id'], ['dag.dag_id'], diff --git a/airflow/migrations/versions/7b2661a43ba3_taskinstance_keyed_to_dagrun.py b/airflow/migrations/versions/7b2661a43ba3_taskinstance_keyed_to_dagrun.py index 775d239471bfc..3e5d917c5cc5f 100644 --- a/airflow/migrations/versions/7b2661a43ba3_taskinstance_keyed_to_dagrun.py +++ b/airflow/migrations/versions/7b2661a43ba3_taskinstance_keyed_to_dagrun.py @@ -30,7 +30,7 @@ from alembic import op from sqlalchemy.sql import and_, column, select, table -from airflow.models.base import COLLATION_ARGS +from airflow.migrations.db_types import TIMESTAMP, StringID ID_LEN = 250 @@ -41,19 +41,6 @@ depends_on = None -def _datetime_type(dialect_name): - if dialect_name == "mssql": - from sqlalchemy.dialects import mssql - - return mssql.DATETIME2(precision=6) - elif dialect_name == "mysql": - from sqlalchemy.dialects import mysql - - return mysql.DATETIME(fsp=6) - - return sa.TIMESTAMP(timezone=True) - - # Just Enough Table to run the conditions for update. task_instance = table( 'task_instance', @@ -106,9 +93,9 @@ def upgrade(): """Apply TaskInstance keyed to DagRun""" conn = op.get_bind() dialect_name = conn.dialect.name - dt_type = _datetime_type(dialect_name) - string_id_col_type = sa.String(length=ID_LEN, **COLLATION_ARGS) + dt_type = TIMESTAMP + string_id_col_type = StringID() if dialect_name == 'sqlite': naming_convention = { @@ -326,8 +313,8 @@ def upgrade(): def downgrade(): """Unapply TaskInstance keyed to DagRun""" dialect_name = op.get_bind().dialect.name - dt_type = _datetime_type(dialect_name) - string_id_col_type = sa.String(length=ID_LEN, **COLLATION_ARGS) + dt_type = TIMESTAMP + string_id_col_type = StringID() op.add_column('task_instance', sa.Column('execution_date', dt_type, nullable=True)) op.add_column('task_reschedule', sa.Column('execution_date', dt_type, nullable=True)) diff --git a/airflow/migrations/versions/83f031fd9f1c_improve_mssql_compatibility.py b/airflow/migrations/versions/83f031fd9f1c_improve_mssql_compatibility.py index daaf711b4c79f..97d0019ac44cc 100644 --- a/airflow/migrations/versions/83f031fd9f1c_improve_mssql_compatibility.py +++ b/airflow/migrations/versions/83f031fd9f1c_improve_mssql_compatibility.py @@ -30,6 +30,8 @@ from alembic import op from sqlalchemy.dialects import mssql +from airflow.migrations.db_types import MSSQL_USE_DATE_TIME2, TIMESTAMP + # revision identifiers, used by Alembic. revision = '83f031fd9f1c' down_revision = 'ccde3e26fe78' @@ -104,16 +106,6 @@ def create_constraints(operator, column_name, constraint_dict): operator.create_unique_constraint(constraint_name=constraint[0], columns=columns) -def _use_date_time2(conn): - result = conn.execute( - """SELECT CASE WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion')) - like '8%' THEN '2000' WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion')) - like '9%' THEN '2005' ELSE '2005Plus' END AS MajorVersion""" - ).fetchone() - mssql_version = result[0] - return mssql_version not in ("2000", "2005") - - def _is_timestamp(conn, table_name, column_name): query = f"""SELECT TYPE_NAME(C.USER_TYPE_ID) AS DATA_TYPE @@ -136,16 +128,13 @@ def recreate_mssql_ts_column(conn, op, table_name, column_name): constraint_dict = get_table_constraints(conn, table_name) drop_column_constraints(batch_op, column_name, constraint_dict) batch_op.drop_column(column_name=column_name) - if _use_date_time2(conn): - batch_op.add_column(sa.Column(column_name, mssql.DATETIME2(precision=6), nullable=False)) - else: - batch_op.add_column(sa.Column(column_name, mssql.DATETIME, nullable=False)) + batch_op.add_column(sa.Column(column_name, TIMESTAMP, nullable=False)) create_constraints(batch_op, column_name, constraint_dict) def alter_mssql_datetime_column(conn, op, table_name, column_name, nullable): """Update the datetime column to datetime2(6)""" - if _use_date_time2(conn): + if MSSQL_USE_DATE_TIME2: op.alter_column( table_name=table_name, column_name=column_name, @@ -156,19 +145,12 @@ def alter_mssql_datetime_column(conn, op, table_name, column_name, nullable): def alter_mssql_datetime2_column(conn, op, table_name, column_name, nullable): """Update the datetime2(6) column to datetime""" - if _use_date_time2(conn): + if MSSQL_USE_DATE_TIME2: op.alter_column( table_name=table_name, column_name=column_name, type_=mssql.DATETIME, nullable=nullable ) -def _get_timestamp(conn): - if _use_date_time2(conn): - return mssql.DATETIME2(precision=6) - else: - return mssql.DATETIME - - def upgrade(): """Improve compatibility with MSSQL backend""" conn = op.get_bind() @@ -177,22 +159,14 @@ def upgrade(): recreate_mssql_ts_column(conn, op, 'dag_code', 'last_updated') recreate_mssql_ts_column(conn, op, 'rendered_task_instance_fields', 'execution_date') alter_mssql_datetime_column(conn, op, 'serialized_dag', 'last_updated', False) - op.alter_column(table_name="xcom", column_name="timestamp", type_=_get_timestamp(conn), nullable=False) + op.alter_column(table_name="xcom", column_name="timestamp", type_=TIMESTAMP, nullable=False) with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op: - task_reschedule_batch_op.alter_column( - column_name='end_date', type_=_get_timestamp(conn), nullable=False - ) - task_reschedule_batch_op.alter_column( - column_name='reschedule_date', type_=_get_timestamp(conn), nullable=False - ) - task_reschedule_batch_op.alter_column( - column_name='start_date', type_=_get_timestamp(conn), nullable=False - ) + task_reschedule_batch_op.alter_column(column_name='end_date', type_=TIMESTAMP, nullable=False) + task_reschedule_batch_op.alter_column(column_name='reschedule_date', type_=TIMESTAMP, nullable=False) + task_reschedule_batch_op.alter_column(column_name='start_date', type_=TIMESTAMP, nullable=False) with op.batch_alter_table('task_fail') as task_fail_batch_op: task_fail_batch_op.drop_index('idx_task_fail_dag_task_date') - task_fail_batch_op.alter_column( - column_name="execution_date", type_=_get_timestamp(conn), nullable=False - ) + task_fail_batch_op.alter_column(column_name="execution_date", type_=TIMESTAMP, nullable=False) task_fail_batch_op.create_index( 'idx_task_fail_dag_task_date', ['dag_id', 'task_id', 'execution_date'], unique=False ) @@ -225,22 +199,14 @@ def downgrade(): if conn.dialect.name != 'mssql': return alter_mssql_datetime2_column(conn, op, 'serialized_dag', 'last_updated', False) - op.alter_column(table_name="xcom", column_name="timestamp", type_=_get_timestamp(conn), nullable=True) + op.alter_column(table_name="xcom", column_name="timestamp", type_=TIMESTAMP, nullable=True) with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op: - task_reschedule_batch_op.alter_column( - column_name='end_date', type_=_get_timestamp(conn), nullable=True - ) - task_reschedule_batch_op.alter_column( - column_name='reschedule_date', type_=_get_timestamp(conn), nullable=True - ) - task_reschedule_batch_op.alter_column( - column_name='start_date', type_=_get_timestamp(conn), nullable=True - ) + task_reschedule_batch_op.alter_column(column_name='end_date', type_=TIMESTAMP, nullable=True) + task_reschedule_batch_op.alter_column(column_name='reschedule_date', type_=TIMESTAMP, nullable=True) + task_reschedule_batch_op.alter_column(column_name='start_date', type_=TIMESTAMP, nullable=True) with op.batch_alter_table('task_fail') as task_fail_batch_op: task_fail_batch_op.drop_index('idx_task_fail_dag_task_date') - task_fail_batch_op.alter_column( - column_name="execution_date", type_=_get_timestamp(conn), nullable=False - ) + task_fail_batch_op.alter_column(column_name="execution_date", type_=TIMESTAMP, nullable=False) task_fail_batch_op.create_index( 'idx_task_fail_dag_task_date', ['dag_id', 'task_id', 'execution_date'], unique=False ) diff --git a/airflow/migrations/versions/852ae6c715af_add_rendered_task_instance_fields_table.py b/airflow/migrations/versions/852ae6c715af_add_rendered_task_instance_fields_table.py index 83a0635dca06b..3a32841dba0f1 100644 --- a/airflow/migrations/versions/852ae6c715af_add_rendered_task_instance_fields_table.py +++ b/airflow/migrations/versions/852ae6c715af_add_rendered_task_instance_fields_table.py @@ -27,7 +27,7 @@ import sqlalchemy as sa from alembic import op -from airflow.models.base import COLLATION_ARGS +from airflow.migrations.db_types import StringID # revision identifiers, used by Alembic. revision = '852ae6c715af' @@ -53,8 +53,8 @@ def upgrade(): op.create_table( TABLE_NAME, - sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=False), - sa.Column('task_id', sa.String(length=250, **COLLATION_ARGS), nullable=False), + sa.Column('dag_id', StringID(), nullable=False), + sa.Column('task_id', StringID(), nullable=False), sa.Column('execution_date', sa.TIMESTAMP(timezone=True), nullable=False), sa.Column('rendered_fields', json_type(), nullable=False), sa.PrimaryKeyConstraint('dag_id', 'task_id', 'execution_date'), diff --git a/airflow/migrations/versions/8646922c8a04_change_default_pool_slots_to_1.py b/airflow/migrations/versions/8646922c8a04_change_default_pool_slots_to_1.py index 678df91ec7636..0b44c110223bd 100644 --- a/airflow/migrations/versions/8646922c8a04_change_default_pool_slots_to_1.py +++ b/airflow/migrations/versions/8646922c8a04_change_default_pool_slots_to_1.py @@ -30,7 +30,7 @@ from sqlalchemy import Column, Float, Integer, PickleType, String from sqlalchemy.ext.declarative import declarative_base -from airflow.models.base import COLLATION_ARGS +from airflow.migrations.db_types import StringID from airflow.utils.sqlalchemy import UtcDateTime # revision identifiers, used by Alembic. @@ -41,7 +41,6 @@ Base = declarative_base() BATCH_SIZE = 5000 -ID_LEN = 250 class TaskInstance(Base): # type: ignore @@ -49,8 +48,8 @@ class TaskInstance(Base): # type: ignore __tablename__ = "task_instance" - task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True) - dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True) + task_id = Column(StringID(), primary_key=True) + dag_id = Column(StringID(), primary_key=True) execution_date = Column(UtcDateTime, primary_key=True) start_date = Column(UtcDateTime) end_date = Column(UtcDateTime) @@ -70,7 +69,7 @@ class TaskInstance(Base): # type: ignore queued_by_job_id = Column(Integer) pid = Column(Integer) executor_config = Column(PickleType(pickler=dill)) - external_executor_id = Column(String(ID_LEN, **COLLATION_ARGS)) + external_executor_id = Column(StringID()) def upgrade(): diff --git a/airflow/migrations/versions/97cdd93827b8_add_queued_at_column_to_dagrun_table.py b/airflow/migrations/versions/97cdd93827b8_add_queued_at_column_to_dagrun_table.py index 03caebc1471a9..db44b0fde326b 100644 --- a/airflow/migrations/versions/97cdd93827b8_add_queued_at_column_to_dagrun_table.py +++ b/airflow/migrations/versions/97cdd93827b8_add_queued_at_column_to_dagrun_table.py @@ -26,7 +26,8 @@ import sqlalchemy as sa from alembic import op -from sqlalchemy.dialects import mssql + +from airflow.migrations.db_types import TIMESTAMP # revision identifiers, used by Alembic. revision = '97cdd93827b8' @@ -37,11 +38,7 @@ def upgrade(): """Apply Add queued_at column to dagrun table""" - conn = op.get_bind() - if conn.dialect.name == "mssql": - op.add_column('dag_run', sa.Column('queued_at', mssql.DATETIME2(precision=6), nullable=True)) - else: - op.add_column('dag_run', sa.Column('queued_at', sa.DateTime(), nullable=True)) + op.add_column('dag_run', sa.Column('queued_at', TIMESTAMP, nullable=True)) def downgrade(): diff --git a/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py b/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py index 9d6ca5736b328..c27f30fc3156c 100644 --- a/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py +++ b/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py @@ -26,7 +26,8 @@ import sqlalchemy as sa from alembic import op -from sqlalchemy.dialects import mssql, mysql + +from airflow.migrations.db_types import TIMESTAMP # revision identifiers, used by Alembic. revision = '98271e7606e2' @@ -35,44 +36,23 @@ depends_on = None -def _use_date_time2(conn): - result = conn.execute( - """SELECT CASE WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion')) - like '8%' THEN '2000' WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion')) - like '9%' THEN '2005' ELSE '2005Plus' END AS MajorVersion""" - ).fetchone() - mssql_version = result[0] - return mssql_version not in ("2000", "2005") - - -def _get_timestamp(conn): - dialect_name = conn.dialect.name - if dialect_name == "mssql": - return mssql.DATETIME2(precision=6) if _use_date_time2(conn) else mssql.DATETIME - elif dialect_name != "mysql": - return sa.TIMESTAMP(timezone=True) - else: - return mysql.TIMESTAMP(fsp=6, timezone=True) - - def upgrade(): """Apply Add scheduling_decision to DagRun and DAG""" conn = op.get_bind() is_sqlite = bool(conn.dialect.name == "sqlite") is_mssql = bool(conn.dialect.name == "mssql") - timestamp = _get_timestamp(conn) if is_sqlite: op.execute("PRAGMA foreign_keys=off") with op.batch_alter_table('dag_run', schema=None) as batch_op: - batch_op.add_column(sa.Column('last_scheduling_decision', timestamp, nullable=True)) + batch_op.add_column(sa.Column('last_scheduling_decision', TIMESTAMP, nullable=True)) batch_op.create_index('idx_last_scheduling_decision', ['last_scheduling_decision'], unique=False) batch_op.add_column(sa.Column('dag_hash', sa.String(32), nullable=True)) with op.batch_alter_table('dag', schema=None) as batch_op: - batch_op.add_column(sa.Column('next_dagrun', timestamp, nullable=True)) - batch_op.add_column(sa.Column('next_dagrun_create_after', timestamp, nullable=True)) + batch_op.add_column(sa.Column('next_dagrun', TIMESTAMP, nullable=True)) + batch_op.add_column(sa.Column('next_dagrun_create_after', TIMESTAMP, nullable=True)) # Create with nullable and no default, then ALTER to set values, to avoid table level lock batch_op.add_column(sa.Column('concurrency', sa.Integer(), nullable=True)) batch_op.add_column(sa.Column('has_task_concurrency_limits', sa.Boolean(), nullable=True)) diff --git a/airflow/migrations/versions/cc1e65623dc7_add_max_tries_column_to_task_instance.py b/airflow/migrations/versions/cc1e65623dc7_add_max_tries_column_to_task_instance.py index 95a32c0b44332..6b01a8c545944 100644 --- a/airflow/migrations/versions/cc1e65623dc7_add_max_tries_column_to_task_instance.py +++ b/airflow/migrations/versions/cc1e65623dc7_add_max_tries_column_to_task_instance.py @@ -25,13 +25,13 @@ import sqlalchemy as sa from alembic import op -from sqlalchemy import Column, Integer, String +from sqlalchemy import Column, Integer from sqlalchemy.engine.reflection import Inspector from sqlalchemy.ext.declarative import declarative_base from airflow import settings +from airflow.migrations.db_types import StringID from airflow.models import DagBag -from airflow.models.base import COLLATION_ARGS # revision identifiers, used by Alembic. revision = 'cc1e65623dc7' @@ -41,7 +41,6 @@ Base = declarative_base() BATCH_SIZE = 5000 -ID_LEN = 250 class TaskInstance(Base): # type: ignore @@ -49,8 +48,8 @@ class TaskInstance(Base): # type: ignore __tablename__ = "task_instance" - task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True) - dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True) + task_id = Column(StringID(), primary_key=True) + dag_id = Column(StringID(), primary_key=True) execution_date = Column(sa.DateTime, primary_key=True) max_tries = Column(Integer) try_number = Column(Integer, default=0) diff --git a/airflow/migrations/versions/d38e04c12aa2_add_serialized_dag_table.py b/airflow/migrations/versions/d38e04c12aa2_add_serialized_dag_table.py index 4b8b058dc05fd..9b96a56db3ecc 100644 --- a/airflow/migrations/versions/d38e04c12aa2_add_serialized_dag_table.py +++ b/airflow/migrations/versions/d38e04c12aa2_add_serialized_dag_table.py @@ -27,7 +27,7 @@ from alembic import op from sqlalchemy.dialects import mysql -from airflow.models.base import COLLATION_ARGS +from airflow.migrations.db_types import StringID # revision identifiers, used by Alembic. revision = 'd38e04c12aa2' @@ -51,7 +51,7 @@ def upgrade(): op.create_table( 'serialized_dag', - sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=False), + sa.Column('dag_id', StringID(), nullable=False), sa.Column('fileloc', sa.String(length=2000), nullable=False), sa.Column('fileloc_hash', sa.Integer(), nullable=False), sa.Column('data', json_type(), nullable=False), diff --git a/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py b/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py index def5fdf0ba26f..213299a2d0c97 100644 --- a/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py +++ b/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py @@ -26,10 +26,9 @@ import sqlalchemy as sa from alembic import op from sqlalchemy import func -from sqlalchemy.dialects import mysql from sqlalchemy.engine.reflection import Inspector -from airflow.models.base import COLLATION_ARGS +from airflow.migrations.db_types import TIMESTAMP, StringID # revision identifiers, used by Alembic. revision = 'e38be357a868' @@ -38,18 +37,6 @@ depends_on = None -def mssql_timestamp(): - return sa.DateTime() - - -def mysql_timestamp(): - return mysql.TIMESTAMP(fsp=6) - - -def sa_timestamp(): - return sa.TIMESTAMP(timezone=True) - - def upgrade(): conn = op.get_bind() @@ -58,30 +45,23 @@ def upgrade(): if 'sensor_instance' in tables: return - if conn.dialect.name == 'mysql': - timestamp = mysql_timestamp - elif conn.dialect.name == 'mssql': - timestamp = mssql_timestamp - else: - timestamp = sa_timestamp - op.create_table( 'sensor_instance', sa.Column('id', sa.Integer(), nullable=False), - sa.Column('task_id', sa.String(length=250, **COLLATION_ARGS), nullable=False), - sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=False), - sa.Column('execution_date', timestamp(), nullable=False), + sa.Column('task_id', StringID(), nullable=False), + sa.Column('dag_id', StringID(), nullable=False), + sa.Column('execution_date', TIMESTAMP, nullable=False), sa.Column('state', sa.String(length=20), nullable=True), sa.Column('try_number', sa.Integer(), nullable=True), - sa.Column('start_date', timestamp(), nullable=True), + sa.Column('start_date', TIMESTAMP, nullable=True), sa.Column('operator', sa.String(length=1000), nullable=False), sa.Column('op_classpath', sa.String(length=1000), nullable=False), sa.Column('hashcode', sa.BigInteger(), nullable=False), sa.Column('shardcode', sa.Integer(), nullable=False), sa.Column('poke_context', sa.Text(), nullable=False), sa.Column('execution_context', sa.Text(), nullable=True), - sa.Column('created_at', timestamp(), default=func.now(), nullable=False), - sa.Column('updated_at', timestamp(), default=func.now(), nullable=False), + sa.Column('created_at', TIMESTAMP, default=func.now(), nullable=False), + sa.Column('updated_at', TIMESTAMP, default=func.now(), nullable=False), sa.PrimaryKeyConstraint('id'), ) op.create_index('ti_primary_key', 'sensor_instance', ['dag_id', 'task_id', 'execution_date'], unique=True) diff --git a/airflow/migrations/versions/e3a246e0dc1_current_schema.py b/airflow/migrations/versions/e3a246e0dc1_current_schema.py index 9760232910739..49e73591887bf 100644 --- a/airflow/migrations/versions/e3a246e0dc1_current_schema.py +++ b/airflow/migrations/versions/e3a246e0dc1_current_schema.py @@ -29,7 +29,7 @@ from sqlalchemy import func from sqlalchemy.engine.reflection import Inspector -from airflow.models.base import COLLATION_ARGS +from airflow.migrations.db_types import StringID # revision identifiers, used by Alembic. revision = 'e3a246e0dc1' @@ -47,7 +47,7 @@ def upgrade(): op.create_table( 'connection', sa.Column('id', sa.Integer(), nullable=False), - sa.Column('conn_id', sa.String(length=250, **COLLATION_ARGS), nullable=True), + sa.Column('conn_id', StringID(), nullable=True), sa.Column('conn_type', sa.String(length=500), nullable=True), sa.Column('host', sa.String(length=500), nullable=True), sa.Column('schema', sa.String(length=500), nullable=True), @@ -60,7 +60,7 @@ def upgrade(): if 'dag' not in tables: op.create_table( 'dag', - sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=False), + sa.Column('dag_id', StringID(), nullable=False), sa.Column('is_paused', sa.Boolean(), nullable=True), sa.Column('is_subdag', sa.Boolean(), nullable=True), sa.Column('is_active', sa.Boolean(), nullable=True), @@ -112,8 +112,8 @@ def upgrade(): 'log', sa.Column('id', sa.Integer(), nullable=False), sa.Column('dttm', sa.DateTime(), nullable=True), - sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=True), - sa.Column('task_id', sa.String(length=250, **COLLATION_ARGS), nullable=True), + sa.Column('dag_id', StringID(), nullable=True), + sa.Column('task_id', StringID(), nullable=True), sa.Column('event', sa.String(length=30), nullable=True), sa.Column('execution_date', sa.DateTime(), nullable=True), sa.Column('owner', sa.String(length=500), nullable=True), @@ -122,8 +122,8 @@ def upgrade(): if 'sla_miss' not in tables: op.create_table( 'sla_miss', - sa.Column('task_id', sa.String(length=250, **COLLATION_ARGS), nullable=False), - sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=False), + sa.Column('task_id', StringID(), nullable=False), + sa.Column('dag_id', StringID(), nullable=False), sa.Column('execution_date', sa.DateTime(), nullable=False), sa.Column('email_sent', sa.Boolean(), nullable=True), sa.Column('timestamp', sa.DateTime(), nullable=True), @@ -134,7 +134,7 @@ def upgrade(): op.create_table( 'slot_pool', sa.Column('id', sa.Integer(), nullable=False), - sa.Column('pool', sa.String(length=50, **COLLATION_ARGS), nullable=True), + sa.Column('pool', StringID(length=50), nullable=True), sa.Column('slots', sa.Integer(), nullable=True), sa.Column('description', sa.Text(), nullable=True), sa.PrimaryKeyConstraint('id'), @@ -143,8 +143,8 @@ def upgrade(): if 'task_instance' not in tables: op.create_table( 'task_instance', - sa.Column('task_id', sa.String(length=250, **COLLATION_ARGS), nullable=False), - sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=False), + sa.Column('task_id', StringID(), nullable=False), + sa.Column('dag_id', StringID(), nullable=False), sa.Column('execution_date', sa.DateTime(), nullable=False), sa.Column('start_date', sa.DateTime(), nullable=True), sa.Column('end_date', sa.DateTime(), nullable=True), @@ -169,7 +169,7 @@ def upgrade(): op.create_table( 'user', sa.Column('id', sa.Integer(), nullable=False), - sa.Column('username', sa.String(length=250, **COLLATION_ARGS), nullable=True), + sa.Column('username', StringID(), nullable=True), sa.Column('email', sa.String(length=500), nullable=True), sa.PrimaryKeyConstraint('id'), sa.UniqueConstraint('username'), @@ -178,7 +178,7 @@ def upgrade(): op.create_table( 'variable', sa.Column('id', sa.Integer(), nullable=False), - sa.Column('key', sa.String(length=250, **COLLATION_ARGS), nullable=True), + sa.Column('key', StringID(), nullable=True), sa.Column('val', sa.Text(), nullable=True), sa.PrimaryKeyConstraint('id'), sa.UniqueConstraint('key'), @@ -211,12 +211,12 @@ def upgrade(): op.create_table( 'xcom', sa.Column('id', sa.Integer(), nullable=False), - sa.Column('key', sa.String(length=512, **COLLATION_ARGS), nullable=True), + sa.Column('key', StringID(length=512), nullable=True), sa.Column('value', sa.PickleType(), nullable=True), sa.Column('timestamp', sa.DateTime(), default=func.now(), nullable=False), sa.Column('execution_date', sa.DateTime(), nullable=False), - sa.Column('task_id', sa.String(length=250, **COLLATION_ARGS), nullable=False), - sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=False), + sa.Column('task_id', StringID(), nullable=False), + sa.Column('dag_id', StringID(), nullable=False), sa.PrimaryKeyConstraint('id'), ) diff --git a/airflow/migrations/versions/e9304a3141f0_make_xcom_pkey_columns_non_nullable.py b/airflow/migrations/versions/e9304a3141f0_make_xcom_pkey_columns_non_nullable.py index bde065b3e1ef4..8b331b6c72fe9 100644 --- a/airflow/migrations/versions/e9304a3141f0_make_xcom_pkey_columns_non_nullable.py +++ b/airflow/migrations/versions/e9304a3141f0_make_xcom_pkey_columns_non_nullable.py @@ -23,11 +23,9 @@ Create Date: 2021-04-06 13:22:02.197726 """ -import sqlalchemy as sa from alembic import op -from sqlalchemy.dialects import mssql, mysql -from airflow.models.base import COLLATION_ARGS +from airflow.migrations.db_types import TIMESTAMP, StringID # revision identifiers, used by Alembic. revision = 'e9304a3141f0' @@ -36,32 +34,12 @@ depends_on = None -def _use_date_time2(conn): - result = conn.execute( - """SELECT CASE WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion')) - like '8%' THEN '2000' WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion')) - like '9%' THEN '2005' ELSE '2005Plus' END AS MajorVersion""" - ).fetchone() - mssql_version = result[0] - return mssql_version not in ("2000", "2005") - - -def _get_timestamp(conn): - dialect_name = conn.dialect.name - if dialect_name == "mssql": - return mssql.DATETIME2(precision=6) if _use_date_time2(conn) else mssql.DATETIME - elif dialect_name == "mysql": - return mysql.TIMESTAMP(fsp=6, timezone=True) - else: - return sa.TIMESTAMP(timezone=True) - - def upgrade(): """Apply make xcom pkey columns non-nullable""" conn = op.get_bind() with op.batch_alter_table('xcom') as bop: - bop.alter_column("key", type_=sa.String(length=512, **COLLATION_ARGS), nullable=False) - bop.alter_column("execution_date", type_=_get_timestamp(conn), nullable=False) + bop.alter_column("key", type_=StringID(length=512), nullable=False) + bop.alter_column("execution_date", type_=TIMESTAMP, nullable=False) if conn.dialect.name == 'mssql': bop.create_primary_key('pk_xcom', ['dag_id', 'task_id', 'key', 'execution_date']) @@ -72,5 +50,5 @@ def downgrade(): with op.batch_alter_table('xcom') as bop: if conn.dialect.name == 'mssql': bop.drop_constraint('pk_xcom', 'primary') - bop.alter_column("key", type_=sa.String(length=512, **COLLATION_ARGS), nullable=True) - bop.alter_column("execution_date", type_=_get_timestamp(conn), nullable=True) + bop.alter_column("key", type_=StringID(length=512), nullable=True) + bop.alter_column("execution_date", type_=TIMESTAMP, nullable=True) diff --git a/airflow/migrations/versions/f2ca10b85618_add_dag_stats_table.py b/airflow/migrations/versions/f2ca10b85618_add_dag_stats_table.py index ce6f5010dd498..644c138eba9ab 100644 --- a/airflow/migrations/versions/f2ca10b85618_add_dag_stats_table.py +++ b/airflow/migrations/versions/f2ca10b85618_add_dag_stats_table.py @@ -26,7 +26,7 @@ import sqlalchemy as sa from alembic import op -from airflow.models.base import COLLATION_ARGS +from airflow.migrations.db_types import StringID # revision identifiers, used by Alembic. revision = 'f2ca10b85618' @@ -38,7 +38,7 @@ def upgrade(): op.create_table( 'dag_stats', - sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=False), + sa.Column('dag_id', StringID(), nullable=False), sa.Column('state', sa.String(length=50), nullable=False), sa.Column('count', sa.Integer(), nullable=False, default=0), sa.Column('dirty', sa.Boolean(), nullable=False, default=False), diff --git a/airflow/models/base.py b/airflow/models/base.py index 02d230df811fb..29a5320e879d8 100644 --- a/airflow/models/base.py +++ b/airflow/models/base.py @@ -16,9 +16,10 @@ # specific language governing permissions and limitations # under the License. -from typing import Any +import functools +from typing import Any, Type -from sqlalchemy import MetaData +from sqlalchemy import MetaData, String from sqlalchemy.ext.declarative import declarative_base from airflow.configuration import conf @@ -61,3 +62,5 @@ def get_id_collation_args(): COLLATION_ARGS = get_id_collation_args() + +StringID: Type[String] = functools.partial(String, length=ID_LEN, **COLLATION_ARGS) diff --git a/tests/utils/test_db.py b/tests/utils/test_db.py index 601dc6f9fe9da..d0c82a10a9011 100644 --- a/tests/utils/test_db.py +++ b/tests/utils/test_db.py @@ -23,6 +23,7 @@ from alembic.autogenerate import compare_metadata from alembic.config import Config from alembic.migration import MigrationContext +from alembic.runtime.environment import EnvironmentContext from alembic.script import ScriptDirectory from sqlalchemy import MetaData @@ -85,9 +86,17 @@ def test_only_single_head_revision_in_migrations(self): config.set_main_option("script_location", "airflow:migrations") script = ScriptDirectory.from_config(config) - # This will raise if there are multiple heads - # To resolve, use the command `alembic merge` - script.get_current_head() + from airflow.settings import engine + + with EnvironmentContext( + config, + script, + as_sql=True, + ) as env: + env.configure(dialect_name=engine.dialect.name) + # This will raise if there are multiple heads + # To resolve, use the command `alembic merge` + script.get_current_head() def test_default_connections_sort(self): pattern = re.compile('conn_id=[\"|\'](.*?)[\"|\']', re.DOTALL)