Skip to content
This repository has been archived by the owner on Sep 12, 2023. It is now read-only.

feat(repo): abstract method filter_collection_by_kwargs() #159

Merged
merged 1 commit into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/starlite_saqlalchemy/repository/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,20 @@ async def upsert(self, data: T) -> T:
RepositoryNotFoundException: If no instance found with same identifier as `data`.
"""

@abstractmethod
def filter_collection_by_kwargs(self, **kwargs: Any) -> None:
"""Filter the collection by kwargs.

Has `AND` semantics where multiple kwargs name/value pairs are provided.

Args:
**kwargs: key/value pairs such that objects remaining in the collection after filtering
have the property that their attribute named `key` has value equal to `value`.

Raises:
RepositoryException: if a named attribute doesn't exist on `self.model_type`.
"""

@staticmethod
def check_not_found(item_or_none: T | None) -> T:
"""Raise `RepositoryNotFoundException` if `item_or_none` is `None`.
Expand Down
10 changes: 10 additions & 0 deletions src/starlite_saqlalchemy/repository/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,16 @@ async def upsert(self, data: ModelT) -> ModelT:
self.session.expunge(instance)
return instance

def filter_collection_by_kwargs(self, **kwargs: Any) -> None:
"""Filter the collection by kwargs.

Args:
**kwargs: key/value pairs such that objects remaining in the collection after filtering
have the property that their attribute named `key` has value equal to `value`.
"""
with wrap_sqlalchemy_exception():
self._select.filter_by(**kwargs)

@classmethod
async def check_health(cls, session: AsyncSession) -> bool:
"""Perform a health check on the database.
Expand Down
21 changes: 20 additions & 1 deletion src/starlite_saqlalchemy/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@

from starlite_saqlalchemy.db import orm
from starlite_saqlalchemy.repository.abc import AbstractRepository
from starlite_saqlalchemy.repository.exceptions import RepositoryConflictException
from starlite_saqlalchemy.repository.exceptions import (
RepositoryConflictException,
RepositoryException,
)

if TYPE_CHECKING:
from collections.abc import Callable, Hashable, Iterable, MutableMapping, Sequence
Expand Down Expand Up @@ -162,6 +165,22 @@ async def upsert(self, data: ModelT) -> ModelT:
return await self.update(data)
return await self.add(data, allow_id=True)

def filter_collection_by_kwargs(self, **kwargs: Any) -> None:
"""Filter the collection by kwargs.

Args:
**kwargs: key/value pairs such that objects remaining in the collection after filtering
have the property that their attribute named `key` has value equal to `value`.
"""
new_collection: dict[Hashable, ModelT] = {}
for item in self.collection.values():
try:
if all(getattr(item, name) == value for name, value in kwargs.items()):
new_collection[item.id] = item
except AttributeError as orig:
raise RepositoryException from orig
self.collection = new_collection

@classmethod
def seed_collection(cls, instances: Iterable[ModelT]) -> None:
"""Seed the collection for repository type.
Expand Down
20 changes: 20 additions & 0 deletions tests/integration/repository/test_sqlalchemy_repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker

from starlite_saqlalchemy.repository.exceptions import RepositoryException
from tests.utils.domain import authors


@pytest.fixture(name="session")
def fx_session(engine: AsyncEngine) -> AsyncSession:
return async_sessionmaker(bind=engine)()


@pytest.fixture(name="repo")
def fx_repo(session: AsyncSession) -> authors.Repository:
return authors.Repository(session=session)


def test_filter_by_kwargs_with_incorrect_attribute_name(repo: authors.Repository) -> None:
with pytest.raises(RepositoryException):
repo.filter_collection_by_kwargs(whoops="silly me")
20 changes: 19 additions & 1 deletion tests/unit/repository/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from unittest.mock import AsyncMock, MagicMock, call

import pytest
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.exc import IntegrityError, InvalidRequestError, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession

from starlite_saqlalchemy.repository.exceptions import (
Expand Down Expand Up @@ -229,3 +229,21 @@ def test__filter_on_datetime_field(
field_mock.__gt__ = field_mock.__lt__ = lambda self, other: True
mock_repo.model_type.updated = field_mock
mock_repo._filter_on_datetime_field("updated", before, after)


def test_filter_collection_by_kwargs(mock_repo: SQLAlchemyRepository) -> None:
"""Test `filter_by()` called with kwargs."""
mock_repo.filter_collection_by_kwargs(a=1, b=2)
mock_repo._select.filter_by.assert_called_once_with(a=1, b=2)


def test_filter_collection_by_kwargs_raises_repository_exception_for_attribute_error(
mock_repo: SQLAlchemyRepository,
) -> None:
"""Test that we raise a repository exception if an attribute name is
incorrect."""
mock_repo._select.filter_by = MagicMock( # type:ignore[assignment]
side_effect=InvalidRequestError,
)
with pytest.raises(RepositoryException):
mock_repo.filter_collection_by_kwargs(a=1)
32 changes: 31 additions & 1 deletion tests/unit/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from starlite.status_codes import HTTP_200_OK, HTTP_404_NOT_FOUND

from starlite_saqlalchemy import testing
from starlite_saqlalchemy.repository.exceptions import RepositoryConflictException
from starlite_saqlalchemy.repository.exceptions import (
RepositoryConflictException,
RepositoryException,
)
from tests.utils.domain.authors import Author
from tests.utils.domain.authors import Service as AuthorService
from tests.utils.domain.books import Book
Expand Down Expand Up @@ -54,6 +57,33 @@ def test_generic_mock_repository_clear_collection(
assert not author_repository_type.collection


def test_generic_mock_repository_filter_collection_by_kwargs(
author_repository: testing.GenericMockRepository[Author],
) -> None:
"""Test filtering the repository collection by kwargs."""
author_repository.filter_collection_by_kwargs(name="Leo Tolstoy")
assert len(author_repository.collection) == 1
assert list(author_repository.collection.values())[0].name == "Leo Tolstoy"


def test_generic_mock_repository_filter_collection_by_kwargs_and_semantics(
author_repository: testing.GenericMockRepository[Author],
) -> None:
"""Test that filtering by kwargs has `AND` semantics when multiple kwargs,
not `OR`."""
author_repository.filter_collection_by_kwargs(name="Agatha Christie", dob="1828-09-09")
assert len(author_repository.collection) == 0


def test_generic_mock_repository_raises_repository_exception_if_named_attribute_doesnt_exist(
author_repository: testing.GenericMockRepository[Author],
) -> None:
"""Test that a repo exception is raised if a named attribute doesn't
exist."""
with pytest.raises(RepositoryException):
author_repository.filter_collection_by_kwargs(cricket="ball")


@pytest.fixture(name="mock_response")
def fx_mock_response() -> MagicMock:
"""Mock response for returning from mock client requests."""
Expand Down