From b90fc14e0c35e679b164ef7bcab22d7a44e0210e Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Mon, 1 Aug 2022 17:12:31 +0800 Subject: [PATCH] Implement XComArg.zip(*xcom_args) (#25176) --- airflow/models/expandinput.py | 88 +----- airflow/models/mappedoperator.py | 8 +- airflow/models/xcom_arg.py | 269 ++++++++++++++++-- airflow/serialization/serialized_objects.py | 19 +- tests/decorators/test_python.py | 7 +- tests/models/test_xcom_arg.py | 44 +++ tests/models/test_xcom_arg_map.py | 44 +++ tests/serialization/test_dag_serialization.py | 28 +- 8 files changed, 372 insertions(+), 135 deletions(-) diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py index 5a94698f9ba46..1b8f7fa80e25c 100644 --- a/airflow/models/expandinput.py +++ b/airflow/models/expandinput.py @@ -24,13 +24,12 @@ import operator from typing import TYPE_CHECKING, Any, Iterable, Mapping, NamedTuple, Sequence, Sized, Union -from sqlalchemy import func -from sqlalchemy.orm import Session - from airflow.compat.functools import cache from airflow.utils.context import Context if TYPE_CHECKING: + from sqlalchemy.orm import Session + from airflow.models.xcom_arg import XComArg ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"] @@ -95,63 +94,16 @@ def _get_map_lengths(self, run_id: str, *, session: Session) -> dict[str, int]: If any arguments are not known right now (upstream task not finished), they will not be present in the dict. """ - from airflow.models.taskmap import TaskMap - from airflow.models.xcom import XCOM_RETURN_KEY, XCom from airflow.models.xcom_arg import XComArg - # Populate literal mapped arguments first. - map_lengths: dict[str, int] = collections.defaultdict(int) - map_lengths.update((k, len(v)) for k, v in self.value.items() if not isinstance(v, XComArg)) - - try: - dag_id = next(v.operator.dag_id for v in self.value.values() if isinstance(v, XComArg)) - except StopIteration: # All mapped arguments are literal. We're done. - return map_lengths - - # Build a reverse mapping of what arguments each task contributes to. - mapped_dep_keys: dict[str, set[str]] = collections.defaultdict(set) - non_mapped_dep_keys: dict[str, set[str]] = collections.defaultdict(set) - for k, v in self.value.items(): - if not isinstance(v, XComArg): - continue - assert v.operator.dag_id == dag_id - if v.operator.is_mapped: - mapped_dep_keys[v.operator.task_id].add(k) - else: - non_mapped_dep_keys[v.operator.task_id].add(k) - # TODO: It's not possible now, but in the future we may support - # depending on one single mapped task instance. When that happens, - # we need to further analyze the mapped case to contain only tasks - # we depend on "as a whole", and put those we only depend on - # individually to the non-mapped lookup. - - # Collect lengths from unmapped upstreams. - taskmap_query = session.query(TaskMap.task_id, TaskMap.length).filter( - TaskMap.dag_id == dag_id, - TaskMap.run_id == run_id, - TaskMap.task_id.in_(non_mapped_dep_keys), - TaskMap.map_index < 0, - ) - for task_id, length in taskmap_query: - for mapped_arg_name in non_mapped_dep_keys[task_id]: - map_lengths[mapped_arg_name] += length - - # Collect lengths from mapped upstreams. - xcom_query = ( - session.query(XCom.task_id, func.count(XCom.map_index)) - .group_by(XCom.task_id) - .filter( - XCom.dag_id == dag_id, - XCom.run_id == run_id, - XCom.key == XCOM_RETURN_KEY, - XCom.task_id.in_(mapped_dep_keys), - XCom.map_index >= 0, - ) + # TODO: This initiates one database call for each XComArg. Would it be + # more efficient to do one single db call and unpack the value here? + map_lengths_iterator = ( + (k, (v.get_task_map_length(run_id, session=session) if isinstance(v, XComArg) else len(v))) + for k, v in self.value.items() ) - for task_id, length in xcom_query: - for mapped_arg_name in mapped_dep_keys[task_id]: - map_lengths[mapped_arg_name] += length + map_lengths = {k: v for k, v in map_lengths_iterator if v is not None} if len(map_lengths) < len(self.value): raise NotFullyPopulated(set(self.value).difference(map_lengths)) return map_lengths @@ -228,28 +180,10 @@ def get_parse_time_mapped_ti_count(self) -> int | None: return None def get_total_map_length(self, run_id: str, *, session: Session) -> int: - from airflow.models.taskmap import TaskMap - from airflow.models.xcom import XCom - - task = self.value.operator - if task.is_mapped: - query = session.query(func.count(XCom.map_index)).filter( - XCom.dag_id == task.dag_id, - XCom.run_id == run_id, - XCom.task_id == task.task_id, - XCom.map_index >= 0, - ) - else: - query = session.query(TaskMap.length).filter( - TaskMap.dag_id == task.dag_id, - TaskMap.run_id == run_id, - TaskMap.task_id == task.task_id, - TaskMap.map_index < 0, - ) - value = query.scalar() - if value is None: + length = self.value.get_task_map_length(run_id, session=session) + if length is None: raise NotFullyPopulated({"expand_kwargs() argument"}) - return value + return length def resolve(self, context: Context, session: Session) -> Mapping[str, Any]: map_index = context["ti"].map_index diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 53ea072a799ed..f62d1c687b58e 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -138,8 +138,9 @@ def ensure_xcomarg_return_value(arg: Any) -> None: from airflow.models.xcom_arg import XCOM_RETURN_KEY, XComArg if isinstance(arg, XComArg): - if arg.key != XCOM_RETURN_KEY: - raise ValueError(f"cannot map over XCom with custom key {arg.key!r} from {arg.operator}") + for operator, key in arg.iter_references(): + if key != XCOM_RETURN_KEY: + raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}") elif not is_container(arg): return elif isinstance(arg, collections.abc.Mapping): @@ -704,7 +705,8 @@ def iter_mapped_dependencies(self) -> Iterator["Operator"]: from airflow.models.xcom_arg import XComArg for ref in XComArg.iter_xcom_args(self._get_specified_expand_input()): - yield ref.operator + for operator, _ in ref.iter_references(): + yield operator @cached_property def parse_time_mapped_ti_count(self) -> Optional[int]: diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 0a602d0487ac8..a4c2b4d46d6e3 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -15,7 +15,25 @@ # specific language governing permissions and limitations # under the License. # -from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Optional, Sequence, Type, Union, overload +import inspect +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, + overload, +) + +from sqlalchemy import func +from sqlalchemy.orm import Session from airflow.exceptions import AirflowException from airflow.models.abstractoperator import AbstractOperator @@ -27,8 +45,7 @@ from airflow.utils.types import NOTSET if TYPE_CHECKING: - from sqlalchemy.orm import Session - + from airflow.models.dag import DAG from airflow.models.operator import Operator @@ -65,9 +82,6 @@ class XComArg(DependencyMixin): i.e. the referenced operator's return value. """ - operator: "Operator" - key: str - @overload def __new__(cls: Type["XComArg"], operator: "Operator", key: str = XCOM_RETURN_KEY) -> "XComArg": """Called when the user writes ``XComArg(...)`` directly.""" @@ -109,17 +123,18 @@ def apply_upstream_relationship(op: "Operator", arg: Any): sets the relationship to ``op`` on any found. """ for ref in XComArg.iter_xcom_args(arg): - op.set_upstream(ref.operator) + for operator, _ in ref.iter_references(): + op.set_upstream(operator) @property def roots(self) -> List[DAGNode]: """Required by TaskMixin""" - return [self.operator] + return [op for op, _ in self.iter_references()] @property def leaves(self) -> List[DAGNode]: """Required by TaskMixin""" - return [self.operator] + return [op for op, _ in self.iter_references()] def set_upstream( self, @@ -127,7 +142,8 @@ def set_upstream( edge_modifier: Optional[EdgeModifier] = None, ): """Proxy to underlying operator set_upstream method. Required by TaskMixin.""" - self.operator.set_upstream(task_or_task_list, edge_modifier) + for operator, _ in self.iter_references(): + operator.set_upstream(task_or_task_list, edge_modifier) def set_downstream( self, @@ -135,9 +151,51 @@ def set_downstream( edge_modifier: Optional[EdgeModifier] = None, ): """Proxy to underlying operator set_downstream method. Required by TaskMixin.""" - self.operator.set_downstream(task_or_task_list, edge_modifier) + for operator, _ in self.iter_references(): + operator.set_downstream(task_or_task_list, edge_modifier) + + def _serialize(self) -> Dict[str, Any]: + """Called by DAG serialization. + + The implementation should be the inverse function to ``deserialize``, + returning a data dict converted from this XComArg derivative. DAG + serialization does not call this directly, but ``serialize_xcom_arg`` + instead, which adds additional information to dispatch deserialization + to the correct class. + """ + raise NotImplementedError() + + @classmethod + def _deserialize(cls, data: Dict[str, Any], dag: "DAG") -> "XComArg": + """Called when deserializing a DAG. + + The implementation should be the inverse function to ``serialize``, + implementing given a data dict converted from this XComArg derivative, + how the original XComArg should be created. DAG serialization relies on + additional information added in ``serialize_xcom_arg`` to dispatch data + dicts to the correct ``_deserialize`` information, so this function does + not need to validate whether the incoming data contains correct keys. + """ + raise NotImplementedError() + + def iter_references(self) -> Iterator[Tuple["Operator", str]]: + """Iterate through (operator, key) references.""" + raise NotImplementedError() def map(self, f: Callable[[Any], Any]) -> "MapXComArg": + return MapXComArg(self, [f]) + + def zip(self, *others: "XComArg", fillvalue: Any = NOTSET) -> "ZipXComArg": + return ZipXComArg([self, *others], fillvalue=fillvalue) + + def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]: + """Inspect length of pushed value for task-mapping. + + This is used to determine how many task instances the scheduler should + create for a downstream using this XComArg for task-mapping. + + *None* may be returned if the depended XCom has not been pushed. + """ raise NotImplementedError() def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any: @@ -166,7 +224,7 @@ def __init__(self, operator: "Operator", key: str = XCOM_RETURN_KEY): self.operator = operator self.key = key - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, PlainXComArg): return NotImplemented return self.operator == other.operator and self.key == other.key @@ -191,7 +249,12 @@ def __iter__(self): """ raise TypeError("'XComArg' object is not iterable") - def __str__(self): + def __repr__(self) -> str: + if self.key == XCOM_RETURN_KEY: + return f"XComArg({self.operator!r})" + return f"XComArg({self.operator!r}, {self.key!r})" + + def __str__(self) -> str: """ Backward compatibility for old-style jinja used in Airflow Operators @@ -203,20 +266,57 @@ def __str__(self): """ xcom_pull_kwargs = [ f"task_ids='{self.operator.task_id}'", - f"dag_id='{self.operator.dag.dag_id}'", + f"dag_id='{self.operator.dag_id}'", ] if self.key is not None: xcom_pull_kwargs.append(f"key='{self.key}'") - xcom_pull_kwargs = ", ".join(xcom_pull_kwargs) + xcom_pull_str = ", ".join(xcom_pull_kwargs) # {{{{ are required for escape {{ in f-string - xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_kwargs}) }}}}" + xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_str}) }}}}" return xcom_pull + def _serialize(self) -> Dict[str, Any]: + return {"task_id": self.operator.task_id, "key": self.key} + + @classmethod + def _deserialize(cls, data: Dict[str, Any], dag: "DAG") -> XComArg: + return cls(dag.get_task(data["task_id"]), data["key"]) + + def iter_references(self) -> Iterator[Tuple["Operator", str]]: + yield self.operator, self.key + def map(self, f: Callable[[Any], Any]) -> "MapXComArg": if self.key != XCOM_RETURN_KEY: - raise ValueError - return MapXComArg(self, [f]) + raise ValueError("cannot map against non-return XCom") + return super().map(f) + + def zip(self, *others: "XComArg", fillvalue: Any = NOTSET) -> "ZipXComArg": + if self.key != XCOM_RETURN_KEY: + raise ValueError("cannot map against non-return XCom") + return super().zip(*others, fillvalue=fillvalue) + + def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]: + from airflow.models.taskmap import TaskMap + from airflow.models.xcom import XCom + + task = self.operator + if task.is_mapped: + query = session.query(func.count(XCom.map_index)).filter( + XCom.dag_id == task.dag_id, + XCom.run_id == run_id, + XCom.task_id == task.task_id, + XCom.map_index >= 0, + XCom.key == XCOM_RETURN_KEY, + ) + else: + query = session.query(TaskMap.length).filter( + TaskMap.dag_id == task.dag_id, + TaskMap.run_id == run_id, + TaskMap.task_id == task.task_id, + TaskMap.map_index < 0, + ) + return query.scalar() @provide_session def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any: @@ -257,23 +357,140 @@ class MapXComArg(XComArg): convert the pulled XCom value. """ - def __init__(self, arg: PlainXComArg, callables: Sequence[Callable[[Any], Any]]) -> None: + def __init__(self, arg: XComArg, callables: Sequence[Callable[[Any], Any]]) -> None: self.arg = arg self.callables = callables - @property - def operator(self) -> "Operator": # type: ignore[override] - return self.arg.operator + def __repr__(self) -> str: + return f"{self.arg!r}.map([{len(self.callables)} functions])" - @property - def key(self) -> str: # type: ignore[override] - return self.arg.key + def _serialize(self) -> Dict[str, Any]: + return { + "arg": serialize_xcom_arg(self.arg), + "callables": [inspect.getsource(c) for c in self.callables], + } + + @classmethod + def _deserialize(cls, data: Dict[str, Any], dag: "DAG") -> XComArg: + # We are deliberately NOT deserializing the callables. These are shown + # in the UI, and displaying a function object is useless. + return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"]) + + def iter_references(self) -> Iterator[Tuple["Operator", str]]: + yield from self.arg.iter_references() def map(self, f: Callable[[Any], Any]) -> "MapXComArg": + # Flatten arg.map(f1).map(f2) into one MapXComArg. return MapXComArg(self.arg, [*self.callables, f]) + def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]: + return self.arg.get_task_map_length(run_id, session=session) + @provide_session def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any: value = self.arg.resolve(context, session=session) - assert isinstance(value, (Sequence, dict)) # Validation was done when XCom was pushed. + if not isinstance(value, (Sequence, dict)): + raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") return _MapResult(value, self.callables) + + +class _ZipResult(Sequence): + def __init__(self, values: Sequence[Union[Sequence, dict]], *, fillvalue: Any = NOTSET) -> None: + self.values = values + self.fillvalue = fillvalue + + @staticmethod + def _get_or_fill(container: Union[Sequence, dict], index: Any, fillvalue: Any) -> Any: + try: + return container[index] + except (IndexError, KeyError): + return fillvalue + + def __getitem__(self, index: Any) -> Any: + if index >= len(self): + raise IndexError(index) + return tuple(self._get_or_fill(value, index, self.fillvalue) for value in self.values) + + def __len__(self) -> int: + lengths = (len(v) for v in self.values) + if self.fillvalue is NOTSET: + return min(lengths) + return max(lengths) + + +class ZipXComArg(XComArg): + """An XCom reference with ``zip()`` applied. + + This is constructed from multiple XComArg instances, and presents an + iterable that "zips" them together like the built-in ``zip()`` (and + ``itertools.zip_longest()`` if ``fillvalue`` is provided). + """ + + def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None: + if not args: + raise ValueError("At least one input is required") + self.args = args + self.fillvalue = fillvalue + + def __repr__(self) -> str: + args_iter = iter(self.args) + first = repr(next(args_iter)) + rest = ", ".join(repr(arg) for arg in args_iter) + if self.fillvalue is NOTSET: + return f"{first}.zip({rest})" + return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})" + + def _serialize(self) -> Dict[str, Any]: + args = [serialize_xcom_arg(arg) for arg in self.args] + if self.fillvalue is NOTSET: + return {"args": args} + return {"args": args, "fillvalue": self.fillvalue} + + @classmethod + def _deserialize(cls, data: Dict[str, Any], dag: "DAG") -> XComArg: + return cls( + [deserialize_xcom_arg(arg, dag) for arg in data["args"]], + fillvalue=data.get("fillvalue", NOTSET), + ) + + def iter_references(self) -> Iterator[Tuple["Operator", str]]: + for arg in self.args: + yield from arg.iter_references() + + def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]: + all_lengths = (arg.get_task_map_length(run_id, session=session) for arg in self.args) + ready_lengths = [length for length in all_lengths if length is not None] + if len(ready_lengths) != len(self.args): + return None # If any of the referenced XComs is not ready, we are not ready either. + if self.fillvalue is NOTSET: + return min(ready_lengths) + return max(ready_lengths) + + @provide_session + def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any: + values = [arg.resolve(context, session=session) for arg in self.args] + for value in values: + if not isinstance(value, (Sequence, dict)): + raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}") + return _ZipResult(values, fillvalue=self.fillvalue) + + +_XCOM_ARG_TYPES: Mapping[str, Type[XComArg]] = { + "": PlainXComArg, + "map": MapXComArg, + "zip": ZipXComArg, +} + + +def serialize_xcom_arg(value: XComArg) -> Dict[str, Any]: + """DAG serialization interface.""" + key = next(k for k, v in _XCOM_ARG_TYPES.items() if v == type(value)) + if key: + return {"type": key, **value._serialize()} + return value._serialize() + + +def deserialize_xcom_arg(data: Dict[str, Any], dag: "DAG") -> XComArg: + """DAG serialization interface.""" + klass = _XCOM_ARG_TYPES[data.get("type", "")] + return klass._deserialize(data, dag) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index e24aa826a4bca..3ace47e2b0ef1 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -43,7 +43,7 @@ from airflow.models.operator import Operator from airflow.models.param import Param, ParamsDict from airflow.models.taskmixin import DAGNode -from airflow.models.xcom_arg import XComArg +from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg from airflow.operators.trigger_dagrun import TriggerDagRunOperator from airflow.providers_manager import ProvidersManager from airflow.sensors.external_task import ExternalTaskSensor @@ -202,11 +202,10 @@ class _XComRef(NamedTuple): post-process it in ``deserialize_dag``. """ - task_id: str - key: str + data: dict def deref(self, dag: DAG) -> XComArg: - return XComArg(operator=dag.get_task(self.task_id), key=self.key) + return deserialize_xcom_arg(self.data, dag) class _ExpandInputRef(NamedTuple): @@ -393,7 +392,7 @@ def _serialize(cls, var: Any) -> Any: # Unfortunately there is no support for r elif isinstance(var, Param): return cls._encode(cls._serialize_param(var), type_=DAT.PARAM) elif isinstance(var, XComArg): - return cls._encode(cls._serialize_xcomarg(var), type_=DAT.XCOM_REF) + return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF) elif isinstance(var, Dataset): return cls._encode(dict(uri=var.uri, extra=var.extra), type_=DAT.DATASET) else: @@ -440,7 +439,7 @@ def _deserialize(cls, encoded_var: Any) -> Any: elif type_ == DAT.PARAM: return cls._deserialize_param(var) elif type_ == DAT.XCOM_REF: - return cls._deserialize_xcomref(var) + return _XComRef(var) # Delay deserializing XComArg objects until we have the entire DAG. elif type_ == DAT.DATASET: return Dataset(**var) else: @@ -545,14 +544,6 @@ def _deserialize_params_dict(cls, encoded_params: Dict) -> ParamsDict: return ParamsDict(op_params) - @classmethod - def _serialize_xcomarg(cls, arg: XComArg) -> dict: - return {"key": arg.key, "task_id": arg.operator.task_id} - - @classmethod - def _deserialize_xcomref(cls, encoded: dict) -> _XComRef: - return _XComRef(key=encoded['key'], task_id=encoded['task_id']) - class DependencyDetector: """ diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index 58ae1c5f8792d..199366df8ed9c 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -32,7 +32,7 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap from airflow.models.xcom import XCOM_RETURN_KEY -from airflow.models.xcom_arg import XComArg +from airflow.models.xcom_arg import PlainXComArg, XComArg from airflow.utils import timezone from airflow.utils.state import State from airflow.utils.task_group import TaskGroup @@ -649,13 +649,16 @@ def product(number: int, multiple: int): product.partial(multiple=2) # No operator is actually created. + assert isinstance(doubled, PlainXComArg) + assert isinstance(trippled, PlainXComArg) + assert isinstance(quadrupled, PlainXComArg) + assert dag.task_dict == { "product": quadrupled.operator, "product__1": doubled.operator, "product__2": trippled.operator, } - assert isinstance(doubled, XComArg) assert isinstance(doubled.operator, DecoratedMappedOperator) assert doubled.operator.op_kwargs_expand_input == DictOfListsExpandInput({"number": literal}) assert doubled.operator.partial_kwargs["op_kwargs"] == {"multiple": 2} diff --git a/tests/models/test_xcom_arg.py b/tests/models/test_xcom_arg.py index cd3b548285933..047412a248df8 100644 --- a/tests/models/test_xcom_arg.py +++ b/tests/models/test_xcom_arg.py @@ -19,6 +19,7 @@ from airflow.models.xcom_arg import XComArg from airflow.operators.bash import BashOperator from airflow.operators.python import PythonOperator +from airflow.utils.types import NOTSET from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_runs @@ -177,3 +178,46 @@ def push_xcom_value(key, value, **context): ) op1 >> op2 dag.run() + + +@pytest.mark.parametrize( + "fillvalue, expected_results", + [ + (NOTSET, {("a", 1), ("b", 2), ("c", 3)}), + (None, {("a", 1), ("b", 2), ("c", 3), (None, 4)}), + ], +) +def test_xcom_zip(dag_maker, session, fillvalue, expected_results): + results = set() + with dag_maker(session=session) as dag: + + @dag.task + def push_letters(): + return ["a", "b", "c"] + + @dag.task + def push_numbers(): + return [1, 2, 3, 4] + + @dag.task + def pull(value): + results.add(value) + + pull.expand(value=push_letters().zip(push_numbers(), fillvalue=fillvalue)) + + dr = dag_maker.create_dagrun() + + # Run "push_letters" and "push_numbers". + decision = dr.task_instance_scheduling_decisions(session=session) + assert decision.schedulable_tis and all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis) + for ti in decision.schedulable_tis: + ti.run(session=session) + session.commit() + + # Run "pull". + decision = dr.task_instance_scheduling_decisions(session=session) + assert decision.schedulable_tis and all(ti.task_id == "pull" for ti in decision.schedulable_tis) + for ti in decision.schedulable_tis: + ti.run(session=session) + + assert results == expected_results diff --git a/tests/models/test_xcom_arg_map.py b/tests/models/test_xcom_arg_map.py index 144eb20327596..9da732b93d360 100644 --- a/tests/models/test_xcom_arg_map.py +++ b/tests/models/test_xcom_arg_map.py @@ -257,3 +257,47 @@ def pull(value): for ti in decision.schedulable_tis: ti.run() assert results == {"aa", "bb", "cc"} + + +def test_xcom_map_zip_nest(dag_maker, session): + results = set() + + with dag_maker(session=session) as dag: + + @dag.task + def push_letters(): + return ["a", "b", "c", "d"] + + @dag.task + def push_numbers(): + return [1, 2, 3, 4] + + @dag.task + def pull(value): + results.add(value) + + doubled = push_numbers().map(lambda v: v * 2) + combined = doubled.zip(push_letters()) + + def convert_zipped(zipped): + letter, number = zipped + return letter * number + + pull.expand(value=combined.map(convert_zipped)) + + dr = dag_maker.create_dagrun() + + # Run "push_letters" and "push_numbers". + decision = dr.task_instance_scheduling_decisions(session=session) + assert decision.schedulable_tis and all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis) + for ti in decision.schedulable_tis: + ti.run(session=session) + session.commit() + + # Run "pull". + decision = dr.task_instance_scheduling_decisions(session=session) + assert decision.schedulable_tis and all(ti.task_id == "pull" for ti in decision.schedulable_tis) + for ti in decision.schedulable_tis: + ti.run(session=session) + + assert results == {"aa", "bbbb", "cccccc", "dddddddd"} diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 5751ae137c5de..c6444b8aaa5da 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -1731,7 +1731,8 @@ def test_operator_expand_serde(): def test_operator_expand_xcomarg_serde(): - from airflow.models.xcom_arg import XComArg + from airflow.models.xcom_arg import PlainXComArg, XComArg + from airflow.serialization.serialized_objects import _XComRef with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag: task1 = BaseOperator(task_id="op1") @@ -1766,20 +1767,21 @@ def test_operator_expand_xcomarg_serde(): op = SerializedBaseOperator.deserialize_operator(serialized) assert op.deps is MappedOperator.deps_for(BaseOperator) - arg = op.expand_input.value['arg2'] - assert arg.task_id == 'op1' - assert arg.key == XCOM_RETURN_KEY + # The XComArg can't be deserialized before the DAG is. + xcom_ref = op.expand_input.value['arg2'] + assert xcom_ref == _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}) serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) xcom_arg = serialized_dag.task_dict['task_2'].expand_input.value['arg2'] - assert isinstance(xcom_arg, XComArg) + assert isinstance(xcom_arg, PlainXComArg) assert xcom_arg.operator is serialized_dag.task_dict['op1'] @pytest.mark.parametrize("strict", [True, False]) def test_operator_expand_kwargs_serde(strict): - from airflow.models.xcom_arg import XComArg + from airflow.models.xcom_arg import PlainXComArg, XComArg + from airflow.serialization.serialized_objects import _XComRef with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag: task1 = BaseOperator(task_id="op1") @@ -1812,14 +1814,14 @@ def test_operator_expand_kwargs_serde(strict): assert op.deps is MappedOperator.deps_for(BaseOperator) assert op._disallow_kwargs_override == strict + # The XComArg can't be deserialized before the DAG is. xcom_ref = op.expand_input.value - assert xcom_ref.task_id == 'op1' - assert xcom_ref.key == XCOM_RETURN_KEY + assert xcom_ref == _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}) serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) xcom_arg = serialized_dag.task_dict['task_2'].expand_input.value - assert isinstance(xcom_arg, XComArg) + assert isinstance(xcom_arg, PlainXComArg) assert xcom_arg.operator is serialized_dag.task_dict['op1'] @@ -1913,7 +1915,7 @@ def x(arg1, arg2, arg3): assert deserialized.op_kwargs_expand_input == _ExpandInputRef( key="dict-of-lists", - value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef("op1", XCOM_RETURN_KEY)}, + value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})}, ) assert deserialized.partial_kwargs == { "op_args": [], @@ -1928,7 +1930,7 @@ def x(arg1, arg2, arg3): pickled = pickle.loads(pickle.dumps(deserialized)) assert pickled.op_kwargs_expand_input == _ExpandInputRef( key="dict-of-lists", - value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef("op1", XCOM_RETURN_KEY)}, + value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})}, ) assert pickled.partial_kwargs == { "op_args": [], @@ -1996,7 +1998,7 @@ def x(arg1, arg2, arg3): assert deserialized.op_kwargs_expand_input == _ExpandInputRef( key="list-of-dicts", - value=_XComRef("op1", XCOM_RETURN_KEY), + value=_XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}), ) assert deserialized.partial_kwargs == { "op_args": [], @@ -2011,7 +2013,7 @@ def x(arg1, arg2, arg3): pickled = pickle.loads(pickle.dumps(deserialized)) assert pickled.op_kwargs_expand_input == _ExpandInputRef( "list-of-dicts", - _XComRef("op1", XCOM_RETURN_KEY), + _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}), ) assert pickled.partial_kwargs == { "op_args": [],