diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index ba357c0bd1bb..675550c82ca6 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -293,6 +293,17 @@ def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Operator]: if t.is_teardown and not t == self: yield t + def get_upstreams_only_setups(self) -> Iterable[Operator]: + """ + Return relevant upstream setups. + + This method is meant to be used when we are checking task dependencies where we need + to wait for all the upstream setups to complete before we can run the task. + """ + for task in self.get_upstreams_only_setups_and_teardowns(): + if task.is_setup: + yield task + def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]: """Return mapped nodes that are direct dependencies of the current task. diff --git a/airflow/models/dag.py b/airflow/models/dag.py index cd3f47da8843..8a396948ca6d 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -123,6 +123,7 @@ with_row_locks, ) from airflow.utils.state import DagRunState, TaskInstanceState +from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType if TYPE_CHECKING: @@ -718,6 +719,13 @@ def validate_setup_teardown(self): :meta private: """ for task in self.tasks: + if task.is_setup: + for down_task in task.downstream_list: + if not down_task.is_teardown and down_task.trigger_rule != TriggerRule.ALL_SUCCESS: + # todo: we can relax this to allow out-of-scope tasks to have other trigger rules + # this is required to ensure consistent behavior of dag + # when clearing an indirect setup + raise ValueError("Setup tasks must be followed with trigger rule ALL_SUCCESS.") FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule) def __repr__(self): diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py index 9731d2a7c235..c3dfb71e6c5d 100644 --- a/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow/ti_deps/deps/trigger_rule_dep.py @@ -20,7 +20,7 @@ import collections import collections.abc import functools -from typing import TYPE_CHECKING, Iterator, NamedTuple +from typing import TYPE_CHECKING, Iterator, KeysView, NamedTuple from sqlalchemy import and_, func, or_, select @@ -33,6 +33,7 @@ from sqlalchemy.orm import Session from sqlalchemy.sql.expression import ColumnOperators + from airflow import DAG from airflow.models.taskinstance import TaskInstance from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.deps.base_ti_dep import TIDepStatus @@ -122,10 +123,6 @@ def _evaluate_trigger_rule( from airflow.models.operator import needs_expansion from airflow.models.taskinstance import TaskInstance - task = ti.task - upstream_tasks = {t.task_id: t for t in task.upstream_list} - trigger_rule = task.trigger_rule - @functools.lru_cache def _get_expanded_ti_count() -> int: """Get how many tis the current task is supposed to be expanded into. @@ -133,7 +130,7 @@ def _get_expanded_ti_count() -> int: This extra closure allows us to query the database only when needed, and at most once. 
""" - return task.get_mapped_ti_count(ti.run_id, session=session) + return ti.task.get_mapped_ti_count(ti.run_id, session=session) @functools.lru_cache def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None: @@ -143,24 +140,34 @@ def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None: and at most once for each task (instead of once for each expanded task instance of the same task). """ + if TYPE_CHECKING: + assert isinstance(ti.task.dag, DAG) try: expanded_ti_count = _get_expanded_ti_count() except (NotFullyPopulated, NotMapped): return None return ti.get_relevant_upstream_map_indexes( - upstream_tasks[upstream_id], - expanded_ti_count, + upstream=ti.task.dag.task_dict[upstream_id], + ti_count=expanded_ti_count, session=session, ) - def _is_relevant_upstream(upstream: TaskInstance) -> bool: - """Whether a task instance is a "relevant upstream" of the current task.""" + def _is_relevant_upstream(upstream: TaskInstance, relevant_ids: set[str] | KeysView[str]) -> bool: + """ + Whether a task instance is a "relevant upstream" of the current task. + + This will return false if upstream.task_id is not in relevant_ids, + or if both of the following are true: + 1. upstream.task_id in relevant_ids is True + 2. ti is in a mapped task group and upstream has a map index + that ti does not depend on. + """ # Not actually an upstream task. - if upstream.task_id not in task.upstream_task_ids: + if upstream.task_id not in relevant_ids: return False # The current task is not in a mapped task group. All tis from an # upstream task are relevant. - if task.get_closest_mapped_task_group() is None: + if ti.task.get_closest_mapped_task_group() is None: return True # The upstream ti is not expanded. The upstream may be mapped or # not, but the ti is relevant either way. @@ -168,7 +175,7 @@ def _is_relevant_upstream(upstream: TaskInstance) -> bool: return True # Now we need to perform fine-grained check on whether this specific # upstream ti's map index is relevant. - relevant = _get_relevant_upstream_map_indexes(upstream.task_id) + relevant = _get_relevant_upstream_map_indexes(upstream_id=upstream.task_id) if relevant is None: return True if relevant == upstream.map_index: @@ -177,31 +184,17 @@ def _is_relevant_upstream(upstream: TaskInstance) -> bool: return True return False - finished_upstream_tis = ( - finished_ti - for finished_ti in dep_context.ensure_finished_tis(ti.get_dagrun(session), session) - if _is_relevant_upstream(finished_ti) - ) - upstream_states = _UpstreamTIStates.calculate(finished_upstream_tis) - - success = upstream_states.success - skipped = upstream_states.skipped - failed = upstream_states.failed - upstream_failed = upstream_states.upstream_failed - removed = upstream_states.removed - done = upstream_states.done - success_setup = upstream_states.success_setup - skipped_setup = upstream_states.skipped_setup - - def _iter_upstream_conditions() -> Iterator[ColumnOperators]: + def _iter_upstream_conditions(relevant_tasks: dict) -> Iterator[ColumnOperators]: # Optimization: If the current task is not in a mapped task group, # it depends on all upstream task instances. 
- if task.get_closest_mapped_task_group() is None: - yield TaskInstance.task_id.in_(upstream_tasks) + from airflow.models.taskinstance import TaskInstance + + if ti.task.get_closest_mapped_task_group() is None: + yield TaskInstance.task_id.in_(relevant_tasks.keys()) return # Otherwise we need to figure out which map indexes are depended on # for each upstream by the current task instance. - for upstream_id in upstream_tasks: + for upstream_id in relevant_tasks.keys(): map_indexes = _get_relevant_upstream_map_indexes(upstream_id) if map_indexes is None: # All tis of this upstream are dependencies. yield (TaskInstance.task_id == upstream_id) @@ -222,27 +215,49 @@ def _iter_upstream_conditions() -> Iterator[ColumnOperators]: else: yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index == map_indexes) - # Optimization: Don't need to hit the database if all upstreams are - # "simple" tasks (no task or task group mapping involved). - if not any(needs_expansion(t) for t in upstream_tasks.values()): - upstream = len(upstream_tasks) - upstream_setup = sum(1 for x in upstream_tasks.values() if x.is_setup) - else: - task_id_counts = session.execute( - select(TaskInstance.task_id, func.count(TaskInstance.task_id)) - .where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id) - .where(or_(*_iter_upstream_conditions())) - .group_by(TaskInstance.task_id) - ).all() - upstream = sum(count for _, count in task_id_counts) - upstream_setup = sum(c for t, c in task_id_counts if upstream_tasks[t].is_setup) - - upstream_done = done >= upstream - - changed = False - new_state = None - if dep_context.flag_upstream_failed: - if trigger_rule == TR.ALL_SUCCESS: + def _evaluate_setup_constraint(*, relevant_setups) -> Iterator[tuple[TIDepStatus, bool]]: + """Evaluate whether the setup constraint for ``ti`` is satisfied. + + A task in the scope of one or more setups must wait for all of its indirect setups + (those that are not also direct upstreams) to complete successfully; this closure reads + ``ti``, ``dep_context`` and ``session`` from the enclosing scope. + """ + task = ti.task + + indirect_setups = {k: v for k, v in relevant_setups.items() if k not in task.upstream_task_ids} + finished_upstream_tis = ( + x + for x in dep_context.ensure_finished_tis(ti.get_dagrun(session), session) + if _is_relevant_upstream(upstream=x, relevant_ids=indirect_setups.keys()) + ) + upstream_states = _UpstreamTIStates.calculate(finished_upstream_tis) + + # all of these counts reflect indirect setups which are relevant for this ti + success = upstream_states.success + skipped = upstream_states.skipped + failed = upstream_states.failed + upstream_failed = upstream_states.upstream_failed + removed = upstream_states.removed + + # Optimization: Don't need to hit the database if all upstreams are + # "simple" tasks (no task or task group mapping involved).
+ if not any(needs_expansion(t) for t in indirect_setups.values()): + upstream = len(indirect_setups) + else: + task_id_counts = session.execute( + select(TaskInstance.task_id, func.count(TaskInstance.task_id)) + .where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id) + .where(or_(*_iter_upstream_conditions(relevant_tasks=indirect_setups))) + .group_by(TaskInstance.task_id) + ).all() + upstream = sum(count for _, count in task_id_counts) + + new_state = None + changed = False + + # if there's a failure, we mark upstream_failed; if there's a skip, we mark skipped + # in either case, we don't wait for all relevant setups to complete + if dep_context.flag_upstream_failed: if upstream_failed or failed: new_state = TaskInstanceState.UPSTREAM_FAILED elif skipped: @@ -250,196 +265,297 @@ def _iter_upstream_conditions() -> Iterator[ColumnOperators]: elif removed and success and ti.map_index > -1: if ti.map_index >= success: new_state = TaskInstanceState.REMOVED - elif trigger_rule == TR.ALL_FAILED: - if success or skipped: - new_state = TaskInstanceState.SKIPPED - elif trigger_rule == TR.ONE_SUCCESS: - if upstream_done and done == skipped: - # if upstream is done and all are skipped mark as skipped - new_state = TaskInstanceState.SKIPPED - elif upstream_done and success <= 0: - # if upstream is done and there are no success mark as upstream failed - new_state = TaskInstanceState.UPSTREAM_FAILED - elif trigger_rule == TR.ONE_FAILED: - if upstream_done and not (failed or upstream_failed): - new_state = TaskInstanceState.SKIPPED - elif trigger_rule == TR.ONE_DONE: - if upstream_done and not (failed or success): - new_state = TaskInstanceState.SKIPPED - elif trigger_rule == TR.NONE_FAILED: - if upstream_failed or failed: - new_state = TaskInstanceState.UPSTREAM_FAILED - elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS: - if upstream_failed or failed: - new_state = TaskInstanceState.UPSTREAM_FAILED - elif skipped == upstream: - new_state = TaskInstanceState.SKIPPED - elif trigger_rule == TR.NONE_SKIPPED: - if skipped: - new_state = TaskInstanceState.SKIPPED - elif trigger_rule == TR.ALL_SKIPPED: - if success or failed: - new_state = TaskInstanceState.SKIPPED - elif trigger_rule == TR.ALL_DONE_SETUP_SUCCESS: - if upstream_done and upstream_setup and skipped_setup >= upstream_setup: - # when there is an upstream setup and they have all skipped, then skip - new_state = TaskInstanceState.SKIPPED - elif upstream_done and upstream_setup and success_setup == 0: - # when there is an upstream setup, if none succeeded, mark upstream failed - # if at least one setup ran, we'll let it run - new_state = TaskInstanceState.UPSTREAM_FAILED - if new_state is not None: - if new_state == TaskInstanceState.SKIPPED and dep_context.wait_for_past_depends_before_skipping: - past_depends_met = ti.xcom_pull( - task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False - ) - if not past_depends_met: - yield self._failing_status( - reason=("Task should be skipped but the past depends are not met") + + if new_state is not None: + if ( + new_state == TaskInstanceState.SKIPPED + and dep_context.wait_for_past_depends_before_skipping + ): + past_depends_met = ti.xcom_pull( + task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False ) - return - changed = ti.set_state(new_state, session) + if not past_depends_met: + yield self._failing_status( + reason="Task should be skipped but the past depends are not met" + ), changed + return + changed = ti.set_state(new_state, session) - if 
changed: - dep_context.have_changed_ti_states = True + if changed: + dep_context.have_changed_ti_states = True - if trigger_rule == TR.ONE_SUCCESS: - if success <= 0: + non_successes = upstream - success + if ti.map_index > -1: + non_successes -= removed + if non_successes > 0: yield self._failing_status( reason=( - f"Task's trigger rule '{trigger_rule}' requires one upstream task success, " - f"but none were found. upstream_states={upstream_states}, " + f"All setup tasks must complete successfully. Relevant setups: {relevant_setups}: " + f"upstream_states={upstream_states}, " f"upstream_task_ids={task.upstream_task_ids}" + ), + ), changed + + def _evaluate_direct_relatives() -> Iterator[TIDepStatus]: + """Evaluate whether ``ti``'s trigger rule was met. + + :param ti: Task instance to evaluate the trigger rule of. + :param dep_context: The current dependency context. + :param session: Database session. + """ + task = ti.task + upstream_tasks = {t.task_id: t for t in task.upstream_list} + trigger_rule = task.trigger_rule + + finished_upstream_tis = ( + finished_ti + for finished_ti in dep_context.ensure_finished_tis(ti.get_dagrun(session), session) + if _is_relevant_upstream(upstream=finished_ti, relevant_ids=ti.task.upstream_task_ids) + ) + upstream_states = _UpstreamTIStates.calculate(finished_upstream_tis) + + success = upstream_states.success + skipped = upstream_states.skipped + failed = upstream_states.failed + upstream_failed = upstream_states.upstream_failed + removed = upstream_states.removed + done = upstream_states.done + success_setup = upstream_states.success_setup + skipped_setup = upstream_states.skipped_setup + + # Optimization: Don't need to hit the database if all upstreams are + # "simple" tasks (no task or task group mapping involved). 
+ if not any(needs_expansion(t) for t in upstream_tasks.values()): + upstream = len(upstream_tasks) + upstream_setup = sum(1 for x in upstream_tasks.values() if x.is_setup) + else: + task_id_counts = session.execute( + select(TaskInstance.task_id, func.count(TaskInstance.task_id)) + .where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id) + .where(or_(*_iter_upstream_conditions(relevant_tasks=upstream_tasks))) + .group_by(TaskInstance.task_id) + ).all() + upstream = sum(count for _, count in task_id_counts) + upstream_setup = sum(c for t, c in task_id_counts if upstream_tasks[t].is_setup) + + upstream_done = done >= upstream + + changed = False + new_state = None + if dep_context.flag_upstream_failed: + if trigger_rule == TR.ALL_SUCCESS: + if upstream_failed or failed: + new_state = TaskInstanceState.UPSTREAM_FAILED + elif skipped: + new_state = TaskInstanceState.SKIPPED + elif removed and success and ti.map_index > -1: + if ti.map_index >= success: + new_state = TaskInstanceState.REMOVED + elif trigger_rule == TR.ALL_FAILED: + if success or skipped: + new_state = TaskInstanceState.SKIPPED + elif trigger_rule == TR.ONE_SUCCESS: + if upstream_done and done == skipped: + # if upstream is done and all are skipped mark as skipped + new_state = TaskInstanceState.SKIPPED + elif upstream_done and success <= 0: + # if upstream is done and there are no success mark as upstream failed + new_state = TaskInstanceState.UPSTREAM_FAILED + elif trigger_rule == TR.ONE_FAILED: + if upstream_done and not (failed or upstream_failed): + new_state = TaskInstanceState.SKIPPED + elif trigger_rule == TR.ONE_DONE: + if upstream_done and not (failed or success): + new_state = TaskInstanceState.SKIPPED + elif trigger_rule == TR.NONE_FAILED: + if upstream_failed or failed: + new_state = TaskInstanceState.UPSTREAM_FAILED + elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS: + if upstream_failed or failed: + new_state = TaskInstanceState.UPSTREAM_FAILED + elif skipped == upstream: + new_state = TaskInstanceState.SKIPPED + elif trigger_rule == TR.NONE_SKIPPED: + if skipped: + new_state = TaskInstanceState.SKIPPED + elif trigger_rule == TR.ALL_SKIPPED: + if success or failed: + new_state = TaskInstanceState.SKIPPED + elif trigger_rule == TR.ALL_DONE_SETUP_SUCCESS: + if upstream_done and upstream_setup and skipped_setup >= upstream_setup: + # when there is an upstream setup and they have all skipped, then skip + new_state = TaskInstanceState.SKIPPED + elif upstream_done and upstream_setup and success_setup == 0: + # when there is an upstream setup, if none succeeded, mark upstream failed + # if at least one setup ran, we'll let it run + new_state = TaskInstanceState.UPSTREAM_FAILED + if new_state is not None: + if ( + new_state == TaskInstanceState.SKIPPED + and dep_context.wait_for_past_depends_before_skipping + ): + past_depends_met = ti.xcom_pull( + task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False ) - ) - elif trigger_rule == TR.ONE_FAILED: - if not failed and not upstream_failed: - yield self._failing_status( - reason=( - f"Task's trigger rule '{trigger_rule}' requires one upstream task failure, " - f"but none were found. 
upstream_states={upstream_states}, " - f"upstream_task_ids={task.upstream_task_ids}" + if not past_depends_met: + yield self._failing_status( + reason=("Task should be skipped but the past depends are not met") + ) + return + changed = ti.set_state(new_state, session) + + if changed: + dep_context.have_changed_ti_states = True + + if trigger_rule == TR.ONE_SUCCESS: + if success <= 0: + yield self._failing_status( + reason=( + f"Task's trigger rule '{trigger_rule}' requires one upstream task success, " + f"but none were found. upstream_states={upstream_states}, " + f"upstream_task_ids={task.upstream_task_ids}" + ) ) - ) - elif trigger_rule == TR.ONE_DONE: - if success + failed <= 0: - yield self._failing_status( - reason=( - f"Task's trigger rule '{trigger_rule}'" - "requires at least one upstream task failure or success" - f"but none were failed or success. upstream_states={upstream_states}, " - f"upstream_task_ids={task.upstream_task_ids}" + elif trigger_rule == TR.ONE_FAILED: + if not failed and not upstream_failed: + yield self._failing_status( + reason=( + f"Task's trigger rule '{trigger_rule}' requires one upstream task failure, " + f"but none were found. upstream_states={upstream_states}, " + f"upstream_task_ids={task.upstream_task_ids}" + ) ) - ) - elif trigger_rule == TR.ALL_SUCCESS: - num_failures = upstream - success - if ti.map_index > -1: - num_failures -= removed - if num_failures > 0: - yield self._failing_status( - reason=( - f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have " - f"succeeded, but found {num_failures} non-success(es). " - f"upstream_states={upstream_states}, " - f"upstream_task_ids={task.upstream_task_ids}" + elif trigger_rule == TR.ONE_DONE: + if success + failed <= 0: + yield self._failing_status( + reason=( + f"Task's trigger rule '{trigger_rule}'" + "requires at least one upstream task failure or success" + f"but none were failed or success. upstream_states={upstream_states}, " + f"upstream_task_ids={task.upstream_task_ids}" + ) ) - ) - elif trigger_rule == TR.ALL_FAILED: - num_success = upstream - failed - upstream_failed - if ti.map_index > -1: - num_success -= removed - if num_success > 0: - yield self._failing_status( - reason=( - f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have failed, " - f"but found {num_success} non-failure(s). " - f"upstream_states={upstream_states}, " - f"upstream_task_ids={task.upstream_task_ids}" + elif trigger_rule == TR.ALL_SUCCESS: + num_failures = upstream - success + if ti.map_index > -1: + num_failures -= removed + if num_failures > 0: + yield self._failing_status( + reason=( + f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have " + f"succeeded, but found {num_failures} non-success(es). " + f"upstream_states={upstream_states}, " + f"upstream_task_ids={task.upstream_task_ids}" + ) ) - ) - elif trigger_rule == TR.ALL_DONE: - if not upstream_done: - yield self._failing_status( - reason=( - f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have " - f"completed, but found {len(upstream_tasks) - done} task(s) that were not done. 
" - f"upstream_states={upstream_states}, " - f"upstream_task_ids={task.upstream_task_ids}" + elif trigger_rule == TR.ALL_FAILED: + num_success = upstream - failed - upstream_failed + if ti.map_index > -1: + num_success -= removed + if num_success > 0: + yield self._failing_status( + reason=( + f"Task's trigger rule '{trigger_rule}' requires all upstream tasks " + f"to have failed, but found {num_success} non-failure(s). " + f"upstream_states={upstream_states}, " + f"upstream_task_ids={task.upstream_task_ids}" + ) ) - ) - elif trigger_rule == TR.NONE_FAILED: - num_failures = upstream - success - skipped - if ti.map_index > -1: - num_failures -= removed - if num_failures > 0: - yield self._failing_status( - reason=( - f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have " - f"succeeded or been skipped, but found {num_failures} non-success(es). " - f"upstream_states={upstream_states}, " - f"upstream_task_ids={task.upstream_task_ids}" + elif trigger_rule == TR.ALL_DONE: + if not upstream_done: + yield self._failing_status( + reason=( + f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have " + f"completed, but found {len(upstream_tasks) - done} task(s) that were " + f"not done. upstream_states={upstream_states}, " + f"upstream_task_ids={task.upstream_task_ids}" + ) ) - ) - elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS: - num_failures = upstream - success - skipped - if ti.map_index > -1: - num_failures -= removed - if num_failures > 0: - yield self._failing_status( - reason=( - f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have " - f"succeeded or been skipped, but found {num_failures} non-success(es). " - f"upstream_states={upstream_states}, " - f"upstream_task_ids={task.upstream_task_ids}" + elif trigger_rule == TR.NONE_FAILED: + num_failures = upstream - success - skipped + if ti.map_index > -1: + num_failures -= removed + if num_failures > 0: + yield self._failing_status( + reason=( + f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have " + f"succeeded or been skipped, but found {num_failures} non-success(es). " + f"upstream_states={upstream_states}, " + f"upstream_task_ids={task.upstream_task_ids}" + ) ) - ) - elif trigger_rule == TR.NONE_SKIPPED: - if not upstream_done or (skipped > 0): - yield self._failing_status( - reason=( - f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to not have been " - f"skipped, but found {skipped} task(s) skipped. " - f"upstream_states={upstream_states}, " - f"upstream_task_ids={task.upstream_task_ids}" + elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS: + num_failures = upstream - success - skipped + if ti.map_index > -1: + num_failures -= removed + if num_failures > 0: + yield self._failing_status( + reason=( + f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have " + f"succeeded or been skipped, but found {num_failures} non-success(es). " + f"upstream_states={upstream_states}, " + f"upstream_task_ids={task.upstream_task_ids}" + ) ) - ) - elif trigger_rule == TR.ALL_SKIPPED: - num_non_skipped = upstream - skipped - if num_non_skipped > 0: - yield self._failing_status( - reason=( - f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have been " - f"skipped, but found {num_non_skipped} task(s) in non skipped state. 
" - f"upstream_states={upstream_states}, " - f"upstream_task_ids={task.upstream_task_ids}" + elif trigger_rule == TR.NONE_SKIPPED: + if not upstream_done or (skipped > 0): + yield self._failing_status( + reason=( + f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to not " + f"have been skipped, but found {skipped} task(s) skipped. " + f"upstream_states={upstream_states}, " + f"upstream_task_ids={task.upstream_task_ids}" + ) ) - ) - elif trigger_rule == TR.ALL_DONE_SETUP_SUCCESS: - if not upstream_done: - yield self._failing_status( - reason=( - f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have " - f"completed, but found {len(upstream_tasks) - done} task(s) that were not done. " - f"upstream_states={upstream_states}, " - f"upstream_task_ids={task.upstream_task_ids}" + elif trigger_rule == TR.ALL_SKIPPED: + num_non_skipped = upstream - skipped + if num_non_skipped > 0: + yield self._failing_status( + reason=( + f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have been " + f"skipped, but found {num_non_skipped} task(s) in non skipped state. " + f"upstream_states={upstream_states}, " + f"upstream_task_ids={task.upstream_task_ids}" + ) ) - ) - elif upstream_setup is None: # for now, None only happens in mapped case - yield self._failing_status( - reason=( - f"Task's trigger rule '{trigger_rule}' cannot have mapped tasks as upstream. " - f"upstream_states={upstream_states}, " - f"upstream_task_ids={task.upstream_task_ids}" + elif trigger_rule == TR.ALL_DONE_SETUP_SUCCESS: + if not upstream_done: + yield self._failing_status( + reason=( + f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have " + f"completed, but found {len(upstream_tasks) - done} task(s) that were not done. " + f"upstream_states={upstream_states}, " + f"upstream_task_ids={task.upstream_task_ids}" + ) ) - ) - elif upstream_setup and not success_setup >= 1: - yield self._failing_status( - reason=( - f"Task's trigger rule '{trigger_rule}' requires at least one upstream setup task be " - f"successful, but found {upstream_setup - success_setup} task(s) that were not. " - f"upstream_states={upstream_states}, " - f"upstream_task_ids={task.upstream_task_ids}" + elif upstream_setup is None: # for now, None only happens in mapped case + yield self._failing_status( + reason=( + f"Task's trigger rule '{trigger_rule}' cannot have mapped tasks as upstream. " + f"upstream_states={upstream_states}, " + f"upstream_task_ids={task.upstream_task_ids}" + ) + ) + elif upstream_setup and not success_setup >= 1: + yield self._failing_status( + reason=( + f"Task's trigger rule '{trigger_rule}' requires at least one upstream setup task " + f"be successful, but found {upstream_setup - success_setup} task(s) that were " + f"not. 
upstream_states={upstream_states}, " + f"upstream_task_ids={task.upstream_task_ids}" + ) ) - ) - else: - yield self._failing_status(reason=f"No strategy to evaluate trigger rule '{trigger_rule}'.") + else: + yield self._failing_status(reason=f"No strategy to evaluate trigger rule '{trigger_rule}'.") + + if not ti.task.is_teardown: + # a teardown cannot have any indirect setups + relevant_setups = {t.task_id: t for t in ti.task.get_upstreams_only_setups()} + if relevant_setups: + for status, changed in _evaluate_setup_constraint(relevant_setups=relevant_setups): + yield status + if not status.passed and changed: + # no need to evaluate trigger rule; we've already marked as skipped or failed + return + + yield from _evaluate_direct_relatives() diff --git a/docs/apache-airflow/howto/setup-and-teardown.rst b/docs/apache-airflow/howto/setup-and-teardown.rst index 355442a299fe..7afb3c4a350b 100644 --- a/docs/apache-airflow/howto/setup-and-teardown.rst +++ b/docs/apache-airflow/howto/setup-and-teardown.rst @@ -125,6 +125,14 @@ In that example, we (in our pretend docs land) actually wanted to delete the clu create_cluster >> run_query >> other_task run_query >> EmptyOperator(task_id="cluster_teardown").as_teardown(setups=create_cluster) +Implicit ALL_SUCCESS constraint +""""""""""""""""""""""""""""""" + +Any task in the scope of a setup has an implicit "all_success" constraint on its setups. +This is necessary to ensure that if a task with indirect setups is cleared, it will +wait for them to complete. If a setup fails or is skipped, the work tasks that depend +on it will be marked as failed or skipped. We also require that any non-teardown directly +downstream of a setup must have trigger rule ALL_SUCCESS. Controlling dag run state """"""""""""""""""""""""" diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 3b5f45a28e4c..cf84577b0b23 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -74,6 +74,7 @@ from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.task_group import TaskGroup, TaskGroupContext from airflow.utils.timezone import datetime as datetime_tz +from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import DagRunType from airflow.utils.weight_rule import WeightRule from tests.models import DEFAULT_DATE @@ -4022,3 +4023,16 @@ def test_clearing_behavior_just_teardown(self): assert self.cleared_upstream(s1) == {s1, t1} assert self.cleared_downstream(s1) == {s1, t1} assert self.cleared_neither(s1) == {s1, t1} + + def test_validate_setup_teardown_trigger_rule(self): + with DAG( + dag_id="direct_setup_trigger_rule", start_date=pendulum.now(), schedule=None, catchup=False + ) as dag: + s1, w1 = self.make_tasks(dag, "s1, w1") + s1 >> w1 + dag.validate_setup_teardown() + w1.trigger_rule = TriggerRule.ONE_FAILED + with pytest.raises( + Exception, match="Setup tasks must be followed with trigger rule ALL_SUCCESS."
+ ): + dag.validate_setup_teardown() diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index e50917e10107..a2e72d614e5f 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -1236,18 +1236,36 @@ def test_depends_on_past(self, dag_maker): 2, _UpstreamTIStates(6, 0, 1, 0, 0, 7, 1, 0), True, - None, + (True, None), # is_teardown=True, expect_state=None True, - id="one setup failed one setup success --> should run", + id="is teardown one setup failed one setup success", + ), + param( + "all_done_setup_success", + 2, + _UpstreamTIStates(6, 0, 1, 0, 0, 7, 1, 0), + True, + (False, "upstream_failed"), # is_teardown=False, expect_state="upstream_failed" + True, + id="not teardown one setup failed one setup success", ), param( "all_done_setup_success", 2, _UpstreamTIStates(6, 1, 0, 0, 0, 7, 1, 1), True, - None, + (True, None), # is_teardown=True, expect_state=None True, - id="one setup success one setup skipped --> should run", + id="is teardown one setup success one setup skipped", + ), + param( + "all_done_setup_success", + 2, + _UpstreamTIStates(6, 1, 0, 0, 0, 7, 1, 1), + True, + (False, "skipped"), # is_teardown=False, expect_state="skipped" + True, + id="not teardown one setup success one setup skipped", ), param( "all_done_setup_success", @@ -1263,18 +1281,36 @@ def test_depends_on_past(self, dag_maker): 1, _UpstreamTIStates(3, 0, 1, 0, 0, 4, 1, 0), True, - None, + (True, None), # is_teardown=True, expect_state=None False, - id="not all done, one failed", + id="is teardown not all done one failed", + ), + param( + "all_done_setup_success", + 1, + _UpstreamTIStates(3, 0, 1, 0, 0, 4, 1, 0), + True, + (False, "upstream_failed"), # is_teardown=False, expect_state="upstream_failed" + False, + id="not teardown not all done one failed", ), param( "all_done_setup_success", 1, _UpstreamTIStates(3, 1, 0, 0, 0, 4, 1, 0), True, - None, + (True, None), # is_teardown=True, expect_state=None False, - id="not all done, one skipped", + id="is teardown not all done one skipped", + ), + param( + "all_done_setup_success", + 1, + _UpstreamTIStates(3, 1, 0, 0, 0, 4, 1, 0), + True, + (False, "skipped"), # is_teardown=False, expect_state="skipped" + False, + id="not teardown not all done one skipped", + ), ], ) @@ -1289,6 +1325,13 @@ def test_check_task_dependencies( expect_state: State, expect_passed: bool, ): + # this allows us to change the expected state depending on whether the + # task is a teardown + set_teardown = False + if isinstance(expect_state, tuple): + set_teardown, expect_state = expect_state + assert isinstance(set_teardown, bool) + monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: upstream_states) # sanity checks @@ -1299,6 +1342,8 @@ def test_check_task_dependencies( with dag_maker() as dag: downstream = EmptyOperator(task_id="downstream", trigger_rule=trigger_rule) + if set_teardown: + downstream.as_teardown() for i in range(5): task = EmptyOperator(task_id=f"work_{i}", dag=dag) task.set_downstream(downstream) diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py index 4293303bd0ad..b1f0cf64fb67 100644 --- a/tests/ti_deps/deps/test_trigger_rule_dep.py +++ b/tests/ti_deps/deps/test_trigger_rule_dep.py @@ -1217,3 +1217,167 @@ def tg(a): results = list(result_iterator) assert len(results) == 1 assert results[0].passed is False + + +class TestTriggerRuleDepSetupConstraint: + @staticmethod + def get_ti(dr, task_id): + return next(ti for ti in dr.task_instances if ti.task_id == task_id) + + def
get_dep_statuses(self, dr, task_id, flag_upstream_failed=False, session=None): + return list( + TriggerRuleDep()._get_dep_statuses( + ti=self.get_ti(dr, task_id), + dep_context=DepContext(flag_upstream_failed=flag_upstream_failed), + session=session, + ) + ) + + def test_setup_constraint_blocks_execution(self, dag_maker, session): + with dag_maker(session=session): + + @task + def t1(): + return 1 + + @task + def t2(): + return 2 + + @task + def t3(): + return 3 + + t1_task = t1() + t2_task = t2() + t3_task = t3() + t1_task >> t2_task >> t3_task + t1_task.as_setup() + dr = dag_maker.create_dagrun() + + # setup constraint is not applied to t2 because it has a direct setup + # so even though the setup is not done, the check passes + # but trigger rule fails because the normal trigger rule dep behavior + statuses = self.get_dep_statuses(dr, "t2", session=session) + assert len(statuses) == 1 + assert statuses[0].passed is False + assert statuses[0].reason.startswith("Task's trigger rule 'all_success' requires all upstream tasks") + + # t3 has an indirect setup so the setup check fails + # trigger rule also fails + statuses = self.get_dep_statuses(dr, "t3", session=session) + assert len(statuses) == 2 + assert statuses[0].passed is False + assert statuses[0].reason.startswith("All setup tasks must complete successfully") + assert statuses[1].passed is False + assert statuses[1].reason.startswith("Task's trigger rule 'all_success' requires all upstream tasks") + + @pytest.mark.parametrize( + "setup_state, expected", [(None, None), ("failed", "upstream_failed"), ("skipped", "skipped")] + ) + def test_setup_constraint_changes_state_appropriately(self, dag_maker, session, setup_state, expected): + with dag_maker(session=session): + + @task + def t1(): + return 1 + + @task + def t2(): + return 2 + + @task + def t3(): + return 3 + + t1_task = t1() + t2_task = t2() + t3_task = t3() + t1_task >> t2_task >> t3_task + t1_task.as_setup() + dr = dag_maker.create_dagrun() + + # if the setup fails then now, in processing the trigger rule dep, the ti states + # will be updated + if setup_state: + self.get_ti(dr, "t1").state = setup_state + session.commit() + (status,) = self.get_dep_statuses(dr, "t2", flag_upstream_failed=True, session=session) + assert status.passed is False + # t2 fails on the non-setup-related trigger rule constraint since it has + # a direct setup + assert status.reason.startswith("Task's trigger rule 'all_success' requires") + assert self.get_ti(dr, "t2").state == expected + assert self.get_ti(dr, "t3").state is None # hasn't been evaluated yet + + # unlike t2, t3 fails on the setup constraint, and the normal trigger rule + # constraint is not actually evaluated, since it ain't gonna run anyway + if setup_state is None: + # when state is None, setup constraint doesn't mutate ti state, so we get + # two failure reasons -- setup constraint and trigger rule + (status, _) = self.get_dep_statuses(dr, "t3", flag_upstream_failed=True, session=session) + else: + (status,) = self.get_dep_statuses(dr, "t3", flag_upstream_failed=True, session=session) + assert status.reason.startswith("All setup tasks must complete successfully") + assert self.get_ti(dr, "t3").state == expected + + @pytest.mark.parametrize( + "setup_state, expected", [(None, None), ("failed", "upstream_failed"), ("skipped", "skipped")] + ) + def test_setup_constraint_will_fail_or_skip_fast(self, dag_maker, session, setup_state, expected): + """ + When a setup fails or skips, the tasks that depend on it will immediately fail or skip 
+ and not, for example, wait for all setups to complete before determining what is + the appropriate state. This is a bit of a race condition, but it's consistent + with the behavior for many-to-one direct upstream task relationships, and it's + required if you want to fail fast. + + So in this test we verify that if even one setup is failed or skipped, the + state will propagate to the in-scope work tasks. + """ + with dag_maker(session=session): + + @task + def s1(): + return 1 + + @task + def s2(): + return 1 + + @task + def w1(): + return 2 + + @task + def w2(): + return 3 + + s1 = s1().as_setup() + s2 = s2().as_setup() + [s1, s2] >> w1() >> w2() + dr = dag_maker.create_dagrun() + + # if the setup fails then now, in processing the trigger rule dep, the ti states + # will be updated + if setup_state: + self.get_ti(dr, "s2").state = setup_state + session.commit() + (status,) = self.get_dep_statuses(dr, "w1", flag_upstream_failed=True, session=session) + assert status.passed is False + # w1 fails on the non-setup-related trigger rule constraint since it has + # a direct setup + assert status.reason.startswith("Task's trigger rule 'all_success' requires") + assert self.get_ti(dr, "w1").state == expected + assert self.get_ti(dr, "w2").state is None # hasn't been evaluated yet + + # unlike w1, w2 fails on the setup constraint, and the normal trigger rule + # constraint is not actually evaluated, since it ain't gonna run anyway + if setup_state is None: + # when state is None, setup constraint doesn't mutate ti state, so we get + # two failure reasons -- setup constraint and trigger rule + (status, _) = self.get_dep_statuses(dr, "w2", flag_upstream_failed=True, session=session) + else: + (status,) = self.get_dep_statuses(dr, "w2", flag_upstream_failed=True, session=session) + assert status.reason.startswith("All setup tasks must complete successfully") + assert self.get_ti(dr, "w2").state == expected
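For illustration, here is a minimal sketch of the behavior described by the setup-and-teardown docs addition above, assuming a build of Airflow that includes this patch. The dag id and task ids (``create_cluster``, ``run_query``, ``summarize``) are hypothetical, and the validation calls mirror ``test_validate_setup_teardown_trigger_rule``::

    import pendulum
    from airflow import DAG
    from airflow.operators.empty import EmptyOperator
    from airflow.utils.trigger_rule import TriggerRule

    with DAG(
        dag_id="setup_constraint_example",
        start_date=pendulum.datetime(2023, 1, 1),
        schedule=None,
        catchup=False,
    ) as dag:
        create_cluster = EmptyOperator(task_id="create_cluster").as_setup()
        run_query = EmptyOperator(task_id="run_query")  # directly downstream of the setup
        summarize = EmptyOperator(task_id="summarize")  # only indirectly downstream of the setup
        create_cluster >> run_query >> summarize

    # Passes: run_query keeps the default ALL_SUCCESS trigger rule, as required for a
    # non-teardown that is directly downstream of a setup.
    dag.validate_setup_teardown()

    # Rejected: giving such a task any other trigger rule now fails validation.
    run_query.trigger_rule = TriggerRule.ONE_FAILED
    try:
        dag.validate_setup_teardown()
    except ValueError as err:
        print(err)  # Setup tasks must be followed with trigger rule ALL_SUCCESS.

Assuming ``summarize`` sits in the scope of ``create_cluster`` (there is no teardown bounding it here), the new setup constraint in ``TriggerRuleDep`` also makes ``summarize`` wait for the setup despite the lack of a direct edge, and marks it ``upstream_failed`` or ``skipped`` if the setup fails or is skipped.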