diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 730757694d921..54aacc4d2276a 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -26,7 +26,22 @@ from tempfile import gettempdir from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple, Union -from sqlalchemy import Table, and_, column, exc, func, inspect, or_, select, table, text, tuple_ +import sqlalchemy.exc +from sqlalchemy import ( + Column, + String, + Table, + and_, + column, + exc, + func, + inspect, + or_, + select, + table, + text, + tuple_, +) from sqlalchemy.orm.session import Session import airflow @@ -37,6 +52,7 @@ from airflow.jobs.base_job import BaseJob # noqa: F401 from airflow.models import ( # noqa: F401 DAG, + ID_LEN, XCOM_RETURN_KEY, Base, BaseOperator, @@ -59,12 +75,14 @@ ) # We need to add this model manually to get reset working well +from airflow.models.base import COLLATION_ARGS from airflow.models.serialized_dag import SerializedDagModel # noqa: F401 from airflow.models.tasklog import LogTemplate from airflow.utils import helpers # TODO: remove create_session once we decide to break backward compatibility from airflow.utils.session import NEW_SESSION, create_session, provide_session # noqa: F401 +from airflow.utils.sqlalchemy import UtcDateTime from airflow.version import version if TYPE_CHECKING: @@ -1063,7 +1081,6 @@ def _move_dangling_data_to_new_table( session=session, ) session.commit() - target_table = source_table.to_metadata(source_table.metadata, name=target_table_name) log.debug("checking whether rows were moved for table %s", target_table_name) moved_rows_exist_query = select([1]).select_from(target_table).limit(1) @@ -1075,6 +1092,7 @@ def _move_dangling_data_to_new_table( # no bad rows were found; drop moved rows table. target_table.drop(bind=session.get_bind(), checkfirst=True) else: + # purge the bad rows log.debug("rows moved; purging from %s", source_table.name) if dialect_name == 'sqlite': pk_cols = source_table.primary_key.columns @@ -1088,6 +1106,7 @@ def _move_dangling_data_to_new_table( ) log.debug(delete.compile()) session.execute(delete) + session.commit() log.debug("exiting move function") @@ -1109,7 +1128,7 @@ def _dangling_against_dag_run(session, source_table, dag_run): ) -def _dangling_against_task_instance(session, source_table, dag_run, task_instance): +def _dangling_against_task_instance(session, source_table, ti_lkp_table): """ Given a source table, we generate a subquery that will return 1 for every row that has a valid task instance (and associated dagrun). @@ -1120,35 +1139,13 @@ def _dangling_against_task_instance(session, source_table, dag_run, task_instanc query logic depending on which revision the database is at. """ - if 'run_id' not in task_instance.c: - # db is < 2.2.0 - dr_join_cond = and_( - source_table.c.dag_id == dag_run.c.dag_id, - source_table.c.execution_date == dag_run.c.execution_date, - ) - ti_join_cond = and_( - dag_run.c.dag_id == task_instance.c.dag_id, - dag_run.c.execution_date == task_instance.c.execution_date, - source_table.c.task_id == task_instance.c.task_id, - ) - else: - # db is 2.2.0 <= version < 2.3.0 - dr_join_cond = and_( - source_table.c.dag_id == dag_run.c.dag_id, - source_table.c.execution_date == dag_run.c.execution_date, - ) - ti_join_cond = and_( - dag_run.c.dag_id == task_instance.c.dag_id, - dag_run.c.run_id == task_instance.c.run_id, - source_table.c.task_id == task_instance.c.task_id, - ) - - return ( - session.query(*[c.label(c.name) for c in source_table.c]) - .join(dag_run, dr_join_cond, isouter=True) - .join(task_instance, ti_join_cond, isouter=True) - .filter(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None))) + where_clause = and_( + source_table.c.dag_id == ti_lkp_table.c.dag_id, + source_table.c.task_id == ti_lkp_table.c.task_id, + source_table.c.execution_date == ti_lkp_table.c.execution_date, ) + exists_subquery = session.query(text('1')).select_from(ti_lkp_table).filter(where_clause) + return exists_subquery def _move_duplicate_data_to_new_table( @@ -1203,6 +1200,43 @@ def _move_duplicate_data_to_new_table( session.execute(delete) +def _create_ti_key_lkp_table(session, table_name) -> Table: + """Creates lkp table for all valid TI keys""" + tmp_table = Table( + table_name, + Base.metadata, + Column('task_id', String(ID_LEN, **COLLATION_ARGS), primary_key=True), + Column('dag_id', String(ID_LEN, **COLLATION_ARGS), primary_key=True), + Column('execution_date', UtcDateTime, primary_key=True), + ) + tmp_table.drop(bind=settings.engine, checkfirst=True) + + log.debug("creating TI key lkp table") + Base.metadata.create_all(settings.engine, tables=[tmp_table]) + log.debug("inserting TI key lkp table") + session.commit() + try: + # post 2.2 + session.execute( + f"insert into {table_name} " + "select ti.task_id, ti.dag_id, dr.execution_date " + "from task_instance ti " + "join dag_run dr on dr.dag_id = ti.dag_id " + " and dr.run_id = ti.run_id " + ) + except sqlalchemy.exc.OperationalError: + # pre-2.2 + session.execute( + f"insert into {table_name} " + "select ti.task_id, ti.dag_id, dr.execution_date " + "from task_instance ti " + "join dag_run dr on dr.dag_id = ti.dag_id " + " and dr.execution_date = ti.execution_date " + " and dr.run_id is not null" + ) + return tmp_table + + def check_bad_references(session: Session) -> Iterable[str]: """ Starting in Airflow 2.2, we began a process of replacing `execution_date` with `run_id` @@ -1255,8 +1289,10 @@ class BadReferenceConfig: return existing_table_names = set(inspect(session.get_bind()).get_table_names()) - errored = False + ti_lkp_table = _create_ti_key_lkp_table(session=session, table_name='_airflow_tmp_ti_key_lkp') + + session.commit() for model, change_version, bad_ref_cfg in models_list: log.debug("checking model %s", model.__tablename__) # We can't use the model here since it may differ from the db state due to @@ -1269,12 +1305,15 @@ class BadReferenceConfig: if "run_id" in source_table.columns: continue - func_kwargs = {x: metadata.tables[x] for x in bad_ref_cfg.join_tables} - bad_rows_query = bad_ref_cfg.bad_rows_func(session, source_table, **func_kwargs) - + session.commit() + bad_rows_subquery = bad_ref_cfg.exists_func(session, source_table, ti_lkp_table=ti_lkp_table) dangling_table_name = _format_airflow_moved_table_name(source_table.name, change_version, 'dangling') + select_list = [x.label(x.name) for x in source_table.c] + log.debug(bad_rows_subquery.selectable.compile()) + invalid_rows_query = session.query(*select_list).filter(~bad_rows_subquery.exists()) + session.commit() if dangling_table_name in existing_table_names: - invalid_row_count = bad_rows_query.count() + invalid_row_count = invalid_rows_query.count() if invalid_row_count <= 0: continue else: @@ -1284,21 +1323,18 @@ class BadReferenceConfig: invalid_count=invalid_row_count, reason=f"without a corresponding {bad_ref_cfg.ref_table} row", ) - errored = True - continue + continue log.debug("moving data for table %s", source_table.name) _move_dangling_data_to_new_table( session, source_table, - bad_rows_query, + invalid_rows_query, dangling_table_name, ) - - if errored: - session.rollback() - else: session.commit() + session.commit() + ti_lkp_table.drop(bind=settings.engine, checkfirst=True) @provide_session