Skip to content

Commit

Permalink
Don't ignore setups when arrowing from group (#33097)
Browse files Browse the repository at this point in the history
This enables us to have a group with just setups in it.
  • Loading branch information
dstandish authored Aug 8, 2023
1 parent 569e32b commit cd7e7bc
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 6 deletions.
14 changes: 9 additions & 5 deletions airflow/utils/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,21 +370,25 @@ def get_leaves(self) -> Generator[BaseOperator, None, None]:
tasks = list(self)
ids = {x.task_id for x in tasks}

def recurse_for_first_non_setup_teardown(task):
def recurse_for_first_non_teardown(task):
for upstream_task in task.upstream_list:
if upstream_task.task_id not in ids:
# upstream task is not in task group
continue
elif upstream_task.is_teardown:
yield from recurse_for_first_non_teardown(upstream_task)
elif task.is_teardown and upstream_task.is_setup:
# don't go through the teardown-to-setup path
continue
if upstream_task.is_setup or upstream_task.is_teardown:
yield from recurse_for_first_non_setup_teardown(upstream_task)
else:
yield upstream_task

for task in tasks:
if task.downstream_task_ids.isdisjoint(ids):
if not (task.is_teardown or task.is_setup):
if not task.is_teardown:
yield task
else:
yield from recurse_for_first_non_setup_teardown(task)
yield from recurse_for_first_non_teardown(task)

def child_id(self, label):
"""Prefix label with group_id if prefix_group_id is True. Otherwise return the label as-is."""
Expand Down
120 changes: 119 additions & 1 deletion tests/utils/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
import pendulum
import pytest

from airflow.decorators import dag, task as task_decorator, task_group as task_group_decorator
from airflow.decorators import (
dag,
setup,
task as task_decorator,
task_group as task_group_decorator,
teardown,
)
from airflow.exceptions import TaskAlreadyInTaskGroup
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
Expand Down Expand Up @@ -1479,3 +1485,115 @@ def test_task_group_arrow_with_setups_teardowns():
tg1 >> w2
assert t1.downstream_task_ids == set()
assert w1.downstream_task_ids == {"tg1.t1", "w2"}


def test_task_group_arrow_with_setup_group():
    """Arrowing from a group holding only setups must connect those setups downstream.

    ``group_1`` contains two setups, ``group_2`` contains the two teardowns
    paired with them, and a single ``work`` task sits between the groups.
    The test checks the resulting downstream ids of every task and that the
    serialized ``group_1`` lists both setups plus its downstream join node.
    """
    with DAG(dag_id="setup_group_teardown_group", start_date=pendulum.now()):
        with TaskGroup("group_1") as setup_group:

            @setup
            def setup_1():
                ...

            @setup
            def setup_2():
                ...

            first_setup = setup_1()
            second_setup = setup_2()

        with TaskGroup("group_2") as teardown_group:

            @teardown
            def teardown_1():
                ...

            @teardown
            def teardown_2():
                ...

            first_teardown = teardown_1()
            second_teardown = teardown_2()

        @task_decorator
        def work():
            ...

        worker = work()

        # Same wiring as ``group_1 >> work >> group_2``, one arrow at a time.
        setup_group >> worker
        worker >> teardown_group

        # Pair each teardown with its setup only after the group arrows exist.
        first_teardown.as_teardown(setups=first_setup)
        second_teardown.as_teardown(setups=second_setup)

    expected_downstream = [
        (first_setup, {"work", "group_2.teardown_1"}),
        (second_setup, {"work", "group_2.teardown_2"}),
        (worker, {"group_2.teardown_1", "group_2.teardown_2"}),
        (first_teardown, set()),
        (second_teardown, set()),
    ]
    for xcom_arg, downstream_ids in expected_downstream:
        assert set(xcom_arg.operator.downstream_task_ids) == downstream_ids

    def get_nodes(group):
        """Reduce the serialized group to its id and the ids of its direct children."""
        serialized = task_group_to_dict(group)
        return {
            "id": serialized["id"],
            "children": [{"id": child["id"]} for child in serialized["children"]],
        }

    assert get_nodes(setup_group) == {
        "id": "group_1",
        "children": [
            {"id": "group_1.setup_1"},
            {"id": "group_1.setup_2"},
            {"id": "group_1.downstream_join_id"},
        ],
    }


def test_task_group_arrow_with_setup_group_deeper_setup():
    """
    A setup sitting directly upstream of a teardown must not be treated as a
    leaf when walking upstream from that teardown.

    In ``group_1``, ``setup_2`` feeds ``teardown_0``; only ``setup_1`` is
    expected to be arrowed to ``work`` when the group is set upstream of it.
    """
    with DAG(dag_id="setup_group_teardown_group_2", start_date=pendulum.now()):
        with TaskGroup("group_1") as setup_group:

            @setup
            def setup_1():
                ...

            @setup
            def setup_2():
                ...

            @teardown
            def teardown_0():
                ...

            first_setup = setup_1()
            second_setup = setup_2()
            inner_teardown = teardown_0()
            second_setup >> inner_teardown

        with TaskGroup("group_2") as teardown_group:

            @teardown
            def teardown_1():
                ...

            @teardown
            def teardown_2():
                ...

            first_teardown = teardown_1()
            second_teardown = teardown_2()

        @task_decorator
        def work():
            ...

        worker = work()

        # Same wiring as ``group_1 >> work >> group_2``, one arrow at a time.
        setup_group >> worker
        worker >> teardown_group

        # Pair each outer teardown with its setup only after the group arrows exist.
        first_teardown.as_teardown(setups=first_setup)
        second_teardown.as_teardown(setups=second_setup)

    expected_downstream = [
        (first_setup, {"work", "group_2.teardown_1"}),
        # setup_2 is not connected to ``work``: it only reaches teardowns.
        (second_setup, {"group_1.teardown_0", "group_2.teardown_2"}),
        (worker, {"group_2.teardown_1", "group_2.teardown_2"}),
        (first_teardown, set()),
        (second_teardown, set()),
    ]
    for xcom_arg, downstream_ids in expected_downstream:
        assert set(xcom_arg.operator.downstream_task_ids) == downstream_ids

0 comments on commit cd7e7bc

Please sign in to comment.