Use TI key lookup table for dangling row checks
dstandish committed May 6, 2022
1 parent 22a9293 commit b74a2e3
Showing 1 changed file with 75 additions and 43 deletions.
118 changes: 75 additions & 43 deletions airflow/utils/db.py
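
In outline, the commit replaces the per-model outer-join queries used to spot dangling rows with a single temporary lookup table of valid TI keys; each model is then checked against it with a NOT EXISTS filter. A self-contained toy of that pattern, independent of Airflow (all table and column names here are illustrative, not the diff's):

    from sqlalchemy import Column, MetaData, String, Table, and_, create_engine, exists, select

    engine = create_engine("sqlite://")
    metadata = MetaData()
    valid_keys = Table(
        "valid_keys", metadata,
        Column("dag_id", String(250), primary_key=True),
        Column("task_id", String(250), primary_key=True),
    )
    rows = Table(
        "rows", metadata,
        Column("dag_id", String(250)),
        Column("task_id", String(250)),
    )
    metadata.create_all(engine)

    with engine.begin() as conn:
        conn.execute(valid_keys.insert(), [{"dag_id": "d1", "task_id": "t1"}])
        conn.execute(rows.insert(), [
            {"dag_id": "d1", "task_id": "t1"},      # has a matching key
            {"dag_id": "d1", "task_id": "orphan"},  # dangling
        ])
        # returns 1 for every row whose key is in the lookup table; negate to find dangling rows
        match = exists().where(and_(
            valid_keys.c.dag_id == rows.c.dag_id,
            valid_keys.c.task_id == rows.c.task_id,
        ))
        print(conn.execute(select([rows]).where(~match)).fetchall())  # [('d1', 'orphan')]
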
@@ -26,7 +26,22 @@
from tempfile import gettempdir
from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple, Union

from sqlalchemy import Table, and_, column, exc, func, inspect, or_, select, table, text, tuple_
import sqlalchemy.exc
from sqlalchemy import (
Column,
String,
Table,
and_,
column,
exc,
func,
inspect,
or_,
select,
table,
text,
tuple_,
)
from sqlalchemy.orm.session import Session

import airflow
@@ -37,6 +52,7 @@
from airflow.jobs.base_job import BaseJob # noqa: F401
from airflow.models import ( # noqa: F401
DAG,
ID_LEN,
XCOM_RETURN_KEY,
Base,
BaseOperator,
@@ -59,12 +75,14 @@
)

# We need to add this model manually to get reset working well
from airflow.models.base import COLLATION_ARGS
from airflow.models.serialized_dag import SerializedDagModel # noqa: F401
from airflow.models.tasklog import LogTemplate
from airflow.utils import helpers

# TODO: remove create_session once we decide to break backward compatibility
from airflow.utils.session import NEW_SESSION, create_session, provide_session # noqa: F401
from airflow.utils.sqlalchemy import UtcDateTime
from airflow.version import version

if TYPE_CHECKING:
@@ -1063,7 +1081,6 @@ def _move_dangling_data_to_new_table(
session=session,
)
session.commit()

target_table = source_table.to_metadata(source_table.metadata, name=target_table_name)
log.debug("checking whether rows were moved for table %s", target_table_name)
moved_rows_exist_query = select([1]).select_from(target_table).limit(1)
@@ -1075,6 +1092,7 @@
# no bad rows were found; drop moved rows table.
target_table.drop(bind=session.get_bind(), checkfirst=True)
else:
# purge the bad rows
log.debug("rows moved; purging from %s", source_table.name)
if dialect_name == 'sqlite':
pk_cols = source_table.primary_key.columns
@@ -1088,6 +1106,7 @@
)
log.debug(delete.compile())
session.execute(delete)

session.commit()

log.debug("exiting move function")
@@ -1109,7 +1128,7 @@ def _dangling_against_dag_run(session, source_table, dag_run):
)


def _dangling_against_task_instance(session, source_table, dag_run, task_instance):
def _dangling_against_task_instance(session, source_table, ti_lkp_table):
"""
Given a source table, we generate a subquery that will return 1 for every row that
has a valid task instance (and associated dagrun).
@@ -1120,35 +1139,13 @@ def _dangling_against_task_instance(session, source_table, dag_run, task_instance):
query logic depending on which revision the database is at.
"""
if 'run_id' not in task_instance.c:
# db is < 2.2.0
dr_join_cond = and_(
source_table.c.dag_id == dag_run.c.dag_id,
source_table.c.execution_date == dag_run.c.execution_date,
)
ti_join_cond = and_(
dag_run.c.dag_id == task_instance.c.dag_id,
dag_run.c.execution_date == task_instance.c.execution_date,
source_table.c.task_id == task_instance.c.task_id,
)
else:
# db is 2.2.0 <= version < 2.3.0
dr_join_cond = and_(
source_table.c.dag_id == dag_run.c.dag_id,
source_table.c.execution_date == dag_run.c.execution_date,
)
ti_join_cond = and_(
dag_run.c.dag_id == task_instance.c.dag_id,
dag_run.c.run_id == task_instance.c.run_id,
source_table.c.task_id == task_instance.c.task_id,
)

return (
session.query(*[c.label(c.name) for c in source_table.c])
.join(dag_run, dr_join_cond, isouter=True)
.join(task_instance, ti_join_cond, isouter=True)
.filter(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None)))
where_clause = and_(
source_table.c.dag_id == ti_lkp_table.c.dag_id,
source_table.c.task_id == ti_lkp_table.c.task_id,
source_table.c.execution_date == ti_lkp_table.c.execution_date,
)
exists_subquery = session.query(text('1')).select_from(ti_lkp_table).filter(where_clause)
return exists_subquery
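
Note the subquery is returned un-negated; the caller flips it, as check_bad_references does further down. Roughly, with source_table bound to a hypothetical pre-2.3 task_reschedule table:

    exists_q = _dangling_against_task_instance(session, source_table, ti_lkp_table)
    dangling = session.query(*[c.label(c.name) for c in source_table.c]).filter(~exists_q.exists())
    # emits SQL shaped like:
    #   SELECT ... FROM task_reschedule
    #   WHERE NOT EXISTS (
    #       SELECT 1 FROM _airflow_tmp_ti_key_lkp
    #       WHERE dag_id = task_reschedule.dag_id
    #         AND task_id = task_reschedule.task_id
    #         AND execution_date = task_reschedule.execution_date)
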


def _move_duplicate_data_to_new_table(
@@ -1203,6 +1200,43 @@ def _move_duplicate_data_to_new_table(
session.execute(delete)


def _create_ti_key_lkp_table(session, table_name) -> Table:
"""Creates lkp table for all valid TI keys"""
tmp_table = Table(
table_name,
Base.metadata,
Column('task_id', String(ID_LEN, **COLLATION_ARGS), primary_key=True),
Column('dag_id', String(ID_LEN, **COLLATION_ARGS), primary_key=True),
Column('execution_date', UtcDateTime, primary_key=True),
)
tmp_table.drop(bind=settings.engine, checkfirst=True)

log.debug("creating TI key lkp table")
Base.metadata.create_all(settings.engine, tables=[tmp_table])
log.debug("inserting TI key lkp table")
session.commit()
try:
# 2.2+ schema: task_instance has a run_id column; derive execution_date via its dag_run
session.execute(
f"insert into {table_name} "
"select ti.task_id, ti.dag_id, dr.execution_date "
"from task_instance ti "
"join dag_run dr on dr.dag_id = ti.dag_id "
" and dr.run_id = ti.run_id "
)
except sqlalchemy.exc.OperationalError:
# pre-2.2 schema: no run_id on task_instance (the insert above fails), so join on execution_date
session.execute(
f"insert into {table_name} "
"select ti.task_id, ti.dag_id, dr.execution_date "
"from task_instance ti "
"join dag_run dr on dr.dag_id = ti.dag_id "
" and dr.execution_date = ti.execution_date "
" and dr.run_id is not null"
)
return tmp_table
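
A minimal usage sketch, assuming an active session and settings.engine as elsewhere in this module (the count query is illustrative only):

    ti_lkp_table = _create_ti_key_lkp_table(session=session, table_name='_airflow_tmp_ti_key_lkp')
    try:
        n = session.execute("select count(*) from _airflow_tmp_ti_key_lkp").scalar()
        log.debug("lookup table holds %s valid TI keys", n)
    finally:
        # drop the temp table when done, as check_bad_references does below
        ti_lkp_table.drop(bind=settings.engine, checkfirst=True)
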


def check_bad_references(session: Session) -> Iterable[str]:
"""
Starting in Airflow 2.2, we began a process of replacing `execution_date` with `run_id`
@@ -1255,7 +1289,8 @@ class BadReferenceConfig:
return

existing_table_names = set(inspect(session.get_bind()).get_table_names())
errored = False

ti_lkp_table = _create_ti_key_lkp_table(session=session, table_name='_airflow_tmp_ti_key_lkp')

for model, change_version, bad_ref_cfg in models_list:
log.debug("checking model %s", model.__tablename__)
@@ -1269,12 +1304,13 @@
if "run_id" in source_table.columns:
continue

func_kwargs = {x: metadata.tables[x] for x in bad_ref_cfg.join_tables}
bad_rows_query = bad_ref_cfg.bad_rows_func(session, source_table, **func_kwargs)

bad_rows_subquery = bad_ref_cfg.exists_func(session, source_table, ti_lkp_table=ti_lkp_table)
dangling_table_name = _format_airflow_moved_table_name(source_table.name, change_version, 'dangling')
select_list = [x.label(x.name) for x in source_table.c]
log.debug(bad_rows_subquery.selectable.compile())
invalid_rows_query = session.query(*select_list).filter(~bad_rows_subquery.exists())
if dangling_table_name in existing_table_names:
invalid_row_count = bad_rows_query.count()
invalid_row_count = invalid_rows_query.count()
if invalid_row_count <= 0:
continue
else:
@@ -1284,21 +1320,17 @@
invalid_count=invalid_row_count,
reason=f"without a corresponding {bad_ref_cfg.ref_table} row",
)
errored = True
continue
continue

log.debug("moving data for table %s", source_table.name)
_move_dangling_data_to_new_table(
session,
source_table,
bad_rows_query,
invalid_rows_query,
dangling_table_name,
)

if errored:
session.rollback()
else:
session.commit()
ti_lkp_table.drop(bind=settings.engine, checkfirst=True)
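
Callers of check_bad_references are unaffected by this change; a hedged example of driving it with create_session (already imported at the top of this module):

    with create_session() as session:
        for message in check_bad_references(session=session):
            log.warning(message)
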


@provide_session
