-
-
Notifications
You must be signed in to change notification settings - Fork 132
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #680 from python-openapi/feature/starlette-middleware
Starlette middleware
- Loading branch information
Showing
12 changed files
with
719 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
"""OpenAPI core contrib starlette handlers module""" | ||
from typing import Any | ||
from typing import Callable | ||
from typing import Dict | ||
from typing import Iterable | ||
from typing import Optional | ||
from typing import Type | ||
|
||
from starlette.middleware.base import RequestResponseEndpoint | ||
from starlette.requests import Request | ||
from starlette.responses import JSONResponse | ||
from starlette.responses import Response | ||
|
||
from openapi_core.templating.media_types.exceptions import MediaTypeNotFound | ||
from openapi_core.templating.paths.exceptions import OperationNotFound | ||
from openapi_core.templating.paths.exceptions import PathNotFound | ||
from openapi_core.templating.paths.exceptions import ServerNotFound | ||
from openapi_core.templating.security.exceptions import SecurityNotFound | ||
from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult | ||
|
||
|
||
class StarletteOpenAPIErrorsHandler: | ||
OPENAPI_ERROR_STATUS: Dict[Type[BaseException], int] = { | ||
ServerNotFound: 400, | ||
SecurityNotFound: 403, | ||
OperationNotFound: 405, | ||
PathNotFound: 404, | ||
MediaTypeNotFound: 415, | ||
} | ||
|
||
def __call__( | ||
self, | ||
errors: Iterable[Exception], | ||
) -> JSONResponse: | ||
data_errors = [self.format_openapi_error(err) for err in errors] | ||
data = { | ||
"errors": data_errors, | ||
} | ||
data_error_max = max(data_errors, key=self.get_error_status) | ||
return JSONResponse(data, status_code=data_error_max["status"]) | ||
|
||
@classmethod | ||
def format_openapi_error(cls, error: BaseException) -> Dict[str, Any]: | ||
if error.__cause__ is not None: | ||
error = error.__cause__ | ||
return { | ||
"title": str(error), | ||
"status": cls.OPENAPI_ERROR_STATUS.get(error.__class__, 400), | ||
"type": str(type(error)), | ||
} | ||
|
||
@classmethod | ||
def get_error_status(cls, error: Dict[str, Any]) -> str: | ||
return str(error["status"]) | ||
|
||
|
||
class StarletteOpenAPIValidRequestHandler: | ||
def __init__(self, request: Request, call_next: RequestResponseEndpoint): | ||
self.request = request | ||
self.call_next = call_next | ||
|
||
async def __call__( | ||
self, request_unmarshal_result: RequestUnmarshalResult | ||
) -> Response: | ||
self.request.scope["openapi"] = request_unmarshal_result | ||
return await self.call_next(self.request) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
"""OpenAPI core contrib starlette middlewares module""" | ||
from typing import Callable | ||
|
||
from aioitertools.builtins import list as alist | ||
from aioitertools.itertools import tee as atee | ||
from jsonschema_path import SchemaPath | ||
from starlette.middleware.base import BaseHTTPMiddleware | ||
from starlette.middleware.base import RequestResponseEndpoint | ||
from starlette.requests import Request | ||
from starlette.responses import Response | ||
from starlette.responses import StreamingResponse | ||
from starlette.types import ASGIApp | ||
|
||
from openapi_core.contrib.starlette.handlers import ( | ||
StarletteOpenAPIErrorsHandler, | ||
) | ||
from openapi_core.contrib.starlette.handlers import ( | ||
StarletteOpenAPIValidRequestHandler, | ||
) | ||
from openapi_core.contrib.starlette.requests import StarletteOpenAPIRequest | ||
from openapi_core.contrib.starlette.responses import StarletteOpenAPIResponse | ||
from openapi_core.unmarshalling.processors import AsyncUnmarshallingProcessor | ||
|
||
|
||
class StarletteOpenAPIMiddleware( | ||
BaseHTTPMiddleware, AsyncUnmarshallingProcessor[Request, Response] | ||
): | ||
request_cls = StarletteOpenAPIRequest | ||
response_cls = StarletteOpenAPIResponse | ||
valid_request_handler_cls = StarletteOpenAPIValidRequestHandler | ||
errors_handler = StarletteOpenAPIErrorsHandler() | ||
|
||
def __init__(self, app: ASGIApp, spec: SchemaPath): | ||
BaseHTTPMiddleware.__init__(self, app) | ||
AsyncUnmarshallingProcessor.__init__(self, spec) | ||
|
||
async def dispatch( | ||
self, request: Request, call_next: RequestResponseEndpoint | ||
) -> Response: | ||
valid_request_handler = self.valid_request_handler_cls( | ||
request, call_next | ||
) | ||
response = await self.handle_request( | ||
request, valid_request_handler, self.errors_handler | ||
) | ||
return await self.handle_response( | ||
request, response, self.errors_handler | ||
) | ||
|
||
async def _get_openapi_request( | ||
self, request: Request | ||
) -> StarletteOpenAPIRequest: | ||
body = await request.body() | ||
return self.request_cls(request, body) | ||
|
||
async def _get_openapi_response( | ||
self, response: Response | ||
) -> StarletteOpenAPIResponse: | ||
assert self.response_cls is not None | ||
data = None | ||
if isinstance(response, StreamingResponse): | ||
body_iter1, body_iter2 = atee(response.body_iterator) | ||
response.body_iterator = body_iter2 | ||
data = b"".join( | ||
[ | ||
chunk.encode(response.charset) | ||
if not isinstance(chunk, bytes) | ||
else chunk | ||
async for chunk in body_iter1 | ||
] | ||
) | ||
return self.response_cls(response, data=data) | ||
|
||
def _validate_response(self) -> bool: | ||
return self.response_cls is not None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Oops, something went wrong.