Skip to content

Commit

Permalink
fix(serialization): fix asset unique and asset alias unique serializa…
Browse files Browse the repository at this point in the history
…tion
  • Loading branch information
Lee-W committed Dec 10, 2024
1 parent 9fa4c72 commit 58c7370
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 26 deletions.
2 changes: 2 additions & 0 deletions airflow/serialization/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class DagAttributeTypes(str, Enum):
ASSET_ANY = "asset_any"
ASSET_ALL = "asset_all"
ASSET_REF = "asset_ref"
ASSET_UNIQUE_KEY = "asset_unique_key"
ASSET_ALIAS_UNIQUE_KEY = "asset_alias_unique_key"
SIMPLE_TASK_INSTANCE = "simple_task_instance"
BASE_JOB = "Job"
TASK_INSTANCE = "task_instance"
Expand Down
38 changes: 35 additions & 3 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,11 @@
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetAliasUniqueKey,
AssetAll,
AssetAny,
AssetRef,
AssetUniqueKey,
BaseAsset,
)
from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
Expand Down Expand Up @@ -316,19 +318,35 @@ def decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor:
outlet_event_accessor = OutletEventAccessor(
key=BaseSerialization.deserialize(var["key"]),
extra=var["extra"],
asset_alias_events=[AssetAliasEvent(**e) for e in asset_alias_events],
asset_alias_events=[
AssetAliasEvent(
source_alias_name=e["source_alias_name"],
dest_asset_key=AssetUniqueKey(
name=e["dest_asset_key"]["name"], uri=e["dest_asset_key"]["uri"]
),
extra=e["extra"],
)
for e in asset_alias_events
],
)
return outlet_event_accessor


def encode_outlet_event_accessors(var: OutletEventAccessors) -> dict[str, Any]:
return {k: encode_outlet_event_accessor(v) for k, v in var._dict.items()} # type: ignore[attr-defined]
return {
"__type": DAT.ASSET_EVENT_ACCESSORS,
"_dict": [
{"key": BaseSerialization.serialize(k), "value": encode_outlet_event_accessor(v)}
for k, v in var._dict.items() # type: ignore[attr-defined]
],
}


def decode_outlet_event_accessors(var: dict[str, Any]) -> OutletEventAccessors:
d = OutletEventAccessors() # type: ignore[assignment]
d._dict = { # type: ignore[attr-defined]
k: decode_outlet_event_accessor(v) for k, v in var.items()
BaseSerialization.deserialize(row["key"]): decode_outlet_event_accessor(row["value"])
for row in var["_dict"]
}
return d

Expand Down Expand Up @@ -694,6 +712,16 @@ def serialize(
encode_outlet_event_accessors(var),
type_=DAT.ASSET_EVENT_ACCESSORS,
)
elif isinstance(var, AssetUniqueKey):
return cls._encode(
attrs.asdict(var),
type_=DAT.ASSET_UNIQUE_KEY,
)
elif isinstance(var, AssetAliasUniqueKey):
return cls._encode(
attrs.asdict(var),
type_=DAT.ASSET_ALIAS_UNIQUE_KEY,
)
elif isinstance(var, DAG):
return cls._encode(SerializedDAG.serialize_dag(var), type_=DAT.DAG)
elif isinstance(var, Resources):
Expand Down Expand Up @@ -858,6 +886,10 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
return {k: cls.deserialize(v, use_pydantic_models) for k, v in var.items()}
elif type_ == DAT.ASSET_EVENT_ACCESSORS:
return decode_outlet_event_accessors(var)
elif type_ == DAT.ASSET_UNIQUE_KEY:
return AssetUniqueKey(name=var["name"], uri=var["uri"])
elif type_ == DAT.ASSET_ALIAS_UNIQUE_KEY:
return AssetAliasUniqueKey(name=var["name"])
elif type_ == DAT.DAG:
return SerializedDAG.deserialize_dag(var)
elif type_ == DAT.OP:
Expand Down
16 changes: 8 additions & 8 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2525,7 +2525,7 @@ def test_outlet_asset_alias(self, dag_maker, session):

@task(outlets=AssetAlias(alias_name_1))
def producer(*, outlet_events):
outlet_events[alias_name_1].add(Asset(asset_uri))
outlet_events[AssetAlias(alias_name_1)].add(Asset(asset_uri))

producer()

Expand Down Expand Up @@ -2581,9 +2581,9 @@ def test_outlet_multiple_asset_alias(self, dag_maker, session):
]
)
def producer(*, outlet_events):
outlet_events[asset_alias_name_1].add(Asset(asset_uri))
outlet_events[asset_alias_name_2].add(Asset(asset_uri))
outlet_events[asset_alias_name_3].add(Asset(asset_uri), extra={"k": "v"})
outlet_events[AssetAlias(asset_alias_name_1)].add(Asset(asset_uri))
outlet_events[AssetAlias(asset_alias_name_2)].add(Asset(asset_uri))
outlet_events[AssetAlias(asset_alias_name_3)].add(Asset(asset_uri), extra={"k": "v"})

