Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass down request context to data accessors #2715

Merged
merged 1 commit into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
package org.opensearch.sql.spark.asyncquery;

import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse;
import org.opensearch.sql.spark.asyncquery.model.RequestContext;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest;
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse;

Expand All @@ -22,7 +22,8 @@ public interface AsyncQueryExecutorService {
* @return {@link CreateAsyncQueryResponse}
*/
CreateAsyncQueryResponse createAsyncQuery(
CreateAsyncQueryRequest createAsyncQueryRequest, RequestContext requestContext);
CreateAsyncQueryRequest createAsyncQueryRequest,
AsyncQueryRequestContext asyncQueryRequestContext);

/**
* Returns async query response for a given queryId.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.RequestContext;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
Expand All @@ -37,9 +37,10 @@ public class AsyncQueryExecutorServiceImpl implements AsyncQueryExecutorService

@Override
public CreateAsyncQueryResponse createAsyncQuery(
CreateAsyncQueryRequest createAsyncQueryRequest, RequestContext requestContext) {
CreateAsyncQueryRequest createAsyncQueryRequest,
AsyncQueryRequestContext asyncQueryRequestContext) {
SparkExecutionEngineConfig sparkExecutionEngineConfig =
sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(requestContext);
sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(asyncQueryRequestContext);
DispatchQueryResponse dispatchQueryResponse =
sparkQueryDispatcher.dispatch(
DispatchQueryRequest.builder()
Expand All @@ -53,7 +54,8 @@ public CreateAsyncQueryResponse createAsyncQuery(
.sparkSubmitParameterModifier(
sparkExecutionEngineConfig.getSparkSubmitParameterModifier())
.sessionId(createAsyncQueryRequest.getSessionId())
.build());
.build(),
asyncQueryRequestContext);
asyncQueryJobMetadataStorageService.storeJobMetadata(
AsyncQueryJobMetadata.builder()
.queryId(dispatchQueryResponse.getQueryId())
Expand All @@ -65,7 +67,8 @@ public CreateAsyncQueryResponse createAsyncQuery(
.datasourceName(dispatchQueryResponse.getDatasourceName())
.jobType(dispatchQueryResponse.getJobType())
.indexName(dispatchQueryResponse.getIndexName())
.build());
.build(),
asyncQueryRequestContext);
return new CreateAsyncQueryResponse(
dispatchQueryResponse.getQueryId(), dispatchQueryResponse.getSessionId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@

import java.util.Optional;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;

/** Storage abstraction for persisting and retrieving async query job metadata. */
public interface AsyncQueryJobMetadataStorageService {

  /**
   * Stores the given async query job metadata.
   *
   * @param asyncQueryJobMetadata metadata describing the async query job
   * @param asyncQueryRequestContext context carrying additional request related information
   */
  void storeJobMetadata(
      AsyncQueryJobMetadata asyncQueryJobMetadata,
      AsyncQueryRequestContext asyncQueryRequestContext);

