Skip to content

Commit

Permalink
Add map_index to XCom model and interface (#22112)
Browse files Browse the repository at this point in the history
* Add map_index to XCom primary key

This is not actually stored correctly yet. We still need to fix the XCom
interface.

* Add map_index to XCom interface

This adds an additional (optional) map_index argument to XCom's
get/set/clear interface so mapped task instances can push to the
correct entries, and have them pulled correctly by a downstream.

To make the XCom interface easier to use for common scenarios, a
convenience method get_value is added to take a TaskInstanceKey that
automatically performs argument unpacking and call get_one underneath.
This is not done as a get_one overload to simplify the implementation
and typing.
  • Loading branch information
uranusjr authored Mar 11, 2022
1 parent 9eb1a1c commit d08284e
Show file tree
Hide file tree
Showing 16 changed files with 167 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.

"""Switch XCom table to use ``run_id``.
"""Switch XCom table to use ``run_id`` and add ``map_index``.
Revision ID: c306b5b5ae4a
Revises: a3bcd0914482
Expand All @@ -25,7 +25,7 @@
from typing import Sequence

from alembic import op
from sqlalchemy import Column, Integer, LargeBinary, MetaData, Table, and_, select
from sqlalchemy import Column, Integer, LargeBinary, MetaData, Table, and_, literal_column, select

from airflow.migrations.db_types import TIMESTAMP, StringID
from airflow.migrations.utils import get_mssql_table_constraints
Expand All @@ -50,6 +50,7 @@ def _get_new_xcom_columns() -> Sequence[Column]:
Column("timestamp", TIMESTAMP, nullable=False),
Column("dag_id", StringID(), nullable=False),
Column("run_id", StringID(), nullable=False),
Column("map_index", Integer, nullable=False, server_default="-1"),
]


