Skip to content

Commit

Permalink
refactor(event-handler): api gateway handler review changes (#420)
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Brewer authored May 4, 2021
1 parent 213caed commit 59b3adf
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 47 deletions.
30 changes: 13 additions & 17 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
import zlib
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Union

from aws_lambda_powertools.shared.json_encoder import Encoder
from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2
Expand All @@ -12,14 +12,11 @@


class ProxyEventType(Enum):
"""An enumerations of the supported proxy event types.
"""An enumerations of the supported proxy event types."""

**NOTE:** api_gateway is an alias of http_api_v1"""

http_api_v1 = "APIGatewayProxyEvent"
http_api_v2 = "APIGatewayProxyEventV2"
alb_event = "ALBEvent"
api_gateway = http_api_v1
APIGatewayProxyEvent = "APIGatewayProxyEvent"
APIGatewayProxyEventV2 = "APIGatewayProxyEventV2"
ALBEvent = "ALBEvent"


class CORSConfig(object):
Expand Down Expand Up @@ -236,7 +233,7 @@ class ApiGatewayResolver:
current_event: BaseProxyEvent
lambda_context: LambdaContext

def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1, cors: CORSConfig = None):
def __init__(self, proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent, cors: CORSConfig = None):
"""
Parameters
----------
Expand Down Expand Up @@ -310,9 +307,9 @@ def _compile_regex(rule: str):

def _to_proxy_event(self, event: Dict) -> BaseProxyEvent:
"""Convert the event dict to the corresponding data class"""
if self._proxy_type == ProxyEventType.http_api_v1:
if self._proxy_type == ProxyEventType.APIGatewayProxyEvent:
return APIGatewayProxyEvent(event)
if self._proxy_type == ProxyEventType.http_api_v2:
if self._proxy_type == ProxyEventType.APIGatewayProxyEventV2:
return APIGatewayProxyEventV2(event)
return ALBEvent(event)

Expand All @@ -327,9 +324,9 @@ def _resolve(self) -> ResponseBuilder:
if match:
return self._call_route(route, match.groupdict())

return self._not_found(method, path)
return self._not_found(method)

def _not_found(self, method: str, path: str) -> ResponseBuilder:
def _not_found(self, method: str) -> ResponseBuilder:
"""Called when no matching route was found and includes support for the cors preflight response"""
headers = {}
if self._cors:
Expand All @@ -344,7 +341,7 @@ def _not_found(self, method: str, path: str) -> ResponseBuilder:
status_code=404,
content_type="application/json",
headers=headers,
body=json.dumps({"message": f"No route found for '{method}.{path}'"}),
body=json.dumps({"message": "Not found"}),
)
)

Expand All @@ -353,12 +350,11 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
return ResponseBuilder(self._to_response(route.func(**args)), route)

@staticmethod
def _to_response(result: Union[Tuple[int, str, Union[bytes, str]], Dict, Response]) -> Response:
def _to_response(result: Union[Dict, Response]) -> Response:
"""Convert the route's result to a Response
3 main result types are supported:
2 main result types are supported:
- Tuple[int, str, bytes] and Tuple[int, str, str]: status code, content-type and body (str|bytes)
- Dict[str, Any]: Rest api response with just the Dict to json stringify and content-type is set to
application/json
- Response: returned as is, and allows for more flexibility
Expand Down
12 changes: 10 additions & 2 deletions docs/core/event_handler/api_gateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ from aws_lambda_powertools.event_handler.api_gateway import (
tracer = Tracer()
# Other supported proxy_types: "APIGatewayProxyEvent", "APIGatewayProxyEventV2", "ALBEvent"
app = ApiGatewayResolver(
proxy_type=ProxyEventType.http_api_v1,
proxy_type=ProxyEventType.APIGatewayProxyEvent,
cors=CORSConfig(
allow_origin="https://www.example.com/",
expose_headers=["x-exposed-response-header"],
Expand All @@ -52,24 +52,28 @@ app = ApiGatewayResolver(
)
)


@app.get("/foo", compress=True)
def get_foo() -> Tuple[int, str, str]:
# Matches on http GET and proxy path "/foo"
# and return status code: 200, content-type: text/html and body: Hello
return 200, "text/html", "Hello"


@app.get("/logo.png")
def get_logo() -> Tuple[int, str, bytes]:
# Base64 encodes the return bytes body automatically
logo: bytes = load_logo()
return 200, "image/png", logo


@app.post("/make_foo", cors=True)
def make_foo() -> Tuple[int, str, str]:
# Matches on http POST and proxy path "/make_foo"
post_data: dict = app. current_event.json_body
post_data: dict = app.current_event.json_body
return 200, "application/json", json.dumps(post_data["value"])


@app.delete("/delete/<uid>")
def delete_foo(uid: str) -> Tuple[int, str, str]:
# Matches on http DELETE and proxy path starting with "/delete/"
Expand All @@ -78,16 +82,19 @@ def delete_foo(uid: str) -> Tuple[int, str, str]:
assert app.current_event.request_context.authorizer.claims["username"] == "Mike"
return 200, "application/json", json.dumps({"id": uid})


@app.get("/hello/<username>")
def hello_user(username: str) -> Tuple[int, str, str]:
return 200, "text/html", f"Hello {username}!"


@app.get("/rest")
def rest_fun() -> Dict:
# Returns a statusCode: 200, Content-Type: application/json and json.dumps dict
# and handles the serialization of decimals to json string
return {"message": "Example", "second": Decimal("100.01")}


@app.get("/foo3")
def foo3() -> Response:
return Response(
Expand All @@ -97,6 +104,7 @@ def foo3() -> Response:
body=json.dumps({"message": "Foo3"}),
)


@tracer.capture_lambda_handler
def lambda_handler(event, context) -> Dict:
return app.resolve(event, context)
Expand Down
56 changes: 28 additions & 28 deletions tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import zlib
from decimal import Decimal
from pathlib import Path
from typing import Dict, Tuple
from typing import Dict

from aws_lambda_powertools.event_handler.api_gateway import (
ApiGatewayResolver,
Expand All @@ -29,10 +29,10 @@ def read_media(file_name: str) -> bytes:

def test_alb_event():
# GIVEN a Application Load Balancer proxy type event
app = ApiGatewayResolver(proxy_type=ProxyEventType.alb_event)
app = ApiGatewayResolver(proxy_type=ProxyEventType.ALBEvent)

@app.get("/lambda")
def foo() -> Tuple[int, str, str]:
def foo():
assert isinstance(app.current_event, ALBEvent)
assert app.lambda_context == {}
return 200, TEXT_HTML, "foo"
Expand All @@ -49,13 +49,13 @@ def foo() -> Tuple[int, str, str]:

def test_api_gateway_v1():
# GIVEN a Http API V1 proxy type event
app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v1)
app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent)

@app.get("/my/path")
def get_lambda() -> Tuple[int, str, str]:
def get_lambda() -> Response:
assert isinstance(app.current_event, APIGatewayProxyEvent)
assert app.lambda_context == {}
return 200, APPLICATION_JSON, json.dumps({"foo": "value"})
return Response(200, APPLICATION_JSON, json.dumps({"foo": "value"}))

# WHEN calling the event handler
result = app(LOAD_GW_EVENT, {})
Expand All @@ -68,12 +68,12 @@ def get_lambda() -> Tuple[int, str, str]:

def test_api_gateway():
# GIVEN a Rest API Gateway proxy type event
app = ApiGatewayResolver(proxy_type=ProxyEventType.api_gateway)
app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent)

