Skip to content

Commit

Permalink
Add map_index to XCom interface
Browse files Browse the repository at this point in the history
This adds an additional (optional) map_index argument to XCom's
get/set/clear interface so mapped task instances can push to the
correct entries, and have them pulled correctly by a downstream task.

To make the XCom interface easier to use for common scenarios, a
convenience method, get_value, is added. It takes a TaskInstanceKey,
automatically unpacks it into arguments, and calls get_one underneath.
This is not done as a get_one overload, to simplify the implementation
and typing.
  • Loading branch information
uranusjr committed Mar 11, 2022
1 parent 765e20a commit cb60690
Show file tree
Hide file tree
Showing 14 changed files with 153 additions and 70 deletions.
142 changes: 99 additions & 43 deletions airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def set(
dag_id: str,
task_id: str,
run_id: str,
map_index: int = -1,
session: Session = NEW_SESSION,
) -> None:
"""Store an XCom value.
Expand All @@ -124,6 +125,8 @@ def set(
:param dag_id: DAG ID.
:param task_id: Task ID.
:param run_id: DAG run ID for the task.
:param map_index: Optional map index to assign XCom for a mapped task.
The default is ``-1`` (set for a non-mapped task).
:param session: Database session. If not given, a new session will be
created for this function.
"""
Expand Down Expand Up @@ -153,6 +156,7 @@ def set(
session: Session = NEW_SESSION,
*,
run_id: Optional[str] = None,
map_index: int = -1,
) -> None:
""":sphinx-autoapi-skip:"""
from airflow.models.dagrun import DagRun
Expand Down Expand Up @@ -184,6 +188,7 @@ def set(
task_id=task_id,
dag_id=dag_id,
run_id=run_id,
map_index=map_index,
)

