Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure that tasks wait for running indirect setup #33903

Merged
merged 15 commits into from
Sep 3, 2023
11 changes: 11 additions & 0 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,17 @@ def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Operator]:
if t.is_teardown and not t == self:
yield t

def get_upstreams_only_setups(self) -> Iterable[Operator]:
    """
    Yield the relevant upstream setup tasks for this task.

    Meant for task-dependency checks: a task must wait for all of its
    relevant upstream setups to complete before it can run, so this filters
    the combined setup/teardown iterator down to setups only.
    """
    yield from (t for t in self.get_upstreams_only_setups_and_teardowns() if t.is_setup)

def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
"""Return mapped nodes that are direct dependencies of the current task.

Expand Down
8 changes: 8 additions & 0 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@
with_row_locks,
)
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType

if TYPE_CHECKING:
Expand Down Expand Up @@ -718,6 +719,13 @@ def validate_setup_teardown(self):
:meta private:
"""
for task in self.tasks:
    if task.is_setup:
        # Every non-teardown task directly downstream of a setup must keep the
        # default ALL_SUCCESS trigger rule; otherwise clearing an indirect setup
        # would not reliably re-block the dependent work tasks.
        for down_task in task.downstream_list:
            if not down_task.is_teardown and down_task.trigger_rule != TriggerRule.ALL_SUCCESS:
                # todo: we can relax this to allow out-of-scope tasks to have other trigger rules
                # this is required to ensure consistent behavior of dag
                # when clearing an indirect setup
                raise ValueError("Setup tasks must be followed with trigger rule ALL_SUCCESS.")
    FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule)

def __repr__(self):
Expand Down
572 changes: 344 additions & 228 deletions airflow/ti_deps/deps/trigger_rule_dep.py

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions docs/apache-airflow/howto/setup-and-teardown.rst
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,14 @@ In that example, we (in our pretend docs land) actually wanted to delete the clu
create_cluster >> run_query >> other_task
run_query >> EmptyOperator(task_id="cluster_teardown").as_teardown(setups=create_cluster)

Implicit ALL_SUCCESS constraint
"""""""""""""""""""""""""""""""

Any task in the scope of a setup has an implicit "all_success" constraint on its setups.
This is necessary to ensure that if a task with indirect setups is cleared, it will
wait for them to complete. If a setup fails or is skipped, the work tasks that depend
on it will be marked as failures or skips. We also require that any non-teardown directly
downstream of a setup must have trigger rule ALL_SUCCESS.

Controlling dag run state
"""""""""""""""""""""""""
Expand Down
14 changes: 14 additions & 0 deletions tests/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.task_group import TaskGroup, TaskGroupContext
from airflow.utils.timezone import datetime as datetime_tz
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import DagRunType
from airflow.utils.weight_rule import WeightRule
from tests.models import DEFAULT_DATE
Expand Down Expand Up @@ -4022,3 +4023,16 @@ def test_clearing_behavior_just_teardown(self):
assert self.cleared_upstream(s1) == {s1, t1}
assert self.cleared_downstream(s1) == {s1, t1}
assert self.cleared_neither(s1) == {s1, t1}

def test_validate_setup_teardown_trigger_rule(self):
    """A non-teardown task directly downstream of a setup must use trigger rule ALL_SUCCESS."""
    with DAG(
        dag_id="direct_setup_trigger_rule", start_date=pendulum.now(), schedule=None, catchup=False
    ) as dag:
        s1, w1 = self.make_tasks(dag, "s1, w1")
        s1 >> w1
        # The default trigger rule is ALL_SUCCESS, so validation passes here.
        dag.validate_setup_teardown()
        # Switching the work task to any other trigger rule must be rejected.
        w1.trigger_rule = TriggerRule.ONE_FAILED
        expected_msg = "Setup tasks must be followed with trigger rule ALL_SUCCESS."
        with pytest.raises(Exception, match=expected_msg):
            dag.validate_setup_teardown()
