diff --git a/plugins/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3BlobContainerRetriesTests.java b/plugins/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3BlobContainerRetriesTests.java
new file mode 100644
index 0000000000000..ab88cc9368292
--- /dev/null
+++ b/plugins/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3BlobContainerRetriesTests.java
@@ -0,0 +1,385 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.repositories.s3;
+
+import com.amazonaws.SdkClientException;
+import com.amazonaws.services.s3.internal.MD5DigestCalculatingInputStream;
+import com.amazonaws.util.Base16;
+import com.sun.net.httpserver.HttpServer;
+import org.apache.http.HttpStatus;
+import org.elasticsearch.cluster.metadata.RepositoryMetaData;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.SuppressForbidden;
+import org.elasticsearch.common.blobstore.BlobContainer;
+import org.elasticsearch.common.blobstore.BlobPath;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.io.Streams;
+import org.elasticsearch.common.lucene.store.ByteArrayIndexInput;
+import org.elasticsearch.common.lucene.store.InputStreamIndexInput;
+import org.elasticsearch.common.settings.MockSecureSettings;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeUnit;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.util.concurrent.CountDown;
+import org.elasticsearch.core.internal.io.IOUtils;
+import org.elasticsearch.mocksocket.MockHttpServer;
+import org.elasticsearch.test.ESTestCase;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.Inet6Address;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.SocketTimeoutException;
+import java.nio.charset.StandardCharsets;
+import java.util.Locale;
+import java.util.Objects;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+
+import static org.elasticsearch.repositories.s3.S3ClientSettings.DISABLE_CHUNKED_ENCODING;
+import static org.elasticsearch.repositories.s3.S3ClientSettings.ENDPOINT_SETTING;
+import static org.elasticsearch.repositories.s3.S3ClientSettings.MAX_RETRIES_SETTING;
+import static org.elasticsearch.repositories.s3.S3ClientSettings.READ_TIMEOUT_SETTING;
+import static org.hamcrest.Matchers.anyOf;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+
+/**
+ * This class tests how an {@link S3BlobContainer} and its underlying AWS S3 client retry requests when reading or writing blobs.
+ */
+@SuppressForbidden(reason = "use a http server")
+public class S3BlobContainerRetriesTests extends ESTestCase {
+
+    private HttpServer httpServer;
+    private S3Service service;
+
+    @Before
+    public void setUp() throws Exception {
+        service = new S3Service();
+        httpServer = MockHttpServer.createHttp(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0);
+        httpServer.start();
+        super.setUp();
+    }
+
+    @After
+    public void tearDown() throws Exception {
+        IOUtils.close(service);
+        httpServer.stop(0);
+        super.tearDown();
+    }
+
+    private BlobContainer createBlobContainer(final @Nullable Integer maxRetries,
+                                              final @Nullable TimeValue readTimeout,
+                                              final @Nullable Boolean disableChunkedEncoding,
+                                              final @Nullable ByteSizeValue bufferSize) {
+        final Settings.Builder clientSettings = Settings.builder();
+        final String clientName = randomAlphaOfLength(5).toLowerCase(Locale.ROOT);
+
+        final String endpoint;
+        if (httpServer.getAddress().getAddress() instanceof Inet6Address) {
+            endpoint = "http://[" + httpServer.getAddress().getHostString() + "]:" + httpServer.getAddress().getPort();
+        } else {
+            endpoint = "http://" + httpServer.getAddress().getHostString() + ":" + httpServer.getAddress().getPort();
+        }
+        clientSettings.put(ENDPOINT_SETTING.getConcreteSettingForNamespace(clientName).getKey(), endpoint);
+        if (maxRetries != null) {
+            clientSettings.put(MAX_RETRIES_SETTING.getConcreteSettingForNamespace(clientName).getKey(), maxRetries);
+        }
+        if (readTimeout != null) {
+            clientSettings.put(READ_TIMEOUT_SETTING.getConcreteSettingForNamespace(clientName).getKey(), readTimeout);
+        }
+        if (disableChunkedEncoding != null) {
+            clientSettings.put(DISABLE_CHUNKED_ENCODING.getConcreteSettingForNamespace(clientName).getKey(), disableChunkedEncoding);
+        }
+
+        final MockSecureSettings secureSettings = new MockSecureSettings();
+        secureSettings.setString(S3ClientSettings.ACCESS_KEY_SETTING.getConcreteSettingForNamespace(clientName).getKey(), "access");
+        secureSettings.setString(S3ClientSettings.SECRET_KEY_SETTING.getConcreteSettingForNamespace(clientName).getKey(), "secret");
+        clientSettings.setSecureSettings(secureSettings);
+        service.refreshAndClearCache(S3ClientSettings.load(clientSettings.build()));
+
+        final RepositoryMetaData repositoryMetaData = new RepositoryMetaData("repository", S3Repository.TYPE,
+            Settings.builder().put(S3Repository.CLIENT_NAME.getKey(), clientName).build());
+
+        return new S3BlobContainer(BlobPath.cleanPath(), new S3BlobStore(service, "bucket",
+            S3Repository.SERVER_SIDE_ENCRYPTION_SETTING.getDefault(Settings.EMPTY),
+            bufferSize == null ? S3Repository.BUFFER_SIZE_SETTING.getDefault(Settings.EMPTY) : bufferSize,
+            S3Repository.CANNED_ACL_SETTING.getDefault(Settings.EMPTY),
+            S3Repository.STORAGE_CLASS_SETTING.getDefault(Settings.EMPTY),
+            repositoryMetaData));
+    }
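+
+    // Each test below registers a handler with the mock HTTP server above and points a genuine AWS S3 client at it,
+    // so the retry behaviour under test is the SDK's own retry logic (configured via MAX_RETRIES_SETTING) rather
+    // than a simulation of it.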
+
+    public void testReadBlobWithRetries() throws Exception {
+        final int maxRetries = randomInt(5);
+        final CountDown countDown = new CountDown(maxRetries + 1);
+
+        final byte[] bytes = randomByteArrayOfLength(randomIntBetween(1, 512));
+        httpServer.createContext("/bucket/read_blob_max_retries", exchange -> {
+            Streams.readFully(exchange.getRequestBody());
+            if (countDown.countDown()) {
+                exchange.getResponseHeaders().add("Content-Type", "text/plain; charset=utf-8");
+                exchange.sendResponseHeaders(HttpStatus.SC_OK, bytes.length);
+                exchange.getResponseBody().write(bytes);
+                exchange.close();
+                return;
+            }
+            exchange.sendResponseHeaders(randomFrom(HttpStatus.SC_INTERNAL_SERVER_ERROR, HttpStatus.SC_BAD_GATEWAY,
+                HttpStatus.SC_SERVICE_UNAVAILABLE, HttpStatus.SC_GATEWAY_TIMEOUT), -1);
+            exchange.close();
+        });
+
+        final BlobContainer blobContainer = createBlobContainer(maxRetries, null, null, null);
+        try (InputStream inputStream = blobContainer.readBlob("read_blob_max_retries")) {
+            assertArrayEquals(bytes, BytesReference.toBytes(Streams.readFully(inputStream)));
+            assertThat(countDown.isCountedDown(), is(true));
+        }
+    }
+
+    public void testReadBlobWithReadTimeouts() {
+        final TimeValue readTimeout = TimeValue.timeValueMillis(randomIntBetween(100, 500));
+        final BlobContainer blobContainer = createBlobContainer(1, readTimeout, null, null);
+
+        // HTTP server does not send a response
+        httpServer.createContext("/bucket/read_blob_unresponsive", exchange -> {});
+
+        Exception exception = expectThrows(SdkClientException.class, () -> blobContainer.readBlob("read_blob_unresponsive"));
+        assertThat(exception.getMessage().toLowerCase(Locale.ROOT), containsString("read timed out"));
+        assertThat(exception.getCause(), instanceOf(SocketTimeoutException.class));
+
+        // HTTP server sends a partial response
+        final byte[] bytes = randomByteArrayOfLength(randomIntBetween(10, 128));
+        httpServer.createContext("/bucket/read_blob_incomplete", exchange -> {
+            exchange.getResponseHeaders().add("Content-Type", "text/plain; charset=utf-8");
+            exchange.sendResponseHeaders(HttpStatus.SC_OK, bytes.length);
+            exchange.getResponseBody().write(bytes, 0, randomIntBetween(1, bytes.length - 1));
+            if (randomBoolean()) {
+                exchange.getResponseBody().flush();
+            }
+        });
+
+        exception = expectThrows(SocketTimeoutException.class, () -> {
+            try (InputStream stream = blobContainer.readBlob("read_blob_incomplete")) {
+                Streams.readFully(stream);
+            }
+        });
+        assertThat(exception.getMessage().toLowerCase(Locale.ROOT), containsString("read timed out"));
+    }
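+
+    // Note: the write tests disable chunked encoding; presumably this keeps the request body byte-identical to the
+    // blob content (no aws-chunked signing framing), so the handler can compare it directly against the expected bytes.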
+
+    public void testWriteBlobWithRetries() throws Exception {
+        final int maxRetries = randomInt(5);
+        final CountDown countDown = new CountDown(maxRetries + 1);
+
+        final byte[] bytes = randomByteArrayOfLength(randomIntBetween(1, 512));
+        httpServer.createContext("/bucket/write_blob_max_retries", exchange -> {
+            final BytesReference body = Streams.readFully(exchange.getRequestBody());
+            if (countDown.countDown()) {
+                if (Objects.deepEquals(bytes, BytesReference.toBytes(body))) {
+                    exchange.sendResponseHeaders(HttpStatus.SC_OK, -1);
+                } else {
+                    exchange.sendResponseHeaders(HttpStatus.SC_BAD_REQUEST, -1);
+                }
+                exchange.close();
+                return;
+            }
+            exchange.sendResponseHeaders(randomFrom(HttpStatus.SC_INTERNAL_SERVER_ERROR, HttpStatus.SC_BAD_GATEWAY,
+                HttpStatus.SC_SERVICE_UNAVAILABLE, HttpStatus.SC_GATEWAY_TIMEOUT), -1);
+            exchange.close();
+        });
+
+        final BlobContainer blobContainer = createBlobContainer(maxRetries, null, true, null);
+        try (InputStream stream = new InputStreamIndexInput(new ByteArrayIndexInput("desc", bytes), bytes.length)) {
+            blobContainer.writeBlob("write_blob_max_retries", stream, bytes.length, false);
+        }
+        assertThat(countDown.isCountedDown(), is(true));
+    }
+
+    public void testWriteBlobWithReadTimeouts() {
+        final TimeValue readTimeout = TimeValue.timeValueMillis(randomIntBetween(100, 500));
+        final BlobContainer blobContainer = createBlobContainer(1, readTimeout, true, null);
+
+        // HTTP server does not send a response
+        httpServer.createContext("/bucket/write_blob_timeout", exchange -> {
+            if (randomBoolean()) {
+                Streams.readFully(exchange.getRequestBody());
+            }
+        });
+
+        final byte[] bytes = randomByteArrayOfLength(randomIntBetween(1, 128));
+        Exception exception = expectThrows(IOException.class, () -> {
+            try (InputStream stream = new InputStreamIndexInput(new ByteArrayIndexInput("desc", bytes), bytes.length)) {
+                blobContainer.writeBlob("write_blob_timeout", stream, bytes.length, false);
+            }
+        });
+        assertThat(exception.getMessage().toLowerCase(Locale.ROOT),
+            containsString("unable to upload object [write_blob_timeout] using a single upload"));
+
+        assertThat(exception.getCause(), instanceOf(SdkClientException.class));
+        assertThat(exception.getCause().getMessage().toLowerCase(Locale.ROOT), containsString("read timed out"));
+
+        assertThat(exception.getCause().getCause(), instanceOf(SocketTimeoutException.class));
+        assertThat(exception.getCause().getCause().getMessage().toLowerCase(Locale.ROOT), containsString("read timed out"));
+    }
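+
+    // Writing a blob larger than the buffer size triggers S3's multipart upload protocol: a POST with the "uploads"
+    // query to initiate, a PUT per part, and a POST with the "uploadId" query to complete. The handler below lets each
+    // of these requests fail (or, when useTimeout is set, time out) before eventually succeeding, forcing every stage
+    // of the upload through the SDK's retry path.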
+
+    public void testWriteLargeBlob() throws Exception {
+        final boolean useTimeout = rarely();
+        final TimeValue readTimeout = useTimeout ? TimeValue.timeValueMillis(randomIntBetween(100, 500)) : null;
+        final ByteSizeValue bufferSize = new ByteSizeValue(5, ByteSizeUnit.MB);
+        final BlobContainer blobContainer = createBlobContainer(null, readTimeout, true, bufferSize);
+
+        final int parts = randomIntBetween(1, 2);
+        final long lastPartSize = randomLongBetween(10, 512);
+        final long blobSize = (parts * bufferSize.getBytes()) + lastPartSize;
+
+        final int maxRetries = 2; // we want all requests to fail at least once
+        final CountDown countDownInitiate = new CountDown(maxRetries);
+        final AtomicInteger countDownUploads = new AtomicInteger(maxRetries * (parts + 1));
+        final CountDown countDownComplete = new CountDown(maxRetries);
+
+        httpServer.createContext("/bucket/write_large_blob", exchange -> {
+            if ("POST".equals(exchange.getRequestMethod())
+                && exchange.getRequestURI().getQuery().equals("uploads")) {
+                // initiate multipart upload request
+                if (countDownInitiate.countDown()) {
+                    byte[] response = ("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n" +
+                        "<InitiateMultipartUploadResult>\n" +
+                        "  <Bucket>bucket</Bucket>\n" +
+                        "  <Key>write_large_blob</Key>\n" +
+                        "  <UploadId>TEST</UploadId>\n" +
+                        "</InitiateMultipartUploadResult>").getBytes(StandardCharsets.UTF_8);
+                    exchange.getResponseHeaders().add("Content-Type", "application/xml");
+                    exchange.sendResponseHeaders(HttpStatus.SC_OK, response.length);
+                    exchange.getResponseBody().write(response);
+                    exchange.close();
+                    return;
+                }
+            } else if ("PUT".equals(exchange.getRequestMethod())) {
+                // upload part request
+                MD5DigestCalculatingInputStream md5 = new MD5DigestCalculatingInputStream(exchange.getRequestBody());
+                BytesReference bytes = Streams.readFully(md5);
+                assertThat((long) bytes.length(), anyOf(equalTo(lastPartSize), equalTo(bufferSize.getBytes())));
+
+                if (countDownUploads.decrementAndGet() % 2 == 0) {
+                    exchange.getResponseHeaders().add("ETag", Base16.encodeAsString(md5.getMd5Digest()));
+                    exchange.sendResponseHeaders(HttpStatus.SC_OK, -1);
+                    exchange.close();
+                    return;
+                }
+
+            } else if ("POST".equals(exchange.getRequestMethod())
+                && exchange.getRequestURI().getQuery().equals("uploadId=TEST")) {
+                // complete multipart upload request
+                Streams.readFully(exchange.getRequestBody());
+                if (countDownComplete.countDown()) {
+                    byte[] response = ("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n" +
+                        "<CompleteMultipartUploadResult>\n" +
+                        "  <Bucket>bucket</Bucket>\n" +
+                        "  <Key>write_large_blob</Key>\n" +
+                        "</CompleteMultipartUploadResult>").getBytes(StandardCharsets.UTF_8);
+                    exchange.getResponseHeaders().add("Content-Type", "application/xml");
+                    exchange.sendResponseHeaders(HttpStatus.SC_OK, response.length);
+                    exchange.getResponseBody().write(response);
+                    exchange.close();
+                    return;
+                }
+            }
+
+            // send an error back or let the request time out
+            if (useTimeout == false) {
+                exchange.sendResponseHeaders(randomFrom(HttpStatus.SC_INTERNAL_SERVER_ERROR, HttpStatus.SC_BAD_GATEWAY,
+                    HttpStatus.SC_SERVICE_UNAVAILABLE, HttpStatus.SC_GATEWAY_TIMEOUT), -1);
+                exchange.close();
+            }
+        });
+
+        blobContainer.writeBlob("write_large_blob", new ZeroInputStream(blobSize), blobSize, false);
+
+        assertThat(countDownInitiate.isCountedDown(), is(true));
+        assertThat(countDownUploads.get(), equalTo(0));
+        assertThat(countDownComplete.isCountedDown(), is(true));
+    }
+
+    /**
+     * A resettable InputStream that only serves zeros.
+     *
+     * Ideally it should be wrapped into a BufferedInputStream, but it seems that the AWS SDK calls {@link #reset()}
+     * before calling {@link #mark(int)}, which is not permitted by the {@link #reset()} method contract.
+     **/
+    private static class ZeroInputStream extends InputStream {
+
+        private final AtomicBoolean closed = new AtomicBoolean(false);
+        private final long length;
+        private final AtomicLong reads;
+        private volatile long mark;
+
+        private ZeroInputStream(final long length) {
+            this.length = length;
+            this.reads = new AtomicLong(length);
+            this.mark = -1;
+        }
+
+        @Override
+        public int read() throws IOException {
+            ensureOpen();
+            if (reads.decrementAndGet() < 0) {
+                return -1;
+            }
+            return 0;
+        }
+
+        @Override
+        public boolean markSupported() {
+            return true;
+        }
+
+        @Override
+        public synchronized void mark(int readlimit) {
+            mark = reads.get();
+        }
+
+        @Override
+        public synchronized void reset() throws IOException {
+            ensureOpen();
+            reads.set(mark);
+        }
+
+        @Override
+        public int available() throws IOException {
+            ensureOpen();
+            return Math.toIntExact(length - reads.get());
+        }
+
+        @Override
+        public void close() throws IOException {
+            closed.set(true);
+        }
+
+        private void ensureOpen() throws IOException {
+            if (closed.get()) {
+                throw new IOException("Stream closed");
+            }
+        }
+    }
+}