# Remove duplicate XComs and insert a new one.
Expand All @@ -204,13 +209,49 @@ def set(
session.add(new)
session.flush()

@classmethod
@provide_session
def get_value(
cls,
*,
ti_key: "TaskInstanceKey",
key: Optional[str] = None,
session: Session = NEW_SESSION,
) -> Any:
"""Retrieve an XCom value for a task instance.
This method returns "full" XCom values (i.e. uses ``deserialize_value``
from the XCom backend). Use :meth:`get_many` if you want the "shortened"
value via ``orm_deserialize_value``.
If there are no results, *None* is returned. If multiple XCom entries
match the criteria, an arbitrary one is returned.
:param ti_key: The TaskInstanceKey to look up the XCom for.
:param key: A key for the XCom. If provided, only XCom with matching
keys will be returned. Pass *None* (default) to remove the filter.
:param session: Database session. If not given, a new session will be
created for this function.
"""
return cls.get_one(
key=key,
task_id=ti_key.task_id,
dag_id=ti_key.dag_id,
run_id=ti_key.run_id,
map_index=ti_key.map_index,
session=session,
)

@overload
@classmethod
def get_one(
cls,
*,
key: Optional[str] = None,
ti_key: "TaskInstanceKey",
dag_id: Optional[str] = None,
task_id: Optional[str] = None,
run_id: Optional[str] = None,
map_index: Optional[int] = None,
session: Session = NEW_SESSION,
) -> Optional[Any]:
"""Retrieve an XCom value, optionally meeting certain criteria.
Expand All @@ -219,12 +260,22 @@ def get_one(
from the XCom backend). Use :meth:`get_many` if you want the "shortened"
value via ``orm_deserialize_value``.
If there are no results, *None* is returned.
If there are no results, *None* is returned. If multiple XCom entries
match the criteria, an arbitrary one is returned.
A deprecated form of this function accepts ``execution_date`` instead of
``run_id``. The two arguments are mutually exclusive.
:param ti_key: The TaskInstanceKey to look up the XCom for
.. seealso:: ``get_value()`` is a convenience function if you already
have a structured TaskInstance or TaskInstanceKey object available.
:param run_id: DAG run ID for the task.
:param dag_id: Only pull XCom from this DAG. Pass *None* (default) to
remove the filter.
:param task_id: Only XCom from task with matching ID will be pulled.
Pass *None* (default) to remove the filter.
:param map_index: Only XCom from task with matching map index will be pulled.
Pass *None* (default) to remove the filter.
:param key: A key for the XCom. If provided, only XCom with matching
keys will be returned. Pass *None* (default) to remove the filter.
:param include_prior_dates: If *False* (default), only XCom from the
Expand All @@ -234,19 +285,6 @@ def get_one(
created for this function.
"""

@overload
@classmethod
def get_one(
cls,
*,
key: Optional[str] = None,
task_id: str,
dag_id: str,
run_id: str,
session: Session = NEW_SESSION,
) -> Optional[Any]:
...

@overload
@classmethod
def get_one(
Expand All @@ -272,27 +310,19 @@ def get_one(
session: Session = NEW_SESSION,
*,
run_id: Optional[str] = None,
ti_key: Optional["TaskInstanceKey"] = None,
map_index: Optional[int] = None,
) -> Optional[Any]:
""":sphinx-autoapi-skip:"""
if not exactly_one(execution_date is not None, ti_key is not None, run_id is not None):
if not exactly_one(execution_date is not None, run_id is not None):
raise ValueError("Exactly one of ti_key, run_id, or execution_date must be passed")

if ti_key is not None:
query = session.query(cls).filter_by(
dag_id=ti_key.dag_id,
run_id=ti_key.run_id,
task_id=ti_key.task_id,
)
if key:
query = query.filter_by(key=key)
query = query.limit(1)
elif run_id:
if run_id:
query = cls.get_many(
run_id=run_id,
key=key,
task_ids=task_id,
dag_ids=dag_id,
map_indexes=map_index,
include_prior_dates=include_prior_dates,
limit=1,
session=session,
Expand All @@ -308,6 +338,7 @@ def get_one(
key=key,
task_ids=task_id,
dag_ids=dag_id,
map_indexes=map_index,
include_prior_dates=include_prior_dates,
limit=1,
session=session,
Expand All @@ -329,6 +360,7 @@ def get_many(
key: Optional[str] = None,
task_ids: Union[str, Iterable[str], None] = None,
dag_ids: Union[str, Iterable[str], None] = None,
map_indexes: Union[int, Iterable[int], None] = None,
include_prior_dates: bool = False,
limit: Optional[int] = None,
session: Session = NEW_SESSION,
Expand All @@ -348,6 +380,8 @@ def get_many(
Pass *None* (default) to remove the filter.
:param dag_id: Only pulls XComs from this DAG. If *None* (default), the
DAG of the calling task is used.
:param map_indexes: Only XComs from matching map indexes will be pulled.
Pass *None* (default) to remove the filter.
:param include_prior_dates: If *False* (default), only XComs from the
specified DAG run are returned. If *True*, all matching XComs are
returned regardless of the run it belongs to.
Expand All @@ -363,6 +397,7 @@ def get_many(
key: Optional[str] = None,
task_ids: Union[str, Iterable[str], None] = None,
dag_ids: Union[str, Iterable[str], None] = None,
map_indexes: Union[int, Iterable[int], None] = None,
include_prior_dates: bool = False,
limit: Optional[int] = None,
session: Session = NEW_SESSION,
Expand All @@ -377,6 +412,7 @@ def get_many(
key: Optional[str] = None,
task_ids: Optional[Union[str, Iterable[str]]] = None,
dag_ids: Optional[Union[str, Iterable[str]]] = None,
map_indexes: Union[int, Iterable[int], None] = None,
include_prior_dates: bool = False,
limit: Optional[int] = None,
session: Session = NEW_SESSION,
Expand Down Expand Up @@ -407,6 +443,11 @@ def get_many(
elif dag_ids is not None:
query = query.filter(cls.dag_id == dag_ids)

if is_container(map_indexes):
query = query.filter(cls.map_index.in_(map_indexes))
elif map_indexes is not None:
query = query.filter(cls.map_index == map_indexes)

if include_prior_dates:
if execution_date is not None:
query = query.filter(DagRun.execution_date <= execution_date)
Expand Down Expand Up @@ -439,7 +480,15 @@ def delete(cls, xcoms: Union["XCom", Iterable["XCom"]], session: Session) -> Non

@overload
@classmethod
def clear(cls, *, dag_id: str, task_id: str, run_id: str, session: Optional[Session] = None) -> None:
def clear(
cls,
*,
dag_id: str,
task_id: str,
run_id: str,
map_index: Optional[int] = None,
session: Session = NEW_SESSION,
) -> None:
"""Clear all XCom data from the database for the given task instance.
A deprecated form of this function accepts ``execution_date`` instead of
Expand All @@ -448,6 +497,8 @@ def clear(cls, *, dag_id: str, task_id: str, run_id: str, session: Optional[Sess
:param dag_id: ID of DAG to clear the XCom for.
:param task_id: ID of task to clear the XCom for.
:param run_id: ID of DAG run to clear the XCom for.
:param map_index: If given, only clear XCom from this particular mapped
task. The default ``None`` clears *all* XComs from the task.
:param session: Database session. If not given, a new session will be
created for this function.
"""
Expand All @@ -473,6 +524,7 @@ def clear(
session: Session = NEW_SESSION,
*,
run_id: Optional[str] = None,
map_index: Optional[int] = None,
) -> None:
""":sphinx-autoapi-skip:"""
from airflow.models import DagRun
Expand All @@ -496,17 +548,20 @@ def clear(
.scalar()
)

return session.query(cls).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id).delete()
query = session.query(cls).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id)
if map_index is not None:
query = query.filter_by(map_index=map_index)
query.delete()

@staticmethod
def serialize_value(
value: Any,
*,
key=None,
task_id=None,
dag_id=None,
run_id=None,
mapping_index: int = -1,
key: Optional[str] = None,
task_id: Optional[str] = None,
dag_id: Optional[str] = None,
run_id: Optional[str] = None,
map_index: Optional[int] = None,
):
"""Serialize XCom value to str or pickled object"""
if conf.getboolean('core', 'enable_xcom_pickling'):
Expand Down Expand Up @@ -550,13 +605,14 @@ def orm_deserialize_value(self) -> Any:
return BaseXCom.deserialize_value(self)


def _patch_outdated_serializer(clazz, params):
"""
Previously XCom.serialize_value only accepted one argument ``value``. In order to give
custom XCom backends more flexibility with how they store values we now forward to
``XCom.serialize_value`` all params passed to ``XCom.set``. In order to maintain
compatibility with XCom backends written with the old signature we check the signature
and if necessary we patch with a method that ignores kwargs the backend does not accept.
def _patch_outdated_serializer(clazz: Type[BaseXCom], params: Iterable[str]) -> None:
"""Patch a custom ``serialize_value`` to accept the modern signature.
To give custom XCom backends more flexibility with how they store values, we
now forward all params passed to ``XCom.set`` to ``XCom.serialize_value``.
In order to maintain compatibility with custom XCom backends written with
the old signature, we check the signature and, if necessary, patch with a
method that ignores kwargs the backend does not accept.
"""
old_serializer = clazz.serialize_value

Expand All @@ -571,7 +627,7 @@ def _shim(**kwargs):
)
return old_serializer(**kwargs)

clazz.serialize_value = _shim
clazz.serialize_value = _shim # type: ignore[assignment]


def _get_function_params(function) -> List[str]:
Expand Down
2 changes: 1 addition & 1 deletion airflow/operators/trigger_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_link(
) -> str:
# Fetch the correct execution date for the triggerED dag which is
# stored in xcom during execution of the triggerING task.
when = XCom.get_one(ti_key=ti_key, key=XCOM_EXECUTION_DATE_ISO)
when = XCom.get_value(ti_key=ti_key, key=XCOM_EXECUTION_DATE_ISO)
query = {"dag_id": cast(TriggerDagRunOperator, operator).trigger_dag_id, "base_date": when}
return build_airflow_url_with_query(query)

Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def get_link(
:return: url link
"""
if ti_key:
flow_id = XCom.get_one(key="return_value", ti_key=ti_key)
flow_id = XCom.get_value(key="return_value", ti_key=ti_key)
else:
assert dttm
flow_id = XCom.get_one(
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/dbt/cloud/operators/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class DbtCloudRunJobOperatorLink(BaseOperatorLink):

def get_link(self, operator, dttm=None, *, ti_key=None):
if ti_key:
job_run_url = XCom.get_one(key="job_run_url", ti_key=ti_key)
job_run_url = XCom.get_value(key="job_run_url", ti_key=ti_key)
else:
assert dttm
job_run_url = XCom.get_one(
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/links/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_link(
ti_key: Optional["TaskInstanceKey"] = None,
) -> str:
if ti_key:
conf = XCom.get_one(key=self.key, ti_key=ti_key)
conf = XCom.get_value(key=self.key, ti_key=ti_key)
else:
assert dttm
conf = XCom.get_one(
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/links/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_link(
ti_key: Optional["TaskInstanceKey"] = None,
) -> str:
if ti_key:
conf = XCom.get_one(key=self.key, ti_key=ti_key)
conf = XCom.get_value(key=self.key, ti_key=ti_key)
else:
assert dttm
conf = XCom.get_one(
Expand Down Expand Up @@ -113,7 +113,7 @@ def get_link(
ti_key: Optional["TaskInstanceKey"] = None,
) -> str:
if ti_key:
list_conf = XCom.get_one(key=self.key, ti_key=ti_key)
list_conf = XCom.get_value(key=self.key, ti_key=ti_key)
else:
assert dttm
list_conf = XCom.get_one(
Expand Down
Loading

0 comments on commit cb60690

Please sign in to comment.