From b44753b4f5dedd6e1acd8eb55af7462921a00f34 Mon Sep 17 00:00:00 2001
From: tatian
Date: Tue, 13 Aug 2024 19:02:10 +0800
Subject: [PATCH 1/5] Add retry in update trigger timeout

---
 airflow/jobs/scheduler_job_runner.py | 32 +++++++++++++++-------------
 1 file changed, 17 insertions(+), 15 deletions(-)

diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index 163bf5b71449..8ed9cfdefdb7 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -1921,21 +1921,23 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> int:
     @provide_session
     def check_trigger_timeouts(self, session: Session = NEW_SESSION) -> None:
         """Mark any "deferred" task as failed if the trigger or execution timeout has passed."""
-        num_timed_out_tasks = session.execute(
-            update(TI)
-            .where(
-                TI.state == TaskInstanceState.DEFERRED,
-                TI.trigger_timeout < timezone.utcnow(),
-            )
-            .values(
-                state=TaskInstanceState.SCHEDULED,
-                next_method="__fail__",
-                next_kwargs={"error": "Trigger/execution timeout"},
-                trigger_id=None,
-            )
-        ).rowcount
-        if num_timed_out_tasks:
-            self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks)
+        for attempt in run_with_db_retries(logger=self.log):
+            with attempt:
+                num_timed_out_tasks = session.execute(
+                    update(TI)
+                    .where(
+                        TI.state == TaskInstanceState.DEFERRED,
+                        TI.trigger_timeout < timezone.utcnow(),
+                    )
+                    .values(
+                        state=TaskInstanceState.SCHEDULED,
+                        next_method="__fail__",
+                        next_kwargs={"error": "Trigger/execution timeout"},
+                        trigger_id=None,
+                    )
+                ).rowcount
+                if num_timed_out_tasks:
+                    self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks)
 
     # [START find_zombies]
     def _find_zombies(self) -> None:

From ed141e8ecaee183108e3f7c896bb9372ea1f2f71 Mon Sep 17 00:00:00 2001
From: tatian
Date: Sat, 17 Aug 2024 17:23:48 +0800
Subject: [PATCH 2/5] add ut for these cases

---
 airflow/jobs/scheduler_job_runner.py |  6 ++-
 tests/jobs/test_scheduler_job.py     | 72 ++++++++++++++++++++++++++++
 2 files changed, 76 insertions(+), 2 deletions(-)

diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index 8ed9cfdefdb7..2b0cc63e90af 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -1919,9 +1919,11 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> int:
         return len(to_reset)
 
     @provide_session
-    def check_trigger_timeouts(self, session: Session = NEW_SESSION) -> None:
+    def check_trigger_timeouts(
+        self, max_retries: int = MAX_DB_RETRIES, session: Session = NEW_SESSION
+    ) -> None:
         """Mark any "deferred" task as failed if the trigger or execution timeout has passed."""
-        for attempt in run_with_db_retries(logger=self.log):
+        for attempt in run_with_db_retries(max_retries, logger=self.log):
             with attempt:
                 num_timed_out_tasks = session.execute(
                     update(TI)
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 2e96728d5eca..58a27550a9d3 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -5227,6 +5227,78 @@ def test_timeout_triggers(self, dag_maker):
         assert ti1.next_method == "__fail__"
         assert ti2.state == State.DEFERRED
 
+    def test_retry_on_db_error_when_update_timeout_triggers(self, dag_maker):
+        """
+        Tests that it will retry on DB error like deadlock when updating timeout triggers.
+        """
+        from sqlalchemy.exc import InternalError
+
+        retry_times = 3
+
+        session = settings.Session()
+        # Create the test DAG and task
+        with dag_maker(
+            dag_id="test_retry_on_db_error_when_update_timeout_triggers",
+            start_date=DEFAULT_DATE,
+            schedule="@once",
+            max_active_runs=1,
+            session=session,
+        ):
+            EmptyOperator(task_id="dummy1")
+
+        # Mock the db failure within retry times
+        might_fail_session = MagicMock(wraps=session)
+
+        def check_if_trigger_timeout(max_retries: int):
+            def side_effect(*args, **kwargs):
+                calls = side_effect.call_count
+                side_effect.call_count += 1
+                if calls < retry_times - 1:
+                    raise InternalError("any_statement", "any_params", "any_orig")
+                else:
+                    return session.execute(*args, **kwargs)
+
+            side_effect.call_count = 0
+            might_fail_session.execute.side_effect = side_effect
+
+            try:
+                # Create a Task Instance for the task that is allegedly deferred
+                # but past its timeout, and one that is still good.
+                # We don't actually need a linked trigger here; the code doesn't check.
+                dr1 = dag_maker.create_dagrun()
+                dr2 = dag_maker.create_dagrun(
+                    run_id="test2", execution_date=DEFAULT_DATE + datetime.timedelta(seconds=1)
+                )
+                ti1 = dr1.get_task_instance("dummy1", session)
+                ti2 = dr2.get_task_instance("dummy1", session)
+                ti1.state = State.DEFERRED
+                ti1.trigger_timeout = timezone.utcnow() - datetime.timedelta(seconds=60)
+                ti2.state = State.DEFERRED
+                ti2.trigger_timeout = timezone.utcnow() + datetime.timedelta(seconds=60)
+                session.flush()
+
+                # Boot up the scheduler and make it check timeouts
+                scheduler_job = Job()
+                self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull)
+
+                self.job_runner.check_trigger_timeouts(max_retries=max_retries, session=might_fail_session)
+
+                # Make sure that TI1 is now scheduled to fail, and 2 wasn't touched
+                session.refresh(ti1)
+                session.refresh(ti2)
+                assert ti1.state == State.SCHEDULED
+                assert ti1.next_method == "__fail__"
+                assert ti2.state == State.DEFERRED
+            finally:
+                self.clean_db()
+
+        # Positive case: will retry until success before reaching the max retry count
+        check_if_trigger_timeout(retry_times)
+
+        # Negative case: no retries, execute only once.
+        with pytest.raises(InternalError):
+            check_if_trigger_timeout(1)
+
     def test_find_zombies_nothing(self):
         executor = MockExecutor(do_update=False)
         scheduler_job = Job(executor=executor)

