diff --git a/airflow/models/dag.py b/airflow/models/dag.py index d6bcb4b641d..b55698eb274 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1206,6 +1206,7 @@ def clear( tis = tis.filter(TI.task_id.in_(self.task_ids)) if include_parentdag and self.is_subdag and self.parent_dag is not None: + dag_ids.append(self.parent_dag.dag_id) p_dag = self.parent_dag.sub_dag( task_ids_or_regex=r"^{}$".format(self.dag_id.split('.')[1]), include_upstream=False, diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 334573d2698..bd9fdb7155d 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -1363,6 +1363,63 @@ def test_clear_set_dagrun_state_for_subdag(self, dag_run_state): ) assert dagrun.state == dag_run_state + @parameterized.expand( + [ + (State.NONE,), + (State.RUNNING,), + ] + ) + def test_clear_set_dagrun_state_for_parent_dag(self, dag_run_state): + dag_id = 'test_clear_set_dagrun_state_parent_dag' + self._clean_up(dag_id) + task_id = 't1' + dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1) + t_1 = DummyOperator(task_id=task_id, dag=dag) + subdag = DAG(dag_id + '.test', start_date=DEFAULT_DATE, max_active_runs=1) + SubDagOperator(task_id='test', subdag=subdag, dag=dag) + t_2 = DummyOperator(task_id='task', dag=subdag) + subdag.parent_dag = dag + subdag.is_subdag = True + + session = settings.Session() + dagrun_1 = dag.create_dagrun( + run_type=DagRunType.BACKFILL_JOB, + state=State.FAILED, + start_date=DEFAULT_DATE, + execution_date=DEFAULT_DATE, + ) + dagrun_2 = subdag.create_dagrun( + run_type=DagRunType.BACKFILL_JOB, + state=State.FAILED, + start_date=DEFAULT_DATE, + execution_date=DEFAULT_DATE, + ) + session.merge(dagrun_1) + session.merge(dagrun_2) + task_instance_1 = TI(t_1, execution_date=DEFAULT_DATE, state=State.RUNNING) + task_instance_2 = TI(t_2, execution_date=DEFAULT_DATE, state=State.RUNNING) + session.merge(task_instance_1) + session.merge(task_instance_2) + session.commit() + + subdag.clear( + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=1), + dag_run_state=dag_run_state, + include_subdags=True, + include_parentdag=True, + session=session, + ) + + dagrun = ( + session.query( + DagRun, + ) + .filter(DagRun.dag_id == dag_id) + .one() + ) + assert dagrun.state == dag_run_state + @parameterized.expand( [(state, State.NONE) for state in State.task_states if state != State.RUNNING] + [(State.RUNNING, State.SHUTDOWN)]