diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index d950bdc9c52..b3d77df24b4 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -664,6 +664,10 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None prefix : str, optional An optional prefix to be added to the originally defined rule """ + + # Add reference to parent ApiGatewayResolver to support use cases where people subclass it to add custom logic + router.api_resolver = self + for route, func in router._routes.items(): if prefix: rule = route[0] @@ -678,6 +682,7 @@ class Router(BaseRouter): def __init__(self): self._routes: Dict[tuple, Callable] = {} + self.api_resolver: Optional[BaseRouter] = None def route( self, diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 09594789ac3..f28752e6de6 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -1057,3 +1057,22 @@ def foo(account_id): assert post_result["headers"]["Content-Type"] == content_types.APPLICATION_JSON assert put_result["statusCode"] == 404 assert put_result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + + +def test_api_gateway_app_router_access_to_resolver(): + # GIVEN a Router with registered routes + app = ApiGatewayResolver() + router = Router() + + @router.get("/my/path") + def foo(): + # WHEN accessing the api resolver instance via the router + # THEN it is accessible and equal to the instantiated api resolver + assert app == router.api_resolver + return {} + + app.include_router(router) + result = app(LOAD_GW_EVENT, {}) + + assert result["statusCode"] == 200 + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON