Skip to content

Commit

Permalink
Implement expand_kwargs() (#24989)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored Jul 14, 2022
1 parent cebff31 commit 7d95bd9
Show file tree
Hide file tree
Showing 9 changed files with 496 additions and 30 deletions.
34 changes: 28 additions & 6 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@
parse_retries,
)
from airflow.models.dag import DAG, DagContext
from airflow.models.expandinput import EXPAND_INPUT_EMPTY, DictOfListsExpandInput, ExpandInput
from airflow.models.expandinput import (
EXPAND_INPUT_EMPTY,
DictOfListsExpandInput,
ExpandInput,
ListOfDictsExpandInput,
)
from airflow.models.mappedoperator import (
MappedOperator,
ValidationSource,
Expand Down Expand Up @@ -171,8 +176,17 @@ def __init__(
op_args = op_args or []
op_kwargs = op_kwargs or {}

# Check that arguments can be binded
inspect.signature(python_callable).bind(*op_args, **op_kwargs)
# Check that arguments can be binded. There's a slight difference when
# we do validation for task-mapping: Since there's no guarantee we can
# receive enough arguments at parse time, we use bind_partial to simply
# check all the arguments we know are valid. Whether these are enough
# can only be known at execution time, when unmapping happens, and this
# is called without the _airflow_mapped_validation_only flag.
if kwargs.get("_airflow_mapped_validation_only"):
inspect.signature(python_callable).bind_partial(*op_args, **op_kwargs)
else:
inspect.signature(python_callable).bind(*op_args, **op_kwargs)

self.multiple_outputs = multiple_outputs
self.op_args = op_args
self.op_kwargs = op_kwargs
Expand Down Expand Up @@ -323,6 +337,13 @@ def expand(self, **map_kwargs: "Mappable") -> XComArg:
# to False to skip the checks on execution.
return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)

def expand_kwargs(self, kwargs: "XComArg", *, strict: bool = True) -> XComArg:
from airflow.models.xcom_arg import XComArg

if not isinstance(kwargs, XComArg):
raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}")
return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
ensure_xcomarg_return_value(expand_input.value)

Expand Down Expand Up @@ -442,10 +463,11 @@ def _get_unmap_kwargs(self, mapped_kwargs: Dict[str, Any], *, strict: bool) -> D
mapped_kwargs["op_kwargs"],
fail_reason="mapping already partial",
)

static_kwargs = {k for k, _ in self.op_kwargs_expand_input.iter_parse_time_resolved_kwargs()}
self._combined_op_kwargs = {**self.partial_kwargs["op_kwargs"], **mapped_kwargs["op_kwargs"]}
self._already_resolved_op_kwargs = {
k for k, v in self.op_kwargs_expand_input.value.items() if isinstance(v, XComArg)
}
self._already_resolved_op_kwargs = {k for k in mapped_kwargs["op_kwargs"] if k not in static_kwargs}

