diff --git a/client/rest/src/main/java/org/opensearch/client/RestClient.java b/client/rest/src/main/java/org/opensearch/client/RestClient.java index 4f899fd709112..92aed2c8fb179 100644 --- a/client/rest/src/main/java/org/opensearch/client/RestClient.java +++ b/client/rest/src/main/java/org/opensearch/client/RestClient.java @@ -36,6 +36,7 @@ import org.apache.http.ConnectionClosedException; import org.apache.http.Header; import org.apache.http.HttpEntity; +import org.apache.http.entity.HttpEntityWrapper; import org.apache.http.HttpHost; import org.apache.http.HttpRequest; import org.apache.http.HttpResponse; @@ -131,6 +132,29 @@ public class RestClient implements Closeable { private volatile NodeTuple> nodeTuple; private final WarningsHandler warningsHandler; private final boolean compressionEnabled; + private final Optional chunkedEnabled; + + RestClient( + CloseableHttpAsyncClient client, + Header[] defaultHeaders, + List nodes, + String pathPrefix, + FailureListener failureListener, + NodeSelector nodeSelector, + boolean strictDeprecationMode, + boolean compressionEnabled, + boolean chunkedEnabled + ) { + this.client = client; + this.defaultHeaders = Collections.unmodifiableList(Arrays.asList(defaultHeaders)); + this.failureListener = failureListener; + this.pathPrefix = pathPrefix; + this.nodeSelector = nodeSelector; + this.warningsHandler = strictDeprecationMode ? WarningsHandler.STRICT : WarningsHandler.PERMISSIVE; + this.compressionEnabled = compressionEnabled; + this.chunkedEnabled = Optional.of(chunkedEnabled); + setNodes(nodes); + } RestClient( CloseableHttpAsyncClient client, @@ -149,6 +173,7 @@ public class RestClient implements Closeable { this.nodeSelector = nodeSelector; this.warningsHandler = strictDeprecationMode ? WarningsHandler.STRICT : WarningsHandler.PERMISSIVE; this.compressionEnabled = compressionEnabled; + this.chunkedEnabled = Optional.empty(); setNodes(nodes); } @@ -583,36 +608,42 @@ private static void addSuppressedException(Exception suppressedException, Except } } - private static HttpRequestBase createHttpRequest(String method, URI uri, HttpEntity entity, boolean compressionEnabled) { + private HttpRequestBase createHttpRequest(String method, URI uri, HttpEntity entity) { switch (method.toUpperCase(Locale.ROOT)) { case HttpDeleteWithEntity.METHOD_NAME: - return addRequestBody(new HttpDeleteWithEntity(uri), entity, compressionEnabled); + return addRequestBody(new HttpDeleteWithEntity(uri), entity); case HttpGetWithEntity.METHOD_NAME: - return addRequestBody(new HttpGetWithEntity(uri), entity, compressionEnabled); + return addRequestBody(new HttpGetWithEntity(uri), entity); case HttpHead.METHOD_NAME: - return addRequestBody(new HttpHead(uri), entity, compressionEnabled); + return addRequestBody(new HttpHead(uri), entity); case HttpOptions.METHOD_NAME: - return addRequestBody(new HttpOptions(uri), entity, compressionEnabled); + return addRequestBody(new HttpOptions(uri), entity); case HttpPatch.METHOD_NAME: - return addRequestBody(new HttpPatch(uri), entity, compressionEnabled); + return addRequestBody(new HttpPatch(uri), entity); case HttpPost.METHOD_NAME: HttpPost httpPost = new HttpPost(uri); - addRequestBody(httpPost, entity, compressionEnabled); + addRequestBody(httpPost, entity); return httpPost; case HttpPut.METHOD_NAME: - return addRequestBody(new HttpPut(uri), entity, compressionEnabled); + return addRequestBody(new HttpPut(uri), entity); case HttpTrace.METHOD_NAME: - return addRequestBody(new HttpTrace(uri), entity, compressionEnabled); + return addRequestBody(new HttpTrace(uri), entity); default: throw new UnsupportedOperationException("http method not supported: " + method); } } - private static HttpRequestBase addRequestBody(HttpRequestBase httpRequest, HttpEntity entity, boolean compressionEnabled) { + private HttpRequestBase addRequestBody(HttpRequestBase httpRequest, HttpEntity entity) { if (entity != null) { if (httpRequest instanceof HttpEntityEnclosingRequestBase) { if (compressionEnabled) { - entity = new ContentCompressingEntity(entity); + if (chunkedEnabled.isPresent()) { + entity = new ContentCompressingEntity(entity, chunkedEnabled.get()); + } else { + entity = new ContentCompressingEntity(entity); + } + } else if (chunkedEnabled.isPresent()) { + entity = new ContentHttpEntity(entity, chunkedEnabled.get()); } ((HttpEntityEnclosingRequestBase) httpRequest).setEntity(entity); } else { @@ -782,7 +813,7 @@ private class InternalRequest { String ignoreString = params.remove("ignore"); this.ignoreErrorCodes = getIgnoreErrorCodes(ignoreString, request.getMethod()); URI uri = buildUri(pathPrefix, request.getEndpoint(), params); - this.httpRequest = createHttpRequest(request.getMethod(), uri, request.getEntity(), compressionEnabled); + this.httpRequest = createHttpRequest(request.getMethod(), uri, request.getEntity()); this.cancellable = Cancellable.fromRequest(httpRequest); setHeaders(httpRequest, request.getOptions().getHeaders()); setRequestConfig(httpRequest, request.getOptions().getRequestConfig()); @@ -936,6 +967,7 @@ private static Exception extractAndWrapCause(Exception exception) { * A gzip compressing entity that also implements {@code getContent()}. */ public static class ContentCompressingEntity extends GzipCompressingEntity { + private Optional chunkedEnabled; /** * Creates a {@link ContentCompressingEntity} instance with the provided HTTP entity. @@ -944,6 +976,18 @@ public static class ContentCompressingEntity extends GzipCompressingEntity { */ public ContentCompressingEntity(HttpEntity entity) { super(entity); + this.chunkedEnabled = Optional.empty(); + } + + /** + * Creates a {@link ContentCompressingEntity} instance with the provided HTTP entity. + * + * @param entity the HTTP entity. + * @param chunkedEnabled force enable/disable chunked transfer-encoding. + */ + public ContentCompressingEntity(HttpEntity entity, boolean chunkedEnabled) { + super(entity); + this.chunkedEnabled = Optional.of(chunkedEnabled); } @Override @@ -954,6 +998,80 @@ public InputStream getContent() throws IOException { } return out.asInput(); } + + /** + * A gzip compressing entity doesn't work with chunked encoding with sigv4 + * + * @return false + */ + @Override + public boolean isChunked() { + return chunkedEnabled.orElseGet(super::isChunked); + } + + /** + * A gzip entity requires content length in http headers. + * + * @return content length of gzip entity + */ + @Override + public long getContentLength() { + if (chunkedEnabled.isPresent()) { + if (chunkedEnabled.get()) { + return -1L; + } else { + long size; + try (InputStream is = getContent()) { + size = is.readAllBytes().length; + } catch (IOException ex) { + size = -1L; + } + + return size; + } + } else { + return super.getContentLength(); + } + } + } + + /** + * An entity that lets the caller specify the return value of {@code isChunked()}. + */ + public static class ContentHttpEntity extends HttpEntityWrapper { + private Optional chunkedEnabled; + + /** + * Creates a {@link ContentHttpEntity} instance with the provided HTTP entity. + * + * @param entity the HTTP entity. + */ + public ContentHttpEntity(HttpEntity entity) { + super(entity); + this.chunkedEnabled = Optional.empty(); + } + + /** + * Creates a {@link ContentHttpEntity} instance with the provided HTTP entity. + * + * @param entity the HTTP entity. + * @param chunkedEnabled force enable/disable chunked transfer-encoding. + */ + public ContentHttpEntity(HttpEntity entity, boolean chunkedEnabled) { + super(entity); + this.chunkedEnabled = Optional.of(chunkedEnabled); + } + + /** + * A chunked entity requires transfer-encoding:chunked in http headers + * which requires isChunked to be true + * + * @return true + */ + @Override + public boolean isChunked() { + return chunkedEnabled.orElseGet(super::isChunked); + } } /** diff --git a/client/rest/src/main/java/org/opensearch/client/RestClientBuilder.java b/client/rest/src/main/java/org/opensearch/client/RestClientBuilder.java index 0b259c7983ca5..8841d371754c3 100644 --- a/client/rest/src/main/java/org/opensearch/client/RestClientBuilder.java +++ b/client/rest/src/main/java/org/opensearch/client/RestClientBuilder.java @@ -46,6 +46,7 @@ import java.security.PrivilegedAction; import java.util.List; import java.util.Objects; +import java.util.Optional; /** * Helps creating a new {@link RestClient}. Allows to set the most common http client configuration options when internally @@ -84,6 +85,7 @@ public final class RestClientBuilder { private NodeSelector nodeSelector = NodeSelector.ANY; private boolean strictDeprecationMode = false; private boolean compressionEnabled = false; + private Optional chunkedEnabled; /** * Creates a new builder instance and sets the hosts that the client will send requests to. @@ -100,6 +102,7 @@ public final class RestClientBuilder { } } this.nodes = nodes; + this.chunkedEnabled = Optional.empty(); } /** @@ -238,6 +241,16 @@ public RestClientBuilder setCompressionEnabled(boolean compressionEnabled) { return this; } + /** + * Whether the REST client should use Transfer-Encoding: chunked for requests or not" + * + * @param chunkedEnabled force enable/disable chunked transfer-encoding. + */ + public RestClientBuilder setChunkedEnabled(boolean chunkedEnabled) { + this.chunkedEnabled = Optional.of(chunkedEnabled); + return this; + } + /** * Creates a new {@link RestClient} based on the provided configuration. */ @@ -248,16 +261,34 @@ public RestClient build() { CloseableHttpAsyncClient httpClient = AccessController.doPrivileged( (PrivilegedAction) this::createHttpClient ); - RestClient restClient = new RestClient( - httpClient, - defaultHeaders, - nodes, - pathPrefix, - failureListener, - nodeSelector, - strictDeprecationMode, - compressionEnabled - ); + + RestClient restClient = null; + + if (chunkedEnabled.isPresent()) { + restClient = new RestClient( + httpClient, + defaultHeaders, + nodes, + pathPrefix, + failureListener, + nodeSelector, + strictDeprecationMode, + compressionEnabled, + chunkedEnabled.get() + ); + } else { + restClient = new RestClient( + httpClient, + defaultHeaders, + nodes, + pathPrefix, + failureListener, + nodeSelector, + strictDeprecationMode, + compressionEnabled + ); + } + httpClient.start(); return restClient; } diff --git a/client/rest/src/test/java/org/opensearch/client/RestClientCompressionTests.java b/client/rest/src/test/java/org/opensearch/client/RestClientCompressionTests.java new file mode 100644 index 0000000000000..e8b7742044f67 --- /dev/null +++ b/client/rest/src/test/java/org/opensearch/client/RestClientCompressionTests.java @@ -0,0 +1,165 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.client; + +import com.sun.net.httpserver.HttpExchange; +import com.sun.net.httpserver.HttpHandler; +import com.sun.net.httpserver.HttpServer; +import org.apache.http.HttpEntity; +import org.apache.http.HttpHost; +import org.apache.http.entity.ContentType; +import org.apache.http.entity.StringEntity; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.CompletableFuture; +import java.util.zip.GZIPInputStream; +import java.util.zip.GZIPOutputStream; + +public class RestClientCompressionTests extends RestClientTestCase { + + private static HttpServer httpServer; + + @BeforeClass + public static void startHttpServer() throws Exception { + httpServer = HttpServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0); + httpServer.createContext("/", new GzipResponseHandler()); + httpServer.start(); + } + + @AfterClass + public static void stopHttpServers() throws IOException { + httpServer.stop(0); + httpServer = null; + } + + /** + * A response handler that accepts gzip-encoded data and replies request and response encoding values + * followed by the request body. The response is compressed if "Accept-Encoding" is "gzip". + */ + private static class GzipResponseHandler implements HttpHandler { + @Override + public void handle(HttpExchange exchange) throws IOException { + + // Decode body (if any) + String contentEncoding = exchange.getRequestHeaders().getFirst("Content-Encoding"); + String contentLength = exchange.getRequestHeaders().getFirst("Content-Length"); + InputStream body = exchange.getRequestBody(); + boolean compressedRequest = false; + if ("gzip".equals(contentEncoding)) { + body = new GZIPInputStream(body); + compressedRequest = true; + } + byte[] bytes = readAll(body); + boolean compress = "gzip".equals(exchange.getRequestHeaders().getFirst("Accept-Encoding")); + if (compress) { + exchange.getResponseHeaders().add("Content-Encoding", "gzip"); + } + + exchange.sendResponseHeaders(200, 0); + + // Encode response if needed + OutputStream out = exchange.getResponseBody(); + if (compress) { + out = new GZIPOutputStream(out); + } + + // Outputs ## + out.write(String.valueOf(contentEncoding).getBytes(StandardCharsets.UTF_8)); + out.write('#'); + out.write((compress ? "gzip" : "null").getBytes(StandardCharsets.UTF_8)); + out.write('#'); + out.write((compressedRequest ? contentLength : "null").getBytes(StandardCharsets.UTF_8)); + out.write('#'); + out.write(bytes); + out.close(); + + exchange.close(); + } + } + + /** + * Read all bytes of an input stream and close it. + */ + private static byte[] readAll(InputStream in) throws IOException { + byte[] buffer = new byte[1024]; + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + int len = 0; + while ((len = in.read(buffer)) > 0) { + bos.write(buffer, 0, len); + } + in.close(); + return bos.toByteArray(); + } + + private RestClient createClient(boolean enableCompression, boolean chunkedEnabled) { + InetSocketAddress address = httpServer.getAddress(); + return RestClient.builder(new HttpHost(address.getHostString(), address.getPort(), "http")) + .setCompressionEnabled(enableCompression) + .setChunkedEnabled(chunkedEnabled) + .build(); + } + + public void testCompressingClientWithContentLengthSync() throws Exception { + RestClient restClient = createClient(true, false); + + Request request = new Request("POST", "/"); + request.setEntity(new StringEntity("compressing client", ContentType.TEXT_PLAIN)); + + Response response = restClient.performRequest(request); + + HttpEntity entity = response.getEntity(); + String content = new String(readAll(entity.getContent()), StandardCharsets.UTF_8); + // Content-Encoding#Accept-Encoding#Content-Length#Content + Assert.assertEquals("gzip#gzip#38#compressing client", content); + + restClient.close(); + } + + public void testCompressingClientContentLengthAsync() throws Exception { + InetSocketAddress address = httpServer.getAddress(); + RestClient restClient = createClient(true, false); + + Request request = new Request("POST", "/"); + request.setEntity(new StringEntity("compressing client", ContentType.TEXT_PLAIN)); + + FutureResponse futureResponse = new FutureResponse(); + restClient.performRequestAsync(request, futureResponse); + Response response = futureResponse.get(); + + // Server should report it had a compressed request and sent back a compressed response + HttpEntity entity = response.getEntity(); + String content = new String(readAll(entity.getContent()), StandardCharsets.UTF_8); + + // Content-Encoding#Accept-Encoding#Content-Length#Content + Assert.assertEquals("gzip#gzip#38#compressing client", content); + + restClient.close(); + } + + public static class FutureResponse extends CompletableFuture implements ResponseListener { + @Override + public void onSuccess(Response response) { + this.complete(response); + } + + @Override + public void onFailure(Exception exception) { + this.completeExceptionally(exception); + } + } +} diff --git a/client/rest/src/test/java/org/opensearch/client/RestClientSingleHostTests.java b/client/rest/src/test/java/org/opensearch/client/RestClientSingleHostTests.java index a2aef065c9ae8..e5ce5eb91ad5a 100644 --- a/client/rest/src/test/java/org/opensearch/client/RestClientSingleHostTests.java +++ b/client/rest/src/test/java/org/opensearch/client/RestClientSingleHostTests.java @@ -138,6 +138,7 @@ public void createRestClient() { failureListener, NodeSelector.ANY, strictDeprecationMode, + false, false ); }