producer()

Expand Down Expand Up @@ -2686,7 +2686,7 @@ def test_outlet_asset_alias_asset_not_exists(self, dag_maker, session):

@task(outlets=AssetAlias(asset_alias_name))
def producer(*, outlet_events):
outlet_events[asset_alias_name].add(Asset(asset_uri), extra={"key": "value"})
outlet_events[AssetAlias(asset_alias_name)].add(Asset(asset_uri), extra={"key": "value"})

producer()

Expand Down Expand Up @@ -2786,7 +2786,7 @@ def test_inlet_asset_alias_extra(self, dag_maker, session):

@task(outlets=AssetAlias(asset_alias_name))
def write(*, ti, outlet_events):
outlet_events[asset_alias_name].add(Asset(asset_uri), extra={"from": ti.task_id})
outlet_events[AssetAlias(asset_alias_name)].add(Asset(asset_uri), extra={"from": ti.task_id})

@task(inlets=AssetAlias(asset_alias_name))
def read(*, inlet_events):
Expand Down Expand Up @@ -2875,7 +2875,7 @@ def test_inlet_asset_extra_slice(self, dag_maker, session, slicer, expected):

@task(outlets=Asset(asset_uri))
def write(*, params, outlet_events):
outlet_events[asset_uri].extra = {"from": params["i"]}
outlet_events[Asset(asset_uri)].extra = {"from": params["i"]}

write()

Expand Down Expand Up @@ -2935,7 +2935,7 @@ def test_inlet_asset_alias_extra_slice(self, dag_maker, session, slicer, expecte

@task(outlets=AssetAlias(asset_alias_name))
def write(*, params, outlet_events):
outlet_events[asset_alias_name].add(Asset(asset_uri), {"from": params["i"]})
outlet_events[AssetAlias(asset_alias_name)].add(Asset(asset_uri), {"from": params["i"]})

write()

Expand Down
43 changes: 28 additions & 15 deletions tests/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from airflow.models.xcom_arg import XComArg
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasUniqueKey, AssetUniqueKey
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.pydantic.asset import AssetEventPydantic, AssetPydantic
from airflow.serialization.pydantic.dag import DagModelPydantic, DagTagPydantic
Expand Down Expand Up @@ -150,6 +150,16 @@ class Test:
DAG_RUN.id = 1


def create_outlet_event_accessors(
key: Asset | AssetAlias, extra: dict, asset_alias_events: list[AssetAliasEvent]
) -> OutletEventAccessors:
o = OutletEventAccessors()
o[key].extra = extra
o[key].asset_alias_events = asset_alias_events

return o


def equals(a, b) -> bool:
return a == b

Expand All @@ -162,6 +172,13 @@ def equal_exception(a: AirflowException, b: AirflowException) -> bool:
return a.__class__ == b.__class__ and str(a) == str(b)


def equal_outlet_event_accessors(a: OutletEventAccessors, b: OutletEventAccessors) -> bool:
return a._dict.keys() == b._dict.keys() and all( # type: ignore[attr-defined]
equal_outlet_event_accessor(a._dict[key], b._dict[key]) # type: ignore[attr-defined]
for key in a._dict # type: ignore[attr-defined]
)


def equal_outlet_event_accessor(a: OutletEventAccessor, b: OutletEventAccessor) -> bool:
return a.key == b.key and a.extra == b.extra and a.asset_alias_events == b.asset_alias_events

Expand Down Expand Up @@ -257,30 +274,26 @@ def __len__(self) -> int:
lambda a, b: a.get_uri() == b.get_uri(),
),
(
OutletEventAccessor(
key=AssetUniqueKey.from_asset(Asset(uri="test", name="test", group="test-group")),
extra={"key": "value"},
asset_alias_events=[],
create_outlet_event_accessors(
Asset(uri="test", name="test", group="test-group"), {"key": "value"}, []
),
DAT.ASSET_EVENT_ACCESSOR,
equal_outlet_event_accessor,
DAT.ASSET_EVENT_ACCESSORS,
equal_outlet_event_accessors,
),
(
OutletEventAccessor(
key=AssetAliasUniqueKey.from_asset_alias(
AssetAlias(name="test_alias", group="test-alias-group")
),
extra={"key": "value"},
asset_alias_events=[
create_outlet_event_accessors(
AssetAlias(name="test_alias", group="test-alias-group"),
{"key": "value"},
[
AssetAliasEvent(
source_alias_name="test_alias",
dest_asset_key=AssetUniqueKey(name="test_name", uri="test://asset-uri"),
extra={},
)
],
),
DAT.ASSET_EVENT_ACCESSOR,
equal_outlet_event_accessor,
DAT.ASSET_EVENT_ACCESSORS,
equal_outlet_event_accessors,
),
(
AirflowException("test123 wohoo!"),
Expand Down

0 comments on commit 58c7370

Please sign in to comment.