diff --git a/CHANGELOG.md b/CHANGELOG.md index 1be3d3f53f2d6..8e7fa8b5547f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Added release notes for 1.3.5 ([#4343](https://github.com/opensearch-project/OpenSearch/pull/4343)) - Added release notes for 2.2.1 ([#4344](https://github.com/opensearch-project/OpenSearch/pull/4344)) - Label configuration for dependabot PRs ([#4348](https://github.com/opensearch-project/OpenSearch/pull/4348)) +- Support for HTTP/2 (server-side) ([#3847](https://github.com/opensearch-project/OpenSearch/pull/3847)) ### Changed - Dependency updates (httpcore, mockito, slf4j, httpasyncclient, commons-codec) ([#4308](https://github.com/opensearch-project/OpenSearch/pull/4308)) diff --git a/modules/transport-netty4/build.gradle b/modules/transport-netty4/build.gradle index b72cb6d868d79..5d2047d7f18a2 100644 --- a/modules/transport-netty4/build.gradle +++ b/modules/transport-netty4/build.gradle @@ -58,6 +58,7 @@ dependencies { api "io.netty:netty-buffer:${versions.netty}" api "io.netty:netty-codec:${versions.netty}" api "io.netty:netty-codec-http:${versions.netty}" + api "io.netty:netty-codec-http2:${versions.netty}" api "io.netty:netty-common:${versions.netty}" api "io.netty:netty-handler:${versions.netty}" api "io.netty:netty-resolver:${versions.netty}" diff --git a/modules/transport-netty4/licenses/netty-codec-http2-4.1.79.Final.jar.sha1 b/modules/transport-netty4/licenses/netty-codec-http2-4.1.79.Final.jar.sha1 new file mode 100644 index 0000000000000..f2989024cfce1 --- /dev/null +++ b/modules/transport-netty4/licenses/netty-codec-http2-4.1.79.Final.jar.sha1 @@ -0,0 +1 @@ +0eeffab0cd5efb699d5e4ab9b694d32fef6694b3 \ No newline at end of file diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4Http2IT.java b/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4Http2IT.java new file mode 100644 index 0000000000000..1424b392af8e7 --- /dev/null +++ b/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4Http2IT.java @@ -0,0 +1,62 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.http.netty4; + +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.util.ReferenceCounted; +import org.opensearch.OpenSearchNetty4IntegTestCase; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.http.HttpServerTransport; +import org.opensearch.test.OpenSearchIntegTestCase.ClusterScope; +import org.opensearch.test.OpenSearchIntegTestCase.Scope; + +import java.util.Collection; +import java.util.Locale; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.hasSize; + +@ClusterScope(scope = Scope.TEST, supportsDedicatedMasters = false, numDataNodes = 1) +public class Netty4Http2IT extends OpenSearchNetty4IntegTestCase { + + @Override + protected boolean addMockHttpTransport() { + return false; // enable http + } + + public void testThatNettyHttpServerSupportsHttp2() throws Exception { + String[] requests = new String[] { "/", "/_nodes/stats", "/", "/_cluster/state", "/" }; + + HttpServerTransport httpServerTransport = internalCluster().getInstance(HttpServerTransport.class); + TransportAddress[] boundAddresses = httpServerTransport.boundAddress().boundAddresses(); + TransportAddress transportAddress = randomFrom(boundAddresses); + + try (Netty4HttpClient nettyHttpClient = Netty4HttpClient.http2()) { + Collection responses = nettyHttpClient.get(transportAddress.address(), requests); + try { + assertThat(responses, hasSize(5)); + + Collection opaqueIds = Netty4HttpClient.returnOpaqueIds(responses); + assertOpaqueIdsInAnyOrder(opaqueIds); + } finally { + responses.forEach(ReferenceCounted::release); + } + } + } + + private void assertOpaqueIdsInAnyOrder(Collection opaqueIds) { + // check if opaque ids are present in any order, since for HTTP/2 we use streaming (no head of line blocking) + // and responses may come back at any order + int i = 0; + String msg = String.format(Locale.ROOT, "Expected list of opaque ids to be in any order, got [%s]", opaqueIds); + assertThat(msg, opaqueIds, containsInAnyOrder(IntStream.range(0, 5).mapToObj(Integer::toString).toArray())); + } + +} diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4HttpRequestSizeLimitIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4HttpRequestSizeLimitIT.java index 08df9259d475f..db76c0b145840 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4HttpRequestSizeLimitIT.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4HttpRequestSizeLimitIT.java @@ -100,7 +100,7 @@ public void testLimitsInFlightRequests() throws Exception { HttpServerTransport httpServerTransport = internalCluster().getInstance(HttpServerTransport.class); TransportAddress transportAddress = randomFrom(httpServerTransport.boundAddress().boundAddresses()); - try (Netty4HttpClient nettyHttpClient = new Netty4HttpClient()) { + try (Netty4HttpClient nettyHttpClient = Netty4HttpClient.http()) { Collection singleResponse = nettyHttpClient.post(transportAddress.address(), requests.subList(0, 1)); try { assertThat(singleResponse, hasSize(1)); @@ -130,7 +130,7 @@ public void testDoesNotLimitExcludedRequests() throws Exception { HttpServerTransport httpServerTransport = internalCluster().getInstance(HttpServerTransport.class); TransportAddress transportAddress = randomFrom(httpServerTransport.boundAddress().boundAddresses()); - try (Netty4HttpClient nettyHttpClient = new Netty4HttpClient()) { + try (Netty4HttpClient nettyHttpClient = Netty4HttpClient.http()) { Collection responses = nettyHttpClient.put(transportAddress.address(), requestUris); try { assertThat(responses, hasSize(requestUris.size())); diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4PipeliningIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4PipeliningIT.java index 2bd1fa07f8afc..96193b0ecb954 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4PipeliningIT.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4PipeliningIT.java @@ -61,7 +61,7 @@ public void testThatNettyHttpServerSupportsPipelining() throws Exception { TransportAddress[] boundAddresses = httpServerTransport.boundAddress().boundAddresses(); TransportAddress transportAddress = randomFrom(boundAddresses); - try (Netty4HttpClient nettyHttpClient = new Netty4HttpClient()) { + try (Netty4HttpClient nettyHttpClient = Netty4HttpClient.http()) { Collection responses = nettyHttpClient.get(transportAddress.address(), requests); try { assertThat(responses, hasSize(5)); diff --git a/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpChannel.java b/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpChannel.java index 66d60032d11a8..2dd7aaf41986f 100644 --- a/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpChannel.java +++ b/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpChannel.java @@ -33,7 +33,10 @@ package org.opensearch.http.netty4; import io.netty.channel.Channel; +import io.netty.channel.ChannelPipeline; + import org.opensearch.action.ActionListener; +import org.opensearch.common.Nullable; import org.opensearch.common.concurrent.CompletableContext; import org.opensearch.http.HttpChannel; import org.opensearch.http.HttpResponse; @@ -45,9 +48,15 @@ public class Netty4HttpChannel implements HttpChannel { private final Channel channel; private final CompletableContext closeContext = new CompletableContext<>(); + private final ChannelPipeline inboundPipeline; Netty4HttpChannel(Channel channel) { + this(channel, null); + } + + Netty4HttpChannel(Channel channel, ChannelPipeline inboundPipeline) { this.channel = channel; + this.inboundPipeline = inboundPipeline; Netty4TcpChannel.addListener(this.channel.closeFuture(), closeContext); } @@ -81,6 +90,10 @@ public void close() { channel.close(); } + public @Nullable ChannelPipeline inboundPipeline() { + return inboundPipeline; + } + public Channel getNettyChannel() { return channel; } diff --git a/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpServerTransport.java b/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpServerTransport.java index decab45ffca38..1e0a4d89f2fd5 100644 --- a/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpServerTransport.java +++ b/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpServerTransport.java @@ -40,18 +40,36 @@ import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelPipeline; import io.netty.channel.FixedRecvByteBufAllocator; import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.socket.nio.NioChannelOption; import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.http.HttpContentCompressor; import io.netty.handler.codec.http.HttpContentDecompressor; +import io.netty.handler.codec.http.HttpMessage; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpRequestDecoder; import io.netty.handler.codec.http.HttpResponseEncoder; +import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.HttpServerUpgradeHandler; +import io.netty.handler.codec.http.HttpServerUpgradeHandler.UpgradeCodec; +import io.netty.handler.codec.http.HttpServerUpgradeHandler.UpgradeCodecFactory; +import io.netty.handler.codec.http2.CleartextHttp2ServerUpgradeHandler; +import io.netty.handler.codec.http2.Http2CodecUtil; +import io.netty.handler.codec.http2.Http2FrameCodecBuilder; +import io.netty.handler.codec.http2.Http2MultiplexHandler; +import io.netty.handler.codec.http2.Http2ServerUpgradeCodec; +import io.netty.handler.codec.http2.Http2StreamFrameToHttpObjectCodec; +import io.netty.handler.logging.LogLevel; +import io.netty.handler.logging.LoggingHandler; import io.netty.handler.timeout.ReadTimeoutException; import io.netty.handler.timeout.ReadTimeoutHandler; +import io.netty.util.AsciiString; import io.netty.util.AttributeKey; +import io.netty.util.ReferenceCountUtil; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; @@ -335,38 +353,152 @@ protected HttpChannelHandler(final Netty4HttpServerTransport transport, final Ht this.responseCreator = new Netty4HttpResponseCreator(); } + public ChannelHandler getRequestHandler() { + return requestHandler; + } + @Override protected void initChannel(Channel ch) throws Exception { Netty4HttpChannel nettyHttpChannel = new Netty4HttpChannel(ch); ch.attr(HTTP_CHANNEL_KEY).set(nettyHttpChannel); ch.pipeline().addLast("byte_buf_sizer", byteBufSizer); ch.pipeline().addLast("read_timeout", new ReadTimeoutHandler(transport.readTimeoutMillis, TimeUnit.MILLISECONDS)); + + configurePipeline(ch); + transport.serverAcceptedChannel(nettyHttpChannel); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + ExceptionsHelper.maybeDieOnAnotherThread(cause); + super.exceptionCaught(ctx, cause); + } + + protected void configurePipeline(Channel ch) { + final UpgradeCodecFactory upgradeCodecFactory = new UpgradeCodecFactory() { + @Override + public UpgradeCodec newUpgradeCodec(CharSequence protocol) { + if (AsciiString.contentEquals(Http2CodecUtil.HTTP_UPGRADE_PROTOCOL_NAME, protocol)) { + return new Http2ServerUpgradeCodec( + Http2FrameCodecBuilder.forServer().build(), + new Http2MultiplexHandler(createHttp2ChannelInitializer(ch.pipeline())) + ); + } else { + return null; + } + } + }; + + final HttpServerCodec sourceCodec = new HttpServerCodec( + handlingSettings.getMaxInitialLineLength(), + handlingSettings.getMaxHeaderSize(), + handlingSettings.getMaxChunkSize() + ); + + final HttpServerUpgradeHandler upgradeHandler = new HttpServerUpgradeHandler(sourceCodec, upgradeCodecFactory); + final CleartextHttp2ServerUpgradeHandler cleartextUpgradeHandler = new CleartextHttp2ServerUpgradeHandler( + sourceCodec, + upgradeHandler, + createHttp2ChannelInitializerPriorKnowledge() + ); + + ch.pipeline().addLast(cleartextUpgradeHandler).addLast(new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, HttpMessage msg) throws Exception { + final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.getMaxContentLength()); + aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents); + + // If this handler is hit then no upgrade has been attempted and the client is just talking HTTP + final ChannelPipeline pipeline = ctx.pipeline(); + pipeline.addAfter(ctx.name(), "handler", getRequestHandler()); + pipeline.replace(this, "aggregator", aggregator); + + ch.pipeline().addLast("decoder_compress", new HttpContentDecompressor()); + ch.pipeline().addLast("encoder", new HttpResponseEncoder()); + if (handlingSettings.isCompression()) { + ch.pipeline() + .addAfter("aggregator", "encoder_compress", new HttpContentCompressor(handlingSettings.getCompressionLevel())); + } + ch.pipeline().addBefore("handler", "request_creator", requestCreator); + ch.pipeline().addBefore("handler", "response_creator", responseCreator); + ch.pipeline() + .addBefore("handler", "pipelining", new Netty4HttpPipeliningHandler(logger, transport.pipeliningMaxEvents)); + + ctx.fireChannelRead(ReferenceCountUtil.retain(msg)); + } + }); + } + + protected void configureDefaultHttpPipeline(ChannelPipeline pipeline) { final HttpRequestDecoder decoder = new HttpRequestDecoder( handlingSettings.getMaxInitialLineLength(), handlingSettings.getMaxHeaderSize(), handlingSettings.getMaxChunkSize() ); decoder.setCumulator(ByteToMessageDecoder.COMPOSITE_CUMULATOR); - ch.pipeline().addLast("decoder", decoder); - ch.pipeline().addLast("decoder_compress", new HttpContentDecompressor()); - ch.pipeline().addLast("encoder", new HttpResponseEncoder()); + pipeline.addLast("decoder", decoder); + pipeline.addLast("decoder_compress", new HttpContentDecompressor()); + pipeline.addLast("encoder", new HttpResponseEncoder()); final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.getMaxContentLength()); aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents); - ch.pipeline().addLast("aggregator", aggregator); + pipeline.addLast("aggregator", aggregator); if (handlingSettings.isCompression()) { - ch.pipeline().addLast("encoder_compress", new HttpContentCompressor(handlingSettings.getCompressionLevel())); + pipeline.addLast("encoder_compress", new HttpContentCompressor(handlingSettings.getCompressionLevel())); } - ch.pipeline().addLast("request_creator", requestCreator); - ch.pipeline().addLast("response_creator", responseCreator); - ch.pipeline().addLast("pipelining", new Netty4HttpPipeliningHandler(logger, transport.pipeliningMaxEvents)); - ch.pipeline().addLast("handler", requestHandler); - transport.serverAcceptedChannel(nettyHttpChannel); + pipeline.addLast("request_creator", requestCreator); + pipeline.addLast("response_creator", responseCreator); + pipeline.addLast("pipelining", new Netty4HttpPipeliningHandler(logger, transport.pipeliningMaxEvents)); + pipeline.addLast("handler", requestHandler); } - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - ExceptionsHelper.maybeDieOnAnotherThread(cause); - super.exceptionCaught(ctx, cause); + protected void configureDefaultHttp2Pipeline(ChannelPipeline pipeline) { + pipeline.addLast(Http2FrameCodecBuilder.forServer().build()) + .addLast(new Http2MultiplexHandler(createHttp2ChannelInitializer(pipeline))); + } + + private ChannelInitializer createHttp2ChannelInitializerPriorKnowledge() { + return new ChannelInitializer() { + @Override + protected void initChannel(Channel childChannel) throws Exception { + configureDefaultHttp2Pipeline(childChannel.pipeline()); + } + }; + } + + /** + * Http2MultiplexHandler creates new pipeline, we are preserving the old one in case some handlers need to be + * access (like for example opensearch-security plugin which accesses SSL handlers). + */ + private ChannelInitializer createHttp2ChannelInitializer(ChannelPipeline inboundPipeline) { + return new ChannelInitializer() { + @Override + protected void initChannel(Channel childChannel) throws Exception { + final Netty4HttpChannel nettyHttpChannel = new Netty4HttpChannel(childChannel, inboundPipeline); + childChannel.attr(HTTP_CHANNEL_KEY).set(nettyHttpChannel); + + final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.getMaxContentLength()); + aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents); + + childChannel.pipeline() + .addLast(new LoggingHandler(LogLevel.DEBUG)) + .addLast(new Http2StreamFrameToHttpObjectCodec(true)) + .addLast("byte_buf_sizer", byteBufSizer) + .addLast("read_timeout", new ReadTimeoutHandler(transport.readTimeoutMillis, TimeUnit.MILLISECONDS)) + .addLast("decoder_decompress", new HttpContentDecompressor()); + + if (handlingSettings.isCompression()) { + childChannel.pipeline() + .addLast("encoder_compress", new HttpContentCompressor(handlingSettings.getCompressionLevel())); + } + + childChannel.pipeline() + .addLast("aggregator", aggregator) + .addLast("request_creator", requestCreator) + .addLast("response_creator", responseCreator) + .addLast("pipelining", new Netty4HttpPipeliningHandler(logger, transport.pipeliningMaxEvents)) + .addLast("handler", getRequestHandler()); + } + }; } } diff --git a/modules/transport-netty4/src/test/java/org/opensearch/http/netty4/Netty4BadRequestTests.java b/modules/transport-netty4/src/test/java/org/opensearch/http/netty4/Netty4BadRequestTests.java index a0100930c7dcb..c18fe6efc4736 100644 --- a/modules/transport-netty4/src/test/java/org/opensearch/http/netty4/Netty4BadRequestTests.java +++ b/modules/transport-netty4/src/test/java/org/opensearch/http/netty4/Netty4BadRequestTests.java @@ -117,7 +117,7 @@ public void dispatchBadRequest(RestChannel channel, ThreadContext threadContext, httpServerTransport.start(); final TransportAddress transportAddress = randomFrom(httpServerTransport.boundAddress().boundAddresses()); - try (Netty4HttpClient nettyHttpClient = new Netty4HttpClient()) { + try (Netty4HttpClient nettyHttpClient = Netty4HttpClient.http()) { final Collection responses = nettyHttpClient.get( transportAddress.address(), "/_cluster/settings?pretty=%" diff --git a/modules/transport-netty4/src/test/java/org/opensearch/http/netty4/Netty4HttpClient.java b/modules/transport-netty4/src/test/java/org/opensearch/http/netty4/Netty4HttpClient.java index 57f95a022a33f..6fdd698c117f2 100644 --- a/modules/transport-netty4/src/test/java/org/opensearch/http/netty4/Netty4HttpClient.java +++ b/modules/transport-netty4/src/test/java/org/opensearch/http/netty4/Netty4HttpClient.java @@ -37,14 +37,19 @@ import io.netty.buffer.Unpooled; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandler; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelPromise; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.handler.codec.http.DefaultFullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpClientUpgradeHandler; import io.netty.handler.codec.http.HttpContentDecompressor; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpMethod; @@ -55,6 +60,17 @@ import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseDecoder; import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http2.DefaultHttp2Connection; +import io.netty.handler.codec.http2.DelegatingDecompressorFrameListener; +import io.netty.handler.codec.http2.Http2ClientUpgradeCodec; +import io.netty.handler.codec.http2.Http2Connection; +import io.netty.handler.codec.http2.Http2Settings; +import io.netty.handler.codec.http2.HttpConversionUtil; +import io.netty.handler.codec.http2.HttpToHttp2ConnectionHandler; +import io.netty.handler.codec.http2.HttpToHttp2ConnectionHandlerBuilder; +import io.netty.handler.codec.http2.InboundHttp2ToHttpAdapterBuilder; +import io.netty.util.AttributeKey; + import org.opensearch.common.collect.Tuple; import org.opensearch.common.unit.ByteSizeUnit; import org.opensearch.common.unit.ByteSizeValue; @@ -70,6 +86,7 @@ import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.function.BiFunction; import static io.netty.handler.codec.http.HttpHeaderNames.HOST; import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; @@ -97,11 +114,32 @@ static Collection returnOpaqueIds(Collection responses } private final Bootstrap clientBootstrap; + private final BiFunction, AwaitableChannelInitializer> handlerFactory; + + Netty4HttpClient( + Bootstrap clientBootstrap, + BiFunction, AwaitableChannelInitializer> handlerFactory + ) { + this.clientBootstrap = clientBootstrap; + this.handlerFactory = handlerFactory; + } + + static Netty4HttpClient http() { + return new Netty4HttpClient( + new Bootstrap().channel(NettyAllocator.getChannelType()) + .option(ChannelOption.ALLOCATOR, NettyAllocator.getAllocator()) + .group(new NioEventLoopGroup(1)), + CountDownLatchHandlerHttp::new + ); + } - Netty4HttpClient() { - clientBootstrap = new Bootstrap().channel(NettyAllocator.getChannelType()) - .option(ChannelOption.ALLOCATOR, NettyAllocator.getAllocator()) - .group(new NioEventLoopGroup(1)); + static Netty4HttpClient http2() { + return new Netty4HttpClient( + new Bootstrap().channel(NettyAllocator.getChannelType()) + .option(ChannelOption.ALLOCATOR, NettyAllocator.getAllocator()) + .group(new NioEventLoopGroup(1)), + CountDownLatchHandlerHttp2::new + ); } public List get(SocketAddress remoteAddress, String... uris) throws InterruptedException { @@ -110,6 +148,7 @@ public List get(SocketAddress remoteAddress, String... uris) t final HttpRequest httpRequest = new DefaultFullHttpRequest(HTTP_1_1, HttpMethod.GET, uris[i]); httpRequest.headers().add(HOST, "localhost"); httpRequest.headers().add("X-Opaque-ID", String.valueOf(i)); + httpRequest.headers().add(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http"); requests.add(httpRequest); } return sendRequests(remoteAddress, requests); @@ -143,6 +182,7 @@ private List processRequestsWithBody( request.headers().add(HttpHeaderNames.HOST, "localhost"); request.headers().add(HttpHeaderNames.CONTENT_LENGTH, content.readableBytes()); request.headers().add(HttpHeaderNames.CONTENT_TYPE, "application/json"); + request.headers().add(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http"); requests.add(request); } return sendRequests(remoteAddress, requests); @@ -153,12 +193,14 @@ private synchronized List sendRequests(final SocketAddress rem final CountDownLatch latch = new CountDownLatch(requests.size()); final List content = Collections.synchronizedList(new ArrayList<>(requests.size())); - clientBootstrap.handler(new CountDownLatchHandler(latch, content)); + final AwaitableChannelInitializer handler = handlerFactory.apply(latch, content); + clientBootstrap.handler(handler); ChannelFuture channelFuture = null; try { channelFuture = clientBootstrap.connect(remoteAddress); channelFuture.sync(); + handler.await(); for (HttpRequest request : requests) { channelFuture.channel().writeAndFlush(request); @@ -184,12 +226,12 @@ public void close() { /** * helper factory which adds returned data to a list and uses a count down latch to decide when done */ - private static class CountDownLatchHandler extends ChannelInitializer { + private static class CountDownLatchHandlerHttp extends AwaitableChannelInitializer { private final CountDownLatch latch; private final Collection content; - CountDownLatchHandler(final CountDownLatch latch, final Collection content) { + CountDownLatchHandlerHttp(final CountDownLatch latch, final Collection content) { this.latch = latch; this.content = content; } @@ -222,4 +264,145 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E } + /** + * The channel initializer with the ability to await for initialization to be completed + * + */ + private static abstract class AwaitableChannelInitializer extends ChannelInitializer { + void await() { + // do nothing + } + } + + /** + * helper factory which adds returned data to a list and uses a count down latch to decide when done + */ + private static class CountDownLatchHandlerHttp2 extends AwaitableChannelInitializer { + + private final CountDownLatch latch; + private final Collection content; + private Http2SettingsHandler settingsHandler; + + CountDownLatchHandlerHttp2(final CountDownLatch latch, final Collection content) { + this.latch = latch; + this.content = content; + } + + @Override + protected void initChannel(SocketChannel ch) { + final int maxContentLength = new ByteSizeValue(100, ByteSizeUnit.MB).bytesAsInt(); + final Http2Connection connection = new DefaultHttp2Connection(false); + settingsHandler = new Http2SettingsHandler(ch.newPromise()); + + final ChannelInboundHandler responseHandler = new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, HttpObject msg) { + final FullHttpResponse response = (FullHttpResponse) msg; + + // this is upgrade request, skipping it over + if (Boolean.TRUE.equals(ctx.channel().attr(AttributeKey.valueOf("upgrade")).getAndRemove())) { + return; + } + + // We copy the buffer manually to avoid a huge allocation on a pooled allocator. We have + // a test that tracks huge allocations, so we want to avoid them in this test code. + ByteBuf newContent = Unpooled.copiedBuffer(((FullHttpResponse) msg).content()); + content.add(response.replace(newContent)); + latch.countDown(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + super.exceptionCaught(ctx, cause); + latch.countDown(); + } + }; + + final HttpToHttp2ConnectionHandler connectionHandler = new HttpToHttp2ConnectionHandlerBuilder().connection(connection) + .frameListener( + new DelegatingDecompressorFrameListener( + connection, + new InboundHttp2ToHttpAdapterBuilder(connection).maxContentLength(maxContentLength).propagateSettings(true).build() + ) + ) + .build(); + + final HttpClientCodec sourceCodec = new HttpClientCodec(); + final Http2ClientUpgradeCodec upgradeCodec = new Http2ClientUpgradeCodec(connectionHandler); + final HttpClientUpgradeHandler upgradeHandler = new HttpClientUpgradeHandler(sourceCodec, upgradeCodec, maxContentLength); + + ch.pipeline().addLast(sourceCodec); + ch.pipeline().addLast(upgradeHandler); + ch.pipeline().addLast(new HttpContentDecompressor()); + ch.pipeline().addLast(new UpgradeRequestHandler(settingsHandler, responseHandler)); + } + + @Override + void await() { + try { + // Await for HTTP/2 settings being sent over before moving on to sending the requests + settingsHandler.awaitSettings(5, TimeUnit.SECONDS); + } catch (final Exception ex) { + throw new RuntimeException(ex); + } + } + } + + /** + * A handler that triggers the cleartext upgrade to HTTP/2 (h2c) by sending an + * initial HTTP request. + */ + private static class UpgradeRequestHandler extends ChannelInboundHandlerAdapter { + private final ChannelInboundHandler settingsHandler; + private final ChannelInboundHandler responseHandler; + + UpgradeRequestHandler(final ChannelInboundHandler settingsHandler, final ChannelInboundHandler responseHandler) { + this.settingsHandler = settingsHandler; + this.responseHandler = responseHandler; + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + // The first request is HTTP/2 protocol upgrade (since we support only h2c there) + final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + request.headers().add(HttpHeaderNames.HOST, "localhost"); + request.headers().add(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http"); + + ctx.channel().attr(AttributeKey.newInstance("upgrade")).set(true); + ctx.writeAndFlush(request); + ctx.fireChannelActive(); + + ctx.pipeline().remove(this); + ctx.pipeline().addLast(settingsHandler); + ctx.pipeline().addLast(responseHandler); + } + } + + private static class Http2SettingsHandler extends SimpleChannelInboundHandler { + private ChannelPromise promise; + + Http2SettingsHandler(ChannelPromise promise) { + this.promise = promise; + } + + /** + * Wait for this handler to be added after the upgrade to HTTP/2, and for initial preface + * handshake to complete. + */ + void awaitSettings(long timeout, TimeUnit unit) throws Exception { + if (!promise.awaitUninterruptibly(timeout, unit)) { + throw new IllegalStateException("Timed out waiting for HTTP/2 settings"); + } + if (!promise.isSuccess()) { + throw new RuntimeException(promise.cause()); + } + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, Http2Settings msg) throws Exception { + promise.setSuccess(); + ctx.pipeline().remove(this); + } + } + } diff --git a/modules/transport-netty4/src/test/java/org/opensearch/http/netty4/Netty4HttpServerPipeliningTests.java b/modules/transport-netty4/src/test/java/org/opensearch/http/netty4/Netty4HttpServerPipeliningTests.java index 029aed1f3cc89..cda66b8d828fa 100644 --- a/modules/transport-netty4/src/test/java/org/opensearch/http/netty4/Netty4HttpServerPipeliningTests.java +++ b/modules/transport-netty4/src/test/java/org/opensearch/http/netty4/Netty4HttpServerPipeliningTests.java @@ -109,7 +109,7 @@ public void testThatHttpPipeliningWorks() throws Exception { } } - try (Netty4HttpClient nettyHttpClient = new Netty4HttpClient()) { + try (Netty4HttpClient nettyHttpClient = Netty4HttpClient.http()) { Collection responses = nettyHttpClient.get(transportAddress.address(), requests.toArray(new String[] {})); try { Collection responseBodies = Netty4HttpClient.returnHttpResponseBodies(responses); @@ -163,9 +163,12 @@ private class CustomHttpChannelHandler extends Netty4HttpServerTransport.HttpCha @Override protected void initChannel(Channel ch) throws Exception { super.initChannel(ch); - ch.pipeline().replace("handler", "handler", new PossiblySlowUpstreamHandler(executorService)); } + @Override + public ChannelHandler getRequestHandler() { + return new PossiblySlowUpstreamHandler(executorService); + } } class PossiblySlowUpstreamHandler extends SimpleChannelInboundHandler { diff --git a/modules/transport-netty4/src/test/java/org/opensearch/http/netty4/Netty4HttpServerTransportTests.java b/modules/transport-netty4/src/test/java/org/opensearch/http/netty4/Netty4HttpServerTransportTests.java index ec879e538fe20..eb96f14f10c70 100644 --- a/modules/transport-netty4/src/test/java/org/opensearch/http/netty4/Netty4HttpServerTransportTests.java +++ b/modules/transport-netty4/src/test/java/org/opensearch/http/netty4/Netty4HttpServerTransportTests.java @@ -202,7 +202,7 @@ public void dispatchBadRequest(RestChannel channel, ThreadContext threadContext, ) { transport.start(); final TransportAddress remoteAddress = randomFrom(transport.boundAddress().boundAddresses()); - try (Netty4HttpClient client = new Netty4HttpClient()) { + try (Netty4HttpClient client = Netty4HttpClient.http()) { final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/"); request.headers().set(HttpHeaderNames.EXPECT, expectation); HttpUtil.setContentLength(request, contentLength); @@ -322,7 +322,7 @@ public void dispatchBadRequest(final RestChannel channel, final ThreadContext th transport.start(); final TransportAddress remoteAddress = randomFrom(transport.boundAddress().boundAddresses()); - try (Netty4HttpClient client = new Netty4HttpClient()) { + try (Netty4HttpClient client = Netty4HttpClient.http()) { final String url = "/" + new String(new byte[maxInitialLineLength], Charset.forName("UTF-8")); final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, url); @@ -384,7 +384,7 @@ public void dispatchBadRequest(final RestChannel channel, final ThreadContext th transport.start(); final TransportAddress remoteAddress = randomFrom(transport.boundAddress().boundAddresses()); - try (Netty4HttpClient client = new Netty4HttpClient()) { + try (Netty4HttpClient client = Netty4HttpClient.http()) { DefaultFullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, url); request.headers().add(HttpHeaderNames.ACCEPT_ENCODING, randomFrom("deflate", "gzip")); long numOfHugeAllocations = getHugeAllocationCount(); @@ -454,7 +454,7 @@ public void dispatchBadRequest(final RestChannel channel, final ThreadContext th final TransportAddress remoteAddress = randomFrom(transport.boundAddress().boundAddresses()); // Test pre-flight request - try (Netty4HttpClient client = new Netty4HttpClient()) { + try (Netty4HttpClient client = Netty4HttpClient.http()) { final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.OPTIONS, "/"); request.headers().add(CorsHandler.ORIGIN, "test-cors.org"); request.headers().add(CorsHandler.ACCESS_CONTROL_REQUEST_METHOD, "POST"); @@ -471,7 +471,7 @@ public void dispatchBadRequest(final RestChannel channel, final ThreadContext th } // Test short-circuited request - try (Netty4HttpClient client = new Netty4HttpClient()) { + try (Netty4HttpClient client = Netty4HttpClient.http()) { final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); request.headers().add(CorsHandler.ORIGIN, "google.com");