diff --git a/async-query-core/build.gradle b/async-query-core/build.gradle index 176d14950f..1de6cb3105 100644 --- a/async-query-core/build.gradle +++ b/async-query-core/build.gradle @@ -50,7 +50,6 @@ dependencies { implementation project(':core') implementation project(':spark') // TODO: dependency to spark should be eliminated implementation project(':datasources') // TODO: dependency to datasources should be eliminated - implementation project(':legacy') // TODO: dependency to legacy should be eliminated implementation 'org.json:json:20231013' implementation 'com.google.code.gson:gson:2.8.9' diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactory.java b/async-query-core/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactory.java index 2c05dc865d..c5305ba445 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactory.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactory.java @@ -11,7 +11,8 @@ public interface EMRServerlessClientFactory { /** * Gets an instance of {@link EMRServerlessClient}. * + * @param accountId Account ID of the requester. It will be used to decide the cluster. * @return An {@link EMRServerlessClient} instance. */ - EMRServerlessClient getClient(); + EMRServerlessClient getClient(String accountId); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java b/async-query-core/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java index 33c0e9fbfa..72973b3bbb 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java @@ -18,7 +18,6 @@ import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; import org.opensearch.sql.spark.metrics.MetricsService; -/** Implementation of {@link EMRServerlessClientFactory}. */ @RequiredArgsConstructor public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactory { @@ -27,13 +26,8 @@ public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactor private EMRServerlessClient emrServerlessClient; private String region; - /** - * Gets an instance of {@link EMRServerlessClient}. - * - * @return An {@link EMRServerlessClient} instance. - */ @Override - public EMRServerlessClient getClient() { + public EMRServerlessClient getClient(String accountId) { SparkExecutionEngineConfig sparkExecutionEngineConfig = this.sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig( new NullAsyncQueryRequestContext()); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java index 603b5a6765..d6e70a9d86 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryHandlerFactory.java @@ -29,9 +29,9 @@ public class QueryHandlerFactory { private final MetricsService metricsService; protected final SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider; - public RefreshQueryHandler getRefreshQueryHandler() { + public RefreshQueryHandler getRefreshQueryHandler(String accountId) { return new RefreshQueryHandler( - emrServerlessClientFactory.getClient(), + emrServerlessClientFactory.getClient(accountId), jobExecutionResponseReader, flintIndexMetadataService, leaseManager, @@ -40,18 +40,18 @@ public RefreshQueryHandler getRefreshQueryHandler() { sparkSubmitParametersBuilderProvider); } - public StreamingQueryHandler getStreamingQueryHandler() { + public StreamingQueryHandler getStreamingQueryHandler(String accountId) { return new StreamingQueryHandler( - emrServerlessClientFactory.getClient(), + emrServerlessClientFactory.getClient(accountId), jobExecutionResponseReader, leaseManager, metricsService, sparkSubmitParametersBuilderProvider); } - public BatchQueryHandler getBatchQueryHandler() { + public BatchQueryHandler getBatchQueryHandler(String accountId) { return new BatchQueryHandler( - emrServerlessClientFactory.getClient(), + emrServerlessClientFactory.getClient(accountId), jobExecutionResponseReader, leaseManager, metricsService, diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 5facdee567..3366e21894 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -54,14 +54,15 @@ public DispatchQueryResponse dispatch( .asyncQueryRequestContext(asyncQueryRequestContext) .build(); - return getQueryHandlerForFlintExtensionQuery(indexQueryDetails) + return getQueryHandlerForFlintExtensionQuery(dispatchQueryRequest, indexQueryDetails) .submit(dispatchQueryRequest, context); } else { DispatchQueryContext context = getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) .asyncQueryRequestContext(asyncQueryRequestContext) .build(); - return getDefaultAsyncQueryHandler().submit(dispatchQueryRequest, context); + return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId()) + .submit(dispatchQueryRequest, context); } } @@ -74,28 +75,28 @@ private DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchConte } private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery( - IndexQueryDetails indexQueryDetails) { + DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) { if (isEligibleForIndexDMLHandling(indexQueryDetails)) { return queryHandlerFactory.getIndexDMLHandler(); } else if (isEligibleForStreamingQuery(indexQueryDetails)) { - return queryHandlerFactory.getStreamingQueryHandler(); + return queryHandlerFactory.getStreamingQueryHandler(dispatchQueryRequest.getAccountId()); } else if (IndexQueryActionType.CREATE.equals(indexQueryDetails.getIndexQueryActionType())) { // Create should be handled by batch handler. This is to avoid DROP index incorrectly cancel // an interactive job. - return queryHandlerFactory.getBatchQueryHandler(); + return queryHandlerFactory.getBatchQueryHandler(dispatchQueryRequest.getAccountId()); } else if (IndexQueryActionType.REFRESH.equals(indexQueryDetails.getIndexQueryActionType())) { // Manual refresh should be handled by batch handler - return queryHandlerFactory.getRefreshQueryHandler(); + return queryHandlerFactory.getRefreshQueryHandler(dispatchQueryRequest.getAccountId()); } else { - return getDefaultAsyncQueryHandler(); + return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId()); } } @NotNull - private AsyncQueryHandler getDefaultAsyncQueryHandler() { + private AsyncQueryHandler getDefaultAsyncQueryHandler(String accountId) { return sessionManager.isEnabled() ? queryHandlerFactory.getInteractiveQueryHandler() - : queryHandlerFactory.getBatchQueryHandler(); + : queryHandlerFactory.getBatchQueryHandler(accountId); } @NotNull @@ -143,11 +144,11 @@ private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery( } else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) { return queryHandlerFactory.getIndexDMLHandler(); } else if (asyncQueryJobMetadata.getJobType() == JobType.BATCH) { - return queryHandlerFactory.getRefreshQueryHandler(); + return queryHandlerFactory.getRefreshQueryHandler(asyncQueryJobMetadata.getAccountId()); } else if (asyncQueryJobMetadata.getJobType() == JobType.STREAMING) { - return queryHandlerFactory.getStreamingQueryHandler(); + return queryHandlerFactory.getStreamingQueryHandler(asyncQueryJobMetadata.getAccountId()); } else { - return queryHandlerFactory.getBatchQueryHandler(); + return queryHandlerFactory.getBatchQueryHandler(asyncQueryJobMetadata.getAccountId()); } } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index f838e89572..0c0727294b 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -33,7 +33,7 @@ public Session createSession( .sessionId(sessionIdProvider.getSessionId(request)) .sessionStorageService(sessionStorageService) .statementStorageService(statementStorageService) - .serverlessClient(emrServerlessClientFactory.getClient()) + .serverlessClient(emrServerlessClientFactory.getClient(request.getAccountId())) .build(); session.open(request, asyncQueryRequestContext); return session; @@ -65,7 +65,7 @@ public Optional getSession(String sessionId, String dataSourceName) { .sessionId(sessionId) .sessionStorageService(sessionStorageService) .statementStorageService(statementStorageService) - .serverlessClient(emrServerlessClientFactory.getClient()) + .serverlessClient(emrServerlessClientFactory.getClient(model.get().getAccountId())) .sessionModel(model.get()) .sessionInactivityTimeoutMilli( sessionConfigSupplier.getSessionInactivityTimeoutMillis()) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java index 97ddccaf8f..244f4aee11 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java @@ -145,7 +145,8 @@ public void cancelStreamingJob(FlintIndexStateModel flintIndexStateModel) throws InterruptedException, TimeoutException { String applicationId = flintIndexStateModel.getApplicationId(); String jobId = flintIndexStateModel.getJobId(); - EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient(); + EMRServerlessClient emrServerlessClient = + emrServerlessClientFactory.getClient(flintIndexStateModel.getAccountId()); try { emrServerlessClient.cancelJobRun( flintIndexStateModel.getApplicationId(), flintIndexStateModel.getJobId(), true); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java b/async-query-core/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java index f3a9a198fb..e3250c7a58 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java @@ -5,12 +5,8 @@ package org.opensearch.sql.spark.rest.model; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - -import java.io.IOException; import lombok.Data; import org.apache.commons.lang3.Validate; -import org.opensearch.core.xcontent.XContentParser; @Data public class CreateAsyncQueryRequest { @@ -32,35 +28,4 @@ public CreateAsyncQueryRequest(String query, String datasource, LangType lang, S this.lang = Validate.notNull(lang, "lang can't be null"); this.sessionId = sessionId; } - - public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser) - throws IOException { - String query = null; - LangType lang = null; - String datasource = null; - String sessionId = null; - try { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String fieldName = parser.currentName(); - parser.nextToken(); - if (fieldName.equals("query")) { - query = parser.textOrNull(); - } else if (fieldName.equals("lang")) { - String langString = parser.textOrNull(); - lang = LangType.fromString(langString); - } else if (fieldName.equals("datasource")) { - datasource = parser.textOrNull(); - } else if (fieldName.equals("sessionId")) { - sessionId = parser.textOrNull(); - } else { - throw new IllegalArgumentException("Unknown field: " + fieldName); - } - } - return new CreateAsyncQueryRequest(query, datasource, lang, sessionId); - } catch (Exception e) { - throw new IllegalArgumentException( - String.format("Error while parsing the request body: %s", e.getMessage())); - } - } } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java index 9dfe30b4b5..a96e203cea 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java @@ -5,6 +5,8 @@ package org.opensearch.sql.spark.utils; +import java.util.LinkedList; +import java.util.List; import java.util.Locale; import lombok.Getter; import lombok.experimental.UtilityClass; @@ -18,6 +20,7 @@ import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsParser; import org.opensearch.sql.spark.antlr.parser.SqlBaseLexer; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.IdentifierReferenceContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; @@ -32,8 +35,7 @@ @UtilityClass public class SQLQueryUtils { - // TODO Handle cases where the query has multiple table Names. - public static FullyQualifiedTableName extractFullyQualifiedTableName(String sqlQuery) { + public static List extractFullyQualifiedTableNames(String sqlQuery) { SqlBaseParser sqlBaseParser = new SqlBaseParser( new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery)))); @@ -41,7 +43,7 @@ public static FullyQualifiedTableName extractFullyQualifiedTableName(String sqlQ SqlBaseParser.StatementContext statement = sqlBaseParser.statement(); SparkSqlTableNameVisitor sparkSqlTableNameVisitor = new SparkSqlTableNameVisitor(); statement.accept(sparkSqlTableNameVisitor); - return sparkSqlTableNameVisitor.getFullyQualifiedTableName(); + return sparkSqlTableNameVisitor.getFullyQualifiedTableNames(); } public static IndexQueryDetails extractIndexDetails(String sqlQuery) { @@ -73,23 +75,21 @@ public static boolean isFlintExtensionQuery(String sqlQuery) { public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor { - @Getter private FullyQualifiedTableName fullyQualifiedTableName; + @Getter private List fullyQualifiedTableNames = new LinkedList<>(); - public SparkSqlTableNameVisitor() { - this.fullyQualifiedTableName = new FullyQualifiedTableName(); - } + public SparkSqlTableNameVisitor() {} @Override - public Void visitTableName(SqlBaseParser.TableNameContext ctx) { - fullyQualifiedTableName = new FullyQualifiedTableName(ctx.getText()); - return super.visitTableName(ctx); + public Void visitIdentifierReference(IdentifierReferenceContext ctx) { + fullyQualifiedTableNames.add(new FullyQualifiedTableName(ctx.getText())); + return super.visitIdentifierReference(ctx); } @Override public Void visitDropTable(SqlBaseParser.DropTableContext ctx) { for (ParseTree parseTree : ctx.children) { if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) { - fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText()); + fullyQualifiedTableNames.add(new FullyQualifiedTableName(parseTree.getText())); } } return super.visitDropTable(ctx); @@ -99,7 +99,7 @@ public Void visitDropTable(SqlBaseParser.DropTableContext ctx) { public Void visitDescribeRelation(SqlBaseParser.DescribeRelationContext ctx) { for (ParseTree parseTree : ctx.children) { if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) { - fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText()); + fullyQualifiedTableNames.add(new FullyQualifiedTableName(parseTree.getText())); } } return super.visitDescribeRelation(ctx); @@ -110,7 +110,7 @@ public Void visitDescribeRelation(SqlBaseParser.DescribeRelationContext ctx) { public Void visitCreateTableHeader(SqlBaseParser.CreateTableHeaderContext ctx) { for (ParseTree parseTree : ctx.children) { if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) { - fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText()); + fullyQualifiedTableNames.add(new FullyQualifiedTableName(parseTree.getText())); } } return super.visitCreateTableHeader(ctx); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index 3d796f67ab..99d4cc722e 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -136,7 +136,7 @@ public class AsyncQueryCoreIntegTest { @BeforeEach public void setUp() { emrServerlessClientFactory = - () -> new EmrServerlessClientImpl(awsemrServerless, metricsService); + (accountId) -> new EmrServerlessClientImpl(awsemrServerless, metricsService); SparkParameterComposerCollection collection = new SparkParameterComposerCollection(); collection.register( DataSourceType.S3GLUE, diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImplTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImplTest.java index a27363a153..309d29c600 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImplTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImplTest.java @@ -21,6 +21,7 @@ @ExtendWith(MockitoExtension.class) public class EMRServerlessClientFactoryImplTest { + public static final String ACCOUNT_ID = "accountId"; @Mock private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier; @Mock private MetricsService metricsService; @@ -30,7 +31,9 @@ public void testGetClient() { .thenReturn(createSparkExecutionEngineConfig()); EMRServerlessClientFactory emrServerlessClientFactory = new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier, metricsService); - EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient(); + + EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient(ACCOUNT_ID); + Assertions.assertNotNull(emrserverlessClient); } @@ -41,16 +44,16 @@ public void testGetClientWithChangeInSetting() { .thenReturn(sparkExecutionEngineConfig); EMRServerlessClientFactory emrServerlessClientFactory = new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier, metricsService); - EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient(); + EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient(ACCOUNT_ID); Assertions.assertNotNull(emrserverlessClient); - EMRServerlessClient emrServerlessClient1 = emrServerlessClientFactory.getClient(); + EMRServerlessClient emrServerlessClient1 = emrServerlessClientFactory.getClient(ACCOUNT_ID); Assertions.assertEquals(emrServerlessClient1, emrserverlessClient); sparkExecutionEngineConfig.setRegion(TestConstants.US_WEST_REGION); when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any())) .thenReturn(sparkExecutionEngineConfig); - EMRServerlessClient emrServerlessClient2 = emrServerlessClientFactory.getClient(); + EMRServerlessClient emrServerlessClient2 = emrServerlessClientFactory.getClient(ACCOUNT_ID); Assertions.assertNotEquals(emrServerlessClient2, emrserverlessClient); Assertions.assertNotEquals(emrServerlessClient2, emrServerlessClient1); } @@ -60,9 +63,11 @@ public void testGetClientWithException() { when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any())).thenReturn(null); EMRServerlessClientFactory emrServerlessClientFactory = new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier, metricsService); + IllegalArgumentException illegalArgumentException = Assertions.assertThrows( - IllegalArgumentException.class, emrServerlessClientFactory::getClient); + IllegalArgumentException.class, () -> emrServerlessClientFactory.getClient(ACCOUNT_ID)); + Assertions.assertEquals( "Async Query APIs are disabled. Please configure plugins.query.executionengine.spark.config" + " in cluster settings to enable them.", @@ -77,9 +82,11 @@ public void testGetClientWithExceptionWithNullRegion() { .thenReturn(sparkExecutionEngineConfig); EMRServerlessClientFactory emrServerlessClientFactory = new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier, metricsService); + IllegalArgumentException illegalArgumentException = Assertions.assertThrows( - IllegalArgumentException.class, emrServerlessClientFactory::getClient); + IllegalArgumentException.class, () -> emrServerlessClientFactory.getClient(ACCOUNT_ID)); + Assertions.assertEquals( "Async Query APIs are disabled. Please configure plugins.query.executionengine.spark.config" + " in cluster settings to enable them.", diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java index 993d489ded..cc151821aa 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java @@ -4,9 +4,7 @@ package org.opensearch.sql.spark.client; -import static java.util.Collections.emptyList; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -32,7 +30,6 @@ import java.util.List; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.ArgumentCaptor; @@ -40,10 +37,7 @@ import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.sql.common.setting.Settings; -import org.opensearch.sql.legacy.esdomain.LocalClusterState; -import org.opensearch.sql.legacy.metrics.Metrics; -import org.opensearch.sql.opensearch.setting.OpenSearchSettings; +import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.metrics.MetricsService; import org.opensearch.sql.spark.parameter.SparkParameterComposerCollection; import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilder; @@ -51,22 +45,12 @@ @ExtendWith(MockitoExtension.class) public class EmrServerlessClientImplTest { @Mock private AWSEMRServerless emrServerless; - @Mock private OpenSearchSettings settings; @Mock private MetricsService metricsService; @Captor private ArgumentCaptor startJobRunRequestArgumentCaptor; @InjectMocks EmrServerlessClientImpl emrServerlessClient; - @BeforeEach - public void setUp() { - doReturn(emptyList()).when(settings).getSettings(); - when(settings.getSettingValue(Settings.Key.METRICS_ROLLING_INTERVAL)).thenReturn(3600L); - when(settings.getSettingValue(Settings.Key.METRICS_ROLLING_WINDOW)).thenReturn(600L); - LocalClusterState.state().setPluginSettings(settings); - Metrics.getInstance().registerDefaultMetrics(); - } - @Test void testStartJobRun() { StartJobRunResult response = new StartJobRunResult(); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 6de778d3cd..5833ee91d4 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -159,7 +159,7 @@ void setUp() { @Test void testDispatchSelectQuery() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -202,7 +202,7 @@ void testDispatchSelectQuery() { @Test void testDispatchSelectQueryWithLakeFormation() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -234,7 +234,7 @@ void testDispatchSelectQueryWithLakeFormation() { @Test void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -265,6 +265,45 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { verifyNoInteractions(flintIndexMetadataService); } + @Test + void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); + HashMap tags = new HashMap<>(); + tags.put(DATASOURCE_TAG_KEY, MY_GLUE); + tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); + tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); + String query = "select * from my_glue.default.http_logs"; + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString( + "noauth", + new HashMap<>() { + { + } + }, + query); + StartJobRequest expected = + new StartJobRequest( + "TEST_CLUSTER:batch", + null, + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + "query_execution_result_my_glue"); + when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithNoAuth(); + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + .thenReturn(dataSourceMetadata); + + DispatchQueryResponse dispatchQueryResponse = + sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); + Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + verifyNoInteractions(flintIndexMetadataService); + } + @Test void testDispatchSelectQueryCreateNewSession() { String query = "select * from my_glue.default.http_logs"; @@ -334,7 +373,7 @@ void testDispatchSelectQueryFailedCreateSession() { @Test void testDispatchCreateAutoRefreshIndexQuery() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index"); @@ -370,7 +409,7 @@ void testDispatchCreateAutoRefreshIndexQuery() { @Test void testDispatchCreateManualRefreshIndexQuery() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -405,7 +444,7 @@ void testDispatchCreateManualRefreshIndexQuery() { @Test void testDispatchWithPPLQuery() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -440,7 +479,7 @@ void testDispatchWithPPLQuery() { @Test void testDispatchQueryWithoutATableAndDataSourceName() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -473,7 +512,7 @@ void testDispatchQueryWithoutATableAndDataSourceName() { @Test void testDispatchIndexQueryWithoutADatasourceName() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index"); @@ -509,7 +548,7 @@ void testDispatchIndexQueryWithoutADatasourceName() { @Test void testDispatchMaterializedViewQuery() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(INDEX_TAG_KEY, "flint_mv_1"); @@ -545,7 +584,7 @@ void testDispatchMaterializedViewQuery() { @Test void testDispatchShowMVQuery() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -578,7 +617,7 @@ void testDispatchShowMVQuery() { @Test void testRefreshIndexQuery() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -611,7 +650,7 @@ void testRefreshIndexQuery() { @Test void testDispatchDescribeIndexQuery() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -644,7 +683,7 @@ void testDispatchDescribeIndexQuery() { @Test void testDispatchAlterToAutoRefreshIndexQuery() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index"); @@ -770,7 +809,7 @@ void testDispatchWithUnSupportedDataSourceType() { @Test void testCancelJob() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false)) .thenReturn( new CancelJobRunResult() @@ -832,7 +871,7 @@ void testCancelQueryWithInvalidStatementId() { @Test void testCancelQueryWithNoSessionId() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false)) .thenReturn( new CancelJobRunResult() @@ -846,7 +885,7 @@ void testCancelQueryWithNoSessionId() { @Test void testGetQueryResponse() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.PENDING))); // simulate result index is not created yet @@ -943,7 +982,7 @@ void testGetQueryResponseWithSuccess() { @Test void testDispatchQueryWithExtraSparkSubmitParameters() { - when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) .thenReturn(dataSourceMetadata); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java index 620d187e52..0d7c43fc0d 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java @@ -5,12 +5,17 @@ package org.opensearch.sql.spark.utils; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.index; import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.mv; import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.skippingIndex; +import java.util.List; import lombok.Getter; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.junit.jupiter.MockitoExtension; @@ -26,46 +31,34 @@ public class SQLQueryUtilsTest { void testExtractionOfTableNameFromSQLQueries() { String sqlQuery = "select * from my_glue.default.http_logs"; FullyQualifiedTableName fullyQualifiedTableName = - SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); - Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); - Assertions.assertEquals("my_glue", fullyQualifiedTableName.getDatasourceName()); - Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName()); - Assertions.assertEquals("http_logs", fullyQualifiedTableName.getTableName()); + SQLQueryUtils.extractFullyQualifiedTableNames(sqlQuery).get(0); + assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); + assertFullyQualifiedTableName("my_glue", "default", "http_logs", fullyQualifiedTableName); sqlQuery = "select * from my_glue.db.http_logs"; - Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); - fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); - Assertions.assertEquals("my_glue", fullyQualifiedTableName.getDatasourceName()); - Assertions.assertEquals("db", fullyQualifiedTableName.getSchemaName()); - Assertions.assertEquals("http_logs", fullyQualifiedTableName.getTableName()); + assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); + fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableNames(sqlQuery).get(0); + assertFullyQualifiedTableName("my_glue", "db", "http_logs", fullyQualifiedTableName); sqlQuery = "select * from my_glue.http_logs"; - fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); - Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); - Assertions.assertEquals("my_glue", fullyQualifiedTableName.getSchemaName()); - Assertions.assertNull(fullyQualifiedTableName.getDatasourceName()); - Assertions.assertEquals("http_logs", fullyQualifiedTableName.getTableName()); + fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableNames(sqlQuery).get(0); + assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); + assertFullyQualifiedTableName(null, "my_glue", "http_logs", fullyQualifiedTableName); sqlQuery = "select * from http_logs"; - fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); - Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); - Assertions.assertNull(fullyQualifiedTableName.getDatasourceName()); - Assertions.assertNull(fullyQualifiedTableName.getSchemaName()); - Assertions.assertEquals("http_logs", fullyQualifiedTableName.getTableName()); + fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableNames(sqlQuery).get(0); + assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); + assertFullyQualifiedTableName(null, null, "http_logs", fullyQualifiedTableName); sqlQuery = "DROP TABLE myS3.default.alb_logs"; - fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); - Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); - Assertions.assertEquals("myS3", fullyQualifiedTableName.getDatasourceName()); - Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName()); - Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName()); + fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableNames(sqlQuery).get(0); + assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); + assertFullyQualifiedTableName("myS3", "default", "alb_logs", fullyQualifiedTableName); sqlQuery = "DESCRIBE TABLE myS3.default.alb_logs"; - fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); - Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); - Assertions.assertEquals("myS3", fullyQualifiedTableName.getDatasourceName()); - Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName()); - Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName()); + fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableNames(sqlQuery).get(0); + assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); + assertFullyQualifiedTableName("myS3", "default", "alb_logs", fullyQualifiedTableName); sqlQuery = "CREATE EXTERNAL TABLE\n" @@ -74,31 +67,60 @@ void testExtractionOfTableNameFromSQLQueries() { + "[ ROW FORMAT DELIMITED row_format ]\n" + "STORED AS file_format\n" + "LOCATION { 's3://bucket/folder/' }"; - fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); - Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); - Assertions.assertEquals("myS3", fullyQualifiedTableName.getDatasourceName()); - Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName()); - Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName()); + fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableNames(sqlQuery).get(0); + assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); + assertFullyQualifiedTableName("myS3", "default", "alb_logs", fullyQualifiedTableName); } @Test - void testErrorScenarios() { + void testMultipleTables() { + String[] sqlQueries = { + "SELECT * FROM my_glue.default.http_logs, my_glue.default.access_logs", + "SELECT * FROM my_glue.default.http_logs LEFT JOIN my_glue.default.access_logs", + "SELECT table1.id, table2.id FROM my_glue.default.http_logs table1 LEFT OUTER JOIN" + + " (SELECT * FROM my_glue.default.access_logs) table2 ON table1.tag = table2.tag", + "SELECT table1.id, table2.id FROM my_glue.default.http_logs FOR VERSION AS OF 1 table1" + + " LEFT OUTER JOIN" + + " (SELECT * FROM my_glue.default.access_logs) table2" + + " ON table1.tag = table2.tag" + }; + + for (String sqlQuery : sqlQueries) { + List fullyQualifiedTableNames = + SQLQueryUtils.extractFullyQualifiedTableNames(sqlQuery); + + assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); + assertEquals(2, fullyQualifiedTableNames.size()); + assertFullyQualifiedTableName( + "my_glue", "default", "http_logs", fullyQualifiedTableNames.get(0)); + assertFullyQualifiedTableName( + "my_glue", "default", "access_logs", fullyQualifiedTableNames.get(1)); + } + } + + @Test + void testMultipleTablesWithJoin() { + String sqlQuery = + "select * from my_glue.default.http_logs LEFT JOIN my_glue.default.access_logs"; + + List fullyQualifiedTableNames = + SQLQueryUtils.extractFullyQualifiedTableNames(sqlQuery); + + assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); + assertFullyQualifiedTableName( + "my_glue", "default", "http_logs", fullyQualifiedTableNames.get(0)); + assertFullyQualifiedTableName( + "my_glue", "default", "access_logs", fullyQualifiedTableNames.get(1)); + } + + @Test + void testNoFullyQualifiedTableName() { String sqlQuery = "SHOW tables"; - FullyQualifiedTableName fullyQualifiedTableName = - SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); - Assertions.assertNotNull(fullyQualifiedTableName); - Assertions.assertNull(fullyQualifiedTableName.getFullyQualifiedName()); - Assertions.assertNull(fullyQualifiedTableName.getSchemaName()); - Assertions.assertNull(fullyQualifiedTableName.getTableName()); - Assertions.assertNull(fullyQualifiedTableName.getDatasourceName()); - - sqlQuery = "DESCRIBE TABLE FROM myS3.default.alb_logs"; - fullyQualifiedTableName = SQLQueryUtils.extractFullyQualifiedTableName(sqlQuery); - Assertions.assertFalse(SQLQueryUtils.isFlintExtensionQuery(sqlQuery)); - Assertions.assertEquals("FROM", fullyQualifiedTableName.getFullyQualifiedName()); - Assertions.assertNull(fullyQualifiedTableName.getSchemaName()); - Assertions.assertEquals("FROM", fullyQualifiedTableName.getTableName()); - Assertions.assertNull(fullyQualifiedTableName.getDatasourceName()); + + List fullyQualifiedTableNames = + SQLQueryUtils.extractFullyQualifiedTableNames(sqlQuery); + + assertEquals(0, fullyQualifiedTableNames.size()); } @Test @@ -112,25 +134,27 @@ void testExtractionFromFlintSkippingIndexQueries() { + " WITH (auto_refresh = true)", "CREATE SKIPPING INDEX ON myS3.default.alb_logs(l_orderkey VALUE_SET) " + " WHERE elb_status_code = 500 " - + " WITH (auto_refresh = true)" + + " WITH (auto_refresh = true)", + "DROP SKIPPING INDEX ON myS3.default.alb_logs", + "VACUUM SKIPPING INDEX ON myS3.default.alb_logs", + "ALTER SKIPPING INDEX ON myS3.default.alb_logs WITH (auto_refresh = false)", }; for (String query : createSkippingIndexQueries) { - Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(query), "Failed query: " + query); + assertTrue(SQLQueryUtils.isFlintExtensionQuery(query), "Failed query: " + query); + IndexQueryDetails indexQueryDetails = SQLQueryUtils.extractIndexDetails(query); FullyQualifiedTableName fullyQualifiedTableName = indexQueryDetails.getFullyQualifiedTableName(); - Assertions.assertNull(indexQueryDetails.getIndexName()); - Assertions.assertEquals("myS3", fullyQualifiedTableName.getDatasourceName()); - Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName()); - Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName()); + assertNull(indexQueryDetails.getIndexName()); + assertFullyQualifiedTableName("myS3", "default", "alb_logs", fullyQualifiedTableName); } } @Test void testExtractionFromFlintCoveringIndexQueries() { - String[] createCoveredIndexQueries = { + String[] coveringIndexQueries = { "CREATE INDEX elb_and_requestUri ON myS3.default.alb_logs(l_orderkey, l_quantity)", "CREATE INDEX IF NOT EXISTS elb_and_requestUri " + " ON myS3.default.alb_logs(l_orderkey, l_quantity) " @@ -139,167 +163,177 @@ void testExtractionFromFlintCoveringIndexQueries() { + " WITH (auto_refresh = true)", "CREATE INDEX elb_and_requestUri ON myS3.default.alb_logs(l_orderkey, l_quantity) " + " WHERE elb_status_code = 500 " - + " WITH (auto_refresh = true)" + + " WITH (auto_refresh = true)", + "DROP INDEX elb_and_requestUri ON myS3.default.alb_logs", + "VACUUM INDEX elb_and_requestUri ON myS3.default.alb_logs", + "ALTER INDEX elb_and_requestUri ON myS3.default.alb_logs WITH (auto_refresh = false)" }; - for (String query : createCoveredIndexQueries) { - Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(query), "Failed query: " + query); + for (String query : coveringIndexQueries) { + assertTrue(SQLQueryUtils.isFlintExtensionQuery(query), "Failed query: " + query); + IndexQueryDetails indexQueryDetails = SQLQueryUtils.extractIndexDetails(query); FullyQualifiedTableName fullyQualifiedTableName = indexQueryDetails.getFullyQualifiedTableName(); - Assertions.assertEquals("elb_and_requestUri", indexQueryDetails.getIndexName()); - Assertions.assertEquals("myS3", fullyQualifiedTableName.getDatasourceName()); - Assertions.assertEquals("default", fullyQualifiedTableName.getSchemaName()); - Assertions.assertEquals("alb_logs", fullyQualifiedTableName.getTableName()); + assertEquals("elb_and_requestUri", indexQueryDetails.getIndexName()); + assertFullyQualifiedTableName("myS3", "default", "alb_logs", fullyQualifiedTableName); } } @Test void testExtractionFromFlintMVQuery() { - String createCoveredIndexQuery = - "CREATE MATERIALIZED VIEW mv_1 AS query=select * from my_glue.default.logs WITH" - + " (auto_refresh = true)"; - Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(createCoveredIndexQuery)); - IndexQueryDetails indexQueryDetails = - SQLQueryUtils.extractIndexDetails(createCoveredIndexQuery); - FullyQualifiedTableName fullyQualifiedTableName = - indexQueryDetails.getFullyQualifiedTableName(); - Assertions.assertNull(indexQueryDetails.getIndexName()); - Assertions.assertNull(fullyQualifiedTableName); - Assertions.assertEquals("mv_1", indexQueryDetails.getMvName()); + String[] mvQueries = { + "CREATE MATERIALIZED VIEW mv_1 AS query=select * from my_glue.default.logs WITH" + + " (auto_refresh = true)", + "DROP MATERIALIZED VIEW mv_1", + "VACUUM MATERIALIZED VIEW mv_1", + "ALTER MATERIALIZED VIEW mv_1 WITH (auto_refresh = false)", + }; + + for (String query : mvQueries) { + assertTrue(SQLQueryUtils.isFlintExtensionQuery(query)); + + IndexQueryDetails indexQueryDetails = SQLQueryUtils.extractIndexDetails(query); + FullyQualifiedTableName fullyQualifiedTableName = + indexQueryDetails.getFullyQualifiedTableName(); + + assertNull(indexQueryDetails.getIndexName()); + assertNull(fullyQualifiedTableName); + assertEquals("mv_1", indexQueryDetails.getMvName()); + } } @Test - void testDescIndex() { + void testDescSkippingIndex() { String descSkippingIndex = "DESC SKIPPING INDEX ON mys3.default.http_logs"; - Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(descSkippingIndex)); + assertTrue(SQLQueryUtils.isFlintExtensionQuery(descSkippingIndex)); IndexQueryDetails indexDetails = SQLQueryUtils.extractIndexDetails(descSkippingIndex); FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); - Assertions.assertNull(indexDetails.getIndexName()); - Assertions.assertNotNull(fullyQualifiedTableName); - Assertions.assertEquals(FlintIndexType.SKIPPING, indexDetails.getIndexType()); - Assertions.assertEquals(IndexQueryActionType.DESCRIBE, indexDetails.getIndexQueryActionType()); + assertNull(indexDetails.getIndexName()); + assertNotNull(fullyQualifiedTableName); + assertEquals(FlintIndexType.SKIPPING, indexDetails.getIndexType()); + assertEquals(IndexQueryActionType.DESCRIBE, indexDetails.getIndexQueryActionType()); String descCoveringIndex = "DESC INDEX cv1 ON mys3.default.http_logs"; - Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(descCoveringIndex)); + assertTrue(SQLQueryUtils.isFlintExtensionQuery(descCoveringIndex)); indexDetails = SQLQueryUtils.extractIndexDetails(descCoveringIndex); fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); - Assertions.assertEquals("cv1", indexDetails.getIndexName()); - Assertions.assertNotNull(fullyQualifiedTableName); - Assertions.assertEquals(FlintIndexType.COVERING, indexDetails.getIndexType()); - Assertions.assertEquals(IndexQueryActionType.DESCRIBE, indexDetails.getIndexQueryActionType()); + assertEquals("cv1", indexDetails.getIndexName()); + assertNotNull(fullyQualifiedTableName); + assertEquals(FlintIndexType.COVERING, indexDetails.getIndexType()); + assertEquals(IndexQueryActionType.DESCRIBE, indexDetails.getIndexQueryActionType()); String descMv = "DESC MATERIALIZED VIEW mv1"; - Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(descMv)); + assertTrue(SQLQueryUtils.isFlintExtensionQuery(descMv)); indexDetails = SQLQueryUtils.extractIndexDetails(descMv); fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); - Assertions.assertNull(indexDetails.getIndexName()); - Assertions.assertEquals("mv1", indexDetails.getMvName()); - Assertions.assertNull(fullyQualifiedTableName); - Assertions.assertEquals(FlintIndexType.MATERIALIZED_VIEW, indexDetails.getIndexType()); - Assertions.assertEquals(IndexQueryActionType.DESCRIBE, indexDetails.getIndexQueryActionType()); + assertNull(indexDetails.getIndexName()); + assertEquals("mv1", indexDetails.getMvName()); + assertNull(fullyQualifiedTableName); + assertEquals(FlintIndexType.MATERIALIZED_VIEW, indexDetails.getIndexType()); + assertEquals(IndexQueryActionType.DESCRIBE, indexDetails.getIndexQueryActionType()); } @Test void testShowIndex() { String showCoveringIndex = " SHOW INDEX ON myS3.default.http_logs"; - Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(showCoveringIndex)); + assertTrue(SQLQueryUtils.isFlintExtensionQuery(showCoveringIndex)); IndexQueryDetails indexDetails = SQLQueryUtils.extractIndexDetails(showCoveringIndex); FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); - Assertions.assertNull(indexDetails.getIndexName()); - Assertions.assertNull(indexDetails.getMvName()); - Assertions.assertNotNull(fullyQualifiedTableName); - Assertions.assertEquals(FlintIndexType.COVERING, indexDetails.getIndexType()); - Assertions.assertEquals(IndexQueryActionType.SHOW, indexDetails.getIndexQueryActionType()); + assertNull(indexDetails.getIndexName()); + assertNull(indexDetails.getMvName()); + assertNotNull(fullyQualifiedTableName); + assertEquals(FlintIndexType.COVERING, indexDetails.getIndexType()); + assertEquals(IndexQueryActionType.SHOW, indexDetails.getIndexQueryActionType()); String showMV = "SHOW MATERIALIZED VIEW IN my_glue.default"; - Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(showMV)); + assertTrue(SQLQueryUtils.isFlintExtensionQuery(showMV)); indexDetails = SQLQueryUtils.extractIndexDetails(showMV); fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); - Assertions.assertNull(indexDetails.getIndexName()); - Assertions.assertNull(indexDetails.getMvName()); - Assertions.assertNull(fullyQualifiedTableName); - Assertions.assertEquals(FlintIndexType.MATERIALIZED_VIEW, indexDetails.getIndexType()); - Assertions.assertEquals(IndexQueryActionType.SHOW, indexDetails.getIndexQueryActionType()); + assertNull(indexDetails.getIndexName()); + assertNull(indexDetails.getMvName()); + assertNull(fullyQualifiedTableName); + assertEquals(FlintIndexType.MATERIALIZED_VIEW, indexDetails.getIndexType()); + assertEquals(IndexQueryActionType.SHOW, indexDetails.getIndexQueryActionType()); } @Test void testRefreshIndex() { String refreshSkippingIndex = "REFRESH SKIPPING INDEX ON mys3.default.http_logs"; - Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(refreshSkippingIndex)); + assertTrue(SQLQueryUtils.isFlintExtensionQuery(refreshSkippingIndex)); IndexQueryDetails indexDetails = SQLQueryUtils.extractIndexDetails(refreshSkippingIndex); FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); - Assertions.assertNull(indexDetails.getIndexName()); - Assertions.assertNotNull(fullyQualifiedTableName); - Assertions.assertEquals(FlintIndexType.SKIPPING, indexDetails.getIndexType()); - Assertions.assertEquals(IndexQueryActionType.REFRESH, indexDetails.getIndexQueryActionType()); + assertNull(indexDetails.getIndexName()); + assertNotNull(fullyQualifiedTableName); + assertEquals(FlintIndexType.SKIPPING, indexDetails.getIndexType()); + assertEquals(IndexQueryActionType.REFRESH, indexDetails.getIndexQueryActionType()); String refreshCoveringIndex = "REFRESH INDEX cv1 ON mys3.default.http_logs"; - Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(refreshCoveringIndex)); + assertTrue(SQLQueryUtils.isFlintExtensionQuery(refreshCoveringIndex)); indexDetails = SQLQueryUtils.extractIndexDetails(refreshCoveringIndex); fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); - Assertions.assertEquals("cv1", indexDetails.getIndexName()); - Assertions.assertNotNull(fullyQualifiedTableName); - Assertions.assertEquals(FlintIndexType.COVERING, indexDetails.getIndexType()); - Assertions.assertEquals(IndexQueryActionType.REFRESH, indexDetails.getIndexQueryActionType()); + assertEquals("cv1", indexDetails.getIndexName()); + assertNotNull(fullyQualifiedTableName); + assertEquals(FlintIndexType.COVERING, indexDetails.getIndexType()); + assertEquals(IndexQueryActionType.REFRESH, indexDetails.getIndexQueryActionType()); String refreshMV = "REFRESH MATERIALIZED VIEW mv1"; - Assertions.assertTrue(SQLQueryUtils.isFlintExtensionQuery(refreshMV)); + assertTrue(SQLQueryUtils.isFlintExtensionQuery(refreshMV)); indexDetails = SQLQueryUtils.extractIndexDetails(refreshMV); fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); - Assertions.assertNull(indexDetails.getIndexName()); - Assertions.assertEquals("mv1", indexDetails.getMvName()); - Assertions.assertNull(fullyQualifiedTableName); - Assertions.assertEquals(FlintIndexType.MATERIALIZED_VIEW, indexDetails.getIndexType()); - Assertions.assertEquals(IndexQueryActionType.REFRESH, indexDetails.getIndexQueryActionType()); + assertNull(indexDetails.getIndexName()); + assertEquals("mv1", indexDetails.getMvName()); + assertNull(fullyQualifiedTableName); + assertEquals(FlintIndexType.MATERIALIZED_VIEW, indexDetails.getIndexType()); + assertEquals(IndexQueryActionType.REFRESH, indexDetails.getIndexQueryActionType()); } /** https://github.com/opensearch-project/sql/issues/2206 */ @Test void testAutoRefresh() { - Assertions.assertFalse( + assertFalse( SQLQueryUtils.extractIndexDetails(skippingIndex().getQuery()) .getFlintIndexOptions() .autoRefresh()); - Assertions.assertFalse( + assertFalse( SQLQueryUtils.extractIndexDetails( skippingIndex().withProperty("auto_refresh", "false").getQuery()) .getFlintIndexOptions() .autoRefresh()); - Assertions.assertTrue( + assertTrue( SQLQueryUtils.extractIndexDetails( skippingIndex().withProperty("auto_refresh", "true").getQuery()) .getFlintIndexOptions() .autoRefresh()); - Assertions.assertTrue( + assertTrue( SQLQueryUtils.extractIndexDetails( skippingIndex().withProperty("auto_refresh", "true").withSemicolon().getQuery()) .getFlintIndexOptions() .autoRefresh()); - Assertions.assertTrue( + assertTrue( SQLQueryUtils.extractIndexDetails( skippingIndex().withProperty("\"auto_refresh\"", "true").getQuery()) .getFlintIndexOptions() .autoRefresh()); - Assertions.assertTrue( + assertTrue( SQLQueryUtils.extractIndexDetails( skippingIndex().withProperty("\"auto_refresh\"", "true").withSemicolon().getQuery()) .getFlintIndexOptions() .autoRefresh()); - Assertions.assertTrue( + assertTrue( SQLQueryUtils.extractIndexDetails( skippingIndex().withProperty("\"auto_refresh\"", "\"true\"").getQuery()) .getFlintIndexOptions() .autoRefresh()); - Assertions.assertTrue( + assertTrue( SQLQueryUtils.extractIndexDetails( skippingIndex() .withProperty("\"auto_refresh\"", "\"true\"") @@ -308,48 +342,48 @@ void testAutoRefresh() { .getFlintIndexOptions() .autoRefresh()); - Assertions.assertFalse( + assertFalse( SQLQueryUtils.extractIndexDetails( skippingIndex().withProperty("auto_refresh", "1").getQuery()) .getFlintIndexOptions() .autoRefresh()); - Assertions.assertFalse( + assertFalse( SQLQueryUtils.extractIndexDetails(skippingIndex().withProperty("interval", "1").getQuery()) .getFlintIndexOptions() .autoRefresh()); - Assertions.assertFalse( + assertFalse( SQLQueryUtils.extractIndexDetails( skippingIndex().withProperty("\"\"", "\"true\"").getQuery()) .getFlintIndexOptions() .autoRefresh()); - Assertions.assertFalse( + assertFalse( SQLQueryUtils.extractIndexDetails(index().getQuery()).getFlintIndexOptions().autoRefresh()); - Assertions.assertFalse( + assertFalse( SQLQueryUtils.extractIndexDetails(index().withProperty("auto_refresh", "false").getQuery()) .getFlintIndexOptions() .autoRefresh()); - Assertions.assertTrue( + assertTrue( SQLQueryUtils.extractIndexDetails(index().withProperty("auto_refresh", "true").getQuery()) .getFlintIndexOptions() .autoRefresh()); - Assertions.assertTrue( + assertTrue( SQLQueryUtils.extractIndexDetails( index().withProperty("auto_refresh", "true").withSemicolon().getQuery()) .getFlintIndexOptions() .autoRefresh()); - Assertions.assertTrue( + assertTrue( SQLQueryUtils.extractIndexDetails(mv().withProperty("auto_refresh", "true").getQuery()) .getFlintIndexOptions() .autoRefresh()); - Assertions.assertTrue( + assertTrue( SQLQueryUtils.extractIndexDetails( mv().withProperty("auto_refresh", "true").withSemicolon().getQuery()) .getFlintIndexOptions() @@ -389,4 +423,14 @@ public IndexQuery withSemicolon() { return this; } } + + private void assertFullyQualifiedTableName( + String expectedDatasourceName, + String expectedSchemaName, + String expectedTableName, + FullyQualifiedTableName fullyQualifiedTableName) { + assertEquals(expectedDatasourceName, fullyQualifiedTableName.getDatasourceName()); + assertEquals(expectedSchemaName, fullyQualifiedTableName.getSchemaName()); + assertEquals(expectedTableName, fullyQualifiedTableName.getTableName()); + } } diff --git a/async-query/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java b/async-query/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java index ced5609083..b4a72584b8 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java @@ -37,6 +37,7 @@ import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction; import org.opensearch.sql.spark.transport.TransportCreateAsyncQueryRequestAction; import org.opensearch.sql.spark.transport.TransportGetAsyncQueryResultAction; +import org.opensearch.sql.spark.transport.format.CreateAsyncQueryRequestConverter; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionRequest; @@ -119,7 +120,7 @@ private RestChannelConsumer executePostRequest(RestRequest restRequest, NodeClie try { MetricUtils.incrementNumericalMetric(MetricName.ASYNC_QUERY_CREATE_API_REQUEST_COUNT); CreateAsyncQueryRequest submitJobRequest = - CreateAsyncQueryRequest.fromXContentParser(restRequest.contentParser()); + CreateAsyncQueryRequestConverter.fromXContentParser(restRequest.contentParser()); Scheduler.schedule( nodeClient, () -> diff --git a/async-query/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java b/async-query/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java index b8252494e7..0e9da0c13c 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java @@ -16,8 +16,8 @@ import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryResult; import org.opensearch.sql.spark.transport.format.AsyncQueryResultResponseFormatter; +import org.opensearch.sql.spark.transport.model.AsyncQueryResult; import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionRequest; import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionResponse; import org.opensearch.tasks.Task; diff --git a/async-query/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java b/async-query/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java index 3a2a5b110f..afa6797694 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatter.java @@ -14,7 +14,7 @@ import org.opensearch.core.common.Strings; import org.opensearch.sql.protocol.response.QueryResult; import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryResult; +import org.opensearch.sql.spark.transport.model.AsyncQueryResult; /** * JSON response format with schema header and data rows. For example, diff --git a/async-query/src/main/java/org/opensearch/sql/spark/transport/format/CreateAsyncQueryRequestConverter.java b/async-query/src/main/java/org/opensearch/sql/spark/transport/format/CreateAsyncQueryRequestConverter.java new file mode 100644 index 0000000000..c22c2da24d --- /dev/null +++ b/async-query/src/main/java/org/opensearch/sql/spark/transport/format/CreateAsyncQueryRequestConverter.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.transport.format; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import lombok.experimental.UtilityClass; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; +import org.opensearch.sql.spark.rest.model.LangType; + +@UtilityClass +public class CreateAsyncQueryRequestConverter { + public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser) { + String query = null; + LangType lang = null; + String datasource = null; + String sessionId = null; + try { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + if (fieldName.equals("query")) { + query = parser.textOrNull(); + } else if (fieldName.equals("lang")) { + String langString = parser.textOrNull(); + lang = LangType.fromString(langString); + } else if (fieldName.equals("datasource")) { + datasource = parser.textOrNull(); + } else if (fieldName.equals("sessionId")) { + sessionId = parser.textOrNull(); + } else { + throw new IllegalArgumentException("Unknown field: " + fieldName); + } + } + return new CreateAsyncQueryRequest(query, datasource, lang, sessionId); + } catch (Exception e) { + throw new IllegalArgumentException( + String.format("Error while parsing the request body: %s", e.getMessage())); + } + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java b/async-query/src/main/java/org/opensearch/sql/spark/transport/model/AsyncQueryResult.java similarity index 87% rename from async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java rename to async-query/src/main/java/org/opensearch/sql/spark/transport/model/AsyncQueryResult.java index c229aa3920..712cebf7e1 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryResult.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/transport/model/AsyncQueryResult.java @@ -1,4 +1,9 @@ -package org.opensearch.sql.spark.asyncquery.model; +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.transport.model; import java.util.Collection; import lombok.Getter; diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index f8b61aee5a..3ff806bf50 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -49,7 +49,7 @@ public class AsyncQueryExecutorServiceImplSpecTest extends AsyncQueryExecutorSer @Disabled("batch query is unsupported") public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -79,7 +79,7 @@ public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { @Disabled("batch query is unsupported") public void sessionLimitNotImpactBatchQuery() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -104,7 +104,7 @@ public void sessionLimitNotImpactBatchQuery() { @Disabled("batch query is unsupported") public void createAsyncQueryCreateJobWithCorrectParameters() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -140,7 +140,7 @@ public void createAsyncQueryCreateJobWithCorrectParameters() { @Test public void withSessionCreateAsyncQueryThenGetResultThenCancel() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -170,7 +170,7 @@ public void withSessionCreateAsyncQueryThenGetResultThenCancel() { @Test public void reuseSessionWhenCreateAsyncQuery() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -224,7 +224,7 @@ public void reuseSessionWhenCreateAsyncQuery() { @Disabled("batch query is unsupported") public void batchQueryHasTimeout() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -240,7 +240,7 @@ public void batchQueryHasTimeout() { @Test public void interactiveQueryNoTimeout() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -274,7 +274,7 @@ public void datasourceWithBasicAuth() { .setProperties(properties) .build()); LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -295,7 +295,7 @@ public void datasourceWithBasicAuth() { @Test public void withSessionCreateAsyncQueryFailed() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -344,7 +344,7 @@ public void withSessionCreateAsyncQueryFailed() { @Test public void createSessionMoreThanLimitFailed() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -376,7 +376,7 @@ public void createSessionMoreThanLimitFailed() { @Test public void recreateSessionIfNotReady() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -417,7 +417,7 @@ public void recreateSessionIfNotReady() { @Test public void submitQueryWithDifferentDataSourceSessionWillCreateNewSession() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -465,7 +465,7 @@ public void submitQueryWithDifferentDataSourceSessionWillCreateNewSession() { @Test public void recreateSessionIfStale() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -523,7 +523,7 @@ public void recreateSessionIfStale() { @Test public void submitQueryInInvalidSessionWillCreateNewSession() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -561,7 +561,7 @@ public void datasourceNameIncludeUppercase() { .build()); LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -583,7 +583,7 @@ public void datasourceNameIncludeUppercase() { @Test public void concurrentSessionLimitIsDomainLevel() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -612,7 +612,7 @@ public void concurrentSessionLimitIsDomainLevel() { @Test public void testDatasourceDisabled() { LocalEMRSClient emrsClient = new LocalEMRSClient(); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index 4a73fc8b13..ed00cb1022 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -361,7 +361,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { public static class LocalEMRServerlessClientFactory implements EMRServerlessClientFactory { @Override - public EMRServerlessClient getClient() { + public EMRServerlessClient getClient(String accountId) { return new LocalEMRSClient(); } } diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java index 12fa8043ea..e0f04761c7 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java @@ -25,7 +25,6 @@ import org.opensearch.sql.protocol.response.format.ResponseFormatter; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryResult; import org.opensearch.sql.spark.asyncquery.model.MockFlintSparkJob; import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; @@ -38,6 +37,7 @@ import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.sql.spark.transport.format.AsyncQueryResultResponseFormatter; +import org.opensearch.sql.spark.transport.model.AsyncQueryResult; public class AsyncQueryGetResultSpecTest extends AsyncQueryExecutorServiceSpec { AsyncQueryRequestContext asyncQueryRequestContext = new NullAsyncQueryRequestContext(); @@ -417,7 +417,7 @@ private class AssertionHelper { private Interaction interaction; AssertionHelper(String query, LocalEMRSClient emrClient) { - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrClient; this.queryService = createAsyncQueryExecutorService( emrServerlessClientFactory, diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java index 230853a5eb..70a43e42d5 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java @@ -63,7 +63,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -131,7 +131,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -214,7 +214,7 @@ public CancelJobRunResult cancelJobRun( throw new ValidationException("Job run is not in a cancellable state"); } }; - EMRServerlessClientFactory emrServerlessCientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessCientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessCientFactory); // Mock flint index @@ -276,8 +276,8 @@ public void testAlterIndexQueryConvertingToAutoRefresh() { ImmutableList.of(ALTER_SKIPPING, ALTER_COVERING, ALTER_MV) .forEach( mockDS -> { - LocalEMRSClient localEMRSClient = new LocalEMRSClient(); - EMRServerlessClientFactory clientFactory = () -> localEMRSClient; + LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory clientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(clientFactory); @@ -307,9 +307,9 @@ public void testAlterIndexQueryConvertingToAutoRefresh() { .getStatus()); flintIndexJob.assertState(FlintIndexState.ACTIVE); - localEMRSClient.startJobRunCalled(1); - localEMRSClient.getJobRunResultCalled(1); - localEMRSClient.cancelJobRunCalled(0); + emrsClient.startJobRunCalled(1); + emrsClient.getJobRunResultCalled(1); + emrsClient.cancelJobRunCalled(0); Map mappings = mockDS.getIndexMappings(); Map meta = (HashMap) mappings.get("_meta"); Map options = (Map) meta.get("options"); @@ -342,8 +342,8 @@ public void testAlterIndexQueryWithOutAnyAutoRefresh() { ImmutableList.of(ALTER_SKIPPING, ALTER_COVERING, ALTER_MV) .forEach( mockDS -> { - LocalEMRSClient localEMRSClient = new LocalEMRSClient(); - EMRServerlessClientFactory clientFactory = () -> localEMRSClient; + LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory clientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(clientFactory); @@ -373,9 +373,9 @@ public void testAlterIndexQueryWithOutAnyAutoRefresh() { .getStatus()); flintIndexJob.assertState(FlintIndexState.ACTIVE); - localEMRSClient.startJobRunCalled(1); - localEMRSClient.getJobRunResultCalled(1); - localEMRSClient.cancelJobRunCalled(0); + emrsClient.startJobRunCalled(1); + emrsClient.getJobRunResultCalled(1); + emrsClient.cancelJobRunCalled(0); Map mappings = mockDS.getIndexMappings(); Map meta = (HashMap) mappings.get("_meta"); Map options = (Map) meta.get("options"); @@ -419,7 +419,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -494,7 +494,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -562,7 +562,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -624,7 +624,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -686,7 +686,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -750,7 +750,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -811,7 +811,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -873,7 +873,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(jobRun); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -940,7 +940,7 @@ public CancelJobRunResult cancelJobRun( throw new ValidationException("Job run is not in a cancellable state"); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -1005,7 +1005,7 @@ public CancelJobRunResult cancelJobRun( throw new ValidationException("Random validation exception"); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index @@ -1071,7 +1071,7 @@ public CancelJobRunResult cancelJobRun( throw new IllegalArgumentException("Unknown Error"); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java index 2b6b1d2ba0..2eed7b13a0 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java @@ -124,7 +124,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Cancelled")); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -175,7 +175,7 @@ public CancelJobRunResult cancelJobRun( throw new ValidationException("Job run is not in a cancellable state"); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -215,7 +215,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Running")); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -253,7 +253,7 @@ public CancelJobRunResult cancelJobRun( throw new ValidationException("Job run is not in a cancellable state"); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -290,7 +290,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Cancelled")); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -350,7 +350,7 @@ public CancelJobRunResult cancelJobRun( throw new ValidationException("Job run is not in a cancellable state"); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -397,7 +397,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Running")); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -443,7 +443,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Cancelled")); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -494,7 +494,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Cancelled")); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -542,7 +542,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Cancelled")); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -590,7 +590,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Cancelled")); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -644,7 +644,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return null; } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -691,7 +691,7 @@ public CancelJobRunResult cancelJobRun( throw new IllegalArgumentException("Job run is not in a cancellable state"); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -742,7 +742,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return null; } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -887,7 +887,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return null; } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); @@ -917,7 +917,7 @@ public void cancelRefreshStatement() { mockDS -> { AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService( - () -> + (accountId) -> new LocalEMRSClient() { @Override public GetJobRunResult getJobRunResult( @@ -962,7 +962,7 @@ public void cancelRefreshStatementWithActiveState() { mockDS -> { AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService( - () -> + (accountId) -> new LocalEMRSClient() { @Override public GetJobRunResult getJobRunResult( @@ -1009,7 +1009,7 @@ public void cancelRefreshStatementWithFailureInFetchingIndexMetadata() { new MockFlintIndex(client(), indexName, FlintIndexType.COVERING, null); AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService( - () -> + (accountId) -> new LocalEMRSClient() { @Override public GetJobRunResult getJobRunResult(String applicationId, String jobId) { diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java index 3bccf1b30b..439b2ed2d6 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java @@ -156,7 +156,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return getJobRunResult.call(); } }; - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = createAsyncQueryExecutorService(emrServerlessClientFactory); diff --git a/async-query/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java b/async-query/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java index 89f3ac9871..c5964a61e3 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java @@ -54,7 +54,9 @@ public void testStreamingJobHouseKeeperWhenDataSourceDisabled() { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + dataSourceService, + flintIndexMetadataService, + getFlintIndexOpFactory((accountId) -> emrsClient)); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); @@ -133,7 +135,9 @@ public void testStreamingJobHouseKeeperWhenCancelJobGivesTimeout() { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + dataSourceService, + flintIndexMetadataService, + getFlintIndexOpFactory((accountId) -> emrsClient)); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); @@ -181,7 +185,9 @@ public void testSimulateConcurrentJobHouseKeeperExecution() { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + dataSourceService, + flintIndexMetadataService, + getFlintIndexOpFactory((accountId) -> emrsClient)); FlintStreamingJobHouseKeeperTask.isRunning.compareAndSet(false, true); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); @@ -231,7 +237,9 @@ public void testStreamingJobClearnerWhenDataSourceIsDeleted() { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + dataSourceService, + flintIndexMetadataService, + getFlintIndexOpFactory((accountId) -> emrsClient)); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); @@ -278,7 +286,9 @@ public void testStreamingJobHouseKeeperWhenDataSourceIsNeitherDisabledNorDeleted FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + dataSourceService, + flintIndexMetadataService, + getFlintIndexOpFactory((accountId) -> emrsClient)); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); @@ -320,7 +330,9 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + dataSourceService, + flintIndexMetadataService, + getFlintIndexOpFactory((accountId) -> emrsClient)); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); @@ -347,7 +359,9 @@ public void testStreamingJobHouseKeeperWhenFlintIndexIsCorrupted() throws Interr FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + dataSourceService, + flintIndexMetadataService, + getFlintIndexOpFactory((accountId) -> emrsClient)); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); @@ -389,7 +403,9 @@ public void updateIndexToManualRefresh( }; FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + dataSourceService, + flintIndexMetadataService, + getFlintIndexOpFactory((accountId) -> emrsClient)); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); @@ -429,7 +445,9 @@ public void testStreamingJobHouseKeeperMultipleTimesWhenDataSourceDisabled() { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + dataSourceService, + flintIndexMetadataService, + getFlintIndexOpFactory((accountId) -> emrsClient)); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); @@ -501,7 +519,9 @@ public void testRunStreamingJobHouseKeeperWhenDataSourceIsDeleted() { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, getFlintIndexOpFactory(() -> emrsClient)); + dataSourceService, + flintIndexMetadataService, + getFlintIndexOpFactory((accountId) -> emrsClient)); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); diff --git a/async-query/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/async-query/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index e8aeb17505..d0bfed94c0 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -58,7 +58,7 @@ public void setup() { new OpenSearchSessionStorageService(stateStore, new SessionModelXContentSerializer()); statementStorageService = new OpenSearchStatementStorageService(stateStore, new StatementModelXContentSerializer()); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; sessionManager = new SessionManager( diff --git a/async-query/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/async-query/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index d76b419df6..e76776e2fc 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -61,7 +61,7 @@ public void setup() { new OpenSearchStatementStorageService(stateStore, new StatementModelXContentSerializer()); sessionStorageService = new OpenSearchSessionStorageService(stateStore, new SessionModelXContentSerializer()); - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; sessionManager = new SessionManager( @@ -279,7 +279,7 @@ public void newStatementFieldAssert() { @Test public void failToSubmitStatementInDeletedSession() { - EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient; Session session = sessionManager.createSession(createSessionRequest(), asyncQueryRequestContext); diff --git a/async-query/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java b/async-query/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java index 711db75efb..bb7d5f7893 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/transport/format/AsyncQueryResultResponseFormatterTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.sql.spark.transport.format; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -11,7 +16,7 @@ import java.util.Arrays; import org.junit.jupiter.api.Test; import org.opensearch.sql.executor.ExecutionEngine; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryResult; +import org.opensearch.sql.spark.transport.model.AsyncQueryResult; public class AsyncQueryResultResponseFormatterTest { diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java b/async-query/src/test/java/org/opensearch/sql/spark/transport/format/CreateAsyncQueryRequestConverterTest.java similarity index 83% rename from async-query-core/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java rename to async-query/src/test/java/org/opensearch/sql/spark/transport/format/CreateAsyncQueryRequestConverterTest.java index de38ca0e3c..d7f8046a1b 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequestTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/transport/format/CreateAsyncQueryRequestConverterTest.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.sql.spark.rest.model; +package org.opensearch.sql.spark.transport.format; import java.io.IOException; import org.junit.jupiter.api.Assertions; @@ -12,8 +12,10 @@ import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; +import org.opensearch.sql.spark.rest.model.LangType; -public class CreateAsyncQueryRequestTest { +public class CreateAsyncQueryRequestConverterTest { @Test public void fromXContent() throws IOException { @@ -24,7 +26,7 @@ public void fromXContent() throws IOException { + " \"query\": \"select 1\"\n" + "}"; CreateAsyncQueryRequest queryRequest = - CreateAsyncQueryRequest.fromXContentParser(xContentParser(request)); + CreateAsyncQueryRequestConverter.fromXContentParser(xContentParser(request)); Assertions.assertEquals("my_glue", queryRequest.getDatasource()); Assertions.assertEquals(LangType.SQL, queryRequest.getLang()); Assertions.assertEquals("select 1", queryRequest.getQuery()); @@ -48,7 +50,7 @@ public void fromXContentWithDuplicateFields() throws IOException { IllegalArgumentException illegalArgumentException = Assertions.assertThrows( IllegalArgumentException.class, - () -> CreateAsyncQueryRequest.fromXContentParser(xContentParser(request))); + () -> CreateAsyncQueryRequestConverter.fromXContentParser(xContentParser(request))); Assertions.assertTrue( illegalArgumentException .getMessage() @@ -67,7 +69,7 @@ public void fromXContentWithUnknownField() throws IOException { IllegalArgumentException illegalArgumentException = Assertions.assertThrows( IllegalArgumentException.class, - () -> CreateAsyncQueryRequest.fromXContentParser(xContentParser(request))); + () -> CreateAsyncQueryRequestConverter.fromXContentParser(xContentParser(request))); Assertions.assertEquals( "Error while parsing the request body: Unknown field: random", illegalArgumentException.getMessage()); @@ -81,7 +83,7 @@ public void fromXContentWithWrongDatatype() throws IOException { IllegalArgumentException illegalArgumentException = Assertions.assertThrows( IllegalArgumentException.class, - () -> CreateAsyncQueryRequest.fromXContentParser(xContentParser(request))); + () -> CreateAsyncQueryRequestConverter.fromXContentParser(xContentParser(request))); Assertions.assertEquals( "Error while parsing the request body: Can't get text on a START_ARRAY at 1:16", illegalArgumentException.getMessage()); @@ -97,7 +99,7 @@ public void fromXContentWithSessionId() throws IOException { + " \"sessionId\": \"00fdjevgkf12s00q\"\n" + "}"; CreateAsyncQueryRequest queryRequest = - CreateAsyncQueryRequest.fromXContentParser(xContentParser(request)); + CreateAsyncQueryRequestConverter.fromXContentParser(xContentParser(request)); Assertions.assertEquals("00fdjevgkf12s00q", queryRequest.getSessionId()); }