  /**
   * Looks up job metadata by job identifier.
   *
   * @param jobId the job identifier
   * @return an {@link Optional} containing the metadata when available
   */
  Optional<AsyncQueryJobMetadata> getJobMetadata(String jobId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.execution.xcontent.AsyncQueryJobMetadataXContentSerializer;
Expand All @@ -28,7 +29,9 @@ public class OpenSearchAsyncQueryJobMetadataStorageService
LogManager.getLogger(OpenSearchAsyncQueryJobMetadataStorageService.class);

@Override
public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) {
public void storeJobMetadata(
AsyncQueryJobMetadata asyncQueryJobMetadata,
AsyncQueryRequestContext asyncQueryRequestContext) {
stateStore.create(
mapIdToDocumentId(asyncQueryJobMetadata.getId()),
asyncQueryJobMetadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
package org.opensearch.sql.spark.asyncquery.model;

/** Context interface to provide additional request related information. */
public interface AsyncQueryRequestContext {

  /**
   * Returns the attribute value stored under the given name.
   *
   * @param name attribute name
   * @return the attribute value; may be {@code null} when no value is associated
   */
  Object getAttribute(String name);
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
package org.opensearch.sql.spark.asyncquery.model;

/** An implementation of AsyncQueryRequestContext for cases where no context is required */
public class NullRequestContext implements RequestContext {
public class NullAsyncQueryRequestContext implements AsyncQueryRequestContext {
@Override
public Object getAttribute(String name) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import java.security.AccessController;
import java.security.PrivilegedAction;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.spark.asyncquery.model.NullRequestContext;
import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;

Expand All @@ -34,7 +34,7 @@ public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactor
public EMRServerlessClient getClient() {
SparkExecutionEngineConfig sparkExecutionEngineConfig =
this.sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(
new NullRequestContext());
new NullAsyncQueryRequestContext());
validateSparkExecutionEngineConfig(sparkExecutionEngineConfig);
if (isNewClientCreationRequired(sparkExecutionEngineConfig.getRegion())) {
region = sparkExecutionEngineConfig.getRegion();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package org.opensearch.sql.spark.config;

import org.opensearch.sql.spark.asyncquery.model.RequestContext;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;

/** Interface for extracting and providing SparkExecutionEngineConfig */
public interface SparkExecutionEngineConfigSupplier {
Expand All @@ -10,5 +10,6 @@ public interface SparkExecutionEngineConfigSupplier {
*
* @return {@link SparkExecutionEngineConfig}.
*/
SparkExecutionEngineConfig getSparkExecutionEngineConfig(RequestContext requestContext);
SparkExecutionEngineConfig getSparkExecutionEngineConfig(
AsyncQueryRequestContext asyncQueryRequestContext);
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@
import org.apache.commons.lang3.StringUtils;
import org.opensearch.cluster.ClusterName;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.spark.asyncquery.model.RequestContext;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;

@AllArgsConstructor
public class SparkExecutionEngineConfigSupplierImpl implements SparkExecutionEngineConfigSupplier {

private Settings settings;

@Override
public SparkExecutionEngineConfig getSparkExecutionEngineConfig(RequestContext requestContext) {
public SparkExecutionEngineConfig getSparkExecutionEngineConfig(
AsyncQueryRequestContext asyncQueryRequestContext) {
ClusterName clusterName = settings.getSettingValue(CLUSTER_NAME);
return getBuilderFromSettingsIfAvailable().clusterName(clusterName.value()).build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.json.JSONObject;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
Expand Down Expand Up @@ -72,7 +73,8 @@ public DispatchQueryResponse submit(
dataSourceMetadata,
JobRunState.SUCCESS.toString(),
StringUtils.EMPTY,
getElapsedTimeSince(startTime));
getElapsedTimeSince(startTime),
context.getAsyncQueryRequestContext());
return DispatchQueryResponse.builder()
.queryId(asyncQueryId)
.jobId(DML_QUERY_JOB_ID)
Expand All @@ -89,7 +91,8 @@ public DispatchQueryResponse submit(
dataSourceMetadata,
JobRunState.FAILED.toString(),
e.getMessage(),
getElapsedTimeSince(startTime));
getElapsedTimeSince(startTime),
context.getAsyncQueryRequestContext());
return DispatchQueryResponse.builder()
.queryId(asyncQueryId)
.jobId(DML_QUERY_JOB_ID)
Expand All @@ -106,7 +109,8 @@ private String storeIndexDMLResult(
DataSourceMetadata dataSourceMetadata,
String status,
String error,
long queryRunTime) {
long queryRunTime,
AsyncQueryRequestContext asyncQueryRequestContext) {
IndexDMLResult indexDMLResult =
IndexDMLResult.builder()
.queryId(queryId)
Expand All @@ -116,7 +120,7 @@ private String storeIndexDMLResult(
.queryRunTime(queryRunTime)
.updateTime(System.currentTimeMillis())
.build();
indexDMLResultStorageService.createIndexDMLResult(indexDMLResult);
indexDMLResultStorageService.createIndexDMLResult(indexDMLResult, asyncQueryRequestContext);
return queryId;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,16 @@ public DispatchQueryResponse submit(
.acceptModifier(dispatchQueryRequest.getSparkSubmitParameterModifier()),
tags,
dataSourceMetadata.getResultIndex(),
dataSourceMetadata.getName()));
dataSourceMetadata.getName()),
context.getAsyncQueryRequestContext());
MetricUtils.incrementNumericalMetric(MetricName.EMR_INTERACTIVE_QUERY_JOBS_CREATION_COUNT);
}
session.submit(
new QueryRequest(
context.getQueryId(),
dispatchQueryRequest.getLangType(),
dispatchQueryRequest.getQuery()));
dispatchQueryRequest.getQuery()),
context.getAsyncQueryRequestContext());
return DispatchQueryResponse.builder()
.queryId(context.getQueryId())
.jobId(session.getSessionModel().getJobId())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
Expand All @@ -37,7 +38,9 @@ public class SparkQueryDispatcher {
private final QueryHandlerFactory queryHandlerFactory;
private final QueryIdProvider queryIdProvider;

public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) {
public DispatchQueryResponse dispatch(
DispatchQueryRequest dispatchQueryRequest,
AsyncQueryRequestContext asyncQueryRequestContext) {
DataSourceMetadata dataSourceMetadata =
this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
dispatchQueryRequest.getDatasource());
Expand All @@ -48,13 +51,16 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest)
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.indexQueryDetails(indexQueryDetails)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();

return getQueryHandlerForFlintExtensionQuery(indexQueryDetails)
.submit(dispatchQueryRequest, context);
} else {
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata).build();
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();
return getDefaultAsyncQueryHandler().submit(dispatchQueryRequest, context);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lombok.Builder;
import lombok.Getter;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;

@Getter
@Builder
Expand All @@ -17,4 +18,5 @@ public class DispatchQueryContext {
private final DataSourceMetadata dataSourceMetadata;
private final Map<String, String> tags;
private final IndexQueryDetails indexQueryDetails;
private final AsyncQueryRequestContext asyncQueryRequestContext;
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.index.engine.VersionConflictEngineException;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.StartJobRequest;
import org.opensearch.sql.spark.execution.statement.QueryRequest;
Expand Down Expand Up @@ -49,12 +50,18 @@ public class InteractiveSession implements Session {
private TimeProvider timeProvider;

@Override
public void open(CreateSessionRequest createSessionRequest) {
public void open(
CreateSessionRequest createSessionRequest,
AsyncQueryRequestContext asyncQueryRequestContext) {
try {
// append session id;
createSessionRequest
.getSparkSubmitParameters()
.sessionExecution(sessionId.getSessionId(), createSessionRequest.getDatasourceName());
.acceptModifier(
(parameters) -> {
parameters.sessionExecution(
sessionId.getSessionId(), createSessionRequest.getDatasourceName());
});
createSessionRequest.getTags().put(SESSION_ID_TAG_KEY, sessionId.getSessionId());
StartJobRequest startJobRequest =
createSessionRequest.getStartJobRequest(sessionId.getSessionId());
Expand All @@ -65,7 +72,7 @@ public void open(CreateSessionRequest createSessionRequest) {
sessionModel =
initInteractiveSession(
accountId, applicationId, jobID, sessionId, createSessionRequest.getDatasourceName());
sessionStorageService.createSession(sessionModel);
sessionStorageService.createSession(sessionModel, asyncQueryRequestContext);
} catch (VersionConflictEngineException e) {
String errorMsg = "session already exist. " + sessionId;
LOG.error(errorMsg);
Expand All @@ -87,7 +94,8 @@ public void close() {
}

/** Submit statement. If submit successfully, Statement in waiting state. */
public StatementId submit(QueryRequest request) {
public StatementId submit(
QueryRequest request, AsyncQueryRequestContext asyncQueryRequestContext) {
Optional<SessionModel> model =
sessionStorageService.getSession(sessionModel.getId(), sessionModel.getDatasourceName());
if (model.isEmpty()) {
Expand All @@ -109,6 +117,7 @@ public StatementId submit(QueryRequest request) {
.datasourceName(sessionModel.getDatasourceName())
.query(request.getQuery())
.queryId(qid)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();
st.open();
return statementId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
package org.opensearch.sql.spark.execution.session;

import java.util.Optional;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.execution.statement.QueryRequest;
import org.opensearch.sql.spark.execution.statement.Statement;
import org.opensearch.sql.spark.execution.statement.StatementId;

/** Session define the statement execution context. Each session is binding to one Spark Job. */
public interface Session {
/** open session. */
void open(CreateSessionRequest createSessionRequest);
void open(
CreateSessionRequest createSessionRequest, AsyncQueryRequestContext asyncQueryRequestContext);

/** close session. */
void close();
Expand All @@ -22,9 +24,10 @@ public interface Session {
* submit {@link QueryRequest}.
*
* @param request {@link QueryRequest}
* @param asyncQueryRequestContext {@link AsyncQueryRequestContext}
* @return {@link StatementId}
*/
StatementId submit(QueryRequest request);
StatementId submit(QueryRequest request, AsyncQueryRequestContext asyncQueryRequestContext);

/**
* get {@link Statement}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
import org.opensearch.sql.spark.execution.statestore.SessionStorageService;
import org.opensearch.sql.spark.execution.statestore.StatementStorageService;
Expand All @@ -26,15 +27,16 @@ public class SessionManager {
private final EMRServerlessClientFactory emrServerlessClientFactory;
private final SessionConfigSupplier sessionConfigSupplier;

public Session createSession(CreateSessionRequest request) {
public Session createSession(
CreateSessionRequest request, AsyncQueryRequestContext asyncQueryRequestContext) {
InteractiveSession session =
InteractiveSession.builder()
.sessionId(newSessionId(request.getDatasourceName()))
.sessionStorageService(sessionStorageService)
.statementStorageService(statementStorageService)
.serverlessClient(emrServerlessClientFactory.getClient())
.build();
session.open(request);
session.open(request, asyncQueryRequestContext);
return session;
}

Expand Down
Loading
Loading