From 98833b9759695d03670e9b7e0e2242b2d46e9636 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Tue, 26 Jul 2022 14:57:05 +0800 Subject: [PATCH] Implement XComArg.zip() With the foundation laid down, this is relatively straightforward. The only thing we need to be aware of is to make the resolved value indexable, and thus we cannot use the built-in zip() (which returns an iterator). --- airflow/models/xcom_arg.py | 100 ++++++++++++++++++++++++++++-- tests/models/test_xcom_arg.py | 44 +++++++++++++ tests/models/test_xcom_arg_map.py | 44 +++++++++++++ 3 files changed, 184 insertions(+), 4 deletions(-) diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 7a6d395e5f5ad..a4c2b4d46d6e3 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -183,7 +183,10 @@ def iter_references(self) -> Iterator[Tuple["Operator", str]]: raise NotImplementedError() def map(self, f: Callable[[Any], Any]) -> "MapXComArg": - raise NotImplementedError() + return MapXComArg(self, [f]) + + def zip(self, *others: "XComArg", fillvalue: Any = NOTSET) -> "ZipXComArg": + return ZipXComArg([self, *others], fillvalue=fillvalue) def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]: """Inspect length of pushed value for task-mapping. 
@@ -285,8 +288,13 @@ def iter_references(self) -> Iterator[Tuple["Operator", str]]: def map(self, f: Callable[[Any], Any]) -> "MapXComArg": if self.key != XCOM_RETURN_KEY: - raise ValueError - return MapXComArg(self, [f]) + raise ValueError("cannot map against non-return XCom") + return super().map(f) + + def zip(self, *others: "XComArg", fillvalue: Any = NOTSET) -> "ZipXComArg": + if self.key != XCOM_RETURN_KEY: + raise ValueError("cannot zip against non-return XCom") + return super().zip(*others, fillvalue=fillvalue) def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]: from airflow.models.taskmap import TaskMap @@ -372,6 +380,7 @@ def iter_references(self) -> Iterator[Tuple["Operator", str]]: yield from self.arg.iter_references() def map(self, f: Callable[[Any], Any]) -> "MapXComArg": + # Flatten arg.map(f1).map(f2) into one MapXComArg. return MapXComArg(self.arg, [*self.callables, f]) def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]: @@ -380,13 +389,96 @@ def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[in @provide_session def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any: value = self.arg.resolve(context, session=session) - assert isinstance(value, (Sequence, dict)) # Validation was done when XCom was pushed. 
+ if not isinstance(value, (Sequence, dict)): + raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") return _MapResult(value, self.callables) +class _ZipResult(Sequence): + def __init__(self, values: Sequence[Union[Sequence, dict]], *, fillvalue: Any = NOTSET) -> None: + self.values = values + self.fillvalue = fillvalue + + @staticmethod + def _get_or_fill(container: Union[Sequence, dict], index: Any, fillvalue: Any) -> Any: + try: + return container[index] + except (IndexError, KeyError): + return fillvalue + + def __getitem__(self, index: Any) -> Any: + if index >= len(self): + raise IndexError(index) + return tuple(self._get_or_fill(value, index, self.fillvalue) for value in self.values) + + def __len__(self) -> int: + lengths = (len(v) for v in self.values) + if self.fillvalue is NOTSET: + return min(lengths) + return max(lengths) + + +class ZipXComArg(XComArg): + """An XCom reference with ``zip()`` applied. + + This is constructed from multiple XComArg instances, and presents an + iterable that "zips" them together like the built-in ``zip()`` (and + ``itertools.zip_longest()`` if ``fillvalue`` is provided). 
+ """ + + def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None: + if not args: + raise ValueError("At least one input is required") + self.args = args + self.fillvalue = fillvalue + + def __repr__(self) -> str: + args_iter = iter(self.args) + first = repr(next(args_iter)) + rest = ", ".join(repr(arg) for arg in args_iter) + if self.fillvalue is NOTSET: + return f"{first}.zip({rest})" + return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})" + + def _serialize(self) -> Dict[str, Any]: + args = [serialize_xcom_arg(arg) for arg in self.args] + if self.fillvalue is NOTSET: + return {"args": args} + return {"args": args, "fillvalue": self.fillvalue} + + @classmethod + def _deserialize(cls, data: Dict[str, Any], dag: "DAG") -> XComArg: + return cls( + [deserialize_xcom_arg(arg, dag) for arg in data["args"]], + fillvalue=data.get("fillvalue", NOTSET), + ) + + def iter_references(self) -> Iterator[Tuple["Operator", str]]: + for arg in self.args: + yield from arg.iter_references() + + def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]: + all_lengths = (arg.get_task_map_length(run_id, session=session) for arg in self.args) + ready_lengths = [length for length in all_lengths if length is not None] + if len(ready_lengths) != len(self.args): + return None # If any of the referenced XComs is not ready, we are not ready either. 
+ if self.fillvalue is NOTSET: + return min(ready_lengths) + return max(ready_lengths) + + @provide_session + def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any: + values = [arg.resolve(context, session=session) for arg in self.args] + for value in values: + if not isinstance(value, (Sequence, dict)): + raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}") + return _ZipResult(values, fillvalue=self.fillvalue) + + _XCOM_ARG_TYPES: Mapping[str, Type[XComArg]] = { "": PlainXComArg, "map": MapXComArg, + "zip": ZipXComArg, } diff --git a/tests/models/test_xcom_arg.py b/tests/models/test_xcom_arg.py index cd3b548285933..047412a248df8 100644 --- a/tests/models/test_xcom_arg.py +++ b/tests/models/test_xcom_arg.py @@ -19,6 +19,7 @@ from airflow.models.xcom_arg import XComArg from airflow.operators.bash import BashOperator from airflow.operators.python import PythonOperator +from airflow.utils.types import NOTSET from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_runs @@ -177,3 +178,46 @@ def push_xcom_value(key, value, **context): ) op1 >> op2 dag.run() + + +@pytest.mark.parametrize( + "fillvalue, expected_results", + [ + (NOTSET, {("a", 1), ("b", 2), ("c", 3)}), + (None, {("a", 1), ("b", 2), ("c", 3), (None, 4)}), + ], +) +def test_xcom_zip(dag_maker, session, fillvalue, expected_results): + results = set() + with dag_maker(session=session) as dag: + + @dag.task + def push_letters(): + return ["a", "b", "c"] + + @dag.task + def push_numbers(): + return [1, 2, 3, 4] + + @dag.task + def pull(value): + results.add(value) + + pull.expand(value=push_letters().zip(push_numbers(), fillvalue=fillvalue)) + + dr = dag_maker.create_dagrun() + + # Run "push_letters" and "push_numbers". 
+ decision = dr.task_instance_scheduling_decisions(session=session) + assert decision.schedulable_tis and all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis) + for ti in decision.schedulable_tis: + ti.run(session=session) + session.commit() + + # Run "pull". + decision = dr.task_instance_scheduling_decisions(session=session) + assert decision.schedulable_tis and all(ti.task_id == "pull" for ti in decision.schedulable_tis) + for ti in decision.schedulable_tis: + ti.run(session=session) + + assert results == expected_results diff --git a/tests/models/test_xcom_arg_map.py b/tests/models/test_xcom_arg_map.py index 144eb20327596..9da732b93d360 100644 --- a/tests/models/test_xcom_arg_map.py +++ b/tests/models/test_xcom_arg_map.py @@ -257,3 +257,47 @@ def pull(value): for ti in decision.schedulable_tis: ti.run() assert results == {"aa", "bb", "cc"} + + +def test_xcom_map_zip_nest(dag_maker, session): + results = set() + + with dag_maker(session=session) as dag: + + @dag.task + def push_letters(): + return ["a", "b", "c", "d"] + + @dag.task + def push_numbers(): + return [1, 2, 3, 4] + + @dag.task + def pull(value): + results.add(value) + + doubled = push_numbers().map(lambda v: v * 2) + combined = doubled.zip(push_letters()) + + def convert_zipped(zipped): + number, letter = zipped + return letter * number + + pull.expand(value=combined.map(convert_zipped)) + + dr = dag_maker.create_dagrun() + + # Run "push_letters" and "push_numbers". + decision = dr.task_instance_scheduling_decisions(session=session) + assert decision.schedulable_tis and all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis) + for ti in decision.schedulable_tis: + ti.run(session=session) + session.commit() + + # Run "pull". 
+ decision = dr.task_instance_scheduling_decisions(session=session) + assert decision.schedulable_tis and all(ti.task_id == "pull" for ti in decision.schedulable_tis) + for ti in decision.schedulable_tis: + ti.run(session=session) + + assert results == {"aa", "bbbb", "cccccc", "dddddddd"}