Skip to content

Commit

Permalink
feat(utils/context): forbid using string to access inlet and outlet e…
Browse files Browse the repository at this point in the history
…vent
  • Loading branch information
Lee-W committed Dec 9, 2024
1 parent c9d8b6f commit 0a81fa4
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 87 deletions.
80 changes: 19 additions & 61 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,8 @@
AssetAliasUniqueKey,
AssetRef,
AssetUniqueKey,
BaseAsset,
BaseAssetUniqueKey,
)
from airflow.sdk.definitions.asset.metadata import extract_event_key
from airflow.utils.db import LazySelectSequence
from airflow.utils.types import NOTSET

Expand Down Expand Up @@ -178,25 +176,16 @@ class OutletEventAccessor:
:meta private:
"""

key: str | BaseAssetUniqueKey
key: BaseAssetUniqueKey
extra: dict[str, Any] = attrs.Factory(dict)
asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list)

def add(self, asset: str | Asset, extra: dict[str, Any] | None = None) -> None:
def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None:
"""Add an AssetEvent to an existing Asset."""
if isinstance(asset, str):
asset = Asset(asset)
elif not isinstance(asset, Asset):
return

if isinstance(self.key, AssetAliasUniqueKey):
asset_alias_name = self.key.name
elif isinstance(self.key, str):
# TODO: deprecate string access
asset_alias_name = self.key
else:
if not isinstance(self.key, AssetAliasUniqueKey):
return

asset_alias_name = self.key.name
event = AssetAliasEvent(
source_alias_name=asset_alias_name,
dest_asset_key=AssetUniqueKey.from_asset(asset),
Expand All @@ -205,61 +194,36 @@ def add(self, asset: str | Asset, extra: dict[str, Any] | None = None) -> None:
self.asset_alias_events.append(event)


class OutletEventAccessors(Mapping[Union[str, BaseAsset], OutletEventAccessor]):
class OutletEventAccessors(Mapping[Union[Asset | AssetAlias], OutletEventAccessor]):
"""
Lazy mapping of outlet asset event accessors.
:meta private:
"""

def __init__(self) -> None:
self._dict: dict[str | BaseAssetUniqueKey, OutletEventAccessor] = {}
self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {}

def __str__(self) -> str:
return f"OutletEventAccessors(_dict={self._dict})"

def __iter__(self) -> Iterator[str | BaseAsset]:
def __iter__(self) -> Iterator[Asset | AssetAlias]:
return iter(key.to_obj() if isinstance(key, BaseAssetUniqueKey) else key for key in self._dict)

def __len__(self) -> int:
return len(self._dict)

def __getitem__(self, key: str | BaseAsset) -> OutletEventAccessor:
hashable_key: str | BaseAssetUniqueKey
# TODO: Remove it once string accessing is deprecated.
# We currently still support accessing through string.
# Asset("abc") and "abc" returns the same thing.
# Thus, if an user pass Asset("abc"), we need to also check whether "abc" is in dict.
# Same for alias.
potential_equivalent_key = None
def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor:
hashable_key: BaseAssetUniqueKey
if isinstance(key, Asset):
hashable_key = AssetUniqueKey.from_asset(key)
# TODO: remove after deprecating string accessing
if key.name == key.uri and key.name in self._dict:
potential_equivalent_key = key.name
elif isinstance(key, AssetAlias):
hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
# TODO: remove after deprecating string accessing
if key.name in self._dict:
potential_equivalent_key = key.name
elif isinstance(key, str):
# TODO: remove after deprecating string accessing
hashable_key = key
else:
raise KeyError("Key should be either an asset or an asset alias")
# TODO: remove after deprecating string accessing
if key.name in self._dict:
potential_equivalent_key = key.name

if (
hashable_key not in self._dict
and not potential_equivalent_key
and potential_equivalent_key not in self._dict
):

if hashable_key not in self._dict:
self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key)
elif potential_equivalent_key:
# TODO: remove after deprecating string accessing
hashable_key = potential_equivalent_key
return self._dict[hashable_key]


Expand All @@ -280,7 +244,7 @@ def _process_row(row: Row) -> AssetEvent:


@attrs.define(init=False)
class InletEventsAccessors(Mapping[Union[str, int, BaseAsset], LazyAssetEventSelectSequence]):
class InletEventsAccessors(Mapping[Union[int, Asset, AssetAlias, AssetRef], LazyAssetEventSelectSequence]):
"""
Lazy mapping for inlet asset events accessors.
Expand Down Expand Up @@ -311,39 +275,33 @@ def __init__(self, inlets: list, *, session: Session) -> None:
for _, asset in fetch_active_assets_by_name(_asset_ref_names, self._session).items():
self._assets[AssetUniqueKey.from_asset(asset)] = asset

def __iter__(self) -> Iterator[BaseAsset]:
def __iter__(self) -> Iterator[Asset | AssetAlias]:
return iter(self._inlets)

def __len__(self) -> int:
return len(self._inlets)

def __getitem__(self, key: int | str | BaseAsset) -> LazyAssetEventSelectSequence:
def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> LazyAssetEventSelectSequence:
if isinstance(key, int): # Support index access; it's easier for trivial cases.
obj = self._inlets[key]
if not isinstance(obj, (Asset, AssetAlias, AssetRef)):
raise IndexError(key)
else:
obj = key

