Skip to content

Commit

Permalink
Allow using mapped upstream's aggregated XCom (#22849)
Browse files Browse the repository at this point in the history
This needs two changes. First, when the upstream pushes the return value
to XCom, we need to identify that the pushed value is not used on its
own, but only aggregated with other return values from other mapped task
instances. Fortunately, this is actually the only possible case right
now, since we have not implemented support for depending on individual
return values from a mapped task (aka nested mapping). So we instead
skip recording any TaskMap metadata from a mapped task to avoid the
problem altogether.

The second change is for when the downstream task is expanded. Since the
task depends on the mapped upstream as a whole, we should not use
TaskMap from the upstream (which corresponds to individual task
instances, as mentioned above), but the XComs pushed by every instance
of the mapped task. Again, since we don't support nested mapping now, we can cut
corners and simply check whether the upstream is mapped or not to decide
what to do, and leave further logic to the future.

Co-authored-by: Ash Berlin-Taylor <[email protected]>
  • Loading branch information
uranusjr and ashb authored Apr 11, 2022
1 parent 1a8b8f5 commit 8af7712
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 27 deletions.
37 changes: 32 additions & 5 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ def _get_expansion_kwargs(self) -> Dict[str, "Mappable"]:
def _get_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, int]:
# TODO: Find a way to cache this.
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XCom
from airflow.models.xcom_arg import XComArg

expansion_kwargs = self._get_expansion_kwargs()
Expand All @@ -518,19 +519,45 @@ def _get_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, int]:
map_lengths.update((k, len(v)) for k, v in expansion_kwargs.items() if not isinstance(v, XComArg))

# Build a reverse mapping of what arguments each task contributes to.
dep_keys: Dict[str, Set[str]] = collections.defaultdict(set)
mapped_dep_keys: Dict[str, Set[str]] = collections.defaultdict(set)
non_mapped_dep_keys: Dict[str, Set[str]] = collections.defaultdict(set)
for k, v in expansion_kwargs.items():
if not isinstance(v, XComArg):
continue
dep_keys[v.operator.task_id].add(k)

if v.operator.is_mapped:
mapped_dep_keys[v.operator.task_id].add(k)
else:
non_mapped_dep_keys[v.operator.task_id].add(k)
# TODO: It's not possible now, but in the future (AIP-42 Phase 2)
# we will add support for depending on one single mapped task
# instance. When that happens, we need to further analyze the mapped
# case to contain only tasks we depend on "as a whole", and put
# those we only depend on individually to the non-mapped lookup.

# Collect lengths from unmapped upstreams.
taskmap_query = session.query(TaskMap.task_id, TaskMap.length).filter(
TaskMap.dag_id == self.dag_id,
TaskMap.run_id == run_id,
TaskMap.task_id.in_(list(dep_keys)),
TaskMap.task_id.in_(non_mapped_dep_keys),
TaskMap.map_index < 0,
)
for task_id, length in taskmap_query:
for mapped_arg_name in dep_keys[task_id]:
for mapped_arg_name in non_mapped_dep_keys[task_id]:
map_lengths[mapped_arg_name] += length

# Collect lengths from mapped upstreams.
xcom_query = (
session.query(XCom.task_id, func.count(XCom.map_index))
.group_by(XCom.task_id)
.filter(
XCom.dag_id == self.dag_id,
XCom.run_id == run_id,
XCom.task_id.in_(mapped_dep_keys),
XCom.map_index >= 0,
)
)
for task_id, length in xcom_query:
for mapped_arg_name in mapped_dep_keys[task_id]:
map_lengths[mapped_arg_name] += length

