Skip to content

Commit

Permalink
Make AssetAliasEvent a class context.py (apache#44709)
Browse files Browse the repository at this point in the history
* Make AssetAliasEvent a class context.py

We don't use this class anywhere else, so it's better for it to live in
this module instead.

I also changed it into a simple class instead since it does not really
make sense for it to be a dict. We need to do some deprecation in the
2.x branch later.

* Make sure AssetAliasEvent is private

* Fix deserialization

* Fix type stub
  • Loading branch information
uranusjr authored and Lefteris Gilmaz committed Jan 5, 2025
1 parent 63d6b2b commit b8268fb
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 46 deletions.
6 changes: 3 additions & 3 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2748,9 +2748,9 @@ def _register_asset_changes_int(
)
elif isinstance(obj, AssetAlias):
for asset_alias_event in events[obj].asset_alias_events:
asset_alias_name = asset_alias_event["source_alias_name"]
asset_uri = asset_alias_event["dest_asset_uri"]
frozen_extra = frozenset(asset_alias_event["extra"].items())
asset_alias_name = asset_alias_event.source_alias_name
asset_uri = asset_alias_event.dest_asset_uri
frozen_extra = frozenset(asset_alias_event.extra.items())
asset_alias_names[(asset_uri, frozen_extra)].add(asset_alias_name)

asset_models: dict[str, AssetModel] = {
Expand Down
5 changes: 3 additions & 2 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
from airflow.triggers.base import BaseTrigger, StartTriggerArgs
from airflow.utils.code_utils import get_python_source
from airflow.utils.context import (
AssetAliasEvent,
ConnectionAccessor,
Context,
OutletEventAccessor,
Expand Down Expand Up @@ -305,7 +306,7 @@ def encode_outlet_event_accessor(var: OutletEventAccessor) -> dict[str, Any]:
raw_key = var.raw_key
return {
"extra": var.extra,
"asset_alias_events": var.asset_alias_events,
"asset_alias_events": [attrs.asdict(cast(attrs.AttrsInstance, e)) for e in var.asset_alias_events],
"raw_key": BaseSerialization.serialize(raw_key),
}

Expand All @@ -316,7 +317,7 @@ def decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor:
outlet_event_accessor = OutletEventAccessor(
extra=var["extra"],
raw_key=BaseSerialization.deserialize(var["raw_key"]),
asset_alias_events=asset_alias_events,
asset_alias_events=[AssetAliasEvent(**e) for e in asset_alias_events],
)
return outlet_event_accessor

Expand Down
24 changes: 15 additions & 9 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,7 @@

from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel, fetch_active_assets_by_name
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetAliasEvent,
AssetRef,
)
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetRef
from airflow.sdk.definitions.asset.metadata import extract_event_key
from airflow.utils.db import LazySelectSequence
from airflow.utils.types import NOTSET
Expand Down Expand Up @@ -145,6 +140,19 @@ def get(self, key: str, default_conn: Any = None) -> Any:
return default_conn


@attrs.define()
class AssetAliasEvent:
"""
Represeation of asset event to be triggered by an asset alias.
:meta private:
"""

source_alias_name: str
dest_asset_uri: str
extra: dict[str, Any]


@attrs.define()
class OutletEventAccessor:
"""
Expand Down Expand Up @@ -173,9 +181,7 @@ def add(self, asset: Asset | str, extra: dict[str, Any] | None = None) -> None:
else:
return

event = AssetAliasEvent(
source_alias_name=asset_alias_name, dest_asset_uri=asset_uri, extra=extra or {}
)
event = AssetAliasEvent(asset_alias_name, asset_uri, extra=extra or {})
self.asset_alias_events.append(event)


Expand Down
8 changes: 7 additions & 1 deletion airflow/utils/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.param import ParamsDict
from airflow.models.taskinstance import TaskInstance
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent
from airflow.sdk.definitions.asset import Asset, AssetAlias
from airflow.serialization.pydantic.asset import AssetEventPydantic
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.typing_compat import TypedDict
Expand All @@ -57,6 +57,12 @@ class VariableAccessor:
class ConnectionAccessor:
def get(self, key: str, default_conn: Any = None) -> Any: ...

class AssetAliasEvent:
source_alias_name: str
dest_asset_uri: str
extra: dict[str, Any]
def __init__(self, source_alias_name: str, dest_asset_uri: str, extra: dict[str, Any]) -> None: ...

class OutletEventAccessor:
def __init__(
self,
Expand Down
18 changes: 2 additions & 16 deletions providers/src/airflow/providers/common/compat/assets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,12 @@
if TYPE_CHECKING:
from airflow.auth.managers.models.resource_details import AssetDetails
from airflow.models.asset import expand_alias_to_assets
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetAliasEvent,
AssetAll,
AssetAny,
)
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny
else:
if AIRFLOW_V_3_0_PLUS:
from airflow.auth.managers.models.resource_details import AssetDetails
from airflow.models.asset import expand_alias_to_assets
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetAliasEvent,
AssetAll,
AssetAny,
)
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny
else:
# dataset is renamed to asset since Airflow 3.0
from airflow.datasets import Dataset as Asset
Expand All @@ -63,15 +51,13 @@
if AIRFLOW_V_2_10_PLUS:
from airflow.datasets import (
DatasetAlias as AssetAlias,
DatasetAliasEvent as AssetAliasEvent,
expand_alias_to_datasets as expand_alias_to_assets,
)


__all__ = [
"Asset",
"AssetAlias",
"AssetAliasEvent",
"AssetAll",
"AssetAny",
"AssetDetails",
Expand Down
18 changes: 7 additions & 11 deletions task_sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import attrs

from airflow.serialization.dag_dependency import DagDependency
from airflow.typing_compat import TypedDict

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
Expand Down Expand Up @@ -454,14 +453,6 @@ def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterat
)


class AssetAliasEvent(TypedDict):
"""A represeation of asset event to be triggered by an asset alias."""

source_alias_name: str
dest_asset_uri: str
extra: dict[str, Any]


class _AssetBooleanCondition(BaseAsset):
"""Base class for asset boolean logic."""

Expand All @@ -476,7 +467,7 @@ def evaluate(self, statuses: dict[str, bool]) -> bool:
return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects)

def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
seen = set() # We want to keep the first instance.
seen: set[AssetUniqueKey] = set() # We want to keep the first instance.
for o in self.objects:
for k, v in o.iter_assets():
if k in seen:
Expand All @@ -486,8 +477,13 @@ def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:

def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
"""Filter asset aliases in the condition."""
seen: set[str] = set() # We want to keep the first instance.
for o in self.objects:
yield from o.iter_asset_aliases()
for k, v in o.iter_asset_aliases():
if k in seen:
continue
yield k, v
seen.add(k)

def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from airflow.models.xcom_arg import XComArg
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent
from airflow.sdk.definitions.asset import Asset, AssetAlias
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.pydantic.asset import AssetEventPydantic, AssetPydantic
from airflow.serialization.pydantic.dag import DagModelPydantic, DagTagPydantic
Expand All @@ -60,7 +60,7 @@
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.triggers.base import BaseTrigger
from airflow.utils import timezone
from airflow.utils.context import OutletEventAccessor, OutletEventAccessors
from airflow.utils.context import AssetAliasEvent, OutletEventAccessor, OutletEventAccessors
from airflow.utils.db import LazySelectSequence
from airflow.utils.operator_resources import Resources
from airflow.utils.state import DagRunState, State
Expand Down
4 changes: 2 additions & 2 deletions tests/utils/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import pytest

from airflow.models.asset import AssetAliasModel, AssetModel
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent
from airflow.utils.context import OutletEventAccessor, OutletEventAccessors
from airflow.sdk.definitions.asset import Asset, AssetAlias
from airflow.utils.context import AssetAliasEvent, OutletEventAccessor, OutletEventAccessors


class TestOutletEventAccessor:
Expand Down

0 comments on commit b8268fb

Please sign in to comment.