@app.get("/my/path")
def get_lambda() -> Tuple[int, str, str]:
def get_lambda() -> Response:
assert isinstance(app.current_event, APIGatewayProxyEvent)
return 200, TEXT_HTML, "foo"
return Response(200, TEXT_HTML, "foo")

# WHEN calling the event handler
result = app(LOAD_GW_EVENT, {})
Expand All @@ -87,13 +87,13 @@ def get_lambda() -> Tuple[int, str, str]:

def test_api_gateway_v2():
# GIVEN a Http API V2 proxy type event
app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v2)
app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEventV2)

@app.post("/my/path")
def my_path() -> Tuple[int, str, str]:
def my_path() -> Response:
assert isinstance(app.current_event, APIGatewayProxyEventV2)
post_data = app.current_event.json_body
return 200, "plain/text", post_data["username"]
return Response(200, "plain/text", post_data["username"])

# WHEN calling the event handler
result = app(load_event("apiGatewayProxyV2Event.json"), {})
Expand All @@ -110,9 +110,9 @@ def test_include_rule_matching():
app = ApiGatewayResolver()

@app.get("/<name>/<my_id>")
def get_lambda(my_id: str, name: str) -> Tuple[int, str, str]:
def get_lambda(my_id: str, name: str) -> Response:
assert name == "my"
return 200, TEXT_HTML, my_id
return Response(200, TEXT_HTML, my_id)