if isinstance(obj, AssetAlias):
asset_alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(obj)]
join_clause = AssetEvent.source_aliases
where_clause = AssetAliasModel.name == asset_alias.name
elif isinstance(obj, Asset):
if isinstance(obj, Asset):
asset = self._assets[AssetUniqueKey.from_asset(obj)]
join_clause = AssetEvent.asset
where_clause = and_(AssetModel.name == asset.name, AssetModel.uri == asset.uri)
elif isinstance(obj, AssetAlias):
asset_alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(obj)]
join_clause = AssetEvent.source_aliases
where_clause = AssetAliasModel.name == asset_alias.name
elif isinstance(obj, AssetRef):
# TODO: handle the case that Asset uri is different from name
asset = self._assets[AssetUniqueKey.from_asset(Asset(name=obj.name))]
join_clause = AssetEvent.asset
where_clause = and_(AssetModel.name == asset.name, AssetModel.uri == asset.uri)
elif isinstance(obj, str):
# TODO: deprecate string access
asset_name = extract_event_key(obj)
asset = self._assets[AssetUniqueKey.from_asset(Asset(name=asset_name))]
join_clause = AssetEvent.asset
where_clause = AssetModel.name == asset.name
else:
raise ValueError(key)

Expand Down
18 changes: 9 additions & 9 deletions airflow/utils/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.param import ParamsDict
from airflow.models.taskinstance import TaskInstance
from airflow.sdk.definitions.asset import Asset, AssetUniqueKey, BaseAsset, BaseAssetUniqueKey
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetRef, AssetUniqueKey, BaseAssetUniqueKey
from airflow.serialization.pydantic.asset import AssetEventPydantic
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.typing_compat import TypedDict
Expand Down Expand Up @@ -69,19 +69,19 @@ class OutletEventAccessor:
def __init__(
self,
*,
key: BaseAssetUniqueKey,
extra: dict[str, Any],
key: str | BaseAssetUniqueKey,
asset_alias_events: list[AssetAliasEvent],
) -> None: ...
def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: ...
key: BaseAssetUniqueKey
extra: dict[str, Any]
key: str | BaseAssetUniqueKey
asset_alias_events: list[AssetAliasEvent]

class OutletEventAccessors(Mapping[str | BaseAsset, OutletEventAccessor]):
def __iter__(self) -> Iterator[BaseAsset]: ...
class OutletEventAccessors(Mapping[Asset | AssetAlias, OutletEventAccessor]):
def __iter__(self) -> Iterator[Asset | AssetAlias]: ...
def __len__(self) -> int: ...
def __getitem__(self, key: str | BaseAsset) -> OutletEventAccessor: ...
def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor: ...

class InletEventsAccessor(Sequence[AssetEvent]):
@overload
Expand All @@ -90,11 +90,11 @@ class InletEventsAccessor(Sequence[AssetEvent]):
def __getitem__(self, key: slice) -> Sequence[AssetEvent]: ...
def __len__(self) -> int: ...

class InletEventsAccessors(Mapping[str | BaseAsset, InletEventsAccessor]):
class InletEventsAccessors(Mapping[Asset | AssetAlias, InletEventsAccessor]):
def __init__(self, inlets: list, *, session: Session) -> None: ...
def __iter__(self) -> Iterator[BaseAsset]: ...
def __iter__(self) -> Iterator[Asset | AssetAlias]: ...
def __len__(self) -> int: ...
def __getitem__(self, key: int | str | BaseAsset) -> InletEventsAccessor: ...
def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> InletEventsAccessor: ...

# NOTE: Please keep this in sync with the following:
# * KNOWN_CONTEXT_KEYS in airflow/utils/context.py
Expand Down
6 changes: 0 additions & 6 deletions tests/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,6 @@ def __len__(self) -> int:
DAT.ASSET_EVENT_ACCESSOR,
equal_outlet_event_accessor,
),
# TODO: deprecate string access
(
OutletEventAccessor(key="test", extra={"key": "value"}, asset_alias_events=[]),
DAT.ASSET_EVENT_ACCESSOR,
equal_outlet_event_accessor,
),
(
AirflowException("test123 wohoo!"),
DAT.AIRFLOW_EXC_SER,
Expand Down
11 changes: 0 additions & 11 deletions tests/utils/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,6 @@ def test_add(self, key, asset_alias_events):
@pytest.mark.parametrize(
"key, asset_alias_events",
(
(
"test_alias",
[
AssetAliasEvent(
source_alias_name="test_alias",
dest_asset_key=AssetUniqueKey(name="test-asset", uri="test://asset-uri/"),
extra={},
)
],
),
(AssetUniqueKey.from_asset(Asset("test_uri")), []),
(
AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")),
Expand Down Expand Up @@ -90,7 +80,6 @@ class TestOutletEventAccessors:
@pytest.mark.parametrize(
"access_key, internal_key",
(
("test", "test"),
(Asset("test"), AssetUniqueKey.from_asset(Asset("test"))),
(
Asset(name="test", uri="test://asset"),
Expand Down

0 comments on commit 0a81fa4

Please sign in to comment.