From 8b276c6fc191254d96451958609faf81db994b94 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 2 Mar 2022 22:19:45 +0800 Subject: [PATCH] Ensure deps is set, convert BaseSensorOperator to classvar (#21815) --- UPDATING.md | 8 ++++++++ airflow/exceptions.py | 4 ++++ airflow/models/mappedoperator.py | 13 ++++++++++--- airflow/sensors/base.py | 14 ++++---------- airflow/ti_deps/deps/ready_to_reschedule.py | 4 ++++ tests/sensors/test_base.py | 10 +++------- tests/serialization/test_dag_serialization.py | 17 +++-------------- 7 files changed, 36 insertions(+), 34 deletions(-) diff --git a/UPDATING.md b/UPDATING.md index 492cfcdefe4a8..6ca92687cb69c 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -214,6 +214,14 @@ In order to support that cleanly we have changed the interface for BaseOperatorL The existing signature will be detected (by the absence of the `ti_key` argument) and continue to work. +### `ReadyToRescheduleDep` now only runs when `reschedule` is *True* + +When a `ReadyToRescheduleDep` is run, it now checks the `reschedule` attribute on the operator, and always reports itself as *passed* unless it is set to *True*. If you use this dep class on your custom operator, you will need to add this attribute to the operator class. Built-in operator classes that use this dep class (including sensors and all subclasses) already have this attribute and are not affected. + +### The `deps` attribute on an operator class should be a class level attribute + +To support operator-mapping (AIP 42), the `deps` attribute on an operator class must be a set at the class level. This means that if a custom operator implements this as an instance-level variable, it will not be able to be used for operator-mapping. This does not affect existing code, but we highly recommend that you restructure the operator's dep logic in order to support the new feature. 
+ ## Airflow 2.2.4 ### Smart sensors deprecated diff --git a/airflow/exceptions.py b/airflow/exceptions.py index d77bd61e7ebe2..630268b01f2d1 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -98,6 +98,10 @@ class AirflowOptionalProviderFeatureException(AirflowException): """Raise by providers when imports are missing for optional provider features.""" +class UnmappableOperator(AirflowException): + """Raise when an operator is not implemented to be mappable.""" + + class UnmappableXComTypePushed(AirflowException): """Raise when an unmappable type is pushed as a mapped downstream's dependency.""" diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index e9f614e4efab6..578a86026f4dd 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -46,6 +46,7 @@ from sqlalchemy.orm.session import Session from airflow.compat.functools import cache +from airflow.exceptions import UnmappableOperator from airflow.models.abstractoperator import ( DEFAULT_OWNER, DEFAULT_POOL_SLOTS, @@ -66,7 +67,6 @@ from airflow.utils.context import Context from airflow.utils.operator_resources import Resources from airflow.utils.state import State, TaskInstanceState -from airflow.utils.task_group import TaskGroup from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import NOTSET @@ -77,6 +77,7 @@ from airflow.models.dag import DAG from airflow.models.taskinstance import TaskInstance from airflow.models.xcom_arg import XComArg + from airflow.utils.task_group import TaskGroup # BaseOperator.apply() can be called on an XComArg, sequence, or dict (not # any mapping since we need the value to be ordered). 
@@ -240,7 +241,7 @@ class MappedOperator(AbstractOperator): _task_type: str dag: Optional["DAG"] - task_group: Optional[TaskGroup] + task_group: Optional["TaskGroup"] start_date: Optional[pendulum.DateTime] end_date: Optional[pendulum.DateTime] upstream_task_ids: Set[str] = attr.ib(factory=set, init=False) @@ -284,7 +285,13 @@ def get_serialized_fields(cls): @staticmethod @cache def deps_for(operator_class: Type["BaseOperator"]) -> FrozenSet[BaseTIDep]: - return operator_class.deps | {MappedTaskIsExpanded()} + operator_deps = operator_class.deps + if not isinstance(operator_deps, collections.abc.Set): + raise UnmappableOperator( + f"'deps' must be a set defined as a class-level variable on {operator_class.__name__}, " + f"not a {type(operator_deps).__name__}" + ) + return operator_deps | {MappedTaskIsExpanded()} def _validate_argument_count(self) -> None: """Validate mapping arguments by unmapping with mocked values. diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py index e063a2f9ce776..235d6eb6d555a 100644 --- a/airflow/sensors/base.py +++ b/airflow/sensors/base.py @@ -102,6 +102,10 @@ class BaseSensorOperator(BaseOperator, SkipMixin): 'email_on_failure', ) + # Adds one additional dependency for all sensor operators that checks if a + # sensor task instance can be rescheduled. + deps = BaseOperator.deps | {ReadyToRescheduleDep()} + def __init__( self, *, @@ -310,16 +314,6 @@ def reschedule(self): """Define mode rescheduled sensors.""" return self.mode == 'reschedule' - @property - def deps(self): - """ - Adds one additional dependency for all sensor operators that - checks if a sensor task instance can be rescheduled. 
- """ - if self.reschedule: - return super().deps | {ReadyToRescheduleDep()} - return super().deps - def poke_mode_only(cls): """ diff --git a/airflow/ti_deps/deps/ready_to_reschedule.py b/airflow/ti_deps/deps/ready_to_reschedule.py index e867c788859cc..9086822ceac8d 100644 --- a/airflow/ti_deps/deps/ready_to_reschedule.py +++ b/airflow/ti_deps/deps/ready_to_reschedule.py @@ -40,6 +40,10 @@ def _get_dep_statuses(self, ti, session, dep_context): considered as passed. This dependency fails if the latest reschedule request's reschedule date is still in future. """ + if not getattr(ti.task, "reschedule", False): + yield self._passing_status(reason="Task is not in reschedule mode.") + return + if dep_context.ignore_in_reschedule_period: yield self._passing_status( reason="The context specified that being in a reschedule period was permitted." diff --git a/tests/sensors/test_base.py b/tests/sensors/test_base.py index 5271c0d9eb25f..6e580b6deea8c 100644 --- a/tests/sensors/test_base.py +++ b/tests/sensors/test_base.py @@ -332,16 +332,12 @@ def test_ok_with_reschedule_and_retry(self, make_sensor): if ti.task_id == DUMMY_OP: assert ti.state == State.NONE - def test_should_include_ready_to_reschedule_dep_in_reschedule_mode(self): - sensor = DummySensor(task_id='a', return_value=True, mode='reschedule') + @pytest.mark.parametrize("mode", ["poke", "reschedule"]) + def test_should_include_ready_to_reschedule_dep(self, mode): + sensor = DummySensor(task_id='a', return_value=True, mode=mode) deps = sensor.deps assert ReadyToRescheduleDep() in deps - def test_should_not_include_ready_to_reschedule_dep_in_poke_mode(self, make_sensor): - sensor = DummySensor(task_id='a', return_value=False, mode='poke') - deps = sensor.deps - assert ReadyToRescheduleDep() not in deps - def test_invalid_mode(self): with pytest.raises(AirflowException): DummySensor(task_id='a', mode='foo') diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 
0a6f56c430830..aa25f56e8ada4 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -1330,14 +1330,8 @@ def test_edge_info_serialization(self): assert serialized_dag.edge_info == dag.edge_info - @pytest.mark.parametrize( - "mode, expect_custom_deps", - [ - ("poke", False), - ("reschedule", True), - ], - ) - def test_serialize_sensor(self, mode, expect_custom_deps): + @pytest.mark.parametrize("mode", ["poke", "reschedule"]) + def test_serialize_sensor(self, mode): from airflow.sensors.base import BaseSensorOperator class DummySensor(BaseSensorOperator): @@ -1347,14 +1341,9 @@ def poke(self, context: Context): op = DummySensor(task_id='dummy', mode=mode, poke_interval=23) blob = SerializedBaseOperator.serialize_operator(op) - - if expect_custom_deps: - assert "deps" in blob - else: - assert "deps" not in blob + assert "deps" in blob serialized_op = SerializedBaseOperator.deserialize_operator(blob) - assert op.deps == serialized_op.deps @pytest.mark.parametrize(