From 9fe33ee1a74b62e17a0d3bc28dbc940be6e7f662 Mon Sep 17 00:00:00 2001
From: tatian
Date: Sat, 17 Aug 2024 17:28:42 +0800
Subject: [PATCH 3/5] use OperationalError in ut to describe deadlock scenarios

---
 tests/jobs/test_scheduler_job.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 58a27550a9d3..2d040169b972 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -5231,7 +5231,7 @@ def test_retry_on_db_error_when_update_timeout_triggers(self, dag_maker):
         """
         Tests that it will retry on DB error like deadlock when updating timeout triggers.
         """
-        from sqlalchemy.exc import InternalError
+        from sqlalchemy.exc import OperationalError
 
         retry_times = 3
 
@@ -5254,7 +5254,7 @@ def side_effect(*args, **kwargs):
                 calls = side_effect.call_count
                 side_effect.call_count += 1
                 if calls < retry_times - 1:
-                    raise InternalError("any_statement", "any_params", "any_orig")
+                    raise OperationalError("any_statement", "any_params", "any_orig")
                 else:
                     return session.execute(*args, **kwargs)
 
@@ -5296,7 +5296,7 @@ def side_effect(*args, **kwargs):
         check_if_trigger_timeout(retry_times)
 
         # Negative case: no retries, execute only once.
-        with pytest.raises(InternalError):
+        with pytest.raises(OperationalError):
             check_if_trigger_timeout(1)
 
     def test_find_zombies_nothing(self):

From fd7044b092ec07547b9d42f19a542412c94ed5e3 Mon Sep 17 00:00:00 2001
From: tatian
Date: Sun, 29 Sep 2024 22:14:57 +0800
Subject: [PATCH 4/5] [MINOR] add newsfragment for this PR

---
 newsfragments/41429.improvement.rst | 1 +
 1 file changed, 1 insertion(+)
 create mode 100644 newsfragments/41429.improvement.rst

diff --git a/newsfragments/41429.improvement.rst b/newsfragments/41429.improvement.rst
new file mode 100644
index 000000000000..6d04d5dfe61a
--- /dev/null
+++ b/newsfragments/41429.improvement.rst
@@ -0,0 +1 @@
+Add ``run_with_db_retries`` when the scheduler updates the deferred Task as failed to tolerate database deadlock issues.

From 2c6be65740a49e24e61edff92624247723eebb12 Mon Sep 17 00:00:00 2001
From: tatian
Date: Wed, 2 Oct 2024 09:25:20 +0800
Subject: [PATCH 5/5] [MINOR] refactor UT for mypy check

---
 tests/jobs/test_scheduler_job.py | 26 +++++++++++++++-----------
 1 file changed, 15 insertions(+), 11 deletions(-)

diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 2d040169b972..8ba1ad501d82 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -144,7 +144,7 @@ def clean_db():
     @pytest.fixture(autouse=True)
     def per_test(self) -> Generator:
         self.clean_db()
-        self.job_runner = None
+        self.job_runner: SchedulerJobRunner | None = None
 
         yield
 
@@ -5250,16 +5250,20 @@ def test_retry_on_db_error_when_update_timeout_triggers(self, dag_maker):
         might_fail_session = MagicMock(wraps=session)
 
         def check_if_trigger_timeout(max_retries: int):
-            def side_effect(*args, **kwargs):
-                calls = side_effect.call_count
-                side_effect.call_count += 1
-                if calls < retry_times - 1:
-                    raise OperationalError("any_statement", "any_params", "any_orig")
-                else:
-                    return session.execute(*args, **kwargs)
-
-            side_effect.call_count = 0
-            might_fail_session.execute.side_effect = side_effect
+            def make_side_effect():
+                call_count = 0
+
+                def side_effect(*args, **kwargs):
+                    nonlocal call_count
+                    if call_count < retry_times - 1:
+                        call_count += 1
+                        raise OperationalError("any_statement", "any_params", "any_orig")
+                    else:
+                        return session.execute(*args, **kwargs)
+
+                return side_effect
+
+            might_fail_session.execute.side_effect = make_side_effect()
 
             try:
                 # Create a Task Instance for the task that is allegedly deferred