Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow depending to a @task_group as a whole #20671

Merged
merged 1 commit into from
Jan 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions airflow/decorators/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
class TaskGroupDecorator(Generic[R]):
""":meta private:"""

function: Callable[..., R] = attr.ib(validator=attr.validators.is_callable())
function: Callable[..., Optional[R]] = attr.ib(validator=attr.validators.is_callable())
kwargs: Dict[str, Any] = attr.ib(factory=dict)
"""kwargs for the TaskGroup"""

Expand All @@ -62,9 +62,24 @@ def _make_task_group(self, **kwargs) -> TaskGroup:
return TaskGroup(**kwargs)

def __call__(self, *args, **kwargs) -> R:
with self._make_task_group(add_suffix_on_collision=True, **self.kwargs):
with self._make_task_group(add_suffix_on_collision=True, **self.kwargs) as task_group:
# Invoke function to run Tasks inside the TaskGroup
return self.function(*args, **kwargs)
retval = self.function(*args, **kwargs)

# If the task-creating function returns a task, forward the return value
# so dependencies bind to it. This is equivalent to
# with TaskGroup(...) as tg:
# t2 = task_2(task_1())
# start >> t2 >> end
if retval is not None:
return retval

# Otherwise return the task group as a whole, equivalent to
# with TaskGroup(...) as tg:
# task_1()
# task_2()
# start >> tg >> end
return task_group

def partial(self, **kwargs) -> "MappedTaskGroupDecorator[R]":
return MappedTaskGroupDecorator(function=self.function, kwargs=self.kwargs).partial(**kwargs)
Expand Down
38 changes: 37 additions & 1 deletion tests/utils/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,42 @@ def section_2(value2):
assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids


def test_build_task_group_depended_by_task():
"""A decorator-based task group should be able to be used as a relative to operators."""

from airflow.decorators import dag as dag_decorator, task

@dag_decorator(start_date=pendulum.now())
def build_task_group_depended_by_task():
@task
def task_start():
return "[Task_start]"

@task
def task_end():
return "[Task_end]"

@task
def task_thing(value):
return f"[Task_thing {value}]"

@task_group_decorator
def section_1():
task_thing(1)
task_thing(2)

task_start() >> section_1() >> task_end()

dag = build_task_group_depended_by_task()
task_thing_1 = dag.task_dict["section_1.task_thing"]
task_thing_2 = dag.task_dict["section_1.task_thing__1"]

# Tasks in the task group don't depend on each other; they both become
# downstreams to task_start, and upstreams to task_end.
assert task_thing_1.upstream_task_ids == task_thing_2.upstream_task_ids == {"task_start"}
assert task_thing_1.downstream_task_ids == task_thing_2.downstream_task_ids == {"task_end"}


def test_build_task_group_with_operators():
"""Tests DAG with Tasks created with *Operators and TaskGroup created with taskgroup decorator"""

Expand Down Expand Up @@ -731,7 +767,7 @@ def section_a(value):
t_end = PythonOperator(task_id='task_end', python_callable=task_end, dag=dag)
sec_1.set_downstream(t_end)

# Testing Tasks ing DAG
# Testing Tasks in DAG
assert set(dag.task_group.children.keys()) == {'section_1', 'task_start', 'task_end'}
assert set(dag.task_group.children['section_1'].children.keys()) == {
'section_1.task_2',
Expand Down