diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 95985712401cf..a36d7d5d43a7b 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -53,7 +53,12 @@
     parse_retries,
 )
 from airflow.models.dag import DAG, DagContext
-from airflow.models.expandinput import EXPAND_INPUT_EMPTY, DictOfListsExpandInput, ExpandInput
+from airflow.models.expandinput import (
+    EXPAND_INPUT_EMPTY,
+    DictOfListsExpandInput,
+    ExpandInput,
+    ListOfDictsExpandInput,
+)
 from airflow.models.mappedoperator import (
     MappedOperator,
     ValidationSource,
@@ -171,8 +176,17 @@ def __init__(
         op_args = op_args or []
         op_kwargs = op_kwargs or {}

-        # Check that arguments can be binded
-        inspect.signature(python_callable).bind(*op_args, **op_kwargs)
+        # Check that arguments can be bound. There's a slight difference when
+        # we do validation for task-mapping: Since there's no guarantee we can
+        # receive enough arguments at parse time, we use bind_partial to simply
+        # check all the arguments we know are valid. Whether these are enough
+        # can only be known at execution time, when unmapping happens, and this
+        # is called without the _airflow_mapped_validation_only flag.
+        if kwargs.get("_airflow_mapped_validation_only"):
+            inspect.signature(python_callable).bind_partial(*op_args, **op_kwargs)
+        else:
+            inspect.signature(python_callable).bind(*op_args, **op_kwargs)
+
         self.multiple_outputs = multiple_outputs
         self.op_args = op_args
         self.op_kwargs = op_kwargs
@@ -323,6 +337,13 @@ def expand(self, **map_kwargs: "Mappable") -> XComArg:
         # to False to skip the checks on execution.
         return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)

+    def expand_kwargs(self, kwargs: "XComArg", *, strict: bool = True) -> XComArg:
+        from airflow.models.xcom_arg import XComArg
+
+        if not isinstance(kwargs, XComArg):
+            raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}")
+        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)
+
     def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
         ensure_xcomarg_return_value(expand_input.value)
@@ -442,10 +463,11 @@ def _get_unmap_kwargs(self, mapped_kwargs: Dict[str, Any], *, strict: bool) -> D
             mapped_kwargs["op_kwargs"],
             fail_reason="mapping already partial",
         )
+
+        static_kwargs = {k for k, _ in self.op_kwargs_expand_input.iter_parse_time_resolved_kwargs()}
         self._combined_op_kwargs = {**self.partial_kwargs["op_kwargs"], **mapped_kwargs["op_kwargs"]}
-        self._already_resolved_op_kwargs = {
-            k for k, v in self.op_kwargs_expand_input.value.items() if isinstance(v, XComArg)
-        }
+        self._already_resolved_op_kwargs = {k for k in mapped_kwargs["op_kwargs"] if k not in static_kwargs}
+
         kwargs = {
             "multiple_outputs": self.multiple_outputs,
             "python_callable": self.python_callable,
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index f1a8c1cb66473..7a91100f1122a 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -113,12 +113,26 @@ def __str__(self) -> str:
 class UnmappableXComTypePushed(AirflowException):
     """Raise when an unmappable type is pushed as a mapped downstream's dependency."""

-    def __init__(self, value: Any) -> None:
-        super().__init__(value)
+    def __init__(self, value: Any, *values: Any) -> None:
+        super().__init__(value, *values)
+
+    def __str__(self) -> str:
+        typename = type(self.args[0]).__qualname__
+        for arg in self.args[1:]:
+            typename = f"{typename}[{type(arg).__qualname__}]"
+        return f"unmappable return type {typename!r}"
+
+
+class UnmappableXComValuePushed(AirflowException):
"""Raise when an invalid value is pushed as a mapped downstream's dependency.""" + + def __init__(self, value: Any, reason: str) -> None: + super().__init__(value, reason) self.value = value + self.reason = reason def __str__(self) -> str: - return f"unmappable return type {type(self.value).__qualname__!r}" + return f"unmappable return value {self.value!r} ({self.reason})" class UnmappableXComLengthPushed(AirflowException): diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 05ef24a124b3c..2795a0f538787 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -756,6 +756,7 @@ def __init__( super().__init__() + kwargs.pop("_airflow_mapped_validation_only", None) if kwargs: if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'): raise AirflowException( @@ -1509,7 +1510,7 @@ def defer( def validate_mapped_arguments(cls, **kwargs: Any) -> None: """Validate arguments when this operator is being mapped.""" if cls.mapped_arguments_validated_by_init: - cls(**kwargs, _airflow_from_mapped=True) + cls(**kwargs, _airflow_from_mapped=True, _airflow_mapped_validation_only=True) def unmap(self, ctx: Union[None, Dict[str, Any], Tuple[Context, Session]]) -> "BaseOperator": """:meta private:""" diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py index 86623b41a1c7c..b5b922f9dfa5e 100644 --- a/airflow/models/expandinput.py +++ b/airflow/models/expandinput.py @@ -22,26 +22,36 @@ import collections.abc import functools import operator -from typing import TYPE_CHECKING, Any, NamedTuple, Sequence, Union +from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Sequence, Sized, Union from sqlalchemy import func from sqlalchemy.orm import Session -from airflow.exceptions import UnmappableXComTypePushed +from airflow.compat.functools import cache +from airflow.exceptions import UnmappableXComTypePushed, UnmappableXComValuePushed from airflow.utils.context import Context if TYPE_CHECKING: from airflow.models.xcom_arg import XComArg +ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"] + # BaseOperator.expand() can be called on an XComArg, sequence, or dict (not any # mapping since we need the value to be ordered). Mappable = Union["XComArg", Sequence, dict] -MAPPABLE_LITERAL_TYPES = (dict, list) + +# For isinstance() check. +@cache +def get_mappable_types() -> tuple[type, ...]: + from airflow.models.xcom_arg import XComArg + + return (XComArg, list, tuple, dict) class NotFullyPopulated(RuntimeError): """Raise when ``get_map_lengths`` cannot populate all mapping metadata. + This is generally due to not all upstream tasks have finished when the function is called. 
""" @@ -67,10 +77,20 @@ def validate_xcom(value: Any) -> None: if not isinstance(value, collections.abc.Collection) or isinstance(value, (bytes, str)): raise UnmappableXComTypePushed(value) + def get_unresolved_kwargs(self) -> dict[str, Any]: + """Get the kwargs dict that can be inferred without resolving.""" + return self.value + + def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]: + """Generate kwargs with values available on parse-time.""" + from airflow.models.xcom_arg import XComArg + + return ((k, v) for k, v in self.value.items() if not isinstance(v, XComArg)) + def get_parse_time_mapped_ti_count(self) -> int | None: if not self.value: return 0 - literal_values = [len(v) for v in self.value.values() if isinstance(v, MAPPABLE_LITERAL_TYPES)] + literal_values = [len(v) for _, v in self.iter_parse_time_resolved_kwargs()] if len(literal_values) != len(self.value): return None # None-literal type encountered, so give up. return functools.reduce(operator.mul, literal_values, 1) @@ -184,12 +204,77 @@ def resolve(self, context: Context, session: Session) -> dict[str, Any]: return {k: self._expand_mapped_field(k, v, context, session=session) for k, v in self.value.items()} -ExpandInput = DictOfListsExpandInput +class ListOfDictsExpandInput(NamedTuple): + """Storage type of a mapped operator's mapped kwargs. + + This is created from ``expand_kwargs(xcom_arg)``. + """ + + value: XComArg + + @staticmethod + def validate_xcom(value: Any) -> None: + if not isinstance(value, collections.abc.Collection): + raise UnmappableXComTypePushed(value) + if isinstance(value, (str, bytes, collections.abc.Mapping)): + raise UnmappableXComTypePushed(value) + for item in value: + if not isinstance(item, collections.abc.Mapping): + raise UnmappableXComTypePushed(value, item) + if not all(isinstance(k, str) for k in item): + raise UnmappableXComValuePushed(value, reason="dict keys must be str") + + def get_unresolved_kwargs(self) -> dict[str, Any]: + """Get the kwargs dict that can be inferred without resolving. + + Since the list-of-dicts case relies entirely on run-time XCom, there's + no kwargs structure available, so this just returns an empty dict. + """ + return {} + + def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]: + return () + + def get_parse_time_mapped_ti_count(self) -> int | None: + return None + + def get_total_map_length(self, run_id: str, *, session: Session) -> int: + from airflow.models.taskmap import TaskMap + from airflow.models.xcom import XCom + + task = self.value.operator + if task.is_mapped: + query = session.query(func.count(XCom.map_index)).filter( + XCom.dag_id == task.dag_id, + XCom.run_id == run_id, + XCom.task_id == task.task_id, + XCom.map_index >= 0, + ) + else: + query = session.query(TaskMap.length).filter( + TaskMap.dag_id == task.dag_id, + TaskMap.run_id == run_id, + TaskMap.task_id == task.task_id, + TaskMap.map_index < 0, + ) + value = query.scalar() + if value is None: + raise NotFullyPopulated({"expand_kwargs() argument"}) + return value + + def resolve(self, context: Context, session: Session) -> dict[str, Any]: + map_index = context["ti"].map_index + if map_index < 0: + raise RuntimeError("can't resolve task-mapping argument without expanding") + # Validation should be done when the upstream returns. + return self.value.resolve(context, session)[map_index] + EXPAND_INPUT_EMPTY = DictOfListsExpandInput({}) # Sentinel value. 
 _EXPAND_INPUT_TYPES = {
     "dict-of-lists": DictOfListsExpandInput,
+    "list-of-dicts": ListOfDictsExpandInput,
 }
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 9e75e9f4aa94e..a883ff2404d5e 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -61,11 +61,12 @@
     TaskStateChangeCallback,
 )
 from airflow.models.expandinput import (
-    MAPPABLE_LITERAL_TYPES,
     DictOfListsExpandInput,
     ExpandInput,
+    ListOfDictsExpandInput,
     Mappable,
     NotFullyPopulated,
+    get_mappable_types,
 )
 from airflow.models.pool import Pool
 from airflow.serialization.enums import DagAttributeTypes
@@ -86,19 +87,12 @@
     from airflow.models.dag import DAG
     from airflow.models.operator import Operator
     from airflow.models.taskinstance import TaskInstance
+    from airflow.models.xcom_arg import XComArg
     from airflow.utils.task_group import TaskGroup

 ValidationSource = Union[Literal["expand"], Literal["partial"]]


-# For isinstance() check.
-@cache
-def get_mappable_types() -> Tuple[type, ...]:
-    from airflow.models.xcom_arg import XComArg
-
-    return (XComArg,) + MAPPABLE_LITERAL_TYPES
-
-
 def validate_mapping_kwargs(op: Type["BaseOperator"], func: ValidationSource, value: Dict[str, Any]) -> None:
     # use a dict so order of args is same as code order
     unknown_args = value.copy()
@@ -198,6 +192,13 @@ def expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
         # to False to skip the checks on execution.
         return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)

+    def expand_kwargs(self, kwargs: "XComArg", *, strict: bool = True) -> "MappedOperator":
+        from airflow.models.xcom_arg import XComArg
+
+        if not isinstance(kwargs, XComArg):
+            raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}")
+        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)
+
     def _expand(self, expand_input: ExpandInput, *, strict: bool) -> "MappedOperator":
         from airflow.operators.empty import EmptyOperator
@@ -541,12 +542,10 @@ def _expand_mapped_kwargs(self, resolve: Optional[Tuple[Context, Session]]) -> D
         operation on the list-of-dicts variant before execution time, an empty
         dict will be returned for this case.
         """
-        kwargs = self._get_specified_expand_input()
+        expand_input = self._get_specified_expand_input()
         if resolve is not None:
-            return kwargs.resolve(*resolve)
-        if isinstance(kwargs, DictOfListsExpandInput):
-            return kwargs.value
-        return {}
+            return expand_input.resolve(*resolve)
+        return expand_input.get_unresolved_kwargs()

     def _get_unmap_kwargs(self, mapped_kwargs: Dict[str, Any], *, strict: bool) -> Dict[str, Any]:
         """Get init kwargs to unmap the underlying operator class.
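Taken together, the changes to airflow/decorators/base.py and airflow/models/mappedoperator.py above give both TaskFlow tasks and classic operators an expand_kwargs() entry point. A minimal sketch of how a DAG author might use it follows; the DAG id, task ids, and commands are illustrative only, not part of this patch:

from pendulum import datetime

from airflow.decorators import dag, task
from airflow.operators.bash import BashOperator


@dag(start_date=datetime(2022, 1, 1), schedule_interval=None)
def example_expand_kwargs():
    @task
    def make_kwargs():
        # Each dict becomes the kwargs of one mapped task instance; this is
        # the list-of-dicts XCom shape that ListOfDictsExpandInput validates.
        return [
            {"bash_command": "echo a", "env": {"X": "1"}},
            {"bash_command": "echo b"},
        ]

    # Unlike expand(), which maps each keyword argument separately,
    # expand_kwargs() takes a single XComArg resolving to a list of dicts.
    BashOperator.partial(task_id="bash").expand_kwargs(make_kwargs())


example_expand_kwargs()

Note that expand_kwargs() defaults to strict=True, so a key in the pushed dicts that collides with something already set in partial() fails the mapped task instance at run time, as exercised by test_expand_kwargs_override_partial below.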
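Likewise, here is a quick sketch of the push-side validation enforced by ListOfDictsExpandInput.validate_xcom() and the new exceptions, assuming the classes behave as defined above (the values are illustrative):

from airflow.exceptions import UnmappableXComTypePushed, UnmappableXComValuePushed
from airflow.models.expandinput import ListOfDictsExpandInput

# A collection of str-keyed mappings is the only accepted shape; this passes.
ListOfDictsExpandInput.validate_xcom([{"arg1": "a"}, {"arg2": "b"}])

# Non-collections, strings, mappings, and non-dict items are type errors.
for bad in (123, "abc", {"arg1": "a"}, [123]):
    try:
        ListOfDictsExpandInput.validate_xcom(bad)
    except UnmappableXComTypePushed as e:
        print(e)  # e.g. "unmappable return type 'int'" or "unmappable return type 'list[int]'"

# Dicts with non-str keys are value errors.
try:
    ListOfDictsExpandInput.validate_xcom([{1: "a"}])
except UnmappableXComValuePushed as e:
    print(e)  # "unmappable return value [{1: 'a'}] (dict keys must be str)"

Validation runs when the upstream task pushes its XCom, so a bad return value fails the pushing task rather than the mapped consumers, as covered by test_expand_kwargs_error_if_unmappable_type below.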
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 17c339620be8c..8110240d4c302 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -1422,3 +1422,24 @@ def test_schedule_tis_map_index(dag_maker, session):
     assert ti0.state == TaskInstanceState.SUCCESS
     assert ti1.state == TaskInstanceState.SCHEDULED
     assert ti2.state == TaskInstanceState.SUCCESS
+
+
+def test_mapped_expand_kwargs(dag_maker):
+    with dag_maker() as dag:
+
+        @task
+        def task_1():
+            return [{"arg1": "a", "arg2": "b"}, {"arg1": "y"}, {"arg2": "z"}]
+
+        MockOperator.partial(task_id="task_2").expand_kwargs(task_1())
+
+    dr: DagRun = dag_maker.create_dagrun()
+    assert len([ti for ti in dr.get_task_instances() if ti.task_id == "task_2"]) == 1
+
+    ti1 = dr.get_task_instance("task_1")
+    ti1.refresh_from_task(dag.get_task("task_1"))
+    ti1.run()
+
+    dr.task_instance_scheduling_decisions()
+    ti_states = {ti.map_index: ti.state for ti in dr.get_task_instances() if ti.task_id == "task_2"}
+    assert ti_states == {0: None, 1: None, 2: None}
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index 92d6226097959..09ab87524b062 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -272,5 +272,110 @@ def __init__(self, value, arg1, **kwargs):

     assert isinstance(op, MyOperator)
     assert op.value == "{{ ds }}", "Should not be templated!"
-    assert op.arg1 == "{{ ds }}"
+    assert op.arg1 == "{{ ds }}", "Should not be templated!"
+    assert op.arg2 == "a"
+
+
+@pytest.mark.parametrize(
+    ["num_existing_tis", "expected"],
+    (
+        pytest.param(0, [(0, None), (1, None), (2, None)], id='only-unmapped-ti-exists'),
+        pytest.param(
+            3,
+            [(0, 'success'), (1, 'success'), (2, 'success')],
+            id='all-tis-exist',
+        ),
+        pytest.param(
+            5,
+            [
+                (0, 'success'),
+                (1, 'success'),
+                (2, 'success'),
+                (3, TaskInstanceState.REMOVED),
+                (4, TaskInstanceState.REMOVED),
+            ],
+            id="tis-to-be-removed",
+        ),
+    ),
+)
+def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis, expected):
+    literal = [{"arg1": "a"}, {"arg1": "b"}, {"arg1": "c"}]
+    with dag_maker(session=session):
+        task1 = BaseOperator(task_id="op1")
BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id='a', arg2='{{ ti.task_id }}').expand_kwargs(XComArg(task1)) + + dr = dag_maker.create_dagrun() + ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) + + ti.xcom_push(key=XCOM_RETURN_KEY, value=[{"arg1": '{{ ds }}'}, {"arg1": 2}], session=session) + + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=task1.task_id, + run_id=dr.run_id, + map_index=-1, + length=2, + keys=None, + ) + ) + session.flush() + + ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) + ti.refresh_from_task(mapped) + ti.map_index = map_index + op = mapped.render_template_fields(context=ti.get_template_context(session=session)) + assert isinstance(op, MockOperator) + assert op.arg1 == expected assert op.arg2 == "a" diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index e17f34bd78d9f..05ac8daae5b53 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -43,6 +43,7 @@ AirflowSkipException, UnmappableXComLengthPushed, UnmappableXComTypePushed, + UnmappableXComValuePushed, XComForMappingNotPushed, ) from airflow.models import ( @@ -86,6 +87,7 @@ from tests.test_utils import db from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_connections, clear_db_runs +from tests.test_utils.mock_operators import MockOperator @pytest.fixture @@ -2500,6 +2502,94 @@ def pull_something(value): assert ti.state == TaskInstanceState.FAILED assert str(ctx.value) == error_message + @pytest.mark.parametrize( + "return_value, exception_type, error_message", + [ + (123, UnmappableXComTypePushed, "unmappable return type 'int'"), + ([123], UnmappableXComTypePushed, "unmappable return type 'list[int]'"), + ([{1: 3}], UnmappableXComValuePushed, "unmappable return value [{1: 3}] (dict keys must be str)"), + (None, XComForMappingNotPushed, "did not push XCom for task mapping"), + ], + ) + def test_expand_kwargs_error_if_unmappable_type( + self, + dag_maker, + return_value, + exception_type, + error_message, + ): + """If an unmappable return value is used for expand_kwargs(), fail the task that pushed the XCom.""" + with dag_maker(dag_id="test_expand_kwargs_error_if_unmappable_type") as dag: + + @dag.task() + def push(): + return return_value + + MockOperator.partial(task_id="pull").expand_kwargs(push()) + + ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push") + with pytest.raises(exception_type) as ctx: + ti.run() + + assert dag_maker.session.query(TaskMap).count() == 0 + assert ti.state == TaskInstanceState.FAILED + assert str(ctx.value) == error_message + + @pytest.mark.parametrize( + "downstream, error_message", + [ + ("taskflow", "mapping already partial argument: arg2"), + ("classic", "unmappable or already specified argument: arg2"), + ], + ids=["taskflow", "classic"], + ) + @pytest.mark.parametrize("strict", [True, False], ids=["strict", "override"]) + def test_expand_kwargs_override_partial(self, dag_maker, session, downstream, error_message, strict): + class ClassicOperator(MockOperator): + def execute(self, context): + return (self.arg1, self.arg2) + + with dag_maker(dag_id="test_expand_kwargs_override_partial", session=session) as dag: + + @dag.task() + def push(): + return [{"arg1": "a"}, {"arg1": "b", "arg2": "c"}] + + push_task = push() + + ClassicOperator.partial(task_id="classic", arg2="d").expand_kwargs(push_task, strict=strict) + + @dag.task(task_id="taskflow") + def pull(arg1, arg2): + 
+                return (arg1, arg2)
+
+            pull.partial(arg2="d").expand_kwargs(push_task, strict=strict)
+
+        dr = dag_maker.create_dagrun()
+        next(ti for ti in dr.task_instances if ti.task_id == "push").run()
+
+        decision = dr.task_instance_scheduling_decisions(session=session)
+        tis = {(ti.task_id, ti.map_index, ti.state): ti for ti in decision.schedulable_tis}
+        assert sorted(tis) == [
+            ("classic", 0, None),
+            ("classic", 1, None),
+            ("taskflow", 0, None),
+            ("taskflow", 1, None),
+        ]
+
+        ti = tis[(downstream, 0, None)]
+        ti.run()
+        assert ti.xcom_pull(task_ids=downstream, map_indexes=0, session=session) == ["a", "d"]
+
+        ti = tis[(downstream, 1, None)]
+        if strict:
+            with pytest.raises(TypeError) as ctx:
+                ti.run()
+            assert str(ctx.value) == error_message
+        else:
+            ti.run()
+            assert ti.xcom_pull(task_ids=downstream, map_indexes=1, session=session) == ["b", "c"]
+
     def test_error_if_upstream_does_not_push(self, dag_maker):
         """Fail the upstream task if it fails to push the XCom used for task mapping."""
         with dag_maker(dag_id="test_not_recorded_for_unused") as dag:
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 418cb4af8928f..5751ae137c5de 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1777,6 +1777,52 @@ def test_operator_expand_xcomarg_serde():
     assert xcom_arg.operator is serialized_dag.task_dict['op1']


+@pytest.mark.parametrize("strict", [True, False])
+def test_operator_expand_kwargs_serde(strict):
+    from airflow.models.xcom_arg import XComArg
+
+    with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
+        task1 = BaseOperator(task_id="op1")
+        mapped = MockOperator.partial(task_id='task_2').expand_kwargs(XComArg(task1), strict=strict)
+
+    serialized = SerializedBaseOperator._serialize(mapped)
+    assert serialized == {
+        '_is_empty': False,
+        '_is_mapped': True,
+        '_task_module': 'tests.test_utils.mock_operators',
+        '_task_type': 'MockOperator',
+        'downstream_task_ids': [],
+        'expand_input': {
+            "type": "list-of-dicts",
+            "value": {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}},
+        },
+        'partial_kwargs': {},
+        'task_id': 'task_2',
+        'template_fields': ['arg1', 'arg2'],
+        'template_ext': [],
+        'template_fields_renderers': {},
+        'operator_extra_links': [],
+        'ui_color': '#fff',
+        'ui_fgcolor': '#000',
+        "_disallow_kwargs_override": strict,
+        '_expand_input_attr': 'expand_input',
+    }
+
+    op = SerializedBaseOperator.deserialize_operator(serialized)
+    assert op.deps is MappedOperator.deps_for(BaseOperator)
+    assert op._disallow_kwargs_override == strict
+
+    xcom_ref = op.expand_input.value
+    assert xcom_ref.task_id == 'op1'
+    assert xcom_ref.key == XCOM_RETURN_KEY
+
+    serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+    xcom_arg = serialized_dag.task_dict['task_2'].expand_input.value
+    assert isinstance(xcom_arg, XComArg)
+    assert xcom_arg.operator is serialized_dag.task_dict['op1']
+
+
 def test_operator_expand_deserialized_unmap():
     """Unmap a deserialized mapped operator should be similar to deserializing an non-mapped operator."""
     normal = BashOperator(task_id='a', bash_command=[1, 2], executor_config={"a": "b"})
@@ -1891,6 +1937,89 @@ def x(arg1, arg2, arg3):
     }


+@pytest.mark.parametrize("strict", [True, False])
+def test_taskflow_expand_kwargs_serde(strict):
+    from airflow.decorators import task
+    from airflow.models.xcom_arg import XComArg
+    from airflow.serialization.serialized_objects import _ExpandInputRef, _XComRef
+
DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag: + op1 = BaseOperator(task_id="op1") + + @task(retry_delay=30) + def x(arg1, arg2, arg3): + print(arg1, arg2, arg3) + + x.partial(arg1=[1, 2, {"a": "b"}]).expand_kwargs(XComArg(op1), strict=strict) + + original = dag.get_task("x") + + serialized = SerializedBaseOperator._serialize(original) + assert serialized == { + '_is_empty': False, + '_is_mapped': True, + '_task_module': 'airflow.decorators.python', + '_task_type': '_PythonDecoratedOperator', + 'downstream_task_ids': [], + 'partial_kwargs': { + 'op_args': [], + 'op_kwargs': { + '__type': 'dict', + '__var': {'arg1': [1, 2, {"__type": "dict", "__var": {'a': 'b'}}]}, + }, + 'retry_delay': {'__type': 'timedelta', '__var': 30.0}, + }, + 'op_kwargs_expand_input': { + "type": "list-of-dicts", + "value": { + "__type": "xcomref", + "__var": {'task_id': 'op1', 'key': 'return_value'}, + }, + }, + 'operator_extra_links': [], + 'ui_color': '#ffefeb', + 'ui_fgcolor': '#000', + 'task_id': 'x', + 'template_ext': [], + 'template_fields': ['op_args', 'op_kwargs'], + 'template_fields_renderers': {"op_args": "py", "op_kwargs": "py"}, + "_disallow_kwargs_override": strict, + '_expand_input_attr': 'op_kwargs_expand_input', + } + + deserialized = SerializedBaseOperator.deserialize_operator(serialized) + assert isinstance(deserialized, MappedOperator) + assert deserialized.deps is MappedOperator.deps_for(BaseOperator) + assert deserialized._disallow_kwargs_override == strict + assert deserialized.upstream_task_ids == set() + assert deserialized.downstream_task_ids == set() + + assert deserialized.op_kwargs_expand_input == _ExpandInputRef( + key="list-of-dicts", + value=_XComRef("op1", XCOM_RETURN_KEY), + ) + assert deserialized.partial_kwargs == { + "op_args": [], + "op_kwargs": {"arg1": [1, 2, {"a": "b"}]}, + "retry_delay": timedelta(seconds=30), + } + + # Ensure the serialized operator can also be correctly pickled, to ensure + # correct interaction between DAG pickling and serialization. This is done + # here so we don't need to duplicate tests between pickled and non-pickled + # DAGs everywhere else. + pickled = pickle.loads(pickle.dumps(deserialized)) + assert pickled.op_kwargs_expand_input == _ExpandInputRef( + "list-of-dicts", + _XComRef("op1", XCOM_RETURN_KEY), + ) + assert pickled.partial_kwargs == { + "op_args": [], + "op_kwargs": {"arg1": [1, 2, {"a": "b"}]}, + "retry_delay": timedelta(seconds=30), + } + + @pytest.mark.filterwarnings("ignore::DeprecationWarning") @pytest.mark.parametrize( "is_inherit",