diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index 167eb53b71e..1c0d1370d7e 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -565,8 +565,6 @@ class MappedTaskGroup(TaskGroup): def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None: super().__init__(**kwargs) self._expand_input = expand_input - for op, _ in expand_input.iter_references(): - self.set_upstream(op) def iter_mapped_dependencies(self) -> Iterator[Operator]: """Upstream dependencies that provide XComs used by this mapped task group.""" @@ -619,6 +617,11 @@ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int: (g._expand_input.get_total_map_length(run_id, session=session) for g in groups), ) + def __exit__(self, exc_type, exc_val, exc_tb): + for op, _ in self._expand_input.iter_references(): + self.set_upstream(op) + super().__exit__(exc_type, exc_val, exc_tb) + class TaskGroupContext: """TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager.""" diff --git a/tests/decorators/test_task_group.py b/tests/decorators/test_task_group.py index 3462c3a1d83..4c741ef1c15 100644 --- a/tests/decorators/test_task_group.py +++ b/tests/decorators/test_task_group.py @@ -191,6 +191,52 @@ def tg(a, b): assert saved == {"a": 1, "b": MappedArgument(input=tg._expand_input, key="b")} +def test_task_group_expand_kwargs_with_upstream(dag_maker, session, caplog): + with dag_maker() as dag: + + @dag.task + def t1(): + return [{"a": 1}, {"a": 2}] + + @task_group("tg1") + def tg1(a, b): + @dag.task() + def t2(): + return [a, b] + + t2() + + tg1.expand_kwargs(t1()) + + dr = dag_maker.create_dagrun() + dr.task_instance_scheduling_decisions() + assert "Cannot expand" not in caplog.text + assert "missing upstream values: ['expand_kwargs() argument']" not in caplog.text + + +def test_task_group_expand_with_upstream(dag_maker, session, caplog): + with dag_maker() as dag: + + @dag.task + def t1(): + return [1, 2, 3] + + @task_group("tg1") + def tg1(a, b): + @dag.task() + def t2(): + return [a, b] + + t2() + + tg1.partial(a=1).expand(b=t1()) + + dr = dag_maker.create_dagrun() + dr.task_instance_scheduling_decisions() + assert "Cannot expand" not in caplog.text + assert "missing upstream values: ['b']" not in caplog.text + + def test_override_dag_default_args(): @dag( dag_id="test_dag",