Skip to content

Commit

Permalink
Remove EMRServerlessClientFactory from SparkQueryDispatcher
Browse files Browse the repository at this point in the history
Signed-off-by: Tomoyuki Morita <[email protected]>
  • Loading branch information
ykmr1224 committed Apr 26, 2024
1 parent 8b02fe8 commit bbed24a
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import lombok.RequiredArgsConstructor;
import org.opensearch.client.Client;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.flint.FlintIndexMetadataService;
Expand All @@ -23,31 +23,34 @@ public class QueryHandlerFactory {
private final SessionManager sessionManager;
private final LeaseManager leaseManager;
private final StateStore stateStore;
private final EMRServerlessClientFactory emrServerlessClientFactory;

public RefreshQueryHandler getRefreshQueryHandler(EMRServerlessClient emrServerlessClient) {
public RefreshQueryHandler getRefreshQueryHandler() {
return new RefreshQueryHandler(
emrServerlessClient,
emrServerlessClientFactory.getClient(),
jobExecutionResponseReader,
flintIndexMetadataService,
stateStore,
leaseManager);
}

public StreamingQueryHandler getStreamingQueryHandler(EMRServerlessClient emrServerlessClient) {
return new StreamingQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager);
public StreamingQueryHandler getStreamingQueryHandler() {
return new StreamingQueryHandler(
emrServerlessClientFactory.getClient(), jobExecutionResponseReader, leaseManager);
}

public BatchQueryHandler getBatchQueryHandler(EMRServerlessClient emrServerlessClient) {
return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager);
public BatchQueryHandler getBatchQueryHandler() {
return new BatchQueryHandler(
emrServerlessClientFactory.getClient(), jobExecutionResponseReader, leaseManager);
}

/**
 * Builds a new {@link InteractiveQueryHandler} on every call.
 *
 * <p>The handler is constructed from the {@code sessionManager}, {@code
 * jobExecutionResponseReader} and {@code leaseManager} references; this method does not obtain
 * an EMR Serverless client.
 *
 * @return a freshly constructed interactive query handler
 */
public InteractiveQueryHandler getInteractiveQueryHandler() {
return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager);
}

public IndexDMLHandler getIndexDMLHandler(EMRServerlessClient emrServerlessClient) {
public IndexDMLHandler getIndexDMLHandler() {
return new IndexDMLHandler(
emrServerlessClient,
emrServerlessClientFactory.getClient(),
jobExecutionResponseReader,
flintIndexMetadataService,
stateStore,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
Expand All @@ -35,68 +33,69 @@ public class SparkQueryDispatcher {
public static final String CLUSTER_NAME_TAG_KEY = "domain_ident";
public static final String JOB_TYPE_TAG_KEY = "type";

private final EMRServerlessClientFactory emrServerlessClientFactory;
private final DataSourceService dataSourceService;
private final SessionManager sessionManager;
private final QueryHandlerFactory queryHandlerFactory;

public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) {
EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient();
DataSourceMetadata dataSourceMetadata =
this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
dispatchQueryRequest.getDatasource());

if (LangType.SQL.equals(dispatchQueryRequest.getLangType())
&& SQLQueryUtils.isFlintExtensionQuery(dispatchQueryRequest.getQuery())) {
IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest);
DispatchQueryContext context = getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.indexQueryDetails(indexQueryDetails)
.build();
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.indexQueryDetails(indexQueryDetails)
.build();

return getQueryHandlerForFlintExtensionQuery(indexQueryDetails, emrServerlessClient)
.submit(dispatchQueryRequest, context);
return getQueryHandlerForFlintExtensionQuery(indexQueryDetails)
.submit(dispatchQueryRequest, context);
} else {
DispatchQueryContext context = getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.build();
return getDefaultAsyncQueryHandler(emrServerlessClient).submit(dispatchQueryRequest, context);
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata).build();
return getDefaultAsyncQueryHandler().submit(dispatchQueryRequest, context);
}
}

private static DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder(DispatchQueryRequest dispatchQueryRequest, DataSourceMetadata dataSourceMetadata) {
private static DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder(
DispatchQueryRequest dispatchQueryRequest, DataSourceMetadata dataSourceMetadata) {
return DispatchQueryContext.builder()
.dataSourceMetadata(dataSourceMetadata)
.tags(getDefaultTagsForJobSubmission(dispatchQueryRequest))
.queryId(AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()));
.dataSourceMetadata(dataSourceMetadata)
.tags(getDefaultTagsForJobSubmission(dispatchQueryRequest))
.queryId(AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()));
}

private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery(IndexQueryDetails indexQueryDetails, EMRServerlessClient emrServerlessClient) {
private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery(
IndexQueryDetails indexQueryDetails) {
if (isEligibleForIndexDMLHandling(indexQueryDetails)) {
return queryHandlerFactory.getIndexDMLHandler(emrServerlessClient);
return queryHandlerFactory.getIndexDMLHandler();
} else if (isEligibleForStreamingQuery(indexQueryDetails)) {
return queryHandlerFactory.getStreamingQueryHandler(emrServerlessClient);
return queryHandlerFactory.getStreamingQueryHandler();
} else if (IndexQueryActionType.REFRESH.equals(indexQueryDetails.getIndexQueryActionType())) {
// manual refresh should be handled by batch handler
return queryHandlerFactory.getRefreshQueryHandler(emrServerlessClient);
return queryHandlerFactory.getRefreshQueryHandler();
} else {
return getDefaultAsyncQueryHandler(emrServerlessClient);
return getDefaultAsyncQueryHandler();
}
}

@NotNull
private AsyncQueryHandler getDefaultAsyncQueryHandler(EMRServerlessClient emrServerlessClient) {
private AsyncQueryHandler getDefaultAsyncQueryHandler() {
return sessionManager.isEnabled()
? queryHandlerFactory.getInteractiveQueryHandler()
: queryHandlerFactory.getBatchQueryHandler(emrServerlessClient);
? queryHandlerFactory.getInteractiveQueryHandler()
: queryHandlerFactory.getBatchQueryHandler();
}

@NotNull
private static IndexQueryDetails getIndexQueryDetails(DispatchQueryRequest dispatchQueryRequest) {
IndexQueryDetails indexQueryDetails = SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery());
IndexQueryDetails indexQueryDetails =
SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery());
fillDatasourceName(dispatchQueryRequest, indexQueryDetails);
return indexQueryDetails;
}


