Skip to content

Commit

Permalink
Updates the S3 sink to use the AWS Plugin for loading AWS credentials (
Browse files Browse the repository at this point in the history
…#2787)

Updates the S3 sink to use the AWS Plugin for loading AWS credentials. Resolves #2767

Signed-off-by: David Venable <[email protected]>
  • Loading branch information
dlvenable authored Jun 2, 2023
1 parent 41c657c commit 40e60fb
Show file tree
Hide file tree
Showing 10 changed files with 225 additions and 293 deletions.
1 change: 1 addition & 0 deletions data-prepper-plugins/s3-sink/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
dependencies {
implementation project(':data-prepper-api')
implementation project(path: ':data-prepper-plugins:common')
implementation project(':data-prepper-plugins:aws-plugin-api')
implementation 'io.micrometer:micrometer-core'
implementation 'com.fasterxml.jackson.core:jackson-core'
implementation 'com.fasterxml.jackson.core:jackson-databind'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@ public void setUp() {
when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse("2mb"));
when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.parse("PT3M"));
when(s3SinkConfig.getThresholdOptions()).thenReturn(thresholdOptions);
when(s3SinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions);
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.of(s3region));

lenient().when(pluginMetrics.counter(S3SinkService.OBJECTS_SUCCEEDED)).thenReturn(snapshotSuccessCounter);
lenient().when(pluginMetrics.counter(S3SinkService.OBJECTS_FAILED)).thenReturn(snapshotFailedCounter);
Expand Down Expand Up @@ -136,7 +134,7 @@ void verify_flushed_records_into_s3_bucket() {
}

private S3SinkService createObjectUnderTest() {
return new S3SinkService(s3SinkConfig, bufferFactory, codec, pluginMetrics);
return new S3SinkService(s3SinkConfig, bufferFactory, codec, s3Client, pluginMetrics);
}

private int gets3ObjectCount() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.sink;

import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.plugins.sink.configuration.AwsAuthenticationOptions;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.services.s3.S3Client;

