Skip to content

Commit

Permalink
feat(dto): more iteration.
Browse files Browse the repository at this point in the history
Further development of the decorator pattern.
  • Loading branch information
peterschutt committed Nov 25, 2022
1 parent 400f0dd commit c52b055
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 149 deletions.
2 changes: 1 addition & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

check_untyped_defs = True
disallow_any_generics = False
disallow_incomplete_defs = True
disallow_incomplete_defs = False
disallow_untyped_decorators = True
disallow_untyped_defs = True
ignore_missing_imports = False
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ testpaths = ["tests/unit"]
disable = [
"line-too-long",
"no-self-argument",
"too-few-public-methods",
"too-many-arguments",
]
enable = "useless-suppression"
Expand Down
176 changes: 54 additions & 122 deletions src/starlite_saqlalchemy/dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,21 @@
from __future__ import annotations

from enum import Enum, auto
from inspect import get_annotations, getmodule, isclass
from inspect import getmodule, isclass
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Generic,
NamedTuple,
TypedDict,
TypeVar,
cast,
get_args,
get_origin,
get_type_hints,
)

from pydantic import BaseConfig as BaseConfig_
from pydantic import BaseModel, create_model, validator
from pydantic.fields import FieldInfo
from sqlalchemy import inspect
Expand All @@ -40,6 +41,9 @@
from sqlalchemy.util import ReadOnlyProperties


AnyDeclarative = TypeVar("AnyDeclarative", bound=DeclarativeBase)


class Mark(str, Enum):
"""For marking column definitions on the domain models.
Expand Down Expand Up @@ -90,31 +94,24 @@ class Attrib(NamedTuple):
"""Single argument callables that are defined on the DTO as validators for the field."""


class BaseConfig(BaseConfig_):
"""Base config for generated pydantic models."""

orm_mode = True


class MapperBind(BaseModel):
class _MapperBind(BaseModel, Generic[AnyDeclarative]):
"""Produce an SQLAlchemy instance with values from a pydantic model."""

__sqla_model__: ClassVar[type[DeclarativeBase]]

class Config(BaseConfig):
"""Config for MapperBind pydantic models."""
class Config:
"""Set orm_mode for `to_mapped()` method."""

@classmethod
def __init_subclass__(cls, model: type[DeclarativeBase], **kwargs: Any) -> None:
"""Set `__sqla_model__` class var.
orm_mode = True

Args:
model: SQLAlchemy model represented by the DTO.
"""
cls.__sqla_model__ = model
return super().__init_subclass__(**kwargs)
def __init_subclass__( # pylint: disable=arguments-differ
cls, model: type[DeclarativeBase] | None = None, **kwargs: Any
) -> None:
if model is not None:
cls.__sqla_model__ = model
super().__init_subclass__(**kwargs)

def to_mapped(self) -> DeclarativeBase:
def to_mapped(self) -> AnyDeclarative:
"""Create an instance of `self.__sqla_model__`
Fill the bound SQLAlchemy model recursively with values from
Expand All @@ -124,11 +121,11 @@ def to_mapped(self) -> DeclarativeBase:
for field in self.__fields__.values():
value = getattr(self, field.name)
if isinstance(value, (list, tuple)):
value = [el.to_mapped() if isinstance(el, MapperBind) else el for el in value]
if isinstance(value, MapperBind):
value = [el.to_mapped() if isinstance(el, _MapperBind) else el for el in value]
if isinstance(value, _MapperBind):
value = value.to_mapped()
as_model[field.name] = value
return self.__sqla_model__(**as_model)
return cast("AnyDeclarative", self.__sqla_model__(**as_model))


