diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
index a20280e5a1111..f39ed2c818fd6 100644
--- a/airflow/models/expandinput.py
+++ b/airflow/models/expandinput.py
@@ -33,6 +33,7 @@
 
     from airflow.models.operator import Operator
     from airflow.models.xcom_arg import XComArg
+    from airflow.serialization.serialized_objects import _ExpandInputRef
     from airflow.typing_compat import TypeGuard
     from airflow.utils.context import Context
 
@@ -281,7 +282,11 @@ def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any]
         }
 
 
-def get_map_type_key(expand_input: ExpandInput) -> str:
+def get_map_type_key(expand_input: ExpandInput | _ExpandInputRef) -> str:
+    from airflow.serialization.serialized_objects import _ExpandInputRef
+
+    if isinstance(expand_input, _ExpandInputRef):
+        return expand_input.key
     return next(k for k, v in _EXPAND_INPUT_TYPES.items() if v == type(expand_input))
 
 
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index be9efb3167f95..27d0510c307c0 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -807,7 +807,12 @@ def get_parse_time_mapped_ti_count(self) -> int:
         return parent_count * current_count
 
     def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
-        current_count = self._get_specified_expand_input().get_total_map_length(run_id, session=session)
+        from airflow.serialization.serialized_objects import _ExpandInputRef
+
+        exp_input = self._get_specified_expand_input()
+        if isinstance(exp_input, _ExpandInputRef):
+            exp_input = exp_input.deref(self.dag)
+        current_count = exp_input.get_total_map_length(run_id, session=session)
         try:
             parent_count = super().get_mapped_ti_count(run_id, session=session)
         except NotMapped:
diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py
index afb7d4e2dd064..e499a98691940 100644
--- a/airflow/serialization/pydantic/taskinstance.py
+++ b/airflow/serialization/pydantic/taskinstance.py
@@ -57,20 +57,21 @@
 
 def serialize_operator(x: Operator | None) -> dict | None:
     if x:
-        from airflow.serialization.serialized_objects import SerializedBaseOperator
+        from airflow.serialization.serialized_objects import BaseSerialization
 
-        return SerializedBaseOperator.serialize_operator(x)
+        return BaseSerialization.serialize(x, use_pydantic_models=True)
     return None
 
 
 def validated_operator(x: dict[str, Any] | Operator, _info: ValidationInfo) -> Any:
     from airflow.models.baseoperator import BaseOperator
     from airflow.models.mappedoperator import MappedOperator
-    from airflow.serialization.serialized_objects import SerializedBaseOperator
 
     if isinstance(x, BaseOperator) or isinstance(x, MappedOperator) or x is None:
         return x
-    return SerializedBaseOperator.deserialize_operator(x)
+    from airflow.serialization.serialized_objects import BaseSerialization
+
+    return BaseSerialization.deserialize(x, use_pydantic_models=True)
 
 
 PydanticOperator = Annotated[
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index d5aeb532cbd32..b9b774ab7b361 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -1185,6 +1185,8 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None:
                 v = {arg: cls.deserialize(value) for arg, value in v.items()}
             elif k in {"expand_input", "op_kwargs_expand_input"}:
                 v = _ExpandInputRef(v["type"], cls.deserialize(v["value"]))
+            elif k == "operator_class":
+                v = {k_: cls.deserialize(v_, use_pydantic_models=True) for k_, v_ in v.items()}
             elif (
                 k in cls._decorated_fields
                 or k not in op.get_serialized_fields()
@@ -1200,7 +1202,7 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None:
             setattr(op, k, v)
 
         for k in op.get_serialized_fields() - encoded_op.keys() - cls._CONSTRUCTOR_PARAMS.keys():
-            # TODO: refactor deserialization of BaseOperator and MappedOperaotr (split it out), then check
+            # TODO: refactor deserialization of BaseOperator and MappedOperator (split it out), then check
             # could go away.
             if not hasattr(op, k):
                 setattr(op, k, None)
diff --git a/tests/serialization/test_pydantic_models.py b/tests/serialization/test_pydantic_models.py
index 549e703485cd7..22f09b68e48a9 100644
--- a/tests/serialization/test_pydantic_models.py
+++ b/tests/serialization/test_pydantic_models.py
@@ -22,8 +22,11 @@
 import pytest
 from dateutil import relativedelta
 
+from airflow.decorators import task
+from airflow.decorators.python import _PythonDecoratedOperator
 from airflow.jobs.job import Job
 from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
+from airflow.models import MappedOperator
 from airflow.models.dag import DagModel
 from airflow.models.dataset import (
     DagScheduleDatasetReference,
@@ -36,6 +39,7 @@
 from airflow.serialization.pydantic.dataset import DatasetEventPydantic
 from airflow.serialization.pydantic.job import JobPydantic
 from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
+from airflow.serialization.serialized_objects import BaseSerialization
 from airflow.settings import _ENABLE_AIP_44
 from airflow.utils import timezone
 from airflow.utils.state import State
@@ -68,6 +72,66 @@ def test_serializing_pydantic_task_instance(session, create_task_instance):
     assert deserialized_model.next_kwargs == {"foo": "bar"}
 
 
+@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
+def test_deserialize_ti_mapped_op_reserialized_with_refresh_from_task(session, dag_maker):
+    op_class_dict_expected = {
+        "_task_type": "_PythonDecoratedOperator",
+        "downstream_task_ids": [],
+        "_operator_name": "@task",
+        "ui_fgcolor": "#000",
+        "ui_color": "#ffefeb",
+        "template_fields": ["templates_dict", "op_args", "op_kwargs"],
+        "template_fields_renderers": {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"},
+        "template_ext": [],
+        "task_id": "target",
+    }
+
+    with dag_maker():
+
+        @task
+        def source():
+            return [1, 2, 3]
+
+        @task
+        def target(val=None):
+            print(val)
+
+        # source() >> target()
+        target.expand(val=source())
+    dr = dag_maker.create_dagrun()
+    ti = dr.task_instances[1]
+
+    # roundtrip task
+    ser_task = BaseSerialization.serialize(ti.task, use_pydantic_models=True)
+    deser_task = BaseSerialization.deserialize(ser_task, use_pydantic_models=True)
+    ti.task.operator_class
+    # this is part of the problem!
+    assert isinstance(ti.task.operator_class, type)
+    assert isinstance(deser_task.operator_class, dict)
+
+    assert ti.task.operator_class == _PythonDecoratedOperator
+    ti.refresh_from_task(deser_task)
+    # roundtrip ti
+    sered = BaseSerialization.serialize(ti, use_pydantic_models=True)
+    desered = BaseSerialization.deserialize(sered, use_pydantic_models=True)
+
+    assert "operator_class" not in sered["__var"]["task"]
+
+    assert desered.task.__class__ == MappedOperator
+
+    assert desered.task.operator_class == op_class_dict_expected
+
+    desered.refresh_from_task(deser_task)
+
+    assert desered.task.__class__ == MappedOperator
+
+    assert isinstance(desered.task.operator_class, dict)
+
+    resered = BaseSerialization.serialize(desered, use_pydantic_models=True)
+    deresered = BaseSerialization.deserialize(resered, use_pydantic_models=True)
+    assert deresered.task.operator_class == desered.task.operator_class == op_class_dict_expected
+
+
 @pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
 def test_serializing_pydantic_dagrun(session, create_task_instance):
     dag_id = "test-dag"
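Note: the expandinput.py and mappedoperator.py changes above share one pattern: after deserialization, `expand_input` may be an `_ExpandInputRef` rather than a concrete `ExpandInput`, so callers either read metadata straight off the reference (its `key`) or dereference it before calling methods on it. The sketch below is a minimal, self-contained illustration of that pattern only; every name in it is a simplified stand-in, not the real Airflow API (for instance, the real `_ExpandInputRef.deref` takes the DAG, and the real `get_total_map_length` takes `run_id` and a `session`).

# Minimal sketch of the "deref before use" pattern applied in this diff.
# All names here are simplified stand-ins for the Airflow classes.
from __future__ import annotations

from typing import NamedTuple


class DictOfListsExpandInput(NamedTuple):
    # Stand-in for a concrete expand input that can compute its map length.
    value: dict[str, list]

    def get_total_map_length(self) -> int:
        return max(len(v) for v in self.value.values())


class ExpandInputRef(NamedTuple):
    # Stand-in for the serialized reference: a type key plus the raw value.
    key: str
    value: dict[str, list]

    def deref(self) -> DictOfListsExpandInput:
        return DictOfListsExpandInput(self.value)


def get_map_type_key(expand_input: DictOfListsExpandInput | ExpandInputRef) -> str:
    # Mirrors the expandinput.py change: a reference already carries its key,
    # so there is no need to reconstruct the concrete type to look it up.
    if isinstance(expand_input, ExpandInputRef):
        return expand_input.key
    return "dict-of-lists"


def get_total_map_length(expand_input: DictOfListsExpandInput | ExpandInputRef) -> int:
    # Mirrors the mappedoperator.py change: dereference first, then use.
    if isinstance(expand_input, ExpandInputRef):
        expand_input = expand_input.deref()
    return expand_input.get_total_map_length()


ref = ExpandInputRef("dict-of-lists", {"val": [1, 2, 3]})
assert get_map_type_key(ref) == "dict-of-lists"
assert get_total_map_length(ref) == 3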