Expand Down Expand Up @@ -98,6 +99,7 @@ def upgrade():
xcom.c.timestamp,
xcom.c.dag_id,
dagrun.c.run_id,
literal_column("-1"),
],
).select_from(
xcom.join(
Expand All @@ -118,9 +120,9 @@ def upgrade():
op.rename_table("__airflow_tmp_xcom", "xcom")

with op.batch_alter_table("xcom") as batch_op:
batch_op.create_primary_key("xcom_pkey", ["dag_run_id", "task_id", "key"])
batch_op.create_primary_key("xcom_pkey", ["dag_run_id", "task_id", "map_index", "key"])
batch_op.create_index("idx_xcom_key", ["key"])
batch_op.create_index("idx_xcom_ti_id", ["dag_id", "task_id", "run_id"])
batch_op.create_index("idx_xcom_ti_id", ["dag_id", "run_id", "task_id", "map_index"])


def downgrade():
Expand All @@ -132,6 +134,10 @@ def downgrade():
op.create_table("__airflow_tmp_xcom", *_get_old_xcom_columns())

xcom = Table("xcom", metadata, *_get_new_xcom_columns())

# Remoe XCom entries from mapped tis.
op.execute(xcom.delete().where(xcom.c.map_index != -1))

dagrun = _get_dagrun_table()
query = select(
[
Expand Down
145 changes: 101 additions & 44 deletions airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class BaseXCom(Base, LoggingMixin):

dag_run_id = Column(Integer(), nullable=False, primary_key=True)
task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False, primary_key=True)
map_index = Column(Integer, primary_key=True, nullable=False, server_default="-1")
key = Column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True)

# Denormalized for easier lookup.
Expand All @@ -87,7 +88,7 @@ class BaseXCom(Base, LoggingMixin):
# but it goes over MySQL's index length limit. So we instead create indexes
# separately, and enforce uniqueness with DagRun.id instead.
Index("idx_xcom_key", key),
Index("idx_xcom_ti_id", dag_id, task_id, run_id),
Index("idx_xcom_ti_id", dag_id, task_id, run_id, map_index),
)

@reconstructor
Expand All @@ -111,6 +112,7 @@ def set(
dag_id: str,
task_id: str,
run_id: str,
map_index: int = -1,
session: Session = NEW_SESSION,
) -> None:
"""Store an XCom value.
Expand All @@ -123,6 +125,8 @@ def set(
:param dag_id: DAG ID.
:param task_id: Task ID.
:param run_id: DAG run ID for the task.
:param map_index: Optional map index to assign XCom for a mapped task.
The default is ``-1`` (set for a non-mapped task).
:param session: Database session. If not given, a new session will be
created for this function.
"""
Expand Down Expand Up @@ -152,6 +156,7 @@ def set(
session: Session = NEW_SESSION,
*,
run_id: Optional[str] = None,
map_index: int = -1,
) -> None:
""":sphinx-autoapi-skip:"""
from airflow.models.dagrun import DagRun
Expand Down Expand Up @@ -183,6 +188,7 @@ def set(
task_id=task_id,
dag_id=dag_id,
run_id=run_id,
map_index=map_index,
)

# Remove duplicate XComs and insert a new one.
Expand All @@ -203,13 +209,49 @@ def set(
session.add(new)
session.flush()

@classmethod
@provide_session
def get_value(
cls,
*,
ti_key: "TaskInstanceKey",
key: Optional[str] = None,
session: Session = NEW_SESSION,
) -> Any:
"""Retrieve an XCom value for a task instance.
This method returns "full" XCom values (i.e. uses ``deserialize_value``
from the XCom backend). Use :meth:`get_many` if you want the "shortened"
value via ``orm_deserialize_value``.
If there are no results, *None* is returned. If multiple XCom entries
match the criteria, an arbitrary one is returned.
:param ti_key: The TaskInstanceKey to look up the XCom for.
:param key: A key for the XCom. If provided, only XCom with matching
keys will be returned. Pass *None* (default) to remove the filter.
:param session: Database session. If not given, a new session will be
created for this function.
"""
return cls.get_one(
key=key,
task_id=ti_key.task_id,
dag_id=ti_key.dag_id,
run_id=ti_key.run_id,
map_index=ti_key.map_index,
session=session,
)

@overload
@classmethod
def get_one(
cls,
*,
key: Optional[str] = None,
ti_key: "TaskInstanceKey",
dag_id: Optional[str] = None,
task_id: Optional[str] = None,
run_id: Optional[str] = None,
map_index: Optional[int] = None,
session: Session = NEW_SESSION,
) -> Optional[Any]:
"""Retrieve an XCom value, optionally meeting certain criteria.
Expand All @@ -218,12 +260,22 @@ def get_one(
from the XCom backend). Use :meth:`get_many` if you want the "shortened"
value via ``orm_deserialize_value``.
If there are no results, *None* is returned.
If there are no results, *None* is returned. If multiple XCom entries
match the criteria, an arbitrary one is returned.
A deprecated form of this function accepts ``execution_date`` instead of
``run_id``. The two arguments are mutually exclusive.
:param ti_key: The TaskInstanceKey to look up the XCom for
.. seealso:: ``get_value()`` is a convenience function if you already
have a structured TaskInstance or TaskInstanceKey object available.
:param run_id: DAG run ID for the task.
:param dag_id: Only pull XCom from this DAG. Pass *None* (default) to
remove the filter.
:param task_id: Only XCom from task with matching ID will be pulled.
Pass *None* (default) to remove the filter.
:param map_index: Only XCom from task with matching ID will be pulled.
Pass *None* (default) to remove the filter.
:param key: A key for the XCom. If provided, only XCom with matching
keys will be returned. Pass *None* (default) to remove the filter.
:param include_prior_dates: If *False* (default), only XCom from the
Expand All @@ -233,19 +285,6 @@ def get_one(
created for this function.
"""

@overload
@classmethod
def get_one(
cls,
*,
key: Optional[str] = None,
task_id: str,
dag_id: str,
run_id: str,
session: Session = NEW_SESSION,
) -> Optional[Any]:
...

@overload
@classmethod
def get_one(
Expand All @@ -271,27 +310,19 @@ def get_one(
session: Session = NEW_SESSION,
*,
run_id: Optional[str] = None,
ti_key: Optional["TaskInstanceKey"] = None,
map_index: Optional[int] = None,
) -> Optional[Any]:
""":sphinx-autoapi-skip:"""
if not exactly_one(execution_date is not None, ti_key is not None, run_id is not None):
if not exactly_one(execution_date is not None, run_id is not None):
raise ValueError("Exactly one of ti_key, run_id, or execution_date must be passed")

if ti_key is not None:
query = session.query(cls).filter_by(
dag_id=ti_key.dag_id,
run_id=ti_key.run_id,
task_id=ti_key.task_id,
)
if key:
query = query.filter_by(key=key)
query = query.limit(1)
elif run_id:
if run_id:
query = cls.get_many(
run_id=run_id,
key=key,
task_ids=task_id,
dag_ids=dag_id,
map_indexes=map_index,
include_prior_dates=include_prior_dates,
limit=1,
session=session,
Expand All @@ -307,6 +338,7 @@ def get_one(
key=key,
task_ids=task_id,
dag_ids=dag_id,
map_indexes=map_index,
include_prior_dates=include_prior_dates,
limit=1,
session=session,
Expand All @@ -328,6 +360,7 @@ def get_many(
key: Optional[str] = None,
task_ids: Union[str, Iterable[str], None] = None,
dag_ids: Union[str, Iterable[str], None] = None,
map_indexes: Union[int, Iterable[int], None] = None,
include_prior_dates: bool = False,
limit: Optional[int] = None,
session: Session = NEW_SESSION,
Expand All @@ -347,6 +380,8 @@ def get_many(
Pass *None* (default) to remove the filter.
:param dag_id: Only pulls XComs from this DAG. If *None* (default), the
DAG of the calling task is used.
:param map_index: Only XComs from matching map indexes will be pulled.
Pass *None* (default) to remove the filter.
:param include_prior_dates: If *False* (default), only XComs from the
specified DAG run are returned. If *True*, all matching XComs are
returned regardless of the run it belongs to.
Expand All @@ -362,6 +397,7 @@ def get_many(
key: Optional[str] = None,
task_ids: Union[str, Iterable[str], None] = None,
dag_ids: Union[str, Iterable[str], None] = None,
map_indexes: Union[int, Iterable[int], None] = None,
include_prior_dates: bool = False,
limit: Optional[int] = None,
session: Session = NEW_SESSION,
Expand All @@ -376,6 +412,7 @@ def get_many(
key: Optional[str] = None,
task_ids: Optional[Union[str, Iterable[str]]] = None,
dag_ids: Optional[Union[str, Iterable[str]]] = None,
map_indexes: Union[int, Iterable[int], None] = None,
include_prior_dates: bool = False,
limit: Optional[int] = None,
session: Session = NEW_SESSION,
Expand Down Expand Up @@ -406,6 +443,11 @@ def get_many(
elif dag_ids is not None:
query = query.filter(cls.dag_id == dag_ids)

if is_container(map_indexes):
query = query.filter(cls.map_index.in_(map_indexes))
elif map_indexes is not None:
query = query.filter(cls.map_index == map_indexes)

if include_prior_dates:
if execution_date is not None:
query = query.filter(DagRun.execution_date <= execution_date)
Expand Down Expand Up @@ -438,7 +480,15 @@ def delete(cls, xcoms: Union["XCom", Iterable["XCom"]], session: Session) -> Non

@overload
@classmethod
def clear(cls, *, dag_id: str, task_id: str, run_id: str, session: Optional[Session] = None) -> None:
def clear(
cls,
*,
dag_id: str,
task_id: str,
run_id: str,
map_index: Optional[int] = None,
session: Session = NEW_SESSION,
) -> None:
"""Clear all XCom data from the database for the given task instance.
A deprecated form of this function accepts ``execution_date`` instead of
Expand All @@ -447,6 +497,8 @@ def clear(cls, *, dag_id: str, task_id: str, run_id: str, session: Optional[Sess
:param dag_id: ID of DAG to clear the XCom for.
:param task_id: ID of task to clear the XCom for.
:param run_id: ID of DAG run to clear the XCom for.
:param map_index: If given, only clear XCom from this particular mapped
task. The default ``None`` clears *all* XComs from the task.
:param session: Database session. If not given, a new session will be
created for this function.
"""
Expand All @@ -472,6 +524,7 @@ def clear(
session: Session = NEW_SESSION,
*,
run_id: Optional[str] = None,
map_index: Optional[int] = None,
) -> None:
""":sphinx-autoapi-skip:"""
from airflow.models import DagRun
Expand All @@ -495,17 +548,20 @@ def clear(
.scalar()
)

return session.query(cls).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id).delete()
query = session.query(cls).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id)
if map_index is not None:
query = query.filter_by(map_index=map_index)
query.delete()

@staticmethod
def serialize_value(
value: Any,
*,
key=None,
task_id=None,
dag_id=None,
run_id=None,
mapping_index: int = -1,
key: Optional[str] = None,
task_id: Optional[str] = None,
dag_id: Optional[str] = None,
run_id: Optional[str] = None,
map_index: Optional[int] = None,
):
"""Serialize XCom value to str or pickled object"""
if conf.getboolean('core', 'enable_xcom_pickling'):
Expand Down Expand Up @@ -549,13 +605,14 @@ def orm_deserialize_value(self) -> Any:
return BaseXCom.deserialize_value(self)


def _patch_outdated_serializer(clazz, params):
"""
Previously XCom.serialize_value only accepted one argument ``value``. In order to give
custom XCom backends more flexibility with how they store values we now forward to
``XCom.serialize_value`` all params passed to ``XCom.set``. In order to maintain
compatibility with XCom backends written with the old signature we check the signature
and if necessary we patch with a method that ignores kwargs the backend does not accept.
def _patch_outdated_serializer(clazz: Type[BaseXCom], params: Iterable[str]) -> None:
"""Patch a custom ``serialize_value`` to accept the modern signature.
To give custom XCom backends more flexibility with how they store values, we
now forward all params passed to ``XCom.set`` to ``XCom.serialize_value``.
In order to maintain compatibility with custom XCom backends written with
the old signature, we check the signature and, if necessary, patch with a
method that ignores kwargs the backend does not accept.
"""
old_serializer = clazz.serialize_value

Expand All @@ -570,7 +627,7 @@ def _shim(**kwargs):
)
return old_serializer(**kwargs)

clazz.serialize_value = _shim
clazz.serialize_value = _shim # type: ignore[assignment]


def _get_function_params(function) -> List[str]:
Expand Down
2 changes: 1 addition & 1 deletion airflow/operators/trigger_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_link(
) -> str:
# Fetch the correct execution date for the triggerED dag which is
# stored in xcom during execution of the triggerING task.
when = XCom.get_one(ti_key=ti_key, key=XCOM_EXECUTION_DATE_ISO)
when = XCom.get_value(ti_key=ti_key, key=XCOM_EXECUTION_DATE_ISO)
query = {"dag_id": cast(TriggerDagRunOperator, operator).trigger_dag_id, "base_date": when}
return build_airflow_url_with_query(query)

Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def get_link(
:return: url link
"""
if ti_key:
flow_id = XCom.get_one(key="return_value", ti_key=ti_key)
flow_id = XCom.get_value(key="return_value", ti_key=ti_key)
else:
assert dttm
flow_id = XCom.get_one(
Expand Down
Loading

0 comments on commit d08284e

Please sign in to comment.