diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java index 0a09b6b8789f7..62bf845a77058 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java @@ -21,6 +21,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelPromise; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.DefaultFullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; @@ -116,21 +117,27 @@ public void testSuccessfulDecodeHttpRequest() throws IOException { ByteBuf buf = requestEncoder.encode(httpRequest); int slicePoint = randomInt(buf.writerIndex() - 1); - ByteBuf slicedBuf = buf.retainedSlice(0, slicePoint); ByteBuf slicedBuf2 = buf.retainedSlice(slicePoint, buf.writerIndex()); - handler.consumeReads(toChannelBuffer(slicedBuf)); + try { + handler.consumeReads(toChannelBuffer(slicedBuf)); - verify(transport, times(0)).incomingRequest(any(HttpRequest.class), any(NioHttpChannel.class)); + verify(transport, times(0)).incomingRequest(any(HttpRequest.class), any(NioHttpChannel.class)); - handler.consumeReads(toChannelBuffer(slicedBuf2)); + handler.consumeReads(toChannelBuffer(slicedBuf2)); - ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); - verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class)); + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class)); - HttpRequest nioHttpRequest = requestCaptor.getValue(); - assertEquals(HttpRequest.HttpVersion.HTTP_1_1, nioHttpRequest.protocolVersion()); - assertEquals(RestRequest.Method.GET, nioHttpRequest.method()); + HttpRequest nioHttpRequest = requestCaptor.getValue(); + assertEquals(HttpRequest.HttpVersion.HTTP_1_1, nioHttpRequest.protocolVersion()); + assertEquals(RestRequest.Method.GET, nioHttpRequest.method()); + } finally { + handler.close(); + buf.release(); + slicedBuf.release(); + slicedBuf2.release(); + } } public void testDecodeHttpRequestError() throws IOException { @@ -138,16 +145,20 @@ public void testDecodeHttpRequestError() throws IOException { io.netty.handler.codec.http.HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); ByteBuf buf = requestEncoder.encode(httpRequest); - buf.setByte(0, ' '); - buf.setByte(1, ' '); - buf.setByte(2, ' '); + try { + buf.setByte(0, ' '); + buf.setByte(1, ' '); + buf.setByte(2, ' '); - handler.consumeReads(toChannelBuffer(buf)); + handler.consumeReads(toChannelBuffer(buf)); - ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); - verify(transport).incomingRequestError(any(HttpRequest.class), any(NioHttpChannel.class), exceptionCaptor.capture()); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(transport).incomingRequestError(any(HttpRequest.class), any(NioHttpChannel.class), exceptionCaptor.capture()); - assertTrue(exceptionCaptor.getValue() instanceof IllegalArgumentException); + assertTrue(exceptionCaptor.getValue() instanceof IllegalArgumentException); + } finally { + buf.release(); + } } public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() throws IOException { @@ -157,9 +168,11 @@ public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() t HttpUtil.setKeepAlive(httpRequest, false); ByteBuf buf = requestEncoder.encode(httpRequest); - - handler.consumeReads(toChannelBuffer(buf)); - + try { + handler.consumeReads(toChannelBuffer(buf)); + } finally { + buf.release(); + } verify(transport, times(0)).incomingRequestError(any(), any(), any()); verify(transport, times(0)).incomingRequest(any(), any()); @@ -168,13 +181,17 @@ public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() t FlushOperation flushOperation = flushOperations.get(0); FullHttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite())); - assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion()); - assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status()); - - flushOperation.getListener().accept(null, null); - // Since we have keep-alive set to false, we should close the channel after the response has been - // flushed - verify(nioHttpChannel).close(); + try { + assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion()); + assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status()); + + flushOperation.getListener().accept(null, null); + // Since we have keep-alive set to false, we should close the channel after the response has been + // flushed + verify(nioHttpChannel).close(); + } finally { + response.release(); + } } @SuppressWarnings("unchecked") @@ -189,11 +206,15 @@ public void testEncodeHttpResponse() throws IOException { SocketChannelContext context = mock(SocketChannelContext.class); HttpWriteOperation writeOperation = new HttpWriteOperation(context, httpResponse, mock(BiConsumer.class)); List flushOperations = handler.writeToBytes(writeOperation); - - FullHttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperations.get(0).getBuffersToWrite())); - - assertEquals(HttpResponseStatus.OK, response.status()); - assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion()); + FlushOperation operation = flushOperations.get(0); + FullHttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(operation.getBuffersToWrite())); + ((ChannelPromise) operation.getListener()).setSuccess(); + try { + assertEquals(HttpResponseStatus.OK, response.status()); + assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion()); + } finally { + response.release(); + } } public void testCorsEnabledWithoutAllowOrigins() throws IOException { @@ -201,9 +222,13 @@ public void testCorsEnabledWithoutAllowOrigins() throws IOException { Settings settings = Settings.builder() .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) .build(); - io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, "remote-host", "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); + FullHttpResponse response = executeCorsRequest(settings, "remote-host", "request-host"); + try { + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); + } finally { + response.release(); + } } public void testCorsEnabledWithAllowOrigins() throws IOException { @@ -213,11 +238,15 @@ public void testCorsEnabledWithAllowOrigins() throws IOException { .put(SETTING_CORS_ENABLED.getKey(), true) .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) .build(); - io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); + FullHttpResponse response = executeCorsRequest(settings, originValue, "request-host"); + try { + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + } finally { + response.release(); + } } public void testCorsAllowOriginWithSameHost() throws IOException { @@ -228,29 +257,44 @@ public void testCorsAllowOriginWithSameHost() throws IOException { .put(SETTING_CORS_ENABLED.getKey(), true) .build(); FullHttpResponse response = executeCorsRequest(settings, originValue, host); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - + String allowedOrigins; + try { + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + } finally { + response.release(); + } originValue = "http://" + originValue; response = executeCorsRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); + try { + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + } finally { + response.release(); + } originValue = originValue + ":5555"; host = host + ":5555"; response = executeCorsRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - + try { + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + } finally { + response.release(); + } originValue = originValue.replace("http", "https"); response = executeCorsRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); + try { + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + } finally { + response.release(); + } } public void testThatStringLiteralWorksOnMatch() throws IOException { @@ -261,12 +305,16 @@ public void testThatStringLiteralWorksOnMatch() throws IOException { .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post") .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) .build(); - io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); + FullHttpResponse response = executeCorsRequest(settings, originValue, "request-host"); + try { + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); + } finally { + response.release(); + } } public void testThatAnyOriginWorks() throws IOException { @@ -275,12 +323,16 @@ public void testThatAnyOriginWorks() throws IOException { .put(SETTING_CORS_ENABLED.getKey(), true) .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) .build(); - io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue()); + FullHttpResponse response = executeCorsRequest(settings, originValue, "request-host"); + try { + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue()); + } finally { + response.release(); + } } private FullHttpResponse executeCorsRequest(final Settings settings, final String originValue, final String host) throws IOException { @@ -300,8 +352,9 @@ private FullHttpResponse executeCorsRequest(final Settings settings, final Strin SocketChannelContext context = mock(SocketChannelContext.class); List flushOperations = handler.writeToBytes(handler.createWriteOperation(context, response, (v, e) -> {})); - + handler.close(); FlushOperation flushOperation = flushOperations.get(0); + ((ChannelPromise) flushOperation.getListener()).setSuccess(); return responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite())); } @@ -314,8 +367,11 @@ private NioHttpRequest prepareHandlerForResponse(HttpReadWriteHandler handler) t io.netty.handler.codec.http.HttpRequest request = new DefaultFullHttpRequest(version, method, uri); ByteBuf buf = requestEncoder.encode(request); - - handler.consumeReads(toChannelBuffer(buf)); + try { + handler.consumeReads(toChannelBuffer(buf)); + } finally { + buf.release(); + } ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(NioHttpRequest.class); verify(transport, atLeastOnce()).incomingRequest(requestCaptor.capture(), any(HttpChannel.class));