diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java index ae82386c3f..d38c8554ae 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java @@ -6,7 +6,7 @@ package org.opensearch.sql.spark.asyncquery; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; -import org.opensearch.sql.spark.asyncquery.model.RequestContext; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; @@ -22,7 +22,8 @@ public interface AsyncQueryExecutorService { * @return {@link CreateAsyncQueryResponse} */ CreateAsyncQueryResponse createAsyncQuery( - CreateAsyncQueryRequest createAsyncQueryRequest, RequestContext requestContext); + CreateAsyncQueryRequest createAsyncQueryRequest, + AsyncQueryRequestContext asyncQueryRequestContext); /** * Returns async query response for a given queryId. diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index 14107712f1..6d3d5b6765 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -18,7 +18,7 @@ import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; -import org.opensearch.sql.spark.asyncquery.model.RequestContext; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; @@ -37,23 +37,29 @@ public class AsyncQueryExecutorServiceImpl implements AsyncQueryExecutorService @Override public CreateAsyncQueryResponse createAsyncQuery( - CreateAsyncQueryRequest createAsyncQueryRequest, RequestContext requestContext) { + CreateAsyncQueryRequest createAsyncQueryRequest, + AsyncQueryRequestContext asyncQueryRequestContext) { SparkExecutionEngineConfig sparkExecutionEngineConfig = - sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(requestContext); + sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(asyncQueryRequestContext); DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - sparkExecutionEngineConfig.getApplicationId(), - createAsyncQueryRequest.getQuery(), - createAsyncQueryRequest.getDatasource(), - createAsyncQueryRequest.getLang(), - sparkExecutionEngineConfig.getExecutionRoleARN(), - sparkExecutionEngineConfig.getClusterName(), - sparkExecutionEngineConfig.getSparkSubmitParameterModifier(), - createAsyncQueryRequest.getSessionId())); + DispatchQueryRequest.builder() + .accountId(sparkExecutionEngineConfig.getAccountId()) + .applicationId(sparkExecutionEngineConfig.getApplicationId()) + .query(createAsyncQueryRequest.getQuery()) + .datasource(createAsyncQueryRequest.getDatasource()) + .langType(createAsyncQueryRequest.getLang()) + .executionRoleARN(sparkExecutionEngineConfig.getExecutionRoleARN()) + .clusterName(sparkExecutionEngineConfig.getClusterName()) + .sparkSubmitParameterModifier( + sparkExecutionEngineConfig.getSparkSubmitParameterModifier()) + .sessionId(createAsyncQueryRequest.getSessionId()) + .build(), + asyncQueryRequestContext); asyncQueryJobMetadataStorageService.storeJobMetadata( AsyncQueryJobMetadata.builder() .queryId(dispatchQueryResponse.getQueryId()) + .accountId(sparkExecutionEngineConfig.getAccountId()) .applicationId(sparkExecutionEngineConfig.getApplicationId()) .jobId(dispatchQueryResponse.getJobId()) .resultIndex(dispatchQueryResponse.getResultIndex()) @@ -61,7 +67,8 @@ public CreateAsyncQueryResponse createAsyncQuery( .datasourceName(dispatchQueryResponse.getDatasourceName()) .jobType(dispatchQueryResponse.getJobType()) .indexName(dispatchQueryResponse.getIndexName()) - .build()); + .build(), + asyncQueryRequestContext); return new CreateAsyncQueryResponse( dispatchQueryResponse.getQueryId(), dispatchQueryResponse.getSessionId()); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryJobMetadataStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryJobMetadataStorageService.java index 4ce34458cd..b4e94c984d 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryJobMetadataStorageService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryJobMetadataStorageService.java @@ -9,10 +9,13 @@ import java.util.Optional; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; public interface AsyncQueryJobMetadataStorageService { - void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata); + void storeJobMetadata( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext); Optional getJobMetadata(String jobId); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpenSearchAsyncQueryJobMetadataStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpenSearchAsyncQueryJobMetadataStorageService.java index 5356f14143..4847c8e00f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpenSearchAsyncQueryJobMetadataStorageService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpenSearchAsyncQueryJobMetadataStorageService.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.execution.xcontent.AsyncQueryJobMetadataXContentSerializer; @@ -28,7 +29,9 @@ public class OpenSearchAsyncQueryJobMetadataStorageService LogManager.getLogger(OpenSearchAsyncQueryJobMetadataStorageService.class); @Override - public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) { + public void storeJobMetadata( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { stateStore.create( mapIdToDocumentId(asyncQueryJobMetadata.getId()), asyncQueryJobMetadata, diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/RequestContext.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryRequestContext.java similarity index 84% rename from spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/RequestContext.java rename to spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryRequestContext.java index 3a0f350701..56176faefb 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/RequestContext.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryRequestContext.java @@ -6,6 +6,6 @@ package org.opensearch.sql.spark.asyncquery.model; /** Context interface to provide additional request related information */ -public interface RequestContext { +public interface AsyncQueryRequestContext { Object getAttribute(String name); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/NullRequestContext.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/NullAsyncQueryRequestContext.java similarity index 78% rename from spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/NullRequestContext.java rename to spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/NullAsyncQueryRequestContext.java index e106f57cff..918d1d5929 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/NullRequestContext.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/NullAsyncQueryRequestContext.java @@ -6,7 +6,7 @@ package org.opensearch.sql.spark.asyncquery.model; /** An implementation of RequestContext for where context is not required */ -public class NullRequestContext implements RequestContext { +public class NullAsyncQueryRequestContext implements AsyncQueryRequestContext { @Override public Object getAttribute(String name) { return null; diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java index 4250d32b0e..54d4f65351 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java @@ -13,7 +13,7 @@ import java.security.AccessController; import java.security.PrivilegedAction; import lombok.RequiredArgsConstructor; -import org.opensearch.sql.spark.asyncquery.model.NullRequestContext; +import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; @@ -34,7 +34,7 @@ public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactor public EMRServerlessClient getClient() { SparkExecutionEngineConfig sparkExecutionEngineConfig = this.sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig( - new NullRequestContext()); + new NullAsyncQueryRequestContext()); validateSparkExecutionEngineConfig(sparkExecutionEngineConfig); if (isNewClientCreationRequired(sparkExecutionEngineConfig.getRegion())) { region = sparkExecutionEngineConfig.getRegion(); diff --git a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplier.java b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplier.java index b5d061bad3..725df6bb0c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplier.java +++ b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplier.java @@ -1,6 +1,6 @@ package org.opensearch.sql.spark.config; -import org.opensearch.sql.spark.asyncquery.model.RequestContext; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; /** Interface for extracting and providing SparkExecutionEngineConfig */ public interface SparkExecutionEngineConfigSupplier { @@ -10,5 +10,6 @@ public interface SparkExecutionEngineConfigSupplier { * * @return {@link SparkExecutionEngineConfig}. */ - SparkExecutionEngineConfig getSparkExecutionEngineConfig(RequestContext requestContext); + SparkExecutionEngineConfig getSparkExecutionEngineConfig( + AsyncQueryRequestContext asyncQueryRequestContext); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImpl.java b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImpl.java index 69a402bdfc..8d2c40f4cd 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImpl.java @@ -9,7 +9,7 @@ import org.apache.commons.lang3.StringUtils; import org.opensearch.cluster.ClusterName; import org.opensearch.sql.common.setting.Settings; -import org.opensearch.sql.spark.asyncquery.model.RequestContext; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; @AllArgsConstructor public class SparkExecutionEngineConfigSupplierImpl implements SparkExecutionEngineConfigSupplier { @@ -17,7 +17,8 @@ public class SparkExecutionEngineConfigSupplierImpl implements SparkExecutionEng private Settings settings; @Override - public SparkExecutionEngineConfig getSparkExecutionEngineConfig(RequestContext requestContext) { + public SparkExecutionEngineConfig getSparkExecutionEngineConfig( + AsyncQueryRequestContext asyncQueryRequestContext) { ClusterName clusterName = settings.getSettingValue(CLUSTER_NAME); return getBuilderFromSettingsIfAvailable().clusterName(clusterName.value()).build(); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java index 199f24977c..e8413f469c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java @@ -17,6 +17,7 @@ import org.json.JSONObject; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; @@ -72,7 +73,8 @@ public DispatchQueryResponse submit( dataSourceMetadata, JobRunState.SUCCESS.toString(), StringUtils.EMPTY, - getElapsedTimeSince(startTime)); + getElapsedTimeSince(startTime), + context.getAsyncQueryRequestContext()); return DispatchQueryResponse.builder() .queryId(asyncQueryId) .jobId(DML_QUERY_JOB_ID) @@ -89,7 +91,8 @@ public DispatchQueryResponse submit( dataSourceMetadata, JobRunState.FAILED.toString(), e.getMessage(), - getElapsedTimeSince(startTime)); + getElapsedTimeSince(startTime), + context.getAsyncQueryRequestContext()); return DispatchQueryResponse.builder() .queryId(asyncQueryId) .jobId(DML_QUERY_JOB_ID) @@ -106,7 +109,8 @@ private String storeIndexDMLResult( DataSourceMetadata dataSourceMetadata, String status, String error, - long queryRunTime) { + long queryRunTime, + AsyncQueryRequestContext asyncQueryRequestContext) { IndexDMLResult indexDMLResult = IndexDMLResult.builder() .queryId(queryId) @@ -116,7 +120,7 @@ private String storeIndexDMLResult( .queryRunTime(queryRunTime) .updateTime(System.currentTimeMillis()) .build(); - indexDMLResultStorageService.createIndexDMLResult(indexDMLResult); + indexDMLResultStorageService.createIndexDMLResult(indexDMLResult, asyncQueryRequestContext); return queryId; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java index e41f4a49fd..5a3a78a33b 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java @@ -110,14 +110,16 @@ public DispatchQueryResponse submit( .acceptModifier(dispatchQueryRequest.getSparkSubmitParameterModifier()), tags, dataSourceMetadata.getResultIndex(), - dataSourceMetadata.getName())); + dataSourceMetadata.getName()), + context.getAsyncQueryRequestContext()); MetricUtils.incrementNumericalMetric(MetricName.EMR_INTERACTIVE_QUERY_JOBS_CREATION_COUNT); } session.submit( new QueryRequest( context.getQueryId(), dispatchQueryRequest.getLangType(), - dispatchQueryRequest.getQuery())); + dispatchQueryRequest.getQuery()), + context.getAsyncQueryRequestContext()); return DispatchQueryResponse.builder() .queryId(context.getQueryId()) .jobId(session.getSessionModel().getJobId()) diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 67d2767493..24950b5cfe 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -13,6 +13,7 @@ import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; @@ -37,7 +38,9 @@ public class SparkQueryDispatcher { private final QueryHandlerFactory queryHandlerFactory; private final QueryIdProvider queryIdProvider; - public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) { + public DispatchQueryResponse dispatch( + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext asyncQueryRequestContext) { DataSourceMetadata dataSourceMetadata = this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata( dispatchQueryRequest.getDatasource()); @@ -48,13 +51,16 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) DispatchQueryContext context = getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) .indexQueryDetails(indexQueryDetails) + .asyncQueryRequestContext(asyncQueryRequestContext) .build(); return getQueryHandlerForFlintExtensionQuery(indexQueryDetails) .submit(dispatchQueryRequest, context); } else { DispatchQueryContext context = - getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata).build(); + getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) + .asyncQueryRequestContext(asyncQueryRequestContext) + .build(); return getDefaultAsyncQueryHandler().submit(dispatchQueryRequest, context); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryContext.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryContext.java index 7b694e47f0..aabe43f641 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryContext.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryContext.java @@ -9,6 +9,7 @@ import lombok.Builder; import lombok.Getter; import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; @Getter @Builder @@ -17,4 +18,5 @@ public class DispatchQueryContext { private final DataSourceMetadata dataSourceMetadata; private final Map tags; private final IndexQueryDetails indexQueryDetails; + private final AsyncQueryRequestContext asyncQueryRequestContext; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index 9920fb9aec..cfbbeff339 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -17,6 +17,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.execution.statement.QueryRequest; @@ -49,22 +50,29 @@ public class InteractiveSession implements Session { private TimeProvider timeProvider; @Override - public void open(CreateSessionRequest createSessionRequest) { + public void open( + CreateSessionRequest createSessionRequest, + AsyncQueryRequestContext asyncQueryRequestContext) { try { // append session id; createSessionRequest .getSparkSubmitParameters() - .sessionExecution(sessionId.getSessionId(), createSessionRequest.getDatasourceName()); + .acceptModifier( + (parameters) -> { + parameters.sessionExecution( + sessionId.getSessionId(), createSessionRequest.getDatasourceName()); + }); createSessionRequest.getTags().put(SESSION_ID_TAG_KEY, sessionId.getSessionId()); StartJobRequest startJobRequest = createSessionRequest.getStartJobRequest(sessionId.getSessionId()); String jobID = serverlessClient.startJobRun(startJobRequest); String applicationId = startJobRequest.getApplicationId(); + String accountId = createSessionRequest.getAccountId(); sessionModel = initInteractiveSession( - applicationId, jobID, sessionId, createSessionRequest.getDatasourceName()); - sessionStorageService.createSession(sessionModel); + accountId, applicationId, jobID, sessionId, createSessionRequest.getDatasourceName()); + sessionStorageService.createSession(sessionModel, asyncQueryRequestContext); } catch (VersionConflictEngineException e) { String errorMsg = "session already exist. " + sessionId; LOG.error(errorMsg); @@ -86,7 +94,8 @@ public void close() { } /** Submit statement. If submit successfully, Statement in waiting state. */ - public StatementId submit(QueryRequest request) { + public StatementId submit( + QueryRequest request, AsyncQueryRequestContext asyncQueryRequestContext) { Optional model = sessionStorageService.getSession(sessionModel.getId(), sessionModel.getDatasourceName()); if (model.isEmpty()) { @@ -99,6 +108,7 @@ public StatementId submit(QueryRequest request) { Statement st = Statement.builder() .sessionId(sessionId) + .accountId(sessionModel.getAccountId()) .applicationId(sessionModel.getApplicationId()) .jobId(sessionModel.getJobId()) .statementStorageService(statementStorageService) @@ -107,6 +117,7 @@ public StatementId submit(QueryRequest request) { .datasourceName(sessionModel.getDatasourceName()) .query(request.getQuery()) .queryId(qid) + .asyncQueryRequestContext(asyncQueryRequestContext) .build(); st.open(); return statementId; @@ -130,6 +141,7 @@ public Optional get(StatementId stID) { model -> Statement.builder() .sessionId(sessionId) + .accountId(model.getAccountId()) .applicationId(model.getApplicationId()) .jobId(model.getJobId()) .statementId(model.getStatementId()) diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java index e684d33989..2f0fcea650 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.execution.session; import java.util.Optional; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.execution.statement.QueryRequest; import org.opensearch.sql.spark.execution.statement.Statement; import org.opensearch.sql.spark.execution.statement.StatementId; @@ -13,7 +14,8 @@ /** Session define the statement execution context. Each session is binding to one Spark Job. */ public interface Session { /** open session. */ - void open(CreateSessionRequest createSessionRequest); + void open( + CreateSessionRequest createSessionRequest, AsyncQueryRequestContext asyncQueryRequestContext); /** close session. */ void close(); @@ -22,9 +24,10 @@ public interface Session { * submit {@link QueryRequest}. * * @param request {@link QueryRequest} + * @param asyncQueryRequestContext {@link AsyncQueryRequestContext} * @return {@link StatementId} */ - StatementId submit(QueryRequest request); + StatementId submit(QueryRequest request, AsyncQueryRequestContext asyncQueryRequestContext); /** * get {@link Statement}. diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index 685fbdf5fa..3a147c00e3 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -9,6 +9,7 @@ import java.util.Optional; import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.statestore.SessionStorageService; import org.opensearch.sql.spark.execution.statestore.StatementStorageService; @@ -26,7 +27,8 @@ public class SessionManager { private final EMRServerlessClientFactory emrServerlessClientFactory; private final SessionConfigSupplier sessionConfigSupplier; - public Session createSession(CreateSessionRequest request) { + public Session createSession( + CreateSessionRequest request, AsyncQueryRequestContext asyncQueryRequestContext) { InteractiveSession session = InteractiveSession.builder() .sessionId(newSessionId(request.getDatasourceName())) @@ -34,7 +36,7 @@ public Session createSession(CreateSessionRequest request) { .statementStorageService(statementStorageService) .serverlessClient(emrServerlessClientFactory.getClient()) .build(); - session.open(request); + session.open(request, asyncQueryRequestContext); return session; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java index b0205aec64..06a4d9ee98 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java @@ -14,6 +14,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.index.engine.DocumentMissingException; import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.execution.session.SessionId; import org.opensearch.sql.spark.execution.statestore.StatementStorageService; import org.opensearch.sql.spark.rest.model.LangType; @@ -32,6 +33,7 @@ public class Statement { private final String datasourceName; private final String query; private final String queryId; + private final AsyncQueryRequestContext asyncQueryRequestContext; private final StatementStorageService statementStorageService; @Setter private StatementModel statementModel; @@ -49,7 +51,8 @@ public void open() { datasourceName, query, queryId); - statementModel = statementStorageService.createStatement(statementModel); + statementModel = + statementStorageService.createStatement(statementModel, asyncQueryRequestContext); } catch (VersionConflictEngineException e) { String errorMsg = "statement already exist. " + statementId; LOG.error(errorMsg); diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchSessionStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchSessionStorageService.java index a4e9ede5ab..eefc6a9b14 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchSessionStorageService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchSessionStorageService.java @@ -7,6 +7,7 @@ import java.util.Optional; import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.execution.session.SessionModel; import org.opensearch.sql.spark.execution.session.SessionState; import org.opensearch.sql.spark.execution.xcontent.SessionModelXContentSerializer; @@ -18,7 +19,8 @@ public class OpenSearchSessionStorageService implements SessionStorageService { private final SessionModelXContentSerializer serializer; @Override - public SessionModel createSession(SessionModel sessionModel) { + public SessionModel createSession( + SessionModel sessionModel, AsyncQueryRequestContext asyncQueryRequestContext) { return stateStore.create( sessionModel.getId(), sessionModel, diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchStatementStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchStatementStorageService.java index 9e74ad9810..5fcccc22a4 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchStatementStorageService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchStatementStorageService.java @@ -7,6 +7,7 @@ import java.util.Optional; import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.execution.statement.StatementModel; import org.opensearch.sql.spark.execution.statement.StatementState; import org.opensearch.sql.spark.execution.xcontent.StatementModelXContentSerializer; @@ -18,7 +19,8 @@ public class OpenSearchStatementStorageService implements StatementStorageServic private final StatementModelXContentSerializer serializer; @Override - public StatementModel createStatement(StatementModel statementModel) { + public StatementModel createStatement( + StatementModel statementModel, AsyncQueryRequestContext asyncQueryRequestContext) { return stateStore.create( statementModel.getId(), statementModel, diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStorageService.java index f67612b115..476e65714b 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStorageService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStorageService.java @@ -6,13 +6,15 @@ package org.opensearch.sql.spark.execution.statestore; import java.util.Optional; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.execution.session.SessionModel; import org.opensearch.sql.spark.execution.session.SessionState; /** Interface for accessing {@link SessionModel} data storage. */ public interface SessionStorageService { - SessionModel createSession(SessionModel sessionModel); + SessionModel createSession( + SessionModel sessionModel, AsyncQueryRequestContext asyncQueryRequestContext); Optional getSession(String id, String datasourceName); diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StatementStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StatementStorageService.java index 9253a4850d..39f1ecf704 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StatementStorageService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StatementStorageService.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.execution.statestore; import java.util.Optional; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.execution.statement.StatementModel; import org.opensearch.sql.spark.execution.statement.StatementState; @@ -15,7 +16,8 @@ */ public interface StatementStorageService { - StatementModel createStatement(StatementModel statementModel); + StatementModel createStatement( + StatementModel statementModel, AsyncQueryRequestContext asyncQueryRequestContext); StatementModel updateStatementState( StatementModel oldStatementModel, StatementState statementState); diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/IndexDMLResultStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/flint/IndexDMLResultStorageService.java index c816572d02..9053e5dbc8 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/IndexDMLResultStorageService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/IndexDMLResultStorageService.java @@ -5,11 +5,13 @@ package org.opensearch.sql.spark.flint; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.IndexDMLResult; /** * Abstraction over the IndexDMLResult storage. It stores the result of IndexDML query execution. */ public interface IndexDMLResultStorageService { - IndexDMLResult createIndexDMLResult(IndexDMLResult result); + IndexDMLResult createIndexDMLResult( + IndexDMLResult result, AsyncQueryRequestContext asyncQueryRequestContext); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchIndexDMLResultStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchIndexDMLResultStorageService.java index f5a1f70d1c..3be44ba410 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchIndexDMLResultStorageService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchIndexDMLResultStorageService.java @@ -8,6 +8,7 @@ import lombok.RequiredArgsConstructor; import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.IndexDMLResult; import org.opensearch.sql.spark.execution.statestore.StateStore; @@ -18,7 +19,8 @@ public class OpenSearchIndexDMLResultStorageService implements IndexDMLResultSto private final StateStore stateStore; @Override - public IndexDMLResult createIndexDMLResult(IndexDMLResult result) { + public IndexDMLResult createIndexDMLResult( + IndexDMLResult result, AsyncQueryRequestContext asyncQueryRequestContexts) { DataSourceMetadata dataSourceMetadata = dataSourceService.getDataSourceMetadata(result.getDatasourceName()); return stateStore.create( diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java index d669875304..bef3b29987 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCreateAsyncQueryRequestAction.java @@ -18,7 +18,7 @@ import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; -import org.opensearch.sql.spark.asyncquery.model.NullRequestContext; +import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionRequest; @@ -66,7 +66,7 @@ protected void doExecute( CreateAsyncQueryRequest createAsyncQueryRequest = request.getCreateAsyncQueryRequest(); CreateAsyncQueryResponse createAsyncQueryResponse = asyncQueryExecutorService.createAsyncQuery( - createAsyncQueryRequest, new NullRequestContext()); + createAsyncQueryRequest, new NullAsyncQueryRequestContext()); String responseContent = new JsonResponseFormatter(JsonResponseFormatter.Style.PRETTY) { @Override diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index 2adf4aef7e..b7848718b9 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -31,8 +31,8 @@ import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.datasources.exceptions.DatasourceDisabledException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; -import org.opensearch.sql.spark.asyncquery.model.NullRequestContext; -import org.opensearch.sql.spark.asyncquery.model.RequestContext; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.session.SessionId; import org.opensearch.sql.spark.execution.session.SessionState; @@ -44,7 +44,7 @@ import org.opensearch.sql.spark.rest.model.LangType; public class AsyncQueryExecutorServiceImplSpecTest extends AsyncQueryExecutorServiceSpec { - RequestContext requestContext = new NullRequestContext(); + AsyncQueryRequestContext asyncQueryRequestContext = new NullAsyncQueryRequestContext(); @Disabled("batch query is unsupported") public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { @@ -60,7 +60,7 @@ public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); assertFalse(clusterService().state().routingTable().hasIndex(SPARK_REQUEST_BUFFER_INDEX_NAME)); emrsClient.startJobRunCalled(1); @@ -91,13 +91,13 @@ public void sessionLimitNotImpactBatchQuery() { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); emrsClient.startJobRunCalled(1); CreateAsyncQueryResponse resp2 = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); emrsClient.startJobRunCalled(2); } @@ -112,7 +112,7 @@ public void createAsyncQueryCreateJobWithCorrectParameters() { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); String params = emrsClient.getJobRequest().getSparkSubmitParams(); assertNull(response.getSessionId()); assertTrue(params.contains(String.format("--class %s", DEFAULT_CLASS_NAME))); @@ -127,7 +127,7 @@ public void createAsyncQueryCreateJobWithCorrectParameters() { response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); params = emrsClient.getJobRequest().getSparkSubmitParams(); assertTrue(params.contains(String.format("--class %s", FLINT_SESSION_CLASS_NAME))); assertTrue( @@ -148,7 +148,7 @@ public void withSessionCreateAsyncQueryThenGetResultThenCancel() { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); assertNotNull(response.getSessionId()); Optional statementModel = statementStorageService.getStatement(response.getQueryId(), MYS3_DATASOURCE); @@ -181,7 +181,7 @@ public void reuseSessionWhenCreateAsyncQuery() { CreateAsyncQueryResponse first = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); assertNotNull(first.getSessionId()); // 2. reuse session id @@ -189,7 +189,7 @@ public void reuseSessionWhenCreateAsyncQuery() { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( "select 1", MYS3_DATASOURCE, LangType.SQL, first.getSessionId()), - requestContext); + asyncQueryRequestContext); assertEquals(first.getSessionId(), second.getSessionId()); assertNotEquals(first.getQueryId(), second.getQueryId()); @@ -232,7 +232,7 @@ public void batchQueryHasTimeout() { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); assertEquals(120L, (long) emrsClient.getJobRequest().executionTimeout()); } @@ -249,7 +249,7 @@ public void interactiveQueryNoTimeout() { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); assertEquals(0L, (long) emrsClient.getJobRequest().executionTimeout()); } @@ -282,7 +282,8 @@ public void datasourceWithBasicAuth() { enableSession(true); asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", "mybasicauth", LangType.SQL, null), requestContext); + new CreateAsyncQueryRequest("select 1", "mybasicauth", LangType.SQL, null), + asyncQueryRequestContext); String params = emrsClient.getJobRequest().getSparkSubmitParams(); assertTrue(params.contains(String.format("--conf spark.datasource.flint.auth=basic"))); assertTrue( @@ -305,7 +306,7 @@ public void withSessionCreateAsyncQueryFailed() { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest("myselect 1", MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); assertNotNull(response.getSessionId()); Optional statementModel = statementStorageService.getStatement(response.getQueryId(), MYS3_DATASOURCE); @@ -356,7 +357,7 @@ public void createSessionMoreThanLimitFailed() { CreateAsyncQueryResponse first = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); assertNotNull(first.getSessionId()); setSessionState(first.getSessionId(), SessionState.RUNNING); @@ -367,7 +368,7 @@ public void createSessionMoreThanLimitFailed() { () -> asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), - requestContext)); + asyncQueryRequestContext)); assertEquals("domain concurrent active session can not exceed 1", exception.getMessage()); } @@ -386,7 +387,7 @@ public void recreateSessionIfNotReady() { CreateAsyncQueryResponse first = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); assertNotNull(first.getSessionId()); // set sessionState to FAIL @@ -397,7 +398,7 @@ public void recreateSessionIfNotReady() { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( "select 1", MYS3_DATASOURCE, LangType.SQL, first.getSessionId()), - requestContext); + asyncQueryRequestContext); assertNotEquals(first.getSessionId(), second.getSessionId()); @@ -409,7 +410,7 @@ public void recreateSessionIfNotReady() { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( "select 1", MYS3_DATASOURCE, LangType.SQL, second.getSessionId()), - requestContext); + asyncQueryRequestContext); assertNotEquals(second.getSessionId(), third.getSessionId()); } @@ -428,7 +429,7 @@ public void submitQueryWithDifferentDataSourceSessionWillCreateNewSession() { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( "SHOW SCHEMAS IN " + MYS3_DATASOURCE, MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); assertNotNull(first.getSessionId()); // set sessionState to RUNNING @@ -442,7 +443,7 @@ public void submitQueryWithDifferentDataSourceSessionWillCreateNewSession() { MYS3_DATASOURCE, LangType.SQL, first.getSessionId()), - requestContext); + asyncQueryRequestContext); assertEquals(first.getSessionId(), second.getSessionId()); @@ -457,7 +458,7 @@ public void submitQueryWithDifferentDataSourceSessionWillCreateNewSession() { MYGLUE_DATASOURCE, LangType.SQL, second.getSessionId()), - requestContext); + asyncQueryRequestContext); assertNotEquals(second.getSessionId(), third.getSessionId()); } @@ -475,7 +476,7 @@ public void recreateSessionIfStale() { CreateAsyncQueryResponse first = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); assertNotNull(first.getSessionId()); // set sessionState to RUNNING @@ -486,7 +487,7 @@ public void recreateSessionIfStale() { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( "select 1", MYS3_DATASOURCE, LangType.SQL, first.getSessionId()), - requestContext); + asyncQueryRequestContext); assertEquals(first.getSessionId(), second.getSessionId()); @@ -505,7 +506,7 @@ public void recreateSessionIfStale() { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( "select 1", MYS3_DATASOURCE, LangType.SQL, second.getSessionId()), - requestContext); + asyncQueryRequestContext); assertNotEquals(second.getSessionId(), third.getSessionId()); } finally { // set timeout setting to 0 @@ -535,7 +536,7 @@ public void submitQueryInInvalidSessionWillCreateNewSession() { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( "select 1", MYS3_DATASOURCE, LangType.SQL, invalidSessionId.getSessionId()), - requestContext); + asyncQueryRequestContext); assertNotNull(asyncQuery.getSessionId()); assertNotEquals(invalidSessionId.getSessionId(), asyncQuery.getSessionId()); } @@ -568,7 +569,8 @@ public void datasourceNameIncludeUppercase() { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("select 1", "TESTS3", LangType.SQL, null), requestContext); + new CreateAsyncQueryRequest("select 1", "TESTS3", LangType.SQL, null), + asyncQueryRequestContext); String params = emrsClient.getJobRequest().getSparkSubmitParams(); assertNotNull(response.getSessionId()); @@ -591,7 +593,7 @@ public void concurrentSessionLimitIsDomainLevel() { CreateAsyncQueryResponse first = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); assertNotNull(first.getSessionId()); setSessionState(first.getSessionId(), SessionState.RUNNING); @@ -602,7 +604,7 @@ public void concurrentSessionLimitIsDomainLevel() { () -> asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest("select 1", MYGLUE_DATASOURCE, LangType.SQL, null), - requestContext)); + asyncQueryRequestContext)); assertEquals("domain concurrent active session can not exceed 1", exception.getMessage()); } @@ -623,7 +625,7 @@ public void testDatasourceDisabled() { try { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); fail("It should have thrown DataSourceDisabledException"); } catch (DatasourceDisabledException exception) { Assertions.assertEquals("Datasource mys3 is disabled.", exception.getMessage()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index 96ed18e897..b87fb0dad7 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -7,6 +7,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -31,7 +32,7 @@ import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; -import org.opensearch.sql.spark.asyncquery.model.RequestContext; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.config.OpenSearchSparkSubmitParameterModifier; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; @@ -53,7 +54,7 @@ public class AsyncQueryExecutorServiceImplTest { @Mock private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier; @Mock private SparkSubmitParameterModifier sparkSubmitParameterModifier; - @Mock private RequestContext requestContext; + @Mock private AsyncQueryRequestContext asyncQueryRequestContext; private final String QUERY_ID = "QUERY_ID"; @BeforeEach @@ -72,22 +73,24 @@ void testCreateAsyncQuery() { "select * from my_glue.default.http_logs", "my_glue", LangType.SQL); when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any())) .thenReturn( - new SparkExecutionEngineConfig( - EMRS_APPLICATION_ID, - "eu-west-1", - EMRS_EXECUTION_ROLE, - sparkSubmitParameterModifier, - TEST_CLUSTER_NAME)); + SparkExecutionEngineConfig.builder() + .applicationId(EMRS_APPLICATION_ID) + .region("eu-west-1") + .executionRoleARN(EMRS_EXECUTION_ROLE) + .sparkSubmitParameterModifier(sparkSubmitParameterModifier) + .clusterName(TEST_CLUSTER_NAME) + .build()); DispatchQueryRequest expectedDispatchQueryRequest = - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - "select * from my_glue.default.http_logs", - "my_glue", - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - sparkSubmitParameterModifier); - when(sparkQueryDispatcher.dispatch(expectedDispatchQueryRequest)) + DispatchQueryRequest.builder() + .applicationId(EMRS_APPLICATION_ID) + .query("select * from my_glue.default.http_logs") + .datasource("my_glue") + .langType(LangType.SQL) + .executionRoleARN(EMRS_EXECUTION_ROLE) + .clusterName(TEST_CLUSTER_NAME) + .sparkSubmitParameterModifier(sparkSubmitParameterModifier) + .build(); + when(sparkQueryDispatcher.dispatch(expectedDispatchQueryRequest, asyncQueryRequestContext)) .thenReturn( DispatchQueryResponse.builder() .queryId(QUERY_ID) @@ -96,15 +99,16 @@ void testCreateAsyncQuery() { .build()); CreateAsyncQueryResponse createAsyncQueryResponse = - jobExecutorService.createAsyncQuery(createAsyncQueryRequest, requestContext); + jobExecutorService.createAsyncQuery(createAsyncQueryRequest, asyncQueryRequestContext); verify(asyncQueryJobMetadataStorageService, times(1)) - .storeJobMetadata(getAsyncQueryJobMetadata()); + .storeJobMetadata(getAsyncQueryJobMetadata(), asyncQueryRequestContext); verify(sparkExecutionEngineConfigSupplier, times(1)) - .getSparkExecutionEngineConfig(requestContext); + .getSparkExecutionEngineConfig(asyncQueryRequestContext); verify(sparkExecutionEngineConfigSupplier, times(1)) - .getSparkExecutionEngineConfig(requestContext); - verify(sparkQueryDispatcher, times(1)).dispatch(expectedDispatchQueryRequest); + .getSparkExecutionEngineConfig(asyncQueryRequestContext); + verify(sparkQueryDispatcher, times(1)) + .dispatch(expectedDispatchQueryRequest, asyncQueryRequestContext); Assertions.assertEquals(QUERY_ID, createAsyncQueryResponse.getQueryId()); } @@ -114,13 +118,15 @@ void testCreateAsyncQueryWithExtraSparkSubmitParameter() { new OpenSearchSparkSubmitParameterModifier("--conf spark.dynamicAllocation.enabled=false"); when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any())) .thenReturn( - new SparkExecutionEngineConfig( - EMRS_APPLICATION_ID, - "eu-west-1", - EMRS_EXECUTION_ROLE, - modifier, - TEST_CLUSTER_NAME)); - when(sparkQueryDispatcher.dispatch(any())) + SparkExecutionEngineConfig.builder() + .applicationId(EMRS_APPLICATION_ID) + .region("eu-west-1") + .executionRoleARN(EMRS_EXECUTION_ROLE) + .sparkSubmitParameterModifier(sparkSubmitParameterModifier) + .sparkSubmitParameterModifier(modifier) + .clusterName(TEST_CLUSTER_NAME) + .build()); + when(sparkQueryDispatcher.dispatch(any(), any())) .thenReturn( DispatchQueryResponse.builder() .queryId(QUERY_ID) @@ -131,11 +137,12 @@ void testCreateAsyncQueryWithExtraSparkSubmitParameter() { jobExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( "select * from my_glue.default.http_logs", "my_glue", LangType.SQL), - requestContext); + asyncQueryRequestContext); verify(sparkQueryDispatcher, times(1)) .dispatch( - argThat(actualReq -> actualReq.getSparkSubmitParameterModifier().equals(modifier))); + argThat(actualReq -> actualReq.getSparkSubmitParameterModifier().equals(modifier)), + eq(asyncQueryRequestContext)); } @Test @@ -161,6 +168,7 @@ void testGetAsyncQueryResultsWithInProgressJob() { JSONObject jobResult = new JSONObject(); jobResult.put("status", JobRunState.PENDING.toString()); when(sparkQueryDispatcher.getQueryResponse(getAsyncQueryJobMetadata())).thenReturn(jobResult); + AsyncQueryExecutionResponse asyncQueryExecutionResponse = jobExecutorService.getAsyncQueryResults(EMR_JOB_ID); diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index 9c378b9274..89819ddf48 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -52,7 +52,7 @@ import org.opensearch.sql.legacy.esdomain.LocalClusterState; import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.opensearch.setting.OpenSearchSettings; -import org.opensearch.sql.spark.asyncquery.model.RequestContext; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.client.StartJobRequest; @@ -104,7 +104,7 @@ public class AsyncQueryExecutorServiceSpec extends OpenSearchIntegTestCase { protected StateStore stateStore; protected SessionStorageService sessionStorageService; protected StatementStorageService statementStorageService; - protected RequestContext requestContext; + protected AsyncQueryRequestContext asyncQueryRequestContext; @Override protected Collection> nodePlugins() { @@ -342,7 +342,8 @@ public EMRServerlessClient getClient() { } } - public SparkExecutionEngineConfig sparkExecutionEngineConfig(RequestContext requestContext) { + public SparkExecutionEngineConfig sparkExecutionEngineConfig( + AsyncQueryRequestContext asyncQueryRequestContext) { return SparkExecutionEngineConfig.builder() .applicationId("appId") .region("us-west-2") diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java index d80c13367f..12fa8043ea 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java @@ -24,10 +24,10 @@ import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; import org.opensearch.sql.protocol.response.format.ResponseFormatter; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryResult; import org.opensearch.sql.spark.asyncquery.model.MockFlintSparkJob; -import org.opensearch.sql.spark.asyncquery.model.NullRequestContext; -import org.opensearch.sql.spark.asyncquery.model.RequestContext; +import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.statement.StatementModel; import org.opensearch.sql.spark.execution.statement.StatementState; @@ -40,7 +40,7 @@ import org.opensearch.sql.spark.transport.format.AsyncQueryResultResponseFormatter; public class AsyncQueryGetResultSpecTest extends AsyncQueryExecutorServiceSpec { - RequestContext requestContext = new NullRequestContext(); + AsyncQueryRequestContext asyncQueryRequestContext = new NullAsyncQueryRequestContext(); /** Mock Flint index and index state */ private final FlintDatasetMock mockIndex = @@ -440,7 +440,7 @@ public JSONObject getResultWithQueryId(String queryId, String resultIndex) { this.createQueryResponse = queryService.createAsyncQuery( new CreateAsyncQueryRequest(query, MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); } AssertionHelper withInteraction(Interaction interaction) { diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java index 4786e496e0..801a24922f 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java @@ -77,7 +77,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -146,7 +146,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -228,7 +228,7 @@ public CancelJobRunResult cancelJobRun( asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -292,7 +292,7 @@ public void testAlterIndexQueryConvertingToAutoRefresh() { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result assertEquals( @@ -358,7 +358,7 @@ public void testAlterIndexQueryWithOutAnyAutoRefresh() { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result assertEquals( @@ -433,7 +433,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -508,7 +508,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -577,7 +577,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -639,7 +639,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -703,7 +703,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -767,7 +767,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -828,7 +828,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -887,7 +887,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -954,7 +954,7 @@ public CancelJobRunResult cancelJobRun( asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -1019,7 +1019,7 @@ public CancelJobRunResult cancelJobRun( asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -1085,7 +1085,7 @@ public CancelJobRunResult cancelJobRun( asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.getQuery(), MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java index 486ccf7031..b4962240f5 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java @@ -136,7 +136,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); assertNotNull(response.getQueryId()); assertTrue(clusterService.state().routingTable().hasIndex(mockDS.indexName)); @@ -187,7 +187,7 @@ public CancelJobRunResult cancelJobRun( asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2.fetch result. AsyncQueryExecutionResponse asyncQueryResults = @@ -227,7 +227,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryResults = @@ -264,7 +264,7 @@ public CancelJobRunResult cancelJobRun( CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest(mockDS.query, MYGLUE_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2.fetch result. AsyncQueryExecutionResponse asyncQueryResults = @@ -307,7 +307,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); assertNotNull(response.getQueryId()); assertTrue(clusterService.state().routingTable().hasIndex(mockDS.indexName)); @@ -367,7 +367,7 @@ public CancelJobRunResult cancelJobRun( asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2.fetch result. AsyncQueryExecutionResponse asyncQueryResults = @@ -414,7 +414,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryResults = @@ -460,7 +460,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result assertEquals( @@ -511,7 +511,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -559,7 +559,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result assertEquals( @@ -606,7 +606,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result assertEquals( @@ -661,7 +661,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); AsyncQueryExecutionResponse asyncQueryExecutionResponse = asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); @@ -706,7 +706,7 @@ public CancelJobRunResult cancelJobRun( CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest(mockDS.query, MYGLUE_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2.fetch result. AsyncQueryExecutionResponse asyncQueryResults = @@ -754,7 +754,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. fetch result AsyncQueryExecutionResponse asyncQueryResults = @@ -784,7 +784,7 @@ public void concurrentRefreshJobLimitNotApplied() { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest(query, MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); assertNull(response.getSessionId()); } @@ -813,7 +813,7 @@ public void concurrentRefreshJobLimitAppliedToDDLWithAuthRefresh() { () -> asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest(query, MYS3_DATASOURCE, LangType.SQL, null), - requestContext)); + asyncQueryRequestContext)); assertEquals("domain concurrent refresh job can not exceed 1", exception.getMessage()); } @@ -840,7 +840,7 @@ public void concurrentRefreshJobLimitAppliedToRefresh() { () -> asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest(query, MYS3_DATASOURCE, LangType.SQL, null), - requestContext)); + asyncQueryRequestContext)); assertEquals("domain concurrent refresh job can not exceed 1", exception.getMessage()); } @@ -863,7 +863,7 @@ public void concurrentRefreshJobLimitNotAppliedToDDL() { CreateAsyncQueryResponse asyncQueryResponse = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest(query, MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); assertNotNull(asyncQueryResponse.getSessionId()); } @@ -896,7 +896,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest(query, MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // 2. cancel query IllegalArgumentException exception = @@ -940,7 +940,7 @@ public GetJobRunResult getJobRunResult( asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.refreshQuery, MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // mock index state. flintIndexJob.refreshing(); @@ -985,7 +985,7 @@ public GetJobRunResult getJobRunResult( asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( mockDS.refreshQuery, MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // mock index state. flintIndexJob.active(); @@ -1032,7 +1032,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); // mock index state. flintIndexJob.refreshing(); diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java index c289bbe53f..3bccf1b30b 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java @@ -172,7 +172,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest(mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), - requestContext); + asyncQueryRequestContext); return asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpenSearchAsyncQueryJobMetadataStorageServiceTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpenSearchAsyncQueryJobMetadataStorageServiceTest.java index a0baaefab8..c84d68421d 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpenSearchAsyncQueryJobMetadataStorageServiceTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpenSearchAsyncQueryJobMetadataStorageServiceTest.java @@ -14,6 +14,8 @@ import org.junit.jupiter.api.Assertions; import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.execution.xcontent.AsyncQueryJobMetadataXContentSerializer; import org.opensearch.sql.spark.utils.IDUtils; @@ -26,6 +28,7 @@ public class OpenSearchAsyncQueryJobMetadataStorageServiceTest extends OpenSearc private static final String MOCK_RESULT_INDEX = "resultIndex"; private static final String MOCK_QUERY_ID = "00fdo6u94n7abo0q"; private OpenSearchAsyncQueryJobMetadataStorageService openSearchJobMetadataStorageService; + private AsyncQueryRequestContext asyncQueryRequestContext = new NullAsyncQueryRequestContext(); @Before public void setup() { @@ -46,7 +49,7 @@ public void testStoreJobMetadata() { .datasourceName(DS_NAME) .build(); - openSearchJobMetadataStorageService.storeJobMetadata(expected); + openSearchJobMetadataStorageService.storeJobMetadata(expected, asyncQueryRequestContext); Optional actual = openSearchJobMetadataStorageService.getJobMetadata(expected.getQueryId()); @@ -68,7 +71,7 @@ public void testStoreJobMetadataWithResultExtraData() { .datasourceName(DS_NAME) .build(); - openSearchJobMetadataStorageService.storeJobMetadata(expected); + openSearchJobMetadataStorageService.storeJobMetadata(expected, asyncQueryRequestContext); Optional actual = openSearchJobMetadataStorageService.getJobMetadata(expected.getQueryId()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImplTest.java index 0eb6be0f64..2409d32726 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImplTest.java @@ -15,14 +15,14 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.cluster.ClusterName; import org.opensearch.sql.common.setting.Settings; -import org.opensearch.sql.spark.asyncquery.model.RequestContext; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; @ExtendWith(MockitoExtension.class) public class SparkExecutionEngineConfigSupplierImplTest { @Mock private Settings settings; - @Mock private RequestContext requestContext; + @Mock private AsyncQueryRequestContext asyncQueryRequestContext; @Test void testGetSparkExecutionEngineConfig() { @@ -34,7 +34,7 @@ void testGetSparkExecutionEngineConfig() { .thenReturn(new ClusterName(TEST_CLUSTER_NAME)); SparkExecutionEngineConfig sparkExecutionEngineConfig = - sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(requestContext); + sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(asyncQueryRequestContext); SparkSubmitParameters parameters = SparkSubmitParameters.builder().build(); sparkExecutionEngineConfig.getSparkSubmitParameterModifier().modifyParameters(parameters); @@ -63,7 +63,7 @@ void testGetSparkExecutionEngineConfigWithNullSetting() { .thenReturn(new ClusterName(TEST_CLUSTER_NAME)); SparkExecutionEngineConfig sparkExecutionEngineConfig = - sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(requestContext); + sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(asyncQueryRequestContext); Assertions.assertNull(sparkExecutionEngineConfig.getApplicationId()); Assertions.assertNull(sparkExecutionEngineConfig.getExecutionRoleARN()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 5d04c86cce..ef9e3736c7 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -57,6 +57,7 @@ import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.client.StartJobRequest; @@ -91,6 +92,7 @@ public class SparkQueryDispatcherTest { @Mock private FlintIndexOpFactory flintIndexOpFactory; @Mock private SparkSubmitParameterModifier sparkSubmitParameterModifier; @Mock private QueryIdProvider queryIdProvider; + @Mock private AsyncQueryRequestContext asyncQueryRequestContext; @Mock(answer = RETURNS_DEEP_STUBS) private Session session; @@ -140,6 +142,7 @@ void testDispatchSelectQuery() { StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", + null, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, sparkSubmitParameters, @@ -153,14 +156,16 @@ void testDispatchSelectQuery() { DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - query, - "my_glue", - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - sparkSubmitParameterModifier)); + DispatchQueryRequest.builder() + .applicationId(EMRS_APPLICATION_ID) + .query(query) + .datasource("my_glue") + .langType(LangType.SQL) + .executionRoleARN(EMRS_EXECUTION_ROLE) + .clusterName(TEST_CLUSTER_NAME) + .sparkSubmitParameterModifier(sparkSubmitParameterModifier) + .build(), + asyncQueryRequestContext); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -189,6 +194,7 @@ void testDispatchSelectQueryWithLakeFormation() { StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", + null, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, sparkSubmitParameters, @@ -201,15 +207,7 @@ void testDispatchSelectQueryWithLakeFormation() { .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = - sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - query, - "my_glue", - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - sparkSubmitParameterModifier)); + sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); @@ -237,6 +235,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", + null, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, sparkSubmitParameters, @@ -249,15 +248,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = - sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - query, - "my_glue", - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - sparkSubmitParameterModifier)); + sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -284,6 +275,7 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", + null, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, sparkSubmitParameters, @@ -296,16 +288,7 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = - sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - query, - "my_glue", - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - sparkSubmitParameterModifier)); - + sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); @@ -318,15 +301,16 @@ void testDispatchSelectQueryCreateNewSession() { DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, null); doReturn(true).when(sessionManager).isEnabled(); - doReturn(session).when(sessionManager).createSession(any()); + doReturn(session).when(sessionManager).createSession(any(), any()); doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId(); - doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any()); + doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any(), any()); when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) .thenReturn(dataSourceMetadata); - DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch(queryRequest); + DispatchQueryResponse dispatchQueryResponse = + sparkQueryDispatcher.dispatch(queryRequest, asyncQueryRequestContext); verifyNoInteractions(emrServerlessClient); verify(sessionManager, never()).getSession(any()); @@ -344,17 +328,18 @@ void testDispatchSelectQueryReuseSession() { .when(sessionManager) .getSession(eq(new SessionId(MOCK_SESSION_ID))); doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId(); - doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any()); + doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any(), any()); when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID); when(session.isOperationalForDataSource(any())).thenReturn(true); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) .thenReturn(dataSourceMetadata); - DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch(queryRequest); + DispatchQueryResponse dispatchQueryResponse = + sparkQueryDispatcher.dispatch(queryRequest, asyncQueryRequestContext); verifyNoInteractions(emrServerlessClient); - verify(sessionManager, never()).createSession(any()); + verify(sessionManager, never()).createSession(any(), any()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId()); } @@ -365,13 +350,14 @@ void testDispatchSelectQueryFailedCreateSession() { DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, null); doReturn(true).when(sessionManager).isEnabled(); - doThrow(RuntimeException.class).when(sessionManager).createSession(any()); + doThrow(RuntimeException.class).when(sessionManager).createSession(any(), any()); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) .thenReturn(dataSourceMetadata); Assertions.assertThrows( - RuntimeException.class, () -> sparkQueryDispatcher.dispatch(queryRequest)); + RuntimeException.class, + () -> sparkQueryDispatcher.dispatch(queryRequest, asyncQueryRequestContext)); verifyNoInteractions(emrServerlessClient); } @@ -400,6 +386,7 @@ void testDispatchIndexQuery() { StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index", + null, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, sparkSubmitParameters, @@ -412,15 +399,7 @@ void testDispatchIndexQuery() { .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = - sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - query, - "my_glue", - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - sparkSubmitParameterModifier)); + sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -448,6 +427,7 @@ void testDispatchWithPPLQuery() { StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", + null, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, sparkSubmitParameters, @@ -461,14 +441,8 @@ void testDispatchWithPPLQuery() { DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - query, - "my_glue", - LangType.PPL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - sparkSubmitParameterModifier)); + getBaseDispatchQueryRequestBuilder(query).langType(LangType.PPL).build(), + asyncQueryRequestContext); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -496,6 +470,7 @@ void testDispatchQueryWithoutATableAndDataSourceName() { StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", + null, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, sparkSubmitParameters, @@ -508,15 +483,7 @@ void testDispatchQueryWithoutATableAndDataSourceName() { .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = - sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - query, - "my_glue", - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - sparkSubmitParameterModifier)); + sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -548,6 +515,7 @@ void testDispatchIndexQueryWithoutADatasourceName() { StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index", + null, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, sparkSubmitParameters, @@ -560,15 +528,7 @@ void testDispatchIndexQueryWithoutADatasourceName() { .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = - sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - query, - "my_glue", - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - sparkSubmitParameterModifier)); + sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -600,6 +560,7 @@ void testDispatchMaterializedViewQuery() { StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:streaming:flint_mv_1", + null, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, sparkSubmitParameters, @@ -612,15 +573,7 @@ void testDispatchMaterializedViewQuery() { .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = - sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - query, - "my_glue", - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - sparkSubmitParameterModifier)); + sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -648,6 +601,7 @@ void testDispatchShowMVQuery() { StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", + null, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, sparkSubmitParameters, @@ -660,15 +614,7 @@ void testDispatchShowMVQuery() { .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = - sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - query, - "my_glue", - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - sparkSubmitParameterModifier)); + sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -696,6 +642,7 @@ void testRefreshIndexQuery() { StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", + null, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, sparkSubmitParameters, @@ -708,15 +655,7 @@ void testRefreshIndexQuery() { .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = - sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - query, - "my_glue", - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - sparkSubmitParameterModifier)); + sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -744,6 +683,7 @@ void testDispatchDescribeIndexQuery() { StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", + null, EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, sparkSubmitParameters, @@ -756,15 +696,7 @@ void testDispatchDescribeIndexQuery() { .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = - sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - query, - "my_glue", - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - sparkSubmitParameterModifier)); + sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext); verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); @@ -783,14 +715,7 @@ void testDispatchWithWrongURI() { IllegalArgumentException.class, () -> sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - query, - "my_glue", - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - sparkSubmitParameterModifier))); + getBaseDispatchQueryRequest(query), asyncQueryRequestContext)); Assertions.assertEquals( "Bad URI in indexstore configuration of the : my_glue datasoure.", @@ -808,14 +733,8 @@ void testDispatchWithUnSupportedDataSourceType() { UnsupportedOperationException.class, () -> sparkQueryDispatcher.dispatch( - new DispatchQueryRequest( - EMRS_APPLICATION_ID, - query, - "my_prometheus", - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - sparkSubmitParameterModifier))); + getBaseDispatchQueryRequestBuilder(query).datasource("my_prometheus").build(), + asyncQueryRequestContext)); Assertions.assertEquals( "UnSupported datasource type for async queries:: PROMETHEUS", @@ -1021,7 +940,7 @@ void testDispatchQueryWithExtraSparkSubmitParameters() { for (DispatchQueryRequest request : requests) { when(emrServerlessClient.startJobRun(any())).thenReturn(EMR_JOB_ID); - sparkQueryDispatcher.dispatch(request); + sparkQueryDispatcher.dispatch(request, asyncQueryRequestContext); verify(emrServerlessClient, times(1)) .startJobRun( @@ -1187,29 +1106,33 @@ private DataSourceMetadata constructPrometheusDataSourceType() { .build(); } + private DispatchQueryRequest getBaseDispatchQueryRequest(String query) { + return getBaseDispatchQueryRequestBuilder(query).build(); + } + + private DispatchQueryRequest.DispatchQueryRequestBuilder getBaseDispatchQueryRequestBuilder( + String query) { + return DispatchQueryRequest.builder() + .applicationId(EMRS_APPLICATION_ID) + .query(query) + .datasource("my_glue") + .langType(LangType.SQL) + .executionRoleARN(EMRS_EXECUTION_ROLE) + .clusterName(TEST_CLUSTER_NAME) + .sparkSubmitParameterModifier(sparkSubmitParameterModifier); + } + private DispatchQueryRequest constructDispatchQueryRequest( String query, LangType langType, String extraParameters) { - return new DispatchQueryRequest( - EMRS_APPLICATION_ID, - query, - "my_glue", - langType, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - (parameters) -> parameters.setExtraParameters(extraParameters), - null); + return getBaseDispatchQueryRequestBuilder(query) + .langType(langType) + .sparkSubmitParameterModifier( + (parameters) -> parameters.setExtraParameters(extraParameters)) + .build(); } private DispatchQueryRequest dispatchQueryRequestWithSessionId(String query, String sessionId) { - return new DispatchQueryRequest( - EMRS_APPLICATION_ID, - query, - "my_glue", - LangType.SQL, - EMRS_EXECUTION_ROLE, - TEST_CLUSTER_NAME, - sparkSubmitParameterModifier, - sessionId); + return getBaseDispatchQueryRequestBuilder(query).sessionId(sessionId).build(); } private AsyncQueryJobMetadata asyncQueryJobMetadata() { diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index 0c606cc5df..a74a3a2737 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -18,6 +18,8 @@ import org.junit.Test; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.dispatcher.model.JobType; @@ -43,6 +45,7 @@ public class InteractiveSessionTest extends OpenSearchIntegTestCase { private StatementStorageService statementStorageService; private SessionConfigSupplier sessionConfigSupplier = () -> 600000L; private SessionManager sessionManager; + private AsyncQueryRequestContext asyncQueryRequestContext = new NullAsyncQueryRequestContext(); @Before public void setup() { @@ -106,7 +109,7 @@ public void openSessionFailedConflict() { .statementStorageService(statementStorageService) .serverlessClient(emrsClient) .build(); - session.open(createSessionRequest()); + session.open(createSessionRequest(), asyncQueryRequestContext); InteractiveSession duplicateSession = InteractiveSession.builder() @@ -117,7 +120,8 @@ public void openSessionFailedConflict() { .build(); IllegalStateException exception = assertThrows( - IllegalStateException.class, () -> duplicateSession.open(createSessionRequest())); + IllegalStateException.class, + () -> duplicateSession.open(createSessionRequest(), asyncQueryRequestContext)); assertEquals("session already exist. " + sessionId, exception.getMessage()); } @@ -131,7 +135,7 @@ public void closeNotExistSession() { .statementStorageService(statementStorageService) .serverlessClient(emrsClient) .build(); - session.open(createSessionRequest()); + session.open(createSessionRequest(), asyncQueryRequestContext); client().delete(new DeleteRequest(indexName, sessionId.getSessionId())).actionGet(); @@ -142,7 +146,8 @@ public void closeNotExistSession() { @Test public void sessionManagerCreateSession() { - Session session = sessionManager.createSession(createSessionRequest()); + Session session = + sessionManager.createSession(createSessionRequest(), asyncQueryRequestContext); new SessionAssertions(session) .assertSessionState(NOT_STARTED) @@ -152,7 +157,8 @@ public void sessionManagerCreateSession() { @Test public void sessionManagerGetSession() { - Session session = sessionManager.createSession(createSessionRequest()); + Session session = + sessionManager.createSession(createSessionRequest(), asyncQueryRequestContext); Optional managerSession = sessionManager.getSession(session.getSessionId()); assertTrue(managerSession.isPresent()); @@ -192,7 +198,7 @@ public SessionAssertions assertJobId(String expected) { } public SessionAssertions open(CreateSessionRequest req) { - session.open(req); + session.open(req, asyncQueryRequestContext); return this; } diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index 9650e5a73c..b6b2279ea9 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -20,6 +20,8 @@ import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.session.Session; import org.opensearch.sql.spark.execution.session.SessionConfigSupplier; @@ -48,6 +50,7 @@ public class StatementTest extends OpenSearchIntegTestCase { private SessionConfigSupplier sessionConfigSupplier = () -> 600000L; private SessionManager sessionManager; + private AsyncQueryRequestContext asyncQueryRequestContext = new NullAsyncQueryRequestContext(); @Before public void setup() { @@ -222,31 +225,36 @@ public void cancelRunningStatementSuccess() { @Test public void submitStatementInRunningSession() { - Session session = sessionManager.createSession(createSessionRequest()); + Session session = + sessionManager.createSession(createSessionRequest(), asyncQueryRequestContext); // App change state to running sessionStorageService.updateSessionState(session.getSessionModel(), SessionState.RUNNING); - StatementId statementId = session.submit(queryRequest()); + StatementId statementId = session.submit(queryRequest(), asyncQueryRequestContext); assertFalse(statementId.getId().isEmpty()); } @Test public void submitStatementInNotStartedState() { - Session session = sessionManager.createSession(createSessionRequest()); + Session session = + sessionManager.createSession(createSessionRequest(), asyncQueryRequestContext); - StatementId statementId = session.submit(queryRequest()); + StatementId statementId = session.submit(queryRequest(), asyncQueryRequestContext); assertFalse(statementId.getId().isEmpty()); } @Test public void failToSubmitStatementInDeadState() { - Session session = sessionManager.createSession(createSessionRequest()); + Session session = + sessionManager.createSession(createSessionRequest(), asyncQueryRequestContext); sessionStorageService.updateSessionState(session.getSessionModel(), SessionState.DEAD); IllegalStateException exception = - assertThrows(IllegalStateException.class, () -> session.submit(queryRequest())); + assertThrows( + IllegalStateException.class, + () -> session.submit(queryRequest(), asyncQueryRequestContext)); assertEquals( "can't submit statement, session should not be in end state, current session state is:" + " dead", @@ -255,12 +263,15 @@ public void failToSubmitStatementInDeadState() { @Test public void failToSubmitStatementInFailState() { - Session session = sessionManager.createSession(createSessionRequest()); + Session session = + sessionManager.createSession(createSessionRequest(), asyncQueryRequestContext); sessionStorageService.updateSessionState(session.getSessionModel(), SessionState.FAIL); IllegalStateException exception = - assertThrows(IllegalStateException.class, () -> session.submit(queryRequest())); + assertThrows( + IllegalStateException.class, + () -> session.submit(queryRequest(), asyncQueryRequestContext)); assertEquals( "can't submit statement, session should not be in end state, current session state is:" + " fail", @@ -269,8 +280,9 @@ public void failToSubmitStatementInFailState() { @Test public void newStatementFieldAssert() { - Session session = sessionManager.createSession(createSessionRequest()); - StatementId statementId = session.submit(queryRequest()); + Session session = + sessionManager.createSession(createSessionRequest(), asyncQueryRequestContext); + StatementId statementId = session.submit(queryRequest(), asyncQueryRequestContext); Optional statement = session.get(statementId); assertTrue(statement.isPresent()); @@ -286,7 +298,8 @@ public void newStatementFieldAssert() { @Test public void failToSubmitStatementInDeletedSession() { EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; - Session session = sessionManager.createSession(createSessionRequest()); + Session session = + sessionManager.createSession(createSessionRequest(), asyncQueryRequestContext); // other's delete session client() @@ -294,16 +307,19 @@ public void failToSubmitStatementInDeletedSession() { .actionGet(); IllegalStateException exception = - assertThrows(IllegalStateException.class, () -> session.submit(queryRequest())); + assertThrows( + IllegalStateException.class, + () -> session.submit(queryRequest(), asyncQueryRequestContext)); assertEquals("session does not exist. " + session.getSessionId(), exception.getMessage()); } @Test public void getStatementSuccess() { - Session session = sessionManager.createSession(createSessionRequest()); + Session session = + sessionManager.createSession(createSessionRequest(), asyncQueryRequestContext); // App change state to running sessionStorageService.updateSessionState(session.getSessionModel(), SessionState.RUNNING); - StatementId statementId = session.submit(queryRequest()); + StatementId statementId = session.submit(queryRequest(), asyncQueryRequestContext); Optional statement = session.get(statementId); assertTrue(statement.isPresent()); @@ -313,7 +329,8 @@ public void getStatementSuccess() { @Test public void getStatementNotExist() { - Session session = sessionManager.createSession(createSessionRequest()); + Session session = + sessionManager.createSession(createSessionRequest(), asyncQueryRequestContext); // App change state to running sessionStorageService.updateSessionState(session.getSessionModel(), SessionState.RUNNING);