diff --git a/airflow/example_dags/example_asset_alias.py b/airflow/example_dags/example_asset_alias.py index a7f2aac5845c2..2d26fa32101c2 100644 --- a/airflow/example_dags/example_asset_alias.py +++ b/airflow/example_dags/example_asset_alias.py @@ -67,7 +67,7 @@ def produce_asset_events(): def produce_asset_events_through_asset_alias(*, outlet_events=None): bucket_name = "bucket" object_path = "my-task" - outlet_events["example-alias"].add(Asset(f"s3://{bucket_name}/{object_path}")) + outlet_events[AssetAlias("example-alias")].add(Asset(f"s3://{bucket_name}/{object_path}")) produce_asset_events_through_asset_alias() diff --git a/airflow/example_dags/example_asset_alias_with_no_taskflow.py b/airflow/example_dags/example_asset_alias_with_no_taskflow.py index 19f31465ea4f8..c3d1ac0b8d14d 100644 --- a/airflow/example_dags/example_asset_alias_with_no_taskflow.py +++ b/airflow/example_dags/example_asset_alias_with_no_taskflow.py @@ -68,7 +68,7 @@ def produce_asset_events(): def produce_asset_events_through_asset_alias_with_no_taskflow(*, outlet_events=None): bucket_name = "bucket" object_path = "my-task" - outlet_events["example-alias-no-taskflow"].add(Asset(f"s3://{bucket_name}/{object_path}")) + outlet_events[AssetAlias("example-alias-no-taskflow")].add(Asset(f"s3://{bucket_name}/{object_path}")) PythonOperator( task_id="produce_asset_events_through_asset_alias_with_no_taskflow", diff --git a/airflow/example_dags/example_outlet_event_extra.py b/airflow/example_dags/example_outlet_event_extra.py index dd3041e18fc07..8b08bb5fc94a4 100644 --- a/airflow/example_dags/example_outlet_event_extra.py +++ b/airflow/example_dags/example_outlet_event_extra.py @@ -31,7 +31,7 @@ from airflow.sdk.definitions.asset import Asset from airflow.sdk.definitions.asset.metadata import Metadata -ds = Asset("s3://output/1.txt") +asset = Asset(uri="s3://output/1.txt", name="test-asset") with DAG( dag_id="asset_with_extra_by_yield", @@ -41,9 +41,9 @@ tags=["produces"], ): - @task(outlets=[ds]) + 
@task(outlets=[asset]) def asset_with_extra_by_yield(): - yield Metadata(ds, {"hi": "bye"}) + yield Metadata(asset, {"hi": "bye"}) asset_with_extra_by_yield() @@ -55,9 +55,9 @@ def asset_with_extra_by_yield(): tags=["produces"], ): - @task(outlets=[ds]) + @task(outlets=[asset]) def asset_with_extra_by_context(*, outlet_events=None): - outlet_events[ds].extra = {"hi": "bye"} + outlet_events[asset].extra = {"hi": "bye"} asset_with_extra_by_context() @@ -70,11 +70,11 @@ def asset_with_extra_by_context(*, outlet_events=None): ): def _asset_with_extra_from_classic_operator_post_execute(context, result): - context["outlet_events"][ds].extra = {"hi": "bye"} + context["outlet_events"][asset].extra = {"hi": "bye"} BashOperator( task_id="asset_with_extra_from_classic_operator", - outlets=[ds], + outlets=[asset], bash_command=":", post_execute=_asset_with_extra_from_classic_operator_post_execute, ) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 87227f9d8c41a..6ff78617c5f1e 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -60,6 +60,7 @@ inspect, or_, text, + tuple_, update, ) from sqlalchemy.dialects import postgresql @@ -100,7 +101,7 @@ from airflow.models.taskreschedule import TaskReschedule from airflow.models.xcom import LazyXComSelectSequence, XCom from airflow.plugins_manager import integrate_macros_plugins -from airflow.sdk.definitions.asset import Asset, AssetAlias +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey from airflow.sentry import Sentry from airflow.settings import task_instance_mutation_hook from airflow.stats import Stats @@ -2735,7 +2736,7 @@ def _register_asset_changes_int( # One task only triggers one asset event for each asset with the same extra. # This tuple[asset uri, extra] to sets alias names mapping is used to find whether # there're assets with same uri but different extra that we need to emit more than one asset events. 
- asset_alias_names: dict[tuple[str, frozenset], set[str]] = defaultdict(set) + asset_alias_names: dict[tuple[AssetUniqueKey, frozenset], set[str]] = defaultdict(set) for obj in ti.task.outlets or []: ti.log.debug("outlet obj %s", obj) # Lineage can have other types of objects besides assets @@ -2749,26 +2750,34 @@ def _register_asset_changes_int( elif isinstance(obj, AssetAlias): for asset_alias_event in events[obj].asset_alias_events: asset_alias_name = asset_alias_event.source_alias_name - asset_uri = asset_alias_event.dest_asset_uri + asset_unique_key = asset_alias_event.dest_asset_key frozen_extra = frozenset(asset_alias_event.extra.items()) - asset_alias_names[(asset_uri, frozen_extra)].add(asset_alias_name) + asset_alias_names[(asset_unique_key, frozen_extra)].add(asset_alias_name) - asset_models: dict[str, AssetModel] = { - asset_obj.uri: asset_obj + asset_models: dict[AssetUniqueKey, AssetModel] = { + AssetUniqueKey.from_asset(asset_obj): asset_obj for asset_obj in session.scalars( - select(AssetModel).where(AssetModel.uri.in_(uri for uri, _ in asset_alias_names)) + select(AssetModel).where( + tuple_(AssetModel.name, AssetModel.uri).in_( + (key.name, key.uri) for key, _ in asset_alias_names + ) + ) ) } - if missing_assets := [Asset(uri=u) for u, _ in asset_alias_names if u not in asset_models]: + if missing_assets := [ + asset_unique_key.to_asset() + for asset_unique_key, _ in asset_alias_names + if asset_unique_key not in asset_models + ]: asset_models.update( - (asset_obj.uri, asset_obj) + (AssetUniqueKey.from_asset(asset_obj), asset_obj) for asset_obj in asset_manager.create_assets(missing_assets, session=session) ) ti.log.warning("Created new assets for alias reference: %s", missing_assets) session.flush() # Needed because we need the id for fk. 
- for (uri, extra_items), alias_names in asset_alias_names.items(): - asset_obj = asset_models[uri] + for (unique_key, extra_items), alias_names in asset_alias_names.items(): + asset_obj = asset_models[unique_key] ti.log.info( 'Creating event for %r through aliases "%s"', asset_obj, diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py index d1b946b38ef1e..ab7d59cb21682 100644 --- a/airflow/serialization/enums.py +++ b/airflow/serialization/enums.py @@ -60,6 +60,8 @@ class DagAttributeTypes(str, Enum): ASSET_ANY = "asset_any" ASSET_ALL = "asset_all" ASSET_REF = "asset_ref" + ASSET_UNIQUE_KEY = "asset_unique_key" + ASSET_ALIAS_UNIQUE_KEY = "asset_alias_unique_key" SIMPLE_TASK_INSTANCE = "simple_task_instance" BASE_JOB = "Job" TASK_INSTANCE = "task_instance" diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 3f07f387c4961..95d2b82f55164 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -59,9 +59,11 @@ from airflow.sdk.definitions.asset import ( Asset, AssetAlias, + AssetAliasUniqueKey, AssetAll, AssetAny, AssetRef, + AssetUniqueKey, BaseAsset, ) from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator @@ -303,25 +305,52 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset: def encode_outlet_event_accessor(var: OutletEventAccessor) -> dict[str, Any]: - raw_key = var.raw_key + key = var.key return { + "key": BaseSerialization.serialize(key), "extra": var.extra, "asset_alias_events": [attrs.asdict(cast(attrs.AttrsInstance, e)) for e in var.asset_alias_events], - "raw_key": BaseSerialization.serialize(raw_key), } def decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor: asset_alias_events = var.get("asset_alias_events", []) - outlet_event_accessor = OutletEventAccessor( + key=BaseSerialization.deserialize(var["key"]), extra=var["extra"], - 
raw_key=BaseSerialization.deserialize(var["raw_key"]), - asset_alias_events=[AssetAliasEvent(**e) for e in asset_alias_events], + asset_alias_events=[ + AssetAliasEvent( + source_alias_name=e["source_alias_name"], + dest_asset_key=AssetUniqueKey( + name=e["dest_asset_key"]["name"], uri=e["dest_asset_key"]["uri"] + ), + extra=e["extra"], + ) + for e in asset_alias_events + ], ) return outlet_event_accessor +def encode_outlet_event_accessors(var: OutletEventAccessors) -> dict[str, Any]: + return { + "__type": DAT.ASSET_EVENT_ACCESSORS, + "_dict": [ + {"key": BaseSerialization.serialize(k), "value": encode_outlet_event_accessor(v)} + for k, v in var._dict.items() # type: ignore[attr-defined] + ], + } + + +def decode_outlet_event_accessors(var: dict[str, Any]) -> OutletEventAccessors: + d = OutletEventAccessors() # type: ignore[assignment] + d._dict = { # type: ignore[attr-defined] + BaseSerialization.deserialize(row["key"]): decode_outlet_event_accessor(row["value"]) + for row in var["_dict"] + } + return d + + def encode_timetable(var: Timetable) -> dict[str, Any]: """ Encode a timetable instance. 
@@ -680,17 +709,18 @@ def serialize( return cls._encode(json_pod, type_=DAT.POD) elif isinstance(var, OutletEventAccessors): return cls._encode( - cls.serialize( - var._dict, # type: ignore[attr-defined] - strict=strict, - use_pydantic_models=use_pydantic_models, - ), + encode_outlet_event_accessors(var), type_=DAT.ASSET_EVENT_ACCESSORS, ) - elif isinstance(var, OutletEventAccessor): + elif isinstance(var, AssetUniqueKey): + return cls._encode( + attrs.asdict(var), + type_=DAT.ASSET_UNIQUE_KEY, + ) + elif isinstance(var, AssetAliasUniqueKey): return cls._encode( - encode_outlet_event_accessor(var), - type_=DAT.ASSET_EVENT_ACCESSOR, + attrs.asdict(var), + type_=DAT.ASSET_ALIAS_UNIQUE_KEY, ) elif isinstance(var, DAG): return cls._encode(SerializedDAG.serialize_dag(var), type_=DAT.DAG) @@ -855,11 +885,11 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any: elif type_ == DAT.DICT: return {k: cls.deserialize(v, use_pydantic_models) for k, v in var.items()} elif type_ == DAT.ASSET_EVENT_ACCESSORS: - d = OutletEventAccessors() # type: ignore[assignment] - d._dict = cls.deserialize(var) # type: ignore[attr-defined] - return d - elif type_ == DAT.ASSET_EVENT_ACCESSOR: - return decode_outlet_event_accessor(var) + return decode_outlet_event_accessors(var) + elif type_ == DAT.ASSET_UNIQUE_KEY: + return AssetUniqueKey(name=var["name"], uri=var["uri"]) + elif type_ == DAT.ASSET_ALIAS_UNIQUE_KEY: + return AssetAliasUniqueKey(name=var["name"]) elif type_ == DAT.DAG: return SerializedDAG.deserialize_dag(var) elif type_ == DAT.OP: diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 8a8178eda473a..67f31a3f2b6b4 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -23,21 +23,36 @@ import copy import functools import warnings -from collections.abc import Container, ItemsView, Iterator, KeysView, Mapping, MutableMapping, ValuesView +from collections.abc import ( + Container, + ItemsView, + Iterator, + KeysView, + Mapping, + 
MutableMapping, + ValuesView, +) from typing import ( TYPE_CHECKING, Any, SupportsIndex, + Union, ) import attrs import lazy_object_proxy -from sqlalchemy import select +from sqlalchemy import and_, select from airflow.exceptions import RemovedInAirflow3Warning from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel, fetch_active_assets_by_name -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetRef -from airflow.sdk.definitions.asset.metadata import extract_event_key +from airflow.sdk.definitions.asset import ( + Asset, + AssetAlias, + AssetAliasUniqueKey, + AssetRef, + AssetUniqueKey, + BaseAssetUniqueKey, +) from airflow.utils.db import LazySelectSequence from airflow.utils.types import NOTSET @@ -149,7 +164,7 @@ class AssetAliasEvent: """ source_alias_name: str - dest_asset_uri: str + dest_asset_key: AssetUniqueKey extra: dict[str, Any] @@ -161,31 +176,25 @@ class OutletEventAccessor: :meta private: """ - raw_key: str | Asset | AssetAlias + key: BaseAssetUniqueKey extra: dict[str, Any] = attrs.Factory(dict) asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list) - def add(self, asset: Asset | str, extra: dict[str, Any] | None = None) -> None: + def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: """Add an AssetEvent to an existing Asset.""" - if isinstance(asset, str): - asset_uri = asset - elif isinstance(asset, Asset): - asset_uri = asset.uri - else: + if not isinstance(self.key, AssetAliasUniqueKey): return - if isinstance(self.raw_key, str): - asset_alias_name = self.raw_key - elif isinstance(self.raw_key, AssetAlias): - asset_alias_name = self.raw_key.name - else: - return - - event = AssetAliasEvent(asset_alias_name, asset_uri, extra=extra or {}) + asset_alias_name = self.key.name + event = AssetAliasEvent( + source_alias_name=asset_alias_name, + dest_asset_key=AssetUniqueKey.from_asset(asset), + extra=extra or {}, + ) self.asset_alias_events.append(event) -class 
OutletEventAccessors(Mapping[str, OutletEventAccessor]): +class OutletEventAccessors(Mapping[Union[Asset, AssetAlias], OutletEventAccessor]): """ Lazy mapping of outlet asset event accessors. @@ -193,22 +202,31 @@ class OutletEventAccessors(Mapping[str, OutletEventAccessor]): """ def __init__(self) -> None: - self._dict: dict[str, OutletEventAccessor] = {} + self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {} def __str__(self) -> str: return f"OutletEventAccessors(_dict={self._dict})" - def __iter__(self) -> Iterator[str]: - return iter(self._dict) + def __iter__(self) -> Iterator[Asset | AssetAlias]: + return ( + key.to_asset() if isinstance(key, AssetUniqueKey) else key.to_asset_alias() for key in self._dict + ) def __len__(self) -> int: return len(self._dict) - def __getitem__(self, key: str | Asset | AssetAlias) -> OutletEventAccessor: - event_key = extract_event_key(key) - if event_key not in self._dict: - self._dict[event_key] = OutletEventAccessor(extra={}, raw_key=key) - return self._dict[event_key] + def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor: + hashable_key: BaseAssetUniqueKey + if isinstance(key, Asset): + hashable_key = AssetUniqueKey.from_asset(key) + elif isinstance(key, AssetAlias): + hashable_key = AssetAliasUniqueKey.from_asset_alias(key) + else: + raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}") + + if hashable_key not in self._dict: + self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key) + return self._dict[hashable_key] class LazyAssetEventSelectSequence(LazySelectSequence[AssetEvent]): @@ -228,7 +246,7 @@ def _process_row(row: Row) -> AssetEvent: @attrs.define(init=False) -class InletEventsAccessors(Mapping[str, LazyAssetEventSelectSequence]): +class InletEventsAccessors(Mapping[Union[int, Asset, AssetAlias, AssetRef], LazyAssetEventSelectSequence]): """ Lazy mapping for inlet asset events accessors. 
@@ -236,8 +254,8 @@ class InletEventsAccessors(Mapping[str, LazyAssetEventSelectSequence]): """ _inlets: list[Any] - _assets: dict[str, Asset] - _asset_aliases: dict[str, AssetAlias] + _assets: dict[AssetUniqueKey, Asset] + _asset_aliases: dict[AssetAliasUniqueKey, AssetAlias] _session: Session def __init__(self, inlets: list, *, session: Session) -> None: @@ -249,23 +267,23 @@ def __init__(self, inlets: list, *, session: Session) -> None: _asset_ref_names: list[str] = [] for inlet in inlets: if isinstance(inlet, Asset): - self._assets[inlet.name] = inlet + self._assets[AssetUniqueKey.from_asset(inlet)] = inlet elif isinstance(inlet, AssetAlias): - self._asset_aliases[inlet.name] = inlet + self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(inlet)] = inlet elif isinstance(inlet, AssetRef): _asset_ref_names.append(inlet.name) if _asset_ref_names: - for asset_name, asset in fetch_active_assets_by_name(_asset_ref_names, self._session).items(): - self._assets[asset_name] = asset + for _, asset in fetch_active_assets_by_name(_asset_ref_names, self._session).items(): + self._assets[AssetUniqueKey.from_asset(asset)] = asset - def __iter__(self) -> Iterator[str]: + def __iter__(self) -> Iterator[Asset | AssetAlias]: return iter(self._inlets) def __len__(self) -> int: return len(self._inlets) - def __getitem__(self, key: int | str | Asset | AssetAlias) -> LazyAssetEventSelectSequence: + def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> LazyAssetEventSelectSequence: if isinstance(key, int): # Support index access; it's easier for trivial cases. 
obj = self._inlets[key] if not isinstance(obj, (Asset, AssetAlias, AssetRef)): @@ -273,17 +291,19 @@ def __getitem__(self, key: int | str | Asset | AssetAlias) -> LazyAssetEventSele else: obj = key - if isinstance(obj, AssetAlias): - asset_alias = self._asset_aliases[obj.name] + if isinstance(obj, Asset): + asset = self._assets[AssetUniqueKey.from_asset(obj)] + join_clause = AssetEvent.asset + where_clause = and_(AssetModel.name == asset.name, AssetModel.uri == asset.uri) + elif isinstance(obj, AssetAlias): + asset_alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(obj)] join_clause = AssetEvent.source_aliases where_clause = AssetAliasModel.name == asset_alias.name - elif isinstance(obj, (Asset, AssetRef)): - join_clause = AssetEvent.asset - where_clause = AssetModel.name == self._assets[obj.name].name - elif isinstance(obj, str): - asset = self._assets[extract_event_key(obj)] + elif isinstance(obj, AssetRef): + # TODO: handle the case that Asset uri is different from name + asset = self._assets[AssetUniqueKey.from_asset(Asset(name=obj.name))] join_clause = AssetEvent.asset - where_clause = AssetModel.name == asset.name + where_clause = and_(AssetModel.name == asset.name, AssetModel.uri == asset.uri) else: raise ValueError(key) diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi index 19b6500809931..b2fdc98a7459a 100644 --- a/airflow/utils/context.pyi +++ b/airflow/utils/context.pyi @@ -39,7 +39,7 @@ from airflow.models.dag import DAG from airflow.models.dagrun import DagRun from airflow.models.param import ParamsDict from airflow.models.taskinstance import TaskInstance -from airflow.sdk.definitions.asset import Asset, AssetAlias +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetRef, AssetUniqueKey, BaseAssetUniqueKey from airflow.serialization.pydantic.asset import AssetEventPydantic from airflow.serialization.pydantic.dag_run import DagRunPydantic from airflow.typing_compat import TypedDict @@ -59,27 +59,29 @@ class 
ConnectionAccessor: class AssetAliasEvent: source_alias_name: str - dest_asset_uri: str + dest_asset_key: AssetUniqueKey extra: dict[str, Any] - def __init__(self, source_alias_name: str, dest_asset_uri: str, extra: dict[str, Any]) -> None: ... + def __init__( + self, source_alias_name: str, dest_asset_key: AssetUniqueKey, extra: dict[str, Any] + ) -> None: ... class OutletEventAccessor: def __init__( self, *, + key: BaseAssetUniqueKey, extra: dict[str, Any], - raw_key: str | Asset | AssetAlias, asset_alias_events: list[AssetAliasEvent], ) -> None: ... - def add(self, asset: Asset | str, extra: dict[str, Any] | None = None) -> None: ... + def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: ... + key: BaseAssetUniqueKey extra: dict[str, Any] - raw_key: str | Asset | AssetAlias asset_alias_events: list[AssetAliasEvent] -class OutletEventAccessors(Mapping[str, OutletEventAccessor]): - def __iter__(self) -> Iterator[str]: ... +class OutletEventAccessors(Mapping[Asset | AssetAlias, OutletEventAccessor]): + def __iter__(self) -> Iterator[Asset | AssetAlias]: ... def __len__(self) -> int: ... - def __getitem__(self, key: str | Asset | AssetAlias) -> OutletEventAccessor: ... + def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor: ... class InletEventsAccessor(Sequence[AssetEvent]): @overload @@ -88,11 +90,11 @@ class InletEventsAccessor(Sequence[AssetEvent]): def __getitem__(self, key: slice) -> Sequence[AssetEvent]: ... def __len__(self) -> int: ... -class InletEventsAccessors(Mapping[str, InletEventsAccessor]): +class InletEventsAccessors(Mapping[Asset | AssetAlias, InletEventsAccessor]): def __init__(self, inlets: list, *, session: Session) -> None: ... - def __iter__(self) -> Iterator[str]: ... + def __iter__(self) -> Iterator[Asset | AssetAlias]: ... def __len__(self) -> int: ... - def __getitem__(self, key: int | str | Asset | AssetAlias) -> InletEventsAccessor: ... 
+ def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> InletEventsAccessor: ... # NOTE: Please keep this in sync with the following: # * KNOWN_CONTEXT_KEYS in airflow/utils/context.py diff --git a/airflow/utils/operator_helpers.py b/airflow/utils/operator_helpers.py index e5e304bb4a414..93bc9e53daf38 100644 --- a/airflow/utils/operator_helpers.py +++ b/airflow/utils/operator_helpers.py @@ -276,10 +276,10 @@ def _run(): for metadata in _run(): if isinstance(metadata, Metadata): - outlet_events[metadata.uri].extra.update(metadata.extra) + outlet_events[metadata.asset].extra.update(metadata.extra) - if metadata.alias_name: - outlet_events[metadata.alias_name].add(metadata.uri, extra=metadata.extra) + if metadata.alias: + outlet_events[metadata.alias].add(metadata.asset, extra=metadata.extra) continue logger.warning("Ignoring unknown data of %r received from task", type(metadata)) diff --git a/docs/apache-airflow/authoring-and-scheduling/datasets.rst b/docs/apache-airflow/authoring-and-scheduling/datasets.rst index 9e777d9299587..01b31fb89965b 100644 --- a/docs/apache-airflow/authoring-and-scheduling/datasets.rst +++ b/docs/apache-airflow/authoring-and-scheduling/datasets.rst @@ -251,7 +251,7 @@ The easiest way to attach extra information to the asset event is by ``yield``-i from airflow.sdk.definitions.asset import Asset from airflow.sdk.definitions.asset.metadata import Metadata - example_s3_asset = Asset("s3://asset/example.csv") + example_s3_asset = Asset(uri="s3://asset/example.csv", name="example_s3") @task(outlets=[example_s3_asset]) @@ -445,7 +445,7 @@ The following example creates an asset event against the S3 URI ``f"s3://bucket/ @task(outlets=[AssetAlias("my-task-outputs")]) def my_task_with_outlet_events(*, outlet_events): - outlet_events["my-task-outputs"].add(Asset("s3://bucket/my-task"), extra={"k": "v"}) + outlet_events[AssetAlias("my-task-outputs")].add(Asset("s3://bucket/my-task"), extra={"k": "v"}) **Emit an asset event during task 
execution through yielding Metadata** @@ -457,7 +457,7 @@ The following example creates an asset event against the S3 URI ``f"s3://bucket/ @task(outlets=[AssetAlias("my-task-outputs")]) def my_task_with_metadata(): - s3_asset = Asset("s3://bucket/my-task") + s3_asset = Asset(uri="s3://bucket/my-task", name="example_s3") - yield Metadata(s3_asset, extra={"k": "v"}, alias="my-task-outputs") + yield Metadata(s3_asset, extra={"k": "v"}, alias=AssetAlias("my-task-outputs")) Only one asset event is emitted for an added asset, even if it is added to the alias multiple times, or added to multiple aliases. However, if different ``extra`` values are passed, it can emit multiple asset events. In the following example, two asset events will be emitted. @@ -475,11 +475,11 @@ Only one asset event is emitted for an added asset, even if it is added to the a ] ) def my_task_with_outlet_events(*, outlet_events): - outlet_events["my-task-outputs-1"].add(Asset("s3://bucket/my-task"), extra={"k": "v"}) + outlet_events[AssetAlias("my-task-outputs-1")].add(Asset("s3://bucket/my-task"), extra={"k": "v"}) # This line won't emit an additional asset event as the asset and extra are the same as the previous line. - outlet_events["my-task-outputs-2"].add(Asset("s3://bucket/my-task"), extra={"k": "v"}) + outlet_events[AssetAlias("my-task-outputs-2")].add(Asset("s3://bucket/my-task"), extra={"k": "v"}) # This line will emit an additional asset event as the extra is different. - outlet_events["my-task-outputs-3"].add(Asset("s3://bucket/my-task"), extra={"k2": "v2"}) + outlet_events[AssetAlias("my-task-outputs-3")].add(Asset("s3://bucket/my-task"), extra={"k2": "v2"}) Scheduling based on asset aliases ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -500,7 +500,7 @@ The asset alias is resolved to the assets during DAG parsing. 
Thus, if the "min_ @task(outlets=[AssetAlias("example-alias")]) def produce_asset_events(*, outlet_events): - outlet_events["example-alias"].add(Asset("s3://bucket/my-task")) + outlet_events[AssetAlias("example-alias")].add(Asset("s3://bucket/my-task")) with DAG(dag_id="asset-consumer", schedule=Asset("s3://bucket/my-task")): @@ -524,7 +524,7 @@ As mentioned in :ref:`Fetching information from previously emitted asset events< @task(outlets=[AssetAlias("example-alias")]) def produce_asset_events(*, outlet_events): - outlet_events["example-alias"].add(Asset("s3://bucket/my-task"), extra={"row_count": 1}) + outlet_events[AssetAlias("example-alias")].add(Asset("s3://bucket/my-task"), extra={"row_count": 1}) with DAG(dag_id="asset-alias-consumer", schedule=None): diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index 9f5b85ccb161b..a0beb24150f7d 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -22,7 +22,7 @@ import os import urllib.parse import warnings -from typing import TYPE_CHECKING, Any, Callable, ClassVar, NamedTuple, overload +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Union, overload import attrs @@ -32,6 +32,7 @@ from collections.abc import Iterable, Iterator from urllib.parse import SplitResult + from airflow.models.asset import AssetModel from airflow.triggers.base import BaseTrigger @@ -49,14 +50,45 @@ log = logging.getLogger(__name__) -class AssetUniqueKey(NamedTuple): +@attrs.define(frozen=True) +class AssetUniqueKey: + """ + Columns to identify a unique asset. 
+ + :meta private: + """ + name: str uri: str @staticmethod - def from_asset(asset: Asset) -> AssetUniqueKey: + def from_asset(asset: Asset | AssetModel) -> AssetUniqueKey: return AssetUniqueKey(name=asset.name, uri=asset.uri) + def to_asset(self) -> Asset: + return Asset(name=self.name, uri=self.uri) + + +@attrs.define(frozen=True) +class AssetAliasUniqueKey: + """ + Columns to identify a unique asset alias. + + :meta private: + """ + + name: str + + @staticmethod + def from_asset_alias(asset_alias: AssetAlias) -> AssetAliasUniqueKey: + return AssetAliasUniqueKey(name=asset_alias.name) + + def to_asset_alias(self) -> AssetAlias: + return AssetAlias(name=self.name) + + +BaseAssetUniqueKey = Union[AssetUniqueKey, AssetAliasUniqueKey] + def normalize_noop(parts: SplitResult) -> SplitResult: """ diff --git a/task_sdk/src/airflow/sdk/definitions/asset/metadata.py b/task_sdk/src/airflow/sdk/definitions/asset/metadata.py index 0881919703963..6639389c7ee97 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/metadata.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/metadata.py @@ -17,53 +17,20 @@ from __future__ import annotations -from typing import ( - Any, -) +from typing import TYPE_CHECKING, Any import attrs -from airflow.sdk.definitions.asset import Asset, AssetAlias, _sanitize_uri +if TYPE_CHECKING: + from airflow.sdk.definitions.asset import Asset, AssetAlias -__all__ = ["Metadata", "extract_event_key"] -def extract_event_key(value: str | Asset | AssetAlias) -> str: - """ - Extract the key of an inlet or an outlet event. - - If the input value is a string, it is treated as a URI and sanitized. If the - input is a :class:`Asset`, the URI it contains is considered sanitized and - returned directly. If the input is a :class:`AssetAlias`, the name it contains - will be returned directly. 
- - :meta private: - """ - if isinstance(value, AssetAlias): - return value.name - - if isinstance(value, Asset): - return value.uri - return _sanitize_uri(str(value)) - - -@attrs.define(init=False) +@attrs.define(init=True) class Metadata: """Metadata to attach to an AssetEvent.""" - uri: str + asset: Asset extra: dict[str, Any] - alias_name: str | None = None - - def __init__( - self, - target: str | Asset, - extra: dict[str, Any], - alias: AssetAlias | str | None = None, - ) -> None: - self.uri = extract_event_key(target) - self.extra = extra - if isinstance(alias, AssetAlias): - self.alias_name = alias.name - else: - self.alias_name = alias + alias: AssetAlias | None = None diff --git a/task_sdk/tests/defintions/test_asset.py b/task_sdk/tests/defintions/test_asset.py index 55439fb1bcd84..afdfb37a5fd5d 100644 --- a/task_sdk/tests/defintions/test_asset.py +++ b/task_sdk/tests/defintions/test_asset.py @@ -29,6 +29,7 @@ AssetAlias, AssetAll, AssetAny, + AssetUniqueKey, BaseAsset, Dataset, Model, @@ -157,7 +158,7 @@ def test_asset_logic_operations(): def test_asset_iter_assets(): - assert list(asset1.iter_assets()) == [(("asset-1", "s3://bucket1/data1"), asset1)] + assert list(asset1.iter_assets()) == [(AssetUniqueKey("asset-1", "s3://bucket1/data1"), asset1)] def test_asset_iter_asset_aliases(): @@ -212,12 +213,12 @@ def test_assset_boolean_condition_evaluate_iter(): assets_any = dict(any_condition.iter_assets()) assets_all = dict(all_condition.iter_assets()) assert assets_any == { - ("asset-1", "s3://bucket1/data1"): asset1, - ("asset-2", "s3://bucket2/data2"): asset2, + AssetUniqueKey("asset-1", "s3://bucket1/data1"): asset1, + AssetUniqueKey("asset-2", "s3://bucket2/data2"): asset2, } assert assets_all == { - ("asset-1", "s3://bucket1/data1"): asset1, - ("asset-2", "s3://bucket2/data2"): asset2, + AssetUniqueKey("asset-1", "s3://bucket1/data1"): asset1, + AssetUniqueKey("asset-2", "s3://bucket2/data2"): asset2, } diff --git a/tests/models/test_taskinstance.py 
b/tests/models/test_taskinstance.py index cd78882ebf1a0..b1b8370f1881f 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -77,7 +77,7 @@ from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.python import PythonOperator from airflow.providers.standard.sensors.python import PythonSensor -from airflow.sdk.definitions.asset import AssetAlias +from airflow.sdk.definitions.asset import Asset, AssetAlias from airflow.sensors.base import BaseSensorOperator from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG from airflow.stats import Stats @@ -2405,12 +2405,12 @@ def test_outlet_asset_extra(self, dag_maker, session): @task(outlets=Asset("test_outlet_asset_extra_1")) def write1(*, outlet_events): - outlet_events["test_outlet_asset_extra_1"].extra = {"foo": "bar"} + outlet_events[Asset("test_outlet_asset_extra_1")].extra = {"foo": "bar"} write1() def _write2_post_execute(context, _): - context["outlet_events"]["test_outlet_asset_extra_2"].extra = {"x": 1} + context["outlet_events"][Asset("test_outlet_asset_extra_2")].extra = {"x": 1} BashOperator( task_id="write2", @@ -2446,8 +2446,8 @@ def test_outlet_asset_extra_ignore_different(self, dag_maker, session): @task(outlets=Asset("test_outlet_asset_extra")) def write(*, outlet_events): - outlet_events["test_outlet_asset_extra"].extra = {"one": 1} - outlet_events["different_uri"].extra = {"foo": "bar"} # Will be silently dropped. + outlet_events[Asset("test_outlet_asset_extra")].extra = {"one": 1} + outlet_events[Asset("different_uri")].extra = {"foo": "bar"} # Will be silently dropped. 
write() @@ -2469,18 +2469,18 @@ def test_outlet_asset_extra_yield(self, dag_maker, session): @task(outlets=Asset("test_outlet_asset_extra_1")) def write1(): result = "write_1 result" - yield Metadata("test_outlet_asset_extra_1", {"foo": "bar"}) + yield Metadata(Asset(name="test_outlet_asset_extra_1"), {"foo": "bar"}) return result write1() def _write2_post_execute(context, result): - yield Metadata("test_outlet_asset_extra_2", {"x": 1}) + yield Metadata(Asset(name="test_outlet_asset_extra_2", uri="test://asset-2"), extra={"x": 1}) BashOperator( task_id="write2", bash_command=":", - outlets=Asset("test_outlet_asset_extra_2"), + outlets=Asset(name="test_outlet_asset_extra_2", uri="test://asset-2"), post_execute=_write2_post_execute, ) @@ -2501,12 +2501,14 @@ def _write2_post_execute(context, result): assert events["write1"].source_run_id == dr.run_id assert events["write1"].source_task_id == "write1" assert events["write1"].asset.uri == "test_outlet_asset_extra_1" + assert events["write1"].asset.name == "test_outlet_asset_extra_1" assert events["write1"].extra == {"foo": "bar"} assert events["write2"].source_dag_id == dr.dag_id assert events["write2"].source_run_id == dr.run_id assert events["write2"].source_task_id == "write2" - assert events["write2"].asset.uri == "test_outlet_asset_extra_2" + assert events["write2"].asset.uri == "test://asset-2/" + assert events["write2"].asset.name == "test_outlet_asset_extra_2" assert events["write2"].extra == {"x": 1} def test_outlet_asset_alias(self, dag_maker, session): @@ -2523,7 +2525,7 @@ def test_outlet_asset_alias(self, dag_maker, session): @task(outlets=AssetAlias(alias_name_1)) def producer(*, outlet_events): - outlet_events[alias_name_1].add(Asset(asset_uri)) + outlet_events[AssetAlias(alias_name_1)].add(Asset(asset_uri)) producer() @@ -2579,9 +2581,9 @@ def test_outlet_multiple_asset_alias(self, dag_maker, session): ] ) def producer(*, outlet_events): - outlet_events[asset_alias_name_1].add(Asset(asset_uri)) - 
outlet_events[asset_alias_name_2].add(Asset(asset_uri)) - outlet_events[asset_alias_name_3].add(Asset(asset_uri), extra={"k": "v"}) + outlet_events[AssetAlias(asset_alias_name_1)].add(Asset(asset_uri)) + outlet_events[AssetAlias(asset_alias_name_2)].add(Asset(asset_uri)) + outlet_events[AssetAlias(asset_alias_name_3)].add(Asset(asset_uri), extra={"k": "v"}) producer() @@ -2645,7 +2647,7 @@ def test_outlet_asset_alias_through_metadata(self, dag_maker, session): @task(outlets=AssetAlias(asset_alias_name)) def producer(*, outlet_events): - yield Metadata(asset_uri, extra={"key": "value"}, alias=asset_alias_name) + yield Metadata(Asset(asset_uri), extra={"key": "value"}, alias=AssetAlias(asset_alias_name)) producer() @@ -2684,7 +2686,7 @@ def test_outlet_asset_alias_asset_not_exists(self, dag_maker, session): @task(outlets=AssetAlias(asset_alias_name)) def producer(*, outlet_events): - outlet_events[asset_alias_name].add(Asset(asset_uri), extra={"key": "value"}) + outlet_events[AssetAlias(asset_alias_name)].add(Asset(asset_uri), extra={"key": "value"}) producer() @@ -2722,22 +2724,24 @@ def test_inlet_asset_extra(self, dag_maker, session): @task(outlets=Asset("test_inlet_asset_extra")) def write(*, ti, outlet_events): - outlet_events["test_inlet_asset_extra"].extra = {"from": ti.task_id} + outlet_events[Asset("test_inlet_asset_extra")].extra = {"from": ti.task_id} + with pytest.raises(TypeError): + outlet_events["test_inlet_asset_extra"] @task(inlets=Asset("test_inlet_asset_extra")) def read(*, inlet_events): - second_event = inlet_events["test_inlet_asset_extra"][1] + second_event = inlet_events[Asset("test_inlet_asset_extra")][1] assert second_event.uri == "test_inlet_asset_extra" assert second_event.extra == {"from": "write2"} - last_event = inlet_events["test_inlet_asset_extra"][-1] + last_event = inlet_events[Asset("test_inlet_asset_extra")][-1] assert last_event.uri == "test_inlet_asset_extra" assert last_event.extra == {"from": "write3"} with 
pytest.raises(KeyError): - inlet_events["does_not_exist"] + inlet_events[Asset("does_not_exist")] with pytest.raises(IndexError): - inlet_events["test_inlet_asset_extra"][5] + inlet_events[Asset("test_inlet_asset_extra")][5] # TODO: Support slices. @@ -2770,7 +2774,7 @@ def test_inlet_asset_alias_extra(self, dag_maker, session): asset_uri = "test_inlet_asset_extra_ds" asset_alias_name = "test_inlet_asset_extra_asset_alias" - asset_model = AssetModel(id=1, uri=asset_uri) + asset_model = AssetModel(id=1, uri=asset_uri, group="asset") asset_alias_model = AssetAliasModel(name=asset_alias_name) asset_alias_model.assets.append(asset_model) session.add_all([asset_model, asset_alias_model]) @@ -2784,7 +2788,7 @@ def test_inlet_asset_alias_extra(self, dag_maker, session): @task(outlets=AssetAlias(asset_alias_name)) def write(*, ti, outlet_events): - outlet_events[asset_alias_name].add(Asset(asset_uri), extra={"from": ti.task_id}) + outlet_events[AssetAlias(asset_alias_name)].add(Asset(asset_uri), extra={"from": ti.task_id}) @task(inlets=AssetAlias(asset_alias_name)) def read(*, inlet_events): @@ -2797,7 +2801,7 @@ def read(*, inlet_events): assert last_event.extra == {"from": "write3"} with pytest.raises(KeyError): - inlet_events["does_not_exist"] + inlet_events[Asset("does_not_exist")] with pytest.raises(KeyError): inlet_events[AssetAlias("does_not_exist")] with pytest.raises(IndexError): @@ -2873,7 +2877,7 @@ def test_inlet_asset_extra_slice(self, dag_maker, session, slicer, expected): @task(outlets=Asset(asset_uri)) def write(*, params, outlet_events): - outlet_events[asset_uri].extra = {"from": params["i"]} + outlet_events[Asset(asset_uri)].extra = {"from": params["i"]} write() @@ -2893,7 +2897,7 @@ def write(*, params, outlet_events): @task(inlets=Asset(asset_uri)) def read(*, inlet_events): nonlocal result - result = [e.extra for e in slicer(inlet_events[asset_uri])] + result = [e.extra for e in slicer(inlet_events[Asset(asset_uri)])] read() @@ -2933,7 +2937,7 @@ def 
test_inlet_asset_alias_extra_slice(self, dag_maker, session, slicer, expecte @task(outlets=AssetAlias(asset_alias_name)) def write(*, params, outlet_events): - outlet_events[asset_alias_name].add(Asset(asset_uri), {"from": params["i"]}) + outlet_events[AssetAlias(asset_alias_name)].add(Asset(asset_uri), {"from": params["i"]}) write() diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py index 5c53304e7af32..8100c2a84bc74 100644 --- a/tests/serialization/test_serialized_objects.py +++ b/tests/serialization/test_serialized_objects.py @@ -49,7 +49,7 @@ from airflow.models.xcom_arg import XComArg from airflow.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator -from airflow.sdk.definitions.asset import Asset, AssetAlias +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.pydantic.asset import AssetEventPydantic, AssetPydantic from airflow.serialization.pydantic.dag import DagModelPydantic, DagTagPydantic @@ -150,6 +150,16 @@ class Test: DAG_RUN.id = 1 +def create_outlet_event_accessors( + key: Asset | AssetAlias, extra: dict, asset_alias_events: list[AssetAliasEvent] +) -> OutletEventAccessors: + o = OutletEventAccessors() + o[key].extra = extra + o[key].asset_alias_events = asset_alias_events + + return o + + def equals(a, b) -> bool: return a == b @@ -162,8 +172,15 @@ def equal_exception(a: AirflowException, b: AirflowException) -> bool: return a.__class__ == b.__class__ and str(a) == str(b) +def equal_outlet_event_accessors(a: OutletEventAccessors, b: OutletEventAccessors) -> bool: + return a._dict.keys() == b._dict.keys() and all( # type: ignore[attr-defined] + equal_outlet_event_accessor(a._dict[key], b._dict[key]) # type: ignore[attr-defined] + for key in a._dict # type: ignore[attr-defined] + ) + + def 
equal_outlet_event_accessor(a: OutletEventAccessor, b: OutletEventAccessor) -> bool: - return a.raw_key == b.raw_key and a.extra == b.extra and a.asset_alias_events == b.asset_alias_events + return a.key == b.key and a.extra == b.extra and a.asset_alias_events == b.asset_alias_events class MockLazySelectSequence(LazySelectSequence): @@ -257,33 +274,26 @@ def __len__(self) -> int: lambda a, b: a.get_uri() == b.get_uri(), ), ( - OutletEventAccessor( - raw_key=Asset(uri="test://asset1", name="test", group="test-group"), - extra={"key": "value"}, - asset_alias_events=[], + create_outlet_event_accessors( + Asset(uri="test", name="test", group="test-group"), {"key": "value"}, [] ), - DAT.ASSET_EVENT_ACCESSOR, - equal_outlet_event_accessor, + DAT.ASSET_EVENT_ACCESSORS, + equal_outlet_event_accessors, ), ( - OutletEventAccessor( - raw_key=AssetAlias(name="test_alias", group="test-alias-group"), - extra={"key": "value"}, - asset_alias_events=[ + create_outlet_event_accessors( + AssetAlias(name="test_alias", group="test-alias-group"), + {"key": "value"}, + [ AssetAliasEvent( source_alias_name="test_alias", - dest_asset_uri="test_uri", + dest_asset_key=AssetUniqueKey(name="test_name", uri="test://asset-uri"), extra={}, ) ], ), - DAT.ASSET_EVENT_ACCESSOR, - equal_outlet_event_accessor, - ), - ( - OutletEventAccessor(raw_key="test", extra={"key": "value"}, asset_alias_events=[]), - DAT.ASSET_EVENT_ACCESSOR, - equal_outlet_event_accessor, + DAT.ASSET_EVENT_ACCESSORS, + equal_outlet_event_accessors, ), ( AirflowException("test123 wohoo!"), @@ -525,12 +535,14 @@ def test_serialized_mapped_operator_unmap(dag_maker): def test_ser_of_asset_event_accessor(): # todo: (Airflow 3.0) we should force reserialization on upgrade d = OutletEventAccessors() - d["hi"].extra = "blah1" # todo: this should maybe be forbidden? i.e. can extra be any json or just dict? - d["yo"].extra = {"this": "that", "the": "other"} + d[ + Asset("hi") + ].extra = "blah1" # todo: this should maybe be forbidden? 
i.e. can extra be any json or just dict? + d[Asset(name="yo", uri="test://yo")].extra = {"this": "that", "the": "other"} ser = BaseSerialization.serialize(var=d) deser = BaseSerialization.deserialize(ser) - assert deser["hi"].extra == "blah1" - assert d["yo"].extra == {"this": "that", "the": "other"} + assert deser[Asset(uri="hi", name="hi")].extra == "blah1" + assert d[Asset(name="yo", uri="test://yo")].extra == {"this": "that", "the": "other"} class MyTrigger(BaseTrigger): diff --git a/tests/utils/test_context.py b/tests/utils/test_context.py index 538b74119204b..3388a33845a4b 100644 --- a/tests/utils/test_context.py +++ b/tests/utils/test_context.py @@ -21,58 +21,77 @@ import pytest from airflow.models.asset import AssetAliasModel, AssetModel -from airflow.sdk.definitions.asset import Asset, AssetAlias +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasUniqueKey, AssetUniqueKey from airflow.utils.context import AssetAliasEvent, OutletEventAccessor, OutletEventAccessors class TestOutletEventAccessor: @pytest.mark.parametrize( - "raw_key, asset_alias_events", + "key, asset_alias_events", ( + (AssetUniqueKey.from_asset(Asset("test_uri")), []), ( - AssetAlias("test_alias"), - [AssetAliasEvent(source_alias_name="test_alias", dest_asset_uri="test_uri", extra={})], + AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")), + [ + AssetAliasEvent( + source_alias_name="test_alias", + dest_asset_key=AssetUniqueKey(uri="test_uri", name="test_uri"), + extra={}, + ) + ], ), - (Asset("test_uri"), []), ), ) - def test_add(self, raw_key, asset_alias_events): - outlet_event_accessor = OutletEventAccessor(raw_key=raw_key, extra={}) + def test_add(self, key, asset_alias_events): + outlet_event_accessor = OutletEventAccessor(key=key, extra={}) outlet_event_accessor.add(Asset("test_uri")) assert outlet_event_accessor.asset_alias_events == asset_alias_events @pytest.mark.db_test @pytest.mark.parametrize( - "raw_key, asset_alias_events", + "key, 
asset_alias_events", ( + (AssetUniqueKey.from_asset(Asset("test_uri")), []), ( - AssetAlias("test_alias"), - [AssetAliasEvent(source_alias_name="test_alias", dest_asset_uri="test_uri", extra={})], + AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")), + [ + AssetAliasEvent( + source_alias_name="test_alias", + dest_asset_key=AssetUniqueKey(name="test-asset", uri="test://asset-uri/"), + extra={}, + ) + ], ), - ( - "test_alias", - [AssetAliasEvent(source_alias_name="test_alias", dest_asset_uri="test_uri", extra={})], - ), - (Asset("test_uri"), []), ), ) - def test_add_with_db(self, raw_key, asset_alias_events, session): - asm = AssetModel(uri="test_uri") + def test_add_with_db(self, key, asset_alias_events, session): + asm = AssetModel(uri="test://asset-uri", name="test-asset", group="asset") aam = AssetAliasModel(name="test_alias") session.add_all([asm, aam]) session.flush() + asset = Asset(uri="test://asset-uri", name="test-asset") - outlet_event_accessor = OutletEventAccessor(raw_key=raw_key, extra={"not": ""}) - outlet_event_accessor.add("test_uri", extra={}) + outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""}) + outlet_event_accessor.add(asset, extra={}) assert outlet_event_accessor.asset_alias_events == asset_alias_events class TestOutletEventAccessors: - @pytest.mark.parametrize("key", ("test", Asset("test"), AssetAlias("test_alias"))) - def test____get_item___dict_key_not_exists(self, key): + @pytest.mark.parametrize( + "access_key, internal_key", + ( + (Asset("test"), AssetUniqueKey.from_asset(Asset("test"))), + ( + Asset(name="test", uri="test://asset"), + AssetUniqueKey.from_asset(Asset(name="test", uri="test://asset")), + ), + (AssetAlias("test_alias"), AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias"))), + ), + ) + def test___get_item__dict_key_not_exists(self, access_key, internal_key): outlet_event_accessors = OutletEventAccessors() assert len(outlet_event_accessors) == 0 - outlet_event_accessor = 
outlet_event_accessors[key] + outlet_event_accessor = outlet_event_accessors[access_key] assert len(outlet_event_accessors) == 1 - assert outlet_event_accessor.raw_key == key + assert outlet_event_accessor.key == internal_key assert outlet_event_accessor.extra == {}