Skip to content

Commit

Permalink
Merge pull request #680 from python-openapi/feature/starlette-middleware
Browse files Browse the repository at this point in the history
Starlette middleware
  • Loading branch information
p1c2u authored Nov 3, 2023
2 parents 09f065b + 2076707 commit ba52208
Show file tree
Hide file tree
Showing 12 changed files with 719 additions and 28 deletions.
51 changes: 51 additions & 0 deletions docs/integrations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,57 @@ Starlette

This section describes integration with `Starlette <https://www.starlette.io>`__ ASGI framework.

Middleware
~~~~~~~~~~

Starlette can be integrated by middleware. Add ``StarletteOpenAPIMiddleware`` with ``spec`` to your ``middleware`` list.

.. code-block:: python
:emphasize-lines: 1,6
from openapi_core.contrib.starlette.middlewares import StarletteOpenAPIMiddleware
from starlette.applications import Starlette
from starlette.middleware import Middleware
middleware = [
Middleware(StarletteOpenAPIMiddleware, spec=spec),
]
app = Starlette(
# ...
middleware=middleware,
)
After that you have access to unmarshal result object with all validated request data from endpoint through ``openapi`` key of request's scope directory.

.. code-block:: python
async def get_endpoint(req):
# get parameters object with path, query, cookies and headers parameters
validated_params = req.scope["openapi"].parameters
# or specific location parameters
validated_path_params = req.scope["openapi"].parameters.path
# get body
validated_body = req.scope["openapi"].body
# get security data
validated_security = req.scope["openapi"].security
You can skip response validation process: by setting ``response_cls`` to ``None``

.. code-block:: python
:emphasize-lines: 2
middleware = [
Middleware(StarletteOpenAPIMiddleware, spec=spec, response_cls=None),
]
app = Starlette(
# ...
middleware=middleware,
)
Low level
~~~~~~~~~

Expand Down
66 changes: 66 additions & 0 deletions openapi_core/contrib/starlette/handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""OpenAPI core contrib starlette handlers module"""
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Type

from starlette.middleware.base import RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.responses import Response

from openapi_core.templating.media_types.exceptions import MediaTypeNotFound
from openapi_core.templating.paths.exceptions import OperationNotFound
from openapi_core.templating.paths.exceptions import PathNotFound
from openapi_core.templating.paths.exceptions import ServerNotFound
from openapi_core.templating.security.exceptions import SecurityNotFound
from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult


class StarletteOpenAPIErrorsHandler:
OPENAPI_ERROR_STATUS: Dict[Type[BaseException], int] = {
ServerNotFound: 400,
SecurityNotFound: 403,
OperationNotFound: 405,
PathNotFound: 404,
MediaTypeNotFound: 415,
}

def __call__(
self,
errors: Iterable[Exception],
) -> JSONResponse:
data_errors = [self.format_openapi_error(err) for err in errors]
data = {
"errors": data_errors,
}
data_error_max = max(data_errors, key=self.get_error_status)
return JSONResponse(data, status_code=data_error_max["status"])

@classmethod
def format_openapi_error(cls, error: BaseException) -> Dict[str, Any]:
if error.__cause__ is not None:
error = error.__cause__
return {
"title": str(error),
"status": cls.OPENAPI_ERROR_STATUS.get(error.__class__, 400),
"type": str(type(error)),
}

@classmethod
def get_error_status(cls, error: Dict[str, Any]) -> str:
return str(error["status"])


class StarletteOpenAPIValidRequestHandler:
def __init__(self, request: Request, call_next: RequestResponseEndpoint):
self.request = request
self.call_next = call_next

async def __call__(
self, request_unmarshal_result: RequestUnmarshalResult
) -> Response:
self.request.scope["openapi"] = request_unmarshal_result
return await self.call_next(self.request)
75 changes: 75 additions & 0 deletions openapi_core/contrib/starlette/middlewares.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""OpenAPI core contrib starlette middlewares module"""
from typing import Callable

from aioitertools.builtins import list as alist
from aioitertools.itertools import tee as atee
from jsonschema_path import SchemaPath
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.base import RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
from starlette.responses import StreamingResponse
from starlette.types import ASGIApp

from openapi_core.contrib.starlette.handlers import (
StarletteOpenAPIErrorsHandler,
)
from openapi_core.contrib.starlette.handlers import (
StarletteOpenAPIValidRequestHandler,
)
from openapi_core.contrib.starlette.requests import StarletteOpenAPIRequest
from openapi_core.contrib.starlette.responses import StarletteOpenAPIResponse
from openapi_core.unmarshalling.processors import AsyncUnmarshallingProcessor


class StarletteOpenAPIMiddleware(
BaseHTTPMiddleware, AsyncUnmarshallingProcessor[Request, Response]
):
request_cls = StarletteOpenAPIRequest
response_cls = StarletteOpenAPIResponse
valid_request_handler_cls = StarletteOpenAPIValidRequestHandler
errors_handler = StarletteOpenAPIErrorsHandler()

