Skip to content

Commit

Permalink
Fix TriggerDagRunOperator with deferrable parameter (#30406)
Browse files Browse the repository at this point in the history
* readding after borked it

* pre-commit

* finally fixing after the github issue last week

* push fix

* feedback from hussein
  • Loading branch information
dylanbstorey authored Apr 13, 2023
1 parent 2c66b24 commit 5e23de5
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 12 deletions.
11 changes: 7 additions & 4 deletions airflow/operators/trigger_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import datetime
import json
import time
from typing import TYPE_CHECKING, Sequence, cast
from typing import TYPE_CHECKING, Any, Sequence, cast

from sqlalchemy.orm.exc import NoResultFound

Expand Down Expand Up @@ -211,13 +211,16 @@ def execute(self, context: Context):
return

@provide_session
def execute_complete(self, context: Context, session: Session, **kwargs):
parsed_execution_date = context["execution_date"]
def execute_complete(self, context: Context, session: Session, event: tuple[str, dict[str, Any]]):

# This execution date is parsed from the return trigger event
provided_execution_date = event[1]["execution_dates"][0]
try:
dag_run = (
session.query(DagRun)
.filter(DagRun.dag_id == self.trigger_dag_id, DagRun.execution_date == parsed_execution_date)
.filter(
DagRun.dag_id == self.trigger_dag_id, DagRun.execution_date == provided_execution_date
)
.one()
)

Expand Down
2 changes: 1 addition & 1 deletion airflow/triggers/external_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ async def run(self) -> typing.AsyncIterator["TriggerEvent"]:
while True:
num_dags = await self.count_dags()
if num_dags == len(self.execution_dates):
yield TriggerEvent(True)
yield TriggerEvent(self.serialize())
await asyncio.sleep(self.poll_interval)

@sync_to_async
Expand Down
40 changes: 33 additions & 7 deletions tests/operators/test_trigger_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from airflow.models import DAG, DagBag, DagModel, DagRun, Log, TaskInstance
from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from airflow.triggers.external_task import DagStateTrigger
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import State
Expand Down Expand Up @@ -410,11 +411,17 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true(self):
with create_session() as session:
dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all()
assert len(dagruns) == 1
trigger = DagStateTrigger(
dag_id="down_stream",
execution_dates=[DEFAULT_DATE],
poll_interval=20,
states=["success", "failed"],
)

task.execute_complete(context={"execution_date": execution_date, "logical_date": execution_date})
task.execute_complete(context={}, event=trigger.serialize())

def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure(self):
"""Test TriggerDagRunOperator with wait_for_completion."""
"""Test TriggerDagRunOperator wait_for_completion dag run in non defined state."""
execution_date = DEFAULT_DATE
task = TriggerDagRunOperator(
task_id="test_task",
Expand All @@ -433,11 +440,21 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure(self):
dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all()
assert len(dagruns) == 1

with pytest.raises(AirflowException):
task.execute_complete(context={"execution_date": execution_date, "logical_date": execution_date})
trigger = DagStateTrigger(
dag_id="down_stream",
execution_dates=[DEFAULT_DATE],
poll_interval=20,
states=["success", "failed"],
)
with pytest.raises(AirflowException) as exception:
task.execute_complete(
context={},
event=trigger.serialize(),
)
assert "which is not in" in str(exception)

def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure_2(self):
"""Test TriggerDagRunOperator with wait_for_completion."""
"""Test TriggerDagRunOperator wait_for_completion dag run in failed state."""
execution_date = DEFAULT_DATE
task = TriggerDagRunOperator(
task_id="test_task",
Expand All @@ -457,5 +474,14 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure_2(self)
dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all()
assert len(dagruns) == 1

with pytest.raises(AirflowException):
task.execute_complete(context={"execution_date": execution_date, "logical_date": execution_date})
trigger = DagStateTrigger(
dag_id="down_stream",
execution_dates=[DEFAULT_DATE],
poll_interval=20,
states=["success", "failed"],
)

with pytest.raises(AirflowException) as exception:
task.execute_complete(context={}, event=trigger.serialize())

assert "failed with failed state" in str(exception)

0 comments on commit 5e23de5

Please sign in to comment.