Skip to content

Commit

Permalink
refactor(context): fix context typing
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W committed Dec 4, 2024
1 parent 8710512 commit 451e1ea
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 14 deletions.
2 changes: 1 addition & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2761,7 +2761,7 @@ def _register_asset_changes_int(
)
}
if missing_assets := [
unique_key.to_asset() for unique_key, _ in asset_alias_names if unique_key not in asset_models
unique_key.to_obj() for unique_key, _ in asset_alias_names if unique_key not in asset_models
]:
asset_models.update(
(AssetUniqueKey.from_asset(asset_obj), asset_obj)
Expand Down
12 changes: 6 additions & 6 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
AssetRef,
AssetUniqueKey,
BaseAsset,
BaseAssetUniqueKey,
)
from airflow.sdk.definitions.asset.metadata import extract_event_key
from airflow.utils.db import LazySelectSequence
Expand Down Expand Up @@ -169,7 +170,7 @@ class OutletEventAccessor:
:meta private:
"""

key: AssetUniqueKey | AssetAliasUniqueKey
key: BaseAssetUniqueKey
extra: dict[str, Any] = attrs.Factory(dict)
asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list)

Expand Down Expand Up @@ -199,26 +200,25 @@ class OutletEventAccessors(Mapping[BaseAsset, OutletEventAccessor]):
"""

def __init__(self) -> None:
self._dict: dict[BaseAsset, OutletEventAccessor] = {}
self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {}

def __str__(self) -> str:
return f"OutletEventAccessors(_dict={self._dict})"

def __iter__(self) -> Iterator[BaseAsset]:
return iter(self._dict)
return iter(key.to_obj() for key in self._dict)

def __len__(self) -> int:
return len(self._dict)

def __getitem__(self, key: BaseAsset) -> OutletEventAccessor:
hashable_key: AssetUniqueKey | AssetAliasUniqueKey
hashable_key: BaseAssetUniqueKey
if isinstance(key, Asset):
hashable_key = AssetUniqueKey.from_asset(key)
elif isinstance(key, AssetAlias):
hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
else:
# TODO
raise SystemExit()
raise KeyError("Key should be either an asset or an asset alias")

if hashable_key not in self._dict:
self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key)
Expand Down
12 changes: 6 additions & 6 deletions airflow/utils/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.param import ParamsDict
from airflow.models.taskinstance import TaskInstance
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, BaseAsset, BaseAssetUniqueKey
from airflow.serialization.pydantic.asset import AssetEventPydantic
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.typing_compat import TypedDict
Expand All @@ -62,18 +62,18 @@ class OutletEventAccessor:
self,
*,
extra: dict[str, Any],
key: Asset | AssetAlias,
key: BaseAssetUniqueKey,
asset_alias_events: list[AssetAliasEvent],
) -> None: ...
def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: ...
extra: dict[str, Any]
key: Asset | AssetAlias
key: BaseAssetUniqueKey
asset_alias_events: list[AssetAliasEvent]

class OutletEventAccessors(Mapping[Asset | AssetAlias, OutletEventAccessor]):
def __iter__(self) -> Iterator[Asset | AssetAlias]: ...
class OutletEventAccessors(Mapping[BaseAsset, OutletEventAccessor]):
def __iter__(self) -> Iterator[BaseAsset]: ...
def __len__(self) -> int: ...
def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor: ...
def __getitem__(self, key: BaseAsset) -> OutletEventAccessor: ...

class InletEventsAccessor(Sequence[AssetEvent]):
@overload
Expand Down
9 changes: 8 additions & 1 deletion task_sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Callable,
ClassVar,
NamedTuple,
Union,
cast,
overload,
)
Expand Down Expand Up @@ -72,7 +73,7 @@ class AssetUniqueKey(NamedTuple):
def from_asset(asset: Asset | AssetModel) -> AssetUniqueKey:
return AssetUniqueKey(name=asset.name, uri=asset.uri)

def to_asset(self) -> Asset:
def to_obj(self) -> Asset:
return Asset(name=self.name, uri=self.uri)


Expand All @@ -83,6 +84,12 @@ class AssetAliasUniqueKey(NamedTuple):
def from_asset_alias(asset_alias: AssetAlias) -> AssetAliasUniqueKey:
return AssetAliasUniqueKey(name=asset_alias.name)

def to_obj(self) -> AssetAlias:
return AssetAlias(name=self.name)


BaseAssetUniqueKey = Union[AssetUniqueKey, AssetAliasUniqueKey]


def normalize_noop(parts: SplitResult) -> SplitResult:
"""
Expand Down

0 comments on commit 451e1ea

Please sign in to comment.