Skip to content

Commit

Permalink
Respect Asset.name when accessing inlet and outlet events (#44639)
Browse files Browse the repository at this point in the history
* docs: removing examples that accessing inlet and outlet events

* feat(utils/context): Deprecate accessing inlet and outlet events through string

* fix(util/context): use asset unique key in outlet event accessors

* feat(asset): extend AssetAliasEvent to include asset uri and asset name

* refactor(asset): merge dest_asset_name and dest_asset_uri to dest_asset_key for AssetAliasEvent

* refactor(context): fix context typing

* feat(utils/context): add back string access to inlet and outlet events

* feat(utils/context): forbid using string to access inlet and outlet event

* feat(utils/context): rename AssetUniqueKey and AssetAliasUniqueKey to_obj methods to to_asset, to_asset_alias

* refactor(utils/context): remove redundant iter call

Co-authored-by: Tzu-ping Chung <[email protected]>

* feat(metadata): remove string access support to Metadata

* fix(serialization): fix asset unique key serialization

* feat(asset): make asset unique key an attrs

* fix(serialization): fix asset unique and asset alias unique serialization

* feat(utils/context): change outlet_event error to TypeError when non-asset or non-asset-alias is passed

---------

Co-authored-by: Tzu-ping Chung <[email protected]>
  • Loading branch information
Lee-W and uranusjr authored Dec 11, 2024
1 parent 37daea7 commit 4289051
Show file tree
Hide file tree
Showing 16 changed files with 326 additions and 228 deletions.
2 changes: 1 addition & 1 deletion airflow/example_dags/example_asset_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def produce_asset_events():
def produce_asset_events_through_asset_alias(*, outlet_events=None):
bucket_name = "bucket"
object_path = "my-task"
outlet_events["example-alias"].add(Asset(f"s3://{bucket_name}/{object_path}"))
outlet_events[AssetAlias("example-alias")].add(Asset(f"s3://{bucket_name}/{object_path}"))

produce_asset_events_through_asset_alias()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def produce_asset_events():
def produce_asset_events_through_asset_alias_with_no_taskflow(*, outlet_events=None):
bucket_name = "bucket"
object_path = "my-task"
outlet_events["example-alias-no-taskflow"].add(Asset(f"s3://{bucket_name}/{object_path}"))
outlet_events[AssetAlias("example-alias-no-taskflow")].add(Asset(f"s3://{bucket_name}/{object_path}"))

PythonOperator(
task_id="produce_asset_events_through_asset_alias_with_no_taskflow",
Expand Down
14 changes: 7 additions & 7 deletions airflow/example_dags/example_outlet_event_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from airflow.sdk.definitions.asset import Asset
from airflow.sdk.definitions.asset.metadata import Metadata

ds = Asset("s3://output/1.txt")
asset = Asset(uri="s3://output/1.txt", name="test-asset")

with DAG(
dag_id="asset_with_extra_by_yield",
Expand All @@ -41,9 +41,9 @@
tags=["produces"],
):

@task(outlets=[ds])
@task(outlets=[asset])
def asset_with_extra_by_yield():
yield Metadata(ds, {"hi": "bye"})
yield Metadata(asset, {"hi": "bye"})

asset_with_extra_by_yield()

Expand All @@ -55,9 +55,9 @@ def asset_with_extra_by_yield():
tags=["produces"],
):

@task(outlets=[ds])
@task(outlets=[asset])
def asset_with_extra_by_context(*, outlet_events=None):
outlet_events[ds].extra = {"hi": "bye"}
outlet_events[asset].extra = {"hi": "bye"}

asset_with_extra_by_context()

Expand All @@ -70,11 +70,11 @@ def asset_with_extra_by_context(*, outlet_events=None):
):

def _asset_with_extra_from_classic_operator_post_execute(context, result):
context["outlet_events"][ds].extra = {"hi": "bye"}
context["outlet_events"][asset].extra = {"hi": "bye"}

