Skip to content

Commit

Permalink
refactor(context): fix context typing
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W committed Dec 6, 2024
1 parent 53db7e4 commit 43aa6f6
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 15 deletions.
2 changes: 1 addition & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2761,7 +2761,7 @@ def _register_asset_changes_int(
)
}
if missing_assets := [
unique_key.to_asset() for unique_key, _ in asset_alias_names if unique_key not in asset_models
unique_key.to_obj() for unique_key, _ in asset_alias_names if unique_key not in asset_models
]:
asset_models.update(
(AssetUniqueKey.from_asset(asset_obj), asset_obj)
Expand Down
12 changes: 6 additions & 6 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
AssetRef,
AssetUniqueKey,
BaseAsset,
BaseAssetUniqueKey,
)
from airflow.sdk.definitions.asset.metadata import extract_event_key
from airflow.utils.db import LazySelectSequence
Expand Down Expand Up @@ -176,7 +177,7 @@ class OutletEventAccessor:
:meta private:
"""

key: AssetUniqueKey | AssetAliasUniqueKey
key: BaseAssetUniqueKey
extra: dict[str, Any] = attrs.Factory(dict)
asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list)

Expand Down Expand Up @@ -206,26 +207,25 @@ class OutletEventAccessors(Mapping[BaseAsset, OutletEventAccessor]):
"""

def __init__(self) -> None:
self._dict: dict[BaseAsset, OutletEventAccessor] = {}
self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {}

def __str__(self) -> str:
return f"OutletEventAccessors(_dict={self._dict})"

def __iter__(self) -> Iterator[BaseAsset]:
return iter(self._dict)
return iter(key.to_obj() for key in self._dict)

def __len__(self) -> int:
return len(self._dict)

def __getitem__(self, key: BaseAsset) -> OutletEventAccessor:
hashable_key: AssetUniqueKey | AssetAliasUniqueKey
hashable_key: BaseAssetUniqueKey
if isinstance(key, Asset):
hashable_key = AssetUniqueKey.from_asset(key)
elif isinstance(key, AssetAlias):
hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
else:
# TODO
raise SystemExit()
raise KeyError("Key should be either an asset or an asset alias")

if hashable_key not in self._dict:
self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key)
Expand Down
12 changes: 6 additions & 6 deletions airflow/utils/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.param import ParamsDict
from airflow.models.taskinstance import TaskInstance
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, BaseAsset, BaseAssetUniqueKey
from airflow.serialization.pydantic.asset import AssetEventPydantic
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.typing_compat import TypedDict
Expand Down Expand Up @@ -70,18 +70,18 @@ class OutletEventAccessor:
self,
*,
extra: dict[str, Any],
key: Asset | AssetAlias,
key: BaseAssetUniqueKey,
asset_alias_events: list[AssetAliasEvent],
) -> None: ...
def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: ...
extra: dict[str, Any]
key: Asset | AssetAlias
key: BaseAssetUniqueKey
asset_alias_events: list[AssetAliasEvent]

class OutletEventAccessors(Mapping[Asset | AssetAlias, OutletEventAccessor]):
def __iter__(self) -> Iterator[Asset | AssetAlias]: ...
class OutletEventAccessors(Mapping[BaseAsset, OutletEventAccessor]):
def __iter__(self) -> Iterator[BaseAsset]: ...
def __len__(self) -> int: ...
def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor: ...
def __getitem__(self, key: BaseAsset) -> OutletEventAccessor: ...

class InletEventsAccessor(Sequence[AssetEvent]):
@overload
Expand Down
10 changes: 8 additions & 2 deletions task_sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import os
import urllib.parse
import warnings
from typing import TYPE_CHECKING, Any, Callable, ClassVar, NamedTuple, overload
from typing import TYPE_CHECKING, Any, Callable, ClassVar, NamedTuple, Union, overload

import attrs

Expand Down Expand Up @@ -58,7 +58,7 @@ class AssetUniqueKey(NamedTuple):
def from_asset(asset: Asset | AssetModel) -> AssetUniqueKey:
return AssetUniqueKey(name=asset.name, uri=asset.uri)

def to_asset(self) -> Asset:
def to_obj(self) -> Asset:
return Asset(name=self.name, uri=self.uri)


Expand All @@ -69,6 +69,12 @@ class AssetAliasUniqueKey(NamedTuple):
def from_asset_alias(asset_alias: AssetAlias) -> AssetAliasUniqueKey:
return AssetAliasUniqueKey(name=asset_alias.name)

def to_obj(self) -> AssetAlias:
return AssetAlias(name=self.name)


BaseAssetUniqueKey = Union[AssetUniqueKey, AssetAliasUniqueKey]


def normalize_noop(parts: SplitResult) -> SplitResult:
"""
Expand Down

0 comments on commit 43aa6f6

Please sign in to comment.