diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py index d1b946b38ef1e..ab7d59cb21682 100644 --- a/airflow/serialization/enums.py +++ b/airflow/serialization/enums.py @@ -60,6 +60,8 @@ class DagAttributeTypes(str, Enum): ASSET_ANY = "asset_any" ASSET_ALL = "asset_all" ASSET_REF = "asset_ref" + ASSET_UNIQUE_KEY = "asset_unique_key" + ASSET_ALIAS_UNIQUE_KEY = "asset_alias_unique_key" SIMPLE_TASK_INSTANCE = "simple_task_instance" BASE_JOB = "Job" TASK_INSTANCE = "task_instance" diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 3d8cb2ad13629..95d2b82f55164 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -59,9 +59,11 @@ from airflow.sdk.definitions.asset import ( Asset, AssetAlias, + AssetAliasUniqueKey, AssetAll, AssetAny, AssetRef, + AssetUniqueKey, BaseAsset, ) from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator @@ -316,19 +318,35 @@ def decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor: outlet_event_accessor = OutletEventAccessor( key=BaseSerialization.deserialize(var["key"]), extra=var["extra"], - asset_alias_events=[AssetAliasEvent(**e) for e in asset_alias_events], + asset_alias_events=[ + AssetAliasEvent( + source_alias_name=e["source_alias_name"], + dest_asset_key=AssetUniqueKey( + name=e["dest_asset_key"]["name"], uri=e["dest_asset_key"]["uri"] + ), + extra=e["extra"], + ) + for e in asset_alias_events + ], ) return outlet_event_accessor def encode_outlet_event_accessors(var: OutletEventAccessors) -> dict[str, Any]: - return {k: encode_outlet_event_accessor(v) for k, v in var._dict.items()} # type: ignore[attr-defined] + return { + "__type": DAT.ASSET_EVENT_ACCESSORS, + "_dict": [ + {"key": BaseSerialization.serialize(k), "value": encode_outlet_event_accessor(v)} + for k, v in var._dict.items() # type: ignore[attr-defined] + ], + } def decode_outlet_event_accessors(var: dict[str, Any]) -> OutletEventAccessors: d = OutletEventAccessors() # type: ignore[assignment] d._dict = { # type: ignore[attr-defined] - k: decode_outlet_event_accessor(v) for k, v in var.items() + BaseSerialization.deserialize(row["key"]): decode_outlet_event_accessor(row["value"]) + for row in var["_dict"] } return d @@ -694,6 +712,16 @@ def serialize( encode_outlet_event_accessors(var), type_=DAT.ASSET_EVENT_ACCESSORS, ) + elif isinstance(var, AssetUniqueKey): + return cls._encode( + attrs.asdict(var), + type_=DAT.ASSET_UNIQUE_KEY, + ) + elif isinstance(var, AssetAliasUniqueKey): + return cls._encode( + attrs.asdict(var), + type_=DAT.ASSET_ALIAS_UNIQUE_KEY, + ) elif isinstance(var, DAG): return cls._encode(SerializedDAG.serialize_dag(var), type_=DAT.DAG) elif isinstance(var, Resources): @@ -858,6 +886,10 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any: return {k: cls.deserialize(v, use_pydantic_models) for k, v in var.items()} elif type_ == DAT.ASSET_EVENT_ACCESSORS: return decode_outlet_event_accessors(var) + elif type_ == DAT.ASSET_UNIQUE_KEY: + return AssetUniqueKey(name=var["name"], uri=var["uri"]) + elif type_ == DAT.ASSET_ALIAS_UNIQUE_KEY: + return AssetAliasUniqueKey(name=var["name"]) elif type_ == DAT.DAG: return SerializedDAG.deserialize_dag(var) elif type_ == DAT.OP: diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 377477a21da07..136d8d68a5035 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -2525,7 +2525,7 @@ def test_outlet_asset_alias(self, dag_maker, session): @task(outlets=AssetAlias(alias_name_1)) def producer(*, outlet_events): - outlet_events[alias_name_1].add(Asset(asset_uri)) + outlet_events[AssetAlias(alias_name_1)].add(Asset(asset_uri)) producer() @@ -2581,9 +2581,9 @@ def test_outlet_multiple_asset_alias(self, dag_maker, session): ] ) def producer(*, outlet_events): - outlet_events[asset_alias_name_1].add(Asset(asset_uri)) - outlet_events[asset_alias_name_2].add(Asset(asset_uri)) - outlet_events[asset_alias_name_3].add(Asset(asset_uri), extra={"k": "v"}) + outlet_events[AssetAlias(asset_alias_name_1)].add(Asset(asset_uri)) + outlet_events[AssetAlias(asset_alias_name_2)].add(Asset(asset_uri)) + outlet_events[AssetAlias(asset_alias_name_3)].add(Asset(asset_uri), extra={"k": "v"}) producer() @@ -2686,7 +2686,7 @@ def test_outlet_asset_alias_asset_not_exists(self, dag_maker, session): @task(outlets=AssetAlias(asset_alias_name)) def producer(*, outlet_events): - outlet_events[asset_alias_name].add(Asset(asset_uri), extra={"key": "value"}) + outlet_events[AssetAlias(asset_alias_name)].add(Asset(asset_uri), extra={"key": "value"}) producer() @@ -2786,7 +2786,7 @@ def test_inlet_asset_alias_extra(self, dag_maker, session): @task(outlets=AssetAlias(asset_alias_name)) def write(*, ti, outlet_events): - outlet_events[asset_alias_name].add(Asset(asset_uri), extra={"from": ti.task_id}) + outlet_events[AssetAlias(asset_alias_name)].add(Asset(asset_uri), extra={"from": ti.task_id}) @task(inlets=AssetAlias(asset_alias_name)) def read(*, inlet_events): @@ -2875,7 +2875,7 @@ def test_inlet_asset_extra_slice(self, dag_maker, session, slicer, expected): @task(outlets=Asset(asset_uri)) def write(*, params, outlet_events): - outlet_events[asset_uri].extra = {"from": params["i"]} + outlet_events[Asset(asset_uri)].extra = {"from": params["i"]} write() @@ -2935,7 +2935,7 @@ def test_inlet_asset_alias_extra_slice(self, dag_maker, session, slicer, expecte @task(outlets=AssetAlias(asset_alias_name)) def write(*, params, outlet_events): - outlet_events[asset_alias_name].add(Asset(asset_uri), {"from": params["i"]}) + outlet_events[AssetAlias(asset_alias_name)].add(Asset(asset_uri), {"from": params["i"]}) write() diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py index 3dda7623ff64e..8100c2a84bc74 100644 --- a/tests/serialization/test_serialized_objects.py +++ b/tests/serialization/test_serialized_objects.py @@ -49,7 +49,7 @@ from airflow.models.xcom_arg import XComArg from airflow.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasUniqueKey, AssetUniqueKey +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.pydantic.asset import AssetEventPydantic, AssetPydantic from airflow.serialization.pydantic.dag import DagModelPydantic, DagTagPydantic @@ -150,6 +150,16 @@ class Test: DAG_RUN.id = 1 +def create_outlet_event_accessors( + key: Asset | AssetAlias, extra: dict, asset_alias_events: list[AssetAliasEvent] +) -> OutletEventAccessors: + o = OutletEventAccessors() + o[key].extra = extra + o[key].asset_alias_events = asset_alias_events + + return o + + def equals(a, b) -> bool: return a == b @@ -162,6 +172,13 @@ def equal_exception(a: AirflowException, b: AirflowException) -> bool: return a.__class__ == b.__class__ and str(a) == str(b) +def equal_outlet_event_accessors(a: OutletEventAccessors, b: OutletEventAccessors) -> bool: + return a._dict.keys() == b._dict.keys() and all( # type: ignore[attr-defined] + equal_outlet_event_accessor(a._dict[key], b._dict[key]) # type: ignore[attr-defined] + for key in a._dict # type: ignore[attr-defined] + ) + + def equal_outlet_event_accessor(a: OutletEventAccessor, b: OutletEventAccessor) -> bool: return a.key == b.key and a.extra == b.extra and a.asset_alias_events == b.asset_alias_events @@ -257,21 +274,17 @@ def __len__(self) -> int: lambda a, b: a.get_uri() == b.get_uri(), ), ( - OutletEventAccessor( - key=AssetUniqueKey.from_asset(Asset(uri="test", name="test", group="test-group")), - extra={"key": "value"}, - asset_alias_events=[], + create_outlet_event_accessors( + Asset(uri="test", name="test", group="test-group"), {"key": "value"}, [] ), - DAT.ASSET_EVENT_ACCESSOR, - equal_outlet_event_accessor, + DAT.ASSET_EVENT_ACCESSORS, + equal_outlet_event_accessors, ), ( - OutletEventAccessor( - key=AssetAliasUniqueKey.from_asset_alias( - AssetAlias(name="test_alias", group="test-alias-group") - ), - extra={"key": "value"}, - asset_alias_events=[ + create_outlet_event_accessors( + AssetAlias(name="test_alias", group="test-alias-group"), + {"key": "value"}, + [ AssetAliasEvent( source_alias_name="test_alias", dest_asset_key=AssetUniqueKey(name="test_name", uri="test://asset-uri"), @@ -279,8 +292,8 @@ def __len__(self) -> int: ) ], ), - DAT.ASSET_EVENT_ACCESSOR, - equal_outlet_event_accessor, + DAT.ASSET_EVENT_ACCESSORS, + equal_outlet_event_accessors, ), ( AirflowException("test123 wohoo!"),