def __init__(self, app: ASGIApp, spec: SchemaPath):
BaseHTTPMiddleware.__init__(self, app)
AsyncUnmarshallingProcessor.__init__(self, spec)

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
valid_request_handler = self.valid_request_handler_cls(
request, call_next
)
response = await self.handle_request(
request, valid_request_handler, self.errors_handler
)
return await self.handle_response(
request, response, self.errors_handler
)

async def _get_openapi_request(
self, request: Request
) -> StarletteOpenAPIRequest:
body = await request.body()
return self.request_cls(request, body)

async def _get_openapi_response(
self, response: Response
) -> StarletteOpenAPIResponse:
assert self.response_cls is not None
data = None
if isinstance(response, StreamingResponse):
body_iter1, body_iter2 = atee(response.body_iterator)
response.body_iterator = body_iter2
data = b"".join(
[
chunk.encode(response.charset)
if not isinstance(chunk, bytes)
else chunk
async for chunk in body_iter1
]
)
return self.response_cls(response, data=data)

def _validate_response(self) -> bool:
return self.response_cls is not None
12 changes: 3 additions & 9 deletions openapi_core/contrib/starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class StarletteOpenAPIRequest:
def __init__(self, request: Request):
def __init__(self, request: Request, body: Optional[bytes] = None):
if not isinstance(request, Request):
raise TypeError(f"'request' argument is not type of {Request}")
self.request = request
Expand All @@ -19,7 +19,7 @@ def __init__(self, request: Request):
cookie=self.request.cookies,
)

self._get_body = AsyncToSync(self.request.body, force_new_loop=True)
self._body = body

@property
def host_url(self) -> str:
Expand All @@ -35,13 +35,7 @@ def method(self) -> str:

@property
def body(self) -> Optional[bytes]:
body = self._get_body()
if body is None:
return None
if isinstance(body, bytes):
return body
assert isinstance(body, str)
return body.encode("utf-8")
return self._body

@property
def content_type(self) -> str:
Expand Down
4 changes: 4 additions & 0 deletions openapi_core/typing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Awaitable
from typing import Callable
from typing import Iterable
from typing import TypeVar
Expand All @@ -11,3 +12,6 @@

ErrorsHandlerCallable = Callable[[Iterable[Exception]], ResponseType]
ValidRequestHandlerCallable = Callable[[RequestUnmarshalResult], ResponseType]
AsyncValidRequestHandlerCallable = Callable[
[RequestUnmarshalResult], Awaitable[ResponseType]
]
72 changes: 72 additions & 0 deletions openapi_core/unmarshalling/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from openapi_core.protocols import Request
from openapi_core.protocols import Response
from openapi_core.shortcuts import get_classes
from openapi_core.typing import AsyncValidRequestHandlerCallable
from openapi_core.typing import ErrorsHandlerCallable
from openapi_core.typing import RequestType
from openapi_core.typing import ResponseType
Expand Down Expand Up @@ -90,3 +91,74 @@ def handle_response(
if response_unmarshal_result.errors:
return errors_handler(response_unmarshal_result.errors)
return response


class AsyncUnmarshallingProcessor(Generic[RequestType, ResponseType]):
def __init__(
self,
spec: SchemaPath,
request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None,
response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None,
**unmarshaller_kwargs: Any,
):
if (
request_unmarshaller_cls is None
or response_unmarshaller_cls is None
):
classes = get_classes(spec)
if request_unmarshaller_cls is None:
request_unmarshaller_cls = classes.request_unmarshaller_cls
if response_unmarshaller_cls is None:
response_unmarshaller_cls = classes.response_unmarshaller_cls

self.request_processor = RequestUnmarshallingProcessor(
spec,
request_unmarshaller_cls,
**unmarshaller_kwargs,
)
self.response_processor = ResponseUnmarshallingProcessor(
spec,
response_unmarshaller_cls,
**unmarshaller_kwargs,
)

async def _get_openapi_request(self, request: RequestType) -> Request:
raise NotImplementedError

async def _get_openapi_response(self, response: ResponseType) -> Response:
raise NotImplementedError

def _validate_response(self) -> bool:
raise NotImplementedError

async def handle_request(
self,
request: RequestType,
valid_handler: AsyncValidRequestHandlerCallable[ResponseType],
errors_handler: ErrorsHandlerCallable[ResponseType],
) -> ResponseType:
openapi_request = await self._get_openapi_request(request)
request_unmarshal_result = self.request_processor.process(
openapi_request
)
if request_unmarshal_result.errors:
return errors_handler(request_unmarshal_result.errors)
result = await valid_handler(request_unmarshal_result)
return result

async def handle_response(
self,
request: RequestType,
response: ResponseType,
errors_handler: ErrorsHandlerCallable[ResponseType],
) -> ResponseType:
if not self._validate_response():
return response
openapi_request = await self._get_openapi_request(request)
openapi_response = await self._get_openapi_response(response)
response_unmarshal_result = self.response_processor.process(
openapi_request, openapi_response
)
if response_unmarshal_result.errors:
return errors_handler(response_unmarshal_result.errors)
return response
34 changes: 31 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit ba52208

Please sign in to comment.