Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change approach to finding bad rows to LEFT OUTER JOIN. #23528

Merged
merged 1 commit into from
May 6, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 40 additions & 33 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from tempfile import gettempdir
from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple, Union

from sqlalchemy import Table, and_, column, exc, func, inspect, or_, select, table, text
from sqlalchemy import Table, and_, column, exc, func, inspect, or_, select, table, text, tuple_
from sqlalchemy.orm.session import Session

import airflow
Expand Down Expand Up @@ -1047,7 +1047,7 @@ def _create_table_as(


def _move_dangling_data_to_new_table(
session, source_table: "Table", source_query: "Query", exists_subquery, target_table_name: str
session, source_table: "Table", source_query: "Query", target_table_name: str
):

bind = session.get_bind()
Expand All @@ -1072,11 +1072,16 @@ def _move_dangling_data_to_new_table(

if not first_moved_row:
log.debug("no rows moved; dropping %s", target_table_name)
# no bad rows were found; drop moved rows table.
target_table.drop(bind=session.get_bind(), checkfirst=True)
else:
log.debug("rows moved; purging from %s", source_table.name)
if dialect_name == 'sqlite':
delete = source_table.delete().where(~exists_subquery.exists())
pk_cols = source_table.primary_key.columns

delete = source_table.delete().where(
tuple_(*pk_cols).in_(session.query(*target_table.primary_key.columns).subquery())
)
else:
delete = source_table.delete().where(
and_(col == target_table.c[col.name] for col in source_table.primary_key.columns)
Expand All @@ -1088,7 +1093,7 @@ def _move_dangling_data_to_new_table(
log.debug("exiting move function")


def _dag_run_exists(session, source_table, dag_run):
def _dangling_against_dag_run(session, source_table, dag_run):
"""
Given a source table, we generate a subquery that will return 1 for every row that
has a dagrun.
Expand All @@ -1097,11 +1102,14 @@ def _dag_run_exists(session, source_table, dag_run):
source_table.c.dag_id == dag_run.c.dag_id,
source_table.c.execution_date == dag_run.c.execution_date,
)
exists_subquery = session.query(text('1')).select_from(dag_run).filter(source_to_dag_run_join_cond)
return exists_subquery
return (
session.query(*[c.label(c.name) for c in source_table.c])
.join(dag_run, source_to_dag_run_join_cond, isouter=True)
.filter(dag_run.c.dag_id.is_(None))
)


def _task_instance_exists(session, source_table, dag_run, task_instance):
def _dangling_against_task_instance(session, source_table, dag_run, task_instance):
"""
Given a source table, we generate a subquery that will return 1 for every row that
has a valid task instance (and associated dagrun).
Expand All @@ -1114,32 +1122,33 @@ def _task_instance_exists(session, source_table, dag_run, task_instance):
"""
if 'run_id' not in task_instance.c:
# db is < 2.2.0
where_clause = and_(
source_table.c.dag_id == task_instance.c.dag_id,
source_table.c.task_id == task_instance.c.task_id,
source_table.c.execution_date == task_instance.c.execution_date,
dr_join_cond = and_(
source_table.c.dag_id == dag_run.c.dag_id,
source_table.c.execution_date == dag_run.c.execution_date,
)
ti_to_dr_join_cond = and_(
ti_join_cond = and_(
dag_run.c.dag_id == task_instance.c.dag_id,
dag_run.c.execution_date == task_instance.c.execution_date,
source_table.c.task_id == task_instance.c.task_id,
)
else:
# db is 2.2.0 <= version < 2.3.0
where_clause = and_(
source_table.c.dag_id == task_instance.c.dag_id,
source_table.c.task_id == task_instance.c.task_id,
dr_join_cond = and_(
source_table.c.dag_id == dag_run.c.dag_id,
source_table.c.execution_date == dag_run.c.execution_date,
)
ti_to_dr_join_cond = and_(
ti_join_cond = and_(
dag_run.c.dag_id == task_instance.c.dag_id,
dag_run.c.run_id == task_instance.c.run_id,
source_table.c.task_id == task_instance.c.task_id,
)
exists_subquery = (
session.query(text('1'))
.select_from(task_instance.join(dag_run, onclause=ti_to_dr_join_cond))
.filter(where_clause)

return (
session.query(*[c.label(c.name) for c in source_table.c])
.join(dag_run, dr_join_cond, isouter=True)
.join(task_instance, ti_join_cond, isouter=True)
.filter(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None)))
)
return exists_subquery


def _move_duplicate_data_to_new_table(
Expand Down Expand Up @@ -1207,23 +1216,23 @@ def check_bad_references(session: Session) -> Iterable[str]:
@dataclass
class BadReferenceConfig:
"""
:param exists_func: function that returns subquery which determines whether bad rows exist
:param bad_rows_func: function that returns a query selecting the bad (dangling) rows
:param join_tables: table objects referenced in subquery
:param ref_table: information-only identifier for categorizing the missing ref
"""

exists_func: Callable
bad_rows_func: Callable
join_tables: List[str]
ref_table: str

missing_dag_run_config = BadReferenceConfig(
exists_func=_dag_run_exists,
bad_rows_func=_dangling_against_dag_run,
join_tables=['dag_run'],
ref_table='dag_run',
)

missing_ti_config = BadReferenceConfig(
exists_func=_task_instance_exists,
bad_rows_func=_dangling_against_task_instance,
join_tables=['dag_run', 'task_instance'],
ref_table='task_instance',
)
Expand All @@ -1238,7 +1247,8 @@ class BadReferenceConfig:
metadata = reflect_tables([*[x[0] for x in models_list], DagRun, TaskInstance], session)

if (
metadata.tables.get(DagRun.__tablename__) is None
not metadata.tables
or metadata.tables.get(DagRun.__tablename__) is None
or metadata.tables.get(TaskInstance.__tablename__) is None
):
# Key table doesn't exist -- likely empty DB.
Expand All @@ -1251,7 +1261,6 @@ class BadReferenceConfig:
log.debug("checking model %s", model.__tablename__)
# We can't use the model here since it may differ from the db state due to
# this function is run prior to migration. Use the reflected table instead.
exists_func_kwargs = {x: metadata.tables[x] for x in bad_ref_cfg.join_tables}
source_table = metadata.tables.get(model.__tablename__) # type: ignore
if source_table is None:
continue
Expand All @@ -1260,13 +1269,12 @@ class BadReferenceConfig:
if "run_id" in source_table.columns:
continue

bad_rows_subquery = bad_ref_cfg.exists_func(session, source_table, **exists_func_kwargs)
select_list = [x.label(x.name) for x in source_table.c]
invalid_rows_query = session.query(*select_list).filter(~bad_rows_subquery.exists())
func_kwargs = {x: metadata.tables[x] for x in bad_ref_cfg.join_tables}
bad_rows_query = bad_ref_cfg.bad_rows_func(session, source_table, **func_kwargs)

dangling_table_name = _format_airflow_moved_table_name(source_table.name, change_version, 'dangling')
if dangling_table_name in existing_table_names:
invalid_row_count = invalid_rows_query.count()
invalid_row_count = bad_rows_query.count()
if invalid_row_count <= 0:
continue
else:
Expand All @@ -1283,8 +1291,7 @@ class BadReferenceConfig:
_move_dangling_data_to_new_table(
session,
source_table,
invalid_rows_query,
bad_rows_subquery,
bad_rows_query,
dangling_table_name,
)

Expand Down