Skip to content

Commit

Permalink
Run triggers inline with dag test (#34642)
Browse files Browse the repository at this point in the history
No need to have trigger running -- will just run them async.

(cherry picked from commit 7b37a78)
  • Loading branch information
dstandish authored and ephraimbuddy committed Dec 5, 2023
1 parent c28ba46 commit 28897f7
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 73 deletions.
68 changes: 30 additions & 38 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
# under the License.
from __future__ import annotations

import collections.abc
import asyncio
import collections
import copy
import functools
import itertools
Expand Down Expand Up @@ -82,11 +83,11 @@
from airflow.exceptions import (
AirflowDagInconsistent,
AirflowException,
AirflowSkipException,
DuplicateTaskIdFound,
FailStopDagInvalidTriggerRule,
ParamValidationError,
RemovedInAirflow3Warning,
TaskDeferred,
TaskNotFound,
)
from airflow.jobs.job import run_job
Expand All @@ -101,7 +102,6 @@
Context,
TaskInstance,
TaskInstanceKey,
TaskReturnCode,
clear_task_instances,
)
from airflow.secrets.local_filesystem import LocalFilesystemBackend
Expand Down Expand Up @@ -285,12 +285,11 @@ def get_dataset_triggered_next_run_info(
}


class _StopDagTest(Exception):
"""
Raise when DAG.test should stop immediately.
def _triggerer_is_healthy():
from airflow.jobs.triggerer_job_runner import TriggererJobRunner

:meta private:
"""
job = TriggererJobRunner.most_recent_job()
return job and job.is_alive()


@functools.total_ordering
Expand Down Expand Up @@ -2844,21 +2843,12 @@ def add_logger_if_needed(ti: TaskInstance):
if not scheduled_tis and ids_unrunnable:
self.log.warning("No tasks to run. unrunnable tasks: %s", ids_unrunnable)
time.sleep(1)
triggerer_running = _triggerer_is_healthy()
for ti in scheduled_tis:
try:
add_logger_if_needed(ti)
ti.task = tasks[ti.task_id]
ret = _run_task(ti, session=session)
if ret is TaskReturnCode.DEFERRED:
if not _triggerer_is_healthy():
raise _StopDagTest(
"Task has deferred but triggerer component is not running. "
"You can start the triggerer by running `airflow triggerer` in a terminal."
)
except _StopDagTest:
# Let this exception bubble out and not be swallowed by the
# except block below.
raise
_run_task(ti=ti, inline_trigger=not triggerer_running, session=session)
except Exception:
self.log.exception("Task failed; ti=%s", ti)
if conn_file_path or variable_file_path:
Expand Down Expand Up @@ -3992,14 +3982,15 @@ def get_current_dag(cls) -> DAG | None:
return None


def _triggerer_is_healthy():
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
def _run_trigger(trigger):
async def _run_trigger_main():
async for event in trigger.run():
return event

job = TriggererJobRunner.most_recent_job()
return job and job.is_alive()
return asyncio.run(_run_trigger_main())


def _run_task(ti: TaskInstance, session) -> TaskReturnCode | None:
def _run_task(*, ti: TaskInstance, inline_trigger: bool = False, session: Session):
"""
Run a single task instance, and push result to Xcom for downstream tasks.
Expand All @@ -4009,20 +4000,21 @@ def _run_task(ti: TaskInstance, session) -> TaskReturnCode | None:
Args:
ti: TaskInstance to run
"""
ret = None
log.info("*****************************************************")
if ti.map_index > 0:
log.info("Running task %s index %d", ti.task_id, ti.map_index)
else:
log.info("Running task %s", ti.task_id)
try:
ret = ti._run_raw_task(session=session)
session.flush()
log.info("%s ran successfully!", ti.task_id)
except AirflowSkipException:
log.info("Task Skipped, continuing")
log.info("*****************************************************")
return ret
log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, ti.map_index)
while True:
try:
log.info("[DAG TEST] running task %s", ti)
ti._run_raw_task(session=session, raise_on_defer=inline_trigger)
break
except TaskDeferred as e:
log.info("[DAG TEST] running trigger in line")
event = _run_trigger(e.trigger)
ti.next_method = e.method_name
ti.next_kwargs = {"event": event.payload} if event else e.kwargs
log.info("[DAG TEST] Trigger completed")
session.merge(ti)
session.commit()
log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id, ti.map_index)


