Skip to content
This repository has been archived by the owner on Sep 12, 2023. It is now read-only.

Commit

Permalink
feat!: optional audit columns on base class (#264)
Browse files Browse the repository at this point in the history
* feat: makes the `created/updated` indicators on ORM models optional.

* fix: remove unused ignore.

* feat: choice of declarative base with/without audit columns.

* test: coverage! coverage! coverage!

Co-authored-by: Cody Fincher <[email protected]>
  • Loading branch information
peterschutt and cofin committed Jan 20, 2023
1 parent 45982b4 commit 5f4e883
Show file tree
Hide file tree
Showing 11 changed files with 213 additions and 96 deletions.
51 changes: 38 additions & 13 deletions src/starlite_saqlalchemy/db/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
DeclarativeBase,
Mapped,
Session,
declarative_mixin,
declared_attr,
mapped_column,
registry,
Expand Down Expand Up @@ -47,32 +48,56 @@ def touch_updated_timestamp(session: Session, *_: Any) -> None:
session.
"""
for instance in session.dirty:
instance.updated = datetime.now()
if hasattr(instance, "updated"):
instance.updated = datetime.now()


class Base(DeclarativeBase):
"""Base for all SQLAlchemy declarative models."""
@declarative_mixin
class CommonColumns:
"""Common functionality shared between all declarative models."""

registry = registry(
metadata=MetaData(naming_convention=convention),
type_annotation_map={UUID: pg.UUID, dict: pg.JSONB},
)
__abstract__ = True
__name__: str

id: Mapped[UUID] = mapped_column(
default=uuid4, primary_key=True, info={DTO_KEY: dto.DTOField(mark=dto.Mark.READ_ONLY)}
)
"""Primary key column."""

# noinspection PyMethodParameters
@declared_attr.directive
def __tablename__(cls) -> str: # pylint: disable=no-self-argument
"""Infer table name from class name."""
return cls.__name__.lower()


@declarative_mixin
class AuditColumns:
"""Created/Updated At Fields Mixin."""

__abstract__ = True

created: Mapped[datetime] = mapped_column(
default=datetime.now, info={DTO_KEY: dto.DTOField(mark=dto.Mark.READ_ONLY)}
)
"""Date/time of instance creation."""
updated: Mapped[datetime] = mapped_column(
default=datetime.now, info={DTO_KEY: dto.DTOField(mark=dto.Mark.READ_ONLY)}
)
"""Date/time of instance update."""
"""Date/time of instance last update."""

# noinspection PyMethodParameters
@declared_attr.directive
def __tablename__(cls) -> str: # pylint: disable=no-self-argument
"""Infer table name from class name."""
return cls.__name__.lower()

meta = MetaData(naming_convention=convention)
registry_ = registry(metadata=meta, type_annotation_map={UUID: pg.UUID, dict: pg.JSONB})


class Base(CommonColumns, DeclarativeBase):
"""Base for all SQLAlchemy declarative models."""

registry = registry_


class AuditBase(AuditColumns, CommonColumns, DeclarativeBase):
"""Base for declarative models with audit columns."""

registry = registry_
2 changes: 1 addition & 1 deletion src/starlite_saqlalchemy/repository/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
]

T = TypeVar("T")
ModelT = TypeVar("ModelT", bound="orm.Base")
ModelT = TypeVar("ModelT", bound="orm.Base | orm.AuditBase")
SQLARepoT = TypeVar("SQLARepoT", bound="SQLAlchemyRepository")


Expand Down
6 changes: 3 additions & 3 deletions src/starlite_saqlalchemy/sqlalchemy_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
__all__ = ["SQLAlchemyHealthCheck", "config", "plugin"]


async def before_send_handler(message: "Message", _: "State", scope: "Scope") -> None:
async def before_send_handler(message: Message, _: State, scope: Scope) -> None:
"""Inspect status of response and commit, or rolls back.
Args:
Expand Down Expand Up @@ -61,8 +61,8 @@ async def ready(self) -> bool:
Returns:
`True` if healthy.
"""
async with self.session_maker() as session:
return ( # type:ignore[no-any-return] # pragma: no cover
async with self.session_maker() as session: # pragma: no cover
return ( # type:ignore[no-any-return]
await session.execute(text("SELECT 1"))
).scalar_one() == 1

Expand Down
3 changes: 1 addition & 2 deletions src/starlite_saqlalchemy/testing/controller_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from pytest import MonkeyPatch
from starlite.testing import TestClient

from starlite_saqlalchemy.db import orm
from starlite_saqlalchemy.service import Service


Expand All @@ -28,7 +27,7 @@ def __init__(
self,
client: TestClient,
base_path: str,
collection: Sequence[orm.Base],
collection: Sequence[Any],
raw_collection: Sequence[dict[str, Any]],
service_type: type[Service],
monkeypatch: MonkeyPatch,
Expand Down
10 changes: 7 additions & 3 deletions src/starlite_saqlalchemy/testing/generic_mock_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from starlite_saqlalchemy.repository.types import FilterTypes

ModelT = TypeVar("ModelT", bound=orm.Base)
ModelT = TypeVar("ModelT", bound=orm.Base | orm.AuditBase)
MockRepoT = TypeVar("MockRepoT", bound="GenericMockRepository")


Expand Down Expand Up @@ -62,7 +62,9 @@ async def add(self, data: ModelT, allow_id: bool = False) -> ModelT:
if allow_id is False and self.get_id_attribute_value(data) is not None:
raise ConflictError("`add()` received identified item.")
now = datetime.now()
data.updated = data.created = now
if hasattr(data, "updated") and hasattr(data, "created"):
# maybe the @declarative_mixin decorator doesn't play nice with pyright?
data.updated = data.created = now # pyright: ignore
if allow_id is False:
id_ = self._id_factory()
self.set_id_attribute_value(id_, data)
Expand Down Expand Up @@ -127,7 +129,9 @@ async def update(self, data: ModelT) -> ModelT:
"""
item = self._find_or_raise_not_found(self.get_id_attribute_value(data))
# should never be modifiable
data.updated = datetime.now()
if hasattr(data, "updated"):
# maybe the @declarative_mixin decorator doesn't play nice with pyright?
data.updated = datetime.now() # pyright: ignore
for key, val in data.__dict__.items():
if key.startswith("_"):
continue
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/sqlalchemy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
if TYPE_CHECKING:
from typing import Any

from pytest import MonkeyPatch

from starlite_saqlalchemy.testing.generic_mock_repository import (
GenericMockRepository,
)
Expand All @@ -20,6 +22,27 @@
collect_ignore_glob = ["*"]


@pytest.fixture(autouse=True)
def _patch_bases(monkeypatch: MonkeyPatch) -> None:
"""Ensure new registry state for every test.
This prevents errors such as "Table '...' is already defined for
this MetaData instance...
"""
from sqlalchemy.orm import DeclarativeBase

from starlite_saqlalchemy.db import orm

class NewBase(orm.CommonColumns, DeclarativeBase):
...

class NewAuditBase(orm.AuditColumns, orm.CommonColumns, DeclarativeBase):
...

monkeypatch.setattr(orm, "Base", NewBase)
monkeypatch.setattr(orm, "AuditBase", NewAuditBase)


@pytest.fixture(name="authors")
def fx_authors(raw_authors: list[dict[str, Any]]) -> list[Author]:
"""Collection of parsed Author models."""
Expand Down
52 changes: 52 additions & 0 deletions tests/unit/sqlalchemy/repository/test_generic_mock_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest

from starlite_saqlalchemy.db import orm
from starlite_saqlalchemy.exceptions import ConflictError, StarliteSaqlalchemyError
from starlite_saqlalchemy.testing.generic_mock_repository import GenericMockRepository
from tests.utils.domain.authors import Author
Expand Down Expand Up @@ -67,3 +68,54 @@ def test_generic_mock_repository_raises_repository_exception_if_named_attribute_
exist."""
with pytest.raises(StarliteSaqlalchemyError):
author_repository.filter_collection_by_kwargs(cricket="ball")


async def test_sets_created_updated_on_add() -> None:
"""Test that the repository updates the 'created' and 'updated' timestamps
if necessary."""

class Model(orm.AuditBase):
"""Inheriting from AuditBase gives the model 'created' and 'updated'
columns."""

...

instance = Model()
assert "created" not in vars(instance)
assert "updated" not in vars(instance)

instance = await GenericMockRepository[Model]().add(instance)
assert "created" in vars(instance)
assert "updated" in vars(instance)


async def test_sets_updated_on_update(author_repository: GenericMockRepository[Author]) -> None:
"""Test that the repository updates the 'updated' timestamp if
necessary."""

instance = list(author_repository.collection.values())[0]
original_updated = instance.updated
instance = await author_repository.update(instance)
assert instance.updated > original_updated


async def test_does_not_set_created_updated() -> None:
"""Test that the repository does not update the 'updated' timestamps when
appropriate."""

class Model(orm.Base):
"""Inheriting from Base means the model has no created/updated
timestamp columns."""

...

instance = Model()
repo = GenericMockRepository[Model]()
assert "created" not in vars(instance)
assert "updated" not in vars(instance)
instance = await repo.add(instance)
assert "created" not in vars(instance)
assert "updated" not in vars(instance)
instance = await repo.update(instance)
assert "created" not in vars(instance)
assert "updated" not in vars(instance)
Loading

0 comments on commit 5f4e883

Please sign in to comment.