Skip to content

Commit

Permalink
Pass down request context to data accessors (opensearch-project#2715)
Browse files Browse the repository at this point in the history
Signed-off-by: Tomoyuki Morita <[email protected]>

(cherry picked from commit c0a5123)
  • Loading branch information
ykmr1224 committed Jun 7, 2024
1 parent bcfafc1 commit 76f04b9
Show file tree
Hide file tree
Showing 36 changed files with 358 additions and 336 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
package org.opensearch.sql.spark.asyncquery;

import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse;
import org.opensearch.sql.spark.asyncquery.model.RequestContext;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest;
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse;

Expand All @@ -22,7 +22,8 @@ public interface AsyncQueryExecutorService {
* @return {@link CreateAsyncQueryResponse}
*/
CreateAsyncQueryResponse createAsyncQuery(
CreateAsyncQueryRequest createAsyncQueryRequest, RequestContext requestContext);
CreateAsyncQueryRequest createAsyncQueryRequest,
AsyncQueryRequestContext asyncQueryRequestContext);

/**
* Returns async query response for a given queryId.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.RequestContext;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
Expand All @@ -37,31 +37,38 @@ public class AsyncQueryExecutorServiceImpl implements AsyncQueryExecutorService

@Override
public CreateAsyncQueryResponse createAsyncQuery(
CreateAsyncQueryRequest createAsyncQueryRequest, RequestContext requestContext) {
CreateAsyncQueryRequest createAsyncQueryRequest,
AsyncQueryRequestContext asyncQueryRequestContext) {
SparkExecutionEngineConfig sparkExecutionEngineConfig =
sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(requestContext);
sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(asyncQueryRequestContext);
DispatchQueryResponse dispatchQueryResponse =
sparkQueryDispatcher.dispatch(
new DispatchQueryRequest(
sparkExecutionEngineConfig.getApplicationId(),
createAsyncQueryRequest.getQuery(),
createAsyncQueryRequest.getDatasource(),
createAsyncQueryRequest.getLang(),
sparkExecutionEngineConfig.getExecutionRoleARN(),
sparkExecutionEngineConfig.getClusterName(),
sparkExecutionEngineConfig.getSparkSubmitParameterModifier(),
createAsyncQueryRequest.getSessionId()));
DispatchQueryRequest.builder()
.accountId(sparkExecutionEngineConfig.getAccountId())
.applicationId(sparkExecutionEngineConfig.getApplicationId())
.query(createAsyncQueryRequest.getQuery())
.datasource(createAsyncQueryRequest.getDatasource())
.langType(createAsyncQueryRequest.getLang())
.executionRoleARN(sparkExecutionEngineConfig.getExecutionRoleARN())
.clusterName(sparkExecutionEngineConfig.getClusterName())
.sparkSubmitParameterModifier(
sparkExecutionEngineConfig.getSparkSubmitParameterModifier())
.sessionId(createAsyncQueryRequest.getSessionId())
.build(),
asyncQueryRequestContext);
asyncQueryJobMetadataStorageService.storeJobMetadata(
AsyncQueryJobMetadata.builder()
.queryId(dispatchQueryResponse.getQueryId())
.accountId(sparkExecutionEngineConfig.getAccountId())
.applicationId(sparkExecutionEngineConfig.getApplicationId())
.jobId(dispatchQueryResponse.getJobId())
.resultIndex(dispatchQueryResponse.getResultIndex())
.sessionId(dispatchQueryResponse.getSessionId())
.datasourceName(dispatchQueryResponse.getDatasourceName())
.jobType(dispatchQueryResponse.getJobType())
.indexName(dispatchQueryResponse.getIndexName())
.build());
.build(),
asyncQueryRequestContext);
return new CreateAsyncQueryResponse(
dispatchQueryResponse.getQueryId(), dispatchQueryResponse.getSessionId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@

import java.util.Optional;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;

public interface AsyncQueryJobMetadataStorageService {

void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata);
void storeJobMetadata(
AsyncQueryJobMetadata asyncQueryJobMetadata,
AsyncQueryRequestContext asyncQueryRequestContext);

Optional<AsyncQueryJobMetadata> getJobMetadata(String jobId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.execution.xcontent.AsyncQueryJobMetadataXContentSerializer;
Expand All @@ -28,7 +29,9 @@ public class OpenSearchAsyncQueryJobMetadataStorageService
LogManager.getLogger(OpenSearchAsyncQueryJobMetadataStorageService.class);

@Override
public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) {
public void storeJobMetadata(
AsyncQueryJobMetadata asyncQueryJobMetadata,
AsyncQueryRequestContext asyncQueryRequestContext) {
stateStore.create(
mapIdToDocumentId(asyncQueryJobMetadata.getId()),
asyncQueryJobMetadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
package org.opensearch.sql.spark.asyncquery.model;

/** Context interface to provide additional request related information */
public interface RequestContext {
public interface AsyncQueryRequestContext {
Object getAttribute(String name);
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
package org.opensearch.sql.spark.asyncquery.model;

/** An implementation of AsyncQueryRequestContext used where context is not required */
public class NullRequestContext implements RequestContext {
public class NullAsyncQueryRequestContext implements AsyncQueryRequestContext {
@Override
public Object getAttribute(String name) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import java.security.AccessController;
import java.security.PrivilegedAction;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.spark.asyncquery.model.NullRequestContext;
import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;

Expand All @@ -34,7 +34,7 @@ public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactor
public EMRServerlessClient getClient() {
SparkExecutionEngineConfig sparkExecutionEngineConfig =
this.sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(
new NullRequestContext());
new NullAsyncQueryRequestContext());
validateSparkExecutionEngineConfig(sparkExecutionEngineConfig);
if (isNewClientCreationRequired(sparkExecutionEngineConfig.getRegion())) {
region = sparkExecutionEngineConfig.getRegion();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package org.opensearch.sql.spark.config;

import org.opensearch.sql.spark.asyncquery.model.RequestContext;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;

/** Interface for extracting and providing SparkExecutionEngineConfig */
public interface SparkExecutionEngineConfigSupplier {
Expand All @@ -10,5 +10,6 @@ public interface SparkExecutionEngineConfigSupplier {
*
* @return {@link SparkExecutionEngineConfig}.
*/
SparkExecutionEngineConfig getSparkExecutionEngineConfig(RequestContext requestContext);
SparkExecutionEngineConfig getSparkExecutionEngineConfig(
AsyncQueryRequestContext asyncQueryRequestContext);
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@
import org.apache.commons.lang3.StringUtils;
import org.opensearch.cluster.ClusterName;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.spark.asyncquery.model.RequestContext;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;

@AllArgsConstructor
public class SparkExecutionEngineConfigSupplierImpl implements SparkExecutionEngineConfigSupplier {

private Settings settings;

@Override
public SparkExecutionEngineConfig getSparkExecutionEngineConfig(RequestContext requestContext) {
public SparkExecutionEngineConfig getSparkExecutionEngineConfig(
AsyncQueryRequestContext asyncQueryRequestContext) {
ClusterName clusterName = settings.getSettingValue(CLUSTER_NAME);
return getBuilderFromSettingsIfAvailable().clusterName(clusterName.value()).build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.json.JSONObject;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
Expand Down Expand Up @@ -72,7 +73,8 @@ public DispatchQueryResponse submit(
dataSourceMetadata,
JobRunState.SUCCESS.toString(),
StringUtils.EMPTY,
getElapsedTimeSince(startTime));
getElapsedTimeSince(startTime),
context.getAsyncQueryRequestContext());
return DispatchQueryResponse.builder()
.queryId(asyncQueryId)
.jobId(DML_QUERY_JOB_ID)
Expand All @@ -89,7 +91,8 @@ public DispatchQueryResponse submit(
dataSourceMetadata,
JobRunState.FAILED.toString(),
e.getMessage(),
getElapsedTimeSince(startTime));
getElapsedTimeSince(startTime),
context.getAsyncQueryRequestContext());
return DispatchQueryResponse.builder()
.queryId(asyncQueryId)
.jobId(DML_QUERY_JOB_ID)
Expand All @@ -106,7 +109,8 @@ private String storeIndexDMLResult(
DataSourceMetadata dataSourceMetadata,
String status,
String error,
long queryRunTime) {
long queryRunTime,
AsyncQueryRequestContext asyncQueryRequestContext) {
IndexDMLResult indexDMLResult =
IndexDMLResult.builder()
.queryId(queryId)
Expand All @@ -116,7 +120,7 @@ private String storeIndexDMLResult(
.queryRunTime(queryRunTime)
.updateTime(System.currentTimeMillis())
.build();
indexDMLResultStorageService.createIndexDMLResult(indexDMLResult);
indexDMLResultStorageService.createIndexDMLResult(indexDMLResult, asyncQueryRequestContext);
return queryId;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,16 @@ public DispatchQueryResponse submit(
.acceptModifier(dispatchQueryRequest.getSparkSubmitParameterModifier()),
tags,
dataSourceMetadata.getResultIndex(),
dataSourceMetadata.getName()));
dataSourceMetadata.getName()),
context.getAsyncQueryRequestContext());
MetricUtils.incrementNumericalMetric(MetricName.EMR_INTERACTIVE_QUERY_JOBS_CREATION_COUNT);
}
session.submit(
new QueryRequest(
context.getQueryId(),
dispatchQueryRequest.getLangType(),
dispatchQueryRequest.getQuery()));
dispatchQueryRequest.getQuery()),
context.getAsyncQueryRequestContext());
return DispatchQueryResponse.builder()
.queryId(context.getQueryId())
.jobId(session.getSessionModel().getJobId())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
Expand All @@ -37,7 +38,9 @@ public class SparkQueryDispatcher {
private final QueryHandlerFactory queryHandlerFactory;
private final QueryIdProvider queryIdProvider;

public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) {
public DispatchQueryResponse dispatch(
DispatchQueryRequest dispatchQueryRequest,
AsyncQueryRequestContext asyncQueryRequestContext) {
DataSourceMetadata dataSourceMetadata =
this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
dispatchQueryRequest.getDatasource());
Expand All @@ -48,13 +51,16 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest)
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.indexQueryDetails(indexQueryDetails)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();

return getQueryHandlerForFlintExtensionQuery(indexQueryDetails)
.submit(dispatchQueryRequest, context);
} else {
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata).build();
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();
return getDefaultAsyncQueryHandler().submit(dispatchQueryRequest, context);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lombok.Builder;
import lombok.Getter;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;

@Getter
@Builder
Expand All @@ -17,4 +18,5 @@ public class DispatchQueryContext {
private final DataSourceMetadata dataSourceMetadata;
private final Map<String, String> tags;
private final IndexQueryDetails indexQueryDetails;
private final AsyncQueryRequestContext asyncQueryRequestContext;
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.index.engine.VersionConflictEngineException;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.StartJobRequest;
import org.opensearch.sql.spark.execution.statement.QueryRequest;
Expand Down Expand Up @@ -49,22 +50,29 @@ public class InteractiveSession implements Session {
private TimeProvider timeProvider;

@Override
public void open(CreateSessionRequest createSessionRequest) {
public void open(
CreateSessionRequest createSessionRequest,
AsyncQueryRequestContext asyncQueryRequestContext) {
try {
// append session id;
createSessionRequest
.getSparkSubmitParameters()
.sessionExecution(sessionId.getSessionId(), createSessionRequest.getDatasourceName());
.acceptModifier(
(parameters) -> {
parameters.sessionExecution(
sessionId.getSessionId(), createSessionRequest.getDatasourceName());
});
createSessionRequest.getTags().put(SESSION_ID_TAG_KEY, sessionId.getSessionId());
StartJobRequest startJobRequest =
createSessionRequest.getStartJobRequest(sessionId.getSessionId());
String jobID = serverlessClient.startJobRun(startJobRequest);
String applicationId = startJobRequest.getApplicationId();
String accountId = createSessionRequest.getAccountId();

sessionModel =
initInteractiveSession(
applicationId, jobID, sessionId, createSessionRequest.getDatasourceName());
sessionStorageService.createSession(sessionModel);
accountId, applicationId, jobID, sessionId, createSessionRequest.getDatasourceName());
sessionStorageService.createSession(sessionModel, asyncQueryRequestContext);
} catch (VersionConflictEngineException e) {
String errorMsg = "session already exist. " + sessionId;
LOG.error(errorMsg);
Expand All @@ -86,7 +94,8 @@ public void close() {
}

/** Submit statement. If submit successfully, Statement in waiting state. */
public StatementId submit(QueryRequest request) {
public StatementId submit(
QueryRequest request, AsyncQueryRequestContext asyncQueryRequestContext) {
Optional<SessionModel> model =
sessionStorageService.getSession(sessionModel.getId(), sessionModel.getDatasourceName());
if (model.isEmpty()) {
Expand All @@ -99,6 +108,7 @@ public StatementId submit(QueryRequest request) {
Statement st =
Statement.builder()
.sessionId(sessionId)
.accountId(sessionModel.getAccountId())
.applicationId(sessionModel.getApplicationId())
.jobId(sessionModel.getJobId())
.statementStorageService(statementStorageService)
Expand All @@ -107,6 +117,7 @@ public StatementId submit(QueryRequest request) {
.datasourceName(sessionModel.getDatasourceName())
.query(request.getQuery())
.queryId(qid)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();
st.open();
return statementId;
Expand All @@ -130,6 +141,7 @@ public Optional<Statement> get(StatementId stID) {
model ->
Statement.builder()
.sessionId(sessionId)
.accountId(model.getAccountId())
.applicationId(model.getApplicationId())
.jobId(model.getJobId())
.statementId(model.getStatementId())
Expand Down
Loading

0 comments on commit 76f04b9

Please sign in to comment.