Purge duplicates from TaskFail prior to 2.3 upgrade #22769

Merged (2 commits) on Apr 14, 2022
airflow/utils/db.py (117 changes: 92 additions, 25 deletions)
@@ -24,9 +24,9 @@
 import warnings
 from dataclasses import dataclass
 from tempfile import gettempdir
-from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple, Union

-from sqlalchemy import Table, and_, column, exc, func, inspect, literal, or_, table, text
+from sqlalchemy import Table, and_, column, exc, func, inspect, or_, table, text
 from sqlalchemy.orm.session import Session

 import airflow
@@ -92,8 +92,8 @@
 }


-def _format_airflow_moved_table_name(source_table, version):
-    return "__".join([settings.AIRFLOW_MOVED_TABLE_PREFIX, version.replace(".", "_"), source_table])
+def _format_airflow_moved_table_name(source_table, version, category):
+    return "__".join([settings.AIRFLOW_MOVED_TABLE_PREFIX, version.replace(".", "_"), category, source_table])


 @provide_session
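Note on the naming scheme: the new `category` argument namespaces moved tables by both version and reason ("duplicates" here, "dangling" for the reference checks below). A minimal sketch of what the helper now produces, assuming `AIRFLOW_MOVED_TABLE_PREFIX` is `"_airflow_moved"` (its value in `airflow/settings.py`):

```python
# Minimal standalone sketch; the prefix value is an assumption taken from
# airflow/settings.py (AIRFLOW_MOVED_TABLE_PREFIX = "_airflow_moved").
def _format_airflow_moved_table_name(source_table, version, category):
    return "__".join(["_airflow_moved", version.replace(".", "_"), category, source_table])

print(_format_airflow_moved_table_name("task_fail", "2.3", "duplicates"))
# _airflow_moved__2_3__duplicates__task_fail
print(_format_airflow_moved_table_name("dag_run", "2.2", "dangling"))
# _airflow_moved__2_2__dangling__dag_run
```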
@@ -849,7 +849,7 @@ def check_conn_id_duplicates(session: Session) -> Iterable[str]:
     )


-def reflect_tables(models, session):
+def reflect_tables(tables: List[Union[Base, str]], session):
     """
     When running checks prior to upgrades, we use reflection to determine current state of the
     database.
Expand All @@ -861,9 +861,10 @@ def reflect_tables(models, session):

metadata = sqlalchemy.schema.MetaData(session.bind)

for model in models:
for tbl in tables:
try:
metadata.reflect(only=[model.__tablename__], extend_existing=True, resolve_fks=False)
table_name = tbl if isinstance(tbl, str) else tbl.__tablename__
metadata.reflect(only=[table_name], extend_existing=True, resolve_fks=False)
except exc.InvalidRequestError:
continue
return metadata
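Accepting plain strings in addition to models matters because the moved tables created during cleanup have no ORM model to reflect. A hedged usage sketch (assumes an open `session`; the moved-table name is illustrative):

```python
# Usage sketch: reflect an ORM model and a model-less table in one call.
from airflow.models import TaskFail

metadata = reflect_tables([TaskFail, "_airflow_moved__2_3__duplicates__task_fail"], session)

# Tables missing from the database are skipped silently (InvalidRequestError
# is swallowed), so check membership before indexing:
if "task_fail" in metadata.tables:
    task_fail_table = metadata.tables["task_fail"]
```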
@@ -881,10 +882,13 @@ def check_task_fail_for_duplicates(session):
         table_name=task_fail.name,
         uniqueness=['dag_id', 'task_id', 'execution_date'],
         session=session,
+        version='2.3',
     )


-def check_table_for_duplicates(table_name: str, uniqueness: List[str], session: Session) -> Iterable[str]:
+def check_table_for_duplicates(
+    *, session: Session, table_name: str, uniqueness: List[str], version: str
+) -> Iterable[str]:
     """
     Check table for duplicates, given a list of columns which define the uniqueness of the table.

@@ -895,24 +899,39 @@ def check_table_for_duplicates(table_name: str, uniqueness: List[str], session:
     :param session: session of the sqlalchemy
     :rtype: str
     """
-    table_obj = table(table_name, *[column(x) for x in uniqueness])
-    dupe_count = 0
+    minimal_table_obj = table(table_name, *[column(x) for x in uniqueness])
     try:
         subquery = (
-            session.query(table_obj, func.count().label('dupe_count'))
+            session.query(minimal_table_obj, func.count().label('dupe_count'))
             .group_by(*[text(x) for x in uniqueness])
-            .having(func.count() > literal(1))
+            .having(func.count() > text('1'))
             .subquery()
         )
         dupe_count = session.query(func.sum(subquery.c.dupe_count)).scalar()
+        if not dupe_count:
+            # there are no duplicates; nothing to do.
+            return
+
+        log.warning("Found %s duplicates in table %s. Will attempt to move them.", dupe_count, table_name)
+
+        metadata = reflect_tables(tables=[table_name], session=session)
+        if table_name not in metadata.tables:
+            yield f"Table {table_name} does not exist in the database."
+
+        # We can't use the model here since it may differ from the db state because
+        # this function runs prior to migration. Use the reflected table instead.
+        table_obj = metadata.tables[table_name]
+
+        _move_duplicate_data_to_new_table(
+            session=session,
+            source_table=table_obj,
+            subquery=subquery,
+            uniqueness=uniqueness,
+            target_table_name=_format_airflow_moved_table_name(table_name, version, 'duplicates'),
+        )
     except (exc.OperationalError, exc.ProgrammingError):
-        # fallback if tables hasn't been created yet
+        # fallback if `table_name` hasn't been created yet
         session.rollback()
-
-    if dupe_count:
-        yield (
-            f"Found {dupe_count} duplicate records in table {table_name}. You must de-dupe these "
-            f"records before upgrading. The uniqueness constraint for this table is {uniqueness!r}"
-        )


 def check_conn_type_null(session: Session) -> Iterable[str]:
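To see the detection query in isolation, here is a self-contained sketch of the same GROUP BY/HAVING pattern against a toy SQLite database (invented data; written against the SQLAlchemy 1.4 Query API used above):

```python
# Self-contained sketch of the duplicate-detection subquery on toy data.
from sqlalchemy import column, create_engine, func, table, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite://")
with engine.begin() as conn:
    conn.execute(text("CREATE TABLE task_fail (dag_id TEXT, task_id TEXT, execution_date TEXT)"))
    conn.execute(
        text(
            "INSERT INTO task_fail VALUES "
            "('d1', 't1', '2022-01-01'), ('d1', 't1', '2022-01-01'), ('d1', 't2', '2022-01-01')"
        )
    )

uniqueness = ["dag_id", "task_id", "execution_date"]
minimal_table_obj = table("task_fail", *[column(x) for x in uniqueness])
with Session(engine) as session:
    subquery = (
        session.query(minimal_table_obj, func.count().label("dupe_count"))
        .group_by(*[text(x) for x in uniqueness])
        .having(func.count() > text("1"))
        .subquery()
    )
    # Two rows share the key ('d1', 't1', '2022-01-01'), so the sum is 2.
    print(session.query(func.sum(subquery.c.dupe_count)).scalar())  # 2
```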
@@ -966,9 +985,7 @@ def check_run_id_null(session: Session) -> Iterable[str]:
     )
     invalid_dagrun_count = session.query(dagrun_table.c.id).filter(invalid_dagrun_filter).count()
     if invalid_dagrun_count > 0:
-        dagrun_dangling_table_name = _format_airflow_moved_table_name(
-            source_table=dagrun_table.name, version="2.2"
-        )
+        dagrun_dangling_table_name = _format_airflow_moved_table_name(dagrun_table.name, '2.2', 'dangling')
         if dagrun_dangling_table_name in inspect(session.get_bind()).get_table_names():
             yield _format_dangling_error(
                 source_table=dagrun_table.name,
@@ -1102,6 +1119,58 @@ def _task_instance_exists(session, source_table, dag_run, task_instance):
     return exists_subquery


+def _move_duplicate_data_to_new_table(
+    session, source_table: "Table", subquery: "Query", uniqueness: List[str], target_table_name: str
+):
+    """
+    When adding a uniqueness constraint we first should ensure that there are no duplicate rows.
+
+    This function accepts a subquery that should return one record for each row with duplicates (e.g.
+    a group by with having count(*) > 1). We select from ``source_table`` getting all rows matching the
+    subquery result and store in ``target_table_name``. Then to purge the duplicates from the source table,
+    we do a DELETE FROM with a join to the target table (which now contains the dupes).
+
+    :param session: sqlalchemy session for metadata db
+    :param source_table: table to purge dupes from
+    :param subquery: the subquery that returns the duplicate rows
+    :param uniqueness: the string list of columns used to define the uniqueness for the table. used in
+        building the DELETE FROM join condition.
+    :param target_table_name: name of the table in which to park the duplicate rows
+    """
+
+    bind = session.get_bind()
+    dialect_name = bind.dialect.name
+    query = (
+        session.query(source_table)
+        .with_entities(*[getattr(source_table.c, x.name).label(str(x.name)) for x in source_table.columns])
+        .select_from(source_table)
+        .join(subquery, and_(*[getattr(source_table.c, x) == getattr(subquery.c, x) for x in uniqueness]))
+    )
+
+    _create_table_as(
+        session=session,
+        dialect_name=dialect_name,
+        source_query=query,
+        target_table_name=target_table_name,
+        source_table_name=source_table.name,
+    )
+
+    # we must ensure that the CTAS table is created prior to the DELETE step since we have to join to it
+    session.commit()
+
+    metadata = reflect_tables([target_table_name], session)
+    target_table = metadata.tables[target_table_name]
+    where_clause = and_(*[getattr(source_table.c, x) == getattr(target_table.c, x) for x in uniqueness])
+
+    if dialect_name == "sqlite":
+        subq = query.selectable.with_only_columns([text(f'{source_table}.ROWID')])
+        delete = source_table.delete().where(column('ROWID').in_(subq))
+    else:
+        delete = source_table.delete(where_clause)
+
+    session.execute(delete)
+
+
 def check_bad_references(session: Session) -> Iterable[str]:
     """
     Starting in Airflow 2.2, we began a process of replacing `execution_date` with `run_id`
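The "park then purge" flow in `_move_duplicate_data_to_new_table` boils down to two statements: a CREATE TABLE ... AS SELECT joining the source to the duplicates subquery, then a DELETE joining back to the parked rows. A toy end-to-end sketch on SQLite (invented schema and data; SQLite is also why the ROWID branch above exists, since its DELETE supports no JOIN clause):

```python
# Toy run of the park-then-purge pattern; table and column names are invented.
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")
with engine.begin() as conn:
    conn.execute(text("CREATE TABLE t (dag_id TEXT, task_id TEXT, execution_date TEXT)"))
    conn.execute(text("INSERT INTO t VALUES ('d', 't', 'e'), ('d', 't', 'e'), ('x', 'y', 'z')"))
    # Step 1 (CTAS): park every row belonging to a duplicated key in a side table.
    conn.execute(
        text(
            "CREATE TABLE t_dupes AS SELECT t.* FROM t JOIN "
            "(SELECT dag_id, task_id, execution_date FROM t "
            "GROUP BY dag_id, task_id, execution_date HAVING COUNT(*) > 1) d "
            "ON t.dag_id = d.dag_id AND t.task_id = d.task_id "
            "AND t.execution_date = d.execution_date"
        )
    )
    # Step 2 (purge): delete those rows from the source by joining back to the
    # side table via a subquery, since SQLite's DELETE has no JOIN clause.
    conn.execute(
        text(
            "DELETE FROM t WHERE ROWID IN ("
            "SELECT t.ROWID FROM t JOIN t_dupes d "
            "ON t.dag_id = d.dag_id AND t.task_id = d.task_id "
            "AND t.execution_date = d.execution_date)"
        )
    )
    print(conn.execute(text("SELECT COUNT(*) FROM t")).scalar())  # 1
```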
@@ -1174,9 +1243,7 @@ class BadReferenceConfig:
         if invalid_row_count <= 0:
             continue

-        dangling_table_name = _format_airflow_moved_table_name(
-            source_table=source_table.name, version=change_version
-        )
+        dangling_table_name = _format_airflow_moved_table_name(source_table.name, change_version, 'dangling')
         if dangling_table_name in existing_table_names:
             yield _format_dangling_error(
                 source_table=source_table.name,
@@ -1207,7 +1274,7 @@ def _check_migration_errors(session: Session = NEW_SESSION) -> Iterable[str]:
     :rtype: list[str]
     """
     check_functions: Tuple[Callable[..., Iterable[str]], ...] = (
-        # todo: check task fail for duplicates
+        check_task_fail_for_duplicates,
         check_conn_id_duplicates,
         check_conn_type_null,
         check_run_id_null,
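With that entry registered, the duplicate purge runs automatically as one of the pre-migration checks. A hedged sketch of the contract these generators follow (hypothetical driver, not the PR's code): each function yields human-readable error strings, and the upgrade proceeds only when every check yields nothing.

```python
# Hypothetical driver loop illustrating the check-function contract.
def run_pre_upgrade_checks(session, check_functions) -> None:
    errors = [msg for check_fn in check_functions for msg in check_fn(session=session)]
    if errors:
        # Any yielded string blocks the upgrade until the operator intervenes.
        raise RuntimeError("Automatic migration is not possible:\n" + "\n".join(errors))
```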