Skip to content

Commit

Permalink
Fix SkipMixin with Database Isolation for AIP-44 (#40781)
Browse files Browse the repository at this point in the history
* Fix SkipMixin with Database Isolation for AIP-44

* Fix pytest of _log instance

* Fix pytests
  • Loading branch information
jscheffl authored Jul 17, 2024
1 parent 6fdc398 commit fec2b10
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 25 deletions.
3 changes: 3 additions & 0 deletions airflow/api_internal/endpoints/rpc_api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def _initialize_map() -> dict[str, Callable]:
from airflow.models.dagrun import DagRun
from airflow.models.dagwarning import DagWarning
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.skipmixin import SkipMixin
from airflow.models.taskinstance import (
TaskInstance,
_add_log,
Expand Down Expand Up @@ -110,6 +111,8 @@ def _initialize_map() -> dict[str, Callable]:
DagRun.fetch_task_instance,
DagRun._get_log_template,
SerializedDagModel.get_serialized_dag,
SkipMixin._skip,
SkipMixin._skip_all_except,
TaskInstance._check_and_change_state_before_execution,
TaskInstance.get_task_instance,
TaskInstance._get_dagrun,
Expand Down
1 change: 1 addition & 0 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,7 @@ def get_task_instance(
)

@staticmethod
@internal_api_call
@provide_session
def fetch_task_instance(
dag_id: str,
Expand Down
73 changes: 51 additions & 22 deletions airflow/models/skipmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@
from __future__ import annotations

import warnings
from types import GeneratorType
from typing import TYPE_CHECKING, Iterable, Sequence

from sqlalchemy import select, update

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import tuple_in_condition
from airflow.utils.state import TaskInstanceState

Expand Down Expand Up @@ -60,8 +62,8 @@ def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]:
class SkipMixin(LoggingMixin):
"""A Mixin to skip Tasks Instances."""

@staticmethod
def _set_state_to_skipped(
self,
dag_run: DagRun | DagRunPydantic,
tasks: Sequence[str] | Sequence[tuple[str, int]],
session: Session,
Expand Down Expand Up @@ -93,12 +95,28 @@ def _set_state_to_skipped(
.execution_options(synchronize_session=False)
)

@provide_session
def skip(
self,
dag_run: DagRun | DagRunPydantic,
execution_date: DateTime,
tasks: Iterable[DAGNode],
map_index: int = -1,
):
"""Facade for compatibility for call to internal API."""
# SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
task_id: str | None = getattr(self, "task_id", None)
SkipMixin._skip(
dag_run=dag_run, task_id=task_id, execution_date=execution_date, tasks=tasks, map_index=map_index
)

@staticmethod
@internal_api_call
@provide_session
def _skip(
dag_run: DagRun | DagRunPydantic,
task_id: str | None,
execution_date: DateTime,
tasks: Iterable[DAGNode],
session: Session = NEW_SESSION,
map_index: int = -1,
):
Expand Down Expand Up @@ -143,11 +161,9 @@ def skip(
raise ValueError("dag_run is required")

task_ids_list = [d.task_id for d in task_list]
self._set_state_to_skipped(dag_run, task_ids_list, session)
SkipMixin._set_state_to_skipped(dag_run, task_ids_list, session)
session.commit()

# SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
task_id: str | None = getattr(self, "task_id", None)
if task_id is not None:
from airflow.models.xcom import XCom

Expand All @@ -165,6 +181,21 @@ def skip_all_except(
self,
ti: TaskInstance | TaskInstancePydantic,
branch_task_ids: None | str | Iterable[str],
):
"""Facade for compatibility for call to internal API."""
# Ensure we don't serialize a generator object
if branch_task_ids and isinstance(branch_task_ids, GeneratorType):
branch_task_ids = list(branch_task_ids)
SkipMixin._skip_all_except(ti=ti, branch_task_ids=branch_task_ids)

@classmethod
@internal_api_call
@provide_session
def _skip_all_except(
cls,
ti: TaskInstance | TaskInstancePydantic,
branch_task_ids: None | str | Iterable[str],
session: Session = NEW_SESSION,
):
"""
Implement the logic for a branching operator.
Expand All @@ -175,6 +206,7 @@ def skip_all_except(
branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
newly added tasks should be skipped when they are cleared.
"""
log = cls().log # Note: need to catch logger form instance, static logger breaks pytest
if isinstance(branch_task_ids, str):
branch_task_id_set = {branch_task_ids}
elif isinstance(branch_task_ids, Iterable):
Expand All @@ -195,20 +227,15 @@ def skip_all_except(
f"but got {type(branch_task_ids).__name__!r}."
)

self.log.info("Following branch %s", branch_task_id_set)
log.info("Following branch %s", branch_task_id_set)

dag_run = ti.get_dagrun()
dag_run = ti.get_dagrun(session=session)
if TYPE_CHECKING:
assert isinstance(dag_run, DagRun)
assert ti.task

# TODO(potiuk): Handle TaskInstancePydantic case differently - we need to figure out the way to
# pass task that has been set in LocalTaskJob but in the way that TaskInstancePydantic definition
# does not attempt to serialize the field from/to ORM
task = ti.task
dag = task.dag
if TYPE_CHECKING:
assert dag
dag = TaskInstance.ensure_dag(ti, session=session)

valid_task_ids = set(dag.task_ids)
invalid_task_ids = branch_task_id_set - valid_task_ids
Expand Down Expand Up @@ -239,15 +266,17 @@ def skip_all_except(
skip_tasks = [
(t.task_id, downstream_ti.map_index)
for t in downstream_tasks
if (downstream_ti := dag_run.get_task_instance(t.task_id, map_index=ti.map_index))
if (
downstream_ti := dag_run.get_task_instance(
t.task_id, map_index=ti.map_index, session=session
)
)
and t.task_id not in branch_task_id_set
]

follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]
self.log.info("Skipping tasks %s", skip_tasks)
with create_session() as session:
self._set_state_to_skipped(dag_run, skip_tasks, session=session)
# For some reason, session.commit() needs to happen before xcom_push.
# Otherwise the session is not committed.
session.commit()
ti.xcom_push(key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids})
log.info("Skipping tasks %s", skip_tasks)
SkipMixin._set_state_to_skipped(dag_run, skip_tasks, session=session)
ti.xcom_push(
key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids}, session=session
)
16 changes: 16 additions & 0 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2633,6 +2633,22 @@ def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:

return dr

@classmethod
@provide_session
def ensure_dag(
cls, task_instance: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION
) -> DAG:
"""Ensure that task has a dag object associated, might have been removed by serialization."""
if TYPE_CHECKING:
assert task_instance.task
if task_instance.task.dag is None or task_instance.task.dag is ATTRIBUTE_REMOVED:
task_instance.task.dag = DagBag(read_dags_from_db=True).get_dag(
dag_id=task_instance.dag_id, session=session
)
if TYPE_CHECKING:
assert task_instance.task.dag
return task_instance.task.dag

@classmethod
@internal_api_call
@provide_session
Expand Down
6 changes: 3 additions & 3 deletions tests/models/test_skipmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_skip(self, mock_now, dag_maker):
execution_date=now,
state=State.FAILED,
)
SkipMixin().skip(dag_run=dag_run, execution_date=now, tasks=tasks, session=session)
SkipMixin().skip(dag_run=dag_run, execution_date=now, tasks=tasks)

session.query(TI).filter(
TI.dag_id == "dag",
Expand All @@ -91,7 +91,7 @@ def test_skip_none_dagrun(self, mock_now, dag_maker):
RemovedInAirflow3Warning,
match=r"Passing an execution_date to `skip\(\)` is deprecated in favour of passing a dag_run",
):
SkipMixin().skip(dag_run=None, execution_date=now, tasks=tasks, session=session)
SkipMixin().skip(dag_run=None, execution_date=now, tasks=tasks)

session.query(TI).filter(
TI.dag_id == "dag",
Expand All @@ -103,7 +103,7 @@ def test_skip_none_dagrun(self, mock_now, dag_maker):

def test_skip_none_tasks(self):
session = Mock()
SkipMixin().skip(dag_run=None, execution_date=None, tasks=[], session=session)
SkipMixin().skip(dag_run=None, execution_date=None, tasks=[])
assert not session.query.called
assert not session.commit.called

Expand Down

0 comments on commit fec2b10

Please sign in to comment.