diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index 5bd9f816a33f..33b765334246 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -1579,11 +1579,11 @@ def _emit_pool_metrics(self, session: Session = NEW_SESSION) -> None:
     @provide_session
     def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> int:
         """
-        Reset any TaskInstance in QUEUED or SCHEDULED state if its SchedulerJob is no longer running.
+        Adopt or reset any TaskInstance in an adoptable state if its SchedulerJob is no longer running.
 
         :return: the number of TIs reset
         """
-        self.log.info("Resetting orphaned tasks for active dag runs")
+        self.log.info("Adopting or resetting orphaned tasks for active dag runs")
         timeout = conf.getint("scheduler", "scheduler_health_check_threshold")
 
         for attempt in run_with_db_retries(logger=self.log):
@@ -1609,10 +1609,9 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> int:
                         self.log.info("Marked %d SchedulerJob instances as failed", num_failed)
                         Stats.incr(self.__class__.__name__.lower() + "_end", num_failed)
 
-                    resettable_states = [TaskInstanceState.QUEUED, TaskInstanceState.RUNNING]
                     query = (
                         select(TI)
-                        .where(TI.state.in_(resettable_states))
+                        .where(TI.state.in_(State.adoptable_states))
                         # outerjoin is because we didn't use to have queued_by_job
                         # set, so we need to pick up anything pre upgrade. This (and the
                         # "or queued_by_job_id IS NONE") can go as soon as scheduler HA is
@@ -1628,11 +1627,11 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> int:
                     )
 
                     # Lock these rows, so that another scheduler can't try and adopt these too
-                    tis_to_reset_or_adopt = with_row_locks(
+                    tis_to_adopt_or_reset = with_row_locks(
                         query, of=TI, session=session, **skip_locked(session=session)
                     )
-                    tis_to_reset_or_adopt = session.scalars(tis_to_reset_or_adopt).all()
-                    to_reset = self.job.executor.try_adopt_task_instances(tis_to_reset_or_adopt)
+                    tis_to_adopt_or_reset = session.scalars(tis_to_adopt_or_reset).all()
+                    to_reset = self.job.executor.try_adopt_task_instances(tis_to_adopt_or_reset)
 
                     reset_tis_message = []
                     for ti in to_reset:
@@ -1640,11 +1639,11 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> int:
                         ti.state = None
                         ti.queued_by_job_id = None
 
-                    for ti in set(tis_to_reset_or_adopt) - set(to_reset):
+                    for ti in set(tis_to_adopt_or_reset) - set(to_reset):
                         ti.queued_by_job_id = self.job.id
 
                     Stats.incr("scheduler.orphaned_tasks.cleared", len(to_reset))
-                    Stats.incr("scheduler.orphaned_tasks.adopted", len(tis_to_reset_or_adopt) - len(to_reset))
+                    Stats.incr("scheduler.orphaned_tasks.adopted", len(tis_to_adopt_or_reset) - len(to_reset))
 
                     if to_reset:
                         task_instance_str = "\n\t".join(reset_tis_message)
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index 22fb6e27c814..6da7dacc75ce 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -199,3 +199,11 @@ def color_fg(cls, state):
     """
     A list of states indicating that a task has been terminated.
     """
+
+    adoptable_states = frozenset(
+        [TaskInstanceState.QUEUED, TaskInstanceState.RUNNING, TaskInstanceState.RESTARTING]
+    )
+    """
+    A list of states indicating that a task can be adopted or reset by a scheduler job
+    if it was queued by another scheduler job that is not running anymore.
+    """
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index cffea246e460..d1612a84f3ca 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -3151,6 +3151,30 @@ def test_adopt_or_reset_orphaned_tasks_nothing(self):
         session = settings.Session()
         assert 0 == self.job_runner.adopt_or_reset_orphaned_tasks(session=session)
 
+    @pytest.mark.parametrize(
+        "adoptable_state",
+        sorted(State.adoptable_states),
+    )
+    def test_adopt_or_reset_resettable_tasks(self, dag_maker, adoptable_state):
+        dag_id = "test_adopt_or_reset_adoptable_tasks_" + adoptable_state.name
+        with dag_maker(dag_id=dag_id, schedule="@daily"):
+            task_id = dag_id + "_task"
+            EmptyOperator(task_id=task_id)
+
+        scheduler_job = Job()
+        self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull)
+        session = settings.Session()
+
+        dr1 = dag_maker.create_dagrun(external_trigger=True)
+        ti = dr1.get_task_instances(session=session)[0]
+        ti.state = adoptable_state
+        session.merge(ti)
+        session.merge(dr1)
+        session.commit()
+
+        num_reset_tis = self.job_runner.adopt_or_reset_orphaned_tasks(session=session)
+        assert 1 == num_reset_tis
+
     def test_adopt_or_reset_orphaned_tasks_external_triggered_dag(self, dag_maker):
         dag_id = "test_reset_orphaned_tasks_external_triggered_dag"
         with dag_maker(dag_id=dag_id, schedule="@daily"):