From 703c75efbc46eb11d3661b8865fccd3537f61131 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Wed, 2 Nov 2022 16:23:30 +1000 Subject: [PATCH] feat(repository): make abc more general. The ABC was typed to accept a sqlalchemy session object. Changed to instead receive arbitrary kwargs, but if the base class actually ever receives kwargs it will error out due to super call to `object.__init__()`. The service object accepts arbitrary kwargs that are passed through to the repository, but doesn't care what they are. This all means that the only thing that knows and cares about the sqlalchemy session, is the sqlalchemy repository, and that feels right. One facet of this approach is that it makes the concept of the transaction an implementation detail. The sqlalchemy repo has the concept of session/transaction, but a repository doesn't _have_ to understand those things. This is consistent with the testing repository implementation, so happy to see how the pattern pans out. Closes #54 --- src/starlite_saqlalchemy/repository/abc.py | 6 ++---- src/starlite_saqlalchemy/repository/sqlalchemy.py | 7 +++++-- src/starlite_saqlalchemy/service.py | 5 ++--- tests/unit/test_service.py | 2 +- tests/unit/utils.py | 8 ++------ 5 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/starlite_saqlalchemy/repository/abc.py b/src/starlite_saqlalchemy/repository/abc.py index 93064d3e..75733cd0 100644 --- a/src/starlite_saqlalchemy/repository/abc.py +++ b/src/starlite_saqlalchemy/repository/abc.py @@ -8,8 +8,6 @@ from starlite_saqlalchemy.repository.exceptions import RepositoryNotFoundException if TYPE_CHECKING: - from sqlalchemy.ext.asyncio import AsyncSession - from .types import FilterTypes __all__ = ["AbstractRepository"] @@ -26,8 +24,8 @@ class AbstractRepository(Generic[T], metaclass=ABCMeta): id_attribute = "id" """Name of the primary identifying attribute on `model_type`.""" - def __init__(self, session: AsyncSession) -> None: - self.session = session + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) @abstractmethod async def add(self, data: T) -> T: diff --git a/src/starlite_saqlalchemy/repository/sqlalchemy.py b/src/starlite_saqlalchemy/repository/sqlalchemy.py index a7bb2fc9..89f540f4 100644 --- a/src/starlite_saqlalchemy/repository/sqlalchemy.py +++ b/src/starlite_saqlalchemy/repository/sqlalchemy.py @@ -69,8 +69,11 @@ class SQLAlchemyRepository(AbstractRepository[ModelT]): model_type: type[ModelT] - def __init__(self, session: AsyncSession, select_: Select[tuple[ModelT]] | None = None) -> None: - super().__init__(session) + def __init__( + self, *, session: AsyncSession, select_: Select[tuple[ModelT]] | None = None, **kwargs: Any + ) -> None: + super().__init__(**kwargs) + self.session = session self._select = select(self.model_type) if select_ is None else select_ async def add(self, data: ModelT) -> ModelT: diff --git a/src/starlite_saqlalchemy/service.py b/src/starlite_saqlalchemy/service.py index ce3e4116..437024bc 100644 --- a/src/starlite_saqlalchemy/service.py +++ b/src/starlite_saqlalchemy/service.py @@ -19,7 +19,6 @@ if TYPE_CHECKING: from pydantic import BaseModel from saq.types import Context - from sqlalchemy.ext.asyncio import AsyncSession from starlite_saqlalchemy.repository.abc import AbstractRepository from starlite_saqlalchemy.repository.types import FilterTypes @@ -67,8 +66,8 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: f"__{model_type.__tablename__}DTO", model_type, dto.Purpose.READ ) - def __init__(self, session: AsyncSession) -> None: - self.repository: AbstractRepository[ModelT] = self.repository_type(session) + def __init__(self, **repo_kwargs: Any) -> None: + self.repository: AbstractRepository[ModelT] = self.repository_type(**repo_kwargs) # noinspection PyMethodMayBeStatic async def authorize_create(self, data: ModelT) -> ModelT: diff --git a/tests/unit/test_service.py b/tests/unit/test_service.py index 6574077f..45db1275 100644 --- a/tests/unit/test_service.py +++ b/tests/unit/test_service.py @@ -52,7 +52,7 @@ async def test_enqueue_service_callback( """Tests that job enqueued with desired arguments.""" enqueue_mock = AsyncMock() monkeypatch.setattr(worker.queue, "enqueue", enqueue_mock) - service_instance = domain.Service(sqlalchemy_plugin.async_session_factory()) + service_instance = domain.Service(session=sqlalchemy_plugin.async_session_factory()) await service_instance.enqueue_callback( service.Operation.UPDATE, domain.Author(**raw_authors[0]) ) diff --git a/tests/unit/utils.py b/tests/unit/utils.py index 4e40b692..e38fa308 100644 --- a/tests/unit/utils.py +++ b/tests/unit/utils.py @@ -12,8 +12,6 @@ if TYPE_CHECKING: from collections import abc - from sqlalchemy.ext.asyncio import AsyncSession - from starlite_saqlalchemy.repository.types import FilterTypes BaseT = TypeVar("BaseT", bound=Base) @@ -27,10 +25,8 @@ class GenericMockRepository(AbstractRepository[BaseT], Generic[BaseT]): collection: "abc.MutableMapping[abc.Hashable, BaseT]" = {} - def __init__( - self, session: "AsyncSession", id_factory: "abc.Callable[[], Any]" = uuid4, **_: Any - ) -> None: - super().__init__(session) + def __init__(self, id_factory: "abc.Callable[[], Any]" = uuid4, **_: Any) -> None: + super().__init__() self._id_factory = id_factory def _find_or_raise_not_found(self, id_: Any) -> BaseT: