Skip to content

Commit

Permalink
Implement XComArg.zip(*xcom_args) (#25176)
Browse files: browse the repository at this point in the history
  • Loading branch information
uranusjr authored Aug 1, 2022
1 parent 3cf7aec commit b90fc14
Show file tree
Hide file tree
Showing 8 changed files with 372 additions and 135 deletions.
88 changes: 11 additions & 77 deletions airflow/models/expandinput.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,12 @@
import operator
from typing import TYPE_CHECKING, Any, Iterable, Mapping, NamedTuple, Sequence, Sized, Union

from sqlalchemy import func
from sqlalchemy.orm import Session

from airflow.compat.functools import cache
from airflow.utils.context import Context

if TYPE_CHECKING:
from sqlalchemy.orm import Session

from airflow.models.xcom_arg import XComArg

ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"]
Expand Down Expand Up @@ -95,63 +94,16 @@ def _get_map_lengths(self, run_id: str, *, session: Session) -> dict[str, int]:
If any arguments are not known right now (upstream task not finished),
they will not be present in the dict.
"""
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XCOM_RETURN_KEY, XCom
from airflow.models.xcom_arg import XComArg

# Populate literal mapped arguments first.
map_lengths: dict[str, int] = collections.defaultdict(int)
map_lengths.update((k, len(v)) for k, v in self.value.items() if not isinstance(v, XComArg))

try:
dag_id = next(v.operator.dag_id for v in self.value.values() if isinstance(v, XComArg))
except StopIteration: # All mapped arguments are literal. We're done.
return map_lengths

# Build a reverse mapping of what arguments each task contributes to.
mapped_dep_keys: dict[str, set[str]] = collections.defaultdict(set)
non_mapped_dep_keys: dict[str, set[str]] = collections.defaultdict(set)
for k, v in self.value.items():
if not isinstance(v, XComArg):
continue
assert v.operator.dag_id == dag_id
if v.operator.is_mapped:
mapped_dep_keys[v.operator.task_id].add(k)
else:
non_mapped_dep_keys[v.operator.task_id].add(k)
# TODO: It's not possible now, but in the future we may support
# depending on one single mapped task instance. When that happens,
# we need to further analyze the mapped case to contain only tasks
# we depend on "as a whole", and put those we only depend on
# individually to the non-mapped lookup.

# Collect lengths from unmapped upstreams.
taskmap_query = session.query(TaskMap.task_id, TaskMap.length).filter(
TaskMap.dag_id == dag_id,
TaskMap.run_id == run_id,
TaskMap.task_id.in_(non_mapped_dep_keys),
TaskMap.map_index < 0,
)
for task_id, length in taskmap_query:
for mapped_arg_name in non_mapped_dep_keys[task_id]:
map_lengths[mapped_arg_name] += length

# Collect lengths from mapped upstreams.
xcom_query = (
session.query(XCom.task_id, func.count(XCom.map_index))
.group_by(XCom.task_id)
.filter(
XCom.dag_id == dag_id,
XCom.run_id == run_id,
XCom.key == XCOM_RETURN_KEY,
XCom.task_id.in_(mapped_dep_keys),
XCom.map_index >= 0,
)
# TODO: This initiates one database call for each XComArg. Would it be
# more efficient to do one single db call and unpack the value here?
map_lengths_iterator = (
(k, (v.get_task_map_length(run_id, session=session) if isinstance(v, XComArg) else len(v)))
for k, v in self.value.items()
)
for task_id, length in xcom_query:
for mapped_arg_name in mapped_dep_keys[task_id]:
map_lengths[mapped_arg_name] += length

map_lengths = {k: v for k, v in map_lengths_iterator if v is not None}
if len(map_lengths) < len(self.value):
raise NotFullyPopulated(set(self.value).difference(map_lengths))
return map_lengths
Expand Down Expand Up @@ -228,28 +180,10 @@ def get_parse_time_mapped_ti_count(self) -> int | None:
return None

def get_total_map_length(self, run_id: str, *, session: Session) -> int:
# Return the map length for ``self.value`` in the run identified by ``run_id``.
# Raises NotFullyPopulated when no length is available.
#
# NOTE(review): this span is a rendered diff whose +/- markers and indentation
# were lost, so two implementations are interleaved here. The query-based
# lines (everything that builds or reads ``query`` / ``value``) look like the
# removed version, and the ``length`` lines look like the replacement --
# confirm against the actual file before relying on either.
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XCom

task = self.value.operator
if task.is_mapped:
# Mapped upstream: count the XCom rows of this task in this run that carry a
# non-negative map_index.
query = session.query(func.count(XCom.map_index)).filter(
XCom.dag_id == task.dag_id,
XCom.run_id == run_id,
XCom.task_id == task.task_id,
XCom.map_index >= 0,
)
else:
# Non-mapped upstream: read the length stored in TaskMap for the row with a
# negative map_index.
query = session.query(TaskMap.length).filter(
TaskMap.dag_id == task.dag_id,
TaskMap.run_id == run_id,
TaskMap.task_id == task.task_id,
TaskMap.map_index < 0,
)
value = query.scalar()
if value is None:
# Delegating form: ask the XComArg itself for its map length.
length = self.value.get_task_map_length(run_id, session=session)
if length is None:
# No length available: report the expand_kwargs() argument as unpopulated.
raise NotFullyPopulated({"expand_kwargs() argument"})
return value
return length

def resolve(self, context: Context, session: Session) -> Mapping[str, Any]:
map_index = context["ti"].map_index
Expand Down
8 changes: 5 additions & 3 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,9 @@ def ensure_xcomarg_return_value(arg: Any) -> None:
from airflow.models.xcom_arg import XCOM_RETURN_KEY, XComArg

if isinstance(arg, XComArg):
if arg.key != XCOM_RETURN_KEY:
raise ValueError(f"cannot map over XCom with custom key {arg.key!r} from {arg.operator}")
for operator, key in arg.iter_references():
if key != XCOM_RETURN_KEY:
raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}")
elif not is_container(arg):
return
elif isinstance(arg, collections.abc.Mapping):
Expand Down Expand Up @@ -704,7 +705,8 @@ def iter_mapped_dependencies(self) -> Iterator["Operator"]:
from airflow.models.xcom_arg import XComArg

for ref in XComArg.iter_xcom_args(self._get_specified_expand_input()):
yield ref.operator
for operator, _ in ref.iter_references():
yield operator

@cached_property
def parse_time_mapped_ti_count(self) -> Optional[int]:
Expand Down
Loading

0 comments on commit b90fc14

Please sign in to comment.