Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Starlette middleware #680

Merged
merged 1 commit into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions docs/integrations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,57 @@ Starlette

This section describes integration with `Starlette <https://www.starlette.io>`__ ASGI framework.

Middleware
~~~~~~~~~~

Starlette can be integrated by middleware. Add ``StarletteOpenAPIMiddleware`` with ``spec`` to your ``middleware`` list.

.. code-block:: python
:emphasize-lines: 1,6

from openapi_core.contrib.starlette.middlewares import StarletteOpenAPIMiddleware
from starlette.applications import Starlette
from starlette.middleware import Middleware

middleware = [
Middleware(StarletteOpenAPIMiddleware, spec=spec),
]

app = Starlette(
# ...
middleware=middleware,
)

After that you have access to unmarshal result object with all validated request data from endpoint through ``openapi`` key of request's scope directory.

.. code-block:: python

async def get_endpoint(req):
# get parameters object with path, query, cookies and headers parameters
validated_params = req.scope["openapi"].parameters
# or specific location parameters
validated_path_params = req.scope["openapi"].parameters.path

# get body
validated_body = req.scope["openapi"].body

# get security data
validated_security = req.scope["openapi"].security

You can skip response validation process: by setting ``response_cls`` to ``None``

.. code-block:: python
:emphasize-lines: 2

middleware = [
Middleware(StarletteOpenAPIMiddleware, spec=spec, response_cls=None),
]

app = Starlette(
# ...
middleware=middleware,
)

Low level
~~~~~~~~~

Expand Down
66 changes: 66 additions & 0 deletions openapi_core/contrib/starlette/handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""OpenAPI core contrib starlette handlers module"""
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Type

from starlette.middleware.base import RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.responses import Response

from openapi_core.templating.media_types.exceptions import MediaTypeNotFound
from openapi_core.templating.paths.exceptions import OperationNotFound
from openapi_core.templating.paths.exceptions import PathNotFound
from openapi_core.templating.paths.exceptions import ServerNotFound
from openapi_core.templating.security.exceptions import SecurityNotFound
from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult


class StarletteOpenAPIErrorsHandler:
OPENAPI_ERROR_STATUS: Dict[Type[BaseException], int] = {
ServerNotFound: 400,
SecurityNotFound: 403,
OperationNotFound: 405,
PathNotFound: 404,
MediaTypeNotFound: 415,
}

def __call__(
self,
errors: Iterable[Exception],
) -> JSONResponse:
data_errors = [self.format_openapi_error(err) for err in errors]
data = {
"errors": data_errors,
}
data_error_max = max(data_errors, key=self.get_error_status)
return JSONResponse(data, status_code=data_error_max["status"])

@classmethod
def format_openapi_error(cls, error: BaseException) -> Dict[str, Any]:
if error.__cause__ is not None:
error = error.__cause__
return {
"title": str(error),
"status": cls.OPENAPI_ERROR_STATUS.get(error.__class__, 400),
"type": str(type(error)),
}

@classmethod
def get_error_status(cls, error: Dict[str, Any]) -> str:
return str(error["status"])


class StarletteOpenAPIValidRequestHandler:
def __init__(self, request: Request, call_next: RequestResponseEndpoint):
self.request = request
self.call_next = call_next

async def __call__(
self, request_unmarshal_result: RequestUnmarshalResult
) -> Response:
self.request.scope["openapi"] = request_unmarshal_result
return await self.call_next(self.request)
75 changes: 75 additions & 0 deletions openapi_core/contrib/starlette/middlewares.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""OpenAPI core contrib starlette middlewares module"""
from typing import Callable

from aioitertools.builtins import list as alist
from aioitertools.itertools import tee as atee
from jsonschema_path import SchemaPath
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.base import RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
from starlette.responses import StreamingResponse
from starlette.types import ASGIApp

from openapi_core.contrib.starlette.handlers import (
StarletteOpenAPIErrorsHandler,
)
from openapi_core.contrib.starlette.handlers import (
StarletteOpenAPIValidRequestHandler,
)
from openapi_core.contrib.starlette.requests import StarletteOpenAPIRequest
from openapi_core.contrib.starlette.responses import StarletteOpenAPIResponse
from openapi_core.unmarshalling.processors import AsyncUnmarshallingProcessor


class StarletteOpenAPIMiddleware(
BaseHTTPMiddleware, AsyncUnmarshallingProcessor[Request, Response]
):
request_cls = StarletteOpenAPIRequest
response_cls = StarletteOpenAPIResponse
valid_request_handler_cls = StarletteOpenAPIValidRequestHandler
errors_handler = StarletteOpenAPIErrorsHandler()

def __init__(self, app: ASGIApp, spec: SchemaPath):
BaseHTTPMiddleware.__init__(self, app)
AsyncUnmarshallingProcessor.__init__(self, spec)

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
valid_request_handler = self.valid_request_handler_cls(
request, call_next
)
response = await self.handle_request(
request, valid_request_handler, self.errors_handler
)
return await self.handle_response(
request, response, self.errors_handler
)

async def _get_openapi_request(
self, request: Request
) -> StarletteOpenAPIRequest:
body = await request.body()
return self.request_cls(request, body)

