Skip to content

Commit

Permalink
Fix TIPydantic serialization of MappedOperator (#39288)
Browse files Browse the repository at this point in the history
Previously we relied on SerializedBaseOperator.serialize_operator for serialization of all Operator objects but this is no bueno because it does not work for mapped operator.  Instead we use BaseSerialization.serialize, which calls the right method for each obj type.  To make this actually work for db isolation though, we had to do a few more things.  1. operator_class obj was not deserialized properly so fixed that. It's a weird case cus it roundtrips to a dict -- not a class.  Also expand_input did not work quite right because it roundtrips to _ExpandInputRef.  Then we had to make a small change to get_map_type_key since we might have _ExpandInputRef here.
  • Loading branch information
dstandish authored May 24, 2024
1 parent bca2930 commit c2f1739
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 7 deletions.
7 changes: 6 additions & 1 deletion airflow/models/expandinput.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from airflow.models.operator import Operator
from airflow.models.xcom_arg import XComArg
from airflow.serialization.serialized_objects import _ExpandInputRef
from airflow.typing_compat import TypeGuard
from airflow.utils.context import Context

Expand Down Expand Up @@ -281,7 +282,11 @@ def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any]
}


def get_map_type_key(expand_input: ExpandInput) -> str:
def get_map_type_key(expand_input: ExpandInput | _ExpandInputRef) -> str:
from airflow.serialization.serialized_objects import _ExpandInputRef

if isinstance(expand_input, _ExpandInputRef):
return expand_input.key
return next(k for k, v in _EXPAND_INPUT_TYPES.items() if v == type(expand_input))


Expand Down
7 changes: 6 additions & 1 deletion airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,12 @@ def get_parse_time_mapped_ti_count(self) -> int:
return parent_count * current_count

def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
current_count = self._get_specified_expand_input().get_total_map_length(run_id, session=session)
from airflow.serialization.serialized_objects import _ExpandInputRef

exp_input = self._get_specified_expand_input()
if isinstance(exp_input, _ExpandInputRef):
exp_input = exp_input.deref(self.dag)
current_count = exp_input.get_total_map_length(run_id, session=session)
try:
parent_count = super().get_mapped_ti_count(run_id, session=session)
except NotMapped:
Expand Down
9 changes: 5 additions & 4 deletions airflow/serialization/pydantic/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,21 @@

def serialize_operator(x: Operator | None) -> dict | None:
if x:
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.serialization.serialized_objects import BaseSerialization

return SerializedBaseOperator.serialize_operator(x)
return BaseSerialization.serialize(x, use_pydantic_models=True)
return None


def validated_operator(x: dict[str, Any] | Operator, _info: ValidationInfo) -> Any:
from airflow.models.baseoperator import BaseOperator
from airflow.models.mappedoperator import MappedOperator
from airflow.serialization.serialized_objects import SerializedBaseOperator

if isinstance(x, BaseOperator) or isinstance(x, MappedOperator) or x is None:
return x
return SerializedBaseOperator.deserialize_operator(x)
from airflow.serialization.serialized_objects import BaseSerialization

return BaseSerialization.deserialize(x, use_pydantic_models=True)


PydanticOperator = Annotated[
Expand Down
4 changes: 3 additions & 1 deletion airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,6 +1185,8 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None:
v = {arg: cls.deserialize(value) for arg, value in v.items()}
elif k in {"expand_input", "op_kwargs_expand_input"}:
v = _ExpandInputRef(v["type"], cls.deserialize(v["value"]))
elif k == "operator_class":
v = {k_: cls.deserialize(v_, use_pydantic_models=True) for k_, v_ in v.items()}
elif (
k in cls._decorated_fields
or k not in op.get_serialized_fields()
Expand All @@ -1200,7 +1202,7 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None:
setattr(op, k, v)

for k in op.get_serialized_fields() - encoded_op.keys() - cls._CONSTRUCTOR_PARAMS.keys():
# TODO: refactor deserialization of BaseOperator and MappedOperaotr (split it out), then check
# TODO: refactor deserialization of BaseOperator and MappedOperator (split it out), then check
# could go away.
if not hasattr(op, k):
setattr(op, k, None)
Expand Down
64 changes: 64 additions & 0 deletions tests/serialization/test_pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@
import pytest
from dateutil import relativedelta

from airflow.decorators import task
from airflow.decorators.python import _PythonDecoratedOperator
from airflow.jobs.job import Job
from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
from airflow.models import MappedOperator
from airflow.models.dag import DagModel
from airflow.models.dataset import (
DagScheduleDatasetReference,
Expand All @@ -36,6 +39,7 @@
from airflow.serialization.pydantic.dataset import DatasetEventPydantic
from airflow.serialization.pydantic.job import JobPydantic
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.settings import _ENABLE_AIP_44
from airflow.utils import timezone
from airflow.utils.state import State
Expand Down Expand Up @@ -68,6 +72,66 @@ def test_serializing_pydantic_task_instance(session, create_task_instance):
assert deserialized_model.next_kwargs == {"foo": "bar"}


@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
def test_deserialize_ti_mapped_op_reserialized_with_refresh_from_task(session, dag_maker):
op_class_dict_expected = {
"_task_type": "_PythonDecoratedOperator",
"downstream_task_ids": [],
"_operator_name": "@task",
"ui_fgcolor": "#000",
"ui_color": "#ffefeb",
"template_fields": ["templates_dict", "op_args", "op_kwargs"],
"template_fields_renderers": {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"},
"template_ext": [],
"task_id": "target",
}

with dag_maker():

@task
def source():
return [1, 2, 3]

@task
def target(val=None):
print(val)

# source() >> target()
target.expand(val=source())
dr = dag_maker.create_dagrun()
ti = dr.task_instances[1]

# roundtrip task
ser_task = BaseSerialization.serialize(ti.task, use_pydantic_models=True)
deser_task = BaseSerialization.deserialize(ser_task, use_pydantic_models=True)
ti.task.operator_class
# this is part of the problem!
assert isinstance(ti.task.operator_class, type)
assert isinstance(deser_task.operator_class, dict)

assert ti.task.operator_class == _PythonDecoratedOperator
ti.refresh_from_task(deser_task)
# roundtrip ti
sered = BaseSerialization.serialize(ti, use_pydantic_models=True)
desered = BaseSerialization.deserialize(sered, use_pydantic_models=True)

assert "operator_class" not in sered["__var"]["task"]

assert desered.task.__class__ == MappedOperator

assert desered.task.operator_class == op_class_dict_expected

desered.refresh_from_task(deser_task)

assert desered.task.__class__ == MappedOperator

assert isinstance(desered.task.operator_class, dict)

resered = BaseSerialization.serialize(desered, use_pydantic_models=True)
deresered = BaseSerialization.deserialize(resered, use_pydantic_models=True)
assert deresered.task.operator_class == desered.task.operator_class == op_class_dict_expected


@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
def test_serializing_pydantic_dagrun(session, create_task_instance):
dag_id = "test-dag"
Expand Down

0 comments on commit c2f1739

Please sign in to comment.