Skip to content

Commit

Permalink
bugfix: deferred tasks are not cancelled when DAG is marked failed (#20649)
Browse files Browse the repository at this point in the history
  • Loading branch information
dungdm93 authored Jan 5, 2022
1 parent b83084b commit 64c0bd5
Showing 1 changed file with 84 additions and 35 deletions.
119 changes: 84 additions & 35 deletions airflow/api/common/experimental/mark_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,27 @@
# under the License.
"""Marks tasks APIs."""

import datetime
from typing import Iterable
from datetime import datetime
from typing import Generator, Iterable, List, Optional

from sqlalchemy import or_
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm.session import Session as SASession
from sqlalchemy.sql.expression import or_

from airflow import DAG
from airflow.models.baseoperator import BaseOperator
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.operators.subdag import SubDagOperator
from airflow.utils import timezone
from airflow.utils.session import provide_session
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.types import DagRunType


def _create_dagruns(dag, execution_dates, state, run_type):
def _create_dagruns(
dag: DAG, execution_dates: List[datetime], state: TaskInstanceState, run_type: DagRunType
) -> List[DagRun]:
"""
Infers from the dates which dag runs need to be created and does so.
Expand Down Expand Up @@ -63,15 +67,15 @@ def _create_dagruns(dag, execution_dates, state, run_type):
@provide_session
def set_state(
tasks: Iterable[BaseOperator],
execution_date: datetime.datetime,
execution_date: datetime,
upstream: bool = False,
downstream: bool = False,
future: bool = False,
past: bool = False,
state: TaskInstanceState = TaskInstanceState.SUCCESS,
commit: bool = False,
session=None,
):
session: SASession = NEW_SESSION,
) -> List[TaskInstance]:
"""
Set the state of a task instance and if needed its relatives. Can set state
for future tasks (calculated from execution_date) and retroactively
Expand Down Expand Up @@ -134,7 +138,9 @@ def set_state(
return tis_altered


def all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates):
def all_subdag_tasks_query(
sub_dag_run_ids: List[str], session: SASession, state: TaskInstanceState, confirmed_dates: List[datetime]
):
"""Get *all* tasks of the sub dags"""
qry_sub_dag = (
session.query(TaskInstance)
Expand All @@ -144,7 +150,13 @@ def all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates):
return qry_sub_dag


def get_all_dag_task_query(dag, session, state, task_ids, confirmed_dates):
def get_all_dag_task_query(
dag: DAG,
session: SASession,
state: TaskInstanceState,
task_ids: List[str],
confirmed_dates: List[datetime],
):
"""Get all tasks of the main dag that will be affected by a state change"""
qry_dag = (
session.query(TaskInstance)
Expand All @@ -160,7 +172,14 @@ def get_all_dag_task_query(dag, session, state, task_ids, confirmed_dates):
return qry_dag


def get_subdag_runs(dag, session, state, task_ids, commit, confirmed_dates):
def get_subdag_runs(
dag: DAG,
session: SASession,
state: TaskInstanceState,
task_ids: List[str],
commit: bool,
confirmed_dates: List[datetime],
) -> List[str]:
"""Go through subdag operators and create dag runs. We will only work
within the scope of the subdag. We won't propagate to the parent dag,
but we will propagate from parent to subdag.
Expand All @@ -181,7 +200,7 @@ def get_subdag_runs(dag, session, state, task_ids, commit, confirmed_dates):
dag_runs = _create_dagruns(
current_task.subdag,
execution_dates=confirmed_dates,
state=State.RUNNING,
state=TaskInstanceState.RUNNING,
run_type=DagRunType.BACKFILL_JOB,
)

Expand All @@ -192,7 +211,13 @@ def get_subdag_runs(dag, session, state, task_ids, commit, confirmed_dates):
return sub_dag_ids


def verify_dagruns(dag_runs, commit, state, session, current_task):
def verify_dagruns(
dag_runs: List[DagRun],
commit: bool,
state: TaskInstanceState,
session: SASession,
current_task: BaseOperator,
):
"""Verifies integrity of dag_runs.
:param dag_runs: dag runs to verify
Expand All @@ -210,7 +235,7 @@ def verify_dagruns(dag_runs, commit, state, session, current_task):
session.merge(dag_run)


def verify_dag_run_integrity(dag, dates):
def verify_dag_run_integrity(dag: DAG, dates: List[datetime]) -> List[datetime]:
"""
Verify the integrity of the dag runs in case a task was added or removed.
Set the confirmed execution dates as they might be different
Expand All @@ -225,7 +250,9 @@ def verify_dag_run_integrity(dag, dates):
return confirmed_dates


def find_task_relatives(tasks, downstream, upstream):
def find_task_relatives(
tasks: Iterable[BaseOperator], downstream: bool, upstream: bool
) -> Generator[str, None, None]:
"""Yield task ids and optionally ancestor and descendant ids."""
for task in tasks:
yield task.task_id
Expand All @@ -237,7 +264,7 @@ def find_task_relatives(tasks, downstream, upstream):
yield relative.task_id


def get_execution_dates(dag, execution_date, future, past):
def get_execution_dates(dag: DAG, execution_date: datetime, future: bool, past: bool) -> List[datetime]:
"""Returns dates of DAG execution"""
latest_execution_date = dag.get_latest_execution_date()
if latest_execution_date is None:
Expand Down Expand Up @@ -266,7 +293,9 @@ def get_execution_dates(dag, execution_date, future, past):


@provide_session
def _set_dag_run_state(dag_id, execution_date, state, session=None):
def _set_dag_run_state(
dag_id: str, execution_date: datetime, state: TaskInstanceState, session: SASession = NEW_SESSION
):
"""
Helper method that sets the dag run state in the DB.
Expand All @@ -279,7 +308,7 @@ def _set_dag_run_state(dag_id, execution_date, state, session=None):
session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date).one()
)
dag_run.state = state
if state == State.RUNNING:
if state == TaskInstanceState.RUNNING:
dag_run.start_date = timezone.utcnow()
dag_run.end_date = None
else:
Expand All @@ -288,7 +317,12 @@ def _set_dag_run_state(dag_id, execution_date, state, session=None):


