Skip to content

Commit

Permalink
feat(utils/context): add back string access to inlet and outlet events
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W committed Dec 6, 2024
1 parent 43aa6f6 commit 411362d
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 43 deletions.
6 changes: 5 additions & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2757,7 +2757,11 @@ def _register_asset_changes_int(
asset_models: dict[AssetUniqueKey, AssetModel] = {
AssetUniqueKey.from_asset(asset_obj): asset_obj
for asset_obj in session.scalars(
select(AssetModel).where(tuple_(AssetModel.name, AssetModel.uri).in_(asset_alias_names))
select(AssetModel).where(
tuple_(AssetModel.name, AssetModel.uri).in_(
(key.name, key.uri) for key, _ in asset_alias_names
)
)
)
}
if missing_assets := [
Expand Down
87 changes: 63 additions & 24 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@
TYPE_CHECKING,
Any,
SupportsIndex,
Union,
)

import attrs
import lazy_object_proxy
from sqlalchemy import select
from sqlalchemy import and_, select

from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel, fetch_active_assets_by_name
Expand Down Expand Up @@ -177,17 +178,22 @@ class OutletEventAccessor:
:meta private:
"""

key: BaseAssetUniqueKey
key: str | BaseAssetUniqueKey
extra: dict[str, Any] = attrs.Factory(dict)
asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list)

def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None:
def add(self, asset: str | Asset, extra: dict[str, Any] | None = None) -> None:
"""Add an AssetEvent to an existing Asset."""
if not isinstance(asset, Asset):
if isinstance(asset, str):
asset = Asset(asset)
elif not isinstance(asset, Asset):
return

if isinstance(self.key, AssetAliasUniqueKey):
asset_alias_name = self.key.name
elif isinstance(self.key, str):
# TODO: deprecate string access
asset_alias_name = self.key
else:
return

Expand All @@ -199,36 +205,61 @@ def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None:
self.asset_alias_events.append(event)


class OutletEventAccessors(Mapping[BaseAsset, OutletEventAccessor]):
class OutletEventAccessors(Mapping[Union[str, BaseAsset], OutletEventAccessor]):
"""
Lazy mapping of outlet asset event accessors.
:meta private:
"""

def __init__(self) -> None:
self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {}
self._dict: dict[str | BaseAssetUniqueKey, OutletEventAccessor] = {}

def __str__(self) -> str:
return f"OutletEventAccessors(_dict={self._dict})"

def __iter__(self) -> Iterator[BaseAsset]:
return iter(key.to_obj() for key in self._dict)
def __iter__(self) -> Iterator[str | BaseAsset]:
return iter(key.to_obj() if isinstance(key, BaseAssetUniqueKey) else key for key in self._dict)

def __len__(self) -> int:
return len(self._dict)

def __getitem__(self, key: BaseAsset) -> OutletEventAccessor:
hashable_key: BaseAssetUniqueKey
def __getitem__(self, key: str | BaseAsset) -> OutletEventAccessor:
hashable_key: str | BaseAssetUniqueKey
# TODO: Remove this once string access is deprecated.
# We currently still support access through a plain string:
# Asset("abc") and "abc" return the same thing.
# Thus, if a user passes Asset("abc"), we also need to check whether "abc" is in the dict.
# The same applies to aliases.
potential_equivalent_key = None
if isinstance(key, Asset):
hashable_key = AssetUniqueKey.from_asset(key)
# TODO: remove after deprecating string accessing
if key.name == key.uri and key.name in self._dict:
potential_equivalent_key = key.name
elif isinstance(key, AssetAlias):
hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
# TODO: remove after deprecating string accessing
if key.name in self._dict:
potential_equivalent_key = key.name
elif isinstance(key, str):
# TODO: remove after deprecating string accessing
hashable_key = key
else:
raise KeyError("Key should be either an asset or an asset alias")

if hashable_key not in self._dict:
# TODO: remove after deprecating string accessing
if key.name in self._dict:
potential_equivalent_key = key.name

if (
hashable_key not in self._dict
and not potential_equivalent_key
and potential_equivalent_key not in self._dict
):
self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key)
elif potential_equivalent_key:
# TODO: remove after deprecating string accessing
hashable_key = potential_equivalent_key
return self._dict[hashable_key]


Expand All @@ -249,16 +280,16 @@ def _process_row(row: Row) -> AssetEvent:


@attrs.define(init=False)
class InletEventsAccessors(Mapping[BaseAsset, LazyAssetEventSelectSequence]):
class InletEventsAccessors(Mapping[Union[str, int, BaseAsset], LazyAssetEventSelectSequence]):
"""
Lazy mapping for inlet asset events accessors.
:meta private:
"""

_inlets: list[Any]
_assets: dict[str, Asset]
_asset_aliases: dict[str, AssetAlias]
_assets: dict[AssetUniqueKey, Asset]
_asset_aliases: dict[AssetAliasUniqueKey, AssetAlias]
_session: Session

def __init__(self, inlets: list, *, session: Session) -> None:
Expand All @@ -270,23 +301,23 @@ def __init__(self, inlets: list, *, session: Session) -> None:
_asset_ref_names: list[str] = []
for inlet in inlets:
if isinstance(inlet, Asset):
self._assets[inlet.name] = inlet
self._assets[AssetUniqueKey.from_asset(inlet)] = inlet
elif isinstance(inlet, AssetAlias):
self._asset_aliases[inlet.name] = inlet
self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(inlet)] = inlet
elif isinstance(inlet, AssetRef):
_asset_ref_names.append(inlet.name)

if _asset_ref_names:
for asset_name, asset in fetch_active_assets_by_name(_asset_ref_names, self._session).items():
self._assets[asset_name] = asset
for _, asset in fetch_active_assets_by_name(_asset_ref_names, self._session).items():
self._assets[AssetUniqueKey.from_asset(asset)] = asset

def __iter__(self) -> Iterator[BaseAsset]:
return iter(self._inlets)

def __len__(self) -> int:
return len(self._inlets)

def __getitem__(self, key: int | BaseAsset) -> LazyAssetEventSelectSequence:
def __getitem__(self, key: int | str | BaseAsset) -> LazyAssetEventSelectSequence:
if isinstance(key, int): # Support index access; it's easier for trivial cases.
obj = self._inlets[key]
if not isinstance(obj, (Asset, AssetAlias, AssetRef)):
Expand All @@ -295,14 +326,22 @@ def __getitem__(self, key: int | BaseAsset) -> LazyAssetEventSelectSequence:
obj = key

if isinstance(obj, AssetAlias):
asset_alias = self._asset_aliases[obj.name]
asset_alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(obj)]
join_clause = AssetEvent.source_aliases
where_clause = AssetAliasModel.name == asset_alias.name
elif isinstance(obj, (Asset, AssetRef)):
elif isinstance(obj, Asset):
asset = self._assets[AssetUniqueKey.from_asset(obj)]
join_clause = AssetEvent.asset
where_clause = and_(AssetModel.name == asset.name, AssetModel.uri == asset.uri)
elif isinstance(obj, AssetRef):
# TODO: handle the case where the Asset uri differs from its name
asset = self._assets[AssetUniqueKey.from_asset(Asset(name=obj.name))]
join_clause = AssetEvent.asset
where_clause = AssetModel.name == self._assets[obj.name].name
where_clause = and_(AssetModel.name == asset.name, AssetModel.uri == asset.uri)
elif isinstance(obj, str):
asset = self._assets[extract_event_key(obj)]
# TODO: deprecate string access
asset_name = extract_event_key(obj)
asset = self._assets[AssetUniqueKey.from_asset(Asset(name=asset_name))]
join_clause = AssetEvent.asset
where_clause = AssetModel.name == asset.name
else:
Expand Down
16 changes: 8 additions & 8 deletions airflow/utils/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.param import ParamsDict
from airflow.models.taskinstance import TaskInstance
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, BaseAsset, BaseAssetUniqueKey
from airflow.sdk.definitions.asset import Asset, AssetUniqueKey, BaseAsset, BaseAssetUniqueKey
from airflow.serialization.pydantic.asset import AssetEventPydantic
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.typing_compat import TypedDict
Expand Down Expand Up @@ -70,18 +70,18 @@ class OutletEventAccessor:
self,
*,
extra: dict[str, Any],
key: BaseAssetUniqueKey,
key: str | BaseAssetUniqueKey,
asset_alias_events: list[AssetAliasEvent],
) -> None: ...
def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: ...
extra: dict[str, Any]
key: BaseAssetUniqueKey
key: str | BaseAssetUniqueKey
asset_alias_events: list[AssetAliasEvent]

class OutletEventAccessors(Mapping[BaseAsset, OutletEventAccessor]):
class OutletEventAccessors(Mapping[str | BaseAsset, OutletEventAccessor]):
def __iter__(self) -> Iterator[BaseAsset]: ...
def __len__(self) -> int: ...
def __getitem__(self, key: BaseAsset) -> OutletEventAccessor: ...
def __getitem__(self, key: str | BaseAsset) -> OutletEventAccessor: ...

class InletEventsAccessor(Sequence[AssetEvent]):
@overload
Expand All @@ -90,11 +90,11 @@ class InletEventsAccessor(Sequence[AssetEvent]):
def __getitem__(self, key: slice) -> Sequence[AssetEvent]: ...
def __len__(self) -> int: ...

class InletEventsAccessors(Mapping[Asset | AssetAlias, InletEventsAccessor]):
class InletEventsAccessors(Mapping[str | BaseAsset, InletEventsAccessor]):
def __init__(self, inlets: list, *, session: Session) -> None: ...
def __iter__(self) -> Iterator[Asset | AssetAlias]: ...
def __iter__(self) -> Iterator[BaseAsset]: ...
def __len__(self) -> int: ...
def __getitem__(self, key: int | Asset | AssetAlias) -> InletEventsAccessor: ...
def __getitem__(self, key: int | str | BaseAsset) -> InletEventsAccessor: ...

# NOTE: Please keep this in sync with the following:
# * KNOWN_CONTEXT_KEYS in airflow/utils/context.py
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2770,7 +2770,7 @@ def test_inlet_asset_alias_extra(self, dag_maker, session):
asset_uri = "test_inlet_asset_extra_ds"
asset_alias_name = "test_inlet_asset_extra_asset_alias"

asset_model = AssetModel(id=1, uri=asset_uri)
asset_model = AssetModel(id=1, uri=asset_uri, group="asset")
asset_alias_model = AssetAliasModel(name=asset_alias_name)
asset_alias_model.assets.append(asset_model)
session.add_all([asset_model, asset_alias_model])
Expand Down
14 changes: 11 additions & 3 deletions tests/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from airflow.models.xcom_arg import XComArg
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasUniqueKey, AssetUniqueKey
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.pydantic.asset import AssetEventPydantic, AssetPydantic
from airflow.serialization.pydantic.dag import DagModelPydantic, DagTagPydantic
Expand Down Expand Up @@ -258,7 +258,7 @@ def __len__(self) -> int:
),
(
OutletEventAccessor(
key=Asset(uri="test", name="test", group="test-group"),
key=AssetUniqueKey.from_asset(Asset(uri="test", name="test", group="test-group")),
extra={"key": "value"},
asset_alias_events=[],
),
Expand All @@ -267,7 +267,9 @@ def __len__(self) -> int:
),
(
OutletEventAccessor(
key=AssetAlias(name="test_alias", group="test-alias-group"),
key=AssetAliasUniqueKey.from_asset_alias(
AssetAlias(name="test_alias", group="test-alias-group")
),
extra={"key": "value"},
asset_alias_events=[
AssetAliasEvent(
Expand All @@ -280,6 +282,12 @@ def __len__(self) -> int:
DAT.ASSET_EVENT_ACCESSOR,
equal_outlet_event_accessor,
),
# TODO: deprecate string access
(
OutletEventAccessor(key="test", extra={"key": "value"}, asset_alias_events=[]),
DAT.ASSET_EVENT_ACCESSOR,
equal_outlet_event_accessor,
),
(
AirflowException("test123 wohoo!"),
DAT.AIRFLOW_EXC_SER,
Expand Down
21 changes: 15 additions & 6 deletions tests/utils/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class TestOutletEventAccessor:
@pytest.mark.parametrize(
"key, asset_alias_events",
(
(AssetUniqueKey.from_asset(Asset("test_uri")), []),
(
AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")),
[
Expand All @@ -39,7 +40,6 @@ class TestOutletEventAccessor:
)
],
),
(AssetUniqueKey.from_asset(Asset("test_uri")), []),
),
)
def test_add(self, key, asset_alias_events):
Expand All @@ -52,7 +52,7 @@ def test_add(self, key, asset_alias_events):
"key, asset_alias_events",
(
(
AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")),
"test_alias",
[
AssetAliasEvent(
source_alias_name="test_alias",
Expand All @@ -62,26 +62,35 @@ def test_add(self, key, asset_alias_events):
],
),
(AssetUniqueKey.from_asset(Asset("test_uri")), []),
(
AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")),
[
AssetAliasEvent(
source_alias_name="test_alias",
dest_asset_key=AssetUniqueKey(name="test-asset", uri="test://asset-uri/"),
extra={},
)
],
),
),
)
def test_add_with_db(self, key, asset_alias_events, session):
asm = AssetModel(uri="test://asset-uri", name="test-asset", group="asset")
aam = AssetAliasModel(name="test_alias")
session.add_all([asm, aam])
session.flush()
asset = Asset(uri="test://asset-uri", name="test-asset")

outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""})
outlet_event_accessor.add(Asset(uri="test://asset-uri", name="test-asset"), extra={})
outlet_event_accessor.add(asset, extra={})
assert outlet_event_accessor.asset_alias_events == asset_alias_events


# TODO: add test case to verify string does not work


class TestOutletEventAccessors:
@pytest.mark.parametrize(
"access_key, internal_key",
(
("test", "test"),
(Asset("test"), AssetUniqueKey.from_asset(Asset("test"))),
(
Asset(name="test", uri="test://asset"),
Expand Down

0 comments on commit 411362d

Please sign in to comment.