From 0c59d5d3831b74fe1a7d091b913de64739086d04 Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Fri, 14 Jun 2024 16:27:21 -0700 Subject: [PATCH 1/3] Introduce SparkParameterComposerCollection Signed-off-by: Tomoyuki Morita --- async-query-core/README.md | 6 +- .../model/SparkSubmitParameters.java | 231 ------------ .../config/SparkSubmitParameterModifier.java | 7 +- .../spark/dispatcher/BatchQueryHandler.java | 13 +- .../dispatcher/InteractiveQueryHandler.java | 15 +- .../spark/dispatcher/QueryHandlerFactory.java | 17 +- .../spark/dispatcher/RefreshQueryHandler.java | 11 +- .../dispatcher/StreamingQueryHandler.java | 20 +- .../session/CreateSessionRequest.java | 6 +- .../execution/session/InteractiveSession.java | 7 +- .../DataSourceSparkParameterComposer.java | 19 + .../GeneralSparkParameterComposer.java | 17 + .../SparkParameterComposerCollection.java | 65 ++++ .../parameter/SparkSubmitParameters.java | 52 +++ .../SparkSubmitParametersBuilder.java | 171 +++++++++ .../SparkSubmitParametersBuilderProvider.java | 18 + .../AsyncQueryExecutorServiceImplTest.java | 5 +- .../model/SparkSubmitParametersTest.java | 84 ----- .../client/EmrServerlessClientImplTest.java | 45 +-- .../dispatcher/SparkQueryDispatcherTest.java | 355 ++++-------------- .../SparkParameterComposerCollectionTest.java | 93 +++++ .../SparkSubmitParametersBuilderTest.java | 201 ++++++++++ .../OpenSearchExtraParameterComposer.java | 30 ++ ...penSearchSparkSubmitParameterModifier.java | 6 +- ...rkExecutionEngineConfigClusterSetting.java | 2 + ...utionEngineConfigClusterSettingLoader.java | 35 ++ ...parkExecutionEngineConfigSupplierImpl.java | 34 +- ...3GlueDataSourceSparkParameterComposer.java | 108 ++++++ .../config/AsyncExecutorServiceModule.java | 32 +- .../AsyncQueryExecutorServiceSpec.java | 12 +- .../OpenSearchExtraParameterComposerTest.java | 52 +++ ...nEngineConfigClusterSettingLoaderTest.java | 67 ++++ ...ExecutionEngineConfigSupplierImplTest.java | 36 +- .../sql/spark/constants/TestConstants.java | 1 + .../execution/session/SessionTestUtil.java | 5 +- ...eDataSourceSparkParameterComposerTest.java | 160 ++++++++ 36 files changed, 1335 insertions(+), 703 deletions(-) delete mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/parameter/DataSourceSparkParameterComposer.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/parameter/GeneralSparkParameterComposer.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkParameterComposerCollection.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParameters.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilder.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilderProvider.java delete mode 100644 async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java create mode 100644 async-query-core/src/test/java/org/opensearch/sql/spark/parameter/SparkParameterComposerCollectionTest.java create mode 100644 async-query-core/src/test/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilderTest.java create mode 100644 async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchExtraParameterComposer.java create mode 100644 async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingLoader.java create mode 100644 async-query/src/main/java/org/opensearch/sql/spark/parameter/S3GlueDataSourceSparkParameterComposer.java create mode 100644 async-query/src/test/java/org/opensearch/sql/spark/config/OpenSearchExtraParameterComposerTest.java create mode 100644 async-query/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingLoaderTest.java create mode 100644 async-query/src/test/java/org/opensearch/sql/spark/parameter/S3GlueDataSourceSparkParameterComposerTest.java diff --git a/async-query-core/README.md b/async-query-core/README.md index 61b6057269..815088bce6 100644 --- a/async-query-core/README.md +++ b/async-query-core/README.md @@ -27,6 +27,8 @@ Following is the list of extension points where the consumer of the library need - [QueryIdProvider](src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java) - [SessionIdProvider](src/main/java/org/opensearch/sql/spark/execution/session/SessionIdProvider.java) - [SessionConfigSupplier](src/main/java/org/opensearch/sql/spark/execution/session/SessionConfigSupplier.java) - - [SparkExecutionEngineConfigSupplier](src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplier.java) - - [SparkSubmitParameterModifier](src/main/java/org/opensearch/sql/spark/config/SparkSubmitParameterModifier.java) - [EMRServerlessClientFactory](src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactory.java) + - [SparkExecutionEngineConfigSupplier](src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplier.java) + - [DataSourceSparkParameterComposer](src/main/java/org/opensearch/sql/spark/parameter/DataSourceSparkParameterComposer.java) + - [GeneralSparkParameterComposer](src/main/java/org/opensearch/sql/spark/parameter/GeneralSparkParameterComposer.java) + - [SparkSubmitParameterModifier](src/main/java/org/opensearch/sql/spark/config/SparkSubmitParameterModifier.java) To be deprecated in favor of GeneralSparkParameterComposer diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java deleted file mode 100644 index 6badea6a74..0000000000 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.asyncquery.model; - -import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_AUTH; -import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_AUTH_PASSWORD; -import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_AUTH_USERNAME; -import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_REGION; -import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_URI; -import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_LAKEFORMATION_ENABLED; -import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_ROLE_ARN; -import static org.opensearch.sql.spark.data.constants.SparkConstants.*; - -import java.net.URI; -import java.net.URISyntaxException; -import java.util.LinkedHashMap; -import java.util.Map; -import java.util.function.Supplier; -import lombok.AllArgsConstructor; -import lombok.RequiredArgsConstructor; -import lombok.Setter; -import org.apache.commons.lang3.BooleanUtils; -import org.apache.commons.text.StringEscapeUtils; -import org.opensearch.sql.datasource.model.DataSourceMetadata; -import org.opensearch.sql.datasource.model.DataSourceType; -import org.opensearch.sql.datasources.auth.AuthenticationType; -import org.opensearch.sql.spark.config.SparkSubmitParameterModifier; -import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil; - -/** Define Spark Submit Parameters. */ -@AllArgsConstructor -@RequiredArgsConstructor -public class SparkSubmitParameters { - public static final String SPACE = " "; - public static final String EQUALS = "="; - public static final String FLINT_BASIC_AUTH = "basic"; - - private final String className; - private final Map config; - - /** Extra parameters to append finally */ - @Setter private String extraParameters; - - public void setConfigItem(String key, String value) { - config.put(key, value); - } - - public void deleteConfigItem(String key) { - config.remove(key); - } - - public static Builder builder() { - return Builder.builder(); - } - - public SparkSubmitParameters acceptModifier(SparkSubmitParameterModifier modifier) { - modifier.modifyParameters(this); - return this; - } - - public static class Builder { - - private String className; - private final Map config; - private String extraParameters; - - private Builder() { - className = DEFAULT_CLASS_NAME; - config = new LinkedHashMap<>(); - - config.put(S3_AWS_CREDENTIALS_PROVIDER_KEY, DEFAULT_S3_AWS_CREDENTIALS_PROVIDER_VALUE); - config.put( - HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY, - DEFAULT_GLUE_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY); - config.put(SPARK_JARS_KEY, ICEBERG_SPARK_RUNTIME_PACKAGE); - config.put( - SPARK_JAR_PACKAGES_KEY, - SPARK_STANDALONE_PACKAGE + "," + SPARK_LAUNCHER_PACKAGE + "," + PPL_STANDALONE_PACKAGE); - config.put(SPARK_JAR_REPOSITORIES_KEY, AWS_SNAPSHOT_REPOSITORY); - config.put(SPARK_DRIVER_ENV_JAVA_HOME_KEY, JAVA_HOME_LOCATION); - config.put(SPARK_EXECUTOR_ENV_JAVA_HOME_KEY, JAVA_HOME_LOCATION); - config.put(SPARK_DRIVER_ENV_FLINT_CLUSTER_NAME_KEY, FLINT_DEFAULT_CLUSTER_NAME); - config.put(SPARK_EXECUTOR_ENV_FLINT_CLUSTER_NAME_KEY, FLINT_DEFAULT_CLUSTER_NAME); - config.put(FLINT_INDEX_STORE_HOST_KEY, FLINT_DEFAULT_HOST); - config.put(FLINT_INDEX_STORE_PORT_KEY, FLINT_DEFAULT_PORT); - config.put(FLINT_INDEX_STORE_SCHEME_KEY, FLINT_DEFAULT_SCHEME); - config.put(FLINT_INDEX_STORE_AUTH_KEY, FLINT_DEFAULT_AUTH); - config.put(FLINT_CREDENTIALS_PROVIDER_KEY, EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER); - config.put( - SPARK_SQL_EXTENSIONS_KEY, - ICEBERG_SPARK_EXTENSION + "," + FLINT_SQL_EXTENSION + "," + FLINT_PPL_EXTENSION); - config.put(HIVE_METASTORE_CLASS_KEY, GLUE_HIVE_CATALOG_FACTORY_CLASS); - config.put(SPARK_CATALOG, ICEBERG_SESSION_CATALOG); - config.put(SPARK_CATALOG_CATALOG_IMPL, ICEBERG_GLUE_CATALOG); - } - - public static Builder builder() { - return new Builder(); - } - - public Builder className(String className) { - this.className = className; - return this; - } - - public Builder clusterName(String clusterName) { - config.put(SPARK_DRIVER_ENV_FLINT_CLUSTER_NAME_KEY, clusterName); - config.put(SPARK_EXECUTOR_ENV_FLINT_CLUSTER_NAME_KEY, clusterName); - return this; - } - - /** - * For query in spark submit parameters to be parsed correctly, escape the characters in the - * query, then wrap the query with double quotes. - */ - public Builder query(String query) { - String escapedQuery = StringEscapeUtils.escapeJava(query); - String wrappedQuery = "\"" + escapedQuery + "\""; - config.put(FLINT_JOB_QUERY, wrappedQuery); - return this; - } - - public Builder dataSource(DataSourceMetadata metadata) { - if (DataSourceType.S3GLUE.equals(metadata.getConnector())) { - String roleArn = metadata.getProperties().get(GLUE_ROLE_ARN); - - config.put(DRIVER_ENV_ASSUME_ROLE_ARN_KEY, roleArn); - config.put(EXECUTOR_ENV_ASSUME_ROLE_ARN_KEY, roleArn); - config.put(HIVE_METASTORE_GLUE_ARN_KEY, roleArn); - config.put("spark.sql.catalog." + metadata.getName(), FLINT_DELEGATE_CATALOG); - config.put(FLINT_DATA_SOURCE_KEY, metadata.getName()); - - final boolean lakeFormationEnabled = - BooleanUtils.toBoolean(metadata.getProperties().get(GLUE_LAKEFORMATION_ENABLED)); - config.put(EMR_LAKEFORMATION_OPTION, Boolean.toString(lakeFormationEnabled)); - config.put(FLINT_ACCELERATE_USING_COVERING_INDEX, Boolean.toString(!lakeFormationEnabled)); - - setFlintIndexStoreHost( - parseUri( - metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_URI), metadata.getName())); - setFlintIndexStoreAuthProperties( - metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_AUTH), - () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_AUTH_USERNAME), - () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_AUTH_PASSWORD), - () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_REGION)); - config.put("spark.flint.datasource.name", metadata.getName()); - return this; - } - throw new UnsupportedOperationException( - String.format( - "UnSupported datasource type for async queries:: %s", metadata.getConnector())); - } - - private void setFlintIndexStoreHost(URI uri) { - config.put(FLINT_INDEX_STORE_HOST_KEY, uri.getHost()); - config.put(FLINT_INDEX_STORE_PORT_KEY, String.valueOf(uri.getPort())); - config.put(FLINT_INDEX_STORE_SCHEME_KEY, uri.getScheme()); - } - - private void setFlintIndexStoreAuthProperties( - String authType, - Supplier userName, - Supplier password, - Supplier region) { - if (AuthenticationType.get(authType).equals(AuthenticationType.BASICAUTH)) { - config.put(FLINT_INDEX_STORE_AUTH_KEY, FLINT_BASIC_AUTH); - config.put(FLINT_INDEX_STORE_AUTH_USERNAME, userName.get()); - config.put(FLINT_INDEX_STORE_AUTH_PASSWORD, password.get()); - } else if (AuthenticationType.get(authType).equals(AuthenticationType.AWSSIGV4AUTH)) { - config.put(FLINT_INDEX_STORE_AUTH_KEY, "sigv4"); - config.put(FLINT_INDEX_STORE_AWSREGION_KEY, region.get()); - } else { - config.put(FLINT_INDEX_STORE_AUTH_KEY, authType); - } - } - - private URI parseUri(String opensearchUri, String datasourceName) { - try { - return new URI(opensearchUri); - } catch (URISyntaxException e) { - throw new IllegalArgumentException( - String.format( - "Bad URI in indexstore configuration of the : %s datasoure.", datasourceName)); - } - } - - public Builder structuredStreaming(Boolean isStructuredStreaming) { - if (isStructuredStreaming) { - config.put("spark.flint.job.type", "streaming"); - } - return this; - } - - public Builder extraParameters(String params) { - extraParameters = params; - return this; - } - - public SparkSubmitParameters build() { - return new SparkSubmitParameters(className, config, extraParameters); - } - } - - public void sessionExecution(String sessionId, String datasourceName) { - config.put(FLINT_JOB_REQUEST_INDEX, OpenSearchStateStoreUtil.getIndexName(datasourceName)); - config.put(FLINT_JOB_SESSION_ID, sessionId); - } - - @Override - public String toString() { - StringBuilder stringBuilder = new StringBuilder(); - stringBuilder.append(" --class "); - stringBuilder.append(this.className); - stringBuilder.append(SPACE); - for (String key : config.keySet()) { - stringBuilder.append(" --conf "); - stringBuilder.append(key); - stringBuilder.append(EQUALS); - stringBuilder.append(config.get(key)); - stringBuilder.append(SPACE); - } - - if (extraParameters != null) { - stringBuilder.append(extraParameters); - } - return stringBuilder.toString(); - } -} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/config/SparkSubmitParameterModifier.java b/async-query-core/src/main/java/org/opensearch/sql/spark/config/SparkSubmitParameterModifier.java index 1c6ce5952a..a50491078c 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/config/SparkSubmitParameterModifier.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/config/SparkSubmitParameterModifier.java @@ -1,11 +1,12 @@ package org.opensearch.sql.spark.config; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilder; /** * Interface for extension point to allow modification of spark submit parameter. modifyParameter - * method is called after the default spark submit parameter is build. + * method is called after the default spark submit parameter is build. To be deprecated in favor of + * {@link org.opensearch.sql.spark.parameter.GeneralSparkParameterComposer} */ public interface SparkSubmitParameterModifier { - void modifyParameters(SparkSubmitParameters parameters); + void modifyParameters(SparkSubmitParametersBuilder parametersBuilder); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java index 8014cf935f..2654f83aad 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java @@ -16,7 +16,6 @@ import org.json.JSONObject; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; @@ -25,6 +24,7 @@ import org.opensearch.sql.spark.dispatcher.model.JobType; import org.opensearch.sql.spark.leasemanager.LeaseManager; import org.opensearch.sql.spark.metrics.MetricsService; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; /** @@ -37,6 +37,7 @@ public class BatchQueryHandler extends AsyncQueryHandler { protected final JobExecutionResponseReader jobExecutionResponseReader; protected final LeaseManager leaseManager; protected final MetricsService metricsService; + protected final SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider; @Override protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) { @@ -80,12 +81,16 @@ public DispatchQueryResponse submit( dispatchQueryRequest.getAccountId(), dispatchQueryRequest.getApplicationId(), dispatchQueryRequest.getExecutionRoleARN(), - SparkSubmitParameters.builder() + sparkSubmitParametersBuilderProvider + .getSparkSubmitParametersBuilder() .clusterName(clusterName) - .dataSource(context.getDataSourceMetadata()) .query(dispatchQueryRequest.getQuery()) - .build() + .dataSource( + context.getDataSourceMetadata(), + dispatchQueryRequest, + context.getAsyncQueryRequestContext()) .acceptModifier(dispatchQueryRequest.getSparkSubmitParameterModifier()) + .acceptComposers(dispatchQueryRequest, context.getAsyncQueryRequestContext()) .toString(), tags, false, diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java index 266d5db978..ec43bccf11 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java @@ -16,7 +16,6 @@ import org.json.JSONObject; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; @@ -32,6 +31,7 @@ import org.opensearch.sql.spark.leasemanager.model.LeaseRequest; import org.opensearch.sql.spark.metrics.EmrMetrics; import org.opensearch.sql.spark.metrics.MetricsService; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; /** @@ -46,6 +46,7 @@ public class InteractiveQueryHandler extends AsyncQueryHandler { private final JobExecutionResponseReader jobExecutionResponseReader; private final LeaseManager leaseManager; private final MetricsService metricsService; + protected final SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider; @Override protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) { @@ -112,12 +113,16 @@ public DispatchQueryResponse submit( dispatchQueryRequest.getAccountId(), dispatchQueryRequest.getApplicationId(), dispatchQueryRequest.getExecutionRoleARN(), - SparkSubmitParameters.builder() + sparkSubmitParametersBuilderProvider + .getSparkSubmitParametersBuilder() .className(FLINT_SESSION_CLASS_NAME) .clusterName(clusterName) - .dataSource(dataSourceMetadata) - .build() - .acceptModifier(dispatchQueryRequest.getSparkSubmitParameterModifier()), + .dataSource( + dataSourceMetadata, + dispatchQueryRequest, + context.getAsyncQueryRequestContext()) + .acceptModifier(dispatchQueryRequest.getSparkSubmitParameterModifier()) + .acceptComposers(dispatchQueryRequest, context.getAsyncQueryRequestContext()), tags, dataSourceMetadata.getResultIndex(), dataSourceMetadata.getName()), diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java index 90329f2f9a..d6e70a9d86 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java @@ -13,6 +13,7 @@ import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.leasemanager.LeaseManager; import org.opensearch.sql.spark.metrics.MetricsService; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; @RequiredArgsConstructor @@ -26,6 +27,7 @@ public class QueryHandlerFactory { private final FlintIndexOpFactory flintIndexOpFactory; private final EMRServerlessClientFactory emrServerlessClientFactory; private final MetricsService metricsService; + protected final SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider; public RefreshQueryHandler getRefreshQueryHandler(String accountId) { return new RefreshQueryHandler( @@ -34,7 +36,8 @@ public RefreshQueryHandler getRefreshQueryHandler(String accountId) { flintIndexMetadataService, leaseManager, flintIndexOpFactory, - metricsService); + metricsService, + sparkSubmitParametersBuilderProvider); } public StreamingQueryHandler getStreamingQueryHandler(String accountId) { @@ -42,7 +45,8 @@ public StreamingQueryHandler getStreamingQueryHandler(String accountId) { emrServerlessClientFactory.getClient(accountId), jobExecutionResponseReader, leaseManager, - metricsService); + metricsService, + sparkSubmitParametersBuilderProvider); } public BatchQueryHandler getBatchQueryHandler(String accountId) { @@ -50,12 +54,17 @@ public BatchQueryHandler getBatchQueryHandler(String accountId) { emrServerlessClientFactory.getClient(accountId), jobExecutionResponseReader, leaseManager, - metricsService); + metricsService, + sparkSubmitParametersBuilderProvider); } public InteractiveQueryHandler getInteractiveQueryHandler() { return new InteractiveQueryHandler( - sessionManager, jobExecutionResponseReader, leaseManager, metricsService); + sessionManager, + jobExecutionResponseReader, + leaseManager, + metricsService, + sparkSubmitParametersBuilderProvider); } public IndexDMLHandler getIndexDMLHandler() { diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java index 634dfa49f6..99984ecc46 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java @@ -20,6 +20,7 @@ import org.opensearch.sql.spark.leasemanager.LeaseManager; import org.opensearch.sql.spark.leasemanager.model.LeaseRequest; import org.opensearch.sql.spark.metrics.MetricsService; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; /** @@ -37,8 +38,14 @@ public RefreshQueryHandler( FlintIndexMetadataService flintIndexMetadataService, LeaseManager leaseManager, FlintIndexOpFactory flintIndexOpFactory, - MetricsService metricsService) { - super(emrServerlessClient, jobExecutionResponseReader, leaseManager, metricsService); + MetricsService metricsService, + SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider) { + super( + emrServerlessClient, + jobExecutionResponseReader, + leaseManager, + metricsService, + sparkSubmitParametersBuilderProvider); this.flintIndexMetadataService = flintIndexMetadataService; this.flintIndexOpFactory = flintIndexOpFactory; } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java index 7291637e5b..2fbf2466da 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java @@ -12,7 +12,6 @@ import java.util.Map; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; @@ -23,6 +22,7 @@ import org.opensearch.sql.spark.leasemanager.LeaseManager; import org.opensearch.sql.spark.leasemanager.model.LeaseRequest; import org.opensearch.sql.spark.metrics.MetricsService; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; /** @@ -35,8 +35,14 @@ public StreamingQueryHandler( EMRServerlessClient emrServerlessClient, JobExecutionResponseReader jobExecutionResponseReader, LeaseManager leaseManager, - MetricsService metricsService) { - super(emrServerlessClient, jobExecutionResponseReader, leaseManager, metricsService); + MetricsService metricsService, + SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider) { + super( + emrServerlessClient, + jobExecutionResponseReader, + leaseManager, + metricsService, + sparkSubmitParametersBuilderProvider); } @Override @@ -70,13 +76,15 @@ public DispatchQueryResponse submit( dispatchQueryRequest.getAccountId(), dispatchQueryRequest.getApplicationId(), dispatchQueryRequest.getExecutionRoleARN(), - SparkSubmitParameters.builder() + sparkSubmitParametersBuilderProvider + .getSparkSubmitParametersBuilder() .clusterName(clusterName) - .dataSource(dataSourceMetadata) .query(dispatchQueryRequest.getQuery()) .structuredStreaming(true) - .build() + .dataSource( + dataSourceMetadata, dispatchQueryRequest, context.getAsyncQueryRequestContext()) .acceptModifier(dispatchQueryRequest.getSparkSubmitParameterModifier()) + .acceptComposers(dispatchQueryRequest, context.getAsyncQueryRequestContext()) .toString(), tags, indexQueryDetails.getFlintIndexOptions().autoRefresh(), diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java b/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java index 4170f0c2d6..6398dd224f 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java @@ -7,9 +7,9 @@ import java.util.Map; import lombok.Data; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.dispatcher.model.JobType; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilder; @Data public class CreateSessionRequest { @@ -17,7 +17,7 @@ public class CreateSessionRequest { private final String accountId; private final String applicationId; private final String executionRoleArn; - private final SparkSubmitParameters sparkSubmitParameters; + private final SparkSubmitParametersBuilder sparkSubmitParametersBuilder; private final Map tags; private final String resultIndex; private final String datasourceName; @@ -28,7 +28,7 @@ public StartJobRequest getStartJobRequest(String sessionId) { accountId, applicationId, executionRoleArn, - sparkSubmitParameters.toString(), + sparkSubmitParametersBuilder.toString(), tags, resultIndex); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index 37b2619783..aeedaef4e7 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -53,11 +53,8 @@ public void open( AsyncQueryRequestContext asyncQueryRequestContext) { // append session id; createSessionRequest - .getSparkSubmitParameters() - .acceptModifier( - (parameters) -> { - parameters.sessionExecution(sessionId, createSessionRequest.getDatasourceName()); - }); + .getSparkSubmitParametersBuilder() + .sessionExecution(sessionId, createSessionRequest.getDatasourceName()); createSessionRequest.getTags().put(SESSION_ID_TAG_KEY, sessionId); StartJobRequest startJobRequest = createSessionRequest.getStartJobRequest(sessionId); String jobID = serverlessClient.startJobRun(startJobRequest); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/DataSourceSparkParameterComposer.java b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/DataSourceSparkParameterComposer.java new file mode 100644 index 0000000000..4411bfe22d --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/DataSourceSparkParameterComposer.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.parameter; + +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; + +/** Compose Spark parameter based on DataSourceMetadata */ +public interface DataSourceSparkParameterComposer { + void compose( + DataSourceMetadata dataSourceMetadata, + SparkSubmitParameters sparkSubmitParameters, + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext context); +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/GeneralSparkParameterComposer.java b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/GeneralSparkParameterComposer.java new file mode 100644 index 0000000000..b5200d550c --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/GeneralSparkParameterComposer.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.parameter; + +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; + +/** Compose spark submit parameters. See {@link SparkParameterComposerCollection}. */ +public interface GeneralSparkParameterComposer { + void compose( + SparkSubmitParameters sparkSubmitParameters, + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext context); +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkParameterComposerCollection.java b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkParameterComposerCollection.java new file mode 100644 index 0000000000..281759afd2 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkParameterComposerCollection.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.parameter; + +import com.google.common.collect.ImmutableList; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.Map; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; + +/** Stores Spark parameter composers and dispatch compose request to each composer */ +public class SparkParameterComposerCollection { + Collection generalComposers = new ArrayList<>(); + Map> datasourceComposers = + new HashMap<>(); + + public void register(DataSourceType dataSourceType, DataSourceSparkParameterComposer composer) { + if (!datasourceComposers.containsKey(dataSourceType)) { + datasourceComposers.put(dataSourceType, new LinkedList<>()); + } + datasourceComposers.get(dataSourceType).add(composer); + } + + public void register(GeneralSparkParameterComposer composer) { + generalComposers.add(composer); + } + + /** Execute composers associated with the datasource type */ + public void composeByDataSource( + DataSourceMetadata dataSourceMetadata, + SparkSubmitParameters sparkSubmitParameters, + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext context) { + for (DataSourceSparkParameterComposer composer : + getComposersFor(dataSourceMetadata.getConnector())) { + composer.compose(dataSourceMetadata, sparkSubmitParameters, dispatchQueryRequest, context); + } + } + + /** Execute all the registered generic composers */ + public void compose( + SparkSubmitParameters sparkSubmitParameters, + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext context) { + for (GeneralSparkParameterComposer composer : generalComposers) { + composer.compose(sparkSubmitParameters, dispatchQueryRequest, context); + } + } + + private Collection getComposersFor(DataSourceType type) { + return datasourceComposers.getOrDefault(type, ImmutableList.of()); + } + + public boolean isComposerRegistered(DataSourceType type) { + return datasourceComposers.containsKey(type); + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParameters.java b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParameters.java new file mode 100644 index 0000000000..2e142ed117 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParameters.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.parameter; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_CLASS_NAME; + +import java.util.LinkedHashMap; +import java.util.Map; +import lombok.Setter; + +/** Define Spark Submit Parameters. */ +public class SparkSubmitParameters { + public static final String SPACE = " "; + public static final String EQUALS = "="; + + @Setter private String className = DEFAULT_CLASS_NAME; + private Map config = new LinkedHashMap<>(); + + /** Extra parameters to append finally */ + @Setter private String extraParameters; + + public void setConfigItem(String key, String value) { + config.put(key, value); + } + + public void deleteConfigItem(String key) { + config.remove(key); + } + + @Override + public String toString() { + StringBuilder stringBuilder = new StringBuilder(); + stringBuilder.append(" --class "); + stringBuilder.append(this.className); + stringBuilder.append(SPACE); + for (String key : config.keySet()) { + stringBuilder.append(" --conf "); + stringBuilder.append(key); + stringBuilder.append(EQUALS); + stringBuilder.append(config.get(key)); + stringBuilder.append(SPACE); + } + + if (extraParameters != null) { + stringBuilder.append(extraParameters); + } + return stringBuilder.toString(); + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilder.java b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilder.java new file mode 100644 index 0000000000..01a665a485 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilder.java @@ -0,0 +1,171 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.parameter; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.AWS_SNAPSHOT_REPOSITORY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_GLUE_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_S3_AWS_CREDENTIALS_PROVIDER_VALUE; +import static org.opensearch.sql.spark.data.constants.SparkConstants.EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_CREDENTIALS_PROVIDER_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_AUTH; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_CLUSTER_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_HOST; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_PORT; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_SCHEME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_HOST_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_PORT_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_SCHEME_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_QUERY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_REQUEST_INDEX; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_SESSION_ID; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_PPL_EXTENSION; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SQL_EXTENSION; +import static org.opensearch.sql.spark.data.constants.SparkConstants.GLUE_HIVE_CATALOG_FACTORY_CLASS; +import static org.opensearch.sql.spark.data.constants.SparkConstants.HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.HIVE_METASTORE_CLASS_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.ICEBERG_GLUE_CATALOG; +import static org.opensearch.sql.spark.data.constants.SparkConstants.ICEBERG_SESSION_CATALOG; +import static org.opensearch.sql.spark.data.constants.SparkConstants.ICEBERG_SPARK_EXTENSION; +import static org.opensearch.sql.spark.data.constants.SparkConstants.ICEBERG_SPARK_RUNTIME_PACKAGE; +import static org.opensearch.sql.spark.data.constants.SparkConstants.JAVA_HOME_LOCATION; +import static org.opensearch.sql.spark.data.constants.SparkConstants.PPL_STANDALONE_PACKAGE; +import static org.opensearch.sql.spark.data.constants.SparkConstants.S3_AWS_CREDENTIALS_PROVIDER_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_CATALOG; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_CATALOG_CATALOG_IMPL; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_DRIVER_ENV_FLINT_CLUSTER_NAME_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_DRIVER_ENV_JAVA_HOME_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_EXECUTOR_ENV_FLINT_CLUSTER_NAME_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_EXECUTOR_ENV_JAVA_HOME_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_JARS_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_JAR_PACKAGES_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_JAR_REPOSITORIES_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_LAUNCHER_PACKAGE; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_SQL_EXTENSIONS_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_STANDALONE_PACKAGE; + +import lombok.Getter; +import org.apache.commons.text.StringEscapeUtils; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.config.SparkSubmitParameterModifier; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; +import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil; + +public class SparkSubmitParametersBuilder { + private final SparkParameterComposerCollection sparkParameterComposerCollection; + @Getter private final SparkSubmitParameters sparkSubmitParameters; + + public SparkSubmitParametersBuilder( + SparkParameterComposerCollection sparkParameterComposerCollection) { + this.sparkParameterComposerCollection = sparkParameterComposerCollection; + sparkSubmitParameters = new SparkSubmitParameters(); + setDefaultConfigs(); + } + + private void setDefaultConfigs() { + setConfigItem(S3_AWS_CREDENTIALS_PROVIDER_KEY, DEFAULT_S3_AWS_CREDENTIALS_PROVIDER_VALUE); + setConfigItem( + HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY, + DEFAULT_GLUE_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY); + setConfigItem(SPARK_JARS_KEY, ICEBERG_SPARK_RUNTIME_PACKAGE); + setConfigItem( + SPARK_JAR_PACKAGES_KEY, + SPARK_STANDALONE_PACKAGE + "," + SPARK_LAUNCHER_PACKAGE + "," + PPL_STANDALONE_PACKAGE); + setConfigItem(SPARK_JAR_REPOSITORIES_KEY, AWS_SNAPSHOT_REPOSITORY); + setConfigItem(SPARK_DRIVER_ENV_JAVA_HOME_KEY, JAVA_HOME_LOCATION); + setConfigItem(SPARK_EXECUTOR_ENV_JAVA_HOME_KEY, JAVA_HOME_LOCATION); + setConfigItem(SPARK_DRIVER_ENV_FLINT_CLUSTER_NAME_KEY, FLINT_DEFAULT_CLUSTER_NAME); + setConfigItem(SPARK_EXECUTOR_ENV_FLINT_CLUSTER_NAME_KEY, FLINT_DEFAULT_CLUSTER_NAME); + setConfigItem(FLINT_INDEX_STORE_HOST_KEY, FLINT_DEFAULT_HOST); + setConfigItem(FLINT_INDEX_STORE_PORT_KEY, FLINT_DEFAULT_PORT); + setConfigItem(FLINT_INDEX_STORE_SCHEME_KEY, FLINT_DEFAULT_SCHEME); + setConfigItem(FLINT_INDEX_STORE_AUTH_KEY, FLINT_DEFAULT_AUTH); + setConfigItem(FLINT_CREDENTIALS_PROVIDER_KEY, EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER); + setConfigItem( + SPARK_SQL_EXTENSIONS_KEY, + ICEBERG_SPARK_EXTENSION + "," + FLINT_SQL_EXTENSION + "," + FLINT_PPL_EXTENSION); + setConfigItem(HIVE_METASTORE_CLASS_KEY, GLUE_HIVE_CATALOG_FACTORY_CLASS); + setConfigItem(SPARK_CATALOG, ICEBERG_SESSION_CATALOG); + setConfigItem(SPARK_CATALOG_CATALOG_IMPL, ICEBERG_GLUE_CATALOG); + } + + private void setConfigItem(String key, String value) { + sparkSubmitParameters.setConfigItem(key, value); + } + + public SparkSubmitParametersBuilder className(String className) { + sparkSubmitParameters.setClassName(className); + return this; + } + + /** clusterName will be used for logging and metrics in Spark */ + public SparkSubmitParametersBuilder clusterName(String clusterName) { + setConfigItem(SPARK_DRIVER_ENV_FLINT_CLUSTER_NAME_KEY, clusterName); + setConfigItem(SPARK_EXECUTOR_ENV_FLINT_CLUSTER_NAME_KEY, clusterName); + return this; + } + + /** + * For query in spark submit parameters to be parsed correctly, escape the characters in the + * query, then wrap the query with double quotes. + */ + public SparkSubmitParametersBuilder query(String query) { + String escapedQuery = StringEscapeUtils.escapeJava(query); + String wrappedQuery = "\"" + escapedQuery + "\""; + setConfigItem(FLINT_JOB_QUERY, wrappedQuery); + return this; + } + + public SparkSubmitParametersBuilder dataSource( + DataSourceMetadata metadata, + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext context) { + if (sparkParameterComposerCollection.isComposerRegistered(metadata.getConnector())) { + sparkParameterComposerCollection.composeByDataSource( + metadata, sparkSubmitParameters, dispatchQueryRequest, context); + return this; + } else { + throw new UnsupportedOperationException( + String.format( + "UnSupported datasource type for async queries:: %s", metadata.getConnector())); + } + } + + public SparkSubmitParametersBuilder structuredStreaming(Boolean isStructuredStreaming) { + if (isStructuredStreaming) { + setConfigItem("spark.flint.job.type", "streaming"); + } + return this; + } + + public SparkSubmitParametersBuilder extraParameters(String params) { + sparkSubmitParameters.setExtraParameters(params); + return this; + } + + public SparkSubmitParametersBuilder sessionExecution(String sessionId, String datasourceName) { + setConfigItem(FLINT_JOB_REQUEST_INDEX, OpenSearchStateStoreUtil.getIndexName(datasourceName)); + setConfigItem(FLINT_JOB_SESSION_ID, sessionId); + return this; + } + + public SparkSubmitParametersBuilder acceptModifier(SparkSubmitParameterModifier modifier) { + modifier.modifyParameters(this); + return this; + } + + public SparkSubmitParametersBuilder acceptComposers( + DispatchQueryRequest dispatchQueryRequest, AsyncQueryRequestContext context) { + sparkParameterComposerCollection.compose(sparkSubmitParameters, dispatchQueryRequest, context); + return this; + } + + @Override + public String toString() { + return sparkSubmitParameters.toString(); + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilderProvider.java b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilderProvider.java new file mode 100644 index 0000000000..ccc9ffb680 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilderProvider.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.parameter; + +import lombok.RequiredArgsConstructor; + +/** Provide SparkSubmitParametersBuilder instance with SparkParameterComposerCollection injected */ +@RequiredArgsConstructor +public class SparkSubmitParametersBuilderProvider { + private final SparkParameterComposerCollection sparkParameterComposerCollection; + + public SparkSubmitParametersBuilder getSparkSubmitParametersBuilder() { + return new SparkSubmitParametersBuilder(sparkParameterComposerCollection); + } +} diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index 8325a10fbc..dbc51bb0ad 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -33,7 +33,6 @@ import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; import org.opensearch.sql.spark.config.SparkSubmitParameterModifier; @@ -115,9 +114,7 @@ void testCreateAsyncQuery() { @Test void testCreateAsyncQueryWithExtraSparkSubmitParameter() { SparkSubmitParameterModifier modifier = - (SparkSubmitParameters parameters) -> { - parameters.setExtraParameters("--conf spark.dynamicAllocation.enabled=false"); - }; + (builder) -> builder.extraParameters("--conf spark.dynamicAllocation.enabled=false"); when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any())) .thenReturn( SparkExecutionEngineConfig.builder() diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java deleted file mode 100644 index 10f12251b0..0000000000 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.asyncquery.model; - -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.opensearch.sql.spark.data.constants.SparkConstants.HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_JARS_KEY; - -import org.junit.jupiter.api.Test; - -public class SparkSubmitParametersTest { - - @Test - public void testBuildWithoutExtraParameters() { - String params = SparkSubmitParameters.builder().build().toString(); - - assertNotNull(params); - } - - @Test - public void testBuildWithExtraParameters() { - String params = - SparkSubmitParameters.builder().extraParameters("--conf A=1").build().toString(); - - // Assert the conf is included with a space - assertTrue(params.endsWith(" --conf A=1")); - } - - @Test - public void testBuildQueryString() { - String rawQuery = "SHOW tables LIKE \"%\";"; - String expectedQueryInParams = "\"SHOW tables LIKE \\\"%\\\";\""; - String params = SparkSubmitParameters.builder().query(rawQuery).build().toString(); - assertTrue(params.contains(expectedQueryInParams)); - } - - @Test - public void testBuildQueryStringNestedQuote() { - String rawQuery = "SELECT '\"1\"'"; - String expectedQueryInParams = "\"SELECT '\\\"1\\\"'\""; - String params = SparkSubmitParameters.builder().query(rawQuery).build().toString(); - assertTrue(params.contains(expectedQueryInParams)); - } - - @Test - public void testBuildQueryStringSpecialCharacter() { - String rawQuery = "SELECT '{\"test ,:+\\\"inner\\\"/\\|?#><\"}'"; - String expectedQueryInParams = "SELECT '{\\\"test ,:+\\\\\\\"inner\\\\\\\"/\\\\|?#><\\\"}'"; - String params = SparkSubmitParameters.builder().query(rawQuery).build().toString(); - assertTrue(params.contains(expectedQueryInParams)); - } - - @Test - public void testOverrideConfigItem() { - SparkSubmitParameters params = SparkSubmitParameters.builder().build(); - params.setConfigItem(SPARK_JARS_KEY, "Overridden"); - String result = params.toString(); - - assertTrue(result.contains(String.format("%s=Overridden", SPARK_JARS_KEY))); - } - - @Test - public void testDeleteConfigItem() { - SparkSubmitParameters params = SparkSubmitParameters.builder().build(); - params.deleteConfigItem(HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY); - String result = params.toString(); - - assertFalse(result.contains(HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY)); - } - - @Test - public void testAddConfigItem() { - SparkSubmitParameters params = SparkSubmitParameters.builder().build(); - params.setConfigItem("AdditionalKey", "Value"); - String result = params.toString(); - - assertTrue(result.contains("AdditionalKey=Value")); - } -} diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java index e2473d0275..42d703f9ac 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java @@ -34,10 +34,12 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.ArgumentCaptor; import org.mockito.Captor; +import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.metrics.MetricsService; +import org.opensearch.sql.spark.parameter.SparkParameterComposerCollection; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilder; @ExtendWith(MockitoExtension.class) public class EmrServerlessClientImplTest { @@ -46,14 +48,17 @@ public class EmrServerlessClientImplTest { @Captor private ArgumentCaptor startJobRunRequestArgumentCaptor; + @InjectMocks EmrServerlessClientImpl emrServerlessClient; + @Test void testStartJobRun() { StartJobRunResult response = new StartJobRunResult(); when(emrServerless.startJobRun(any())).thenReturn(response); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); - String parameters = SparkSubmitParameters.builder().query(QUERY).build().toString(); + String parameters = + new SparkSubmitParametersBuilder(new SparkParameterComposerCollection()) + .query(QUERY) + .toString(); emrServerlessClient.startJobRun( new StartJobRequest( @@ -87,8 +92,6 @@ void testStartJobRunWithErrorMetric() { doThrow(new AWSEMRServerlessException("Couldn't start job")) .when(emrServerless) .startJobRun(any()); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); RuntimeException runtimeException = Assertions.assertThrows( RuntimeException.class, @@ -111,8 +114,6 @@ void testStartJobRunResultIndex() { StartJobRunResult response = new StartJobRunResult(); when(emrServerless.startJobRun(any())).thenReturn(response); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); emrServerlessClient.startJobRun( new StartJobRequest( EMRS_JOB_NAME, @@ -132,16 +133,12 @@ void testGetJobRunState() { GetJobRunResult response = new GetJobRunResult(); response.setJobRun(jobRun); when(emrServerless.getJobRun(any())).thenReturn(response); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, "123"); } @Test void testGetJobRunStateWithErrorMetric() { doThrow(new ValidationException("Not a good job")).when(emrServerless).getJobRun(any()); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); RuntimeException runtimeException = Assertions.assertThrows( RuntimeException.class, @@ -153,18 +150,17 @@ void testGetJobRunStateWithErrorMetric() { void testCancelJobRun() { when(emrServerless.cancelJobRun(any())) .thenReturn(new CancelJobRunResult().withJobRunId(EMR_JOB_ID)); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); + CancelJobRunResult cancelJobRunResult = emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false); + Assertions.assertEquals(EMR_JOB_ID, cancelJobRunResult.getJobRunId()); } @Test void testCancelJobRunWithErrorMetric() { doThrow(new RuntimeException()).when(emrServerless).cancelJobRun(any()); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); + Assertions.assertThrows( RuntimeException.class, () -> emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, "123", false)); @@ -173,24 +169,24 @@ void testCancelJobRunWithErrorMetric() { @Test void testCancelJobRunWithValidationException() { doThrow(new ValidationException("Error")).when(emrServerless).cancelJobRun(any()); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); + RuntimeException runtimeException = Assertions.assertThrows( RuntimeException.class, () -> emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false)); + Assertions.assertEquals("Internal Server Error.", runtimeException.getMessage()); } @Test void testCancelJobRunWithNativeEMRExceptionWithValidationException() { doThrow(new ValidationException("Error")).when(emrServerless).cancelJobRun(any()); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); + ValidationException validationException = Assertions.assertThrows( ValidationException.class, () -> emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, true)); + Assertions.assertTrue(validationException.getMessage().contains("Error")); } @@ -198,10 +194,10 @@ void testCancelJobRunWithNativeEMRExceptionWithValidationException() { void testCancelJobRunWithNativeEMRException() { when(emrServerless.cancelJobRun(any())) .thenReturn(new CancelJobRunResult().withJobRunId(EMR_JOB_ID)); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); + CancelJobRunResult cancelJobRunResult = emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, true); + Assertions.assertEquals(EMR_JOB_ID, cancelJobRunResult.getJobRunId()); } @@ -210,8 +206,6 @@ void testStartJobRunWithLongJobName() { StartJobRunResult response = new StartJobRunResult(); when(emrServerless.startJobRun(any())).thenReturn(response); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); emrServerlessClient.startJobRun( new StartJobRequest( RandomStringUtils.random(300), @@ -222,6 +216,7 @@ void testStartJobRunWithLongJobName() { new HashMap<>(), false, DEFAULT_RESULT_INDEX)); + verify(emrServerless, times(1)).startJobRun(startJobRunRequestArgumentCaptor.capture()); StartJobRunRequest startJobRunRequest = startJobRunRequestArgumentCaptor.getValue(); Assertions.assertEquals(255, startJobRunRequest.getName().length()); @@ -230,8 +225,6 @@ void testStartJobRunWithLongJobName() { @Test void testStartJobRunThrowsValidationException() { when(emrServerless.startJobRun(any())).thenThrow(new ValidationException("Unmatched quote")); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); IllegalArgumentException exception = Assertions.assertThrows( diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index b9c95f66cc..5582de332c 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -27,9 +27,10 @@ import static org.opensearch.sql.spark.constants.TestConstants.TEST_CLUSTER_NAME; import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD; import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_PASSWORD; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_USERNAME; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AWSREGION_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_HOST_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_PORT_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_SCHEME_KEY; import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; import static org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher.CLUSTER_NAME_TAG_KEY; import static org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher.DATASOURCE_TAG_KEY; @@ -45,6 +46,7 @@ import java.util.HashSet; import java.util.Map; import java.util.Optional; +import java.util.stream.Collectors; import org.json.JSONObject; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; @@ -76,6 +78,10 @@ import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.leasemanager.LeaseManager; import org.opensearch.sql.spark.metrics.MetricsService; +import org.opensearch.sql.spark.parameter.DataSourceSparkParameterComposer; +import org.opensearch.sql.spark.parameter.GeneralSparkParameterComposer; +import org.opensearch.sql.spark.parameter.SparkParameterComposerCollection; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.LangType; @@ -83,6 +89,10 @@ public class SparkQueryDispatcherTest { public static final String MY_GLUE = "my_glue"; + public static final String KEY_FROM_COMPOSER = "key.from.composer"; + public static final String VALUE_FROM_COMPOSER = "value.from.composer"; + public static final String KEY_FROM_DATASOURCE_COMPOSER = "key.from.datasource.composer"; + public static final String VALUE_FROM_DATASOURCE_COMPOSER = "value.from.datasource.composer"; @Mock private EMRServerlessClient emrServerlessClient; @Mock private EMRServerlessClientFactory emrServerlessClientFactory; @Mock private DataSourceService dataSourceService; @@ -96,6 +106,22 @@ public class SparkQueryDispatcherTest { @Mock private QueryIdProvider queryIdProvider; @Mock private AsyncQueryRequestContext asyncQueryRequestContext; @Mock private MetricsService metricsService; + private DataSourceSparkParameterComposer dataSourceSparkParameterComposer = + (datasourceMetadata, sparkSubmitParameters, dispatchQueryRequest, context) -> { + sparkSubmitParameters.setConfigItem(FLINT_INDEX_STORE_AUTH_KEY, "basic"); + sparkSubmitParameters.setConfigItem(FLINT_INDEX_STORE_HOST_KEY, "HOST"); + sparkSubmitParameters.setConfigItem(FLINT_INDEX_STORE_PORT_KEY, "PORT"); + sparkSubmitParameters.setConfigItem(FLINT_INDEX_STORE_SCHEME_KEY, "SCHEMA"); + sparkSubmitParameters.setConfigItem( + KEY_FROM_DATASOURCE_COMPOSER, VALUE_FROM_DATASOURCE_COMPOSER); + }; + + private GeneralSparkParameterComposer generalSparkParameterComposer = + (sparkSubmitParameters, dispatchQueryRequest, context) -> { + sparkSubmitParameters.setConfigItem(KEY_FROM_COMPOSER, VALUE_FROM_COMPOSER); + }; + + private SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider; @Mock(answer = RETURNS_DEEP_STUBS) private Session session; @@ -111,6 +137,10 @@ public class SparkQueryDispatcherTest { @BeforeEach void setUp() { + SparkParameterComposerCollection collection = new SparkParameterComposerCollection(); + collection.register(DataSourceType.S3GLUE, dataSourceSparkParameterComposer); + collection.register(generalSparkParameterComposer); + sparkSubmitParametersBuilderProvider = new SparkSubmitParametersBuilderProvider(collection); QueryHandlerFactory queryHandlerFactory = new QueryHandlerFactory( jobExecutionResponseReader, @@ -120,7 +150,8 @@ void setUp() { indexDMLResultStorageService, flintIndexOpFactory, emrServerlessClientFactory, - metricsService); + metricsService, + sparkSubmitParametersBuilderProvider); sparkQueryDispatcher = new SparkQueryDispatcher( dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider); @@ -134,15 +165,7 @@ void testDispatchSelectQuery() { tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "select * from my_glue.default.http_logs"; - String sparkSubmitParameters = - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }, - query); + String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -185,16 +208,7 @@ void testDispatchSelectQueryWithLakeFormation() { tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "select * from my_glue.default.http_logs"; - String sparkSubmitParameters = - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }, - query, - true); + String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -226,16 +240,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "select * from my_glue.default.http_logs"; - String sparkSubmitParameters = - constructExpectedSparkSubmitParameterString( - "basic", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AUTH_USERNAME, "username"); - put(FLINT_INDEX_STORE_AUTH_PASSWORD, "password"); - } - }, - query); + String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -260,45 +265,6 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { verifyNoInteractions(flintIndexMetadataService); } - @Test - void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { - when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); - HashMap tags = new HashMap<>(); - tags.put(DATASOURCE_TAG_KEY, MY_GLUE); - tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); - tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); - String query = "select * from my_glue.default.http_logs"; - String sparkSubmitParameters = - constructExpectedSparkSubmitParameterString( - "noauth", - new HashMap<>() { - { - } - }, - query); - StartJobRequest expected = - new StartJobRequest( - "TEST_CLUSTER:batch", - null, - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - sparkSubmitParameters, - tags, - false, - "query_execution_result_my_glue"); - when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); - DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithNoAuth(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) - .thenReturn(dataSourceMetadata); - - DispatchQueryResponse dispatchQueryResponse = - sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext); - verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); - Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); - Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); - verifyNoInteractions(flintIndexMetadataService); - } - @Test void testDispatchSelectQueryCreateNewSession() { String query = "select * from my_glue.default.http_logs"; @@ -377,16 +343,7 @@ void testDispatchCreateAutoRefreshIndexQuery() { String query = "CREATE INDEX elb_and_requestUri ON my_glue.default.http_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = true)"; - String sparkSubmitParameters = - withStructuredStreaming( - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }, - query)); + String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming"); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index", @@ -421,15 +378,7 @@ void testDispatchCreateManualRefreshIndexQuery() { String query = "CREATE INDEX elb_and_requestUri ON my_glue.default.http_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = false)"; - String sparkSubmitParameters = - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }, - query); + String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -462,15 +411,7 @@ void testDispatchWithPPLQuery() { tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "source = my_glue.default.http_logs"; - String sparkSubmitParameters = - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }, - query); + String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -505,15 +446,7 @@ void testDispatchQueryWithoutATableAndDataSourceName() { tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "show tables"; - String sparkSubmitParameters = - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }, - query); + String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -549,16 +482,7 @@ void testDispatchIndexQueryWithoutADatasourceName() { String query = "CREATE INDEX elb_and_requestUri ON default.http_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = true)"; - String sparkSubmitParameters = - withStructuredStreaming( - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }, - query)); + String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming"); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index", @@ -594,16 +518,7 @@ void testDispatchMaterializedViewQuery() { String query = "CREATE MATERIALIZED VIEW mv_1 AS query=select * from my_glue.default.logs WITH" + " (auto_refresh = true)"; - String sparkSubmitParameters = - withStructuredStreaming( - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }, - query)); + String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming"); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:streaming:flint_mv_1", @@ -636,15 +551,7 @@ void testDispatchShowMVQuery() { tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "SHOW MATERIALIZED VIEW IN mys3.default"; - String sparkSubmitParameters = - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }, - query); + String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -677,15 +584,7 @@ void testRefreshIndexQuery() { tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "REFRESH SKIPPING INDEX ON my_glue.default.http_logs"; - String sparkSubmitParameters = - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }, - query); + String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -718,15 +617,7 @@ void testDispatchDescribeIndexQuery() { tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "DESCRIBE SKIPPING INDEX ON mys3.default.http_logs"; - String sparkSubmitParameters = - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }, - query); + String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -762,16 +653,7 @@ void testDispatchAlterToAutoRefreshIndexQuery() { String query = "ALTER INDEX elb_and_requestUri ON my_glue.default.http_logs WITH" + " (auto_refresh = true)"; - String sparkSubmitParameters = - withStructuredStreaming( - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }, - query)); + String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming"); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index", @@ -867,24 +749,6 @@ void testDispatchVacuumIndexQuery() { verify(queryHandlerFactory, times(1)).getIndexDMLHandler(); } - @Test - void testDispatchWithWrongURI() { - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) - .thenReturn(constructMyGlueDataSourceMetadataWithBadURISyntax()); - String query = "select * from my_glue.default.http_logs"; - - IllegalArgumentException illegalArgumentException = - Assertions.assertThrows( - IllegalArgumentException.class, - () -> - sparkQueryDispatcher.dispatch( - getBaseDispatchQueryRequest(query), asyncQueryRequestContext)); - - Assertions.assertEquals( - "Bad URI in indexstore configuration of the : my_glue datasoure.", - illegalArgumentException.getMessage()); - } - @Test void testDispatchWithUnSupportedDataSourceType() { when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_prometheus")) @@ -1111,71 +975,46 @@ void testDispatchQueryWithExtraSparkSubmitParameters() { } } - private String constructExpectedSparkSubmitParameterString( - String auth, Map authParams, String query) { - return constructExpectedSparkSubmitParameterString(auth, authParams, query, false); + private String constructExpectedSparkSubmitParameterString(String query) { + return constructExpectedSparkSubmitParameterString(query, null); } - private String constructExpectedSparkSubmitParameterString( - String auth, Map authParams, String query, boolean lakeFormationEnabled) { - StringBuilder authParamConfigBuilder = new StringBuilder(); - for (String key : authParams.keySet()) { - authParamConfigBuilder.append(" --conf "); - authParamConfigBuilder.append(key); - authParamConfigBuilder.append("="); - authParamConfigBuilder.append(authParams.get(key)); - } + private String constructExpectedSparkSubmitParameterString(String query, String jobType) { query = "\"" + query + "\""; - return " --class org.apache.spark.sql.FlintJob --conf" - + " spark.hadoop.fs.s3.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider" - + " --conf" - + " spark.hadoop.aws.catalog.credentials.provider.factory.class=com.amazonaws.glue.catalog.metastore.STSAssumeRoleSessionCredentialsProviderFactory" - + " --conf spark.jars=/usr/share/aws/iceberg/lib/iceberg-spark3-runtime.jar --conf" - + " spark.jars.packages=org.opensearch:opensearch-spark-standalone_2.12:0.3.0-SNAPSHOT,org.opensearch:opensearch-spark-sql-application_2.12:0.3.0-SNAPSHOT,org.opensearch:opensearch-spark-ppl_2.12:0.3.0-SNAPSHOT" - + " --conf" - + " spark.jars.repositories=https://aws.oss.sonatype.org/content/repositories/snapshots" - + " --conf" - + " spark.emr-serverless.driverEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64/" - + " --conf spark.executorEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64/" - + " --conf spark.emr-serverless.driverEnv.FLINT_CLUSTER_NAME=TEST_CLUSTER --conf" - + " spark.executorEnv.FLINT_CLUSTER_NAME=TEST_CLUSTER --conf" - + " spark.datasource.flint.host=search-flint-dp-benchmark-cf5crj5mj2kfzvgwdeynkxnefy.eu-west-1.es.amazonaws.com" - + " --conf spark.datasource.flint.port=-1 --conf" - + " spark.datasource.flint.scheme=https --conf spark.datasource.flint.auth=" - + auth - + " --conf" - + " spark.datasource.flint.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider" - + " --conf" - + " spark.sql.extensions=org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions,org.opensearch.flint.spark.FlintSparkExtensions,org.opensearch.flint.spark.FlintPPLSparkExtensions" - + " --conf" - + " spark.hadoop.hive.metastore.client.factory.class=com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory" - + " --conf spark.sql.catalog.spark_catalog=org.apache.iceberg.spark.SparkSessionCatalog " - + " --conf" - + " spark.sql.catalog.spark_catalog.catalog-impl=org.apache.iceberg.aws.glue.GlueCatalog " - + " --conf" - + " spark.emr-serverless.driverEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" - + " --conf" - + " spark.executorEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" - + " --conf" - + " spark.hive.metastore.glue.role.arn=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" - + " --conf spark.sql.catalog.my_glue=org.opensearch.sql.FlintDelegatingSessionCatalog " - + " --conf spark.flint.datasource.name=my_glue --conf" - + " spark.emr-serverless.lakeformation.enabled=" - + Boolean.toString(lakeFormationEnabled) - + " --conf spark.flint.optimizer.covering.enabled=" - + Boolean.toString(!lakeFormationEnabled) - + authParamConfigBuilder - + " --conf spark.flint.job.query=" - + query - + " "; + return " --class org.apache.spark.sql.FlintJob " + + getConfParam( + "spark.hadoop.fs.s3.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider", + "spark.hadoop.aws.catalog.credentials.provider.factory.class=com.amazonaws.glue.catalog.metastore.STSAssumeRoleSessionCredentialsProviderFactory", + "spark.jars=/usr/share/aws/iceberg/lib/iceberg-spark3-runtime.jar", + "spark.jars.packages=org.opensearch:opensearch-spark-standalone_2.12:0.3.0-SNAPSHOT,org.opensearch:opensearch-spark-sql-application_2.12:0.3.0-SNAPSHOT,org.opensearch:opensearch-spark-ppl_2.12:0.3.0-SNAPSHOT", + "spark.jars.repositories=https://aws.oss.sonatype.org/content/repositories/snapshots", + "spark.emr-serverless.driverEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64/", + "spark.executorEnv.JAVA_HOME=/usr/lib/jvm/java-17-amazon-corretto.x86_64/", + "spark.emr-serverless.driverEnv.FLINT_CLUSTER_NAME=TEST_CLUSTER", + "spark.executorEnv.FLINT_CLUSTER_NAME=TEST_CLUSTER", + "spark.datasource.flint.host=HOST", + "spark.datasource.flint.port=PORT", + "spark.datasource.flint.scheme=SCHEMA", + "spark.datasource.flint.auth=basic", + "spark.datasource.flint.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider", + "spark.sql.extensions=org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions,org.opensearch.flint.spark.FlintSparkExtensions,org.opensearch.flint.spark.FlintPPLSparkExtensions", + "spark.hadoop.hive.metastore.client.factory.class=com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory", + "spark.sql.catalog.spark_catalog=org.apache.iceberg.spark.SparkSessionCatalog", + "spark.sql.catalog.spark_catalog.catalog-impl=org.apache.iceberg.aws.glue.GlueCatalog") + + getConfParam("spark.flint.job.query=" + query) + + (jobType != null ? getConfParam("spark.flint.job.type=" + jobType) : "") + + getConfParam( + KEY_FROM_DATASOURCE_COMPOSER + "=" + VALUE_FROM_DATASOURCE_COMPOSER, + KEY_FROM_COMPOSER + "=" + VALUE_FROM_COMPOSER); } - private String withStructuredStreaming(String parameters) { - return parameters + " --conf spark.flint.job.type=streaming "; + private String getConfParam(String... params) { + return Arrays.stream(params) + .map(param -> String.format(" --conf %s ", param)) + .collect(Collectors.joining()); } private DataSourceMetadata constructMyGlueDataSourceMetadata() { - Map properties = new HashMap<>(); properties.put("glue.auth.type", "iam_role"); properties.put( @@ -1210,37 +1049,6 @@ private DataSourceMetadata constructMyGlueDataSourceMetadataWithBasicAuth() { .build(); } - private DataSourceMetadata constructMyGlueDataSourceMetadataWithNoAuth() { - Map properties = new HashMap<>(); - properties.put("glue.auth.type", "iam_role"); - properties.put( - "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole"); - properties.put( - "glue.indexstore.opensearch.uri", - "https://search-flint-dp-benchmark-cf5crj5mj2kfzvgwdeynkxnefy.eu-west-1.es.amazonaws.com"); - properties.put("glue.indexstore.opensearch.auth", "noauth"); - return new DataSourceMetadata.Builder() - .setName(MY_GLUE) - .setConnector(DataSourceType.S3GLUE) - .setProperties(properties) - .build(); - } - - private DataSourceMetadata constructMyGlueDataSourceMetadataWithBadURISyntax() { - Map properties = new HashMap<>(); - properties.put("glue.auth.type", "iam_role"); - properties.put( - "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole"); - properties.put("glue.indexstore.opensearch.uri", "http://localhost:9090? param"); - properties.put("glue.indexstore.opensearch.auth", "awssigv4"); - properties.put("glue.indexstore.opensearch.region", "eu-west-1"); - return new DataSourceMetadata.Builder() - .setName(MY_GLUE) - .setConnector(DataSourceType.S3GLUE) - .setProperties(properties) - .build(); - } - private DataSourceMetadata constructMyGlueDataSourceMetadataWithLakeFormation() { Map properties = new HashMap<>(); @@ -1287,8 +1095,7 @@ private DispatchQueryRequest constructDispatchQueryRequest( String query, LangType langType, String extraParameters) { return getBaseDispatchQueryRequestBuilder(query) .langType(langType) - .sparkSubmitParameterModifier( - (parameters) -> parameters.setExtraParameters(extraParameters)) + .sparkSubmitParameterModifier((builder) -> builder.extraParameters(extraParameters)) .build(); } diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/parameter/SparkParameterComposerCollectionTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/parameter/SparkParameterComposerCollectionTest.java new file mode 100644 index 0000000000..c0c97caa58 --- /dev/null +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/parameter/SparkParameterComposerCollectionTest.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.parameter; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; + +@ExtendWith(MockitoExtension.class) +class SparkParameterComposerCollectionTest { + + @Mock DataSourceSparkParameterComposer composer1; + @Mock DataSourceSparkParameterComposer composer2; + @Mock DataSourceSparkParameterComposer composer3; + @Mock GeneralSparkParameterComposer generalComposer; + @Mock DispatchQueryRequest dispatchQueryRequest; + @Mock AsyncQueryRequestContext asyncQueryRequestContext; + + DataSourceType type1 = new DataSourceType("TYPE1"); + DataSourceType type2 = new DataSourceType("TYPE2"); + DataSourceType type3 = new DataSourceType("TYPE3"); + + SparkParameterComposerCollection collection; + + @BeforeEach + void setUp() { + collection = new SparkParameterComposerCollection(); + collection.register(type1, composer1); + collection.register(type1, composer2); + collection.register(type2, composer3); + collection.register(generalComposer); + } + + @Test + void isComposerRegistered() { + assertTrue(collection.isComposerRegistered(type1)); + assertTrue(collection.isComposerRegistered(type2)); + assertFalse(collection.isComposerRegistered(type3)); + } + + @Test + void composeByDataSourceWithRegisteredType() { + DataSourceMetadata metadata = + new DataSourceMetadata.Builder().setConnector(type1).setName("name").build(); + SparkSubmitParameters sparkSubmitParameters = new SparkSubmitParameters(); + + collection.composeByDataSource( + metadata, sparkSubmitParameters, dispatchQueryRequest, asyncQueryRequestContext); + + verify(composer1) + .compose(metadata, sparkSubmitParameters, dispatchQueryRequest, asyncQueryRequestContext); + verify(composer2) + .compose(metadata, sparkSubmitParameters, dispatchQueryRequest, asyncQueryRequestContext); + verifyNoInteractions(composer3); + } + + @Test + void composeByDataSourceWithUnregisteredType() { + DataSourceMetadata metadata = + new DataSourceMetadata.Builder().setConnector(type3).setName("name").build(); + SparkSubmitParameters sparkSubmitParameters = new SparkSubmitParameters(); + + collection.composeByDataSource( + metadata, sparkSubmitParameters, dispatchQueryRequest, asyncQueryRequestContext); + + verifyNoInteractions(composer1, composer2, composer3); + } + + @Test + void compose() { + SparkSubmitParameters sparkSubmitParameters = new SparkSubmitParameters(); + + collection.compose(sparkSubmitParameters, dispatchQueryRequest, asyncQueryRequestContext); + + verify(generalComposer) + .compose(sparkSubmitParameters, dispatchQueryRequest, asyncQueryRequestContext); + verifyNoInteractions(composer1, composer2, composer3); + } +} diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilderTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilderTest.java new file mode 100644 index 0000000000..3f4bea02f2 --- /dev/null +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilderTest.java @@ -0,0 +1,201 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.parameter; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.data.constants.SparkConstants.HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_JARS_KEY; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.config.SparkSubmitParameterModifier; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; + +@ExtendWith(MockitoExtension.class) +public class SparkSubmitParametersBuilderTest { + + @Mock SparkParameterComposerCollection sparkParameterComposerCollection; + @Mock SparkSubmitParameterModifier sparkSubmitParameterModifier; + @Mock AsyncQueryRequestContext asyncQueryRequestContext; + @Mock DispatchQueryRequest dispatchQueryRequest; + + @InjectMocks SparkSubmitParametersBuilder sparkSubmitParametersBuilder; + + @Test + public void testBuildWithoutExtraParameters() { + String params = sparkSubmitParametersBuilder.toString(); + + assertNotNull(params); + } + + @Test + public void testBuildWithExtraParameters() { + String params = sparkSubmitParametersBuilder.extraParameters("--conf A=1").toString(); + + // Assert the conf is included with a space + assertTrue(params.endsWith(" --conf A=1")); + } + + @Test + public void testBuildQueryString() { + String rawQuery = "SHOW tables LIKE \"%\";"; + String expectedQueryInParams = "\"SHOW tables LIKE \\\"%\\\";\""; + String params = sparkSubmitParametersBuilder.query(rawQuery).toString(); + assertTrue(params.contains(expectedQueryInParams)); + } + + @Test + public void testBuildQueryStringNestedQuote() { + String rawQuery = "SELECT '\"1\"'"; + String expectedQueryInParams = "\"SELECT '\\\"1\\\"'\""; + String params = sparkSubmitParametersBuilder.query(rawQuery).toString(); + assertTrue(params.contains(expectedQueryInParams)); + } + + @Test + public void testBuildQueryStringSpecialCharacter() { + String rawQuery = "SELECT '{\"test ,:+\\\"inner\\\"/\\|?#><\"}'"; + String expectedQueryInParams = "SELECT '{\\\"test ,:+\\\\\\\"inner\\\\\\\"/\\\\|?#><\\\"}'"; + String params = sparkSubmitParametersBuilder.query(rawQuery).toString(); + assertTrue(params.contains(expectedQueryInParams)); + } + + @Test + public void testClassName() { + String params = sparkSubmitParametersBuilder.className("CLASS_NAME").toString(); + assertTrue(params.contains("--class CLASS_NAME")); + } + + @Test + public void testClusterName() { + String params = sparkSubmitParametersBuilder.clusterName("CLUSTER_NAME").toString(); + assertTrue(params.contains("spark.emr-serverless.driverEnv.FLINT_CLUSTER_NAME=CLUSTER_NAME")); + assertTrue(params.contains("spark.executorEnv.FLINT_CLUSTER_NAME=CLUSTER_NAME")); + } + + @Test + public void testOverrideConfigItem() { + SparkSubmitParameters params = sparkSubmitParametersBuilder.getSparkSubmitParameters(); + params.setConfigItem(SPARK_JARS_KEY, "Overridden"); + String result = params.toString(); + + assertTrue(result.contains(String.format("%s=Overridden", SPARK_JARS_KEY))); + } + + @Test + public void testDeleteConfigItem() { + SparkSubmitParameters params = sparkSubmitParametersBuilder.getSparkSubmitParameters(); + params.deleteConfigItem(HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY); + String result = params.toString(); + + assertFalse(result.contains(HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY)); + } + + @Test + public void testAddConfigItem() { + SparkSubmitParameters params = sparkSubmitParametersBuilder.getSparkSubmitParameters(); + params.setConfigItem("AdditionalKey", "Value"); + String result = params.toString(); + + assertTrue(result.contains("AdditionalKey=Value")); + } + + @Test + public void testStructuredStreaming() { + SparkSubmitParameters params = + sparkSubmitParametersBuilder.structuredStreaming(true).getSparkSubmitParameters(); + String result = params.toString(); + + assertTrue(result.contains("spark.flint.job.type=streaming")); + } + + @Test + public void testNonStructuredStreaming() { + SparkSubmitParameters params = + sparkSubmitParametersBuilder.structuredStreaming(false).getSparkSubmitParameters(); + String result = params.toString(); + + assertFalse(result.contains("spark.flint.job.type=streaming")); + } + + @Test + public void testSessionExecution() { + SparkSubmitParameters params = + sparkSubmitParametersBuilder + .sessionExecution("SESSION_ID", "DATASOURCE_NAME") + .getSparkSubmitParameters(); + String result = params.toString(); + + assertTrue( + result.contains("spark.flint.job.requestIndex=.query_execution_request_datasource_name")); + assertTrue(result.contains("spark.flint.job.sessionId=SESSION_ID")); + } + + @Test + public void testAcceptModifier() { + sparkSubmitParametersBuilder.acceptModifier(sparkSubmitParameterModifier); + + verify(sparkSubmitParameterModifier).modifyParameters(sparkSubmitParametersBuilder); + } + + @Test + public void testDataSource() { + when(sparkParameterComposerCollection.isComposerRegistered(DataSourceType.S3GLUE)) + .thenReturn(true); + + DataSourceMetadata metadata = + new DataSourceMetadata.Builder() + .setConnector(DataSourceType.S3GLUE) + .setName("name") + .build(); + SparkSubmitParameters params = + sparkSubmitParametersBuilder + .dataSource(metadata, dispatchQueryRequest, asyncQueryRequestContext) + .getSparkSubmitParameters(); + + verify(sparkParameterComposerCollection) + .composeByDataSource(metadata, params, dispatchQueryRequest, asyncQueryRequestContext); + } + + @Test + public void testUnsupportedDataSource() { + when(sparkParameterComposerCollection.isComposerRegistered(DataSourceType.S3GLUE)) + .thenReturn(false); + + DataSourceMetadata metadata = + new DataSourceMetadata.Builder() + .setConnector(DataSourceType.S3GLUE) + .setName("name") + .build(); + assertThrows( + UnsupportedOperationException.class, + () -> + sparkSubmitParametersBuilder.dataSource( + metadata, dispatchQueryRequest, asyncQueryRequestContext)); + } + + @Test + public void testAcceptComposers() { + SparkSubmitParameters params = + sparkSubmitParametersBuilder + .acceptComposers(dispatchQueryRequest, asyncQueryRequestContext) + .getSparkSubmitParameters(); + + verify(sparkParameterComposerCollection) + .compose(params, dispatchQueryRequest, asyncQueryRequestContext); + } +} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchExtraParameterComposer.java b/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchExtraParameterComposer.java new file mode 100644 index 0000000000..1925ada46e --- /dev/null +++ b/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchExtraParameterComposer.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.config; + +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; +import org.opensearch.sql.spark.parameter.GeneralSparkParameterComposer; +import org.opensearch.sql.spark.parameter.SparkSubmitParameters; + +/** Load extra parameters from settings and add to Spark submit parameters */ +@RequiredArgsConstructor +public class OpenSearchExtraParameterComposer implements GeneralSparkParameterComposer { + private final SparkExecutionEngineConfigClusterSettingLoader settingLoader; + + @Override + public void compose( + SparkSubmitParameters sparkSubmitParameters, + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext context) { + settingLoader + .load() + .ifPresent( + settings -> + sparkSubmitParameters.setExtraParameters(settings.getSparkSubmitParameters())); + } +} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchSparkSubmitParameterModifier.java b/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchSparkSubmitParameterModifier.java index a034e04095..117d161440 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchSparkSubmitParameterModifier.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchSparkSubmitParameterModifier.java @@ -6,7 +6,7 @@ package org.opensearch.sql.spark.config; import lombok.AllArgsConstructor; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilder; @AllArgsConstructor public class OpenSearchSparkSubmitParameterModifier implements SparkSubmitParameterModifier { @@ -14,7 +14,7 @@ public class OpenSearchSparkSubmitParameterModifier implements SparkSubmitParame private String extraParameters; @Override - public void modifyParameters(SparkSubmitParameters parameters) { - parameters.setExtraParameters(this.extraParameters); + public void modifyParameters(SparkSubmitParametersBuilder builder) { + builder.extraParameters(this.extraParameters); } } diff --git a/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSetting.java b/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSetting.java index 0347f5ffc1..5c1328bf91 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSetting.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSetting.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.config; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import lombok.Builder; import lombok.Data; import org.opensearch.sql.utils.SerializeUtils; @@ -14,6 +15,7 @@ * setting. */ @Data +@Builder @JsonIgnoreProperties(ignoreUnknown = true) public class SparkExecutionEngineConfigClusterSetting { // optional diff --git a/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingLoader.java b/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingLoader.java new file mode 100644 index 0000000000..561a4653a5 --- /dev/null +++ b/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingLoader.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.config; + +import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; + +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.apache.commons.lang3.StringUtils; +import org.opensearch.sql.common.setting.Settings; + +@RequiredArgsConstructor +public class SparkExecutionEngineConfigClusterSettingLoader { + private final Settings settings; + + public Optional load() { + String sparkExecutionEngineConfigSettingString = + this.settings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG); + if (!StringUtils.isBlank(sparkExecutionEngineConfigSettingString)) { + return Optional.of( + AccessController.doPrivileged( + (PrivilegedAction) + () -> + SparkExecutionEngineConfigClusterSetting.toSparkExecutionEngineConfig( + sparkExecutionEngineConfigSettingString))); + } else { + return Optional.empty(); + } + } +} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImpl.java b/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImpl.java index fe931a5b91..66ad964ad1 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImpl.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImpl.java @@ -6,12 +6,8 @@ package org.opensearch.sql.spark.config; import static org.opensearch.sql.common.setting.Settings.Key.CLUSTER_NAME; -import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; -import java.security.AccessController; -import java.security.PrivilegedAction; import lombok.AllArgsConstructor; -import org.apache.commons.lang3.StringUtils; import org.opensearch.cluster.ClusterName; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; @@ -19,7 +15,8 @@ @AllArgsConstructor public class SparkExecutionEngineConfigSupplierImpl implements SparkExecutionEngineConfigSupplier { - private Settings settings; + private final Settings settings; + private final SparkExecutionEngineConfigClusterSettingLoader settingLoader; @Override public SparkExecutionEngineConfig getSparkExecutionEngineConfig( @@ -30,23 +27,14 @@ public SparkExecutionEngineConfig getSparkExecutionEngineConfig( private SparkExecutionEngineConfig.SparkExecutionEngineConfigBuilder getBuilderFromSettingsIfAvailable() { - String sparkExecutionEngineConfigSettingString = - this.settings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG); - if (!StringUtils.isBlank(sparkExecutionEngineConfigSettingString)) { - SparkExecutionEngineConfigClusterSetting setting = - AccessController.doPrivileged( - (PrivilegedAction) - () -> - SparkExecutionEngineConfigClusterSetting.toSparkExecutionEngineConfig( - sparkExecutionEngineConfigSettingString)); - return SparkExecutionEngineConfig.builder() - .applicationId(setting.getApplicationId()) - .executionRoleARN(setting.getExecutionRoleARN()) - .sparkSubmitParameterModifier( - new OpenSearchSparkSubmitParameterModifier(setting.getSparkSubmitParameters())) - .region(setting.getRegion()); - } else { - return SparkExecutionEngineConfig.builder(); - } + return settingLoader + .load() + .map( + setting -> + SparkExecutionEngineConfig.builder() + .applicationId(setting.getApplicationId()) + .executionRoleARN(setting.getExecutionRoleARN()) + .region(setting.getRegion())) + .orElse(SparkExecutionEngineConfig.builder()); } } diff --git a/async-query/src/main/java/org/opensearch/sql/spark/parameter/S3GlueDataSourceSparkParameterComposer.java b/async-query/src/main/java/org/opensearch/sql/spark/parameter/S3GlueDataSourceSparkParameterComposer.java new file mode 100644 index 0000000000..26dbf3529a --- /dev/null +++ b/async-query/src/main/java/org/opensearch/sql/spark/parameter/S3GlueDataSourceSparkParameterComposer.java @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.parameter; + +import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_AUTH; +import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_AUTH_PASSWORD; +import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_AUTH_USERNAME; +import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_REGION; +import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_URI; +import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_LAKEFORMATION_ENABLED; +import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_ROLE_ARN; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DRIVER_ENV_ASSUME_ROLE_ARN_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.EMR_LAKEFORMATION_OPTION; +import static org.opensearch.sql.spark.data.constants.SparkConstants.EXECUTOR_ENV_ASSUME_ROLE_ARN_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_ACCELERATE_USING_COVERING_INDEX; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DATA_SOURCE_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DELEGATE_CATALOG; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_PASSWORD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_USERNAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AWSREGION_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_HOST_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_PORT_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_SCHEME_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.HIVE_METASTORE_GLUE_ARN_KEY; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.function.Supplier; +import org.apache.commons.lang3.BooleanUtils; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasources.auth.AuthenticationType; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; + +public class S3GlueDataSourceSparkParameterComposer implements DataSourceSparkParameterComposer { + public static final String FLINT_BASIC_AUTH = "basic"; + + @Override + public void compose( + DataSourceMetadata metadata, + SparkSubmitParameters params, + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext context) { + String roleArn = metadata.getProperties().get(GLUE_ROLE_ARN); + + params.setConfigItem(DRIVER_ENV_ASSUME_ROLE_ARN_KEY, roleArn); + params.setConfigItem(EXECUTOR_ENV_ASSUME_ROLE_ARN_KEY, roleArn); + params.setConfigItem(HIVE_METASTORE_GLUE_ARN_KEY, roleArn); + params.setConfigItem("spark.sql.catalog." + metadata.getName(), FLINT_DELEGATE_CATALOG); + params.setConfigItem(FLINT_DATA_SOURCE_KEY, metadata.getName()); + + final boolean lakeFormationEnabled = + BooleanUtils.toBoolean(metadata.getProperties().get(GLUE_LAKEFORMATION_ENABLED)); + params.setConfigItem(EMR_LAKEFORMATION_OPTION, Boolean.toString(lakeFormationEnabled)); + params.setConfigItem( + FLINT_ACCELERATE_USING_COVERING_INDEX, Boolean.toString(!lakeFormationEnabled)); + + setFlintIndexStoreHost( + params, + parseUri( + metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_URI), metadata.getName())); + setFlintIndexStoreAuthProperties( + params, + metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_AUTH), + () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_AUTH_USERNAME), + () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_AUTH_PASSWORD), + () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_REGION)); + params.setConfigItem("spark.flint.datasource.name", metadata.getName()); + } + + private void setFlintIndexStoreHost(SparkSubmitParameters params, URI uri) { + params.setConfigItem(FLINT_INDEX_STORE_HOST_KEY, uri.getHost()); + params.setConfigItem(FLINT_INDEX_STORE_PORT_KEY, String.valueOf(uri.getPort())); + params.setConfigItem(FLINT_INDEX_STORE_SCHEME_KEY, uri.getScheme()); + } + + private void setFlintIndexStoreAuthProperties( + SparkSubmitParameters params, + String authType, + Supplier userName, + Supplier password, + Supplier region) { + if (AuthenticationType.get(authType).equals(AuthenticationType.BASICAUTH)) { + params.setConfigItem(FLINT_INDEX_STORE_AUTH_KEY, FLINT_BASIC_AUTH); + params.setConfigItem(FLINT_INDEX_STORE_AUTH_USERNAME, userName.get()); + params.setConfigItem(FLINT_INDEX_STORE_AUTH_PASSWORD, password.get()); + } else if (AuthenticationType.get(authType).equals(AuthenticationType.AWSSIGV4AUTH)) { + params.setConfigItem(FLINT_INDEX_STORE_AUTH_KEY, "sigv4"); + params.setConfigItem(FLINT_INDEX_STORE_AWSREGION_KEY, region.get()); + } else { + params.setConfigItem(FLINT_INDEX_STORE_AUTH_KEY, authType); + } + } + + private URI parseUri(String opensearchUri, String datasourceName) { + try { + return new URI(opensearchUri); + } catch (URISyntaxException e) { + throw new IllegalArgumentException( + String.format( + "Bad URI in indexstore configuration of the : %s datasoure.", datasourceName)); + } + } +} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java index d75b6616f7..05f7d1095c 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java @@ -15,6 +15,7 @@ import org.opensearch.common.inject.Singleton; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.legacy.metrics.GaugeMetric; import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; @@ -23,6 +24,8 @@ import org.opensearch.sql.spark.asyncquery.OpenSearchAsyncQueryJobMetadataStorageService; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.client.EMRServerlessClientFactoryImpl; +import org.opensearch.sql.spark.config.OpenSearchExtraParameterComposer; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfigClusterSettingLoader; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl; import org.opensearch.sql.spark.dispatcher.DatasourceEmbeddedQueryIdProvider; @@ -53,6 +56,9 @@ import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager; import org.opensearch.sql.spark.metrics.MetricsService; import org.opensearch.sql.spark.metrics.OpenSearchMetricsService; +import org.opensearch.sql.spark.parameter.S3GlueDataSourceSparkParameterComposer; +import org.opensearch.sql.spark.parameter.SparkParameterComposerCollection; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.response.OpenSearchJobExecutionResponseReader; @@ -111,7 +117,8 @@ public QueryHandlerFactory queryhandlerFactory( IndexDMLResultStorageService indexDMLResultStorageService, FlintIndexOpFactory flintIndexOpFactory, EMRServerlessClientFactory emrServerlessClientFactory, - MetricsService metricsService) { + MetricsService metricsService, + SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider) { return new QueryHandlerFactory( openSearchJobExecutionResponseReader, flintIndexMetadataReader, @@ -120,7 +127,8 @@ public QueryHandlerFactory queryhandlerFactory( indexDMLResultStorageService, flintIndexOpFactory, emrServerlessClientFactory, - metricsService); + metricsService, + sparkSubmitParametersBuilderProvider); } @Provides @@ -147,6 +155,15 @@ public FlintIndexStateModelService flintIndexStateModelService( return new OpenSearchFlintIndexStateModelService(stateStore, serializer); } + @Provides + public SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider( + Settings settings, SparkExecutionEngineConfigClusterSettingLoader clusterSettingLoader) { + SparkParameterComposerCollection collection = new SparkParameterComposerCollection(); + collection.register(DataSourceType.S3GLUE, new S3GlueDataSourceSparkParameterComposer()); + collection.register(new OpenSearchExtraParameterComposer(clusterSettingLoader)); + return new SparkSubmitParametersBuilderProvider(collection); + } + @Provides public IndexDMLResultStorageService indexDMLResultStorageService( DataSourceService dataSourceService, StateStore stateStore) { @@ -197,8 +214,15 @@ public MetricsService metricsService() { } @Provides - public SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier(Settings settings) { - return new SparkExecutionEngineConfigSupplierImpl(settings); + public SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier( + Settings settings, SparkExecutionEngineConfigClusterSettingLoader clusterSettingLoader) { + return new SparkExecutionEngineConfigSupplierImpl(settings, clusterSettingLoader); + } + + @Provides + public SparkExecutionEngineConfigClusterSettingLoader + sparkExecutionEngineConfigClusterSettingLoader(Settings settings) { + return new SparkExecutionEngineConfigClusterSettingLoader(settings); } @Provides diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index a12a5aeac8..ed00cb1022 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -88,6 +88,9 @@ import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager; import org.opensearch.sql.spark.metrics.OpenSearchMetricsService; +import org.opensearch.sql.spark.parameter.S3GlueDataSourceSparkParameterComposer; +import org.opensearch.sql.spark.parameter.SparkParameterComposerCollection; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.response.OpenSearchJobExecutionResponseReader; import org.opensearch.sql.storage.DataSourceFactory; @@ -253,6 +256,12 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = new OpenSearchAsyncQueryJobMetadataStorageService( stateStore, new AsyncQueryJobMetadataXContentSerializer()); + SparkParameterComposerCollection sparkParameterComposerCollection = + new SparkParameterComposerCollection(); + sparkParameterComposerCollection.register( + DataSourceType.S3GLUE, new S3GlueDataSourceSparkParameterComposer()); + SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider = + new SparkSubmitParametersBuilderProvider(sparkParameterComposerCollection); QueryHandlerFactory queryHandlerFactory = new QueryHandlerFactory( jobExecutionResponseReader, @@ -271,7 +280,8 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( new FlintIndexMetadataServiceImpl(client), emrServerlessClientFactory), emrServerlessClientFactory, - new OpenSearchMetricsService()); + new OpenSearchMetricsService(), + sparkSubmitParametersBuilderProvider); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( this.dataSourceService, diff --git a/async-query/src/test/java/org/opensearch/sql/spark/config/OpenSearchExtraParameterComposerTest.java b/async-query/src/test/java/org/opensearch/sql/spark/config/OpenSearchExtraParameterComposerTest.java new file mode 100644 index 0000000000..d3b0b2727a --- /dev/null +++ b/async-query/src/test/java/org/opensearch/sql/spark/config/OpenSearchExtraParameterComposerTest.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.config; + +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +import java.util.Optional; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; +import org.opensearch.sql.spark.parameter.SparkSubmitParameters; + +@ExtendWith(MockitoExtension.class) +class OpenSearchExtraParameterComposerTest { + + public static final String PARAMS = "PARAMS"; + @Mock SparkExecutionEngineConfigClusterSettingLoader settingsLoader; + @Mock SparkSubmitParameters sparkSubmitParameters; + @Mock DispatchQueryRequest dispatchQueryRequest; + @Mock AsyncQueryRequestContext context; + + @InjectMocks OpenSearchExtraParameterComposer openSearchExtraParameterComposer; + + @Test + public void paramExists_compose() { + SparkExecutionEngineConfigClusterSetting setting = + SparkExecutionEngineConfigClusterSetting.builder().sparkSubmitParameters(PARAMS).build(); + when(settingsLoader.load()).thenReturn(Optional.of(setting)); + + openSearchExtraParameterComposer.compose(sparkSubmitParameters, dispatchQueryRequest, context); + + verify(sparkSubmitParameters).setExtraParameters(PARAMS); + } + + @Test + public void paramNotExist_compose() { + when(settingsLoader.load()).thenReturn(Optional.empty()); + + openSearchExtraParameterComposer.compose(sparkSubmitParameters, dispatchQueryRequest, context); + + verifyNoInteractions(sparkSubmitParameters); + } +} diff --git a/async-query/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingLoaderTest.java b/async-query/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingLoaderTest.java new file mode 100644 index 0000000000..f9ccd93b00 --- /dev/null +++ b/async-query/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingLoaderTest.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.config; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; +import static org.opensearch.sql.spark.constants.TestConstants.ACCOUNT_ID; +import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; +import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; +import static org.opensearch.sql.spark.constants.TestConstants.SPARK_SUBMIT_PARAMETERS; +import static org.opensearch.sql.spark.constants.TestConstants.US_WEST_REGION; + +import java.util.Optional; +import org.json.JSONObject; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.common.setting.Settings; + +@ExtendWith(MockitoExtension.class) +class SparkExecutionEngineConfigClusterSettingLoaderTest { + @Mock Settings settings; + + @InjectMocks + SparkExecutionEngineConfigClusterSettingLoader sparkExecutionEngineConfigClusterSettingLoader; + + @Test + public void blankConfig() { + when(settings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG)).thenReturn(""); + + Optional result = + sparkExecutionEngineConfigClusterSettingLoader.load(); + + assertTrue(result.isEmpty()); + } + + @Test + public void validConfig() { + when(settings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG)).thenReturn(getConfigJson()); + + SparkExecutionEngineConfigClusterSetting result = + sparkExecutionEngineConfigClusterSettingLoader.load().get(); + + Assertions.assertEquals(ACCOUNT_ID, result.getAccountId()); + Assertions.assertEquals(EMRS_APPLICATION_ID, result.getApplicationId()); + Assertions.assertEquals(EMRS_EXECUTION_ROLE, result.getExecutionRoleARN()); + Assertions.assertEquals(US_WEST_REGION, result.getRegion()); + Assertions.assertEquals(SPARK_SUBMIT_PARAMETERS, result.getSparkSubmitParameters()); + } + + String getConfigJson() { + return new JSONObject() + .put("accountId", ACCOUNT_ID) + .put("applicationId", EMRS_APPLICATION_ID) + .put("executionRoleARN", EMRS_EXECUTION_ROLE) + .put("region", US_WEST_REGION) + .put("sparkSubmitParameters", SPARK_SUBMIT_PARAMETERS) + .toString(); + } +} diff --git a/async-query/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImplTest.java b/async-query/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImplTest.java index 128868a755..124d8d0b6e 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImplTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImplTest.java @@ -12,7 +12,7 @@ import static org.opensearch.sql.spark.constants.TestConstants.TEST_CLUSTER_NAME; import static org.opensearch.sql.spark.constants.TestConstants.US_WEST_REGION; -import org.json.JSONObject; +import java.util.Optional; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -21,7 +21,6 @@ import org.opensearch.cluster.ClusterName; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; @ExtendWith(MockitoExtension.class) public class SparkExecutionEngineConfigSupplierImplTest { @@ -29,43 +28,46 @@ public class SparkExecutionEngineConfigSupplierImplTest { @Mock private Settings settings; @Mock private AsyncQueryRequestContext asyncQueryRequestContext; + @Mock + private SparkExecutionEngineConfigClusterSettingLoader + sparkExecutionEngineConfigClusterSettingLoader; + @Test void testGetSparkExecutionEngineConfig() { SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier = - new SparkExecutionEngineConfigSupplierImpl(settings); - when(settings.getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG)) - .thenReturn(getConfigJson()); + new SparkExecutionEngineConfigSupplierImpl( + settings, sparkExecutionEngineConfigClusterSettingLoader); when(settings.getSettingValue(Settings.Key.CLUSTER_NAME)) .thenReturn(new ClusterName(TEST_CLUSTER_NAME)); + when(sparkExecutionEngineConfigClusterSettingLoader.load()) + .thenReturn(Optional.of(getClusterSetting())); SparkExecutionEngineConfig sparkExecutionEngineConfig = sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(asyncQueryRequestContext); - SparkSubmitParameters parameters = SparkSubmitParameters.builder().build(); - sparkExecutionEngineConfig.getSparkSubmitParameterModifier().modifyParameters(parameters); Assertions.assertEquals(EMRS_APPLICATION_ID, sparkExecutionEngineConfig.getApplicationId()); Assertions.assertEquals(EMRS_EXECUTION_ROLE, sparkExecutionEngineConfig.getExecutionRoleARN()); Assertions.assertEquals(US_WEST_REGION, sparkExecutionEngineConfig.getRegion()); Assertions.assertEquals(TEST_CLUSTER_NAME, sparkExecutionEngineConfig.getClusterName()); - Assertions.assertTrue(parameters.toString().contains(SPARK_SUBMIT_PARAMETERS)); } - String getConfigJson() { - return new JSONObject() - .put("applicationId", EMRS_APPLICATION_ID) - .put("executionRoleARN", EMRS_EXECUTION_ROLE) - .put("region", US_WEST_REGION) - .put("sparkSubmitParameters", SPARK_SUBMIT_PARAMETERS) - .toString(); + SparkExecutionEngineConfigClusterSetting getClusterSetting() { + return SparkExecutionEngineConfigClusterSetting.builder() + .applicationId(EMRS_APPLICATION_ID) + .executionRoleARN(EMRS_EXECUTION_ROLE) + .region(US_WEST_REGION) + .sparkSubmitParameters(SPARK_SUBMIT_PARAMETERS) + .build(); } @Test void testGetSparkExecutionEngineConfigWithNullSetting() { SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier = - new SparkExecutionEngineConfigSupplierImpl(settings); - when(settings.getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG)).thenReturn(null); + new SparkExecutionEngineConfigSupplierImpl( + settings, sparkExecutionEngineConfigClusterSettingLoader); when(settings.getSettingValue(Settings.Key.CLUSTER_NAME)) .thenReturn(new ClusterName(TEST_CLUSTER_NAME)); + when(sparkExecutionEngineConfigClusterSettingLoader.load()).thenReturn(Optional.empty()); SparkExecutionEngineConfig sparkExecutionEngineConfig = sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(asyncQueryRequestContext); diff --git a/async-query/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java b/async-query/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java index 5b4ffbea2c..15871bf6b2 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java @@ -8,6 +8,7 @@ public class TestConstants { public static final String TEST_DATASOURCE_NAME = "test_datasource_name"; public static final String EMR_JOB_ID = "job-123xxx"; + public static final String ACCOUNT_ID = "TEST_ACCOUNT_ID"; public static final String EMRS_APPLICATION_ID = "app-xxxxx"; public static final String EMRS_EXECUTION_ROLE = "execution_role"; public static final String SPARK_SUBMIT_PARAMETERS = "--conf org.flint.sql.SQLJob"; diff --git a/async-query/src/test/java/org/opensearch/sql/spark/execution/session/SessionTestUtil.java b/async-query/src/test/java/org/opensearch/sql/spark/execution/session/SessionTestUtil.java index 06689a15d0..e5ca93e96e 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/execution/session/SessionTestUtil.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/execution/session/SessionTestUtil.java @@ -9,7 +9,8 @@ import static org.opensearch.sql.spark.constants.TestConstants.TEST_DATASOURCE_NAME; import java.util.HashMap; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; +import org.opensearch.sql.spark.parameter.SparkParameterComposerCollection; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilder; public class SessionTestUtil { @@ -19,7 +20,7 @@ public static CreateSessionRequest createSessionRequest() { null, "appId", "arn", - SparkSubmitParameters.builder().build(), + new SparkSubmitParametersBuilder(new SparkParameterComposerCollection()), new HashMap<>(), "resultIndex", TEST_DATASOURCE_NAME); diff --git a/async-query/src/test/java/org/opensearch/sql/spark/parameter/S3GlueDataSourceSparkParameterComposerTest.java b/async-query/src/test/java/org/opensearch/sql/spark/parameter/S3GlueDataSourceSparkParameterComposerTest.java new file mode 100644 index 0000000000..55e62d52f0 --- /dev/null +++ b/async-query/src/test/java/org/opensearch/sql/spark/parameter/S3GlueDataSourceSparkParameterComposerTest.java @@ -0,0 +1,160 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.parameter; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.google.common.collect.ImmutableMap; +import java.util.Arrays; +import java.util.Map; +import java.util.stream.Collectors; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceStatus; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.datasources.auth.AuthenticationType; +import org.opensearch.sql.datasources.glue.GlueDataSourceFactory; +import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; + +@ExtendWith(MockitoExtension.class) +class S3GlueDataSourceSparkParameterComposerTest { + + public static final String VALID_URI = "https://test.host.com:9200"; + public static final String INVALID_URI = "http://test/\r\n"; + public static final String USERNAME = "USERNAME"; + public static final String PASSWORD = "PASSWORD"; + public static final String REGION = "REGION"; + public static final String TRUE = "true"; + public static final String ROLE_ARN = "ROLE_ARN"; + + private static final String COMMON_EXPECTED_PARAMS = + " --class org.apache.spark.sql.FlintJob " + + getConfList( + "spark.emr-serverless.driverEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN=ROLE_ARN", + "spark.executorEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN=ROLE_ARN", + "spark.hive.metastore.glue.role.arn=ROLE_ARN", + "spark.sql.catalog.DATASOURCE_NAME=org.opensearch.sql.FlintDelegatingSessionCatalog", + "spark.flint.datasource.name=DATASOURCE_NAME", + "spark.emr-serverless.lakeformation.enabled=true", + "spark.flint.optimizer.covering.enabled=false", + "spark.datasource.flint.host=test.host.com", + "spark.datasource.flint.port=9200", + "spark.datasource.flint.scheme=https"); + + @Mock DispatchQueryRequest dispatchQueryRequest; + + @Test + public void testBasicAuth() { + DataSourceMetadata dataSourceMetadata = + getDataSourceMetadata(AuthenticationType.BASICAUTH, VALID_URI); + SparkSubmitParameters sparkSubmitParameters = new SparkSubmitParameters(); + + new S3GlueDataSourceSparkParameterComposer() + .compose( + dataSourceMetadata, + sparkSubmitParameters, + dispatchQueryRequest, + new NullAsyncQueryRequestContext()); + + assertEquals( + COMMON_EXPECTED_PARAMS + + getConfList( + "spark.datasource.flint.auth=basic", + "spark.datasource.flint.auth.username=USERNAME", + "spark.datasource.flint.auth.password=PASSWORD"), + sparkSubmitParameters.toString()); + } + + @Test + public void testComposeWithSigV4Auth() { + DataSourceMetadata dataSourceMetadata = + getDataSourceMetadata(AuthenticationType.AWSSIGV4AUTH, VALID_URI); + SparkSubmitParameters sparkSubmitParameters = new SparkSubmitParameters(); + + new S3GlueDataSourceSparkParameterComposer() + .compose( + dataSourceMetadata, + sparkSubmitParameters, + dispatchQueryRequest, + new NullAsyncQueryRequestContext()); + + assertEquals( + COMMON_EXPECTED_PARAMS + + getConfList( + "spark.datasource.flint.auth=sigv4", "spark.datasource.flint.region=REGION"), + sparkSubmitParameters.toString()); + } + + @Test + public void testComposeWithNoAuth() { + DataSourceMetadata dataSourceMetadata = + getDataSourceMetadata(AuthenticationType.NOAUTH, VALID_URI); + SparkSubmitParameters sparkSubmitParameters = new SparkSubmitParameters(); + + new S3GlueDataSourceSparkParameterComposer() + .compose( + dataSourceMetadata, + sparkSubmitParameters, + dispatchQueryRequest, + new NullAsyncQueryRequestContext()); + + assertEquals( + COMMON_EXPECTED_PARAMS + getConfList("spark.datasource.flint.auth=noauth"), + sparkSubmitParameters.toString()); + } + + @Test + public void testComposeWithBadUri() { + DataSourceMetadata dataSourceMetadata = + getDataSourceMetadata(AuthenticationType.NOAUTH, INVALID_URI); + SparkSubmitParameters sparkSubmitParameters = new SparkSubmitParameters(); + + assertThrows( + IllegalArgumentException.class, + () -> + new S3GlueDataSourceSparkParameterComposer() + .compose( + dataSourceMetadata, + sparkSubmitParameters, + dispatchQueryRequest, + new NullAsyncQueryRequestContext())); + } + + private DataSourceMetadata getDataSourceMetadata( + AuthenticationType authenticationType, String uri) { + return new DataSourceMetadata.Builder() + .setConnector(DataSourceType.S3GLUE) + .setName("DATASOURCE_NAME") + .setDescription("DESCRIPTION") + .setResultIndex("RESULT_INDEX") + .setDataSourceStatus(DataSourceStatus.ACTIVE) + .setProperties(getProperties(authenticationType, uri)) + .build(); + } + + private Map getProperties(AuthenticationType authType, String uri) { + return ImmutableMap.builder() + .put(GlueDataSourceFactory.GLUE_ROLE_ARN, ROLE_ARN) + .put(GlueDataSourceFactory.GLUE_LAKEFORMATION_ENABLED, TRUE) + .put(GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_URI, uri) + .put(GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_AUTH, authType.getName()) + .put(GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_AUTH_USERNAME, USERNAME) + .put(GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_AUTH_PASSWORD, PASSWORD) + .put(GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_REGION, REGION) + .build(); + } + + private static String getConfList(String... params) { + return Arrays.stream(params) + .map(param -> String.format(" --conf %s ", param)) + .collect(Collectors.joining()); + } +} From eb49d98c693a2edd93e1a3eb1682c2c58e9faaca Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Fri, 28 Jun 2024 11:45:27 -0700 Subject: [PATCH 2/3] Fix comments Signed-off-by: Tomoyuki Morita --- .../DataSourceSparkParameterComposer.java | 7 ++++++- .../GeneralSparkParameterComposer.java | 16 +++++++++++++++- .../SparkParameterComposerCollection.java | 17 ++++++++++++++--- ...ecutionEngineConfigClusterSettingLoader.java | 1 + 4 files changed, 36 insertions(+), 5 deletions(-) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/DataSourceSparkParameterComposer.java b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/DataSourceSparkParameterComposer.java index 4411bfe22d..324889b6e0 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/DataSourceSparkParameterComposer.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/DataSourceSparkParameterComposer.java @@ -9,7 +9,12 @@ import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; -/** Compose Spark parameter based on DataSourceMetadata */ +/** + * Compose Spark parameters specific to the {@link + * org.opensearch.sql.datasource.model.DataSourceType} based on the {@link DataSourceMetadata}. For + * the parameters not specific to {@link org.opensearch.sql.datasource.model.DataSourceType}, please + * use {@link GeneralSparkParameterComposer}. + */ public interface DataSourceSparkParameterComposer { void compose( DataSourceMetadata dataSourceMetadata, diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/GeneralSparkParameterComposer.java b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/GeneralSparkParameterComposer.java index b5200d550c..c3d46ba5c6 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/GeneralSparkParameterComposer.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/GeneralSparkParameterComposer.java @@ -8,8 +8,22 @@ import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; -/** Compose spark submit parameters. See {@link SparkParameterComposerCollection}. */ +/** + * Compose spark submit parameters based on the request and context. For {@link + * org.opensearch.sql.datasource.model.DataSourceType} specific parameters, please use {@link + * DataSourceSparkParameterComposer}. See {@link SparkParameterComposerCollection}. + */ public interface GeneralSparkParameterComposer { + + /** + * Modify sparkSubmitParameters based on dispatchQueryRequest and context. + * + * @param sparkSubmitParameters Implementation of this method will modify this. + * @param dispatchQueryRequest Request. Implementation can refer it to compose + * sparkSubmitParameters. + * @param context Context of the request. Implementation can refer it to compose + * sparkSubmitParameters. + */ void compose( SparkSubmitParameters sparkSubmitParameters, DispatchQueryRequest dispatchQueryRequest, diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkParameterComposerCollection.java b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkParameterComposerCollection.java index 281759afd2..a6a88738bf 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkParameterComposerCollection.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkParameterComposerCollection.java @@ -16,12 +16,19 @@ import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; -/** Stores Spark parameter composers and dispatch compose request to each composer */ +/** + * Stores Spark parameter composers and dispatch compose request to each composer. Composers should + * be registered during initialization such as in Guice Module. + */ public class SparkParameterComposerCollection { - Collection generalComposers = new ArrayList<>(); - Map> datasourceComposers = + private Collection generalComposers = new ArrayList<>(); + private Map> datasourceComposers = new HashMap<>(); + /** + * Register composers for specific DataSourceType. The registered composer is called only if the + * request is for the dataSourceType. + */ public void register(DataSourceType dataSourceType, DataSourceSparkParameterComposer composer) { if (!datasourceComposers.containsKey(dataSourceType)) { datasourceComposers.put(dataSourceType, new LinkedList<>()); @@ -29,6 +36,10 @@ public void register(DataSourceType dataSourceType, DataSourceSparkParameterComp datasourceComposers.get(dataSourceType).add(composer); } + /** + * Register general composer. The composer is called when spark parameter is generated regardless + * of datasource type. + */ public void register(GeneralSparkParameterComposer composer) { generalComposers.add(composer); } diff --git a/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingLoader.java b/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingLoader.java index 561a4653a5..73b057ca5c 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingLoader.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingLoader.java @@ -14,6 +14,7 @@ import org.apache.commons.lang3.StringUtils; import org.opensearch.sql.common.setting.Settings; +/** Load SparkExecutionEngineConfigClusterSetting from settings with privilege check. */ @RequiredArgsConstructor public class SparkExecutionEngineConfigClusterSettingLoader { private final Settings settings; From 9f2162eac0ba281eb3e67c09ca43ad7c6de200c9 Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Fri, 28 Jun 2024 15:52:33 -0700 Subject: [PATCH 3/3] Fix integ test Signed-off-by: Tomoyuki Morita --- .../asyncquery/AsyncQueryCoreIntegTest.java | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index f2d205ffc6..99d4cc722e 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -7,9 +7,9 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.verify; @@ -77,6 +77,8 @@ import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.leasemanager.LeaseManager; import org.opensearch.sql.spark.metrics.MetricsService; +import org.opensearch.sql.spark.parameter.SparkParameterComposerCollection; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; @@ -110,6 +112,7 @@ public class AsyncQueryCoreIntegTest { @Mock FlintIndexClient flintIndexClient; @Mock AsyncQueryRequestContext asyncQueryRequestContext; @Mock MetricsService metricsService; + @Mock SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider; // storage services @Mock AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService; @@ -134,6 +137,16 @@ public class AsyncQueryCoreIntegTest { public void setUp() { emrServerlessClientFactory = (accountId) -> new EmrServerlessClientImpl(awsemrServerless, metricsService); + SparkParameterComposerCollection collection = new SparkParameterComposerCollection(); + collection.register( + DataSourceType.S3GLUE, + (dataSourceMetadata, sparkSubmitParameters, dispatchQueryRequest, context) -> + sparkSubmitParameters.setConfigItem( + "key.from.datasource.composer", "value.from.datasource.composer")); + collection.register( + (sparkSubmitParameters, dispatchQueryRequest, context) -> + sparkSubmitParameters.setConfigItem( + "key.from.generic.composer", "value.from.generic.composer")); SessionManager sessionManager = new SessionManager( sessionStorageService, @@ -156,7 +169,8 @@ public void setUp() { indexDMLResultStorageService, flintIndexOpFactory, emrServerlessClientFactory, - metricsService); + metricsService, + new SparkSubmitParametersBuilderProvider(collection)); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider); @@ -272,7 +286,11 @@ private void verifyStartJobRunCalled() { verify(awsemrServerless).startJobRun(startJobRunRequestArgumentCaptor.capture()); StartJobRunRequest startJobRunRequest = startJobRunRequestArgumentCaptor.getValue(); assertEquals(APPLICATION_ID, startJobRunRequest.getApplicationId()); - assertNotNull(startJobRunRequest.getJobDriver().getSparkSubmit().getSparkSubmitParameters()); + String submitParameters = + startJobRunRequest.getJobDriver().getSparkSubmit().getSparkSubmitParameters(); + assertTrue( + submitParameters.contains("key.from.datasource.composer=value.from.datasource.composer")); + assertTrue(submitParameters.contains("key.from.generic.composer=value.from.generic.composer")); } @Test