Ensure deps is set, convert BaseSensorOperator to classvar (#21815)
uranusjr authored Mar 2, 2022
1 parent 2c57ad4 commit 8b276c6
Showing 7 changed files with 36 additions and 34 deletions.
8 changes: 8 additions & 0 deletions UPDATING.md
@@ -214,6 +214,14 @@ In order to support that cleanly we have changed the interface for BaseOperatorLink

The existing signature will be detected (by the absence of the `ti_key` argument) and continue to work.
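
As an illustrative sketch (the link classes and URLs below are hypothetical, not part of this commit), both `get_link` signatures continue to work side by side:

```python
from airflow.models.baseoperator import BaseOperatorLink


class LegacyLink(BaseOperatorLink):
    """Old-style link: detected by the absence of ``ti_key`` and still supported."""

    name = "Legacy Link"

    def get_link(self, operator, dttm):
        return f"https://example.com/{operator.task_id}/{dttm.isoformat()}"


class NewStyleLink(BaseOperatorLink):
    """New-style link: receives the task instance key instead of the execution date."""

    name = "New Style Link"

    def get_link(self, operator, *, ti_key):
        return f"https://example.com/{operator.task_id}/{ti_key.run_id}"
```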

### `ReadyToRescheduleDep` now only runs when `reschedule` is *True*

When a `ReadyToRescheduleDep` is run, it now checks the `reschedule` attribute on the operator, and always reports itself as *passed* unless that attribute is set to *True*. If you use this dep class on your custom operator, you will need to add this attribute to the operator class. Built-in operator classes that use this dep class (including sensors and all their subclasses) already have this attribute and are not affected.
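
For example, a custom operator that relies on this dep might look like the following sketch (the operator name and poke logic are illustrative, not taken from this commit):

```python
from airflow.models.baseoperator import BaseOperator
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep


class MyReschedulingOperator(BaseOperator):
    """Hypothetical operator that uses ReadyToRescheduleDep directly."""

    # Without this attribute the dep now always reports itself as passed.
    reschedule = True

    # deps must be a class-level set (see the next section).
    deps = BaseOperator.deps | {ReadyToRescheduleDep()}

    def execute(self, context):
        ...
```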

### The `deps` attribute on an operator class should be a class-level attribute

To support operator mapping (AIP-42), the `deps` attribute on an operator class must be defined as a set at the class level. If a custom operator implements `deps` as an instance-level variable (for example via a property), it cannot be used for operator mapping. This does not affect existing code, but we highly recommend restructuring the operator's dep logic to support the new feature.
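
A minimal before/after sketch (class names are illustrative):

```python
from airflow.models.baseoperator import BaseOperator
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep


class InstanceLevelDepsOperator(BaseOperator):
    """Rejected for mapping: ``deps`` is only available on instances."""

    @property
    def deps(self):
        return super().deps | {ReadyToRescheduleDep()}


class ClassLevelDepsOperator(BaseOperator):
    """Mappable: ``deps`` is a set defined on the class itself."""

    deps = BaseOperator.deps | {ReadyToRescheduleDep()}
```

This mirrors the change made to `BaseSensorOperator` in this commit, where the `deps` property was replaced with a class-level set.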

## Airflow 2.2.4

### Smart sensors deprecated
4 changes: 4 additions & 0 deletions airflow/exceptions.py
@@ -98,6 +98,10 @@ class AirflowOptionalProviderFeatureException(AirflowException):
"""Raise by providers when imports are missing for optional provider features."""


class UnmappableOperator(AirflowException):
"""Raise when an operator is not implemented to be mappable."""


class UnmappableXComTypePushed(AirflowException):
"""Raise when an unmappable type is pushed as a mapped downstream's dependency."""

13 changes: 10 additions & 3 deletions airflow/models/mappedoperator.py
@@ -46,6 +46,7 @@
from sqlalchemy.orm.session import Session

from airflow.compat.functools import cache
from airflow.exceptions import UnmappableOperator
from airflow.models.abstractoperator import (
DEFAULT_OWNER,
DEFAULT_POOL_SLOTS,
@@ -66,7 +67,6 @@
from airflow.utils.context import Context
from airflow.utils.operator_resources import Resources
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.task_group import TaskGroup
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET

@@ -77,6 +77,7 @@
from airflow.models.dag import DAG
from airflow.models.taskinstance import TaskInstance
from airflow.models.xcom_arg import XComArg
from airflow.utils.task_group import TaskGroup

# BaseOperator.apply() can be called on an XComArg, sequence, or dict (not
# any mapping since we need the value to be ordered).
@@ -240,7 +241,7 @@ class MappedOperator(AbstractOperator):
_task_type: str

dag: Optional["DAG"]
task_group: Optional[TaskGroup]
task_group: Optional["TaskGroup"]
start_date: Optional[pendulum.DateTime]
end_date: Optional[pendulum.DateTime]
upstream_task_ids: Set[str] = attr.ib(factory=set, init=False)
@@ -284,7 +285,13 @@ def get_serialized_fields(cls):
@staticmethod
@cache
def deps_for(operator_class: Type["BaseOperator"]) -> FrozenSet[BaseTIDep]:
return operator_class.deps | {MappedTaskIsExpanded()}
operator_deps = operator_class.deps
if not isinstance(operator_deps, collections.abc.Set):
raise UnmappableOperator(
f"'deps' must be a set defined as a class-level variable on {operator_class.__name__}, "
f"not a {type(operator_deps).__name__}"
)
return operator_deps | {MappedTaskIsExpanded()}

def _validate_argument_count(self) -> None:
"""Validate mapping arguments by unmapping with mocked values.
14 changes: 4 additions & 10 deletions airflow/sensors/base.py
@@ -102,6 +102,10 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
'email_on_failure',
)

# Adds one additional dependency for all sensor operators that checks if a
# sensor task instance can be rescheduled.
deps = BaseOperator.deps | {ReadyToRescheduleDep()}

def __init__(
self,
*,
@@ -310,16 +314,6 @@ def reschedule(self):
"""Define mode rescheduled sensors."""
return self.mode == 'reschedule'

@property
def deps(self):
"""
Adds one additional dependency for all sensor operators that
checks if a sensor task instance can be rescheduled.
"""
if self.reschedule:
return super().deps | {ReadyToRescheduleDep()}
return super().deps


def poke_mode_only(cls):
"""
4 changes: 4 additions & 0 deletions airflow/ti_deps/deps/ready_to_reschedule.py
@@ -40,6 +40,10 @@ def _get_dep_statuses(self, ti, session, dep_context):
considered as passed. This dependency fails if the latest reschedule
request's reschedule date is still in future.
"""
if not getattr(ti.task, "reschedule", False):
yield self._passing_status(reason="Task is not in reschedule mode.")
return

if dep_context.ignore_in_reschedule_period:
yield self._passing_status(
reason="The context specified that being in a reschedule period was permitted."
10 changes: 3 additions & 7 deletions tests/sensors/test_base.py
@@ -332,16 +332,12 @@ def test_ok_with_reschedule_and_retry(self, make_sensor):
if ti.task_id == DUMMY_OP:
assert ti.state == State.NONE

def test_should_include_ready_to_reschedule_dep_in_reschedule_mode(self):
sensor = DummySensor(task_id='a', return_value=True, mode='reschedule')
@pytest.mark.parametrize("mode", ["poke", "reschedule"])
def test_should_include_ready_to_reschedule_dep(self, mode):
sensor = DummySensor(task_id='a', return_value=True, mode=mode)
deps = sensor.deps
assert ReadyToRescheduleDep() in deps

def test_should_not_include_ready_to_reschedule_dep_in_poke_mode(self, make_sensor):
sensor = DummySensor(task_id='a', return_value=False, mode='poke')
deps = sensor.deps
assert ReadyToRescheduleDep() not in deps

def test_invalid_mode(self):
with pytest.raises(AirflowException):
DummySensor(task_id='a', mode='foo')
17 changes: 3 additions & 14 deletions tests/serialization/test_dag_serialization.py
@@ -1330,14 +1330,8 @@ def test_edge_info_serialization(self):

assert serialized_dag.edge_info == dag.edge_info

@pytest.mark.parametrize(
"mode, expect_custom_deps",
[
("poke", False),
("reschedule", True),
],
)
def test_serialize_sensor(self, mode, expect_custom_deps):
@pytest.mark.parametrize("mode", ["poke", "reschedule"])
def test_serialize_sensor(self, mode):
from airflow.sensors.base import BaseSensorOperator

class DummySensor(BaseSensorOperator):
@@ -1347,14 +1341,9 @@ def poke(self, context: Context):
op = DummySensor(task_id='dummy', mode=mode, poke_interval=23)

blob = SerializedBaseOperator.serialize_operator(op)

if expect_custom_deps:
assert "deps" in blob
else:
assert "deps" not in blob
assert "deps" in blob

serialized_op = SerializedBaseOperator.deserialize_operator(blob)

assert op.deps == serialized_op.deps

@pytest.mark.parametrize(
