forked from topsport-com-au/starlite-saqlalchemy
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Further development of the decorator pattern.
- Loading branch information
1 parent
400f0dd
commit c52b055
Showing
10 changed files
with
82 additions
and
149 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,20 +9,21 @@ | |
from __future__ import annotations | ||
|
||
from enum import Enum, auto | ||
from inspect import get_annotations, getmodule, isclass | ||
from inspect import getmodule, isclass | ||
from typing import ( | ||
TYPE_CHECKING, | ||
Any, | ||
ClassVar, | ||
Generic, | ||
NamedTuple, | ||
TypedDict, | ||
TypeVar, | ||
cast, | ||
get_args, | ||
get_origin, | ||
get_type_hints, | ||
) | ||
|
||
from pydantic import BaseConfig as BaseConfig_ | ||
from pydantic import BaseModel, create_model, validator | ||
from pydantic.fields import FieldInfo | ||
from sqlalchemy import inspect | ||
|
@@ -40,6 +41,9 @@ | |
from sqlalchemy.util import ReadOnlyProperties | ||
|
||
|
||
AnyDeclarative = TypeVar("AnyDeclarative", bound=DeclarativeBase) | ||
|
||
|
||
class Mark(str, Enum): | ||
"""For marking column definitions on the domain models. | ||
|
@@ -90,31 +94,24 @@ class Attrib(NamedTuple): | |
"""Single argument callables that are defined on the DTO as validators for the field.""" | ||
|
||
|
||
class BaseConfig(BaseConfig_): | ||
"""Base config for generated pydantic models.""" | ||
|
||
orm_mode = True | ||
|
||
|
||
class MapperBind(BaseModel): | ||
class _MapperBind(BaseModel, Generic[AnyDeclarative]): | ||
"""Produce an SQLAlchemy instance with values from a pydantic model.""" | ||
|
||
__sqla_model__: ClassVar[type[DeclarativeBase]] | ||
|
||
class Config(BaseConfig): | ||
"""Config for MapperBind pydantic models.""" | ||
class Config: | ||
"""Set orm_mode for `to_mapped()` method.""" | ||
|
||
@classmethod | ||
def __init_subclass__(cls, model: type[DeclarativeBase], **kwargs: Any) -> None: | ||
"""Set `__sqla_model__` class var. | ||
orm_mode = True | ||
|
||
Args: | ||
model: SQLAlchemy model represented by the DTO. | ||
""" | ||
cls.__sqla_model__ = model | ||
return super().__init_subclass__(**kwargs) | ||
def __init_subclass__( # pylint: disable=arguments-differ | ||
cls, model: type[DeclarativeBase] | None = None, **kwargs: Any | ||
) -> None: | ||
if model is not None: | ||
cls.__sqla_model__ = model | ||
super().__init_subclass__(**kwargs) | ||
|
||
def to_mapped(self) -> DeclarativeBase: | ||
def to_mapped(self) -> AnyDeclarative: | ||
"""Create an instance of `self.__sqla_model__` | ||
Fill the bound SQLAlchemy model recursively with values from | ||
|
@@ -124,11 +121,11 @@ def to_mapped(self) -> DeclarativeBase: | |
for field in self.__fields__.values(): | ||
value = getattr(self, field.name) | ||
if isinstance(value, (list, tuple)): | ||
value = [el.to_mapped() if isinstance(el, MapperBind) else el for el in value] | ||
if isinstance(value, MapperBind): | ||
value = [el.to_mapped() if isinstance(el, _MapperBind) else el for el in value] | ||
if isinstance(value, _MapperBind): | ||
value = value.to_mapped() | ||
as_model[field.name] = value | ||
return self.__sqla_model__(**as_model) | ||
return cast("AnyDeclarative", self.__sqla_model__(**as_model)) | ||
|
||
|
||
def _construct_field_info(elem: Column | RelationshipProperty, purpose: Purpose) -> FieldInfo: | ||
|
@@ -172,6 +169,11 @@ def _inspect_model( | |
return columns, relationships | ||
|
||
|
||
def _get_localns(model: type[DeclarativeBase]) -> dict[str, Any]: | ||
model_module = getmodule(model) | ||
return vars(model_module) if model_module is not None else {} | ||
|
||
|
||
def mark(mark_type: Mark) -> DTOInfo: | ||
"""Shortcut for ```python. | ||
|
@@ -198,16 +200,15 @@ class User(DeclarativeBase): | |
return {"dto": Attrib(mark=mark_type)} | ||
|
||
|
||
def factory( # pylint: disable=too-many-locals | ||
def factory( | ||
name: str, | ||
model: type[DeclarativeBase], | ||
model: type[AnyDeclarative], | ||
purpose: Purpose, | ||
*, | ||
exclude: set[str] | None = None, | ||
namespace: dict[str, Any] | None = None, | ||
annotations: dict[str, Any] | None = None, | ||
base: type[BaseModel] = MapperBind, | ||
) -> type[BaseModel]: | ||
"""Create a pydantic model class from a SQLAlchemy declarative ORM class. | ||
base: type[BaseModel] | None = None, | ||
) -> type[_MapperBind[AnyDeclarative]]: | ||
"""Infer a Pydantic model from a SQLAlchemy model. | ||
The fields that are included in the model can be controlled on the SQLAlchemy class | ||
definition by including a "dto" key in the `Column.info` mapping. For example: | ||
|
@@ -230,32 +231,24 @@ class User(DeclarativeBase): | |
model: The SQLAlchemy model class. | ||
purpose: Is the DTO for write or read operations? | ||
exclude: Explicitly exclude attributes from the DTO. | ||
namespace: Additional namespace used to resolve forward references. The default namespace | ||
used is that of the module of `model`. | ||
annotations: Annotations that override and supplement the annotations derived from `model`. | ||
base: A subclass of `pydantic.BaseModel` to be used as the base class of the DTO. | ||
Returns: | ||
A Pydantic model that includes only fields that are appropriate to `purpose` and not in | ||
`exclude`. | ||
""" | ||
exclude_ = exclude or set[str]() | ||
annotations_ = annotations or {} | ||
namespace_ = namespace or {} | ||
del exclude, annotations, namespace | ||
exclude = set() if exclude is None else exclude | ||
|
||
columns, relationships = _inspect_model(model) | ||
model_module = getmodule(model) | ||
localns = vars(model_module) if model_module is not None else {} | ||
localns.update(namespace_) | ||
type_hints = get_type_hints(model, localns=localns) | ||
type_hints.update(annotations_) | ||
fields: dict[str, tuple[Any, FieldInfo]] = {} | ||
validators: dict[str, AnyClassMethod] = {} | ||
for key, type_hint in type_hints.items(): | ||
for key, type_hint in get_type_hints(model, localns=_get_localns(model)).items(): | ||
# don't override fields that already exist on `base`. | ||
if base is not None and key in base.__fields__: | ||
continue | ||
|
||
if get_origin(type_hint) is Mapped: | ||
(type_,) = get_args(type_hint) | ||
else: | ||
type_ = type_hint | ||
(type_hint,) = get_args(type_hint) | ||
|
||
elem: Column | RelationshipProperty | ||
if key in columns: | ||
|
@@ -268,101 +261,40 @@ class User(DeclarativeBase): | |
|
||
attrib = _get_dto_attrib(elem) | ||
|
||
if _should_exclude_field(purpose, elem, exclude_, attrib): | ||
if _should_exclude_field(purpose, elem, exclude, attrib): | ||
continue | ||
|
||
if attrib.pydantic_type is not None: | ||
type_ = attrib.pydantic_type | ||
type_hint = attrib.pydantic_type | ||
|
||
for i, func in enumerate(attrib.validators or []): | ||
validators[f"_validates_{key}_{i}"] = validator(key, allow_reuse=True)(func) | ||
|
||
if isclass(type_) and issubclass(type_, DeclarativeBase): | ||
type_ = factory(f"{name}_{type_.__name__}", type_, purpose=purpose) | ||
if isclass(type_hint) and issubclass(type_hint, DeclarativeBase): | ||
type_hint = factory(f"{name}_{type_hint.__name__}", type_hint, purpose=purpose) | ||
|
||
fields[key] = (type_, _construct_field_info(elem, purpose)) | ||
fields[key] = (type_hint, _construct_field_info(elem, purpose)) | ||
|
||
return create_model( # type:ignore[no-any-return,call-overload] | ||
name, | ||
__base__=base, | ||
__cls_kwargs__={"model": model} if base is MapperBind else {}, | ||
__base__=(base, _MapperBind[AnyDeclarative]) | ||
if base is not None | ||
else _MapperBind[AnyDeclarative], | ||
__cls_kwargs__={"model": model}, | ||
__module__=getattr(model, "__module__", __name__), | ||
__validators__=validators, | ||
**fields, | ||
) | ||
|
||
|
||
def decorator( | ||
model: type[DeclarativeBase], | ||
purpose: Purpose, | ||
exclude: set[str] | None = None, | ||
) -> Callable[[type[BaseModel]], type[BaseModel]]: | ||
"""Create a pydantic model class from a SQLAlchemy declarative ORM class | ||
with validation support. | ||
This decorator is not recursive so relationships will be ignored (but can be | ||
overridden on the decorated class). | ||
Pydantic validation is supported using `validator` decorator from `starlite_sqlalchemy.pydantic` | ||
As for the `factory` function, included fields can be controlled on the SQLAlchemy class | ||
definition by including a "dto" key in the `Column.info` mapping. | ||
Example: | ||
```python | ||
class User(DeclarativeBase): | ||
id: Mapped[UUID] = mapped_column( | ||
default=uuid4, primary_key=True, info={"dto": Attrib(mark=dto.Mark.READ_ONLY)} | ||
) | ||
email: Mapped[str] | ||
password_hash: Mapped[str] = mapped_column(info={"dto": Attrib(mark=dto.Mark.SKIP)}) | ||
@dto.dto(User, Purpose.WRITE) | ||
class UserCreate: | ||
@dto.validator("email") | ||
def val_email(cls, v): | ||
if "@" not in v: | ||
raise ValueError("Invalid email") | ||
``` | ||
When setting `mapper_bind` (default to `True`), the resulting pydantic model will have a `mapper` method | ||
that can be used to create an instance of the original SQLAlchemy model with values from the pydantic one: | ||
```python | ||
@dto.dto(User, Purpose.WRITE) | ||
class UserCreate: | ||
pass | ||
user = UserCreate(email="[email protected]") | ||
user_model = user.mapper() # user_model is a User instance | ||
``` | ||
Args: | ||
model: The SQLAlchemy model class. | ||
purpose: Is the DTO for write or read operations? | ||
exclude: Explicitly exclude attributes from the DTO. | ||
Returns: | ||
A Pydantic model that includes only fields that are appropriate to `purpose` and not in | ||
`exclude`, except relationship fields. | ||
""" | ||
model: type[AnyDeclarative], purpose: Purpose, *, exclude: set[str] | None = None | ||
) -> Callable[[type[BaseModel]], type[_MapperBind[AnyDeclarative]]]: | ||
"""Infer a Pydantic model from SQLAlchemy model.""" | ||
|
||
def wrapper(cls: type[BaseModel]) -> type[BaseModel]: | ||
def wrapped() -> type[BaseModel]: | ||
name = cls.__name__ | ||
cls_module = getmodule(cls) | ||
namespace = vars(cls_module) if cls_module is not None else {} | ||
return factory( | ||
name=name, | ||
model=model, | ||
purpose=purpose, | ||
exclude=exclude, | ||
namespace=namespace, | ||
annotations=get_annotations(cls), | ||
base=cls, | ||
) | ||
def wrapper(cls: type[BaseModel]) -> type[_MapperBind[AnyDeclarative]]: | ||
def wrapped() -> type[_MapperBind[AnyDeclarative]]: | ||
return factory(cls.__name__, model, purpose, exclude=exclude, base=cls) | ||
|
||
return wrapped() | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.