Skip to content

Commit

Permalink
Fix operator_extra_links property serialization in mapped tasks (#3…
Browse files Browse the repository at this point in the history
…1904)

* reproduce the problem in a unit test

* Fix operator_extra_links serialization

* replace fget by __get__

(cherry picked from commit 3318212)
  • Loading branch information
hussein-awala authored and ephraimbuddy committed Jul 6, 2023
1 parent 52a2fd3 commit ad7f474
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 3 deletions.
6 changes: 4 additions & 2 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ class BaseSerialization:
_datetime_types = (datetime.datetime,)

# Object types that are always excluded in serialization.
_excluded_types = (logging.Logger, Connection, type)
_excluded_types = (logging.Logger, Connection, type, property)

_json_schema: Validator | None = None

Expand Down Expand Up @@ -822,7 +822,9 @@ def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool)

if op.operator_extra_links:
serialize_op["_operator_extra_links"] = cls._serialize_operator_extra_links(
op.operator_extra_links
op.operator_extra_links.__get__(op)
if isinstance(op.operator_extra_links, property)
else op.operator_extra_links
)

if include_deps:
Expand Down
40 changes: 39 additions & 1 deletion tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
from airflow.utils.task_group import TaskGroup
from airflow.utils.xcom import XCOM_RETURN_KEY
from tests.test_utils.config import conf_vars
from tests.test_utils.mock_operators import CustomOperator, GoogleLink, MockOperator
from tests.test_utils.mock_operators import AirflowLink2, CustomOperator, GoogleLink, MockOperator
from tests.test_utils.timetables import CustomSerializationTimetable, cron_timetable, delta_timetable

repo_root = Path(airflow.__file__).parent.parent
Expand Down Expand Up @@ -2481,3 +2481,41 @@ def tg(a: str) -> None:
serde_tg = serde_dag.task_group.children["tg"]
assert isinstance(serde_tg, MappedTaskGroup)
assert serde_tg._expand_input == DictOfListsExpandInput({"a": [".", ".."]})


def test_mapped_task_with_operator_extra_links_property():
class _DummyOperator(BaseOperator):
def __init__(self, inputs, **kwargs):
super().__init__(**kwargs)
self.inputs = inputs

@property
def operator_extra_links(self):
return (AirflowLink2(),)

with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
_DummyOperator.partial(task_id="task").expand(inputs=[1, 2, 3])
serialized_dag = SerializedBaseOperator.serialize(dag)
assert serialized_dag["tasks"][0] == {
"task_id": "task",
"expand_input": {
"type": "dict-of-lists",
"value": {"__type": "dict", "__var": {"inputs": [1, 2, 3]}},
},
"partial_kwargs": {},
"_disallow_kwargs_override": False,
"_expand_input_attr": "expand_input",
"downstream_task_ids": [],
"_operator_extra_links": [{"tests.test_utils.mock_operators.AirflowLink2": {}}],
"ui_color": "#fff",
"ui_fgcolor": "#000",
"template_ext": [],
"template_fields": [],
"template_fields_renderers": {},
"_task_type": "_DummyOperator",
"_task_module": "tests.serialization.test_dag_serialization",
"_is_empty": False,
"_is_mapped": True,
}
deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag)
assert deserialized_dag.task_dict["task"].operator_extra_links == [AirflowLink2()]

0 comments on commit ad7f474

Please sign in to comment.