From f49607e3062483419afc91ed74670092951ceb92 Mon Sep 17 00:00:00 2001 From: p1c2u Date: Tue, 26 Sep 2023 13:22:59 +0000 Subject: [PATCH] Starlette middleware --- docs/integrations.rst | 34 ++ openapi_core/contrib/starlette/handlers.py | 66 ++++ openapi_core/contrib/starlette/middlewares.py | 75 +++++ openapi_core/contrib/starlette/requests.py | 12 +- openapi_core/typing.py | 4 + openapi_core/unmarshalling/processors.py | 72 ++++ poetry.lock | 34 +- pyproject.toml | 5 +- .../data/v3.0/starletteproject/__main__.py | 18 + .../v3.0/starletteproject/pets/endpoints.py | 92 ++++- .../starlette/test_starlette_project.py | 313 +++++++++++++++++- .../starlette/test_starlette_validation.py | 5 +- 12 files changed, 702 insertions(+), 28 deletions(-) create mode 100644 openapi_core/contrib/starlette/handlers.py create mode 100644 openapi_core/contrib/starlette/middlewares.py diff --git a/docs/integrations.rst b/docs/integrations.rst index e1f8bcb7..d6c7cc5f 100644 --- a/docs/integrations.rst +++ b/docs/integrations.rst @@ -402,6 +402,40 @@ Starlette This section describes integration with `Starlette `__ ASGI framework. +Middleware +~~~~~~~~~~ + +Starlette can be integrated by middleware. Add ``StarletteOpenAPIMiddleware`` with ``spec`` to your ``middleware`` list. + +.. code-block:: python + :emphasize-lines: 1,6 + + from openapi_core.contrib.starlette.middlewares import StarletteOpenAPIMiddleware + from starlette.applications import Starlette + from starlette.middleware import Middleware + + middleware = [ + Middleware(StarletteOpenAPIMiddleware, spec=spec), + ] + + app = Starlette(middleware=middleware) + +After that you have access to unmarshal result object with all validated request data from endpoint through ``openapi`` key of request's scope directory. + +.. code-block:: python + + async def get_endpoint(req): + # get parameters object with path, query, cookies and headers parameters + validated_params = req.scope["openapi"].parameters + # or specific location parameters + validated_path_params = req.scope["openapi"].parameters.path + + # get body + validated_body = req.scope["openapi"].body + + # get security data + validated_security = req.scope["openapi"].security + Low level ~~~~~~~~~ diff --git a/openapi_core/contrib/starlette/handlers.py b/openapi_core/contrib/starlette/handlers.py new file mode 100644 index 00000000..fbd16cca --- /dev/null +++ b/openapi_core/contrib/starlette/handlers.py @@ -0,0 +1,66 @@ +"""OpenAPI core contrib starlette handlers module""" +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterable +from typing import Optional +from typing import Type + +from starlette.middleware.base import RequestResponseEndpoint +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.responses import Response + +from openapi_core.templating.media_types.exceptions import MediaTypeNotFound +from openapi_core.templating.paths.exceptions import OperationNotFound +from openapi_core.templating.paths.exceptions import PathNotFound +from openapi_core.templating.paths.exceptions import ServerNotFound +from openapi_core.templating.security.exceptions import SecurityNotFound +from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult + + +class StarletteOpenAPIErrorsHandler: + OPENAPI_ERROR_STATUS: Dict[Type[BaseException], int] = { + ServerNotFound: 400, + SecurityNotFound: 403, + OperationNotFound: 405, + PathNotFound: 404, + MediaTypeNotFound: 415, + } + + def __call__( + self, + errors: Iterable[Exception], + ) -> JSONResponse: + data_errors = [self.format_openapi_error(err) for err in errors] + data = { + "errors": data_errors, + } + data_error_max = max(data_errors, key=self.get_error_status) + return JSONResponse(data, status_code=data_error_max["status"]) + + @classmethod + def format_openapi_error(cls, error: BaseException) -> Dict[str, Any]: + if error.__cause__ is not None: + error = error.__cause__ + return { + "title": str(error), + "status": cls.OPENAPI_ERROR_STATUS.get(error.__class__, 400), + "type": str(type(error)), + } + + @classmethod + def get_error_status(cls, error: Dict[str, Any]) -> str: + return str(error["status"]) + + +class StarletteOpenAPIValidRequestHandler: + def __init__(self, request: Request, call_next: RequestResponseEndpoint): + self.request = request + self.call_next = call_next + + async def __call__( + self, request_unmarshal_result: RequestUnmarshalResult + ) -> Response: + self.request.scope["openapi"] = request_unmarshal_result + return await self.call_next(self.request) diff --git a/openapi_core/contrib/starlette/middlewares.py b/openapi_core/contrib/starlette/middlewares.py new file mode 100644 index 00000000..f9bfb779 --- /dev/null +++ b/openapi_core/contrib/starlette/middlewares.py @@ -0,0 +1,75 @@ +"""OpenAPI core contrib starlette middlewares module""" +from typing import Callable + +from aioitertools.builtins import list as alist +from aioitertools.itertools import tee as atee +from jsonschema_path import SchemaPath +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.base import RequestResponseEndpoint +from starlette.requests import Request +from starlette.responses import Response +from starlette.responses import StreamingResponse +from starlette.types import ASGIApp + +from openapi_core.contrib.starlette.handlers import ( + StarletteOpenAPIErrorsHandler, +) +from openapi_core.contrib.starlette.handlers import ( + StarletteOpenAPIValidRequestHandler, +) +from openapi_core.contrib.starlette.requests import StarletteOpenAPIRequest +from openapi_core.contrib.starlette.responses import StarletteOpenAPIResponse +from openapi_core.unmarshalling.processors import AsyncUnmarshallingProcessor + + +class StarletteOpenAPIMiddleware( + BaseHTTPMiddleware, AsyncUnmarshallingProcessor[Request, Response] +): + request_cls = StarletteOpenAPIRequest + response_cls = StarletteOpenAPIResponse + valid_request_handler_cls = StarletteOpenAPIValidRequestHandler + errors_handler = StarletteOpenAPIErrorsHandler() + + def __init__(self, app: ASGIApp, spec: SchemaPath): + BaseHTTPMiddleware.__init__(self, app) + AsyncUnmarshallingProcessor.__init__(self, spec) + + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + valid_request_handler = self.valid_request_handler_cls( + request, call_next + ) + response = await self.handle_request( + request, valid_request_handler, self.errors_handler + ) + return await self.handle_response( + request, response, self.errors_handler + ) + + async def _get_openapi_request( + self, request: Request + ) -> StarletteOpenAPIRequest: + body = await request.body() + return self.request_cls(request, body) + + async def _get_openapi_response( + self, response: Response + ) -> StarletteOpenAPIResponse: + assert self.response_cls is not None + data = None + if isinstance(response, StreamingResponse): + body_iter1, body_iter2 = atee(response.body_iterator) + response.body_iterator = body_iter2 + data = b"".join( + [ + chunk.encode(response.charset) + if not isinstance(chunk, bytes) + else chunk + async for chunk in body_iter1 + ] + ) + return self.response_cls(response, data=data) + + def _validate_response(self) -> bool: + return self.response_cls is not None diff --git a/openapi_core/contrib/starlette/requests.py b/openapi_core/contrib/starlette/requests.py index b556fd8f..c6d223ae 100644 --- a/openapi_core/contrib/starlette/requests.py +++ b/openapi_core/contrib/starlette/requests.py @@ -8,7 +8,7 @@ class StarletteOpenAPIRequest: - def __init__(self, request: Request): + def __init__(self, request: Request, body: Optional[bytes] = None): if not isinstance(request, Request): raise TypeError(f"'request' argument is not type of {Request}") self.request = request @@ -19,7 +19,7 @@ def __init__(self, request: Request): cookie=self.request.cookies, ) - self._get_body = AsyncToSync(self.request.body, force_new_loop=True) + self._body = body @property def host_url(self) -> str: @@ -35,13 +35,7 @@ def method(self) -> str: @property def body(self) -> Optional[bytes]: - body = self._get_body() - if body is None: - return None - if isinstance(body, bytes): - return body - assert isinstance(body, str) - return body.encode("utf-8") + return self._body @property def content_type(self) -> str: diff --git a/openapi_core/typing.py b/openapi_core/typing.py index 78b66b24..ed682913 100644 --- a/openapi_core/typing.py +++ b/openapi_core/typing.py @@ -1,3 +1,4 @@ +from typing import Awaitable from typing import Callable from typing import Iterable from typing import TypeVar @@ -11,3 +12,6 @@ ErrorsHandlerCallable = Callable[[Iterable[Exception]], ResponseType] ValidRequestHandlerCallable = Callable[[RequestUnmarshalResult], ResponseType] +AsyncValidRequestHandlerCallable = Callable[ + [RequestUnmarshalResult], Awaitable[ResponseType] +] diff --git a/openapi_core/unmarshalling/processors.py b/openapi_core/unmarshalling/processors.py index f6afed1b..6a1945f9 100644 --- a/openapi_core/unmarshalling/processors.py +++ b/openapi_core/unmarshalling/processors.py @@ -8,6 +8,7 @@ from openapi_core.protocols import Request from openapi_core.protocols import Response from openapi_core.shortcuts import get_classes +from openapi_core.typing import AsyncValidRequestHandlerCallable from openapi_core.typing import ErrorsHandlerCallable from openapi_core.typing import RequestType from openapi_core.typing import ResponseType @@ -90,3 +91,74 @@ def handle_response( if response_unmarshal_result.errors: return errors_handler(response_unmarshal_result.errors) return response + + +class AsyncUnmarshallingProcessor(Generic[RequestType, ResponseType]): + def __init__( + self, + spec: SchemaPath, + request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, + response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, + **unmarshaller_kwargs: Any, + ): + if ( + request_unmarshaller_cls is None + or response_unmarshaller_cls is None + ): + classes = get_classes(spec) + if request_unmarshaller_cls is None: + request_unmarshaller_cls = classes.request_unmarshaller_cls + if response_unmarshaller_cls is None: + response_unmarshaller_cls = classes.response_unmarshaller_cls + + self.request_processor = RequestUnmarshallingProcessor( + spec, + request_unmarshaller_cls, + **unmarshaller_kwargs, + ) + self.response_processor = ResponseUnmarshallingProcessor( + spec, + response_unmarshaller_cls, + **unmarshaller_kwargs, + ) + + async def _get_openapi_request(self, request: RequestType) -> Request: + raise NotImplementedError + + async def _get_openapi_response(self, response: ResponseType) -> Response: + raise NotImplementedError + + def _validate_response(self) -> bool: + raise NotImplementedError + + async def handle_request( + self, + request: RequestType, + valid_handler: AsyncValidRequestHandlerCallable[ResponseType], + errors_handler: ErrorsHandlerCallable[ResponseType], + ) -> ResponseType: + openapi_request = await self._get_openapi_request(request) + request_unmarshal_result = self.request_processor.process( + openapi_request + ) + if request_unmarshal_result.errors: + return errors_handler(request_unmarshal_result.errors) + result = await valid_handler(request_unmarshal_result) + return result + + async def handle_response( + self, + request: RequestType, + response: ResponseType, + errors_handler: ErrorsHandlerCallable[ResponseType], + ) -> ResponseType: + if not self._validate_response(): + return response + openapi_request = await self._get_openapi_request(request) + openapi_response = await self._get_openapi_response(response) + response_unmarshal_result = self.response_processor.process( + openapi_request, openapi_response + ) + if response_unmarshal_result.errors: + return errors_handler(response_unmarshal_result.errors) + return response diff --git a/poetry.lock b/poetry.lock index c82de783..ef9bab6e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -96,6 +96,20 @@ yarl = ">=1.0,<2.0" [package.extras] speedups = ["Brotli", "aiodns", "brotlicffi"] +[[package]] +name = "aioitertools" +version = "0.11.0" +description = "itertools and builtins for AsyncIO and mixed iterables" +optional = true +python-versions = ">=3.6" +files = [ + {file = "aioitertools-0.11.0-py3-none-any.whl", hash = "sha256:04b95e3dab25b449def24d7df809411c10e62aab0cbe31a50ca4e68748c43394"}, + {file = "aioitertools-0.11.0.tar.gz", hash = "sha256:42c68b8dd3a69c2bf7f2233bf7df4bb58b557bca5252ac02ed5187bbc67d6831"}, +] + +[package.dependencies] +typing_extensions = {version = ">=4.0", markers = "python_version < \"3.10\""} + [[package]] name = "aiosignal" version = "1.3.1" @@ -1753,6 +1767,20 @@ files = [ flake8 = ">=4.0" pytest = ">=7.0" +[[package]] +name = "python-multipart" +version = "0.0.6" +description = "A streaming multipart parser for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "python_multipart-0.0.6-py3-none-any.whl", hash = "sha256:ee698bab5ef148b0a760751c261902cd096e57e10558e11aca17646b74ee1c18"}, + {file = "python_multipart-0.0.6.tar.gz", hash = "sha256:e9925a80bb668529f1b67c7fdb0a5dacdd7cbfc6fb0bff3ea443fe22bdd62132"}, +] + +[package.extras] +dev = ["atomicwrites (==1.2.1)", "attrs (==19.2.0)", "coverage (==6.5.0)", "hatch", "invoke (==1.7.3)", "more-itertools (==4.3.0)", "pbr (==4.3.0)", "pluggy (==1.0.0)", "py (==1.11.0)", "pytest (==7.2.0)", "pytest-cov (==4.0.0)", "pytest-timeout (==2.1.0)", "pyyaml (==5.1)"] + [[package]] name = "pytz" version = "2023.3" @@ -2219,7 +2247,7 @@ test = ["pytest", "pytest-cov"] name = "starlette" version = "0.31.1" description = "The little ASGI library that shines." -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "starlette-0.31.1-py3-none-any.whl", hash = "sha256:009fb98ecd551a55017d204f033c58b13abcd4719cb5c41503abbf6d260fde11"}, @@ -2464,9 +2492,9 @@ django = ["django"] falcon = ["falcon"] flask = ["flask"] requests = ["requests"] -starlette = ["starlette"] +starlette = ["aioitertools", "starlette"] [metadata] lock-version = "2.0" python-versions = "^3.8.0" -content-hash = "54f6f9ffda98506f95c651ddab0fccdd02790773341a193638e07b754aa9426d" +content-hash = "740dab592e2a321a845a8d504065b879abb56cfac43f2ee155c593d98f3278ea" diff --git a/pyproject.toml b/pyproject.toml index 4cb536a4..99147afe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,7 @@ jsonschema-path = "^0.3.1" asgiref = "^3.6.0" jsonschema = "^4.18.0" multidict = {version = "^6.0.4", optional = true} +aioitertools = {version = "^0.11.0", optional = true} [tool.poetry.extras] django = ["django"] @@ -83,7 +84,7 @@ falcon = ["falcon"] flask = ["flask"] requests = ["requests"] aiohttp = ["aiohttp", "multidict"] -starlette = ["starlette"] +starlette = ["starlette", "aioitertools"] [tool.poetry.group.dev.dependencies] black = "^23.3.0" @@ -96,7 +97,9 @@ pre-commit = "*" pytest = "^7" pytest-flake8 = "*" pytest-cov = "*" +python-multipart = "*" responses = "*" +starlette = ">=0.26.1,<0.32.0" strict-rfc3339 = "^0.7" webob = "*" mypy = "^1.2" diff --git a/tests/integration/contrib/starlette/data/v3.0/starletteproject/__main__.py b/tests/integration/contrib/starlette/data/v3.0/starletteproject/__main__.py index bf1b0e7a..ee16e9c7 100644 --- a/tests/integration/contrib/starlette/data/v3.0/starletteproject/__main__.py +++ b/tests/integration/contrib/starlette/data/v3.0/starletteproject/__main__.py @@ -1,8 +1,25 @@ from starlette.applications import Starlette +from starlette.middleware import Middleware from starlette.routing import Route +from starletteproject.openapi import spec +from starletteproject.pets.endpoints import pet_detail_endpoint +from starletteproject.pets.endpoints import pet_list_endpoint from starletteproject.pets.endpoints import pet_photo_endpoint +from openapi_core.contrib.starlette.middlewares import ( + StarletteOpenAPIMiddleware, +) + +middleware = [ + Middleware( + StarletteOpenAPIMiddleware, + spec=spec, + ), +] + routes = [ + Route("/v1/pets", pet_list_endpoint, methods=["GET", "POST"]), + Route("/v1/pets/{petId}", pet_detail_endpoint, methods=["GET", "POST"]), Route( "/v1/pets/{petId}/photo", pet_photo_endpoint, methods=["GET", "POST"] ), @@ -10,5 +27,6 @@ app = Starlette( debug=True, + middleware=middleware, routes=routes, ) diff --git a/tests/integration/contrib/starlette/data/v3.0/starletteproject/pets/endpoints.py b/tests/integration/contrib/starlette/data/v3.0/starletteproject/pets/endpoints.py index 99f88ceb..c569cad2 100644 --- a/tests/integration/contrib/starlette/data/v3.0/starletteproject/pets/endpoints.py +++ b/tests/integration/contrib/starlette/data/v3.0/starletteproject/pets/endpoints.py @@ -1,5 +1,6 @@ from base64 import b64decode +from starlette.responses import JSONResponse from starlette.responses import Response from starlette.responses import StreamingResponse from starletteproject.openapi import spec @@ -20,18 +21,85 @@ ) -def pet_photo_endpoint(request): - openapi_request = StarletteOpenAPIRequest(request) - request_unmarshalled = unmarshal_request(openapi_request, spec=spec) +async def pet_list_endpoint(request): + assert request.scope["openapi"] + assert not request.scope["openapi"].errors + if request.method == "GET": + assert request.scope["openapi"].parameters.query == { + "page": 1, + "limit": 12, + "search": "", + } + data = [ + { + "id": 12, + "name": "Cat", + "ears": { + "healthy": True, + }, + }, + ] + response_dict = { + "data": data, + } + headers = { + "X-Rate-Limit": "12", + } + return JSONResponse(response_dict, headers=headers) + elif request.method == "POST": + assert request.scope["openapi"].parameters.cookie == { + "user": 1, + } + assert request.scope["openapi"].parameters.header == { + "api-key": "12345", + } + assert request.scope["openapi"].body.__class__.__name__ == "PetCreate" + assert request.scope["openapi"].body.name in ["Cat", "Bird"] + if request.scope["openapi"].body.name == "Cat": + assert ( + request.scope["openapi"].body.ears.__class__.__name__ == "Ears" + ) + assert request.scope["openapi"].body.ears.healthy is True + if request.scope["openapi"].body.name == "Bird": + assert ( + request.scope["openapi"].body.wings.__class__.__name__ + == "Wings" + ) + assert request.scope["openapi"].body.wings.healthy is True + + headers = { + "X-Rate-Limit": "12", + } + return Response(status_code=201, headers=headers) + + +async def pet_detail_endpoint(request): + assert request.scope["openapi"] + assert not request.scope["openapi"].errors + if request.method == "GET": + assert request.scope["openapi"].parameters.path == { + "petId": 12, + } + data = { + "id": 12, + "name": "Cat", + "ears": { + "healthy": True, + }, + } + response_dict = { + "data": data, + } + headers = { + "X-Rate-Limit": "12", + } + return JSONResponse(response_dict, headers=headers) + + +async def pet_photo_endpoint(request): + body = await request.body() if request.method == "GET": contents = iter([OPENID_LOGO]) - response = StreamingResponse(contents, media_type="image/gif") - openapi_response = StarletteOpenAPIResponse(response, data=OPENID_LOGO) + return StreamingResponse(contents, media_type="image/gif") elif request.method == "POST": - contents = request.body() - response = Response(status_code=201) - openapi_response = StarletteOpenAPIResponse(response) - response_unmarshalled = unmarshal_response( - openapi_request, openapi_response, spec=spec - ) - return response + return Response(status_code=201) diff --git a/tests/integration/contrib/starlette/test_starlette_project.py b/tests/integration/contrib/starlette/test_starlette_project.py index f2cef304..ee7027f7 100644 --- a/tests/integration/contrib/starlette/test_starlette_project.py +++ b/tests/integration/contrib/starlette/test_starlette_project.py @@ -37,7 +37,318 @@ def api_key_encoded(self): return str(api_key_bytes_enc, "utf8") -class TestPetPhotoView(BaseTestPetstore): +class TestPetListEndpoint(BaseTestPetstore): + def test_get_no_required_param(self, client): + headers = { + "Authorization": "Basic testuser", + } + + with pytest.warns(DeprecationWarning): + response = client.get("/v1/pets", headers=headers) + + expected_data = { + "errors": [ + { + "type": ( + "" + ), + "status": 400, + "title": "Missing required query parameter: limit", + } + ] + } + assert response.status_code == 400 + assert response.json() == expected_data + + def test_get_valid(self, client): + data_json = { + "limit": 12, + } + headers = { + "Authorization": "Basic testuser", + } + + with pytest.warns(DeprecationWarning): + response = client.get( + "/v1/pets", + params=data_json, + headers=headers, + ) + + expected_data = { + "data": [ + { + "id": 12, + "name": "Cat", + "ears": { + "healthy": True, + }, + }, + ], + } + assert response.status_code == 200 + assert response.json() == expected_data + + def test_post_server_invalid(self, client): + response = client.post("/v1/pets") + + expected_data = { + "errors": [ + { + "type": ( + "" + ), + "status": 400, + "title": ( + "Server not found for " + "http://petstore.swagger.io/v1/pets" + ), + } + ] + } + assert response.status_code == 400 + assert response.json() == expected_data + + def test_post_required_header_param_missing(self, client): + client.cookies.set("user", "1") + pet_name = "Cat" + pet_tag = "cats" + pet_street = "Piekna" + pet_city = "Warsaw" + pet_healthy = False + data_json = { + "name": pet_name, + "tag": pet_tag, + "position": 2, + "address": { + "street": pet_street, + "city": pet_city, + }, + "healthy": pet_healthy, + "wings": { + "healthy": pet_healthy, + }, + } + content_type = "application/json" + headers = { + "Authorization": "Basic testuser", + "Content-Type": content_type, + } + response = client.post( + "https://staging.gigantic-server.com/v1/pets", + json=data_json, + headers=headers, + ) + + expected_data = { + "errors": [ + { + "type": ( + "" + ), + "status": 400, + "title": "Missing required header parameter: api-key", + } + ] + } + assert response.status_code == 400 + assert response.json() == expected_data + + def test_post_media_type_invalid(self, client): + client.cookies.set("user", "1") + data = "data" + content_type = "text/html" + headers = { + "Authorization": "Basic testuser", + "Content-Type": content_type, + "Api-Key": self.api_key_encoded, + } + response = client.post( + "https://staging.gigantic-server.com/v1/pets", + data=data, + headers=headers, + ) + + expected_data = { + "errors": [ + { + "type": ( + "" + ), + "status": 415, + "title": ( + "Content for the following mimetype not found: " + "text/html. " + "Valid mimetypes: ['application/json', 'application/x-www-form-urlencoded', 'text/plain']" + ), + } + ] + } + assert response.status_code == 415 + assert response.json() == expected_data + + def test_post_required_cookie_param_missing(self, client): + data_json = { + "id": 12, + "name": "Cat", + "ears": { + "healthy": True, + }, + } + content_type = "application/json" + headers = { + "Authorization": "Basic testuser", + "Content-Type": content_type, + "Api-Key": self.api_key_encoded, + } + response = client.post( + "https://staging.gigantic-server.com/v1/pets", + json=data_json, + headers=headers, + ) + + expected_data = { + "errors": [ + { + "type": ( + "" + ), + "status": 400, + "title": "Missing required cookie parameter: user", + } + ] + } + assert response.status_code == 400 + assert response.json() == expected_data + + @pytest.mark.parametrize( + "data_json", + [ + { + "id": 12, + "name": "Cat", + "ears": { + "healthy": True, + }, + }, + { + "id": 12, + "name": "Bird", + "wings": { + "healthy": True, + }, + }, + ], + ) + def test_post_valid(self, client, data_json): + client.cookies.set("user", "1") + content_type = "application/json" + headers = { + "Authorization": "Basic testuser", + "Content-Type": content_type, + "Api-Key": self.api_key_encoded, + } + response = client.post( + "https://staging.gigantic-server.com/v1/pets", + json=data_json, + headers=headers, + ) + + assert response.status_code == 201 + assert not response.content + + +class TestPetDetailEndpoint(BaseTestPetstore): + def test_get_server_invalid(self, client): + response = client.get("http://testserver/v1/pets/12") + + expected_data = { + "errors": [ + { + "type": ( + "" + ), + "status": 400, + "title": ( + "Server not found for " "http://testserver/v1/pets/12" + ), + } + ] + } + assert response.status_code == 400 + assert response.json() == expected_data + + def test_get_unauthorized(self, client): + response = client.get("/v1/pets/12") + + expected_data = { + "errors": [ + { + "type": ( + "" + ), + "status": 403, + "title": ( + "Security not found. Schemes not valid for any " + "requirement: [['petstore_auth']]" + ), + } + ] + } + assert response.status_code == 403 + assert response.json() == expected_data + + def test_delete_method_invalid(self, client): + headers = { + "Authorization": "Basic testuser", + } + response = client.delete("/v1/pets/12", headers=headers) + + expected_data = { + "errors": [ + { + "type": ( + "" + ), + "status": 405, + "title": ( + "Operation delete not found for " + "http://petstore.swagger.io/v1/pets/12" + ), + } + ] + } + assert response.status_code == 405 + assert response.json() == expected_data + + def test_get_valid(self, client): + headers = { + "Authorization": "Basic testuser", + } + response = client.get("/v1/pets/12", headers=headers) + + expected_data = { + "data": { + "id": 12, + "name": "Cat", + "ears": { + "healthy": True, + }, + }, + } + assert response.status_code == 200 + assert response.json() == expected_data + + +class TestPetPhotoEndpoint(BaseTestPetstore): def test_get_valid(self, client, data_gif): headers = { "Authorization": "Basic testuser", diff --git a/tests/integration/contrib/starlette/test_starlette_validation.py b/tests/integration/contrib/starlette/test_starlette_validation.py index 992f4821..09f4a96b 100644 --- a/tests/integration/contrib/starlette/test_starlette_validation.py +++ b/tests/integration/contrib/starlette/test_starlette_validation.py @@ -48,8 +48,9 @@ def client(self, app): def test_request_validator_path_pattern(self, client, spec): response_data = {"data": "data"} - def test_route(request): - openapi_request = StarletteOpenAPIRequest(request) + async def test_route(request): + body = await request.body() + openapi_request = StarletteOpenAPIRequest(request, body) result = unmarshal_request(openapi_request, spec) assert not result.errors return JSONResponse(