From 43aa6f6473c2c3e9cbf048eaf962bcd73e278823 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Mon, 2 Dec 2024 21:42:12 +0800 Subject: [PATCH] refactor(context): fix context typing --- airflow/models/taskinstance.py | 2 +- airflow/utils/context.py | 12 ++++++------ airflow/utils/context.pyi | 12 ++++++------ .../src/airflow/sdk/definitions/asset/__init__.py | 10 ++++++++-- 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 75d8a2bf2cb94..e53c963b92466 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2761,7 +2761,7 @@ def _register_asset_changes_int( ) } if missing_assets := [ - unique_key.to_asset() for unique_key, _ in asset_alias_names if unique_key not in asset_models + unique_key.to_obj() for unique_key, _ in asset_alias_names if unique_key not in asset_models ]: asset_models.update( (AssetUniqueKey.from_asset(asset_obj), asset_obj) diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 31c5cd89a8278..600393a78f40c 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -51,6 +51,7 @@ AssetRef, AssetUniqueKey, BaseAsset, + BaseAssetUniqueKey, ) from airflow.sdk.definitions.asset.metadata import extract_event_key from airflow.utils.db import LazySelectSequence @@ -176,7 +177,7 @@ class OutletEventAccessor: :meta private: """ - key: AssetUniqueKey | AssetAliasUniqueKey + key: BaseAssetUniqueKey extra: dict[str, Any] = attrs.Factory(dict) asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list) @@ -206,26 +207,25 @@ class OutletEventAccessors(Mapping[BaseAsset, OutletEventAccessor]): """ def __init__(self) -> None: - self._dict: dict[BaseAsset, OutletEventAccessor] = {} + self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {} def __str__(self) -> str: return f"OutletEventAccessors(_dict={self._dict})" def __iter__(self) -> Iterator[BaseAsset]: - return iter(self._dict) + return iter(key.to_obj() for key in self._dict) def __len__(self) -> int: return len(self._dict) def __getitem__(self, key: BaseAsset) -> OutletEventAccessor: - hashable_key: AssetUniqueKey | AssetAliasUniqueKey + hashable_key: BaseAssetUniqueKey if isinstance(key, Asset): hashable_key = AssetUniqueKey.from_asset(key) elif isinstance(key, AssetAlias): hashable_key = AssetAliasUniqueKey.from_asset_alias(key) else: - # TODO - raise SystemExit() + raise KeyError("Key should be either an asset or an asset alias") if hashable_key not in self._dict: self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key) diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi index 4b573920219ba..bc0964f713aaa 100644 --- a/airflow/utils/context.pyi +++ b/airflow/utils/context.pyi @@ -39,7 +39,7 @@ from airflow.models.dag import DAG from airflow.models.dagrun import DagRun from airflow.models.param import ParamsDict from airflow.models.taskinstance import TaskInstance -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, BaseAsset, BaseAssetUniqueKey from airflow.serialization.pydantic.asset import AssetEventPydantic from airflow.serialization.pydantic.dag_run import DagRunPydantic from airflow.typing_compat import TypedDict @@ -70,18 +70,18 @@ class OutletEventAccessor: self, *, extra: dict[str, Any], - key: Asset | AssetAlias, + key: BaseAssetUniqueKey, asset_alias_events: list[AssetAliasEvent], ) -> None: ... def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: ... extra: dict[str, Any] - key: Asset | AssetAlias + key: BaseAssetUniqueKey asset_alias_events: list[AssetAliasEvent] -class OutletEventAccessors(Mapping[Asset | AssetAlias, OutletEventAccessor]): - def __iter__(self) -> Iterator[Asset | AssetAlias]: ... +class OutletEventAccessors(Mapping[BaseAsset, OutletEventAccessor]): + def __iter__(self) -> Iterator[BaseAsset]: ... def __len__(self) -> int: ... - def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor: ... + def __getitem__(self, key: BaseAsset) -> OutletEventAccessor: ... class InletEventsAccessor(Sequence[AssetEvent]): @overload diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index 71d6ea4b2bc45..7abc4e53d7472 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -22,7 +22,7 @@ import os import urllib.parse import warnings -from typing import TYPE_CHECKING, Any, Callable, ClassVar, NamedTuple, overload +from typing import TYPE_CHECKING, Any, Callable, ClassVar, NamedTuple, Union, overload import attrs @@ -58,7 +58,7 @@ class AssetUniqueKey(NamedTuple): def from_asset(asset: Asset | AssetModel) -> AssetUniqueKey: return AssetUniqueKey(name=asset.name, uri=asset.uri) - def to_asset(self) -> Asset: + def to_obj(self) -> Asset: return Asset(name=self.name, uri=self.uri) @@ -69,6 +69,12 @@ class AssetAliasUniqueKey(NamedTuple): def from_asset_alias(asset_alias: AssetAlias) -> AssetAliasUniqueKey: return AssetAliasUniqueKey(name=asset_alias.name) + def to_obj(self) -> AssetAlias: + return AssetAlias(name=self.name) + + +BaseAssetUniqueKey = Union[AssetUniqueKey, AssetAliasUniqueKey] + def normalize_noop(parts: SplitResult) -> SplitResult: """