Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Respect Asset.name when accessing inlet and outlet events #44639

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow/example_dags/example_asset_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def produce_asset_events():
def produce_asset_events_through_asset_alias(*, outlet_events=None):
bucket_name = "bucket"
object_path = "my-task"
outlet_events["example-alias"].add(Asset(f"s3://{bucket_name}/{object_path}"))
outlet_events[AssetAlias("example-alias")].add(Asset(f"s3://{bucket_name}/{object_path}"))

produce_asset_events_through_asset_alias()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def produce_asset_events():
def produce_asset_events_through_asset_alias_with_no_taskflow(*, outlet_events=None):
bucket_name = "bucket"
object_path = "my-task"
outlet_events["example-alias-no-taskflow"].add(Asset(f"s3://{bucket_name}/{object_path}"))
outlet_events[AssetAlias("example-alias-no-taskflow")].add(Asset(f"s3://{bucket_name}/{object_path}"))

PythonOperator(
task_id="produce_asset_events_through_asset_alias_with_no_taskflow",
Expand Down
14 changes: 7 additions & 7 deletions airflow/example_dags/example_outlet_event_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from airflow.sdk.definitions.asset import Asset
from airflow.sdk.definitions.asset.metadata import Metadata

ds = Asset("s3://output/1.txt")
asset = Asset(uri="s3://output/1.txt", name="test-asset")

with DAG(
dag_id="asset_with_extra_by_yield",
Expand All @@ -41,9 +41,9 @@
tags=["produces"],
):

@task(outlets=[ds])
@task(outlets=[asset])
def asset_with_extra_by_yield():
yield Metadata(ds, {"hi": "bye"})
yield Metadata(asset, {"hi": "bye"})

asset_with_extra_by_yield()

Expand All @@ -55,9 +55,9 @@ def asset_with_extra_by_yield():
tags=["produces"],
):

@task(outlets=[ds])
@task(outlets=[asset])
def asset_with_extra_by_context(*, outlet_events=None):
outlet_events[ds].extra = {"hi": "bye"}
outlet_events[asset].extra = {"hi": "bye"}

asset_with_extra_by_context()

Expand All @@ -70,11 +70,11 @@ def asset_with_extra_by_context(*, outlet_events=None):
):

def _asset_with_extra_from_classic_operator_post_execute(context, result):
context["outlet_events"][ds].extra = {"hi": "bye"}
context["outlet_events"][asset].extra = {"hi": "bye"}

