From a76d3572262080c909cb8e735b3e6df275caf453 Mon Sep 17 00:00:00 2001 From: Jye Cusch Date: Tue, 19 Oct 2021 13:00:12 +1100 Subject: [PATCH] fix: allow for async and sync handlers/middleware --- nitric/faas.py | 25 +++++++++++++++++++++---- tests/test_faas.py | 17 ++++++++++++++++- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/nitric/faas.py b/nitric/faas.py index 6047b8f..44fd51a 100644 --- a/nitric/faas.py +++ b/nitric/faas.py @@ -217,16 +217,33 @@ def compose_middleware(*middlewares: Union[Middleware, List[Middleware]]) -> Mid The resulting middleware will effectively be a chain of the provided middleware, where each calls the next in the chain when they're successful. """ + middlewares = list(middlewares) if len(middlewares) == 1 and not isinstance(middlewares[0], list): return middlewares[0] middlewares = [compose_middleware(m) if isinstance(m, list) else m for m in middlewares] async def handler(ctx, next_middleware=lambda ctx: ctx): - middleware_chain = functools.reduce( - lambda acc_next, cur: lambda context: cur(context, acc_next), reversed(middlewares + (next_middleware,)) - ) - return middleware_chain(ctx) + def reduceChain(acc_next, cur): + async def chainedMiddleware(context): + # Count the positional arguments to determine if the function is a handler or middleware. + all_args = cur.__code__.co_argcount + kwargs = len(cur.__defaults__) if cur.__defaults__ is not None else 0 + pos_args = all_args - kwargs + if pos_args == 2: + # Call the middleware with next and return the result + return ( + (await cur(context, acc_next)) if asyncio.iscoroutinefunction(cur) else cur(context, acc_next) + ) + else: + # Call the handler with ctx only, then call the remainder of the middleware chain + result = (await cur(context)) if asyncio.iscoroutinefunction(cur) else cur(context) + return (await acc_next(result)) if asyncio.iscoroutinefunction(acc_next) else acc_next(result) + + return chainedMiddleware + + middleware_chain = functools.reduce(reduceChain, reversed(middlewares + [next_middleware])) + return await middleware_chain(ctx) return handler diff --git a/tests/test_faas.py b/tests/test_faas.py index 1989ca2..f014d7c 100644 --- a/tests/test_faas.py +++ b/tests/test_faas.py @@ -21,7 +21,7 @@ import pytest -from nitric.faas import start, FunctionServer, EventContext, HttpContext +from nitric.faas import start, FunctionServer, HttpContext, compose_middleware, HttpResponse from nitricapi.nitric.faas.v1 import ( ServerMessage, @@ -46,6 +46,21 @@ def __init__(self): class EventClientTest(IsolatedAsyncioTestCase): + async def test_compose_middleware(self): + async def middleware(ctx: HttpContext, next) -> HttpContext: + ctx.res.status = 401 + return await next(ctx) + + async def handler(ctx: HttpContext) -> HttpContext: + ctx.res.body = "some text" + return ctx + + composed = compose_middleware(middleware, handler) + + ctx = HttpContext(response=HttpResponse(), request=None) + result = await composed(ctx) + assert result.res.status == 401 + def test_start_with_one_handler(self): mock_server_constructor = Mock() mock_server = Object()