diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py index 235d6eb6d555a..72992b2170e97 100644 --- a/airflow/sensors/base.py +++ b/airflow/sensors/base.py @@ -22,7 +22,7 @@ import time import warnings from datetime import timedelta -from typing import Any, Callable, Iterable, Union +from typing import Any, Callable, Iterable, Optional, Union from airflow import settings from airflow.configuration import conf @@ -56,6 +56,23 @@ def _is_metadatabase_mysql() -> bool: return settings.engine.url.get_backend_name() == "mysql" +class PokeReturnValue: + """ + Sensors can optionally return an instance of the PokeReturnValue class in the poke method. + If an XCom value is supplied when the sensor is done, then the XCom value will be + pushed through the operator return value. + :param is_done: Set to true to indicate the sensor can stop poking. + :param xcom_value: An optional XCOM value to be returned by the operator. + """ + + def __init__(self, is_done: bool, xcom_value: Optional[Any] = None) -> None: + self.xcom_value = xcom_value + self.is_done = is_done + + def __bool__(self) -> bool: + return self.is_done + + class BaseSensorOperator(BaseOperator, SkipMixin): """ Sensor operators are derived from this class and inherit these attributes. @@ -150,7 +167,7 @@ def _validate_input_values(self) -> None: f"mode since it will take reschedule time over MySQL's TIMESTAMP limit." ) - def poke(self, context: Context) -> bool: + def poke(self, context: Context) -> Union[bool, PokeReturnValue]: """ Function that the sensors defined while deriving this class should override. @@ -255,7 +272,14 @@ def run_duration() -> float: try_number = 1 log_dag_id = self.dag.dag_id if self.has_dag() else "" - while not self.poke(context): + xcom_value = None + while True: + poke_return = self.poke(context) + if poke_return: + if isinstance(poke_return, PokeReturnValue): + xcom_value = poke_return.xcom_value + break + if run_duration() > self.timeout: # If sensor is in soft fail mode but times out raise AirflowSkipException. if self.soft_fail: @@ -275,6 +299,7 @@ def run_duration() -> float: time.sleep(self._get_next_poke_interval(started_at, run_duration, try_number)) try_number += 1 self.log.info("Success criteria met. Exiting.") + return xcom_value def _get_next_poke_interval( self, diff --git a/docs/apache-airflow-providers/howto/create-update-providers.rst b/docs/apache-airflow-providers/howto/create-update-providers.rst index 5fdb717ab6d21..6cb5886e84e3d 100644 --- a/docs/apache-airflow-providers/howto/create-update-providers.rst +++ b/docs/apache-airflow-providers/howto/create-update-providers.rst @@ -389,6 +389,46 @@ this (note the ``if ti_key is not None:`` condition). return BIGQUERY_JOB_DETAILS_LINK_FMT.format(job_id=job_id) +Having sensors return XOM values +-------------------------------- +In Airflow 2.3, sensor operators will be able to return XCOM values. This is achieved by returning an instance of the ``PokeReturnValue`` object at the end of the ``poke()`` method: + + .. code-block:: python + + from airflow.sensors.base import PokeReturnValue + + + class SensorWithXcomValue(BaseSensorOperator): + def poke(self, context: Context) -> Union[bool, PokeReturnValue]: + # ... + is_done = ... # set to true if the sensor should stop poking. + xcom_value = ... # return value of the sensor operator to be pushed to XCOM. + return PokeReturnValue(is_done, xcom_value) + + +To implement a sensor operator that pushes a XCOM value and supports both version 2.3 and pre-2.3, you need to explicitly push the XCOM value if the version is pre-2.3. + + .. code-block:: python + + try: + from airflow.sensors.base import PokeReturnValue + except ImportError: + PokeReturnValue = None + + + class SensorWithXcomValue(BaseSensorOperator): + def poke(self, context: Context) -> bool: + # ... + is_done = ... # set to true if the sensor should stop poking. + xcom_value = ... # return value of the sensor operator to be pushed to XCOM. + if PokeReturnValue is not None: + return PokeReturnValue(is_done, xcom_value) + else: + if is_done: + context["ti"].xcom_push(key="xcom_key", value=xcom_value) + return is_done + + How-to Update a community provider ---------------------------------- diff --git a/tests/sensors/test_base.py b/tests/sensors/test_base.py index 6e580b6deea8c..e77b61c6ccaa6 100644 --- a/tests/sensors/test_base.py +++ b/tests/sensors/test_base.py @@ -23,8 +23,9 @@ from airflow.exceptions import AirflowException, AirflowRescheduleException, AirflowSensorTimeout from airflow.models import TaskReschedule +from airflow.models.xcom import XCom from airflow.operators.dummy import DummyOperator -from airflow.sensors.base import BaseSensorOperator, poke_mode_only +from airflow.sensors.base import BaseSensorOperator, PokeReturnValue, poke_mode_only from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep from airflow.utils import timezone from airflow.utils.context import Context @@ -48,6 +49,16 @@ def poke(self, context: Context): return self.return_value +class DummySensorWithXcomValue(BaseSensorOperator): + def __init__(self, return_value=False, xcom_value=None, **kwargs): + super().__init__(**kwargs) + self.xcom_value = xcom_value + self.return_value = return_value + + def poke(self, context: Context): + return PokeReturnValue(self.return_value, self.xcom_value) + + class TestBaseSensor: @staticmethod def clean_db(): @@ -78,7 +89,10 @@ def _make_sensor(return_value, task_id=SENSOR_OP, **kwargs): kwargs[timeout] = 0 with dag_maker(TEST_DAG_ID): - sensor = DummySensor(task_id=task_id, return_value=return_value, **kwargs) + if "xcom_value" in kwargs: + sensor = DummySensorWithXcomValue(task_id=task_id, return_value=return_value, **kwargs) + else: + sensor = DummySensor(task_id=task_id, return_value=return_value, **kwargs) dummy_op = DummyOperator(task_id=DUMMY_OP) sensor >> dummy_op @@ -607,6 +621,41 @@ def assert_ti_state(try_number, max_tries, state): self._run(sensor) assert_ti_state(4, 4, State.FAILED) + def test_sensor_with_xcom(self, make_sensor): + xcom_value = "TestValue" + sensor, dr = make_sensor(True, xcom_value=xcom_value) + + self._run(sensor) + tis = dr.get_task_instances() + assert len(tis) == 2 + for ti in tis: + if ti.task_id == SENSOR_OP: + assert ti.state == State.SUCCESS + if ti.task_id == DUMMY_OP: + assert ti.state == State.NONE + actual_xcom_value = XCom.get_one( + key="return_value", task_id=SENSOR_OP, dag_id=dr.dag_id, run_id=dr.run_id + ) + assert actual_xcom_value == xcom_value + + def test_sensor_with_xcom_fails(self, make_sensor): + xcom_value = "TestValue" + sensor, dr = make_sensor(False, xcom_value=xcom_value) + + with pytest.raises(AirflowSensorTimeout): + self._run(sensor) + tis = dr.get_task_instances() + assert len(tis) == 2 + for ti in tis: + if ti.task_id == SENSOR_OP: + assert ti.state == State.FAILED + if ti.task_id == DUMMY_OP: + assert ti.state == State.NONE + actual_xcom_value = XCom.get_one( + key="return_value", task_id=SENSOR_OP, dag_id=dr.dag_id, run_id=dr.run_id + ) + assert actual_xcom_value is None + @poke_mode_only class DummyPokeOnlySensor(BaseSensorOperator):