Skip to content

Commit

Permalink
feat(utils/context): add back string access to inlet and outlet events
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W committed Dec 4, 2024
1 parent 451e1ea commit 18404ac
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 26 deletions.
26 changes: 17 additions & 9 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
TYPE_CHECKING,
Any,
SupportsIndex,
Union,
)

import attrs
Expand Down Expand Up @@ -170,7 +171,7 @@ class OutletEventAccessor:
:meta private:
"""

key: BaseAssetUniqueKey
key: str | BaseAssetUniqueKey
extra: dict[str, Any] = attrs.Factory(dict)
asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list)

Expand All @@ -181,6 +182,9 @@ def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None:

if isinstance(self.key, AssetAliasUniqueKey):
asset_alias_name = self.key.name
elif isinstance(self.key, str):
# TODO: deprecate string access
asset_alias_name = self.key
else:
return

Expand All @@ -192,31 +196,34 @@ def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None:
self.asset_alias_events.append(event)


class OutletEventAccessors(Mapping[BaseAsset, OutletEventAccessor]):
class OutletEventAccessors(Mapping[Union[str, BaseAsset], OutletEventAccessor]):
"""
Lazy mapping of outlet asset event accessors.
:meta private:
"""

def __init__(self) -> None:
self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {}
self._dict: dict[str | BaseAssetUniqueKey, OutletEventAccessor] = {}

def __str__(self) -> str:
return f"OutletEventAccessors(_dict={self._dict})"

def __iter__(self) -> Iterator[BaseAsset]:
return iter(key.to_obj() for key in self._dict)
def __iter__(self) -> Iterator[str | BaseAsset]:
return iter(key.to_obj() if isinstance(key, BaseAssetUniqueKey) else key for key in self._dict)

def __len__(self) -> int:
return len(self._dict)

def __getitem__(self, key: BaseAsset) -> OutletEventAccessor:
hashable_key: BaseAssetUniqueKey
def __getitem__(self, key: str | BaseAsset) -> OutletEventAccessor:
hashable_key: str | BaseAssetUniqueKey
if isinstance(key, Asset):
hashable_key = AssetUniqueKey.from_asset(key)
elif isinstance(key, AssetAlias):
hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
elif isinstance(key, str):
# TODO: deprecate string accessing
hashable_key = key
else:
raise KeyError("Key should be either an asset or an asset alias")

Expand All @@ -242,7 +249,7 @@ def _process_row(row: Row) -> AssetEvent:


@attrs.define(init=False)
class InletEventsAccessors(Mapping[BaseAsset, LazyAssetEventSelectSequence]):
class InletEventsAccessors(Mapping[Union[str, int, BaseAsset], LazyAssetEventSelectSequence]):
"""
Lazy mapping for inlet asset events accessors.
Expand Down Expand Up @@ -279,7 +286,7 @@ def __iter__(self) -> Iterator[BaseAsset]:
def __len__(self) -> int:
return len(self._inlets)

def __getitem__(self, key: int | BaseAsset) -> LazyAssetEventSelectSequence:
def __getitem__(self, key: int | str | BaseAsset) -> LazyAssetEventSelectSequence:
if isinstance(key, int): # Support index access; it's easier for trivial cases.
obj = self._inlets[key]
if not isinstance(obj, (Asset, AssetAlias, AssetRef)):
Expand All @@ -295,6 +302,7 @@ def __getitem__(self, key: int | BaseAsset) -> LazyAssetEventSelectSequence:
join_clause = AssetEvent.asset
where_clause = AssetModel.name == self._assets[obj.name].name
elif isinstance(obj, str):
# TODO: deprecate string access
asset = self._assets[extract_event_key(obj)]
join_clause = AssetEvent.asset
where_clause = AssetModel.name == asset.name
Expand Down
16 changes: 8 additions & 8 deletions airflow/utils/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.param import ParamsDict
from airflow.models.taskinstance import TaskInstance
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, BaseAsset, BaseAssetUniqueKey
from airflow.sdk.definitions.asset import Asset, AssetAliasEvent, BaseAsset, BaseAssetUniqueKey
from airflow.serialization.pydantic.asset import AssetEventPydantic
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.typing_compat import TypedDict
Expand All @@ -62,18 +62,18 @@ class OutletEventAccessor:
self,
*,
extra: dict[str, Any],
key: BaseAssetUniqueKey,
key: str | BaseAssetUniqueKey,
asset_alias_events: list[AssetAliasEvent],
) -> None: ...
def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: ...
extra: dict[str, Any]
key: BaseAssetUniqueKey
key: str | BaseAssetUniqueKey
asset_alias_events: list[AssetAliasEvent]

class OutletEventAccessors(Mapping[BaseAsset, OutletEventAccessor]):
class OutletEventAccessors(Mapping[str | BaseAsset, OutletEventAccessor]):
def __iter__(self) -> Iterator[BaseAsset]: ...
def __len__(self) -> int: ...
def __getitem__(self, key: BaseAsset) -> OutletEventAccessor: ...
def __getitem__(self, key: str | BaseAsset) -> OutletEventAccessor: ...

