diff --git a/src/starlite_saqlalchemy/repository/abc.py b/src/starlite_saqlalchemy/repository/abc.py index ffb06127..598eeb63 100644 --- a/src/starlite_saqlalchemy/repository/abc.py +++ b/src/starlite_saqlalchemy/repository/abc.py @@ -27,13 +27,6 @@ def __init__(self, **kwargs: Any) -> None: """Repository constructors accept arbitrary kwargs.""" super().__init__(**kwargs) - @classmethod - def __class_getitem__(cls: type[RepoT], item: type[T]) -> type[RepoT]: - """Set `model_type` attribute, using generic parameter.""" - if not isinstance(item, TypeVar) and not getattr(cls, "model_type", None): - cls.model_type = item - return cls - @abstractmethod async def add(self, data: T) -> T: """Add `data` to the collection. diff --git a/src/starlite_saqlalchemy/repository/sqlalchemy.py b/src/starlite_saqlalchemy/repository/sqlalchemy.py index 8e338513..58f611a0 100644 --- a/src/starlite_saqlalchemy/repository/sqlalchemy.py +++ b/src/starlite_saqlalchemy/repository/sqlalchemy.py @@ -2,7 +2,7 @@ from __future__ import annotations from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Literal, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar from sqlalchemy import select, text from sqlalchemy.exc import IntegrityError, SQLAlchemyError @@ -36,6 +36,7 @@ T = TypeVar("T") ModelT = TypeVar("ModelT", bound="orm.Base") +SQLARepoT = TypeVar("SQLARepoT", bound="SQLAlchemyRepository") @contextmanager @@ -59,11 +60,9 @@ def wrap_sqlalchemy_exception() -> Any: raise RepositoryException(f"An exception occurred: {exc}") from exc -class SQLAlchemyRepository(AbstractRepository[ModelT]): +class SQLAlchemyRepository(AbstractRepository[ModelT], Generic[ModelT]): """SQLAlchemy based implementation of the repository interface.""" - model_type: type[ModelT] - def __init__( self, *, session: AsyncSession, select_: Select[tuple[ModelT]] | None = None, **kwargs: Any ) -> None: diff --git a/src/starlite_saqlalchemy/service.py b/src/starlite_saqlalchemy/service.py index ff6a0f82..d07a75ca 100644 --- a/src/starlite_saqlalchemy/service.py +++ b/src/starlite_saqlalchemy/service.py @@ -5,17 +5,17 @@ """ from __future__ import annotations -import importlib import inspect import logging -from typing import TYPE_CHECKING, Any, Generic, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar -from starlite_saqlalchemy.db import async_session_factory, orm -from starlite_saqlalchemy.repository.sqlalchemy import SQLAlchemyRepository +from starlite_saqlalchemy.db import async_session_factory from starlite_saqlalchemy.repository.types import ModelT from starlite_saqlalchemy.worker import queue if TYPE_CHECKING: + from saq.types import Context + from starlite_saqlalchemy.repository.abc import AbstractRepository from starlite_saqlalchemy.repository.types import FilterTypes @@ -23,7 +23,8 @@ logger = logging.getLogger(__name__) ServiceT = TypeVar("ServiceT", bound="Service") -Context = dict[str, Any] + +service_object_identity_map: dict[str, type[Service]] = {} class ServiceException(Exception): @@ -37,6 +38,7 @@ class UnauthorizedException(ServiceException): class Service(Generic[ModelT]): """Generic Service object.""" + __id__: ClassVar[str] repository_type: type[AbstractRepository[ModelT]] def __init__(self, **repo_kwargs: Any) -> None: @@ -47,12 +49,17 @@ def __init__(self, **repo_kwargs: Any) -> None: """ self.repository = self.repository_type(**repo_kwargs) - @classmethod - def __class_getitem__(cls: type[ServiceT], item: type[ModelT]) -> type[ServiceT]: - """Set `repository_type` from generic parameter.""" - if not getattr(cls, "repository_type", None) and issubclass(item, orm.Base): - cls.repository_type = SQLAlchemyRepository[item] # type:ignore[valid-type] - return cls + def __init_subclass__(cls, *_: Any, **__: Any) -> None: + """Map the service object to a unique identifier. + + Important that the id is deterministic across running + application instances, e.g., using something like `hash()` or + `id()` won't work as those would be different on different + instances of the running application. So we use the full import + path to the object. + """ + cls.__id__ = f"{cls.__module__}.{cls.__name__}" + service_object_identity_map[cls.__id__] = cls async def create(self, data: ModelT) -> ModelT: """Wrap repository instance creation. @@ -138,8 +145,7 @@ async def enqueue_background_task(self, method_name: str, **kwargs: Any) -> None return await queue.enqueue( make_service_callback.__qualname__, - service_module_name=module.__name__, - service_type_fqdn=type(self).__qualname__, + service_type_id=self.__id__, service_method_name=method_name, **kwargs, ) @@ -148,8 +154,7 @@ async def enqueue_background_task(self, method_name: str, **kwargs: Any) -> None async def make_service_callback( _ctx: Context, *, - service_module_name: str, - service_type_fqdn: str, + service_type_id: str, service_method_name: str, **kwargs: Any, ) -> None: @@ -157,19 +162,11 @@ async def make_service_callback( Args: _ctx: the SAQ context - service_module_name: Module of service type to instantiate. - service_type_fqdn: Reference to service type in module. + service_type_id: Value of `__id__` class var on service type. service_method_name: Method to be called on the service object. **kwargs: Unpacked into the service method call as keyword arguments. """ - obj_: Any = importlib.import_module(service_module_name) - for name in service_type_fqdn.split("."): - obj_ = getattr(obj_, name, None) - if inspect.isclass(obj_) and issubclass(obj_, Service): - service_type = obj_ - break - else: - raise RuntimeError("Couldn't find a service type with given module and fqdn") + service_type = service_object_identity_map[service_type_id] async with async_session_factory() as session: service_object: Service = service_type(session=session) method = getattr(service_object, service_method_name) diff --git a/src/starlite_saqlalchemy/testing.py b/src/starlite_saqlalchemy/testing.py index 80181783..c672a0cd 100644 --- a/src/starlite_saqlalchemy/testing.py +++ b/src/starlite_saqlalchemy/testing.py @@ -2,6 +2,8 @@ Uses a `dict` for storage. """ +from __future__ import annotations + from datetime import datetime from typing import TYPE_CHECKING, Any, Generic, TypeVar from uuid import uuid4 @@ -24,9 +26,9 @@ class GenericMockRepository(AbstractRepository[BaseT], Generic[BaseT]): Uses a `dict` for storage. """ - collection: "abc.MutableMapping[abc.Hashable, BaseT]" = {} + collection: abc.MutableMapping[abc.Hashable, BaseT] = {} - def __init__(self, id_factory: "abc.Callable[[], Any]" = uuid4, **_: Any) -> None: + def __init__(self, id_factory: abc.Callable[[], Any] = uuid4, **_: Any) -> None: super().__init__() self._id_factory = id_factory @@ -34,6 +36,14 @@ def _find_or_raise_not_found(self, id_: Any) -> BaseT: return self.check_not_found(self.collection.get(id_)) async def add(self, data: BaseT, _allow_id: bool = False) -> BaseT: + """Add `data` to the collection. + + Args: + data: Instance to be added to the collection. + + Returns: + The added instance. + """ if _allow_id is False and self.get_id_attribute_value(data) is not None: raise RepositoryConflictException("`add()` received identified item.") now = datetime.now() @@ -45,18 +55,61 @@ async def add(self, data: BaseT, _allow_id: bool = False) -> BaseT: return data async def delete(self, id_: Any) -> BaseT: + """Delete instance identified by `id_`. + + Args: + id_: Identifier of instance to be deleted. + + Returns: + The deleted instance. + + Raises: + RepositoryNotFoundException: If no instance found identified by `id_`. + """ try: return self._find_or_raise_not_found(id_) finally: del self.collection[id_] async def get(self, id_: Any) -> BaseT: + """Get instance identified by `id_`. + + Args: + id_: Identifier of the instance to be retrieved. + + Returns: + The retrieved instance. + + Raises: + RepositoryNotFoundException: If no instance found identified by `id_`. + """ return self._find_or_raise_not_found(id_) async def list(self, *filters: "FilterTypes", **kwargs: Any) -> list[BaseT]: + """Get a list of instances, optionally filtered. + + Args: + *filters: Types for specific filtering operations. + **kwargs: Instance attribute value filters. + + Returns: + The list of instances, after filtering applied. + """ return list(self.collection.values()) async def update(self, data: BaseT) -> BaseT: + """Update instance with the attribute values present on `data`. + + Args: + data: An instance that should have a value for `self.id_attribute` that exists in the + collection. + + Returns: + The updated instance. + + Raises: + RepositoryNotFoundException: If no instance found with same identifier as `data`. + """ item = self._find_or_raise_not_found(self.get_id_attribute_value(data)) # should never be modifiable data.updated = datetime.now() @@ -67,6 +120,22 @@ async def update(self, data: BaseT) -> BaseT: return item async def upsert(self, data: BaseT) -> BaseT: + """Update or create instance. + + Updates instance with the attribute values present on `data`, or creates a new instance if + one doesn't exist. + + Args: + data: Instance to update existing, or be created. Identifier used to determine if an + existing instance exists is the value of an attribute on `data` named as value of + `self.id_attribute`. + + Returns: + The updated or created instance. + + Raises: + RepositoryNotFoundException: If no instance found with same identifier as `data`. + """ id_ = self.get_id_attribute_value(data) if id_ in self.collection: return await self.update(data) diff --git a/tests/unit/test_service.py b/tests/unit/test_service.py index 559fd7cf..8a966673 100644 --- a/tests/unit/test_service.py +++ b/tests/unit/test_service.py @@ -85,8 +85,7 @@ async def test_make_service_callback( monkeypatch.setattr(service.Service, "receive_callback", recv_cb_mock, raising=False) await service.make_service_callback( {}, - service_module_name="starlite_saqlalchemy.service", - service_type_fqdn="Service", + service_type_id="tests.utils.domain.Service", service_method_name="receive_callback", raw_obj=orjson.loads(orjson.dumps(raw_authors[0], default=str)), ) @@ -105,11 +104,10 @@ async def test_make_service_callback_raises_runtime_error( raw_authors: list[dict[str, Any]] ) -> None: """Tests loading and retrieval of service object types.""" - with pytest.raises(RuntimeError): + with pytest.raises(KeyError): await service.make_service_callback( {}, - service_module_name="starlite_saqlalchemy.service", - service_type_fqdn="TheService", + service_type_id="tests.utils.domain.LSKDFJ", service_method_name="receive_callback", raw_obj=orjson.loads(orjson.dumps(raw_authors[0], default=str)), ) @@ -123,8 +121,7 @@ async def test_enqueue_service_callback(monkeypatch: "MonkeyPatch") -> None: await service_instance.enqueue_background_task("receive_callback", raw_obj={"a": "b"}) enqueue_mock.assert_called_once_with( "make_service_callback", - service_module_name="starlite_saqlalchemy.service", - service_type_fqdn="Service", + service_type_id="tests.utils.domain.Service", service_method_name="receive_callback", raw_obj={"a": "b"}, ) diff --git a/tests/utils/domain.py b/tests/utils/domain.py index 15aecb07..2679c370 100644 --- a/tests/utils/domain.py +++ b/tests/utils/domain.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Mapped -from starlite_saqlalchemy import db, dto, service +from starlite_saqlalchemy import db, dto, repository, service class Author(db.orm.Base): # pylint: disable=too-few-public-methods @@ -15,8 +15,17 @@ class Author(db.orm.Base): # pylint: disable=too-few-public-methods dob: Mapped[date] -Service = service.Service[Author] -"""Author service object.""" +class Repository(repository.sqlalchemy.SQLAlchemyRepository[Author]): + """Author repository.""" + + model_type = Author + + +class Service(service.Service[Author]): + """Author service object.""" + + repository_type = Repository + CreateDTO = dto.factory("AuthorCreateDTO", Author, purpose=dto.Purpose.WRITE, exclude={"id"}) """