private boolean isEligibleForStreamingQuery(IndexQueryDetails indexQueryDetails) {
Boolean isCreateAutoRefreshIndex =
IndexQueryActionType.CREATE.equals(indexQueryDetails.getIndexQueryActionType())
Expand All @@ -118,25 +117,27 @@ private boolean isEligibleForIndexDMLHandling(IndexQueryDetails indexQueryDetail
}

public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) {
return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata).getQueryResponse(asyncQueryJobMetadata);
return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata)
.getQueryResponse(asyncQueryJobMetadata);
}

public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata).cancelJob(asyncQueryJobMetadata);
return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata)
.cancelJob(asyncQueryJobMetadata);
}

private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery(AsyncQueryJobMetadata asyncQueryJobMetadata) {
EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient();
private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery(
AsyncQueryJobMetadata asyncQueryJobMetadata) {
if (asyncQueryJobMetadata.getSessionId() != null) {
return queryHandlerFactory.getInteractiveQueryHandler();
} else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) {
return queryHandlerFactory.getIndexDMLHandler(emrServerlessClient);
return queryHandlerFactory.getIndexDMLHandler();
} else if (asyncQueryJobMetadata.getJobType() == JobType.BATCH) {
return queryHandlerFactory.getRefreshQueryHandler(emrServerlessClient);
return queryHandlerFactory.getRefreshQueryHandler();
} else if (asyncQueryJobMetadata.getJobType() == JobType.STREAMING) {
return queryHandlerFactory.getStreamingQueryHandler(emrServerlessClient);
return queryHandlerFactory.getStreamingQueryHandler();
} else {
return queryHandlerFactory.getBatchQueryHandler(emrServerlessClient);
return queryHandlerFactory.getBatchQueryHandler();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@
public class AsyncExecutorServiceModule extends AbstractModule {

@Override
protected void configure() {
}
protected void configure() {}

@Provides
public AsyncQueryExecutorService asyncQueryExecutorService(
Expand Down Expand Up @@ -67,17 +66,10 @@ public StateStore stateStore(NodeClient client, ClusterService clusterService) {

@Provides
public SparkQueryDispatcher sparkQueryDispatcher(
EMRServerlessClientFactory emrServerlessClientFactory,
DataSourceService dataSourceService,
SessionManager sessionManager,
QueryHandlerFactory queryHandlerFactory
) {
return new SparkQueryDispatcher(
emrServerlessClientFactory,
dataSourceService,
sessionManager,
queryHandlerFactory
);
QueryHandlerFactory queryHandlerFactory) {
return new SparkQueryDispatcher(dataSourceService, sessionManager, queryHandlerFactory);
}

@Provides
Expand All @@ -87,15 +79,16 @@ public QueryHandlerFactory queryhandlerFactory(
NodeClient client,
SessionManager sessionManager,
DefaultLeaseManager defaultLeaseManager,
StateStore stateStore
) {
StateStore stateStore,
EMRServerlessClientFactory emrServerlessClientFactory) {
return new QueryHandlerFactory(
jobExecutionResponseReader,
flintIndexMetadataReader,
client,
sessionManager,
defaultLeaseManager,
stateStore);
stateStore,
emrServerlessClientFactory);
}

@Provides
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,17 +201,17 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService(
StateStore stateStore = new StateStore(client, clusterService);
AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService =
new OpensearchAsyncQueryJobMetadataStorageService(stateStore);
QueryHandlerFactory queryHandlerFactory = new QueryHandlerFactory(
jobExecutionResponseReader,
new FlintIndexMetadataServiceImpl(client),
client,
new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings),
new DefaultLeaseManager(pluginSettings, stateStore),
stateStore
);
QueryHandlerFactory queryHandlerFactory =
new QueryHandlerFactory(
jobExecutionResponseReader,
new FlintIndexMetadataServiceImpl(client),
client,
new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings),
new DefaultLeaseManager(pluginSettings, stateStore),
stateStore,
emrServerlessClientFactory);
SparkQueryDispatcher sparkQueryDispatcher =
new SparkQueryDispatcher(
emrServerlessClientFactory,
this.dataSourceService,
new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings),
queryHandlerFactory);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,17 @@ public class SparkQueryDispatcherTest {

@BeforeEach
void setUp() {
QueryHandlerFactory queryHandlerFactory = new QueryHandlerFactory(
jobExecutionResponseReader,
flintIndexMetadataService,
openSearchClient,
sessionManager,
leaseManager,
stateStore
);
sparkQueryDispatcher =
new SparkQueryDispatcher(
emrServerlessClientFactory,
dataSourceService,
QueryHandlerFactory queryHandlerFactory =
new QueryHandlerFactory(
jobExecutionResponseReader,
flintIndexMetadataService,
openSearchClient,
sessionManager,
queryHandlerFactory);
leaseManager,
stateStore,
emrServerlessClientFactory);
sparkQueryDispatcher =
new SparkQueryDispatcher(dataSourceService, sessionManager, queryHandlerFactory);
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
}

Expand Down

0 comments on commit bbed24a

Please sign in to comment.