From 6b9c374626133a71e5ceca99f094f9146449c713 Mon Sep 17 00:00:00 2001 From: Ashley Heath Date: Thu, 14 Dec 2023 16:34:24 +0000 Subject: [PATCH] Ensure jobs are scheduled for retry regardless of exception type raised --- procrastinate/retry.py | 6 +++--- procrastinate/tasks.py | 2 +- procrastinate/worker.py | 26 +++++++++++++++++++++++--- tests/unit/test_worker.py | 26 ++++++++++++++++++++++++++ 4 files changed, 53 insertions(+), 7 deletions(-) diff --git a/procrastinate/retry.py b/procrastinate/retry.py index 55e58cbb7..10c050d5f 100644 --- a/procrastinate/retry.py +++ b/procrastinate/retry.py @@ -17,7 +17,7 @@ class BaseRetryStrategy: """ def get_retry_exception( - self, exception: Exception, attempts: int + self, exception: BaseException, attempts: int ) -> Optional[exceptions.JobRetry]: schedule_in = self.get_schedule_in(exception=exception, attempts=attempts) if schedule_in is None: @@ -26,7 +26,7 @@ def get_retry_exception( schedule_at = utils.utcnow() + datetime.timedelta(seconds=schedule_in) return exceptions.JobRetry(schedule_at.replace(microsecond=0)) - def get_schedule_in(self, *, exception: Exception, attempts: int) -> Optional[int]: + def get_schedule_in(self, *, exception: BaseException, attempts: int) -> Optional[int]: """ Parameters ---------- @@ -81,7 +81,7 @@ class RetryStrategy(BaseRetryStrategy): exponential_wait: int = 0 retry_exceptions: Optional[Iterable[Type[Exception]]] = None - def get_schedule_in(self, *, exception: Exception, attempts: int) -> Optional[int]: + def get_schedule_in(self, *, exception: BaseException, attempts: int) -> Optional[int]: if self.max_attempts and attempts >= self.max_attempts: return None # isinstance's 2nd param must be a tuple, not an arbitrary iterable diff --git a/procrastinate/tasks.py b/procrastinate/tasks.py index 26431d7fb..fe1dda68e 100644 --- a/procrastinate/tasks.py +++ b/procrastinate/tasks.py @@ -197,7 +197,7 @@ def configure( ) def get_retry_exception( - self, exception: Exception, job: jobs.Job + self, exception: BaseException, job: jobs.Job ) -> Optional[exceptions.JobRetry]: if not self.retry_strategy: return None diff --git a/procrastinate/worker.py b/procrastinate/worker.py index 75c7f3c04..1205f9c67 100644 --- a/procrastinate/worker.py +++ b/procrastinate/worker.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Iterable, Optional, Union from procrastinate import app, exceptions, job_context, jobs, signals, tasks, utils +from procrastinate.exceptions import ProcrastinateException logger = logging.getLogger(__name__) @@ -164,14 +165,24 @@ async def process_job(self, job: jobs.Job, worker_id: int = 0) -> None: extra=context.log_extra(action="loaded_job_info"), ) + def find_exception_to_re_raise(ex: ProcrastinateException) -> Optional[BaseException]: + # If the job raises a BaseException that is _not_ an Exception + # (e.g. a CancelledError, SystemExit, etc.) then we want to persist the + # outcome of the job before propagating the exception further up the + # call stack. + return ex.__cause__ if not isinstance(e.__cause__, Exception) else None + status, retry_at = None, None + exception_to_re_raise = None try: await self.run_job(job=job, worker_id=worker_id) status = jobs.Status.SUCCEEDED except exceptions.JobRetry as e: retry_at = e.scheduled_at - except exceptions.JobError: + exception_to_re_raise = find_exception_to_re_raise(e) + except exceptions.JobError as e: status = jobs.Status.FAILED + exception_to_re_raise = find_exception_to_re_raise(e) except exceptions.TaskNotFound as exc: status = jobs.Status.FAILED self.logger.exception( @@ -201,6 +212,9 @@ async def process_job(self, job: jobs.Job, worker_id: int = 0) -> None: # Remove job information from the current context self.context_for_worker(worker_id=worker_id, reset=True) + if exception_to_re_raise is not None: + raise exception_to_re_raise + def find_task(self, task_name: str) -> tasks.Task: try: return self.app.tasks[task_name] @@ -221,10 +235,16 @@ async def run_job(self, job: jobs.Job, worker_id: int) -> None: f"Starting job {job.call_string}", extra=context.log_extra(action="start_job"), ) - exc_info: Union[bool, Exception] job_args = [] if task.pass_context: job_args.append(context) + + # Initialise logging variables + task_result = None + log_title = "Error" + log_action = "job_error" + log_level = logging.ERROR + exc_info: Union[bool, BaseException] = False try: task_result = task(*job_args, **job.task_kwargs) if asyncio.iscoroutine(task_result): @@ -236,7 +256,7 @@ async def run_job(self, job: jobs.Job, worker_id: int) -> None: extra=context.log_extra(action="concurrent_sync_task"), ) - except Exception as e: + except BaseException as e: task_result = None log_title = "Error" log_action = "job_error" diff --git a/tests/unit/test_worker.py b/tests/unit/test_worker.py index cdb6cf847..c572ea507 100644 --- a/tests/unit/test_worker.py +++ b/tests/unit/test_worker.py @@ -137,6 +137,32 @@ async def coro(*args, **kwargs): test_worker.run_job.assert_called_with(job=job, worker_id=0) assert connector.jobs[1]["status"] == "todo" assert connector.jobs[1]["scheduled_at"] == scheduled_at + assert connector.jobs[1]["attempts"] == 1 + + +async def test_process_job_retry_failed_job_re_raise_base_exception( + mocker, test_worker, job_factory, connector +): + class TestException(BaseException): + pass + + scheduled_at = conftest.aware_datetime(2000, 1, 1) + job_exception = exceptions.JobRetry(scheduled_at=scheduled_at) + job_exception.__cause__ = TestException() + + test_worker.run_job = mocker.Mock(side_effect=job_exception) + job = job_factory(id=1) + await test_worker.job_manager.defer_job_async(job) + + # Exceptions that extend BaseException should be re-raised after the failed job + # is scheduled for retry (if retry is applicable). + with pytest.raises(TestException): + await test_worker.process_job(job=job, worker_id=0) + + test_worker.run_job.assert_called_with(job=job, worker_id=0) + assert connector.jobs[1]["status"] == "todo" + assert connector.jobs[1]["scheduled_at"] == scheduled_at + assert connector.jobs[1]["attempts"] == 1 async def test_run_job(app):