Commit 6b9c374
Ensure jobs are scheduled for retry regardless of exception type raised
Ashley Heath committed Dec 14, 2023
1 parent f91743c commit 6b9c374
Showing 4 changed files with 53 additions and 7 deletions.
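In short: run_job previously caught only Exception, so a BaseException raised by a task (asyncio.CancelledError during shutdown, SystemExit, KeyboardInterrupt) could escape the worker before the job's status or retry schedule was written back. The diff below records the outcome first and re-raises the original exception afterwards. A rough sketch of the resulting behaviour, where the app wiring, task name, and retry settings are illustrative assumptions rather than part of this commit:

# Illustrative sketch only; the app wiring, task name, and retry settings below
# are assumptions, not code from this commit.
import asyncio

import procrastinate
from procrastinate import testing
from procrastinate.retry import RetryStrategy

app = procrastinate.App(connector=testing.InMemoryConnector())


@app.task(retry=RetryStrategy(max_attempts=3, wait=5))
async def flaky_task():
    # A task interrupted like this (e.g. while the worker shuts down) used to
    # leave its job row untouched; with this change the job is still marked
    # failed or rescheduled for retry before CancelledError propagates out of
    # process_job.
    raise asyncio.CancelledError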
6 changes: 3 additions & 3 deletions procrastinate/retry.py
@@ -17,7 +17,7 @@ class BaseRetryStrategy:
     """
 
     def get_retry_exception(
-        self, exception: Exception, attempts: int
+        self, exception: BaseException, attempts: int
     ) -> Optional[exceptions.JobRetry]:
         schedule_in = self.get_schedule_in(exception=exception, attempts=attempts)
         if schedule_in is None:
@@ -26,7 +26,7 @@ def get_retry_exception(
         schedule_at = utils.utcnow() + datetime.timedelta(seconds=schedule_in)
         return exceptions.JobRetry(schedule_at.replace(microsecond=0))
 
-    def get_schedule_in(self, *, exception: Exception, attempts: int) -> Optional[int]:
+    def get_schedule_in(self, *, exception: BaseException, attempts: int) -> Optional[int]:
         """
         Parameters
         ----------
@@ -81,7 +81,7 @@ class RetryStrategy(BaseRetryStrategy):
     exponential_wait: int = 0
     retry_exceptions: Optional[Iterable[Type[Exception]]] = None
 
-    def get_schedule_in(self, *, exception: Exception, attempts: int) -> Optional[int]:
+    def get_schedule_in(self, *, exception: BaseException, attempts: int) -> Optional[int]:
         if self.max_attempts and attempts >= self.max_attempts:
             return None
         # isinstance's 2nd param must be a tuple, not an arbitrary iterable
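As an aside to the hunk above: the isinstance comment refers to converting retry_exceptions to a tuple before the check. A minimal, self-contained sketch of that decision logic, now typed against BaseException (my own example, not the library's implementation; the wait handling is simplified):

# Standalone sketch of the retry decision; not procrastinate's actual code.
import datetime
from typing import Iterable, Optional, Type


def next_attempt_at(
    exception: BaseException,
    attempts: int,
    max_attempts: Optional[int] = None,
    retry_exceptions: Optional[Iterable[Type[Exception]]] = None,
    wait: int = 5,
) -> Optional[datetime.datetime]:
    if max_attempts and attempts >= max_attempts:
        return None  # out of attempts: do not retry
    # isinstance's 2nd param must be a tuple, not an arbitrary iterable
    if retry_exceptions is not None and not isinstance(exception, tuple(retry_exceptions)):
        return None  # this exception type is not configured for retry
    return datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=wait)


# With no retry_exceptions filter, even a non-Exception such as KeyboardInterrupt
# still yields a schedule for the next attempt.
print(next_attempt_at(KeyboardInterrupt(), attempts=1, max_attempts=3))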
2 changes: 1 addition & 1 deletion procrastinate/tasks.py
@@ -197,7 +197,7 @@ def configure(
         )
 
     def get_retry_exception(
-        self, exception: Exception, job: jobs.Job
+        self, exception: BaseException, job: jobs.Job
     ) -> Optional[exceptions.JobRetry]:
         if not self.retry_strategy:
             return None
26 changes: 23 additions & 3 deletions procrastinate/worker.py
@@ -5,6 +5,7 @@
 from typing import Any, Dict, Iterable, Optional, Union
 
 from procrastinate import app, exceptions, job_context, jobs, signals, tasks, utils
+from procrastinate.exceptions import ProcrastinateException
 
 logger = logging.getLogger(__name__)
 
@@ -164,14 +165,24 @@ async def process_job(self, job: jobs.Job, worker_id: int = 0) -> None:
             extra=context.log_extra(action="loaded_job_info"),
         )
 
+        def find_exception_to_re_raise(ex: ProcrastinateException) -> Optional[BaseException]:
+            # If the job raises a BaseException that is _not_ an Exception
+            # (e.g. a CancelledError, SystemExit, etc.) then we want to persist the
+            # outcome of the job before propagating the exception further up the
+            # call stack.
+            return ex.__cause__ if not isinstance(ex.__cause__, Exception) else None
+
         status, retry_at = None, None
+        exception_to_re_raise = None
         try:
             await self.run_job(job=job, worker_id=worker_id)
             status = jobs.Status.SUCCEEDED
         except exceptions.JobRetry as e:
             retry_at = e.scheduled_at
-        except exceptions.JobError:
+            exception_to_re_raise = find_exception_to_re_raise(e)
+        except exceptions.JobError as e:
             status = jobs.Status.FAILED
+            exception_to_re_raise = find_exception_to_re_raise(e)
         except exceptions.TaskNotFound as exc:
             status = jobs.Status.FAILED
             self.logger.exception(
@@ -201,6 +212,9 @@ async def process_job(self, job: jobs.Job, worker_id: int = 0) -> None:
             # Remove job information from the current context
             self.context_for_worker(worker_id=worker_id, reset=True)
 
+        if exception_to_re_raise is not None:
+            raise exception_to_re_raise
+
     def find_task(self, task_name: str) -> tasks.Task:
         try:
             return self.app.tasks[task_name]
@@ -221,10 +235,16 @@ async def run_job(self, job: jobs.Job, worker_id: int) -> None:
             f"Starting job {job.call_string}",
             extra=context.log_extra(action="start_job"),
         )
-        exc_info: Union[bool, Exception]
         job_args = []
         if task.pass_context:
             job_args.append(context)
+
+        # Initialise logging variables
+        task_result = None
+        log_title = "Error"
+        log_action = "job_error"
+        log_level = logging.ERROR
+        exc_info: Union[bool, BaseException] = False
         try:
             task_result = task(*job_args, **job.task_kwargs)
             if asyncio.iscoroutine(task_result):
@@ -236,7 +256,7 @@ async def run_job(self, job: jobs.Job, worker_id: int) -> None:
                     extra=context.log_extra(action="concurrent_sync_task"),
                 )
 
-        except Exception as e:
+        except BaseException as e:
             task_result = None
             log_title = "Error"
             log_action = "job_error"
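Condensed into a self-contained sketch (helper names are placeholders; the real code routes the original exception through JobError/JobRetry's __cause__ as shown above), the new control flow records the job's outcome first and only afterwards re-raises anything that falls outside the Exception hierarchy:

# Illustrative sketch of the persist-then-re-raise flow; not the worker's real code.
import asyncio
from typing import Awaitable, Callable, Optional


async def process_once(
    run: Callable[[], Awaitable[None]],
    persist_outcome: Callable[[str], Awaitable[None]],
) -> None:
    to_re_raise: Optional[BaseException] = None
    try:
        await run()
        status = "succeeded"
    except BaseException as exc:  # also catches CancelledError, SystemExit, KeyboardInterrupt
        status = "failed"
        if not isinstance(exc, Exception):
            # Remember the exception, but record the job's outcome before propagating it.
            to_re_raise = exc
    await persist_outcome(status)  # always reached, so the job row is never left stale
    if to_re_raise is not None:
        raise to_re_raise


async def _demo() -> None:
    async def run() -> None:
        raise SystemExit(1)

    async def persist_outcome(status: str) -> None:
        print("job recorded as", status)

    await process_once(run, persist_outcome)


if __name__ == "__main__":
    try:
        asyncio.run(_demo())
    except SystemExit:
        print("SystemExit propagated after the outcome was persisted")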
26 changes: 26 additions & 0 deletions tests/unit/test_worker.py
@@ -137,6 +137,32 @@ async def coro(*args, **kwargs):
     test_worker.run_job.assert_called_with(job=job, worker_id=0)
     assert connector.jobs[1]["status"] == "todo"
     assert connector.jobs[1]["scheduled_at"] == scheduled_at
+    assert connector.jobs[1]["attempts"] == 1
+
+
+async def test_process_job_retry_failed_job_re_raise_base_exception(
+    mocker, test_worker, job_factory, connector
+):
+    class TestException(BaseException):
+        pass
+
+    scheduled_at = conftest.aware_datetime(2000, 1, 1)
+    job_exception = exceptions.JobRetry(scheduled_at=scheduled_at)
+    job_exception.__cause__ = TestException()
+
+    test_worker.run_job = mocker.Mock(side_effect=job_exception)
+    job = job_factory(id=1)
+    await test_worker.job_manager.defer_job_async(job)
+
+    # Exceptions that extend BaseException should be re-raised after the failed job
+    # is scheduled for retry (if retry is applicable).
+    with pytest.raises(TestException):
+        await test_worker.process_job(job=job, worker_id=0)
+
+    test_worker.run_job.assert_called_with(job=job, worker_id=0)
+    assert connector.jobs[1]["status"] == "todo"
+    assert connector.jobs[1]["scheduled_at"] == scheduled_at
+    assert connector.jobs[1]["attempts"] == 1
 
 
 async def test_run_job(app):
