From cb60690c269b9d6520b76bc2a0e8fef0a7fb7ed9 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 9 Mar 2022 17:48:04 +0800 Subject: [PATCH] Add map_index to XCom interface This adds an additional (optional) map_index argument to XCom's get/set/clear interface so mapped task instances can push to the correct entries, and have them pulled correctly by a downstream task. To make the XCom interface easier to use for common scenarios, a convenience method get_value is added to take a TaskInstanceKey that automatically performs argument unpacking and calls get_one underneath. This is not done as a get_one overload to simplify the implementation and typing. --- airflow/models/xcom.py | 142 ++++++++++++------ airflow/operators/trigger_dagrun.py | 2 +- airflow/providers/amazon/aws/operators/emr.py | 2 +- airflow/providers/dbt/cloud/operators/dbt.py | 2 +- airflow/providers/google/cloud/links/base.py | 2 +- .../providers/google/cloud/links/dataproc.py | 4 +- .../google/cloud/operators/bigquery.py | 27 ++-- .../cloud/operators/dataproc_metastore.py | 4 +- .../google/cloud/operators/mlengine.py | 22 ++- .../microsoft/azure/operators/data_factory.py | 2 +- airflow/providers/qubole/operators/qubole.py | 2 +- airflow/www/views.py | 2 + tests/models/test_xcom.py | 8 +- tests/serialization/test_dag_serialization.py | 2 +- 14 files changed, 153 insertions(+), 70 deletions(-) diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index 5b7bef0eb84c1..7456fd293d778 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -112,6 +112,7 @@ def set( dag_id: str, task_id: str, run_id: str, + map_index: int = -1, session: Session = NEW_SESSION, ) -> None: """Store an XCom value. @@ -124,6 +125,8 @@ def set( :param dag_id: DAG ID. :param task_id: Task ID. :param run_id: DAG run ID for the task. + :param map_index: Optional map index to assign XCom for a mapped task. + The default is ``-1`` (set for a non-mapped task). :param session: Database session. 
If not given, a new session will be created for this function. """ @@ -153,6 +156,7 @@ def set( session: Session = NEW_SESSION, *, run_id: Optional[str] = None, + map_index: int = -1, ) -> None: """:sphinx-autoapi-skip:""" from airflow.models.dagrun import DagRun @@ -184,6 +188,7 @@ def set( task_id=task_id, dag_id=dag_id, run_id=run_id, + map_index=map_index, ) # Remove duplicate XComs and insert a new one. @@ -204,13 +209,49 @@ def set( session.add(new) session.flush() + @classmethod + @provide_session + def get_value( + cls, + *, + ti_key: "TaskInstanceKey", + key: Optional[str] = None, + session: Session = NEW_SESSION, + ) -> Any: + """Retrieve an XCom value for a task instance. + + This method returns "full" XCom values (i.e. uses ``deserialize_value`` + from the XCom backend). Use :meth:`get_many` if you want the "shortened" + value via ``orm_deserialize_value``. + + If there are no results, *None* is returned. If multiple XCom entries + match the criteria, an arbitrary one is returned. + + :param ti_key: The TaskInstanceKey to look up the XCom for. + :param key: A key for the XCom. If provided, only XCom with matching + keys will be returned. Pass *None* (default) to remove the filter. + :param session: Database session. If not given, a new session will be + created for this function. + """ + return cls.get_one( + key=key, + task_id=ti_key.task_id, + dag_id=ti_key.dag_id, + run_id=ti_key.run_id, + map_index=ti_key.map_index, + session=session, + ) + @overload @classmethod def get_one( cls, *, key: Optional[str] = None, - ti_key: "TaskInstanceKey", + dag_id: Optional[str] = None, + task_id: Optional[str] = None, + run_id: Optional[str] = None, + map_index: Optional[int] = None, session: Session = NEW_SESSION, ) -> Optional[Any]: """Retrieve an XCom value, optionally meeting certain criteria. @@ -219,12 +260,22 @@ def get_one( from the XCom backend). Use :meth:`get_many` if you want the "shortened" value via ``orm_deserialize_value``. 
- If there are no results, *None* is returned. + If there are no results, *None* is returned. If multiple XCom entries + match the criteria, an arbitrary one is returned. A deprecated form of this function accepts ``execution_date`` instead of ``run_id``. The two arguments are mutually exclusive. - :param ti_key: The TaskInstanceKey to look up the XCom for + .. seealso:: ``get_value()`` is a convenience function if you already + have a structured TaskInstance or TaskInstanceKey object available. + + :param run_id: DAG run ID for the task. + :param dag_id: Only pull XCom from this DAG. Pass *None* (default) to + remove the filter. + :param task_id: Only XCom from task with matching ID will be pulled. + Pass *None* (default) to remove the filter. + :param map_index: Only XCom from task with matching map index will be pulled. + Pass *None* (default) to remove the filter. :param key: A key for the XCom. If provided, only XCom with matching keys will be returned. Pass *None* (default) to remove the filter. :param include_prior_dates: If *False* (default), only XCom from the @@ -234,19 +285,6 @@ def get_one( created for this function. """ 
- @overload @classmethod def get_one( @@ -272,27 +310,19 @@ def get_one( session: Session = NEW_SESSION, *, run_id: Optional[str] = None, - ti_key: Optional["TaskInstanceKey"] = None, + map_index: Optional[int] = None, ) -> Optional[Any]: """:sphinx-autoapi-skip:""" - if not exactly_one(execution_date is not None, ti_key is not None, run_id is not None): + if not exactly_one(execution_date is not None, run_id is not None): raise ValueError("Exactly one of ti_key, run_id, or execution_date must be passed") - if ti_key is not None: - query = session.query(cls).filter_by( - dag_id=ti_key.dag_id, - run_id=ti_key.run_id, - task_id=ti_key.task_id, - ) - if key: - query = query.filter_by(key=key) - query = query.limit(1) - elif run_id: + if run_id: query = cls.get_many( run_id=run_id, key=key, task_ids=task_id, dag_ids=dag_id, + map_indexes=map_index, include_prior_dates=include_prior_dates, limit=1, session=session, @@ -308,6 +338,7 @@ def get_one( key=key, task_ids=task_id, dag_ids=dag_id, + map_indexes=map_index, include_prior_dates=include_prior_dates, limit=1, session=session, @@ -329,6 +360,7 @@ def get_many( key: Optional[str] = None, task_ids: Union[str, Iterable[str], None] = None, dag_ids: Union[str, Iterable[str], None] = None, + map_indexes: Union[int, Iterable[int], None] = None, include_prior_dates: bool = False, limit: Optional[int] = None, session: Session = NEW_SESSION, @@ -348,6 +380,8 @@ def get_many( Pass *None* (default) to remove the filter. :param dag_id: Only pulls XComs from this DAG. If *None* (default), the DAG of the calling task is used. + :param map_index: Only XComs from matching map indexes will be pulled. + Pass *None* (default) to remove the filter. :param include_prior_dates: If *False* (default), only XComs from the specified DAG run are returned. If *True*, all matching XComs are returned regardless of the run it belongs to. 
@@ -363,6 +397,7 @@ def get_many( key: Optional[str] = None, task_ids: Union[str, Iterable[str], None] = None, dag_ids: Union[str, Iterable[str], None] = None, + map_indexes: Union[int, Iterable[int], None] = None, include_prior_dates: bool = False, limit: Optional[int] = None, session: Session = NEW_SESSION, @@ -377,6 +412,7 @@ def get_many( key: Optional[str] = None, task_ids: Optional[Union[str, Iterable[str]]] = None, dag_ids: Optional[Union[str, Iterable[str]]] = None, + map_indexes: Union[int, Iterable[int], None] = None, include_prior_dates: bool = False, limit: Optional[int] = None, session: Session = NEW_SESSION, @@ -407,6 +443,11 @@ def get_many( elif dag_ids is not None: query = query.filter(cls.dag_id == dag_ids) + if is_container(map_indexes): + query = query.filter(cls.map_index.in_(map_indexes)) + elif map_indexes is not None: + query = query.filter(cls.map_index == map_indexes) + if include_prior_dates: if execution_date is not None: query = query.filter(DagRun.execution_date <= execution_date) @@ -439,7 +480,15 @@ def delete(cls, xcoms: Union["XCom", Iterable["XCom"]], session: Session) -> Non @overload @classmethod - def clear(cls, *, dag_id: str, task_id: str, run_id: str, session: Optional[Session] = None) -> None: + def clear( + cls, + *, + dag_id: str, + task_id: str, + run_id: str, + map_index: Optional[int] = None, + session: Session = NEW_SESSION, + ) -> None: """Clear all XCom data from the database for the given task instance. A deprecated form of this function accepts ``execution_date`` instead of @@ -448,6 +497,8 @@ def clear(cls, *, dag_id: str, task_id: str, run_id: str, session: Optional[Sess :param dag_id: ID of DAG to clear the XCom for. :param task_id: ID of task to clear the XCom for. :param run_id: ID of DAG run to clear the XCom for. + :param map_index: If given, only clear XCom from this particular mapped + task. The default ``None`` clears *all* XComs from the task. :param session: Database session. 
If not given, a new session will be created for this function. """ @@ -473,6 +524,7 @@ def clear( session: Session = NEW_SESSION, *, run_id: Optional[str] = None, + map_index: Optional[int] = None, ) -> None: """:sphinx-autoapi-skip:""" from airflow.models import DagRun @@ -496,17 +548,20 @@ def clear( .scalar() ) - return session.query(cls).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id).delete() + query = session.query(cls).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id) + if map_index is not None: + query = query.filter_by(map_index=map_index) + query.delete() @staticmethod def serialize_value( value: Any, *, - key=None, - task_id=None, - dag_id=None, - run_id=None, - mapping_index: int = -1, + key: Optional[str] = None, + task_id: Optional[str] = None, + dag_id: Optional[str] = None, + run_id: Optional[str] = None, + map_index: Optional[int] = None, ): """Serialize XCom value to str or pickled object""" if conf.getboolean('core', 'enable_xcom_pickling'): @@ -550,13 +605,14 @@ def orm_deserialize_value(self) -> Any: return BaseXCom.deserialize_value(self) -def _patch_outdated_serializer(clazz, params): - """ - Previously XCom.serialize_value only accepted one argument ``value``. In order to give - custom XCom backends more flexibility with how they store values we now forward to - ``XCom.serialize_value`` all params passed to ``XCom.set``. In order to maintain - compatibility with XCom backends written with the old signature we check the signature - and if necessary we patch with a method that ignores kwargs the backend does not accept. +def _patch_outdated_serializer(clazz: Type[BaseXCom], params: Iterable[str]) -> None: + """Patch a custom ``serialize_value`` to accept the modern signature. + + To give custom XCom backends more flexibility with how they store values, we + now forward all params passed to ``XCom.set`` to ``XCom.serialize_value``. 
+ In order to maintain compatibility with custom XCom backends written with + the old signature, we check the signature and, if necessary, patch with a + method that ignores kwargs the backend does not accept. """ old_serializer = clazz.serialize_value @@ -571,7 +627,7 @@ def _shim(**kwargs): ) return old_serializer(**kwargs) - clazz.serialize_value = _shim + clazz.serialize_value = _shim # type: ignore[assignment] def _get_function_params(function) -> List[str]: diff --git a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py index b72373a43dde4..0689f14c56261 100644 --- a/airflow/operators/trigger_dagrun.py +++ b/airflow/operators/trigger_dagrun.py @@ -56,7 +56,7 @@ def get_link( ) -> str: # Fetch the correct execution date for the triggerED dag which is # stored in xcom during execution of the triggerING task. - when = XCom.get_one(ti_key=ti_key, key=XCOM_EXECUTION_DATE_ISO) + when = XCom.get_value(ti_key=ti_key, key=XCOM_EXECUTION_DATE_ISO) query = {"dag_id": cast(TriggerDagRunOperator, operator).trigger_dag_id, "base_date": when} return build_airflow_url_with_query(query) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 183d20a738a64..84ce7d3fe8fdf 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -245,7 +245,7 @@ def get_link( :return: url link """ if ti_key: - flow_id = XCom.get_one(key="return_value", ti_key=ti_key) + flow_id = XCom.get_value(key="return_value", ti_key=ti_key) else: assert dttm flow_id = XCom.get_one( diff --git a/airflow/providers/dbt/cloud/operators/dbt.py b/airflow/providers/dbt/cloud/operators/dbt.py index 00d8bcd8b07b0..fa93eb63f0b13 100644 --- a/airflow/providers/dbt/cloud/operators/dbt.py +++ b/airflow/providers/dbt/cloud/operators/dbt.py @@ -35,7 +35,7 @@ class DbtCloudRunJobOperatorLink(BaseOperatorLink): def get_link(self, operator, dttm=None, *, ti_key=None): if ti_key: - 
job_run_url = XCom.get_one(key="job_run_url", ti_key=ti_key) + job_run_url = XCom.get_value(key="job_run_url", ti_key=ti_key) else: assert dttm job_run_url = XCom.get_one( diff --git a/airflow/providers/google/cloud/links/base.py b/airflow/providers/google/cloud/links/base.py index c0f160e200a29..a47920ddd4b39 100644 --- a/airflow/providers/google/cloud/links/base.py +++ b/airflow/providers/google/cloud/links/base.py @@ -38,7 +38,7 @@ def get_link( ti_key: Optional["TaskInstanceKey"] = None, ) -> str: if ti_key: - conf = XCom.get_one(key=self.key, ti_key=ti_key) + conf = XCom.get_value(key=self.key, ti_key=ti_key) else: assert dttm conf = XCom.get_one( diff --git a/airflow/providers/google/cloud/links/dataproc.py b/airflow/providers/google/cloud/links/dataproc.py index 2cfa6e3e86ce6..c797c7fd6b5a1 100644 --- a/airflow/providers/google/cloud/links/dataproc.py +++ b/airflow/providers/google/cloud/links/dataproc.py @@ -70,7 +70,7 @@ def get_link( ti_key: Optional["TaskInstanceKey"] = None, ) -> str: if ti_key: - conf = XCom.get_one(key=self.key, ti_key=ti_key) + conf = XCom.get_value(key=self.key, ti_key=ti_key) else: assert dttm conf = XCom.get_one( @@ -113,7 +113,7 @@ def get_link( ti_key: Optional["TaskInstanceKey"] = None, ) -> str: if ti_key: - list_conf = XCom.get_one(key=self.key, ti_key=ti_key) + list_conf = XCom.get_value(key=self.key, ti_key=ti_key) else: assert dttm list_conf = XCom.get_one( diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 04b7cc3282e53..f7d6e7d1206f8 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -63,13 +63,22 @@ class BigQueryConsoleLink(BaseOperatorLink): name = 'BigQuery Console' - def get_link(self, operator, dttm): - job_id = XCom.get_one( - dag_id=operator.dag.dag_id, - task_id=operator.task_id, - execution_date=dttm, - key='job_id', - ) + def get_link( + self, + operator, + 
dttm: Optional[datetime] = None, + ti_key: Optional["TaskInstanceKey"] = None, + ): + if ti_key is not None: + job_id = XCom.get_value(key='job_id', ti_key=ti_key) + else: + assert dttm is not None + job_id = XCom.get_one( + dag_id=operator.dag.dag_id, + task_id=operator.task_id, + execution_date=dttm, + key='job_id', + ) return BIGQUERY_JOB_DETAILS_LINK_FMT.format(job_id=job_id) if job_id else '' @@ -90,9 +99,9 @@ def get_link( ti_key: Optional["TaskInstanceKey"] = None, ): if ti_key: - job_ids = XCom.get_one(key='job_id', ti_key=ti_key) + job_ids = XCom.get_value(key='job_id', ti_key=ti_key) else: - assert dttm + assert dttm is not None job_ids = XCom.get_one( key='job_id', dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm ) diff --git a/airflow/providers/google/cloud/operators/dataproc_metastore.py b/airflow/providers/google/cloud/operators/dataproc_metastore.py index 89dae26e530a5..5e70ed95185e5 100644 --- a/airflow/providers/google/cloud/operators/dataproc_metastore.py +++ b/airflow/providers/google/cloud/operators/dataproc_metastore.py @@ -86,7 +86,7 @@ def get_link( ti_key: Optional["TaskInstanceKey"] = None, ) -> str: if ti_key: - conf = XCom.get_one(key=self.key, ti_key=ti_key) + conf = XCom.get_value(key=self.key, ti_key=ti_key) else: assert dttm conf = XCom.get_one( @@ -141,7 +141,7 @@ def get_link( ti_key: Optional["TaskInstanceKey"] = None, ) -> str: if ti_key: - conf = XCom.get_one(key=self.key, ti_key=ti_key) + conf = XCom.get_value(key=self.key, ti_key=ti_key) else: assert dttm conf = XCom.get_one( diff --git a/airflow/providers/google/cloud/operators/mlengine.py b/airflow/providers/google/cloud/operators/mlengine.py index f2784a08b7fb3..ce7d6ca9d51fd 100644 --- a/airflow/providers/google/cloud/operators/mlengine.py +++ b/airflow/providers/google/cloud/operators/mlengine.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains Google Cloud MLEngine operators.""" +import datetime import logging import re import warnings @@ -26,6 +27,7 @@ from airflow.providers.google.cloud.hooks.mlengine import MLEngineHook if TYPE_CHECKING: + from airflow.models.taskinstance import TaskInstanceKey from airflow.utils.context import Context @@ -978,10 +980,22 @@ class AIPlatformConsoleLink(BaseOperatorLink): name = "AI Platform Console" - def get_link(self, operator, dttm): - gcp_metadata_dict = XCom.get_one( - key="gcp_metadata", dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm - ) + def get_link( + self, + operator, + dttm: Optional[datetime.datetime] = None, + ti_key: Optional["TaskInstanceKey"] = None, + ) -> str: + if ti_key is not None: + gcp_metadata_dict = XCom.get_value(key="gcp_metadata", ti_key=ti_key) + else: + assert dttm is not None + gcp_metadata_dict = XCom.get_one( + key="gcp_metadata", + dag_id=operator.dag.dag_id, + task_id=operator.task_id, + execution_date=dttm, + ) if not gcp_metadata_dict: return '' job_id = gcp_metadata_dict['job_id'] diff --git a/airflow/providers/microsoft/azure/operators/data_factory.py b/airflow/providers/microsoft/azure/operators/data_factory.py index 2800fe900bdcb..88adc217d4fa1 100644 --- a/airflow/providers/microsoft/azure/operators/data_factory.py +++ b/airflow/providers/microsoft/azure/operators/data_factory.py @@ -43,7 +43,7 @@ def get_link( ti_key: Optional["TaskInstanceKey"] = None, ) -> str: if ti_key: - run_id = XCom.get_one(key="run_id", ti_key=ti_key) + run_id = XCom.get_value(key="run_id", ti_key=ti_key) else: assert dttm run_id = XCom.get_one( diff --git a/airflow/providers/qubole/operators/qubole.py b/airflow/providers/qubole/operators/qubole.py index c867536bf0153..ec4ce329a947b 100644 --- a/airflow/providers/qubole/operators/qubole.py +++ b/airflow/providers/qubole/operators/qubole.py @@ -64,7 +64,7 @@ def get_link( else: host = 'https://api.qubole.com/v2/analyze?command_id=' if ti_key: - 
qds_command_id = XCom.get_one(key='qbol_cmd_id', ti_key=ti_key) + qds_command_id = XCom.get_value(key='qbol_cmd_id', ti_key=ti_key) else: assert dttm qds_command_id = XCom.get_one( diff --git a/airflow/www/views.py b/airflow/www/views.py index ae70427fb3b66..4a624dd20d18a 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -3670,6 +3670,7 @@ def pre_add(self, item): task_id=item.task_id, dag_id=item.dag_id, run_id=item.run_id, + map_index=item.map_index, ) def pre_update(self, item): @@ -3681,6 +3682,7 @@ def pre_update(self, item): task_id=item.task_id, dag_id=item.dag_id, run_id=item.run_id, + map_index=item.map_index, ) diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py index 49663b28b2696..0d8181abf61c8 100644 --- a/tests/models/test_xcom.py +++ b/tests/models/test_xcom.py @@ -81,7 +81,7 @@ def test_resolve_xcom_class(self): def test_resolve_xcom_class_fallback_to_basexcom(self): cls = resolve_xcom_backend() assert issubclass(cls, BaseXCom) - assert cls().serialize_value([1]) == b"[1]" + assert cls.serialize_value([1]) == b"[1]" @conf_vars({("core", "enable_xcom_pickling"): "False"}) @conf_vars({("core", "xcom_backend"): "to be removed"}) @@ -89,7 +89,7 @@ def test_resolve_xcom_class_fallback_to_basexcom_no_config(self): conf.remove_option("core", "xcom_backend") cls = resolve_xcom_backend() assert issubclass(cls, BaseXCom) - assert cls().serialize_value([1]) == b"[1]" + assert cls.serialize_value([1]) == b"[1]" def test_xcom_deserialize_with_json_to_pickle_switch(self, dag_run, session): ti_key = TaskInstanceKey( @@ -226,7 +226,7 @@ def serialize_value( dag_id=None, task_id=None, run_id=None, - mapping_index: int = -1, + map_index=None, ): serialize_watcher( value=value, @@ -234,6 +234,7 @@ def serialize_value( dag_id=dag_id, task_id=task_id, run_id=run_id, + map_index=map_index, ) return json.dumps(value).encode('utf-8') @@ -245,6 +246,7 @@ def serialize_value( dag_id="test_dag", task_id="test_task", run_id=IN_MEMORY_RUN_ID, + 
map_index=-1, ) expected = {**kwargs, 'run_id': '__airflow_in_memory_dagrun__'} XCom = resolve_xcom_backend() diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 7144feddbbd39..bca754d3321df 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -920,7 +920,7 @@ class TaskStateLink(BaseOperatorLink): name = 'My Link' - def get_link(self, operator, dttm): + def get_link(self, operator, *, ti_key): return 'https://www.google.com' class MyOperator(BaseOperator):