From 1b50f3c3c3e2d08349814c0142bfbc00d176a62f Mon Sep 17 00:00:00 2001 From: yuqian90 Date: Thu, 25 Feb 2021 23:23:15 +0800 Subject: [PATCH] Make airflow dags show command display TaskGroup (#14269) closes: #13053 Make `airflow dags show` display TaskGroup. GitOrigin-RevId: c71f707d24a9196d33b91a7a2a9e3384698e5193 --- airflow/utils/dot_renderer.py | 120 ++++++++++++++++++++++++------- tests/utils/test_dot_renderer.py | 101 +++++++++++++++++++++++++- 2 files changed, 191 insertions(+), 30 deletions(-) diff --git a/airflow/utils/dot_renderer.py b/airflow/utils/dot_renderer.py index 990c7a7d126..4123f99303b 100644 --- a/airflow/utils/dot_renderer.py +++ b/airflow/utils/dot_renderer.py @@ -17,13 +17,17 @@ # specific language governing permissions and limitations # under the License. """Renderer DAG (tasks and dependencies) to the graphviz object.""" -from typing import List, Optional +from typing import Dict, List, Optional import graphviz from airflow.models import TaskInstance +from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG +from airflow.models.taskmixin import TaskMixin from airflow.utils.state import State +from airflow.utils.task_group import TaskGroup +from airflow.www.views import dag_edges def _refine_color(color: str): @@ -42,6 +46,88 @@ def _refine_color(color: str): return color +def _draw_task(task: BaseOperator, parent_graph: graphviz.Digraph, states_by_task_id: Dict[str, str]) -> None: + """Draw a single task on the given parent_graph""" + if states_by_task_id: + state = states_by_task_id.get(task.task_id, State.NONE) + color = State.color_fg(state) + fill_color = State.color(state) + else: + color = task.ui_fgcolor + fill_color = task.ui_color + + parent_graph.node( + task.task_id, + _attributes={ + "label": task.label, + "shape": "rectangle", + "style": "filled,rounded", + "color": _refine_color(color), + "fillcolor": _refine_color(fill_color), + }, + ) + + +def _draw_task_group( + task_group: TaskGroup, parent_graph: graphviz.Digraph, states_by_task_id: Dict[str, str] +) -> None: + """Draw the given task_group and its children on the given parent_graph""" + # Draw joins + if task_group.upstream_group_ids or task_group.upstream_task_ids: + parent_graph.node( + task_group.upstream_join_id, + _attributes={ + "label": "", + "shape": "circle", + "style": "filled,rounded", + "color": _refine_color(task_group.ui_fgcolor), + "fillcolor": _refine_color(task_group.ui_color), + "width": "0.2", + "height": "0.2", + }, + ) + + if task_group.downstream_group_ids or task_group.downstream_task_ids: + parent_graph.node( + task_group.downstream_join_id, + _attributes={ + "label": "", + "shape": "circle", + "style": "filled,rounded", + "color": _refine_color(task_group.ui_fgcolor), + "fillcolor": _refine_color(task_group.ui_color), + "width": "0.2", + "height": "0.2", + }, + ) + + # Draw children + for child in sorted(task_group.children.values(), key=lambda t: t.label): + _draw_nodes(child, parent_graph, states_by_task_id) + + +def _draw_nodes(node: TaskMixin, parent_graph: graphviz.Digraph, states_by_task_id: Dict[str, str]) -> None: + """Draw the node and its children on the given parent_graph recursively.""" + if isinstance(node, BaseOperator): + _draw_task(node, parent_graph, states_by_task_id) + else: + # Draw TaskGroup + if node.is_root: + # No need to draw background for root TaskGroup. + _draw_task_group(node, parent_graph, states_by_task_id) + else: + with parent_graph.subgraph(name=f"cluster_{node.group_id}") as sub: + sub.attr( + shape="rectangle", + style="filled", + color=_refine_color(node.ui_fgcolor), + # Partially transparent CornflowerBlue + fillcolor="#6495ed7f", + label=node.label, + ) + _draw_task_group(node, sub, states_by_task_id) + + def render_dag(dag: DAG, tis: Optional[List[TaskInstance]] = None) -> graphviz.Digraph: """ Renders the DAG object to the DOT object. @@ -66,30 +152,10 @@ def render_dag(dag: DAG, tis: Optional[List[TaskInstance]] = None) -> graphviz.D states_by_task_id = None if tis is not None: states_by_task_id = {ti.task_id: ti.state for ti in tis} - for task in dag.tasks: - node_attrs = { - "shape": "rectangle", - "style": "filled,rounded", - } - if states_by_task_id is None: - node_attrs.update( - { - "color": _refine_color(task.ui_fgcolor), - "fillcolor": _refine_color(task.ui_color), - } - ) - else: - state = states_by_task_id.get(task.task_id, State.NONE) - node_attrs.update( - { - "color": State.color_fg(state), - "fillcolor": State.color(state), - } - ) - dot.node( - task.task_id, - _attributes=node_attrs, - ) - for downstream_task_id in task.downstream_task_ids: - dot.edge(task.task_id, downstream_task_id) + + _draw_nodes(dag.task_group, dot, states_by_task_id) + + for edge in dag_edges(dag): + dot.edge(edge["source_id"], edge["target_id"]) + return dot diff --git a/tests/utils/test_dot_renderer.py b/tests/utils/test_dot_renderer.py index b0306233fc9..ca3ea017941 100644 --- a/tests/utils/test_dot_renderer.py +++ b/tests/utils/test_dot_renderer.py @@ -23,9 +23,11 @@ from airflow.models import TaskInstance from airflow.models.dag import DAG from airflow.operators.bash import BashOperator +from airflow.operators.dummy import DummyOperator from airflow.operators.python import PythonOperator from airflow.utils import dot_renderer from airflow.utils.state import State +from airflow.utils.task_group import TaskGroup START_DATE = datetime.datetime.now() @@ -72,9 +74,16 @@ def test_should_render_dag_with_task_instances(self): source = dot.source # Should render DAG title assert "label=DAG_ID" in source - assert 'first [color=black fillcolor=tan shape=rectangle style="filled,rounded"]' in source - assert 'second [color=white fillcolor=green shape=rectangle style="filled,rounded"]' in source - assert 'third [color=black fillcolor=lime shape=rectangle style="filled,rounded"]' in source + assert ( + 'first [color=black fillcolor=tan label=first shape=rectangle style="filled,rounded"]' in source + ) + assert ( + 'second [color=white fillcolor=green label=second shape=rectangle style="filled,rounded"]' + in source + ) + assert ( + 'third [color=black fillcolor=lime label=third shape=rectangle style="filled,rounded"]' in source + ) def test_should_render_dag_orientation(self): orientation = "TB" @@ -105,3 +114,89 @@ def test_should_render_dag_orientation(self): # Should render DAG title with orientation assert "label=DAG_ID" in source assert f'label=DAG_ID labelloc=t rankdir={orientation}' in source + + def test_render_task_group(self): + with DAG(dag_id="example_task_group", start_date=START_DATE) as dag: + start = DummyOperator(task_id="start") + + with TaskGroup("section_1", tooltip="Tasks for section_1") as section_1: + task_1 = DummyOperator(task_id="task_1") + task_2 = BashOperator(task_id="task_2", bash_command='echo 1') + task_3 = DummyOperator(task_id="task_3") + + task_1 >> [task_2, task_3] + + with TaskGroup("section_2", tooltip="Tasks for section_2") as section_2: + task_1 = DummyOperator(task_id="task_1") + + with TaskGroup("inner_section_2", tooltip="Tasks for inner_section2"): + task_2 = BashOperator(task_id="task_2", bash_command='echo 1') + task_3 = DummyOperator(task_id="task_3") + task_4 = DummyOperator(task_id="task_4") + + [task_2, task_3] >> task_4 + + end = DummyOperator(task_id='end') + + start >> section_1 >> section_2 >> end + + dot = dot_renderer.render_dag(dag) + + assert dot.source == '\n'.join( + [ + 'digraph example_task_group {', + '\tgraph [label=example_task_group labelloc=t rankdir=LR]', + '\tend [color="#000000" fillcolor="#e8f7e4" label=end shape=rectangle ' + 'style="filled,rounded"]', + '\tsubgraph cluster_section_1 {', + '\t\tcolor="#000000" fillcolor="#6495ed7f" label=section_1 shape=rectangle style=filled', + '\t\t"section_1.upstream_join_id" [color="#000000" fillcolor=CornflowerBlue height=0.2 ' + 'label="" shape=circle style="filled,rounded" width=0.2]', + '\t\t"section_1.downstream_join_id" [color="#000000" fillcolor=CornflowerBlue height=0.2 ' + 'label="" shape=circle style="filled,rounded" width=0.2]', + '\t\t"section_1.task_1" [color="#000000" fillcolor="#e8f7e4" label=task_1 shape=rectangle ' + 'style="filled,rounded"]', + '\t\t"section_1.task_2" [color="#000000" fillcolor="#f0ede4" label=task_2 shape=rectangle ' + 'style="filled,rounded"]', + '\t\t"section_1.task_3" [color="#000000" fillcolor="#e8f7e4" label=task_3 shape=rectangle ' + 'style="filled,rounded"]', + '\t}', + '\tsubgraph cluster_section_2 {', + '\t\tcolor="#000000" fillcolor="#6495ed7f" label=section_2 shape=rectangle style=filled', + '\t\t"section_2.upstream_join_id" [color="#000000" fillcolor=CornflowerBlue height=0.2 ' + 'label="" shape=circle style="filled,rounded" width=0.2]', + '\t\t"section_2.downstream_join_id" [color="#000000" fillcolor=CornflowerBlue height=0.2 ' + 'label="" shape=circle style="filled,rounded" width=0.2]', + '\t\tsubgraph "cluster_section_2.inner_section_2" {', + '\t\t\tcolor="#000000" fillcolor="#6495ed7f" label=inner_section_2 shape=rectangle ' + 'style=filled', + '\t\t\t"section_2.inner_section_2.task_2" [color="#000000" fillcolor="#f0ede4" label=task_2 ' + 'shape=rectangle style="filled,rounded"]', + '\t\t\t"section_2.inner_section_2.task_3" [color="#000000" fillcolor="#e8f7e4" label=task_3 ' + 'shape=rectangle style="filled,rounded"]', + '\t\t\t"section_2.inner_section_2.task_4" [color="#000000" fillcolor="#e8f7e4" label=task_4 ' + 'shape=rectangle style="filled,rounded"]', + '\t\t}', + '\t\t"section_2.task_1" [color="#000000" fillcolor="#e8f7e4" label=task_1 shape=rectangle ' + 'style="filled,rounded"]', + '\t}', + '\tstart [color="#000000" fillcolor="#e8f7e4" label=start shape=rectangle ' + 'style="filled,rounded"]', + '\t"section_1.downstream_join_id" -> "section_2.upstream_join_id"', + '\t"section_1.task_1" -> "section_1.task_2"', + '\t"section_1.task_1" -> "section_1.task_3"', + '\t"section_1.task_2" -> "section_1.downstream_join_id"', + '\t"section_1.task_3" -> "section_1.downstream_join_id"', + '\t"section_1.upstream_join_id" -> "section_1.task_1"', + '\t"section_2.downstream_join_id" -> end', + '\t"section_2.inner_section_2.task_2" -> "section_2.inner_section_2.task_4"', + '\t"section_2.inner_section_2.task_3" -> "section_2.inner_section_2.task_4"', + '\t"section_2.inner_section_2.task_4" -> "section_2.downstream_join_id"', + '\t"section_2.task_1" -> "section_2.downstream_join_id"', + '\t"section_2.upstream_join_id" -> "section_2.inner_section_2.task_2"', + '\t"section_2.upstream_join_id" -> "section_2.inner_section_2.task_3"', + '\t"section_2.upstream_join_id" -> "section_2.task_1"', + '\tstart -> "section_1.upstream_join_id"', + '}', + ] + )