Add pre-upgrade check for dangling TI references #22924

Merged 1 commit on Apr 14, 2022
airflow/utils/db.py (141 changes: 104 additions & 37 deletions)
@@ -22,12 +22,14 @@
import sys
import time
import warnings
from dataclasses import dataclass
from tempfile import gettempdir
from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple

from sqlalchemy import Table, column, exc, func, inspect, literal, or_, table, text
from sqlalchemy import Table, and_, column, exc, func, inspect, literal, or_, table, text
from sqlalchemy.orm.session import Session

import airflow
from airflow import settings
from airflow.compat.sqlalchemy import has_table
from airflow.configuration import conf
@@ -654,8 +656,7 @@ def initdb(session: Session = NEW_SESSION):
def _get_alembic_config():
    from alembic.config import Config

    current_dir = os.path.dirname(os.path.abspath(__file__))
    package_dir = os.path.normpath(os.path.join(current_dir, '..'))
    package_dir = os.path.dirname(airflow.__file__)
    directory = os.path.join(package_dir, 'migrations')
    config = Config(os.path.join(package_dir, 'alembic.ini'))
    config.set_main_option('script_location', directory.replace('%', '%%'))
@@ -955,7 +956,9 @@ def check_run_id_null(session: Session) -> Iterable[str]:
    )
    invalid_dagrun_count = session.query(dagrun_table.c.id).filter(invalid_dagrun_filter).count()
    if invalid_dagrun_count > 0:
        dagrun_dangling_table_name = _format_airflow_moved_table_name(dagrun_table.name, "2.2")
        dagrun_dangling_table_name = _format_airflow_moved_table_name(
            source_table=dagrun_table.name, version="2.2"
        )
        if dagrun_dangling_table_name in inspect(session.get_bind()).get_table_names():
            yield _format_dangling_error(
                source_table=dagrun_table.name,
@@ -1035,28 +1038,102 @@ def _move_dangling_data_to_new_table(
    session.execute(delete)


def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str]:
def _dag_run_exists(session, source_table, dag_run):
"""
Given a source table, we generate a subquery that will return 1 for every row that
has a dagrun.
"""
source_to_dag_run_join_cond = and_(
source_table.c.dag_id == dag_run.c.dag_id,
source_table.c.execution_date == dag_run.c.execution_date,
)
exists_subquery = session.query(text('1')).select_from(dag_run).filter(source_to_dag_run_join_cond)
return exists_subquery


def _task_instance_exists(session, source_table, dag_run, task_instance):
"""
Given a source table, we generate a subquery that will return 1 for every row that
has a valid task instance (and associated dagrun).

This is used to identify rows that need to be removed from tables prior to adding a TI fk.

Since this check is applied prior to running the migrations, we have to use different
query logic depending on which revision the database is at.

"""
if 'run_id' not in task_instance.c:
# db is < 2.2.0
source_to_ti_join_cond = and_(
source_table.c.dag_id == task_instance.c.dag_id,
source_table.c.task_id == task_instance.c.task_id,
source_table.c.execution_date == task_instance.c.execution_date,
)
ti_to_dr_join_cond = and_(
source_table.c.dag_id == task_instance.c.dag_id,
source_table.c.execution_date == task_instance.c.execution_date,
)
else:
# db is 2.2.0 <= version < 2.3.0
source_to_ti_join_cond = and_(
source_table.c.dag_id == task_instance.c.dag_id,
source_table.c.task_id == task_instance.c.task_id,
)
ti_to_dr_join_cond = and_(
source_table.c.dag_id == task_instance.c.dag_id,
dag_run.c.run_id == task_instance.c.run_id,
source_table.c.execution_date == dag_run.c.execution_date,
)
exists_subquery = (
session.query(text('1'))
.select_from(task_instance.join(dag_run, onclause=ti_to_dr_join_cond))
.filter(source_to_ti_join_cond)
)
return exists_subquery


def check_bad_references(session: Session) -> Iterable[str]:
"""
Starting in Airflow 2.2, we began a process of replacing `execution_date` with `run_id`
in many tables.
Here we go through each table and look for records that can't be mapped to a dag run.
When we find such "dangling" rows we back them up in a special table and delete them
from the main table.
"""
from sqlalchemy import and_

from airflow.models.renderedtifields import RenderedTaskInstanceFields

models_to_dagrun: List[Tuple[Base, str]] = [
(mod, ver)
for ver, models in {
'2.2': [TaskInstance, TaskReschedule],
'2.3': [RenderedTaskInstanceFields, TaskFail, XCom],
}.items()
for mod in models
]
@dataclass
class BadReferenceConfig:
"""
:param exists_func: function that returns subquery which determines whether bad rows exist
:param join_tables: table objects referenced in subquery
:param ref_table: information-only identifier for categorizing the missing ref
"""

exists_func: Callable
join_tables: List[str]
ref_table: str

missing_dag_run_config = BadReferenceConfig(
exists_func=_dag_run_exists,
join_tables=['dag_run'],
ref_table='dag_run',
)

missing_ti_config = BadReferenceConfig(
exists_func=_task_instance_exists,
join_tables=['dag_run', 'task_instance'],
ref_table='task_instance',
)

metadata = reflect_tables([*[x[0] for x in models_to_dagrun], DagRun], session)
models_list: List[Tuple[Base, str, BadReferenceConfig]] = [
(TaskInstance, '2.2', missing_dag_run_config),
(TaskReschedule, '2.2', missing_ti_config),
(RenderedTaskInstanceFields, '2.3', missing_ti_config),
(TaskFail, '2.3', missing_ti_config),
(XCom, '2.3', missing_ti_config),
]
metadata = reflect_tables([*[x[0] for x in models_list], DagRun, TaskInstance], session)

if (
metadata.tables.get(DagRun.__tablename__) is None
@@ -1065,16 +1142,13 @@ def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str]:
        # Key table doesn't exist -- likely empty DB.
        return

    # We can't use the model here since it may differ from the db state due to
    # this function is run prior to migration. Use the reflected table instead.
    dagrun_table = metadata.tables[DagRun.__tablename__]

    existing_table_names = set(inspect(session.get_bind()).get_table_names())
    errored = False

    for model, change_version in models_to_dagrun:
    for model, change_version, bad_ref_cfg in models_list:
        # We can't use the model here since it may differ from the db state, because
        # this function is run prior to migration. Use the reflected table instead.
        exists_func_kwargs = {x: metadata.tables[x] for x in bad_ref_cfg.join_tables}
        source_table = metadata.tables.get(model.__tablename__)  # type: ignore
        if source_table is None:
            continue
@@ -1083,37 +1157,30 @@ def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str]:
if "run_id" in source_table.columns:
continue

# find rows in source table which don't have a matching dag run
source_to_dag_run_join_cond = and_(
source_table.c.dag_id == dagrun_table.c.dag_id,
source_table.c.execution_date == dagrun_table.c.execution_date,
)
exists_subquery = (
session.query(text('1')).select_from(dagrun_table).filter(source_to_dag_run_join_cond)
)
invalid_rows_query = session.query(*[x.label(x.name) for x in source_table.c]).filter(
~exists_subquery.exists()
)

bad_rows_subquery = bad_ref_cfg.exists_func(session, source_table, **exists_func_kwargs)
select_list = [x.label(x.name) for x in source_table.c]
invalid_rows_query = session.query(*select_list).filter(~bad_rows_subquery.exists())
invalid_row_count = invalid_rows_query.count()
if invalid_row_count <= 0:
continue

        dangling_table_name = _format_airflow_moved_table_name(source_table.name, change_version)
        dangling_table_name = _format_airflow_moved_table_name(
            source_table=source_table.name, version=change_version
        )
        if dangling_table_name in existing_table_names:
            yield _format_dangling_error(
                source_table=source_table.name,
                target_table=dangling_table_name,
                invalid_count=invalid_row_count,
                reason=f"without a corresponding {dagrun_table.name} row",
                reason=f"without a corresponding {bad_ref_cfg.ref_table} row",
            )
            errored = True
            continue
        _move_dangling_data_to_new_table(
            session,
            source_table,
            invalid_rows_query,
            exists_subquery,
            bad_rows_subquery,
            dangling_table_name,
        )

@@ -1134,7 +1201,7 @@ def _check_migration_errors(session: Session = NEW_SESSION) -> Iterable[str]:
        check_conn_id_duplicates,
        check_conn_type_null,
        check_run_id_null,
        check_task_tables_without_matching_dagruns,
        check_bad_references,
    )
    for check_fn in check_functions:
        yield from check_fn(session=session)
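For orientation, the check added in this diff boils down to an anti-join: for each source table, build an EXISTS subquery against dag_run (and, for the TI-keyed tables, against task_instance joined to dag_run), then select the source rows for which the subquery finds nothing; those rows are moved aside (via _format_airflow_moved_table_name) before the migration adds the foreign key. The snippet below is an illustrative sketch only, not code from this PR: it reproduces the _dag_run_exists-style pattern against a toy in-memory SQLite schema, and the engine, tables, and column set are simplified stand-ins for the reflected Airflow tables.

# Illustrative sketch only -- not code from this PR.
from sqlalchemy import Column, DateTime, Integer, MetaData, String, Table, and_, create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite://")  # hypothetical in-memory stand-in for the Airflow metadata DB
metadata = MetaData()

dag_run = Table(
    "dag_run",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("dag_id", String),
    Column("execution_date", DateTime),
)
task_instance = Table(
    "task_instance",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("dag_id", String),
    Column("task_id", String),
    Column("execution_date", DateTime),
)
metadata.create_all(engine)

with Session(engine) as session:
    # Subquery returning 1 for every source row that has a matching dag_run
    # (same shape as the join conditions in the diff above).
    exists_subquery = session.query(text("1")).select_from(dag_run).filter(
        and_(
            task_instance.c.dag_id == dag_run.c.dag_id,
            task_instance.c.execution_date == dag_run.c.execution_date,
        )
    )
    # Anti-join: rows for which the subquery finds nothing are the "dangling" ones
    # that the pre-upgrade check would back up and delete.
    dangling = session.query(*[c.label(c.name) for c in task_instance.c]).filter(
        ~exists_subquery.exists()
    )
    print(dangling.count())  # 0 for this empty toy schema

The same shape covers the task-instance case: swap the subquery's FROM for task_instance joined to dag_run, using the run_id-based join when the reflected task_instance table already has a run_id column and the execution_date-based join otherwise, exactly as _task_instance_exists does in the diff.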