61 changes: 53 additions & 8 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,18 +1236,36 @@ def test_depends_on_past(self, dag_maker):
2,
_UpstreamTIStates(6, 0, 1, 0, 0, 7, 1, 0),
True,
None,
(True, None), # is_teardown=True, expect_state=None
True,
id="one setup failed one setup success --> should run",
id="is teardown one setup failed one setup success",
),
param(
"all_done_setup_success",
2,
_UpstreamTIStates(6, 0, 1, 0, 0, 7, 1, 0),
True,
(False, "upstream_failed"), # is_teardown=False, expect_state="upstream_failed"
True,
id="not teardown one setup failed one setup success",
),
param(
"all_done_setup_success",
2,
_UpstreamTIStates(6, 1, 0, 0, 0, 7, 1, 1),
True,
None,
(True, None), # is_teardown=True, expect_state=None
True,
id="one setup success one setup skipped --> should run",
id="is teardown one setup success one setup skipped",
),
param(
"all_done_setup_success",
2,
_UpstreamTIStates(6, 1, 0, 0, 0, 7, 1, 1),
True,
(False, "skipped"), # is_teardown=False, expect_state="skipped"
True,
id="not teardown one setup success one setup skipped",
),
param(
"all_done_setup_success",
Expand All @@ -1263,18 +1281,36 @@ def test_depends_on_past(self, dag_maker):
1,
_UpstreamTIStates(3, 0, 1, 0, 0, 4, 1, 0),
True,
None,
(True, None), # is_teardown=True, expect_state=None
False,
id="not all done, one failed",
id="is teardown not all done one failed",
),
param(
"all_done_setup_success",
1,
_UpstreamTIStates(3, 0, 1, 0, 0, 4, 1, 0),
True,
(False, "upstream_failed"), # is_teardown=False, expect_state="upstream_failed"
False,
id="not teardown not all done one failed",
),
param(
"all_done_setup_success",
1,
_UpstreamTIStates(3, 1, 0, 0, 0, 4, 1, 0),
True,
None,
(True, None), # is_teardown=True, expect_state=None
False,
id="not all done, one skipped",
id="not all done one skipped",
),
param(
"all_done_setup_success",
1,
_UpstreamTIStates(3, 1, 0, 0, 0, 4, 1, 0),
True,
(False, "skipped"), # is_teardown=False, expect_state="skipped'
False,
id="not all done one skipped",
),
],
)
Expand All @@ -1289,6 +1325,13 @@ def test_check_task_dependencies(
expect_state: State,
expect_passed: bool,
):
# this allows us to change the expected state depending on whether the
# task is a teardown
set_teardown = False
if isinstance(expect_state, tuple):
set_teardown, expect_state = expect_state
assert isinstance(set_teardown, bool)

monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: upstream_states)

