diff --git a/docs/user/interfaces/asyncqueryinterface.rst b/docs/user/interfaces/asyncqueryinterface.rst index 89529c8c82..eda2c0c72c 100644 --- a/docs/user/interfaces/asyncqueryinterface.rst +++ b/docs/user/interfaces/asyncqueryinterface.rst @@ -45,6 +45,7 @@ Sample Request:: curl --location 'http://localhost:9200/_plugins/_async_query' \ --header 'Content-Type: application/json' \ --data '{ + "datasource" : "my_glue", "lang" : "sql", "query" : "select * from my_glue.default.http_logs limit 10" }' diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index a86aa82695..74065c2d20 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -69,6 +69,7 @@ public CreateAsyncQueryResponse createAsyncQuery( new DispatchQueryRequest( sparkExecutionEngineConfig.getApplicationId(), createAsyncQueryRequest.getQuery(), + createAsyncQueryRequest.getDatasource(), createAsyncQueryRequest.getLang(), sparkExecutionEngineConfig.getExecutionRoleARN(), clusterName.value())); diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 904d199663..ece79ea2dc 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -48,7 +48,7 @@ public class SparkQueryDispatcher { public static final String TABLE_TAG_KEY = "table"; public static final String CLUSTER_NAME_TAG_KEY = "cluster"; - private EMRServerlessClient EMRServerlessClient; + private EMRServerlessClient emrServerlessClient; private DataSourceService dataSourceService; @@ -57,12 +57,12 @@ public class 
SparkQueryDispatcher { private JobExecutionResponseReader jobExecutionResponseReader; public String dispatch(DispatchQueryRequest dispatchQueryRequest) { - return EMRServerlessClient.startJobRun(getStartJobRequest(dispatchQueryRequest)); + return emrServerlessClient.startJobRun(getStartJobRequest(dispatchQueryRequest)); } // TODO : Fetch from Result Index and then make call to EMR Serverless. public JSONObject getQueryResponse(String applicationId, String queryId) { - GetJobRunResult getJobRunResult = EMRServerlessClient.getJobRunResult(applicationId, queryId); + GetJobRunResult getJobRunResult = emrServerlessClient.getJobRunResult(applicationId, queryId); JSONObject result = new JSONObject(); if (getJobRunResult.getJobRun().getState().equals(JobRunState.SUCCESS.toString())) { result = jobExecutionResponseReader.getResultFromOpensearchIndex(queryId); @@ -72,10 +72,13 @@ public JSONObject getQueryResponse(String applicationId, String queryId) { } public String cancelJob(String applicationId, String jobId) { - CancelJobRunResult cancelJobRunResult = EMRServerlessClient.cancelJobRun(applicationId, jobId); + CancelJobRunResult cancelJobRunResult = emrServerlessClient.cancelJobRun(applicationId, jobId); return cancelJobRunResult.getJobRunId(); } + // we currently don't support index queries in PPL language. + // so we are treating all of them as non-index queries which don't require any kind of query + // parsing. 
private StartJobRequest getStartJobRequest(DispatchQueryRequest dispatchQueryRequest) { if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) { if (SQLQueryUtils.isIndexQuery(dispatchQueryRequest.getQuery())) @@ -83,9 +86,9 @@ private StartJobRequest getStartJobRequest(DispatchQueryRequest dispatchQueryReq else { return getStartJobRequestForNonIndexQueries(dispatchQueryRequest); } + } else { + return getStartJobRequestForNonIndexQueries(dispatchQueryRequest); } - throw new UnsupportedOperationException( - String.format("UnSupported Lang type:: %s", dispatchQueryRequest.getLangType())); } private String getDataSourceRoleARN(DataSourceMetadata dataSourceMetadata) { @@ -133,27 +136,17 @@ private String constructSparkParameters(String datasourceName) { private StartJobRequest getStartJobRequestForNonIndexQueries( DispatchQueryRequest dispatchQueryRequest) { StartJobRequest startJobRequest; - FullyQualifiedTableName fullyQualifiedTableName = - SQLQueryUtils.extractFullyQualifiedTableName(dispatchQueryRequest.getQuery()); - if (fullyQualifiedTableName.getDatasourceName() == null) { - throw new UnsupportedOperationException("Missing datasource in the query syntax."); - } dataSourceUserAuthorizationHelper.authorizeDataSource( - this.dataSourceService.getRawDataSourceMetadata( - fullyQualifiedTableName.getDatasourceName())); - String jobName = - dispatchQueryRequest.getClusterName() - + ":" - + fullyQualifiedTableName.getFullyQualifiedName(); - Map tags = - getDefaultTagsForJobSubmission(dispatchQueryRequest, fullyQualifiedTableName); + this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource())); + String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query"; + Map tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); startJobRequest = new StartJobRequest( dispatchQueryRequest.getQuery(), jobName, dispatchQueryRequest.getApplicationId(), dispatchQueryRequest.getExecutionRoleARN(), - 
constructSparkParameters(fullyQualifiedTableName.getDatasourceName()), + constructSparkParameters(dispatchQueryRequest.getDatasource()), tags); return startJobRequest; } @@ -163,46 +156,54 @@ private StartJobRequest getStartJobRequestForIndexRequest( StartJobRequest startJobRequest; IndexDetails indexDetails = SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery()); FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); - if (fullyQualifiedTableName.getDatasourceName() == null) { - throw new UnsupportedOperationException("Queries without a datasource are not supported"); - } dataSourceUserAuthorizationHelper.authorizeDataSource( - this.dataSourceService.getRawDataSourceMetadata( - fullyQualifiedTableName.getDatasourceName())); + this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource())); String jobName = getJobNameForIndexQuery(dispatchQueryRequest, indexDetails, fullyQualifiedTableName); - Map tags = - getDefaultTagsForJobSubmission(dispatchQueryRequest, fullyQualifiedTableName); + Map tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); tags.put(INDEX_TAG_KEY, indexDetails.getIndexName()); + tags.put(TABLE_TAG_KEY, fullyQualifiedTableName.getTableName()); + tags.put(SCHEMA_TAG_KEY, fullyQualifiedTableName.getSchemaName()); startJobRequest = new StartJobRequest( dispatchQueryRequest.getQuery(), jobName, dispatchQueryRequest.getApplicationId(), dispatchQueryRequest.getExecutionRoleARN(), - constructSparkParameters(fullyQualifiedTableName.getDatasourceName()), + constructSparkParameters(dispatchQueryRequest.getDatasource()), tags); return startJobRequest; } private static Map getDefaultTagsForJobSubmission( - DispatchQueryRequest dispatchQueryRequest, FullyQualifiedTableName fullyQualifiedTableName) { + DispatchQueryRequest dispatchQueryRequest) { Map tags = new HashMap<>(); tags.put(CLUSTER_NAME_TAG_KEY, dispatchQueryRequest.getClusterName()); - tags.put(DATASOURCE_TAG_KEY, 
fullyQualifiedTableName.getDatasourceName()); - tags.put(SCHEMA_TAG_KEY, fullyQualifiedTableName.getSchemaName()); - tags.put(TABLE_TAG_KEY, fullyQualifiedTableName.getTableName()); + tags.put(DATASOURCE_TAG_KEY, dispatchQueryRequest.getDatasource()); return tags; } + // Our queries work with datasource name and without datasource name. + // In order to have a constant jobName in both scenarios, + // we are adding data source name from dispatcher request to the jobName. private static String getJobNameForIndexQuery( DispatchQueryRequest dispatchQueryRequest, IndexDetails indexDetails, FullyQualifiedTableName fullyQualifiedTableName) { - return dispatchQueryRequest.getClusterName() - + ":" - + fullyQualifiedTableName.getFullyQualifiedName() - + "." - + indexDetails.getIndexName(); + if (fullyQualifiedTableName.getDatasourceName() == null) { + return dispatchQueryRequest.getClusterName() + + ":" + + dispatchQueryRequest.getDatasource() + + "." + + fullyQualifiedTableName.getFullyQualifiedName() + + "." + + indexDetails.getIndexName(); + } else { + return dispatchQueryRequest.getClusterName() + + ":" + + fullyQualifiedTableName.getFullyQualifiedName() + + "." 
+ + indexDetails.getIndexName(); + } } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java index 330eb3a03e..09240278ee 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java @@ -12,6 +12,7 @@ public class DispatchQueryRequest { private final String applicationId; private final String query; + private final String datasource; private final LangType langType; private final String executionRoleARN; private final String clusterName; diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java index c1ad979877..eb6260b16c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java @@ -8,6 +8,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; +import java.util.HashSet; import lombok.AllArgsConstructor; import lombok.Data; import org.opensearch.core.xcontent.XContentParser; @@ -17,27 +18,44 @@ public class CreateAsyncQueryRequest { private String query; + private String datasource; private LangType lang; public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser) throws IOException { String query = null; LangType lang = null; + String datasource = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + HashSet missingFields = new HashSet<>(); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); parser.nextToken(); if (fieldName.equals("query")) { query = parser.textOrNull(); + 
if (query == null) { + missingFields.add("query"); + } } else if (fieldName.equals("lang")) { - lang = LangType.fromString(parser.textOrNull()); + String langString = parser.textOrNull(); + if (langString == null) { + missingFields.add("lang"); + } + lang = LangType.fromString(langString); + } else if (fieldName.equals("datasource")) { + datasource = parser.textOrNull(); + if (datasource == null) { + missingFields.add("datasource"); + } } else { throw new IllegalArgumentException("Unknown field: " + fieldName); } } - if (lang == null || query == null) { - throw new IllegalArgumentException("lang and query are required fields."); + + if (missingFields.size() > 0) { + throw new IllegalArgumentException( + String.format("Missing %s fields in the query request", missingFields)); } - return new CreateAsyncQueryRequest(query, lang); + return new CreateAsyncQueryRequest(query, datasource, lang); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index 1ff2493e6d..ff1b17473a 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -48,7 +48,8 @@ void testCreateAsyncQuery() { new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); CreateAsyncQueryRequest createAsyncQueryRequest = - new CreateAsyncQueryRequest("select * from my_glue.default.http_logs", LangType.SQL); + new CreateAsyncQueryRequest( + "select * from my_glue.default.http_logs", "my_glue", LangType.SQL); when(settings.getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG)) .thenReturn( "{\"applicationId\":\"00fd775baqpu4g0p\",\"executionRoleARN\":\"arn:aws:iam::270824043731:role/emr-job-execution-role\",\"region\":\"eu-west-1\"}"); @@ -58,6 +59,7 @@ 
void testCreateAsyncQuery() { new DispatchQueryRequest( "00fd775baqpu4g0p", "select * from my_glue.default.http_logs", + "my_glue", LangType.SQL, "arn:aws:iam::270824043731:role/emr-job-execution-role", TEST_CLUSTER_NAME))) @@ -73,6 +75,7 @@ void testCreateAsyncQuery() { new DispatchQueryRequest( "00fd775baqpu4g0p", "select * from my_glue.default.http_logs", + "my_glue", LangType.SQL, "arn:aws:iam::270824043731:role/emr-job-execution-role", TEST_CLUSTER_NAME)); diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index d83505fde0..6d8f9b075b 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -42,7 +42,7 @@ @ExtendWith(MockitoExtension.class) public class SparkQueryDispatcherTest { - @Mock private EMRServerlessClient EMRServerlessClient; + @Mock private EMRServerlessClient emrServerlessClient; @Mock private DataSourceService dataSourceService; @Mock private JobExecutionResponseReader jobExecutionResponseReader; @Mock private DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper; @@ -51,20 +51,18 @@ public class SparkQueryDispatcherTest { void testDispatchSelectQuery() { SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( - EMRServerlessClient, + emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, jobExecutionResponseReader); HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); - tags.put("table", "http_logs"); tags.put("cluster", TEST_CLUSTER_NAME); - tags.put("schema", "default"); String query = "select * from my_glue.default.http_logs"; - when(EMRServerlessClient.startJobRun( + when(emrServerlessClient.startJobRun( new StartJobRequest( query, - "TEST_CLUSTER:my_glue.default.http_logs", + "TEST_CLUSTER:non-index-query", 
EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, constructExpectedSparkSubmitParameterString(), @@ -76,12 +74,17 @@ void testDispatchSelectQuery() { String jobId = sparkQueryDispatcher.dispatch( new DispatchQueryRequest( - EMRS_APPLICATION_ID, query, LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(EMRServerlessClient, times(1)) + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)) .startJobRun( new StartJobRequest( query, - "TEST_CLUSTER:my_glue.default.http_logs", + "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, constructExpectedSparkSubmitParameterString(), @@ -93,7 +96,7 @@ void testDispatchSelectQuery() { void testDispatchIndexQuery() { SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( - EMRServerlessClient, + emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, jobExecutionResponseReader); @@ -106,7 +109,7 @@ void testDispatchIndexQuery() { String query = "CREATE INDEX elb_and_requestUri ON my_glue.default.http_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = true)"; - when(EMRServerlessClient.startJobRun( + when(emrServerlessClient.startJobRun( new StartJobRequest( query, "TEST_CLUSTER:my_glue.default.http_logs.elb_and_requestUri", @@ -121,8 +124,13 @@ void testDispatchIndexQuery() { String jobId = sparkQueryDispatcher.dispatch( new DispatchQueryRequest( - EMRS_APPLICATION_ID, query, LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(EMRServerlessClient, times(1)) + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)) .startJobRun( new StartJobRequest( query, @@ -136,124 +144,149 @@ void testDispatchIndexQuery() { @Test void testDispatchWithPPLQuery() { + HashMap tags = new HashMap<>(); + tags.put("datasource", "my_glue"); + tags.put("cluster", TEST_CLUSTER_NAME); 
SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( - EMRServerlessClient, + emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, jobExecutionResponseReader); - String query = "select * from my_glue.default.http_logs"; - UnsupportedOperationException unsupportedOperationException = - Assertions.assertThrows( - UnsupportedOperationException.class, - () -> - sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - query, - LangType.PPL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME))); - Assertions.assertEquals( - "UnSupported Lang type:: PPL", unsupportedOperationException.getMessage()); - verifyNoInteractions(EMRServerlessClient); - verifyNoInteractions(dataSourceService); - verifyNoInteractions(dataSourceUserAuthorizationHelper); - verifyNoInteractions(jobExecutionResponseReader); + String query = "source = my_glue.default.http_logs"; + when(emrServerlessClient.startJobRun( + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString(), + tags))) + .thenReturn(EMR_JOB_ID); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + String jobId = + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.PPL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)) + .startJobRun( + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString(), + tags)); + Assertions.assertEquals(EMR_JOB_ID, jobId); } @Test - void testDispatchQueryWithoutATableName() { + void testDispatchQueryWithoutATableAndDataSourceName() { + HashMap 
tags = new HashMap<>(); + tags.put("datasource", "my_glue"); + tags.put("cluster", TEST_CLUSTER_NAME); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( - EMRServerlessClient, + emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, jobExecutionResponseReader); String query = "show tables"; - UnsupportedOperationException unsupportedOperationException = - Assertions.assertThrows( - UnsupportedOperationException.class, - () -> - sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - query, - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME))); - Assertions.assertEquals( - "Missing datasource in the query syntax.", unsupportedOperationException.getMessage()); - verifyNoInteractions(EMRServerlessClient); - verifyNoInteractions(dataSourceService); - verifyNoInteractions(dataSourceUserAuthorizationHelper); - verifyNoInteractions(jobExecutionResponseReader); - } - - @Test - void testDispatchQueryWithoutADataSourceName() { - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - EMRServerlessClient, - dataSourceService, - dataSourceUserAuthorizationHelper, - jobExecutionResponseReader); - String query = "select * from default.http_logs"; - UnsupportedOperationException unsupportedOperationException = - Assertions.assertThrows( - UnsupportedOperationException.class, - () -> - sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - query, - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME))); - Assertions.assertEquals( - "Missing datasource in the query syntax.", unsupportedOperationException.getMessage()); - verifyNoInteractions(EMRServerlessClient); - verifyNoInteractions(dataSourceService); - verifyNoInteractions(dataSourceUserAuthorizationHelper); - verifyNoInteractions(jobExecutionResponseReader); + when(emrServerlessClient.startJobRun( + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + 
EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString(), + tags))) + .thenReturn(EMR_JOB_ID); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + String jobId = + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)) + .startJobRun( + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString(), + tags)); + Assertions.assertEquals(EMR_JOB_ID, jobId); } @Test void testDispatchIndexQueryWithoutADatasourceName() { + HashMap tags = new HashMap<>(); + tags.put("datasource", "my_glue"); + tags.put("table", "http_logs"); + tags.put("index", "elb_and_requestUri"); + tags.put("cluster", TEST_CLUSTER_NAME); + tags.put("schema", "default"); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( - EMRServerlessClient, + emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, jobExecutionResponseReader); String query = "CREATE INDEX elb_and_requestUri ON default.http_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = true)"; - UnsupportedOperationException unsupportedOperationException = - Assertions.assertThrows( - UnsupportedOperationException.class, - () -> - sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - query, - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME))); - Assertions.assertEquals( - "Queries without a datasource are not supported", - unsupportedOperationException.getMessage()); - verifyNoInteractions(EMRServerlessClient); - verifyNoInteractions(dataSourceService); - 
verifyNoInteractions(dataSourceUserAuthorizationHelper); - verifyNoInteractions(jobExecutionResponseReader); + when(emrServerlessClient.startJobRun( + new StartJobRequest( + query, + "TEST_CLUSTER:my_glue.default.http_logs.elb_and_requestUri", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString(), + tags))) + .thenReturn(EMR_JOB_ID); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + String jobId = + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)) + .startJobRun( + new StartJobRequest( + query, + "TEST_CLUSTER:my_glue.default.http_logs.elb_and_requestUri", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString(), + tags)); + Assertions.assertEquals(EMR_JOB_ID, jobId); } @Test void testDispatchWithWrongURI() { SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( - EMRServerlessClient, + emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, jobExecutionResponseReader); @@ -268,6 +301,7 @@ void testDispatchWithWrongURI() { new DispatchQueryRequest( EMRS_APPLICATION_ID, query, + "my_glue", LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME))); @@ -280,7 +314,7 @@ void testDispatchWithWrongURI() { void testDispatchWithUnSupportedDataSourceType() { SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( - EMRServerlessClient, + emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, jobExecutionResponseReader); @@ -295,6 +329,7 @@ void testDispatchWithUnSupportedDataSourceType() { new DispatchQueryRequest( EMRS_APPLICATION_ID, query, + "my_prometheus", 
LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME))); @@ -307,11 +342,11 @@ void testDispatchWithUnSupportedDataSourceType() { void testCancelJob() { SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( - EMRServerlessClient, + emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, jobExecutionResponseReader); - when(EMRServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID)) + when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn( new CancelJobRunResult() .withJobRunId(EMR_JOB_ID) @@ -324,11 +359,11 @@ void testCancelJob() { void testGetQueryResponse() { SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( - EMRServerlessClient, + emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, jobExecutionResponseReader); - when(EMRServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) + when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.PENDING))); JSONObject result = sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID); Assertions.assertEquals("PENDING", result.get("status")); @@ -339,18 +374,18 @@ void testGetQueryResponse() { void testGetQueryResponseWithSuccess() { SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( - EMRServerlessClient, + emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, jobExecutionResponseReader); - when(EMRServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) + when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.SUCCESS))); JSONObject queryResult = new JSONObject(); queryResult.put("data", "result"); when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID)) .thenReturn(queryResult); JSONObject result = 
sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID); - verify(EMRServerlessClient, times(1)).getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID); + verify(emrServerlessClient, times(1)).getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID); verify(jobExecutionResponseReader, times(1)).getResultFromOpensearchIndex(EMR_JOB_ID); Assertions.assertEquals(new HashSet<>(Arrays.asList("data", "status")), result.keySet()); Assertions.assertEquals("result", result.get("data")); diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java index ef49d29829..8599e4b88e 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestActionTest.java @@ -57,7 +57,7 @@ public void setUp() { @Test public void testDoExecute() { CreateAsyncQueryRequest createAsyncQueryRequest = - new CreateAsyncQueryRequest("source = my_glue.default.alb_logs", LangType.SQL); + new CreateAsyncQueryRequest("source = my_glue.default.alb_logs", "my_glue", LangType.SQL); CreateAsyncQueryActionRequest request = new CreateAsyncQueryActionRequest(createAsyncQueryRequest); when(jobExecutorService.createAsyncQuery(createAsyncQueryRequest)) @@ -73,7 +73,7 @@ public void testDoExecute() { @Test public void testDoExecuteWithException() { CreateAsyncQueryRequest createAsyncQueryRequest = - new CreateAsyncQueryRequest("source = my_glue.default.alb_logs", LangType.SQL); + new CreateAsyncQueryRequest("source = my_glue.default.alb_logs", "my_glue", LangType.SQL); CreateAsyncQueryActionRequest request = new CreateAsyncQueryActionRequest(createAsyncQueryRequest); doThrow(new RuntimeException("Error"))