diff --git a/data-prepper-expression/src/test/java/org/opensearch/dataprepper/expression/GenericExpressionEvaluator_ConditionalIT.java b/data-prepper-expression/src/test/java/org/opensearch/dataprepper/expression/GenericExpressionEvaluator_ConditionalIT.java index 3b71b8527b..38f34a9eeb 100644 --- a/data-prepper-expression/src/test/java/org/opensearch/dataprepper/expression/GenericExpressionEvaluator_ConditionalIT.java +++ b/data-prepper-expression/src/test/java/org/opensearch/dataprepper/expression/GenericExpressionEvaluator_ConditionalIT.java @@ -5,6 +5,7 @@ package org.opensearch.dataprepper.expression; +import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -17,13 +18,13 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; -import java.util.Map; import java.util.List; +import java.util.Map; +import java.util.Random; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.stream.Stream; -import java.util.Random; import static org.awaitility.Awaitility.await; import static org.hamcrest.CoreMatchers.equalTo; @@ -34,7 +35,6 @@ import static org.hamcrest.CoreMatchers.sameInstance; import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; -import org.apache.commons.lang3.RandomStringUtils; class GenericExpressionEvaluator_ConditionalIT { /** @@ -94,7 +94,7 @@ void testConditionalExpressionEvaluator(final String expression, final Event eve void testGenericExpressionEvaluatorWithMultipleThreads(final String expression, final Event event, final Boolean expected) { final GenericExpressionEvaluator evaluator = applicationContext.getBean(GenericExpressionEvaluator.class); - final int numberOfThreads = 50; + final int numberOfThreads = 10; final ExecutorService executorService = Executors.newFixedThreadPool(numberOfThreads); List evaluationResults = Collections.synchronizedList(new ArrayList<>()); diff --git a/data-prepper-plugins/opensearch-source/build.gradle b/data-prepper-plugins/opensearch-source/build.gradle index f2e20c6b37..7310eb1d10 100644 --- a/data-prepper-plugins/opensearch-source/build.gradle +++ b/data-prepper-plugins/opensearch-source/build.gradle @@ -4,11 +4,17 @@ */ dependencies { implementation project(path: ':data-prepper-api') + implementation project(':data-prepper-plugins:aws-plugin-api') + implementation 'software.amazon.awssdk:apache-client' implementation 'com.fasterxml.jackson.core:jackson-databind' implementation 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.14.2' implementation 'software.amazon.awssdk:s3' implementation 'software.amazon.awssdk:sts' testImplementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-yaml' + implementation 'org.opensearch.client:opensearch-java:2.4.0' + implementation 'org.opensearch.client:opensearch-rest-client:2.7.0' + implementation "org.apache.commons:commons-lang3:3.12.0" + implementation 'org.apache.maven:maven-artifact:3.0.3' testImplementation testLibs.mockito.inline } diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchIndexProgressState.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchIndexProgressState.java new file mode 100644 index 0000000000..a67976130d --- /dev/null +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchIndexProgressState.java @@ -0,0 +1,9 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.opensearch; + +public class OpenSearchIndexProgressState { +} diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchService.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchService.java index 87ead06fea..a78ba0c22e 100644 --- a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchService.java +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchService.java @@ -4,14 +4,59 @@ */ package org.opensearch.dataprepper.plugins.source.opensearch; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.source.coordinator.SourceCoordinator; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.PitWorker; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.ScrollWorker; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.SearchAccessor; + public class OpenSearchService { + private final SearchAccessor searchAccessor; + private final OpenSearchSourceConfiguration openSearchSourceConfiguration; + private final SourceCoordinator sourceCoordinator; + private final Buffer> buffer; + + private Thread searchWorkerThread; + + public static OpenSearchService createOpenSearchService(final SearchAccessor searchAccessor, + final SourceCoordinator sourceCoordinator, + final OpenSearchSourceConfiguration openSearchSourceConfiguration, + final Buffer> buffer) { + return new OpenSearchService(searchAccessor, sourceCoordinator, openSearchSourceConfiguration, buffer); + } + + private OpenSearchService(final SearchAccessor searchAccessor, + final SourceCoordinator sourceCoordinator, + final OpenSearchSourceConfiguration openSearchSourceConfiguration, + final Buffer> buffer) { + this.searchAccessor = searchAccessor; + this.openSearchSourceConfiguration = openSearchSourceConfiguration; + this.buffer = buffer; + this.sourceCoordinator = sourceCoordinator; + this.sourceCoordinator.initialize(); + } + public void start() { - // todo: to implement - // Leverages a runnable (SearchWorker) to perform the querying on the source cluster + switch(searchAccessor.getSearchContextType()) { + case POINT_IN_TIME: + searchWorkerThread = new Thread(new PitWorker(searchAccessor, openSearchSourceConfiguration, sourceCoordinator, buffer)); + break; + case SCROLL: + searchWorkerThread = new Thread(new ScrollWorker(searchAccessor, openSearchSourceConfiguration, sourceCoordinator, buffer)); + break; + default: + throw new IllegalArgumentException( + String.format("Search context type must be POINT_IN_TIME or SCROLL, type %s was given instead", + searchAccessor.getSearchContextType())); + } + + searchWorkerThread.start(); } public void stop() { - // todo: to implement + } } diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSource.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSource.java index fbf4397311..141079631f 100644 --- a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSource.java +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSource.java @@ -10,32 +10,57 @@ import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.model.source.Source; +import org.opensearch.dataprepper.model.source.coordinator.SourceCoordinator; +import org.opensearch.dataprepper.model.source.coordinator.UsesSourceCoordination; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.SearchAccessor; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.SearchAccessorStrategy; -@DataPrepperPlugin(name="opensearch", pluginType = Source.class , pluginConfigurationType =OpenSearchSourceConfiguration.class ) -public class OpenSearchSource implements Source> { +@DataPrepperPlugin(name="opensearch", pluginType = Source.class , pluginConfigurationType = OpenSearchSourceConfiguration.class ) +public class OpenSearchSource implements Source>, UsesSourceCoordination { + private final AwsCredentialsSupplier awsCredentialsSupplier; private final OpenSearchSourceConfiguration openSearchSourceConfiguration; + private SourceCoordinator sourceCoordinator; + private OpenSearchService openSearchService; + @DataPrepperPluginConstructor - public OpenSearchSource(final OpenSearchSourceConfiguration openSearchSourceConfiguration) { + public OpenSearchSource(final OpenSearchSourceConfiguration openSearchSourceConfiguration, + final AwsCredentialsSupplier awsCredentialsSupplier) { this.openSearchSourceConfiguration = openSearchSourceConfiguration; + this.awsCredentialsSupplier = awsCredentialsSupplier; } @Override - public void start(Buffer> buffer) { + public void start(final Buffer> buffer) { if (buffer == null) { throw new IllegalStateException("Buffer provided is null"); } - startProcess(openSearchSourceConfiguration); + startProcess(openSearchSourceConfiguration, buffer); } - private void startProcess(final OpenSearchSourceConfiguration openSearchSourceConfiguration) { - // todo: implement - // Should leverage OpenSearchService to run the actual plugin core logic. + private void startProcess(final OpenSearchSourceConfiguration openSearchSourceConfiguration, final Buffer> buffer) { + final SearchAccessorStrategy searchAccessorStrategy = SearchAccessorStrategy.create(openSearchSourceConfiguration, awsCredentialsSupplier); + + final SearchAccessor searchAccessor = searchAccessorStrategy.getSearchAccessor(); + + openSearchService = OpenSearchService.createOpenSearchService(searchAccessor, sourceCoordinator, openSearchSourceConfiguration, buffer); + openSearchService.start(); } @Override public void stop() { - // Yet to implement + openSearchService.stop(); + } + + @Override + public void setSourceCoordinator(final SourceCoordinator sourceCoordinator) { + this.sourceCoordinator = (SourceCoordinator) sourceCoordinator; + } + + @Override + public Class getPartitionProgressStateClass() { + return OpenSearchIndexProgressState.class; } } diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSourceConfiguration.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSourceConfiguration.java index f4f53f70f7..69edc9fb60 100644 --- a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSourceConfiguration.java +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSourceConfiguration.java @@ -39,7 +39,7 @@ public class OpenSearchSourceConfiguration { @JsonProperty("connection") @Valid - private ConnectionConfiguration connectionConfiguration; + private ConnectionConfiguration connectionConfiguration = new ConnectionConfiguration(); @JsonProperty("indices") @Valid @@ -97,7 +97,7 @@ public SearchConfiguration getSearchConfiguration() { boolean validateAwsConfigWithUsernameAndPassword() { return !((Objects.nonNull(awsAuthenticationOptions) && (Objects.nonNull(username) || Objects.nonNull(password))) || - (Objects.isNull(awsAuthenticationOptions) && Objects.isNull(username) && Objects.isNull(password))); + (Objects.isNull(awsAuthenticationOptions) && (Objects.isNull(username) || Objects.isNull(password)))); } } diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/configuration/AwsAuthenticationConfiguration.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/configuration/AwsAuthenticationConfiguration.java index c2a844adba..4c4a1ec058 100644 --- a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/configuration/AwsAuthenticationConfiguration.java +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/configuration/AwsAuthenticationConfiguration.java @@ -7,17 +7,9 @@ import com.fasterxml.jackson.annotation.JsonProperty; import jakarta.validation.constraints.Size; -import software.amazon.awssdk.arns.Arn; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.sts.StsClient; -import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider; -import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; import java.util.Map; -import java.util.Optional; -import java.util.UUID; public class AwsAuthenticationConfiguration { private static final String AWS_IAM_ROLE = "role"; @@ -35,25 +27,6 @@ public class AwsAuthenticationConfiguration { @Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override") private Map awsStsHeaderOverrides; - private void validateStsRoleArn() { - final Arn arn = getArn(); - if (!AWS_IAM.equals(arn.service())) { - throw new IllegalArgumentException("sts_role_arn must be an IAM Role"); - } - final Optional resourceType = arn.resource().resourceType(); - if (resourceType.isEmpty() || !resourceType.get().equals(AWS_IAM_ROLE)) { - throw new IllegalArgumentException("sts_role_arn must be an IAM Role"); - } - } - - private Arn getArn() { - try { - return Arn.fromString(awsStsRoleArn); - } catch (final Exception e) { - throw new IllegalArgumentException(String.format("Invalid ARN format for aws.sts_role_arn. Check the format of %s", awsStsRoleArn)); - } - } - public String getAwsStsRoleArn() { return awsStsRoleArn; } @@ -62,36 +35,8 @@ public Region getAwsRegion() { return awsRegion != null ? Region.of(awsRegion) : null; } - public AwsCredentialsProvider authenticateAwsConfiguration() { - - final AwsCredentialsProvider awsCredentialsProvider; - if (awsStsRoleArn != null && !awsStsRoleArn.isEmpty()) { - - validateStsRoleArn(); - - final StsClient stsClient = StsClient.builder() - .region(getAwsRegion()) - .build(); - - AssumeRoleRequest.Builder assumeRoleRequestBuilder = AssumeRoleRequest.builder() - .roleSessionName("OpenSearch-Source-" + UUID.randomUUID()) - .roleArn(awsStsRoleArn); - if(awsStsHeaderOverrides != null && !awsStsHeaderOverrides.isEmpty()) { - assumeRoleRequestBuilder = assumeRoleRequestBuilder - .overrideConfiguration(configuration -> awsStsHeaderOverrides.forEach(configuration::putHeader)); - } - - awsCredentialsProvider = StsAssumeRoleCredentialsProvider.builder() - .stsClient(stsClient) - .refreshRequest(assumeRoleRequestBuilder.build()) - .build(); - - } else { - // use default credential provider - awsCredentialsProvider = DefaultCredentialsProvider.create(); - } - - return awsCredentialsProvider; + public Map getAwsStsHeaderOverrides() { + return awsStsHeaderOverrides; } } diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/PitWorker.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/PitWorker.java index b176d36d01..78b9d0328f 100644 --- a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/PitWorker.java +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/PitWorker.java @@ -4,13 +4,36 @@ */ package org.opensearch.dataprepper.plugins.source.opensearch.worker; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.source.coordinator.SourceCoordinator; +import org.opensearch.dataprepper.plugins.source.opensearch.OpenSearchIndexProgressState; +import org.opensearch.dataprepper.plugins.source.opensearch.OpenSearchSourceConfiguration; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.SearchAccessor; + /** * PitWorker polls the source cluster via Point-In-Time contexts. */ -public class PitWorker implements Runnable { +public class PitWorker implements SearchWorker, Runnable { + + private final SearchAccessor searchAccessor; + private final OpenSearchSourceConfiguration openSearchSourceConfiguration; + private final SourceCoordinator sourceCoordinator; + private final Buffer> buffer; + + public PitWorker(final SearchAccessor searchAccessor, + final OpenSearchSourceConfiguration openSearchSourceConfiguration, + final SourceCoordinator sourceCoordinator, + final Buffer> buffer) { + this.searchAccessor = searchAccessor; + this.sourceCoordinator = sourceCoordinator; + this.openSearchSourceConfiguration = openSearchSourceConfiguration; + this.buffer = buffer; + } @Override public void run() { - //todo: to implement + // todo: implement } } diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/ScrollWorker.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/ScrollWorker.java index 8a5d9400c9..c6e0d049ae 100644 --- a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/ScrollWorker.java +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/ScrollWorker.java @@ -4,13 +4,36 @@ */ package org.opensearch.dataprepper.plugins.source.opensearch.worker; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.source.coordinator.SourceCoordinator; +import org.opensearch.dataprepper.plugins.source.opensearch.OpenSearchIndexProgressState; +import org.opensearch.dataprepper.plugins.source.opensearch.OpenSearchSourceConfiguration; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.SearchAccessor; + /** * ScrollWorker polls the source cluster via scroll contexts. */ -public class ScrollWorker implements Runnable { +public class ScrollWorker implements SearchWorker { + + private final SearchAccessor searchAccessor; + private final OpenSearchSourceConfiguration openSearchSourceConfiguration; + private final SourceCoordinator sourceCoordinator; + private final Buffer> buffer; + + public ScrollWorker(final SearchAccessor searchAccessor, + final OpenSearchSourceConfiguration openSearchSourceConfiguration, + final SourceCoordinator sourceCoordinator, + final Buffer> buffer) { + this.searchAccessor = searchAccessor; + this.openSearchSourceConfiguration = openSearchSourceConfiguration; + this.sourceCoordinator = sourceCoordinator; + this.buffer = buffer; + } @Override public void run() { - //todo: to implement + // todo: implement } } diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/SearchWorker.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/SearchWorker.java new file mode 100644 index 0000000000..d075ebad4f --- /dev/null +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/SearchWorker.java @@ -0,0 +1,9 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.opensearch.worker; + +public interface SearchWorker extends Runnable { +} diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/SearchWorkerStrategy.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/SearchWorkerStrategy.java deleted file mode 100644 index bfed961931..0000000000 --- a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/SearchWorkerStrategy.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.dataprepper.plugins.source.opensearch.worker; - -import org.opensearch.dataprepper.plugins.source.opensearch.OpenSearchSourceConfiguration; - -/** - * Search Worker Strategy determines which SearchWorker (PITSearchWorker or ScrollSearchWorker) based on the - * {@link OpenSearchSourceConfiguration}. - * - * @since 2.4 - */ -public class SearchWorkerStrategy { - - /** - * Get the SearchWorker based on the based on the {@link OpenSearchSourceConfiguration}. - * @param openSearchSourceConfiguration the plugins configuration - * @return a runnable to execute a strategy for polling the source cluster - * @since 2.4 - */ - public Runnable getSearchWorker(final OpenSearchSourceConfiguration openSearchSourceConfiguration) { - // todo: to implement - return null; - } -} diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/AwsRequestSigningApacheInterceptor.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/AwsRequestSigningApacheInterceptor.java new file mode 100644 index 0000000000..b3460553a0 --- /dev/null +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/AwsRequestSigningApacheInterceptor.java @@ -0,0 +1,226 @@ +/* + * Copyright OpenSearch Contributors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with + * the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package org.opensearch.dataprepper.plugins.source.opensearch.worker.client; + +import org.apache.http.Header; +import org.apache.http.HttpEntityEnclosingRequest; +import org.apache.http.HttpHost; +import org.apache.http.HttpRequest; +import org.apache.http.HttpRequestInterceptor; +import org.apache.http.NameValuePair; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.entity.BasicHttpEntity; +import org.apache.http.message.BasicHeader; +import org.apache.http.protocol.HttpContext; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.signer.Signer; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.regions.Region; + +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.TreeMap; + +import static org.apache.http.protocol.HttpCoreContext.HTTP_TARGET_HOST; + +/** + * An {@link HttpRequestInterceptor} that signs requests using any AWS {@link Signer} + * and {@link AwsCredentialsProvider}. This is a copy from the opensearch sink + */ +final class AwsRequestSigningApacheInterceptor implements HttpRequestInterceptor { + + /** + * Constant to check content-length + */ + private static final String CONTENT_LENGTH = "content-length"; + /** + * Constant to check Zero content length + */ + private static final String ZERO_CONTENT_LENGTH = "0"; + /** + * Constant to check if host is the endpoint + */ + private static final String HOST = "host"; + + /** + * The service that we're connecting to. + */ + private final String service; + + /** + * The particular signer implementation. + */ + private final Signer signer; + + /** + * The source of AWS credentials for signing. + */ + private final AwsCredentialsProvider awsCredentialsProvider; + + /** + * The region signing region. + */ + private final Region region; + + /** + * + * @param service service that we're connecting to + * @param signer particular signer implementation + * @param awsCredentialsProvider source of AWS credentials for signing + * @param region signing region + */ + public AwsRequestSigningApacheInterceptor(final String service, + final Signer signer, + final AwsCredentialsProvider awsCredentialsProvider, + final Region region) { + this.service = Objects.requireNonNull(service); + this.signer = Objects.requireNonNull(signer); + this.awsCredentialsProvider = Objects.requireNonNull(awsCredentialsProvider); + this.region = Objects.requireNonNull(region); + } + + /** + * {@inheritDoc} + */ + @Override + public void process(final HttpRequest request, final HttpContext context) + throws IOException { + URIBuilder uriBuilder; + try { + uriBuilder = new URIBuilder(request.getRequestLine().getUri()); + } catch (URISyntaxException e) { + throw new IOException("Invalid URI" , e); + } + + // Copy Apache HttpRequest to AWS Request + SdkHttpFullRequest.Builder requestBuilder = SdkHttpFullRequest.builder() + .method(SdkHttpMethod.fromValue(request.getRequestLine().getMethod())) + .uri(buildUri(context, uriBuilder)); + + if (request instanceof HttpEntityEnclosingRequest) { + HttpEntityEnclosingRequest httpEntityEnclosingRequest = + (HttpEntityEnclosingRequest) request; + if (httpEntityEnclosingRequest.getEntity() != null) { + InputStream content = httpEntityEnclosingRequest.getEntity().getContent(); + requestBuilder.contentStreamProvider(() -> content); + } + } + requestBuilder.rawQueryParameters(nvpToMapParams(uriBuilder.getQueryParams())); + requestBuilder.headers(headerArrayToMap(request.getAllHeaders())); + + ExecutionAttributes attributes = new ExecutionAttributes(); + attributes.putAttribute(AwsSignerExecutionAttribute.AWS_CREDENTIALS, awsCredentialsProvider.resolveCredentials()); + attributes.putAttribute(AwsSignerExecutionAttribute.SERVICE_SIGNING_NAME, service); + attributes.putAttribute(AwsSignerExecutionAttribute.SIGNING_REGION, region); + + // Sign it + SdkHttpFullRequest signedRequest = signer.sign(requestBuilder.build(), attributes); + + // Now copy everything back + request.setHeaders(mapToHeaderArray(signedRequest.headers())); + if (request instanceof HttpEntityEnclosingRequest) { + HttpEntityEnclosingRequest httpEntityEnclosingRequest = + (HttpEntityEnclosingRequest) request; + if (httpEntityEnclosingRequest.getEntity() != null) { + BasicHttpEntity basicHttpEntity = new BasicHttpEntity(); + basicHttpEntity.setContent(signedRequest.contentStreamProvider() + .orElseThrow(() -> new IllegalStateException("There must be content")) + .newStream()); + httpEntityEnclosingRequest.setEntity(basicHttpEntity); + } + } + } + + private URI buildUri(final HttpContext context, URIBuilder uriBuilder) throws IOException { + try { + HttpHost host = (HttpHost) context.getAttribute(HTTP_TARGET_HOST); + + if (host != null) { + uriBuilder.setHost(host.getHostName()); + uriBuilder.setScheme(host.getSchemeName()); + uriBuilder.setPort(host.getPort()); + } + + return uriBuilder.build(); + } catch (URISyntaxException e) { + throw new IOException("Invalid URI", e); + } + } + + /** + * + * @param params list of HTTP query params as NameValuePairs + * @return a multimap of HTTP query params + */ + private static Map> nvpToMapParams(final List params) { + Map> parameterMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + for (NameValuePair nvp : params) { + List argsList = + parameterMap.computeIfAbsent(nvp.getName(), k -> new ArrayList<>()); + argsList.add(nvp.getValue()); + } + return parameterMap; + } + + /** + * @param headers modelled Header objects + * @return a Map of header entries + */ + private static Map> headerArrayToMap(final Header[] headers) { + Map> headersMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + for (Header header : headers) { + if (!skipHeader(header)) { + headersMap.put(header.getName(), headersMap + .getOrDefault(header.getName(), + new LinkedList<>(Collections.singletonList(header.getValue())))); + } + } + return headersMap; + } + + /** + * @param header header line to check + * @return true if the given header should be excluded when signing + */ + private static boolean skipHeader(final Header header) { + return (CONTENT_LENGTH.equalsIgnoreCase(header.getName()) + && ZERO_CONTENT_LENGTH.equals(header.getValue())) // Strip Content-Length: 0 + || HOST.equalsIgnoreCase(header.getName()); // Host comes from endpoint + } + + /** + * @param mapHeaders Map of header entries + * @return modelled Header objects + */ + private static Header[] mapToHeaderArray(final Map> mapHeaders) { + Header[] headers = new Header[mapHeaders.size()]; + int i = 0; + for (Map.Entry> headerEntry : mapHeaders.entrySet()) { + for (String value : headerEntry.getValue()) { + headers[i++] = new BasicHeader(headerEntry.getKey(), value); + } + } + return headers; + } +} diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/ElasticsearchAccessor.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/ElasticsearchAccessor.java index ced025f836..0c8ae1ea3a 100644 --- a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/ElasticsearchAccessor.java +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/ElasticsearchAccessor.java @@ -10,12 +10,18 @@ import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.CreateScrollResponse; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.DeletePitRequest; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.DeleteScrollRequest; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchContextType; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchPitRequest; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchPitResponse; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchScrollRequest; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchScrollResponse; public class ElasticsearchAccessor implements SearchAccessor { + @Override + public SearchContextType getSearchContextType() { + // todo: implement + return null; + } @Override public CreatePitResponse createPit(CreatePitRequest createPitRequest) { diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/OpenSearchAccessor.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/OpenSearchAccessor.java index 20ff59edce..1da995bff6 100644 --- a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/OpenSearchAccessor.java +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/OpenSearchAccessor.java @@ -4,12 +4,14 @@ */ package org.opensearch.dataprepper.plugins.source.opensearch.worker.client; +import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.CreatePitRequest; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.CreatePitResponse; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.CreateScrollRequest; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.CreateScrollResponse; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.DeletePitRequest; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.DeleteScrollRequest; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchContextType; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchPitRequest; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchPitResponse; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchScrollRequest; @@ -17,6 +19,20 @@ public class OpenSearchAccessor implements SearchAccessor { + private final OpenSearchClient openSearchClient; + private final SearchContextType searchContextType; + + public OpenSearchAccessor(final OpenSearchClient openSearchClient, final SearchContextType searchContextType) { + this.openSearchClient = openSearchClient; + this.searchContextType = searchContextType; + } + + + @Override + public SearchContextType getSearchContextType() { + return searchContextType; + } + @Override public CreatePitResponse createPit(CreatePitRequest createPitRequest) { //todo: implement diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/SearchAccessor.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/SearchAccessor.java index 2ceacfb329..8c451e6995 100644 --- a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/SearchAccessor.java +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/SearchAccessor.java @@ -10,6 +10,7 @@ import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.CreateScrollResponse; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.DeletePitRequest; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.DeleteScrollRequest; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchContextType; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchPitRequest; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchPitResponse; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchScrollRequest; @@ -21,6 +22,11 @@ * @since 2.4 */ public interface SearchAccessor { + /** + * Information on whether how this SearchAccessor should be used by {@link org.opensearch.dataprepper.plugins.source.opensearch.OpenSearchService} + * @return the {@link SearchContextType} that has information on which search strategy should be used + */ + SearchContextType getSearchContextType(); /** * Creates a Point-In-Time (PIT) context for searching diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/SearchAccessorStrategy.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/SearchAccessorStrategy.java index 9a67d70210..3a382b67d7 100644 --- a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/SearchAccessorStrategy.java +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/SearchAccessorStrategy.java @@ -4,7 +4,53 @@ */ package org.opensearch.dataprepper.plugins.source.opensearch.worker.client; +import org.apache.http.HttpHost; +import org.apache.http.HttpRequestInterceptor; +import org.apache.http.auth.AuthScope; +import org.apache.http.auth.UsernamePasswordCredentials; +import org.apache.http.client.CredentialsProvider; +import org.apache.http.conn.ssl.NoopHostnameVerifier; +import org.apache.http.conn.ssl.TrustAllStrategy; +import org.apache.http.conn.ssl.TrustStrategy; +import org.apache.http.impl.client.BasicCredentialsProvider; +import org.apache.http.impl.nio.client.HttpAsyncClientBuilder; +import org.apache.http.ssl.SSLContextBuilder; +import org.apache.http.ssl.SSLContexts; +import org.apache.maven.artifact.versioning.DefaultArtifactVersion; +import org.opensearch.client.RestClient; +import org.opensearch.client.RestClientBuilder; +import org.opensearch.client.json.jackson.JacksonJsonpMapper; +import org.opensearch.client.opensearch.OpenSearchClient; +import org.opensearch.client.opensearch._types.OpenSearchException; +import org.opensearch.client.opensearch.core.InfoResponse; +import org.opensearch.client.transport.OpenSearchTransport; +import org.opensearch.client.transport.aws.AwsSdk2Transport; +import org.opensearch.client.transport.aws.AwsSdk2TransportOptions; +import org.opensearch.client.transport.rest_client.RestClientTransport; +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.plugins.source.opensearch.OpenSearchSourceConfiguration; +import org.opensearch.dataprepper.plugins.source.opensearch.configuration.ConnectionConfiguration; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchContextType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.signer.Aws4Signer; +import software.amazon.awssdk.http.SdkHttpClient; +import software.amazon.awssdk.http.apache.ApacheHttpClient; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.KeyStore; +import java.security.cert.Certificate; +import java.security.cert.CertificateFactory; +import java.util.List; +import java.util.Objects; /** * SearchAccessorStrategy determines which {@link SearchAccessor} (Elasticsearch or OpenSearch) should be used based on @@ -13,14 +59,242 @@ */ public class SearchAccessorStrategy { + private static final Logger LOG = LoggerFactory.getLogger(SearchAccessorStrategy.class); + + private static final String AOS_SERVICE_NAME = "es"; + static final String OPENSEARCH_DISTRIBUTION = "opensearch"; + private static final String OPENSEARCH_POINT_IN_TIME_SUPPORT_VERSION_CUTOFF = "2.5.0"; + + private final AwsCredentialsSupplier awsCredentialsSupplier; + private final OpenSearchSourceConfiguration openSearchSourceConfiguration; + + public static SearchAccessorStrategy create(final OpenSearchSourceConfiguration openSearchSourceConfiguration, + final AwsCredentialsSupplier awsCredentialsSupplier) { + return new SearchAccessorStrategy(openSearchSourceConfiguration, awsCredentialsSupplier); + } + + private SearchAccessorStrategy(final OpenSearchSourceConfiguration openSearchSourceConfiguration, + final AwsCredentialsSupplier awsCredentialsSupplier) { + this.awsCredentialsSupplier = awsCredentialsSupplier; + this.openSearchSourceConfiguration = openSearchSourceConfiguration; + } + /** * Provides a {@link SearchAccessor} that is based on the {@link OpenSearchSourceConfiguration} - * @param openSearchSourceConfiguration the plugins configuration * @return a {@link SearchAccessor} * @since 2.4 */ - public SearchAccessor getSearchAccessor(final OpenSearchSourceConfiguration openSearchSourceConfiguration) { - //todo: implement - return null; + public SearchAccessor getSearchAccessor() { + final RestClient restClient = createOpenSearchRestClient(); + final OpenSearchTransport transport = createOpenSearchTransport(restClient); + final OpenSearchClient openSearchClient = new OpenSearchClient(transport); + + InfoResponse infoResponse; + try { + infoResponse = openSearchClient.info(); + } catch (final IOException | OpenSearchException e) { + throw new RuntimeException("There was an error looking up the OpenSearch cluster info: ", e); + } + + final String distribution = infoResponse.version().distribution(); + final String versionNumber = infoResponse.version().number(); + + if (!distribution.equals(OPENSEARCH_DISTRIBUTION)) { + throw new IllegalArgumentException(String.format("Only opensearch distributions are supported at this time. The cluster distribution being used is '%s'", distribution)); + } + + SearchContextType searchContextType; + + if (versionSupportsPointInTimeForOpenSearch(versionNumber)) { + LOG.info("OpenSearch version {} detected. Point in time APIs will be used to search documents", versionNumber); + searchContextType = SearchContextType.POINT_IN_TIME; + } else { + LOG.info("OpenSearch version {} detected. Scroll contexts will be used to search documents. " + + "Upgrade your cluster to at least version {} to use Point in Time APIs instead of scroll.", versionNumber, OPENSEARCH_POINT_IN_TIME_SUPPORT_VERSION_CUTOFF); + searchContextType = SearchContextType.SCROLL; + } + + return new OpenSearchAccessor(openSearchClient, searchContextType); + } + + private RestClient createOpenSearchRestClient() { + final List hosts = openSearchSourceConfiguration.getHosts(); + final HttpHost[] httpHosts = new HttpHost[hosts.size()]; + + int i = 0; + for (final String host : hosts) { + httpHosts[i] = HttpHost.create(host); + i++; + } + + final RestClientBuilder restClientBuilder = RestClient.builder(httpHosts); + + if (Objects.nonNull(openSearchSourceConfiguration.getAwsAuthenticationOptions())) { + LOG.info("Using aws credentials and sigv4 for auth for the OpenSearch source"); + attachSigv4Auth(restClientBuilder); + } else { + LOG.info("Using username and password for auth for the OpenSearch source"); + attachUsernamePassword(restClientBuilder); + } + + setConnectAndSocketTimeout(restClientBuilder); + + return restClientBuilder.build(); + } + + private void attachSSLContext(final ApacheHttpClient.Builder apacheHttpClientBuilder) { + TrustManager[] trustManagers = createTrustManagers(openSearchSourceConfiguration.getConnectionConfiguration().getCertPath()); + apacheHttpClientBuilder.tlsTrustManagersProvider(() -> trustManagers); + } + + private void attachSSLContext(final HttpAsyncClientBuilder httpClientBuilder) { + + final ConnectionConfiguration connectionConfiguration = openSearchSourceConfiguration.getConnectionConfiguration(); + final SSLContext sslContext = Objects.nonNull(connectionConfiguration.getCertPath()) ? getCAStrategy(connectionConfiguration.getCertPath()) : getTrustAllStrategy(); + httpClientBuilder.setSSLContext(sslContext); + + if (connectionConfiguration.isInsecure()) { + httpClientBuilder.setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE); + } + } + + private static TrustManager[] createTrustManagers(final Path certPath) { + if (certPath != null) { + LOG.info("Using the cert provided in the config."); + try (InputStream certificateInputStream = Files.newInputStream(certPath)) { + final CertificateFactory factory = CertificateFactory.getInstance("X.509"); + final Certificate trustedCa = factory.generateCertificate(certificateInputStream); + final KeyStore trustStore = KeyStore.getInstance("pkcs12"); + trustStore.load(null, null); + trustStore.setCertificateEntry("ca", trustedCa); + + final TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance("X509"); + trustManagerFactory.init(trustStore); + return trustManagerFactory.getTrustManagers(); + } catch (Exception ex) { + throw new RuntimeException(ex.getMessage(), ex); + } + } else { + return new TrustManager[] { new X509TrustAllManager() }; + } + } + + private void attachSigv4Auth(final RestClientBuilder restClientBuilder) { + final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(AwsCredentialsOptions.builder() + .withRegion(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsRegion()) + .withStsRoleArn(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsStsRoleArn()) + .withStsHeaderOverrides(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsStsHeaderOverrides()) + .build()); + + final HttpRequestInterceptor httpRequestInterceptor = new AwsRequestSigningApacheInterceptor( + AOS_SERVICE_NAME, + Aws4Signer.create(), + awsCredentialsProvider, + openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsRegion() + ); + + restClientBuilder.setHttpClientConfigCallback(httpClientBuilder -> { + httpClientBuilder.addInterceptorLast(httpRequestInterceptor); + attachSSLContext(httpClientBuilder); + return httpClientBuilder; + }); + } + + private void attachUsernamePassword(final RestClientBuilder restClientBuilder) { + final CredentialsProvider credentialsProvider = new BasicCredentialsProvider(); + credentialsProvider.setCredentials(AuthScope.ANY, + new UsernamePasswordCredentials(openSearchSourceConfiguration.getUsername(), openSearchSourceConfiguration.getPassword())); + + restClientBuilder.setHttpClientConfigCallback(httpClientBuilder -> { + httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider); + attachSSLContext(httpClientBuilder); + return httpClientBuilder; + }); + } + + private void setConnectAndSocketTimeout(final RestClientBuilder restClientBuilder) { + restClientBuilder.setRequestConfigCallback(requestConfigBuilder -> { + if (Objects.nonNull(openSearchSourceConfiguration.getConnectionConfiguration().getConnectTimeout())) { + requestConfigBuilder.setConnectTimeout((int) openSearchSourceConfiguration.getConnectionConfiguration().getConnectTimeout().toMillis()); + } + + if (Objects.nonNull(openSearchSourceConfiguration.getConnectionConfiguration().getSocketTimeout())) { + requestConfigBuilder.setSocketTimeout((int) openSearchSourceConfiguration.getConnectionConfiguration().getSocketTimeout().toMillis()); + } + + return requestConfigBuilder; + }); + } + + private OpenSearchTransport createOpenSearchTransport(final RestClient restClient) { + + if (Objects.nonNull(openSearchSourceConfiguration.getAwsAuthenticationOptions())) { + final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(AwsCredentialsOptions.builder() + .withRegion(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsRegion()) + .withStsRoleArn(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsStsRoleArn()) + .withStsHeaderOverrides(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsStsHeaderOverrides()) + .build()); + + return new AwsSdk2Transport(createSdkHttpClient(), + HttpHost.create(openSearchSourceConfiguration.getHosts().get(0)).getHostName(), + AOS_SERVICE_NAME, openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsRegion(), + AwsSdk2TransportOptions.builder() + .setCredentials(awsCredentialsProvider) + .setMapper(new JacksonJsonpMapper()) + .build()); + } + + return new RestClientTransport(restClient, new JacksonJsonpMapper()); + } + + private SdkHttpClient createSdkHttpClient() { + final ApacheHttpClient.Builder apacheHttpClientBuilder = ApacheHttpClient.builder(); + + if (Objects.nonNull(openSearchSourceConfiguration.getConnectionConfiguration().getConnectTimeout())) { + apacheHttpClientBuilder.connectionTimeout(openSearchSourceConfiguration.getConnectionConfiguration().getConnectTimeout()); + } + + if (Objects.nonNull(openSearchSourceConfiguration.getConnectionConfiguration().getSocketTimeout())) { + apacheHttpClientBuilder.socketTimeout(openSearchSourceConfiguration.getConnectionConfiguration().getSocketTimeout()); + } + + attachSSLContext(apacheHttpClientBuilder); + + return apacheHttpClientBuilder.build(); + } + + private SSLContext getCAStrategy(final Path certPath) { + LOG.info("Using the cert provided in the config."); + try { + CertificateFactory factory = CertificateFactory.getInstance("X.509"); + Certificate trustedCa; + try (InputStream is = Files.newInputStream(certPath)) { + trustedCa = factory.generateCertificate(is); + } + KeyStore trustStore = KeyStore.getInstance("pkcs12"); + trustStore.load(null, null); + trustStore.setCertificateEntry("ca", trustedCa); + SSLContextBuilder sslContextBuilder = SSLContexts.custom() + .loadTrustMaterial(trustStore, null); + return sslContextBuilder.build(); + } catch (Exception ex) { + throw new RuntimeException(ex.getMessage(), ex); + } + } + + private SSLContext getTrustAllStrategy() { + LOG.info("Using the trust all strategy"); + final TrustStrategy trustStrategy = new TrustAllStrategy(); + try { + return SSLContexts.custom().loadTrustMaterial(null, trustStrategy).build(); + } catch (Exception ex) { + throw new RuntimeException(ex.getMessage(), ex); + } + } + + private boolean versionSupportsPointInTimeForOpenSearch(final String version) { + final DefaultArtifactVersion cutoffVersion = new DefaultArtifactVersion(OPENSEARCH_POINT_IN_TIME_SUPPORT_VERSION_CUTOFF); + final DefaultArtifactVersion actualVersion = new DefaultArtifactVersion(version); + return actualVersion.compareTo(cutoffVersion) >= 0; } } diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/X509TrustAllManager.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/X509TrustAllManager.java new file mode 100644 index 0000000000..1baffdac9e --- /dev/null +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/X509TrustAllManager.java @@ -0,0 +1,22 @@ +package org.opensearch.dataprepper.plugins.source.opensearch.worker.client; + +import javax.net.ssl.X509TrustManager; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; + +public class X509TrustAllManager implements X509TrustManager { + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException { + + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException { + + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return null; + } +} diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/model/SearchContextType.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/model/SearchContextType.java new file mode 100644 index 0000000000..b044c6e41d --- /dev/null +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/model/SearchContextType.java @@ -0,0 +1,12 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model; + +public enum SearchContextType { + SCROLL, + POINT_IN_TIME, + NONE +} diff --git a/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchServiceTest.java b/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchServiceTest.java new file mode 100644 index 0000000000..7e013dc277 --- /dev/null +++ b/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchServiceTest.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.opensearch; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.source.coordinator.SourceCoordinator; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.SearchAccessor; + +import static org.mockito.Mockito.verify; + +@ExtendWith(MockitoExtension.class) +public class OpenSearchServiceTest { + + @Mock + private OpenSearchSourceConfiguration openSearchSourceConfiguration; + + @Mock + private SearchAccessor searchAccessor; + + @Mock + private Buffer> buffer; + + @Mock + private SourceCoordinator sourceCoordinator; + + private OpenSearchService createObjectUnderTest() { + return OpenSearchService.createOpenSearchService(searchAccessor, sourceCoordinator, openSearchSourceConfiguration, buffer); + } + + @Test + void source_coordinator_is_initialized_on_construction() { + createObjectUnderTest(); + verify(sourceCoordinator).initialize(); + } +} diff --git a/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSourceTest.java b/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSourceTest.java new file mode 100644 index 0000000000..68affcfc64 --- /dev/null +++ b/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSourceTest.java @@ -0,0 +1,78 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.opensearch; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.source.coordinator.SourceCoordinator; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.SearchAccessor; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.SearchAccessorStrategy; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +public class OpenSearchSourceTest { + + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + + @Mock + private OpenSearchSourceConfiguration openSearchSourceConfiguration; + + @Mock + private OpenSearchService openSearchService; + + @Mock + private SearchAccessorStrategy searchAccessorStrategy; + + @Mock + private SearchAccessor searchAccessor; + + @Mock + private Buffer> buffer; + + @Mock + private SourceCoordinator sourceCoordinator; + + private OpenSearchSource createObjectUnderTest() { + return new OpenSearchSource(openSearchSourceConfiguration, awsCredentialsSupplier); + } + + @Test + void start_with_null_buffer_throws_IllegalStateException() { + assertThrows(IllegalStateException.class, () -> createObjectUnderTest().start(null)); + } + + @Test + void start_with_non_null_buffer_does_not_throw() { + + when(searchAccessorStrategy.getSearchAccessor()).thenReturn(searchAccessor); + doNothing().when(openSearchService).start(); + + final OpenSearchSource objectUnderTest = createObjectUnderTest(); + objectUnderTest.setSourceCoordinator(sourceCoordinator); + + try (final MockedStatic searchAccessorStrategyMockedStatic = mockStatic(SearchAccessorStrategy.class); + final MockedStatic openSearchServiceMockedStatic = mockStatic(OpenSearchService.class)) { + searchAccessorStrategyMockedStatic.when(() -> SearchAccessorStrategy.create(openSearchSourceConfiguration, awsCredentialsSupplier)).thenReturn(searchAccessorStrategy); + + openSearchServiceMockedStatic.when(() -> OpenSearchService.createOpenSearchService(searchAccessor, sourceCoordinator, openSearchSourceConfiguration, buffer)) + .thenReturn(openSearchService); + + objectUnderTest.start(buffer); + } + } +} diff --git a/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/configuration/AwsAuthenticationConfigurationTest.java b/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/configuration/AwsAuthenticationConfigurationTest.java index b50a9f887d..8adb0ab47c 100644 --- a/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/configuration/AwsAuthenticationConfigurationTest.java +++ b/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/configuration/AwsAuthenticationConfigurationTest.java @@ -6,43 +6,23 @@ package org.opensearch.dataprepper.plugins.source.opensearch.configuration; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; -import org.mockito.ArgumentCaptor; import org.mockito.MockedStatic; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; -import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.sts.StsClient; -import software.amazon.awssdk.services.sts.StsClientBuilder; -import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; import java.lang.reflect.Field; -import java.util.Collections; -import java.util.Map; import java.util.UUID; -import java.util.function.Consumer; import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.nullValue; -import static org.hamcrest.CoreMatchers.sameInstance; import static org.hamcrest.MatcherAssert.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; class AwsAuthenticationConfigurationTest { private AwsAuthenticationConfiguration awsAuthenticationOptions; - private final String TEST_ROLE = "arn:aws:iam::123456789012:role/test-role"; - @BeforeEach void setUp() { awsAuthenticationOptions = new AwsAuthenticationConfiguration(); @@ -54,7 +34,7 @@ void getAwsRegion_returns_Region_of() throws NoSuchFieldException, IllegalAccess final Region expectedRegionObject = mock(Region.class); reflectivelySetField(awsAuthenticationOptions, "awsRegion", regionString); final Region actualRegion; - try(final MockedStatic regionMockedStatic = mockStatic(Region.class)) { + try (final MockedStatic regionMockedStatic = mockStatic(Region.class)) { regionMockedStatic.when(() -> Region.of(regionString)).thenReturn(expectedRegionObject); actualRegion = awsAuthenticationOptions.getAwsRegion(); } @@ -67,159 +47,6 @@ void getAwsRegion_returns_null_when_region_is_null() throws NoSuchFieldException assertThat(awsAuthenticationOptions.getAwsRegion(), nullValue()); } - @Test - void authenticateAWSConfiguration_should_return_s3Client_without_sts_role_arn() throws NoSuchFieldException, IllegalAccessException { - reflectivelySetField(awsAuthenticationOptions, "awsRegion", "us-east-1"); - reflectivelySetField(awsAuthenticationOptions, "awsStsRoleArn", null); - - final DefaultCredentialsProvider mockedCredentialsProvider = mock(DefaultCredentialsProvider.class); - final AwsCredentialsProvider actualCredentialsProvider; - try (final MockedStatic defaultCredentialsProviderMockedStatic = mockStatic(DefaultCredentialsProvider.class)) { - defaultCredentialsProviderMockedStatic.when(DefaultCredentialsProvider::create) - .thenReturn(mockedCredentialsProvider); - actualCredentialsProvider = awsAuthenticationOptions.authenticateAwsConfiguration(); - } - - assertThat(actualCredentialsProvider, sameInstance(mockedCredentialsProvider)); - } - - @Nested - class WithSts { - private StsClient stsClient; - private StsClientBuilder stsClientBuilder; - - @BeforeEach - void setUp() { - stsClient = mock(StsClient.class); - stsClientBuilder = mock(StsClientBuilder.class); - - when(stsClientBuilder.build()).thenReturn(stsClient); - } - - @Test - void authenticateAWSConfiguration_should_return_s3Client_with_sts_role_arn() throws NoSuchFieldException, IllegalAccessException { - reflectivelySetField(awsAuthenticationOptions, "awsRegion", "us-east-1"); - reflectivelySetField(awsAuthenticationOptions, "awsStsRoleArn", TEST_ROLE); - - when(stsClientBuilder.region(Region.US_EAST_1)).thenReturn(stsClientBuilder); - final AssumeRoleRequest.Builder assumeRoleRequestBuilder = mock(AssumeRoleRequest.Builder.class); - when(assumeRoleRequestBuilder.roleSessionName(anyString())) - .thenReturn(assumeRoleRequestBuilder); - when(assumeRoleRequestBuilder.roleArn(anyString())) - .thenReturn(assumeRoleRequestBuilder); - - final AwsCredentialsProvider actualCredentialsProvider; - try (final MockedStatic stsClientMockedStatic = mockStatic(StsClient.class); - final MockedStatic assumeRoleRequestMockedStatic = mockStatic(AssumeRoleRequest.class)) { - stsClientMockedStatic.when(StsClient::builder).thenReturn(stsClientBuilder); - assumeRoleRequestMockedStatic.when(AssumeRoleRequest::builder).thenReturn(assumeRoleRequestBuilder); - actualCredentialsProvider = awsAuthenticationOptions.authenticateAwsConfiguration(); - } - - assertThat(actualCredentialsProvider, instanceOf(AwsCredentialsProvider.class)); - - verify(assumeRoleRequestBuilder).roleArn(TEST_ROLE); - verify(assumeRoleRequestBuilder).roleSessionName(anyString()); - verify(assumeRoleRequestBuilder).build(); - verifyNoMoreInteractions(assumeRoleRequestBuilder); - } - - @Test - void authenticateAWSConfiguration_should_return_s3Client_with_sts_role_arn_when_no_region() throws NoSuchFieldException, IllegalAccessException { - reflectivelySetField(awsAuthenticationOptions, "awsRegion", null); - reflectivelySetField(awsAuthenticationOptions, "awsStsRoleArn", TEST_ROLE); - assertThat(awsAuthenticationOptions.getAwsRegion(), equalTo(null)); - - when(stsClientBuilder.region(null)).thenReturn(stsClientBuilder); - - final AwsCredentialsProvider actualCredentialsProvider; - try (final MockedStatic stsClientMockedStatic = mockStatic(StsClient.class)) { - stsClientMockedStatic.when(StsClient::builder).thenReturn(stsClientBuilder); - actualCredentialsProvider = awsAuthenticationOptions.authenticateAwsConfiguration(); - } - - assertThat(actualCredentialsProvider, instanceOf(AwsCredentialsProvider.class)); - } - - @Test - void authenticateAWSConfiguration_should_override_STS_Headers_when_HeaderOverrides_when_set() throws NoSuchFieldException, IllegalAccessException { - final String headerName1 = UUID.randomUUID().toString(); - final String headerValue1 = UUID.randomUUID().toString(); - final String headerName2 = UUID.randomUUID().toString(); - final String headerValue2 = UUID.randomUUID().toString(); - final Map overrideHeaders = Map.of(headerName1, headerValue1, headerName2, headerValue2); - - reflectivelySetField(awsAuthenticationOptions, "awsRegion", "us-east-1"); - reflectivelySetField(awsAuthenticationOptions, "awsStsRoleArn", TEST_ROLE); - reflectivelySetField(awsAuthenticationOptions, "awsStsHeaderOverrides", overrideHeaders); - - when(stsClientBuilder.region(Region.US_EAST_1)).thenReturn(stsClientBuilder); - - final AssumeRoleRequest.Builder assumeRoleRequestBuilder = mock(AssumeRoleRequest.Builder.class); - when(assumeRoleRequestBuilder.roleSessionName(anyString())) - .thenReturn(assumeRoleRequestBuilder); - when(assumeRoleRequestBuilder.roleArn(anyString())) - .thenReturn(assumeRoleRequestBuilder); - when(assumeRoleRequestBuilder.overrideConfiguration(any(Consumer.class))) - .thenReturn(assumeRoleRequestBuilder); - - final AwsCredentialsProvider actualCredentialsProvider; - try (final MockedStatic stsClientMockedStatic = mockStatic(StsClient.class); - final MockedStatic assumeRoleRequestMockedStatic = mockStatic(AssumeRoleRequest.class)) { - stsClientMockedStatic.when(StsClient::builder).thenReturn(stsClientBuilder); - assumeRoleRequestMockedStatic.when(AssumeRoleRequest::builder).thenReturn(assumeRoleRequestBuilder); - actualCredentialsProvider = awsAuthenticationOptions.authenticateAwsConfiguration(); - } - - assertThat(actualCredentialsProvider, instanceOf(AwsCredentialsProvider.class)); - - final ArgumentCaptor> configurationCaptor = ArgumentCaptor.forClass(Consumer.class); - - verify(assumeRoleRequestBuilder).roleArn(TEST_ROLE); - verify(assumeRoleRequestBuilder).roleSessionName(anyString()); - verify(assumeRoleRequestBuilder).overrideConfiguration(configurationCaptor.capture()); - verify(assumeRoleRequestBuilder).build(); - verifyNoMoreInteractions(assumeRoleRequestBuilder); - - final Consumer actualOverride = configurationCaptor.getValue(); - - final AwsRequestOverrideConfiguration.Builder configurationBuilder = mock(AwsRequestOverrideConfiguration.Builder.class); - actualOverride.accept(configurationBuilder); - verify(configurationBuilder).putHeader(headerName1, headerValue1); - verify(configurationBuilder).putHeader(headerName2, headerValue2); - verifyNoMoreInteractions(configurationBuilder); - } - - @Test - void authenticateAWSConfiguration_should_not_override_STS_Headers_when_HeaderOverrides_are_empty() throws NoSuchFieldException, IllegalAccessException { - reflectivelySetField(awsAuthenticationOptions, "awsRegion", "us-east-1"); - reflectivelySetField(awsAuthenticationOptions, "awsStsRoleArn", TEST_ROLE); - reflectivelySetField(awsAuthenticationOptions, "awsStsHeaderOverrides", Collections.emptyMap()); - - when(stsClientBuilder.region(Region.US_EAST_1)).thenReturn(stsClientBuilder); - final AssumeRoleRequest.Builder assumeRoleRequestBuilder = mock(AssumeRoleRequest.Builder.class); - when(assumeRoleRequestBuilder.roleSessionName(anyString())) - .thenReturn(assumeRoleRequestBuilder); - when(assumeRoleRequestBuilder.roleArn(anyString())) - .thenReturn(assumeRoleRequestBuilder); - - final AwsCredentialsProvider actualCredentialsProvider; - try (final MockedStatic stsClientMockedStatic = mockStatic(StsClient.class); - final MockedStatic assumeRoleRequestMockedStatic = mockStatic(AssumeRoleRequest.class)) { - stsClientMockedStatic.when(StsClient::builder).thenReturn(stsClientBuilder); - assumeRoleRequestMockedStatic.when(AssumeRoleRequest::builder).thenReturn(assumeRoleRequestBuilder); - actualCredentialsProvider = awsAuthenticationOptions.authenticateAwsConfiguration(); - } - - assertThat(actualCredentialsProvider, instanceOf(AwsCredentialsProvider.class)); - - verify(assumeRoleRequestBuilder).roleArn(TEST_ROLE); - verify(assumeRoleRequestBuilder).roleSessionName(anyString()); - verify(assumeRoleRequestBuilder).build(); - verifyNoMoreInteractions(assumeRoleRequestBuilder); - } - } - private void reflectivelySetField(final AwsAuthenticationConfiguration awsAuthenticationOptions, final String fieldName, final Object value) throws NoSuchFieldException, IllegalAccessException { final Field field = AwsAuthenticationConfiguration.class.getDeclaredField(fieldName); try { @@ -229,4 +56,4 @@ private void reflectivelySetField(final AwsAuthenticationConfiguration awsAuthen field.setAccessible(false); } } -} \ No newline at end of file +} diff --git a/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/SearchAccessStrategyTest.java b/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/SearchAccessStrategyTest.java new file mode 100644 index 0000000000..2d1821b59c --- /dev/null +++ b/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/SearchAccessStrategyTest.java @@ -0,0 +1,145 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.opensearch.worker.client; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockedConstruction; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.client.opensearch.OpenSearchClient; +import org.opensearch.client.opensearch._types.OpenSearchVersionInfo; +import org.opensearch.client.opensearch.core.InfoResponse; +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.plugins.source.opensearch.OpenSearchSourceConfiguration; +import org.opensearch.dataprepper.plugins.source.opensearch.configuration.AwsAuthenticationConfiguration; +import org.opensearch.dataprepper.plugins.source.opensearch.configuration.ConnectionConfiguration; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchContextType; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +import java.util.Collections; +import java.util.List; +import java.util.UUID; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.notNullValue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; +import static org.opensearch.dataprepper.plugins.source.opensearch.worker.client.SearchAccessorStrategy.OPENSEARCH_DISTRIBUTION; + +@ExtendWith(MockitoExtension.class) +public class SearchAccessStrategyTest { + + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + + @Mock + private OpenSearchSourceConfiguration openSearchSourceConfiguration; + + @Mock + private ConnectionConfiguration connectionConfiguration; + + @BeforeEach + void setup() { + when(openSearchSourceConfiguration.getHosts()).thenReturn(List.of("http://localhost:9200")); + when(openSearchSourceConfiguration.getConnectionConfiguration()).thenReturn(connectionConfiguration); + } + + private SearchAccessorStrategy createObjectUnderTest() { + return SearchAccessorStrategy.create(openSearchSourceConfiguration, awsCredentialsSupplier); + } + + @ParameterizedTest + @ValueSource(strings = {"2.5.0", "2.6.1", "3.0.0"}) + void testHappyPath_with_username_and_password_and_insecure_for_different_point_in_time_versions_for_opensearch(final String osVersion) { + final String username = UUID.randomUUID().toString(); + final String password = UUID.randomUUID().toString(); + when(openSearchSourceConfiguration.getUsername()).thenReturn(username); + when(openSearchSourceConfiguration.getPassword()).thenReturn(password); + + when(connectionConfiguration.getCertPath()).thenReturn(null); + when(connectionConfiguration.getSocketTimeout()).thenReturn(null); + when(connectionConfiguration.getConnectTimeout()).thenReturn(null); + when(connectionConfiguration.isInsecure()).thenReturn(true); + + when(openSearchSourceConfiguration.getAwsAuthenticationOptions()).thenReturn(null); + + final InfoResponse infoResponse = mock(InfoResponse.class); + final OpenSearchVersionInfo openSearchVersionInfo = mock(OpenSearchVersionInfo.class); + when(openSearchVersionInfo.distribution()).thenReturn(OPENSEARCH_DISTRIBUTION); + when(openSearchVersionInfo.number()).thenReturn(osVersion); + when(infoResponse.version()).thenReturn(openSearchVersionInfo); + + try (MockedConstruction openSearchClientMockedConstruction = mockConstruction(OpenSearchClient.class, + (clientMock, context) -> { + when(clientMock.info()).thenReturn(infoResponse); + })) { + + final SearchAccessor searchAccessor = createObjectUnderTest().getSearchAccessor(); + assertThat(searchAccessor, notNullValue()); + assertThat(searchAccessor.getSearchContextType(), equalTo(SearchContextType.POINT_IN_TIME)); + + final List constructedClients = openSearchClientMockedConstruction.constructed(); + assertThat(constructedClients.size(), equalTo(1)); + } + + verifyNoInteractions(awsCredentialsSupplier); + + } + + @ParameterizedTest + @ValueSource(strings = {"1.3.0", "2.4.9", "0.3.2"}) + void testHappyPath_with_aws_credentials_for_different_scroll_versions_for_opensearch(final String osVersion) { + when(connectionConfiguration.getCertPath()).thenReturn(null); + when(connectionConfiguration.getSocketTimeout()).thenReturn(null); + when(connectionConfiguration.getConnectTimeout()).thenReturn(null); + when(connectionConfiguration.isInsecure()).thenReturn(true); + + final AwsAuthenticationConfiguration awsAuthenticationConfiguration = mock(AwsAuthenticationConfiguration.class); + when(awsAuthenticationConfiguration.getAwsRegion()).thenReturn(Region.US_EAST_1); + final String stsRoleArn = "arn:aws:iam::123456789012:role/my-role"; + when(awsAuthenticationConfiguration.getAwsStsRoleArn()).thenReturn(stsRoleArn); + when(awsAuthenticationConfiguration.getAwsStsHeaderOverrides()).thenReturn(Collections.emptyMap()); + when(openSearchSourceConfiguration.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationConfiguration); + + final InfoResponse infoResponse = mock(InfoResponse.class); + final OpenSearchVersionInfo openSearchVersionInfo = mock(OpenSearchVersionInfo.class); + when(openSearchVersionInfo.distribution()).thenReturn(OPENSEARCH_DISTRIBUTION); + when(openSearchVersionInfo.number()).thenReturn(osVersion); + when(infoResponse.version()).thenReturn(openSearchVersionInfo); + + final ArgumentCaptor awsCredentialsOptionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class); + final AwsCredentialsProvider awsCredentialsProvider = mock(AwsCredentialsProvider.class); + when(awsCredentialsSupplier.getProvider(awsCredentialsOptionsArgumentCaptor.capture())).thenReturn(awsCredentialsProvider); + + try (MockedConstruction openSearchClientMockedConstruction = mockConstruction(OpenSearchClient.class, + (clientMock, context) -> { + when(clientMock.info()).thenReturn(infoResponse); + })) { + + final SearchAccessor searchAccessor = createObjectUnderTest().getSearchAccessor(); + assertThat(searchAccessor, notNullValue()); + assertThat(searchAccessor.getSearchContextType(), equalTo(SearchContextType.SCROLL)); + + final List constructedClients = openSearchClientMockedConstruction.constructed(); + assertThat(constructedClients.size(), equalTo(1)); + } + + final AwsCredentialsOptions awsCredentialsOptions = awsCredentialsOptionsArgumentCaptor.getValue(); + assertThat(awsCredentialsOptions, notNullValue()); + assertThat(awsCredentialsOptions.getRegion(), equalTo(Region.US_EAST_1)); + assertThat(awsCredentialsOptions.getStsHeaderOverrides(), equalTo(Collections.emptyMap())); + assertThat(awsCredentialsOptions.getStsRoleArn(), equalTo(stsRoleArn)); + } +} diff --git a/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/sink/opensearch/AwsRequestSigningApacheInterceptor.java b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/sink/opensearch/AwsRequestSigningApacheInterceptor.java index ba2b04a521..5d10779a3e 100644 --- a/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/sink/opensearch/AwsRequestSigningApacheInterceptor.java +++ b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/sink/opensearch/AwsRequestSigningApacheInterceptor.java @@ -14,7 +14,6 @@ import org.apache.http.Header; import org.apache.http.HttpEntityEnclosingRequest; -import org.apache.http.HttpException; import org.apache.http.HttpHost; import org.apache.http.HttpRequest; import org.apache.http.HttpRequestInterceptor; @@ -120,7 +119,7 @@ public AwsRequestSigningApacheInterceptor(final String service, */ @Override public void process(final HttpRequest request, final HttpContext context) - throws HttpException, IOException { + throws IOException { URIBuilder uriBuilder; try { uriBuilder = new URIBuilder(request.getRequestLine().getUri());