From 85c9f67690619fcaad543182728d78b3e91699eb Mon Sep 17 00:00:00 2001 From: Qian Yu Date: Thu, 17 Sep 2020 11:45:26 +0800 Subject: [PATCH] Fix dag.sub_dag not copying task_group tasks bug --- airflow/models/dag.py | 4 +++ tests/utils/test_task_group.py | 50 ++++++++++++++++++++++++++++------ 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 4e528e8edc0cc..886837c429a76 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1279,6 +1279,10 @@ def remove_excluded(group): if isinstance(child, BaseOperator): if child.task_id not in dag.task_dict: group.children.pop(child.task_id) + else: + # The tasks in the subdag are a copy of tasks in the original dag + # so update the reference in the TaskGroups too. + group.children[child.task_id] = dag.task_dict[child.task_id] else: remove_excluded(child) diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py index 1aa51b613a0e7..c4f7a125dec4c 100644 --- a/tests/utils/test_task_group.py +++ b/tests/utils/test_task_group.py @@ -354,34 +354,66 @@ def test_sub_dag_task_group(): _ = DummyOperator(task_id="task3") _ = DummyOperator(task_id="task4") - with TaskGroup("group67") as group6: + with TaskGroup("group6") as group6: _ = DummyOperator(task_id="task6") task7 = DummyOperator(task_id="task7") - task5 = DummyOperator(task_id="task5") + task1 >> group234 group34 >> task5 group234 >> group6 group234 >> task7 - subdag = dag.sub_dag(task_regex="task2", include_upstream=True, include_downstream=False) + subdag = dag.sub_dag(task_regex="task5", include_upstream=True, include_downstream=False) assert extract_node_id(task_group_to_dict(subdag.task_group)) == { 'id': None, 'children': [ - {'id': 'group234', 'children': [{'id': 'group234.task2'}, {'id': 'group234.upstream_join_id'}]}, + { + 'id': 'group234', + 'children': [ + { + 'id': 'group234.group34', + 'children': [ + {'id': 'group234.group34.task3'}, + {'id': 'group234.group34.task4'}, + {'id': 'group234.group34.downstream_join_id'}, + ], + }, + {'id': 'group234.upstream_join_id'}, + ], + }, {'id': 'task1'}, + {'id': 'task5'}, ], } + edges = dag_edges(subdag) + assert sorted((e["source_id"], e["target_id"]) for e in edges) == [ + ('group234.group34.downstream_join_id', 'task5'), + ('group234.group34.task3', 'group234.group34.downstream_join_id'), + ('group234.group34.task4', 'group234.group34.downstream_join_id'), + ('group234.upstream_join_id', 'group234.group34.task3'), + ('group234.upstream_join_id', 'group234.group34.task4'), + ('task1', 'group234.upstream_join_id'), + ] + subdag_task_groups = subdag.task_group.get_task_group_dict() - assert subdag_task_groups.keys() == {None, "group234"} + assert subdag_task_groups.keys() == {None, "group234", "group234.group34"} + + included_group_ids = {"group234", "group234.group34"} + included_task_ids = {'group234.group34.task3', 'group234.group34.task4', 'task1', 'task5'} + for task_group in subdag_task_groups.values(): - assert task_group.upstream_group_ids.issubset({"group234"}) - assert task_group.downstream_group_ids.issubset({"group234"}) - assert task_group.upstream_task_ids.issubset({"task1", "task2"}) - assert task_group.downstream_task_ids.issubset({"task1", "task2"}) + assert task_group.upstream_group_ids.issubset(included_group_ids) + assert task_group.downstream_group_ids.issubset(included_group_ids) + assert task_group.upstream_task_ids.issubset(included_task_ids) + assert task_group.downstream_task_ids.issubset(included_task_ids) + + for task in subdag.task_group: + assert task.upstream_task_ids.issubset(included_task_ids) + assert task.downstream_task_ids.issubset(included_task_ids) def test_dag_edges():