Skip to content

Commit

Permalink
Fix bug where dag.sub_dag does not copy task_group tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
yuqian90 committed Sep 18, 2020
1 parent 4c9464a commit 85c9f67
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 9 deletions.
4 changes: 4 additions & 0 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1279,6 +1279,10 @@ def remove_excluded(group):
if isinstance(child, BaseOperator):
if child.task_id not in dag.task_dict:
group.children.pop(child.task_id)
else:
# The tasks in the subdag are copies of the tasks in the original dag,
# so update the references in the TaskGroups too.
group.children[child.task_id] = dag.task_dict[child.task_id]
else:
remove_excluded(child)

Expand Down
50 changes: 41 additions & 9 deletions tests/utils/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,34 +354,66 @@ def test_sub_dag_task_group():
_ = DummyOperator(task_id="task3")
_ = DummyOperator(task_id="task4")

with TaskGroup("group67") as group6:
with TaskGroup("group6") as group6:
_ = DummyOperator(task_id="task6")

task7 = DummyOperator(task_id="task7")

task5 = DummyOperator(task_id="task5")

task1 >> group234
group34 >> task5
group234 >> group6
group234 >> task7

subdag = dag.sub_dag(task_regex="task2", include_upstream=True, include_downstream=False)
subdag = dag.sub_dag(task_regex="task5", include_upstream=True, include_downstream=False)

assert extract_node_id(task_group_to_dict(subdag.task_group)) == {
'id': None,
'children': [
{'id': 'group234', 'children': [{'id': 'group234.task2'}, {'id': 'group234.upstream_join_id'}]},
{
'id': 'group234',
'children': [
{
'id': 'group234.group34',
'children': [
{'id': 'group234.group34.task3'},
{'id': 'group234.group34.task4'},
{'id': 'group234.group34.downstream_join_id'},
],
},
{'id': 'group234.upstream_join_id'},
],
},
{'id': 'task1'},
{'id': 'task5'},
],
}

edges = dag_edges(subdag)
assert sorted((e["source_id"], e["target_id"]) for e in edges) == [
('group234.group34.downstream_join_id', 'task5'),
('group234.group34.task3', 'group234.group34.downstream_join_id'),
('group234.group34.task4', 'group234.group34.downstream_join_id'),
('group234.upstream_join_id', 'group234.group34.task3'),
('group234.upstream_join_id', 'group234.group34.task4'),
('task1', 'group234.upstream_join_id'),
]

subdag_task_groups = subdag.task_group.get_task_group_dict()
assert subdag_task_groups.keys() == {None, "group234"}
assert subdag_task_groups.keys() == {None, "group234", "group234.group34"}

included_group_ids = {"group234", "group234.group34"}
included_task_ids = {'group234.group34.task3', 'group234.group34.task4', 'task1', 'task5'}

for task_group in subdag_task_groups.values():
assert task_group.upstream_group_ids.issubset({"group234"})
assert task_group.downstream_group_ids.issubset({"group234"})
assert task_group.upstream_task_ids.issubset({"task1", "task2"})
assert task_group.downstream_task_ids.issubset({"task1", "task2"})
assert task_group.upstream_group_ids.issubset(included_group_ids)
assert task_group.downstream_group_ids.issubset(included_group_ids)
assert task_group.upstream_task_ids.issubset(included_task_ids)
assert task_group.downstream_task_ids.issubset(included_task_ids)

for task in subdag.task_group:
assert task.upstream_task_ids.issubset(included_task_ids)
assert task.downstream_task_ids.issubset(included_task_ids)


def test_dag_edges():
Expand Down

0 comments on commit 85c9f67

Please sign in to comment.