Skip to content

Commit

Permalink
feat(apigateway): add exception_handler support (#898)
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Brewer authored Dec 16, 2021
1 parent e91932c commit 8c859de
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 15 deletions.
63 changes: 49 additions & 14 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from enum import Enum
from functools import partial
from http import HTTPStatus
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

from aws_lambda_powertools.event_handler import content_types
from aws_lambda_powertools.event_handler.exceptions import ServiceError
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
from aws_lambda_powertools.shared import constants
from aws_lambda_powertools.shared.functions import resolve_truthy_env_var_choice
from aws_lambda_powertools.shared.json_encoder import Encoder
Expand All @@ -27,7 +27,6 @@
_SAFE_URI = "-._~()'!*:@,;" # https://www.ietf.org/rfc/rfc3986.txt
# API GW/ALB decode non-safe URI chars; we must support them too
_UNSAFE_URI = "%<>\[\]{}|^" # noqa: W605

_NAMED_GROUP_BOUNDARY_PATTERN = fr"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)"


Expand Down Expand Up @@ -435,6 +434,7 @@ def __init__(
self._proxy_type = proxy_type
self._routes: List[Route] = []
self._route_keys: List[str] = []
self._exception_handlers: Dict[Type, Callable] = {}
self._cors = cors
self._cors_enabled: bool = cors is not None
self._cors_methods: Set[str] = {"OPTIONS"}
Expand Down Expand Up @@ -596,6 +596,10 @@ def _not_found(self, method: str) -> ResponseBuilder:
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=None))

handler = self._lookup_exception_handler(NotFoundError)
if handler:
return ResponseBuilder(handler(NotFoundError()))

return ResponseBuilder(
Response(
status_code=HTTPStatus.NOT_FOUND.value,
Expand All @@ -609,16 +613,11 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
"""Actually call the matching route with any provided keyword arguments."""
try:
return ResponseBuilder(self._to_response(route.func(**args)), route)
except ServiceError as e:
return ResponseBuilder(
Response(
status_code=e.status_code,
content_type=content_types.APPLICATION_JSON,
body=self._json_dump({"statusCode": e.status_code, "message": e.msg}),
),
route,
)
except Exception:
except Exception as exc:
response_builder = self._call_exception_handler(exc, route)
if response_builder:
return response_builder

if self._debug:
# If the user has turned on debug mode,
# we'll let the original exception propagate so
Expand All @@ -628,10 +627,46 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
status_code=500,
content_type=content_types.TEXT_PLAIN,
body="".join(traceback.format_exc()),
)
),
route,
)

raise

def not_found(self, func: Callable):
return self.exception_handler(NotFoundError)(func)

def exception_handler(self, exc_class: Type[Exception]):
def register_exception_handler(func: Callable):
self._exception_handlers[exc_class] = func

return register_exception_handler

def _lookup_exception_handler(self, exp_type: Type) -> Optional[Callable]:
# Use "Method Resolution Order" to allow for matching against a base class
# of an exception
for cls in exp_type.__mro__:
if cls in self._exception_handlers:
return self._exception_handlers[cls]
return None

def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[ResponseBuilder]:
handler = self._lookup_exception_handler(type(exp))
if handler:
return ResponseBuilder(handler(exp), route)

if isinstance(exp, ServiceError):
return ResponseBuilder(
Response(
status_code=exp.status_code,
content_type=content_types.APPLICATION_JSON,
body=self._json_dump({"statusCode": exp.status_code, "message": exp.msg}),
),
route,
)

return None

def _to_response(self, result: Union[Dict, Response]) -> Response:
"""Convert the route's result to a Response
Expand Down
75 changes: 74 additions & 1 deletion tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def patch_func():
def handler(event, context):
return app.resolve(event, context)

# Also check check the route configurations
# Also check the route configurations
routes = app._routes
assert len(routes) == 5
for route in routes:
Expand Down Expand Up @@ -1076,3 +1076,76 @@ def foo():

assert result["statusCode"] == 200
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON


def test_exception_handler():
# GIVEN a resolver with an exception handler defined for ValueError
app = ApiGatewayResolver()

@app.exception_handler(ValueError)
def handle_value_error(ex: ValueError):
print(f"request path is '{app.current_event.path}'")
return Response(
status_code=418,
content_type=content_types.TEXT_HTML,
body=str(ex),
)

@app.get("/my/path")
def get_lambda() -> Response:
raise ValueError("Foo!")

# WHEN calling the event handler
# AND a ValueError is raised
result = app(LOAD_GW_EVENT, {})

# THEN call the exception_handler
assert result["statusCode"] == 418
assert result["headers"]["Content-Type"] == content_types.TEXT_HTML
assert result["body"] == "Foo!"


def test_exception_handler_service_error():
# GIVEN
app = ApiGatewayResolver()

@app.exception_handler(ServiceError)
def service_error(ex: ServiceError):
print(ex.msg)
return Response(
status_code=ex.status_code,
content_type=content_types.APPLICATION_JSON,
body="CUSTOM ERROR FORMAT",
)

@app.get("/my/path")
def get_lambda() -> Response:
raise InternalServerError("Something sensitive")

# WHEN calling the event handler
# AND a ServiceError is raised
result = app(LOAD_GW_EVENT, {})

# THEN call the exception_handler
assert result["statusCode"] == 500
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
assert result["body"] == "CUSTOM ERROR FORMAT"


def test_exception_handler_not_found():
# GIVEN a resolver with an exception handler defined for a 404 not found
app = ApiGatewayResolver()

@app.not_found
def handle_not_found(exc: NotFoundError) -> Response:
assert isinstance(exc, NotFoundError)
return Response(status_code=404, content_type=content_types.TEXT_PLAIN, body="I am a teapot!")

# WHEN calling the event handler
# AND not route is found
result = app(LOAD_GW_EVENT, {})

# THEN call the exception_handler
assert result["statusCode"] == 404
assert result["headers"]["Content-Type"] == content_types.TEXT_PLAIN
assert result["body"] == "I am a teapot!"

0 comments on commit 8c859de

Please sign in to comment.