diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java new file mode 100644 index 0000000000..1713bed4e2 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import lombok.RequiredArgsConstructor; +import org.opensearch.client.Client; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.flint.FlintIndexMetadataService; +import org.opensearch.sql.spark.leasemanager.LeaseManager; +import org.opensearch.sql.spark.response.JobExecutionResponseReader; + +@RequiredArgsConstructor +public class QueryHandlerFactory { + + private final JobExecutionResponseReader jobExecutionResponseReader; + private final FlintIndexMetadataService flintIndexMetadataService; + private final Client client; + private final SessionManager sessionManager; + private final LeaseManager leaseManager; + private final StateStore stateStore; + private final EMRServerlessClientFactory emrServerlessClientFactory; + + public RefreshQueryHandler getRefreshQueryHandler() { + return new RefreshQueryHandler( + emrServerlessClientFactory.getClient(), + jobExecutionResponseReader, + flintIndexMetadataService, + stateStore, + leaseManager); + } + + public StreamingQueryHandler getStreamingQueryHandler() { + return new StreamingQueryHandler( + emrServerlessClientFactory.getClient(), jobExecutionResponseReader, leaseManager); + } + + public BatchQueryHandler getBatchQueryHandler() { + return new BatchQueryHandler( + emrServerlessClientFactory.getClient(), jobExecutionResponseReader, leaseManager); + } + + public InteractiveQueryHandler getInteractiveQueryHandler() { + return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager); + } + + public IndexDMLHandler getIndexDMLHandler() { + return new IndexDMLHandler( + emrServerlessClientFactory.getClient(), + jobExecutionResponseReader, + flintIndexMetadataService, + stateStore, + client); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index c4f4c74868..b6f5bcceb3 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -8,14 +8,12 @@ import java.util.HashMap; import java.util.Map; import lombok.AllArgsConstructor; +import org.jetbrains.annotations.NotNull; import org.json.JSONObject; -import org.opensearch.client.Client; import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; -import org.opensearch.sql.spark.client.EMRServerlessClient; -import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; @@ -23,10 +21,6 @@ import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; import org.opensearch.sql.spark.dispatcher.model.JobType; import org.opensearch.sql.spark.execution.session.SessionManager; -import org.opensearch.sql.spark.execution.statestore.StateStore; -import org.opensearch.sql.spark.flint.FlintIndexMetadataService; -import org.opensearch.sql.spark.leasemanager.LeaseManager; -import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.sql.spark.utils.SQLQueryUtils; @@ -39,63 +33,67 @@ public class SparkQueryDispatcher { public static final String CLUSTER_NAME_TAG_KEY = "domain_ident"; public static final String JOB_TYPE_TAG_KEY = "type"; - private EMRServerlessClientFactory emrServerlessClientFactory; - - private DataSourceService dataSourceService; - - private JobExecutionResponseReader jobExecutionResponseReader; - - private FlintIndexMetadataService flintIndexMetadataService; - - private Client client; - - private SessionManager sessionManager; - - private LeaseManager leaseManager; - - private StateStore stateStore; + private final DataSourceService dataSourceService; + private final SessionManager sessionManager; + private final QueryHandlerFactory queryHandlerFactory; public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) { - EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient(); DataSourceMetadata dataSourceMetadata = this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata( dispatchQueryRequest.getDatasource()); - AsyncQueryHandler asyncQueryHandler = - sessionManager.isEnabled() - ? new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager) - : new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager); - DispatchQueryContext.DispatchQueryContextBuilder contextBuilder = - DispatchQueryContext.builder() - .dataSourceMetadata(dataSourceMetadata) - .tags(getDefaultTagsForJobSubmission(dispatchQueryRequest)) - .queryId(AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName())); - - // override asyncQueryHandler with specific. + if (LangType.SQL.equals(dispatchQueryRequest.getLangType()) && SQLQueryUtils.isFlintExtensionQuery(dispatchQueryRequest.getQuery())) { - IndexQueryDetails indexQueryDetails = - SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery()); - fillMissingDetails(dispatchQueryRequest, indexQueryDetails); - contextBuilder.indexQueryDetails(indexQueryDetails); - - if (isEligibleForIndexDMLHandling(indexQueryDetails)) { - asyncQueryHandler = createIndexDMLHandler(emrServerlessClient); - } else if (isEligibleForStreamingQuery(indexQueryDetails)) { - asyncQueryHandler = - new StreamingQueryHandler( - emrServerlessClient, jobExecutionResponseReader, leaseManager); - } else if (IndexQueryActionType.REFRESH.equals(indexQueryDetails.getIndexQueryActionType())) { - // manual refresh should be handled by batch handler - asyncQueryHandler = - new RefreshQueryHandler( - emrServerlessClient, - jobExecutionResponseReader, - flintIndexMetadataService, - stateStore, - leaseManager); - } + IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest); + DispatchQueryContext context = + getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) + .indexQueryDetails(indexQueryDetails) + .build(); + + return getQueryHandlerForFlintExtensionQuery(indexQueryDetails) + .submit(dispatchQueryRequest, context); + } else { + DispatchQueryContext context = + getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata).build(); + return getDefaultAsyncQueryHandler().submit(dispatchQueryRequest, context); } - return asyncQueryHandler.submit(dispatchQueryRequest, contextBuilder.build()); + } + + private static DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder( + DispatchQueryRequest dispatchQueryRequest, DataSourceMetadata dataSourceMetadata) { + return DispatchQueryContext.builder() + .dataSourceMetadata(dataSourceMetadata) + .tags(getDefaultTagsForJobSubmission(dispatchQueryRequest)) + .queryId(AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName())); + } + + private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery( + IndexQueryDetails indexQueryDetails) { + if (isEligibleForIndexDMLHandling(indexQueryDetails)) { + return queryHandlerFactory.getIndexDMLHandler(); + } else if (isEligibleForStreamingQuery(indexQueryDetails)) { + return queryHandlerFactory.getStreamingQueryHandler(); + } else if (IndexQueryActionType.REFRESH.equals(indexQueryDetails.getIndexQueryActionType())) { + // manual refresh should be handled by batch handler + return queryHandlerFactory.getRefreshQueryHandler(); + } else { + return getDefaultAsyncQueryHandler(); + } + } + + @NotNull + private AsyncQueryHandler getDefaultAsyncQueryHandler() { + return sessionManager.isEnabled() + ? queryHandlerFactory.getInteractiveQueryHandler() + : queryHandlerFactory.getBatchQueryHandler(); + } + + @NotNull + private static IndexQueryDetails getIndexQueryDetails(DispatchQueryRequest dispatchQueryRequest) { + IndexQueryDetails indexQueryDetails = + SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery()); + fillDatasourceName(dispatchQueryRequest, indexQueryDetails); + return indexQueryDetails; } private boolean isEligibleForStreamingQuery(IndexQueryDetails indexQueryDetails) { @@ -119,58 +117,35 @@ private boolean isEligibleForIndexDMLHandling(IndexQueryDetails indexQueryDetail } public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) { - EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient(); - if (asyncQueryJobMetadata.getSessionId() != null) { - return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager) - .getQueryResponse(asyncQueryJobMetadata); - } else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) { - return createIndexDMLHandler(emrServerlessClient).getQueryResponse(asyncQueryJobMetadata); - } else { - return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager) - .getQueryResponse(asyncQueryJobMetadata); - } + return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata) + .getQueryResponse(asyncQueryJobMetadata); } public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { - EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient(); - AsyncQueryHandler queryHandler; + return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata) + .cancelJob(asyncQueryJobMetadata); + } + + private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery( + AsyncQueryJobMetadata asyncQueryJobMetadata) { if (asyncQueryJobMetadata.getSessionId() != null) { - queryHandler = - new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager); + return queryHandlerFactory.getInteractiveQueryHandler(); } else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) { - queryHandler = createIndexDMLHandler(emrServerlessClient); + return queryHandlerFactory.getIndexDMLHandler(); } else if (asyncQueryJobMetadata.getJobType() == JobType.BATCH) { - queryHandler = - new RefreshQueryHandler( - emrServerlessClient, - jobExecutionResponseReader, - flintIndexMetadataService, - stateStore, - leaseManager); + return queryHandlerFactory.getRefreshQueryHandler(); } else if (asyncQueryJobMetadata.getJobType() == JobType.STREAMING) { - queryHandler = - new StreamingQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager); + return queryHandlerFactory.getStreamingQueryHandler(); } else { - queryHandler = - new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager); + return queryHandlerFactory.getBatchQueryHandler(); } - return queryHandler.cancelJob(asyncQueryJobMetadata); - } - - private IndexDMLHandler createIndexDMLHandler(EMRServerlessClient emrServerlessClient) { - return new IndexDMLHandler( - emrServerlessClient, - jobExecutionResponseReader, - flintIndexMetadataService, - stateStore, - client); } // TODO: Revisit this logic. // Currently, Spark if datasource is not provided in query. // Spark Assumes the datasource to be catalog. // This is required to handle drop index case properly when datasource name is not provided. - private static void fillMissingDetails( + private static void fillDatasourceName( DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) { if (indexQueryDetails.getFullyQualifiedTableName() != null && indexQueryDetails.getFullyQualifiedTableName().getDatasourceName() == null) { diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java index 9038870c63..f93d065855 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java @@ -25,6 +25,7 @@ import org.opensearch.sql.spark.client.EMRServerlessClientFactoryImpl; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl; +import org.opensearch.sql.spark.dispatcher.QueryHandlerFactory; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.execution.statestore.StateStore; @@ -65,23 +66,29 @@ public StateStore stateStore(NodeClient client, ClusterService clusterService) { @Provides public SparkQueryDispatcher sparkQueryDispatcher( - EMRServerlessClientFactory emrServerlessClientFactory, DataSourceService dataSourceService, + SessionManager sessionManager, + QueryHandlerFactory queryHandlerFactory) { + return new SparkQueryDispatcher(dataSourceService, sessionManager, queryHandlerFactory); + } + + @Provides + public QueryHandlerFactory queryhandlerFactory( JobExecutionResponseReader jobExecutionResponseReader, FlintIndexMetadataServiceImpl flintIndexMetadataReader, NodeClient client, SessionManager sessionManager, DefaultLeaseManager defaultLeaseManager, - StateStore stateStore) { - return new SparkQueryDispatcher( - emrServerlessClientFactory, - dataSourceService, + StateStore stateStore, + EMRServerlessClientFactory emrServerlessClientFactory) { + return new QueryHandlerFactory( jobExecutionResponseReader, flintIndexMetadataReader, client, sessionManager, defaultLeaseManager, - stateStore); + stateStore, + emrServerlessClientFactory); } @Provides diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index c4cb96391b..fdd094259f 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -58,6 +58,7 @@ import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; +import org.opensearch.sql.spark.dispatcher.QueryHandlerFactory; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.execution.session.SessionModel; @@ -200,16 +201,20 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( StateStore stateStore = new StateStore(client, clusterService); AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = new OpensearchAsyncQueryJobMetadataStorageService(stateStore); - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClientFactory, - this.dataSourceService, + QueryHandlerFactory queryHandlerFactory = + new QueryHandlerFactory( jobExecutionResponseReader, new FlintIndexMetadataServiceImpl(client), client, new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings), new DefaultLeaseManager(pluginSettings, stateStore), - stateStore); + stateStore, + emrServerlessClientFactory); + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + this.dataSourceService, + new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings), + queryHandlerFactory); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index bdadbc13df..ec87a86717 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -110,21 +110,22 @@ public class SparkQueryDispatcherTest { @BeforeEach void setUp() { - sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClientFactory, - dataSourceService, + QueryHandlerFactory queryHandlerFactory = + new QueryHandlerFactory( jobExecutionResponseReader, flintIndexMetadataService, openSearchClient, sessionManager, leaseManager, - stateStore); - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + stateStore, + emrServerlessClientFactory); + sparkQueryDispatcher = + new SparkQueryDispatcher(dataSourceService, sessionManager, queryHandlerFactory); } @Test void testDispatchSelectQuery() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -162,6 +163,7 @@ void testDispatchSelectQuery() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); @@ -216,6 +218,7 @@ void testDispatchSelectQueryWithLakeFormation() { @Test void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -244,6 +247,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithBasicAuth(); when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) .thenReturn(dataSourceMetadata); + DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch( new DispatchQueryRequest( @@ -253,6 +257,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); @@ -261,6 +266,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { @Test void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -369,6 +375,7 @@ void testDispatchSelectQueryFailedCreateSession() { @Test void testDispatchIndexQuery() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index"); @@ -419,6 +426,7 @@ void testDispatchIndexQuery() { @Test void testDispatchWithPPLQuery() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -446,6 +454,7 @@ void testDispatchWithPPLQuery() { DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) .thenReturn(dataSourceMetadata); + DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch( new DispatchQueryRequest( @@ -455,6 +464,7 @@ void testDispatchWithPPLQuery() { LangType.PPL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); @@ -463,6 +473,7 @@ void testDispatchWithPPLQuery() { @Test void testDispatchQueryWithoutATableAndDataSourceName() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -508,6 +519,7 @@ void testDispatchQueryWithoutATableAndDataSourceName() { @Test void testDispatchIndexQueryWithoutADatasourceName() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index"); @@ -557,6 +569,7 @@ void testDispatchIndexQueryWithoutADatasourceName() { @Test void testDispatchMaterializedViewQuery() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(INDEX_TAG_KEY, "flint_mv_1"); @@ -606,6 +619,7 @@ void testDispatchMaterializedViewQuery() { @Test void testDispatchShowMVQuery() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -651,6 +665,7 @@ void testDispatchShowMVQuery() { @Test void testRefreshIndexQuery() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -696,6 +711,7 @@ void testRefreshIndexQuery() { @Test void testDispatchDescribeIndexQuery() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -744,6 +760,7 @@ void testDispatchWithWrongURI() { when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) .thenReturn(constructMyGlueDataSourceMetadataWithBadURISyntax()); String query = "select * from my_glue.default.http_logs"; + IllegalArgumentException illegalArgumentException = Assertions.assertThrows( IllegalArgumentException.class, @@ -756,6 +773,7 @@ void testDispatchWithWrongURI() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME))); + Assertions.assertEquals( "Bad URI in indexstore configuration of the : my_glue datasoure.", illegalArgumentException.getMessage()); @@ -766,6 +784,7 @@ void testDispatchWithUnSupportedDataSourceType() { when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_prometheus")) .thenReturn(constructPrometheusDataSourceType()); String query = "select * from my_prometheus.default.http_logs"; + UnsupportedOperationException unsupportedOperationException = Assertions.assertThrows( UnsupportedOperationException.class, @@ -778,6 +797,7 @@ void testDispatchWithUnSupportedDataSourceType() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME))); + Assertions.assertEquals( "UnSupported datasource type for async queries:: PROMETHEUS", unsupportedOperationException.getMessage()); @@ -785,12 +805,15 @@ void testDispatchWithUnSupportedDataSourceType() { @Test void testCancelJob() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false)) .thenReturn( new CancelJobRunResult() .withJobRunId(EMR_JOB_ID) .withApplicationId(EMRS_APPLICATION_ID)); + String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata()); + Assertions.assertEquals(QUERY_ID.getId(), queryId); } @@ -845,24 +868,29 @@ void testCancelQueryWithInvalidStatementId() { @Test void testCancelQueryWithNoSessionId() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false)) .thenReturn( new CancelJobRunResult() .withJobRunId(EMR_JOB_ID) .withApplicationId(EMRS_APPLICATION_ID)); + String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata()); + Assertions.assertEquals(QUERY_ID.getId(), queryId); } @Test void testGetQueryResponse() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.PENDING))); - // simulate result index is not created yet when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null)) .thenReturn(new JSONObject()); + JSONObject result = sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata()); + Assertions.assertEquals("PENDING", result.get("status")); } @@ -872,10 +900,10 @@ void testGetQueryResponseWithSession() { doReturn(Optional.of(statement)).when(session).get(any()); when(statement.getStatementModel().getError()).thenReturn("mock error"); doReturn(StatementState.WAITING).when(statement).getStatementState(); - doReturn(new JSONObject()) .when(jobExecutionResponseReader) .getResultWithQueryId(eq(MOCK_STATEMENT_ID), any()); + JSONObject result = sparkQueryDispatcher.getQueryResponse( asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID)); @@ -890,6 +918,7 @@ void testGetQueryResponseWithInvalidSession() { doReturn(new JSONObject()) .when(jobExecutionResponseReader) .getResultWithQueryId(eq(MOCK_STATEMENT_ID), any()); + IllegalArgumentException exception = Assertions.assertThrows( IllegalArgumentException.class, @@ -916,6 +945,7 @@ void testGetQueryResponseWithStatementNotExist() { () -> sparkQueryDispatcher.getQueryResponse( asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID))); + verifyNoInteractions(emrServerlessClient); Assertions.assertEquals( "no statement found. " + new StatementId(MOCK_STATEMENT_ID), exception.getMessage()); @@ -949,6 +979,7 @@ void testGetQueryResponseWithSuccess() { @Test void testDispatchQueryWithExtraSparkSubmitParameters() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) .thenReturn(dataSourceMetadata);