def _construct_field_info(elem: Column | RelationshipProperty, purpose: Purpose) -> FieldInfo:
Expand Down Expand Up @@ -172,6 +169,11 @@ def _inspect_model(
return columns, relationships


def _get_localns(model: type[DeclarativeBase]) -> dict[str, Any]:
model_module = getmodule(model)
return vars(model_module) if model_module is not None else {}


def mark(mark_type: Mark) -> DTOInfo:
"""Shortcut for ```python.
Expand All @@ -198,16 +200,15 @@ class User(DeclarativeBase):
return {"dto": Attrib(mark=mark_type)}


def factory( # pylint: disable=too-many-locals
def factory(
name: str,
model: type[DeclarativeBase],
model: type[AnyDeclarative],
purpose: Purpose,
*,
exclude: set[str] | None = None,
namespace: dict[str, Any] | None = None,
annotations: dict[str, Any] | None = None,
base: type[BaseModel] = MapperBind,
) -> type[BaseModel]:
"""Create a pydantic model class from a SQLAlchemy declarative ORM class.
base: type[BaseModel] | None = None,
) -> type[_MapperBind[AnyDeclarative]]:
"""Infer a Pydantic model from a SQLAlchemy model.
The fields that are included in the model can be controlled on the SQLAlchemy class
definition by including a "dto" key in the `Column.info` mapping. For example:
Expand All @@ -230,32 +231,24 @@ class User(DeclarativeBase):
model: The SQLAlchemy model class.
purpose: Is the DTO for write or read operations?
exclude: Explicitly exclude attributes from the DTO.
namespace: Additional namespace used to resolve forward references. The default namespace
used is that of the module of `model`.
annotations: Annotations that override and supplement the annotations derived from `model`.
base: A subclass of `pydantic.BaseModel` to be used as the base class of the DTO.
Returns:
A Pydantic model that includes only fields that are appropriate to `purpose` and not in
`exclude`.
"""
exclude_ = exclude or set[str]()
annotations_ = annotations or {}
namespace_ = namespace or {}
del exclude, annotations, namespace
exclude = set() if exclude is None else exclude

columns, relationships = _inspect_model(model)
model_module = getmodule(model)
localns = vars(model_module) if model_module is not None else {}
localns.update(namespace_)
type_hints = get_type_hints(model, localns=localns)
type_hints.update(annotations_)
fields: dict[str, tuple[Any, FieldInfo]] = {}
validators: dict[str, AnyClassMethod] = {}
for key, type_hint in type_hints.items():
for key, type_hint in get_type_hints(model, localns=_get_localns(model)).items():
# don't override fields that already exist on `base`.
if base is not None and key in base.__fields__:
continue

if get_origin(type_hint) is Mapped:
(type_,) = get_args(type_hint)
else:
type_ = type_hint
(type_hint,) = get_args(type_hint)

elem: Column | RelationshipProperty
if key in columns:
Expand All @@ -268,101 +261,40 @@ class User(DeclarativeBase):

attrib = _get_dto_attrib(elem)

if _should_exclude_field(purpose, elem, exclude_, attrib):
if _should_exclude_field(purpose, elem, exclude, attrib):
continue

if attrib.pydantic_type is not None:
type_ = attrib.pydantic_type
type_hint = attrib.pydantic_type

for i, func in enumerate(attrib.validators or []):
validators[f"_validates_{key}_{i}"] = validator(key, allow_reuse=True)(func)

if isclass(type_) and issubclass(type_, DeclarativeBase):
type_ = factory(f"{name}_{type_.__name__}", type_, purpose=purpose)
if isclass(type_hint) and issubclass(type_hint, DeclarativeBase):
type_hint = factory(f"{name}_{type_hint.__name__}", type_hint, purpose=purpose)

fields[key] = (type_, _construct_field_info(elem, purpose))
fields[key] = (type_hint, _construct_field_info(elem, purpose))

return create_model( # type:ignore[no-any-return,call-overload]
name,
__base__=base,
__cls_kwargs__={"model": model} if base is MapperBind else {},
__base__=(base, _MapperBind[AnyDeclarative])
if base is not None
else _MapperBind[AnyDeclarative],
__cls_kwargs__={"model": model},
__module__=getattr(model, "__module__", __name__),
__validators__=validators,
**fields,
)


def decorator(
model: type[DeclarativeBase],
purpose: Purpose,
exclude: set[str] | None = None,
) -> Callable[[type[BaseModel]], type[BaseModel]]:
"""Create a pydantic model class from a SQLAlchemy declarative ORM class
with validation support.
This decorator is not recursive so relationships will be ignored (but can be
overridden on the decorated class).
Pydantic validation is supported using `validator` decorator from `starlite_sqlalchemy.pydantic`
As for the `factory` function, included fields can be controlled on the SQLAlchemy class
definition by including a "dto" key in the `Column.info` mapping.
Example:
```python
class User(DeclarativeBase):
id: Mapped[UUID] = mapped_column(
default=uuid4, primary_key=True, info={"dto": Attrib(mark=dto.Mark.READ_ONLY)}
)
email: Mapped[str]
password_hash: Mapped[str] = mapped_column(info={"dto": Attrib(mark=dto.Mark.SKIP)})
@dto.dto(User, Purpose.WRITE)
class UserCreate:
@dto.validator("email")
def val_email(cls, v):
if "@" not in v:
raise ValueError("Invalid email")
```
When setting `mapper_bind` (default to `True`), the resulting pydantic model will have a `mapper` method
that can be used to create an instance of the original SQLAlchemy model with values from the pydantic one:
```python
@dto.dto(User, Purpose.WRITE)
class UserCreate:
pass
user = UserCreate(email="[email protected]")
user_model = user.mapper() # user_model is a User instance
```
Args:
model: The SQLAlchemy model class.
purpose: Is the DTO for write or read operations?
exclude: Explicitly exclude attributes from the DTO.
Returns:
A Pydantic model that includes only fields that are appropriate to `purpose` and not in
`exclude`, except relationship fields.
"""
model: type[AnyDeclarative], purpose: Purpose, *, exclude: set[str] | None = None
) -> Callable[[type[BaseModel]], type[_MapperBind[AnyDeclarative]]]:
"""Infer a Pydantic model from SQLAlchemy model."""

