From e24b51fd392985bee4224f4c2a60d0dbe2a5e1b4 Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Fri, 28 Jun 2024 16:25:45 -0700 Subject: [PATCH] Pass accountId to EMRServerlessClientFactory.getClient (#2783) Signed-off-by: Tomoyuki Morita --- .../client/EMRServerlessClientFactory.java | 3 +- .../EMRServerlessClientFactoryImpl.java | 8 +--- .../spark/dispatcher/QueryHandlerFactory.java | 12 ++--- .../dispatcher/SparkQueryDispatcher.java | 25 +++++----- .../execution/session/SessionManager.java | 4 +- .../spark/flint/operation/FlintIndexOp.java | 3 +- .../EMRServerlessClientFactoryImplTest.java | 19 +++++--- .../dispatcher/SparkQueryDispatcherTest.java | 36 +++++++------- ...AsyncQueryExecutorServiceImplSpecTest.java | 34 ++++++------- .../AsyncQueryExecutorServiceSpec.java | 2 +- .../AsyncQueryGetResultSpecTest.java | 2 +- .../asyncquery/IndexQuerySpecAlterTest.java | 48 +++++++++---------- .../spark/asyncquery/IndexQuerySpecTest.java | 36 +++++++------- .../asyncquery/IndexQuerySpecVacuumTest.java | 2 +- .../FlintStreamingJobHouseKeeperTaskTest.java | 40 ++++++++++++---- .../session/InteractiveSessionTest.java | 2 +- .../execution/statement/StatementTest.java | 4 +- 17 files changed, 152 insertions(+), 128 deletions(-) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactory.java b/async-query-core/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactory.java index 2c05dc865d..c5305ba445 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactory.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactory.java @@ -11,7 +11,8 @@ public interface EMRServerlessClientFactory { /** * Gets an instance of {@link EMRServerlessClient}. * + * @param accountId Account ID of the requester. It will be used to decide the cluster. * @return An {@link EMRServerlessClient} instance. */ - EMRServerlessClient getClient(); + EMRServerlessClient getClient(String accountId); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java b/async-query-core/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java index 33c0e9fbfa..72973b3bbb 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java @@ -18,7 +18,6 @@ import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; import org.opensearch.sql.spark.metrics.MetricsService; -/** Implementation of {@link EMRServerlessClientFactory}. */ @RequiredArgsConstructor public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactory { @@ -27,13 +26,8 @@ public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactor private EMRServerlessClient emrServerlessClient; private String region; - /** - * Gets an instance of {@link EMRServerlessClient}. - * - * @return An {@link EMRServerlessClient} instance. - */ @Override - public EMRServerlessClient getClient() { + public EMRServerlessClient getClient(String accountId) { SparkExecutionEngineConfig sparkExecutionEngineConfig = this.sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig( new NullAsyncQueryRequestContext()); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java index 9951edc5a9..90329f2f9a 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java @@ -27,9 +27,9 @@ public class QueryHandlerFactory { private final EMRServerlessClientFactory emrServerlessClientFactory; private final MetricsService metricsService; - public RefreshQueryHandler getRefreshQueryHandler() { + public RefreshQueryHandler getRefreshQueryHandler(String accountId) { return new RefreshQueryHandler( - emrServerlessClientFactory.getClient(), + emrServerlessClientFactory.getClient(accountId), jobExecutionResponseReader, flintIndexMetadataService, leaseManager, @@ -37,17 +37,17 @@ public RefreshQueryHandler getRefreshQueryHandler() { metricsService); } - public StreamingQueryHandler getStreamingQueryHandler() { + public StreamingQueryHandler getStreamingQueryHandler(String accountId) { return new StreamingQueryHandler( - emrServerlessClientFactory.getClient(), + emrServerlessClientFactory.getClient(accountId), jobExecutionResponseReader, leaseManager, metricsService); } - public BatchQueryHandler getBatchQueryHandler() { + public BatchQueryHandler getBatchQueryHandler(String accountId) { return new BatchQueryHandler( - emrServerlessClientFactory.getClient(), + emrServerlessClientFactory.getClient(accountId), jobExecutionResponseReader, leaseManager, metricsService); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 5facdee567..3366e21894 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -54,14 +54,15 @@ public DispatchQueryResponse dispatch( .asyncQueryRequestContext(asyncQueryRequestContext) .build(); - return getQueryHandlerForFlintExtensionQuery(indexQueryDetails) + return getQueryHandlerForFlintExtensionQuery(dispatchQueryRequest, indexQueryDetails) .submit(dispatchQueryRequest, context); } else { DispatchQueryContext context = getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) .asyncQueryRequestContext(asyncQueryRequestContext) .build(); - return getDefaultAsyncQueryHandler().submit(dispatchQueryRequest, context); + return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId()) + .submit(dispatchQueryRequest, context); } } @@ -74,28 +75,28 @@ private DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchConte } private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery( - IndexQueryDetails indexQueryDetails) { + DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) { if (isEligibleForIndexDMLHandling(indexQueryDetails)) { return queryHandlerFactory.getIndexDMLHandler(); } else if (isEligibleForStreamingQuery(indexQueryDetails)) { - return queryHandlerFactory.getStreamingQueryHandler(); + return queryHandlerFactory.getStreamingQueryHandler(dispatchQueryRequest.getAccountId()); } else if (IndexQueryActionType.CREATE.equals(indexQueryDetails.getIndexQueryActionType())) { // Create should be handled by batch handler. This is to avoid DROP index incorrectly cancel // an interactive job. - return queryHandlerFactory.getBatchQueryHandler(); + return queryHandlerFactory.getBatchQueryHandler(dispatchQueryRequest.getAccountId()); } else if (IndexQueryActionType.REFRESH.equals(indexQueryDetails.getIndexQueryActionType())) { // Manual refresh should be handled by batch handler - return queryHandlerFactory.getRefreshQueryHandler(); + return queryHandlerFactory.getRefreshQueryHandler(dispatchQueryRequest.getAccountId()); } else { - return getDefaultAsyncQueryHandler(); + return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId()); } } @NotNull - private AsyncQueryHandler getDefaultAsyncQueryHandler() { + private AsyncQueryHandler getDefaultAsyncQueryHandler(String accountId) { return sessionManager.isEnabled() ? queryHandlerFactory.getInteractiveQueryHandler() - : queryHandlerFactory.getBatchQueryHandler(); + : queryHandlerFactory.getBatchQueryHandler(accountId); } @NotNull @@ -143,11 +144,11 @@ private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery( } else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) { return queryHandlerFactory.getIndexDMLHandler(); } else if (asyncQueryJobMetadata.getJobType() == JobType.BATCH) { - return queryHandlerFactory.getRefreshQueryHandler(); + return queryHandlerFactory.getRefreshQueryHandler(asyncQueryJobMetadata.getAccountId()); } else if (asyncQueryJobMetadata.getJobType() == JobType.STREAMING) { - return queryHandlerFactory.getStreamingQueryHandler(); + return queryHandlerFactory.getStreamingQueryHandler(asyncQueryJobMetadata.getAccountId()); } else { - return queryHandlerFactory.getBatchQueryHandler(); + return queryHandlerFactory.getBatchQueryHandler(asyncQueryJobMetadata.getAccountId()); } } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index f838e89572..0c0727294b 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -33,7 +33,7 @@ public Session createSession( .sessionId(sessionIdProvider.getSessionId(request)) .sessionStorageService(sessionStorageService) .statementStorageService(statementStorageService) - .serverlessClient(emrServerlessClientFactory.getClient()) + .serverlessClient(emrServerlessClientFactory.getClient(request.getAccountId())) .build(); session.open(request, asyncQueryRequestContext); return session; @@ -65,7 +65,7 @@ public Optional getSession(String sessionId, String dataSourceName) { .sessionId(sessionId) .sessionStorageService(sessionStorageService) .statementStorageService(statementStorageService) - .serverlessClient(emrServerlessClientFactory.getClient()) + .serverlessClient(emrServerlessClientFactory.getClient(model.get().getAccountId())) .sessionModel(model.get()) .sessionInactivityTimeoutMilli( sessionConfigSupplier.getSessionInactivityTimeoutMillis()) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java index 97ddccaf8f..244f4aee11 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java @@ -145,7 +145,8 @@ public void cancelStreamingJob(FlintIndexStateModel flintIndexStateModel) throws InterruptedException, TimeoutException { String applicationId = flintIndexStateModel.getApplicationId(); String jobId = flintIndexStateModel.getJobId(); - EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient(); + EMRServerlessClient emrServerlessClient = + emrServerlessClientFactory.getClient(flintIndexStateModel.getAccountId()); try { emrServerlessClient.cancelJobRun( flintIndexStateModel.getApplicationId(), flintIndexStateModel.getJobId(), true); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImplTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImplTest.java index a27363a153..309d29c600 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImplTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImplTest.java @@ -21,6 +21,7 @@ @ExtendWith(MockitoExtension.class) public class EMRServerlessClientFactoryImplTest { + public static final String ACCOUNT_ID = "accountId"; @Mock private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier; @Mock private MetricsService metricsService; @@ -30,7 +31,9 @@ public void testGetClient() { .thenReturn(createSparkExecutionEngineConfig()); EMRServerlessClientFactory emrServerlessClientFactory = new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier, metricsService); - EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient(); + + EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient(ACCOUNT_ID); + Assertions.assertNotNull(emrserverlessClient); } @@ -41,16 +44,16 @@ public void testGetClientWithChangeInSetting() { .thenReturn(sparkExecutionEngineConfig); EMRServerlessClientFactory emrServerlessClientFactory = new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier, metricsService); - EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient(); + EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient(ACCOUNT_ID); Assertions.assertNotNull(emrserverlessClient); - EMRServerlessClient emrServerlessClient1 = emrServerlessClientFactory.getClient(); + EMRServerlessClient emrServerlessClient1 = emrServerlessClientFactory.getClient(ACCOUNT_ID); Assertions.assertEquals(emrServerlessClient1, emrserverlessClient); sparkExecutionEngineConfig.setRegion(TestConstants.US_WEST_REGION); when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any())) .thenReturn(sparkExecutionEngineConfig); - EMRServerlessClient emrServerlessClient2 = emrServerlessClientFactory.getClient(); + EMRServerlessClient emrServerlessClient2 = emrServerlessClientFactory.getClient(ACCOUNT_ID); Assertions.assertNotEquals(emrServerlessClient2, emrserverlessClient); Assertions.assertNotEquals(emrServerlessClient2, emrServerlessClient1); } @@ -60,9 +63,11 @@ public void testGetClientWithException() { when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any())).thenReturn(null); EMRServerlessClientFactory emrServerlessClientFactory = new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier, metricsService); + IllegalArgumentException illegalArgumentException = Assertions.assertThrows( - IllegalArgumentException.class, emrServerlessClientFactory::getClient); + IllegalArgumentException.class, () -> emrServerlessClientFactory.getClient(ACCOUNT_ID)); + Assertions.assertEquals( "Async Query APIs are disabled. Please configure plugins.query.executionengine.spark.config" + " in cluster settings to enable them.", @@ -77,9 +82,11 @@ public void testGetClientWithExceptionWithNullRegion() { .thenReturn(sparkExecutionEngineConfig); EMRServerlessClientFactory emrServerlessClientFactory = new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier, metricsService); + IllegalArgumentException illegalArgumentException = Assertions.assertThrows( - IllegalArgumentException.class, emrServerlessClientFactory::getClient); + IllegalArgumentException.class, () -> emrServerlessClientFactory.getClient(ACCOUNT_ID)); + Assertions.assertEquals( "Async Query APIs are disabled. Please configure plugins.query.executionengine.spark.config" + " in cluster settings to enable them.", diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index d57284b9ca..b9c95f66cc 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -128,7 +128,7 @@ void setUp() { @Test void testDispatchSelectQuery() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -179,7 +179,7 @@ void testDispatchSelectQuery() { @Test void testDispatchSelectQueryWithLakeFormation() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -220,7 +220,7 @@ void testDispatchSelectQueryWithLakeFormation() { @Test void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -262,7 +262,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { @Test void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -368,7 +368,7 @@ void testDispatchSelectQueryFailedCreateSession() { @Test void testDispatchCreateAutoRefreshIndexQuery() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index"); @@ -413,7 +413,7 @@ void testDispatchCreateAutoRefreshIndexQuery() { @Test void testDispatchCreateManualRefreshIndexQuery() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -456,7 +456,7 @@ void testDispatchCreateManualRefreshIndexQuery() { @Test void testDispatchWithPPLQuery() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -499,7 +499,7 @@ void testDispatchWithPPLQuery() { @Test void testDispatchQueryWithoutATableAndDataSourceName() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -540,7 +540,7 @@ void testDispatchQueryWithoutATableAndDataSourceName() { @Test void testDispatchIndexQueryWithoutADatasourceName() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index"); @@ -585,7 +585,7 @@ void testDispatchIndexQueryWithoutADatasourceName() { @Test void testDispatchMaterializedViewQuery() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(INDEX_TAG_KEY, "flint_mv_1"); @@ -630,7 +630,7 @@ void testDispatchMaterializedViewQuery() { @Test void testDispatchShowMVQuery() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -671,7 +671,7 @@ void testDispatchShowMVQuery() { @Test void testRefreshIndexQuery() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -712,7 +712,7 @@ void testRefreshIndexQuery() { @Test void testDispatchDescribeIndexQuery() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -753,7 +753,7 @@ void testDispatchDescribeIndexQuery() { @Test void testDispatchAlterToAutoRefreshIndexQuery() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index"); @@ -906,7 +906,7 @@ void testDispatchWithUnSupportedDataSourceType() { @Test void testCancelJob() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false)) .thenReturn( new CancelJobRunResult() @@ -968,7 +968,7 @@ void testCancelQueryWithInvalidStatementId() { @Test void testCancelQueryWithNoSessionId() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false)) .thenReturn( new CancelJobRunResult() @@ -982,7 +982,7 @@ void testCancelQueryWithNoSessionId() { @Test void testGetQueryResponse() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.PENDING))); // simulate result index is not created yet @@ -1079,7 +1079,7 @@ void testGetQueryResponseWithSuccess() { @Test void testDispatchQueryWithExtraSparkSubmitParameters() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) .thenReturn(dataSourceMetadata); diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index f8b61aee5a..3ff806bf50 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -49,7 +49,7 @@ public class AsyncQueryExecutorServiceImplSpecTest extends AsyncQueryExecutorSer @Disabled("batch query is unsupported") public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -79,7 +79,7 @@ public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { @Disabled("batch query is unsupported") public void sessionLimitNotImpactBatchQuery() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -104,7 +104,7 @@ public void sessionLimitNotImpactBatchQuery() { @Disabled("batch query is unsupported") public void createAsyncQueryCreateJobWithCorrectParameters() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -140,7 +140,7 @@ public void createAsyncQueryCreateJobWithCorrectParameters() { @Test public void withSessionCreateAsyncQueryThenGetResultThenCancel() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -170,7 +170,7 @@ public void withSessionCreateAsyncQueryThenGetResultThenCancel() { @Test public void reuseSessionWhenCreateAsyncQuery() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -224,7 +224,7 @@ public void reuseSessionWhenCreateAsyncQuery() { @Disabled("batch query is unsupported") public void batchQueryHasTimeout() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -240,7 +240,7 @@ public void batchQueryHasTimeout() { @Test public void interactiveQueryNoTimeout() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -274,7 +274,7 @@ public void datasourceWithBasicAuth() { .setProperties(properties) .build()); LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -295,7 +295,7 @@ public void datasourceWithBasicAuth() { @Test public void withSessionCreateAsyncQueryFailed() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -344,7 +344,7 @@ public void withSessionCreateAsyncQueryFailed() { @Test public void createSessionMoreThanLimitFailed() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -376,7 +376,7 @@ public void createSessionMoreThanLimitFailed() { @Test public void recreateSessionIfNotReady() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -417,7 +417,7 @@ public void recreateSessionIfNotReady() { @Test public void submitQueryWithDifferentDataSourceSessionWillCreateNewSession() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -465,7 +465,7 @@ public void submitQueryWithDifferentDataSourceSessionWillCreateNewSession() { @Test public void recreateSessionIfStale() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -523,7 +523,7 @@ public void recreateSessionIfStale() { @Test public void submitQueryInInvalidSessionWillCreateNewSession() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -561,7 +561,7 @@ public void datasourceNameIncludeUppercase() { .build()); LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -583,7 +583,7 @@ public void datasourceNameIncludeUppercase() { @Test public void concurrentSessionLimitIsDomainLevel() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -612,7 +612,7 @@ public void concurrentSessionLimitIsDomainLevel() { @Test public void testDatasourceDisabled() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index a5935db2c9..a12a5aeac8 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -351,7 +351,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { public static class LocalEMRServerlessClientFactory implements EMRServerlessClientFactory { @Override - public EMRServerlessClient getClient() { + public EMRServerlessClient getClient(String accountId) { return new LocalEMRSClient(); } } diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java index 518aa84a9f..e0f04761c7 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java @@ -417,7 +417,7 @@ private class AssertionHelper { private Interaction interaction; AssertionHelper(String query, LocalEMRSClient emrClient) { - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrClient; this.queryService = createAsyncQueryExecutorService( emrServerlessClientFactory, diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java index 230853a5eb..70a43e42d5 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java @@ -63,7 +63,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -131,7 +131,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -214,7 +214,7 @@ public CancelJobRunResult cancelJobRun( throw new ValidationException("Job run is not in a cancellable state"); } }; - EMRServerlessClientFactory emrServerlessCientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessCientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessCientFactory); // Mock flint index @@ -276,8 +276,8 @@ public void testAlterIndexQueryConvertingToAutoRefresh() { ImmutableList.of(ALTER_SKIPPING, ALTER_COVERING, ALTER_MV) .forEach( mockDS -> { - LocalEMRSClient localEMRSClient = new LocalEMRSClient(); - EMRServerlessClientFactory clientFactory = () -> localEMRSClient; + LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory clientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(clientFactory); @@ -307,9 +307,9 @@ public void testAlterIndexQueryConvertingToAutoRefresh() { .getStatus()); flintIndexJob.assertState(FlintIndexState.ACTIVE); - localEMRSClient.startJobRunCalled(1); - localEMRSClient.getJobRunResultCalled(1); - localEMRSClient.cancelJobRunCalled(0); + emrsClient.startJobRunCalled(1); + emrsClient.getJobRunResultCalled(1); + emrsClient.cancelJobRunCalled(0); Map mappings = mockDS.getIndexMappings(); Map meta = (HashMap) mappings.get("_meta"); Map options = (Map) meta.get("options"); @@ -342,8 +342,8 @@ public void testAlterIndexQueryWithOutAnyAutoRefresh() { ImmutableList.of(ALTER_SKIPPING, ALTER_COVERING, ALTER_MV) .forEach( mockDS -> { - LocalEMRSClient localEMRSClient = new LocalEMRSClient(); - EMRServerlessClientFactory clientFactory = () -> localEMRSClient; + LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory clientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(clientFactory); @@ -373,9 +373,9 @@ public void testAlterIndexQueryWithOutAnyAutoRefresh() { .getStatus()); flintIndexJob.assertState(FlintIndexState.ACTIVE); - localEMRSClient.startJobRunCalled(1); - localEMRSClient.getJobRunResultCalled(1); - localEMRSClient.cancelJobRunCalled(0); + emrsClient.startJobRunCalled(1); + emrsClient.getJobRunResultCalled(1); + emrsClient.cancelJobRunCalled(0); Map mappings = mockDS.getIndexMappings(); Map meta = (HashMap) mappings.get("_meta"); Map options = (Map) meta.get("options"); @@ -419,7 +419,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -494,7 +494,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -562,7 +562,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -624,7 +624,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -686,7 +686,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -750,7 +750,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -811,7 +811,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -873,7 +873,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -940,7 +940,7 @@ public CancelJobRunResult cancelJobRun( throw new ValidationException("Job run is not in a cancellable state"); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -1005,7 +1005,7 @@ public CancelJobRunResult cancelJobRun( throw new ValidationException("Random validation exception"); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -1071,7 +1071,7 @@ public CancelJobRunResult cancelJobRun( throw new IllegalArgumentException("Unknown Error"); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java index 2b6b1d2ba0..2eed7b13a0 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java @@ -124,7 +124,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Cancelled")); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -175,7 +175,7 @@ public CancelJobRunResult cancelJobRun( throw new ValidationException("Job run is not in a cancellable state"); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -215,7 +215,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Running")); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -253,7 +253,7 @@ public CancelJobRunResult cancelJobRun( throw new ValidationException("Job run is not in a cancellable state"); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -290,7 +290,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Cancelled")); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -350,7 +350,7 @@ public CancelJobRunResult cancelJobRun( throw new ValidationException("Job run is not in a cancellable state"); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -397,7 +397,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Running")); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -443,7 +443,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Cancelled")); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -494,7 +494,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Cancelled")); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -542,7 +542,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Cancelled")); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -590,7 +590,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Cancelled")); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -644,7 +644,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return null; } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -691,7 +691,7 @@ public CancelJobRunResult cancelJobRun( throw new IllegalArgumentException("Job run is not in a cancellable state"); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -742,7 +742,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return null; } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -887,7 +887,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return null; } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -917,7 +917,7 @@ public void cancelRefreshStatement() { mockDS -> { AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService( - () -> + (accountId) -> new LocalEMRSClient() { @Override public GetJobRunResult getJobRunResult( @@ -962,7 +962,7 @@ public void cancelRefreshStatementWithActiveState() { mockDS -> { AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService( - () -> + (accountId) -> new LocalEMRSClient() { @Override public GetJobRunResult getJobRunResult( @@ -1009,7 +1009,7 @@ public void cancelRefreshStatementWithFailureInFetchingIndexMetadata() { new MockFlintIndex(client(), indexName, FlintIndexType.COVERING, null); AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService( - () -> + (accountId) -> new LocalEMRSClient() { @Override public GetJobRunResult getJobRunResult(String applicationId, String jobId) { diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java index 3bccf1b30b..439b2ed2d6 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java @@ -156,7 +156,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return getJobRunResult.call(); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); diff --git a/async-query/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java b/async-query/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java index 89f3ac9871..c5964a61e3 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java @@ -54,7 +54,9 @@ public void testStreamingJobHouseKeeperWhenDataSourceDisabled() { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + dataSourceService, + flintIndexMetadataService, + getFlintIndexOpFactory((accountId) -> emrsClient)); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); @@ -133,7 +135,9 @@ public void testStreamingJobHouseKeeperWhenCancelJobGivesTimeout() { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + dataSourceService, + flintIndexMetadataService, + getFlintIndexOpFactory((accountId) -> emrsClient)); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); @@ -181,7 +185,9 @@ public void testSimulateConcurrentJobHouseKeeperExecution() { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + dataSourceService, + flintIndexMetadataService, + getFlintIndexOpFactory((accountId) -> emrsClient)); FlintStreamingJobHouseKeeperTask.isRunning.compareAndSet(false, true); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); @@ -231,7 +237,9 @@ public void testStreamingJobClearnerWhenDataSourceIsDeleted() { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + dataSourceService, + flintIndexMetadataService, + getFlintIndexOpFactory((accountId) -> emrsClient)); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); @@ -278,7 +286,9 @@ public void testStreamingJobHouseKeeperWhenDataSourceIsNeitherDisabledNorDeleted FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + dataSourceService, + flintIndexMetadataService, + getFlintIndexOpFactory((accountId) -> emrsClient)); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); @@ -320,7 +330,9 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + dataSourceService, + flintIndexMetadataService, + getFlintIndexOpFactory((accountId) -> emrsClient)); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); @@ -347,7 +359,9 @@ public void testStreamingJobHouseKeeperWhenFlintIndexIsCorrupted() throws Interr FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + dataSourceService, + flintIndexMetadataService, + getFlintIndexOpFactory((accountId) -> emrsClient)); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); @@ -389,7 +403,9 @@ public void updateIndexToManualRefresh( }; FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + dataSourceService, + flintIndexMetadataService, + getFlintIndexOpFactory((accountId) -> emrsClient)); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); @@ -429,7 +445,9 @@ public void testStreamingJobHouseKeeperMultipleTimesWhenDataSourceDisabled() { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + dataSourceService, + flintIndexMetadataService, + getFlintIndexOpFactory((accountId) -> emrsClient)); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); @@ -501,7 +519,9 @@ public void testRunStreamingJobHouseKeeperWhenDataSourceIsDeleted() { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + dataSourceService, + flintIndexMetadataService, + getFlintIndexOpFactory((accountId) -> emrsClient)); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); diff --git a/async-query/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/async-query/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index e8aeb17505..d0bfed94c0 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -58,7 +58,7 @@ public void setup() { new OpenSearchSessionStorageService(stateStore, new SessionModelXContentSerializer()); statementStorageService = new OpenSearchStatementStorageService(stateStore, new StatementModelXContentSerializer()); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; sessionManager = new SessionManager( diff --git a/async-query/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/async-query/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index d76b419df6..e76776e2fc 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -61,7 +61,7 @@ public void setup() { new OpenSearchStatementStorageService(stateStore, new StatementModelXContentSerializer()); sessionStorageService = new OpenSearchSessionStorageService(stateStore, new SessionModelXContentSerializer()); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; sessionManager = new SessionManager( @@ -279,7 +279,7 @@ public void newStatementFieldAssert() { @Test public void failToSubmitStatementInDeletedSession() { - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; Session session = sessionManager.createSession(createSessionRequest(), asyncQueryRequestContext);