Skip to content
This repository has been archived by the owner on Sep 12, 2023. It is now read-only.

Commit

Permalink
feat!: simplify service and repo. (#134)
Browse files Browse the repository at this point in the history
* feat!: simplify service and repo.

Was all just a bit too magic. Would rather keep it simple for now and
see how it evolves.

* Update tests/unit/repository/test_sqlalchemy.py
  • Loading branch information
peterschutt authored Nov 21, 2022
1 parent 52f7ccb commit ed3b59a
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 48 deletions.
7 changes: 0 additions & 7 deletions src/starlite_saqlalchemy/repository/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,6 @@ def __init__(self, **kwargs: Any) -> None:
"""Repository constructors accept arbitrary kwargs."""
super().__init__(**kwargs)

@classmethod
def __class_getitem__(cls: type[RepoT], item: type[T]) -> type[RepoT]:
"""Set `model_type` attribute, using generic parameter."""
if not isinstance(item, TypeVar) and not getattr(cls, "model_type", None):
cls.model_type = item
return cls

@abstractmethod
async def add(self, data: T) -> T:
"""Add `data` to the collection.
Expand Down
7 changes: 3 additions & 4 deletions src/starlite_saqlalchemy/repository/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Literal, TypeVar
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar

from sqlalchemy import select, text
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
Expand Down Expand Up @@ -36,6 +36,7 @@

T = TypeVar("T")
ModelT = TypeVar("ModelT", bound="orm.Base")
SQLARepoT = TypeVar("SQLARepoT", bound="SQLAlchemyRepository")


@contextmanager
Expand All @@ -59,11 +60,9 @@ def wrap_sqlalchemy_exception() -> Any:
raise RepositoryException(f"An exception occurred: {exc}") from exc


class SQLAlchemyRepository(AbstractRepository[ModelT]):
class SQLAlchemyRepository(AbstractRepository[ModelT], Generic[ModelT]):
"""SQLAlchemy based implementation of the repository interface."""

model_type: type[ModelT]

def __init__(
self, *, session: AsyncSession, select_: Select[tuple[ModelT]] | None = None, **kwargs: Any
) -> None:
Expand Down
47 changes: 22 additions & 25 deletions src/starlite_saqlalchemy/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,26 @@
"""
from __future__ import annotations

import importlib
import inspect
import logging
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar

from starlite_saqlalchemy.db import async_session_factory, orm
from starlite_saqlalchemy.repository.sqlalchemy import SQLAlchemyRepository
from starlite_saqlalchemy.db import async_session_factory
from starlite_saqlalchemy.repository.types import ModelT
from starlite_saqlalchemy.worker import queue

if TYPE_CHECKING:
from saq.types import Context

from starlite_saqlalchemy.repository.abc import AbstractRepository
from starlite_saqlalchemy.repository.types import FilterTypes


logger = logging.getLogger(__name__)

ServiceT = TypeVar("ServiceT", bound="Service")
Context = dict[str, Any]

service_object_identity_map: dict[str, type[Service]] = {}


class ServiceException(Exception):
Expand All @@ -37,6 +38,7 @@ class UnauthorizedException(ServiceException):
class Service(Generic[ModelT]):
"""Generic Service object."""

__id__: ClassVar[str]
repository_type: type[AbstractRepository[ModelT]]

def __init__(self, **repo_kwargs: Any) -> None:
Expand All @@ -47,12 +49,17 @@ def __init__(self, **repo_kwargs: Any) -> None:
"""
self.repository = self.repository_type(**repo_kwargs)

@classmethod
def __class_getitem__(cls: type[ServiceT], item: type[ModelT]) -> type[ServiceT]:
"""Set `repository_type` from generic parameter."""
if not getattr(cls, "repository_type", None) and issubclass(item, orm.Base):
cls.repository_type = SQLAlchemyRepository[item] # type:ignore[valid-type]
return cls
def __init_subclass__(cls, *_: Any, **__: Any) -> None:
"""Map the service object to a unique identifier.
Important that the id is deterministic across running
application instances, e.g., using something like `hash()` or
`id()` won't work as those would be different on different
instances of the running application. So we use the full import
path to the object.
"""
cls.__id__ = f"{cls.__module__}.{cls.__name__}"
service_object_identity_map[cls.__id__] = cls

async def create(self, data: ModelT) -> ModelT:
"""Wrap repository instance creation.
Expand Down Expand Up @@ -138,8 +145,7 @@ async def enqueue_background_task(self, method_name: str, **kwargs: Any) -> None
return
await queue.enqueue(
make_service_callback.__qualname__,
service_module_name=module.__name__,
service_type_fqdn=type(self).__qualname__,
service_type_id=self.__id__,
service_method_name=method_name,
**kwargs,
)
Expand All @@ -148,28 +154,19 @@ async def enqueue_background_task(self, method_name: str, **kwargs: Any) -> None
async def make_service_callback(
_ctx: Context,
*,
service_module_name: str,
service_type_fqdn: str,
service_type_id: str,
service_method_name: str,
**kwargs: Any,
) -> None:
"""Make an async service callback.
Args:
_ctx: the SAQ context
service_module_name: Module of service type to instantiate.
service_type_fqdn: Reference to service type in module.
service_type_id: Value of `__id__` class var on service type.
service_method_name: Method to be called on the service object.
**kwargs: Unpacked into the service method call as keyword arguments.
"""
obj_: Any = importlib.import_module(service_module_name)
for name in service_type_fqdn.split("."):
obj_ = getattr(obj_, name, None)
if inspect.isclass(obj_) and issubclass(obj_, Service):
service_type = obj_
break
else:
raise RuntimeError("Couldn't find a service type with given module and fqdn")
service_type = service_object_identity_map[service_type_id]
async with async_session_factory() as session:
service_object: Service = service_type(session=session)
method = getattr(service_object, service_method_name)
Expand Down
73 changes: 71 additions & 2 deletions src/starlite_saqlalchemy/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Uses a `dict` for storage.
"""
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from uuid import uuid4
Expand All @@ -24,16 +26,24 @@ class GenericMockRepository(AbstractRepository[BaseT], Generic[BaseT]):
Uses a `dict` for storage.
"""

collection: "abc.MutableMapping[abc.Hashable, BaseT]" = {}
collection: abc.MutableMapping[abc.Hashable, BaseT] = {}

def __init__(self, id_factory: "abc.Callable[[], Any]" = uuid4, **_: Any) -> None:
def __init__(self, id_factory: abc.Callable[[], Any] = uuid4, **_: Any) -> None:
super().__init__()
self._id_factory = id_factory

def _find_or_raise_not_found(self, id_: Any) -> BaseT:
return self.check_not_found(self.collection.get(id_))

async def add(self, data: BaseT, _allow_id: bool = False) -> BaseT:
"""Add `data` to the collection.
Args:
data: Instance to be added to the collection.
Returns:
The added instance.
"""
if _allow_id is False and self.get_id_attribute_value(data) is not None:
raise RepositoryConflictException("`add()` received identified item.")
now = datetime.now()
Expand All @@ -45,18 +55,61 @@ async def add(self, data: BaseT, _allow_id: bool = False) -> BaseT:
return data

async def delete(self, id_: Any) -> BaseT:
"""Delete instance identified by `id_`.
Args:
id_: Identifier of instance to be deleted.
Returns:
The deleted instance.
Raises:
RepositoryNotFoundException: If no instance found identified by `id_`.
"""
try:
return self._find_or_raise_not_found(id_)
finally:
del self.collection[id_]

async def get(self, id_: Any) -> BaseT:
"""Get instance identified by `id_`.
Args:
id_: Identifier of the instance to be retrieved.
Returns:
The retrieved instance.
Raises:
RepositoryNotFoundException: If no instance found identified by `id_`.
"""
return self._find_or_raise_not_found(id_)

async def list(self, *filters: "FilterTypes", **kwargs: Any) -> list[BaseT]:
"""Get a list of instances, optionally filtered.
Args:
*filters: Types for specific filtering operations.
**kwargs: Instance attribute value filters.
Returns:
The list of instances, after filtering applied.
"""
return list(self.collection.values())

async def update(self, data: BaseT) -> BaseT:
"""Update instance with the attribute values present on `data`.
Args:
data: An instance that should have a value for `self.id_attribute` that exists in the
collection.
Returns:
The updated instance.
Raises:
RepositoryNotFoundException: If no instance found with same identifier as `data`.
"""
item = self._find_or_raise_not_found(self.get_id_attribute_value(data))
# should never be modifiable
data.updated = datetime.now()
Expand All @@ -67,6 +120,22 @@ async def update(self, data: BaseT) -> BaseT:
return item

async def upsert(self, data: BaseT) -> BaseT:
"""Update or create instance.
Updates instance with the attribute values present on `data`, or creates a new instance if
one doesn't exist.
Args:
data: Instance to update existing, or be created. Identifier used to determine if an
existing instance exists is the value of an attribute on `data` named as value of
`self.id_attribute`.
Returns:
The updated or created instance.
Raises:
RepositoryNotFoundException: If no instance found with same identifier as `data`.
"""
id_ = self.get_id_attribute_value(data)
if id_ in self.collection:
return await self.update(data)
Expand Down
11 changes: 4 additions & 7 deletions tests/unit/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ async def test_make_service_callback(
monkeypatch.setattr(service.Service, "receive_callback", recv_cb_mock, raising=False)
await service.make_service_callback(
{},
service_module_name="starlite_saqlalchemy.service",
service_type_fqdn="Service",
service_type_id="tests.utils.domain.Service",
service_method_name="receive_callback",
raw_obj=orjson.loads(orjson.dumps(raw_authors[0], default=str)),
)
Expand All @@ -105,11 +104,10 @@ async def test_make_service_callback_raises_runtime_error(
raw_authors: list[dict[str, Any]]
) -> None:
"""Tests loading and retrieval of service object types."""
with pytest.raises(RuntimeError):
with pytest.raises(KeyError):
await service.make_service_callback(
{},
service_module_name="starlite_saqlalchemy.service",
service_type_fqdn="TheService",
service_type_id="tests.utils.domain.LSKDFJ",
service_method_name="receive_callback",
raw_obj=orjson.loads(orjson.dumps(raw_authors[0], default=str)),
)
Expand All @@ -123,8 +121,7 @@ async def test_enqueue_service_callback(monkeypatch: "MonkeyPatch") -> None:
await service_instance.enqueue_background_task("receive_callback", raw_obj={"a": "b"})
enqueue_mock.assert_called_once_with(
"make_service_callback",
service_module_name="starlite_saqlalchemy.service",
service_type_fqdn="Service",
service_type_id="tests.utils.domain.Service",
service_method_name="receive_callback",
raw_obj={"a": "b"},
)
15 changes: 12 additions & 3 deletions tests/utils/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from sqlalchemy.orm import Mapped

from starlite_saqlalchemy import db, dto, service
from starlite_saqlalchemy import db, dto, repository, service


class Author(db.orm.Base): # pylint: disable=too-few-public-methods
Expand All @@ -15,8 +15,17 @@ class Author(db.orm.Base): # pylint: disable=too-few-public-methods
dob: Mapped[date]


Service = service.Service[Author]
"""Author service object."""
class Repository(repository.sqlalchemy.SQLAlchemyRepository[Author]):
"""Author repository."""

model_type = Author


class Service(service.Service[Author]):
"""Author service object."""

repository_type = Repository


CreateDTO = dto.factory("AuthorCreateDTO", Author, purpose=dto.Purpose.WRITE, exclude={"id"})
"""
Expand Down

0 comments on commit ed3b59a

Please sign in to comment.