When clearing task instances try to get associated DAGs from database (#29065)

* When clearing task instances try to get associated DAGs from database.

This fixes two problems when recursively clearing task instances across multiple DAGs (see the sketch after this list):
  * Task instances in downstream DAGs weren't having their `max_tries` property incremented, which could cause downstream external task sensors in reschedule mode to instantly time out (issue #29049).
  * Task instances in downstream DAGs could have some of their properties overridden by an unrelated task in the upstream DAG if they had the same task ID.
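For context, here is a condensed sketch of the new lookup (not the exact patch, which appears in the diff below): `clear_task_instances()` now resolves each task instance's own DAG, using the passed-in `dag` only when it actually owns the instance and otherwise consulting a database-backed `DagBag`. The helper name `_resolve_max_tries` and the standalone form are illustrative only:

    from airflow.models.dagbag import DagBag

    def _resolve_max_tries(tis, session, dag=None):
        """Illustrative condensation of the patched loop in clear_task_instances()."""
        dag_bag = DagBag(read_dags_from_db=True)  # reads serialized DAGs from the DB
        for ti in tis:
            # Only trust the passed-in DAG if it owns this task instance;
            # otherwise fetch the owning DAG from the database.
            ti_dag = dag if dag and dag.dag_id == ti.dag_id else dag_bag.get_dag(ti.dag_id, session=session)
            if ti_dag and ti_dag.has_task(ti.task_id):
                task = ti_dag.get_task(ti.task_id)
                ti.refresh_from_task(task)
                # Grant a fresh budget of `task.retries` attempts after clearing.
                ti.max_tries = ti.try_number + task.retries - 1
            else:
                # DAG or task not found; records may be outdated, so never shrink max_tries.
                ti.max_tries = max(ti.max_tries, ti.prev_attempted_tries)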

* Use session fixture for new `test_clear_task_instances_without_dag_param` test.

* Use session fixture for new `test_clear_task_instances_in_multiple_dags` test.

---------

Co-authored-by: eladkal <[email protected]>
(cherry picked from commit 0d2e6dc)
sean-rose authored and ephraimbuddy committed Apr 14, 2023
1 parent eccac28 commit 403cf86
Showing 2 changed files with 101 additions and 7 deletions.
11 changes: 7 additions & 4 deletions airflow/models/taskinstance.py
@@ -87,6 +87,7 @@
 )
 from airflow.listeners.listener import get_listener_manager
 from airflow.models.base import Base, StringID
+from airflow.models.dagbag import DagBag
 from airflow.models.log import Log
 from airflow.models.mappedoperator import MappedOperator
 from airflow.models.param import process_params
@@ -203,6 +204,7 @@ def clear_task_instances(
     task_id_by_key: dict[str, dict[str, dict[int, dict[int, set[str]]]]] = defaultdict(
         lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(set)))
     )
+    dag_bag = DagBag(read_dags_from_db=True)
     for ti in tis:
         if ti.state == TaskInstanceState.RUNNING:
             if ti.job_id:
@@ -211,15 +213,16 @@
                 ti.state = TaskInstanceState.RESTARTING
                 job_ids.append(ti.job_id)
         else:
+            ti_dag = dag if dag and dag.dag_id == ti.dag_id else dag_bag.get_dag(ti.dag_id, session=session)
             task_id = ti.task_id
-            if dag and dag.has_task(task_id):
-                task = dag.get_task(task_id)
+            if ti_dag and ti_dag.has_task(task_id):
+                task = ti_dag.get_task(task_id)
                 ti.refresh_from_task(task)
                 task_retries = task.retries
                 ti.max_tries = ti.try_number + task_retries - 1
             else:
-                # Ignore errors when updating max_tries if dag is None or
-                # task not found in dag since database records could be
+                # Ignore errors when updating max_tries if the DAG or
+                # task are not found since database records could be
                 # outdated. We make max_tries the maximum value of its
                 # original max_tries or the last attempted try number.
                 ti.max_tries = max(ti.max_tries, ti.prev_attempted_tries)
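A note on the database-backed lookup above: `DagBag(read_dags_from_db=True).get_dag(...)` only returns DAGs that were previously serialized to the database, which is why the tests below call `SerializedDagModel.write_dag()` first. A minimal sketch of that pairing (`my_dag` and `session` are placeholders):

    from airflow.models.dagbag import DagBag
    from airflow.models.serialized_dag import SerializedDagModel

    # Persist the DAG so a DB-backed DagBag can find it later.
    SerializedDagModel.write_dag(my_dag, session=session)

    # Later, e.g. inside clear_task_instances():
    dag_bag = DagBag(read_dags_from_db=True)
    found = dag_bag.get_dag(my_dag.dag_id, session=session)  # deserialized DAG, or None if never written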
97 changes: 94 additions & 3 deletions tests/models/test_cleartasks.py
@@ -23,6 +23,7 @@
 
 from airflow import settings
 from airflow.models import DAG, TaskInstance as TI, TaskReschedule, clear_task_instances
+from airflow.models.serialized_dag import SerializedDagModel
 from airflow.operators.empty import EmptyOperator
 from airflow.sensors.python import PythonSensor
 from airflow.utils.session import create_session
@@ -202,9 +203,9 @@ def test_clear_task_instances_without_task(self, dag_maker):
         # but it works for our case because we specifically constructed test DAGS
         # in the way that those two sort methods are equivalent
         qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
-        clear_task_instances(qry, session)
+        clear_task_instances(qry, session, dag=dag)
 
-        # When dag is None, max_tries will be maximum of original max_tries or try_number.
+        # When no task is found, max_tries will be maximum of original max_tries or try_number.
         ti0.refresh_from_db()
         ti1.refresh_from_db()
         # Next try to run will be try 2
@@ -214,6 +215,7 @@ def test_clear_task_instances_without_task(self, dag_maker):
         assert ti1.max_tries == 2
 
     def test_clear_task_instances_without_dag(self, dag_maker):
+        # Don't write DAG to the database, so no DAG is found by clear_task_instances().
         with dag_maker(
             "test_clear_task_instances_without_dag",
             start_date=DEFAULT_DATE,
@@ -242,7 +244,7 @@ def test_clear_task_instances_without_dag(self, dag_maker):
         qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
         clear_task_instances(qry, session)
 
-        # When dag is None, max_tries will be maximum of original max_tries or try_number.
+        # When no DAG is found, max_tries will be maximum of original max_tries or try_number.
         ti0.refresh_from_db()
         ti1.refresh_from_db()
         # Next try to run will be try 2
@@ -251,6 +253,95 @@ def test_clear_task_instances_without_dag(self, dag_maker):
         assert ti1.try_number == 2
         assert ti1.max_tries == 2
 
+    def test_clear_task_instances_without_dag_param(self, dag_maker, session):
+        with dag_maker(
+            "test_clear_task_instances_without_dag_param",
+            start_date=DEFAULT_DATE,
+            end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+            session=session,
+        ) as dag:
+            task0 = EmptyOperator(task_id="task0")
+            task1 = EmptyOperator(task_id="task1", retries=2)
+
+        # Write DAG to the database so it can be found by clear_task_instances().
+        SerializedDagModel.write_dag(dag, session=session)
+
+        dr = dag_maker.create_dagrun(
+            state=State.RUNNING,
+            run_type=DagRunType.SCHEDULED,
+        )
+
+        ti0, ti1 = sorted(dr.task_instances, key=lambda ti: ti.task_id)
+        ti0.refresh_from_task(task0)
+        ti1.refresh_from_task(task1)
+
+        ti0.run(session=session)
+        ti1.run(session=session)
+
+        # we use order_by(task_id) here because for the test DAG structure of ours
+        # this is equivalent to topological sort. It would not work in general case
+        # but it works for our case because we specifically constructed test DAGS
+        # in the way that those two sort methods are equivalent
+        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+        clear_task_instances(qry, session)
+
+        ti0.refresh_from_db(session=session)
+        ti1.refresh_from_db(session=session)
+        # Next try to run will be try 2
+        assert ti0.try_number == 2
+        assert ti0.max_tries == 1
+        assert ti1.try_number == 2
+        assert ti1.max_tries == 3
+
+    def test_clear_task_instances_in_multiple_dags(self, dag_maker, session):
+        with dag_maker(
+            "test_clear_task_instances_in_multiple_dags0",
+            start_date=DEFAULT_DATE,
+            end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+            session=session,
+        ) as dag0:
+            task0 = EmptyOperator(task_id="task0")
+
+        dr0 = dag_maker.create_dagrun(
+            state=State.RUNNING,
+            run_type=DagRunType.SCHEDULED,
+        )
+
+        with dag_maker(
+            "test_clear_task_instances_in_multiple_dags1",
+            start_date=DEFAULT_DATE,
+            end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+            session=session,
+        ) as dag1:
+            task1 = EmptyOperator(task_id="task1", retries=2)
+
+        # Write secondary DAG to the database so it can be found by clear_task_instances().
+        SerializedDagModel.write_dag(dag1, session=session)
+
+        dr1 = dag_maker.create_dagrun(
+            state=State.RUNNING,
+            run_type=DagRunType.SCHEDULED,
+        )
+
+        ti0 = dr0.task_instances[0]
+        ti1 = dr1.task_instances[0]
+        ti0.refresh_from_task(task0)
+        ti1.refresh_from_task(task1)
+
+        ti0.run(session=session)
+        ti1.run(session=session)
+
+        qry = session.query(TI).filter(TI.dag_id.in_((dag0.dag_id, dag1.dag_id))).all()
+        clear_task_instances(qry, session, dag=dag0)
+
+        ti0.refresh_from_db(session=session)
+        ti1.refresh_from_db(session=session)
+        # Next try to run will be try 2
+        assert ti0.try_number == 2
+        assert ti0.max_tries == 1
+        assert ti1.try_number == 2
+        assert ti1.max_tries == 3
+
     def test_clear_task_instances_with_task_reschedule(self, dag_maker):
         """Test that TaskReschedules are deleted correctly when TaskInstances are cleared"""
 
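Sanity check on the new assertions: after one completed attempt, a cleared task instance reports `try_number == 2`, so the patched formula `max_tries = try_number + retries - 1` gives `2 + 0 - 1 = 1` for `task0` (default `retries=0`) and `2 + 2 - 1 = 3` for `task1` (`retries=2`), matching the `ti0.max_tries == 1` and `ti1.max_tries == 3` assertions in both new tests.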
