Skip to content
This repository has been archived by the owner on Sep 12, 2023. It is now read-only.

Commit

Permalink
feat(service): __class_getitem__ sets repository_type on service.
Browse files Browse the repository at this point in the history
  • Loading branch information
peterschutt committed Nov 17, 2022
1 parent d81e6c6 commit 4127f3a
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 21 deletions.
15 changes: 11 additions & 4 deletions src/starlite_saqlalchemy/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
import logging
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from starlite_saqlalchemy.db import async_session_factory
from starlite_saqlalchemy.repository.sqlalchemy import ModelT
from starlite_saqlalchemy.db import async_session_factory, orm
from starlite_saqlalchemy.repository.sqlalchemy import SQLAlchemyRepository
from starlite_saqlalchemy.repository.types import ModelT
from starlite_saqlalchemy.worker import queue

if TYPE_CHECKING:
Expand Down Expand Up @@ -39,7 +40,13 @@ class Service(Generic[ModelT]):
repository_type: type[AbstractRepository[ModelT]]

def __init__(self, **repo_kwargs: Any) -> None:
self.repository: AbstractRepository[ModelT] = self.repository_type(**repo_kwargs)
self.repository = self.repository_type(**repo_kwargs)

@classmethod
def __class_getitem__(cls: type[ServiceT], item: type[ModelT]) -> type[ServiceT]:
if not getattr(cls, "repository_type", None) and issubclass(item, orm.Base):
cls.repository_type = SQLAlchemyRepository[item] # type:ignore[valid-type]
return cls

async def create(self, data: ModelT) -> ModelT:
"""Wraps repository instance creation.
Expand Down Expand Up @@ -154,7 +161,7 @@ async def make_service_callback(
"""
obj_: Any = importlib.import_module(service_module_name)
for name in service_type_fqdn.split("."):
obj_ = getattr(obj_, name)
obj_ = getattr(obj_, name, None)
if inspect.isclass(obj_) and issubclass(obj_, Service):
service_type = obj_
break
Expand Down
1 change: 0 additions & 1 deletion tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def _author_repository(raw_authors: list[dict[str, Any]], monkeypatch: pytest.Mo
author = domain.Author(**raw_author)
collection[getattr(author, AuthorRepository.id_attribute)] = author
monkeypatch.setattr(AuthorRepository, "collection", collection)
monkeypatch.setattr(domain, "Repository", AuthorRepository)
monkeypatch.setattr(domain.Service, "repository_type", AuthorRepository)


Expand Down
8 changes: 4 additions & 4 deletions tests/unit/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ async def test_make_service_callback(
monkeypatch.setattr(service.Service, "receive_callback", recv_cb_mock, raising=False)
await service.make_service_callback(
{},
service_module_name="tests.utils.domain",
service_module_name="starlite_saqlalchemy.service",
service_type_fqdn="Service",
service_method_name="receive_callback",
raw_obj=orjson.loads(orjson.dumps(raw_authors[0], default=str)),
Expand All @@ -108,8 +108,8 @@ async def test_make_service_callback_raises_runtime_error(
with pytest.raises(RuntimeError):
await service.make_service_callback(
{},
service_module_name="tests.utils.domain",
service_type_fqdn="Author.name",
service_module_name="starlite_saqlalchemy.service",
service_type_fqdn="TheService",
service_method_name="receive_callback",
raw_obj=orjson.loads(orjson.dumps(raw_authors[0], default=str)),
)
Expand All @@ -123,7 +123,7 @@ async def test_enqueue_service_callback(monkeypatch: "MonkeyPatch") -> None:
await service_instance.enqueue_background_task("receive_callback", raw_obj={"a": "b"})
enqueue_mock.assert_called_once_with(
"make_service_callback",
service_module_name="tests.utils.domain",
service_module_name="starlite_saqlalchemy.service",
service_type_fqdn="Service",
service_method_name="receive_callback",
raw_obj={"a": "b"},
Expand Down
16 changes: 4 additions & 12 deletions tests/utils/domain.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Example domain objects for testing."""
from __future__ import annotations

from datetime import date # noqa: TC003

from sqlalchemy.orm import Mapped

from starlite_saqlalchemy import db, dto, service
from starlite_saqlalchemy.repository.sqlalchemy import SQLAlchemyRepository


class Author(db.orm.Base): # pylint: disable=too-few-public-methods
Expand All @@ -14,17 +15,8 @@ class Author(db.orm.Base): # pylint: disable=too-few-public-methods
dob: Mapped[date]


class Repository(SQLAlchemyRepository[Author]):
"""Author repository."""

model_type = Author


class Service(service.Service[Author]):
"""Author service object."""

repository_type = Repository

Service = service.Service[Author]
"""Author service object."""

CreateDTO = dto.factory("AuthorCreateDTO", Author, purpose=dto.Purpose.WRITE, exclude={"id"})
"""
Expand Down

0 comments on commit 4127f3a

Please sign in to comment.