diff --git a/src/starlite_saqlalchemy/__init__.py b/src/starlite_saqlalchemy/__init__.py index bdaf6c2c..16c60667 100644 --- a/src/starlite_saqlalchemy/__init__.py +++ b/src/starlite_saqlalchemy/__init__.py @@ -34,7 +34,6 @@ def example_handler() -> dict: orm, redis, repository, - response, sentry, service, settings, @@ -58,7 +57,6 @@ def example_handler() -> dict: "orm", "redis", "repository", - "response", "sentry", "service", "settings", diff --git a/src/starlite_saqlalchemy/init_plugin.py b/src/starlite_saqlalchemy/init_plugin.py index eb2b5ef2..42f4d89b 100644 --- a/src/starlite_saqlalchemy/init_plugin.py +++ b/src/starlite_saqlalchemy/init_plugin.py @@ -35,6 +35,7 @@ def example_handler() -> dict: from pydantic import BaseModel from starlite.app import DEFAULT_CACHE_CONFIG, DEFAULT_OPENAPI_CONFIG from starlite.plugins.sql_alchemy import SQLAlchemyPlugin +from starlite.response import Response from starlite_saqlalchemy import ( cache, @@ -45,17 +46,19 @@ def example_handler() -> dict: logging, openapi, redis, - response, sentry, sqlalchemy_plugin, static_files, ) from starlite_saqlalchemy.health import health_check from starlite_saqlalchemy.repository.exceptions import RepositoryException +from starlite_saqlalchemy.serializer import default_serializer from starlite_saqlalchemy.service import ServiceException, make_service_callback from starlite_saqlalchemy.worker import create_worker_instance if TYPE_CHECKING: + from collections.abc import Callable + from typing import Any from starlite.config.app import AppConfig @@ -143,6 +146,15 @@ class PluginConfig(BaseModel): [`AppConfig.on_shutdown`][starlite.config.app.AppConfig.on_shutdown] that manage the lifecycle of the `SAQ` worker. """ + serializer: Callable[[Any], Any] = default_serializer + """ + The serializer callable that is used by the custom [`Response`][starlite.response.Response] + class that is created. + If [`AppConfig.response_class`][starlite.config.app.AppConfig.response_class] is not `None`, + this is ignored. + If [`PluginConfig.do_response_class`][PluginConfig.do_response_class] is `False`, this is + ignored. + """ class ConfigureApp: @@ -290,7 +302,9 @@ def configure_response_class(self, app_config: AppConfig) -> None: app_config: The Starlite application config object. """ if self.config.do_response_class and app_config.response_class is None: - app_config.response_class = response.Response + app_config.response_class = type( + "Response", (Response,), {"serializer": staticmethod(self.config.serializer)} + ) def configure_sentry(self, app_config: AppConfig) -> None: """Add handler to configure Sentry integration. diff --git a/src/starlite_saqlalchemy/response.py b/src/starlite_saqlalchemy/response.py deleted file mode 100644 index 571ffe59..00000000 --- a/src/starlite_saqlalchemy/response.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Custom response class for the application that handles serialization of pg -UUID values.""" -from __future__ import annotations - -from typing import Any - -import starlite -from asyncpg.pgproto import pgproto -from starlite.response import Response as _Response - -__all__ = ["Response"] - - -class Response(_Response[Any]): - """Custom [`starlite.Response`][starlite.response.Response] that handles - serialization of the postgres UUID type used by SQLAlchemy.""" - - @staticmethod - def serializer(value: Any) -> Any: - """Serialize `value`. - - Args: - value: To be serialized. - - Returns: - Serialized representation of `value`. - """ - if isinstance(value, pgproto.UUID): - return str(value) - return starlite.Response[Any].serializer(value) diff --git a/src/starlite_saqlalchemy/serializer.py b/src/starlite_saqlalchemy/serializer.py new file mode 100644 index 00000000..b7339625 --- /dev/null +++ b/src/starlite_saqlalchemy/serializer.py @@ -0,0 +1,19 @@ +"""Default serializer used by plugin if one not provided.""" +from typing import Any + +from asyncpg.pgproto import pgproto +from starlite import Response + + +def default_serializer(value: Any) -> Any: + """Serialize `value`. + + Args: + value: To be serialized. + + Returns: + Serialized representation of `value`. + """ + if isinstance(value, pgproto.UUID): + return str(value) + return Response[Any].serializer(value)