Skip to content

Commit

Permalink
Pass down request context to data accessors (#2715) (#2725)
Browse files Browse the repository at this point in the history
Signed-off-by: Tomoyuki Morita <[email protected]>
(cherry picked from commit c0a5123)
  • Loading branch information
ykmr1224 authored Jun 7, 2024
1 parent 392a7ae commit 88116bb
Show file tree
Hide file tree
Showing 36 changed files with 278 additions and 180 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
package org.opensearch.sql.spark.asyncquery;

import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse;
import org.opensearch.sql.spark.asyncquery.model.RequestContext;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest;
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse;

Expand All @@ -22,7 +22,8 @@ public interface AsyncQueryExecutorService {
* @return {@link CreateAsyncQueryResponse}
*/
CreateAsyncQueryResponse createAsyncQuery(
CreateAsyncQueryRequest createAsyncQueryRequest, RequestContext requestContext);
CreateAsyncQueryRequest createAsyncQueryRequest,
AsyncQueryRequestContext asyncQueryRequestContext);

/**
* Returns async query response for a given queryId.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.RequestContext;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
Expand All @@ -37,9 +37,10 @@ public class AsyncQueryExecutorServiceImpl implements AsyncQueryExecutorService

@Override
public CreateAsyncQueryResponse createAsyncQuery(
CreateAsyncQueryRequest createAsyncQueryRequest, RequestContext requestContext) {
CreateAsyncQueryRequest createAsyncQueryRequest,
AsyncQueryRequestContext asyncQueryRequestContext) {
SparkExecutionEngineConfig sparkExecutionEngineConfig =
sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(requestContext);
sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(asyncQueryRequestContext);
DispatchQueryResponse dispatchQueryResponse =
sparkQueryDispatcher.dispatch(
DispatchQueryRequest.builder()
Expand All @@ -53,7 +54,8 @@ public CreateAsyncQueryResponse createAsyncQuery(
.sparkSubmitParameterModifier(
sparkExecutionEngineConfig.getSparkSubmitParameterModifier())
.sessionId(createAsyncQueryRequest.getSessionId())
.build());
.build(),
asyncQueryRequestContext);
asyncQueryJobMetadataStorageService.storeJobMetadata(
AsyncQueryJobMetadata.builder()
.queryId(dispatchQueryResponse.getQueryId())
Expand All @@ -65,7 +67,8 @@ public CreateAsyncQueryResponse createAsyncQuery(
.datasourceName(dispatchQueryResponse.getDatasourceName())
.jobType(dispatchQueryResponse.getJobType())
.indexName(dispatchQueryResponse.getIndexName())
.build());
.build(),
asyncQueryRequestContext);
return new CreateAsyncQueryResponse(
dispatchQueryResponse.getQueryId(), dispatchQueryResponse.getSessionId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@

import java.util.Optional;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;

public interface AsyncQueryJobMetadataStorageService {

void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata);
void storeJobMetadata(
AsyncQueryJobMetadata asyncQueryJobMetadata,
AsyncQueryRequestContext asyncQueryRequestContext);

Optional<AsyncQueryJobMetadata> getJobMetadata(String jobId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.execution.xcontent.AsyncQueryJobMetadataXContentSerializer;
Expand All @@ -28,7 +29,9 @@ public class OpenSearchAsyncQueryJobMetadataStorageService
LogManager.getLogger(OpenSearchAsyncQueryJobMetadataStorageService.class);

@Override
public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) {
public void storeJobMetadata(
AsyncQueryJobMetadata asyncQueryJobMetadata,
AsyncQueryRequestContext asyncQueryRequestContext) {
stateStore.create(
mapIdToDocumentId(asyncQueryJobMetadata.getId()),
asyncQueryJobMetadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
package org.opensearch.sql.spark.asyncquery.model;

/** Context interface to provide additional request related information */
public interface RequestContext {
public interface AsyncQueryRequestContext {
Object getAttribute(String name);
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
package org.opensearch.sql.spark.asyncquery.model;

/** An implementation of RequestContext for where context is not required */
public class NullRequestContext implements RequestContext {
public class NullAsyncQueryRequestContext implements AsyncQueryRequestContext {
@Override
public Object getAttribute(String name) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import java.security.AccessController;
import java.security.PrivilegedAction;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.spark.asyncquery.model.NullRequestContext;
import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;

Expand All @@ -34,7 +34,7 @@ public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactor
public EMRServerlessClient getClient() {
SparkExecutionEngineConfig sparkExecutionEngineConfig =
this.sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(
new NullRequestContext());
new NullAsyncQueryRequestContext());
validateSparkExecutionEngineConfig(sparkExecutionEngineConfig);
if (isNewClientCreationRequired(sparkExecutionEngineConfig.getRegion())) {
region = sparkExecutionEngineConfig.getRegion();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package org.opensearch.sql.spark.config;

import org.opensearch.sql.spark.asyncquery.model.RequestContext;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;

/** Interface for extracting and providing SparkExecutionEngineConfig */
public interface SparkExecutionEngineConfigSupplier {
Expand All @@ -10,5 +10,6 @@ public interface SparkExecutionEngineConfigSupplier {
*
* @return {@link SparkExecutionEngineConfig}.
*/
SparkExecutionEngineConfig getSparkExecutionEngineConfig(RequestContext requestContext);
SparkExecutionEngineConfig getSparkExecutionEngineConfig(
AsyncQueryRequestContext asyncQueryRequestContext);
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@
import org.apache.commons.lang3.StringUtils;
import org.opensearch.cluster.ClusterName;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.spark.asyncquery.model.RequestContext;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;

@AllArgsConstructor
public class SparkExecutionEngineConfigSupplierImpl implements SparkExecutionEngineConfigSupplier {

private Settings settings;

@Override
public SparkExecutionEngineConfig getSparkExecutionEngineConfig(RequestContext requestContext) {
public SparkExecutionEngineConfig getSparkExecutionEngineConfig(
AsyncQueryRequestContext asyncQueryRequestContext) {
ClusterName clusterName = settings.getSettingValue(CLUSTER_NAME);
return getBuilderFromSettingsIfAvailable().clusterName(clusterName.value()).build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.json.JSONObject;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
Expand Down Expand Up @@ -72,7 +73,8 @@ public DispatchQueryResponse submit(
dataSourceMetadata,
JobRunState.SUCCESS.toString(),
StringUtils.EMPTY,
getElapsedTimeSince(startTime));
getElapsedTimeSince(startTime),
context.getAsyncQueryRequestContext());
return DispatchQueryResponse.builder()
.queryId(asyncQueryId)
.jobId(DML_QUERY_JOB_ID)
Expand All @@ -89,7 +91,8 @@ public DispatchQueryResponse submit(
dataSourceMetadata,
JobRunState.FAILED.toString(),
e.getMessage(),
getElapsedTimeSince(startTime));
getElapsedTimeSince(startTime),
context.getAsyncQueryRequestContext());
return DispatchQueryResponse.builder()
.queryId(asyncQueryId)
.jobId(DML_QUERY_JOB_ID)
Expand All @@ -106,7 +109,8 @@ private String storeIndexDMLResult(
DataSourceMetadata dataSourceMetadata,
String status,
String error,
long queryRunTime) {
long queryRunTime,
AsyncQueryRequestContext asyncQueryRequestContext) {
IndexDMLResult indexDMLResult =
IndexDMLResult.builder()
.queryId(queryId)
Expand All @@ -116,7 +120,7 @@ private String storeIndexDMLResult(
.queryRunTime(queryRunTime)
.updateTime(System.currentTimeMillis())
.build();
indexDMLResultStorageService.createIndexDMLResult(indexDMLResult);
indexDMLResultStorageService.createIndexDMLResult(indexDMLResult, asyncQueryRequestContext);
return queryId;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,16 @@ public DispatchQueryResponse submit(
.acceptModifier(dispatchQueryRequest.getSparkSubmitParameterModifier()),
tags,
dataSourceMetadata.getResultIndex(),
dataSourceMetadata.getName()));
dataSourceMetadata.getName()),
context.getAsyncQueryRequestContext());
MetricUtils.incrementNumericalMetric(MetricName.EMR_INTERACTIVE_QUERY_JOBS_CREATION_COUNT);
}
session.submit(
new QueryRequest(
context.getQueryId(),
dispatchQueryRequest.getLangType(),
dispatchQueryRequest.getQuery()));
dispatchQueryRequest.getQuery()),
context.getAsyncQueryRequestContext());
return DispatchQueryResponse.builder()
.queryId(context.getQueryId())
.jobId(session.getSessionModel().getJobId())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
Expand All @@ -37,7 +38,9 @@ public class SparkQueryDispatcher {
private final QueryHandlerFactory queryHandlerFactory;
private final QueryIdProvider queryIdProvider;

public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) {
public DispatchQueryResponse dispatch(
DispatchQueryRequest dispatchQueryRequest,
AsyncQueryRequestContext asyncQueryRequestContext) {
DataSourceMetadata dataSourceMetadata =
this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
dispatchQueryRequest.getDatasource());
Expand All @@ -48,13 +51,16 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest)
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.indexQueryDetails(indexQueryDetails)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();

return getQueryHandlerForFlintExtensionQuery(indexQueryDetails)
.submit(dispatchQueryRequest, context);
} else {
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata).build();
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();
return getDefaultAsyncQueryHandler().submit(dispatchQueryRequest, context);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lombok.Builder;
import lombok.Getter;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;

@Getter
@Builder
Expand All @@ -17,4 +18,5 @@ public class DispatchQueryContext {
private final DataSourceMetadata dataSourceMetadata;
private final Map<String, String> tags;
private final IndexQueryDetails indexQueryDetails;
private final AsyncQueryRequestContext asyncQueryRequestContext;
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.index.engine.VersionConflictEngineException;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.StartJobRequest;
import org.opensearch.sql.spark.execution.statement.QueryRequest;
Expand Down Expand Up @@ -49,12 +50,18 @@ public class InteractiveSession implements Session {
private TimeProvider timeProvider;

@Override
public void open(CreateSessionRequest createSessionRequest) {
public void open(
CreateSessionRequest createSessionRequest,
AsyncQueryRequestContext asyncQueryRequestContext) {
try {
// append session id;
createSessionRequest
.getSparkSubmitParameters()
.sessionExecution(sessionId.getSessionId(), createSessionRequest.getDatasourceName());
.acceptModifier(
(parameters) -> {
parameters.sessionExecution(
sessionId.getSessionId(), createSessionRequest.getDatasourceName());
});
createSessionRequest.getTags().put(SESSION_ID_TAG_KEY, sessionId.getSessionId());
StartJobRequest startJobRequest =
createSessionRequest.getStartJobRequest(sessionId.getSessionId());
Expand All @@ -65,7 +72,7 @@ public void open(CreateSessionRequest createSessionRequest) {
sessionModel =
initInteractiveSession(
accountId, applicationId, jobID, sessionId, createSessionRequest.getDatasourceName());
sessionStorageService.createSession(sessionModel);
sessionStorageService.createSession(sessionModel, asyncQueryRequestContext);
} catch (VersionConflictEngineException e) {
String errorMsg = "session already exist. " + sessionId;
LOG.error(errorMsg);
Expand All @@ -87,7 +94,8 @@ public void close() {
}

/** Submit statement. If submit successfully, Statement in waiting state. */
public StatementId submit(QueryRequest request) {
public StatementId submit(
QueryRequest request, AsyncQueryRequestContext asyncQueryRequestContext) {
Optional<SessionModel> model =
sessionStorageService.getSession(sessionModel.getId(), sessionModel.getDatasourceName());
if (model.isEmpty()) {
Expand All @@ -109,6 +117,7 @@ public StatementId submit(QueryRequest request) {
.datasourceName(sessionModel.getDatasourceName())
.query(request.getQuery())
.queryId(qid)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();
st.open();
return statementId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
package org.opensearch.sql.spark.execution.session;

import java.util.Optional;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.execution.statement.QueryRequest;
import org.opensearch.sql.spark.execution.statement.Statement;
import org.opensearch.sql.spark.execution.statement.StatementId;

/** Session define the statement execution context. Each session is binding to one Spark Job. */
public interface Session {
/** open session. */
void open(CreateSessionRequest createSessionRequest);
void open(
CreateSessionRequest createSessionRequest, AsyncQueryRequestContext asyncQueryRequestContext);

/** close session. */
void close();
Expand All @@ -22,9 +24,10 @@ public interface Session {
* submit {@link QueryRequest}.
*
* @param request {@link QueryRequest}
* @param asyncQueryRequestContext {@link AsyncQueryRequestContext}
* @return {@link StatementId}
*/
StatementId submit(QueryRequest request);
StatementId submit(QueryRequest request, AsyncQueryRequestContext asyncQueryRequestContext);

/**
* get {@link Statement}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
import org.opensearch.sql.spark.execution.statestore.SessionStorageService;
import org.opensearch.sql.spark.execution.statestore.StatementStorageService;
Expand All @@ -26,15 +27,16 @@ public class SessionManager {
private final EMRServerlessClientFactory emrServerlessClientFactory;
private final SessionConfigSupplier sessionConfigSupplier;

public Session createSession(CreateSessionRequest request) {
public Session createSession(
CreateSessionRequest request, AsyncQueryRequestContext asyncQueryRequestContext) {
InteractiveSession session =
InteractiveSession.builder()
.sessionId(newSessionId(request.getDatasourceName()))
.sessionStorageService(sessionStorageService)
.statementStorageService(statementStorageService)
.serverlessClient(emrServerlessClientFactory.getClient())
.build();
session.open(request);
session.open(request, asyncQueryRequestContext);
return session;
}

Expand Down
Loading

0 comments on commit 88116bb

Please sign in to comment.