From 8b02fe86a8ee7f370285b85e5ddb424e21f06e29 Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Thu, 25 Apr 2024 21:51:35 +0000 Subject: [PATCH 1/3] Refactor SparkQueryDispatcher Signed-off-by: Tomoyuki Morita --- .../spark/dispatcher/QueryHandlerFactory.java | 56 +++++++ .../dispatcher/SparkQueryDispatcher.java | 148 ++++++++---------- .../config/AsyncExecutorServiceModule.java | 24 ++- .../AsyncQueryExecutorServiceSpec.java | 15 +- .../dispatcher/SparkQueryDispatcherTest.java | 14 +- 5 files changed, 155 insertions(+), 102 deletions(-) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java new file mode 100644 index 0000000000..f7dcd071bb --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import lombok.RequiredArgsConstructor; +import org.opensearch.client.Client; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.flint.FlintIndexMetadataService; +import org.opensearch.sql.spark.leasemanager.LeaseManager; +import org.opensearch.sql.spark.response.JobExecutionResponseReader; + +@RequiredArgsConstructor +public class QueryHandlerFactory { + + private final JobExecutionResponseReader jobExecutionResponseReader; + private final FlintIndexMetadataService flintIndexMetadataService; + private final Client client; + private final SessionManager sessionManager; + private final LeaseManager leaseManager; + private final StateStore stateStore; + + public RefreshQueryHandler getRefreshQueryHandler(EMRServerlessClient emrServerlessClient) { + return new RefreshQueryHandler( + emrServerlessClient, + jobExecutionResponseReader, + flintIndexMetadataService, + stateStore, + leaseManager); + } + + public StreamingQueryHandler getStreamingQueryHandler(EMRServerlessClient emrServerlessClient) { + return new StreamingQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager); + } + + public BatchQueryHandler getBatchQueryHandler(EMRServerlessClient emrServerlessClient) { + return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager); + } + + public InteractiveQueryHandler getInteractiveQueryHandler() { + return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager); + } + + public IndexDMLHandler getIndexDMLHandler(EMRServerlessClient emrServerlessClient) { + return new IndexDMLHandler( + emrServerlessClient, + jobExecutionResponseReader, + flintIndexMetadataService, + stateStore, + client); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index c4f4c74868..bd2e5e4476 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -8,8 +8,8 @@ import java.util.HashMap; import java.util.Map; import lombok.AllArgsConstructor; +import org.jetbrains.annotations.NotNull; import org.json.JSONObject; -import org.opensearch.client.Client; import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; @@ -23,10 +23,6 @@ import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; import org.opensearch.sql.spark.dispatcher.model.JobType; import org.opensearch.sql.spark.execution.session.SessionManager; -import org.opensearch.sql.spark.execution.statestore.StateStore; -import org.opensearch.sql.spark.flint.FlintIndexMetadataService; -import org.opensearch.sql.spark.leasemanager.LeaseManager; -import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.sql.spark.utils.SQLQueryUtils; @@ -39,65 +35,68 @@ public class SparkQueryDispatcher { public static final String CLUSTER_NAME_TAG_KEY = "domain_ident"; public static final String JOB_TYPE_TAG_KEY = "type"; - private EMRServerlessClientFactory emrServerlessClientFactory; - - private DataSourceService dataSourceService; - - private JobExecutionResponseReader jobExecutionResponseReader; - - private FlintIndexMetadataService flintIndexMetadataService; - - private Client client; - - private SessionManager sessionManager; - - private LeaseManager leaseManager; - - private StateStore stateStore; + private final EMRServerlessClientFactory emrServerlessClientFactory; + private final DataSourceService dataSourceService; + private final SessionManager sessionManager; + private final QueryHandlerFactory queryHandlerFactory; public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) { EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient(); DataSourceMetadata dataSourceMetadata = this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata( dispatchQueryRequest.getDatasource()); - AsyncQueryHandler asyncQueryHandler = - sessionManager.isEnabled() - ? new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager) - : new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager); - DispatchQueryContext.DispatchQueryContextBuilder contextBuilder = - DispatchQueryContext.builder() + + if (LangType.SQL.equals(dispatchQueryRequest.getLangType()) + && SQLQueryUtils.isFlintExtensionQuery(dispatchQueryRequest.getQuery())) { + IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest); + DispatchQueryContext context = getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) + .indexQueryDetails(indexQueryDetails) + .build(); + + return getQueryHandlerForFlintExtensionQuery(indexQueryDetails, emrServerlessClient) + .submit(dispatchQueryRequest, context); + } else { + DispatchQueryContext context = getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) + .build(); + return getDefaultAsyncQueryHandler(emrServerlessClient).submit(dispatchQueryRequest, context); + } + } + + private static DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder(DispatchQueryRequest dispatchQueryRequest, DataSourceMetadata dataSourceMetadata) { + return DispatchQueryContext.builder() .dataSourceMetadata(dataSourceMetadata) .tags(getDefaultTagsForJobSubmission(dispatchQueryRequest)) .queryId(AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName())); + } - // override asyncQueryHandler with specific. - if (LangType.SQL.equals(dispatchQueryRequest.getLangType()) - && SQLQueryUtils.isFlintExtensionQuery(dispatchQueryRequest.getQuery())) { - IndexQueryDetails indexQueryDetails = - SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery()); - fillMissingDetails(dispatchQueryRequest, indexQueryDetails); - contextBuilder.indexQueryDetails(indexQueryDetails); - - if (isEligibleForIndexDMLHandling(indexQueryDetails)) { - asyncQueryHandler = createIndexDMLHandler(emrServerlessClient); - } else if (isEligibleForStreamingQuery(indexQueryDetails)) { - asyncQueryHandler = - new StreamingQueryHandler( - emrServerlessClient, jobExecutionResponseReader, leaseManager); - } else if (IndexQueryActionType.REFRESH.equals(indexQueryDetails.getIndexQueryActionType())) { - // manual refresh should be handled by batch handler - asyncQueryHandler = - new RefreshQueryHandler( - emrServerlessClient, - jobExecutionResponseReader, - flintIndexMetadataService, - stateStore, - leaseManager); - } + private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery(IndexQueryDetails indexQueryDetails, EMRServerlessClient emrServerlessClient) { + if (isEligibleForIndexDMLHandling(indexQueryDetails)) { + return queryHandlerFactory.getIndexDMLHandler(emrServerlessClient); + } else if (isEligibleForStreamingQuery(indexQueryDetails)) { + return queryHandlerFactory.getStreamingQueryHandler(emrServerlessClient); + } else if (IndexQueryActionType.REFRESH.equals(indexQueryDetails.getIndexQueryActionType())) { + // manual refresh should be handled by batch handler + return queryHandlerFactory.getRefreshQueryHandler(emrServerlessClient); + } else { + return getDefaultAsyncQueryHandler(emrServerlessClient); } - return asyncQueryHandler.submit(dispatchQueryRequest, contextBuilder.build()); } + @NotNull + private AsyncQueryHandler getDefaultAsyncQueryHandler(EMRServerlessClient emrServerlessClient) { + return sessionManager.isEnabled() + ? queryHandlerFactory.getInteractiveQueryHandler() + : queryHandlerFactory.getBatchQueryHandler(emrServerlessClient); + } + + @NotNull + private static IndexQueryDetails getIndexQueryDetails(DispatchQueryRequest dispatchQueryRequest) { + IndexQueryDetails indexQueryDetails = SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery()); + fillDatasourceName(dispatchQueryRequest, indexQueryDetails); + return indexQueryDetails; + } + + private boolean isEligibleForStreamingQuery(IndexQueryDetails indexQueryDetails) { Boolean isCreateAutoRefreshIndex = IndexQueryActionType.CREATE.equals(indexQueryDetails.getIndexQueryActionType()) @@ -119,58 +118,33 @@ private boolean isEligibleForIndexDMLHandling(IndexQueryDetails indexQueryDetail } public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) { - EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient(); - if (asyncQueryJobMetadata.getSessionId() != null) { - return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager) - .getQueryResponse(asyncQueryJobMetadata); - } else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) { - return createIndexDMLHandler(emrServerlessClient).getQueryResponse(asyncQueryJobMetadata); - } else { - return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager) - .getQueryResponse(asyncQueryJobMetadata); - } + return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata).getQueryResponse(asyncQueryJobMetadata); } public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata).cancelJob(asyncQueryJobMetadata); + } + + private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery(AsyncQueryJobMetadata asyncQueryJobMetadata) { EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient(); - AsyncQueryHandler queryHandler; if (asyncQueryJobMetadata.getSessionId() != null) { - queryHandler = - new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager); + return queryHandlerFactory.getInteractiveQueryHandler(); } else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) { - queryHandler = createIndexDMLHandler(emrServerlessClient); + return queryHandlerFactory.getIndexDMLHandler(emrServerlessClient); } else if (asyncQueryJobMetadata.getJobType() == JobType.BATCH) { - queryHandler = - new RefreshQueryHandler( - emrServerlessClient, - jobExecutionResponseReader, - flintIndexMetadataService, - stateStore, - leaseManager); + return queryHandlerFactory.getRefreshQueryHandler(emrServerlessClient); } else if (asyncQueryJobMetadata.getJobType() == JobType.STREAMING) { - queryHandler = - new StreamingQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager); + return queryHandlerFactory.getStreamingQueryHandler(emrServerlessClient); } else { - queryHandler = - new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager); + return queryHandlerFactory.getBatchQueryHandler(emrServerlessClient); } - return queryHandler.cancelJob(asyncQueryJobMetadata); - } - - private IndexDMLHandler createIndexDMLHandler(EMRServerlessClient emrServerlessClient) { - return new IndexDMLHandler( - emrServerlessClient, - jobExecutionResponseReader, - flintIndexMetadataService, - stateStore, - client); } // TODO: Revisit this logic. // Currently, Spark if datasource is not provided in query. // Spark Assumes the datasource to be catalog. // This is required to handle drop index case properly when datasource name is not provided. - private static void fillMissingDetails( + private static void fillDatasourceName( DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) { if (indexQueryDetails.getFullyQualifiedTableName() != null && indexQueryDetails.getFullyQualifiedTableName().getDatasourceName() == null) { diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java index 9038870c63..15bd543bda 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java @@ -25,6 +25,7 @@ import org.opensearch.sql.spark.client.EMRServerlessClientFactoryImpl; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl; +import org.opensearch.sql.spark.dispatcher.QueryHandlerFactory; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.execution.statestore.StateStore; @@ -36,7 +37,8 @@ public class AsyncExecutorServiceModule extends AbstractModule { @Override - protected void configure() {} + protected void configure() { + } @Provides public AsyncQueryExecutorService asyncQueryExecutorService( @@ -67,15 +69,27 @@ public StateStore stateStore(NodeClient client, ClusterService clusterService) { public SparkQueryDispatcher sparkQueryDispatcher( EMRServerlessClientFactory emrServerlessClientFactory, DataSourceService dataSourceService, + SessionManager sessionManager, + QueryHandlerFactory queryHandlerFactory + ) { + return new SparkQueryDispatcher( + emrServerlessClientFactory, + dataSourceService, + sessionManager, + queryHandlerFactory + ); + } + + @Provides + public QueryHandlerFactory queryhandlerFactory( JobExecutionResponseReader jobExecutionResponseReader, FlintIndexMetadataServiceImpl flintIndexMetadataReader, NodeClient client, SessionManager sessionManager, DefaultLeaseManager defaultLeaseManager, - StateStore stateStore) { - return new SparkQueryDispatcher( - emrServerlessClientFactory, - dataSourceService, + StateStore stateStore + ) { + return new QueryHandlerFactory( jobExecutionResponseReader, flintIndexMetadataReader, client, diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index c4cb96391b..bebb690737 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -58,6 +58,7 @@ import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; +import org.opensearch.sql.spark.dispatcher.QueryHandlerFactory; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.execution.session.SessionModel; @@ -200,16 +201,20 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( StateStore stateStore = new StateStore(client, clusterService); AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = new OpensearchAsyncQueryJobMetadataStorageService(stateStore); + QueryHandlerFactory queryHandlerFactory = new QueryHandlerFactory( + jobExecutionResponseReader, + new FlintIndexMetadataServiceImpl(client), + client, + new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings), + new DefaultLeaseManager(pluginSettings, stateStore), + stateStore + ); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( emrServerlessClientFactory, this.dataSourceService, - jobExecutionResponseReader, - new FlintIndexMetadataServiceImpl(client), - client, new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings), - new DefaultLeaseManager(pluginSettings, stateStore), - stateStore); + queryHandlerFactory); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 1f250a0aea..76b747f440 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -110,16 +110,20 @@ public class SparkQueryDispatcherTest { @BeforeEach void setUp() { + QueryHandlerFactory queryHandlerFactory = new QueryHandlerFactory( + jobExecutionResponseReader, + flintIndexMetadataService, + openSearchClient, + sessionManager, + leaseManager, + stateStore + ); sparkQueryDispatcher = new SparkQueryDispatcher( emrServerlessClientFactory, dataSourceService, - jobExecutionResponseReader, - flintIndexMetadataService, - openSearchClient, sessionManager, - leaseManager, - stateStore); + queryHandlerFactory); when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); } From bbed24adb6b440634f0e034860e2403b06c2018d Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Fri, 26 Apr 2024 22:39:52 +0000 Subject: [PATCH 2/3] Remove EMRServerlessClientFactory from SparkQueryDispatcher Signed-off-by: Tomoyuki Morita --- .../spark/dispatcher/QueryHandlerFactory.java | 21 +++--- .../dispatcher/SparkQueryDispatcher.java | 69 ++++++++++--------- .../config/AsyncExecutorServiceModule.java | 21 ++---- .../AsyncQueryExecutorServiceSpec.java | 18 ++--- .../dispatcher/SparkQueryDispatcherTest.java | 23 +++---- 5 files changed, 73 insertions(+), 79 deletions(-) diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java index f7dcd071bb..1713bed4e2 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java @@ -7,7 +7,7 @@ import lombok.RequiredArgsConstructor; import org.opensearch.client.Client; -import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; @@ -23,31 +23,34 @@ public class QueryHandlerFactory { private final SessionManager sessionManager; private final LeaseManager leaseManager; private final StateStore stateStore; + private final EMRServerlessClientFactory emrServerlessClientFactory; - public RefreshQueryHandler getRefreshQueryHandler(EMRServerlessClient emrServerlessClient) { + public RefreshQueryHandler getRefreshQueryHandler() { return new RefreshQueryHandler( - emrServerlessClient, + emrServerlessClientFactory.getClient(), jobExecutionResponseReader, flintIndexMetadataService, stateStore, leaseManager); } - public StreamingQueryHandler getStreamingQueryHandler(EMRServerlessClient emrServerlessClient) { - return new StreamingQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager); + public StreamingQueryHandler getStreamingQueryHandler() { + return new StreamingQueryHandler( + emrServerlessClientFactory.getClient(), jobExecutionResponseReader, leaseManager); } - public BatchQueryHandler getBatchQueryHandler(EMRServerlessClient emrServerlessClient) { - return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager); + public BatchQueryHandler getBatchQueryHandler() { + return new BatchQueryHandler( + emrServerlessClientFactory.getClient(), jobExecutionResponseReader, leaseManager); } public InteractiveQueryHandler getInteractiveQueryHandler() { return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager); } - public IndexDMLHandler getIndexDMLHandler(EMRServerlessClient emrServerlessClient) { + public IndexDMLHandler getIndexDMLHandler() { return new IndexDMLHandler( - emrServerlessClient, + emrServerlessClientFactory.getClient(), jobExecutionResponseReader, flintIndexMetadataService, stateStore, diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index bd2e5e4476..b6f5bcceb3 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -14,8 +14,6 @@ import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; -import org.opensearch.sql.spark.client.EMRServerlessClient; -import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; @@ -35,13 +33,11 @@ public class SparkQueryDispatcher { public static final String CLUSTER_NAME_TAG_KEY = "domain_ident"; public static final String JOB_TYPE_TAG_KEY = "type"; - private final EMRServerlessClientFactory emrServerlessClientFactory; private final DataSourceService dataSourceService; private final SessionManager sessionManager; private final QueryHandlerFactory queryHandlerFactory; public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) { - EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient(); DataSourceMetadata dataSourceMetadata = this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata( dispatchQueryRequest.getDatasource()); @@ -49,54 +45,57 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) if (LangType.SQL.equals(dispatchQueryRequest.getLangType()) && SQLQueryUtils.isFlintExtensionQuery(dispatchQueryRequest.getQuery())) { IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest); - DispatchQueryContext context = getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) - .indexQueryDetails(indexQueryDetails) - .build(); + DispatchQueryContext context = + getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) + .indexQueryDetails(indexQueryDetails) + .build(); - return getQueryHandlerForFlintExtensionQuery(indexQueryDetails, emrServerlessClient) - .submit(dispatchQueryRequest, context); + return getQueryHandlerForFlintExtensionQuery(indexQueryDetails) + .submit(dispatchQueryRequest, context); } else { - DispatchQueryContext context = getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) - .build(); - return getDefaultAsyncQueryHandler(emrServerlessClient).submit(dispatchQueryRequest, context); + DispatchQueryContext context = + getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata).build(); + return getDefaultAsyncQueryHandler().submit(dispatchQueryRequest, context); } } - private static DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder(DispatchQueryRequest dispatchQueryRequest, DataSourceMetadata dataSourceMetadata) { + private static DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder( + DispatchQueryRequest dispatchQueryRequest, DataSourceMetadata dataSourceMetadata) { return DispatchQueryContext.builder() - .dataSourceMetadata(dataSourceMetadata) - .tags(getDefaultTagsForJobSubmission(dispatchQueryRequest)) - .queryId(AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName())); + .dataSourceMetadata(dataSourceMetadata) + .tags(getDefaultTagsForJobSubmission(dispatchQueryRequest)) + .queryId(AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName())); } - private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery(IndexQueryDetails indexQueryDetails, EMRServerlessClient emrServerlessClient) { + private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery( + IndexQueryDetails indexQueryDetails) { if (isEligibleForIndexDMLHandling(indexQueryDetails)) { - return queryHandlerFactory.getIndexDMLHandler(emrServerlessClient); + return queryHandlerFactory.getIndexDMLHandler(); } else if (isEligibleForStreamingQuery(indexQueryDetails)) { - return queryHandlerFactory.getStreamingQueryHandler(emrServerlessClient); + return queryHandlerFactory.getStreamingQueryHandler(); } else if (IndexQueryActionType.REFRESH.equals(indexQueryDetails.getIndexQueryActionType())) { // manual refresh should be handled by batch handler - return queryHandlerFactory.getRefreshQueryHandler(emrServerlessClient); + return queryHandlerFactory.getRefreshQueryHandler(); } else { - return getDefaultAsyncQueryHandler(emrServerlessClient); + return getDefaultAsyncQueryHandler(); } } @NotNull - private AsyncQueryHandler getDefaultAsyncQueryHandler(EMRServerlessClient emrServerlessClient) { + private AsyncQueryHandler getDefaultAsyncQueryHandler() { return sessionManager.isEnabled() - ? queryHandlerFactory.getInteractiveQueryHandler() - : queryHandlerFactory.getBatchQueryHandler(emrServerlessClient); + ? queryHandlerFactory.getInteractiveQueryHandler() + : queryHandlerFactory.getBatchQueryHandler(); } @NotNull private static IndexQueryDetails getIndexQueryDetails(DispatchQueryRequest dispatchQueryRequest) { - IndexQueryDetails indexQueryDetails = SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery()); + IndexQueryDetails indexQueryDetails = + SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery()); fillDatasourceName(dispatchQueryRequest, indexQueryDetails); return indexQueryDetails; } - private boolean isEligibleForStreamingQuery(IndexQueryDetails indexQueryDetails) { Boolean isCreateAutoRefreshIndex = IndexQueryActionType.CREATE.equals(indexQueryDetails.getIndexQueryActionType()) @@ -118,25 +117,27 @@ private boolean isEligibleForIndexDMLHandling(IndexQueryDetails indexQueryDetail } public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) { - return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata).getQueryResponse(asyncQueryJobMetadata); + return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata) + .getQueryResponse(asyncQueryJobMetadata); } public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { - return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata).cancelJob(asyncQueryJobMetadata); + return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata) + .cancelJob(asyncQueryJobMetadata); } - private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery(AsyncQueryJobMetadata asyncQueryJobMetadata) { - EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient(); + private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery( + AsyncQueryJobMetadata asyncQueryJobMetadata) { if (asyncQueryJobMetadata.getSessionId() != null) { return queryHandlerFactory.getInteractiveQueryHandler(); } else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) { - return queryHandlerFactory.getIndexDMLHandler(emrServerlessClient); + return queryHandlerFactory.getIndexDMLHandler(); } else if (asyncQueryJobMetadata.getJobType() == JobType.BATCH) { - return queryHandlerFactory.getRefreshQueryHandler(emrServerlessClient); + return queryHandlerFactory.getRefreshQueryHandler(); } else if (asyncQueryJobMetadata.getJobType() == JobType.STREAMING) { - return queryHandlerFactory.getStreamingQueryHandler(emrServerlessClient); + return queryHandlerFactory.getStreamingQueryHandler(); } else { - return queryHandlerFactory.getBatchQueryHandler(emrServerlessClient); + return queryHandlerFactory.getBatchQueryHandler(); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java index 15bd543bda..f93d065855 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java @@ -37,8 +37,7 @@ public class AsyncExecutorServiceModule extends AbstractModule { @Override - protected void configure() { - } + protected void configure() {} @Provides public AsyncQueryExecutorService asyncQueryExecutorService( @@ -67,17 +66,10 @@ public StateStore stateStore(NodeClient client, ClusterService clusterService) { @Provides public SparkQueryDispatcher sparkQueryDispatcher( - EMRServerlessClientFactory emrServerlessClientFactory, DataSourceService dataSourceService, SessionManager sessionManager, - QueryHandlerFactory queryHandlerFactory - ) { - return new SparkQueryDispatcher( - emrServerlessClientFactory, - dataSourceService, - sessionManager, - queryHandlerFactory - ); + QueryHandlerFactory queryHandlerFactory) { + return new SparkQueryDispatcher(dataSourceService, sessionManager, queryHandlerFactory); } @Provides @@ -87,15 +79,16 @@ public QueryHandlerFactory queryhandlerFactory( NodeClient client, SessionManager sessionManager, DefaultLeaseManager defaultLeaseManager, - StateStore stateStore - ) { + StateStore stateStore, + EMRServerlessClientFactory emrServerlessClientFactory) { return new QueryHandlerFactory( jobExecutionResponseReader, flintIndexMetadataReader, client, sessionManager, defaultLeaseManager, - stateStore); + stateStore, + emrServerlessClientFactory); } @Provides diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index bebb690737..fdd094259f 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -201,17 +201,17 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( StateStore stateStore = new StateStore(client, clusterService); AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = new OpensearchAsyncQueryJobMetadataStorageService(stateStore); - QueryHandlerFactory queryHandlerFactory = new QueryHandlerFactory( - jobExecutionResponseReader, - new FlintIndexMetadataServiceImpl(client), - client, - new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings), - new DefaultLeaseManager(pluginSettings, stateStore), - stateStore - ); + QueryHandlerFactory queryHandlerFactory = + new QueryHandlerFactory( + jobExecutionResponseReader, + new FlintIndexMetadataServiceImpl(client), + client, + new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings), + new DefaultLeaseManager(pluginSettings, stateStore), + stateStore, + emrServerlessClientFactory); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( - emrServerlessClientFactory, this.dataSourceService, new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings), queryHandlerFactory); diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 76b747f440..49242728b7 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -110,20 +110,17 @@ public class SparkQueryDispatcherTest { @BeforeEach void setUp() { - QueryHandlerFactory queryHandlerFactory = new QueryHandlerFactory( - jobExecutionResponseReader, - flintIndexMetadataService, - openSearchClient, - sessionManager, - leaseManager, - stateStore - ); - sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClientFactory, - dataSourceService, + QueryHandlerFactory queryHandlerFactory = + new QueryHandlerFactory( + jobExecutionResponseReader, + flintIndexMetadataService, + openSearchClient, sessionManager, - queryHandlerFactory); + leaseManager, + stateStore, + emrServerlessClientFactory); + sparkQueryDispatcher = + new SparkQueryDispatcher(dataSourceService, sessionManager, queryHandlerFactory); when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); } From 02137dc180f35b97f942bd18dd4c7a2714e11936 Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Mon, 29 Apr 2024 11:55:14 -0700 Subject: [PATCH 3/3] Fix unit test failures in SparkQueryDispatcherTest Signed-off-by: Tomoyuki Morita --- .../dispatcher/SparkQueryDispatcherTest.java | 36 +++++++++++++++++-- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 49242728b7..60f09330d7 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -121,11 +121,11 @@ void setUp() { emrServerlessClientFactory); sparkQueryDispatcher = new SparkQueryDispatcher(dataSourceService, sessionManager, queryHandlerFactory); - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); } @Test void testDispatchSelectQuery() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -163,6 +163,7 @@ void testDispatchSelectQuery() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); @@ -171,6 +172,7 @@ void testDispatchSelectQuery() { @Test void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -199,6 +201,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithBasicAuth(); when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) .thenReturn(dataSourceMetadata); + DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch( new DispatchQueryRequest( @@ -208,6 +211,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); @@ -216,6 +220,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { @Test void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -324,6 +329,7 @@ void testDispatchSelectQueryFailedCreateSession() { @Test void testDispatchIndexQuery() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index"); @@ -374,6 +380,7 @@ void testDispatchIndexQuery() { @Test void testDispatchWithPPLQuery() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -401,6 +408,7 @@ void testDispatchWithPPLQuery() { DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) .thenReturn(dataSourceMetadata); + DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch( new DispatchQueryRequest( @@ -410,6 +418,7 @@ void testDispatchWithPPLQuery() { LangType.PPL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); @@ -418,6 +427,7 @@ void testDispatchWithPPLQuery() { @Test void testDispatchQueryWithoutATableAndDataSourceName() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -463,6 +473,7 @@ void testDispatchQueryWithoutATableAndDataSourceName() { @Test void testDispatchIndexQueryWithoutADatasourceName() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index"); @@ -512,6 +523,7 @@ void testDispatchIndexQueryWithoutADatasourceName() { @Test void testDispatchMaterializedViewQuery() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(INDEX_TAG_KEY, "flint_mv_1"); @@ -561,6 +573,7 @@ void testDispatchMaterializedViewQuery() { @Test void testDispatchShowMVQuery() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -606,6 +619,7 @@ void testDispatchShowMVQuery() { @Test void testRefreshIndexQuery() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -651,6 +665,7 @@ void testRefreshIndexQuery() { @Test void testDispatchDescribeIndexQuery() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -699,6 +714,7 @@ void testDispatchWithWrongURI() { when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) .thenReturn(constructMyGlueDataSourceMetadataWithBadURISyntax()); String query = "select * from my_glue.default.http_logs"; + IllegalArgumentException illegalArgumentException = Assertions.assertThrows( IllegalArgumentException.class, @@ -711,6 +727,7 @@ void testDispatchWithWrongURI() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME))); + Assertions.assertEquals( "Bad URI in indexstore configuration of the : my_glue datasoure.", illegalArgumentException.getMessage()); @@ -721,6 +738,7 @@ void testDispatchWithUnSupportedDataSourceType() { when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_prometheus")) .thenReturn(constructPrometheusDataSourceType()); String query = "select * from my_prometheus.default.http_logs"; + UnsupportedOperationException unsupportedOperationException = Assertions.assertThrows( UnsupportedOperationException.class, @@ -733,6 +751,7 @@ void testDispatchWithUnSupportedDataSourceType() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME))); + Assertions.assertEquals( "UnSupported datasource type for async queries:: PROMETHEUS", unsupportedOperationException.getMessage()); @@ -740,12 +759,15 @@ void testDispatchWithUnSupportedDataSourceType() { @Test void testCancelJob() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false)) .thenReturn( new CancelJobRunResult() .withJobRunId(EMR_JOB_ID) .withApplicationId(EMRS_APPLICATION_ID)); + String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata()); + Assertions.assertEquals(QUERY_ID.getId(), queryId); } @@ -800,24 +822,29 @@ void testCancelQueryWithInvalidStatementId() { @Test void testCancelQueryWithNoSessionId() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false)) .thenReturn( new CancelJobRunResult() .withJobRunId(EMR_JOB_ID) .withApplicationId(EMRS_APPLICATION_ID)); + String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata()); + Assertions.assertEquals(QUERY_ID.getId(), queryId); } @Test void testGetQueryResponse() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.PENDING))); - // simulate result index is not created yet when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null)) .thenReturn(new JSONObject()); + JSONObject result = sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata()); + Assertions.assertEquals("PENDING", result.get("status")); } @@ -827,10 +854,10 @@ void testGetQueryResponseWithSession() { doReturn(Optional.of(statement)).when(session).get(any()); when(statement.getStatementModel().getError()).thenReturn("mock error"); doReturn(StatementState.WAITING).when(statement).getStatementState(); - doReturn(new JSONObject()) .when(jobExecutionResponseReader) .getResultWithQueryId(eq(MOCK_STATEMENT_ID), any()); + JSONObject result = sparkQueryDispatcher.getQueryResponse( asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID)); @@ -845,6 +872,7 @@ void testGetQueryResponseWithInvalidSession() { doReturn(new JSONObject()) .when(jobExecutionResponseReader) .getResultWithQueryId(eq(MOCK_STATEMENT_ID), any()); + IllegalArgumentException exception = Assertions.assertThrows( IllegalArgumentException.class, @@ -871,6 +899,7 @@ void testGetQueryResponseWithStatementNotExist() { () -> sparkQueryDispatcher.getQueryResponse( asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID))); + verifyNoInteractions(emrServerlessClient); Assertions.assertEquals( "no statement found. " + new StatementId(MOCK_STATEMENT_ID), exception.getMessage()); @@ -904,6 +933,7 @@ void testGetQueryResponseWithSuccess() { @Test void testDispatchQueryWithExtraSparkSubmitParameters() { + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) .thenReturn(dataSourceMetadata);