From b3b808afeb41d8f41b8de316cbbd5ee6a36e236a Mon Sep 17 00:00:00 2001 From: Bill Burke Date: Mon, 1 Jun 2020 16:38:39 -0400 Subject: [PATCH] Support FileRegions --- .../amazon/lambda/http/LambdaHttpHandler.java | 40 ++++++++-- .../resteasy/runtime/BaseFunction.java | 33 ++++++-- .../netty/runtime/virtual/VirtualChannel.java | 10 ++- .../virtual/VirtualClientConnection.java | 4 +- .../netty/runtime/virtual/VirtualMessage.java | 24 ++++++ integration-tests/amazon-lambda-http/pom.xml | 4 + .../lambda/AmazonLambdaSimpleTestCase.java | 25 +++++-- .../virtual-http-resteasy/pom.xml | 4 + .../io/quarkus/it/virtual/FunctionTest.java | 75 ++++++++----------- 9 files changed, 155 insertions(+), 64 deletions(-) create mode 100644 extensions/netty/runtime/src/main/java/io/quarkus/netty/runtime/virtual/VirtualMessage.java diff --git a/extensions/amazon-lambda-http/runtime/src/main/java/io/quarkus/amazon/lambda/http/LambdaHttpHandler.java b/extensions/amazon-lambda-http/runtime/src/main/java/io/quarkus/amazon/lambda/http/LambdaHttpHandler.java index a6b43e9e571e2..6aa3456478e77 100644 --- a/extensions/amazon-lambda-http/runtime/src/main/java/io/quarkus/amazon/lambda/http/LambdaHttpHandler.java +++ b/extensions/amazon-lambda-http/runtime/src/main/java/io/quarkus/amazon/lambda/http/LambdaHttpHandler.java @@ -3,18 +3,23 @@ import java.io.ByteArrayOutputStream; import java.net.InetSocketAddress; import java.net.URLEncoder; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; import java.nio.charset.StandardCharsets; import java.util.Base64; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import org.jboss.logging.Logger; + import com.amazonaws.serverless.proxy.internal.LambdaContainerHandler; import com.amazonaws.services.lambda.runtime.Context; import com.amazonaws.services.lambda.runtime.RequestHandler; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import io.netty.channel.FileRegion; import io.netty.handler.codec.http.DefaultHttpRequest; import io.netty.handler.codec.http.DefaultLastHttpContent; import io.netty.handler.codec.http.HttpContent; @@ -28,10 +33,12 @@ import io.quarkus.amazon.lambda.http.model.AwsProxyResponse; import io.quarkus.amazon.lambda.http.model.Headers; import io.quarkus.netty.runtime.virtual.VirtualClientConnection; +import io.quarkus.netty.runtime.virtual.VirtualMessage; import io.quarkus.vertx.http.runtime.VertxHttpRecorder; @SuppressWarnings("unused") public class LambdaHttpHandler implements RequestHandler { + private static final Logger log = Logger.getLogger("quarkus.amazon.lambda.http"); private static Headers errorHeaders = new Headers(); static { @@ -50,6 +57,7 @@ public AwsProxyResponse handleRequest(AwsProxyRequest request, Context context) try { return nettyDispatch(connection, request); } catch (Exception e) { + log.error("Request Failure", e); return new AwsProxyResponse(500, errorHeaders, "{ \"message\": \"Internal Server Error\" }"); } finally { connection.close(); @@ -59,6 +67,7 @@ public AwsProxyResponse handleRequest(AwsProxyRequest request, Context context) private AwsProxyResponse nettyDispatch(VirtualClientConnection connection, AwsProxyRequest request) throws Exception { String path = request.getPath(); + //log.info("---- Got lambda request: " + path); if (request.getMultiValueQueryStringParameters() != null && !request.getMultiValueQueryStringParameters().isEmpty()) { StringBuilder sb = new StringBuilder(path); sb.append("?"); @@ -109,14 +118,14 @@ private AwsProxyResponse nettyDispatch(VirtualClientConnection connection, AwsPr connection.sendMessage(requestContent); AwsProxyResponse responseBuilder = new AwsProxyResponse(); ByteArrayOutputStream baos = null; + WritableByteChannel byteChannel = null; try { for (;;) { - // todo should we timeout? have a timeout config? //log.info("waiting for message"); - Object msg = connection.queue().poll(100, TimeUnit.MILLISECONDS); + VirtualMessage virtualMessage = connection.queue().poll(100, TimeUnit.MILLISECONDS); + if (virtualMessage == null) continue; + Object msg = virtualMessage.getMessage(); try { - if (msg == null) - continue; //log.info("Got message: " + msg.getClass().getName()); if (msg instanceof HttpResponse) { @@ -137,13 +146,22 @@ private AwsProxyResponse nettyDispatch(VirtualClientConnection connection, AwsPr HttpContent content = (HttpContent) msg; int readable = content.content().readableBytes(); if (baos == null && readable > 0) { - // todo what is right size? - baos = new ByteArrayOutputStream(500); + baos = createByteStream(); } for (int i = 0; i < readable; i++) { baos.write(content.content().readByte()); } } + if (msg instanceof FileRegion) { + FileRegion file = (FileRegion) msg; + if (file.count() > 0) { + if (baos == null) + baos = createByteStream(); + if (byteChannel == null) + byteChannel = Channels.newChannel(baos); + file.transferTo(byteChannel, 0); + } + } if (msg instanceof LastHttpContent) { if (baos != null) { if (isBinary(responseBuilder.getMultiValueHeaders().getFirst("Content-Type"))) { @@ -156,8 +174,10 @@ private AwsProxyResponse nettyDispatch(VirtualClientConnection connection, AwsPr return responseBuilder; } } finally { - if (msg != null) + if (msg != null) { + virtualMessage.completed(); ReferenceCountUtil.release(msg); + } } } } finally { @@ -167,6 +187,12 @@ private AwsProxyResponse nettyDispatch(VirtualClientConnection connection, AwsPr } } + private ByteArrayOutputStream createByteStream() { + ByteArrayOutputStream baos;// todo what is right size? + baos = new ByteArrayOutputStream(1000); + return baos; + } + private boolean isBinary(String contentType) { if (contentType != null) { int index = contentType.indexOf(';'); diff --git a/extensions/azure-functions-http/runtime/src/main/java/io/quarkus/azure/functions/resteasy/runtime/BaseFunction.java b/extensions/azure-functions-http/runtime/src/main/java/io/quarkus/azure/functions/resteasy/runtime/BaseFunction.java index 4aa03f9b918d8..1caa6ccb96dcc 100644 --- a/extensions/azure-functions-http/runtime/src/main/java/io/quarkus/azure/functions/resteasy/runtime/BaseFunction.java +++ b/extensions/azure-functions-http/runtime/src/main/java/io/quarkus/azure/functions/resteasy/runtime/BaseFunction.java @@ -3,6 +3,8 @@ import java.io.ByteArrayOutputStream; import java.io.PrintWriter; import java.io.StringWriter; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; import java.util.Map; import java.util.Optional; import java.util.concurrent.TimeUnit; @@ -15,6 +17,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import io.netty.channel.FileRegion; import io.netty.handler.codec.http.DefaultHttpRequest; import io.netty.handler.codec.http.DefaultLastHttpContent; import io.netty.handler.codec.http.HttpContent; @@ -24,6 +27,7 @@ import io.netty.handler.codec.http.LastHttpContent; import io.netty.util.ReferenceCountUtil; import io.quarkus.netty.runtime.virtual.VirtualClientConnection; +import io.quarkus.netty.runtime.virtual.VirtualMessage; import io.quarkus.runtime.Application; import io.quarkus.vertx.http.runtime.VertxHttpRecorder; @@ -96,14 +100,15 @@ protected HttpResponseMessage nettyDispatch(VirtualClientConnection connection, connection.sendMessage(requestContent); HttpResponseMessage.Builder responseBuilder = null; ByteArrayOutputStream baos = null; + WritableByteChannel byteChannel = null; try { for (;;) { // todo should we timeout? have a timeout config? //log.info("waiting for message"); - Object msg = connection.queue().poll(100, TimeUnit.MILLISECONDS); + VirtualMessage virtualMessage = connection.queue().poll(100, TimeUnit.MILLISECONDS); + if (virtualMessage == null) continue; + Object msg = virtualMessage.getMessage(); try { - if (msg == null) - continue; //log.info("Got message: " + msg.getClass().getName()); if (msg instanceof HttpResponse) { @@ -117,20 +122,32 @@ protected HttpResponseMessage nettyDispatch(VirtualClientConnection connection, HttpContent content = (HttpContent) msg; if (baos == null) { // todo what is right size? - baos = new ByteArrayOutputStream(500); + baos = createByteStream(); } int readable = content.content().readableBytes(); for (int i = 0; i < readable; i++) { baos.write(content.content().readByte()); } } + if (msg instanceof FileRegion) { + FileRegion file = (FileRegion) msg; + if (file.count() > 0) { + if (baos == null) + baos = createByteStream(); + if (byteChannel == null) + byteChannel = Channels.newChannel(baos); + file.transferTo(byteChannel, 0); + } + } if (msg instanceof LastHttpContent) { responseBuilder.body(baos.toByteArray()); return responseBuilder.build(); } } finally { - if (msg != null) + if (msg != null) { + virtualMessage.completed(); ReferenceCountUtil.release(msg); + } } } } finally { @@ -139,4 +156,10 @@ protected HttpResponseMessage nettyDispatch(VirtualClientConnection connection, } } } + + private ByteArrayOutputStream createByteStream() { + ByteArrayOutputStream baos; + baos = new ByteArrayOutputStream(500); + return baos; + } } diff --git a/extensions/netty/runtime/src/main/java/io/quarkus/netty/runtime/virtual/VirtualChannel.java b/extensions/netty/runtime/src/main/java/io/quarkus/netty/runtime/virtual/VirtualChannel.java index c0c364197ab72..439f7f27c229d 100644 --- a/extensions/netty/runtime/src/main/java/io/quarkus/netty/runtime/virtual/VirtualChannel.java +++ b/extensions/netty/runtime/src/main/java/io/quarkus/netty/runtime/virtual/VirtualChannel.java @@ -313,7 +313,15 @@ protected void doWrite(ChannelOutboundBuffer in) throws Exception { // It is possible the peer could have closed while we are writing, and in this case we should // simulate real socket behavior and ensure the sendMessage operation is failed. if (peer.isConnected()) { - peer.queue().add(ReferenceCountUtil.retain(msg)); + VirtualMessage virtualMessage = new VirtualMessage(msg); + ReferenceCountUtil.retain(msg); + peer.queue().add(virtualMessage); + // need to wait until client is finished with message + // Things like FileRegion get closed when they are removed from outbound buffer + // It sucks we have to synchronize the threads here with every message, + // but the buffer class isn't flexible enough to handle this scenario. + // Might not be a big deal. :) + virtualMessage.awaitComplete(); in.remove(); } else { if (exception == null) { diff --git a/extensions/netty/runtime/src/main/java/io/quarkus/netty/runtime/virtual/VirtualClientConnection.java b/extensions/netty/runtime/src/main/java/io/quarkus/netty/runtime/virtual/VirtualClientConnection.java index c19e91c81a01f..252eb1ab25f66 100644 --- a/extensions/netty/runtime/src/main/java/io/quarkus/netty/runtime/virtual/VirtualClientConnection.java +++ b/extensions/netty/runtime/src/main/java/io/quarkus/netty/runtime/virtual/VirtualClientConnection.java @@ -16,7 +16,7 @@ */ public class VirtualClientConnection { protected SocketAddress clientAddress; - protected BlockingQueue queue = new LinkedBlockingQueue<>(); + protected BlockingQueue queue = new LinkedBlockingQueue<>(); protected boolean connected = true; protected VirtualChannel peer; @@ -33,7 +33,7 @@ public SocketAddress clientAddress() { * * @return */ - public BlockingQueue queue() { + public BlockingQueue queue() { return queue; } diff --git a/extensions/netty/runtime/src/main/java/io/quarkus/netty/runtime/virtual/VirtualMessage.java b/extensions/netty/runtime/src/main/java/io/quarkus/netty/runtime/virtual/VirtualMessage.java new file mode 100644 index 0000000000000..6e7c126c84781 --- /dev/null +++ b/extensions/netty/runtime/src/main/java/io/quarkus/netty/runtime/virtual/VirtualMessage.java @@ -0,0 +1,24 @@ +package io.quarkus.netty.runtime.virtual; + +import java.util.concurrent.CompletableFuture; + +public class VirtualMessage { + private Object message; + private CompletableFuture future = new CompletableFuture<>(); + + public VirtualMessage(Object message) { + this.message = message; + } + + public Object getMessage() { + return message; + } + + public void completed() { + future.complete(null); + } + + public void awaitComplete() throws Exception { + future.get(); + } +} diff --git a/integration-tests/amazon-lambda-http/pom.xml b/integration-tests/amazon-lambda-http/pom.xml index 8b4d7fbeed028..a3db1ddc18fa1 100644 --- a/integration-tests/amazon-lambda-http/pom.xml +++ b/integration-tests/amazon-lambda-http/pom.xml @@ -22,6 +22,10 @@ io.quarkus quarkus-resteasy + + io.quarkus + quarkus-smallrye-openapi + io.quarkus quarkus-undertow diff --git a/integration-tests/amazon-lambda-http/src/test/java/io/quarkus/it/amazon/lambda/AmazonLambdaSimpleTestCase.java b/integration-tests/amazon-lambda-http/src/test/java/io/quarkus/it/amazon/lambda/AmazonLambdaSimpleTestCase.java index 952c1b61d5f06..c1e203d84684c 100644 --- a/integration-tests/amazon-lambda-http/src/test/java/io/quarkus/it/amazon/lambda/AmazonLambdaSimpleTestCase.java +++ b/integration-tests/amazon-lambda-http/src/test/java/io/quarkus/it/amazon/lambda/AmazonLambdaSimpleTestCase.java @@ -24,6 +24,16 @@ public void testGetText() throws Exception { testGetText("/hello"); } + @Test + public void testSwaggerUi() throws Exception { + // this tests the FileRegion support in the handler + AwsProxyRequest request = request("/swagger-ui/"); + AwsProxyResponse out = LambdaClient.invoke(AwsProxyResponse.class, request); + Assertions.assertEquals(out.getStatusCode(), 200); + Assertions.assertTrue(body(out).contains("Swagger UI")); + + } + private String body(AwsProxyResponse response) { if (!response.isBase64Encoded()) return response.getBody(); @@ -31,20 +41,23 @@ private String body(AwsProxyResponse response) { } private void testGetText(String path) { - AwsProxyRequest request = new AwsProxyRequest(); - request.setHttpMethod("GET"); - request.setPath(path); + AwsProxyRequest request = request(path); AwsProxyResponse out = LambdaClient.invoke(AwsProxyResponse.class, request); Assertions.assertEquals(out.getStatusCode(), 200); Assertions.assertEquals(body(out), "hello"); Assertions.assertTrue(out.getMultiValueHeaders().getFirst("Content-Type").startsWith("text/plain")); } - @Test - public void test404() throws Exception { + private AwsProxyRequest request(String path) { AwsProxyRequest request = new AwsProxyRequest(); request.setHttpMethod("GET"); - request.setPath("/nowhere"); + request.setPath(path); + return request; + } + + @Test + public void test404() throws Exception { + AwsProxyRequest request = request("/nowhere"); AwsProxyResponse out = LambdaClient.invoke(AwsProxyResponse.class, request); Assertions.assertEquals(out.getStatusCode(), 404); } diff --git a/integration-tests/virtual-http-resteasy/pom.xml b/integration-tests/virtual-http-resteasy/pom.xml index 07812e09b691c..3dafaae9369c0 100644 --- a/integration-tests/virtual-http-resteasy/pom.xml +++ b/integration-tests/virtual-http-resteasy/pom.xml @@ -23,6 +23,10 @@ io.quarkus quarkus-azure-functions-http + + io.quarkus + quarkus-smallrye-openapi + com.microsoft.azure.functions azure-functions-java-library diff --git a/integration-tests/virtual-http-resteasy/src/test/java/io/quarkus/it/virtual/FunctionTest.java b/integration-tests/virtual-http-resteasy/src/test/java/io/quarkus/it/virtual/FunctionTest.java index ed41833c3d548..507198a3a94f8 100644 --- a/integration-tests/virtual-http-resteasy/src/test/java/io/quarkus/it/virtual/FunctionTest.java +++ b/integration-tests/virtual-http-resteasy/src/test/java/io/quarkus/it/virtual/FunctionTest.java @@ -27,20 +27,22 @@ @QuarkusTest public class FunctionTest { @Test - public void testJaxrs() throws Exception { - String uri = "https://foo.com/hello"; - testGET(uri); - testPOST(uri); - } - - @Test - public void testNotFound() { + public void testSwagger() { final HttpRequestMessageMock req = new HttpRequestMessageMock(); - req.setUri(URI.create("https://nowhere.com/badroute")); + req.setUri(URI.create("https://foo.com/swagger-ui/")); req.setHttpMethod(HttpMethod.GET); // Invoke - final HttpResponseMessage ret = new Function().run(req, new ExecutionContext() { + final HttpResponseMessage ret = new Function().run(req, createContext()); + + // Verify + Assertions.assertEquals(ret.getStatus(), HttpStatus.OK); + String body = new String((byte[]) ret.getBody(), StandardCharsets.UTF_8); + Assertions.assertTrue(body.contains("Swagger UI")); + } + + private ExecutionContext createContext() { + return new ExecutionContext() { @Override public Logger getLogger() { return null; @@ -55,7 +57,24 @@ public String getInvocationId() { public String getFunctionName() { return null; } - }); + }; + } + + @Test + public void testJaxrs() throws Exception { + String uri = "https://foo.com/hello"; + testGET(uri); + testPOST(uri); + } + + @Test + public void testNotFound() { + final HttpRequestMessageMock req = new HttpRequestMessageMock(); + req.setUri(URI.create("https://nowhere.com/badroute")); + req.setHttpMethod(HttpMethod.GET); + + // Invoke + final HttpResponseMessage ret = new Function().run(req, createContext()); // Verify Assertions.assertEquals(ret.getStatus(), HttpStatus.NOT_FOUND); @@ -79,22 +98,7 @@ private void testGET(String uri) { req.setHttpMethod(HttpMethod.GET); // Invoke - final HttpResponseMessage ret = new Function().run(req, new ExecutionContext() { - @Override - public Logger getLogger() { - return null; - } - - @Override - public String getInvocationId() { - return null; - } - - @Override - public String getFunctionName() { - return null; - } - }); + final HttpResponseMessage ret = new Function().run(req, createContext()); // Verify Assertions.assertEquals(ret.getStatus(), HttpStatus.OK); @@ -112,22 +116,7 @@ private void testPOST(String uri) { req.getHeaders().put("Content-Type", "text/plain"); // Invoke - final HttpResponseMessage ret = new Function().run(req, new ExecutionContext() { - @Override - public Logger getLogger() { - return null; - } - - @Override - public String getInvocationId() { - return null; - } - - @Override - public String getFunctionName() { - return null; - } - }); + final HttpResponseMessage ret = new Function().run(req, createContext()); // Verify Assertions.assertEquals(ret.getStatus(), HttpStatus.OK);