Skip to content

Commit

Permalink
Update base sensor operator to support XCOM return value (apache#20656)
Browse files Browse the repository at this point in the history
Co-authored-by: mingshi <[email protected]>
  • Loading branch information
mingshi-wang and mingshi authored Mar 21, 2022
1 parent cef004d commit cd35972
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 5 deletions.
31 changes: 28 additions & 3 deletions airflow/sensors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import time
import warnings
from datetime import timedelta
from typing import Any, Callable, Iterable, Union
from typing import Any, Callable, Iterable, Optional, Union

from airflow import settings
from airflow.configuration import conf
Expand Down Expand Up @@ -56,6 +56,23 @@ def _is_metadatabase_mysql() -> bool:
return settings.engine.url.get_backend_name() == "mysql"


class PokeReturnValue:
    """
    Optional return value of a sensor's ``poke`` method.

    Sensors can optionally return an instance of the PokeReturnValue class in the poke method.
    If an XCom value is supplied when the sensor is done, then the XCom value will be
    pushed through the operator return value.

    The truthiness of an instance mirrors the truthiness of ``is_done``, so it can be
    used anywhere a plain boolean poke result is expected.

    :param is_done: Set to true to indicate the sensor can stop poking.
    :param xcom_value: An optional XCOM value to be returned by the operator.
    """

    def __init__(self, is_done: bool, xcom_value: Optional[Any] = None) -> None:
        self.xcom_value = xcom_value
        self.is_done = is_done

    def __bool__(self) -> bool:
        # ``__bool__`` must return a real ``bool``. Coerce here so that a truthy or
        # falsy non-bool ``is_done`` (e.g. ``1`` or ``None``) does not raise
        # ``TypeError`` when the instance is tested in an ``if``.
        return bool(self.is_done)


class BaseSensorOperator(BaseOperator, SkipMixin):
"""
Sensor operators are derived from this class and inherit these attributes.
Expand Down Expand Up @@ -150,7 +167,7 @@ def _validate_input_values(self) -> None:
f"mode since it will take reschedule time over MySQL's TIMESTAMP limit."
)

def poke(self, context: Context) -> bool:
def poke(self, context: Context) -> Union[bool, PokeReturnValue]:
"""
Function that the sensors defined while deriving this class should
override.
Expand Down Expand Up @@ -255,7 +272,14 @@ def run_duration() -> float:
try_number = 1
log_dag_id = self.dag.dag_id if self.has_dag() else ""

while not self.poke(context):
xcom_value = None
while True:
poke_return = self.poke(context)
if poke_return:
if isinstance(poke_return, PokeReturnValue):
xcom_value = poke_return.xcom_value
break

if run_duration() > self.timeout:
# If sensor is in soft fail mode but times out raise AirflowSkipException.
if self.soft_fail:
Expand All @@ -275,6 +299,7 @@ def run_duration() -> float:
time.sleep(self._get_next_poke_interval(started_at, run_duration, try_number))
try_number += 1
self.log.info("Success criteria met. Exiting.")
return xcom_value

def _get_next_poke_interval(
self,
Expand Down
40 changes: 40 additions & 0 deletions docs/apache-airflow-providers/howto/create-update-providers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,46 @@ this (note the ``if ti_key is not None:`` condition).
return BIGQUERY_JOB_DETAILS_LINK_FMT.format(job_id=job_id)
Having sensors return XCOM values
---------------------------------
In Airflow 2.3, sensor operators will be able to return XCOM values. This is achieved by returning an instance of the ``PokeReturnValue`` class from the ``poke()`` method:

.. code-block:: python
from airflow.sensors.base import PokeReturnValue
class SensorWithXcomValue(BaseSensorOperator):
    """Example sensor that reports its result through a ``PokeReturnValue``."""

    def poke(self, context: Context) -> Union[bool, PokeReturnValue]:
        # ...
        is_done = ...  # set to true if the sensor should stop poking.
        xcom_value = ...  # return value of the sensor operator to be pushed to XCOM.
        return PokeReturnValue(is_done=is_done, xcom_value=xcom_value)
To implement a sensor operator that pushes an XCOM value and supports both Airflow 2.3 and pre-2.3 versions, you need to explicitly push the XCOM value yourself when running on a pre-2.3 version.

.. code-block:: python
try:
from airflow.sensors.base import PokeReturnValue
except ImportError:
PokeReturnValue = None
class SensorWithXcomValue(BaseSensorOperator):
    """Example sensor that works whether or not ``PokeReturnValue`` could be imported."""

    def poke(self, context: Context) -> bool:
        # ...
        is_done = ...  # set to true if the sensor should stop poking.
        xcom_value = ...  # return value of the sensor operator to be pushed to XCOM.

        # ``PokeReturnValue`` is available: hand the XCOM value back with the result.
        if PokeReturnValue is not None:
            return PokeReturnValue(is_done, xcom_value)

        # ``PokeReturnValue`` is unavailable: push the XCOM value explicitly once done.
        if is_done:
            context["ti"].xcom_push(key="xcom_key", value=xcom_value)
        return is_done
How-to Update a community provider
----------------------------------

Expand Down
53 changes: 51 additions & 2 deletions tests/sensors/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@

from airflow.exceptions import AirflowException, AirflowRescheduleException, AirflowSensorTimeout
from airflow.models import TaskReschedule
from airflow.models.xcom import XCom
from airflow.operators.dummy import DummyOperator
from airflow.sensors.base import BaseSensorOperator, poke_mode_only
from airflow.sensors.base import BaseSensorOperator, PokeReturnValue, poke_mode_only
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
from airflow.utils import timezone
from airflow.utils.context import Context
Expand All @@ -48,6 +49,16 @@ def poke(self, context: Context):
return self.return_value


class DummySensorWithXcomValue(BaseSensorOperator):
    """
    Test sensor whose ``poke`` wraps fixed values in a ``PokeReturnValue``.

    :param return_value: passed as ``is_done`` to the returned ``PokeReturnValue``.
    :param xcom_value: passed as ``xcom_value`` to the returned ``PokeReturnValue``.
    """

    def __init__(self, return_value=False, xcom_value=None, **kwargs):
        super().__init__(**kwargs)
        self.return_value = return_value
        self.xcom_value = xcom_value

    def poke(self, context: Context):
        """Return a ``PokeReturnValue`` built from the configured values."""
        return PokeReturnValue(is_done=self.return_value, xcom_value=self.xcom_value)


class TestBaseSensor:
@staticmethod
def clean_db():
Expand Down Expand Up @@ -78,7 +89,10 @@ def _make_sensor(return_value, task_id=SENSOR_OP, **kwargs):
kwargs[timeout] = 0

with dag_maker(TEST_DAG_ID):
sensor = DummySensor(task_id=task_id, return_value=return_value, **kwargs)
if "xcom_value" in kwargs:
sensor = DummySensorWithXcomValue(task_id=task_id, return_value=return_value, **kwargs)
else:
sensor = DummySensor(task_id=task_id, return_value=return_value, **kwargs)

dummy_op = DummyOperator(task_id=DUMMY_OP)
sensor >> dummy_op
Expand Down Expand Up @@ -607,6 +621,41 @@ def assert_ti_state(try_number, max_tries, state):
self._run(sensor)
assert_ti_state(4, 4, State.FAILED)

def test_sensor_with_xcom(self, make_sensor):
    """A sensor that is done on its first poke pushes ``xcom_value`` as its return value."""
    expected = "TestValue"
    sensor, dr = make_sensor(True, xcom_value=expected)

    self._run(sensor)

    task_instances = dr.get_task_instances()
    assert len(task_instances) == 2
    for task_instance in task_instances:
        if task_instance.task_id == SENSOR_OP:
            assert task_instance.state == State.SUCCESS
        if task_instance.task_id == DUMMY_OP:
            assert task_instance.state == State.NONE

    # The value carried by PokeReturnValue is stored under the "return_value" key.
    pushed = XCom.get_one(
        key="return_value",
        task_id=SENSOR_OP,
        dag_id=dr.dag_id,
        run_id=dr.run_id,
    )
    assert pushed == expected

def test_sensor_with_xcom_fails(self, make_sensor):
    """A sensor that never reports done times out and leaves no XCom return value."""
    sensor, dr = make_sensor(False, xcom_value="TestValue")

    with pytest.raises(AirflowSensorTimeout):
        self._run(sensor)

    task_instances = dr.get_task_instances()
    assert len(task_instances) == 2
    for task_instance in task_instances:
        if task_instance.task_id == SENSOR_OP:
            assert task_instance.state == State.FAILED
        if task_instance.task_id == DUMMY_OP:
            assert task_instance.state == State.NONE

    # Nothing should be stored under "return_value" because the sensor never finished.
    pushed = XCom.get_one(
        key="return_value",
        task_id=SENSOR_OP,
        dag_id=dr.dag_id,
        run_id=dr.run_id,
    )
    assert pushed is None


@poke_mode_only
class DummyPokeOnlySensor(BaseSensorOperator):
Expand Down

0 comments on commit cd35972

Please sign in to comment.