def wrapper(cls: type[BaseModel]) -> type[BaseModel]:
def wrapped() -> type[BaseModel]:
name = cls.__name__
cls_module = getmodule(cls)
namespace = vars(cls_module) if cls_module is not None else {}
return factory(
name=name,
model=model,
purpose=purpose,
exclude=exclude,
namespace=namespace,
annotations=get_annotations(cls),
base=cls,
)
def wrapper(cls: type[BaseModel]) -> type[_MapperBind[AnyDeclarative]]:
def wrapped() -> type[_MapperBind[AnyDeclarative]]:
return factory(cls.__name__, model, purpose, exclude=exclude, base=cls)

return wrapped()

Expand Down
2 changes: 1 addition & 1 deletion src/starlite_saqlalchemy/repository/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
T = TypeVar("T")


class ModelProtocol(Protocol): # pylint: disable=too-few-public-methods
class ModelProtocol(Protocol):
"""Protocol for repository models."""

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion src/starlite_saqlalchemy/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""
from __future__ import annotations

# pylint: disable=too-few-public-methods,missing-class-docstring
# pylint: disable=missing-class-docstring
from typing import Literal

from pydantic import AnyUrl, BaseSettings, PostgresDsn, parse_obj_as
Expand Down
16 changes: 12 additions & 4 deletions tests/unit/test_dto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Tests for the dto factory."""
# pylint: disable=redefined-outer-name,too-few-public-methods,missing-class-docstring
# pylint: disable=missing-class-docstring
from datetime import date, datetime, timedelta
from typing import TYPE_CHECKING, Any, ClassVar
from uuid import UUID, uuid4
Expand All @@ -10,7 +10,7 @@
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column

from starlite_saqlalchemy import dto, settings
from tests.utils.domain import Author
from tests.utils.domain import Author, CreateDTO

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -50,8 +50,8 @@ def test_dto_exclude() -> None:
assert dto_type.__fields__.keys() == {"name", "dob", "created", "updated"}


@pytest.fixture()
def base() -> type[DeclarativeBase]:
@pytest.fixture(name="base")
def fx_base() -> type[DeclarativeBase]:
"""Declarative base for test models.
Need a new base for every test, otherwise will get errors to do with
Expand Down Expand Up @@ -245,3 +245,11 @@ class Model(MappedAsDataclass, base):

dto_model = dto.factory("DTO", Model, purpose=dto.Purpose.WRITE)
assert dto_model.__fields__.keys() == {"id", "field"}


def test_from_dto() -> None:
"""Test conversion of a DTO instance to a model instance."""
data = CreateDTO.parse_obj({"name": "someone", "dob": "1982-03-22"})
author = data.to_mapped()
assert author.name == "someone"
assert author.dob == date(1982, 3, 22)
1 change: 0 additions & 1 deletion tests/unit/test_endpoint_decorator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Tests for endpoint_decorator.py."""
# pylint: disable=too-few-public-methods
from __future__ import annotations

import pytest
Expand Down
9 changes: 0 additions & 9 deletions tests/unit/test_orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from unittest.mock import MagicMock

from starlite_saqlalchemy.db import orm
from tests.utils.domain import Author, CreateDTO


def test_sqla_touch_updated_timestamp() -> None:
Expand All @@ -13,11 +12,3 @@ def test_sqla_touch_updated_timestamp() -> None:
orm.touch_updated_timestamp(mock_session)
for mock_instance in mock_session.dirty:
assert isinstance(mock_instance.updated, datetime.datetime)


def test_from_dto() -> None:
"""Test conversion of a DTO instance to a model instance."""
data = CreateDTO(name="someone", dob="1982-03-22")
author = Author.from_dto(data)
assert author.name == "someone"
assert author.dob == datetime.date(1982, 3, 22)
Loading

0 comments on commit c52b055

Please sign in to comment.