From d7f96864f4b3b21c2225cf212fafa3e39b16ba95 Mon Sep 17 00:00:00 2001 From: David Moore <4121492+davemooreuws@users.noreply.github.com> Date: Thu, 4 Apr 2024 10:00:10 +1100 Subject: [PATCH] fix: apply api middleware correctly (#129) --- nitric/context.py | 4 ++-- nitric/resources/apis.py | 10 ++++++++++ tests/resources/test_apis.py | 20 +++++++++++++++++++- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/nitric/context.py b/nitric/context.py index 779cd3c..5c770c0 100644 --- a/nitric/context.py +++ b/nitric/context.py @@ -439,9 +439,9 @@ async def chained_middleware(ctx: C, nxt: Optional[Middleware[C]] = None) -> C: return chained_middleware - middleware_chain = functools.reduce(reduce_chain, reversed(middlewares)) # type: ignore + middleware_chain = functools.reduce(reduce_chain, reversed(middlewares), last_middleware) # type: ignore # type ignored because mypy appears to misidentify the correct return type - return await middleware_chain(ctx, last_middleware) # type: ignore + return await middleware_chain(ctx) # type: ignore return composed diff --git a/nitric/resources/apis.py b/nitric/resources/apis.py index e88d2f3..43453e3 100644 --- a/nitric/resources/apis.py +++ b/nitric/resources/apis.py @@ -201,6 +201,9 @@ def _route(self, match: str, opts: Optional[RouteOptions] = None) -> Route: if opts is None: opts = RouteOptions() + if self.middleware is not None: + opts.middleware = self.middleware + opts.middleware + r = Route(self, match, opts) self.routes.append(r) return r @@ -339,6 +342,13 @@ def method( self, methods: List[HttpMethod], *middleware: HttpMiddleware | HttpHandler, opts: Optional[MethodOptions] = None ) -> None: """Register middleware for multiple HTTP Methods.""" + + # ensure route/api middlewares are added + middleware = ( + *self.middleware, + *middleware + ) + Method(self, methods, *middleware, opts=opts if opts else MethodOptions()) def get(self, *middleware: HttpMiddleware | HttpHandler, opts: Optional[MethodOptions] = None) -> None: diff --git a/tests/resources/test_apis.py b/tests/resources/test_apis.py index 4d53de8..22cee32 100644 --- a/tests/resources/test_apis.py +++ b/tests/resources/test_apis.py @@ -26,7 +26,7 @@ # from nitric.faas import HttpMethod, MethodOptions, ApiWorkerOptions from nitric.resources import api, ApiOptions, JwtSecurityDefinition -from nitric.resources.apis import MethodOptions, ScopedOidcOptions, oidc_rule +from nitric.resources.apis import MethodOptions, ScopedOidcOptions, oidc_rule, HttpMiddleware from nitric.proto.resources.v1 import ( ApiOpenIdConnectionDefinition, ApiSecurityDefinitionResource, @@ -40,6 +40,7 @@ from nitric.proto.apis.v1 import ApiDetailsResponse, ApiDetailsRequest, ApiWorkerScopes from nitric.context import ( + HttpContext, HttpMethod, ) @@ -221,6 +222,23 @@ def test_api_route(self): assert test_route.middleware == [] assert test_route.api.name == test_api.name + def test_api_route_middleware(self): + mock_declare = AsyncMock() + mock_response = Object() + mock_declare.return_value = mock_response + + async def middleware_test(ctx: HttpContext, nxt: HttpMiddleware): + return nxt(ctx) + + with patch("nitric.proto.resources.v1.ResourcesStub.declare", mock_declare): + test_api = api("test-api-route-middleware", ApiOptions(path="/api/v2/", middleware=[middleware_test])) + + test_route = test_api._route("/test") + + assert len(test_api.middleware) == 1 + assert len(test_route.middleware) == 1 + + def test_define_route(self): mock_declare = AsyncMock() mock_response = Object()