Refactor SparkQueryDispatcher #2636

Merged · 3 commits · Apr 30, 2024
QueryHandlerFactory.java (new file)
@@ -0,0 +1,59 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.spark.dispatcher;

import lombok.RequiredArgsConstructor;
import org.opensearch.client.Client;
import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.flint.FlintIndexMetadataService;
import org.opensearch.sql.spark.leasemanager.LeaseManager;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;

@RequiredArgsConstructor
public class QueryHandlerFactory {

private final JobExecutionResponseReader jobExecutionResponseReader;
private final FlintIndexMetadataService flintIndexMetadataService;
private final Client client;
private final SessionManager sessionManager;
private final LeaseManager leaseManager;
private final StateStore stateStore;
private final EMRServerlessClientFactory emrServerlessClientFactory;

public RefreshQueryHandler getRefreshQueryHandler() {
return new RefreshQueryHandler(
emrServerlessClientFactory.getClient(),
jobExecutionResponseReader,
flintIndexMetadataService,
stateStore,
leaseManager);
}

public StreamingQueryHandler getStreamingQueryHandler() {
return new StreamingQueryHandler(
emrServerlessClientFactory.getClient(), jobExecutionResponseReader, leaseManager);
}

public BatchQueryHandler getBatchQueryHandler() {
return new BatchQueryHandler(
emrServerlessClientFactory.getClient(), jobExecutionResponseReader, leaseManager);
}

public InteractiveQueryHandler getInteractiveQueryHandler() {
return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager);
}

public IndexDMLHandler getIndexDMLHandler() {
return new IndexDMLHandler(
emrServerlessClientFactory.getClient(),
jobExecutionResponseReader,
flintIndexMetadataService,
stateStore,
client);
}
}
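For orientation, here is a minimal test-style sketch of how a caller obtains handlers from the new factory. The sketch is not part of this PR: it assumes Mockito is on the classpath and mocks every collaborator, including the EMR serverless client factory.

package org.opensearch.sql.spark.dispatcher;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import org.opensearch.client.Client;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.flint.FlintIndexMetadataService;
import org.opensearch.sql.spark.leasemanager.LeaseManager;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;

public class QueryHandlerFactorySketch {
  public static void main(String[] args) {
    // Stub the EMR serverless client factory so handler construction can resolve a client.
    EMRServerlessClientFactory clientFactory = mock(EMRServerlessClientFactory.class);
    when(clientFactory.getClient()).thenReturn(mock(EMRServerlessClient.class));

    // Constructor argument order follows the field order generated by @RequiredArgsConstructor.
    QueryHandlerFactory factory =
        new QueryHandlerFactory(
            mock(JobExecutionResponseReader.class),
            mock(FlintIndexMetadataService.class),
            mock(Client.class),
            mock(SessionManager.class),
            mock(LeaseManager.class),
            mock(StateStore.class),
            clientFactory);

    // Each getter wires a handler from the shared dependencies; callers no longer need to
    // know the individual handler constructor signatures.
    BatchQueryHandler batchHandler = factory.getBatchQueryHandler();
    RefreshQueryHandler refreshHandler = factory.getRefreshQueryHandler();
    IndexDMLHandler indexDMLHandler = factory.getIndexDMLHandler();
  }
}

The point of the class is that SparkQueryDispatcher can now depend on this single collaborator instead of constructing handlers from seven separate dependencies, as the refactored dispatcher below shows.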
SparkQueryDispatcher.java
@@ -8,25 +8,19 @@
import java.util.HashMap;
import java.util.Map;
import lombok.AllArgsConstructor;
import org.jetbrains.annotations.NotNull;
import org.json.JSONObject;
import org.opensearch.client.Client;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType;
import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails;
import org.opensearch.sql.spark.dispatcher.model.JobType;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.flint.FlintIndexMetadataService;
import org.opensearch.sql.spark.leasemanager.LeaseManager;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
import org.opensearch.sql.spark.rest.model.LangType;
import org.opensearch.sql.spark.utils.SQLQueryUtils;

@@ -39,63 +33,67 @@ public class SparkQueryDispatcher {
public static final String CLUSTER_NAME_TAG_KEY = "domain_ident";
public static final String JOB_TYPE_TAG_KEY = "type";

private EMRServerlessClientFactory emrServerlessClientFactory;

private DataSourceService dataSourceService;

private JobExecutionResponseReader jobExecutionResponseReader;

private FlintIndexMetadataService flintIndexMetadataService;

private Client client;

private SessionManager sessionManager;

private LeaseManager leaseManager;

private StateStore stateStore;
private final DataSourceService dataSourceService;
private final SessionManager sessionManager;
private final QueryHandlerFactory queryHandlerFactory;

public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) {
EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient();
DataSourceMetadata dataSourceMetadata =
this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
dispatchQueryRequest.getDatasource());
AsyncQueryHandler asyncQueryHandler =
sessionManager.isEnabled()
? new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager)
: new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager);
DispatchQueryContext.DispatchQueryContextBuilder contextBuilder =
DispatchQueryContext.builder()
.dataSourceMetadata(dataSourceMetadata)
.tags(getDefaultTagsForJobSubmission(dispatchQueryRequest))
.queryId(AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()));

// override asyncQueryHandler with specific.

if (LangType.SQL.equals(dispatchQueryRequest.getLangType())
&& SQLQueryUtils.isFlintExtensionQuery(dispatchQueryRequest.getQuery())) {
IndexQueryDetails indexQueryDetails =
SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery());
fillMissingDetails(dispatchQueryRequest, indexQueryDetails);
contextBuilder.indexQueryDetails(indexQueryDetails);

if (isEligibleForIndexDMLHandling(indexQueryDetails)) {
asyncQueryHandler = createIndexDMLHandler(emrServerlessClient);
} else if (isEligibleForStreamingQuery(indexQueryDetails)) {
asyncQueryHandler =
new StreamingQueryHandler(
emrServerlessClient, jobExecutionResponseReader, leaseManager);
} else if (IndexQueryActionType.REFRESH.equals(indexQueryDetails.getIndexQueryActionType())) {
// manual refresh should be handled by batch handler
asyncQueryHandler =
new RefreshQueryHandler(
emrServerlessClient,
jobExecutionResponseReader,
flintIndexMetadataService,
stateStore,
leaseManager);
}
IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest);
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.indexQueryDetails(indexQueryDetails)
.build();

return getQueryHandlerForFlintExtensionQuery(indexQueryDetails)
.submit(dispatchQueryRequest, context);
} else {
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata).build();
return getDefaultAsyncQueryHandler().submit(dispatchQueryRequest, context);
}
return asyncQueryHandler.submit(dispatchQueryRequest, contextBuilder.build());
}

private static DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder(
DispatchQueryRequest dispatchQueryRequest, DataSourceMetadata dataSourceMetadata) {
return DispatchQueryContext.builder()
.dataSourceMetadata(dataSourceMetadata)
.tags(getDefaultTagsForJobSubmission(dispatchQueryRequest))
.queryId(AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()));
}

private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery(
IndexQueryDetails indexQueryDetails) {
if (isEligibleForIndexDMLHandling(indexQueryDetails)) {
return queryHandlerFactory.getIndexDMLHandler();
} else if (isEligibleForStreamingQuery(indexQueryDetails)) {
return queryHandlerFactory.getStreamingQueryHandler();
} else if (IndexQueryActionType.REFRESH.equals(indexQueryDetails.getIndexQueryActionType())) {
// manual refresh should be handled by batch handler
return queryHandlerFactory.getRefreshQueryHandler();
} else {
return getDefaultAsyncQueryHandler();
}
}

@NotNull
private AsyncQueryHandler getDefaultAsyncQueryHandler() {
return sessionManager.isEnabled()
? queryHandlerFactory.getInteractiveQueryHandler()
: queryHandlerFactory.getBatchQueryHandler();
}

@NotNull
private static IndexQueryDetails getIndexQueryDetails(DispatchQueryRequest dispatchQueryRequest) {
IndexQueryDetails indexQueryDetails =
SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery());
fillDatasourceName(dispatchQueryRequest, indexQueryDetails);
return indexQueryDetails;
}

private boolean isEligibleForStreamingQuery(IndexQueryDetails indexQueryDetails) {
@@ -119,58 +117,35 @@ private boolean isEligibleForIndexDMLHandling(IndexQueryDetail
}

public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) {
EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient();
if (asyncQueryJobMetadata.getSessionId() != null) {
return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager)
.getQueryResponse(asyncQueryJobMetadata);
} else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) {
return createIndexDMLHandler(emrServerlessClient).getQueryResponse(asyncQueryJobMetadata);
} else {
return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager)
.getQueryResponse(asyncQueryJobMetadata);
}
return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata)
.getQueryResponse(asyncQueryJobMetadata);
}

public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient();
AsyncQueryHandler queryHandler;
return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata)
.cancelJob(asyncQueryJobMetadata);
}

private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery(
AsyncQueryJobMetadata asyncQueryJobMetadata) {
if (asyncQueryJobMetadata.getSessionId() != null) {
queryHandler =
new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager);
return queryHandlerFactory.getInteractiveQueryHandler();
} else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) {
queryHandler = createIndexDMLHandler(emrServerlessClient);
return queryHandlerFactory.getIndexDMLHandler();
} else if (asyncQueryJobMetadata.getJobType() == JobType.BATCH) {
queryHandler =
new RefreshQueryHandler(
emrServerlessClient,
jobExecutionResponseReader,
flintIndexMetadataService,
stateStore,
leaseManager);
return queryHandlerFactory.getRefreshQueryHandler();
} else if (asyncQueryJobMetadata.getJobType() == JobType.STREAMING) {
queryHandler =
new StreamingQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager);
return queryHandlerFactory.getStreamingQueryHandler();
} else {
queryHandler =
new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager);
return queryHandlerFactory.getBatchQueryHandler();
}
return queryHandler.cancelJob(asyncQueryJobMetadata);
}

private IndexDMLHandler createIndexDMLHandler(EMRServerlessClient emrServerlessClient) {
return new IndexDMLHandler(
emrServerlessClient,
jobExecutionResponseReader,
flintIndexMetadataService,
stateStore,
client);
}

// TODO: Revisit this logic.
// Currently, if the datasource is not provided in the query, Spark assumes the datasource to be the catalog.
// This is required to handle the drop index case properly when the datasource name is not provided.
private static void fillMissingDetails(
private static void fillDatasourceName(
DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) {
if (indexQueryDetails.getFullyQualifiedTableName() != null
&& indexQueryDetails.getFullyQualifiedTableName().getDatasourceName() == null) {
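The handler routing that the refactored dispatch() performs can be condensed into the following standalone sketch. This is not code from the PR: the boolean parameters and string results stand in for the real IndexQueryDetails checks and handler types, and note that getQueryResponse() and cancelJob() use the separate getAsyncQueryHandlerForExistingQuery(), which keys off the stored job metadata instead.

public class HandlerSelectionSketch {
  enum ActionType { REFRESH, OTHER }

  // Mirrors getQueryHandlerForFlintExtensionQuery() plus the interactive/batch default.
  static String selectHandlerForDispatch(
      boolean flintExtensionQuery,
      boolean eligibleForIndexDML,
      boolean eligibleForStreaming,
      ActionType actionType,
      boolean sessionEnabled) {
    if (flintExtensionQuery) {
      if (eligibleForIndexDML) return "IndexDMLHandler";
      if (eligibleForStreaming) return "StreamingQueryHandler";
      // Manual refresh is handled by the batch-style RefreshQueryHandler.
      if (actionType == ActionType.REFRESH) return "RefreshQueryHandler";
    }
    // Default: interactive when sessions are enabled, otherwise batch.
    return sessionEnabled ? "InteractiveQueryHandler" : "BatchQueryHandler";
  }

  public static void main(String[] args) {
    System.out.println(selectHandlerForDispatch(true, false, false, ActionType.REFRESH, true));
    // -> RefreshQueryHandler
    System.out.println(selectHandlerForDispatch(false, false, false, ActionType.OTHER, false));
    // -> BatchQueryHandler
  }
}

The precedence matches the new getQueryHandlerForFlintExtensionQuery(): index DML first, then streaming, then manual refresh, with everything else falling through to the interactive-or-batch default.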
DI module providing SparkQueryDispatcher and QueryHandlerFactory
@@ -25,6 +25,7 @@
import org.opensearch.sql.spark.client.EMRServerlessClientFactoryImpl;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl;
import org.opensearch.sql.spark.dispatcher.QueryHandlerFactory;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.statestore.StateStore;
@@ -65,23 +66,29 @@ public StateStore stateStore(NodeClient client, ClusterService clusterService) {

@Provides
public SparkQueryDispatcher sparkQueryDispatcher(
EMRServerlessClientFactory emrServerlessClientFactory,
DataSourceService dataSourceService,
SessionManager sessionManager,
QueryHandlerFactory queryHandlerFactory) {
return new SparkQueryDispatcher(dataSourceService, sessionManager, queryHandlerFactory);
}

@Provides
public QueryHandlerFactory queryhandlerFactory(
JobExecutionResponseReader jobExecutionResponseReader,
FlintIndexMetadataServiceImpl flintIndexMetadataReader,
NodeClient client,
SessionManager sessionManager,
DefaultLeaseManager defaultLeaseManager,
StateStore stateStore) {
return new SparkQueryDispatcher(
emrServerlessClientFactory,
dataSourceService,
StateStore stateStore,
EMRServerlessClientFactory emrServerlessClientFactory) {
return new QueryHandlerFactory(
jobExecutionResponseReader,
flintIndexMetadataReader,
client,
sessionManager,
defaultLeaseManager,
stateStore);
stateStore,
emrServerlessClientFactory);
}

@Provides
Test setup constructing AsyncQueryExecutorService
@@ -58,6 +58,7 @@
import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
import org.opensearch.sql.spark.client.StartJobRequest;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.dispatcher.QueryHandlerFactory;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.session.SessionModel;
@@ -200,16 +201,20 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService(
StateStore stateStore = new StateStore(client, clusterService);
AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService =
new OpensearchAsyncQueryJobMetadataStorageService(stateStore);
SparkQueryDispatcher sparkQueryDispatcher =
new SparkQueryDispatcher(
emrServerlessClientFactory,
this.dataSourceService,
QueryHandlerFactory queryHandlerFactory =
new QueryHandlerFactory(
jobExecutionResponseReader,
new FlintIndexMetadataServiceImpl(client),
client,
new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings),
new DefaultLeaseManager(pluginSettings, stateStore),
stateStore);
stateStore,
emrServerlessClientFactory);
SparkQueryDispatcher sparkQueryDispatcher =
new SparkQueryDispatcher(
this.dataSourceService,
new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings),
queryHandlerFactory);
return new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService,
sparkQueryDispatcher,
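Condensed from the hunks above, the net change at construction sites: the dispatcher's former eight-argument constructor is split into a three-argument dispatcher plus a QueryHandlerFactory that absorbs the handler dependencies. Below is a compilable sketch of the new wiring, taking the same eight dependencies as parameters; the helper method and its parameter order are illustrative, only the two constructor calls come from the diff.

import org.opensearch.client.Client;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
import org.opensearch.sql.spark.dispatcher.QueryHandlerFactory;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.flint.FlintIndexMetadataService;
import org.opensearch.sql.spark.leasemanager.LeaseManager;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;

public class DispatcherWiringSketch {
  // After this PR, the eight former constructor arguments are regrouped: three go to the
  // dispatcher, the rest (plus the EMR serverless client factory) go to QueryHandlerFactory.
  static SparkQueryDispatcher wire(
      DataSourceService dataSourceService,
      SessionManager sessionManager,
      JobExecutionResponseReader jobExecutionResponseReader,
      FlintIndexMetadataService flintIndexMetadataService,
      Client client,
      LeaseManager leaseManager,
      StateStore stateStore,
      EMRServerlessClientFactory emrServerlessClientFactory) {
    QueryHandlerFactory queryHandlerFactory =
        new QueryHandlerFactory(
            jobExecutionResponseReader,
            flintIndexMetadataService,
            client,
            sessionManager,
            leaseManager,
            stateStore,
            emrServerlessClientFactory);
    return new SparkQueryDispatcher(dataSourceService, sessionManager, queryHandlerFactory);
  }
}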