Change s3 sink client to async client (#4425)
Signed-off-by: Taylor Gray <[email protected]>
graytaylor0 authored Apr 18, 2024
1 parent c8a94fb commit 2859023
Showing 28 changed files with 665 additions and 473 deletions.
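The heart of the change: the sink now talks to S3 through the SDK's non-blocking S3AsyncClient, whose operations return a CompletableFuture instead of blocking the calling thread until S3 responds. A minimal sketch of the two call styles (standalone illustration, not code from this commit; the bucket and key are placeholders):

import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse;

import java.util.concurrent.CompletableFuture;

public class SyncVersusAsyncSketch {
    public static void main(String[] args) {
        CreateMultipartUploadRequest request = CreateMultipartUploadRequest.builder()
                .bucket("example-bucket")  // placeholder
                .key("example-key")        // placeholder
                .build();

        // Synchronous client: the calling thread blocks until S3 responds.
        try (S3Client syncClient = S3Client.create()) {
            CreateMultipartUploadResponse response = syncClient.createMultipartUpload(request);
            System.out.println(response.uploadId());
        }

        // Asynchronous client: the call returns a future immediately;
        // the caller decides where (or whether) to wait.
        try (S3AsyncClient asyncClient = S3AsyncClient.create()) {
            CompletableFuture<CreateMultipartUploadResponse> future = asyncClient.createMultipartUpload(request);
            System.out.println(future.join().uploadId());
        }
    }
}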
S3SinkServiceIT.java

@@ -64,6 +64,7 @@
 import org.opensearch.dataprepper.plugins.sink.s3.grouping.S3GroupManager;
 import software.amazon.awssdk.core.ResponseBytes;
 import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.s3.S3AsyncClient;
 import software.amazon.awssdk.services.s3.S3Client;
 import software.amazon.awssdk.services.s3.model.GetObjectRequest;
 import software.amazon.awssdk.services.s3.model.GetObjectResponse;
@@ -104,6 +105,8 @@ class S3SinkServiceIT {
     private static final String PATH_PREFIX = UUID.randomUUID() + "/%{yyyy}/%{MM}/%{dd}/";
     private static final int numberOfRecords = 2;
     private S3Client s3Client;
+
+    private S3AsyncClient s3AsyncClient;
     private String bucketName;
     private String s3region;
     private ParquetOutputCodecConfig parquetOutputCodecConfig;
@@ -152,6 +155,7 @@ public void setUp() {
         s3region = System.getProperty("tests.s3sink.region");

         s3Client = S3Client.builder().region(Region.of(s3region)).build();
+        s3AsyncClient = S3AsyncClient.builder().region(Region.of(s3region)).build();
         bucketName = System.getProperty("tests.s3sink.bucket");
         bufferFactory = new InMemoryBufferFactory();

@@ -266,9 +270,9 @@ void verify_flushed_records_into_s3_bucketNewLine_with_compression() throws IOException
     private S3SinkService createObjectUnderTest() {
         OutputCodecContext codecContext = new OutputCodecContext("Tag", Collections.emptyList(), Collections.emptyList());
         final S3GroupIdentifierFactory groupIdentifierFactory = new S3GroupIdentifierFactory(keyGenerator, expressionEvaluator, s3SinkConfig);
-        s3GroupManager = new S3GroupManager(s3SinkConfig, groupIdentifierFactory, bufferFactory, codecFactory, s3Client);
+        s3GroupManager = new S3GroupManager(s3SinkConfig, groupIdentifierFactory, bufferFactory, codecFactory, s3AsyncClient);

-        return new S3SinkService(s3SinkConfig, codecContext, s3Client, keyGenerator, Duration.ofSeconds(5), pluginMetrics, s3GroupManager);
+        return new S3SinkService(s3SinkConfig, codecContext, Duration.ofSeconds(5), pluginMetrics, s3GroupManager);
     }

     private int gets3ObjectCount() {
S3OutputStream.java

@@ -8,22 +8,29 @@
 import org.apache.parquet.io.PositionOutputStream;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import software.amazon.awssdk.core.sync.RequestBody;
-import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.core.async.AsyncRequestBody;
+import software.amazon.awssdk.services.s3.S3AsyncClient;
 import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest;
+import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse;
 import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload;
 import software.amazon.awssdk.services.s3.model.CompletedPart;
 import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
+import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse;
 import software.amazon.awssdk.services.s3.model.NoSuchBucketException;
-import software.amazon.awssdk.services.s3.model.S3Exception;
 import software.amazon.awssdk.services.s3.model.UploadPartRequest;
 import software.amazon.awssdk.services.s3.model.UploadPartResponse;

 import java.io.ByteArrayInputStream;
 import java.io.IOException;
+import java.io.InputStream;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.function.Consumer;
 import java.util.function.Supplier;

 public class S3OutputStream extends PositionOutputStream {
@@ -51,7 +58,7 @@ public class S3OutputStream extends PositionOutputStream {
      */
     private final byte[] buf;

-    private final S3Client s3Client;
+    private final S3AsyncClient s3Client;
     /**
      * Collection of the etags for the parts that have been uploaded
      */
@@ -74,14 +81,16 @@ public class S3OutputStream extends PositionOutputStream {
      */
     private final String defaultBucket;

+    private final ExecutorService executorService;
+
     /**
      * Creates a new S3 OutputStream
      *
      * @param s3Client the AmazonS3 client
      * @param bucketSupplier name of the bucket
      * @param keySupplier path within the bucket
      */
-    public S3OutputStream(final S3Client s3Client,
+    public S3OutputStream(final S3AsyncClient s3Client,
                           final Supplier<String> bucketSupplier,
                           final Supplier<String> keySupplier,
                           final String defaultBucket) {
@@ -93,13 +102,18 @@ public S3OutputStream(final S3Client s3Client,
         etags = new ArrayList<>();
         open = true;
         this.defaultBucket = defaultBucket;
+        this.executorService = Executors.newSingleThreadExecutor();
     }

     @Override
     public void write(int b) {
         assertOpen();
         if (position >= buf.length) {
-            flushBufferAndRewind();
+            try {
+                flushBufferAndRewind();
+            } catch (ExecutionException | InterruptedException e) {
+                throw new RuntimeException(e);
+            }
         }
         buf[position++] = (byte) b;
     }
@@ -132,7 +146,12 @@ public void write(byte[] byteArray, int o, int l) {
         while (len > (size = buf.length - position)) {
             System.arraycopy(byteArray, ofs, buf, position, size);
             position += size;
-            flushBufferAndRewind();
+            try {
+                flushBufferAndRewind();
+            } catch (ExecutionException | InterruptedException e) {
+                throw new RuntimeException(e);
+            }
+
             ofs += size;
             len -= size;
         }
@@ -147,36 +166,48 @@ public void write(byte[] byteArray, int o, int l) {
     public void flush() {
     }

-    @Override
-    public void close() {
+    public CompletableFuture<?> close(final Consumer<Boolean> runOnCompletion, final Consumer<Throwable> runOnError) {
         if (open) {
             open = false;
-            possiblyStartMultipartUpload();
-            if (position > 0) {
-                uploadPart();
-            }
-
-            CompletedPart[] completedParts = new CompletedPart[etags.size()];
-            for (int i = 0; i < etags.size(); i++) {
-                completedParts[i] = CompletedPart.builder()
-                        .eTag(etags.get(i))
-                        .partNumber(i + 1)
-                        .build();
-            }
-
-            LOG.debug("Completing S3 multipart upload with {} parts.", completedParts.length);
-
-            CompletedMultipartUpload completedMultipartUpload = CompletedMultipartUpload.builder()
-                    .parts(completedParts)
-                    .build();
-            CompleteMultipartUploadRequest completeMultipartUploadRequest = CompleteMultipartUploadRequest.builder()
-                    .bucket(bucket)
-                    .key(key)
-                    .uploadId(uploadId)
-                    .multipartUpload(completedMultipartUpload)
-                    .build();
-            s3Client.completeMultipartUpload(completeMultipartUploadRequest);
+            try {
+                possiblyStartMultipartUpload();
+
+                if (position > 0) {
+                    uploadPart();
+                }
+
+                CompletedPart[] completedParts = new CompletedPart[etags.size()];
+                for (int i = 0; i < etags.size(); i++) {
+                    completedParts[i] = CompletedPart.builder()
+                            .eTag(etags.get(i))
+                            .partNumber(i + 1)
+                            .build();
+                }
+
+                LOG.debug("Completing S3 multipart upload with {} parts.", completedParts.length);
+
+                CompletedMultipartUpload completedMultipartUpload = CompletedMultipartUpload.builder()
+                        .parts(completedParts)
+                        .build();
+                CompleteMultipartUploadRequest completeMultipartUploadRequest = CompleteMultipartUploadRequest.builder()
+                        .bucket(bucket)
+                        .key(key)
+                        .uploadId(uploadId)
+                        .multipartUpload(completedMultipartUpload)
+                        .build();
+                CompletableFuture<CompleteMultipartUploadResponse> multipartUploadResponseCompletableFuture = s3Client.completeMultipartUpload(completeMultipartUploadRequest);
+
+                multipartUploadResponseCompletableFuture.join();
+
+                runOnCompletion.accept(true);
+                return multipartUploadResponseCompletableFuture;
+            } catch (final Exception e) {
+                runOnError.accept(e);
+                runOnCompletion.accept(false);
+            }
         }
+
+        return null;
     }

     public String getKey() {
@@ -189,7 +220,7 @@ private void assertOpen() {
         }
     }

-    private void flushBufferAndRewind() {
+    private void flushBufferAndRewind() throws ExecutionException, InterruptedException {
         possiblyStartMultipartUpload();
         uploadPart();
         position = 0;
@@ -200,10 +231,11 @@ private void possiblyStartMultipartUpload() {

         try {
             createMultipartUpload();
-        } catch (final S3Exception e) {
-            if (defaultBucket != null && (e instanceof NoSuchBucketException || e.getMessage().contains(ACCESS_DENIED))) {
+        } catch (final CompletionException e) {
+            if (defaultBucket != null && (e.getCause() != null &&
+                    (e.getCause() instanceof NoSuchBucketException || (e.getCause().getMessage() != null && e.getCause().getMessage().contains(ACCESS_DENIED))))) {
                 bucket = defaultBucket;
-                LOG.warn("Bucket {} could not be accessed to create multi-part upload, attempting to create multi-part upload to default_bucket {}", bucket, defaultBucket);
+                LOG.warn("Bucket {} could not be accessed to create multi-part upload, attempting to create multi-part upload to default_bucket {}. Error: {}", bucket, defaultBucket, e.getCause().getMessage());
                 createMultipartUpload();
             } else {
                 throw e;
@@ -223,12 +255,17 @@ private void uploadPart() {
                 .partNumber(partNumber)
                 .contentLength((long) position)
                 .build();
-        RequestBody requestBody = RequestBody.fromInputStream(new ByteArrayInputStream(buf, 0, position),
-                position);
+        final InputStream inputStream = new ByteArrayInputStream(buf, 0, position);
+
+        AsyncRequestBody asyncRequestBody = AsyncRequestBody.fromInputStream(inputStream, (long) position, executorService);

         LOG.debug("Writing {} bytes to S3 multipart part number {}.", buf.length, partNumber);

-        UploadPartResponse uploadPartResponse = s3Client.uploadPart(uploadRequest, requestBody);
+        CompletableFuture<UploadPartResponse> uploadPartResponseFuture = s3Client.uploadPart(uploadRequest, asyncRequestBody);
+
+        final UploadPartResponse uploadPartResponse = uploadPartResponseFuture.join();

         etags.add(uploadPartResponse.eTag());
     }
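AsyncRequestBody.fromInputStream needs a thread it is allowed to block on while reading the stream, so this overload takes an ExecutorService; that is why the class now creates a single-thread executor in its constructor and passes it to every part upload. A minimal sketch of the overload in isolation (hypothetical helper, not code from this commit):

import software.amazon.awssdk.core.async.AsyncRequestBody;

import java.io.ByteArrayInputStream;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class AsyncBodySketch {
    static AsyncRequestBody bodyFor(byte[] buf, int length) {
        // The executor reads the InputStream off the caller's thread;
        // a dedicated single-thread executor mirrors what S3OutputStream does.
        ExecutorService executor = Executors.newSingleThreadExecutor();
        return AsyncRequestBody.fromInputStream(new ByteArrayInputStream(buf, 0, length), (long) length, executor);
    }
}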

@@ -242,8 +279,11 @@ private void createMultipartUpload() {
                 .bucket(bucket)
                 .key(key)
                 .build();
-        CreateMultipartUploadResponse multipartUpload = s3Client.createMultipartUpload(uploadRequest);
-        uploadId = multipartUpload.uploadId();
+        CompletableFuture<CreateMultipartUploadResponse> multipartUpload = s3Client.createMultipartUpload(uploadRequest);
+
+        final CreateMultipartUploadResponse response = multipartUpload.join();
+
+        uploadId = response.uploadId();
     }
 }

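Note that close no longer overrides OutputStream.close(): it takes completion and error callbacks and returns the CompletableFuture of the final complete-multipart-upload call, or null if the stream was already closed or the upload failed. A sketch of how a caller might drive it (hypothetical caller, not code from this commit):

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.CompletableFuture;

class CloseCallerSketch {
    private static final Logger LOG = LoggerFactory.getLogger(CloseCallerSketch.class);

    static void finish(S3OutputStream outputStream) {
        CompletableFuture<?> completion = outputStream.close(
                succeeded -> LOG.info("S3 upload completed, success={}", succeeded),  // runOnCompletion
                error -> LOG.error("Failed to complete S3 upload", error));           // runOnError

        if (completion != null) {
            completion.join();  // optional: block until the multipart upload is finished
        }
    }
}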
ClientFactory.java

@@ -11,6 +11,7 @@
 import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
 import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
 import software.amazon.awssdk.core.retry.RetryPolicy;
+import software.amazon.awssdk.services.s3.S3AsyncClient;
 import software.amazon.awssdk.services.s3.S3Client;

 public final class ClientFactory {
@@ -26,8 +27,18 @@ static S3Client createS3Client(final S3SinkConfig s3SinkConfig, final AwsCredentialsSupplier awsCredentialsSupplier) {
             .overrideConfiguration(createOverrideConfiguration(s3SinkConfig)).build();
     }

+    static S3AsyncClient createS3AsyncClient(final S3SinkConfig s3SinkConfig, final AwsCredentialsSupplier awsCredentialsSupplier) {
+        final AwsCredentialsOptions awsCredentialsOptions = convertToCredentialsOptions(s3SinkConfig.getAwsAuthenticationOptions());
+        final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(awsCredentialsOptions);
+
+        return S3AsyncClient.builder()
+                .region(s3SinkConfig.getAwsAuthenticationOptions().getAwsRegion())
+                .credentialsProvider(awsCredentialsProvider)
+                .overrideConfiguration(createOverrideConfiguration(s3SinkConfig)).build();
+    }
+
     private static ClientOverrideConfiguration createOverrideConfiguration(final S3SinkConfig s3SinkConfig) {
-        final RetryPolicy retryPolicy = RetryPolicy.builder().numRetries(s3SinkConfig.getMaxConnectionRetries()).build();
+        final RetryPolicy retryPolicy = RetryPolicy.builder().numRetries(s3SinkConfig.getMaxConnectionRetries() * s3SinkConfig.getMaxUploadRetries()).build();
         return ClientOverrideConfiguration.builder()
                 .retryPolicy(retryPolicy)
                 .build();
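One behavioral change worth noting in createOverrideConfiguration: the SDK-level retry budget is now the product of the two settings rather than the connection-retry count alone. Illustrative arithmetic (values are made up, not the plugin's defaults):

// Hypothetical values, for illustration only:
int maxConnectionRetries = 5;
int maxUploadRetries = 3;
int numRetries = maxConnectionRetries * maxUploadRetries;  // RetryPolicy now allows 15 retries instead of 5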
S3Sink.java

@@ -33,7 +33,7 @@
 import org.opensearch.dataprepper.plugins.sink.s3.grouping.S3GroupManager;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.s3.S3AsyncClient;

 import java.time.Duration;
 import java.util.Collection;
@@ -77,7 +77,7 @@ public S3Sink(final PluginSetting pluginSetting,
         final OutputCodec testCodec = pluginFactory.loadPlugin(OutputCodec.class, codecPluginSettings);
         sinkInitialized = Boolean.FALSE;

-        final S3Client s3Client = ClientFactory.createS3Client(s3SinkConfig, awsCredentialsSupplier);
+        final S3AsyncClient s3Client = ClientFactory.createS3AsyncClient(s3SinkConfig, awsCredentialsSupplier);
         BufferFactory innerBufferFactory = s3SinkConfig.getBufferType().getBufferFactory();
         if(testCodec instanceof ParquetOutputCodec && s3SinkConfig.getBufferType() != BufferTypeOptions.INMEMORY) {
             throw new InvalidPluginConfigurationException("The Parquet sink codec is an in_memory buffer only.");
@@ -115,7 +115,7 @@ public S3Sink(final PluginSetting pluginSetting,
         final S3GroupManager s3GroupManager = new S3GroupManager(s3SinkConfig, s3GroupIdentifierFactory, bufferFactory, codecFactory, s3Client);


-        s3SinkService = new S3SinkService(s3SinkConfig, s3OutputCodecContext, s3Client, keyGenerator, RETRY_FLUSH_BACKOFF, pluginMetrics, s3GroupManager);
+        s3SinkService = new S3SinkService(s3SinkConfig, s3OutputCodecContext, RETRY_FLUSH_BACKOFF, pluginMetrics, s3GroupManager);
     }

     @Override