@provide_session
def set_dag_run_state_to_success(dag, execution_date, commit=False, session=None):
def set_dag_run_state_to_success(
dag: Optional[DAG],
execution_date: Optional[datetime],
commit: bool = False,
session: SASession = NEW_SESSION,
) -> List[TaskInstance]:
"""
Set the dag run for a specific execution date and its task instances
to success.
Expand All @@ -306,18 +340,27 @@ def set_dag_run_state_to_success(dag, execution_date, commit=False, session=None

# Mark the dag run to success.
if commit:
_set_dag_run_state(dag.dag_id, execution_date, State.SUCCESS, session)
_set_dag_run_state(dag.dag_id, execution_date, TaskInstanceState.SUCCESS, session)

# Mark all task instances of the dag run to success.
for task in dag.tasks:
task.dag = dag
return set_state(
tasks=dag.tasks, execution_date=execution_date, state=State.SUCCESS, commit=commit, session=session
tasks=dag.tasks,
execution_date=execution_date,
state=TaskInstanceState.SUCCESS,
commit=commit,
session=session,
)


@provide_session
def set_dag_run_state_to_failed(dag, execution_date, commit=False, session=None):
def set_dag_run_state_to_failed(
dag: Optional[DAG],
execution_date: Optional[datetime],
commit: bool = False,
session: SASession = NEW_SESSION,
) -> List[TaskInstance]:
"""
Set the dag run for a specific execution date and its running task instances
to failed.
Expand All @@ -335,18 +378,15 @@ def set_dag_run_state_to_failed(dag, execution_date, commit=False, session=None)

# Mark the dag run to failed.
if commit:
_set_dag_run_state(dag.dag_id, execution_date, State.FAILED, session)
_set_dag_run_state(dag.dag_id, execution_date, TaskInstanceState.FAILED, session)

# Mark only RUNNING task instances.
# Mark only running task instances.
task_ids = [task.task_id for task in dag.tasks]
tis = (
session.query(TaskInstance)
.filter(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.execution_date == execution_date,
TaskInstance.task_id.in_(task_ids),
)
.filter(TaskInstance.state == State.RUNNING)
tis = session.query(TaskInstance).filter(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.execution_date == execution_date,
TaskInstance.task_id.in_(task_ids),
TaskInstance.state.in_(State.running),
)
task_ids_of_running_tis = [task_instance.task_id for task_instance in tis]

Expand All @@ -358,12 +398,21 @@ def set_dag_run_state_to_failed(dag, execution_date, commit=False, session=None)
tasks.append(task)

return set_state(
tasks=tasks, execution_date=execution_date, state=State.FAILED, commit=commit, session=session
tasks=tasks,
execution_date=execution_date,
state=TaskInstanceState.FAILED,
commit=commit,
session=session,
)


@provide_session
def set_dag_run_state_to_running(dag, execution_date, commit=False, session=None):
def set_dag_run_state_to_running(
dag: Optional[DAG],
execution_date: Optional[datetime],
commit: bool = False,
session: SASession = NEW_SESSION,
) -> List[TaskInstance]:
"""
Set the dag run for a specific execution date to running.
Expand All @@ -380,7 +429,7 @@ def set_dag_run_state_to_running(dag, execution_date, commit=False, session=None

# Mark the dag run to running.
if commit:
_set_dag_run_state(dag.dag_id, execution_date, State.RUNNING, session)
_set_dag_run_state(dag.dag_id, execution_date, TaskInstanceState.RUNNING, session)

# To keep the return type consistent with the other similar functions.
return res

0 comments on commit 64c0bd5

Please sign in to comment.