Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add map_index to XCom model and interface #22112

Merged
merged 2 commits into from
Mar 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.

"""Switch XCom table to use ``run_id``.
"""Switch XCom table to use ``run_id`` and add ``map_index``.

Revision ID: c306b5b5ae4a
Revises: a3bcd0914482
Expand All @@ -25,7 +25,7 @@
from typing import Sequence

from alembic import op
from sqlalchemy import Column, Integer, LargeBinary, MetaData, Table, and_, select
from sqlalchemy import Column, Integer, LargeBinary, MetaData, Table, and_, literal_column, select

from airflow.migrations.db_types import TIMESTAMP, StringID
from airflow.migrations.utils import get_mssql_table_constraints
Expand All @@ -50,6 +50,7 @@ def _get_new_xcom_columns() -> Sequence[Column]:
Column("timestamp", TIMESTAMP, nullable=False),
Column("dag_id", StringID(), nullable=False),
Column("run_id", StringID(), nullable=False),
Column("map_index", Integer, nullable=False, server_default="-1"),
]


Expand Down Expand Up @@ -98,6 +99,7 @@ def upgrade():
xcom.c.timestamp,
xcom.c.dag_id,
dagrun.c.run_id,
literal_column("-1"),
],
).select_from(
xcom.join(
Expand All @@ -118,9 +120,9 @@ def upgrade():
op.rename_table("__airflow_tmp_xcom", "xcom")

with op.batch_alter_table("xcom") as batch_op:
batch_op.create_primary_key("xcom_pkey", ["dag_run_id", "task_id", "key"])
batch_op.create_primary_key("xcom_pkey", ["dag_run_id", "task_id", "map_index", "key"])
batch_op.create_index("idx_xcom_key", ["key"])
batch_op.create_index("idx_xcom_ti_id", ["dag_id", "task_id", "run_id"])
batch_op.create_index("idx_xcom_ti_id", ["dag_id", "run_id", "task_id", "map_index"])


def downgrade():
Expand All @@ -132,6 +134,10 @@ def downgrade():
op.create_table("__airflow_tmp_xcom", *_get_old_xcom_columns())

xcom = Table("xcom", metadata, *_get_new_xcom_columns())

# Remoe XCom entries from mapped tis.
op.execute(xcom.delete().where(xcom.c.map_index != -1))

dagrun = _get_dagrun_table()
query = select(
[
Expand Down
145 changes: 101 additions & 44 deletions airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class BaseXCom(Base, LoggingMixin):

dag_run_id = Column(Integer(), nullable=False, primary_key=True)
task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False, primary_key=True)
map_index = Column(Integer, primary_key=True, nullable=False, server_default="-1")
key = Column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True)

# Denormalized for easier lookup.
Expand All @@ -87,7 +88,7 @@ class BaseXCom(Base, LoggingMixin):
# but it goes over MySQL's index length limit. So we instead create indexes
# separately, and enforce uniqueness with DagRun.id instead.
Index("idx_xcom_key", key),
Index("idx_xcom_ti_id", dag_id, task_id, run_id),
Index("idx_xcom_ti_id", dag_id, task_id, run_id, map_index),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The column order here doesn't match the migrations -- this matters from one of our DBs (sqlite? mssql?)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we have an FK to TaskInstance?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There wasn’t a foreign key before. I think previously (when XCom was based on execution date) it was possible to push an XCom to a future date, so a foreign key did not make sense back then; now it’s based on run ID perhaps it makes sense to have a ti-xcom relation, but that should be a separate discussion regardless.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it was possible to push an XCom to a future date

Already done https://lists.apache.org/thread/gofj3g6m6vvksy6n0cmgq1qxd309bbbl (I don't think it ever actually worked.)

)

@reconstructor
Expand All @@ -111,6 +112,7 @@ def set(
dag_id: str,
task_id: str,
run_id: str,
map_index: int = -1,
session: Session = NEW_SESSION,
) -> None:
"""Store an XCom value.
Expand All @@ -123,6 +125,8 @@ def set(
:param dag_id: DAG ID.
:param task_id: Task ID.
:param run_id: DAG run ID for the task.
:param map_index: Optional map index to assign XCom for a mapped task.
The default is ``-1`` (set for a non-mapped task).
:param session: Database session. If not given, a new session will be
created for this function.
"""
Expand Down Expand Up @@ -152,6 +156,7 @@ def set(
session: Session = NEW_SESSION,
*,
run_id: Optional[str] = None,
map_index: int = -1,
) -> None:
""":sphinx-autoapi-skip:"""
from airflow.models.dagrun import DagRun
Expand Down Expand Up @@ -183,6 +188,7 @@ def set(
task_id=task_id,
dag_id=dag_id,
run_id=run_id,
map_index=map_index,
)

# Remove duplicate XComs and insert a new one.
Expand All @@ -203,13 +209,49 @@ def set(
session.add(new)
session.flush()

@classmethod
@provide_session
def get_value(
cls,
*,
ti_key: "TaskInstanceKey",
key: Optional[str] = None,
session: Session = NEW_SESSION,
) -> Any:
"""Retrieve an XCom value for a task instance.

This method returns "full" XCom values (i.e. uses ``deserialize_value``
from the XCom backend). Use :meth:`get_many` if you want the "shortened"
value via ``orm_deserialize_value``.

If there are no results, *None* is returned. If multiple XCom entries
match the criteria, an arbitrary one is returned.

:param ti_key: The TaskInstanceKey to look up the XCom for.
:param key: A key for the XCom. If provided, only XCom with matching
keys will be returned. Pass *None* (default) to remove the filter.
:param session: Database session. If not given, a new session will be
created for this function.
"""
return cls.get_one(
key=key,
task_id=ti_key.task_id,
dag_id=ti_key.dag_id,
run_id=ti_key.run_id,
map_index=ti_key.map_index,
session=session,
)

@overload
@classmethod
def get_one(
cls,
*,
key: Optional[str] = None,
ti_key: "TaskInstanceKey",
dag_id: Optional[str] = None,
task_id: Optional[str] = None,
run_id: Optional[str] = None,
map_index: Optional[int] = None,
session: Session = NEW_SESSION,
) -> Optional[Any]:
"""Retrieve an XCom value, optionally meeting certain criteria.
Expand All @@ -218,12 +260,22 @@ def get_one(
from the XCom backend). Use :meth:`get_many` if you want the "shortened"
value via ``orm_deserialize_value``.

If there are no results, *None* is returned.
If there are no results, *None* is returned. If multiple XCom entries
match the criteria, an arbitrary one is returned.

A deprecated form of this function accepts ``execution_date`` instead of
``run_id``. The two arguments are mutually exclusive.

:param ti_key: The TaskInstanceKey to look up the XCom for
.. seealso:: ``get_value()`` is a convenience function if you already
have a structured TaskInstance or TaskInstanceKey object available.

:param run_id: DAG run ID for the task.
:param dag_id: Only pull XCom from this DAG. Pass *None* (default) to
remove the filter.
:param task_id: Only XCom from task with matching ID will be pulled.
Pass *None* (default) to remove the filter.
:param map_index: Only XCom from task with matching ID will be pulled.
Pass *None* (default) to remove the filter.
:param key: A key for the XCom. If provided, only XCom with matching
keys will be returned. Pass *None* (default) to remove the filter.
:param include_prior_dates: If *False* (default), only XCom from the
Expand All @@ -233,19 +285,6 @@ def get_one(
created for this function.
"""

@overload
@classmethod
def get_one(
cls,
*,
key: Optional[str] = None,
task_id: str,
dag_id: str,
run_id: str,
session: Session = NEW_SESSION,
) -> Optional[Any]:
...

@overload
@classmethod
def get_one(
Expand All @@ -271,27 +310,19 @@ def get_one(
session: Session = NEW_SESSION,
*,
run_id: Optional[str] = None,
ti_key: Optional["TaskInstanceKey"] = None,
map_index: Optional[int] = None,
) -> Optional[Any]:
""":sphinx-autoapi-skip:"""
if not exactly_one(execution_date is not None, ti_key is not None, run_id is not None):
if not exactly_one(execution_date is not None, run_id is not None):
raise ValueError("Exactly one of ti_key, run_id, or execution_date must be passed")

if ti_key is not None:
query = session.query(cls).filter_by(
dag_id=ti_key.dag_id,
run_id=ti_key.run_id,
task_id=ti_key.task_id,
)
if key:
query = query.filter_by(key=key)
query = query.limit(1)
elif run_id:
if run_id:
query = cls.get_many(
run_id=run_id,
key=key,
task_ids=task_id,
dag_ids=dag_id,
map_indexes=map_index,
include_prior_dates=include_prior_dates,
limit=1,
session=session,
Expand All @@ -307,6 +338,7 @@ def get_one(
key=key,
task_ids=task_id,
dag_ids=dag_id,
map_indexes=map_index,
include_prior_dates=include_prior_dates,
limit=1,
session=session,
Expand All @@ -328,6 +360,7 @@ def get_many(
key: Optional[str] = None,
task_ids: Union[str, Iterable[str], None] = None,
dag_ids: Union[str, Iterable[str], None] = None,
map_indexes: Union[int, Iterable[int], None] = None,
include_prior_dates: bool = False,
limit: Optional[int] = None,
session: Session = NEW_SESSION,
Expand All @@ -347,6 +380,8 @@ def get_many(
Pass *None* (default) to remove the filter.
:param dag_id: Only pulls XComs from this DAG. If *None* (default), the
DAG of the calling task is used.
:param map_index: Only XComs from matching map indexes will be pulled.
Pass *None* (default) to remove the filter.
:param include_prior_dates: If *False* (default), only XComs from the
specified DAG run are returned. If *True*, all matching XComs are
returned regardless of the run it belongs to.
Expand All @@ -362,6 +397,7 @@ def get_many(
key: Optional[str] = None,
task_ids: Union[str, Iterable[str], None] = None,
dag_ids: Union[str, Iterable[str], None] = None,
map_indexes: Union[int, Iterable[int], None] = None,
include_prior_dates: bool = False,
limit: Optional[int] = None,
session: Session = NEW_SESSION,
Expand All @@ -376,6 +412,7 @@ def get_many(
key: Optional[str] = None,
task_ids: Optional[Union[str, Iterable[str]]] = None,
dag_ids: Optional[Union[str, Iterable[str]]] = None,
map_indexes: Union[int, Iterable[int], None] = None,
include_prior_dates: bool = False,
limit: Optional[int] = None,
session: Session = NEW_SESSION,
Expand Down Expand Up @@ -406,6 +443,11 @@ def get_many(
elif dag_ids is not None:
query = query.filter(cls.dag_id == dag_ids)

if is_container(map_indexes):
query = query.filter(cls.map_index.in_(map_indexes))
elif map_indexes is not None:
query = query.filter(cls.map_index == map_indexes)

if include_prior_dates:
if execution_date is not None:
query = query.filter(DagRun.execution_date <= execution_date)
Expand Down Expand Up @@ -438,7 +480,15 @@ def delete(cls, xcoms: Union["XCom", Iterable["XCom"]], session: Session) -> Non

@overload
@classmethod
def clear(cls, *, dag_id: str, task_id: str, run_id: str, session: Optional[Session] = None) -> None:
def clear(
cls,
*,
dag_id: str,
task_id: str,
run_id: str,
map_index: Optional[int] = None,
session: Session = NEW_SESSION,
) -> None:
"""Clear all XCom data from the database for the given task instance.

A deprecated form of this function accepts ``execution_date`` instead of
Expand All @@ -447,6 +497,8 @@ def clear(cls, *, dag_id: str, task_id: str, run_id: str, session: Optional[Sess
:param dag_id: ID of DAG to clear the XCom for.
:param task_id: ID of task to clear the XCom for.
:param run_id: ID of DAG run to clear the XCom for.
:param map_index: If given, only clear XCom from this particular mapped
task. The default ``None`` clears *all* XComs from the task.
:param session: Database session. If not given, a new session will be
created for this function.
"""
Expand All @@ -472,6 +524,7 @@ def clear(
session: Session = NEW_SESSION,
*,
run_id: Optional[str] = None,
map_index: Optional[int] = None,
) -> None:
""":sphinx-autoapi-skip:"""
from airflow.models import DagRun
Expand All @@ -495,17 +548,20 @@ def clear(
.scalar()
)

return session.query(cls).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id).delete()
query = session.query(cls).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id)
if map_index is not None:
query = query.filter_by(map_index=map_index)
query.delete()

@staticmethod
def serialize_value(
value: Any,
*,
key=None,
task_id=None,
dag_id=None,
run_id=None,
mapping_index: int = -1,
key: Optional[str] = None,
task_id: Optional[str] = None,
dag_id: Optional[str] = None,
run_id: Optional[str] = None,
map_index: Optional[int] = None,
):
"""Serialize XCom value to str or pickled object"""
if conf.getboolean('core', 'enable_xcom_pickling'):
Expand Down Expand Up @@ -549,13 +605,14 @@ def orm_deserialize_value(self) -> Any:
return BaseXCom.deserialize_value(self)


def _patch_outdated_serializer(clazz, params):
"""
Previously XCom.serialize_value only accepted one argument ``value``. In order to give
custom XCom backends more flexibility with how they store values we now forward to
``XCom.serialize_value`` all params passed to ``XCom.set``. In order to maintain
compatibility with XCom backends written with the old signature we check the signature
and if necessary we patch with a method that ignores kwargs the backend does not accept.
def _patch_outdated_serializer(clazz: Type[BaseXCom], params: Iterable[str]) -> None:
"""Patch a custom ``serialize_value`` to accept the modern signature.

To give custom XCom backends more flexibility with how they store values, we
now forward all params passed to ``XCom.set`` to ``XCom.serialize_value``.
In order to maintain compatibility with custom XCom backends written with
the old signature, we check the signature and, if necessary, patch with a
method that ignores kwargs the backend does not accept.
"""
old_serializer = clazz.serialize_value

Expand All @@ -570,7 +627,7 @@ def _shim(**kwargs):
)
return old_serializer(**kwargs)

clazz.serialize_value = _shim
clazz.serialize_value = _shim # type: ignore[assignment]


def _get_function_params(function) -> List[str]:
Expand Down
2 changes: 1 addition & 1 deletion airflow/operators/trigger_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_link(
) -> str:
# Fetch the correct execution date for the triggerED dag which is
# stored in xcom during execution of the triggerING task.
when = XCom.get_one(ti_key=ti_key, key=XCOM_EXECUTION_DATE_ISO)
when = XCom.get_value(ti_key=ti_key, key=XCOM_EXECUTION_DATE_ISO)
query = {"dag_id": cast(TriggerDagRunOperator, operator).trigger_dag_id, "base_date": when}
return build_airflow_url_with_query(query)

Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def get_link(
:return: url link
"""
if ti_key:
flow_id = XCom.get_one(key="return_value", ti_key=ti_key)
flow_id = XCom.get_value(key="return_value", ti_key=ti_key)
uranusjr marked this conversation as resolved.
Show resolved Hide resolved
else:
assert dttm
flow_id = XCom.get_one(
Expand Down
Loading