From 5e23de5a379258da7e0f57db1b808d69195fce7a Mon Sep 17 00:00:00 2001 From: Dylan Storey Date: Thu, 13 Apr 2023 10:17:57 -0400 Subject: [PATCH] Fix `TriggerDagRunOperator` with deferrable parameter (#30406) * readding after borked it * pre-commit * finally fixing after the github issue last week * push fix * feedback from hussein --- airflow/operators/trigger_dagrun.py | 11 ++++--- airflow/triggers/external_task.py | 2 +- tests/operators/test_trigger_dagrun.py | 40 +++++++++++++++++++++----- 3 files changed, 41 insertions(+), 12 deletions(-) diff --git a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py index 9a84bfac97dd1..b0115636c3d2c 100644 --- a/airflow/operators/trigger_dagrun.py +++ b/airflow/operators/trigger_dagrun.py @@ -20,7 +20,7 @@ import datetime import json import time -from typing import TYPE_CHECKING, Sequence, cast +from typing import TYPE_CHECKING, Any, Sequence, cast from sqlalchemy.orm.exc import NoResultFound @@ -211,13 +211,16 @@ def execute(self, context: Context): return @provide_session - def execute_complete(self, context: Context, session: Session, **kwargs): - parsed_execution_date = context["execution_date"] + def execute_complete(self, context: Context, session: Session, event: tuple[str, dict[str, Any]]): + # This execution date is parsed from the return trigger event + provided_execution_date = event[1]["execution_dates"][0] try: dag_run = ( session.query(DagRun) - .filter(DagRun.dag_id == self.trigger_dag_id, DagRun.execution_date == parsed_execution_date) + .filter( + DagRun.dag_id == self.trigger_dag_id, DagRun.execution_date == provided_execution_date + ) .one() ) diff --git a/airflow/triggers/external_task.py b/airflow/triggers/external_task.py index 5ed0d3b5e3f92..883753401c58f 100644 --- a/airflow/triggers/external_task.py +++ b/airflow/triggers/external_task.py @@ -144,7 +144,7 @@ async def run(self) -> typing.AsyncIterator["TriggerEvent"]: while True: num_dags = await self.count_dags() if num_dags == len(self.execution_dates): - yield TriggerEvent(True) + yield TriggerEvent(self.serialize()) await asyncio.sleep(self.poll_interval) @sync_to_async diff --git a/tests/operators/test_trigger_dagrun.py b/tests/operators/test_trigger_dagrun.py index cb2d75e84c8b7..3d24315dbc6ec 100644 --- a/tests/operators/test_trigger_dagrun.py +++ b/tests/operators/test_trigger_dagrun.py @@ -28,6 +28,7 @@ from airflow.models import DAG, DagBag, DagModel, DagRun, Log, TaskInstance from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.trigger_dagrun import TriggerDagRunOperator +from airflow.triggers.external_task import DagStateTrigger from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import State @@ -410,11 +411,17 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true(self): with create_session() as session: dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() assert len(dagruns) == 1 + trigger = DagStateTrigger( + dag_id="down_stream", + execution_dates=[DEFAULT_DATE], + poll_interval=20, + states=["success", "failed"], + ) - task.execute_complete(context={"execution_date": execution_date, "logical_date": execution_date}) + task.execute_complete(context={}, event=trigger.serialize()) def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure(self): - """Test TriggerDagRunOperator with wait_for_completion.""" + """Test TriggerDagRunOperator wait_for_completion dag run in non defined state.""" execution_date = DEFAULT_DATE task = TriggerDagRunOperator( task_id="test_task", @@ -433,11 +440,21 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure(self): dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() assert len(dagruns) == 1 - with pytest.raises(AirflowException): - task.execute_complete(context={"execution_date": execution_date, "logical_date": execution_date}) + trigger = DagStateTrigger( + dag_id="down_stream", + execution_dates=[DEFAULT_DATE], + poll_interval=20, + states=["success", "failed"], + ) + with pytest.raises(AirflowException) as exception: + task.execute_complete( + context={}, + event=trigger.serialize(), + ) + assert "which is not in" in str(exception) def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure_2(self): - """Test TriggerDagRunOperator with wait_for_completion.""" + """Test TriggerDagRunOperator wait_for_completion dag run in failed state.""" execution_date = DEFAULT_DATE task = TriggerDagRunOperator( task_id="test_task", @@ -457,5 +474,14 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure_2(self) dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() assert len(dagruns) == 1 - with pytest.raises(AirflowException): - task.execute_complete(context={"execution_date": execution_date, "logical_date": execution_date}) + trigger = DagStateTrigger( + dag_id="down_stream", + execution_dates=[DEFAULT_DATE], + poll_interval=20, + states=["success", "failed"], + ) + + with pytest.raises(AirflowException) as exception: + task.execute_complete(context={}, event=trigger.serialize()) + + assert "failed with failed state" in str(exception)