diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index a7e557ff96272..a28a6ea85815d 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -509,6 +509,7 @@ def _get_expansion_kwargs(self) -> Dict[str, "Mappable"]:
     def _get_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, int]:
         # TODO: Find a way to cache this.
         from airflow.models.taskmap import TaskMap
+        from airflow.models.xcom import XCom
         from airflow.models.xcom_arg import XComArg
 
         expansion_kwargs = self._get_expansion_kwargs()
@@ -518,19 +519,45 @@ def _get_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, int]:
         map_lengths.update((k, len(v)) for k, v in expansion_kwargs.items() if not isinstance(v, XComArg))
 
         # Build a reverse mapping of what arguments each task contributes to.
-        dep_keys: Dict[str, Set[str]] = collections.defaultdict(set)
+        mapped_dep_keys: Dict[str, Set[str]] = collections.defaultdict(set)
+        non_mapped_dep_keys: Dict[str, Set[str]] = collections.defaultdict(set)
         for k, v in expansion_kwargs.items():
             if not isinstance(v, XComArg):
                 continue
-            dep_keys[v.operator.task_id].add(k)
-
+            if v.operator.is_mapped:
+                mapped_dep_keys[v.operator.task_id].add(k)
+            else:
+                non_mapped_dep_keys[v.operator.task_id].add(k)
+        # TODO: This is not possible now, but in the future (AIP-42 Phase 2)
+        # we will add support for depending on a single mapped task instance.
+        # When that happens, we need to narrow the mapped case to contain only
+        # tasks we depend on "as a whole", and move those we only depend on
+        # individually into the non-mapped lookup.
+
+        # Collect lengths from unmapped upstreams.
         taskmap_query = session.query(TaskMap.task_id, TaskMap.length).filter(
             TaskMap.dag_id == self.dag_id,
             TaskMap.run_id == run_id,
-            TaskMap.task_id.in_(list(dep_keys)),
+            TaskMap.task_id.in_(non_mapped_dep_keys),
+            TaskMap.map_index < 0,
         )
         for task_id, length in taskmap_query:
-            for mapped_arg_name in dep_keys[task_id]:
+            for mapped_arg_name in non_mapped_dep_keys[task_id]:
+                map_lengths[mapped_arg_name] += length
+
+        # Collect lengths from mapped upstreams.
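+        # Mapped upstreams have no TaskMap row of their own; each expanded
+        # task instance pushes its own XCom with a non-negative map_index,
+        # so counting those rows per task recovers the upstream's length.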
+        xcom_query = (
+            session.query(XCom.task_id, func.count(XCom.map_index))
+            .group_by(XCom.task_id)
+            .filter(
+                XCom.dag_id == self.dag_id,
+                XCom.run_id == run_id,
+                XCom.task_id.in_(mapped_dep_keys),
+                XCom.map_index >= 0,
+            )
+        )
+        for task_id, length in xcom_query:
+            for mapped_arg_name in mapped_dep_keys[task_id]:
                 map_lengths[mapped_arg_name] += length
 
         if len(map_lengths) < len(expansion_kwargs):
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 00ea4e2c05feb..2f3d75436de55 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -35,7 +35,9 @@
     IO,
     TYPE_CHECKING,
     Any,
+    ContextManager,
     Dict,
+    Generator,
     Iterable,
     Iterator,
     List,
@@ -69,6 +71,8 @@
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.orm import reconstructor, relationship
 from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
+from sqlalchemy.orm.exc import NoResultFound
+from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.elements import BooleanClauseList
 from sqlalchemy.sql.sqltypes import BigInteger
@@ -295,6 +299,71 @@ def clear_task_instances(
             dr.start_date = None
 
 
+class _LazyXComAccessIterator(collections.abc.Iterator):
+    __slots__ = ['_cm', '_it']
+
+    def __init__(self, cm: ContextManager[Query]):
+        self._cm = cm
+        self._it = None
+
+    def __del__(self):
+        if self._it:
+            self._cm.__exit__(None, None, None)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if not self._it:
+            self._it = iter(self._cm.__enter__())
+        return XCom.deserialize_value(next(self._it))
+
+
+class _LazyXComAccess(collections.abc.Sequence):
+    """Wrapper to lazily pull XCom with a sequence-like interface.
+
+    Note that since the session bound to the parent query may have died when
+    we actually access the sequence's content, each function call creates a
+    new session and binds it to the query with ``with_session()`` as needed.
+    """
+
+    def __init__(self, query: Query):
+        self._q = query
+        self._len = None
+
+    def __len__(self):
+        if self._len is None:
+            with self._get_bound_query() as query:
+                self._len = query.count()
+        return self._len
+
+    def __iter__(self):
+        return _LazyXComAccessIterator(self._get_bound_query())
+
+    def __getitem__(self, key):
+        if not isinstance(key, int):
+            raise ValueError("only supports index access for now")
+        try:
+            with self._get_bound_query() as query:
+                r = query.offset(key).limit(1).one()
+        except NoResultFound:
+            raise IndexError(key) from None
+        return XCom.deserialize_value(r)
+
+    @contextlib.contextmanager
+    def _get_bound_query(self) -> Generator[Query, None, None]:
+        # Do we have a valid session already?
+        if self._q.session and self._q.session.is_active:
+            yield self._q
+            return
+
+        session = settings.Session()
+        try:
+            yield self._q.with_session(session)
+        finally:
+            session.close()
+
+
 class TaskInstanceKey(NamedTuple):
     """Key used to identify task instance."""
 
@@ -2233,14 +2302,19 @@ def set_duration(self) -> None:
         self.log.debug("Task Duration set to %s", self.duration)
 
     def _record_task_map_for_downstreams(self, task: "Operator", value: Any, *, session: Session) -> None:
-        if not task.has_mapped_dependants():
+        # TODO: We don't push TaskMap for mapped task instances because it's
+        # not currently possible for a downstream to depend on an individual
+        # mapped task instance, only on the task as a whole. This will change
+        # in AIP-42 Phase 2, and we'll need to further analyze the mapped
+        # task case then.
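+        # Lengths for mapped upstreams are instead recovered by counting the
+        # per-instance XComs in MappedOperator._get_map_lengths.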
+ if task.is_mapped or not task.has_mapped_dependants(): return if not isinstance(value, collections.abc.Collection) or isinstance(value, (bytes, str)): raise UnmappableXComTypePushed(value) + task_map = TaskMap.from_task_instance_xcom(self, value) max_map_length = conf.getint("core", "max_map_length", fallback=1024) - if len(value) > max_map_length: + if task_map.length > max_map_length: raise UnmappableXComLengthPushed(value, max_map_length) - session.merge(TaskMap.from_task_instance_xcom(self, value)) + session.merge(task_map) @provide_session def xcom_push( @@ -2351,21 +2425,13 @@ def xcom_pull( # make sure all XComs come from one task and run (for task_ids=None # and include_prior_dates=True), and re-order by map index (reset # needed because XCom.get_many() orders by XCom timestamp). - query = ( + return _LazyXComAccess( query.with_entities(XCom.value) - .filter(XCom.run_id == first.run_id, XCom.task_id == first.task_id) + .filter(XCom.run_id == first.run_id, XCom.task_id == first.task_id, XCom.map_index >= 0) .order_by(None) .order_by(XCom.map_index.asc()) ) - def iter_xcom_values(query): - # The session passed to xcom_pull() may die before this is - # iterated through, so we need to bind to a new session. - for r in query.with_session(settings.Session()): - yield XCom.deserialize_value(r) - - return iter_xcom_values(query) - # At this point either task_ids or map_indexes is explicitly multi-value. results = ( diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 8b4da46692f38..2a1ea0889e3d8 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -1089,10 +1089,7 @@ def test_xcom_pull_mapped(self, dag_maker, session): assert ti_2.xcom_pull(["task_1"], map_indexes=[0, 1], session=session) == ["a", "b"] assert ti_2.xcom_pull("task_1", map_indexes=1, session=session) == "b" - - joined = ti_2.xcom_pull("task_1", session=session) - assert iter(joined) is joined, "should be iterator" - assert list(joined) == ["a", "b"] + assert list(ti_2.xcom_pull("task_1", session=session)) == ["a", "b"] def test_xcom_pull_after_success(self, create_task_instance): """ @@ -2635,7 +2632,7 @@ def cmds(): @mock.patch("airflow.models.taskinstance.XCom.deserialize_value", side_effect=XCom.deserialize_value) -def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterator(mock_deserialize_value, dag_maker, session): +def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_value, dag_maker, session): """Ensure we access XCom lazily when pulling from a mapped operator.""" with dag_maker(dag_id="test_xcom", session=session): task_1 = DummyOperator.partial(task_id="task_1").expand() @@ -2657,10 +2654,36 @@ def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterator(mock_deserialize_v joined = ti_2.xcom_pull("task_1", session=session) assert mock_deserialize_value.call_count == 0 - # Only when we go through the iterator does deserialization happen. - assert next(joined) == "a" + # Only when we go through the iterable does deserialization happen. 
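+    # Each next() call should deserialize exactly one XCom value.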
+ it = iter(joined) + assert next(it) == "a" assert mock_deserialize_value.call_count == 1 - assert next(joined) == "b" + assert next(it) == "b" assert mock_deserialize_value.call_count == 2 with pytest.raises(StopIteration): - next(joined) + next(it) + + +def test_ti_mapped_depends_on_mapped_xcom_arg(dag_maker, session): + with dag_maker(session=session) as dag: + + @dag.task + def add_one(x): + return x + 1 + + two_three_four = add_one.expand(x=[1, 2, 3]) + add_one.expand(x=two_three_four) + + dagrun = dag_maker.create_dagrun() + for map_index in range(3): + ti = dagrun.get_task_instance("add_one", map_index=map_index) + ti.refresh_from_task(dag.get_task("add_one")) + ti.run() + + task_345 = dag.get_task("add_one__1") + for ti in task_345.expand_mapped_task(dagrun.run_id, session=session): + ti.refresh_from_task(task_345) + ti.run() + + query = XCom.get_many(run_id=dagrun.run_id, task_ids=["add_one__1"], session=session) + assert [x.value for x in query.order_by(None).order_by(XCom.map_index)] == [3, 4, 5]
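+    # The first expansion maps add_one over [1, 2, 3] and pushes [2, 3, 4];
+    # the second expands over those XComs, so add_one__1 pushes [3, 4, 5].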