diff --git a/mypy.ini b/mypy.ini index a4f69ede..0ead1e2f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2,7 +2,7 @@ check_untyped_defs = True disallow_any_generics = False -disallow_incomplete_defs = True +disallow_incomplete_defs = False disallow_untyped_decorators = True disallow_untyped_defs = True ignore_missing_imports = False diff --git a/pyproject.toml b/pyproject.toml index eb2ebce3..5fae20e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ testpaths = ["tests/unit"] disable = [ "line-too-long", "no-self-argument", + "too-few-public-methods", "too-many-arguments", ] enable = "useless-suppression" diff --git a/src/starlite_saqlalchemy/dto.py b/src/starlite_saqlalchemy/dto.py index 9c003e17..a0c299c2 100644 --- a/src/starlite_saqlalchemy/dto.py +++ b/src/starlite_saqlalchemy/dto.py @@ -9,20 +9,21 @@ from __future__ import annotations from enum import Enum, auto -from inspect import get_annotations, getmodule, isclass +from inspect import getmodule, isclass from typing import ( TYPE_CHECKING, Any, ClassVar, + Generic, NamedTuple, TypedDict, + TypeVar, cast, get_args, get_origin, get_type_hints, ) -from pydantic import BaseConfig as BaseConfig_ from pydantic import BaseModel, create_model, validator from pydantic.fields import FieldInfo from sqlalchemy import inspect @@ -40,6 +41,9 @@ from sqlalchemy.util import ReadOnlyProperties +AnyDeclarative = TypeVar("AnyDeclarative", bound=DeclarativeBase) + + class Mark(str, Enum): """For marking column definitions on the domain models. @@ -90,31 +94,24 @@ class Attrib(NamedTuple): """Single argument callables that are defined on the DTO as validators for the field.""" -class BaseConfig(BaseConfig_): - """Base config for generated pydantic models.""" - - orm_mode = True - - -class MapperBind(BaseModel): +class _MapperBind(BaseModel, Generic[AnyDeclarative]): """Produce an SQLAlchemy instance with values from a pydantic model.""" __sqla_model__: ClassVar[type[DeclarativeBase]] - class Config(BaseConfig): - """Config for MapperBind pydantic models.""" + class Config: + """Set orm_mode for `to_mapped()` method.""" - @classmethod - def __init_subclass__(cls, model: type[DeclarativeBase], **kwargs: Any) -> None: - """Set `__sqla_model__` class var. + orm_mode = True - Args: - model: SQLAlchemy model represented by the DTO. - """ - cls.__sqla_model__ = model - return super().__init_subclass__(**kwargs) + def __init_subclass__( # pylint: disable=arguments-differ + cls, model: type[DeclarativeBase] | None = None, **kwargs: Any + ) -> None: + if model is not None: + cls.__sqla_model__ = model + super().__init_subclass__(**kwargs) - def to_mapped(self) -> DeclarativeBase: + def to_mapped(self) -> AnyDeclarative: """Create an instance of `self.__sqla_model__` Fill the bound SQLAlchemy model recursively with values from @@ -124,11 +121,11 @@ def to_mapped(self) -> DeclarativeBase: for field in self.__fields__.values(): value = getattr(self, field.name) if isinstance(value, (list, tuple)): - value = [el.to_mapped() if isinstance(el, MapperBind) else el for el in value] - if isinstance(value, MapperBind): + value = [el.to_mapped() if isinstance(el, _MapperBind) else el for el in value] + if isinstance(value, _MapperBind): value = value.to_mapped() as_model[field.name] = value - return self.__sqla_model__(**as_model) + return cast("AnyDeclarative", self.__sqla_model__(**as_model)) def _construct_field_info(elem: Column | RelationshipProperty, purpose: Purpose) -> FieldInfo: @@ -172,6 +169,11 @@ def _inspect_model( return columns, relationships +def _get_localns(model: type[DeclarativeBase]) -> dict[str, Any]: + model_module = getmodule(model) + return vars(model_module) if model_module is not None else {} + + def mark(mark_type: Mark) -> DTOInfo: """Shortcut for ```python. @@ -198,16 +200,15 @@ class User(DeclarativeBase): return {"dto": Attrib(mark=mark_type)} -def factory( # pylint: disable=too-many-locals +def factory( name: str, - model: type[DeclarativeBase], + model: type[AnyDeclarative], purpose: Purpose, + *, exclude: set[str] | None = None, - namespace: dict[str, Any] | None = None, - annotations: dict[str, Any] | None = None, - base: type[BaseModel] = MapperBind, -) -> type[BaseModel]: - """Create a pydantic model class from a SQLAlchemy declarative ORM class. + base: type[BaseModel] | None = None, +) -> type[_MapperBind[AnyDeclarative]]: + """Infer a Pydantic model from a SQLAlchemy model. The fields that are included in the model can be controlled on the SQLAlchemy class definition by including a "dto" key in the `Column.info` mapping. For example: @@ -230,32 +231,24 @@ class User(DeclarativeBase): model: The SQLAlchemy model class. purpose: Is the DTO for write or read operations? exclude: Explicitly exclude attributes from the DTO. - namespace: Additional namespace used to resolve forward references. The default namespace - used is that of the module of `model`. - annotations: Annotations that override and supplement the annotations derived from `model`. base: A subclass of `pydantic.BaseModel` to be used as the base class of the DTO. Returns: A Pydantic model that includes only fields that are appropriate to `purpose` and not in `exclude`. """ - exclude_ = exclude or set[str]() - annotations_ = annotations or {} - namespace_ = namespace or {} - del exclude, annotations, namespace + exclude = set() if exclude is None else exclude + columns, relationships = _inspect_model(model) - model_module = getmodule(model) - localns = vars(model_module) if model_module is not None else {} - localns.update(namespace_) - type_hints = get_type_hints(model, localns=localns) - type_hints.update(annotations_) fields: dict[str, tuple[Any, FieldInfo]] = {} validators: dict[str, AnyClassMethod] = {} - for key, type_hint in type_hints.items(): + for key, type_hint in get_type_hints(model, localns=_get_localns(model)).items(): + # don't override fields that already exist on `base`. + if base is not None and key in base.__fields__: + continue + if get_origin(type_hint) is Mapped: - (type_,) = get_args(type_hint) - else: - type_ = type_hint + (type_hint,) = get_args(type_hint) elem: Column | RelationshipProperty if key in columns: @@ -268,24 +261,26 @@ class User(DeclarativeBase): attrib = _get_dto_attrib(elem) - if _should_exclude_field(purpose, elem, exclude_, attrib): + if _should_exclude_field(purpose, elem, exclude, attrib): continue if attrib.pydantic_type is not None: - type_ = attrib.pydantic_type + type_hint = attrib.pydantic_type for i, func in enumerate(attrib.validators or []): validators[f"_validates_{key}_{i}"] = validator(key, allow_reuse=True)(func) - if isclass(type_) and issubclass(type_, DeclarativeBase): - type_ = factory(f"{name}_{type_.__name__}", type_, purpose=purpose) + if isclass(type_hint) and issubclass(type_hint, DeclarativeBase): + type_hint = factory(f"{name}_{type_hint.__name__}", type_hint, purpose=purpose) - fields[key] = (type_, _construct_field_info(elem, purpose)) + fields[key] = (type_hint, _construct_field_info(elem, purpose)) return create_model( # type:ignore[no-any-return,call-overload] name, - __base__=base, - __cls_kwargs__={"model": model} if base is MapperBind else {}, + __base__=(base, _MapperBind[AnyDeclarative]) + if base is not None + else _MapperBind[AnyDeclarative], + __cls_kwargs__={"model": model}, __module__=getattr(model, "__module__", __name__), __validators__=validators, **fields, @@ -293,76 +288,13 @@ class User(DeclarativeBase): def decorator( - model: type[DeclarativeBase], - purpose: Purpose, - exclude: set[str] | None = None, -) -> Callable[[type[BaseModel]], type[BaseModel]]: - """Create a pydantic model class from a SQLAlchemy declarative ORM class - with validation support. - - This decorator is not recursive so relationships will be ignored (but can be - overridden on the decorated class). - - Pydantic validation is supported using `validator` decorator from `starlite_sqlalchemy.pydantic` - - As for the `factory` function, included fields can be controlled on the SQLAlchemy class - definition by including a "dto" key in the `Column.info` mapping. - - Example: - ```python - class User(DeclarativeBase): - id: Mapped[UUID] = mapped_column( - default=uuid4, primary_key=True, info={"dto": Attrib(mark=dto.Mark.READ_ONLY)} - ) - email: Mapped[str] - password_hash: Mapped[str] = mapped_column(info={"dto": Attrib(mark=dto.Mark.SKIP)}) - - - @dto.dto(User, Purpose.WRITE) - class UserCreate: - @dto.validator("email") - def val_email(cls, v): - if "@" not in v: - raise ValueError("Invalid email") - ``` - - When setting `mapper_bind` (default to `True`), the resulting pydantic model will have a `mapper` method - that can be used to create an instance of the original SQLAlchemy model with values from the pydantic one: - - ```python - @dto.dto(User, Purpose.WRITE) - class UserCreate: - pass - - - user = UserCreate(email="john@email.me") - user_model = user.mapper() # user_model is a User instance - ``` - - Args: - model: The SQLAlchemy model class. - purpose: Is the DTO for write or read operations? - exclude: Explicitly exclude attributes from the DTO. - - Returns: - A Pydantic model that includes only fields that are appropriate to `purpose` and not in - `exclude`, except relationship fields. - """ + model: type[AnyDeclarative], purpose: Purpose, *, exclude: set[str] | None = None +) -> Callable[[type[BaseModel]], type[_MapperBind[AnyDeclarative]]]: + """Infer a Pydantic model from SQLAlchemy model.""" - def wrapper(cls: type[BaseModel]) -> type[BaseModel]: - def wrapped() -> type[BaseModel]: - name = cls.__name__ - cls_module = getmodule(cls) - namespace = vars(cls_module) if cls_module is not None else {} - return factory( - name=name, - model=model, - purpose=purpose, - exclude=exclude, - namespace=namespace, - annotations=get_annotations(cls), - base=cls, - ) + def wrapper(cls: type[BaseModel]) -> type[_MapperBind[AnyDeclarative]]: + def wrapped() -> type[_MapperBind[AnyDeclarative]]: + return factory(cls.__name__, model, purpose, exclude=exclude, base=cls) return wrapped() diff --git a/src/starlite_saqlalchemy/repository/types.py b/src/starlite_saqlalchemy/repository/types.py index fdffad7e..56ac9f60 100644 --- a/src/starlite_saqlalchemy/repository/types.py +++ b/src/starlite_saqlalchemy/repository/types.py @@ -19,7 +19,7 @@ T = TypeVar("T") -class ModelProtocol(Protocol): # pylint: disable=too-few-public-methods +class ModelProtocol(Protocol): """Protocol for repository models.""" @classmethod diff --git a/src/starlite_saqlalchemy/settings.py b/src/starlite_saqlalchemy/settings.py index dc4e809d..81637d48 100644 --- a/src/starlite_saqlalchemy/settings.py +++ b/src/starlite_saqlalchemy/settings.py @@ -5,7 +5,7 @@ """ from __future__ import annotations -# pylint: disable=too-few-public-methods,missing-class-docstring +# pylint: disable=missing-class-docstring from typing import Literal from pydantic import AnyUrl, BaseSettings, PostgresDsn, parse_obj_as diff --git a/tests/unit/test_dto.py b/tests/unit/test_dto.py index 6e4a1fea..3e8ea62a 100644 --- a/tests/unit/test_dto.py +++ b/tests/unit/test_dto.py @@ -1,5 +1,5 @@ """Tests for the dto factory.""" -# pylint: disable=redefined-outer-name,too-few-public-methods,missing-class-docstring +# pylint: disable=missing-class-docstring from datetime import date, datetime, timedelta from typing import TYPE_CHECKING, Any, ClassVar from uuid import UUID, uuid4 @@ -10,7 +10,7 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column from starlite_saqlalchemy import dto, settings -from tests.utils.domain import Author +from tests.utils.domain import Author, CreateDTO if TYPE_CHECKING: from collections.abc import Callable @@ -50,8 +50,8 @@ def test_dto_exclude() -> None: assert dto_type.__fields__.keys() == {"name", "dob", "created", "updated"} -@pytest.fixture() -def base() -> type[DeclarativeBase]: +@pytest.fixture(name="base") +def fx_base() -> type[DeclarativeBase]: """Declarative base for test models. Need a new base for every test, otherwise will get errors to do with @@ -245,3 +245,11 @@ class Model(MappedAsDataclass, base): dto_model = dto.factory("DTO", Model, purpose=dto.Purpose.WRITE) assert dto_model.__fields__.keys() == {"id", "field"} + + +def test_from_dto() -> None: + """Test conversion of a DTO instance to a model instance.""" + data = CreateDTO.parse_obj({"name": "someone", "dob": "1982-03-22"}) + author = data.to_mapped() + assert author.name == "someone" + assert author.dob == date(1982, 3, 22) diff --git a/tests/unit/test_endpoint_decorator.py b/tests/unit/test_endpoint_decorator.py index ac693d05..5946db5f 100644 --- a/tests/unit/test_endpoint_decorator.py +++ b/tests/unit/test_endpoint_decorator.py @@ -1,5 +1,4 @@ """Tests for endpoint_decorator.py.""" -# pylint: disable=too-few-public-methods from __future__ import annotations import pytest diff --git a/tests/unit/test_orm.py b/tests/unit/test_orm.py index fd25ba22..52ebb942 100644 --- a/tests/unit/test_orm.py +++ b/tests/unit/test_orm.py @@ -3,7 +3,6 @@ from unittest.mock import MagicMock from starlite_saqlalchemy.db import orm -from tests.utils.domain import Author, CreateDTO def test_sqla_touch_updated_timestamp() -> None: @@ -13,11 +12,3 @@ def test_sqla_touch_updated_timestamp() -> None: orm.touch_updated_timestamp(mock_session) for mock_instance in mock_session.dirty: assert isinstance(mock_instance.updated, datetime.datetime) - - -def test_from_dto() -> None: - """Test conversion of a DTO instance to a model instance.""" - data = CreateDTO(name="someone", dob="1982-03-22") - author = Author.from_dto(data) - assert author.name == "someone" - assert author.dob == datetime.date(1982, 3, 22) diff --git a/tests/utils/controllers.py b/tests/utils/controllers.py index fb871d76..e3029eae 100644 --- a/tests/utils/controllers.py +++ b/tests/utils/controllers.py @@ -8,7 +8,7 @@ from starlite.status_codes import HTTP_200_OK from starlite_saqlalchemy.repository.types import FilterTypes -from tests.utils.domain import Author, CreateDTO, ReadDTO, Service, UpdateDTO +from tests.utils.domain import CreateDTO, ReadDTO, Service, UpdateDTO DETAIL_ROUTE = "/{author_id:uuid}" @@ -30,7 +30,7 @@ async def get_authors( @post() async def create_author(data: CreateDTO, service: Service) -> ReadDTO: """Create an `Author`.""" - return ReadDTO.from_orm(await service.create(Author.from_dto(data))) + return ReadDTO.from_orm(await service.create(data.to_mapped())) @get(DETAIL_ROUTE) @@ -42,7 +42,7 @@ async def get_author(service: Service, author_id: UUID) -> ReadDTO: @put(DETAIL_ROUTE) async def update_author(data: UpdateDTO, service: Service, author_id: UUID) -> ReadDTO: """Update an author.""" - return ReadDTO.from_orm(await service.update(author_id, Author.from_dto(data))) + return ReadDTO.from_orm(await service.update(author_id, data.to_mapped())) @delete(DETAIL_ROUTE, status_code=HTTP_200_OK) diff --git a/tests/utils/domain.py b/tests/utils/domain.py index 2679c370..588def3b 100644 --- a/tests/utils/domain.py +++ b/tests/utils/domain.py @@ -3,12 +3,13 @@ from datetime import date # noqa: TC003 +from pydantic import BaseModel from sqlalchemy.orm import Mapped from starlite_saqlalchemy import db, dto, repository, service -class Author(db.orm.Base): # pylint: disable=too-few-public-methods +class Author(db.orm.Base): """The Author domain object.""" name: Mapped[str] @@ -27,15 +28,16 @@ class Service(service.Service[Author]): repository_type = Repository -CreateDTO = dto.factory("AuthorCreateDTO", Author, purpose=dto.Purpose.WRITE, exclude={"id"}) -""" -A pydantic model to validate `Author` creation data. -""" -ReadDTO = dto.factory("AuthorReadDTO", Author, purpose=dto.Purpose.READ) +@dto.decorator(Author, dto.Purpose.WRITE) +class CreateDTO(BaseModel): + """A pydantic model to validate `Author` creation data.""" + + +ReadDTO = dto.factory("AuthorReadDTO", Author, dto.Purpose.READ) """ A pydantic model to serialize outbound `Author` representations. """ -UpdateDTO = dto.factory("AuthorUpdateDTO", Author, purpose=dto.Purpose.WRITE) +UpdateDTO = dto.factory("AuthorUpdateDTO", Author, dto.Purpose.WRITE) """ A pydantic model to validate and deserialize `Author` update data. """