diff --git a/src/starlite_saqlalchemy/repository/abc.py b/src/starlite_saqlalchemy/repository/abc.py index 598eeb63..7c0597d9 100644 --- a/src/starlite_saqlalchemy/repository/abc.py +++ b/src/starlite_saqlalchemy/repository/abc.py @@ -112,6 +112,20 @@ async def upsert(self, data: T) -> T: RepositoryNotFoundException: If no instance found with same identifier as `data`. """ + @abstractmethod + def filter_collection_by_kwargs(self, **kwargs: Any) -> None: + """Filter the collection by kwargs. + + Has `AND` semantics where multiple kwargs name/value pairs are provided. + + Args: + **kwargs: key/value pairs such that objects remaining in the collection after filtering + have the property that their attribute named `key` has value equal to `value`. + + Raises: + RepositoryException: if a named attribute doesn't exist on `self.model_type`. + """ + @staticmethod def check_not_found(item_or_none: T | None) -> T: """Raise `RepositoryNotFoundException` if `item_or_none` is `None`. diff --git a/src/starlite_saqlalchemy/repository/sqlalchemy.py b/src/starlite_saqlalchemy/repository/sqlalchemy.py index 58f611a0..582f079c 100644 --- a/src/starlite_saqlalchemy/repository/sqlalchemy.py +++ b/src/starlite_saqlalchemy/repository/sqlalchemy.py @@ -206,6 +206,16 @@ async def upsert(self, data: ModelT) -> ModelT: self.session.expunge(instance) return instance + def filter_collection_by_kwargs(self, **kwargs: Any) -> None: + """Filter the collection by kwargs. + + Args: + **kwargs: key/value pairs such that objects remaining in the collection after filtering + have the property that their attribute named `key` has value equal to `value`. + """ + with wrap_sqlalchemy_exception(): + self._select.filter_by(**kwargs) + @classmethod async def check_health(cls, session: AsyncSession) -> bool: """Perform a health check on the database. diff --git a/src/starlite_saqlalchemy/testing.py b/src/starlite_saqlalchemy/testing.py index 3ddace7f..29d06531 100644 --- a/src/starlite_saqlalchemy/testing.py +++ b/src/starlite_saqlalchemy/testing.py @@ -13,7 +13,10 @@ from starlite_saqlalchemy.db import orm from starlite_saqlalchemy.repository.abc import AbstractRepository -from starlite_saqlalchemy.repository.exceptions import RepositoryConflictException +from starlite_saqlalchemy.repository.exceptions import ( + RepositoryConflictException, + RepositoryException, +) if TYPE_CHECKING: from collections.abc import Callable, Hashable, Iterable, MutableMapping, Sequence @@ -162,6 +165,22 @@ async def upsert(self, data: ModelT) -> ModelT: return await self.update(data) return await self.add(data, allow_id=True) + def filter_collection_by_kwargs(self, **kwargs: Any) -> None: + """Filter the collection by kwargs. + + Args: + **kwargs: key/value pairs such that objects remaining in the collection after filtering + have the property that their attribute named `key` has value equal to `value`. + """ + new_collection: dict[Hashable, ModelT] = {} + for item in self.collection.values(): + try: + if all(getattr(item, name) == value for name, value in kwargs.items()): + new_collection[item.id] = item + except AttributeError as orig: + raise RepositoryException from orig + self.collection = new_collection + @classmethod def seed_collection(cls, instances: Iterable[ModelT]) -> None: """Seed the collection for repository type. diff --git a/tests/integration/repository/test_sqlalchemy_repository.py b/tests/integration/repository/test_sqlalchemy_repository.py new file mode 100644 index 00000000..b3f5abe8 --- /dev/null +++ b/tests/integration/repository/test_sqlalchemy_repository.py @@ -0,0 +1,20 @@ +import pytest +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker + +from starlite_saqlalchemy.repository.exceptions import RepositoryException +from tests.utils.domain import authors + + +@pytest.fixture(name="session") +def fx_session(engine: AsyncEngine) -> AsyncSession: + return async_sessionmaker(bind=engine)() + + +@pytest.fixture(name="repo") +def fx_repo(session: AsyncSession) -> authors.Repository: + return authors.Repository(session=session) + + +def test_filter_by_kwargs_with_incorrect_attribute_name(repo: authors.Repository) -> None: + with pytest.raises(RepositoryException): + repo.filter_collection_by_kwargs(whoops="silly me") diff --git a/tests/unit/repository/test_sqlalchemy.py b/tests/unit/repository/test_sqlalchemy.py index 70af2cd7..0de02e9b 100644 --- a/tests/unit/repository/test_sqlalchemy.py +++ b/tests/unit/repository/test_sqlalchemy.py @@ -7,7 +7,7 @@ from unittest.mock import AsyncMock, MagicMock, call import pytest -from sqlalchemy.exc import IntegrityError, SQLAlchemyError +from sqlalchemy.exc import IntegrityError, InvalidRequestError, SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession from starlite_saqlalchemy.repository.exceptions import ( @@ -229,3 +229,21 @@ def test__filter_on_datetime_field( field_mock.__gt__ = field_mock.__lt__ = lambda self, other: True mock_repo.model_type.updated = field_mock mock_repo._filter_on_datetime_field("updated", before, after) + + +def test_filter_collection_by_kwargs(mock_repo: SQLAlchemyRepository) -> None: + """Test `filter_by()` called with kwargs.""" + mock_repo.filter_collection_by_kwargs(a=1, b=2) + mock_repo._select.filter_by.assert_called_once_with(a=1, b=2) + + +def test_filter_collection_by_kwargs_raises_repository_exception_for_attribute_error( + mock_repo: SQLAlchemyRepository, +) -> None: + """Test that we raise a repository exception if an attribute name is + incorrect.""" + mock_repo._select.filter_by = MagicMock( # type:ignore[assignment] + side_effect=InvalidRequestError, + ) + with pytest.raises(RepositoryException): + mock_repo.filter_collection_by_kwargs(a=1) diff --git a/tests/unit/test_testing.py b/tests/unit/test_testing.py index 008c997d..f29669e6 100644 --- a/tests/unit/test_testing.py +++ b/tests/unit/test_testing.py @@ -9,7 +9,10 @@ from starlite.status_codes import HTTP_200_OK, HTTP_404_NOT_FOUND from starlite_saqlalchemy import testing -from starlite_saqlalchemy.repository.exceptions import RepositoryConflictException +from starlite_saqlalchemy.repository.exceptions import ( + RepositoryConflictException, + RepositoryException, +) from tests.utils.domain.authors import Author from tests.utils.domain.authors import Service as AuthorService from tests.utils.domain.books import Book @@ -54,6 +57,33 @@ def test_generic_mock_repository_clear_collection( assert not author_repository_type.collection +def test_generic_mock_repository_filter_collection_by_kwargs( + author_repository: testing.GenericMockRepository[Author], +) -> None: + """Test filtering the repository collection by kwargs.""" + author_repository.filter_collection_by_kwargs(name="Leo Tolstoy") + assert len(author_repository.collection) == 1 + assert list(author_repository.collection.values())[0].name == "Leo Tolstoy" + + +def test_generic_mock_repository_filter_collection_by_kwargs_and_semantics( + author_repository: testing.GenericMockRepository[Author], +) -> None: + """Test that filtering by kwargs has `AND` semantics when multiple kwargs, + not `OR`.""" + author_repository.filter_collection_by_kwargs(name="Agatha Christie", dob="1828-09-09") + assert len(author_repository.collection) == 0 + + +def test_generic_mock_repository_raises_repository_exception_if_named_attribute_doesnt_exist( + author_repository: testing.GenericMockRepository[Author], +) -> None: + """Test that a repo exception is raised if a named attribute doesn't + exist.""" + with pytest.raises(RepositoryException): + author_repository.filter_collection_by_kwargs(cricket="ball") + + @pytest.fixture(name="mock_response") def fx_mock_response() -> MagicMock: """Mock response for returning from mock client requests."""