Skip to content

Commit

Permalink
feat(event_handler): add ability to expose a Swagger UI (#3254)
Browse files Browse the repository at this point in the history
Co-authored-by: Leandro Damascena <[email protected]>
  • Loading branch information
rubenfonseca and leandrodamascena authored Nov 6, 2023
1 parent 8329153 commit 8a09adc
Show file tree
Hide file tree
Showing 17 changed files with 671 additions and 63 deletions.
225 changes: 198 additions & 27 deletions aws_lambda_powertools/event_handler/api_gateway.py

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions aws_lambda_powertools/event_handler/lambda_function_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,9 @@ def __init__(
strip_prefixes,
enable_validation,
)

def _get_base_path(self) -> str:
stage = self.current_event.request_context.stage
if stage and stage != "$default" and self.current_event.request_context.http.method.startswith(f"/{stage}"):
return f"/{stage}"
return ""
Original file line number Diff line number Diff line change
Expand Up @@ -94,20 +94,31 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
else:
# Re-write the route_args with the validated values, and call the next middleware
app.context["_route_args"] = values
response = next_middleware(app)

# Process the response body if it exists
raw_response = jsonable_encoder(response.body)
# Call the handler by calling the next middleware
response = next_middleware(app)

# Validate and serialize the response
return self._serialize_response(field=route.dependant.return_param, response_content=raw_response)
# Process the response
return self._handle_response(route=route, response=response)
except RequestValidationError as e:
return Response(
status_code=422,
content_type="application/json",
body=json.dumps({"detail": e.errors()}),
)

def _handle_response(self, *, route: Route, response: Response):
# Process the response body if it exists
if response.body:
# Validate and serialize the response, if it's JSON
if response.is_json():
response.body = json.dumps(
self._serialize_response(field=route.dependant.return_param, response_content=response.body),
sort_keys=True,
)

return response

def _serialize_response(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions aws_lambda_powertools/event_handler/openapi/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
DEFAULT_API_VERSION = "1.0.0"
DEFAULT_OPENAPI_VERSION = "3.1.0"
32 changes: 16 additions & 16 deletions aws_lambda_powertools/event_handler/openapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,24 @@ class Config:
extra = "allow"


# https://swagger.io/specification/#tag-object
class Tag(BaseModel):
name: str
description: Optional[str] = None
externalDocs: Optional[ExternalDocumentation] = None

if PYDANTIC_V2:
model_config = {"extra": "allow"}

else:

class Config:
extra = "allow"


# https://swagger.io/specification/#operation-object
class Operation(BaseModel):
tags: Optional[List[str]] = None
tags: Optional[List[Tag]] = None
summary: Optional[str] = None
description: Optional[str] = None
externalDocs: Optional[ExternalDocumentation] = None
Expand Down Expand Up @@ -540,21 +555,6 @@ class Config:
extra = "allow"


# https://swagger.io/specification/#tag-object
class Tag(BaseModel):
name: str
description: Optional[str] = None
externalDocs: Optional[ExternalDocumentation] = None

if PYDANTIC_V2:
model_config = {"extra": "allow"}

else:

class Config:
extra = "allow"


# https://swagger.io/specification/#openapi-object
class OpenAPI(BaseModel):
openapi: str
Expand Down
Empty file.
52 changes: 52 additions & 0 deletions aws_lambda_powertools/event_handler/openapi/swagger_ui/html.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
def generate_swagger_html(spec: str, js_url: str, css_url: str) -> str:
"""
Generate Swagger UI HTML page
Parameters
----------
spec: str
The OpenAPI spec in the JSON format
js_url: str
The URL to the Swagger UI JavaScript file
css_url: str
The URL to the Swagger UI CSS file
"""
return f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Swagger UI</title>
<link rel="stylesheet" type="text/css" href="{css_url}">
</head>
<body>
<div id="swagger-ui">
Loading...
</div>
</body>
<script src="{js_url}"></script>
<script>
var swaggerUIOptions = {{
dom_id: "#swagger-ui",
docExpansion: "list",
deepLinking: true,
filter: true,
spec: JSON.parse(`
{spec}
`.trim()),
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
plugins: [
SwaggerUIBundle.plugins.DownloadUrl
]
}}
var ui = SwaggerUIBundle(swaggerUIOptions)
</script>
</html>
""".strip()

Large diffs are not rendered by default.

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions aws_lambda_powertools/event_handler/vpc_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def __init__(
"""Amazon VPC Lattice resolver"""
super().__init__(ProxyEventType.VPCLatticeEvent, cors, debug, serializer, strip_prefixes, enable_validation)

def _get_base_path(self) -> str:
return ""


class VPCLatticeV2Resolver(ApiGatewayResolver):
"""VPC Lattice resolver
Expand Down Expand Up @@ -98,3 +101,6 @@ def __init__(
):
"""Amazon VPC Lattice resolver"""
super().__init__(ProxyEventType.VPCLatticeEventV2, cors, debug, serializer, strip_prefixes, enable_validation)

def _get_base_path(self) -> str:
return ""
8 changes: 7 additions & 1 deletion aws_lambda_powertools/utilities/parameters/ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,13 @@ def _get(self, name: str, decrypt: bool = False, **sdk_options) -> str:

return self.client.get_parameter(**sdk_options)["Parameter"]["Value"]

def _get_multiple(self, path: str, decrypt: Optional[bool] = None, recursive: bool = False, **sdk_options) -> Dict[str, str]:
def _get_multiple(
self,
path: str,
decrypt: Optional[bool] = None,
recursive: bool = False,
**sdk_options,
) -> Dict[str, str]:
"""
Retrieve multiple parameter values from AWS Systems Manager Parameter Store
Expand Down
23 changes: 23 additions & 0 deletions tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,29 @@ def handler(event, context):
assert headers["Content-Encoding"] == ["gzip"]


def test_response_is_json_without_content_type():
response = Response(200, None, "")

assert response.is_json() is False


def test_response_is_json_with_json_content_type():
response = Response(200, content_types.APPLICATION_JSON, "")
assert response.is_json() is True


def test_response_is_json_with_multiple_json_content_types():
response = Response(
200,
None,
"",
{
"Content-Type": [content_types.APPLICATION_JSON, content_types.APPLICATION_JSON],
},
)
assert response.is_json() is True


def test_compress():
# GIVEN a function that has compress=True
# AND an event with a "Accept-Encoding" that include gzip
Expand Down
103 changes: 103 additions & 0 deletions tests/functional/event_handler/test_base_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from aws_lambda_powertools.event_handler import (
ALBResolver,
APIGatewayHttpResolver,
APIGatewayRestResolver,
LambdaFunctionUrlResolver,
VPCLatticeResolver,
VPCLatticeV2Resolver,
)
from tests.functional.utils import load_event


def test_base_path_api_gateway_rest():
app = APIGatewayRestResolver(enable_validation=True)

@app.get("/")
def handle():
return app._get_base_path()

event = load_event("apiGatewayProxyEvent.json")
event["path"] = "/"

result = app(event, {})
assert result["statusCode"] == 200
assert result["body"] == ""


def test_base_path_api_gateway_http():
app = APIGatewayHttpResolver(enable_validation=True)

@app.get("/")
def handle():
return app._get_base_path()

event = load_event("apiGatewayProxyV2Event.json")
event["rawPath"] = "/"
event["requestContext"]["http"]["path"] = "/"
event["requestContext"]["http"]["method"] = "GET"

result = app(event, {})
assert result["statusCode"] == 200
assert result["body"] == ""


def test_base_path_alb():
app = ALBResolver(enable_validation=True)

@app.get("/")
def handle():
return app._get_base_path()

event = load_event("albEvent.json")
event["path"] = "/"

result = app(event, {})
assert result["statusCode"] == 200
assert result["body"] == ""


def test_base_path_lambda_function_url():
app = LambdaFunctionUrlResolver(enable_validation=True)

@app.get("/")
def handle():
return app._get_base_path()

event = load_event("lambdaFunctionUrlIAMEvent.json")
event["rawPath"] = "/"
event["requestContext"]["http"]["path"] = "/"
event["requestContext"]["http"]["method"] = "GET"

result = app(event, {})
assert result["statusCode"] == 200
assert result["body"] == ""


def test_vpc_lattice():
app = VPCLatticeResolver(enable_validation=True)

@app.get("/")
def handle():
return app._get_base_path()

event = load_event("vpcLatticeEvent.json")
event["raw_path"] = "/"

result = app(event, {})
assert result["statusCode"] == 200
assert result["body"] == ""


def test_vpc_latticev2():
app = VPCLatticeV2Resolver(enable_validation=True)

@app.get("/")
def handle():
return app._get_base_path()

event = load_event("vpcLatticeV2Event.json")
event["path"] = "/"

result = app(event, {})
assert result["statusCode"] == 200
assert result["body"] == ""
Loading

0 comments on commit 8a09adc

Please sign in to comment.