# sanity checks
Expand All @@ -1299,6 +1342,8 @@ def test_check_task_dependencies(

with dag_maker() as dag:
downstream = EmptyOperator(task_id="downstream", trigger_rule=trigger_rule)
if set_teardown:
downstream.as_teardown()
for i in range(5):
task = EmptyOperator(task_id=f"work_{i}", dag=dag)
task.set_downstream(downstream)
Expand Down
164 changes: 164 additions & 0 deletions tests/ti_deps/deps/test_trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,3 +1217,167 @@ def tg(a):
results = list(result_iterator)
assert len(results) == 1
assert results[0].passed is False


class TestTriggerRuleDepSetupConstraint:
    """Tests for the setup-constraint behavior of TriggerRuleDep.

    Tasks with *indirect* setups (a setup further upstream, not a direct parent)
    get an implicit "all setups must succeed" constraint; tasks with a *direct*
    setup are covered by the normal trigger rule evaluation instead.
    """

    @staticmethod
    def get_ti(dr, task_id):
        # Fetch the task instance with the given task_id from the dag run.
        return next(ti for ti in dr.task_instances if ti.task_id == task_id)

    def get_dep_statuses(self, dr, task_id, flag_upstream_failed=False, session=None):
        # Evaluate TriggerRuleDep for one task instance and materialize the
        # resulting dep statuses as a list.
        return list(
            TriggerRuleDep()._get_dep_statuses(
                ti=self.get_ti(dr, task_id),
                dep_context=DepContext(flag_upstream_failed=flag_upstream_failed),
                session=session,
            )
        )

    def test_setup_constraint_blocks_execution(self, dag_maker, session):
        """A task with an indirect setup is blocked by the setup constraint."""
        with dag_maker(session=session):

            @task
            def t1():
                return 1

            @task
            def t2():
                return 2

            @task
            def t3():
                return 3

            t1_task = t1()
            t2_task = t2()
            t3_task = t3()
            t1_task >> t2_task >> t3_task
            t1_task.as_setup()
        dr = dag_maker.create_dagrun()

        # setup constraint is not applied to t2 because it has a direct setup
        # so even though the setup is not done, the check passes
        # but trigger rule fails because the normal trigger rule dep behavior
        statuses = self.get_dep_statuses(dr, "t2", session=session)
        assert len(statuses) == 1
        assert statuses[0].passed is False
        assert statuses[0].reason.startswith("Task's trigger rule 'all_success' requires all upstream tasks")

        # t3 has an indirect setup so the setup check fails
        # trigger rule also fails
        statuses = self.get_dep_statuses(dr, "t3", session=session)
        assert len(statuses) == 2
        assert statuses[0].passed is False
        assert statuses[0].reason.startswith("All setup tasks must complete successfully")
        assert statuses[1].passed is False
        assert statuses[1].reason.startswith("Task's trigger rule 'all_success' requires all upstream tasks")

    @pytest.mark.parametrize(
        "setup_state, expected", [(None, None), ("failed", "upstream_failed"), ("skipped", "skipped")]
    )
    def test_setup_constraint_changes_state_appropriately(self, dag_maker, session, setup_state, expected):
        """A failed/skipped indirect setup propagates upstream_failed/skipped to the work task."""
        with dag_maker(session=session):

            @task
            def t1():
                return 1

            @task
            def t2():
                return 2

            @task
            def t3():
                return 3

            t1_task = t1()
            t2_task = t2()
            t3_task = t3()
            t1_task >> t2_task >> t3_task
            t1_task.as_setup()
        dr = dag_maker.create_dagrun()

        # if the setup fails then now, in processing the trigger rule dep, the ti states
        # will be updated
        if setup_state:
            self.get_ti(dr, "t1").state = setup_state
            session.commit()
        (status,) = self.get_dep_statuses(dr, "t2", flag_upstream_failed=True, session=session)
        assert status.passed is False
        # t2 fails on the non-setup-related trigger rule constraint since it has
        # a direct setup
        assert status.reason.startswith("Task's trigger rule 'all_success' requires")
        assert self.get_ti(dr, "t2").state == expected
        assert self.get_ti(dr, "t3").state is None  # hasn't been evaluated yet

        # unlike t2, t3 fails on the setup constraint, and the normal trigger rule
        # constraint is not actually evaluated, since it ain't gonna run anyway
        if setup_state is None:
            # when state is None, setup constraint doesn't mutate ti state, so we get
            # two failure reasons -- setup constraint and trigger rule
            (status, _) = self.get_dep_statuses(dr, "t3", flag_upstream_failed=True, session=session)
        else:
            (status,) = self.get_dep_statuses(dr, "t3", flag_upstream_failed=True, session=session)
        assert status.reason.startswith("All setup tasks must complete successfully")
        assert self.get_ti(dr, "t3").state == expected

    @pytest.mark.parametrize(
        "setup_state, expected", [(None, None), ("failed", "upstream_failed"), ("skipped", "skipped")]
    )
    def test_setup_constraint_will_fail_or_skip_fast(self, dag_maker, session, setup_state, expected):
        """
        When a setup fails or skips, the tasks that depend on it will immediately fail or skip
        and not, for example, wait for all setups to complete before determining what is
        the appropriate state. This is a bit of a race condition, but it's consistent
        with the behavior for many-to-one direct upstream task relationships, and it's
        required if you want to fail fast.

        So in this test we verify that if even one setup is failed or skipped, the
        state will propagate to the in-scope work tasks.
        """
        with dag_maker(session=session):

            @task
            def s1():
                return 1

            @task
            def s2():
                return 1

            @task
            def w1():
                return 2

            @task
            def w2():
                return 3

            s1 = s1().as_setup()
            s2 = s2().as_setup()
            [s1, s2] >> w1() >> w2()
        dr = dag_maker.create_dagrun()

        # if the setup fails then now, in processing the trigger rule dep, the ti states
        # will be updated
        if setup_state:
            self.get_ti(dr, "s2").state = setup_state
            session.commit()
        (status,) = self.get_dep_statuses(dr, "w1", flag_upstream_failed=True, session=session)
        assert status.passed is False
        # w1 fails on the non-setup-related trigger rule constraint since it has
        # direct setups
        assert status.reason.startswith("Task's trigger rule 'all_success' requires")
        assert self.get_ti(dr, "w1").state == expected
        assert self.get_ti(dr, "w2").state is None  # hasn't been evaluated yet

        # unlike w1, w2 fails on the setup constraint, and the normal trigger rule
        # constraint is not actually evaluated, since it ain't gonna run anyway
        if setup_state is None:
            # when state is None, setup constraint doesn't mutate ti state, so we get
            # two failure reasons -- setup constraint and trigger rule
            (status, _) = self.get_dep_statuses(dr, "w2", flag_upstream_failed=True, session=session)
        else:
            (status,) = self.get_dep_statuses(dr, "w2", flag_upstream_failed=True, session=session)
        assert status.reason.startswith("All setup tasks must complete successfully")
        assert self.get_ti(dr, "w2").state == expected