Skip to content

Commit

Permalink
Only count bad refs when moved table exists (#23491)
Browse files Browse the repository at this point in the history
This keeps the existing behavior of failing without upgrading when (A) there are bad rows and
(B) the "moved" table already exists. However, we now skip counting
the bad rows unless the "moved" table is there. Previously we always counted,
but the first time a user attempts an upgrade the "moved" tables won't exist yet, so
there's no point in counting.

Instead, we skip straight to the CTAS (CREATE TABLE AS SELECT), creating the _airflow_moved
tables. If there aren't any rows in the "moved" table, then we drop the table
immediately.

Also included here is a delete optimization: except on SQLite, we join to the moved table
instead of running the NOT EXISTS query again.

Co-authored-by: Jed Cunningham <[email protected]>
Co-authored-by: Ash Berlin-Taylor <[email protected]>
(cherry picked from commit 6cc41ab)
  • Loading branch information
dstandish authored and ephraimbuddy committed May 8, 2022
1 parent e4521ef commit d9075f8
Showing 1 changed file with 40 additions and 13 deletions.
53 changes: 40 additions & 13 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from tempfile import gettempdir
from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple, Union

from sqlalchemy import Table, and_, column, exc, func, inspect, or_, table, text
from sqlalchemy import Table, and_, column, exc, func, inspect, or_, select, table, text
from sqlalchemy.orm.session import Session

import airflow
Expand Down Expand Up @@ -1054,16 +1054,38 @@ def _move_dangling_data_to_new_table(
dialect_name = bind.dialect.name

# First: Create moved rows from new table
log.debug("running CTAS for table %s", target_table_name)
_create_table_as(
dialect_name=dialect_name,
source_query=source_query,
target_table_name=target_table_name,
source_table_name=source_table.name,
session=session,
)
session.commit()

delete = source_table.delete().where(~exists_subquery.exists())
session.execute(delete)
target_table = source_table.to_metadata(source_table.metadata, name=target_table_name)
log.debug("checking whether rows were moved for table %s", target_table_name)
moved_rows_exist_query = select([1]).select_from(target_table).limit(1)
first_moved_row = session.execute(moved_rows_exist_query).all()
session.commit()

if not first_moved_row:
log.debug("no rows moved; dropping %s", target_table_name)
target_table.drop(bind=session.get_bind(), checkfirst=True)
else:
log.debug("rows moved; purging from %s", source_table.name)
if dialect_name == 'sqlite':
delete = source_table.delete().where(~exists_subquery.exists())
else:
delete = source_table.delete().where(
and_(col == target_table.c[col.name] for col in source_table.primary_key.columns)
)
log.debug(delete.compile())
session.execute(delete)
session.commit()

log.debug("exiting move function")


def _dag_run_exists(session, source_table, dag_run):
Expand Down Expand Up @@ -1226,6 +1248,7 @@ class BadReferenceConfig:
errored = False

for model, change_version, bad_ref_cfg in models_list:
log.debug("checking model %s", model.__tablename__)
# We can't use the model here since it may differ from the db state, because
# this function is run prior to migration. Use the reflected table instead.
exists_func_kwargs = {x: metadata.tables[x] for x in bad_ref_cfg.join_tables}
Expand All @@ -1240,20 +1263,23 @@ class BadReferenceConfig:
bad_rows_subquery = bad_ref_cfg.exists_func(session, source_table, **exists_func_kwargs)
select_list = [x.label(x.name) for x in source_table.c]
invalid_rows_query = session.query(*select_list).filter(~bad_rows_subquery.exists())
invalid_row_count = invalid_rows_query.count()
if invalid_row_count <= 0:
continue

dangling_table_name = _format_airflow_moved_table_name(source_table.name, change_version, 'dangling')
if dangling_table_name in existing_table_names:
yield _format_dangling_error(
source_table=source_table.name,
target_table=dangling_table_name,
invalid_count=invalid_row_count,
reason=f"without a corresponding {bad_ref_cfg.ref_table} row",
)
errored = True
invalid_row_count = invalid_rows_query.count()
if invalid_row_count <= 0:
continue
else:
yield _format_dangling_error(
source_table=source_table.name,
target_table=dangling_table_name,
invalid_count=invalid_row_count,
reason=f"without a corresponding {bad_ref_cfg.ref_table} row",
)
errored = True
continue

log.debug("moving data for table %s", source_table.name)
_move_dangling_data_to_new_table(
session,
source_table,
Expand Down Expand Up @@ -1282,6 +1308,7 @@ def _check_migration_errors(session: Session = NEW_SESSION) -> Iterable[str]:
check_bad_references,
)
for check_fn in check_functions:
log.debug("running check function %s", check_fn.__name__)
yield from check_fn(session=session)
# Ensure there is no "active" transaction. Seems odd, but without this MSSQL can hang
session.commit()
Expand Down

0 comments on commit d9075f8

Please sign in to comment.