Skip to content

Commit

Permalink
Merge pull request #52 from nitrictech/fix/middleware-composition
Browse files Browse the repository at this point in the history
Fix/middleware composition
  • Loading branch information
tjholm authored Oct 26, 2021
2 parents 069471d + f46dd9d commit 3b3e8b3
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 17 deletions.
29 changes: 23 additions & 6 deletions nitric/faas.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,16 +217,33 @@ def compose_middleware(*middlewares: Union[Middleware, List[Middleware]]) -> Mid
The resulting middleware will effectively be a chain of the provided middleware,
where each calls the next in the chain when they're successful.
"""
middlewares = list(middlewares)
if len(middlewares) == 1 and not isinstance(middlewares[0], list):
return middlewares[0]

middlewares = [compose_middleware(m) if isinstance(m, list) else m for m in middlewares]

async def handler(ctx, next_middleware=lambda ctx: ctx):
middleware_chain = functools.reduce(
lambda acc_next, cur: lambda context: cur(context, acc_next), reversed(middlewares + (next_middleware,))
)
return middleware_chain(ctx)
def reduce_chain(acc_next, cur):
async def chained_middleware(context):
# Count the positional arguments to determine if the function is a handler or middleware.
all_args = cur.__code__.co_argcount
kwargs = len(cur.__defaults__) if cur.__defaults__ is not None else 0
pos_args = all_args - kwargs
if pos_args == 2:
# Call the middleware with next and return the result
return (
(await cur(context, acc_next)) if asyncio.iscoroutinefunction(cur) else cur(context, acc_next)
)
else:
# Call the handler with ctx only, then call the remainder of the middleware chain
result = (await cur(context)) if asyncio.iscoroutinefunction(cur) else cur(context)
return (await acc_next(result)) if asyncio.iscoroutinefunction(acc_next) else acc_next(result)

return chained_middleware

middleware_chain = functools.reduce(reduce_chain, reversed(middlewares + [next_middleware]))
return await middleware_chain(ctx)

return handler

Expand Down Expand Up @@ -279,7 +296,7 @@ def start(self, *handlers: Union[Middleware, List[Middleware]]):
if not self._any_handler and not self._http_handler and not self._event_handler:
raise Exception("At least one handler function must be provided.")

asyncio.run(self.run())
asyncio.run(self._run())

@property
def _http_handler(self):
Expand All @@ -289,7 +306,7 @@ def _http_handler(self):
def _event_handler(self):
return self.__event_handler if self.__event_handler else self._any_handler

async def run(self):
async def _run(self):
"""Register a new FaaS worker with the Membrane, using the provided function as the handler."""
channel = new_default_channel()
client = FaasServiceStub(channel)
Expand Down
37 changes: 26 additions & 11 deletions tests/test_faas.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import pytest

from nitric.faas import start, FunctionServer, EventContext, HttpContext
from nitric.faas import start, FunctionServer, HttpContext, compose_middleware, HttpResponse

from nitricapi.nitric.faas.v1 import (
ServerMessage,
Expand All @@ -46,6 +46,21 @@ def __init__(self):


class EventClientTest(IsolatedAsyncioTestCase):
async def test_compose_middleware(self):
async def middleware(ctx: HttpContext, next) -> HttpContext:
ctx.res.status = 401
return await next(ctx)

async def handler(ctx: HttpContext) -> HttpContext:
ctx.res.body = "some text"
return ctx

composed = compose_middleware(middleware, handler)

ctx = HttpContext(response=HttpResponse(), request=None)
result = await composed(ctx)
assert result.res.status == 401

def test_start_with_one_handler(self):
mock_server_constructor = Mock()
mock_server = Object()
Expand Down Expand Up @@ -94,7 +109,7 @@ def test_start_starts_event_loop(self):
mock_run.return_value = mock_run_coroutine

with patch("nitric.faas.compose_middleware", mock_compose):
with patch("nitric.faas.FunctionServer.run", mock_run):
with patch("nitric.faas.FunctionServer._run", mock_run):
with patch("asyncio.run", mock_asyncio_run):
FunctionServer().start(mock_handler)

Expand Down Expand Up @@ -124,7 +139,7 @@ async def mock_stream(self, request_iterator):
with patch("nitric.faas.AsyncChannel", mock_async_channel_init), patch(
"nitricapi.nitric.faas.v1.FaasServiceStub.trigger_stream", mock_stream
), patch("nitric.faas.new_default_channel", mock_grpc_channel):
await FunctionServer().http(mock_handler).run()
await FunctionServer().http(mock_handler)._run()

# gRPC channel created
mock_grpc_channel.assert_called_once()
Expand Down Expand Up @@ -165,7 +180,7 @@ async def mock_stream(self, request_iterator):
with patch("nitric.faas.AsyncChannel", mock_async_channel_init), patch(
"nitricapi.nitric.faas.v1.FaasServiceStub.trigger_stream", mock_stream
), patch("nitric.faas.new_default_channel", mock_grpc_channel):
await FunctionServer().http(mock_http_handler).event(mock_event_handler).run()
await FunctionServer().http(mock_http_handler).event(mock_event_handler)._run()

# accept the init response from server
assert 1 == stream_calls
Expand Down Expand Up @@ -200,7 +215,7 @@ async def mock_stream(self, request_iterator):
with patch("nitric.faas.AsyncChannel", mock_async_channel_init), patch(
"nitricapi.nitric.faas.v1.FaasServiceStub.trigger_stream", mock_stream
), patch("nitric.faas.new_default_channel", mock_grpc_channel):
await FunctionServer().http(mock_http_handler).event(mock_event_handler).run()
await FunctionServer().http(mock_http_handler).event(mock_event_handler)._run()

# accept the init response from server
assert 1 == stream_calls
Expand Down Expand Up @@ -235,7 +250,7 @@ async def mock_stream(self, request_iterator):
with patch("nitric.faas.AsyncChannel", mock_async_channel_init), patch(
"nitricapi.nitric.faas.v1.FaasServiceStub.trigger_stream", mock_stream
), patch("nitric.faas.new_default_channel", mock_grpc_channel):
await FunctionServer().http(mock_http_handler).event(mock_event_handler).run()
await FunctionServer().http(mock_http_handler).event(mock_event_handler)._run()

# accept the init response from server
assert 1 == stream_calls
Expand Down Expand Up @@ -270,7 +285,7 @@ async def mock_stream(self, request_iterator):
with patch("nitric.faas.AsyncChannel", mock_async_channel_init), patch(
"nitricapi.nitric.faas.v1.FaasServiceStub.trigger_stream", mock_stream
), patch("nitric.faas.new_default_channel", mock_grpc_channel):
await FunctionServer().http(mock_http_handler).event(mock_event_handler).run()
await FunctionServer().http(mock_http_handler).event(mock_event_handler)._run()

# accept the init response from server
assert 1 == stream_calls
Expand Down Expand Up @@ -305,7 +320,7 @@ async def mock_stream(self, request_iterator):
with patch("nitric.faas.AsyncChannel", mock_async_channel_init), patch(
"nitricapi.nitric.faas.v1.FaasServiceStub.trigger_stream", mock_stream
), patch("nitric.faas.new_default_channel", mock_grpc_channel):
await FunctionServer().http(mock_http_handler).event(mock_event_handler).run()
await FunctionServer().http(mock_http_handler).event(mock_event_handler)._run()

# accept the init response from server
assert 1 == stream_calls
Expand Down Expand Up @@ -334,7 +349,7 @@ async def mock_stream(self, request_iterator):
with patch("nitric.faas.AsyncChannel", mock_async_channel_init), patch(
"nitricapi.nitric.faas.v1.FaasServiceStub.trigger_stream", mock_stream
), patch("nitric.faas.new_default_channel", mock_grpc_channel):
await FunctionServer().event(mock_handler).run()
await FunctionServer().event(mock_handler)._run()

# accept the trigger response from server
assert 1 == stream_calls
Expand Down Expand Up @@ -373,7 +388,7 @@ async def mock_stream(self, request_iterator):
with patch("nitric.faas.AsyncChannel", mock_async_channel_init), patch(
"nitricapi.nitric.faas.v1.FaasServiceStub.trigger_stream", mock_stream
), patch("nitric.faas.new_default_channel", mock_grpc_channel):
await FunctionServer().http(mock_handler).run()
await FunctionServer().http(mock_handler)._run()

# accept the trigger response from server
assert 1 == stream_calls
Expand Down Expand Up @@ -414,7 +429,7 @@ async def mock_stream(self, request_iterator):
with patch("nitric.faas.AsyncChannel", mock_async_channel_init), patch(
"nitricapi.nitric.faas.v1.FaasServiceStub.trigger_stream", mock_stream
), patch("nitric.faas.new_default_channel", mock_grpc_channel):
await FunctionServer().http(mock_handler).run()
await FunctionServer().http(mock_handler)._run()

# accept the trigger response from server
assert 1 == stream_calls
Expand Down

0 comments on commit 3b3e8b3

Please sign in to comment.