BashOperator(
task_id="asset_with_extra_from_classic_operator",
outlets=[ds],
outlets=[asset],
bash_command=":",
post_execute=_asset_with_extra_from_classic_operator_post_execute,
)
31 changes: 20 additions & 11 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
inspect,
or_,
text,
tuple_,
update,
)
from sqlalchemy.dialects import postgresql
Expand Down Expand Up @@ -100,7 +101,7 @@
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.xcom import LazyXComSelectSequence, XCom
from airflow.plugins_manager import integrate_macros_plugins
from airflow.sdk.definitions.asset import Asset, AssetAlias
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey
from airflow.sentry import Sentry
from airflow.settings import task_instance_mutation_hook
from airflow.stats import Stats
Expand Down Expand Up @@ -2735,7 +2736,7 @@ def _register_asset_changes_int(
# One task only triggers one asset event for each asset with the same extra.
# This tuple[asset uri, extra] to sets alias names mapping is used to find whether
# there're assets with same uri but different extra that we need to emit more than one asset events.
asset_alias_names: dict[tuple[str, frozenset], set[str]] = defaultdict(set)
asset_alias_names: dict[tuple[AssetUniqueKey, frozenset], set[str]] = defaultdict(set)
for obj in ti.task.outlets or []:
ti.log.debug("outlet obj %s", obj)
# Lineage can have other types of objects besides assets
Expand All @@ -2749,26 +2750,34 @@ def _register_asset_changes_int(
elif isinstance(obj, AssetAlias):
for asset_alias_event in events[obj].asset_alias_events:
asset_alias_name = asset_alias_event.source_alias_name
asset_uri = asset_alias_event.dest_asset_uri
asset_unique_key = asset_alias_event.dest_asset_key
frozen_extra = frozenset(asset_alias_event.extra.items())
asset_alias_names[(asset_uri, frozen_extra)].add(asset_alias_name)
asset_alias_names[(asset_unique_key, frozen_extra)].add(asset_alias_name)

asset_models: dict[str, AssetModel] = {
asset_obj.uri: asset_obj
asset_models: dict[AssetUniqueKey, AssetModel] = {
AssetUniqueKey.from_asset(asset_obj): asset_obj
for asset_obj in session.scalars(
select(AssetModel).where(AssetModel.uri.in_(uri for uri, _ in asset_alias_names))
select(AssetModel).where(
tuple_(AssetModel.name, AssetModel.uri).in_(
(key.name, key.uri) for key, _ in asset_alias_names
)
)
)
}
if missing_assets := [Asset(uri=u) for u, _ in asset_alias_names if u not in asset_models]:
if missing_assets := [
asset_unique_key.to_asset()
for asset_unique_key, _ in asset_alias_names
if asset_unique_key not in asset_models
]:
asset_models.update(
(asset_obj.uri, asset_obj)
(AssetUniqueKey.from_asset(asset_obj), asset_obj)
for asset_obj in asset_manager.create_assets(missing_assets, session=session)
)
ti.log.warning("Created new assets for alias reference: %s", missing_assets)
session.flush() # Needed because we need the id for fk.

for (uri, extra_items), alias_names in asset_alias_names.items():
asset_obj = asset_models[uri]
for (unique_key, extra_items), alias_names in asset_alias_names.items():
asset_obj = asset_models[unique_key]
ti.log.info(
'Creating event for %r through aliases "%s"',
asset_obj,
Expand Down
2 changes: 2 additions & 0 deletions airflow/serialization/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class DagAttributeTypes(str, Enum):
ASSET_ANY = "asset_any"
ASSET_ALL = "asset_all"
ASSET_REF = "asset_ref"
ASSET_UNIQUE_KEY = "asset_unique_key"
ASSET_ALIAS_UNIQUE_KEY = "asset_alias_unique_key"
SIMPLE_TASK_INSTANCE = "simple_task_instance"
BASE_JOB = "Job"
TASK_INSTANCE = "task_instance"
Expand Down
66 changes: 48 additions & 18 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,11 @@
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetAliasUniqueKey,
AssetAll,
AssetAny,
AssetRef,
AssetUniqueKey,
BaseAsset,
)
from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
Expand Down Expand Up @@ -303,25 +305,52 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset:


def encode_outlet_event_accessor(var: OutletEventAccessor) -> dict[str, Any]:
raw_key = var.raw_key
key = var.key
return {
"key": BaseSerialization.serialize(key),
"extra": var.extra,
"asset_alias_events": [attrs.asdict(cast(attrs.AttrsInstance, e)) for e in var.asset_alias_events],
"raw_key": BaseSerialization.serialize(raw_key),
}


def decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor:
asset_alias_events = var.get("asset_alias_events", [])

outlet_event_accessor = OutletEventAccessor(
key=BaseSerialization.deserialize(var["key"]),
extra=var["extra"],
raw_key=BaseSerialization.deserialize(var["raw_key"]),
asset_alias_events=[AssetAliasEvent(**e) for e in asset_alias_events],
asset_alias_events=[
AssetAliasEvent(
source_alias_name=e["source_alias_name"],
dest_asset_key=AssetUniqueKey(
name=e["dest_asset_key"]["name"], uri=e["dest_asset_key"]["uri"]
),
extra=e["extra"],
)
for e in asset_alias_events
],
)
return outlet_event_accessor


def encode_outlet_event_accessors(var: OutletEventAccessors) -> dict[str, Any]:
return {
"__type": DAT.ASSET_EVENT_ACCESSORS,
"_dict": [
{"key": BaseSerialization.serialize(k), "value": encode_outlet_event_accessor(v)}
for k, v in var._dict.items() # type: ignore[attr-defined]
],
}


def decode_outlet_event_accessors(var: dict[str, Any]) -> OutletEventAccessors:
d = OutletEventAccessors() # type: ignore[assignment]
d._dict = { # type: ignore[attr-defined]
BaseSerialization.deserialize(row["key"]): decode_outlet_event_accessor(row["value"])
for row in var["_dict"]
}
return d


def encode_timetable(var: Timetable) -> dict[str, Any]:
"""
Encode a timetable instance.
Expand Down Expand Up @@ -680,17 +709,18 @@ def serialize(
return cls._encode(json_pod, type_=DAT.POD)
elif isinstance(var, OutletEventAccessors):
return cls._encode(
cls.serialize(
var._dict, # type: ignore[attr-defined]
strict=strict,
use_pydantic_models=use_pydantic_models,
),
encode_outlet_event_accessors(var),
type_=DAT.ASSET_EVENT_ACCESSORS,
)
elif isinstance(var, OutletEventAccessor):
elif isinstance(var, AssetUniqueKey):
return cls._encode(
attrs.asdict(var),
type_=DAT.ASSET_UNIQUE_KEY,
)
elif isinstance(var, AssetAliasUniqueKey):
return cls._encode(
encode_outlet_event_accessor(var),
type_=DAT.ASSET_EVENT_ACCESSOR,
attrs.asdict(var),
type_=DAT.ASSET_ALIAS_UNIQUE_KEY,
)
elif isinstance(var, DAG):
return cls._encode(SerializedDAG.serialize_dag(var), type_=DAT.DAG)
Expand Down Expand Up @@ -855,11 +885,11 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
elif type_ == DAT.DICT:
return {k: cls.deserialize(v, use_pydantic_models) for k, v in var.items()}
elif type_ == DAT.ASSET_EVENT_ACCESSORS:
d = OutletEventAccessors() # type: ignore[assignment]
d._dict = cls.deserialize(var) # type: ignore[attr-defined]
return d
elif type_ == DAT.ASSET_EVENT_ACCESSOR:
return decode_outlet_event_accessor(var)
return decode_outlet_event_accessors(var)
elif type_ == DAT.ASSET_UNIQUE_KEY:
return AssetUniqueKey(name=var["name"], uri=var["uri"])
elif type_ == DAT.ASSET_ALIAS_UNIQUE_KEY:
return AssetAliasUniqueKey(name=var["name"])
elif type_ == DAT.DAG:
return SerializedDAG.deserialize_dag(var)
elif type_ == DAT.OP:
Expand Down
Loading