From 3af2b73f3c9ac1f9d297abac7efe2d6a6d5c3d6b Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Wed, 13 Mar 2024 15:17:19 -0700 Subject: [PATCH] Refactor query param (#2519) * Refactor query param Signed-off-by: Louis Chu * Reduce scope of changes Signed-off-by: Louis Chu --------- Signed-off-by: Louis Chu --- .../model/SparkSubmitParameters.java | 5 ++ .../spark/client/EmrServerlessClientImpl.java | 2 +- .../sql/spark/client/StartJobRequest.java | 1 - .../spark/data/constants/SparkConstants.java | 3 +- .../spark/dispatcher/BatchQueryHandler.java | 2 +- .../dispatcher/StreamingQueryHandler.java | 2 +- .../session/CreateSessionRequest.java | 12 +---- .../model/SparkSubmitParametersTest.java | 7 +++ .../client/EmrServerlessClientImplTest.java | 17 ++++--- .../sql/spark/client/StartJobRequestTest.java | 4 +- .../dispatcher/SparkQueryDispatcherTest.java | 51 ++++++++++--------- .../session/InteractiveSessionTest.java | 2 +- 12 files changed, 59 insertions(+), 49 deletions(-) diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java index 7ddb92900d..e3fe931a9e 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java @@ -85,6 +85,11 @@ public Builder clusterName(String clusterName) { return this; } + public Builder query(String query) { + config.put(FLINT_JOB_QUERY, query); + return this; + } + public Builder dataSource(DataSourceMetadata metadata) { if (DataSourceType.S3GLUE.equals(metadata.getConnector())) { String roleArn = metadata.getProperties().get(GLUE_ROLE_ARN); diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java index 82644a2fb2..3a47eb21a7 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java @@ -56,7 +56,7 @@ public String startJobRun(StartJobRequest startJobRequest) { .withSparkSubmit( new SparkSubmit() .withEntryPoint(SPARK_SQL_APPLICATION_JAR) - .withEntryPointArguments(startJobRequest.getQuery(), resultIndex) + .withEntryPointArguments(resultIndex) .withSparkSubmitParameters(startJobRequest.getSparkSubmitParams()))); StartJobRunResult startJobRunResult = diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java index f57c8facee..b532c439c0 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java @@ -19,7 +19,6 @@ public class StartJobRequest { public static final Long DEFAULT_JOB_TIMEOUT = 120L; - private final String query; private final String jobName; private final String applicationId; private final String executionRoleArn; diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java index 95b3c25b99..906a0b740a 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java +++ b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java @@ -89,8 +89,9 @@ public class SparkConstants { public static final String EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER = "com.amazonaws.emr.AssumeRoleAWSCredentialsProvider"; public static final String JAVA_HOME_LOCATION = "/usr/lib/jvm/java-17-amazon-corretto.x86_64/"; - + public static final String FLINT_JOB_QUERY = "spark.flint.job.query"; public static final String FLINT_JOB_REQUEST_INDEX = "spark.flint.job.requestIndex"; public static final String FLINT_JOB_SESSION_ID = "spark.flint.job.sessionId"; + public static final String FLINT_SESSION_CLASS_NAME = "org.apache.spark.sql.FlintREPL"; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java index ecab31ebc9..0153291eb8 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java @@ -74,13 +74,13 @@ public DispatchQueryResponse submit( tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); StartJobRequest startJobRequest = new StartJobRequest( - dispatchQueryRequest.getQuery(), clusterName + ":" + JobType.BATCH.getText(), dispatchQueryRequest.getApplicationId(), dispatchQueryRequest.getExecutionRoleARN(), SparkSubmitParameters.Builder.builder() .clusterName(clusterName) .dataSource(context.getDataSourceMetadata()) + .query(dispatchQueryRequest.getQuery()) .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) .build() .toString(), diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java index 4a3c052739..8170b41c66 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java @@ -65,13 +65,13 @@ public DispatchQueryResponse submit( + indexQueryDetails.openSearchIndexName(); StartJobRequest startJobRequest = new StartJobRequest( - dispatchQueryRequest.getQuery(), jobName, dispatchQueryRequest.getApplicationId(), dispatchQueryRequest.getExecutionRoleARN(), SparkSubmitParameters.Builder.builder() .clusterName(clusterName) .dataSource(dataSourceMetadata) + .query(dispatchQueryRequest.getQuery()) .structuredStreaming(true) .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) .build() diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java index 855e1ce5b2..419b125ab9 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java @@ -23,7 +23,6 @@ public class CreateSessionRequest { public StartJobRequest getStartJobRequest(String sessionId) { return new InteractiveSessionStartJobRequest( - "select 1", clusterName + ":" + JobType.INTERACTIVE.getText() + ":" + sessionId, applicationId, executionRoleArn, @@ -34,22 +33,13 @@ public StartJobRequest getStartJobRequest(String sessionId) { static class InteractiveSessionStartJobRequest extends StartJobRequest { public InteractiveSessionStartJobRequest( - String query, String jobName, String applicationId, String executionRoleArn, String sparkSubmitParams, Map tags, String resultIndex) { - super( - query, - jobName, - applicationId, - executionRoleArn, - sparkSubmitParams, - tags, - false, - resultIndex); + super(jobName, applicationId, executionRoleArn, sparkSubmitParams, tags, false, resultIndex); } /** Interactive query keep running. */ diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java index a914a975b9..9b47cfc43a 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java @@ -27,4 +27,11 @@ public void testBuildWithExtraParameters() { // Assert the conf is included with a space assertTrue(params.endsWith(" --conf A=1")); } + + @Test + public void testBuildQueryString() { + String query = "SHOW tables LIKE \"%\";"; + String params = SparkSubmitParameters.Builder.builder().query(query).build().toString(); + assertTrue(params.contains(query)); + } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java index 51f9add1e8..a5123e0174 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java @@ -42,6 +42,7 @@ import org.opensearch.sql.legacy.esdomain.LocalClusterState; import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.opensearch.setting.OpenSearchSettings; +import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; @ExtendWith(MockitoExtension.class) public class EmrServerlessClientImplTest { @@ -66,13 +67,14 @@ void testStartJobRun() { when(emrServerless.startJobRun(any())).thenReturn(response); EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); + String parameters = SparkSubmitParameters.Builder.builder().query(QUERY).build().toString(); + emrServerlessClient.startJobRun( new StartJobRequest( - QUERY, EMRS_JOB_NAME, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - SPARK_SUBMIT_PARAMETERS, + parameters, new HashMap<>(), false, DEFAULT_RESULT_INDEX)); @@ -83,8 +85,14 @@ void testStartJobRun() { Assertions.assertEquals( ENTRY_POINT_START_JAR, startJobRunRequest.getJobDriver().getSparkSubmit().getEntryPoint()); Assertions.assertEquals( - List.of(QUERY, DEFAULT_RESULT_INDEX), + List.of(DEFAULT_RESULT_INDEX), startJobRunRequest.getJobDriver().getSparkSubmit().getEntryPointArguments()); + Assertions.assertTrue( + startJobRunRequest + .getJobDriver() + .getSparkSubmit() + .getSparkSubmitParameters() + .contains(QUERY)); } @Test @@ -97,7 +105,6 @@ void testStartJobRunWithErrorMetric() { () -> emrServerlessClient.startJobRun( new StartJobRequest( - QUERY, EMRS_JOB_NAME, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -116,7 +123,6 @@ void testStartJobRunResultIndex() { EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); emrServerlessClient.startJobRun( new StartJobRequest( - QUERY, EMRS_JOB_NAME, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -185,7 +191,6 @@ void testStartJobRunWithLongJobName() { EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); emrServerlessClient.startJobRun( new StartJobRequest( - QUERY, RandomStringUtils.random(300), EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/StartJobRequestTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/StartJobRequestTest.java index eb7d9634ec..3671cfaa42 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/client/StartJobRequestTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/client/StartJobRequestTest.java @@ -20,10 +20,10 @@ void executionTimeout() { } private StartJobRequest onDemandJob() { - return new StartJobRequest("", "", "", "", "", Map.of(), false, null); + return new StartJobRequest("", "", "", "", Map.of(), false, null); } private StartJobRequest streamingJob() { - return new StartJobRequest("", "", "", "", "", Map.of(), true, null); + return new StartJobRequest("", "", "", "", Map.of(), true, null); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index aa2ffacac9..d1d5033ee0 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -140,10 +140,10 @@ void testDispatchSelectQuery() { { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } - }); + }, + query); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:batch", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -186,10 +186,10 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { put(FLINT_INDEX_STORE_AUTH_USERNAME, "username"); put(FLINT_INDEX_STORE_AUTH_PASSWORD, "password"); } - }); + }, + query); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:batch", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -229,10 +229,10 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { new HashMap<>() { { } - }); + }, + query); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:batch", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -342,10 +342,10 @@ void testDispatchIndexQuery() { { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } - })); + }, + query)); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -388,10 +388,10 @@ void testDispatchWithPPLQuery() { { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } - }); + }, + query); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:batch", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -432,10 +432,10 @@ void testDispatchQueryWithoutATableAndDataSourceName() { { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } - }); + }, + query); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:batch", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -481,10 +481,10 @@ void testDispatchIndexQueryWithoutADatasourceName() { { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } - })); + }, + query)); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -530,10 +530,10 @@ void testDispatchMaterializedViewQuery() { { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } - })); + }, + query)); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:streaming:flint_mv_1", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -575,10 +575,10 @@ void testDispatchShowMVQuery() { { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } - }); + }, + query); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:batch", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -620,10 +620,10 @@ void testRefreshIndexQuery() { { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } - }); + }, + query); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:batch", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -665,10 +665,10 @@ void testDispatchDescribeIndexQuery() { { put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); } - }); + }, + query); StartJobRequest expected = new StartJobRequest( - query, "TEST_CLUSTER:batch", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, @@ -938,7 +938,7 @@ void testDispatchQueryWithExtraSparkSubmitParameters() { } private String constructExpectedSparkSubmitParameterString( - String auth, Map authParams) { + String auth, Map authParams, String query) { StringBuilder authParamConfigBuilder = new StringBuilder(); for (String key : authParams.keySet()) { authParamConfigBuilder.append(" --conf "); @@ -978,7 +978,10 @@ private String constructExpectedSparkSubmitParameterString( + " spark.hive.metastore.glue.role.arn=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" + " --conf spark.sql.catalog.my_glue=org.opensearch.sql.FlintDelegatingSessionCatalog " + " --conf spark.flint.datasource.name=my_glue " - + authParamConfigBuilder; + + authParamConfigBuilder + + " --conf spark.flint.job.query=" + + query + + " "; } private String withStructuredStreaming(String parameters) { diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index 5669716684..6112261336 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -43,7 +43,7 @@ public class InteractiveSessionTest extends OpenSearchIntegTestCase { @Before public void setup() { emrsClient = new TestEMRServerlessClient(); - startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, ""); + startJobRequest = new StartJobRequest("", "appId", "", "", new HashMap<>(), false, ""); stateStore = new StateStore(client(), clusterService()); }