Skip to content

Commit

Permalink
Change approach to finding bad rows to LEFT OUTER JOIN. (#23528)
Browse files Browse the repository at this point in the history
Rather than sub-selects (two for count, or one for the CREATE TABLE).

For a _large_ database (27m TaskInstances, 2m DagRuns) this takes the
time per table from around 10 minutes (we have 3 tables) down to 3
minutes per table. (All times on Postgres.)

Before:

```sql
CREATE TABLE _airflow_moved__2_3__dangling__rendered_task_instance_fields AS
SELECT
  rendered_task_instance_fields.dag_id AS dag_id,
  rendered_task_instance_fields.task_id AS task_id,
  rendered_task_instance_fields.execution_date AS execution_date,
  rendered_task_instance_fields.rendered_fields AS rendered_fields,
  rendered_task_instance_fields.k8s_pod_yaml AS k8s_pod_yaml
FROM
  rendered_task_instance_fields
WHERE
  NOT (
    EXISTS (
      SELECT
        1
      FROM
        task_instance
        JOIN dag_run ON dag_run.dag_id = task_instance.dag_id
        AND dag_run.run_id = task_instance.run_id
      WHERE
        rendered_task_instance_fields.dag_id = task_instance.dag_id
        AND rendered_task_instance_fields.task_id = task_instance.task_id
        AND rendered_task_instance_fields.execution_date = dag_run.execution_date
    )
  )
```

After:

```sql
CREATE TABLE _airflow_moved__2_3__dangling__rendered_task_instance_fields AS
SELECT
  rendered_task_instance_fields.dag_id AS dag_id,
  rendered_task_instance_fields.task_id AS task_id,
  rendered_task_instance_fields.execution_date AS execution_date,
  rendered_task_instance_fields.rendered_fields AS rendered_fields,
  rendered_task_instance_fields.k8s_pod_yaml AS k8s_pod_yaml
FROM
  rendered_task_instance_fields
  LEFT OUTER JOIN dag_run ON rendered_task_instance_fields.dag_id = dag_run.dag_id
  AND rendered_task_instance_fields.execution_date = dag_run.execution_date
  LEFT OUTER JOIN task_instance ON dag_run.dag_id = task_instance.dag_id
  AND dag_run.run_id = task_instance.run_id
  AND rendered_task_instance_fields.task_id = task_instance.task_id
WHERE
  task_instance.dag_id IS NULL
  OR dag_run.dag_id IS NULL
;
```
  • Loading branch information
ashb authored May 6, 2022
1 parent 6cc41ab commit 22a9293
Showing 1 changed file with 40 additions and 33 deletions.
73 changes: 40 additions & 33 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from tempfile import gettempdir
from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple, Union

from sqlalchemy import Table, and_, column, exc, func, inspect, or_, select, table, text
from sqlalchemy import Table, and_, column, exc, func, inspect, or_, select, table, text, tuple_
from sqlalchemy.orm.session import Session

import airflow
Expand Down Expand Up @@ -1047,7 +1047,7 @@ def _create_table_as(


def _move_dangling_data_to_new_table(
session, source_table: "Table", source_query: "Query", exists_subquery, target_table_name: str
session, source_table: "Table", source_query: "Query", target_table_name: str
):

bind = session.get_bind()
Expand All @@ -1072,11 +1072,16 @@ def _move_dangling_data_to_new_table(

if not first_moved_row:
log.debug("no rows moved; dropping %s", target_table_name)
# no bad rows were found; drop moved rows table.
target_table.drop(bind=session.get_bind(), checkfirst=True)
else:
log.debug("rows moved; purging from %s", source_table.name)
if dialect_name == 'sqlite':
delete = source_table.delete().where(~exists_subquery.exists())
pk_cols = source_table.primary_key.columns

delete = source_table.delete().where(
tuple_(*pk_cols).in_(session.query(*target_table.primary_key.columns).subquery())
)
else:
delete = source_table.delete().where(
and_(col == target_table.c[col.name] for col in source_table.primary_key.columns)
Expand All @@ -1088,7 +1093,7 @@ def _move_dangling_data_to_new_table(
log.debug("exiting move function")


def _dag_run_exists(session, source_table, dag_run):
def _dangling_against_dag_run(session, source_table, dag_run):
"""
Given a source table, we generate a subquery that will return 1 for every row that
has a dagrun.
Expand All @@ -1097,11 +1102,14 @@ def _dag_run_exists(session, source_table, dag_run):
source_table.c.dag_id == dag_run.c.dag_id,
source_table.c.execution_date == dag_run.c.execution_date,
)
exists_subquery = session.query(text('1')).select_from(dag_run).filter(source_to_dag_run_join_cond)
return exists_subquery
return (
session.query(*[c.label(c.name) for c in source_table.c])
.join(dag_run, source_to_dag_run_join_cond, isouter=True)
.filter(dag_run.c.dag_id.is_(None))
)


def _task_instance_exists(session, source_table, dag_run, task_instance):
def _dangling_against_task_instance(session, source_table, dag_run, task_instance):
"""
Given a source table, we generate a subquery that will return 1 for every row that
has a valid task instance (and associated dagrun).
Expand All @@ -1114,32 +1122,33 @@ def _task_instance_exists(session, source_table, dag_run, task_instance):
"""
if 'run_id' not in task_instance.c:
# db is < 2.2.0
where_clause = and_(
source_table.c.dag_id == task_instance.c.dag_id,
source_table.c.task_id == task_instance.c.task_id,
source_table.c.execution_date == task_instance.c.execution_date,
dr_join_cond = and_(
source_table.c.dag_id == dag_run.c.dag_id,
source_table.c.execution_date == dag_run.c.execution_date,
)
ti_to_dr_join_cond = and_(
ti_join_cond = and_(
dag_run.c.dag_id == task_instance.c.dag_id,
dag_run.c.execution_date == task_instance.c.execution_date,
source_table.c.task_id == task_instance.c.task_id,
)
else:
# db is 2.2.0 <= version < 2.3.0
where_clause = and_(
source_table.c.dag_id == task_instance.c.dag_id,
source_table.c.task_id == task_instance.c.task_id,
dr_join_cond = and_(
source_table.c.dag_id == dag_run.c.dag_id,
source_table.c.execution_date == dag_run.c.execution_date,
)
ti_to_dr_join_cond = and_(
ti_join_cond = and_(
dag_run.c.dag_id == task_instance.c.dag_id,
dag_run.c.run_id == task_instance.c.run_id,
source_table.c.task_id == task_instance.c.task_id,
)
exists_subquery = (
session.query(text('1'))
.select_from(task_instance.join(dag_run, onclause=ti_to_dr_join_cond))
.filter(where_clause)

return (
session.query(*[c.label(c.name) for c in source_table.c])
.join(dag_run, dr_join_cond, isouter=True)
.join(task_instance, ti_join_cond, isouter=True)
.filter(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None)))
)
return exists_subquery


def _move_duplicate_data_to_new_table(
Expand Down Expand Up @@ -1207,23 +1216,23 @@ def check_bad_references(session: Session) -> Iterable[str]:
@dataclass
class BadReferenceConfig:
"""
:param exists_func: function that returns subquery which determines whether bad rows exist
:param bad_rows_func: function that returns subquery which determines whether bad rows exist
:param join_tables: table objects referenced in subquery
:param ref_table: information-only identifier for categorizing the missing ref
"""

exists_func: Callable
bad_rows_func: Callable
join_tables: List[str]
ref_table: str

missing_dag_run_config = BadReferenceConfig(
exists_func=_dag_run_exists,
bad_rows_func=_dangling_against_dag_run,
join_tables=['dag_run'],
ref_table='dag_run',
)

missing_ti_config = BadReferenceConfig(
exists_func=_task_instance_exists,
bad_rows_func=_dangling_against_task_instance,
join_tables=['dag_run', 'task_instance'],
ref_table='task_instance',
)
Expand All @@ -1238,7 +1247,8 @@ class BadReferenceConfig:
metadata = reflect_tables([*[x[0] for x in models_list], DagRun, TaskInstance], session)

if (
metadata.tables.get(DagRun.__tablename__) is None
not metadata.tables
or metadata.tables.get(DagRun.__tablename__) is None
or metadata.tables.get(TaskInstance.__tablename__) is None
):
# Key table doesn't exist -- likely empty DB.
Expand All @@ -1251,7 +1261,6 @@ class BadReferenceConfig:
log.debug("checking model %s", model.__tablename__)
# We can't use the model here since it may differ from the db state due to
# this function is run prior to migration. Use the reflected table instead.
exists_func_kwargs = {x: metadata.tables[x] for x in bad_ref_cfg.join_tables}
source_table = metadata.tables.get(model.__tablename__) # type: ignore
if source_table is None:
continue
Expand All @@ -1260,13 +1269,12 @@ class BadReferenceConfig:
if "run_id" in source_table.columns:
continue

bad_rows_subquery = bad_ref_cfg.exists_func(session, source_table, **exists_func_kwargs)
select_list = [x.label(x.name) for x in source_table.c]
invalid_rows_query = session.query(*select_list).filter(~bad_rows_subquery.exists())
func_kwargs = {x: metadata.tables[x] for x in bad_ref_cfg.join_tables}
bad_rows_query = bad_ref_cfg.bad_rows_func(session, source_table, **func_kwargs)

dangling_table_name = _format_airflow_moved_table_name(source_table.name, change_version, 'dangling')
if dangling_table_name in existing_table_names:
invalid_row_count = invalid_rows_query.count()
invalid_row_count = bad_rows_query.count()
if invalid_row_count <= 0:
continue
else:
Expand All @@ -1283,8 +1291,7 @@ class BadReferenceConfig:
_move_dangling_data_to_new_table(
session,
source_table,
invalid_rows_query,
bad_rows_subquery,
bad_rows_query,
dangling_table_name,
)

Expand Down

0 comments on commit 22a9293

Please sign in to comment.