Change s3 sink client to async client #4425

Merged 1 commit on Apr 18, 2024
S3SinkServiceIT.java
@@ -64,6 +64,7 @@
import org.opensearch.dataprepper.plugins.sink.s3.grouping.S3GroupManager;
import software.amazon.awssdk.core.ResponseBytes;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
@@ -104,6 +105,8 @@ class S3SinkServiceIT {
private static final String PATH_PREFIX = UUID.randomUUID() + "/%{yyyy}/%{MM}/%{dd}/";
private static final int numberOfRecords = 2;
private S3Client s3Client;

private S3AsyncClient s3AsyncClient;
private String bucketName;
private String s3region;
private ParquetOutputCodecConfig parquetOutputCodecConfig;
@@ -152,6 +155,7 @@ public void setUp() {
s3region = System.getProperty("tests.s3sink.region");

s3Client = S3Client.builder().region(Region.of(s3region)).build();
Member: can we remove s3Client?

Member Author: S3Client is used in the test to validate objects (see the verification sketch at the end of this file's diff).
s3AsyncClient = S3AsyncClient.builder().region(Region.of(s3region)).build();
bucketName = System.getProperty("tests.s3sink.bucket");
bufferFactory = new InMemoryBufferFactory();

@@ -266,9 +270,9 @@ void verify_flushed_records_into_s3_bucketNewLine_with_compression() throws IOException {
private S3SinkService createObjectUnderTest() {
OutputCodecContext codecContext = new OutputCodecContext("Tag", Collections.emptyList(), Collections.emptyList());
final S3GroupIdentifierFactory groupIdentifierFactory = new S3GroupIdentifierFactory(keyGenerator, expressionEvaluator, s3SinkConfig);
- s3GroupManager = new S3GroupManager(s3SinkConfig, groupIdentifierFactory, bufferFactory, codecFactory, s3Client);
+ s3GroupManager = new S3GroupManager(s3SinkConfig, groupIdentifierFactory, bufferFactory, codecFactory, s3AsyncClient);

- return new S3SinkService(s3SinkConfig, codecContext, s3Client, keyGenerator, Duration.ofSeconds(5), pluginMetrics, s3GroupManager);

Member: Is keyGenerator used?

Member Author: Yes, it's used by S3GroupIdentifierFactory in this test.

+ return new S3SinkService(s3SinkConfig, codecContext, Duration.ofSeconds(5), pluginMetrics, s3GroupManager);
}

private int gets3ObjectCount() {
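The body of gets3ObjectCount() is collapsed in the view above. A plausible sketch of the verification it performs, per the review thread earlier in this file: the synchronous S3Client reads back what the sink wrote through the async client. The exact calls are assumptions, not the test's actual code.

    // Hedged sketch: count the objects the sink flushed, using the sync client.
    // Requires software.amazon.awssdk.services.s3.model.ListObjectsV2Request.
    private int gets3ObjectCount() {
        final ListObjectsV2Request listRequest = ListObjectsV2Request.builder()
                .bucket(bucketName) // set in setUp() from tests.s3sink.bucket
                .build();
        return s3Client.listObjectsV2(listRequest).keyCount(); // blocking call keeps the assertion simple
    }

This is also why removing s3Client entirely would gain little: the assertions would have to join() on async calls instead.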
S3OutputStream.java
@@ -8,22 +8,29 @@
import org.apache.parquet.io.PositionOutputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
- import software.amazon.awssdk.core.sync.RequestBody;
- import software.amazon.awssdk.services.s3.S3Client;
+ import software.amazon.awssdk.core.async.AsyncRequestBody;
+ import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest;
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse;
import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload;
import software.amazon.awssdk.services.s3.model.CompletedPart;
import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse;
import software.amazon.awssdk.services.s3.model.NoSuchBucketException;
import software.amazon.awssdk.services.s3.model.S3Exception;
import software.amazon.awssdk.services.s3.model.UploadPartRequest;
import software.amazon.awssdk.services.s3.model.UploadPartResponse;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Consumer;
import java.util.function.Supplier;

public class S3OutputStream extends PositionOutputStream {
@@ -51,7 +58,7 @@ public class S3OutputStream {
*/
private final byte[] buf;

- private final S3Client s3Client;
+ private final S3AsyncClient s3Client;
/**
* Collection of the etags for the parts that have been uploaded
*/
@@ -74,14 +81,16 @@ public class S3OutputStream {
*/
private final String defaultBucket;

private final ExecutorService executorService;

/**
* Creates a new S3 OutputStream
*
* @param s3Client the AmazonS3 client
* @param bucketSupplier name of the bucket
* @param keySupplier path within the bucket
*/
- public S3OutputStream(final S3Client s3Client,
+ public S3OutputStream(final S3AsyncClient s3Client,
final Supplier<String> bucketSupplier,
final Supplier<String> keySupplier,
final String defaultBucket) {
@@ -93,13 +102,18 @@ public S3OutputStream(final S3Client s3Client,
etags = new ArrayList<>();
open = true;
this.defaultBucket = defaultBucket;
this.executorService = Executors.newSingleThreadExecutor();
}
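The single-thread executor introduced here services AsyncRequestBody.fromInputStream in uploadPart() below, which drains the buffered part on a separate thread. One caveat worth noting as an assumption, since the PR does not show it: each stream owns its executor, so a long-lived owner would eventually want to release that thread.

    // Hypothetical cleanup, not part of this change: release the upload
    // thread once the stream has uploaded its final part.
    executorService.shutdown();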

@Override
public void write(int b) {
assertOpen();
if (position >= buf.length) {
- flushBufferAndRewind();
+ try {
+ flushBufferAndRewind();
+ } catch (ExecutionException | InterruptedException e) {
+ throw new RuntimeException(e);

Member: How is the client handling the exception? (See the caller-side sketch after this method.)

+ }
}
buf[position++] = (byte) b;
}
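On the review question above: with this change, a flush failure reaches the caller of write() as an unchecked exception whose cause chain carries the SDK error. A minimal caller-side sketch; the stream and data names are placeholders, and the unwrapping depth is an assumption based on the wrapping done here.

    // Sketch: a caller of write() surfacing the underlying S3 failure. Whether
    // the future is awaited with get() or join(), the unchecked exception is a
    // RuntimeException (CompletionException is one), so a single catch covers both.
    try {
        s3OutputStream.write(data);
    } catch (final RuntimeException e) {
        final Throwable root = (e.getCause() != null) ? e.getCause() : e;
        LOG.error("Failed to flush part to S3", root);
    }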
@@ -132,7 +146,12 @@ public void write(byte[] byteArray, int o, int l) {
while (len > (size = buf.length - position)) {
System.arraycopy(byteArray, ofs, buf, position, size);
position += size;
- flushBufferAndRewind();
+ try {
+ flushBufferAndRewind();
+ } catch (ExecutionException | InterruptedException e) {
+ throw new RuntimeException(e);
+ }

ofs += size;
len -= size;
}
@@ -147,36 +166,48 @@ public void write(byte[] byteArray, int o, int l) {
public void flush() {
}

-    @Override
-    public void close() {
+    public CompletableFuture<?> close(final Consumer<Boolean> runOnCompletion, final Consumer<Throwable> runOnError) {
         if (open) {
             open = false;
+            try {
                 possiblyStartMultipartUpload();

                 if (position > 0) {
                     uploadPart();
                 }

                 CompletedPart[] completedParts = new CompletedPart[etags.size()];
                 for (int i = 0; i < etags.size(); i++) {
                     completedParts[i] = CompletedPart.builder()
                             .eTag(etags.get(i))
                             .partNumber(i + 1)
                             .build();
                 }

                 LOG.debug("Completing S3 multipart upload with {} parts.", completedParts.length);

                 CompletedMultipartUpload completedMultipartUpload = CompletedMultipartUpload.builder()
                         .parts(completedParts)
                         .build();
                 CompleteMultipartUploadRequest completeMultipartUploadRequest = CompleteMultipartUploadRequest.builder()
                         .bucket(bucket)
                         .key(key)
                         .uploadId(uploadId)
                         .multipartUpload(completedMultipartUpload)
                         .build();
-            s3Client.completeMultipartUpload(completeMultipartUploadRequest);
+                CompletableFuture<CompleteMultipartUploadResponse> multipartUploadResponseCompletableFuture = s3Client.completeMultipartUpload(completeMultipartUploadRequest);
+
+                multipartUploadResponseCompletableFuture.join();
+
+                runOnCompletion.accept(true);
+                return multipartUploadResponseCompletableFuture;

Member: The execution is already complete at the join() above. Why do we want to return the future here?

Member Author: We can return null here.

+            } catch (final Exception e) {
+                runOnError.accept(e);
+                runOnCompletion.accept(false);
+            }
         }
+
+        return null;
     }
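On the thread above: since join() already blocks until the upload completes, returning the future adds little, hence the author's "we can return null here". A fully non-blocking alternative, shown only as a sketch of the design space rather than what this PR does, would chain the callbacks instead of joining:

    // Sketch: chain the completion callbacks rather than blocking on join().
    return s3Client.completeMultipartUpload(completeMultipartUploadRequest)
            .whenComplete((response, throwable) -> {
                if (throwable != null) {
                    runOnError.accept(throwable);
                    runOnCompletion.accept(false);
                } else {
                    runOnCompletion.accept(true);
                }
            });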

public String getKey() {
@@ -189,7 +220,7 @@ private void assertOpen() {
}
}

- private void flushBufferAndRewind() {
+ private void flushBufferAndRewind() throws ExecutionException, InterruptedException {
possiblyStartMultipartUpload();
uploadPart();
position = 0;
@@ -200,10 +231,11 @@ private void possiblyStartMultipartUpload() {

         try {
             createMultipartUpload();
-        } catch (final S3Exception e) {
-            if (defaultBucket != null && (e instanceof NoSuchBucketException || e.getMessage().contains(ACCESS_DENIED))) {
+        } catch (final CompletionException e) {
+            if (defaultBucket != null && (e.getCause() != null &&
+                    (e.getCause() instanceof NoSuchBucketException || (e.getCause().getMessage() != null && e.getCause().getMessage().contains(ACCESS_DENIED))))) {
                 bucket = defaultBucket;
-                LOG.warn("Bucket {} could not be accessed to create multi-part upload, attempting to create multi-part upload to default_bucket {}", bucket, defaultBucket);
+                LOG.warn("Bucket {} could not be accessed to create multi-part upload, attempting to create multi-part upload to default_bucket {}. Error: {}", bucket, defaultBucket, e.getCause().getMessage());
                 createMultipartUpload();
             } else {
                 throw e;
@@ -223,12 +255,17 @@ private void uploadPart() {
.partNumber(partNumber)
.contentLength((long) position)
.build();
- RequestBody requestBody = RequestBody.fromInputStream(new ByteArrayInputStream(buf, 0, position), position);
+ final InputStream inputStream = new ByteArrayInputStream(buf, 0, position);
+
+ AsyncRequestBody asyncRequestBody = AsyncRequestBody.fromInputStream(inputStream, (long) position, executorService);

LOG.debug("Writing {} bytes to S3 multipart part number {}.", buf.length, partNumber);

- UploadPartResponse uploadPartResponse = s3Client.uploadPart(uploadRequest, requestBody);
+ CompletableFuture<UploadPartResponse> uploadPartResponseFuture = s3Client.uploadPart(uploadRequest, asyncRequestBody);
+
+ final UploadPartResponse uploadPartResponse = uploadPartResponseFuture.join();

etags.add(uploadPartResponse.eTag());
}
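AsyncRequestBody.fromInputStream needs the executor because the SDK drains the stream on a separate thread. Since the part already sits in a byte array here, an alternative worth noting (a design option, not the PR's choice) avoids the executor at the cost of one array copy:

    // Sketch: fromBytes copies the filled slice of buf up front, so no
    // ExecutorService is needed to feed the request body.
    // Requires java.util.Arrays.
    AsyncRequestBody asyncRequestBody =
            AsyncRequestBody.fromBytes(Arrays.copyOfRange(buf, 0, position));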

@@ -242,8 +279,11 @@ private void createMultipartUpload() {
.bucket(bucket)
.key(key)
.build();
- CreateMultipartUploadResponse multipartUpload = s3Client.createMultipartUpload(uploadRequest);
- uploadId = multipartUpload.uploadId();
+ CompletableFuture<CreateMultipartUploadResponse> multipartUpload = s3Client.createMultipartUpload(uploadRequest);
+
+ final CreateMultipartUploadResponse response = multipartUpload.join();
+
+ uploadId = response.uploadId();
}
}
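The catch (final CompletionException e) above relies on how CompletableFuture.join() reports failure: the original SDK exception arrives as the cause, not as the thrown type itself. A minimal standalone sketch of that unwrap pattern (the method name is illustrative):

    // Sketch: join() wraps any failure in CompletionException; the original
    // SDK exception (e.g. NoSuchBucketException) is its cause.
    static <T> T joinAndUnwrapExample(final CompletableFuture<T> future) {
        try {
            return future.join();
        } catch (final CompletionException e) {
            if (e.getCause() instanceof NoSuchBucketException) {
                // recoverable: for example, retry against a default bucket,
                // as possiblyStartMultipartUpload() does above
            }
            throw e;
        }
    }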

ClientFactory.java
@@ -11,6 +11,7 @@
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3Client;

public final class ClientFactory {
@@ -26,8 +27,18 @@ static S3Client createS3Client(final S3SinkConfig s3SinkConfig, final AwsCredentialsSupplier awsCredentialsSupplier) {
.overrideConfiguration(createOverrideConfiguration(s3SinkConfig)).build();
}

static S3AsyncClient createS3AsyncClient(final S3SinkConfig s3SinkConfig, final AwsCredentialsSupplier awsCredentialsSupplier) {
final AwsCredentialsOptions awsCredentialsOptions = convertToCredentialsOptions(s3SinkConfig.getAwsAuthenticationOptions());
final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(awsCredentialsOptions);

return S3AsyncClient.builder()
.region(s3SinkConfig.getAwsAuthenticationOptions().getAwsRegion())
.credentialsProvider(awsCredentialsProvider)
.overrideConfiguration(createOverrideConfiguration(s3SinkConfig)).build();
}

private static ClientOverrideConfiguration createOverrideConfiguration(final S3SinkConfig s3SinkConfig) {
- final RetryPolicy retryPolicy = RetryPolicy.builder().numRetries(s3SinkConfig.getMaxConnectionRetries()).build();
+ final RetryPolicy retryPolicy = RetryPolicy.builder().numRetries(s3SinkConfig.getMaxConnectionRetries() * s3SinkConfig.getMaxUploadRetries()).build();
Member: Are these too many retries? What is the default value?

Member Author: Previously we did this number of retries times the max upload retries, but the outer retries were done manually. Multiplying like this gives the same total number of retries using just the client. (See the worked example after this file's diff.)

return ClientOverrideConfiguration.builder()
.retryPolicy(retryPolicy)
.build();
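A worked example of the retry budget discussed above; the two values are assumptions chosen for illustration, not the plugin's documented defaults:

    // If max_connection_retries = 5 and max_upload_retries = 5, the single SDK
    // retry policy carries the whole 5 * 5 = 25 retry budget that the previous
    // code spread across a manual outer loop and the client's internal retries.
    final RetryPolicy retryPolicy = RetryPolicy.builder()
            .numRetries(5 * 5)
            .build();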
S3Sink.java
@@ -33,7 +33,7 @@
import org.opensearch.dataprepper.plugins.sink.s3.grouping.S3GroupManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
- import software.amazon.awssdk.services.s3.S3Client;
+ import software.amazon.awssdk.services.s3.S3AsyncClient;

import java.time.Duration;
import java.util.Collection;
@@ -77,7 +77,7 @@ public S3Sink(final PluginSetting pluginSetting,
final OutputCodec testCodec = pluginFactory.loadPlugin(OutputCodec.class, codecPluginSettings);
sinkInitialized = Boolean.FALSE;

- final S3Client s3Client = ClientFactory.createS3Client(s3SinkConfig, awsCredentialsSupplier);
+ final S3AsyncClient s3Client = ClientFactory.createS3AsyncClient(s3SinkConfig, awsCredentialsSupplier);
Collaborator: Should we have an option to choose the type of client (i.e., support both the sync and async clients)?

Member Author: We could make this configurable, or dynamic based on the buffer_type, but I would say it doesn't make sense to make this configurable for users. (A sketch of the buffer_type idea follows at the end of this diff.)

BufferFactory innerBufferFactory = s3SinkConfig.getBufferType().getBufferFactory();
if(testCodec instanceof ParquetOutputCodec && s3SinkConfig.getBufferType() != BufferTypeOptions.INMEMORY) {
throw new InvalidPluginConfigurationException("The Parquet sink codec is an in_memory buffer only.");
@@ -115,7 +115,7 @@ public S3Sink(final PluginSetting pluginSetting,
final S3GroupManager s3GroupManager = new S3GroupManager(s3SinkConfig, s3GroupIdentifierFactory, bufferFactory, codecFactory, s3Client);


- s3SinkService = new S3SinkService(s3SinkConfig, s3OutputCodecContext, s3Client, keyGenerator, RETRY_FLUSH_BACKOFF, pluginMetrics, s3GroupManager);
+ s3SinkService = new S3SinkService(s3SinkConfig, s3OutputCodecContext, RETRY_FLUSH_BACKOFF, pluginMetrics, s3GroupManager);
}

@Override
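On the sync-versus-async review thread in S3Sink above: if the client choice ever did become dynamic, the lowest-surface option the author hints at is keying it off the existing buffer_type setting rather than adding a new user-facing option. A rough sketch of that idea, hypothetical and explicitly not adopted by this PR:

    // Hypothetical helper, not part of this change: derive the client style
    // from buffer_type instead of a dedicated configuration option.
    static boolean shouldUseAsyncClient(final S3SinkConfig s3SinkConfig) {
        // in_memory buffering holds whole parts in RAM, which suits fully
        // asynchronous multipart uploads; other buffer types could keep the
        // synchronous client.
        return s3SinkConfig.getBufferType() == BufferTypeOptions.INMEMORY;
    }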