Skip to content
This repository has been archived by the owner on Sep 12, 2023. It is now read-only.

Commit

Permalink
feat(testing): ControllerTest utility. (#152)
Browse files Browse the repository at this point in the history
A utility for testing standard set of controllers in a standard way.

It's a starting point, there are things that will need to be improved
to make it more flexible, e.g.,:
- configurable expected HTTP status for methods
- support for patch routes
- support for testing a subset of http methods
  • Loading branch information
peterschutt authored Dec 4, 2022
1 parent ef04bf9 commit 4cc707b
Show file tree
Hide file tree
Showing 6 changed files with 274 additions and 41 deletions.
3 changes: 2 additions & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[flake8]
exclude = alembic/*
max-line-length = 100
ignore = E,W,B008
ignore = E,W,B008,PT013
type-checking-exempt-modules = from sqlalchemy.orm
per-file-ignores =
examples/dto/*:T201,TC
Expand All @@ -11,6 +11,7 @@ per-file-ignores =
src/starlite_saqlalchemy/repository/filters.py:TC
src/starlite_saqlalchemy/scripts.py:T201
src/starlite_saqlalchemy/settings.py:TC
src/starlite_saqlalchemy/testing.py:SCS108
src/starlite_saqlalchemy/users/controllers.py:TC
tests/*:SCS108,PT013
tests/integration/test_tests.py:TC002,SCS108
Expand Down
103 changes: 101 additions & 2 deletions src/starlite_saqlalchemy/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,25 @@
"""
from __future__ import annotations

import random
from datetime import datetime
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from uuid import uuid4

from starlite.status_codes import HTTP_200_OK, HTTP_201_CREATED

from starlite_saqlalchemy.db import orm
from starlite_saqlalchemy.repository.abc import AbstractRepository
from starlite_saqlalchemy.repository.exceptions import RepositoryConflictException

if TYPE_CHECKING:
from collections.abc import Callable, Hashable, Iterable, MutableMapping
from collections.abc import Callable, Hashable, Iterable, MutableMapping, Sequence

from pytest import MonkeyPatch
from starlite.testing import TestClient

from starlite_saqlalchemy.repository.types import FilterTypes
from starlite_saqlalchemy.service import Service

ModelT = TypeVar("ModelT", bound=orm.Base)
MockRepoT = TypeVar("MockRepoT", bound="GenericMockRepository")
Expand All @@ -27,7 +34,7 @@ class GenericMockRepository(AbstractRepository[ModelT], Generic[ModelT]):
Uses a `dict` for storage.
"""

collection: MutableMapping[Hashable, ModelT] = {}
collection: MutableMapping[Hashable, ModelT]
model_type: type[ModelT]

def __init__(self, id_factory: Callable[[], Any] = uuid4, **_: Any) -> None:
Expand Down Expand Up @@ -169,3 +176,95 @@ def seed_collection(cls, instances: Iterable[ModelT]) -> None:
def clear_collection(cls) -> None:
"""Empty the collection for repository type."""
cls.collection = {}


class ControllerTest:
"""Standard controller testing utility."""

def __init__(
self,
client: TestClient,
base_path: str,
collection: Sequence[orm.Base],
raw_collection: Sequence[dict[str, Any]],
service_type: type[Service],
monkeypatch: MonkeyPatch,
collection_filters: dict[str, Any] | None = None,
) -> None:
"""Perform standard tests of controllers.
Args:
client: Test client instance.
base_path: Path for POST and collection GET requests.
collection: Collection of domain objects.
raw_collection: Collection of raw representations of domain objects.
service_type: The domain Service object type.
monkeypatch: Pytest's monkeypatch.
collection_filters: Collection filters for GET collection request.
"""
self.client = client
self.base_path = base_path
self.collection = collection
self.raw_collection = raw_collection
self.service_type = service_type
self.monkeypatch = monkeypatch
self.collection_filters = collection_filters

def _get_random_member(self) -> Any:
return random.choice(self.collection)

def _get_raw_for_member(self, member: Any) -> dict[str, Any]:
return [item for item in self.raw_collection if item["id"] == str(member.id)][0]

def test_get_collection(self, with_filters: bool = False) -> None:
"""Test collection endpoint get request."""

async def _list(*_: Any, **__: Any) -> list[Any]:
return list(self.collection)

self.monkeypatch.setattr(self.service_type, "list", _list)

resp = self.client.get(
self.base_path, params=self.collection_filters if with_filters else None
)

assert resp.status_code == HTTP_200_OK
assert resp.json() == self.raw_collection

def test_member_request(self, method: str, service_method: str, exp_status: int) -> None:
"""Test member endpoint request."""
member = self._get_random_member()
raw = self._get_raw_for_member(member)

async def _method(*_: Any, **__: Any) -> Any:
return member

self.monkeypatch.setattr(self.service_type, service_method, _method)

if method.lower() == "post":
url = self.base_path
else:
url = f"{self.base_path}/{member.id}"

request_kw: dict[str, Any] = {}
if method.lower() in ("put", "post"):
request_kw["json"] = raw

resp = self.client.request(method, url, **request_kw)

assert resp.status_code == exp_status
assert resp.json() == raw

def run(self) -> None:
"""Run the tests."""
# test the collection route with and without filters for branch coverage.
self.test_get_collection()
if self.collection_filters:
self.test_get_collection(with_filters=True)
for method, service_method, status in [
("GET", "get", HTTP_200_OK),
("PUT", "update", HTTP_200_OK),
("POST", "create", HTTP_201_CREATED),
("DELETE", "delete", HTTP_200_OK),
]:
self.test_member_request(method, service_method, status)
42 changes: 20 additions & 22 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

import importlib
import sys
from datetime import date, datetime
from typing import TYPE_CHECKING, TypeVar
from uuid import UUID, uuid4
from uuid import uuid4

import pytest
from starlite import Starlite
Expand All @@ -14,8 +13,7 @@

import starlite_saqlalchemy
from starlite_saqlalchemy import ConfigureApp, log
from tests.utils.domain.authors import Author
from tests.utils.domain.books import Book
from tests.utils.domain import authors, books

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -65,47 +63,47 @@ def fx_raw_authors() -> list[dict[str, Any]]:

return [
{
"id": UUID("97108ac1-ffcb-411d-8b1e-d9183399f63b"),
"id": "97108ac1-ffcb-411d-8b1e-d9183399f63b",
"name": "Agatha Christie",
"dob": date(1890, 9, 15),
"created": datetime.min,
"updated": datetime.min,
"dob": "1890-09-15",
"created": "0001-01-01T00:00:00",
"updated": "0001-01-01T00:00:00",
},
{
"id": UUID("5ef29f3c-3560-4d15-ba6b-a2e5c721e4d2"),
"id": "5ef29f3c-3560-4d15-ba6b-a2e5c721e4d2",
"name": "Leo Tolstoy",
"dob": date(1828, 9, 9),
"created": datetime.min,
"updated": datetime.min,
"dob": "1828-09-09",
"created": "0001-01-01T00:00:00",
"updated": "0001-01-01T00:00:00",
},
]


@pytest.fixture(name="authors")
def fx_authors(raw_authors: list[dict[str, Any]]) -> list[Author]:
def fx_authors(raw_authors: list[dict[str, Any]]) -> list[authors.Author]:
"""Collection of parsed Author models."""
return [Author(**raw) for raw in raw_authors]
return [authors.ReadDTO(**raw).to_mapped() for raw in raw_authors]


@pytest.fixture(name="raw_books")
def fx_raw_books() -> list[dict[str, Any]]:
def fx_raw_books(raw_authors: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Unstructured book representations."""
return [
{
"id": UUID("f34545b9-663c-4fce-915d-dd1ae9cea42a"),
"id": "f34545b9-663c-4fce-915d-dd1ae9cea42a",
"title": "Murder on the Orient Express",
"author_id": UUID("97108ac1-ffcb-411d-8b1e-d9183399f63b"),
"created": datetime.min,
"updated": datetime.min,
"author_id": "97108ac1-ffcb-411d-8b1e-d9183399f63b",
"author": raw_authors[0],
"created": "0001-01-01T00:00:00",
"updated": "0001-01-01T00:00:00",
},
]


@pytest.fixture(name="books")
def fx_books(raw_books: list[dict[str, Any]], authors: list[Author]) -> list[Book]:
def fx_books(raw_books: list[dict[str, Any]]) -> list[books.Book]:
"""Collection of parsed Book models."""
author_id_map = {author.id: author for author in authors}
return [Book(**raw, author=author_id_map[raw["author_id"]]) for raw in raw_books]
return [books.ReadDTO(**raw).to_mapped() for raw in raw_books]


@pytest.fixture(name="create_module")
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from pytest_docker.plugin import Services # type:ignore[import]
from starlite import Starlite

from tests.utils.domain.authors import Author


here = Path(__file__).parent

Expand Down Expand Up @@ -165,9 +167,7 @@ async def engine(docker_ip: str) -> AsyncEngine:


@pytest.fixture(autouse=True)
async def _seed_db(
engine: AsyncEngine, raw_authors: list[dict[str, Any]]
) -> abc.AsyncIterator[None]:
async def _seed_db(engine: AsyncEngine, authors: list[Author]) -> abc.AsyncIterator[None]:
"""Populate test database with.
Args:
Expand All @@ -179,7 +179,7 @@ async def _seed_db(
async with engine.begin() as conn:
await conn.run_sync(metadata.create_all)
async with engine.begin() as conn:
await conn.execute(author_table.insert(), raw_authors)
await conn.execute(author_table.insert(), [vars(item) for item in authors])
yield
async with engine.begin() as conn:
await conn.run_sync(metadata.drop_all)
Expand Down
15 changes: 3 additions & 12 deletions tests/unit/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from unittest.mock import AsyncMock
from uuid import uuid4

import orjson
import pytest

from starlite_saqlalchemy import db, service, worker
Expand Down Expand Up @@ -95,17 +94,9 @@ async def test_make_service_callback(
{},
service_type_id="tests.utils.domain.authors.Service",
service_method_name="receive_callback",
raw_obj=orjson.loads(orjson.dumps(raw_authors[0], default=str)),
)
recv_cb_mock.assert_called_once_with(
raw_obj={
"id": "97108ac1-ffcb-411d-8b1e-d9183399f63b",
"name": "Agatha Christie",
"dob": "1890-09-15",
"created": "0001-01-01T00:00:00",
"updated": "0001-01-01T00:00:00",
},
raw_obj=raw_authors[0],
)
recv_cb_mock.assert_called_once_with(raw_obj=raw_authors[0])


async def test_make_service_callback_raises_runtime_error(
Expand All @@ -117,7 +108,7 @@ async def test_make_service_callback_raises_runtime_error(
{},
service_type_id="tests.utils.domain.LSKDFJ",
service_method_name="receive_callback",
raw_obj=orjson.loads(orjson.dumps(raw_authors[0], default=str)),
raw_obj=raw_authors[0],
)


Expand Down
Loading

0 comments on commit 4cc707b

Please sign in to comment.