def _get_or_create_dagrun(
Expand Down
3 changes: 3 additions & 0 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2284,6 +2284,7 @@ def _run_raw_task(
test_mode: bool = False,
job_id: str | None = None,
pool: str | None = None,
raise_on_defer: bool = False,
session: Session = NEW_SESSION,
) -> TaskReturnCode | None:
"""
Expand Down Expand Up @@ -2338,6 +2339,8 @@ def _run_raw_task(
except TaskDeferred as defer:
# The task has signalled it wants to defer execution based on
# a trigger.
if raise_on_defer:
raise
self._defer_task(defer=defer, session=session)
self.log.info(
"Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s",
Expand Down
81 changes: 47 additions & 34 deletions tests/cli/commands/test_dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@
from airflow.exceptions import AirflowException
from airflow.models import DagBag, DagModel, DagRun
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import _StopDagTest
from airflow.models.dag import _run_trigger
from airflow.models.serialized_dag import SerializedDagModel
from airflow.triggers.temporal import TimeDeltaTrigger
from airflow.triggers.base import TriggerEvent
from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.types import DagRunType
Expand Down Expand Up @@ -824,35 +825,47 @@ def test_dag_test_with_custom_timetable(self, mock__get_or_create_dagrun, _):
dag_command.dag_test(cli_args)
assert "data_interval" in mock__get_or_create_dagrun.call_args.kwargs

def test_dag_test_no_triggerer(self, dag_maker):
with dag_maker() as dag:

@task
def one():
return 1

@task
def two(val):
return val + 1

class MyOp(BaseOperator):
template_fields = ("tfield",)

def __init__(self, tfield, **kwargs):
self.tfield = tfield
super().__init__(**kwargs)

def execute(self, context, event=None):
if event is None:
print("I AM DEFERRING")
self.defer(trigger=TimeDeltaTrigger(timedelta(seconds=20)), method_name="execute")
return
print("RESUMING")
return self.tfield + 1

task_one = one()
task_two = two(task_one)
op = MyOp(task_id="abc", tfield=str(task_two))
task_two >> op
with pytest.raises(_StopDagTest, match="Task has deferred but triggerer component is not running"):
dag.test()
def test_dag_test_run_trigger(self, dag_maker):
now = timezone.utcnow()
trigger = DateTimeTrigger(moment=now)
e = _run_trigger(trigger)
assert isinstance(e, TriggerEvent)
assert e.payload == now

def test_dag_test_no_triggerer_running(self, dag_maker):
with mock.patch("airflow.models.dag._run_trigger", wraps=_run_trigger) as mock_run:
with dag_maker() as dag:

@task
def one():
return 1

@task
def two(val):
return val + 1

trigger = TimeDeltaTrigger(timedelta(seconds=0))

class MyOp(BaseOperator):
template_fields = ("tfield",)

def __init__(self, tfield, **kwargs):
self.tfield = tfield
super().__init__(**kwargs)

def execute(self, context, event=None):
if event is None:
print("I AM DEFERRING")
self.defer(trigger=trigger, method_name="execute")
return
print("RESUMING")
return self.tfield + 1

task_one = one()
task_two = two(task_one)
op = MyOp(task_id="abc", tfield=task_two)
task_two >> op
dr = dag.test()
assert mock_run.call_args_list[0] == ((trigger,), {})
tis = dr.get_task_instances()
assert [x for x in tis if x.task_id == "abc"][0].state == "success"
2 changes: 1 addition & 1 deletion tests/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def execute(self, context: Context):
mapped = CustomOperator.partial(task_id="task_2").expand(arg=unrenderable_values)
task1 >> mapped
dag.test()
assert caplog.text.count("task_2 ran successfully") == 2
assert caplog.text.count("[DAG TEST] end task task_id=task_2") == 2
assert (
"Unable to check if the value of type 'UnrenderableClass' is False for task 'task_2', field 'arg'"
in caplog.text
Expand Down

0 comments on commit 28897f7

Please sign in to comment.