From 411362d10ffdf5b4014c08d2770aeff87891bd16 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 4 Dec 2024 17:09:19 +0800 Subject: [PATCH] feat(utils/context): add back string access to inlet and outlet events --- airflow/models/taskinstance.py | 6 +- airflow/utils/context.py | 87 ++++++++++++++----- airflow/utils/context.pyi | 16 ++-- tests/models/test_taskinstance.py | 2 +- .../serialization/test_serialized_objects.py | 14 ++- tests/utils/test_context.py | 21 +++-- 6 files changed, 103 insertions(+), 43 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index e53c963b92466..79912a5f71a0e 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2757,7 +2757,11 @@ def _register_asset_changes_int( asset_models: dict[AssetUniqueKey, AssetModel] = { AssetUniqueKey.from_asset(asset_obj): asset_obj for asset_obj in session.scalars( - select(AssetModel).where(tuple_(AssetModel.name, AssetModel.uri).in_(asset_alias_names)) + select(AssetModel).where( + tuple_(AssetModel.name, AssetModel.uri).in_( + (key.name, key.uri) for key, _ in asset_alias_names + ) + ) ) } if missing_assets := [ diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 600393a78f40c..f811c875ca63b 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -36,11 +36,12 @@ TYPE_CHECKING, Any, SupportsIndex, + Union, ) import attrs import lazy_object_proxy -from sqlalchemy import select +from sqlalchemy import and_, select from airflow.exceptions import RemovedInAirflow3Warning from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel, fetch_active_assets_by_name @@ -177,17 +178,22 @@ class OutletEventAccessor: :meta private: """ - key: BaseAssetUniqueKey + key: str | BaseAssetUniqueKey extra: dict[str, Any] = attrs.Factory(dict) asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list) - def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: + def add(self, asset: str | Asset, extra: dict[str, Any] | None = None) -> None: """Add an AssetEvent to an existing Asset.""" - if not isinstance(asset, Asset): + if isinstance(asset, str): + asset = Asset(asset) + elif not isinstance(asset, Asset): return if isinstance(self.key, AssetAliasUniqueKey): asset_alias_name = self.key.name + elif isinstance(self.key, str): + # TODO: deprecate string access + asset_alias_name = self.key else: return @@ -199,7 +205,7 @@ def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: self.asset_alias_events.append(event) -class OutletEventAccessors(Mapping[BaseAsset, OutletEventAccessor]): +class OutletEventAccessors(Mapping[Union[str, BaseAsset], OutletEventAccessor]): """ Lazy mapping of outlet asset event accessors. @@ -207,28 +213,53 @@ class OutletEventAccessors(Mapping[BaseAsset, OutletEventAccessor]): """ def __init__(self) -> None: - self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {} + self._dict: dict[str | BaseAssetUniqueKey, OutletEventAccessor] = {} def __str__(self) -> str: return f"OutletEventAccessors(_dict={self._dict})" - def __iter__(self) -> Iterator[BaseAsset]: - return iter(key.to_obj() for key in self._dict) + def __iter__(self) -> Iterator[str | BaseAsset]: + return iter(key.to_obj() if isinstance(key, BaseAssetUniqueKey) else key for key in self._dict) def __len__(self) -> int: return len(self._dict) - def __getitem__(self, key: BaseAsset) -> OutletEventAccessor: - hashable_key: BaseAssetUniqueKey + def __getitem__(self, key: str | BaseAsset) -> OutletEventAccessor: + hashable_key: str | BaseAssetUniqueKey + # TODO: Remove it once string accessing is deprecated. + # We currently still support accessing through string. + # Asset("abc") and "abc" returns the same thing. + # Thus, if an user pass Asset("abc"), we need to also check whether "abc" is in dict. + # Same for alias. + potential_equivalent_key = None if isinstance(key, Asset): hashable_key = AssetUniqueKey.from_asset(key) + # TODO: remove after deprecating string accessing + if key.name == key.uri and key.name in self._dict: + potential_equivalent_key = key.name elif isinstance(key, AssetAlias): hashable_key = AssetAliasUniqueKey.from_asset_alias(key) + # TODO: remove after deprecating string accessing + if key.name in self._dict: + potential_equivalent_key = key.name + elif isinstance(key, str): + # TODO: remove after deprecating string accessing + hashable_key = key else: raise KeyError("Key should be either an asset or an asset alias") - - if hashable_key not in self._dict: + # TODO: remove after deprecating string accessing + if key.name in self._dict: + potential_equivalent_key = key.name + + if ( + hashable_key not in self._dict + and not potential_equivalent_key + and potential_equivalent_key not in self._dict + ): self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key) + elif potential_equivalent_key: + # TODO: remove after deprecating string accessing + hashable_key = potential_equivalent_key return self._dict[hashable_key] @@ -249,7 +280,7 @@ def _process_row(row: Row) -> AssetEvent: @attrs.define(init=False) -class InletEventsAccessors(Mapping[BaseAsset, LazyAssetEventSelectSequence]): +class InletEventsAccessors(Mapping[Union[str, int, BaseAsset], LazyAssetEventSelectSequence]): """ Lazy mapping for inlet asset events accessors. @@ -257,8 +288,8 @@ class InletEventsAccessors(Mapping[BaseAsset, LazyAssetEventSelectSequence]): """ _inlets: list[Any] - _assets: dict[str, Asset] - _asset_aliases: dict[str, AssetAlias] + _assets: dict[AssetUniqueKey, Asset] + _asset_aliases: dict[AssetAliasUniqueKey, AssetAlias] _session: Session def __init__(self, inlets: list, *, session: Session) -> None: @@ -270,15 +301,15 @@ def __init__(self, inlets: list, *, session: Session) -> None: _asset_ref_names: list[str] = [] for inlet in inlets: if isinstance(inlet, Asset): - self._assets[inlet.name] = inlet + self._assets[AssetUniqueKey.from_asset(inlet)] = inlet elif isinstance(inlet, AssetAlias): - self._asset_aliases[inlet.name] = inlet + self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(inlet)] = inlet elif isinstance(inlet, AssetRef): _asset_ref_names.append(inlet.name) if _asset_ref_names: - for asset_name, asset in fetch_active_assets_by_name(_asset_ref_names, self._session).items(): - self._assets[asset_name] = asset + for _, asset in fetch_active_assets_by_name(_asset_ref_names, self._session).items(): + self._assets[AssetUniqueKey.from_asset(asset)] = asset def __iter__(self) -> Iterator[BaseAsset]: return iter(self._inlets) @@ -286,7 +317,7 @@ def __iter__(self) -> Iterator[BaseAsset]: def __len__(self) -> int: return len(self._inlets) - def __getitem__(self, key: int | BaseAsset) -> LazyAssetEventSelectSequence: + def __getitem__(self, key: int | str | BaseAsset) -> LazyAssetEventSelectSequence: if isinstance(key, int): # Support index access; it's easier for trivial cases. obj = self._inlets[key] if not isinstance(obj, (Asset, AssetAlias, AssetRef)): @@ -295,14 +326,22 @@ def __getitem__(self, key: int | BaseAsset) -> LazyAssetEventSelectSequence: obj = key if isinstance(obj, AssetAlias): - asset_alias = self._asset_aliases[obj.name] + asset_alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(obj)] join_clause = AssetEvent.source_aliases where_clause = AssetAliasModel.name == asset_alias.name - elif isinstance(obj, (Asset, AssetRef)): + elif isinstance(obj, Asset): + asset = self._assets[AssetUniqueKey.from_asset(obj)] + join_clause = AssetEvent.asset + where_clause = and_(AssetModel.name == asset.name, AssetModel.uri == asset.uri) + elif isinstance(obj, AssetRef): + # TODO: handle the case that Asset uri is different from name + asset = self._assets[AssetUniqueKey.from_asset(Asset(name=obj.name))] join_clause = AssetEvent.asset - where_clause = AssetModel.name == self._assets[obj.name].name + where_clause = and_(AssetModel.name == asset.name, AssetModel.uri == asset.uri) elif isinstance(obj, str): - asset = self._assets[extract_event_key(obj)] + # TODO: deprecate string access + asset_name = extract_event_key(obj) + asset = self._assets[AssetUniqueKey.from_asset(Asset(name=asset_name))] join_clause = AssetEvent.asset where_clause = AssetModel.name == asset.name else: diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi index bc0964f713aaa..e95e1bac94805 100644 --- a/airflow/utils/context.pyi +++ b/airflow/utils/context.pyi @@ -39,7 +39,7 @@ from airflow.models.dag import DAG from airflow.models.dagrun import DagRun from airflow.models.param import ParamsDict from airflow.models.taskinstance import TaskInstance -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, BaseAsset, BaseAssetUniqueKey +from airflow.sdk.definitions.asset import Asset, AssetUniqueKey, BaseAsset, BaseAssetUniqueKey from airflow.serialization.pydantic.asset import AssetEventPydantic from airflow.serialization.pydantic.dag_run import DagRunPydantic from airflow.typing_compat import TypedDict @@ -70,18 +70,18 @@ class OutletEventAccessor: self, *, extra: dict[str, Any], - key: BaseAssetUniqueKey, + key: str | BaseAssetUniqueKey, asset_alias_events: list[AssetAliasEvent], ) -> None: ... def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: ... extra: dict[str, Any] - key: BaseAssetUniqueKey + key: str | BaseAssetUniqueKey asset_alias_events: list[AssetAliasEvent] -class OutletEventAccessors(Mapping[BaseAsset, OutletEventAccessor]): +class OutletEventAccessors(Mapping[str | BaseAsset, OutletEventAccessor]): def __iter__(self) -> Iterator[BaseAsset]: ... def __len__(self) -> int: ... - def __getitem__(self, key: BaseAsset) -> OutletEventAccessor: ... + def __getitem__(self, key: str | BaseAsset) -> OutletEventAccessor: ... class InletEventsAccessor(Sequence[AssetEvent]): @overload @@ -90,11 +90,11 @@ class InletEventsAccessor(Sequence[AssetEvent]): def __getitem__(self, key: slice) -> Sequence[AssetEvent]: ... def __len__(self) -> int: ... -class InletEventsAccessors(Mapping[Asset | AssetAlias, InletEventsAccessor]): +class InletEventsAccessors(Mapping[str | BaseAsset, InletEventsAccessor]): def __init__(self, inlets: list, *, session: Session) -> None: ... - def __iter__(self) -> Iterator[Asset | AssetAlias]: ... + def __iter__(self) -> Iterator[BaseAsset]: ... def __len__(self) -> int: ... - def __getitem__(self, key: int | Asset | AssetAlias) -> InletEventsAccessor: ... + def __getitem__(self, key: int | str | BaseAsset) -> InletEventsAccessor: ... # NOTE: Please keep this in sync with the following: # * KNOWN_CONTEXT_KEYS in airflow/utils/context.py diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index e6563c4417494..071ad6e0d08f6 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -2770,7 +2770,7 @@ def test_inlet_asset_alias_extra(self, dag_maker, session): asset_uri = "test_inlet_asset_extra_ds" asset_alias_name = "test_inlet_asset_extra_asset_alias" - asset_model = AssetModel(id=1, uri=asset_uri) + asset_model = AssetModel(id=1, uri=asset_uri, group="asset") asset_alias_model = AssetAliasModel(name=asset_alias_name) asset_alias_model.assets.append(asset_model) session.add_all([asset_model, asset_alias_model]) diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py index b96a1fd230c86..34ff88c844812 100644 --- a/tests/serialization/test_serialized_objects.py +++ b/tests/serialization/test_serialized_objects.py @@ -49,7 +49,7 @@ from airflow.models.xcom_arg import XComArg from airflow.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasUniqueKey, AssetUniqueKey from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.pydantic.asset import AssetEventPydantic, AssetPydantic from airflow.serialization.pydantic.dag import DagModelPydantic, DagTagPydantic @@ -258,7 +258,7 @@ def __len__(self) -> int: ), ( OutletEventAccessor( - key=Asset(uri="test", name="test", group="test-group"), + key=AssetUniqueKey.from_asset(Asset(uri="test", name="test", group="test-group")), extra={"key": "value"}, asset_alias_events=[], ), @@ -267,7 +267,9 @@ def __len__(self) -> int: ), ( OutletEventAccessor( - key=AssetAlias(name="test_alias", group="test-alias-group"), + key=AssetAliasUniqueKey.from_asset_alias( + AssetAlias(name="test_alias", group="test-alias-group") + ), extra={"key": "value"}, asset_alias_events=[ AssetAliasEvent( @@ -280,6 +282,12 @@ def __len__(self) -> int: DAT.ASSET_EVENT_ACCESSOR, equal_outlet_event_accessor, ), + # TODO: deprecate string access + ( + OutletEventAccessor(key="test", extra={"key": "value"}, asset_alias_events=[]), + DAT.ASSET_EVENT_ACCESSOR, + equal_outlet_event_accessor, + ), ( AirflowException("test123 wohoo!"), DAT.AIRFLOW_EXC_SER, diff --git a/tests/utils/test_context.py b/tests/utils/test_context.py index 1f6ff0c8d77eb..7b3711eb52124 100644 --- a/tests/utils/test_context.py +++ b/tests/utils/test_context.py @@ -29,6 +29,7 @@ class TestOutletEventAccessor: @pytest.mark.parametrize( "key, asset_alias_events", ( + (AssetUniqueKey.from_asset(Asset("test_uri")), []), ( AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")), [ @@ -39,7 +40,6 @@ class TestOutletEventAccessor: ) ], ), - (AssetUniqueKey.from_asset(Asset("test_uri")), []), ), ) def test_add(self, key, asset_alias_events): @@ -52,7 +52,7 @@ def test_add(self, key, asset_alias_events): "key, asset_alias_events", ( ( - AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")), + "test_alias", [ AssetAliasEvent( source_alias_name="test_alias", @@ -62,6 +62,16 @@ def test_add(self, key, asset_alias_events): ], ), (AssetUniqueKey.from_asset(Asset("test_uri")), []), + ( + AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")), + [ + AssetAliasEvent( + source_alias_name="test_alias", + dest_asset_key=AssetUniqueKey(name="test-asset", uri="test://asset-uri/"), + extra={}, + ) + ], + ), ), ) def test_add_with_db(self, key, asset_alias_events, session): @@ -69,19 +79,18 @@ def test_add_with_db(self, key, asset_alias_events, session): aam = AssetAliasModel(name="test_alias") session.add_all([asm, aam]) session.flush() + asset = Asset(uri="test://asset-uri", name="test-asset") outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""}) - outlet_event_accessor.add(Asset(uri="test://asset-uri", name="test-asset"), extra={}) + outlet_event_accessor.add(asset, extra={}) assert outlet_event_accessor.asset_alias_events == asset_alias_events -# TODO: add test case to verify string does not work - - class TestOutletEventAccessors: @pytest.mark.parametrize( "access_key, internal_key", ( + ("test", "test"), (Asset("test"), AssetUniqueKey.from_asset(Asset("test"))), ( Asset(name="test", uri="test://asset"),