Skip to content

Commit

Permalink
Implement XComArg.zip()
Browse files Browse the repository at this point in the history
With the foundation layed down, this is relatively straightforward. The
only thing we need to be aware is the make the resolved value
index-able, and thus cannot use the built-in zip() (which returns an
iterator).
  • Loading branch information
uranusjr committed Aug 1, 2022
1 parent aa650d3 commit 98833b9
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 4 deletions.
100 changes: 96 additions & 4 deletions airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,10 @@ def iter_references(self) -> Iterator[Tuple["Operator", str]]:
raise NotImplementedError()

def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
raise NotImplementedError()
return MapXComArg(self, [f])

def zip(self, *others: "XComArg", fillvalue: Any = NOTSET) -> "ZipXComArg":
return ZipXComArg([self, *others], fillvalue=fillvalue)

def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]:
"""Inspect length of pushed value for task-mapping.
Expand Down Expand Up @@ -285,8 +288,13 @@ def iter_references(self) -> Iterator[Tuple["Operator", str]]:

def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
if self.key != XCOM_RETURN_KEY:
raise ValueError
return MapXComArg(self, [f])
raise ValueError("cannot map against non-return XCom")
return super().map(f)

def zip(self, *others: "XComArg", fillvalue: Any = NOTSET) -> "ZipXComArg":
if self.key != XCOM_RETURN_KEY:
raise ValueError("cannot map against non-return XCom")
return super().zip(*others, fillvalue=fillvalue)

def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]:
from airflow.models.taskmap import TaskMap
Expand Down Expand Up @@ -372,6 +380,7 @@ def iter_references(self) -> Iterator[Tuple["Operator", str]]:
yield from self.arg.iter_references()

def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
# Flatten arg.map(f1).map(f2) into one MapXComArg.
return MapXComArg(self.arg, [*self.callables, f])

def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]:
Expand All @@ -380,13 +389,96 @@ def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[in
@provide_session
def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
value = self.arg.resolve(context, session=session)
assert isinstance(value, (Sequence, dict)) # Validation was done when XCom was pushed.
if not isinstance(value, (Sequence, dict)):
raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}")
return _MapResult(value, self.callables)


class _ZipResult(Sequence):
def __init__(self, values: Sequence[Union[Sequence, dict]], *, fillvalue: Any = NOTSET) -> None:
self.values = values
self.fillvalue = fillvalue

@staticmethod
def _get_or_fill(container: Union[Sequence, dict], index: Any, fillvalue: Any) -> Any:
try:
return container[index]
except (IndexError, KeyError):
return fillvalue

def __getitem__(self, index: Any) -> Any:
if index >= len(self):
raise IndexError(index)
return tuple(self._get_or_fill(value, index, self.fillvalue) for value in self.values)

def __len__(self) -> int:
lengths = (len(v) for v in self.values)
if self.fillvalue is NOTSET:
return min(lengths)
return max(lengths)


class ZipXComArg(XComArg):
"""An XCom reference with ``zip()`` applied.
This is constructed from multiple XComArg instances, and presents an
iterable that "zips" them together like the built-in ``zip()`` (and
``itertools.zip_longest()`` if ``fillvalue`` is provided).
"""

def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None:
if not args:
raise ValueError("At least one input is required")
self.args = args
self.fillvalue = fillvalue

def __repr__(self) -> str:
args_iter = iter(self.args)
first = repr(next(args_iter))
rest = ", ".join(repr(arg) for arg in args_iter)
if self.fillvalue is NOTSET:
return f"{first}.zip({rest})"
return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})"

def _serialize(self) -> Dict[str, Any]:
args = [serialize_xcom_arg(arg) for arg in self.args]
if self.fillvalue is NOTSET:
return {"args": args}
return {"args": args, "fillvalue": self.fillvalue}

@classmethod
def _deserialize(cls, data: Dict[str, Any], dag: "DAG") -> XComArg:
return cls(
[deserialize_xcom_arg(arg, dag) for arg in data["args"]],
fillvalue=data.get("fillvalue", NOTSET),
)

def iter_references(self) -> Iterator[Tuple["Operator", str]]:
for arg in self.args:
yield from arg.iter_references()

def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]:
all_lengths = (arg.get_task_map_length(run_id, session=session) for arg in self.args)
ready_lengths = [length for length in all_lengths if length is not None]
if len(ready_lengths) != len(self.args):
return None # If any of the referenced XComs is not ready, we are not ready either.
if self.fillvalue is NOTSET:
return min(ready_lengths)
return max(ready_lengths)

@provide_session
def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
values = [arg.resolve(context, session=session) for arg in self.args]
for value in values:
if not isinstance(value, (Sequence, dict)):
raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}")
return _ZipResult(values, fillvalue=self.fillvalue)


_XCOM_ARG_TYPES: Mapping[str, Type[XComArg]] = {
"": PlainXComArg,
"map": MapXComArg,
"zip": ZipXComArg,
}


Expand Down
44 changes: 44 additions & 0 deletions tests/models/test_xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from airflow.models.xcom_arg import XComArg
from airflow.operators.bash import BashOperator
from airflow.operators.python import PythonOperator
from airflow.utils.types import NOTSET
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_dags, clear_db_runs

Expand Down Expand Up @@ -177,3 +178,46 @@ def push_xcom_value(key, value, **context):
)
op1 >> op2
dag.run()


@pytest.mark.parametrize(
"fillvalue, expected_results",
[
(NOTSET, {("a", 1), ("b", 2), ("c", 3)}),
(None, {("a", 1), ("b", 2), ("c", 3), (None, 4)}),
],
)
def test_xcom_zip(dag_maker, session, fillvalue, expected_results):
results = set()
with dag_maker(session=session) as dag:

@dag.task
def push_letters():
return ["a", "b", "c"]

@dag.task
def push_numbers():
return [1, 2, 3, 4]

@dag.task
def pull(value):
results.add(value)

pull.expand(value=push_letters().zip(push_numbers(), fillvalue=fillvalue))

dr = dag_maker.create_dagrun()

# Run "push_letters" and "push_numbers".
decision = dr.task_instance_scheduling_decisions(session=session)
assert decision.schedulable_tis and all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis)
for ti in decision.schedulable_tis:
ti.run(session=session)
session.commit()

# Run "pull".
decision = dr.task_instance_scheduling_decisions(session=session)
assert decision.schedulable_tis and all(ti.task_id == "pull" for ti in decision.schedulable_tis)
for ti in decision.schedulable_tis:
ti.run(session=session)

assert results == expected_results
44 changes: 44 additions & 0 deletions tests/models/test_xcom_arg_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,47 @@ def pull(value):
for ti in decision.schedulable_tis:
ti.run()
assert results == {"aa", "bb", "cc"}


def test_xcom_map_zip_nest(dag_maker, session):
results = set()

with dag_maker(session=session) as dag:

@dag.task
def push_letters():
return ["a", "b", "c", "d"]

@dag.task
def push_numbers():
return [1, 2, 3, 4]

@dag.task
def pull(value):
results.add(value)

doubled = push_numbers().map(lambda v: v * 2)
combined = doubled.zip(push_letters())

def convert_zipped(zipped):
letter, number = zipped
return letter * number

pull.expand(value=combined.map(convert_zipped))

dr = dag_maker.create_dagrun()

# Run "push_letters" and "push_numbers".
decision = dr.task_instance_scheduling_decisions(session=session)
assert decision.schedulable_tis and all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis)
for ti in decision.schedulable_tis:
ti.run(session=session)
session.commit()

# Run "pull".
decision = dr.task_instance_scheduling_decisions(session=session)
assert decision.schedulable_tis and all(ti.task_id == "pull" for ti in decision.schedulable_tis)
for ti in decision.schedulable_tis:
ti.run(session=session)

assert results == {"aa", "bbbb", "cccccc", "dddddddd"}

0 comments on commit 98833b9

Please sign in to comment.