BashOperator(
task_id="asset_with_extra_from_classic_operator",
outlets=[ds],
outlets=[asset],
bash_command=":",
post_execute=_asset_with_extra_from_classic_operator_post_execute,
)
31 changes: 20 additions & 11 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
inspect,
or_,
text,
tuple_,
update,
)
from sqlalchemy.dialects import postgresql
Expand Down Expand Up @@ -100,7 +101,7 @@
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.xcom import LazyXComSelectSequence, XCom
from airflow.plugins_manager import integrate_macros_plugins
from airflow.sdk.definitions.asset import Asset, AssetAlias
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey
from airflow.sentry import Sentry
from airflow.settings import task_instance_mutation_hook
from airflow.stats import Stats
Expand Down Expand Up @@ -2735,7 +2736,7 @@ def _register_asset_changes_int(
# One task only triggers one asset event for each asset with the same extra.
# This tuple[asset uri, extra] to sets alias names mapping is used to find whether
# there're assets with same uri but different extra that we need to emit more than one asset events.
asset_alias_names: dict[tuple[str, frozenset], set[str]] = defaultdict(set)
asset_alias_names: dict[tuple[AssetUniqueKey, frozenset], set[str]] = defaultdict(set)
for obj in ti.task.outlets or []:
ti.log.debug("outlet obj %s", obj)
# Lineage can have other types of objects besides assets
Expand All @@ -2749,26 +2750,34 @@ def _register_asset_changes_int(
elif isinstance(obj, AssetAlias):
for asset_alias_event in events[obj].asset_alias_events:
asset_alias_name = asset_alias_event.source_alias_name
asset_uri = asset_alias_event.dest_asset_uri
asset_unique_key = asset_alias_event.dest_asset_key
frozen_extra = frozenset(asset_alias_event.extra.items())
asset_alias_names[(asset_uri, frozen_extra)].add(asset_alias_name)
asset_alias_names[(asset_unique_key, frozen_extra)].add(asset_alias_name)

asset_models: dict[str, AssetModel] = {
asset_obj.uri: asset_obj
asset_models: dict[AssetUniqueKey, AssetModel] = {
AssetUniqueKey.from_asset(asset_obj): asset_obj
for asset_obj in session.scalars(
select(AssetModel).where(AssetModel.uri.in_(uri for uri, _ in asset_alias_names))
select(AssetModel).where(
tuple_(AssetModel.name, AssetModel.uri).in_(
(key.name, key.uri) for key, _ in asset_alias_names
)
)
)
}
if missing_assets := [Asset(uri=u) for u, _ in asset_alias_names if u not in asset_models]:
if missing_assets := [
asset_unique_key.to_asset()
for asset_unique_key, _ in asset_alias_names
if asset_unique_key not in asset_models
]:
asset_models.update(
(asset_obj.uri, asset_obj)
(AssetUniqueKey.from_asset(asset_obj), asset_obj)
for asset_obj in asset_manager.create_assets(missing_assets, session=session)
)
ti.log.warning("Created new assets for alias reference: %s", missing_assets)
session.flush() # Needed because we need the id for fk.

for (uri, extra_items), alias_names in asset_alias_names.items():
asset_obj = asset_models[uri]
for (unique_key, extra_items), alias_names in asset_alias_names.items():
asset_obj = asset_models[unique_key]
ti.log.info(
'Creating event for %r through aliases "%s"',
asset_obj,
Expand Down
2 changes: 2 additions & 0 deletions airflow/serialization/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class DagAttributeTypes(str, Enum):
ASSET_ANY = "asset_any"
ASSET_ALL = "asset_all"
ASSET_REF = "asset_ref"
ASSET_UNIQUE_KEY = "asset_unique_key"
ASSET_ALIAS_UNIQUE_KEY = "asset_alias_unique_key"
SIMPLE_TASK_INSTANCE = "simple_task_instance"
BASE_JOB = "Job"
TASK_INSTANCE = "task_instance"
Expand Down
66 changes: 48 additions & 18 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,11 @@
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetAliasUniqueKey,
AssetAll,
AssetAny,
AssetRef,
AssetUniqueKey,
BaseAsset,
)
from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
Expand Down Expand Up @@ -303,25 +305,52 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset:


def encode_outlet_event_accessor(var: OutletEventAccessor) -> dict[str, Any]:
raw_key = var.raw_key
key = var.key
return {
"key": BaseSerialization.serialize(key),
"extra": var.extra,
"asset_alias_events": [attrs.asdict(cast(attrs.AttrsInstance, e)) for e in var.asset_alias_events],
"raw_key": BaseSerialization.serialize(raw_key),
}


def decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor:
asset_alias_events = var.get("asset_alias_events", [])

outlet_event_accessor = OutletEventAccessor(
key=BaseSerialization.deserialize(var["key"]),
extra=var["extra"],
raw_key=BaseSerialization.deserialize(var["raw_key"]),
asset_alias_events=[AssetAliasEvent(**e) for e in asset_alias_events],
asset_alias_events=[
AssetAliasEvent(
source_alias_name=e["source_alias_name"],
dest_asset_key=AssetUniqueKey(
name=e["dest_asset_key"]["name"], uri=e["dest_asset_key"]["uri"]
),
extra=e["extra"],
)
for e in asset_alias_events
],
)
return outlet_event_accessor


def encode_outlet_event_accessors(var: OutletEventAccessors) -> dict[str, Any]:
return {
"__type": DAT.ASSET_EVENT_ACCESSORS,
"_dict": [
{"key": BaseSerialization.serialize(k), "value": encode_outlet_event_accessor(v)}
for k, v in var._dict.items() # type: ignore[attr-defined]
],
}


def decode_outlet_event_accessors(var: dict[str, Any]) -> OutletEventAccessors:
d = OutletEventAccessors() # type: ignore[assignment]
d._dict = { # type: ignore[attr-defined]
BaseSerialization.deserialize(row["key"]): decode_outlet_event_accessor(row["value"])
for row in var["_dict"]
}
return d


def encode_timetable(var: Timetable) -> dict[str, Any]:
"""
Encode a timetable instance.
Expand Down Expand Up @@ -680,17 +709,18 @@ def serialize(
return cls._encode(json_pod, type_=DAT.POD)
elif isinstance(var, OutletEventAccessors):
return cls._encode(
cls.serialize(
var._dict, # type: ignore[attr-defined]
strict=strict,
use_pydantic_models=use_pydantic_models,
),
encode_outlet_event_accessors(var),
type_=DAT.ASSET_EVENT_ACCESSORS,
)
elif isinstance(var, OutletEventAccessor):
elif isinstance(var, AssetUniqueKey):
return cls._encode(
attrs.asdict(var),
type_=DAT.ASSET_UNIQUE_KEY,
)
elif isinstance(var, AssetAliasUniqueKey):
return cls._encode(
encode_outlet_event_accessor(var),
type_=DAT.ASSET_EVENT_ACCESSOR,
attrs.asdict(var),
type_=DAT.ASSET_ALIAS_UNIQUE_KEY,
)
elif isinstance(var, DAG):
return cls._encode(SerializedDAG.serialize_dag(var), type_=DAT.DAG)
Expand Down Expand Up @@ -855,11 +885,11 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
elif type_ == DAT.DICT:
return {k: cls.deserialize(v, use_pydantic_models) for k, v in var.items()}
elif type_ == DAT.ASSET_EVENT_ACCESSORS:
d = OutletEventAccessors() # type: ignore[assignment]
d._dict = cls.deserialize(var) # type: ignore[attr-defined]
return d
elif type_ == DAT.ASSET_EVENT_ACCESSOR:
return decode_outlet_event_accessor(var)
return decode_outlet_event_accessors(var)
elif type_ == DAT.ASSET_UNIQUE_KEY:
return AssetUniqueKey(name=var["name"], uri=var["uri"])
elif type_ == DAT.ASSET_ALIAS_UNIQUE_KEY:
return AssetAliasUniqueKey(name=var["name"])
elif type_ == DAT.DAG:
return SerializedDAG.deserialize_dag(var)
elif type_ == DAT.OP:
Expand Down
Loading

0 comments on commit 4289051

Please sign in to comment.