Skip to content

Commit

Permalink
fix(serialization): fix asset unique and asset alias unique serializa…
Browse files Browse the repository at this point in the history
…tion
  • Loading branch information
Lee-W committed Dec 10, 2024
1 parent 5b27b1b commit f047a64
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 18 deletions.
2 changes: 2 additions & 0 deletions airflow/serialization/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class DagAttributeTypes(str, Enum):
ASSET_ANY = "asset_any"
ASSET_ALL = "asset_all"
ASSET_REF = "asset_ref"
ASSET_UNIQUE_KEY = "asset_unique_key"
ASSET_ALIAS_UNIQUE_KEY = "asset_alias_unique_key"
SIMPLE_TASK_INSTANCE = "simple_task_instance"
BASE_JOB = "Job"
TASK_INSTANCE = "task_instance"
Expand Down
38 changes: 35 additions & 3 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,11 @@
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetAliasUniqueKey,
AssetAll,
AssetAny,
AssetRef,
AssetUniqueKey,
BaseAsset,
)
from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
Expand Down Expand Up @@ -316,19 +318,35 @@ def decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor:
outlet_event_accessor = OutletEventAccessor(
key=BaseSerialization.deserialize(var["key"]),
extra=var["extra"],
asset_alias_events=[AssetAliasEvent(**e) for e in asset_alias_events],
asset_alias_events=[
AssetAliasEvent(
source_alias_name=e["source_alias_name"],
dest_asset_key=AssetUniqueKey(
name=e["dest_asset_key"]["name"], uri=e["dest_asset_key"]["uri"]
),
extra=e["extra"],
)
for e in asset_alias_events
],
)
return outlet_event_accessor


def encode_outlet_event_accessors(var: OutletEventAccessors) -> dict[str, Any]:
return {k: encode_outlet_event_accessor(v) for k, v in var._dict.items()} # type: ignore[attr-defined]
return {
"__type": DAT.ASSET_EVENT_ACCESSORS,
"_dict": [
{"key": BaseSerialization.serialize(k), "value": encode_outlet_event_accessor(v)}
for k, v in var._dict.items()
],
}


def decode_outlet_event_accessors(var: dict[str, Any]) -> OutletEventAccessors:
d = OutletEventAccessors() # type: ignore[assignment]
d._dict = { # type: ignore[attr-defined]
k: decode_outlet_event_accessor(v) for k, v in var.items()
BaseSerialization.deserialize(row["key"]): decode_outlet_event_accessor(row["value"])
for row in var["_dict"]
}
return d

Expand Down Expand Up @@ -694,6 +712,16 @@ def serialize(
encode_outlet_event_accessors(var),
type_=DAT.ASSET_EVENT_ACCESSORS,
)
elif isinstance(var, AssetUniqueKey):
return cls._encode(
attrs.asdict(var),
type_=DAT.ASSET_UNIQUE_KEY,
)
elif isinstance(var, AssetAliasUniqueKey):
return cls._encode(
attrs.asdict(var),
type_=DAT.ASSET_ALIAS_UNIQUE_KEY,
)
elif isinstance(var, DAG):
return cls._encode(SerializedDAG.serialize_dag(var), type_=DAT.DAG)
elif isinstance(var, Resources):
Expand Down Expand Up @@ -858,6 +886,10 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
return {k: cls.deserialize(v, use_pydantic_models) for k, v in var.items()}
elif type_ == DAT.ASSET_EVENT_ACCESSORS:
return decode_outlet_event_accessors(var)
elif type_ == DAT.ASSET_UNIQUE_KEY:
return AssetUniqueKey(name=var["name"], uri=var["uri"])
elif type_ == DAT.ASSET_ALIAS_UNIQUE_KEY:
return AssetAliasUniqueKey(name=var["name"])
elif type_ == DAT.DAG:
return SerializedDAG.deserialize_dag(var)
elif type_ == DAT.OP:
Expand Down
42 changes: 27 additions & 15 deletions tests/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from airflow.models.xcom_arg import XComArg
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasUniqueKey, AssetUniqueKey
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.pydantic.asset import AssetEventPydantic, AssetPydantic
from airflow.serialization.pydantic.dag import DagModelPydantic, DagTagPydantic
Expand Down Expand Up @@ -150,6 +150,16 @@ class Test:
DAG_RUN.id = 1


def create_outlet_event_accessors(
key: Asset | AssetAlias, extra: dict, asset_alias_events: list[AssetAliasEvent]
) -> OutletEventAccessors:
o = OutletEventAccessors()
o[key].extra = extra
o[key].asset_alias_events = asset_alias_events

return o


def equals(a, b) -> bool:
return a == b

Expand All @@ -162,6 +172,12 @@ def equal_exception(a: AirflowException, b: AirflowException) -> bool:
return a.__class__ == b.__class__ and str(a) == str(b)


def equal_outlet_event_accessors(a: OutletEventAccessors, b: OutletEventAccessors) -> bool:
return a._dict.keys() == b._dict.keys() and all(
equal_outlet_event_accessor(a._dict[key], b._dict[key]) for key in a._dict
)


def equal_outlet_event_accessor(a: OutletEventAccessor, b: OutletEventAccessor) -> bool:
return a.key == b.key and a.extra == b.extra and a.asset_alias_events == b.asset_alias_events

Expand Down Expand Up @@ -257,30 +273,26 @@ def __len__(self) -> int:
lambda a, b: a.get_uri() == b.get_uri(),
),
(
OutletEventAccessor(
key=AssetUniqueKey.from_asset(Asset(uri="test", name="test", group="test-group")),
extra={"key": "value"},
asset_alias_events=[],
create_outlet_event_accessors(
Asset(uri="test", name="test", group="test-group"), {"key": "value"}, []
),
DAT.ASSET_EVENT_ACCESSOR,
equal_outlet_event_accessor,
DAT.ASSET_EVENT_ACCESSORS,
equal_outlet_event_accessors,
),
(
OutletEventAccessor(
key=AssetAliasUniqueKey.from_asset_alias(
AssetAlias(name="test_alias", group="test-alias-group")
),
extra={"key": "value"},
asset_alias_events=[
create_outlet_event_accessors(
AssetAlias(name="test_alias", group="test-alias-group"),
{"key": "value"},
[
AssetAliasEvent(
source_alias_name="test_alias",
dest_asset_key=AssetUniqueKey(name="test_name", uri="test://asset-uri"),
extra={},
)
],
),
DAT.ASSET_EVENT_ACCESSOR,
equal_outlet_event_accessor,
DAT.ASSET_EVENT_ACCESSORS,
equal_outlet_event_accessors,
),
(
AirflowException("test123 wohoo!"),
Expand Down

0 comments on commit f047a64

Please sign in to comment.