From fb6ded6c8d6fa91059b7c0ada74dfe8890f1d403 Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Fri, 14 Jun 2024 16:27:21 -0700 Subject: [PATCH] Introduce SparkParameterComposerCollection Signed-off-by: Tomoyuki Morita --- .../model/SparkSubmitParameters.java | 231 ------------------ .../config/SparkSubmitParameterModifier.java | 4 +- .../spark/dispatcher/BatchQueryHandler.java | 13 +- .../dispatcher/InteractiveQueryHandler.java | 15 +- .../spark/dispatcher/QueryHandlerFactory.java | 17 +- .../spark/dispatcher/RefreshQueryHandler.java | 11 +- .../dispatcher/StreamingQueryHandler.java | 20 +- .../session/CreateSessionRequest.java | 6 +- .../execution/session/InteractiveSession.java | 7 +- .../DataSourceSparkParameterComposer.java | 19 ++ .../GeneralSparkParameterComposer.java | 17 ++ .../SparkParameterComposerCollection.java | 65 +++++ .../parameter/SparkSubmitParameters.java | 52 ++++ .../SparkSubmitParametersBuilder.java | 171 +++++++++++++ .../SparkSubmitParametersBuilderProvider.java | 18 ++ .../AsyncQueryExecutorServiceImplTest.java | 5 +- .../model/SparkSubmitParametersTest.java | 84 ------- .../client/EmrServerlessClientImplTest.java | 45 ++-- .../dispatcher/SparkQueryDispatcherTest.java | 15 +- .../SparkParameterComposerCollectionTest.java | 80 ++++++ .../SparkSubmitParametersBuilderTest.java | 181 ++++++++++++++ .../OpenSearchExtraParameterComposer.java | 30 +++ ...penSearchSparkSubmitParameterModifier.java | 6 +- ...rkExecutionEngineConfigClusterSetting.java | 2 + ...utionEngineConfigClusterSettingLoader.java | 35 +++ ...parkExecutionEngineConfigSupplierImpl.java | 34 +-- ...3GlueDataSourceSparkParameterComposer.java | 108 ++++++++ .../config/AsyncExecutorServiceModule.java | 32 ++- .../AsyncQueryExecutorServiceSpec.java | 12 +- .../OpenSearchExtraParameterComposerTest.java | 52 ++++ ...nEngineConfigClusterSettingLoaderTest.java | 67 +++++ ...ExecutionEngineConfigSupplierImplTest.java | 36 +-- .../sql/spark/constants/TestConstants.java | 1 + .../execution/session/SessionTestUtil.java | 5 +- ...eDataSourceSparkParameterComposerTest.java | 160 ++++++++++++ 35 files changed, 1227 insertions(+), 429 deletions(-) create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/parameter/DataSourceSparkParameterComposer.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/parameter/GeneralSparkParameterComposer.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkParameterComposerCollection.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParameters.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilder.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilderProvider.java create mode 100644 async-query-core/src/test/java/org/opensearch/sql/spark/parameter/SparkParameterComposerCollectionTest.java create mode 100644 async-query-core/src/test/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilderTest.java create mode 100644 async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchExtraParameterComposer.java create mode 100644 async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingLoader.java create mode 100644 async-query/src/main/java/org/opensearch/sql/spark/parameter/S3GlueDataSourceSparkParameterComposer.java create mode 100644 async-query/src/test/java/org/opensearch/sql/spark/config/OpenSearchExtraParameterComposerTest.java create mode 100644 async-query/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingLoaderTest.java create mode 100644 async-query/src/test/java/org/opensearch/sql/spark/parameter/S3GlueDataSourceSparkParameterComposerTest.java diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java index 6badea6a74..e69de29bb2 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java @@ -1,231 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.asyncquery.model; - -import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_AUTH; -import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_AUTH_PASSWORD; -import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_AUTH_USERNAME; -import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_REGION; -import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_URI; -import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_LAKEFORMATION_ENABLED; -import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_ROLE_ARN; -import static org.opensearch.sql.spark.data.constants.SparkConstants.*; - -import java.net.URI; -import java.net.URISyntaxException; -import java.util.LinkedHashMap; -import java.util.Map; -import java.util.function.Supplier; -import lombok.AllArgsConstructor; -import lombok.RequiredArgsConstructor; -import lombok.Setter; -import org.apache.commons.lang3.BooleanUtils; -import org.apache.commons.text.StringEscapeUtils; -import org.opensearch.sql.datasource.model.DataSourceMetadata; -import org.opensearch.sql.datasource.model.DataSourceType; -import org.opensearch.sql.datasources.auth.AuthenticationType; -import org.opensearch.sql.spark.config.SparkSubmitParameterModifier; -import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil; - -/** Define Spark Submit Parameters. */ -@AllArgsConstructor -@RequiredArgsConstructor -public class SparkSubmitParameters { - public static final String SPACE = " "; - public static final String EQUALS = "="; - public static final String FLINT_BASIC_AUTH = "basic"; - - private final String className; - private final Map config; - - /** Extra parameters to append finally */ - @Setter private String extraParameters; - - public void setConfigItem(String key, String value) { - config.put(key, value); - } - - public void deleteConfigItem(String key) { - config.remove(key); - } - - public static Builder builder() { - return Builder.builder(); - } - - public SparkSubmitParameters acceptModifier(SparkSubmitParameterModifier modifier) { - modifier.modifyParameters(this); - return this; - } - - public static class Builder { - - private String className; - private final Map config; - private String extraParameters; - - private Builder() { - className = DEFAULT_CLASS_NAME; - config = new LinkedHashMap<>(); - - config.put(S3_AWS_CREDENTIALS_PROVIDER_KEY, DEFAULT_S3_AWS_CREDENTIALS_PROVIDER_VALUE); - config.put( - HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY, - DEFAULT_GLUE_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY); - config.put(SPARK_JARS_KEY, ICEBERG_SPARK_RUNTIME_PACKAGE); - config.put( - SPARK_JAR_PACKAGES_KEY, - SPARK_STANDALONE_PACKAGE + "," + SPARK_LAUNCHER_PACKAGE + "," + PPL_STANDALONE_PACKAGE); - config.put(SPARK_JAR_REPOSITORIES_KEY, AWS_SNAPSHOT_REPOSITORY); - config.put(SPARK_DRIVER_ENV_JAVA_HOME_KEY, JAVA_HOME_LOCATION); - config.put(SPARK_EXECUTOR_ENV_JAVA_HOME_KEY, JAVA_HOME_LOCATION); - config.put(SPARK_DRIVER_ENV_FLINT_CLUSTER_NAME_KEY, FLINT_DEFAULT_CLUSTER_NAME); - config.put(SPARK_EXECUTOR_ENV_FLINT_CLUSTER_NAME_KEY, FLINT_DEFAULT_CLUSTER_NAME); - config.put(FLINT_INDEX_STORE_HOST_KEY, FLINT_DEFAULT_HOST); - config.put(FLINT_INDEX_STORE_PORT_KEY, FLINT_DEFAULT_PORT); - config.put(FLINT_INDEX_STORE_SCHEME_KEY, FLINT_DEFAULT_SCHEME); - config.put(FLINT_INDEX_STORE_AUTH_KEY, FLINT_DEFAULT_AUTH); - config.put(FLINT_CREDENTIALS_PROVIDER_KEY, EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER); - config.put( - SPARK_SQL_EXTENSIONS_KEY, - ICEBERG_SPARK_EXTENSION + "," + FLINT_SQL_EXTENSION + "," + FLINT_PPL_EXTENSION); - config.put(HIVE_METASTORE_CLASS_KEY, GLUE_HIVE_CATALOG_FACTORY_CLASS); - config.put(SPARK_CATALOG, ICEBERG_SESSION_CATALOG); - config.put(SPARK_CATALOG_CATALOG_IMPL, ICEBERG_GLUE_CATALOG); - } - - public static Builder builder() { - return new Builder(); - } - - public Builder className(String className) { - this.className = className; - return this; - } - - public Builder clusterName(String clusterName) { - config.put(SPARK_DRIVER_ENV_FLINT_CLUSTER_NAME_KEY, clusterName); - config.put(SPARK_EXECUTOR_ENV_FLINT_CLUSTER_NAME_KEY, clusterName); - return this; - } - - /** - * For query in spark submit parameters to be parsed correctly, escape the characters in the - * query, then wrap the query with double quotes. - */ - public Builder query(String query) { - String escapedQuery = StringEscapeUtils.escapeJava(query); - String wrappedQuery = "\"" + escapedQuery + "\""; - config.put(FLINT_JOB_QUERY, wrappedQuery); - return this; - } - - public Builder dataSource(DataSourceMetadata metadata) { - if (DataSourceType.S3GLUE.equals(metadata.getConnector())) { - String roleArn = metadata.getProperties().get(GLUE_ROLE_ARN); - - config.put(DRIVER_ENV_ASSUME_ROLE_ARN_KEY, roleArn); - config.put(EXECUTOR_ENV_ASSUME_ROLE_ARN_KEY, roleArn); - config.put(HIVE_METASTORE_GLUE_ARN_KEY, roleArn); - config.put("spark.sql.catalog." + metadata.getName(), FLINT_DELEGATE_CATALOG); - config.put(FLINT_DATA_SOURCE_KEY, metadata.getName()); - - final boolean lakeFormationEnabled = - BooleanUtils.toBoolean(metadata.getProperties().get(GLUE_LAKEFORMATION_ENABLED)); - config.put(EMR_LAKEFORMATION_OPTION, Boolean.toString(lakeFormationEnabled)); - config.put(FLINT_ACCELERATE_USING_COVERING_INDEX, Boolean.toString(!lakeFormationEnabled)); - - setFlintIndexStoreHost( - parseUri( - metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_URI), metadata.getName())); - setFlintIndexStoreAuthProperties( - metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_AUTH), - () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_AUTH_USERNAME), - () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_AUTH_PASSWORD), - () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_REGION)); - config.put("spark.flint.datasource.name", metadata.getName()); - return this; - } - throw new UnsupportedOperationException( - String.format( - "UnSupported datasource type for async queries:: %s", metadata.getConnector())); - } - - private void setFlintIndexStoreHost(URI uri) { - config.put(FLINT_INDEX_STORE_HOST_KEY, uri.getHost()); - config.put(FLINT_INDEX_STORE_PORT_KEY, String.valueOf(uri.getPort())); - config.put(FLINT_INDEX_STORE_SCHEME_KEY, uri.getScheme()); - } - - private void setFlintIndexStoreAuthProperties( - String authType, - Supplier userName, - Supplier password, - Supplier region) { - if (AuthenticationType.get(authType).equals(AuthenticationType.BASICAUTH)) { - config.put(FLINT_INDEX_STORE_AUTH_KEY, FLINT_BASIC_AUTH); - config.put(FLINT_INDEX_STORE_AUTH_USERNAME, userName.get()); - config.put(FLINT_INDEX_STORE_AUTH_PASSWORD, password.get()); - } else if (AuthenticationType.get(authType).equals(AuthenticationType.AWSSIGV4AUTH)) { - config.put(FLINT_INDEX_STORE_AUTH_KEY, "sigv4"); - config.put(FLINT_INDEX_STORE_AWSREGION_KEY, region.get()); - } else { - config.put(FLINT_INDEX_STORE_AUTH_KEY, authType); - } - } - - private URI parseUri(String opensearchUri, String datasourceName) { - try { - return new URI(opensearchUri); - } catch (URISyntaxException e) { - throw new IllegalArgumentException( - String.format( - "Bad URI in indexstore configuration of the : %s datasoure.", datasourceName)); - } - } - - public Builder structuredStreaming(Boolean isStructuredStreaming) { - if (isStructuredStreaming) { - config.put("spark.flint.job.type", "streaming"); - } - return this; - } - - public Builder extraParameters(String params) { - extraParameters = params; - return this; - } - - public SparkSubmitParameters build() { - return new SparkSubmitParameters(className, config, extraParameters); - } - } - - public void sessionExecution(String sessionId, String datasourceName) { - config.put(FLINT_JOB_REQUEST_INDEX, OpenSearchStateStoreUtil.getIndexName(datasourceName)); - config.put(FLINT_JOB_SESSION_ID, sessionId); - } - - @Override - public String toString() { - StringBuilder stringBuilder = new StringBuilder(); - stringBuilder.append(" --class "); - stringBuilder.append(this.className); - stringBuilder.append(SPACE); - for (String key : config.keySet()) { - stringBuilder.append(" --conf "); - stringBuilder.append(key); - stringBuilder.append(EQUALS); - stringBuilder.append(config.get(key)); - stringBuilder.append(SPACE); - } - - if (extraParameters != null) { - stringBuilder.append(extraParameters); - } - return stringBuilder.toString(); - } -} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/config/SparkSubmitParameterModifier.java b/async-query-core/src/main/java/org/opensearch/sql/spark/config/SparkSubmitParameterModifier.java index 1c6ce5952a..8c184f2f69 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/config/SparkSubmitParameterModifier.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/config/SparkSubmitParameterModifier.java @@ -1,11 +1,11 @@ package org.opensearch.sql.spark.config; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilder; /** * Interface for extension point to allow modification of spark submit parameter. modifyParameter * method is called after the default spark submit parameter is build. */ public interface SparkSubmitParameterModifier { - void modifyParameters(SparkSubmitParameters parameters); + void modifyParameters(SparkSubmitParametersBuilder parametersBuilder); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java index 8014cf935f..2654f83aad 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java @@ -16,7 +16,6 @@ import org.json.JSONObject; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; @@ -25,6 +24,7 @@ import org.opensearch.sql.spark.dispatcher.model.JobType; import org.opensearch.sql.spark.leasemanager.LeaseManager; import org.opensearch.sql.spark.metrics.MetricsService; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; /** @@ -37,6 +37,7 @@ public class BatchQueryHandler extends AsyncQueryHandler { protected final JobExecutionResponseReader jobExecutionResponseReader; protected final LeaseManager leaseManager; protected final MetricsService metricsService; + protected final SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider; @Override protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) { @@ -80,12 +81,16 @@ public DispatchQueryResponse submit( dispatchQueryRequest.getAccountId(), dispatchQueryRequest.getApplicationId(), dispatchQueryRequest.getExecutionRoleARN(), - SparkSubmitParameters.builder() + sparkSubmitParametersBuilderProvider + .getSparkSubmitParametersBuilder() .clusterName(clusterName) - .dataSource(context.getDataSourceMetadata()) .query(dispatchQueryRequest.getQuery()) - .build() + .dataSource( + context.getDataSourceMetadata(), + dispatchQueryRequest, + context.getAsyncQueryRequestContext()) .acceptModifier(dispatchQueryRequest.getSparkSubmitParameterModifier()) + .acceptComposers(dispatchQueryRequest, context.getAsyncQueryRequestContext()) .toString(), tags, false, diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java index 266d5db978..ec43bccf11 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java @@ -16,7 +16,6 @@ import org.json.JSONObject; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; @@ -32,6 +31,7 @@ import org.opensearch.sql.spark.leasemanager.model.LeaseRequest; import org.opensearch.sql.spark.metrics.EmrMetrics; import org.opensearch.sql.spark.metrics.MetricsService; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; /** @@ -46,6 +46,7 @@ public class InteractiveQueryHandler extends AsyncQueryHandler { private final JobExecutionResponseReader jobExecutionResponseReader; private final LeaseManager leaseManager; private final MetricsService metricsService; + protected final SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider; @Override protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) { @@ -112,12 +113,16 @@ public DispatchQueryResponse submit( dispatchQueryRequest.getAccountId(), dispatchQueryRequest.getApplicationId(), dispatchQueryRequest.getExecutionRoleARN(), - SparkSubmitParameters.builder() + sparkSubmitParametersBuilderProvider + .getSparkSubmitParametersBuilder() .className(FLINT_SESSION_CLASS_NAME) .clusterName(clusterName) - .dataSource(dataSourceMetadata) - .build() - .acceptModifier(dispatchQueryRequest.getSparkSubmitParameterModifier()), + .dataSource( + dataSourceMetadata, + dispatchQueryRequest, + context.getAsyncQueryRequestContext()) + .acceptModifier(dispatchQueryRequest.getSparkSubmitParameterModifier()) + .acceptComposers(dispatchQueryRequest, context.getAsyncQueryRequestContext()), tags, dataSourceMetadata.getResultIndex(), dataSourceMetadata.getName()), diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java index 9951edc5a9..603b5a6765 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java @@ -13,6 +13,7 @@ import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.leasemanager.LeaseManager; import org.opensearch.sql.spark.metrics.MetricsService; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; @RequiredArgsConstructor @@ -26,6 +27,7 @@ public class QueryHandlerFactory { private final FlintIndexOpFactory flintIndexOpFactory; private final EMRServerlessClientFactory emrServerlessClientFactory; private final MetricsService metricsService; + protected final SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider; public RefreshQueryHandler getRefreshQueryHandler() { return new RefreshQueryHandler( @@ -34,7 +36,8 @@ public RefreshQueryHandler getRefreshQueryHandler() { flintIndexMetadataService, leaseManager, flintIndexOpFactory, - metricsService); + metricsService, + sparkSubmitParametersBuilderProvider); } public StreamingQueryHandler getStreamingQueryHandler() { @@ -42,7 +45,8 @@ public StreamingQueryHandler getStreamingQueryHandler() { emrServerlessClientFactory.getClient(), jobExecutionResponseReader, leaseManager, - metricsService); + metricsService, + sparkSubmitParametersBuilderProvider); } public BatchQueryHandler getBatchQueryHandler() { @@ -50,12 +54,17 @@ public BatchQueryHandler getBatchQueryHandler() { emrServerlessClientFactory.getClient(), jobExecutionResponseReader, leaseManager, - metricsService); + metricsService, + sparkSubmitParametersBuilderProvider); } public InteractiveQueryHandler getInteractiveQueryHandler() { return new InteractiveQueryHandler( - sessionManager, jobExecutionResponseReader, leaseManager, metricsService); + sessionManager, + jobExecutionResponseReader, + leaseManager, + metricsService, + sparkSubmitParametersBuilderProvider); } public IndexDMLHandler getIndexDMLHandler() { diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java index 634dfa49f6..99984ecc46 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java @@ -20,6 +20,7 @@ import org.opensearch.sql.spark.leasemanager.LeaseManager; import org.opensearch.sql.spark.leasemanager.model.LeaseRequest; import org.opensearch.sql.spark.metrics.MetricsService; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; /** @@ -37,8 +38,14 @@ public RefreshQueryHandler( FlintIndexMetadataService flintIndexMetadataService, LeaseManager leaseManager, FlintIndexOpFactory flintIndexOpFactory, - MetricsService metricsService) { - super(emrServerlessClient, jobExecutionResponseReader, leaseManager, metricsService); + MetricsService metricsService, + SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider) { + super( + emrServerlessClient, + jobExecutionResponseReader, + leaseManager, + metricsService, + sparkSubmitParametersBuilderProvider); this.flintIndexMetadataService = flintIndexMetadataService; this.flintIndexOpFactory = flintIndexOpFactory; } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java index 7291637e5b..2fbf2466da 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java @@ -12,7 +12,6 @@ import java.util.Map; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; @@ -23,6 +22,7 @@ import org.opensearch.sql.spark.leasemanager.LeaseManager; import org.opensearch.sql.spark.leasemanager.model.LeaseRequest; import org.opensearch.sql.spark.metrics.MetricsService; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; /** @@ -35,8 +35,14 @@ public StreamingQueryHandler( EMRServerlessClient emrServerlessClient, JobExecutionResponseReader jobExecutionResponseReader, LeaseManager leaseManager, - MetricsService metricsService) { - super(emrServerlessClient, jobExecutionResponseReader, leaseManager, metricsService); + MetricsService metricsService, + SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider) { + super( + emrServerlessClient, + jobExecutionResponseReader, + leaseManager, + metricsService, + sparkSubmitParametersBuilderProvider); } @Override @@ -70,13 +76,15 @@ public DispatchQueryResponse submit( dispatchQueryRequest.getAccountId(), dispatchQueryRequest.getApplicationId(), dispatchQueryRequest.getExecutionRoleARN(), - SparkSubmitParameters.builder() + sparkSubmitParametersBuilderProvider + .getSparkSubmitParametersBuilder() .clusterName(clusterName) - .dataSource(dataSourceMetadata) .query(dispatchQueryRequest.getQuery()) .structuredStreaming(true) - .build() + .dataSource( + dataSourceMetadata, dispatchQueryRequest, context.getAsyncQueryRequestContext()) .acceptModifier(dispatchQueryRequest.getSparkSubmitParameterModifier()) + .acceptComposers(dispatchQueryRequest, context.getAsyncQueryRequestContext()) .toString(), tags, indexQueryDetails.getFlintIndexOptions().autoRefresh(), diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java b/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java index 4170f0c2d6..6398dd224f 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java @@ -7,9 +7,9 @@ import java.util.Map; import lombok.Data; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.dispatcher.model.JobType; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilder; @Data public class CreateSessionRequest { @@ -17,7 +17,7 @@ public class CreateSessionRequest { private final String accountId; private final String applicationId; private final String executionRoleArn; - private final SparkSubmitParameters sparkSubmitParameters; + private final SparkSubmitParametersBuilder sparkSubmitParametersBuilder; private final Map tags; private final String resultIndex; private final String datasourceName; @@ -28,7 +28,7 @@ public StartJobRequest getStartJobRequest(String sessionId) { accountId, applicationId, executionRoleArn, - sparkSubmitParameters.toString(), + sparkSubmitParametersBuilder.toString(), tags, resultIndex); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index 4a8d6a8f58..b836457e8e 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -56,11 +56,8 @@ public void open( try { // append session id; createSessionRequest - .getSparkSubmitParameters() - .acceptModifier( - (parameters) -> { - parameters.sessionExecution(sessionId, createSessionRequest.getDatasourceName()); - }); + .getSparkSubmitParametersBuilder() + .sessionExecution(sessionId, createSessionRequest.getDatasourceName()); createSessionRequest.getTags().put(SESSION_ID_TAG_KEY, sessionId); StartJobRequest startJobRequest = createSessionRequest.getStartJobRequest(sessionId); String jobID = serverlessClient.startJobRun(startJobRequest); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/DataSourceSparkParameterComposer.java b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/DataSourceSparkParameterComposer.java new file mode 100644 index 0000000000..4411bfe22d --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/DataSourceSparkParameterComposer.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.parameter; + +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; + +/** Compose Spark parameter based on DataSourceMetadata */ +public interface DataSourceSparkParameterComposer { + void compose( + DataSourceMetadata dataSourceMetadata, + SparkSubmitParameters sparkSubmitParameters, + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext context); +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/GeneralSparkParameterComposer.java b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/GeneralSparkParameterComposer.java new file mode 100644 index 0000000000..b5200d550c --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/GeneralSparkParameterComposer.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.parameter; + +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; + +/** Compose spark submit parameters. See {@link SparkParameterComposerCollection}. */ +public interface GeneralSparkParameterComposer { + void compose( + SparkSubmitParameters sparkSubmitParameters, + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext context); +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkParameterComposerCollection.java b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkParameterComposerCollection.java new file mode 100644 index 0000000000..7a7c00b655 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkParameterComposerCollection.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.parameter; + +import com.google.common.collect.ImmutableList; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.Map; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; + +/** Stores Spark parameter composers and dispatch compose request to each composer */ +public class SparkParameterComposerCollection { + Collection generalComposers = new ArrayList<>(); + Map> datasourceComposers = + new HashMap<>(); + + public void register(DataSourceType dataSourceType, DataSourceSparkParameterComposer composer) { + if (!datasourceComposers.containsKey(dataSourceType)) { + datasourceComposers.put(dataSourceType, new LinkedList<>()); + } + datasourceComposers.get(dataSourceType).add(composer); + } + + public void register(GeneralSparkParameterComposer composer) { + generalComposers.add(composer); + } + + /** Execute composers associated with the datasource type */ + public void composeByDataSource( + DataSourceMetadata dataSourceMetadata, + SparkSubmitParameters sparkSubmitParameters, + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext context) { + for (DataSourceSparkParameterComposer composer : + getComposersFor(dataSourceMetadata.getConnector())) { + composer.compose(dataSourceMetadata, sparkSubmitParameters, dispatchQueryRequest, context); + } + } + + /** Execute all the registered generic composers */ + public void compose( + SparkSubmitParameters sparkSubmitParameters, + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext context) { + for (GeneralSparkParameterComposer composer : generalComposers) { + composer.compose(sparkSubmitParameters, dispatchQueryRequest, context); + } + } + + private Collection getComposersFor(DataSourceType type) { + return datasourceComposers.getOrDefault(type, ImmutableList.of()); + } + + public boolean isComposerRegistered(DataSourceType type) { + return datasourceComposers.containsKey(type) && !datasourceComposers.get(type).isEmpty(); + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParameters.java b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParameters.java new file mode 100644 index 0000000000..2e142ed117 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParameters.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.parameter; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_CLASS_NAME; + +import java.util.LinkedHashMap; +import java.util.Map; +import lombok.Setter; + +/** Define Spark Submit Parameters. */ +public class SparkSubmitParameters { + public static final String SPACE = " "; + public static final String EQUALS = "="; + + @Setter private String className = DEFAULT_CLASS_NAME; + private Map config = new LinkedHashMap<>(); + + /** Extra parameters to append finally */ + @Setter private String extraParameters; + + public void setConfigItem(String key, String value) { + config.put(key, value); + } + + public void deleteConfigItem(String key) { + config.remove(key); + } + + @Override + public String toString() { + StringBuilder stringBuilder = new StringBuilder(); + stringBuilder.append(" --class "); + stringBuilder.append(this.className); + stringBuilder.append(SPACE); + for (String key : config.keySet()) { + stringBuilder.append(" --conf "); + stringBuilder.append(key); + stringBuilder.append(EQUALS); + stringBuilder.append(config.get(key)); + stringBuilder.append(SPACE); + } + + if (extraParameters != null) { + stringBuilder.append(extraParameters); + } + return stringBuilder.toString(); + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilder.java b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilder.java new file mode 100644 index 0000000000..01a665a485 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilder.java @@ -0,0 +1,171 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.parameter; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.AWS_SNAPSHOT_REPOSITORY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_GLUE_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DEFAULT_S3_AWS_CREDENTIALS_PROVIDER_VALUE; +import static org.opensearch.sql.spark.data.constants.SparkConstants.EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_CREDENTIALS_PROVIDER_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_AUTH; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_CLUSTER_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_HOST; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_PORT; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_SCHEME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_HOST_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_PORT_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_SCHEME_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_QUERY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_REQUEST_INDEX; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_SESSION_ID; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_PPL_EXTENSION; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SQL_EXTENSION; +import static org.opensearch.sql.spark.data.constants.SparkConstants.GLUE_HIVE_CATALOG_FACTORY_CLASS; +import static org.opensearch.sql.spark.data.constants.SparkConstants.HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.HIVE_METASTORE_CLASS_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.ICEBERG_GLUE_CATALOG; +import static org.opensearch.sql.spark.data.constants.SparkConstants.ICEBERG_SESSION_CATALOG; +import static org.opensearch.sql.spark.data.constants.SparkConstants.ICEBERG_SPARK_EXTENSION; +import static org.opensearch.sql.spark.data.constants.SparkConstants.ICEBERG_SPARK_RUNTIME_PACKAGE; +import static org.opensearch.sql.spark.data.constants.SparkConstants.JAVA_HOME_LOCATION; +import static org.opensearch.sql.spark.data.constants.SparkConstants.PPL_STANDALONE_PACKAGE; +import static org.opensearch.sql.spark.data.constants.SparkConstants.S3_AWS_CREDENTIALS_PROVIDER_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_CATALOG; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_CATALOG_CATALOG_IMPL; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_DRIVER_ENV_FLINT_CLUSTER_NAME_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_DRIVER_ENV_JAVA_HOME_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_EXECUTOR_ENV_FLINT_CLUSTER_NAME_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_EXECUTOR_ENV_JAVA_HOME_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_JARS_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_JAR_PACKAGES_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_JAR_REPOSITORIES_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_LAUNCHER_PACKAGE; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_SQL_EXTENSIONS_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_STANDALONE_PACKAGE; + +import lombok.Getter; +import org.apache.commons.text.StringEscapeUtils; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.config.SparkSubmitParameterModifier; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; +import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil; + +public class SparkSubmitParametersBuilder { + private final SparkParameterComposerCollection sparkParameterComposerCollection; + @Getter private final SparkSubmitParameters sparkSubmitParameters; + + public SparkSubmitParametersBuilder( + SparkParameterComposerCollection sparkParameterComposerCollection) { + this.sparkParameterComposerCollection = sparkParameterComposerCollection; + sparkSubmitParameters = new SparkSubmitParameters(); + setDefaultConfigs(); + } + + private void setDefaultConfigs() { + setConfigItem(S3_AWS_CREDENTIALS_PROVIDER_KEY, DEFAULT_S3_AWS_CREDENTIALS_PROVIDER_VALUE); + setConfigItem( + HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY, + DEFAULT_GLUE_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY); + setConfigItem(SPARK_JARS_KEY, ICEBERG_SPARK_RUNTIME_PACKAGE); + setConfigItem( + SPARK_JAR_PACKAGES_KEY, + SPARK_STANDALONE_PACKAGE + "," + SPARK_LAUNCHER_PACKAGE + "," + PPL_STANDALONE_PACKAGE); + setConfigItem(SPARK_JAR_REPOSITORIES_KEY, AWS_SNAPSHOT_REPOSITORY); + setConfigItem(SPARK_DRIVER_ENV_JAVA_HOME_KEY, JAVA_HOME_LOCATION); + setConfigItem(SPARK_EXECUTOR_ENV_JAVA_HOME_KEY, JAVA_HOME_LOCATION); + setConfigItem(SPARK_DRIVER_ENV_FLINT_CLUSTER_NAME_KEY, FLINT_DEFAULT_CLUSTER_NAME); + setConfigItem(SPARK_EXECUTOR_ENV_FLINT_CLUSTER_NAME_KEY, FLINT_DEFAULT_CLUSTER_NAME); + setConfigItem(FLINT_INDEX_STORE_HOST_KEY, FLINT_DEFAULT_HOST); + setConfigItem(FLINT_INDEX_STORE_PORT_KEY, FLINT_DEFAULT_PORT); + setConfigItem(FLINT_INDEX_STORE_SCHEME_KEY, FLINT_DEFAULT_SCHEME); + setConfigItem(FLINT_INDEX_STORE_AUTH_KEY, FLINT_DEFAULT_AUTH); + setConfigItem(FLINT_CREDENTIALS_PROVIDER_KEY, EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER); + setConfigItem( + SPARK_SQL_EXTENSIONS_KEY, + ICEBERG_SPARK_EXTENSION + "," + FLINT_SQL_EXTENSION + "," + FLINT_PPL_EXTENSION); + setConfigItem(HIVE_METASTORE_CLASS_KEY, GLUE_HIVE_CATALOG_FACTORY_CLASS); + setConfigItem(SPARK_CATALOG, ICEBERG_SESSION_CATALOG); + setConfigItem(SPARK_CATALOG_CATALOG_IMPL, ICEBERG_GLUE_CATALOG); + } + + private void setConfigItem(String key, String value) { + sparkSubmitParameters.setConfigItem(key, value); + } + + public SparkSubmitParametersBuilder className(String className) { + sparkSubmitParameters.setClassName(className); + return this; + } + + /** clusterName will be used for logging and metrics in Spark */ + public SparkSubmitParametersBuilder clusterName(String clusterName) { + setConfigItem(SPARK_DRIVER_ENV_FLINT_CLUSTER_NAME_KEY, clusterName); + setConfigItem(SPARK_EXECUTOR_ENV_FLINT_CLUSTER_NAME_KEY, clusterName); + return this; + } + + /** + * For query in spark submit parameters to be parsed correctly, escape the characters in the + * query, then wrap the query with double quotes. + */ + public SparkSubmitParametersBuilder query(String query) { + String escapedQuery = StringEscapeUtils.escapeJava(query); + String wrappedQuery = "\"" + escapedQuery + "\""; + setConfigItem(FLINT_JOB_QUERY, wrappedQuery); + return this; + } + + public SparkSubmitParametersBuilder dataSource( + DataSourceMetadata metadata, + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext context) { + if (sparkParameterComposerCollection.isComposerRegistered(metadata.getConnector())) { + sparkParameterComposerCollection.composeByDataSource( + metadata, sparkSubmitParameters, dispatchQueryRequest, context); + return this; + } else { + throw new UnsupportedOperationException( + String.format( + "UnSupported datasource type for async queries:: %s", metadata.getConnector())); + } + } + + public SparkSubmitParametersBuilder structuredStreaming(Boolean isStructuredStreaming) { + if (isStructuredStreaming) { + setConfigItem("spark.flint.job.type", "streaming"); + } + return this; + } + + public SparkSubmitParametersBuilder extraParameters(String params) { + sparkSubmitParameters.setExtraParameters(params); + return this; + } + + public SparkSubmitParametersBuilder sessionExecution(String sessionId, String datasourceName) { + setConfigItem(FLINT_JOB_REQUEST_INDEX, OpenSearchStateStoreUtil.getIndexName(datasourceName)); + setConfigItem(FLINT_JOB_SESSION_ID, sessionId); + return this; + } + + public SparkSubmitParametersBuilder acceptModifier(SparkSubmitParameterModifier modifier) { + modifier.modifyParameters(this); + return this; + } + + public SparkSubmitParametersBuilder acceptComposers( + DispatchQueryRequest dispatchQueryRequest, AsyncQueryRequestContext context) { + sparkParameterComposerCollection.compose(sparkSubmitParameters, dispatchQueryRequest, context); + return this; + } + + @Override + public String toString() { + return sparkSubmitParameters.toString(); + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilderProvider.java b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilderProvider.java new file mode 100644 index 0000000000..ccc9ffb680 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilderProvider.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.parameter; + +import lombok.RequiredArgsConstructor; + +/** Provide SparkSubmitParametersBuilder instance with SparkParameterComposerCollection injected */ +@RequiredArgsConstructor +public class SparkSubmitParametersBuilderProvider { + private final SparkParameterComposerCollection sparkParameterComposerCollection; + + public SparkSubmitParametersBuilder getSparkSubmitParametersBuilder() { + return new SparkSubmitParametersBuilder(sparkParameterComposerCollection); + } +} diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index 8325a10fbc..dbc51bb0ad 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -33,7 +33,6 @@ import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; import org.opensearch.sql.spark.config.SparkSubmitParameterModifier; @@ -115,9 +114,7 @@ void testCreateAsyncQuery() { @Test void testCreateAsyncQueryWithExtraSparkSubmitParameter() { SparkSubmitParameterModifier modifier = - (SparkSubmitParameters parameters) -> { - parameters.setExtraParameters("--conf spark.dynamicAllocation.enabled=false"); - }; + (builder) -> builder.extraParameters("--conf spark.dynamicAllocation.enabled=false"); when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any())) .thenReturn( SparkExecutionEngineConfig.builder() diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java index 10f12251b0..e69de29bb2 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java @@ -1,84 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.asyncquery.model; - -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.opensearch.sql.spark.data.constants.SparkConstants.HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY; -import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_JARS_KEY; - -import org.junit.jupiter.api.Test; - -public class SparkSubmitParametersTest { - - @Test - public void testBuildWithoutExtraParameters() { - String params = SparkSubmitParameters.builder().build().toString(); - - assertNotNull(params); - } - - @Test - public void testBuildWithExtraParameters() { - String params = - SparkSubmitParameters.builder().extraParameters("--conf A=1").build().toString(); - - // Assert the conf is included with a space - assertTrue(params.endsWith(" --conf A=1")); - } - - @Test - public void testBuildQueryString() { - String rawQuery = "SHOW tables LIKE \"%\";"; - String expectedQueryInParams = "\"SHOW tables LIKE \\\"%\\\";\""; - String params = SparkSubmitParameters.builder().query(rawQuery).build().toString(); - assertTrue(params.contains(expectedQueryInParams)); - } - - @Test - public void testBuildQueryStringNestedQuote() { - String rawQuery = "SELECT '\"1\"'"; - String expectedQueryInParams = "\"SELECT '\\\"1\\\"'\""; - String params = SparkSubmitParameters.builder().query(rawQuery).build().toString(); - assertTrue(params.contains(expectedQueryInParams)); - } - - @Test - public void testBuildQueryStringSpecialCharacter() { - String rawQuery = "SELECT '{\"test ,:+\\\"inner\\\"/\\|?#><\"}'"; - String expectedQueryInParams = "SELECT '{\\\"test ,:+\\\\\\\"inner\\\\\\\"/\\\\|?#><\\\"}'"; - String params = SparkSubmitParameters.builder().query(rawQuery).build().toString(); - assertTrue(params.contains(expectedQueryInParams)); - } - - @Test - public void testOverrideConfigItem() { - SparkSubmitParameters params = SparkSubmitParameters.builder().build(); - params.setConfigItem(SPARK_JARS_KEY, "Overridden"); - String result = params.toString(); - - assertTrue(result.contains(String.format("%s=Overridden", SPARK_JARS_KEY))); - } - - @Test - public void testDeleteConfigItem() { - SparkSubmitParameters params = SparkSubmitParameters.builder().build(); - params.deleteConfigItem(HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY); - String result = params.toString(); - - assertFalse(result.contains(HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY)); - } - - @Test - public void testAddConfigItem() { - SparkSubmitParameters params = SparkSubmitParameters.builder().build(); - params.setConfigItem("AdditionalKey", "Value"); - String result = params.toString(); - - assertTrue(result.contains("AdditionalKey=Value")); - } -} diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java index 35b42ccaaf..993d489ded 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java @@ -37,14 +37,16 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.ArgumentCaptor; import org.mockito.Captor; +import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.legacy.esdomain.LocalClusterState; import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.opensearch.setting.OpenSearchSettings; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.metrics.MetricsService; +import org.opensearch.sql.spark.parameter.SparkParameterComposerCollection; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilder; @ExtendWith(MockitoExtension.class) public class EmrServerlessClientImplTest { @@ -54,6 +56,8 @@ public class EmrServerlessClientImplTest { @Captor private ArgumentCaptor startJobRunRequestArgumentCaptor; + @InjectMocks EmrServerlessClientImpl emrServerlessClient; + @BeforeEach public void setUp() { doReturn(emptyList()).when(settings).getSettings(); @@ -68,9 +72,10 @@ void testStartJobRun() { StartJobRunResult response = new StartJobRunResult(); when(emrServerless.startJobRun(any())).thenReturn(response); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); - String parameters = SparkSubmitParameters.builder().query(QUERY).build().toString(); + String parameters = + new SparkSubmitParametersBuilder(new SparkParameterComposerCollection()) + .query(QUERY) + .toString(); emrServerlessClient.startJobRun( new StartJobRequest( @@ -104,8 +109,6 @@ void testStartJobRunWithErrorMetric() { doThrow(new AWSEMRServerlessException("Couldn't start job")) .when(emrServerless) .startJobRun(any()); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); RuntimeException runtimeException = Assertions.assertThrows( RuntimeException.class, @@ -128,8 +131,6 @@ void testStartJobRunResultIndex() { StartJobRunResult response = new StartJobRunResult(); when(emrServerless.startJobRun(any())).thenReturn(response); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); emrServerlessClient.startJobRun( new StartJobRequest( EMRS_JOB_NAME, @@ -149,16 +150,12 @@ void testGetJobRunState() { GetJobRunResult response = new GetJobRunResult(); response.setJobRun(jobRun); when(emrServerless.getJobRun(any())).thenReturn(response); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, "123"); } @Test void testGetJobRunStateWithErrorMetric() { doThrow(new ValidationException("Not a good job")).when(emrServerless).getJobRun(any()); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); RuntimeException runtimeException = Assertions.assertThrows( RuntimeException.class, @@ -170,18 +167,17 @@ void testGetJobRunStateWithErrorMetric() { void testCancelJobRun() { when(emrServerless.cancelJobRun(any())) .thenReturn(new CancelJobRunResult().withJobRunId(EMR_JOB_ID)); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); + CancelJobRunResult cancelJobRunResult = emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false); + Assertions.assertEquals(EMR_JOB_ID, cancelJobRunResult.getJobRunId()); } @Test void testCancelJobRunWithErrorMetric() { doThrow(new RuntimeException()).when(emrServerless).cancelJobRun(any()); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); + Assertions.assertThrows( RuntimeException.class, () -> emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, "123", false)); @@ -190,24 +186,24 @@ void testCancelJobRunWithErrorMetric() { @Test void testCancelJobRunWithValidationException() { doThrow(new ValidationException("Error")).when(emrServerless).cancelJobRun(any()); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); + RuntimeException runtimeException = Assertions.assertThrows( RuntimeException.class, () -> emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false)); + Assertions.assertEquals("Internal Server Error.", runtimeException.getMessage()); } @Test void testCancelJobRunWithNativeEMRExceptionWithValidationException() { doThrow(new ValidationException("Error")).when(emrServerless).cancelJobRun(any()); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); + ValidationException validationException = Assertions.assertThrows( ValidationException.class, () -> emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, true)); + Assertions.assertTrue(validationException.getMessage().contains("Error")); } @@ -215,10 +211,10 @@ void testCancelJobRunWithNativeEMRExceptionWithValidationException() { void testCancelJobRunWithNativeEMRException() { when(emrServerless.cancelJobRun(any())) .thenReturn(new CancelJobRunResult().withJobRunId(EMR_JOB_ID)); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); + CancelJobRunResult cancelJobRunResult = emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, true); + Assertions.assertEquals(EMR_JOB_ID, cancelJobRunResult.getJobRunId()); } @@ -227,8 +223,6 @@ void testStartJobRunWithLongJobName() { StartJobRunResult response = new StartJobRunResult(); when(emrServerless.startJobRun(any())).thenReturn(response); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); emrServerlessClient.startJobRun( new StartJobRequest( RandomStringUtils.random(300), @@ -239,6 +233,7 @@ void testStartJobRunWithLongJobName() { new HashMap<>(), false, DEFAULT_RESULT_INDEX)); + verify(emrServerless, times(1)).startJobRun(startJobRunRequestArgumentCaptor.capture()); StartJobRunRequest startJobRunRequest = startJobRunRequestArgumentCaptor.getValue(); Assertions.assertEquals(255, startJobRunRequest.getName().length()); @@ -247,8 +242,6 @@ void testStartJobRunWithLongJobName() { @Test void testStartJobRunThrowsValidationException() { when(emrServerless.startJobRun(any())).thenThrow(new ValidationException("Unmatched quote")); - EmrServerlessClientImpl emrServerlessClient = - new EmrServerlessClientImpl(emrServerless, metricsService); IllegalArgumentException exception = Assertions.assertThrows( diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index d57284b9ca..2158aef68a 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -76,6 +76,9 @@ import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.leasemanager.LeaseManager; import org.opensearch.sql.spark.metrics.MetricsService; +import org.opensearch.sql.spark.parameter.DataSourceSparkParameterComposer; +import org.opensearch.sql.spark.parameter.SparkParameterComposerCollection; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.LangType; @@ -95,8 +98,11 @@ public class SparkQueryDispatcherTest { @Mock private SparkSubmitParameterModifier sparkSubmitParameterModifier; @Mock private QueryIdProvider queryIdProvider; @Mock private AsyncQueryRequestContext asyncQueryRequestContext; + @Mock private DataSourceSparkParameterComposer dataSourceSparkParameterComposer; @Mock private MetricsService metricsService; + private SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider; + @Mock(answer = RETURNS_DEEP_STUBS) private Session session; @@ -111,6 +117,9 @@ public class SparkQueryDispatcherTest { @BeforeEach void setUp() { + SparkParameterComposerCollection collection = new SparkParameterComposerCollection(); + collection.register(DataSourceType.S3GLUE, dataSourceSparkParameterComposer); + sparkSubmitParametersBuilderProvider = new SparkSubmitParametersBuilderProvider(collection); QueryHandlerFactory queryHandlerFactory = new QueryHandlerFactory( jobExecutionResponseReader, @@ -120,7 +129,8 @@ void setUp() { indexDMLResultStorageService, flintIndexOpFactory, emrServerlessClientFactory, - metricsService); + metricsService, + sparkSubmitParametersBuilderProvider); sparkQueryDispatcher = new SparkQueryDispatcher( dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider); @@ -1287,8 +1297,7 @@ private DispatchQueryRequest constructDispatchQueryRequest( String query, LangType langType, String extraParameters) { return getBaseDispatchQueryRequestBuilder(query) .langType(langType) - .sparkSubmitParameterModifier( - (parameters) -> parameters.setExtraParameters(extraParameters)) + .sparkSubmitParameterModifier((builder) -> builder.extraParameters(extraParameters)) .build(); } diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/parameter/SparkParameterComposerCollectionTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/parameter/SparkParameterComposerCollectionTest.java new file mode 100644 index 0000000000..91aa521499 --- /dev/null +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/parameter/SparkParameterComposerCollectionTest.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.parameter; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; + +@ExtendWith(MockitoExtension.class) +class SparkParameterComposerCollectionTest { + + @Mock DataSourceSparkParameterComposer composer1; + @Mock DataSourceSparkParameterComposer composer2; + @Mock DataSourceSparkParameterComposer composer3; + @Mock DispatchQueryRequest dispatchQueryRequest; + @Mock AsyncQueryRequestContext asyncQueryRequestContext; + + DataSourceType type1 = new DataSourceType("TYPE1"); + DataSourceType type2 = new DataSourceType("TYPE2"); + DataSourceType type3 = new DataSourceType("TYPE3"); + + SparkParameterComposerCollection collection; + + @BeforeEach + void setUp() { + collection = new SparkParameterComposerCollection(); + collection.register(type1, composer1); + collection.register(type1, composer2); + collection.register(type2, composer3); + } + + @Test + void isComposerRegistered() { + assertTrue(collection.isComposerRegistered(type1)); + assertTrue(collection.isComposerRegistered(type2)); + assertFalse(collection.isComposerRegistered(type3)); + } + + @Test + void composeWithRegisteredType() { + DataSourceMetadata metadata = + new DataSourceMetadata.Builder().setConnector(type1).setName("name").build(); + SparkSubmitParameters sparkSubmitParameters = new SparkSubmitParameters(); + + collection.composeByDataSource( + metadata, sparkSubmitParameters, dispatchQueryRequest, asyncQueryRequestContext); + + verify(composer1) + .compose(metadata, sparkSubmitParameters, dispatchQueryRequest, asyncQueryRequestContext); + verify(composer2) + .compose(metadata, sparkSubmitParameters, dispatchQueryRequest, asyncQueryRequestContext); + verifyNoInteractions(composer3); + } + + @Test + void composeWithUnregisteredType() { + DataSourceMetadata metadata = + new DataSourceMetadata.Builder().setConnector(type3).setName("name").build(); + SparkSubmitParameters sparkSubmitParameters = new SparkSubmitParameters(); + + collection.composeByDataSource( + metadata, sparkSubmitParameters, dispatchQueryRequest, asyncQueryRequestContext); + + verifyNoInteractions(composer1, composer2, composer3); + } +} diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilderTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilderTest.java new file mode 100644 index 0000000000..467b7247c8 --- /dev/null +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilderTest.java @@ -0,0 +1,181 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.parameter; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.data.constants.SparkConstants.HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_JARS_KEY; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.config.SparkSubmitParameterModifier; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; + +@ExtendWith(MockitoExtension.class) +public class SparkSubmitParametersBuilderTest { + + @Mock SparkParameterComposerCollection sparkParameterComposerCollection; + @Mock SparkSubmitParameterModifier sparkSubmitParameterModifier; + @Mock AsyncQueryRequestContext asyncQueryRequestContext; + @Mock DispatchQueryRequest dispatchQueryRequest; + + @InjectMocks SparkSubmitParametersBuilder sparkSubmitParametersBuilder; + + @Test + public void testBuildWithoutExtraParameters() { + String params = sparkSubmitParametersBuilder.toString(); + + assertNotNull(params); + } + + @Test + public void testBuildWithExtraParameters() { + String params = sparkSubmitParametersBuilder.extraParameters("--conf A=1").toString(); + + // Assert the conf is included with a space + assertTrue(params.endsWith(" --conf A=1")); + } + + @Test + public void testBuildQueryString() { + String rawQuery = "SHOW tables LIKE \"%\";"; + String expectedQueryInParams = "\"SHOW tables LIKE \\\"%\\\";\""; + String params = sparkSubmitParametersBuilder.query(rawQuery).toString(); + assertTrue(params.contains(expectedQueryInParams)); + } + + @Test + public void testBuildQueryStringNestedQuote() { + String rawQuery = "SELECT '\"1\"'"; + String expectedQueryInParams = "\"SELECT '\\\"1\\\"'\""; + String params = sparkSubmitParametersBuilder.query(rawQuery).toString(); + assertTrue(params.contains(expectedQueryInParams)); + } + + @Test + public void testBuildQueryStringSpecialCharacter() { + String rawQuery = "SELECT '{\"test ,:+\\\"inner\\\"/\\|?#><\"}'"; + String expectedQueryInParams = "SELECT '{\\\"test ,:+\\\\\\\"inner\\\\\\\"/\\\\|?#><\\\"}'"; + String params = sparkSubmitParametersBuilder.query(rawQuery).toString(); + assertTrue(params.contains(expectedQueryInParams)); + } + + @Test + public void testClassName() { + String params = sparkSubmitParametersBuilder.className("CLASS_NAME").toString(); + assertTrue(params.contains("--class CLASS_NAME")); + } + + @Test + public void testClusterName() { + String params = sparkSubmitParametersBuilder.clusterName("CLUSTER_NAME").toString(); + assertTrue(params.contains("spark.emr-serverless.driverEnv.FLINT_CLUSTER_NAME=CLUSTER_NAME")); + assertTrue(params.contains("spark.executorEnv.FLINT_CLUSTER_NAME=CLUSTER_NAME")); + } + + @Test + public void testOverrideConfigItem() { + SparkSubmitParameters params = sparkSubmitParametersBuilder.getSparkSubmitParameters(); + params.setConfigItem(SPARK_JARS_KEY, "Overridden"); + String result = params.toString(); + + assertTrue(result.contains(String.format("%s=Overridden", SPARK_JARS_KEY))); + } + + @Test + public void testDeleteConfigItem() { + SparkSubmitParameters params = sparkSubmitParametersBuilder.getSparkSubmitParameters(); + params.deleteConfigItem(HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY); + String result = params.toString(); + + assertFalse(result.contains(HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY)); + } + + @Test + public void testAddConfigItem() { + SparkSubmitParameters params = sparkSubmitParametersBuilder.getSparkSubmitParameters(); + params.setConfigItem("AdditionalKey", "Value"); + String result = params.toString(); + + assertTrue(result.contains("AdditionalKey=Value")); + } + + @Test + public void testStructuredStreaming() { + SparkSubmitParameters params = + sparkSubmitParametersBuilder.structuredStreaming(true).getSparkSubmitParameters(); + String result = params.toString(); + + assertTrue(result.contains("spark.flint.job.type=streaming")); + } + + @Test + public void testSessionExecution() { + SparkSubmitParameters params = + sparkSubmitParametersBuilder + .sessionExecution("SESSION_ID", "DATASOURCE_NAME") + .getSparkSubmitParameters(); + String result = params.toString(); + + assertTrue( + result.contains("spark.flint.job.requestIndex=.query_execution_request_datasource_name")); + assertTrue(result.contains("spark.flint.job.sessionId=SESSION_ID")); + } + + @Test + public void testAcceptModifier() { + sparkSubmitParametersBuilder.acceptModifier(sparkSubmitParameterModifier); + + verify(sparkSubmitParameterModifier).modifyParameters(sparkSubmitParametersBuilder); + } + + @Test + public void testDataSource() { + when(sparkParameterComposerCollection.isComposerRegistered(DataSourceType.S3GLUE)) + .thenReturn(true); + + DataSourceMetadata metadata = + new DataSourceMetadata.Builder() + .setConnector(DataSourceType.S3GLUE) + .setName("name") + .build(); + SparkSubmitParameters params = + sparkSubmitParametersBuilder + .dataSource(metadata, dispatchQueryRequest, asyncQueryRequestContext) + .getSparkSubmitParameters(); + + verify(sparkParameterComposerCollection) + .composeByDataSource(metadata, params, dispatchQueryRequest, asyncQueryRequestContext); + } + + @Test + public void testUnsupportedDataSource() { + when(sparkParameterComposerCollection.isComposerRegistered(DataSourceType.S3GLUE)) + .thenReturn(false); + + DataSourceMetadata metadata = + new DataSourceMetadata.Builder() + .setConnector(DataSourceType.S3GLUE) + .setName("name") + .build(); + assertThrows( + UnsupportedOperationException.class, + () -> + sparkSubmitParametersBuilder.dataSource( + metadata, dispatchQueryRequest, asyncQueryRequestContext)); + } +} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchExtraParameterComposer.java b/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchExtraParameterComposer.java new file mode 100644 index 0000000000..1925ada46e --- /dev/null +++ b/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchExtraParameterComposer.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.config; + +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; +import org.opensearch.sql.spark.parameter.GeneralSparkParameterComposer; +import org.opensearch.sql.spark.parameter.SparkSubmitParameters; + +/** Load extra parameters from settings and add to Spark submit parameters */ +@RequiredArgsConstructor +public class OpenSearchExtraParameterComposer implements GeneralSparkParameterComposer { + private final SparkExecutionEngineConfigClusterSettingLoader settingLoader; + + @Override + public void compose( + SparkSubmitParameters sparkSubmitParameters, + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext context) { + settingLoader + .load() + .ifPresent( + settings -> + sparkSubmitParameters.setExtraParameters(settings.getSparkSubmitParameters())); + } +} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchSparkSubmitParameterModifier.java b/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchSparkSubmitParameterModifier.java index a034e04095..117d161440 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchSparkSubmitParameterModifier.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchSparkSubmitParameterModifier.java @@ -6,7 +6,7 @@ package org.opensearch.sql.spark.config; import lombok.AllArgsConstructor; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilder; @AllArgsConstructor public class OpenSearchSparkSubmitParameterModifier implements SparkSubmitParameterModifier { @@ -14,7 +14,7 @@ public class OpenSearchSparkSubmitParameterModifier implements SparkSubmitParame private String extraParameters; @Override - public void modifyParameters(SparkSubmitParameters parameters) { - parameters.setExtraParameters(this.extraParameters); + public void modifyParameters(SparkSubmitParametersBuilder builder) { + builder.extraParameters(this.extraParameters); } } diff --git a/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSetting.java b/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSetting.java index 0347f5ffc1..5c1328bf91 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSetting.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSetting.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.config; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import lombok.Builder; import lombok.Data; import org.opensearch.sql.utils.SerializeUtils; @@ -14,6 +15,7 @@ * setting. */ @Data +@Builder @JsonIgnoreProperties(ignoreUnknown = true) public class SparkExecutionEngineConfigClusterSetting { // optional diff --git a/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingLoader.java b/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingLoader.java new file mode 100644 index 0000000000..561a4653a5 --- /dev/null +++ b/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingLoader.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.config; + +import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; + +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.apache.commons.lang3.StringUtils; +import org.opensearch.sql.common.setting.Settings; + +@RequiredArgsConstructor +public class SparkExecutionEngineConfigClusterSettingLoader { + private final Settings settings; + + public Optional load() { + String sparkExecutionEngineConfigSettingString = + this.settings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG); + if (!StringUtils.isBlank(sparkExecutionEngineConfigSettingString)) { + return Optional.of( + AccessController.doPrivileged( + (PrivilegedAction) + () -> + SparkExecutionEngineConfigClusterSetting.toSparkExecutionEngineConfig( + sparkExecutionEngineConfigSettingString))); + } else { + return Optional.empty(); + } + } +} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImpl.java b/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImpl.java index fe931a5b91..66ad964ad1 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImpl.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImpl.java @@ -6,12 +6,8 @@ package org.opensearch.sql.spark.config; import static org.opensearch.sql.common.setting.Settings.Key.CLUSTER_NAME; -import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; -import java.security.AccessController; -import java.security.PrivilegedAction; import lombok.AllArgsConstructor; -import org.apache.commons.lang3.StringUtils; import org.opensearch.cluster.ClusterName; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; @@ -19,7 +15,8 @@ @AllArgsConstructor public class SparkExecutionEngineConfigSupplierImpl implements SparkExecutionEngineConfigSupplier { - private Settings settings; + private final Settings settings; + private final SparkExecutionEngineConfigClusterSettingLoader settingLoader; @Override public SparkExecutionEngineConfig getSparkExecutionEngineConfig( @@ -30,23 +27,14 @@ public SparkExecutionEngineConfig getSparkExecutionEngineConfig( private SparkExecutionEngineConfig.SparkExecutionEngineConfigBuilder getBuilderFromSettingsIfAvailable() { - String sparkExecutionEngineConfigSettingString = - this.settings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG); - if (!StringUtils.isBlank(sparkExecutionEngineConfigSettingString)) { - SparkExecutionEngineConfigClusterSetting setting = - AccessController.doPrivileged( - (PrivilegedAction) - () -> - SparkExecutionEngineConfigClusterSetting.toSparkExecutionEngineConfig( - sparkExecutionEngineConfigSettingString)); - return SparkExecutionEngineConfig.builder() - .applicationId(setting.getApplicationId()) - .executionRoleARN(setting.getExecutionRoleARN()) - .sparkSubmitParameterModifier( - new OpenSearchSparkSubmitParameterModifier(setting.getSparkSubmitParameters())) - .region(setting.getRegion()); - } else { - return SparkExecutionEngineConfig.builder(); - } + return settingLoader + .load() + .map( + setting -> + SparkExecutionEngineConfig.builder() + .applicationId(setting.getApplicationId()) + .executionRoleARN(setting.getExecutionRoleARN()) + .region(setting.getRegion())) + .orElse(SparkExecutionEngineConfig.builder()); } } diff --git a/async-query/src/main/java/org/opensearch/sql/spark/parameter/S3GlueDataSourceSparkParameterComposer.java b/async-query/src/main/java/org/opensearch/sql/spark/parameter/S3GlueDataSourceSparkParameterComposer.java new file mode 100644 index 0000000000..26dbf3529a --- /dev/null +++ b/async-query/src/main/java/org/opensearch/sql/spark/parameter/S3GlueDataSourceSparkParameterComposer.java @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.parameter; + +import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_AUTH; +import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_AUTH_PASSWORD; +import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_AUTH_USERNAME; +import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_REGION; +import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_URI; +import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_LAKEFORMATION_ENABLED; +import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_ROLE_ARN; +import static org.opensearch.sql.spark.data.constants.SparkConstants.DRIVER_ENV_ASSUME_ROLE_ARN_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.EMR_LAKEFORMATION_OPTION; +import static org.opensearch.sql.spark.data.constants.SparkConstants.EXECUTOR_ENV_ASSUME_ROLE_ARN_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_ACCELERATE_USING_COVERING_INDEX; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DATA_SOURCE_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DELEGATE_CATALOG; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_PASSWORD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AUTH_USERNAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_AWSREGION_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_HOST_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_PORT_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_SCHEME_KEY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.HIVE_METASTORE_GLUE_ARN_KEY; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.function.Supplier; +import org.apache.commons.lang3.BooleanUtils; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasources.auth.AuthenticationType; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; + +public class S3GlueDataSourceSparkParameterComposer implements DataSourceSparkParameterComposer { + public static final String FLINT_BASIC_AUTH = "basic"; + + @Override + public void compose( + DataSourceMetadata metadata, + SparkSubmitParameters params, + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext context) { + String roleArn = metadata.getProperties().get(GLUE_ROLE_ARN); + + params.setConfigItem(DRIVER_ENV_ASSUME_ROLE_ARN_KEY, roleArn); + params.setConfigItem(EXECUTOR_ENV_ASSUME_ROLE_ARN_KEY, roleArn); + params.setConfigItem(HIVE_METASTORE_GLUE_ARN_KEY, roleArn); + params.setConfigItem("spark.sql.catalog." + metadata.getName(), FLINT_DELEGATE_CATALOG); + params.setConfigItem(FLINT_DATA_SOURCE_KEY, metadata.getName()); + + final boolean lakeFormationEnabled = + BooleanUtils.toBoolean(metadata.getProperties().get(GLUE_LAKEFORMATION_ENABLED)); + params.setConfigItem(EMR_LAKEFORMATION_OPTION, Boolean.toString(lakeFormationEnabled)); + params.setConfigItem( + FLINT_ACCELERATE_USING_COVERING_INDEX, Boolean.toString(!lakeFormationEnabled)); + + setFlintIndexStoreHost( + params, + parseUri( + metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_URI), metadata.getName())); + setFlintIndexStoreAuthProperties( + params, + metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_AUTH), + () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_AUTH_USERNAME), + () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_AUTH_PASSWORD), + () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_REGION)); + params.setConfigItem("spark.flint.datasource.name", metadata.getName()); + } + + private void setFlintIndexStoreHost(SparkSubmitParameters params, URI uri) { + params.setConfigItem(FLINT_INDEX_STORE_HOST_KEY, uri.getHost()); + params.setConfigItem(FLINT_INDEX_STORE_PORT_KEY, String.valueOf(uri.getPort())); + params.setConfigItem(FLINT_INDEX_STORE_SCHEME_KEY, uri.getScheme()); + } + + private void setFlintIndexStoreAuthProperties( + SparkSubmitParameters params, + String authType, + Supplier userName, + Supplier password, + Supplier region) { + if (AuthenticationType.get(authType).equals(AuthenticationType.BASICAUTH)) { + params.setConfigItem(FLINT_INDEX_STORE_AUTH_KEY, FLINT_BASIC_AUTH); + params.setConfigItem(FLINT_INDEX_STORE_AUTH_USERNAME, userName.get()); + params.setConfigItem(FLINT_INDEX_STORE_AUTH_PASSWORD, password.get()); + } else if (AuthenticationType.get(authType).equals(AuthenticationType.AWSSIGV4AUTH)) { + params.setConfigItem(FLINT_INDEX_STORE_AUTH_KEY, "sigv4"); + params.setConfigItem(FLINT_INDEX_STORE_AWSREGION_KEY, region.get()); + } else { + params.setConfigItem(FLINT_INDEX_STORE_AUTH_KEY, authType); + } + } + + private URI parseUri(String opensearchUri, String datasourceName) { + try { + return new URI(opensearchUri); + } catch (URISyntaxException e) { + throw new IllegalArgumentException( + String.format( + "Bad URI in indexstore configuration of the : %s datasoure.", datasourceName)); + } + } +} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java index d75b6616f7..05f7d1095c 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java @@ -15,6 +15,7 @@ import org.opensearch.common.inject.Singleton; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.legacy.metrics.GaugeMetric; import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; @@ -23,6 +24,8 @@ import org.opensearch.sql.spark.asyncquery.OpenSearchAsyncQueryJobMetadataStorageService; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.client.EMRServerlessClientFactoryImpl; +import org.opensearch.sql.spark.config.OpenSearchExtraParameterComposer; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfigClusterSettingLoader; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl; import org.opensearch.sql.spark.dispatcher.DatasourceEmbeddedQueryIdProvider; @@ -53,6 +56,9 @@ import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager; import org.opensearch.sql.spark.metrics.MetricsService; import org.opensearch.sql.spark.metrics.OpenSearchMetricsService; +import org.opensearch.sql.spark.parameter.S3GlueDataSourceSparkParameterComposer; +import org.opensearch.sql.spark.parameter.SparkParameterComposerCollection; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.response.OpenSearchJobExecutionResponseReader; @@ -111,7 +117,8 @@ public QueryHandlerFactory queryhandlerFactory( IndexDMLResultStorageService indexDMLResultStorageService, FlintIndexOpFactory flintIndexOpFactory, EMRServerlessClientFactory emrServerlessClientFactory, - MetricsService metricsService) { + MetricsService metricsService, + SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider) { return new QueryHandlerFactory( openSearchJobExecutionResponseReader, flintIndexMetadataReader, @@ -120,7 +127,8 @@ public QueryHandlerFactory queryhandlerFactory( indexDMLResultStorageService, flintIndexOpFactory, emrServerlessClientFactory, - metricsService); + metricsService, + sparkSubmitParametersBuilderProvider); } @Provides @@ -147,6 +155,15 @@ public FlintIndexStateModelService flintIndexStateModelService( return new OpenSearchFlintIndexStateModelService(stateStore, serializer); } + @Provides + public SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider( + Settings settings, SparkExecutionEngineConfigClusterSettingLoader clusterSettingLoader) { + SparkParameterComposerCollection collection = new SparkParameterComposerCollection(); + collection.register(DataSourceType.S3GLUE, new S3GlueDataSourceSparkParameterComposer()); + collection.register(new OpenSearchExtraParameterComposer(clusterSettingLoader)); + return new SparkSubmitParametersBuilderProvider(collection); + } + @Provides public IndexDMLResultStorageService indexDMLResultStorageService( DataSourceService dataSourceService, StateStore stateStore) { @@ -197,8 +214,15 @@ public MetricsService metricsService() { } @Provides - public SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier(Settings settings) { - return new SparkExecutionEngineConfigSupplierImpl(settings); + public SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier( + Settings settings, SparkExecutionEngineConfigClusterSettingLoader clusterSettingLoader) { + return new SparkExecutionEngineConfigSupplierImpl(settings, clusterSettingLoader); + } + + @Provides + public SparkExecutionEngineConfigClusterSettingLoader + sparkExecutionEngineConfigClusterSettingLoader(Settings settings) { + return new SparkExecutionEngineConfigClusterSettingLoader(settings); } @Provides diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index a5935db2c9..4a73fc8b13 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -88,6 +88,9 @@ import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager; import org.opensearch.sql.spark.metrics.OpenSearchMetricsService; +import org.opensearch.sql.spark.parameter.S3GlueDataSourceSparkParameterComposer; +import org.opensearch.sql.spark.parameter.SparkParameterComposerCollection; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.response.OpenSearchJobExecutionResponseReader; import org.opensearch.sql.storage.DataSourceFactory; @@ -253,6 +256,12 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = new OpenSearchAsyncQueryJobMetadataStorageService( stateStore, new AsyncQueryJobMetadataXContentSerializer()); + SparkParameterComposerCollection sparkParameterComposerCollection = + new SparkParameterComposerCollection(); + sparkParameterComposerCollection.register( + DataSourceType.S3GLUE, new S3GlueDataSourceSparkParameterComposer()); + SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider = + new SparkSubmitParametersBuilderProvider(sparkParameterComposerCollection); QueryHandlerFactory queryHandlerFactory = new QueryHandlerFactory( jobExecutionResponseReader, @@ -271,7 +280,8 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( new FlintIndexMetadataServiceImpl(client), emrServerlessClientFactory), emrServerlessClientFactory, - new OpenSearchMetricsService()); + new OpenSearchMetricsService(), + sparkSubmitParametersBuilderProvider); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( this.dataSourceService, diff --git a/async-query/src/test/java/org/opensearch/sql/spark/config/OpenSearchExtraParameterComposerTest.java b/async-query/src/test/java/org/opensearch/sql/spark/config/OpenSearchExtraParameterComposerTest.java new file mode 100644 index 0000000000..d3b0b2727a --- /dev/null +++ b/async-query/src/test/java/org/opensearch/sql/spark/config/OpenSearchExtraParameterComposerTest.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.config; + +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +import java.util.Optional; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; +import org.opensearch.sql.spark.parameter.SparkSubmitParameters; + +@ExtendWith(MockitoExtension.class) +class OpenSearchExtraParameterComposerTest { + + public static final String PARAMS = "PARAMS"; + @Mock SparkExecutionEngineConfigClusterSettingLoader settingsLoader; + @Mock SparkSubmitParameters sparkSubmitParameters; + @Mock DispatchQueryRequest dispatchQueryRequest; + @Mock AsyncQueryRequestContext context; + + @InjectMocks OpenSearchExtraParameterComposer openSearchExtraParameterComposer; + + @Test + public void paramExists_compose() { + SparkExecutionEngineConfigClusterSetting setting = + SparkExecutionEngineConfigClusterSetting.builder().sparkSubmitParameters(PARAMS).build(); + when(settingsLoader.load()).thenReturn(Optional.of(setting)); + + openSearchExtraParameterComposer.compose(sparkSubmitParameters, dispatchQueryRequest, context); + + verify(sparkSubmitParameters).setExtraParameters(PARAMS); + } + + @Test + public void paramNotExist_compose() { + when(settingsLoader.load()).thenReturn(Optional.empty()); + + openSearchExtraParameterComposer.compose(sparkSubmitParameters, dispatchQueryRequest, context); + + verifyNoInteractions(sparkSubmitParameters); + } +} diff --git a/async-query/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingLoaderTest.java b/async-query/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingLoaderTest.java new file mode 100644 index 0000000000..f9ccd93b00 --- /dev/null +++ b/async-query/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingLoaderTest.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.config; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; +import static org.opensearch.sql.spark.constants.TestConstants.ACCOUNT_ID; +import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; +import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; +import static org.opensearch.sql.spark.constants.TestConstants.SPARK_SUBMIT_PARAMETERS; +import static org.opensearch.sql.spark.constants.TestConstants.US_WEST_REGION; + +import java.util.Optional; +import org.json.JSONObject; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.common.setting.Settings; + +@ExtendWith(MockitoExtension.class) +class SparkExecutionEngineConfigClusterSettingLoaderTest { + @Mock Settings settings; + + @InjectMocks + SparkExecutionEngineConfigClusterSettingLoader sparkExecutionEngineConfigClusterSettingLoader; + + @Test + public void blankConfig() { + when(settings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG)).thenReturn(""); + + Optional result = + sparkExecutionEngineConfigClusterSettingLoader.load(); + + assertTrue(result.isEmpty()); + } + + @Test + public void validConfig() { + when(settings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG)).thenReturn(getConfigJson()); + + SparkExecutionEngineConfigClusterSetting result = + sparkExecutionEngineConfigClusterSettingLoader.load().get(); + + Assertions.assertEquals(ACCOUNT_ID, result.getAccountId()); + Assertions.assertEquals(EMRS_APPLICATION_ID, result.getApplicationId()); + Assertions.assertEquals(EMRS_EXECUTION_ROLE, result.getExecutionRoleARN()); + Assertions.assertEquals(US_WEST_REGION, result.getRegion()); + Assertions.assertEquals(SPARK_SUBMIT_PARAMETERS, result.getSparkSubmitParameters()); + } + + String getConfigJson() { + return new JSONObject() + .put("accountId", ACCOUNT_ID) + .put("applicationId", EMRS_APPLICATION_ID) + .put("executionRoleARN", EMRS_EXECUTION_ROLE) + .put("region", US_WEST_REGION) + .put("sparkSubmitParameters", SPARK_SUBMIT_PARAMETERS) + .toString(); + } +} diff --git a/async-query/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImplTest.java b/async-query/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImplTest.java index 128868a755..124d8d0b6e 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImplTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImplTest.java @@ -12,7 +12,7 @@ import static org.opensearch.sql.spark.constants.TestConstants.TEST_CLUSTER_NAME; import static org.opensearch.sql.spark.constants.TestConstants.US_WEST_REGION; -import org.json.JSONObject; +import java.util.Optional; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -21,7 +21,6 @@ import org.opensearch.cluster.ClusterName; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; @ExtendWith(MockitoExtension.class) public class SparkExecutionEngineConfigSupplierImplTest { @@ -29,43 +28,46 @@ public class SparkExecutionEngineConfigSupplierImplTest { @Mock private Settings settings; @Mock private AsyncQueryRequestContext asyncQueryRequestContext; + @Mock + private SparkExecutionEngineConfigClusterSettingLoader + sparkExecutionEngineConfigClusterSettingLoader; + @Test void testGetSparkExecutionEngineConfig() { SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier = - new SparkExecutionEngineConfigSupplierImpl(settings); - when(settings.getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG)) - .thenReturn(getConfigJson()); + new SparkExecutionEngineConfigSupplierImpl( + settings, sparkExecutionEngineConfigClusterSettingLoader); when(settings.getSettingValue(Settings.Key.CLUSTER_NAME)) .thenReturn(new ClusterName(TEST_CLUSTER_NAME)); + when(sparkExecutionEngineConfigClusterSettingLoader.load()) + .thenReturn(Optional.of(getClusterSetting())); SparkExecutionEngineConfig sparkExecutionEngineConfig = sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(asyncQueryRequestContext); - SparkSubmitParameters parameters = SparkSubmitParameters.builder().build(); - sparkExecutionEngineConfig.getSparkSubmitParameterModifier().modifyParameters(parameters); Assertions.assertEquals(EMRS_APPLICATION_ID, sparkExecutionEngineConfig.getApplicationId()); Assertions.assertEquals(EMRS_EXECUTION_ROLE, sparkExecutionEngineConfig.getExecutionRoleARN()); Assertions.assertEquals(US_WEST_REGION, sparkExecutionEngineConfig.getRegion()); Assertions.assertEquals(TEST_CLUSTER_NAME, sparkExecutionEngineConfig.getClusterName()); - Assertions.assertTrue(parameters.toString().contains(SPARK_SUBMIT_PARAMETERS)); } - String getConfigJson() { - return new JSONObject() - .put("applicationId", EMRS_APPLICATION_ID) - .put("executionRoleARN", EMRS_EXECUTION_ROLE) - .put("region", US_WEST_REGION) - .put("sparkSubmitParameters", SPARK_SUBMIT_PARAMETERS) - .toString(); + SparkExecutionEngineConfigClusterSetting getClusterSetting() { + return SparkExecutionEngineConfigClusterSetting.builder() + .applicationId(EMRS_APPLICATION_ID) + .executionRoleARN(EMRS_EXECUTION_ROLE) + .region(US_WEST_REGION) + .sparkSubmitParameters(SPARK_SUBMIT_PARAMETERS) + .build(); } @Test void testGetSparkExecutionEngineConfigWithNullSetting() { SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier = - new SparkExecutionEngineConfigSupplierImpl(settings); - when(settings.getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG)).thenReturn(null); + new SparkExecutionEngineConfigSupplierImpl( + settings, sparkExecutionEngineConfigClusterSettingLoader); when(settings.getSettingValue(Settings.Key.CLUSTER_NAME)) .thenReturn(new ClusterName(TEST_CLUSTER_NAME)); + when(sparkExecutionEngineConfigClusterSettingLoader.load()).thenReturn(Optional.empty()); SparkExecutionEngineConfig sparkExecutionEngineConfig = sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(asyncQueryRequestContext); diff --git a/async-query/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java b/async-query/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java index 5b4ffbea2c..15871bf6b2 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java @@ -8,6 +8,7 @@ public class TestConstants { public static final String TEST_DATASOURCE_NAME = "test_datasource_name"; public static final String EMR_JOB_ID = "job-123xxx"; + public static final String ACCOUNT_ID = "TEST_ACCOUNT_ID"; public static final String EMRS_APPLICATION_ID = "app-xxxxx"; public static final String EMRS_EXECUTION_ROLE = "execution_role"; public static final String SPARK_SUBMIT_PARAMETERS = "--conf org.flint.sql.SQLJob"; diff --git a/async-query/src/test/java/org/opensearch/sql/spark/execution/session/SessionTestUtil.java b/async-query/src/test/java/org/opensearch/sql/spark/execution/session/SessionTestUtil.java index 06689a15d0..e5ca93e96e 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/execution/session/SessionTestUtil.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/execution/session/SessionTestUtil.java @@ -9,7 +9,8 @@ import static org.opensearch.sql.spark.constants.TestConstants.TEST_DATASOURCE_NAME; import java.util.HashMap; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; +import org.opensearch.sql.spark.parameter.SparkParameterComposerCollection; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilder; public class SessionTestUtil { @@ -19,7 +20,7 @@ public static CreateSessionRequest createSessionRequest() { null, "appId", "arn", - SparkSubmitParameters.builder().build(), + new SparkSubmitParametersBuilder(new SparkParameterComposerCollection()), new HashMap<>(), "resultIndex", TEST_DATASOURCE_NAME); diff --git a/async-query/src/test/java/org/opensearch/sql/spark/parameter/S3GlueDataSourceSparkParameterComposerTest.java b/async-query/src/test/java/org/opensearch/sql/spark/parameter/S3GlueDataSourceSparkParameterComposerTest.java new file mode 100644 index 0000000000..55e62d52f0 --- /dev/null +++ b/async-query/src/test/java/org/opensearch/sql/spark/parameter/S3GlueDataSourceSparkParameterComposerTest.java @@ -0,0 +1,160 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.parameter; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.google.common.collect.ImmutableMap; +import java.util.Arrays; +import java.util.Map; +import java.util.stream.Collectors; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceStatus; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.datasources.auth.AuthenticationType; +import org.opensearch.sql.datasources.glue.GlueDataSourceFactory; +import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; + +@ExtendWith(MockitoExtension.class) +class S3GlueDataSourceSparkParameterComposerTest { + + public static final String VALID_URI = "https://test.host.com:9200"; + public static final String INVALID_URI = "http://test/\r\n"; + public static final String USERNAME = "USERNAME"; + public static final String PASSWORD = "PASSWORD"; + public static final String REGION = "REGION"; + public static final String TRUE = "true"; + public static final String ROLE_ARN = "ROLE_ARN"; + + private static final String COMMON_EXPECTED_PARAMS = + " --class org.apache.spark.sql.FlintJob " + + getConfList( + "spark.emr-serverless.driverEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN=ROLE_ARN", + "spark.executorEnv.ASSUME_ROLE_CREDENTIALS_ROLE_ARN=ROLE_ARN", + "spark.hive.metastore.glue.role.arn=ROLE_ARN", + "spark.sql.catalog.DATASOURCE_NAME=org.opensearch.sql.FlintDelegatingSessionCatalog", + "spark.flint.datasource.name=DATASOURCE_NAME", + "spark.emr-serverless.lakeformation.enabled=true", + "spark.flint.optimizer.covering.enabled=false", + "spark.datasource.flint.host=test.host.com", + "spark.datasource.flint.port=9200", + "spark.datasource.flint.scheme=https"); + + @Mock DispatchQueryRequest dispatchQueryRequest; + + @Test + public void testBasicAuth() { + DataSourceMetadata dataSourceMetadata = + getDataSourceMetadata(AuthenticationType.BASICAUTH, VALID_URI); + SparkSubmitParameters sparkSubmitParameters = new SparkSubmitParameters(); + + new S3GlueDataSourceSparkParameterComposer() + .compose( + dataSourceMetadata, + sparkSubmitParameters, + dispatchQueryRequest, + new NullAsyncQueryRequestContext()); + + assertEquals( + COMMON_EXPECTED_PARAMS + + getConfList( + "spark.datasource.flint.auth=basic", + "spark.datasource.flint.auth.username=USERNAME", + "spark.datasource.flint.auth.password=PASSWORD"), + sparkSubmitParameters.toString()); + } + + @Test + public void testComposeWithSigV4Auth() { + DataSourceMetadata dataSourceMetadata = + getDataSourceMetadata(AuthenticationType.AWSSIGV4AUTH, VALID_URI); + SparkSubmitParameters sparkSubmitParameters = new SparkSubmitParameters(); + + new S3GlueDataSourceSparkParameterComposer() + .compose( + dataSourceMetadata, + sparkSubmitParameters, + dispatchQueryRequest, + new NullAsyncQueryRequestContext()); + + assertEquals( + COMMON_EXPECTED_PARAMS + + getConfList( + "spark.datasource.flint.auth=sigv4", "spark.datasource.flint.region=REGION"), + sparkSubmitParameters.toString()); + } + + @Test + public void testComposeWithNoAuth() { + DataSourceMetadata dataSourceMetadata = + getDataSourceMetadata(AuthenticationType.NOAUTH, VALID_URI); + SparkSubmitParameters sparkSubmitParameters = new SparkSubmitParameters(); + + new S3GlueDataSourceSparkParameterComposer() + .compose( + dataSourceMetadata, + sparkSubmitParameters, + dispatchQueryRequest, + new NullAsyncQueryRequestContext()); + + assertEquals( + COMMON_EXPECTED_PARAMS + getConfList("spark.datasource.flint.auth=noauth"), + sparkSubmitParameters.toString()); + } + + @Test + public void testComposeWithBadUri() { + DataSourceMetadata dataSourceMetadata = + getDataSourceMetadata(AuthenticationType.NOAUTH, INVALID_URI); + SparkSubmitParameters sparkSubmitParameters = new SparkSubmitParameters(); + + assertThrows( + IllegalArgumentException.class, + () -> + new S3GlueDataSourceSparkParameterComposer() + .compose( + dataSourceMetadata, + sparkSubmitParameters, + dispatchQueryRequest, + new NullAsyncQueryRequestContext())); + } + + private DataSourceMetadata getDataSourceMetadata( + AuthenticationType authenticationType, String uri) { + return new DataSourceMetadata.Builder() + .setConnector(DataSourceType.S3GLUE) + .setName("DATASOURCE_NAME") + .setDescription("DESCRIPTION") + .setResultIndex("RESULT_INDEX") + .setDataSourceStatus(DataSourceStatus.ACTIVE) + .setProperties(getProperties(authenticationType, uri)) + .build(); + } + + private Map getProperties(AuthenticationType authType, String uri) { + return ImmutableMap.builder() + .put(GlueDataSourceFactory.GLUE_ROLE_ARN, ROLE_ARN) + .put(GlueDataSourceFactory.GLUE_LAKEFORMATION_ENABLED, TRUE) + .put(GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_URI, uri) + .put(GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_AUTH, authType.getName()) + .put(GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_AUTH_USERNAME, USERNAME) + .put(GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_AUTH_PASSWORD, PASSWORD) + .put(GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_REGION, REGION) + .build(); + } + + private static String getConfList(String... params) { + return Arrays.stream(params) + .map(param -> String.format(" --conf %s ", param)) + .collect(Collectors.joining()); + } +}