diff --git a/.env.example b/.env.example
index 892f7cf6..86ec042e 100644
--- a/.env.example
+++ b/.env.example
@@ -14,7 +14,7 @@ API_HEALTH_PATH=/health
 # Log
 LOG_EXCLUDE_PATHS="\A(?!x)x"
 LOG_HTTP_EVENT="HTTP"
-c$LOG_INCLUDE_COMPRESSED_BODY=false
+LOG_INCLUDE_COMPRESSED_BODY=false
 LOG_LEVEL=20
 LOG_OBFUSCATE_COOKIES='["session"]'
 LOG_OBFUSCATE_HEADERS='["Authorization","X-API-KEY"]'
@@ -56,3 +56,11 @@ SERVER_KEEPALIVE=65
 SERVER_PORT=8000
 SERVER_RELOAD=false
 SERVER_RELOAD_DIRS='[]'
+
+# Worker
+WORKER_JOB_TIMEOUT=10
+WORKER_JOB_HEARTBEAT=0
+WORKER_JOB_RETRIES=10
+WORKER_JOB_TTL=600
+WORKER_JOB_RETRY_DELAY=1.0
+WORKER_JOB_RETRY_BACKOFF=60
diff --git a/src/starlite_saqlalchemy/log/worker.py b/src/starlite_saqlalchemy/log/worker.py
index ea2f081c..406dd5e0 100644
--- a/src/starlite_saqlalchemy/log/worker.py
+++ b/src/starlite_saqlalchemy/log/worker.py
@@ -28,6 +28,10 @@ async def after_process(ctx: Context) -> None:
     # parse log context from `ctx`
     job: Job = ctx["job"]
     log_ctx = {k: getattr(job, k) for k in settings.log.JOB_FIELDS}
+    # add duration measures
+    log_ctx["pickup_time_ms"] = job.started - job.queued
+    log_ctx["completed_time_ms"] = job.completed - job.started
+    log_ctx["total_time_ms"] = job.completed - job.queued
     # emit the log
     if job.error:
         level = logging.ERROR
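Note on the new duration fields: SAQ records `Job.queued`, `Job.started` and `Job.completed` as epoch milliseconds, so the three measures are plain subtractions. A minimal sketch of the arithmetic, with hypothetical timestamp values:

```python
# Illustrative only - mirrors the arithmetic added to `after_process` above.
# SAQ stores job timestamps as epoch milliseconds, so subtraction yields ms.
queued, started, completed = 1_700_000_000_000, 1_700_000_000_015, 1_700_000_000_040

pickup_time_ms = started - queued        # 15 ms spent waiting in the queue
completed_time_ms = completed - started  # 25 ms spent executing
total_time_ms = completed - queued       # 40 ms end to end

assert total_time_ms == pickup_time_ms + completed_time_ms
```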
""" - raise NotImplementedError + return data async def get(self, id_: Any) -> T: """Retrieve a representation of `T` with that is identified by `id_` @@ -105,7 +112,7 @@ async def get(self, id_: Any) -> T: Returns: Representation of instance with identifier `id_`. """ - raise NotImplementedError + raise NotFoundError async def delete(self, id_: Any) -> T: """Delete `T` that is identified by `id_`. @@ -116,25 +123,36 @@ async def delete(self, id_: Any) -> T: Returns: Representation of the deleted instance. """ - raise NotImplementedError + raise NotFoundError - async def enqueue_background_task(self, method_name: str, **kwargs: Any) -> None: + async def enqueue_background_task( + self, method_name: str, job_config: JobConfig | None = None, **kwargs: Any + ) -> None: """Enqueue an async callback for the operation and data. Args: method_name: Method on the service object that should be called by the async worker. + job_config: Configuration object to control the job that is enqueued. **kwargs: Arguments to be passed to the method when called. Must be JSON serializable. """ module = inspect.getmodule(self) if module is None: # pragma: no cover logger.warning("Callback not enqueued, no module resolved for %s", self) return - await queue.enqueue( - make_service_callback.__qualname__, - service_type_id=self.__id__, - service_method_name=method_name, - **kwargs, + job_config_dict: dict[str, Any] + if job_config is None: + job_config_dict = default_job_config_dict + else: + job_config_dict = utils.dataclass_as_dict_shallow(job_config, exclude_none=True) + + kwargs["service_type_id"] = self.__id__ + kwargs["service_method_name"] = method_name + job = Job( + function=make_service_callback.__qualname__, + kwargs=kwargs, + **job_config_dict, ) + await queue.enqueue(job) @classmethod @contextlib.asynccontextmanager @@ -246,11 +264,7 @@ async def new(cls: type[RepoServiceT]) -> AsyncIterator[RepoServiceT]: async def make_service_callback( - _ctx: Context, - *, - service_type_id: str, - service_method_name: str, - **kwargs: Any, + _ctx: Context, *, service_type_id: str, service_method_name: str, **kwargs: Any ) -> None: """Make an async service callback. diff --git a/src/starlite_saqlalchemy/settings.py b/src/starlite_saqlalchemy/settings.py index c4c7a089..136630e0 100644 --- a/src/starlite_saqlalchemy/settings.py +++ b/src/starlite_saqlalchemy/settings.py @@ -255,6 +255,44 @@ class Config: EXPONENTIAL_BACKOFF_MULTIPLIER: float = 1 +class WorkerSettings(BaseSettings): + """Global SAQ Job configuration.""" + + class Config: + case_sensitive = True + env_file = ".env" + env_prefix = "WORKER_" + + JOB_TIMEOUT: int = 10 + """Max time a job can run for, in seconds. + + Set to `0` for no timeout. + """ + JOB_HEARTBEAT: int = 0 + """Max time a job can survive without emitting a heartbeat. `0` to disable. + + `job.update()` will trigger a heartbeat. + """ + JOB_RETRIES: int = 10 + """Max attempts for any job.""" + JOB_TTL: int = 600 + """Lifetime of available job information, in seconds. + + 0: indefinite + -1: disabled (no info retained) + """ + JOB_RETRY_DELAY: float = 1.0 + """Seconds to delay before retrying a job.""" + JOB_RETRY_BACKOFF: bool | float = 60 + """If true, use exponential backoff for retry delays. + + - The first retry will have whatever retry_delay is. + - The second retry will have retry_delay*2. The third retry will have retry_delay*4. And so on. + - This always includes jitter, where the final retry delay is a random number between 0 and the calculated retry delay. 
diff --git a/src/starlite_saqlalchemy/settings.py b/src/starlite_saqlalchemy/settings.py
index c4c7a089..136630e0 100644
--- a/src/starlite_saqlalchemy/settings.py
+++ b/src/starlite_saqlalchemy/settings.py
@@ -255,6 +255,44 @@ class Config:
     EXPONENTIAL_BACKOFF_MULTIPLIER: float = 1
 
 
+class WorkerSettings(BaseSettings):
+    """Global SAQ Job configuration."""
+
+    class Config:
+        case_sensitive = True
+        env_file = ".env"
+        env_prefix = "WORKER_"
+
+    JOB_TIMEOUT: int = 10
+    """Max time a job can run for, in seconds.
+
+    Set to `0` for no timeout.
+    """
+    JOB_HEARTBEAT: int = 0
+    """Max time a job can survive without emitting a heartbeat. `0` to disable.
+
+    `job.update()` will trigger a heartbeat.
+    """
+    JOB_RETRIES: int = 10
+    """Max attempts for any job."""
+    JOB_TTL: int = 600
+    """Lifetime of available job information, in seconds.
+
+    0: indefinite
+    -1: disabled (no info retained)
+    """
+    JOB_RETRY_DELAY: float = 1.0
+    """Seconds to delay before retrying a job."""
+    JOB_RETRY_BACKOFF: bool | float = 60
+    """If true, use exponential backoff for retry delays.
+
+    - The first retry will have whatever retry_delay is.
+    - The second retry will have retry_delay*2. The third retry will have retry_delay*4. And so on.
+    - This always includes jitter, where the final retry delay is a random number between 0 and the calculated retry delay.
+    - If retry_backoff is set to a number, that number is the maximum retry delay, in seconds.
+    """
+
+
 # `.parse_obj()` thing is a workaround for pyright and pydantic interplay, see:
 # https://github.com/pydantic/pydantic/issues/3753#issuecomment-1087417884
 api = APISettings.parse_obj({})
@@ -275,3 +313,5 @@ class Config:
 """Sentry settings."""
 server = ServerSettings.parse_obj({})
 """Server settings."""
+worker = WorkerSettings.parse_obj({})
+"""Worker settings."""
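Since `WorkerSettings` sets `env_prefix = "WORKER_"`, each default above should be overridable per environment without code changes. A sketch, assuming pydantic's usual `BaseSettings` env handling (the values are hypothetical):

```python
# Hypothetical override sketch: WorkerSettings reads `WORKER_`-prefixed
# environment variables (or `.env` entries) when it is instantiated,
# so the variables must be set before the settings module is imported.
import os

os.environ["WORKER_JOB_RETRIES"] = "3"
os.environ["WORKER_JOB_RETRY_BACKOFF"] = "120"

from starlite_saqlalchemy.settings import WorkerSettings

worker_settings = WorkerSettings.parse_obj({})
assert worker_settings.JOB_RETRIES == 3
assert worker_settings.JOB_RETRY_BACKOFF == 120.0
```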
diff --git a/src/starlite_saqlalchemy/utils.py b/src/starlite_saqlalchemy/utils.py
new file mode 100644
index 00000000..c140ac36
--- /dev/null
+++ b/src/starlite_saqlalchemy/utils.py
@@ -0,0 +1,14 @@
+"""General utility functions."""
+import dataclasses
+from typing import Any
+
+
+def dataclass_as_dict_shallow(dataclass: Any, *, exclude_none: bool = False) -> dict[str, Any]:
+    """Convert a dataclass to dict, without deepcopy."""
+    ret: dict[str, Any] = {}
+    for field in dataclasses.fields(dataclass):
+        value = getattr(dataclass, field.name)
+        if exclude_none and value is None:
+            continue
+        ret[field.name] = value
+    return ret
diff --git a/src/starlite_saqlalchemy/worker.py b/src/starlite_saqlalchemy/worker.py
index 4ca68089..73608d66 100644
--- a/src/starlite_saqlalchemy/worker.py
+++ b/src/starlite_saqlalchemy/worker.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import asyncio
+import dataclasses
 from functools import partial
 from typing import TYPE_CHECKING, Any
 
@@ -9,16 +10,18 @@
 import saq
 from starlite.utils.serialization import default_serializer
 
-from starlite_saqlalchemy import redis, settings, type_encoders
+from starlite_saqlalchemy import redis, settings, type_encoders, utils
 
 if TYPE_CHECKING:
     from collections.abc import Awaitable, Callable, Collection
     from signal import Signals
 
 __all__ = [
+    "JobConfig",
     "Queue",
     "Worker",
     "create_worker_instance",
+    "default_job_config_dict",
     "queue",
 ]
 
@@ -31,10 +34,13 @@ class Queue(saq.Queue):
     """Async task queue."""
 
     def __init__(self, *args: Any, **kwargs: Any) -> None:
-        """[SAQ
-        Queue](https://github.com/tobymao/saq/blob/master/saq/queue.py).
+        """Create an SAQ Queue.
 
-        Configures `orjson` for JSON serialization/deserialization if not
+        See: https://github.com/tobymao/saq/blob/master/saq/queue.py
+
+        Names the queue per the application slug - namespaces SAQ's redis keys to the app.
+
+        Configures `msgspec` for JSON serialization/deserialization if not
         otherwise configured.
 
         Args:
@@ -67,6 +73,53 @@ async def on_app_startup(self) -> None:  # pragma: no cover
     """
 
 
+@dataclasses.dataclass()
+class JobConfig:
+    """Configure a Job.
+
+    Used to configure jobs enqueued via
+    `Service.enqueue_background_task()`
+    """
+
+    # pylint:disable=too-many-instance-attributes
+
+    queue: Queue = queue
+    """Queue associated with the job."""
+    key: str | None = None
+    """Pass in to control duplicate jobs."""
+    timeout: int = settings.worker.JOB_TIMEOUT
+    """Max time a job can run for, in seconds.
+
+    Set to `0` for no timeout.
+    """
+    heartbeat: int = settings.worker.JOB_HEARTBEAT
+    """Max time a job can survive without emitting a heartbeat. `0` to disable.
+
+    `job.update()` will trigger a heartbeat.
+    """
+    retries: int = settings.worker.JOB_RETRIES
+    """Max attempts for any job."""
+    ttl: int = settings.worker.JOB_TTL
+    """Lifetime of available job information, in seconds.
+
+    0: indefinite
+    -1: disabled (no info retained)
+    """
+    retry_delay: float = settings.worker.JOB_RETRY_DELAY
+    """Seconds to delay before retrying a job."""
+    retry_backoff: bool | float = settings.worker.JOB_RETRY_BACKOFF
+    """If true, use exponential backoff for retry delays.
+
+    - The first retry will have whatever retry_delay is.
+    - The second retry will have retry_delay*2. The third retry will have retry_delay*4. And so on.
+    - This always includes jitter, where the final retry delay is a random number between 0 and the calculated retry delay.
+    - If retry_backoff is set to a number, that number is the maximum retry delay, in seconds.
+    """
+
+
+default_job_config_dict = utils.dataclass_as_dict_shallow(JobConfig(), exclude_none=True)
+
+
 def create_worker_instance(
     functions: Collection[Callable[..., Any] | tuple[str, Callable]],
     before_process: Callable[[dict[str, Any]], Awaitable[Any]] | None = None,
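The module-level `default_job_config_dict` is where `dataclass_as_dict_shallow` earns its keep: `dataclasses.asdict` would attempt a deep copy of the shared `Queue` instance held by `JobConfig`, while the shallow helper preserves object identity. An illustrative sketch, assuming the default `WORKER_*` settings:

```python
# Illustrative sketch of how a `JobConfig` becomes `Job` kwargs.
from starlite_saqlalchemy import utils
from starlite_saqlalchemy.worker import JobConfig, queue

config = JobConfig(timeout=999)  # remaining fields keep settings-derived defaults
config_dict = utils.dataclass_as_dict_shallow(config, exclude_none=True)

assert config_dict["timeout"] == 999
assert config_dict["retry_delay"] == 1.0  # from WORKER_JOB_RETRY_DELAY default
assert config_dict["queue"] is queue      # shallow: the shared Queue, not a copy
assert "key" not in config_dict           # `key=None` dropped by `exclude_none`
```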
diff --git a/tests/unit/test_log.py b/tests/unit/test_log.py
index 552af1e8..d09fbdc9 100644
--- a/tests/unit/test_log.py
+++ b/tests/unit/test_log.py
@@ -346,6 +346,9 @@ async def test_after_process(job: Job, cap_logger: CapturingLogger) -> None:
             "event": "Worker",
             "level": "info",
             "timestamp": ANY,
+            "pickup_time_ms": 0,
+            "completed_time_ms": 0,
+            "total_time_ms": 0,
         },
     )
 ] == cap_logger.calls
@@ -373,6 +376,9 @@ async def test_after_process_logs_at_error(job: Job, cap_logger: CapturingLogger
             "event": "Worker",
             "level": "error",
             "timestamp": ANY,
+            "pickup_time_ms": 0,
+            "completed_time_ms": 0,
+            "total_time_ms": 0,
         },
     )
 ] == cap_logger.calls
diff --git a/tests/unit/test_service.py b/tests/unit/test_service.py
index 10719e5d..e9f4b57d 100644
--- a/tests/unit/test_service.py
+++ b/tests/unit/test_service.py
@@ -7,8 +7,10 @@
 from uuid import uuid4
 
 import pytest
+from saq import Job
 
 from starlite_saqlalchemy import db, service, worker
+from starlite_saqlalchemy.exceptions import NotFoundError
 from tests.utils import domain
 
 if TYPE_CHECKING:
@@ -118,15 +120,52 @@ async def test_enqueue_service_callback(monkeypatch: "MonkeyPatch") -> None:
     monkeypatch.setattr(worker.queue, "enqueue", enqueue_mock)
     service_instance = domain.authors.Service(session=db.async_session_factory())
     await service_instance.enqueue_background_task("receive_callback", raw_obj={"a": "b"})
-    enqueue_mock.assert_called_once_with(
-        "make_service_callback",
-        service_type_id="tests.utils.domain.authors.Service",
-        service_method_name="receive_callback",
-        raw_obj={"a": "b"},
+    enqueue_mock.assert_called_once()
+    assert isinstance(enqueue_mock.mock_calls[0].args[0], Job)
+    job = enqueue_mock.mock_calls[0].args[0]
+    assert job.function == service.make_service_callback.__qualname__
+    assert job.kwargs == {
+        "service_type_id": "tests.utils.domain.authors.Service",
+        "service_method_name": "receive_callback",
+        "raw_obj": {"a": "b"},
+    }
+
+
+async def test_enqueue_service_callback_with_custom_job_config(monkeypatch: "MonkeyPatch") -> None:
+    """Test that the job is enqueued with the desired arguments."""
+    enqueue_mock = AsyncMock()
+    monkeypatch.setattr(worker.queue, "enqueue", enqueue_mock)
+    service_instance = domain.authors.Service(session=db.async_session_factory())
+    await service_instance.enqueue_background_task(
+        "receive_callback", job_config=worker.JobConfig(timeout=999), raw_obj={"a": "b"}
     )
+    enqueue_mock.assert_called_once()
+    assert isinstance(enqueue_mock.mock_calls[0].args[0], Job)
+    job = enqueue_mock.mock_calls[0].args[0]
+    assert job.function == service.make_service_callback.__qualname__
+    assert job.timeout == 999
+    assert job.kwargs == {
+        "service_type_id": "tests.utils.domain.authors.Service",
+        "service_method_name": "receive_callback",
+        "raw_obj": {"a": "b"},
+    }
 
 
 async def test_service_new_context_manager() -> None:
     """Simple test of `Service.new()` context manager behavior."""
     async with service.Service[domain.authors.Author].new() as service_obj:
         assert isinstance(service_obj, service.Service)
+
+
+async def test_service_method_default_behavior() -> None:
+    """Test default behavior of base service methods."""
+    service_obj = service.Service[object]()
+    data = object()
+    assert await service_obj.create(data) is data
+    assert await service_obj.list() == []
+    assert await service_obj.update("abc", data) is data
+    assert await service_obj.upsert("abc", data) is data
+    with pytest.raises(NotFoundError):
+        await service_obj.get("abc")
+    with pytest.raises(NotFoundError):
+        await service_obj.delete("abc")
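To make the `retry_backoff` bullets in `WorkerSettings` and `JobConfig` concrete: with the defaults above (`retry_delay=1.0`, `retry_backoff=60`), the delay's upper bound doubles per attempt, is capped at 60 seconds, and jitter draws the final delay below that bound. A sketch of the documented behaviour, not SAQ's actual implementation:

```python
# Sketch of the documented retry backoff: exponential upper bound per attempt,
# capped when `retry_backoff` is a number, with full jitter below the bound.
import random


def sketch_retry_delay(attempt: int, retry_delay: float = 1.0, max_delay: float = 60.0) -> float:
    """Return a jittered delay for the given 1-based retry attempt."""
    upper_bound = min(retry_delay * 2 ** (attempt - 1), max_delay)
    return random.uniform(0, upper_bound)


for attempt in (1, 2, 3, 7, 10):
    assert 0 <= sketch_retry_delay(attempt) <= min(2 ** (attempt - 1), 60)
```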