# WHEN calling the event handler
result = app(LOAD_GW_EVENT, {})
Expand Down Expand Up @@ -179,8 +179,8 @@ def test_cors():
app = ApiGatewayResolver()

@app.get("/my/path", cors=True)
def with_cors() -> Tuple[int, str, str]:
return 200, TEXT_HTML, "test"
def with_cors() -> Response:
return Response(200, TEXT_HTML, "test")

def handler(event, context):
return app.resolve(event, context)
Expand All @@ -205,8 +205,8 @@ def test_compress():
expected_value = '{"test": "value"}'

@app.get("/my/request", compress=True)
def with_compression() -> Tuple[int, str, str]:
return 200, APPLICATION_JSON, expected_value
def with_compression() -> Response:
return Response(200, APPLICATION_JSON, expected_value)

def handler(event, context):
return app.resolve(event, context)
Expand All @@ -230,8 +230,8 @@ def test_base64_encode():
mock_event = {"path": "/my/path", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}}

@app.get("/my/path", compress=True)
def read_image() -> Tuple[int, str, bytes]:
return 200, "image/png", read_media("idempotent_sequence_exception.png")
def read_image() -> Response:
return Response(200, "image/png", read_media("idempotent_sequence_exception.png"))

# WHEN calling the event handler
result = app(mock_event, None)
Expand All @@ -251,8 +251,8 @@ def test_compress_no_accept_encoding():
expected_value = "Foo"

@app.get("/my/path", compress=True)
def return_text() -> Tuple[int, str, str]:
return 200, "text/plain", expected_value
def return_text() -> Response:
return Response(200, "text/plain", expected_value)

# WHEN calling the event handler
result = app({"path": "/my/path", "httpMethod": "GET", "headers": {}}, None)
Expand All @@ -267,8 +267,8 @@ def test_cache_control_200():
app = ApiGatewayResolver()

@app.get("/success", cache_control="max-age=600")
def with_cache_control() -> Tuple[int, str, str]:
return 200, TEXT_HTML, "has 200 response"
def with_cache_control() -> Response:
return Response(200, TEXT_HTML, "has 200 response")

def handler(event, context):
return app.resolve(event, context)
Expand All @@ -288,8 +288,8 @@ def test_cache_control_non_200():
app = ApiGatewayResolver()

@app.delete("/fails", cache_control="max-age=600")
def with_cache_control_has_500() -> Tuple[int, str, str]:
return 503, TEXT_HTML, "has 503 response"
def with_cache_control_has_500() -> Response:
return Response(503, TEXT_HTML, "has 503 response")

def handler(event, context):
return app.resolve(event, context)
Expand All @@ -306,7 +306,7 @@ def handler(event, context):

def test_rest_api():
# GIVEN a function that returns a Dict
app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v1)
app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent)
expected_dict = {"foo": "value", "second": Decimal("100.01")}

@app.get("/my/path")
Expand All @@ -325,7 +325,7 @@ def rest_func() -> Dict:

def test_handling_response_type():
# GIVEN a function that returns Response
app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v1)
app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent)

@app.get("/my/path")
def rest_func() -> Response:
Expand Down

0 comments on commit 59b3adf

Please sign in to comment.