From 384fa4a87dfaa79a89ad8e18ac1980e07badec4b Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Sat, 8 Jan 2022 12:09:03 +0800 Subject: [PATCH] Allow depending to a @task_group as a whole (#20671) --- airflow/decorators/task_group.py | 21 +++++++++++++++--- tests/utils/test_task_group.py | 38 +++++++++++++++++++++++++++++++- 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/airflow/decorators/task_group.py b/airflow/decorators/task_group.py index e93002384cca9..02a980cc2f9b6 100644 --- a/airflow/decorators/task_group.py +++ b/airflow/decorators/task_group.py @@ -42,7 +42,7 @@ class TaskGroupDecorator(Generic[R]): """:meta private:""" - function: Callable[..., R] = attr.ib(validator=attr.validators.is_callable()) + function: Callable[..., Optional[R]] = attr.ib(validator=attr.validators.is_callable()) kwargs: Dict[str, Any] = attr.ib(factory=dict) """kwargs for the TaskGroup""" @@ -62,9 +62,24 @@ def _make_task_group(self, **kwargs) -> TaskGroup: return TaskGroup(**kwargs) def __call__(self, *args, **kwargs) -> R: - with self._make_task_group(add_suffix_on_collision=True, **self.kwargs): + with self._make_task_group(add_suffix_on_collision=True, **self.kwargs) as task_group: # Invoke function to run Tasks inside the TaskGroup - return self.function(*args, **kwargs) + retval = self.function(*args, **kwargs) + + # If the task-creating function returns a task, forward the return value + # so dependencies bind to it. This is equivalent to + # with TaskGroup(...) as tg: + # t2 = task_2(task_1()) + # start >> t2 >> end + if retval is not None: + return retval + + # Otherwise return the task group as a whole, equivalent to + # with TaskGroup(...) as tg: + # task_1() + # task_2() + # start >> tg >> end + return task_group def partial(self, **kwargs) -> "MappedTaskGroupDecorator[R]": return MappedTaskGroupDecorator(function=self.function, kwargs=self.kwargs).partial(**kwargs) diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py index 51e0319e49963..2c784e01704ab 100644 --- a/tests/utils/test_task_group.py +++ b/tests/utils/test_task_group.py @@ -689,6 +689,42 @@ def section_2(value2): assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids +def test_build_task_group_depended_by_task(): + """A decorator-based task group should be able to be used as a relative to operators.""" + + from airflow.decorators import dag as dag_decorator, task + + @dag_decorator(start_date=pendulum.now()) + def build_task_group_depended_by_task(): + @task + def task_start(): + return "[Task_start]" + + @task + def task_end(): + return "[Task_end]" + + @task + def task_thing(value): + return f"[Task_thing {value}]" + + @task_group_decorator + def section_1(): + task_thing(1) + task_thing(2) + + task_start() >> section_1() >> task_end() + + dag = build_task_group_depended_by_task() + task_thing_1 = dag.task_dict["section_1.task_thing"] + task_thing_2 = dag.task_dict["section_1.task_thing__1"] + + # Tasks in the task group don't depend on each other; they both become + # downstreams to task_start, and upstreams to task_end. + assert task_thing_1.upstream_task_ids == task_thing_2.upstream_task_ids == {"task_start"} + assert task_thing_1.downstream_task_ids == task_thing_2.downstream_task_ids == {"task_end"} + + def test_build_task_group_with_operators(): """Tests DAG with Tasks created with *Operators and TaskGroup created with taskgroup decorator""" @@ -731,7 +767,7 @@ def section_a(value): t_end = PythonOperator(task_id='task_end', python_callable=task_end, dag=dag) sec_1.set_downstream(t_end) - # Testing Tasks ing DAG + # Testing Tasks in DAG assert set(dag.task_group.children.keys()) == {'section_1', 'task_start', 'task_end'} assert set(dag.task_group.children['section_1'].children.keys()) == { 'section_1.task_2',