diff --git a/airflow/migrations/versions/0102_c306b5b5ae4a_switch_xcom_table_to_use_run_id.py b/airflow/migrations/versions/0102_c306b5b5ae4a_switch_xcom_table_to_use_run_id.py index 951c0a6c89bf3..633f545c6e51c 100644 --- a/airflow/migrations/versions/0102_c306b5b5ae4a_switch_xcom_table_to_use_run_id.py +++ b/airflow/migrations/versions/0102_c306b5b5ae4a_switch_xcom_table_to_use_run_id.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -"""Switch XCom table to use ``run_id``. +"""Switch XCom table to use ``run_id`` and add ``map_index``. Revision ID: c306b5b5ae4a Revises: a3bcd0914482 @@ -25,7 +25,7 @@ from typing import Sequence from alembic import op -from sqlalchemy import Column, Integer, LargeBinary, MetaData, Table, and_, select +from sqlalchemy import Column, Integer, LargeBinary, MetaData, Table, and_, literal_column, select from airflow.migrations.db_types import TIMESTAMP, StringID from airflow.migrations.utils import get_mssql_table_constraints @@ -50,6 +50,7 @@ def _get_new_xcom_columns() -> Sequence[Column]: Column("timestamp", TIMESTAMP, nullable=False), Column("dag_id", StringID(), nullable=False), Column("run_id", StringID(), nullable=False), + Column("map_index", Integer, nullable=False, server_default="-1"), ] @@ -98,6 +99,7 @@ def upgrade(): xcom.c.timestamp, xcom.c.dag_id, dagrun.c.run_id, + literal_column("-1"), ], ).select_from( xcom.join( @@ -118,9 +120,9 @@ def upgrade(): op.rename_table("__airflow_tmp_xcom", "xcom") with op.batch_alter_table("xcom") as batch_op: - batch_op.create_primary_key("xcom_pkey", ["dag_run_id", "task_id", "key"]) + batch_op.create_primary_key("xcom_pkey", ["dag_run_id", "task_id", "map_index", "key"]) batch_op.create_index("idx_xcom_key", ["key"]) - batch_op.create_index("idx_xcom_ti_id", ["dag_id", "task_id", "run_id"]) + batch_op.create_index("idx_xcom_ti_id", ["dag_id", "run_id", "task_id", "map_index"]) def downgrade(): @@ -132,6 +134,10 @@ def downgrade(): op.create_table("__airflow_tmp_xcom", *_get_old_xcom_columns()) xcom = Table("xcom", metadata, *_get_new_xcom_columns()) + + # Remoe XCom entries from mapped tis. + op.execute(xcom.delete().where(xcom.c.map_index != -1)) + dagrun = _get_dagrun_table() query = select( [ diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index 32190d23f081c..7456fd293d778 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -61,6 +61,7 @@ class BaseXCom(Base, LoggingMixin): dag_run_id = Column(Integer(), nullable=False, primary_key=True) task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False, primary_key=True) + map_index = Column(Integer, primary_key=True, nullable=False, server_default="-1") key = Column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True) # Denormalized for easier lookup. @@ -87,7 +88,7 @@ class BaseXCom(Base, LoggingMixin): # but it goes over MySQL's index length limit. So we instead create indexes # separately, and enforce uniqueness with DagRun.id instead. Index("idx_xcom_key", key), - Index("idx_xcom_ti_id", dag_id, task_id, run_id), + Index("idx_xcom_ti_id", dag_id, task_id, run_id, map_index), ) @reconstructor @@ -111,6 +112,7 @@ def set( dag_id: str, task_id: str, run_id: str, + map_index: int = -1, session: Session = NEW_SESSION, ) -> None: """Store an XCom value. @@ -123,6 +125,8 @@ def set( :param dag_id: DAG ID. :param task_id: Task ID. :param run_id: DAG run ID for the task. + :param map_index: Optional map index to assign XCom for a mapped task. + The default is ``-1`` (set for a non-mapped task). :param session: Database session. If not given, a new session will be created for this function. """ @@ -152,6 +156,7 @@ def set( session: Session = NEW_SESSION, *, run_id: Optional[str] = None, + map_index: int = -1, ) -> None: """:sphinx-autoapi-skip:""" from airflow.models.dagrun import DagRun @@ -183,6 +188,7 @@ def set( task_id=task_id, dag_id=dag_id, run_id=run_id, + map_index=map_index, ) # Remove duplicate XComs and insert a new one. @@ -203,13 +209,49 @@ def set( session.add(new) session.flush() + @classmethod + @provide_session + def get_value( + cls, + *, + ti_key: "TaskInstanceKey", + key: Optional[str] = None, + session: Session = NEW_SESSION, + ) -> Any: + """Retrieve an XCom value for a task instance. + + This method returns "full" XCom values (i.e. uses ``deserialize_value`` + from the XCom backend). Use :meth:`get_many` if you want the "shortened" + value via ``orm_deserialize_value``. + + If there are no results, *None* is returned. If multiple XCom entries + match the criteria, an arbitrary one is returned. + + :param ti_key: The TaskInstanceKey to look up the XCom for. + :param key: A key for the XCom. If provided, only XCom with matching + keys will be returned. Pass *None* (default) to remove the filter. + :param session: Database session. If not given, a new session will be + created for this function. + """ + return cls.get_one( + key=key, + task_id=ti_key.task_id, + dag_id=ti_key.dag_id, + run_id=ti_key.run_id, + map_index=ti_key.map_index, + session=session, + ) + @overload @classmethod def get_one( cls, *, key: Optional[str] = None, - ti_key: "TaskInstanceKey", + dag_id: Optional[str] = None, + task_id: Optional[str] = None, + run_id: Optional[str] = None, + map_index: Optional[int] = None, session: Session = NEW_SESSION, ) -> Optional[Any]: """Retrieve an XCom value, optionally meeting certain criteria. @@ -218,12 +260,22 @@ def get_one( from the XCom backend). Use :meth:`get_many` if you want the "shortened" value via ``orm_deserialize_value``. - If there are no results, *None* is returned. + If there are no results, *None* is returned. If multiple XCom entries + match the criteria, an arbitrary one is returned. A deprecated form of this function accepts ``execution_date`` instead of ``run_id``. The two arguments are mutually exclusive. - :param ti_key: The TaskInstanceKey to look up the XCom for + .. seealso:: ``get_value()`` is a convenience function if you already + have a structured TaskInstance or TaskInstanceKey object available. + + :param run_id: DAG run ID for the task. + :param dag_id: Only pull XCom from this DAG. Pass *None* (default) to + remove the filter. + :param task_id: Only XCom from task with matching ID will be pulled. + Pass *None* (default) to remove the filter. + :param map_index: Only XCom from task with matching ID will be pulled. + Pass *None* (default) to remove the filter. :param key: A key for the XCom. If provided, only XCom with matching keys will be returned. Pass *None* (default) to remove the filter. :param include_prior_dates: If *False* (default), only XCom from the @@ -233,19 +285,6 @@ def get_one( created for this function. """ - @overload - @classmethod - def get_one( - cls, - *, - key: Optional[str] = None, - task_id: str, - dag_id: str, - run_id: str, - session: Session = NEW_SESSION, - ) -> Optional[Any]: - ... - @overload @classmethod def get_one( @@ -271,27 +310,19 @@ def get_one( session: Session = NEW_SESSION, *, run_id: Optional[str] = None, - ti_key: Optional["TaskInstanceKey"] = None, + map_index: Optional[int] = None, ) -> Optional[Any]: """:sphinx-autoapi-skip:""" - if not exactly_one(execution_date is not None, ti_key is not None, run_id is not None): + if not exactly_one(execution_date is not None, run_id is not None): raise ValueError("Exactly one of ti_key, run_id, or execution_date must be passed") - if ti_key is not None: - query = session.query(cls).filter_by( - dag_id=ti_key.dag_id, - run_id=ti_key.run_id, - task_id=ti_key.task_id, - ) - if key: - query = query.filter_by(key=key) - query = query.limit(1) - elif run_id: + if run_id: query = cls.get_many( run_id=run_id, key=key, task_ids=task_id, dag_ids=dag_id, + map_indexes=map_index, include_prior_dates=include_prior_dates, limit=1, session=session, @@ -307,6 +338,7 @@ def get_one( key=key, task_ids=task_id, dag_ids=dag_id, + map_indexes=map_index, include_prior_dates=include_prior_dates, limit=1, session=session, @@ -328,6 +360,7 @@ def get_many( key: Optional[str] = None, task_ids: Union[str, Iterable[str], None] = None, dag_ids: Union[str, Iterable[str], None] = None, + map_indexes: Union[int, Iterable[int], None] = None, include_prior_dates: bool = False, limit: Optional[int] = None, session: Session = NEW_SESSION, @@ -347,6 +380,8 @@ def get_many( Pass *None* (default) to remove the filter. :param dag_id: Only pulls XComs from this DAG. If *None* (default), the DAG of the calling task is used. + :param map_index: Only XComs from matching map indexes will be pulled. + Pass *None* (default) to remove the filter. :param include_prior_dates: If *False* (default), only XComs from the specified DAG run are returned. If *True*, all matching XComs are returned regardless of the run it belongs to. @@ -362,6 +397,7 @@ def get_many( key: Optional[str] = None, task_ids: Union[str, Iterable[str], None] = None, dag_ids: Union[str, Iterable[str], None] = None, + map_indexes: Union[int, Iterable[int], None] = None, include_prior_dates: bool = False, limit: Optional[int] = None, session: Session = NEW_SESSION, @@ -376,6 +412,7 @@ def get_many( key: Optional[str] = None, task_ids: Optional[Union[str, Iterable[str]]] = None, dag_ids: Optional[Union[str, Iterable[str]]] = None, + map_indexes: Union[int, Iterable[int], None] = None, include_prior_dates: bool = False, limit: Optional[int] = None, session: Session = NEW_SESSION, @@ -406,6 +443,11 @@ def get_many( elif dag_ids is not None: query = query.filter(cls.dag_id == dag_ids) + if is_container(map_indexes): + query = query.filter(cls.map_index.in_(map_indexes)) + elif map_indexes is not None: + query = query.filter(cls.map_index == map_indexes) + if include_prior_dates: if execution_date is not None: query = query.filter(DagRun.execution_date <= execution_date) @@ -438,7 +480,15 @@ def delete(cls, xcoms: Union["XCom", Iterable["XCom"]], session: Session) -> Non @overload @classmethod - def clear(cls, *, dag_id: str, task_id: str, run_id: str, session: Optional[Session] = None) -> None: + def clear( + cls, + *, + dag_id: str, + task_id: str, + run_id: str, + map_index: Optional[int] = None, + session: Session = NEW_SESSION, + ) -> None: """Clear all XCom data from the database for the given task instance. A deprecated form of this function accepts ``execution_date`` instead of @@ -447,6 +497,8 @@ def clear(cls, *, dag_id: str, task_id: str, run_id: str, session: Optional[Sess :param dag_id: ID of DAG to clear the XCom for. :param task_id: ID of task to clear the XCom for. :param run_id: ID of DAG run to clear the XCom for. + :param map_index: If given, only clear XCom from this particular mapped + task. The default ``None`` clears *all* XComs from the task. :param session: Database session. If not given, a new session will be created for this function. """ @@ -472,6 +524,7 @@ def clear( session: Session = NEW_SESSION, *, run_id: Optional[str] = None, + map_index: Optional[int] = None, ) -> None: """:sphinx-autoapi-skip:""" from airflow.models import DagRun @@ -495,17 +548,20 @@ def clear( .scalar() ) - return session.query(cls).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id).delete() + query = session.query(cls).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id) + if map_index is not None: + query = query.filter_by(map_index=map_index) + query.delete() @staticmethod def serialize_value( value: Any, *, - key=None, - task_id=None, - dag_id=None, - run_id=None, - mapping_index: int = -1, + key: Optional[str] = None, + task_id: Optional[str] = None, + dag_id: Optional[str] = None, + run_id: Optional[str] = None, + map_index: Optional[int] = None, ): """Serialize XCom value to str or pickled object""" if conf.getboolean('core', 'enable_xcom_pickling'): @@ -549,13 +605,14 @@ def orm_deserialize_value(self) -> Any: return BaseXCom.deserialize_value(self) -def _patch_outdated_serializer(clazz, params): - """ - Previously XCom.serialize_value only accepted one argument ``value``. In order to give - custom XCom backends more flexibility with how they store values we now forward to - ``XCom.serialize_value`` all params passed to ``XCom.set``. In order to maintain - compatibility with XCom backends written with the old signature we check the signature - and if necessary we patch with a method that ignores kwargs the backend does not accept. +def _patch_outdated_serializer(clazz: Type[BaseXCom], params: Iterable[str]) -> None: + """Patch a custom ``serialize_value`` to accept the modern signature. + + To give custom XCom backends more flexibility with how they store values, we + now forward all params passed to ``XCom.set`` to ``XCom.serialize_value``. + In order to maintain compatibility with custom XCom backends written with + the old signature, we check the signature and, if necessary, patch with a + method that ignores kwargs the backend does not accept. """ old_serializer = clazz.serialize_value @@ -570,7 +627,7 @@ def _shim(**kwargs): ) return old_serializer(**kwargs) - clazz.serialize_value = _shim + clazz.serialize_value = _shim # type: ignore[assignment] def _get_function_params(function) -> List[str]: diff --git a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py index b72373a43dde4..0689f14c56261 100644 --- a/airflow/operators/trigger_dagrun.py +++ b/airflow/operators/trigger_dagrun.py @@ -56,7 +56,7 @@ def get_link( ) -> str: # Fetch the correct execution date for the triggerED dag which is # stored in xcom during execution of the triggerING task. - when = XCom.get_one(ti_key=ti_key, key=XCOM_EXECUTION_DATE_ISO) + when = XCom.get_value(ti_key=ti_key, key=XCOM_EXECUTION_DATE_ISO) query = {"dag_id": cast(TriggerDagRunOperator, operator).trigger_dag_id, "base_date": when} return build_airflow_url_with_query(query) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 183d20a738a64..84ce7d3fe8fdf 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -245,7 +245,7 @@ def get_link( :return: url link """ if ti_key: - flow_id = XCom.get_one(key="return_value", ti_key=ti_key) + flow_id = XCom.get_value(key="return_value", ti_key=ti_key) else: assert dttm flow_id = XCom.get_one( diff --git a/airflow/providers/dbt/cloud/operators/dbt.py b/airflow/providers/dbt/cloud/operators/dbt.py index 00d8bcd8b07b0..fa93eb63f0b13 100644 --- a/airflow/providers/dbt/cloud/operators/dbt.py +++ b/airflow/providers/dbt/cloud/operators/dbt.py @@ -35,7 +35,7 @@ class DbtCloudRunJobOperatorLink(BaseOperatorLink): def get_link(self, operator, dttm=None, *, ti_key=None): if ti_key: - job_run_url = XCom.get_one(key="job_run_url", ti_key=ti_key) + job_run_url = XCom.get_value(key="job_run_url", ti_key=ti_key) else: assert dttm job_run_url = XCom.get_one( diff --git a/airflow/providers/google/cloud/links/base.py b/airflow/providers/google/cloud/links/base.py index c0f160e200a29..a47920ddd4b39 100644 --- a/airflow/providers/google/cloud/links/base.py +++ b/airflow/providers/google/cloud/links/base.py @@ -38,7 +38,7 @@ def get_link( ti_key: Optional["TaskInstanceKey"] = None, ) -> str: if ti_key: - conf = XCom.get_one(key=self.key, ti_key=ti_key) + conf = XCom.get_value(key=self.key, ti_key=ti_key) else: assert dttm conf = XCom.get_one( diff --git a/airflow/providers/google/cloud/links/dataproc.py b/airflow/providers/google/cloud/links/dataproc.py index 2cfa6e3e86ce6..c797c7fd6b5a1 100644 --- a/airflow/providers/google/cloud/links/dataproc.py +++ b/airflow/providers/google/cloud/links/dataproc.py @@ -70,7 +70,7 @@ def get_link( ti_key: Optional["TaskInstanceKey"] = None, ) -> str: if ti_key: - conf = XCom.get_one(key=self.key, ti_key=ti_key) + conf = XCom.get_value(key=self.key, ti_key=ti_key) else: assert dttm conf = XCom.get_one( @@ -113,7 +113,7 @@ def get_link( ti_key: Optional["TaskInstanceKey"] = None, ) -> str: if ti_key: - list_conf = XCom.get_one(key=self.key, ti_key=ti_key) + list_conf = XCom.get_value(key=self.key, ti_key=ti_key) else: assert dttm list_conf = XCom.get_one( diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 04b7cc3282e53..f7d6e7d1206f8 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -63,13 +63,22 @@ class BigQueryConsoleLink(BaseOperatorLink): name = 'BigQuery Console' - def get_link(self, operator, dttm): - job_id = XCom.get_one( - dag_id=operator.dag.dag_id, - task_id=operator.task_id, - execution_date=dttm, - key='job_id', - ) + def get_link( + self, + operator, + dttm: Optional[datetime] = None, + ti_key: Optional["TaskInstanceKey"] = None, + ): + if ti_key is not None: + job_id = XCom.get_value(key='job_id', ti_key=ti_key) + else: + assert dttm is not None + job_id = XCom.get_one( + dag_id=operator.dag.dag_id, + task_id=operator.task_id, + execution_date=dttm, + key='job_id', + ) return BIGQUERY_JOB_DETAILS_LINK_FMT.format(job_id=job_id) if job_id else '' @@ -90,9 +99,9 @@ def get_link( ti_key: Optional["TaskInstanceKey"] = None, ): if ti_key: - job_ids = XCom.get_one(key='job_id', ti_key=ti_key) + job_ids = XCom.get_value(key='job_id', ti_key=ti_key) else: - assert dttm + assert dttm is not None job_ids = XCom.get_one( key='job_id', dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm ) diff --git a/airflow/providers/google/cloud/operators/dataproc_metastore.py b/airflow/providers/google/cloud/operators/dataproc_metastore.py index 89dae26e530a5..5e70ed95185e5 100644 --- a/airflow/providers/google/cloud/operators/dataproc_metastore.py +++ b/airflow/providers/google/cloud/operators/dataproc_metastore.py @@ -86,7 +86,7 @@ def get_link( ti_key: Optional["TaskInstanceKey"] = None, ) -> str: if ti_key: - conf = XCom.get_one(key=self.key, ti_key=ti_key) + conf = XCom.get_value(key=self.key, ti_key=ti_key) else: assert dttm conf = XCom.get_one( @@ -141,7 +141,7 @@ def get_link( ti_key: Optional["TaskInstanceKey"] = None, ) -> str: if ti_key: - conf = XCom.get_one(key=self.key, ti_key=ti_key) + conf = XCom.get_value(key=self.key, ti_key=ti_key) else: assert dttm conf = XCom.get_one( diff --git a/airflow/providers/google/cloud/operators/mlengine.py b/airflow/providers/google/cloud/operators/mlengine.py index f2784a08b7fb3..ce7d6ca9d51fd 100644 --- a/airflow/providers/google/cloud/operators/mlengine.py +++ b/airflow/providers/google/cloud/operators/mlengine.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. """This module contains Google Cloud MLEngine operators.""" +import datetime import logging import re import warnings @@ -26,6 +27,7 @@ from airflow.providers.google.cloud.hooks.mlengine import MLEngineHook if TYPE_CHECKING: + from airflow.models.taskinstance import TaskInstanceKey from airflow.utils.context import Context @@ -978,10 +980,22 @@ class AIPlatformConsoleLink(BaseOperatorLink): name = "AI Platform Console" - def get_link(self, operator, dttm): - gcp_metadata_dict = XCom.get_one( - key="gcp_metadata", dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm - ) + def get_link( + self, + operator, + dttm: Optional[datetime.datetime] = None, + ti_key: Optional["TaskInstanceKey"] = None, + ) -> str: + if ti_key is not None: + gcp_metadata_dict = XCom.get_value(key="gcp_metadata", ti_key=ti_key) + else: + assert dttm is not None + gcp_metadata_dict = XCom.get_one( + key="gcp_metadata", + dag_id=operator.dag.dag_id, + task_id=operator.task_id, + execution_date=dttm, + ) if not gcp_metadata_dict: return '' job_id = gcp_metadata_dict['job_id'] diff --git a/airflow/providers/microsoft/azure/operators/data_factory.py b/airflow/providers/microsoft/azure/operators/data_factory.py index 2800fe900bdcb..88adc217d4fa1 100644 --- a/airflow/providers/microsoft/azure/operators/data_factory.py +++ b/airflow/providers/microsoft/azure/operators/data_factory.py @@ -43,7 +43,7 @@ def get_link( ti_key: Optional["TaskInstanceKey"] = None, ) -> str: if ti_key: - run_id = XCom.get_one(key="run_id", ti_key=ti_key) + run_id = XCom.get_value(key="run_id", ti_key=ti_key) else: assert dttm run_id = XCom.get_one( diff --git a/airflow/providers/qubole/operators/qubole.py b/airflow/providers/qubole/operators/qubole.py index c867536bf0153..ec4ce329a947b 100644 --- a/airflow/providers/qubole/operators/qubole.py +++ b/airflow/providers/qubole/operators/qubole.py @@ -64,7 +64,7 @@ def get_link( else: host = 'https://api.qubole.com/v2/analyze?command_id=' if ti_key: - qds_command_id = XCom.get_one(key='qbol_cmd_id', ti_key=ti_key) + qds_command_id = XCom.get_value(key='qbol_cmd_id', ti_key=ti_key) else: assert dttm qds_command_id = XCom.get_one( diff --git a/airflow/www/views.py b/airflow/www/views.py index ae70427fb3b66..4a624dd20d18a 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -3670,6 +3670,7 @@ def pre_add(self, item): task_id=item.task_id, dag_id=item.dag_id, run_id=item.run_id, + map_index=item.map_index, ) def pre_update(self, item): @@ -3681,6 +3682,7 @@ def pre_update(self, item): task_id=item.task_id, dag_id=item.dag_id, run_id=item.run_id, + map_index=item.map_index, ) diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst index 9663a4c91340e..9dfd7dd838b34 100644 --- a/docs/apache-airflow/migrations-ref.rst +++ b/docs/apache-airflow/migrations-ref.rst @@ -29,7 +29,7 @@ Here's the list of all the Database Migrations that are executed via when you ru +---------------------------------+-------------------+-------------+--------------------------------------------------------------+ | ``c97c2ab6aa23`` | ``c306b5b5ae4a`` | ``2.3.0`` | add callback request table | +---------------------------------+-------------------+-------------+--------------------------------------------------------------+ -| ``c306b5b5ae4a`` | ``a3bcd0914482`` | ``2.3.0`` | Switch XCom table to use ``run_id``. | +| ``c306b5b5ae4a`` | ``a3bcd0914482`` | ``2.3.0`` | Switch XCom table to use ``run_id`` and add ``map_index``. | +---------------------------------+-------------------+-------------+--------------------------------------------------------------+ | ``a3bcd0914482`` | ``e655c0453f75`` | ``2.3.0`` | add data_compressed to serialized_dag | +---------------------------------+-------------------+-------------+--------------------------------------------------------------+ diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py index 5dd979f74caa1..0d8181abf61c8 100644 --- a/tests/models/test_xcom.py +++ b/tests/models/test_xcom.py @@ -81,7 +81,7 @@ def test_resolve_xcom_class(self): def test_resolve_xcom_class_fallback_to_basexcom(self): cls = resolve_xcom_backend() assert issubclass(cls, BaseXCom) - assert cls().serialize_value([1]) == b"[1]" + assert cls.serialize_value([1]) == b"[1]" @conf_vars({("core", "enable_xcom_pickling"): "False"}) @conf_vars({("core", "xcom_backend"): "to be removed"}) @@ -89,7 +89,7 @@ def test_resolve_xcom_class_fallback_to_basexcom_no_config(self): conf.remove_option("core", "xcom_backend") cls = resolve_xcom_backend() assert issubclass(cls, BaseXCom) - assert cls().serialize_value([1]) == b"[1]" + assert cls.serialize_value([1]) == b"[1]" def test_xcom_deserialize_with_json_to_pickle_switch(self, dag_run, session): ti_key = TaskInstanceKey( @@ -107,7 +107,7 @@ def test_xcom_deserialize_with_json_to_pickle_switch(self, dag_run, session): session=session, ) with conf_vars({("core", "enable_xcom_pickling"): "True"}): - ret_value = XCom.get_one(key="xcom_test3", ti_key=ti_key, session=session) + ret_value = XCom.get_value(key="xcom_test3", ti_key=ti_key, session=session) assert ret_value == {"key": "value"} def test_xcom_deserialize_with_pickle_to_json_switch(self, dag_run, session): @@ -226,7 +226,7 @@ def serialize_value( dag_id=None, task_id=None, run_id=None, - mapping_index: int = -1, + map_index=None, ): serialize_watcher( value=value, @@ -234,6 +234,7 @@ def serialize_value( dag_id=dag_id, task_id=task_id, run_id=run_id, + map_index=map_index, ) return json.dumps(value).encode('utf-8') @@ -245,6 +246,7 @@ def serialize_value( dag_id="test_dag", task_id="test_task", run_id=IN_MEMORY_RUN_ID, + map_index=-1, ) expected = {**kwargs, 'run_id': '__airflow_in_memory_dagrun__'} XCom = resolve_xcom_backend() diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 7144feddbbd39..bca754d3321df 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -920,7 +920,7 @@ class TaskStateLink(BaseOperatorLink): name = 'My Link' - def get_link(self, operator, dttm): + def get_link(self, operator, *, ti_key): return 'https://www.google.com' class MyOperator(BaseOperator):