Skip to content

Commit

Permalink
Allow depending on a @task_group as a whole (#20671)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored Jan 8, 2022
1 parent f2039b4 commit 384fa4a
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 4 deletions.
21 changes: 18 additions & 3 deletions airflow/decorators/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
class TaskGroupDecorator(Generic[R]):
""":meta private:"""

function: Callable[..., R] = attr.ib(validator=attr.validators.is_callable())
function: Callable[..., Optional[R]] = attr.ib(validator=attr.validators.is_callable())
kwargs: Dict[str, Any] = attr.ib(factory=dict)
"""kwargs for the TaskGroup"""

Expand All @@ -62,9 +62,24 @@ def _make_task_group(self, **kwargs) -> TaskGroup:
return TaskGroup(**kwargs)

def __call__(self, *args, **kwargs) -> R:
with self._make_task_group(add_suffix_on_collision=True, **self.kwargs):
with self._make_task_group(add_suffix_on_collision=True, **self.kwargs) as task_group:
# Invoke function to run Tasks inside the TaskGroup
return self.function(*args, **kwargs)
retval = self.function(*args, **kwargs)

# If the task-creating function returns a task, forward the return value
# so dependencies bind to it. This is equivalent to
# with TaskGroup(...) as tg:
# t2 = task_2(task_1())
# start >> t2 >> end
if retval is not None:
return retval

# Otherwise return the task group as a whole, equivalent to
# with TaskGroup(...) as tg:
# task_1()
# task_2()
# start >> tg >> end
return task_group

def partial(self, **kwargs) -> "MappedTaskGroupDecorator[R]":
return MappedTaskGroupDecorator(function=self.function, kwargs=self.kwargs).partial(**kwargs)
Expand Down
38 changes: 37 additions & 1 deletion tests/utils/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,42 @@ def section_2(value2):
assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids


def test_build_task_group_depended_by_task():
    """A decorator-based task group can be chained between tasks with ``>>``.

    ``section_1`` creates two tasks and returns nothing, so the dependency
    arrows written against the ``section_1()`` call are expected to land on
    every task inside the group.
    """
    from airflow.decorators import dag as dag_decorator, task

    @dag_decorator(start_date=pendulum.now())
    def build_task_group_depended_by_task():
        @task
        def task_start():
            return "[Task_start]"

        @task
        def task_end():
            return "[Task_end]"

        @task
        def task_thing(value):
            return f"[Task_thing {value}]"

        @task_group_decorator
        def section_1():
            # Two independent tasks; the second one gets the "__1" suffix.
            task_thing(1)
            task_thing(2)

        task_start() >> section_1() >> task_end()

    dag = build_task_group_depended_by_task()
    grouped_tasks = [
        dag.task_dict[task_id]
        for task_id in ("section_1.task_thing", "section_1.task_thing__1")
    ]

    # Tasks in the task group don't depend on each other; each one is a
    # downstream of task_start and an upstream of task_end.
    for grouped_task in grouped_tasks:
        assert grouped_task.upstream_task_ids == {"task_start"}
        assert grouped_task.downstream_task_ids == {"task_end"}


def test_build_task_group_with_operators():
"""Tests DAG with Tasks created with *Operators and TaskGroup created with taskgroup decorator"""

Expand Down Expand Up @@ -731,7 +767,7 @@ def section_a(value):
t_end = PythonOperator(task_id='task_end', python_callable=task_end, dag=dag)
sec_1.set_downstream(t_end)

# Testing Tasks ing DAG
# Testing Tasks in DAG
assert set(dag.task_group.children.keys()) == {'section_1', 'task_start', 'task_end'}
assert set(dag.task_group.children['section_1'].children.keys()) == {
'section_1.task_2',
Expand Down

0 comments on commit 384fa4a

Please sign in to comment.