kwargs = {
"multiple_outputs": self.multiple_outputs,
"python_callable": self.python_callable,
Expand Down
20 changes: 17 additions & 3 deletions airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,26 @@ def __str__(self) -> str:
class UnmappableXComTypePushed(AirflowException):
"""Raise when an unmappable type is pushed as a mapped downstream's dependency."""

def __init__(self, value: Any) -> None:
super().__init__(value)
def __init__(self, value: Any, *values: Any) -> None:
super().__init__(value, *values)

def __str__(self) -> str:
typename = type(self.args[0]).__qualname__
for arg in self.args[1:]:
typename = f"{typename}[{type(arg).__qualname__}]"
return f"unmappable return type {typename!r}"


class UnmappableXComValuePushed(AirflowException):
"""Raise when an invalid value is pushed as a mapped downstream's dependency."""

def __init__(self, value: Any, reason: str) -> None:
super().__init__(value, reason)
self.value = value
self.reason = reason

def __str__(self) -> str:
return f"unmappable return type {type(self.value).__qualname__!r}"
return f"unmappable return value {self.value!r} ({self.reason})"


class UnmappableXComLengthPushed(AirflowException):
Expand Down
3 changes: 2 additions & 1 deletion airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,7 @@ def __init__(

super().__init__()

kwargs.pop("_airflow_mapped_validation_only", None)
if kwargs:
if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'):
raise AirflowException(
Expand Down Expand Up @@ -1509,7 +1510,7 @@ def defer(
def validate_mapped_arguments(cls, **kwargs: Any) -> None:
"""Validate arguments when this operator is being mapped."""
if cls.mapped_arguments_validated_by_init:
cls(**kwargs, _airflow_from_mapped=True)
cls(**kwargs, _airflow_from_mapped=True, _airflow_mapped_validation_only=True)

def unmap(self, ctx: Union[None, Dict[str, Any], Tuple[Context, Session]]) -> "BaseOperator":
""":meta private:"""
Expand Down
95 changes: 90 additions & 5 deletions airflow/models/expandinput.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,36 @@
import collections.abc
import functools
import operator
from typing import TYPE_CHECKING, Any, NamedTuple, Sequence, Union
from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Sequence, Sized, Union

from sqlalchemy import func
from sqlalchemy.orm import Session

from airflow.exceptions import UnmappableXComTypePushed
from airflow.compat.functools import cache
from airflow.exceptions import UnmappableXComTypePushed, UnmappableXComValuePushed
from airflow.utils.context import Context

if TYPE_CHECKING:
from airflow.models.xcom_arg import XComArg

ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"]

# BaseOperator.expand() can be called on an XComArg, sequence, or dict (not any
# mapping since we need the value to be ordered).
Mappable = Union["XComArg", Sequence, dict]

MAPPABLE_LITERAL_TYPES = (dict, list)

# For isinstance() check.
@cache
def get_mappable_types() -> tuple[type, ...]:
from airflow.models.xcom_arg import XComArg

return (XComArg, list, tuple, dict)


class NotFullyPopulated(RuntimeError):
"""Raise when ``get_map_lengths`` cannot populate all mapping metadata.
This is generally due to not all upstream tasks have finished when the
function is called.
"""
Expand All @@ -67,10 +77,20 @@ def validate_xcom(value: Any) -> None:
if not isinstance(value, collections.abc.Collection) or isinstance(value, (bytes, str)):
raise UnmappableXComTypePushed(value)

def get_unresolved_kwargs(self) -> dict[str, Any]:
"""Get the kwargs dict that can be inferred without resolving."""
return self.value

def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
"""Generate kwargs with values available on parse-time."""
from airflow.models.xcom_arg import XComArg

return ((k, v) for k, v in self.value.items() if not isinstance(v, XComArg))

def get_parse_time_mapped_ti_count(self) -> int | None:
if not self.value:
return 0
literal_values = [len(v) for v in self.value.values() if isinstance(v, MAPPABLE_LITERAL_TYPES)]
literal_values = [len(v) for _, v in self.iter_parse_time_resolved_kwargs()]
if len(literal_values) != len(self.value):
return None # None-literal type encountered, so give up.
return functools.reduce(operator.mul, literal_values, 1)
Expand Down Expand Up @@ -184,12 +204,77 @@ def resolve(self, context: Context, session: Session) -> dict[str, Any]:
return {k: self._expand_mapped_field(k, v, context, session=session) for k, v in self.value.items()}


ExpandInput = DictOfListsExpandInput
class ListOfDictsExpandInput(NamedTuple):
"""Storage type of a mapped operator's mapped kwargs.
This is created from ``expand_kwargs(xcom_arg)``.
"""

value: XComArg

@staticmethod
def validate_xcom(value: Any) -> None:
if not isinstance(value, collections.abc.Collection):
raise UnmappableXComTypePushed(value)
if isinstance(value, (str, bytes, collections.abc.Mapping)):
raise UnmappableXComTypePushed(value)
for item in value:
if not isinstance(item, collections.abc.Mapping):
raise UnmappableXComTypePushed(value, item)
if not all(isinstance(k, str) for k in item):
raise UnmappableXComValuePushed(value, reason="dict keys must be str")

def get_unresolved_kwargs(self) -> dict[str, Any]:
"""Get the kwargs dict that can be inferred without resolving.
Since the list-of-dicts case relies entirely on run-time XCom, there's
no kwargs structure available, so this just returns an empty dict.
"""
return {}

def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
return ()

def get_parse_time_mapped_ti_count(self) -> int | None:
return None

def get_total_map_length(self, run_id: str, *, session: Session) -> int:
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XCom

task = self.value.operator
if task.is_mapped:
query = session.query(func.count(XCom.map_index)).filter(
XCom.dag_id == task.dag_id,
XCom.run_id == run_id,
XCom.task_id == task.task_id,
XCom.map_index >= 0,
)
else:
query = session.query(TaskMap.length).filter(
TaskMap.dag_id == task.dag_id,
TaskMap.run_id == run_id,
TaskMap.task_id == task.task_id,
TaskMap.map_index < 0,
)
value = query.scalar()
if value is None:
raise NotFullyPopulated({"expand_kwargs() argument"})
return value

def resolve(self, context: Context, session: Session) -> dict[str, Any]:
map_index = context["ti"].map_index
if map_index < 0:
raise RuntimeError("can't resolve task-mapping argument without expanding")
# Validation should be done when the upstream returns.
return self.value.resolve(context, session)[map_index]


EXPAND_INPUT_EMPTY = DictOfListsExpandInput({}) # Sentinel value.

_EXPAND_INPUT_TYPES = {
"dict-of-lists": DictOfListsExpandInput,
"list-of-dicts": ListOfDictsExpandInput,
}


Expand Down
27 changes: 13 additions & 14 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,12 @@
TaskStateChangeCallback,
)
from airflow.models.expandinput import (
MAPPABLE_LITERAL_TYPES,
DictOfListsExpandInput,
ExpandInput,
ListOfDictsExpandInput,
Mappable,
NotFullyPopulated,
get_mappable_types,
)
from airflow.models.pool import Pool
from airflow.serialization.enums import DagAttributeTypes
Expand All @@ -86,19 +87,12 @@
from airflow.models.dag import DAG
from airflow.models.operator import Operator
from airflow.models.taskinstance import TaskInstance
from airflow.models.xcom_arg import XComArg
from airflow.utils.task_group import TaskGroup

ValidationSource = Union[Literal["expand"], Literal["partial"]]


# For isinstance() check.
@cache
def get_mappable_types() -> Tuple[type, ...]:
from airflow.models.xcom_arg import XComArg

return (XComArg,) + MAPPABLE_LITERAL_TYPES


def validate_mapping_kwargs(op: Type["BaseOperator"], func: ValidationSource, value: Dict[str, Any]) -> None:
# use a dict so order of args is same as code order
unknown_args = value.copy()
Expand Down Expand Up @@ -198,6 +192,13 @@ def expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
# to False to skip the checks on execution.
return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)

def expand_kwargs(self, kwargs: "XComArg", *, strict: bool = True) -> "MappedOperator":
from airflow.models.xcom_arg import XComArg

if not isinstance(kwargs, XComArg):
raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}")
return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

def _expand(self, expand_input: ExpandInput, *, strict: bool) -> "MappedOperator":
from airflow.operators.empty import EmptyOperator

Expand Down Expand Up @@ -541,12 +542,10 @@ def _expand_mapped_kwargs(self, resolve: Optional[Tuple[Context, Session]]) -> D
operation on the list-of-dicts variant before execution time, an empty
dict will be returned for this case.
"""
kwargs = self._get_specified_expand_input()
expand_input = self._get_specified_expand_input()
if resolve is not None:
return kwargs.resolve(*resolve)
if isinstance(kwargs, DictOfListsExpandInput):
return kwargs.value
return {}
return expand_input.resolve(*resolve)
return expand_input.get_unresolved_kwargs()

def _get_unmap_kwargs(self, mapped_kwargs: Dict[str, Any], *, strict: bool) -> Dict[str, Any]:
"""Get init kwargs to unmap the underlying operator class.
Expand Down
21 changes: 21 additions & 0 deletions tests/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -1422,3 +1422,24 @@ def test_schedule_tis_map_index(dag_maker, session):
assert ti0.state == TaskInstanceState.SUCCESS
assert ti1.state == TaskInstanceState.SCHEDULED
assert ti2.state == TaskInstanceState.SUCCESS


def test_mapped_expand_kwargs(dag_maker):
with dag_maker() as dag:

@task
def task_1():
return [{"arg1": "a", "arg2": "b"}, {"arg1": "y"}, {"arg2": "z"}]

MockOperator.partial(task_id="task_2").expand_kwargs(task_1())

dr: DagRun = dag_maker.create_dagrun()
assert len([ti for ti in dr.get_task_instances() if ti.task_id == "task_2"]) == 1

ti1 = dr.get_task_instance("task_1")
ti1.refresh_from_task(dag.get_task("task_1"))
ti1.run()

dr.task_instance_scheduling_decisions()
ti_states = {ti.map_index: ti.state for ti in dr.get_task_instances() if ti.task_id == "task_2"}
assert ti_states == {0: None, 1: None, 2: None}
Loading

0 comments on commit 7d95bd9

Please sign in to comment.