diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index 242154820df9e..de6ce5019b9de 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -1884,23 +1884,27 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> int:
         return len(to_reset)
 
     @provide_session
-    def check_trigger_timeouts(self, session: Session = NEW_SESSION) -> None:
+    def check_trigger_timeouts(
+        self, max_retries: int = MAX_DB_RETRIES, session: Session = NEW_SESSION
+    ) -> None:
         """Mark any "deferred" task as failed if the trigger or execution timeout has passed."""
-        num_timed_out_tasks = session.execute(
-            update(TI)
-            .where(
-                TI.state == TaskInstanceState.DEFERRED,
-                TI.trigger_timeout < timezone.utcnow(),
-            )
-            .values(
-                state=TaskInstanceState.SCHEDULED,
-                next_method="__fail__",
-                next_kwargs={"error": "Trigger/execution timeout"},
-                trigger_id=None,
-            )
-        ).rowcount
-        if num_timed_out_tasks:
-            self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks)
+        for attempt in run_with_db_retries(max_retries, logger=self.log):
+            with attempt:
+                num_timed_out_tasks = session.execute(
+                    update(TI)
+                    .where(
+                        TI.state == TaskInstanceState.DEFERRED,
+                        TI.trigger_timeout < timezone.utcnow(),
+                    )
+                    .values(
+                        state=TaskInstanceState.SCHEDULED,
+                        next_method="__fail__",
+                        next_kwargs={"error": "Trigger/execution timeout"},
+                        trigger_id=None,
+                    )
+                ).rowcount
+                if num_timed_out_tasks:
+                    self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks)
 
     # [START find_zombies]
     def _find_zombies(self) -> None:
diff --git a/newsfragments/41429.improvement.rst b/newsfragments/41429.improvement.rst
new file mode 100644
index 0000000000000..6d04d5dfe61af
--- /dev/null
+++ b/newsfragments/41429.improvement.rst
@@ -0,0 +1 @@
+Use ``run_with_db_retries`` when the scheduler marks timed-out deferred tasks as failed, to tolerate database deadlocks.
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 32662d7d873db..40a7220698407 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -148,7 +148,7 @@ def clean_db():
     @pytest.fixture(autouse=True)
     def per_test(self) -> Generator:
         self.clean_db()
-        self.job_runner = None
+        self.job_runner: SchedulerJobRunner | None = None
 
         yield
 
@@ -5192,6 +5192,82 @@ def test_timeout_triggers(self, dag_maker):
         assert ti1.next_method == "__fail__"
         assert ti2.state == State.DEFERRED
 
+    def test_retry_on_db_error_when_update_timeout_triggers(self, dag_maker):
+        """
+        Test that the scheduler retries on DB errors such as deadlocks when updating timed-out triggers.
+        """
+        from sqlalchemy.exc import OperationalError
+
+        retry_times = 3
+
+        session = settings.Session()
+        # Create the test DAG and task
+        with dag_maker(
+            dag_id="test_retry_on_db_error_when_update_timeout_triggers",
+            start_date=DEFAULT_DATE,
+            schedule="@once",
+            max_active_runs=1,
+            session=session,
+        ):
+            EmptyOperator(task_id="dummy1")
+
+        # Mock DB failures on the first retry_times - 1 attempts
+        might_fail_session = MagicMock(wraps=session)
+
+        def check_if_trigger_timeout(max_retries: int):
+            def make_side_effect():
+                call_count = 0
+
+                def side_effect(*args, **kwargs):
+                    nonlocal call_count
+                    if call_count < retry_times - 1:
+                        call_count += 1
+                        raise OperationalError("any_statement", "any_params", "any_orig")
+                    else:
+                        return session.execute(*args, **kwargs)
+
+                return side_effect
+
+            might_fail_session.execute.side_effect = make_side_effect()
+
+            try:
+                # Create a Task Instance for the task that is allegedly deferred
+                # but past its timeout, and one that is still good.
+                # We don't actually need a linked trigger here; the code doesn't check.
+                dr1 = dag_maker.create_dagrun()
+                dr2 = dag_maker.create_dagrun(
+                    run_id="test2", execution_date=DEFAULT_DATE + datetime.timedelta(seconds=1)
+                )
+                ti1 = dr1.get_task_instance("dummy1", session)
+                ti2 = dr2.get_task_instance("dummy1", session)
+                ti1.state = State.DEFERRED
+                ti1.trigger_timeout = timezone.utcnow() - datetime.timedelta(seconds=60)
+                ti2.state = State.DEFERRED
+                ti2.trigger_timeout = timezone.utcnow() + datetime.timedelta(seconds=60)
+                session.flush()
+
+                # Boot up the scheduler and make it check timeouts
+                scheduler_job = Job()
+                self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull)
+
+                self.job_runner.check_trigger_timeouts(max_retries=max_retries, session=might_fail_session)
+
+                # Make sure that TI1 is now scheduled to fail, and 2 wasn't touched
+                session.refresh(ti1)
+                session.refresh(ti2)
+                assert ti1.state == State.SCHEDULED
+                assert ti1.next_method == "__fail__"
+                assert ti2.state == State.DEFERRED
+            finally:
+                self.clean_db()
+
+        # Positive case: retries until success before reaching the max retry count
+        check_if_trigger_timeout(retry_times)
+
+        # Negative case: only a single attempt, so the error propagates
+        with pytest.raises(OperationalError):
+            check_if_trigger_timeout(1)
+
     def test_find_zombies_nothing(self):
         executor = MockExecutor(do_update=False)
         scheduler_job = Job(executor=executor)