diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java index 7ddb92900d..d292c068bd 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java @@ -25,13 +25,13 @@ import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.datasources.auth.AuthenticationType; -/** Define Spark Submit Parameters. */ +/** Defines the parameters required for Spark submit command construction. */ @AllArgsConstructor @RequiredArgsConstructor public class SparkSubmitParameters { - public static final String SPACE = " "; - public static final String EQUALS = "="; - public static final String FLINT_BASIC_AUTH = "basic"; + private static final String SPACE = " "; + private static final String EQUALS = "="; + private static final String FLINT_BASIC_AUTH = "basic"; private final String className; private final Map config; @@ -40,34 +40,13 @@ public class SparkSubmitParameters { private String extraParameters; public static class Builder { - private String className; - private final Map config; + private String className; + private final Map config = new LinkedHashMap<>(); private String extraParameters; private Builder() { - className = DEFAULT_CLASS_NAME; - config = new LinkedHashMap<>(); - - config.put(S3_AWS_CREDENTIALS_PROVIDER_KEY, DEFAULT_S3_AWS_CREDENTIALS_PROVIDER_VALUE); - config.put( - HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY, - DEFAULT_GLUE_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY); - config.put( - SPARK_JAR_PACKAGES_KEY, - SPARK_STANDALONE_PACKAGE + "," + SPARK_LAUNCHER_PACKAGE + "," + PPL_STANDALONE_PACKAGE); - config.put(SPARK_JAR_REPOSITORIES_KEY, AWS_SNAPSHOT_REPOSITORY); - config.put(SPARK_DRIVER_ENV_JAVA_HOME_KEY, JAVA_HOME_LOCATION); - config.put(SPARK_EXECUTOR_ENV_JAVA_HOME_KEY, 
JAVA_HOME_LOCATION); - config.put(SPARK_DRIVER_ENV_FLINT_CLUSTER_NAME_KEY, FLINT_DEFAULT_CLUSTER_NAME); - config.put(SPARK_EXECUTOR_ENV_FLINT_CLUSTER_NAME_KEY, FLINT_DEFAULT_CLUSTER_NAME); - config.put(FLINT_INDEX_STORE_HOST_KEY, FLINT_DEFAULT_HOST); - config.put(FLINT_INDEX_STORE_PORT_KEY, FLINT_DEFAULT_PORT); - config.put(FLINT_INDEX_STORE_SCHEME_KEY, FLINT_DEFAULT_SCHEME); - config.put(FLINT_INDEX_STORE_AUTH_KEY, FLINT_DEFAULT_AUTH); - config.put(FLINT_CREDENTIALS_PROVIDER_KEY, EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER); - config.put(SPARK_SQL_EXTENSIONS_KEY, FLINT_SQL_EXTENSION + "," + FLINT_PPL_EXTENSION); - config.put(HIVE_METASTORE_CLASS_KEY, GLUE_HIVE_CATALOG_FACTORY_CLASS); + initializeDefaultConfigurations(); } public static Builder builder() { @@ -86,35 +64,101 @@ public Builder clusterName(String clusterName) { } public Builder dataSource(DataSourceMetadata metadata) { - if (DataSourceType.S3GLUE.equals(metadata.getConnector())) { - String roleArn = metadata.getProperties().get(GLUE_ROLE_ARN); - - config.put(DRIVER_ENV_ASSUME_ROLE_ARN_KEY, roleArn); - config.put(EXECUTOR_ENV_ASSUME_ROLE_ARN_KEY, roleArn); - config.put(HIVE_METASTORE_GLUE_ARN_KEY, roleArn); - config.put("spark.sql.catalog." 
+ metadata.getName(), FLINT_DELEGATE_CATALOG); - config.put(FLINT_DATA_SOURCE_KEY, metadata.getName()); - - setFlintIndexStoreHost( - parseUri( - metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_URI), metadata.getName())); - setFlintIndexStoreAuthProperties( - metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_AUTH), - () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_AUTH_USERNAME), - () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_AUTH_PASSWORD), - () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_REGION)); - config.put("spark.flint.datasource.name", metadata.getName()); - return this; + if (!DataSourceType.S3GLUE.equals(metadata.getConnector())) { + throw new UnsupportedOperationException( + String.format( + "Unsupported datasource type for async queries: %s", metadata.getConnector())); } - throw new UnsupportedOperationException( - String.format( - "UnSupported datasource type for async queries:: %s", metadata.getConnector())); + + configureDataSource(metadata); + return this; + } + + public Builder extraParameters(String params) { + this.extraParameters = params; + return this; } - private void setFlintIndexStoreHost(URI uri) { + public Builder query(String query) { + config.put(FLINT_JOB_QUERY, query); + return this; + } + + public Builder sessionExecution(String sessionId, String datasourceName) { + config.put(FLINT_JOB_REQUEST_INDEX, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + config.put(FLINT_JOB_SESSION_ID, sessionId); + return this; + } + + public Builder structuredStreaming(Boolean isStructuredStreaming) { + if (Boolean.TRUE.equals(isStructuredStreaming)) { + config.put("spark.flint.job.type", "streaming"); + } + return this; + } + + public SparkSubmitParameters build() { + return new SparkSubmitParameters(className, config, extraParameters); + } + + private void configureDataSource(DataSourceMetadata metadata) { + // DataSource specific configuration + String roleArn = 
metadata.getProperties().get(GLUE_ROLE_ARN); + + config.put(DRIVER_ENV_ASSUME_ROLE_ARN_KEY, roleArn); + config.put(EXECUTOR_ENV_ASSUME_ROLE_ARN_KEY, roleArn); + config.put(HIVE_METASTORE_GLUE_ARN_KEY, roleArn); + config.put("spark.sql.catalog." + metadata.getName(), FLINT_DELEGATE_CATALOG); + config.put(FLINT_DATA_SOURCE_KEY, metadata.getName()); + + URI uri = + parseUri( + metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_URI), metadata.getName()); config.put(FLINT_INDEX_STORE_HOST_KEY, uri.getHost()); config.put(FLINT_INDEX_STORE_PORT_KEY, String.valueOf(uri.getPort())); config.put(FLINT_INDEX_STORE_SCHEME_KEY, uri.getScheme()); + + setFlintIndexStoreAuthProperties( + metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_AUTH), + () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_AUTH_USERNAME), + () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_AUTH_PASSWORD), + () -> metadata.getProperties().get(GLUE_INDEX_STORE_OPENSEARCH_REGION)); + config.put("spark.flint.datasource.name", metadata.getName()); + } + + private void initializeDefaultConfigurations() { + className = DEFAULT_CLASS_NAME; + // Default configurations initialization + config.put(S3_AWS_CREDENTIALS_PROVIDER_KEY, DEFAULT_S3_AWS_CREDENTIALS_PROVIDER_VALUE); + config.put( + HADOOP_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY, + DEFAULT_GLUE_CATALOG_CREDENTIALS_PROVIDER_FACTORY_KEY); + config.put( + SPARK_JAR_PACKAGES_KEY, + SPARK_STANDALONE_PACKAGE + "," + SPARK_LAUNCHER_PACKAGE + "," + PPL_STANDALONE_PACKAGE); + config.put(SPARK_JAR_REPOSITORIES_KEY, AWS_SNAPSHOT_REPOSITORY); + config.put(SPARK_DRIVER_ENV_JAVA_HOME_KEY, JAVA_HOME_LOCATION); + config.put(SPARK_EXECUTOR_ENV_JAVA_HOME_KEY, JAVA_HOME_LOCATION); + config.put(SPARK_DRIVER_ENV_FLINT_CLUSTER_NAME_KEY, FLINT_DEFAULT_CLUSTER_NAME); + config.put(SPARK_EXECUTOR_ENV_FLINT_CLUSTER_NAME_KEY, FLINT_DEFAULT_CLUSTER_NAME); + config.put(FLINT_INDEX_STORE_HOST_KEY, FLINT_DEFAULT_HOST); + 
config.put(FLINT_INDEX_STORE_PORT_KEY, FLINT_DEFAULT_PORT); + config.put(FLINT_INDEX_STORE_SCHEME_KEY, FLINT_DEFAULT_SCHEME); + config.put(FLINT_INDEX_STORE_AUTH_KEY, FLINT_DEFAULT_AUTH); + config.put(FLINT_CREDENTIALS_PROVIDER_KEY, EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER); + config.put(SPARK_SQL_EXTENSIONS_KEY, FLINT_SQL_EXTENSION + "," + FLINT_PPL_EXTENSION); + config.put(HIVE_METASTORE_CLASS_KEY, GLUE_HIVE_CATALOG_FACTORY_CLASS); + } + + private URI parseUri(String opensearchUri, String datasourceName) { + try { + return new URI(opensearchUri); + } catch (URISyntaxException e) { + throw new IllegalArgumentException( + String.format( + "Bad URI in indexstore configuration for datasource: %s.", datasourceName), + e); + } } private void setFlintIndexStoreAuthProperties( @@ -133,57 +177,20 @@ private void setFlintIndexStoreAuthProperties( config.put(FLINT_INDEX_STORE_AUTH_KEY, authType); } } - - private URI parseUri(String opensearchUri, String datasourceName) { - try { - return new URI(opensearchUri); - } catch (URISyntaxException e) { - throw new IllegalArgumentException( - String.format( - "Bad URI in indexstore configuration of the : %s datasoure.", datasourceName)); - } - } - - public Builder structuredStreaming(Boolean isStructuredStreaming) { - if (isStructuredStreaming) { - config.put("spark.flint.job.type", "streaming"); - } - return this; - } - - public Builder extraParameters(String params) { - extraParameters = params; - return this; - } - - public Builder sessionExecution(String sessionId, String datasourceName) { - config.put(FLINT_JOB_REQUEST_INDEX, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - config.put(FLINT_JOB_SESSION_ID, sessionId); - return this; - } - - public SparkSubmitParameters build() { - return new SparkSubmitParameters(className, config, extraParameters); - } } @Override public String toString() { - StringBuilder stringBuilder = new StringBuilder(); - stringBuilder.append(" --class "); - stringBuilder.append(this.className); - 
stringBuilder.append(SPACE); - for (String key : config.keySet()) { - stringBuilder.append(" --conf "); - stringBuilder.append(key); - stringBuilder.append(EQUALS); - stringBuilder.append(config.get(key)); - stringBuilder.append(SPACE); - } - - if (extraParameters != null) { - stringBuilder.append(extraParameters); - } + StringBuilder stringBuilder = new StringBuilder(" --class ").append(className).append(SPACE); + config.forEach( + (key, value) -> + stringBuilder + .append(" --conf ") + .append(key) + .append(EQUALS) + .append(value) + .append(SPACE)); + if (extraParameters != null) stringBuilder.append(extraParameters); return stringBuilder.toString(); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java index 913e1ac378..21b48ba75c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java @@ -53,7 +53,7 @@ public String startJobRun(StartJobRequest startJobRequest) { .withSparkSubmit( new SparkSubmit() .withEntryPoint(SPARK_SQL_APPLICATION_JAR) - .withEntryPointArguments(startJobRequest.getQuery(), resultIndex) + .withEntryPointArguments(resultIndex) .withSparkSubmitParameters(startJobRequest.getSparkSubmitParams()))); StartJobRunResult startJobRunResult = diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java index f57c8facee..b532c439c0 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java @@ -19,7 +19,6 @@ public class StartJobRequest { public static final Long DEFAULT_JOB_TIMEOUT = 120L; - private final String query; private final String jobName; private final String applicationId; private final 
String executionRoleArn; diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java index 95b3c25b99..906a0b740a 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java +++ b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java @@ -89,8 +89,9 @@ public class SparkConstants { public static final String EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER = "com.amazonaws.emr.AssumeRoleAWSCredentialsProvider"; public static final String JAVA_HOME_LOCATION = "/usr/lib/jvm/java-17-amazon-corretto.x86_64/"; - + public static final String FLINT_JOB_QUERY = "spark.flint.job.query"; public static final String FLINT_JOB_REQUEST_INDEX = "spark.flint.job.requestIndex"; public static final String FLINT_JOB_SESSION_ID = "spark.flint.job.sessionId"; + public static final String FLINT_SESSION_CLASS_NAME = "org.apache.spark.sql.FlintREPL"; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java index 46dec38038..e61e04b22a 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java @@ -75,13 +75,13 @@ public DispatchQueryResponse submit( tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); StartJobRequest startJobRequest = new StartJobRequest( - dispatchQueryRequest.getQuery(), jobName, dispatchQueryRequest.getApplicationId(), dispatchQueryRequest.getExecutionRoleARN(), SparkSubmitParameters.Builder.builder() .clusterName(clusterName) .dataSource(context.getDataSourceMetadata()) + .query(dispatchQueryRequest.getQuery()) .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) .build() .toString(), diff --git 
a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java index 75337a3dad..218d6b11b7 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java @@ -52,13 +52,13 @@ public DispatchQueryResponse submit( tags.put(JOB_TYPE_TAG_KEY, JobType.STREAMING.getText()); StartJobRequest startJobRequest = new StartJobRequest( - dispatchQueryRequest.getQuery(), jobName, dispatchQueryRequest.getApplicationId(), dispatchQueryRequest.getExecutionRoleARN(), SparkSubmitParameters.Builder.builder() .clusterName(clusterName) .dataSource(dataSourceMetadata) + .query(dispatchQueryRequest.getQuery()) .structuredStreaming(true) .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) .build() diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java index b2201fbd01..9807feb57c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java @@ -22,7 +22,6 @@ public class CreateSessionRequest { public StartJobRequest getStartJobRequest() { return new InteractiveSessionStartJobRequest( - "select 1", jobName, applicationId, executionRoleArn, @@ -33,22 +32,13 @@ public StartJobRequest getStartJobRequest() { static class InteractiveSessionStartJobRequest extends StartJobRequest { public InteractiveSessionStartJobRequest( - String query, String jobName, String applicationId, String executionRoleArn, String sparkSubmitParams, Map tags, String resultIndex) { - super( - query, - jobName, - applicationId, - executionRoleArn, - sparkSubmitParams, - tags, - false, - resultIndex); + 
super(jobName, applicationId, executionRoleArn, sparkSubmitParams, tags, false, resultIndex); } /** Interactive query keep running. */ diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java index a914a975b9..9b47cfc43a 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java @@ -27,4 +27,11 @@ public void testBuildWithExtraParameters() { // Assert the conf is included with a space assertTrue(params.endsWith(" --conf A=1")); } + + @Test + public void testBuildQueryString() { + String query = "SHOW tables LIKE \"%\";"; + String params = SparkSubmitParameters.Builder.builder().query(query).build().toString(); + assertTrue(params.contains(query)); + } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java index 67f4d9eb40..8b1cbc3af2 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java @@ -41,6 +41,7 @@ import org.opensearch.sql.legacy.esdomain.LocalClusterState; import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.opensearch.setting.OpenSearchSettings; +import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; @ExtendWith(MockitoExtension.class) public class EmrServerlessClientImplTest { @@ -65,13 +66,14 @@ void testStartJobRun() { when(emrServerless.startJobRun(any())).thenReturn(response); EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); + String parameters = SparkSubmitParameters.Builder.builder().query(QUERY).build().toString(); + 
emrServerlessClient.startJobRun( new StartJobRequest( - QUERY, EMRS_JOB_NAME, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - SPARK_SUBMIT_PARAMETERS, + parameters, new HashMap<>(), false, DEFAULT_RESULT_INDEX)); @@ -82,8 +84,14 @@ void testStartJobRun() { Assertions.assertEquals( ENTRY_POINT_START_JAR, startJobRunRequest.getJobDriver().getSparkSubmit().getEntryPoint()); Assertions.assertEquals( - List.of(QUERY, DEFAULT_RESULT_INDEX), + List.of(DEFAULT_RESULT_INDEX), startJobRunRequest.getJobDriver().getSparkSubmit().getEntryPointArguments()); + Assertions.assertTrue( + startJobRunRequest + .getJobDriver() + .getSparkSubmit() + .getSparkSubmitParameters() + .contains(QUERY)); } @Test @@ -96,7 +104,6 @@ void testStartJobRunWithErrorMetric() { () -> emrServerlessClient.startJobRun( new StartJobRequest( - QUERY, EMRS_JOB_NAME, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -115,7 +122,6 @@ void testStartJobRunResultIndex() { EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); emrServerlessClient.startJobRun( new StartJobRequest( - QUERY, EMRS_JOB_NAME, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/StartJobRequestTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/StartJobRequestTest.java index eb7d9634ec..3671cfaa42 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/client/StartJobRequestTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/client/StartJobRequestTest.java @@ -20,10 +20,10 @@ void executionTimeout() { } private StartJobRequest onDemandJob() { - return new StartJobRequest("", "", "", "", "", Map.of(), false, null); + return new StartJobRequest("", "", "", "", Map.of(), false, null); } private StartJobRequest streamingJob() { - return new StartJobRequest("", "", "", "", "", Map.of(), true, null); + return new StartJobRequest("", "", "", "", Map.of(), true, null); } } diff --git 
a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 2a499e7d30..76bc2f97e0 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -140,10 +140,10 @@ void testDispatchSelectQuery() { { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } - }); + }, + query); when(emrServerlessClient.startJobRun( new StartJobRequest( - query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -167,7 +167,6 @@ void testDispatchSelectQuery() { verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -195,10 +194,10 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { put(FLINT_INDEX_STORE_AUTH_USERNAME, "username"); put(FLINT_INDEX_STORE_AUTH_PASSWORD, "password"); } - }); + }, + query); when(emrServerlessClient.startJobRun( new StartJobRequest( - query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -222,7 +221,6 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -248,10 +246,10 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { new HashMap<>() { { } - }); + }, + query); when(emrServerlessClient.startJobRun( new StartJobRequest( - query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -275,7 +273,6 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { verify(emrServerlessClient, 
times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -367,10 +364,10 @@ void testDispatchIndexQuery() { { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } - })); + }, + query)); when(emrServerlessClient.startJobRun( new StartJobRequest( - query, "TEST_CLUSTER:index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -394,7 +391,6 @@ void testDispatchIndexQuery() { verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -421,10 +417,10 @@ void testDispatchWithPPLQuery() { { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } - }); + }, + query); when(emrServerlessClient.startJobRun( new StartJobRequest( - query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -448,7 +444,6 @@ void testDispatchWithPPLQuery() { verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -475,10 +470,10 @@ void testDispatchQueryWithoutATableAndDataSourceName() { { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } - }); + }, + query); when(emrServerlessClient.startJobRun( new StartJobRequest( - query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -502,7 +497,6 @@ void testDispatchQueryWithoutATableAndDataSourceName() { verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -533,10 +527,10 @@ void testDispatchIndexQueryWithoutADatasourceName() { { put(FLINT_INDEX_STORE_AWSREGION_KEY, 
"eu-west-1"); } - })); + }, + query)); when(emrServerlessClient.startJobRun( new StartJobRequest( - query, "TEST_CLUSTER:index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -560,7 +554,6 @@ void testDispatchIndexQueryWithoutADatasourceName() { verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -591,10 +584,10 @@ void testDispatchMaterializedViewQuery() { { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } - })); + }, + query)); when(emrServerlessClient.startJobRun( new StartJobRequest( - query, "TEST_CLUSTER:index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -618,7 +611,6 @@ void testDispatchMaterializedViewQuery() { verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -645,10 +637,10 @@ void testDispatchShowMVQuery() { { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } - }); + }, + query); when(emrServerlessClient.startJobRun( new StartJobRequest( - query, "TEST_CLUSTER:index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -672,7 +664,6 @@ void testDispatchShowMVQuery() { verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -699,10 +690,10 @@ void testRefreshIndexQuery() { { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } - }); + }, + query); when(emrServerlessClient.startJobRun( new StartJobRequest( - query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -726,7 +717,6 @@ void testRefreshIndexQuery() { verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); 
StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -753,10 +743,10 @@ void testDispatchDescribeIndexQuery() { { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } - }); + }, + query); when(emrServerlessClient.startJobRun( new StartJobRequest( - query, "TEST_CLUSTER:index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -780,7 +770,6 @@ void testDispatchDescribeIndexQuery() { verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -811,7 +800,7 @@ void testDispatchWithWrongURI() { EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME))); Assertions.assertEquals( - "Bad URI in indexstore configuration of the : my_glue datasoure.", + "Bad URI in indexstore configuration for datasource: my_glue.", illegalArgumentException.getMessage()); } @@ -833,7 +822,7 @@ void testDispatchWithUnSupportedDataSourceType() { EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME))); Assertions.assertEquals( - "UnSupported datasource type for async queries:: PROMETHEUS", + "Unsupported datasource type for async queries: PROMETHEUS", unsupportedOperationException.getMessage()); } @@ -1035,7 +1024,7 @@ void testDispatchQueryWithExtraSparkSubmitParameters() { } private String constructExpectedSparkSubmitParameterString( - String auth, Map authParams) { + String auth, Map authParams, String query) { StringBuilder authParamConfigBuilder = new StringBuilder(); for (String key : authParams.keySet()) { authParamConfigBuilder.append(" --conf "); @@ -1075,7 +1064,10 @@ private String constructExpectedSparkSubmitParameterString( + " spark.hive.metastore.glue.role.arn=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" + " --conf spark.sql.catalog.my_glue=org.opensearch.sql.FlintDelegatingSessionCatalog " + " --conf spark.flint.datasource.name=my_glue " - + 
authParamConfigBuilder; + + authParamConfigBuilder + + " --conf spark.flint.job.query=" + + query + + " "; } private String withStructuredStreaming(String parameters) { diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index 338da431fb..adbe75c236 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -41,7 +41,7 @@ public class InteractiveSessionTest extends OpenSearchIntegTestCase { @Before public void setup() { emrsClient = new TestEMRServerlessClient(); - startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, ""); + startJobRequest = new StartJobRequest("", "appId", "", "", new HashMap<>(), false, ""); stateStore = new StateStore(client(), clusterService()); }