async def _get_openapi_response(
self, response: Response
) -> StarletteOpenAPIResponse:
assert self.response_cls is not None
data = None
if isinstance(response, StreamingResponse):
body_iter1, body_iter2 = atee(response.body_iterator)
response.body_iterator = body_iter2
data = b"".join(
[
chunk.encode(response.charset)
if not isinstance(chunk, bytes)
else chunk
async for chunk in body_iter1
]
)
return self.response_cls(response, data=data)

def _validate_response(self) -> bool:
return self.response_cls is not None
12 changes: 3 additions & 9 deletions openapi_core/contrib/starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class StarletteOpenAPIRequest:
def __init__(self, request: Request):
def __init__(self, request: Request, body: Optional[bytes] = None):
if not isinstance(request, Request):
raise TypeError(f"'request' argument is not type of {Request}")
self.request = request
Expand All @@ -19,7 +19,7 @@ def __init__(self, request: Request):
cookie=self.request.cookies,
)

self._get_body = AsyncToSync(self.request.body, force_new_loop=True)
self._body = body

@property
def host_url(self) -> str:
Expand All @@ -35,13 +35,7 @@ def method(self) -> str:

@property
def body(self) -> Optional[bytes]:
body = self._get_body()
if body is None:
return None
if isinstance(body, bytes):
return body
assert isinstance(body, str)
return body.encode("utf-8")
return self._body

@property
def content_type(self) -> str:
Expand Down
4 changes: 4 additions & 0 deletions openapi_core/typing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Awaitable
from typing import Callable
from typing import Iterable
from typing import TypeVar
Expand All @@ -11,3 +12,6 @@

ErrorsHandlerCallable = Callable[[Iterable[Exception]], ResponseType]
ValidRequestHandlerCallable = Callable[[RequestUnmarshalResult], ResponseType]
AsyncValidRequestHandlerCallable = Callable[
[RequestUnmarshalResult], Awaitable[ResponseType]
]
72 changes: 72 additions & 0 deletions openapi_core/unmarshalling/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from openapi_core.protocols import Request
from openapi_core.protocols import Response
from openapi_core.shortcuts import get_classes
from openapi_core.typing import AsyncValidRequestHandlerCallable
from openapi_core.typing import ErrorsHandlerCallable
from openapi_core.typing import RequestType
from openapi_core.typing import ResponseType
Expand Down Expand Up @@ -90,3 +91,74 @@
if response_unmarshal_result.errors:
return errors_handler(response_unmarshal_result.errors)
return response


class AsyncUnmarshallingProcessor(Generic[RequestType, ResponseType]):
def __init__(
self,
spec: SchemaPath,
request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None,
response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None,
**unmarshaller_kwargs: Any,
):
if (
request_unmarshaller_cls is None
or response_unmarshaller_cls is None
):
classes = get_classes(spec)
if request_unmarshaller_cls is None:
request_unmarshaller_cls = classes.request_unmarshaller_cls
if response_unmarshaller_cls is None:
response_unmarshaller_cls = classes.response_unmarshaller_cls

self.request_processor = RequestUnmarshallingProcessor(
spec,
request_unmarshaller_cls,
**unmarshaller_kwargs,
)
self.response_processor = ResponseUnmarshallingProcessor(
spec,
response_unmarshaller_cls,
**unmarshaller_kwargs,
)

async def _get_openapi_request(self, request: RequestType) -> Request:
raise NotImplementedError

Check warning on line 126 in openapi_core/unmarshalling/processors.py

View check run for this annotation

Codecov / codecov/patch

openapi_core/unmarshalling/processors.py#L126

Added line #L126 was not covered by tests

async def _get_openapi_response(self, response: ResponseType) -> Response:
raise NotImplementedError

Check warning on line 129 in openapi_core/unmarshalling/processors.py

View check run for this annotation

Codecov / codecov/patch

openapi_core/unmarshalling/processors.py#L129

Added line #L129 was not covered by tests

def _validate_response(self) -> bool:
raise NotImplementedError

Check warning on line 132 in openapi_core/unmarshalling/processors.py

View check run for this annotation

Codecov / codecov/patch

openapi_core/unmarshalling/processors.py#L132

Added line #L132 was not covered by tests

async def handle_request(
self,
request: RequestType,
valid_handler: AsyncValidRequestHandlerCallable[ResponseType],
errors_handler: ErrorsHandlerCallable[ResponseType],
) -> ResponseType:
openapi_request = await self._get_openapi_request(request)
request_unmarshal_result = self.request_processor.process(
openapi_request
)
if request_unmarshal_result.errors:
return errors_handler(request_unmarshal_result.errors)
result = await valid_handler(request_unmarshal_result)
return result

async def handle_response(
self,
request: RequestType,
response: ResponseType,
errors_handler: ErrorsHandlerCallable[ResponseType],
) -> ResponseType:
if not self._validate_response():
return response

Check warning on line 156 in openapi_core/unmarshalling/processors.py

View check run for this annotation

Codecov / codecov/patch

openapi_core/unmarshalling/processors.py#L156

Added line #L156 was not covered by tests
openapi_request = await self._get_openapi_request(request)
openapi_response = await self._get_openapi_response(response)
response_unmarshal_result = self.response_processor.process(
openapi_request, openapi_response
)
if response_unmarshal_result.errors:
return errors_handler(response_unmarshal_result.errors)
return response
34 changes: 31 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading