Skip to content

Commit

Permalink
Fix MappedTaskGroup tasks not respecting upstream dependency (#33732)
Browse files Browse the repository at this point in the history
* Fix MappedTaskGroup tasks not respecting upstream dependency

When a MappedTaskGroup has upstream dependencies, the tasks in the group don't wait for the upstream tasks
before they start running, this causes the tasks to fail.
From my investigation, the tasks inside the MappedTaskGroup don't have upstream tasks while the
MappedTaskGroup has the upstream tasks properly set. Due to this, the task's dependencies are met even though the Group has
upstreams that haven't finished.
The Fix was to set upstreams after creating the task group with the factory
Closes: apache/airflow#33446

* set the relationship in __exit__

(cherry picked from commit fe27031382e2034b59a23db1c6b9bdbfef259137)

GitOrigin-RevId: 22df7b111261c78fbeeb38191226f9694986bd05
  • Loading branch information
ephraimbuddy authored and Cloud Composer Team committed May 15, 2024
1 parent ac5ad7a commit 16c1f7e
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
7 changes: 5 additions & 2 deletions airflow/utils/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,8 +565,6 @@ class MappedTaskGroup(TaskGroup):
def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._expand_input = expand_input
for op, _ in expand_input.iter_references():
self.set_upstream(op)

def iter_mapped_dependencies(self) -> Iterator[Operator]:
"""Upstream dependencies that provide XComs used by this mapped task group."""
Expand Down Expand Up @@ -619,6 +617,11 @@ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
(g._expand_input.get_total_map_length(run_id, session=session) for g in groups),
)

def __exit__(self, exc_type, exc_val, exc_tb):
for op, _ in self._expand_input.iter_references():
self.set_upstream(op)
super().__exit__(exc_type, exc_val, exc_tb)


class TaskGroupContext:
"""TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager."""
Expand Down
46 changes: 46 additions & 0 deletions tests/decorators/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,52 @@ def tg(a, b):
assert saved == {"a": 1, "b": MappedArgument(input=tg._expand_input, key="b")}


def test_task_group_expand_kwargs_with_upstream(dag_maker, session, caplog):
with dag_maker() as dag:

@dag.task
def t1():
return [{"a": 1}, {"a": 2}]

@task_group("tg1")
def tg1(a, b):
@dag.task()
def t2():
return [a, b]

t2()

tg1.expand_kwargs(t1())

dr = dag_maker.create_dagrun()
dr.task_instance_scheduling_decisions()
assert "Cannot expand" not in caplog.text
assert "missing upstream values: ['expand_kwargs() argument']" not in caplog.text


def test_task_group_expand_with_upstream(dag_maker, session, caplog):
with dag_maker() as dag:

@dag.task
def t1():
return [1, 2, 3]

@task_group("tg1")
def tg1(a, b):
@dag.task()
def t2():
return [a, b]

t2()

tg1.partial(a=1).expand(b=t1())

dr = dag_maker.create_dagrun()
dr.task_instance_scheduling_decisions()
assert "Cannot expand" not in caplog.text
assert "missing upstream values: ['b']" not in caplog.text


def test_override_dag_default_args():
@dag(
dag_id="test_dag",
Expand Down

0 comments on commit 16c1f7e

Please sign in to comment.