public final class ClientFactory {
private ClientFactory() { }

static S3Client createS3Client(final S3SinkConfig s3SinkConfig, final AwsCredentialsSupplier awsCredentialsSupplier) {
final AwsCredentialsOptions awsCredentialsOptions = convertToCredentialsOptions(s3SinkConfig.getAwsAuthenticationOptions());
final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(awsCredentialsOptions);

return S3Client.builder()
.region(s3SinkConfig.getAwsAuthenticationOptions().getAwsRegion())
.credentialsProvider(awsCredentialsProvider)
.overrideConfiguration(createOverrideConfiguration(s3SinkConfig)).build();
}

private static ClientOverrideConfiguration createOverrideConfiguration(final S3SinkConfig s3SinkConfig) {
final RetryPolicy retryPolicy = RetryPolicy.builder().numRetries(s3SinkConfig.getMaxConnectionRetries()).build();
return ClientOverrideConfiguration.builder()
.retryPolicy(retryPolicy)
.build();
}

private static AwsCredentialsOptions convertToCredentialsOptions(final AwsAuthenticationOptions awsAuthenticationOptions) {
return AwsCredentialsOptions.builder()
.withRegion(awsAuthenticationOptions.getAwsRegion())
.withStsRoleArn(awsAuthenticationOptions.getAwsStsRoleArn())
.withStsHeaderOverrides(awsAuthenticationOptions.getAwsStsHeaderOverrides())
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.dataprepper.plugins.sink;

import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;
import org.opensearch.dataprepper.model.configuration.PluginModel;
Expand All @@ -22,6 +23,8 @@
import org.opensearch.dataprepper.plugins.sink.codec.Codec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.s3.S3Client;

import java.util.Collection;

/**
Expand All @@ -35,7 +38,7 @@ public class S3Sink extends AbstractSink<Record<Event>> {
private final S3SinkConfig s3SinkConfig;
private final Codec codec;
private volatile boolean sinkInitialized;
private S3SinkService s3SinkService;
private final S3SinkService s3SinkService;
private final BufferFactory bufferFactory;

/**
Expand All @@ -44,8 +47,10 @@ public class S3Sink extends AbstractSink<Record<Event>> {
* @param pluginFactory dp plugin factory.
*/
@DataPrepperPluginConstructor
public S3Sink(final PluginSetting pluginSetting, final S3SinkConfig s3SinkConfig,
final PluginFactory pluginFactory) {
public S3Sink(final PluginSetting pluginSetting,
final S3SinkConfig s3SinkConfig,
final PluginFactory pluginFactory,
final AwsCredentialsSupplier awsCredentialsSupplier) {
super(pluginSetting);
this.s3SinkConfig = s3SinkConfig;
final PluginModel codecConfiguration = s3SinkConfig.getCodec();
Expand All @@ -59,6 +64,8 @@ public S3Sink(final PluginSetting pluginSetting, final S3SinkConfig s3SinkConfig
} else {
bufferFactory = new InMemoryBufferFactory();
}
final S3Client s3Client = ClientFactory.createS3Client(s3SinkConfig, awsCredentialsSupplier);
s3SinkService = new S3SinkService(s3SinkConfig, bufferFactory, codec, s3Client, pluginMetrics);
}

@Override
Expand All @@ -85,7 +92,6 @@ public void doInitialize() {
* Initialize {@link S3SinkService}
*/
private void doInitializeInternal() {
s3SinkService = new S3SinkService(s3SinkConfig, bufferFactory, codec, pluginMetrics);
sinkInitialized = Boolean.TRUE;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.services.s3.S3Client;

import java.io.IOException;
Expand All @@ -47,6 +45,7 @@ public class S3SinkService {
private final BufferFactory bufferFactory;
private final Collection<EventHandle> bufferedEventHandles;
private final Codec codec;
private final S3Client s3Client;
private Buffer currentBuffer;
private final int maxEvents;
private final ByteCount maxBytes;
Expand All @@ -61,15 +60,17 @@ public class S3SinkService {

/**
* @param s3SinkConfig s3 sink related configuration.
* @param bufferFactory factory of buffer.
* @param bufferFactory factory of buffer.
* @param codec parser.
* @param s3Client
* @param pluginMetrics metrics.
*/
public S3SinkService(final S3SinkConfig s3SinkConfig, final BufferFactory bufferFactory,
final Codec codec, final PluginMetrics pluginMetrics) {
final Codec codec, final S3Client s3Client, final PluginMetrics pluginMetrics) {
this.s3SinkConfig = s3SinkConfig;
this.bufferFactory = bufferFactory;
this.codec = codec;
this.s3Client = s3Client;
reentrantLock = new ReentrantLock();

bufferedEventHandles = new LinkedList<>();
Expand Down Expand Up @@ -154,7 +155,7 @@ protected boolean retryFlushToS3(final Buffer currentBuffer, final String s3Key)
int retryCount = maxRetries;
do {
try {
currentBuffer.flushToS3(createS3Client(), bucket, s3Key);
currentBuffer.flushToS3(s3Client, bucket, s3Key);
isUploadedToS3 = Boolean.TRUE;
} catch (AwsServiceException | SdkClientException e) {
LOG.error("Exception occurred while uploading records to s3 bucket. Retry countdown : {} | exception:",
Expand All @@ -179,15 +180,4 @@ protected String generateKey() {
final String namePattern = ObjectKey.objectFileName(s3SinkConfig);
return (!pathPrefix.isEmpty()) ? pathPrefix + namePattern : namePattern;
}

/**
* create s3 client instance.
* @return {@link S3Client}
*/
public S3Client createS3Client() {
return S3Client.builder().region(s3SinkConfig.getAwsAuthenticationOptions().getAwsRegion())
.credentialsProvider(s3SinkConfig.getAwsAuthenticationOptions().authenticateAwsConfiguration())
.overrideConfiguration(ClientOverrideConfiguration.builder().retryPolicy(RetryPolicy.builder()
.numRetries(s3SinkConfig.getMaxConnectionRetries()).build()).build()).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,11 @@

import com.fasterxml.jackson.annotation.JsonProperty;
import jakarta.validation.constraints.Size;
import software.amazon.awssdk.arns.Arn;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;

import java.util.Map;
import java.util.Optional;
import java.util.UUID;

public class AwsAuthenticationOptions {
private static final String AWS_IAM_ROLE = "role";
private static final String AWS_IAM = "iam";

@JsonProperty("region")
@Size(min = 1, message = "Region cannot be empty string")
private String awsRegion;
Expand All @@ -35,58 +24,15 @@ public class AwsAuthenticationOptions {
@Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override")
private Map<String, String> awsStsHeaderOverrides;

private void validateStsRoleArn() {
final Arn arn = getArn();
if (!AWS_IAM.equals(arn.service())) {
throw new IllegalArgumentException("sts_role_arn must be an IAM Role");
}
final Optional<String> resourceType = arn.resource().resourceType();
if (resourceType.isEmpty() || !resourceType.get().equals(AWS_IAM_ROLE)) {
throw new IllegalArgumentException("sts_role_arn must be an IAM Role");
}
}

private Arn getArn() {
try {
return Arn.fromString(awsStsRoleArn);
} catch (final Exception e) {
throw new IllegalArgumentException(String.format("Invalid ARN format for awsStsRoleArn. Check the format of %s", awsStsRoleArn));
}
}

public Region getAwsRegion() {
return awsRegion != null ? Region.of(awsRegion) : null;
}

public AwsCredentialsProvider authenticateAwsConfiguration() {

final AwsCredentialsProvider awsCredentialsProvider;
if (awsStsRoleArn != null && !awsStsRoleArn.isEmpty()) {

validateStsRoleArn();

final StsClient stsClient = StsClient.builder()
.region(getAwsRegion())
.build();

AssumeRoleRequest.Builder assumeRoleRequestBuilder = AssumeRoleRequest.builder()
.roleSessionName("S3-Sink-" + UUID.randomUUID())
.roleArn(awsStsRoleArn);
if(awsStsHeaderOverrides != null && !awsStsHeaderOverrides.isEmpty()) {
assumeRoleRequestBuilder = assumeRoleRequestBuilder
.overrideConfiguration(configuration -> awsStsHeaderOverrides.forEach(configuration::putHeader));
}

awsCredentialsProvider = StsAssumeRoleCredentialsProvider.builder()
.stsClient(stsClient)
.refreshRequest(assumeRoleRequestBuilder.build())
.build();

} else {
// use default credential provider
awsCredentialsProvider = DefaultCredentialsProvider.create();
}
public String getAwsStsRoleArn() {
return awsStsRoleArn;
}

return awsCredentialsProvider;
public Map<String, String> getAwsStsHeaderOverrides() {
return awsStsHeaderOverrides;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.sink;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.plugins.sink.configuration.AwsAuthenticationOptions;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.S3ClientBuilder;

import java.util.Map;
import java.util.UUID;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
class ClientFactoryTest {
@Mock
private S3SinkConfig s3SinkConfig;
@Mock
private AwsCredentialsSupplier awsCredentialsSupplier;

@Mock
private AwsAuthenticationOptions awsAuthenticationOptions;

@BeforeEach
void setUp() {
when(s3SinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions);
}

@Test
void createS3Client_with_real_S3Client() {
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.US_EAST_1);
final S3Client s3Client = ClientFactory.createS3Client(s3SinkConfig, awsCredentialsSupplier);

assertThat(s3Client, notNullValue());
}

@ParameterizedTest
@ValueSource(strings = {"us-east-1", "us-west-2", "eu-central-1"})
void createS3Client_provides_correct_inputs(final String regionString) {
final Region region = Region.of(regionString);
final String stsRoleArn = UUID.randomUUID().toString();
final Map<String, String> stsHeaderOverrides = Map.of(UUID.randomUUID().toString(), UUID.randomUUID().toString());
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(region);
when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(stsRoleArn);
when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(stsHeaderOverrides);

final AwsCredentialsProvider expectedCredentialsProvider = mock(AwsCredentialsProvider.class);
when(awsCredentialsSupplier.getProvider(any())).thenReturn(expectedCredentialsProvider);

final S3ClientBuilder s3ClientBuilder = mock(S3ClientBuilder.class);
when(s3ClientBuilder.region(region)).thenReturn(s3ClientBuilder);
when(s3ClientBuilder.credentialsProvider(any())).thenReturn(s3ClientBuilder);
when(s3ClientBuilder.overrideConfiguration(any(ClientOverrideConfiguration.class))).thenReturn(s3ClientBuilder);
try(final MockedStatic<S3Client> s3ClientMockedStatic = mockStatic(S3Client.class)) {
s3ClientMockedStatic.when(S3Client::builder)
.thenReturn(s3ClientBuilder);
ClientFactory.createS3Client(s3SinkConfig, awsCredentialsSupplier);
}

final ArgumentCaptor<AwsCredentialsProvider> credentialsProviderArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsProvider.class);
verify(s3ClientBuilder).credentialsProvider(credentialsProviderArgumentCaptor.capture());

final AwsCredentialsProvider actualCredentialsProvider = credentialsProviderArgumentCaptor.getValue();

assertThat(actualCredentialsProvider, equalTo(expectedCredentialsProvider));

final ArgumentCaptor<AwsCredentialsOptions> optionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class);
verify(awsCredentialsSupplier).getProvider(optionsArgumentCaptor.capture());

final AwsCredentialsOptions actualCredentialsOptions = optionsArgumentCaptor.getValue();
assertThat(actualCredentialsOptions.getRegion(), equalTo(region));
assertThat(actualCredentialsOptions.getStsRoleArn(), equalTo(stsRoleArn));
assertThat(actualCredentialsOptions.getStsHeaderOverrides(), equalTo(stsHeaderOverrides));
}
}
Loading

0 comments on commit 40e60fb

Please sign in to comment.