From d81e6c6df972c5d8d919e8d9c398404b57acfd8f Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Thu, 17 Nov 2022 17:07:23 +1000 Subject: [PATCH] feat(repository): __class_getitem__ sets `model_type` on repo. --- src/starlite_saqlalchemy/repository/abc.py | 7 +++++ src/starlite_saqlalchemy/repository/types.py | 27 +++++++++++++++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/starlite_saqlalchemy/repository/abc.py b/src/starlite_saqlalchemy/repository/abc.py index 75733cd0..e2006221 100644 --- a/src/starlite_saqlalchemy/repository/abc.py +++ b/src/starlite_saqlalchemy/repository/abc.py @@ -13,6 +13,7 @@ __all__ = ["AbstractRepository"] T = TypeVar("T") +RepoT = TypeVar("RepoT", bound="AbstractRepository") class AbstractRepository(Generic[T], metaclass=ABCMeta): @@ -27,6 +28,12 @@ class AbstractRepository(Generic[T], metaclass=ABCMeta): def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) + @classmethod + def __class_getitem__(cls: type[RepoT], item: type[T]) -> type[RepoT]: + if not isinstance(item, TypeVar) and not getattr(cls, "model_type", None): + cls.model_type = item + return cls + @abstractmethod async def add(self, data: T) -> T: """Add `data` to the collection. diff --git a/src/starlite_saqlalchemy/repository/types.py b/src/starlite_saqlalchemy/repository/types.py index e9a66500..fdffad7e 100644 --- a/src/starlite_saqlalchemy/repository/types.py +++ b/src/starlite_saqlalchemy/repository/types.py @@ -1,7 +1,7 @@ """Repository type definitions.""" from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any, Protocol, TypeVar from starlite_saqlalchemy.repository.filters import ( BeforeAfter, @@ -9,5 +9,30 @@ LimitOffset, ) +if TYPE_CHECKING: + from pydantic import BaseModel + FilterTypes = BeforeAfter | CollectionFilter[Any] | LimitOffset """Aggregate type alias of the types supported for collection filtering.""" + + +T = TypeVar("T") + + +class ModelProtocol(Protocol): # pylint: disable=too-few-public-methods + """Protocol for repository models.""" + + @classmethod + def from_dto(cls: type[T], dto_instance: BaseModel) -> T: # pragma: no cover + """ + + Args: + dto_instance: A pydantic model. + + Returns: + Instance of type with values populated from `dto_instance`. + """ + ... # pylint: disable=unnecessary-ellipsis + + +ModelT = TypeVar("ModelT", bound=ModelProtocol)