class InletEventsAccessor(Sequence[AssetEvent]):
@overload
Expand All @@ -82,11 +82,11 @@ class InletEventsAccessor(Sequence[AssetEvent]):
def __getitem__(self, key: slice) -> Sequence[AssetEvent]: ...
def __len__(self) -> int: ...

class InletEventsAccessors(Mapping[Asset | AssetAlias, InletEventsAccessor]):
class InletEventsAccessors(Mapping[str | BaseAsset, InletEventsAccessor]):
def __init__(self, inlets: list, *, session: Session) -> None: ...
def __iter__(self) -> Iterator[Asset | AssetAlias]: ...
def __iter__(self) -> Iterator[BaseAsset]: ...
def __len__(self) -> int: ...
def __getitem__(self, key: int | Asset | AssetAlias) -> InletEventsAccessor: ...
def __getitem__(self, key: int | str | BaseAsset) -> InletEventsAccessor: ...

# NOTE: Please keep this in sync with the following:
# * KNOWN_CONTEXT_KEYS in airflow/utils/context.py
Expand Down
20 changes: 17 additions & 3 deletions tests/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,13 @@
from airflow.models.xcom_arg import XComArg
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetUniqueKey
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetAliasEvent,
AssetAliasUniqueKey,
AssetUniqueKey,
)
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.pydantic.asset import AssetEventPydantic, AssetPydantic
from airflow.serialization.pydantic.dag import DagModelPydantic, DagTagPydantic
Expand Down Expand Up @@ -258,7 +264,7 @@ def __len__(self) -> int:
),
(
OutletEventAccessor(
key=Asset(uri="test", name="test", group="test-group"),
key=AssetUniqueKey.from_asset(Asset(uri="test", name="test", group="test-group")),
extra={"key": "value"},
asset_alias_events=[],
),
Expand All @@ -267,7 +273,9 @@ def __len__(self) -> int:
),
(
OutletEventAccessor(
key=AssetAlias(name="test_alias", group="test-alias-group"),
key=AssetAliasUniqueKey.from_asset_alias(
AssetAlias(name="test_alias", group="test-alias-group")
),
extra={"key": "value"},
asset_alias_events=[
AssetAliasEvent(
Expand All @@ -280,6 +288,12 @@ def __len__(self) -> int:
DAT.ASSET_EVENT_ACCESSOR,
equal_outlet_event_accessor,
),
# TODO: deprecate string access
(
OutletEventAccessor(key="test", extra={"key": "value"}, asset_alias_events=[]),
DAT.ASSET_EVENT_ACCESSOR,
equal_outlet_event_accessor,
),
(
AirflowException("test123 wohoo!"),
DAT.AIRFLOW_EXC_SER,
Expand Down
21 changes: 15 additions & 6 deletions tests/utils/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class TestOutletEventAccessor:
@pytest.mark.parametrize(
"key, asset_alias_events",
(
(AssetUniqueKey.from_asset(Asset("test_uri")), []),
(
AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")),
[
Expand All @@ -45,7 +46,6 @@ class TestOutletEventAccessor:
)
],
),
(AssetUniqueKey.from_asset(Asset("test_uri")), []),
),
)
def test_add(self, key, asset_alias_events):
Expand All @@ -58,7 +58,7 @@ def test_add(self, key, asset_alias_events):
"key, asset_alias_events",
(
(
AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")),
"test_alias",
[
AssetAliasEvent(
source_alias_name="test_alias",
Expand All @@ -68,26 +68,35 @@ def test_add(self, key, asset_alias_events):
],
),
(AssetUniqueKey.from_asset(Asset("test_uri")), []),
(
AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")),
[
AssetAliasEvent(
source_alias_name="test_alias",
dest_asset_key=AssetUniqueKey(name="test-asset", uri="test://asset-uri/"),
extra={},
)
],
),
),
)
def test_add_with_db(self, key, asset_alias_events, session):
asm = AssetModel(uri="test://asset-uri", name="test-asset", group="asset")
aam = AssetAliasModel(name="test_alias")
session.add_all([asm, aam])
session.flush()
asset = Asset(uri="test://asset-uri", name="test-asset")

outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""})
outlet_event_accessor.add(Asset(uri="test://asset-uri", name="test-asset"), extra={})
outlet_event_accessor.add(asset, extra={})
assert outlet_event_accessor.asset_alias_events == asset_alias_events


# TODO: add test case to verify string does not work


class TestOutletEventAccessors:
@pytest.mark.parametrize(
"access_key, internal_key",
(
("test", "test"),
(Asset("test"), AssetUniqueKey.from_asset(Asset("test"))),
(
Asset(name="test", uri="test://asset"),
Expand Down

0 comments on commit 18404ac

Please sign in to comment.