From 6da11f476d5b062b5a7a66297f70707f34f85005 Mon Sep 17 00:00:00 2001 From: Niko Oliveira Date: Wed, 3 Nov 2021 12:10:11 -0700 Subject: [PATCH] Fix TriggerDagRunOperator extra link The extra link provided by the operator was previously using the execution date of the triggering dag, not the triggered dag. Store the execution date of the triggered dag in xcom so that it can be read back later within the webserver when the link is being created. --- airflow/operators/trigger_dagrun.py | 19 +++++++++- tests/operators/test_trigger_dagrun.py | 49 ++++++++++++++++++++------ 2 files changed, 57 insertions(+), 11 deletions(-) diff --git a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py index 2a32a3366538c..7193945e03874 100644 --- a/airflow/operators/trigger_dagrun.py +++ b/airflow/operators/trigger_dagrun.py @@ -24,11 +24,15 @@ from airflow.api.common.experimental.trigger_dag import trigger_dag from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists from airflow.models import BaseOperator, BaseOperatorLink, DagBag, DagModel, DagRun +from airflow.models.xcom import XCom from airflow.utils import timezone from airflow.utils.helpers import build_airflow_url_with_query from airflow.utils.state import State from airflow.utils.types import DagRunType +XCOM_EXECUTION_DATE_ISO = "trigger_execution_date_iso" +XCOM_RUN_ID = "trigger_run_id" + class TriggerDagRunLink(BaseOperatorLink): """ @@ -39,7 +43,13 @@ class TriggerDagRunLink(BaseOperatorLink): name = 'Triggered DAG' def get_link(self, operator, dttm): - query = {"dag_id": operator.trigger_dag_id, "execution_date": dttm.isoformat()} + # Fetch the correct execution date for the triggerED dag which is + # stored in xcom during execution of the triggerING task. + trigger_execution_date_iso = XCom.get_one( + execution_date=dttm, key=XCOM_EXECUTION_DATE_ISO, task_id=operator.task_id, dag_id=operator.dag_id + ) + + query = {"dag_id": operator.trigger_dag_id, "base_date": trigger_execution_date_iso} return build_airflow_url_with_query(query) @@ -139,6 +149,7 @@ def execute(self, context: Dict): execution_date=self.execution_date, replace_microseconds=False, ) + except DagRunAlreadyExists as e: if self.reset_dag_run: self.log.info("Clearing %s on %s", self.trigger_dag_id, self.execution_date) @@ -156,6 +167,12 @@ def execute(self, context: Dict): else: raise e + # Store the execution date from the dag run (either created or found above) to + # be used when creating the extra link on the webserver. + ti = context['task_instance'] + ti.xcom_push(key=XCOM_EXECUTION_DATE_ISO, value=dag_run.execution_date.isoformat()) + ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id) + if self.wait_for_completion: # wait for dag to complete while True: diff --git a/tests/operators/test_trigger_dagrun.py b/tests/operators/test_trigger_dagrun.py index ea61687db1656..fc5cb5c6ae0f4 100644 --- a/tests/operators/test_trigger_dagrun.py +++ b/tests/operators/test_trigger_dagrun.py @@ -19,7 +19,7 @@ import pathlib import tempfile from datetime import datetime -from unittest import TestCase +from unittest import TestCase, mock import pytest @@ -76,6 +76,25 @@ def tearDown(self): pathlib.Path(self._tmpfile).unlink() + @mock.patch('airflow.operators.trigger_dagrun.build_airflow_url_with_query') + def assert_extra_link(self, triggering_exec_date, triggered_dag_run, triggering_task, mock_build_url): + """ + Asserts whether the correct extra links url will be created. + + Specifically it tests whether the correct dag id and date are passed to + the method which constructs the final url. + Note: We can't run that method to generate the url itself because the Flask app context + isn't available within the test logic, so it is mocked here. + """ + triggering_task.get_extra_links(triggering_exec_date, 'Triggered DAG') + assert mock_build_url.called + args, _ = mock_build_url.call_args + expected_args = { + 'dag_id': triggered_dag_run.dag_id, + 'base_date': triggered_dag_run.execution_date.isoformat(), + } + assert expected_args in args + def test_trigger_dagrun(self): """Test TriggerDagRunOperator.""" task = TriggerDagRunOperator(task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, dag=self.dag) @@ -84,7 +103,9 @@ def test_trigger_dagrun(self): with create_session() as session: dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() assert len(dagruns) == 1 - assert dagruns[0].external_trigger + triggered_dag_run = dagruns[0] + assert triggered_dag_run.external_trigger + self.assert_extra_link(DEFAULT_DATE, triggered_dag_run, task) def test_trigger_dagrun_custom_run_id(self): task = TriggerDagRunOperator( @@ -114,8 +135,10 @@ def test_trigger_dagrun_with_execution_date(self): with create_session() as session: dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() assert len(dagruns) == 1 - assert dagruns[0].external_trigger - assert dagruns[0].execution_date == utc_now + triggered_dag_run = dagruns[0] + assert triggered_dag_run.external_trigger + assert triggered_dag_run.execution_date == utc_now + self.assert_extra_link(DEFAULT_DATE, triggered_dag_run, task) def test_trigger_dagrun_twice(self): """Test TriggerDagRunOperator with custom execution_date.""" @@ -140,12 +163,14 @@ def test_trigger_dagrun_twice(self): ) session.add(dag_run) session.commit() - task.execute(None) + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() assert len(dagruns) == 1 - assert dagruns[0].external_trigger - assert dagruns[0].execution_date == utc_now + triggered_dag_run = dagruns[0] + assert triggered_dag_run.external_trigger + assert triggered_dag_run.execution_date == utc_now + self.assert_extra_link(DEFAULT_DATE, triggered_dag_run, task) def test_trigger_dagrun_with_templated_execution_date(self): """Test TriggerDagRunOperator with templated execution_date.""" @@ -160,8 +185,10 @@ def test_trigger_dagrun_with_templated_execution_date(self): with create_session() as session: dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() assert len(dagruns) == 1 - assert dagruns[0].external_trigger - assert dagruns[0].execution_date == DEFAULT_DATE + triggered_dag_run = dagruns[0] + assert triggered_dag_run.external_trigger + assert triggered_dag_run.execution_date == DEFAULT_DATE + self.assert_extra_link(DEFAULT_DATE, triggered_dag_run, task) def test_trigger_dagrun_operator_conf(self): """Test passing conf to the triggered DagRun.""" @@ -288,7 +315,9 @@ def test_trigger_dagrun_triggering_itself(self): .all() ) assert len(dagruns) == 2 - assert dagruns[1].state == State.QUEUED + triggered_dag_run = dagruns[1] + assert triggered_dag_run.state == State.QUEUED + self.assert_extra_link(execution_date, triggered_dag_run, task) def test_trigger_dagrun_triggering_itself_with_execution_date(self): """Test TriggerDagRunOperator that triggers itself with execution date,