Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix operator_extra_links property serialization in mapped tasks #31904

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ class BaseSerialization:
_datetime_types = (datetime.datetime,)

# Object types that are always excluded in serialization.
_excluded_types = (logging.Logger, Connection, type)
_excluded_types = (logging.Logger, Connection, type, property)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this one affect?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we don't exclude it, the serialized object will have both operator_extra_links (added by the method serialize_to_json with the method as value) and _operator_extra_links, also the property type should be excluded anyway because it is not deserialized.


_json_schema: Validator | None = None

Expand Down Expand Up @@ -822,7 +822,9 @@ def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool)

if op.operator_extra_links:
serialize_op["_operator_extra_links"] = cls._serialize_operator_extra_links(
op.operator_extra_links
op.operator_extra_links.__get__(op)
if isinstance(op.operator_extra_links, property)
else op.operator_extra_links
)

if include_deps:
Expand Down
40 changes: 39 additions & 1 deletion tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
from airflow.utils.task_group import TaskGroup
from airflow.utils.xcom import XCOM_RETURN_KEY
from tests.test_utils.config import conf_vars
from tests.test_utils.mock_operators import CustomOperator, GoogleLink, MockOperator
from tests.test_utils.mock_operators import AirflowLink2, CustomOperator, GoogleLink, MockOperator
from tests.test_utils.timetables import CustomSerializationTimetable, cron_timetable, delta_timetable

repo_root = Path(airflow.__file__).parent.parent
Expand Down Expand Up @@ -2479,3 +2479,41 @@ def tg(a: str) -> None:
serde_tg = serde_dag.task_group.children["tg"]
assert isinstance(serde_tg, MappedTaskGroup)
assert serde_tg._expand_input == DictOfListsExpandInput({"a": [".", ".."]})


def test_mapped_task_with_operator_extra_links_property():
class _DummyOperator(BaseOperator):
def __init__(self, inputs, **kwargs):
super().__init__(**kwargs)
self.inputs = inputs

@property
def operator_extra_links(self):
return (AirflowLink2(),)

with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
_DummyOperator.partial(task_id="task").expand(inputs=[1, 2, 3])
serialized_dag = SerializedBaseOperator.serialize(dag)
assert serialized_dag["tasks"][0] == {
"task_id": "task",
"expand_input": {
"type": "dict-of-lists",
"value": {"__type": "dict", "__var": {"inputs": [1, 2, 3]}},
},
"partial_kwargs": {},
"_disallow_kwargs_override": False,
"_expand_input_attr": "expand_input",
"downstream_task_ids": [],
"_operator_extra_links": [{"tests.test_utils.mock_operators.AirflowLink2": {}}],
"ui_color": "#fff",
"ui_fgcolor": "#000",
"template_ext": [],
"template_fields": [],
"template_fields_renderers": {},
"_task_type": "_DummyOperator",
"_task_module": "tests.serialization.test_dag_serialization",
"_is_empty": False,
"_is_mapped": True,
}
deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag)
assert deserialized_dag.task_dict["task"].operator_extra_links == [AirflowLink2()]