diff --git a/benchmarks/build.gradle b/benchmarks/build.gradle index 4da2b63b600..0d192ab3b11 100644 --- a/benchmarks/build.gradle +++ b/benchmarks/build.gradle @@ -22,6 +22,8 @@ dependencies { api project(":inject") api project(":inject-java-test") api project(":http-server") + api project(":http-server-netty") + api project(":jackson-databind") api project(":router") api project(":runtime") diff --git a/benchmarks/src/jmh/java/io/micronaut/http/server/stack/FullHttpStackBenchmark.java b/benchmarks/src/jmh/java/io/micronaut/http/server/stack/FullHttpStackBenchmark.java new file mode 100644 index 00000000000..ae33a657cd8 --- /dev/null +++ b/benchmarks/src/jmh/java/io/micronaut/http/server/stack/FullHttpStackBenchmark.java @@ -0,0 +1,195 @@ +package io.micronaut.http.server.stack; + +import io.micronaut.context.ApplicationContext; +import io.micronaut.http.server.netty.NettyHttpServer; +import io.micronaut.runtime.server.EmbeddedServer; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.util.concurrent.FastThreadLocalThread; +import org.junit.jupiter.api.Assertions; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.profile.AsyncProfiler; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; + +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +public class FullHttpStackBenchmark { + @Benchmark + public void test(Holder holder) { + ByteBuf response = holder.exchange(); + if (!holder.responseBytes.equals(response)) { + throw new AssertionError("Response did not match"); + } + response.release(); + } + + public static void main(String[] args) throws Exception { + JmhFastThreadLocalExecutor exec = new JmhFastThreadLocalExecutor(1, "init-test"); + exec.submit(() -> { + // simple test that everything works properly + for (StackFactory stack : StackFactory.values()) { + Holder holder = new Holder(); + holder.stack = stack; + holder.setUp(); + holder.tearDown(); + } + return null; + }).get(); + exec.shutdown(); + + Options opt = new OptionsBuilder() + .include(FullHttpStackBenchmark.class.getName() + ".*") + .warmupIterations(20) + .measurementIterations(30) + .mode(Mode.AverageTime) + .timeUnit(TimeUnit.NANOSECONDS) + .addProfiler(AsyncProfiler.class, "libPath=/home/yawkat/bin/async-profiler-2.9-linux-x64/build/libasyncProfiler.so;output=flamegraph") + .forks(1) + .jvmArgsAppend("-Djmh.executor=CUSTOM", "-Djmh.executor.class=" + JmhFastThreadLocalExecutor.class.getName()) + .build(); + + new Runner(opt).run(); + } + + @State(Scope.Thread) + public static class Holder { + @Param({"MICRONAUT"/*, "PURE_NETTY"*/}) + StackFactory stack = StackFactory.MICRONAUT; + + AutoCloseable ctx; + EmbeddedChannel channel; + ByteBuf requestBytes; + ByteBuf responseBytes; + + @Setup + public void setUp() { + if (!(Thread.currentThread() instanceof FastThreadLocalThread)) { + throw new IllegalStateException("Should run on a netty FTL thread"); + } + + Stack stack = this.stack.openChannel(); + ctx = stack.closeable; + channel = stack.serverChannel; + + channel.freezeTime(); + + EmbeddedChannel clientChannel = new EmbeddedChannel(); + clientChannel.pipeline().addLast(new HttpClientCodec()); + clientChannel.pipeline().addLast(new HttpObjectAggregator(1000)); + + FullHttpRequest request = new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, + HttpMethod.POST, + "/search/find", + Unpooled.wrappedBuffer("{\"haystack\": [\"xniomb\", \"seelzp\", \"nzogdq\", \"omblsg\", \"idgtlm\", \"ydonzo\"], \"needle\": \"idg\"}".getBytes(StandardCharsets.UTF_8)) + ); + request.headers().add(HttpHeaderNames.CONTENT_LENGTH, request.content().readableBytes()); + request.headers().add(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON); + request.headers().add(HttpHeaderNames.ACCEPT, HttpHeaderValues.APPLICATION_JSON); + clientChannel.writeOutbound(request); + clientChannel.flushOutbound(); + + requestBytes = PooledByteBufAllocator.DEFAULT.buffer(); + while (true) { + ByteBuf part = clientChannel.readOutbound(); + if (part == null) { + break; + } + requestBytes.writeBytes(part); + } + + // sanity check: run req/resp once and see that the response is correct + responseBytes = exchange(); + clientChannel.writeInbound(responseBytes.retainedDuplicate()); + FullHttpResponse response = clientChannel.readInbound(); + //System.out.println(response); + //System.out.println(response.content().toString(StandardCharsets.UTF_8)); + Assertions.assertEquals(HttpResponseStatus.OK, response.status()); + Assertions.assertEquals("application/json", response.headers().get(HttpHeaderNames.CONTENT_TYPE)); + Assertions.assertEquals("keep-alive", response.headers().get(HttpHeaderNames.CONNECTION)); + String expectedResponseBody = "{\"listIndex\":4,\"stringIndex\":0}"; + Assertions.assertEquals(expectedResponseBody, response.content().toString(StandardCharsets.UTF_8)); + Assertions.assertEquals(expectedResponseBody.length(), response.headers().getInt(HttpHeaderNames.CONTENT_LENGTH)); + response.release(); + } + + private ByteBuf exchange() { + channel.writeInbound(requestBytes.retainedDuplicate()); + channel.runPendingTasks(); + CompositeByteBuf response = PooledByteBufAllocator.DEFAULT.compositeBuffer(); + while (true) { + ByteBuf part = channel.readOutbound(); + if (part == null) { + break; + } + response.addComponent(true, part); + } + return response; + } + + @TearDown + public void tearDown() throws Exception { + ctx.close(); + requestBytes.release(); + responseBytes.release(); + } + } + + public enum StackFactory { + MICRONAUT { + @Override + Stack openChannel() { + ApplicationContext ctx = ApplicationContext.run(Map.of( + "spec.name", "FullHttpStackBenchmark", + "micronaut.server.date-header", false // disabling this makes the response identical each time + )); + EmbeddedServer server = ctx.getBean(EmbeddedServer.class); + EmbeddedChannel channel = ((NettyHttpServer) server).buildEmbeddedChannel(false); + return new Stack(channel, ctx); + } + }, + PURE_NETTY { + @Override + Stack openChannel() { + HttpObjectAggregator aggregator = new HttpObjectAggregator(10_000_000); + aggregator.setMaxCumulationBufferComponents(100000); + EmbeddedChannel channel = new EmbeddedChannel(); + channel.pipeline().addLast(new HttpServerCodec()); + channel.pipeline().addLast(aggregator); + channel.pipeline().addLast(new RequestHandler()); + return new Stack(channel, () -> { + }); + } + }; + + abstract Stack openChannel(); + } + + private record Stack(EmbeddedChannel serverChannel, AutoCloseable closeable) { + } + +} diff --git a/benchmarks/src/jmh/java/io/micronaut/http/server/stack/JmhFastThreadLocalExecutor.java b/benchmarks/src/jmh/java/io/micronaut/http/server/stack/JmhFastThreadLocalExecutor.java new file mode 100644 index 00000000000..43850330c9b --- /dev/null +++ b/benchmarks/src/jmh/java/io/micronaut/http/server/stack/JmhFastThreadLocalExecutor.java @@ -0,0 +1,26 @@ +package io.micronaut.http.server.stack; + +import io.micronaut.core.annotation.NonNull; +import io.netty.util.concurrent.FastThreadLocalThread; + +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +public final class JmhFastThreadLocalExecutor extends ThreadPoolExecutor { + public JmhFastThreadLocalExecutor(int maxThreads, String prefix) { + super(maxThreads, maxThreads, + 60L, TimeUnit.SECONDS, + new SynchronousQueue<>(), + new ThreadFactory() { + final AtomicInteger counter = new AtomicInteger(); + + @Override + public Thread newThread(@NonNull Runnable r) { + return new FastThreadLocalThread(r, prefix + "-jmh-worker-ftl-" + counter.incrementAndGet()); + } + }); + } +} diff --git a/benchmarks/src/jmh/java/io/micronaut/http/server/stack/RequestHandler.java b/benchmarks/src/jmh/java/io/micronaut/http/server/stack/RequestHandler.java new file mode 100644 index 00000000000..9354538a47b --- /dev/null +++ b/benchmarks/src/jmh/java/io/micronaut/http/server/stack/RequestHandler.java @@ -0,0 +1,120 @@ +package io.micronaut.http.server.stack; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.ObjectReader; +import com.fasterxml.jackson.databind.ObjectWriter; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.ByteBufOutputStream; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslProvider; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.util.List; + +@ChannelHandler.Sharable +final class RequestHandler extends SimpleChannelInboundHandler { + private final ObjectMapper objectMapper = new ObjectMapper(); + private final ObjectReader reader = objectMapper.readerFor(SearchController.Input.class); + private final ObjectWriter writerResult = objectMapper.writerFor(SearchController.Result.class); + private final ObjectWriter writerStatus = objectMapper.writerFor(Status.class); + + @Override + protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest msg) throws Exception { + FullHttpResponse response = computeResponse(ctx, msg); + response.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE); + response.headers().add(HttpHeaderNames.CONTENT_LENGTH, response.content().readableBytes()); + ctx.writeAndFlush(response, ctx.voidPromise()); + ctx.read(); + } + + private FullHttpResponse computeResponse(ChannelHandlerContext ctx, FullHttpRequest msg) { + try { + String path = URI.create(msg.uri()).getPath(); + if (path.equals("/search/find")) { + return computeResponseSearch(ctx, msg); + } + if (path.equals("/status")) { + return computeResponseStatus(ctx, msg); + } + return new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.NOT_FOUND); + } catch (Exception e) { + e.printStackTrace(); + return new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR); + } + } + + private FullHttpResponse computeResponseSearch(ChannelHandlerContext ctx, FullHttpRequest msg) throws IOException { + if (!msg.method().equals(HttpMethod.POST)) { + return new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.METHOD_NOT_ALLOWED); + } + if (!msg.headers().contains(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON, true)) { + return new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE); + } + + ByteBuf content = msg.content(); + SearchController.Input input; + if (content.hasArray()) { + input = reader.readValue(content.array(), content.readerIndex() + content.arrayOffset(), content.readableBytes()); + } else { + input = reader.readValue((InputStream) new ByteBufInputStream(content)); + } + + SearchController.Result result = find(input.haystack(), input.needle()); + if (result == null) { + return new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.NOT_FOUND); + } else { + return serialize(ctx, writerResult, result); + } + } + + private FullHttpResponse serialize(ChannelHandlerContext ctx, ObjectWriter writer, Object result) throws IOException { + ByteBuf buffer = ctx.alloc().buffer(); + writer.writeValue((OutputStream) new ByteBufOutputStream(buffer), result); + DefaultFullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, buffer); + response.headers().add(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON); + return response; + } + + private FullHttpResponse computeResponseStatus(ChannelHandlerContext ctx, FullHttpRequest msg) throws IOException { + if (!msg.method().equals(HttpMethod.GET)) { + return new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.METHOD_NOT_ALLOWED); + } + + Status status = new Status( + ctx.channel().getClass().getName(), + SslContext.defaultServerProvider() + ); + + return serialize(ctx, writerStatus, status); + } + + private static SearchController.Result find(List haystack, String needle) { + for (int listIndex = 0; listIndex < haystack.size(); listIndex++) { + String s = haystack.get(listIndex); + int stringIndex = s.indexOf(needle); + if (stringIndex != -1) { + return new SearchController.Result(listIndex, stringIndex); + } + } + return null; + } + + record Status(String channelImplementation, + SslProvider sslProvider) { + } +} diff --git a/benchmarks/src/jmh/java/io/micronaut/http/server/stack/SearchController.java b/benchmarks/src/jmh/java/io/micronaut/http/server/stack/SearchController.java new file mode 100644 index 00000000000..ac23997ad09 --- /dev/null +++ b/benchmarks/src/jmh/java/io/micronaut/http/server/stack/SearchController.java @@ -0,0 +1,39 @@ +package io.micronaut.http.server.stack; + +import io.micronaut.context.annotation.Requires; +import io.micronaut.core.annotation.Introspected; +import io.micronaut.http.HttpResponse; +import io.micronaut.http.MutableHttpResponse; +import io.micronaut.http.annotation.Body; +import io.micronaut.http.annotation.Controller; +import io.micronaut.http.annotation.Post; + +import java.util.List; + +@Controller("/search") +@Requires(property = "spec.name", value = "FullHttpStackBenchmark") +public class SearchController { + @Post("find") + public HttpResponse find(@Body Input input) { + return find(input.haystack, input.needle); + } + + private static MutableHttpResponse find(List haystack, String needle) { + for (int listIndex = 0; listIndex < haystack.size(); listIndex++) { + String s = haystack.get(listIndex); + int stringIndex = s.indexOf(needle); + if (stringIndex != -1) { + return HttpResponse.ok(new Result(listIndex, stringIndex)); + } + } + return HttpResponse.notFound(); + } + + @Introspected + record Input(List haystack, String needle) { + } + + @Introspected + record Result(int listIndex, int stringIndex) { + } +} diff --git a/core/src/main/java/io/micronaut/core/execution/DelayedExecutionFlow.java b/core/src/main/java/io/micronaut/core/execution/DelayedExecutionFlow.java new file mode 100644 index 00000000000..44490372deb --- /dev/null +++ b/core/src/main/java/io/micronaut/core/execution/DelayedExecutionFlow.java @@ -0,0 +1,44 @@ +/* + * Copyright 2017-2023 original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.micronaut.core.execution; + +import io.micronaut.core.annotation.Nullable; + +/** + * {@link ExecutionFlow} that can be completed similar to a + * {@link java.util.concurrent.CompletableFuture}. + * + * @param The type of this flow + */ +public sealed interface DelayedExecutionFlow extends ExecutionFlow permits DelayedExecutionFlowImpl { + static DelayedExecutionFlow create() { + return new DelayedExecutionFlowImpl<>(); + } + + /** + * Complete this flow normally. + * + * @param result The result value + */ + void complete(@Nullable T result); + + /** + * Complete this flow with an exception. + * + * @param exc The exception + */ + void completeExceptionally(Throwable exc); +} diff --git a/core/src/main/java/io/micronaut/core/execution/DelayedExecutionFlowImpl.java b/core/src/main/java/io/micronaut/core/execution/DelayedExecutionFlowImpl.java new file mode 100644 index 00000000000..10ccc9fe7b1 --- /dev/null +++ b/core/src/main/java/io/micronaut/core/execution/DelayedExecutionFlowImpl.java @@ -0,0 +1,449 @@ +/* + * Copyright 2017-2023 original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.micronaut.core.execution; + +import io.micronaut.core.annotation.NonNull; +import io.micronaut.core.annotation.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.function.BiConsumer; +import java.util.function.Function; +import java.util.function.Supplier; + +@SuppressWarnings("rawtypes") +final class DelayedExecutionFlowImpl implements DelayedExecutionFlow { + private static final Logger LOG = LoggerFactory.getLogger(DelayedExecutionFlowImpl.class); + + /** + * Object used as a stand-in for a {@code null} completion to distinguish it from the + * uncompleted state. + */ + private static final Object NULL = new Object(); + + /** + * The head of the linked list of steps in this flow. + */ + private final Head head = new Head(); + /** + * The tail of the linked list of steps in this flow. + */ + private Step tail = head; + + /** + * Perform the given step with the given item. Continue on until there is either no more steps, + * either because onComplete was hit or because the consumer is not finished adding all the + * steps, or until a step does not finish immediately, e.g. flatMap returning a non-immediate + * flow. + * + * @param step The step to execute first + * @param item The input item for the step + */ + private static void work(Step step, Object item) { + while (true) { + item = step.apply(item); + if (item == null) { + // step suspended + break; + } + step = step.atomicSetOutput(item); + if (step == null) { + break; + } + } + } + + /** + * Complete this flow with the given result. + * + * @param result The result object. May be a {@link Failure}, {@link #NULL}, or any other + * successful value. + */ + private void complete0(@NonNull Object result) { + Step immediateStep = head.atomicSetOutput(result); + if (immediateStep != null) { + work(immediateStep, result); + } + } + + @Override + public void complete(T result) { + complete0(result == null ? NULL : result); + } + + @Override + public void completeExceptionally(Throwable exc) { + complete0(new Failure(exc)); + } + + /** + * Add a new step to this flow. + * + * @param next The new step + * @param The return type of the flow for generics support + * @return This flow + */ + @SuppressWarnings("unchecked") + private ExecutionFlow next(Step next) { + Step oldTail = tail; + tail = next; + Object output = oldTail.atomicSetNext(next); + if (output != null) { + work(next, output); + } + return (ExecutionFlow) this; + } + + @Override + public ExecutionFlow map(Function transformer) { + return next(new Map(transformer)); + } + + @SuppressWarnings("unchecked") + @Override + public ExecutionFlow flatMap(Function> transformer) { + return next(new FlatMap((Function) transformer)); + } + + @Override + public ExecutionFlow then(Supplier> supplier) { + return next(new Then<>(supplier)); + } + + @Override + public ExecutionFlow onErrorResume(Function> fallback) { + return next(new OnErrorResume(fallback)); + } + + @Override + public ExecutionFlow putInContext(String key, Object value) { + return this; + } + + @Override + public void onComplete(BiConsumer fn) { + next(new OnComplete<>(fn)); + } + + @SuppressWarnings("unchecked") + @Nullable + @Override + public ImperativeExecutionFlow tryComplete() { + Object tailOutput = tail.output; + if (tailOutput != null) { + if (tailOutput instanceof Failure failure) { + return (ImperativeExecutionFlow) new ImperativeExecutionFlowImpl(null, failure.t); + } else if (tailOutput == NULL) { + return (ImperativeExecutionFlow) new ImperativeExecutionFlowImpl(null, null); + } else { + return (ImperativeExecutionFlow) new ImperativeExecutionFlowImpl(tailOutput, null); + } + } else { + return null; + } + } + + /** + * Special wrapper for exception results. + * + * @param t The exception of the failure + */ + private record Failure(Throwable t) { + } + + private abstract static class Step { + /** + * The next step to take, or {@code null} if there is no next step yet. + */ + private volatile Step next; + /** + * The output of this step, or {@code null} if this step has not completed yet. + */ + private volatile Object output; + + /** + * Apply this step. Must call one of {@link #returnImmediate}, {@link #returnFlow}, + * {@link #returnError} or {@link #returnUnchanged}. + * + * @param input The input for the step + * @return The return value of the {@code return*} method called + */ + abstract Object apply(Object input); + + /** + * Atomically set the output of this step. If this returns non-null, the caller must call + * {@link #work(Step, Object)} with the returned step. + * + * @param output The output of this step + * @return The next step to execute using {@link #work(Step, Object)}, or {@code null} if + * the next step will be executed later + */ + @Nullable + final Step atomicSetOutput(Object output) { + if (this.output != null) { + // this is a best-effort check, the output field isn't always set + throw new IllegalStateException("Already completed"); + } + Step next = this.next; + if (next != null) { + return next; + } + this.output = output; + next = this.next; + if (next != null) { + // another thread completed at the same time! one or both threads hit this sync + // block. + synchronized (this) { + // deconfliction path + next = this.next; + if (next != null) { + // our sync block was executed first, unset output so the other thread aborts + this.output = null; + return next; + } + } + } + // no next step yet. + return null; + } + + /** + * Atomically set the next step. If this returns non-null, the caller must call + * {@link #work(Step, Object)} with the returned output value. + * + * @param next The next step to execute + * @return The output value of this step, to be passed to {@link #work(Step, Object)}, or + * {@code null} if the output is not yet known and the given step will be executed later + */ + @Nullable + final Object atomicSetNext(Step next) { + if (this.next != null) { + // this is a best-effort check, the next field isn't always set + throw new IllegalStateException("Already added a next step"); + } + Object output = this.output; + if (output != null) { + return output; + } + this.next = next; + output = this.output; + if (output != null) { + // another thread completed at the same time! one or both threads hit this sync + // block. + synchronized (this) { + // deconfliction path + output = this.output; + if (output != null) { + // our sync block was executed first, unset next so the other thread aborts + this.next = null; + return output; + } + } + } + // no output yet. + return null; + } + + /** + * Return a flow from this step (e.g. from flatMap). + * + * @param outputFlow The flow to return + * @return The value to return from {@link #work} + */ + final Object returnFlow(ExecutionFlow outputFlow) { + ImperativeExecutionFlow complete = outputFlow.tryComplete(); + if (complete != null) { + Throwable error = complete.getError(); + if (error == null) { + return returnImmediate(complete.getValue()); + } else { + return returnError(error); + } + } + + outputFlow.onComplete((v, t) -> { + Object result; + if (t == null) { + result = v == null ? NULL : v; + } else { + result = new Failure(t); + } + Step step = atomicSetOutput(result); + if (step != null) { + work(step, result); + } + }); + return null; + } + + /** + * Return an immediate successful value from this step (e.g. from map). + * + * @param o The value to return + * @return The value to return from {@link #work} + */ + final Object returnImmediate(@Nullable Object o) { + return o == null ? NULL : o; + } + + /** + * Signal that this step made no change to the input (e.g. a {@code map} when the flow has + * an error). + * + * @param input The input passed to {@link #apply} + * @return The value to return from {@link #work} + */ + final Object returnUnchanged(Object input) { + return input; + } + + /** + * Return an immediate failed value from this step (e.g. from map). + * + * @param e The exception to return + * @return The value to return from {@link #work} + */ + final Object returnError(Throwable e) { + return new Failure(e); + } + } + + /** + * Mock step used as the head of the linked list of steps. + */ + private static final class Head extends Step { + @Override + Object apply(Object input) { + throw new UnsupportedOperationException(); + } + } + + private static final class Map extends Step { + private final Function transformer; + + private Map(Function transformer) { + this.transformer = transformer; + } + + @SuppressWarnings("unchecked") + @Override + Object apply(Object input) { + try { + if (input instanceof Failure) { + return returnUnchanged(input); + } else if (input == NULL) { + return returnImmediate(transformer.apply(null)); + } else { + return returnImmediate(transformer.apply(input)); + } + } catch (Exception e) { + return returnError(e); + } + } + } + + private static final class FlatMap extends Step { + private final Function transformer; + + private FlatMap(Function transformer) { + this.transformer = transformer; + } + + @Override + Object apply(Object input) { + if (input instanceof Failure) { + return returnUnchanged(input); + } else { + try { + if (input == NULL) { + return returnFlow(transformer.apply(null)); + } else { + return returnFlow(transformer.apply(input)); + } + } catch (Exception e) { + return returnError(e); + } + } + } + } + + private static final class Then extends Step { + private final Supplier> transformer; + + private Then(Supplier> transformer) { + this.transformer = transformer; + } + + @Override + Object apply(Object input) { + if (input instanceof Failure) { + return returnUnchanged(input); + } else { + try { + return returnFlow(transformer.get()); + } catch (Exception e) { + return returnError(e); + } + } + } + } + + private static final class OnErrorResume extends Step { + private final Function> fallback; + + private OnErrorResume(Function> fallback) { + this.fallback = fallback; + } + + @Override + Object apply(Object input) { + if (input instanceof Failure failure) { + try { + return returnFlow(fallback.apply(failure.t)); + } catch (Exception e) { + return returnError(e); + } + } else { + return returnUnchanged(input); + } + } + } + + private static final class OnComplete extends Step { + private final BiConsumer consumer; + + public OnComplete(BiConsumer consumer) { + this.consumer = consumer; + } + + @SuppressWarnings("unchecked") + @Override + Object apply(Object input) { + try { + if (input instanceof Failure failure) { + consumer.accept(null, failure.t); + } else if (input == NULL) { + consumer.accept(null, null); + } else { + consumer.accept((E) input, null); + } + } catch (Exception e) { + LOG.error("Failed to execute onComplete", e); + } + return null; + } + } +} diff --git a/core/src/main/java/io/micronaut/core/execution/ExecutionFlow.java b/core/src/main/java/io/micronaut/core/execution/ExecutionFlow.java index 892bb361f91..a9dea9633b2 100644 --- a/core/src/main/java/io/micronaut/core/execution/ExecutionFlow.java +++ b/core/src/main/java/io/micronaut/core/execution/ExecutionFlow.java @@ -83,7 +83,7 @@ static ExecutionFlow empty() { */ @NonNull static ExecutionFlow async(@NonNull Executor executor, @NonNull Supplier> supplier) { - CompletableFuture completableFuture = new CompletableFuture<>(); + DelayedExecutionFlow completableFuture = DelayedExecutionFlow.create(); executor.execute(() -> supplier.get().onComplete((t, throwable) -> { if (throwable != null) { if (throwable instanceof CompletionException completionException) { @@ -94,7 +94,7 @@ static ExecutionFlow async(@NonNull Executor executor, @NonNull Supplier< completableFuture.complete(t); } })); - return CompletableFutureExecutionFlow.just(completableFuture); + return completableFuture; } /** @@ -176,9 +176,10 @@ default CompletableFuture toCompletableFuture() { if (throwable instanceof CompletionException completionException) { throwable = completionException.getCause(); } - CompletableFuture.failedFuture(throwable); + completableFuture.completeExceptionally(throwable); + } else { + completableFuture.complete(value); } - CompletableFuture.completedFuture(value); }); return completableFuture; } diff --git a/core/src/test/groovy/io/micronaut/core/execution/DelayedExecutionFlowSpec.groovy b/core/src/test/groovy/io/micronaut/core/execution/DelayedExecutionFlowSpec.groovy new file mode 100644 index 00000000000..56c9c9a4dbe --- /dev/null +++ b/core/src/test/groovy/io/micronaut/core/execution/DelayedExecutionFlowSpec.groovy @@ -0,0 +1,81 @@ +package io.micronaut.core.execution + +import org.apache.groovy.internal.util.Function +import spock.lang.Specification + +import java.util.concurrent.CompletableFuture + +class DelayedExecutionFlowSpec extends Specification { + def "single thread permutations"(List orderOfCompletion) { + given: + List> futures = [null, new CompletableFuture(), new CompletableFuture(), new CompletableFuture()] + List results = ["step0", "step2", "step3", "step5"] + ExecutionFlow inputFlow = new DelayedExecutionFlowImpl<>() + ExecutionFlow flowStep2 = CompletableFutureExecutionFlow.just(futures[1]) + ExecutionFlow flowStep3 = CompletableFutureExecutionFlow.just(futures[2]) + ExecutionFlow flowStep5 = CompletableFutureExecutionFlow.just(futures[3]) + String output = null + List, ExecutionFlow>> permTestSteps = [ + (ExecutionFlow prev) -> prev.map { + assert it == "step0" + return "step1" + }, + (ExecutionFlow prev) -> prev.flatMap { + assert it == "step1" + return flowStep2 + }, + (ExecutionFlow prev) -> prev.then { + return flowStep3 + }, + (ExecutionFlow prev) -> prev.map { + assert it == "step3" + throw new RuntimeException("step4") + }, + (ExecutionFlow prev) -> prev.map { + throw new AssertionError("should not be called") + }, + (ExecutionFlow prev) -> prev.flatMap { + throw new AssertionError("should not be called") + }, + (ExecutionFlow prev) -> prev.then { + throw new AssertionError("should not be called") + }, + (ExecutionFlow prev) -> prev.onErrorResume { + assert it.message == "step4" + return flowStep5 + }, + (ExecutionFlow prev) -> prev.onComplete((s, t) -> output = s), + ] + + ExecutionFlow flow = inputFlow + for (int i = 0; i < permTestSteps.size(); i++) { + for (int j = 0; j < orderOfCompletion.size(); j++) { + if (orderOfCompletion[j] == i) { + if (j == 0) { + inputFlow.complete(results[j]) + } else { + futures[j].complete(results[j]) + } + } + } + flow = permTestSteps[i](flow) + } + + where: + orderOfCompletion << powerSet([0, 1, 2, 3, 4, 5, 6, 7, 8], 4) + } + + private static List> powerSet(List base, int exp) { + if (exp == 0) { + return [[]] + } + List> output = [] + List> next = powerSet(base, exp - 1) + for (T t : base) { + for (List head : next) { + output.add(head + t) + } + } + return output + } +} diff --git a/http-netty/build.gradle b/http-netty/build.gradle index d4edb612b48..3732901ad0d 100644 --- a/http-netty/build.gradle +++ b/http-netty/build.gradle @@ -21,6 +21,7 @@ dependencies { testImplementation project(":runtime") testImplementation project(":websocket") + testImplementation project(":jackson-databind") } spotless { diff --git a/http-netty/src/main/java/io/micronaut/http/netty/NettyHttpHeaders.java b/http-netty/src/main/java/io/micronaut/http/netty/NettyHttpHeaders.java index 0a46c1ebb34..e777fb3eb37 100644 --- a/http-netty/src/main/java/io/micronaut/http/netty/NettyHttpHeaders.java +++ b/http-netty/src/main/java/io/micronaut/http/netty/NettyHttpHeaders.java @@ -20,6 +20,7 @@ import io.micronaut.core.convert.ConversionService; import io.micronaut.core.type.MutableHeaders; import io.micronaut.http.HttpHeaderValues; +import io.micronaut.http.HttpHeaders; import io.micronaut.http.MediaType; import io.micronaut.http.MutableHttpHeaders; import io.netty.handler.codec.http.DefaultHttpHeaders; @@ -242,4 +243,17 @@ public ConversionService getConversionService() { public void setConversionService(ConversionService conversionService) { this.conversionService = conversionService; } + + @Override + public Optional contentType() { + // optimization to avoid ConversionService + String str = get(HttpHeaders.CONTENT_TYPE); + if (str != null) { + try { + return Optional.of(MediaType.of(str)); + } catch (IllegalArgumentException ignored) { + } + } + return Optional.empty(); + } } diff --git a/http-netty/src/main/java/io/micronaut/http/netty/reactive/HandlerPublisher.java b/http-netty/src/main/java/io/micronaut/http/netty/reactive/HandlerPublisher.java index 9fe95e86902..1c826cfb1c6 100644 --- a/http-netty/src/main/java/io/micronaut/http/netty/reactive/HandlerPublisher.java +++ b/http-netty/src/main/java/io/micronaut/http/netty/reactive/HandlerPublisher.java @@ -21,7 +21,6 @@ import io.netty.handler.codec.http.HttpContent; import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.EventExecutor; -import io.netty.util.internal.TypeParameterMatcher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import org.slf4j.Logger; @@ -32,7 +31,15 @@ import java.util.Queue; import java.util.concurrent.atomic.AtomicBoolean; -import static io.micronaut.http.netty.reactive.HandlerPublisher.State.*; +import static io.micronaut.http.netty.reactive.HandlerPublisher.State.BUFFERING; +import static io.micronaut.http.netty.reactive.HandlerPublisher.State.DEMANDING; +import static io.micronaut.http.netty.reactive.HandlerPublisher.State.DONE; +import static io.micronaut.http.netty.reactive.HandlerPublisher.State.DRAINING; +import static io.micronaut.http.netty.reactive.HandlerPublisher.State.IDLE; +import static io.micronaut.http.netty.reactive.HandlerPublisher.State.NO_CONTEXT; +import static io.micronaut.http.netty.reactive.HandlerPublisher.State.NO_SUBSCRIBER; +import static io.micronaut.http.netty.reactive.HandlerPublisher.State.NO_SUBSCRIBER_ERROR; +import static io.micronaut.http.netty.reactive.HandlerPublisher.State.NO_SUBSCRIBER_OR_CONTEXT; /** * Publisher for a Netty Handler. @@ -62,7 +69,7 @@ * @since 1.0 */ @Internal -public class HandlerPublisher extends ChannelDuplexHandler implements HotObservable { +public abstract class HandlerPublisher extends ChannelDuplexHandler implements HotObservable { private static final Logger LOG = LoggerFactory.getLogger(HandlerPublisher.class); /** * Used for buffering a completion signal. @@ -76,7 +83,6 @@ public String toString() { private final AtomicBoolean completed = new AtomicBoolean(false); private final EventExecutor executor; - private final TypeParameterMatcher matcher; private final Queue buffer = new LinkedList<>(); @@ -100,11 +106,9 @@ public String toString() { * with, if not, an exception will be thrown when the handler is registered. * * @param executor The executor to execute asynchronous events from the subscriber on. - * @param subscriberMessageType The type of message this publisher accepts. */ - public HandlerPublisher(EventExecutor executor, Class subscriberMessageType) { + public HandlerPublisher(EventExecutor executor) { this.executor = executor; - this.matcher = TypeParameterMatcher.get(subscriberMessageType); } @Override @@ -139,9 +143,7 @@ public void cancel() { * @param msg The message to check. * @return True if the message should be accepted. */ - protected boolean acceptInboundMessage(Object msg) { - return matcher.match(msg); - } + protected abstract boolean acceptInboundMessage(Object msg); /** * Override to handle when a subscriber cancels the subscription. diff --git a/http-netty/src/main/java/io/micronaut/http/netty/stream/HttpStreamsHandler.java b/http-netty/src/main/java/io/micronaut/http/netty/stream/HttpStreamsHandler.java index 2a6a659ec82..080d4c6638b 100644 --- a/http-netty/src/main/java/io/micronaut/http/netty/stream/HttpStreamsHandler.java +++ b/http-netty/src/main/java/io/micronaut/http/netty/stream/HttpStreamsHandler.java @@ -230,7 +230,12 @@ public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exce currentlyStreamedMessage = inMsg; // It has a body, stream it - HandlerPublisher publisher = new HandlerPublisher(ctx.executor(), HttpContent.class) { + HandlerPublisher publisher = new HandlerPublisher(ctx.executor()) { + @Override + protected boolean acceptInboundMessage(Object msg) { + return msg instanceof HttpContent; + } + @Override protected void cancelled() { if (ctx.executor().inEventLoop()) { diff --git a/http-netty/src/main/java/io/micronaut/http/netty/stream/HttpStreamsServerHandler.java b/http-netty/src/main/java/io/micronaut/http/netty/stream/HttpStreamsServerHandler.java index a2d4cf67584..d0baea2d300 100644 --- a/http-netty/src/main/java/io/micronaut/http/netty/stream/HttpStreamsServerHandler.java +++ b/http-netty/src/main/java/io/micronaut/http/netty/stream/HttpStreamsServerHandler.java @@ -246,7 +246,12 @@ private void handleWebSocketResponse(ChannelHandlerContext ctx, HttpResponse mes } else { // First, insert new handlers in the chain after us for handling the websocket ChannelPipeline pipeline = ctx.pipeline(); - HandlerPublisher publisher = new HandlerPublisher<>(ctx.executor(), WebSocketFrame.class); + HandlerPublisher publisher = new HandlerPublisher<>(ctx.executor()) { + @Override + protected boolean acceptInboundMessage(Object msg) { + return msg instanceof WebSocketFrame; + } + }; HandlerSubscriber subscriber = new HandlerSubscriber<>(ctx.executor()); pipeline.addAfter(ctx.executor(), ctx.name(), "websocket-subscriber", subscriber); pipeline.addAfter(ctx.executor(), ctx.name(), "websocket-publisher", publisher); diff --git a/http-netty/src/test/groovy/io/micronaut/http/netty/reactive/HandlerPublisherSpec.groovy b/http-netty/src/test/groovy/io/micronaut/http/netty/reactive/HandlerPublisherSpec.groovy index 5b2fb717a4f..34e046b6592 100644 --- a/http-netty/src/test/groovy/io/micronaut/http/netty/reactive/HandlerPublisherSpec.groovy +++ b/http-netty/src/test/groovy/io/micronaut/http/netty/reactive/HandlerPublisherSpec.groovy @@ -27,7 +27,12 @@ class HandlerPublisherSpec extends Specification { */ def embeddedChannel = new EmbeddedChannel() - def handlerPublisher = new HandlerPublisher(embeddedChannel.eventLoop(), Object) + def handlerPublisher = new HandlerPublisher(embeddedChannel.eventLoop()) { + @Override + protected boolean acceptInboundMessage(Object msg) { + return true + } + } boolean killOnNextRead = false embeddedChannel.pipeline().addLast(new ChannelDuplexHandler() { @Override diff --git a/http-server-netty/src/main/java/io/micronaut/http/server/netty/HttpPipelineBuilder.java b/http-server-netty/src/main/java/io/micronaut/http/server/netty/HttpPipelineBuilder.java index e8c66825ecc..1a4b638562b 100644 --- a/http-server-netty/src/main/java/io/micronaut/http/server/netty/HttpPipelineBuilder.java +++ b/http-server-netty/src/main/java/io/micronaut/http/server/netty/HttpPipelineBuilder.java @@ -550,7 +550,10 @@ private void insertMicronautHandlers(boolean zeroCopySupported) { pipeline.addLast(ChannelPipelineCustomizer.HANDLER_HTTP_COMPRESSOR, contentCompressor); pipeline.addLast(ChannelPipelineCustomizer.HANDLER_HTTP_DECOMPRESSOR, new HttpContentDecompressor()); - pipeline.addLast(NettyServerWebSocketUpgradeHandler.COMPRESSION_HANDLER, new WebSocketServerCompressionHandler()); + Optional>> webSocketUpgradeHandler = embeddedServices.getWebSocketUpgradeHandler(server); + if (webSocketUpgradeHandler.isPresent()) { + pipeline.addLast(NettyServerWebSocketUpgradeHandler.COMPRESSION_HANDLER, new WebSocketServerCompressionHandler()); + } pipeline.addLast(ChannelPipelineCustomizer.HANDLER_HTTP_STREAM, new HttpStreamsServerHandler()); pipeline.addLast(ChannelPipelineCustomizer.HANDLER_HTTP_CHUNK, new ChunkedWriteHandler()); pipeline.addLast(HttpRequestDecoder.ID, requestDecoder); @@ -561,9 +564,7 @@ private void insertMicronautHandlers(boolean zeroCopySupported) { pipeline.addLast("request-certificate-handler", new HttpRequestCertificateHandler(sslHandler)); } pipeline.addLast(HttpResponseEncoder.ID, responseEncoder); - embeddedServices.getWebSocketUpgradeHandler(server).ifPresent(websocketHandler -> - pipeline.addLast(ChannelPipelineCustomizer.HANDLER_WEBSOCKET_UPGRADE, websocketHandler) - ); + webSocketUpgradeHandler.ifPresent(h -> pipeline.addLast(ChannelPipelineCustomizer.HANDLER_WEBSOCKET_UPGRADE, h)); pipeline.addLast(ChannelPipelineCustomizer.HANDLER_MICRONAUT_INBOUND, routingInBoundHandler); } diff --git a/http-server-netty/src/main/java/io/micronaut/http/server/netty/NettyHttpRequest.java b/http-server-netty/src/main/java/io/micronaut/http/server/netty/NettyHttpRequest.java index 119a831120b..96bca32ad2b 100644 --- a/http-server-netty/src/main/java/io/micronaut/http/server/netty/NettyHttpRequest.java +++ b/http-server-netty/src/main/java/io/micronaut/http/server/netty/NettyHttpRequest.java @@ -287,7 +287,7 @@ public MutableConvertibleValues getAttributes() { synchronized (this) { // double check attributes = this.attributes; if (attributes == null) { - attributes = new MutableConvertibleValuesMap<>(new HashMap<>(4)); + attributes = new MutableConvertibleValuesMap<>(new HashMap<>(8)); this.attributes = attributes; } } diff --git a/http-server-netty/src/main/java/io/micronaut/http/server/netty/NettyRequestLifecycle.java b/http-server-netty/src/main/java/io/micronaut/http/server/netty/NettyRequestLifecycle.java index 46984fd23ac..3d56eda33cf 100644 --- a/http-server-netty/src/main/java/io/micronaut/http/server/netty/NettyRequestLifecycle.java +++ b/http-server-netty/src/main/java/io/micronaut/http/server/netty/NettyRequestLifecycle.java @@ -17,7 +17,7 @@ import io.micronaut.core.annotation.Internal; import io.micronaut.core.annotation.Nullable; -import io.micronaut.core.execution.CompletableFutureExecutionFlow; +import io.micronaut.core.execution.DelayedExecutionFlow; import io.micronaut.core.execution.ExecutionFlow; import io.micronaut.core.type.Argument; import io.micronaut.http.HttpMethod; @@ -54,7 +54,6 @@ import java.util.Collection; import java.util.List; import java.util.Optional; -import java.util.concurrent.CompletableFuture; @Internal final class NettyRequestLifecycle extends RequestLifecycle { @@ -151,7 +150,7 @@ private ExecutionFlow> waitForBody(RouteMatch routeMatch) { HttpContentProcessor processor = rib.httpContentProcessorResolver.resolve(nettyRequest, routeMatch); StreamingDataSubscriber pr = new StreamingDataSubscriber(completer, processor); ((StreamedHttpRequest) nettyRequest.getNativeRequest()).subscribe(pr); - return CompletableFutureExecutionFlow.just(pr.completion); + return pr.completion; } void handleException(Throwable cause) { @@ -187,7 +186,8 @@ private boolean shouldReadBody(RouteMatch routeMatch) { } private static class StreamingDataSubscriber implements Subscriber { - final CompletableFuture> completion = new CompletableFuture<>(); + final DelayedExecutionFlow> completion = DelayedExecutionFlow.create(); + private boolean completed = false; private final List bufferList = new ArrayList<>(1); private final HttpContentProcessor contentProcessor; @@ -272,7 +272,10 @@ private void handleError(Throwable t) { // this may drop the exception if the route has already been executed. However, that is // only the case if there are publisher parameters, and those will still receive the // failure. Hopefully. - completion.completeExceptionally(t); + if (!completed) { + completion.completeExceptionally(t); + completed = true; + } downstreamDone = true; } @@ -298,6 +301,7 @@ public void onComplete() { private void executeRoute() { completion.complete(completer.routeMatch); + completed = true; } } } diff --git a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/cors/CorsFilterSpec.groovy b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/cors/CorsFilterSpec.groovy index ea793220e50..65ca22d9cee 100644 --- a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/cors/CorsFilterSpec.groovy +++ b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/cors/CorsFilterSpec.groovy @@ -17,12 +17,16 @@ package io.micronaut.http.server.netty.cors import io.micronaut.context.ApplicationContext import io.micronaut.core.annotation.Nullable -import io.micronaut.core.async.publisher.Publishers import io.micronaut.core.util.StringUtils -import io.micronaut.http.* +import io.micronaut.http.HttpAttributes +import io.micronaut.http.HttpHeaders +import io.micronaut.http.HttpMethod +import io.micronaut.http.HttpRequest +import io.micronaut.http.HttpResponse +import io.micronaut.http.HttpStatus +import io.micronaut.http.MutableHttpResponse import io.micronaut.http.annotation.Controller import io.micronaut.http.annotation.Get -import io.micronaut.http.filter.ServerFilterChain import io.micronaut.http.server.HttpServerConfiguration import io.micronaut.http.server.cors.CorsFilter import io.micronaut.http.server.cors.CorsOriginConfiguration @@ -32,8 +36,6 @@ import io.micronaut.web.router.RouteMatch import io.micronaut.web.router.Router import io.micronaut.web.router.UriRouteMatch import org.apache.http.client.utils.URIBuilder -import org.reactivestreams.Publisher -import reactor.core.publisher.Mono import spock.lang.AutoCleanup import spock.lang.Shared import spock.lang.Specification @@ -41,7 +43,15 @@ import spock.lang.Unroll import java.util.stream.Collectors -import static io.micronaut.http.HttpHeaders.* +import static io.micronaut.http.HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS +import static io.micronaut.http.HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS +import static io.micronaut.http.HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS +import static io.micronaut.http.HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN +import static io.micronaut.http.HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS +import static io.micronaut.http.HttpHeaders.ACCESS_CONTROL_MAX_AGE +import static io.micronaut.http.HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD +import static io.micronaut.http.HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS +import static io.micronaut.http.HttpHeaders.VARY class CorsFilterSpec extends Specification { @@ -56,7 +66,7 @@ class CorsFilterSpec extends Specification { HttpRequest request = createRequest(null as String) when: - Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + Optional> result = filterOk(corsHandler, request) then: "the request is passed through" result.isPresent() @@ -79,7 +89,7 @@ class CorsFilterSpec extends Specification { CorsFilter corsHandler = buildCorsHandler(config) when: - Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + Optional> result = filterOk(corsHandler, request) then: result.isPresent() @@ -104,7 +114,7 @@ class CorsFilterSpec extends Specification { CorsFilter corsHandler = buildCorsHandler(config) when: - Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + Optional> result = filterOk(corsHandler, request) then: result.isPresent() @@ -146,7 +156,7 @@ class CorsFilterSpec extends Specification { CorsFilter corsHandler = buildCorsHandler(config) when: - Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + Optional> result = filterOk(corsHandler, request) then: result.isPresent() @@ -172,7 +182,7 @@ class CorsFilterSpec extends Specification { CorsFilter corsHandler = buildCorsHandler(config) when: - Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + Optional> result = filterOk(corsHandler, request) then: result.isPresent() @@ -219,7 +229,7 @@ class CorsFilterSpec extends Specification { CorsFilter corsHandler = buildCorsHandler(config) when: - Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + Optional> result = filterOk(corsHandler, request) then: result.isPresent() @@ -260,7 +270,7 @@ class CorsFilterSpec extends Specification { when: - Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + Optional> result = filterOk(corsHandler, request) then: result.isPresent() @@ -301,7 +311,7 @@ class CorsFilterSpec extends Specification { } when: - Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + Optional> result = filterOk(corsHandler, request) then: notThrown(NullPointerException) @@ -339,7 +349,7 @@ class CorsFilterSpec extends Specification { CorsFilter corsHandler = buildCorsHandler(config) when: - Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + Optional> result = filterOk(corsHandler, request) then: result.isPresent() @@ -383,7 +393,7 @@ class CorsFilterSpec extends Specification { CorsFilter corsHandler = buildCorsHandler(config) when: - Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + Optional> result = filterOk(corsHandler, request) then: result.isPresent() @@ -431,7 +441,7 @@ class CorsFilterSpec extends Specification { request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route->route.getHttpMethod()).collect(Collectors.toList())) when: - Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + Optional> result = filterOk(corsHandler, request) then: result.isPresent() @@ -481,7 +491,7 @@ class CorsFilterSpec extends Specification { CorsFilter corsHandler = buildCorsHandler(config) when: - Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + Optional> result = filterOk(corsHandler, request) then: result.isPresent() @@ -522,7 +532,7 @@ class CorsFilterSpec extends Specification { request.getAttribute(HttpAttributes.AVAILABLE_HTTP_METHODS, _) >> Optional.of(routes.stream().map(route->route.getHttpMethod()).collect(Collectors.toList())) when: - Optional> result = Mono.from(corsHandler.doFilter(request, okChain())).blockOptional() + Optional> result = filterOk(corsHandler, request) then: result.isPresent() @@ -554,13 +564,14 @@ class CorsFilterSpec extends Specification { } } - private ServerFilterChain okChain() { - new ServerFilterChain() { - @Override - Publisher> proceed(HttpRequest req) { - Publishers.just(HttpResponse.ok()) - } + private Optional> filterOk(CorsFilter filter, HttpRequest req) { + def earlyResponse = filter.filterRequest(req) + if (earlyResponse != null) { + return Optional.of(earlyResponse) } + MutableHttpResponse response = HttpResponse.ok() + filter.filterResponse(req, response) + return Optional.of(response) } private HttpServerConfiguration.CorsConfiguration enabledCorsConfiguration(Map corsConfigurationMap = null) { diff --git a/http-server/src/main/java/io/micronaut/http/server/cors/CorsFilter.java b/http-server/src/main/java/io/micronaut/http/server/cors/CorsFilter.java index f8f16200434..48408994e68 100644 --- a/http-server/src/main/java/io/micronaut/http/server/cors/CorsFilter.java +++ b/http-server/src/main/java/io/micronaut/http/server/cors/CorsFilter.java @@ -15,13 +15,14 @@ */ package io.micronaut.http.server.cors; +import io.micronaut.core.annotation.Internal; import io.micronaut.core.annotation.NonNull; import io.micronaut.core.annotation.Nullable; -import io.micronaut.core.async.publisher.Publishers; import io.micronaut.core.convert.ArgumentConversionContext; import io.micronaut.core.convert.ConversionContext; import io.micronaut.core.convert.ImmutableArgumentConversionContext; import io.micronaut.core.io.socket.SocketUtils; +import io.micronaut.core.order.Ordered; import io.micronaut.core.util.StringUtils; import io.micronaut.http.HttpHeaders; import io.micronaut.http.HttpMethod; @@ -29,14 +30,13 @@ import io.micronaut.http.HttpResponse; import io.micronaut.http.HttpStatus; import io.micronaut.http.MutableHttpResponse; -import io.micronaut.http.annotation.Filter; -import io.micronaut.http.filter.HttpServerFilter; -import io.micronaut.http.filter.ServerFilterChain; +import io.micronaut.http.annotation.RequestFilter; +import io.micronaut.http.annotation.ResponseFilter; +import io.micronaut.http.annotation.ServerFilter; import io.micronaut.http.filter.ServerFilterPhase; import io.micronaut.http.server.HttpServerConfiguration; import io.micronaut.http.server.util.HttpHostResolver; import org.jetbrains.annotations.NotNull; -import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -49,7 +49,16 @@ import java.util.stream.Collectors; import static io.micronaut.http.HttpAttributes.AVAILABLE_HTTP_METHODS; -import static io.micronaut.http.HttpHeaders.*; +import static io.micronaut.http.HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS; +import static io.micronaut.http.HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS; +import static io.micronaut.http.HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS; +import static io.micronaut.http.HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN; +import static io.micronaut.http.HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS; +import static io.micronaut.http.HttpHeaders.ACCESS_CONTROL_MAX_AGE; +import static io.micronaut.http.HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS; +import static io.micronaut.http.HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD; +import static io.micronaut.http.HttpHeaders.ORIGIN; +import static io.micronaut.http.HttpHeaders.VARY; import static io.micronaut.http.annotation.Filter.MATCH_ALL_PATTERN; /** @@ -59,8 +68,8 @@ * @author Graeme Rocher * @since 1.0 */ -@Filter(MATCH_ALL_PATTERN) -public class CorsFilter implements HttpServerFilter { +@ServerFilter(MATCH_ALL_PATTERN) +public class CorsFilter implements Ordered { private static final Logger LOG = LoggerFactory.getLogger(CorsFilter.class); private static final ArgumentConversionContext CONVERSION_CONTEXT_HTTP_METHOD = ImmutableArgumentConversionContext.of(HttpMethod.class); @@ -79,17 +88,19 @@ public CorsFilter(HttpServerConfiguration.CorsConfiguration corsConfiguration, this.httpHostResolver = httpHostResolver; } - @Override - public Publisher> doFilter(HttpRequest request, ServerFilterChain chain) { + @RequestFilter + @Nullable + @Internal + public final HttpResponse filterRequest(HttpRequest request) { String origin = request.getHeaders().getOrigin().orElse(null); if (origin == null) { LOG.trace("Http Header " + HttpHeaders.ORIGIN + " not present. Proceeding with the request."); - return chain.proceed(request); + return null; // proceed } CorsOriginConfiguration corsOriginConfiguration = getConfiguration(request).orElse(null); if (corsOriginConfiguration != null) { if (CorsUtil.isPreflightRequest(request)) { - return handlePreflightRequest(request, chain, corsOriginConfiguration); + return handlePreflightRequest(request, corsOriginConfiguration); } if (!validateMethodToMatch(request, corsOriginConfiguration).isPresent()) { return forbidden(); @@ -98,13 +109,25 @@ public Publisher> doFilter(HttpRequest request, Server LOG.trace("The resolved configuration allows any origin. To prevent drive-by-localhost attacks the request is forbidden"); return forbidden(); } - return Publishers.then(chain.proceed(request), resp -> decorateResponseWithHeaders(request, resp, corsOriginConfiguration)); + return null; // proceed } else if (shouldDenyToPreventDriveByLocalhostAttack(origin, request)) { LOG.trace("the request specifies an origin different than localhost. To prevent drive-by-localhost attacks the request is forbidden"); return forbidden(); } LOG.trace("CORS configuration not found for {} origin", origin); - return chain.proceed(request); + return null; // proceed + } + + @ResponseFilter + @Internal + public final void filterResponse(HttpRequest request, MutableHttpResponse response) { + CorsOriginConfiguration corsOriginConfiguration = getConfiguration(request).orElse(null); + if (corsOriginConfiguration != null) { + if (CorsUtil.isPreflightRequest(request)) { + decorateResponseWithHeadersForPreflightRequest(request, response, corsOriginConfiguration); + } + decorateResponseWithHeaders(request, response, corsOriginConfiguration); + } } /** @@ -347,8 +370,8 @@ private boolean hasAllowedHeaders(@NonNull HttpRequest request, @NonNull Cors } @NotNull - private static Publisher> forbidden() { - return Publishers.just(HttpResponse.status(HttpStatus.FORBIDDEN)); + private static MutableHttpResponse forbidden() { + return HttpResponse.status(HttpStatus.FORBIDDEN); } @NonNull @@ -375,24 +398,20 @@ private void decorateResponseWithHeaders(@NonNull HttpRequest request, } @NonNull - private Publisher> handlePreflightRequest(@NonNull HttpRequest request, - @NonNull ServerFilterChain chain, + private MutableHttpResponse handlePreflightRequest(@NonNull HttpRequest request, @NonNull CorsOriginConfiguration corsOriginConfiguration) { Optional statusOptional = validatePreflightRequest(request, corsOriginConfiguration); if (statusOptional.isPresent()) { HttpStatus status = statusOptional.get(); if (status.getCode() >= 400) { - return Publishers.just(HttpResponse.status(status)); + return HttpResponse.status(status); } MutableHttpResponse resp = HttpResponse.status(status); decorateResponseWithHeadersForPreflightRequest(request, resp, corsOriginConfiguration); decorateResponseWithHeaders(request, resp, corsOriginConfiguration); - return Publishers.just(resp); + return resp; } - return Publishers.then(chain.proceed(request), resp -> { - decorateResponseWithHeadersForPreflightRequest(request, resp, corsOriginConfiguration); - decorateResponseWithHeaders(request, resp, corsOriginConfiguration); - }); + return null; } @NonNull diff --git a/http/src/main/java/io/micronaut/http/reactive/execution/ReactorExecutionFlowImpl.java b/http/src/main/java/io/micronaut/http/reactive/execution/ReactorExecutionFlowImpl.java index 74fb7d33722..7ece13d7075 100644 --- a/http/src/main/java/io/micronaut/http/reactive/execution/ReactorExecutionFlowImpl.java +++ b/http/src/main/java/io/micronaut/http/reactive/execution/ReactorExecutionFlowImpl.java @@ -17,7 +17,6 @@ import io.micronaut.core.annotation.Internal; import io.micronaut.core.annotation.Nullable; -import io.micronaut.core.execution.CompletableFutureExecutionFlow; import io.micronaut.core.execution.ExecutionFlow; import io.micronaut.core.execution.ImperativeExecutionFlow; import org.reactivestreams.Publisher; @@ -129,8 +128,6 @@ public ImperativeExecutionFlow tryComplete() { static Mono toMono(ExecutionFlow next) { if (next instanceof ReactorExecutionFlowImpl reactiveFlowImpl) { return reactiveFlowImpl.value; - } else if (next instanceof CompletableFutureExecutionFlow completableFutureFlow) { - return Mono.fromCompletionStage(completableFutureFlow.toCompletableFuture()); } else if (next instanceof ImperativeExecutionFlow imperativeFlow) { Mono m; if (imperativeFlow.getError() != null) { @@ -150,8 +147,9 @@ static Mono toMono(ExecutionFlow next) { }); } return m; + } else { + return Mono.fromCompletionStage(next.toCompletableFuture()); } - throw new IllegalStateException(); } static Mono toMono(Supplier> next) {