diff --git a/procrastinate/blueprints.py b/procrastinate/blueprints.py index 24a0546bd..c15df3222 100644 --- a/procrastinate/blueprints.py +++ b/procrastinate/blueprints.py @@ -3,15 +3,21 @@ import functools import logging import sys -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Literal, Union, cast, overload + +from typing_extensions import Concatenate, ParamSpec, Unpack from procrastinate import exceptions, jobs, periodic, retry, utils +from procrastinate.job_context import JobContext if TYPE_CHECKING: - from procrastinate import tasks + from procrastinate.tasks import ConfigureTaskOptions, Task + logger = logging.getLogger(__name__) +P = ParamSpec("P") + class Blueprint: """ @@ -63,7 +69,7 @@ def my_task(): """ def __init__(self) -> None: - self.tasks: dict[str, tasks.Task] = {} + self.tasks: dict[str, Task] = {} self.periodic_registry = periodic.PeriodicRegistry() self._check_stack() @@ -88,7 +94,7 @@ def _check_stack(self): extra={"action": "app_defined_in___main__"}, ) - def _register_task(self, task: tasks.Task) -> None: + def _register_task(self, task: Task) -> None: """ Register the task into the blueprint task registry. Raises exceptions.TaskAlreadyRegistered if the task name @@ -98,7 +104,7 @@ def _register_task(self, task: tasks.Task) -> None: # Each call to _add_task may raise TaskAlreadyRegistered. # We're using an intermediary dict to make sure that if the registration # is interrupted midway though, self.tasks is left unmodified. - to_add: dict[str, tasks.Task] = {} + to_add: dict[str, Task] = {} self._add_task(task=task, name=task.name, to=to_add) for alias in task.aliases: @@ -106,7 +112,7 @@ def _register_task(self, task: tasks.Task) -> None: self.tasks.update(to_add) - def _add_task(self, task: tasks.Task, name: str, to: dict | None = None) -> None: + def _add_task(self, task: Task, name: str, to: dict | None = None) -> None: # Add a task to a dict of task while making # sure a task of the same name was not already in self.tasks. # This lets us prepare a dict of tasks we might add while not adding @@ -120,7 +126,7 @@ def _add_task(self, task: tasks.Task, name: str, to: dict | None = None) -> None result_dict = self.tasks if to is None else to result_dict[name] = task - def add_task_alias(self, task: tasks.Task, alias: str) -> None: + def add_task_alias(self, task: Task, alias: str) -> None: """ Add an alias to a task. This can be useful if a task was in a given Blueprint and moves to a different blueprint. @@ -189,37 +195,22 @@ def add_tasks_from(self, blueprint: Blueprint, *, namespace: str) -> None: configure_kwargs=periodic_task.configure_kwargs, ) + @overload def task( self, - _func: Callable[..., Any] | None = None, *, + _func: None = None, name: str | None = None, aliases: list[str] | None = None, retry: retry.RetryValue = False, - pass_context: bool = False, + pass_context: Literal[False] = False, queue: str = jobs.DEFAULT_QUEUE, lock: str | None = None, queueing_lock: str | None = None, - ) -> Any: - """ - Declare a function as a task. This method is meant to be used as a decorator:: - - @app.task(...) - def my_task(args): - ... - - or:: - - @app.task - def my_task(args): - ... - - The second form will use the default value for all parameters. - + ) -> Callable[[Callable[P]], Task[P, P]]: + """Declare a function as a task. This method is meant to be used as a decorator Parameters ---------- - _func : - The decorated function queue : The name of the queue in which jobs from this task will be launched, if the queue is not overridden at launch. @@ -251,11 +242,73 @@ def my_task(args): pass_context : Passes the task execution context in the task as first """ + ... + + @overload + def task( + self, + *, + _func: None = None, + name: str | None = None, + aliases: list[str] | None = None, + retry: retry.RetryValue = False, + pass_context: Literal[True], + queue: str = jobs.DEFAULT_QUEUE, + lock: str | None = None, + queueing_lock: str | None = None, + ) -> Callable[ + [Callable[Concatenate[JobContext, P]]], + Task[Concatenate[JobContext, P], P], + ]: + """Declare a function as a task. This method is meant to be used as a decorator + Parameters + ---------- + _func : + The decorated function + """ + ... - def _wrap(func: Callable[..., tasks.Task]): - from procrastinate import tasks + @overload + def task(self, _func: Callable[P]) -> Task[P, P]: + """Declare a function as a task. This method is meant to be used as a decorator + Parameters + ---------- + _func : + The decorated function + """ + ... + + def task( + self, + _func: Callable[P] | None = None, + *, + name: str | None = None, + aliases: list[str] | None = None, + retry: retry.RetryValue = False, + pass_context: bool = False, + queue: str = jobs.DEFAULT_QUEUE, + lock: str | None = None, + queueing_lock: str | None = None, + ): + from procrastinate.tasks import Task - task = tasks.Task( + """ + Declare a function as a task. This method is meant to be used as a decorator:: + + @app.task(...) + def my_task(args): + ... + + or:: + + @app.task + def my_task(args): + ... + The second form will use the default value for all parameters. + """ + + def _wrap(func: Callable[P]): + task = Task( func, blueprint=self, queue=queue, @@ -271,11 +324,26 @@ def _wrap(func: Callable[..., tasks.Task]): return functools.update_wrapper(task, func, updated=()) if _func is None: # Called as @app.task(...) - return _wrap + return cast( + Union[ + Callable[[Callable[P, Any]], Task[P, P]], + Callable[ + [Callable[Concatenate[JobContext, P], Any]], + Task[Concatenate[JobContext, P], P], + ], + ], + _wrap, + ) return _wrap(_func) # Called as @app.task - def periodic(self, *, cron: str, periodic_id: str = "", **kwargs: dict[str, Any]): + def periodic( + self, + *, + cron: str, + periodic_id: str = "", + **configure_kwargs: Unpack[ConfigureTaskOptions], + ): """ Task decorator, marks task as being scheduled for periodic deferring (see `howto/advanced/cron`). @@ -290,7 +358,7 @@ def periodic(self, *, cron: str, periodic_id: str = "", **kwargs: dict[str, Any] Additional parameters are passed to `Task.configure`. """ return self.periodic_registry.periodic_decorator( - cron=cron, periodic_id=periodic_id, **kwargs + cron=cron, periodic_id=periodic_id, **configure_kwargs ) def will_configure_task(self) -> None: diff --git a/procrastinate/periodic.py b/procrastinate/periodic.py index a2fe1ce2c..938fa980d 100644 --- a/procrastinate/periodic.py +++ b/procrastinate/periodic.py @@ -4,13 +4,17 @@ import functools import logging import time -from typing import Any, Iterable, Tuple +from typing import Callable, Generic, Iterable, Tuple, cast import attr import croniter +from typing_extensions import Concatenate, ParamSpec, Unpack from procrastinate import exceptions, tasks +P = ParamSpec("P") +Args = ParamSpec("Args") + # The maximum delay after which tasks will be considered as # outdated, and ignored. MAX_DELAY = 60 * 10 # 10 minutes @@ -24,11 +28,11 @@ @attr.dataclass(frozen=True) -class PeriodicTask: - task: tasks.Task +class PeriodicTask(Generic[P, Args]): + task: tasks.Task[P, Args] cron: str periodic_id: str - configure_kwargs: dict[str, Any] + configure_kwargs: tasks.ConfigureTaskOptions @cached_property def croniter(self) -> croniter.croniter: @@ -42,28 +46,36 @@ class PeriodicRegistry: def __init__(self): self.periodic_tasks: dict[tuple[str, str], PeriodicTask] = {} - def periodic_decorator(self, cron: str, periodic_id: str, **kwargs): + def periodic_decorator( + self, + cron: str, + periodic_id: str, + **configure_kwargs: Unpack[tasks.ConfigureTaskOptions], + ) -> Callable[[tasks.Task[P, Concatenate[int, Args]]], tasks.Task[P, Args]]: """ Decorator over a task definition that registers that task for periodic launch. This decorator should not be used directly, ``@app.periodic()`` is meant to be used instead. """ - def wrapper(task: tasks.Task): + def wrapper(task: tasks.Task[P, Concatenate[int, Args]]) -> tasks.Task[P, Args]: self.register_task( - task=task, cron=cron, periodic_id=periodic_id, configure_kwargs=kwargs + task=task, + cron=cron, + periodic_id=periodic_id, + configure_kwargs=configure_kwargs, ) - return task + return cast(tasks.Task[P, Args], task) return wrapper def register_task( self, - task: tasks.Task, + task: tasks.Task[P, Concatenate[int, Args]], cron: str, periodic_id: str, - configure_kwargs: dict[str, Any], - ) -> PeriodicTask: + configure_kwargs: tasks.ConfigureTaskOptions, + ) -> PeriodicTask[P, Concatenate[int, Args]]: key = (task.name, periodic_id) if key in self.periodic_tasks: raise exceptions.TaskAlreadyRegistered( @@ -190,7 +202,12 @@ async def defer_jobs(self, jobs_to_defer: Iterable[TaskAtTime]) -> None: task = periodic_task.task periodic_id = periodic_task.periodic_id configure_kwargs = periodic_task.configure_kwargs - configure_kwargs.setdefault("task_kwargs", {})["timestamp"] = timestamp + task_kwargs = configure_kwargs.get("task_kwargs") + if task_kwargs is None: + task_kwargs = {} + configure_kwargs["task_kwargs"] = task_kwargs + task_kwargs["timestamp"] = timestamp + description = { "task_name": task.name, "periodic_id": periodic_id, diff --git a/procrastinate/tasks.py b/procrastinate/tasks.py index de48b094c..4c52467a9 100644 --- a/procrastinate/tasks.py +++ b/procrastinate/tasks.py @@ -2,7 +2,9 @@ import datetime import logging -from typing import Any, Callable, cast +from typing import Any, Callable, Generic, TypedDict, cast + +from typing_extensions import NotRequired, ParamSpec, Unpack from procrastinate import app as app_module from procrastinate import blueprints, exceptions, jobs, manager, types, utils @@ -11,31 +13,52 @@ logger = logging.getLogger(__name__) +Args = ParamSpec("Args") +P = ParamSpec("P") + + +class TimeDeltaParams(TypedDict): + weeks: NotRequired[int] + days: NotRequired[int] + hours: NotRequired[int] + minutes: NotRequired[int] + seconds: NotRequired[int] + milliseconds: NotRequired[int] + microseconds: NotRequired[int] + + +class ConfigureTaskOptions(TypedDict): + lock: NotRequired[str | None] + queueing_lock: NotRequired[str | None] + task_kwargs: NotRequired[types.JSONDict | None] + schedule_at: NotRequired[datetime.datetime | None] + schedule_in: NotRequired[TimeDeltaParams | None] + queue: NotRequired[str | None] + + def configure_task( *, name: str, job_manager: manager.JobManager, - lock: str | None = None, - queueing_lock: str | None = None, - task_kwargs: types.JSONDict | None = None, - schedule_at: datetime.datetime | None = None, - schedule_in: dict[str, int] | None = None, - queue: str = jobs.DEFAULT_QUEUE, + **options: Unpack[ConfigureTaskOptions], ) -> jobs.JobDeferrer: + schedule_at = options.get("schedule_at") + schedule_in = options.get("schedule_in") + if schedule_at and schedule_in is not None: raise ValueError("Cannot set both schedule_at and schedule_in") if schedule_in is not None: schedule_at = utils.utcnow() + datetime.timedelta(**schedule_in) - task_kwargs = task_kwargs or {} + task_kwargs = options.get("task_kwargs") or {} return jobs.JobDeferrer( job=jobs.Job( id=None, - lock=lock, - queueing_lock=queueing_lock, + lock=options.get("lock"), + queueing_lock=options.get("queueing_lock"), task_name=name, - queue=queue, + queue=options.get("queue") or jobs.DEFAULT_QUEUE, task_kwargs=task_kwargs, scheduled_at=schedule_at, ), @@ -43,7 +66,7 @@ def configure_task( ) -class Task: +class Task(Generic[P, Args]): """ A task is a function that should be executed later. It is linked to a default queue, and expects keyword arguments. @@ -72,7 +95,7 @@ class Task: def __init__( self, - func: Callable, + func: Callable[P], *, blueprint: blueprints.Blueprint, # task naming @@ -88,7 +111,7 @@ def __init__( ): self.queue = queue self.blueprint = blueprint - self.func: Callable = func + self.func: Callable[P] = func self.aliases = aliases if aliases else [] self.retry_strategy = retry_module.get_retry_strategy(retry) self.name: str = name if name else self.full_path @@ -106,14 +129,14 @@ def add_namespace(self, namespace: str) -> None: for alias in self.aliases ] - def __call__(self, *args, **kwargs: types.JSONValue) -> Any: + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Any: return self.func(*args, **kwargs) @property def full_path(self) -> str: return utils.get_full_path(self.func) - async def defer_async(self, **task_kwargs: types.JSONValue) -> int: + async def defer_async(self, *_: Args.args, **task_kwargs: Args.kwargs) -> int: """ Create a job from this task and the given arguments. The job will be created with default parameters, if you want to better @@ -121,7 +144,7 @@ async def defer_async(self, **task_kwargs: types.JSONValue) -> int: """ return await self.configure().defer_async(**task_kwargs) - def defer(self, **task_kwargs: types.JSONValue) -> int: + def defer(self, *_: Args.args, **task_kwargs: Args.kwargs) -> int: """ Create a job from this task and the given arguments. The job will be created with default parameters, if you want to better @@ -129,16 +152,7 @@ def defer(self, **task_kwargs: types.JSONValue) -> int: """ return self.configure().defer(**task_kwargs) - def configure( - self, - *, - lock: str | None = None, - queueing_lock: str | None = None, - task_kwargs: types.JSONDict | None = None, - schedule_at: datetime.datetime | None = None, - schedule_in: dict[str, int] | None = None, - queue: str | None = None, - ) -> jobs.JobDeferrer: + def configure(self, **options: Unpack[ConfigureTaskOptions]) -> jobs.JobDeferrer: """ Configure the job with all the specific settings, defining how the job should be launched. @@ -181,8 +195,14 @@ def configure( """ self.blueprint.will_configure_task() - app = cast(app_module.App, self.blueprint) + lock = options.get("lock") + queueing_lock = options.get("queueing_lock") + task_kwargs = options.get("task_kwargs") + schedule_at = options.get("schedule_at") + schedule_in = options.get("schedule_in") + queue = options.get("queue") + app = cast(app_module.App, self.blueprint) return configure_task( name=self.name, job_manager=app.job_manager, diff --git a/tests/unit/test_app.py b/tests/unit/test_app.py index 6031218ec..129c5c184 100644 --- a/tests/unit/test_app.py +++ b/tests/unit/test_app.py @@ -17,10 +17,10 @@ def task_func(): def test_app_no_connector(): with pytest.raises(TypeError): - app_module.App() + app_module.App() # type: ignore -def test_app_task_dont_read_function_attributes(app): +def test_app_task_dont_read_function_attributes(app: app_module.App): # This is a weird one. It's a regression test. At some point, we noted that, # due to the slightly wrong usage of update_wrapper, the attributes on the # decorated function were copied on the task. This led to surprising @@ -33,12 +33,12 @@ def wrapped(): assert task.pass_context is False -def test_app_register_builtins(app): +def test_app_register_builtins(app: app_module.App): assert "procrastinate.builtin_tasks.remove_old_jobs" in app.tasks assert "builtin:procrastinate.builtin_tasks.remove_old_jobs" in app.tasks -def test_app_register(app): +def test_app_register(app: app_module.App): task = tasks.Task(task_func, blueprint=app, queue="queue", name="bla") app._register_task(task) @@ -58,7 +58,7 @@ def test_app_worker(app, mocker): ) -def test_app_run_worker(app): +def test_app_run_worker(app: app_module.App): result = [] @app.task @@ -72,7 +72,7 @@ def my_task(a): assert result == [1] -async def test_app_run_worker_async(app): +async def test_app_run_worker_async(app: app_module.App): result = [] @app.task @@ -86,7 +86,7 @@ async def my_task(a): assert result == [1] -async def test_app_run_worker_async_cancel(app): +async def test_app_run_worker_async_cancel(app: app_module.App): result = [] @app.task @@ -110,7 +110,7 @@ def test_from_path(mocker): load.assert_called_once_with("dotted.path", app_module.App) -def test_app_configure_task(app): +def test_app_configure_task(app: app_module.App): scheduled_at = conftest.aware_datetime(2000, 1, 1) job = app.configure_task( name="my_name", @@ -127,7 +127,7 @@ def test_app_configure_task(app): assert job.task_kwargs == {"a": 1} -def test_app_configure_task_unknown_allowed(app): +def test_app_configure_task_unknown_allowed(app: app_module.App): @app.task(name="my_name", queue="bla") def my_name(a): pass @@ -144,15 +144,15 @@ def my_name(a): assert job.task_kwargs == {"a": 1} -def test_app_configure_task_unkown_not_allowed(app): +def test_app_configure_task_unkown_not_allowed(app: app_module.App): with pytest.raises(exceptions.TaskNotFound): app.configure_task(name="my_name", allow_unknown=False) -def test_app_periodic(app): +def test_app_periodic(app: app_module.App): @app.periodic(cron="0 * * * 1", periodic_id="foo") @app.task - def yay(timestamp): + def yay(timestamp: int): pass assert len(app.periodic_registry.periodic_tasks) == 1 diff --git a/tests/unit/test_blueprints.py b/tests/unit/test_blueprints.py index 15d17b761..188cb6e4f 100644 --- a/tests/unit/test_blueprints.py +++ b/tests/unit/test_blueprints.py @@ -3,6 +3,7 @@ import pytest from procrastinate import blueprints, exceptions, periodic, retry +from procrastinate.job_context import JobContext def test_blueprint_task_aliases(blueprint, mocker): @@ -79,7 +80,7 @@ def my_other_task(): assert blueprint.tasks == {"foo": my_task} -def test_register_task_clash_alias(blueprint): +def test_register_task_clash_alias(blueprint: blueprints.Blueprint): @blueprint.task(name="foo") def my_task(): return "foo" @@ -93,7 +94,7 @@ def my_other_task(): assert blueprint.tasks == {"foo": my_task} -def test_add_task_alias(blueprint): +def test_add_task_alias(blueprint: blueprints.Blueprint): @blueprint.task(name="foo") def my_task(): return "foo" @@ -103,7 +104,7 @@ def my_task(): assert blueprint.tasks == {"foo": my_task, "bar": my_task} -def test_add_tasks_from(blueprint): +def test_add_tasks_from(blueprint: blueprints.Blueprint): other = blueprints.Blueprint() @blueprint.task(name="foo") @@ -120,17 +121,17 @@ def my_other_task(): assert my_other_task.name == "ns:bar" -def test_add_tasks_from__periodic(blueprint): +def test_add_tasks_from__periodic(blueprint: blueprints.Blueprint): other = blueprints.Blueprint() @blueprint.periodic(cron="0 * * * 1", periodic_id="foo") @blueprint.task(name="foo") - def my_task(): + def my_task(timestamp: int): return "foo" @other.periodic(cron="0 * * * 1", periodic_id="foo") - @other.task(name="bar") - def my_other_task(): + @other.task(name="bar", pass_context=True) + def my_other_task(context: JobContext, timestamp: int): return "bar" blueprint.add_tasks_from(other, namespace="ns") @@ -151,7 +152,7 @@ def my_other_task(): } -def test_add_tasks_from_clash(blueprint): +def test_add_tasks_from_clash(blueprint: blueprints.Blueprint): other = blueprints.Blueprint() @blueprint.task(name="ns:foo") @@ -169,7 +170,7 @@ def my_other_task(): assert my_other_task.name == "foo" -def test_add_tasks_from_clash_alias(blueprint): +def test_add_tasks_from_clash_alias(blueprint: blueprints.Blueprint): other = blueprints.Blueprint() @blueprint.task(name="foo", aliases=["ns:foo"]) @@ -187,7 +188,7 @@ def my_other_task(): assert my_other_task.name == "foo" -def test_add_tasks_from_clash_other_alias(blueprint): +def test_add_tasks_from_clash_other_alias(blueprint: blueprints.Blueprint): other = blueprints.Blueprint() @blueprint.task(name="ns:foo") @@ -205,7 +206,7 @@ def my_other_task(): assert my_other_task.name == "bar" -def test_blueprint_task_explicit(blueprint, mocker): +def test_blueprint_task_explicit(blueprint: blueprints.Blueprint, mocker): @blueprint.task( name="foobar", queue="bar", @@ -215,10 +216,10 @@ def test_blueprint_task_explicit(blueprint, mocker): pass_context=True, aliases=["a"], ) - def my_task(): + def my_task(context: JobContext): return "foo" - assert my_task() == "foo" + assert my_task(JobContext()) == "foo" assert blueprint.tasks["foobar"].name == "foobar" assert blueprint.tasks["foobar"].queue == "bar" assert blueprint.tasks["foobar"].lock == "sher" diff --git a/tests/unit/test_tasks.py b/tests/unit/test_tasks.py index 4b7927863..c78a3cb23 100644 --- a/tests/unit/test_tasks.py +++ b/tests/unit/test_tasks.py @@ -3,6 +3,7 @@ import pytest from procrastinate import tasks, utils +from procrastinate.app import App from .. import conftest @@ -11,14 +12,14 @@ def task_func(): pass -def test_task_init_with_no_name(app): +def test_task_init_with_no_name(app: App): task = tasks.Task(task_func, blueprint=app, queue="queue") assert task.func is task_func assert task.name == "tests.unit.test_tasks.task_func" -async def test_task_defer_async(app, connector): +async def test_task_defer_async(app: App, connector): task = tasks.Task(task_func, blueprint=app, queue="queue") await task.defer_async(c=3) @@ -114,3 +115,4 @@ def test_task_get_retry_exception(app, mocker): exception = ValueError() assert task.get_retry_exception(exception=exception, job=job) is mock.return_value mock.assert_called_with(exception=exception, attempts=0) + mock.assert_called_with(exception=exception, attempts=0)