Prevent DagRun's start_date from reset (#30124) (#30125)
(cherry picked from commit 070ecbd)
Dmytro Suvorov authored and ephraimbuddy committed May 8, 2023
1 parent ecc4e15 commit 94582bd
Showing 5 changed files with 125 additions and 15 deletions.
17 changes: 11 additions & 6 deletions airflow/models/taskinstance.py
@@ -190,7 +190,11 @@ def clear_task_instances(
) -> None:
"""
Clears a set of task instances, but makes sure the running ones
get killed.
get killed. Also sets DagRun's `state` to QUEUED and `start_date`
to the time of execution, but only for finished DRs (SUCCESS and FAILED).
Doesn't clear DR's `state` and `start_date` for unfinished
DRs (QUEUED and RUNNING), because clearing the state of an already
running DR is redundant and clearing `start_date` affects the DR's duration.
:param tis: a list of task instances
:param session: current session
@@ -302,11 +306,12 @@ def clear_task_instances(
)
dag_run_state = DagRunState(dag_run_state) # Validate the state value.
for dr in drs:
dr.state = dag_run_state
dr.start_date = timezone.utcnow()
if dag_run_state == DagRunState.QUEUED:
dr.last_scheduling_decision = None
dr.start_date = None
if dr.state in State.finished_dr_states:
dr.state = dag_run_state
dr.start_date = timezone.utcnow()
if dag_run_state == DagRunState.QUEUED:
dr.last_scheduling_decision = None
dr.start_date = None
session.flush()


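For orientation, a minimal sketch (not part of the commit) of the behaviour the taskinstance.py hunk above establishes, assuming the patched clear_task_instances; the clear_run helper and the dag, dag_run and session objects are illustrative placeholders:

from airflow.models.taskinstance import clear_task_instances
from airflow.utils.state import DagRunState, State

def clear_run(dag, dag_run, session):
    # Record whether the run was finished before its task instances are cleared.
    was_finished = dag_run.state in State.finished_dr_states
    tis = dag_run.get_task_instances(session=session)
    clear_task_instances(tis, session, dag=dag)
    session.flush()
    session.refresh(dag_run)
    if was_finished:
        # A SUCCESS/FAILED run is re-queued and its start_date reset, so the
        # scheduler assigns a fresh start_date when it picks the run up again.
        assert dag_run.state == DagRunState.QUEUED
        assert dag_run.start_date is None
    else:
        # A QUEUED/RUNNING run is left untouched: state, start_date and
        # last_scheduling_decision keep their previous values, preserving duration.
        assert dag_run.state in State.unfinished_dr_states
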
3 changes: 3 additions & 0 deletions airflow/utils/state.py
@@ -95,6 +95,9 @@ class State:
SKIPPED = TaskInstanceState.SKIPPED
DEFERRED = TaskInstanceState.DEFERRED

finished_dr_states: frozenset[DagRunState] = frozenset([DagRunState.SUCCESS, DagRunState.FAILED])
unfinished_dr_states: frozenset[DagRunState] = frozenset([DagRunState.QUEUED, DagRunState.RUNNING])

task_states: tuple[TaskInstanceState | None, ...] = (None,) + tuple(TaskInstanceState)

dag_states: tuple[DagRunState, ...] = (
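The guard in clear_task_instances keys on these new sets; a small illustration, assuming the patched State class:

from airflow.utils.state import DagRunState, State

# Only SUCCESS/FAILED runs qualify for the reset in clear_task_instances.
assert DagRunState.SUCCESS in State.finished_dr_states
assert DagRunState.FAILED in State.finished_dr_states
# QUEUED/RUNNING runs are the ones whose state and start_date are preserved.
assert DagRunState.QUEUED in State.unfinished_dr_states
assert DagRunState.RUNNING in State.unfinished_dr_states
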
4 changes: 2 additions & 2 deletions tests/api_connexion/endpoints/test_dag_run_endpoint.py
@@ -31,7 +31,7 @@
from airflow.security import permissions
from airflow.utils import timezone
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import State
from airflow.utils.state import DagRunState, State
from airflow.utils.types import DagRunType
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user
from tests.test_utils.config import conf_vars
@@ -1440,7 +1440,7 @@ def test_should_respond_200(self, dag_maker, session):
with dag_maker(dag_id) as dag:
task = EmptyOperator(task_id="task_id", dag=dag)
self.app.dag_bag.bag_dag(dag, root_dag=dag)
dr = dag_maker.create_dagrun(run_id=dag_run_id)
dr = dag_maker.create_dagrun(run_id=dag_run_id, state=DagRunState.FAILED)
ti = dr.get_task_instance(task_id="task_id")
ti.task = task
ti.state = State.SUCCESS
90 changes: 86 additions & 4 deletions tests/models/test_cleartasks.py
@@ -27,7 +27,7 @@
from airflow.operators.empty import EmptyOperator
from airflow.sensors.python import PythonSensor
from airflow.utils.session import create_session
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import DagRunType
from tests.models import DEFAULT_DATE
from tests.test_utils import db
@@ -132,7 +132,7 @@ def test_clear_task_instances_next_method(self, dag_maker, session):
assert ti0.next_kwargs is None

@pytest.mark.parametrize(
["state", "last_scheduling"], [(State.QUEUED, None), (State.RUNNING, DEFAULT_DATE)]
["state", "last_scheduling"], [(DagRunState.QUEUED, None), (DagRunState.RUNNING, DEFAULT_DATE)]
)
def test_clear_task_instances_dr_state(self, state, last_scheduling, dag_maker):
"""Test that DR state is set to None after clear.
@@ -147,7 +147,7 @@ def test_clear_task_instances_dr_state(self, state, last_scheduling, dag_maker):
EmptyOperator(task_id="0")
EmptyOperator(task_id="1", retries=2)
dr = dag_maker.create_dagrun(
state=State.RUNNING,
state=DagRunState.SUCCESS,
run_type=DagRunType.SCHEDULED,
)
ti0, ti1 = sorted(dr.task_instances, key=lambda ti: ti.task_id)
@@ -168,9 +168,91 @@ def test_clear_task_instances_dr_state(self, state, last_scheduling, dag_maker):
session.refresh(dr)

assert dr.state == state
assert dr.start_date is None if state == State.QUEUED else dr.start_date
assert dr.start_date is None if state == DagRunState.QUEUED else dr.start_date
assert dr.last_scheduling_decision == last_scheduling

@pytest.mark.parametrize("state", [DagRunState.QUEUED, DagRunState.RUNNING])
def test_clear_task_instances_on_running_dr(self, state, dag_maker):
"""Test that DagRun state, start_date and last_scheduling_decision
are not changed after clearing TI in an unfinished DagRun.
"""
with dag_maker(
"test_clear_task_instances",
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=10),
) as dag:
EmptyOperator(task_id="0")
EmptyOperator(task_id="1", retries=2)
dr = dag_maker.create_dagrun(
state=state,
run_type=DagRunType.SCHEDULED,
)
ti0, ti1 = sorted(dr.task_instances, key=lambda ti: ti.task_id)
dr.last_scheduling_decision = DEFAULT_DATE
ti0.state = TaskInstanceState.SUCCESS
ti1.state = TaskInstanceState.SUCCESS
session = dag_maker.session
session.flush()

# we use order_by(task_id) here because, for our test DAG structure, it is
# equivalent to a topological sort. It would not work in the general case,
# but it works here because we deliberately constructed the test DAGs so
# that the two orderings are equivalent
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
clear_task_instances(qry, session, dag=dag)
session.flush()

session.refresh(dr)

assert dr.state == state
assert dr.start_date
assert dr.last_scheduling_decision == DEFAULT_DATE

@pytest.mark.parametrize(
["state", "last_scheduling"],
[
(DagRunState.SUCCESS, None),
(DagRunState.SUCCESS, DEFAULT_DATE),
(DagRunState.FAILED, None),
(DagRunState.FAILED, DEFAULT_DATE),
],
)
def test_clear_task_instances_on_finished_dr(self, state, last_scheduling, dag_maker):
"""Test that DagRun state, start_date and last_scheduling_decision
are changed after clearing TI in a finished DagRun.
"""
with dag_maker(
"test_clear_task_instances",
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=10),
) as dag:
EmptyOperator(task_id="0")
EmptyOperator(task_id="1", retries=2)
dr = dag_maker.create_dagrun(
state=state,
run_type=DagRunType.SCHEDULED,
)
ti0, ti1 = sorted(dr.task_instances, key=lambda ti: ti.task_id)
dr.last_scheduling_decision = DEFAULT_DATE
ti0.state = TaskInstanceState.SUCCESS
ti1.state = TaskInstanceState.SUCCESS
session = dag_maker.session
session.flush()

# we use order_by(task_id) here because, for our test DAG structure, it is
# equivalent to a topological sort. It would not work in the general case,
# but it works here because we deliberately constructed the test DAGs so
# that the two orderings are equivalent
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
clear_task_instances(qry, session, dag=dag)
session.flush()

session.refresh(dr)

assert dr.state == DagRunState.QUEUED
assert dr.start_date is None
assert dr.last_scheduling_decision is None

def test_clear_task_instances_without_task(self, dag_maker):
with dag_maker(
"test_clear_task_instances_without_task",
26 changes: 23 additions & 3 deletions tests/models/test_dagrun.py
@@ -81,6 +81,7 @@ def create_dag_run(
task_states: Mapping[str, TaskInstanceState] | None = None,
execution_date: datetime.datetime | None = None,
is_backfill: bool = False,
state: DagRunState = DagRunState.RUNNING,
session: Session,
):
now = timezone.utcnow()
@@ -98,7 +99,7 @@
execution_date=execution_date,
data_interval=data_interval,
start_date=now,
state=DagRunState.RUNNING,
state=state,
external_trigger=False,
)

@@ -110,11 +111,30 @@

return dag_run

def test_clear_task_instances_for_backfill_dagrun(self, session):
@pytest.mark.parametrize("state", [DagRunState.QUEUED, DagRunState.RUNNING])
def test_clear_task_instances_for_backfill_unfinished_dagrun(self, state, session):
now = timezone.utcnow()
dag_id = "test_clear_task_instances_for_backfill_dagrun"
dag = DAG(dag_id=dag_id, start_date=now)
dag_run = self.create_dag_run(dag, execution_date=now, is_backfill=True, state=state, session=session)

task0 = EmptyOperator(task_id="backfill_task_0", owner="test", dag=dag)
ti0 = TI(task=task0, run_id=dag_run.run_id)
ti0.run()

qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
clear_task_instances(qry, session)
session.commit()
ti0.refresh_from_db()
dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.execution_date == now).first()
assert dr0.state == state

@pytest.mark.parametrize("state", [DagRunState.SUCCESS, DagRunState.FAILED])
def test_clear_task_instances_for_backfill_finished_dagrun(self, state, session):
now = timezone.utcnow()
dag_id = "test_clear_task_instances_for_backfill_dagrun"
dag = DAG(dag_id=dag_id, start_date=now)
dag_run = self.create_dag_run(dag, execution_date=now, is_backfill=True, session=session)
dag_run = self.create_dag_run(dag, execution_date=now, is_backfill=True, state=state, session=session)

task0 = EmptyOperator(task_id="backfill_task_0", owner="test", dag=dag)
ti0 = TI(task=task0, run_id=dag_run.run_id)
