From 22a9293ff8f48411d39074d9bc88af35abe9850f Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Fri, 6 May 2022 17:02:27 +0100 Subject: [PATCH] Change approach to finding bad rows to LEFT OUTER JOIN. (#23528) Rather than sub-selects (two for count, or one for the CREATE TABLE). For a _large_ database (27m TaskInstances, 2m DagRuns) this takes the time from around 10 minutes per table down to around 3 minutes per table (we have 3 tables to check). (All times on Postgres.) Before: ```sql CREATE TABLE _airflow_moved__2_3__dangling__rendered_task_instance_fields AS SELECT rendered_task_instance_fields.dag_id AS dag_id, rendered_task_instance_fields.task_id AS task_id, rendered_task_instance_fields.execution_date AS execution_date, rendered_task_instance_fields.rendered_fields AS rendered_fields, rendered_task_instance_fields.k8s_pod_yaml AS k8s_pod_yaml + FROM rendered_task_instance_fields WHERE NOT ( EXISTS ( SELECT 1 FROM task_instance JOIN dag_run ON dag_run.dag_id = task_instance.dag_id AND dag_run.run_id = task_instance.run_id WHERE rendered_task_instance_fields.dag_id = task_instance.dag_id AND rendered_task_instance_fields.task_id = task_instance.task_id AND rendered_task_instance_fields.execution_date = dag_run.execution_date ) ) ``` After: ```sql CREATE TABLE _airflow_moved__2_3__dangling__rendered_task_instance_fields AS SELECT rendered_task_instance_fields.dag_id AS dag_id, rendered_task_instance_fields.task_id AS task_id, rendered_task_instance_fields.execution_date AS execution_date, rendered_task_instance_fields.rendered_fields AS rendered_fields, rendered_task_instance_fields.k8s_pod_yaml AS k8s_pod_yaml + FROM rendered_task_instance_fields LEFT OUTER JOIN dag_run ON rendered_task_instance_fields.dag_id = dag_run.dag_id AND rendered_task_instance_fields.execution_date = dag_run.execution_date LEFT OUTER JOIN task_instance ON dag_run.dag_id = task_instance.dag_id AND dag_run.run_id = task_instance.run_id AND rendered_task_instance_fields.task_id = 
task_instance.task_id WHERE task_instance.dag_id IS NULL OR dag_run.dag_id IS NULL ; ``` --- airflow/utils/db.py | 73 +++++++++++++++++++++++++-------------------- 1 file changed, 40 insertions(+), 33 deletions(-) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 7325c0e243672..730757694d921 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -26,7 +26,7 @@ from tempfile import gettempdir from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple, Union -from sqlalchemy import Table, and_, column, exc, func, inspect, or_, select, table, text +from sqlalchemy import Table, and_, column, exc, func, inspect, or_, select, table, text, tuple_ from sqlalchemy.orm.session import Session import airflow @@ -1047,7 +1047,7 @@ def _create_table_as( def _move_dangling_data_to_new_table( - session, source_table: "Table", source_query: "Query", exists_subquery, target_table_name: str + session, source_table: "Table", source_query: "Query", target_table_name: str ): bind = session.get_bind() @@ -1072,11 +1072,16 @@ def _move_dangling_data_to_new_table( if not first_moved_row: log.debug("no rows moved; dropping %s", target_table_name) + # no bad rows were found; drop moved rows table. 
target_table.drop(bind=session.get_bind(), checkfirst=True) else: log.debug("rows moved; purging from %s", source_table.name) if dialect_name == 'sqlite': - delete = source_table.delete().where(~exists_subquery.exists()) + pk_cols = source_table.primary_key.columns + + delete = source_table.delete().where( + tuple_(*pk_cols).in_(session.query(*target_table.primary_key.columns).subquery()) + ) else: delete = source_table.delete().where( and_(col == target_table.c[col.name] for col in source_table.primary_key.columns) @@ -1088,7 +1093,7 @@ def _move_dangling_data_to_new_table( log.debug("exiting move function") -def _dag_run_exists(session, source_table, dag_run): +def _dangling_against_dag_run(session, source_table, dag_run): """ Given a source table, we generate a subquery that will return 1 for every row that has a dagrun. @@ -1097,11 +1102,14 @@ def _dag_run_exists(session, source_table, dag_run): source_table.c.dag_id == dag_run.c.dag_id, source_table.c.execution_date == dag_run.c.execution_date, ) - exists_subquery = session.query(text('1')).select_from(dag_run).filter(source_to_dag_run_join_cond) - return exists_subquery + return ( + session.query(*[c.label(c.name) for c in source_table.c]) + .join(dag_run, source_to_dag_run_join_cond, isouter=True) + .filter(dag_run.c.dag_id.is_(None)) + ) -def _task_instance_exists(session, source_table, dag_run, task_instance): +def _dangling_against_task_instance(session, source_table, dag_run, task_instance): """ Given a source table, we generate a subquery that will return 1 for every row that has a valid task instance (and associated dagrun). 
@@ -1114,32 +1122,33 @@ def _task_instance_exists(session, source_table, dag_run, task_instance): """ if 'run_id' not in task_instance.c: # db is < 2.2.0 - where_clause = and_( - source_table.c.dag_id == task_instance.c.dag_id, - source_table.c.task_id == task_instance.c.task_id, - source_table.c.execution_date == task_instance.c.execution_date, + dr_join_cond = and_( + source_table.c.dag_id == dag_run.c.dag_id, + source_table.c.execution_date == dag_run.c.execution_date, ) - ti_to_dr_join_cond = and_( + ti_join_cond = and_( dag_run.c.dag_id == task_instance.c.dag_id, dag_run.c.execution_date == task_instance.c.execution_date, + source_table.c.task_id == task_instance.c.task_id, ) else: # db is 2.2.0 <= version < 2.3.0 - where_clause = and_( - source_table.c.dag_id == task_instance.c.dag_id, - source_table.c.task_id == task_instance.c.task_id, + dr_join_cond = and_( + source_table.c.dag_id == dag_run.c.dag_id, source_table.c.execution_date == dag_run.c.execution_date, ) - ti_to_dr_join_cond = and_( + ti_join_cond = and_( dag_run.c.dag_id == task_instance.c.dag_id, dag_run.c.run_id == task_instance.c.run_id, + source_table.c.task_id == task_instance.c.task_id, ) - exists_subquery = ( - session.query(text('1')) - .select_from(task_instance.join(dag_run, onclause=ti_to_dr_join_cond)) - .filter(where_clause) + + return ( + session.query(*[c.label(c.name) for c in source_table.c]) + .join(dag_run, dr_join_cond, isouter=True) + .join(task_instance, ti_join_cond, isouter=True) + .filter(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None))) ) - return exists_subquery def _move_duplicate_data_to_new_table( @@ -1207,23 +1216,23 @@ def check_bad_references(session: Session) -> Iterable[str]: @dataclass class BadReferenceConfig: """ - :param exists_func: function that returns subquery which determines whether bad rows exist + :param bad_rows_func: function that returns subquery which determines whether bad rows exist :param join_tables: table objects referenced in 
subquery :param ref_table: information-only identifier for categorizing the missing ref """ - exists_func: Callable + bad_rows_func: Callable join_tables: List[str] ref_table: str missing_dag_run_config = BadReferenceConfig( - exists_func=_dag_run_exists, + bad_rows_func=_dangling_against_dag_run, join_tables=['dag_run'], ref_table='dag_run', ) missing_ti_config = BadReferenceConfig( - exists_func=_task_instance_exists, + bad_rows_func=_dangling_against_task_instance, join_tables=['dag_run', 'task_instance'], ref_table='task_instance', ) @@ -1238,7 +1247,8 @@ class BadReferenceConfig: metadata = reflect_tables([*[x[0] for x in models_list], DagRun, TaskInstance], session) if ( - metadata.tables.get(DagRun.__tablename__) is None + not metadata.tables + or metadata.tables.get(DagRun.__tablename__) is None or metadata.tables.get(TaskInstance.__tablename__) is None ): # Key table doesn't exist -- likely empty DB. @@ -1251,7 +1261,6 @@ class BadReferenceConfig: log.debug("checking model %s", model.__tablename__) # We can't use the model here since it may differ from the db state due to # this function is run prior to migration. Use the reflected table instead. 
- exists_func_kwargs = {x: metadata.tables[x] for x in bad_ref_cfg.join_tables} source_table = metadata.tables.get(model.__tablename__) # type: ignore if source_table is None: continue @@ -1260,13 +1269,12 @@ class BadReferenceConfig: if "run_id" in source_table.columns: continue - bad_rows_subquery = bad_ref_cfg.exists_func(session, source_table, **exists_func_kwargs) - select_list = [x.label(x.name) for x in source_table.c] - invalid_rows_query = session.query(*select_list).filter(~bad_rows_subquery.exists()) + func_kwargs = {x: metadata.tables[x] for x in bad_ref_cfg.join_tables} + bad_rows_query = bad_ref_cfg.bad_rows_func(session, source_table, **func_kwargs) dangling_table_name = _format_airflow_moved_table_name(source_table.name, change_version, 'dangling') if dangling_table_name in existing_table_names: - invalid_row_count = invalid_rows_query.count() + invalid_row_count = bad_rows_query.count() if invalid_row_count <= 0: continue else: @@ -1283,8 +1291,7 @@ class BadReferenceConfig: _move_dangling_data_to_new_table( session, source_table, - invalid_rows_query, - bad_rows_subquery, + bad_rows_query, dangling_table_name, )