diff --git a/src/starlite_saqlalchemy/service.py b/src/starlite_saqlalchemy/service.py index f9ebb848..31bd9eeb 100644 --- a/src/starlite_saqlalchemy/service.py +++ b/src/starlite_saqlalchemy/service.py @@ -10,8 +10,9 @@ import logging from typing import TYPE_CHECKING, Any, Generic, TypeVar -from starlite_saqlalchemy.db import async_session_factory -from starlite_saqlalchemy.repository.sqlalchemy import ModelT +from starlite_saqlalchemy.db import async_session_factory, orm +from starlite_saqlalchemy.repository.sqlalchemy import SQLAlchemyRepository +from starlite_saqlalchemy.repository.types import ModelT from starlite_saqlalchemy.worker import queue if TYPE_CHECKING: @@ -39,7 +40,13 @@ class Service(Generic[ModelT]): repository_type: type[AbstractRepository[ModelT]] def __init__(self, **repo_kwargs: Any) -> None: - self.repository: AbstractRepository[ModelT] = self.repository_type(**repo_kwargs) + self.repository = self.repository_type(**repo_kwargs) + + @classmethod + def __class_getitem__(cls: type[ServiceT], item: type[ModelT]) -> type[ServiceT]: + if not getattr(cls, "repository_type", None) and issubclass(item, orm.Base): + cls.repository_type = SQLAlchemyRepository[item] # type:ignore[valid-type] + return cls async def create(self, data: ModelT) -> ModelT: """Wraps repository instance creation. @@ -154,7 +161,7 @@ async def make_service_callback( """ obj_: Any = importlib.import_module(service_module_name) for name in service_type_fqdn.split("."): - obj_ = getattr(obj_, name) + obj_ = getattr(obj_, name, None) if inspect.isclass(obj_) and issubclass(obj_, Service): service_type = obj_ break diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index caea657c..bce74576 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -52,7 +52,6 @@ def _author_repository(raw_authors: list[dict[str, Any]], monkeypatch: pytest.Mo author = domain.Author(**raw_author) collection[getattr(author, AuthorRepository.id_attribute)] = author monkeypatch.setattr(AuthorRepository, "collection", collection) - monkeypatch.setattr(domain, "Repository", AuthorRepository) monkeypatch.setattr(domain.Service, "repository_type", AuthorRepository) diff --git a/tests/unit/test_service.py b/tests/unit/test_service.py index dfae2931..559fd7cf 100644 --- a/tests/unit/test_service.py +++ b/tests/unit/test_service.py @@ -85,7 +85,7 @@ async def test_make_service_callback( monkeypatch.setattr(service.Service, "receive_callback", recv_cb_mock, raising=False) await service.make_service_callback( {}, - service_module_name="tests.utils.domain", + service_module_name="starlite_saqlalchemy.service", service_type_fqdn="Service", service_method_name="receive_callback", raw_obj=orjson.loads(orjson.dumps(raw_authors[0], default=str)), @@ -108,8 +108,8 @@ async def test_make_service_callback_raises_runtime_error( with pytest.raises(RuntimeError): await service.make_service_callback( {}, - service_module_name="tests.utils.domain", - service_type_fqdn="Author.name", + service_module_name="starlite_saqlalchemy.service", + service_type_fqdn="TheService", service_method_name="receive_callback", raw_obj=orjson.loads(orjson.dumps(raw_authors[0], default=str)), ) @@ -123,7 +123,7 @@ async def test_enqueue_service_callback(monkeypatch: "MonkeyPatch") -> None: await service_instance.enqueue_background_task("receive_callback", raw_obj={"a": "b"}) enqueue_mock.assert_called_once_with( "make_service_callback", - service_module_name="tests.utils.domain", + service_module_name="starlite_saqlalchemy.service", service_type_fqdn="Service", service_method_name="receive_callback", raw_obj={"a": "b"}, diff --git a/tests/utils/domain.py b/tests/utils/domain.py index c248d8e9..15aecb07 100644 --- a/tests/utils/domain.py +++ b/tests/utils/domain.py @@ -1,10 +1,11 @@ """Example domain objects for testing.""" +from __future__ import annotations + from datetime import date # noqa: TC003 from sqlalchemy.orm import Mapped from starlite_saqlalchemy import db, dto, service -from starlite_saqlalchemy.repository.sqlalchemy import SQLAlchemyRepository class Author(db.orm.Base): # pylint: disable=too-few-public-methods @@ -14,17 +15,8 @@ class Author(db.orm.Base): # pylint: disable=too-few-public-methods dob: Mapped[date] -class Repository(SQLAlchemyRepository[Author]): - """Author repository.""" - - model_type = Author - - -class Service(service.Service[Author]): - """Author service object.""" - - repository_type = Repository - +Service = service.Service[Author] +"""Author service object.""" CreateDTO = dto.factory("AuthorCreateDTO", Author, purpose=dto.Purpose.WRITE, exclude={"id"}) """