if len(map_lengths) < len(expansion_kwargs):
Expand Down
92 changes: 79 additions & 13 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
IO,
TYPE_CHECKING,
Any,
ContextManager,
Dict,
Generator,
Iterable,
Iterator,
List,
Expand Down Expand Up @@ -69,6 +71,8 @@
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import reconstructor, relationship
from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.elements import BooleanClauseList
from sqlalchemy.sql.sqltypes import BigInteger
Expand Down Expand Up @@ -295,6 +299,71 @@ def clear_task_instances(
dr.start_date = None


class _LazyXComAccessIterator(collections.abc.Iterator):
    """Iterator that lazily enters a query context manager and deserializes rows.

    The context manager is entered on the first ``next()`` call, not at
    construction. It is exited as soon as the underlying row iterator is
    exhausted, so whatever the context manager holds is released
    deterministically. ``__del__`` remains as a fallback for iterators that
    are abandoned before being fully consumed.

    :param cm: Context manager whose ``__enter__`` returns an iterable query.
    """

    __slots__ = ['_cm', '_it']

    def __init__(self, cm: ContextManager[Query]):
        self._cm = cm
        self._it = None

    def __del__(self):
        self._close()

    def __iter__(self):
        return self

    def __next__(self):
        if self._it is None:
            self._it = iter(self._cm.__enter__())
        try:
            row = next(self._it)
        except StopIteration:
            # Release the context manager now instead of waiting for garbage
            # collection, and swap in an empty iterator so further ``next()``
            # calls keep raising StopIteration without touching the query.
            self._close()
            self._it = iter(())
            raise
        return XCom.deserialize_value(row)

    def _close(self):
        """Exit the context manager if it was entered and not exited yet."""
        # ``getattr`` guards against ``__del__`` running on an instance whose
        # ``__init__`` never populated the slots.
        if getattr(self, "_it", None) is None:
            return
        cm = getattr(self, "_cm", None)
        if cm is None:
            return
        self._cm = None  # Make repeated calls a no-op.
        cm.__exit__(None, None, None)


class _LazyXComAccess(collections.abc.Sequence):
    """Wrapper to lazily pull XCom with a sequence-like interface.

    Note that since the session bound to the parent query may have died when we
    actually access the sequence's content, we must create a new session
    for every function call with ``with_session()``.

    Rows are only deserialized (through ``XCom.deserialize_value``) when they
    are accessed. The length is queried once and cached afterwards.

    :param query: Query selecting the XCom values to expose.
    """

    def __init__(self, query: Query):
        self._q = query
        self._len = None

    def __len__(self):
        if self._len is None:
            with self._get_bound_query() as query:
                self._len = query.count()
        return self._len

    def __iter__(self):
        return _LazyXComAccessIterator(self._get_bound_query())

    def __getitem__(self, key):
        """Return the deserialized value at integer position ``key``.

        :raises ValueError: If ``key`` is not an integer (e.g. a slice).
        :raises IndexError: If ``key`` is out of range.
        """
        if not isinstance(key, int):
            raise ValueError("only support index access for now")
        offset = key
        if offset < 0:
            # OFFSET cannot count from the end, so translate a negative index
            # into the equivalent non-negative one (uses the cached length).
            offset += len(self)
            if offset < 0:
                raise IndexError(key)
        try:
            with self._get_bound_query() as query:
                r = query.offset(offset).limit(1).one()
        except NoResultFound:
            raise IndexError(key) from None
        return XCom.deserialize_value(r)

    @contextlib.contextmanager
    def _get_bound_query(self) -> Generator[Query, None, None]:
        """Yield the query bound to a usable session.

        The query's own session is reused while it is still active. Otherwise
        a new session is created for the duration of the ``with`` block and
        closed afterwards.
        """
        # Do we have a valid session already?
        if self._q.session and self._q.session.is_active:
            yield self._q
            return

        session = settings.Session()
        try:
            yield self._q.with_session(session)
        finally:
            session.close()


class TaskInstanceKey(NamedTuple):
"""Key used to identify task instance."""

Expand Down Expand Up @@ -2233,14 +2302,19 @@ def set_duration(self) -> None:
self.log.debug("Task Duration set to %s", self.duration)

def _record_task_map_for_downstreams(self, task: "Operator", value: Any, *, session: Session) -> None:
if not task.has_mapped_dependants():
# TODO: We don't push TaskMap for mapped task instances because it's not
# currently possible for a downstream to depend on one individual mapped
# task instance, only a task as a whole. This will change in AIP-42
# Phase 2, and we'll need to further analyze the mapped task case.
if task.is_mapped or not task.has_mapped_dependants():
return
if not isinstance(value, collections.abc.Collection) or isinstance(value, (bytes, str)):
raise UnmappableXComTypePushed(value)
task_map = TaskMap.from_task_instance_xcom(self, value)
max_map_length = conf.getint("core", "max_map_length", fallback=1024)
if len(value) > max_map_length:
if task_map.length > max_map_length:
raise UnmappableXComLengthPushed(value, max_map_length)
session.merge(TaskMap.from_task_instance_xcom(self, value))
session.merge(task_map)

@provide_session
def xcom_push(
Expand Down Expand Up @@ -2351,21 +2425,13 @@ def xcom_pull(
# make sure all XComs come from one task and run (for task_ids=None
# and include_prior_dates=True), and re-order by map index (reset
# needed because XCom.get_many() orders by XCom timestamp).
query = (
return _LazyXComAccess(
query.with_entities(XCom.value)
.filter(XCom.run_id == first.run_id, XCom.task_id == first.task_id)
.filter(XCom.run_id == first.run_id, XCom.task_id == first.task_id, XCom.map_index >= 0)
.order_by(None)
.order_by(XCom.map_index.asc())
)

def iter_xcom_values(query):
# The session passed to xcom_pull() may die before this is
# iterated through, so we need to bind to a new session.
for r in query.with_session(settings.Session()):
yield XCom.deserialize_value(r)

return iter_xcom_values(query)

# At this point either task_ids or map_indexes is explicitly multi-value.

results = (
Expand Down
41 changes: 32 additions & 9 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,10 +1089,7 @@ def test_xcom_pull_mapped(self, dag_maker, session):
assert ti_2.xcom_pull(["task_1"], map_indexes=[0, 1], session=session) == ["a", "b"]

assert ti_2.xcom_pull("task_1", map_indexes=1, session=session) == "b"

joined = ti_2.xcom_pull("task_1", session=session)
assert iter(joined) is joined, "should be iterator"
assert list(joined) == ["a", "b"]
assert list(ti_2.xcom_pull("task_1", session=session)) == ["a", "b"]

def test_xcom_pull_after_success(self, create_task_instance):
"""
Expand Down Expand Up @@ -2635,7 +2632,7 @@ def cmds():


@mock.patch("airflow.models.taskinstance.XCom.deserialize_value", side_effect=XCom.deserialize_value)
def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterator(mock_deserialize_value, dag_maker, session):
def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_value, dag_maker, session):
"""Ensure we access XCom lazily when pulling from a mapped operator."""
with dag_maker(dag_id="test_xcom", session=session):
task_1 = DummyOperator.partial(task_id="task_1").expand()
Expand All @@ -2657,10 +2654,36 @@ def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterator(mock_deserialize_v
joined = ti_2.xcom_pull("task_1", session=session)
assert mock_deserialize_value.call_count == 0

# Only when we go through the iterator does deserialization happen.
assert next(joined) == "a"
# Only when we go through the iterable does deserialization happen.
it = iter(joined)
assert next(it) == "a"
assert mock_deserialize_value.call_count == 1
assert next(joined) == "b"
assert next(it) == "b"
assert mock_deserialize_value.call_count == 2
with pytest.raises(StopIteration):
next(joined)
next(it)


def test_ti_mapped_depends_on_mapped_xcom_arg(dag_maker, session):
    """A mapped task can be expanded over the return values of a mapped upstream."""
    with dag_maker(session=session) as dag:

        @dag.task
        def add_one(x):
            return x + 1

        upstream_results = add_one.expand(x=[1, 2, 3])
        add_one.expand(x=upstream_results)

    dagrun = dag_maker.create_dagrun()

    # Run the three upstream mapped task instances first.
    for map_index in range(3):
        upstream_ti = dagrun.get_task_instance("add_one", map_index=map_index)
        upstream_ti.refresh_from_task(dag.get_task("add_one"))
        upstream_ti.run()

    # Expand the downstream task and run every task instance it yields.
    downstream_task = dag.get_task("add_one__1")
    expanded_tis = downstream_task.expand_mapped_task(dagrun.run_id, session=session)
    for downstream_ti in expanded_tis:
        downstream_ti.refresh_from_task(downstream_task)
        downstream_ti.run()

    xcoms = XCom.get_many(run_id=dagrun.run_id, task_ids=["add_one__1"], session=session)
    ordered_xcoms = xcoms.order_by(None).order_by(XCom.map_index)
    assert [row.value for row in ordered_xcoms] == [3, 4, 5]

0 comments on commit 8af7712

Please sign in to comment.