
Merge branch 'main' into dqs/datasource-extension
Signed-off-by: Tomoyuki MORITA <[email protected]>
ykmr1224 authored Jun 29, 2024
2 parents f2f0cf5 + 4326396 commit fe9b35f
Showing 30 changed files with 460 additions and 346 deletions.
async-query-core/build.gradle (1 change: 0 additions & 1 deletion)

@@ -50,7 +50,6 @@ dependencies {
implementation project(':core')
implementation project(':spark') // TODO: dependency to spark should be eliminated
implementation project(':datasources') // TODO: dependency to datasources should be eliminated
- implementation project(':legacy') // TODO: dependency to legacy should be eliminated
implementation 'org.json:json:20231013'
implementation 'com.google.code.gson:gson:2.8.9'

EMRServerlessClientFactory.java

@@ -11,7 +11,8 @@ public interface EMRServerlessClientFactory {
/**
* Gets an instance of {@link EMRServerlessClient}.
*
+ * @param accountId Account ID of the requester. It will be used to decide the cluster.
* @return An {@link EMRServerlessClient} instance.
*/
- EMRServerlessClient getClient();
+ EMRServerlessClient getClient(String accountId);
}
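The factory stays a single-method interface, so a lambda still satisfies it; the AsyncQueryCoreIntegTest change near the end of this diff does exactly that. A minimal sketch of the new call shape, assuming a hypothetical request object that carries the requester's account ID:

  // Sketch only: `request` is a stand-in for whatever holds the account ID.
  EMRServerlessClientFactory factory =
      (accountId) -> new EmrServerlessClientImpl(awsemrServerless, metricsService);
  EMRServerlessClient client = factory.getClient(request.getAccountId());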
EMRServerlessClientFactoryImpl.java

@@ -18,7 +18,6 @@
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
import org.opensearch.sql.spark.metrics.MetricsService;

- /** Implementation of {@link EMRServerlessClientFactory}. */
@RequiredArgsConstructor
public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactory {

@@ -27,13 +26,8 @@ public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactory {
private EMRServerlessClient emrServerlessClient;
private String region;

- /**
- * Gets an instance of {@link EMRServerlessClient}.
- *
- * @return An {@link EMRServerlessClient} instance.
- */
@Override
- public EMRServerlessClient getClient() {
+ public EMRServerlessClient getClient(String accountId) {
SparkExecutionEngineConfig sparkExecutionEngineConfig =
this.sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(
new NullAsyncQueryRequestContext());
QueryHandlerFactory.java

@@ -29,9 +29,9 @@ public class QueryHandlerFactory {
private final MetricsService metricsService;
protected final SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider;

- public RefreshQueryHandler getRefreshQueryHandler() {
+ public RefreshQueryHandler getRefreshQueryHandler(String accountId) {
return new RefreshQueryHandler(
- emrServerlessClientFactory.getClient(),
+ emrServerlessClientFactory.getClient(accountId),
jobExecutionResponseReader,
flintIndexMetadataService,
leaseManager,
@@ -40,18 +40,18 @@ public RefreshQueryHandler getRefreshQueryHandler() {
sparkSubmitParametersBuilderProvider);
}

- public StreamingQueryHandler getStreamingQueryHandler() {
+ public StreamingQueryHandler getStreamingQueryHandler(String accountId) {
return new StreamingQueryHandler(
- emrServerlessClientFactory.getClient(),
+ emrServerlessClientFactory.getClient(accountId),
jobExecutionResponseReader,
leaseManager,
metricsService,
sparkSubmitParametersBuilderProvider);
}

- public BatchQueryHandler getBatchQueryHandler() {
+ public BatchQueryHandler getBatchQueryHandler(String accountId) {
return new BatchQueryHandler(
- emrServerlessClientFactory.getClient(),
+ emrServerlessClientFactory.getClient(accountId),
jobExecutionResponseReader,
leaseManager,
metricsService,
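Only the handlers that launch EMR Serverless jobs directly (refresh, streaming, batch) take the new parameter; getIndexDMLHandler() and getInteractiveQueryHandler() remain zero-argument, since interactive statements reach EMR Serverless through the session layer updated further down. A sketch of the new factory call, assuming accountId was taken from the dispatch request:

  // The account-specific client is now resolved per call rather than shared.
  BatchQueryHandler handler = queryHandlerFactory.getBatchQueryHandler(accountId);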
@@ -54,14 +54,15 @@ public DispatchQueryResponse dispatch(
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();

- return getQueryHandlerForFlintExtensionQuery(indexQueryDetails)
+ return getQueryHandlerForFlintExtensionQuery(dispatchQueryRequest, indexQueryDetails)
.submit(dispatchQueryRequest, context);
} else {
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();
- return getDefaultAsyncQueryHandler().submit(dispatchQueryRequest, context);
+ return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId())
+     .submit(dispatchQueryRequest, context);
}
}

@@ -74,28 +75,28 @@ private DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder(
}

private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery(
- IndexQueryDetails indexQueryDetails) {
+ DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) {
if (isEligibleForIndexDMLHandling(indexQueryDetails)) {
return queryHandlerFactory.getIndexDMLHandler();
} else if (isEligibleForStreamingQuery(indexQueryDetails)) {
- return queryHandlerFactory.getStreamingQueryHandler();
+ return queryHandlerFactory.getStreamingQueryHandler(dispatchQueryRequest.getAccountId());
} else if (IndexQueryActionType.CREATE.equals(indexQueryDetails.getIndexQueryActionType())) {
// Create should be handled by batch handler. This is to avoid DROP index incorrectly cancel
// an interactive job.
- return queryHandlerFactory.getBatchQueryHandler();
+ return queryHandlerFactory.getBatchQueryHandler(dispatchQueryRequest.getAccountId());
} else if (IndexQueryActionType.REFRESH.equals(indexQueryDetails.getIndexQueryActionType())) {
// Manual refresh should be handled by batch handler
- return queryHandlerFactory.getRefreshQueryHandler();
+ return queryHandlerFactory.getRefreshQueryHandler(dispatchQueryRequest.getAccountId());
} else {
- return getDefaultAsyncQueryHandler();
+ return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId());
}
}

@NotNull
- private AsyncQueryHandler getDefaultAsyncQueryHandler() {
+ private AsyncQueryHandler getDefaultAsyncQueryHandler(String accountId) {
return sessionManager.isEnabled()
? queryHandlerFactory.getInteractiveQueryHandler()
- : queryHandlerFactory.getBatchQueryHandler();
+ : queryHandlerFactory.getBatchQueryHandler(accountId);
}

@NotNull
@@ -143,11 +144,11 @@ private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery(
} else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) {
return queryHandlerFactory.getIndexDMLHandler();
} else if (asyncQueryJobMetadata.getJobType() == JobType.BATCH) {
- return queryHandlerFactory.getRefreshQueryHandler();
+ return queryHandlerFactory.getRefreshQueryHandler(asyncQueryJobMetadata.getAccountId());
} else if (asyncQueryJobMetadata.getJobType() == JobType.STREAMING) {
- return queryHandlerFactory.getStreamingQueryHandler();
+ return queryHandlerFactory.getStreamingQueryHandler(asyncQueryJobMetadata.getAccountId());
} else {
- return queryHandlerFactory.getBatchQueryHandler();
+ return queryHandlerFactory.getBatchQueryHandler(asyncQueryJobMetadata.getAccountId());
}
}

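The dispatcher now draws the account ID from two places: a fresh submission reads it off the incoming DispatchQueryRequest, while operations on an existing job recover it from the persisted AsyncQueryJobMetadata, so follow-up calls land on the same account's cluster. The two call shapes, using the names from the hunks above:

  // New query: account ID from the incoming request.
  queryHandlerFactory.getBatchQueryHandler(dispatchQueryRequest.getAccountId());
  // Existing query: account ID read back from stored job metadata.
  queryHandlerFactory.getRefreshQueryHandler(asyncQueryJobMetadata.getAccountId());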
@@ -33,7 +33,7 @@ public Session createSession(
.sessionId(sessionIdProvider.getSessionId(request))
.sessionStorageService(sessionStorageService)
.statementStorageService(statementStorageService)
- .serverlessClient(emrServerlessClientFactory.getClient())
+ .serverlessClient(emrServerlessClientFactory.getClient(request.getAccountId()))
.build();
session.open(request, asyncQueryRequestContext);
return session;
@@ -65,7 +65,7 @@ public Optional<Session> getSession(String sessionId, String dataSourceName) {
.sessionId(sessionId)
.sessionStorageService(sessionStorageService)
.statementStorageService(statementStorageService)
- .serverlessClient(emrServerlessClientFactory.getClient())
+ .serverlessClient(emrServerlessClientFactory.getClient(model.get().getAccountId()))
.sessionModel(model.get())
.sessionInactivityTimeoutMilli(
sessionConfigSupplier.getSessionInactivityTimeoutMillis())
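Sessions follow the same pattern: on creation the account ID comes from the request, and on reattach it is read back from the persisted session model, so a restored session talks to the cluster that started it. A sketch against the signatures visible above (the sessionManager variable and the createSession parameter list are assumptions):

  Session created = sessionManager.createSession(request, asyncQueryRequestContext);
  Optional<Session> restored = sessionManager.getSession(sessionId, dataSourceName);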
@@ -145,7 +145,8 @@ public void cancelStreamingJob(FlintIndexStateModel flintIndexStateModel)
throws InterruptedException, TimeoutException {
String applicationId = flintIndexStateModel.getApplicationId();
String jobId = flintIndexStateModel.getJobId();
- EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient();
+ EMRServerlessClient emrServerlessClient =
+     emrServerlessClientFactory.getClient(flintIndexStateModel.getAccountId());
try {
emrServerlessClient.cancelJobRun(
flintIndexStateModel.getApplicationId(), flintIndexStateModel.getJobId(), true);
CreateAsyncQueryRequest.java

@@ -5,12 +5,8 @@

package org.opensearch.sql.spark.rest.model;

- import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
-
- import java.io.IOException;
import lombok.Data;
import org.apache.commons.lang3.Validate;
- import org.opensearch.core.xcontent.XContentParser;

@Data
public class CreateAsyncQueryRequest {
@@ -32,35 +28,4 @@ public CreateAsyncQueryRequest(String query, String datasource, LangType lang, String sessionId) {
this.lang = Validate.notNull(lang, "lang can't be null");
this.sessionId = sessionId;
}
-
- public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser)
- throws IOException {
- String query = null;
- LangType lang = null;
- String datasource = null;
- String sessionId = null;
- try {
- ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
- while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
- String fieldName = parser.currentName();
- parser.nextToken();
- if (fieldName.equals("query")) {
- query = parser.textOrNull();
- } else if (fieldName.equals("lang")) {
- String langString = parser.textOrNull();
- lang = LangType.fromString(langString);
- } else if (fieldName.equals("datasource")) {
- datasource = parser.textOrNull();
- } else if (fieldName.equals("sessionId")) {
- sessionId = parser.textOrNull();
- } else {
- throw new IllegalArgumentException("Unknown field: " + fieldName);
- }
- }
- return new CreateAsyncQueryRequest(query, datasource, lang, sessionId);
- } catch (Exception e) {
- throw new IllegalArgumentException(
- String.format("Error while parsing the request body: %s", e.getMessage()));
- }
- }
}
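With fromXContentParser removed, async-query-core no longer pulls in OpenSearch's XContent classes for request parsing (presumably that parsing now lives in the server-side REST layer), which matches the dependency trimming in build.gradle above. Requests are built through the remaining constructor; a minimal sketch, assuming LangType.SQL is a valid constant and with an illustrative datasource name:

  // sessionId may be null; lang is validated non-null in the constructor.
  CreateAsyncQueryRequest request =
      new CreateAsyncQueryRequest("SELECT 1", "my_glue", LangType.SQL, null);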
SQLQueryUtils.java

@@ -5,6 +5,8 @@

package org.opensearch.sql.spark.utils;

+ import java.util.LinkedList;
+ import java.util.List;
import java.util.Locale;
import lombok.Getter;
import lombok.experimental.UtilityClass;
@@ -18,6 +20,7 @@
import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsParser;
import org.opensearch.sql.spark.antlr.parser.SqlBaseLexer;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser;
+ import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.IdentifierReferenceContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor;
import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions;
import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName;
@@ -32,16 +35,15 @@
@UtilityClass
public class SQLQueryUtils {

- // TODO Handle cases where the query has multiple table Names.
- public static FullyQualifiedTableName extractFullyQualifiedTableName(String sqlQuery) {
+ public static List<FullyQualifiedTableName> extractFullyQualifiedTableNames(String sqlQuery) {
SqlBaseParser sqlBaseParser =
new SqlBaseParser(
new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery))));
sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener());
SqlBaseParser.StatementContext statement = sqlBaseParser.statement();
SparkSqlTableNameVisitor sparkSqlTableNameVisitor = new SparkSqlTableNameVisitor();
statement.accept(sparkSqlTableNameVisitor);
- return sparkSqlTableNameVisitor.getFullyQualifiedTableName();
+ return sparkSqlTableNameVisitor.getFullyQualifiedTableNames();
}

public static IndexQueryDetails extractIndexDetails(String sqlQuery) {
@@ -73,23 +75,21 @@ public static boolean isFlintExtensionQuery(String sqlQuery) {

public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor<Void> {

- @Getter private FullyQualifiedTableName fullyQualifiedTableName;
+ @Getter private List<FullyQualifiedTableName> fullyQualifiedTableNames = new LinkedList<>();

- public SparkSqlTableNameVisitor() {
- this.fullyQualifiedTableName = new FullyQualifiedTableName();
- }
+ public SparkSqlTableNameVisitor() {}

@Override
- public Void visitTableName(SqlBaseParser.TableNameContext ctx) {
- fullyQualifiedTableName = new FullyQualifiedTableName(ctx.getText());
- return super.visitTableName(ctx);
+ public Void visitIdentifierReference(IdentifierReferenceContext ctx) {
+ fullyQualifiedTableNames.add(new FullyQualifiedTableName(ctx.getText()));
+ return super.visitIdentifierReference(ctx);
}

@Override
public Void visitDropTable(SqlBaseParser.DropTableContext ctx) {
for (ParseTree parseTree : ctx.children) {
if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) {
- fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText());
+ fullyQualifiedTableNames.add(new FullyQualifiedTableName(parseTree.getText()));
}
}
return super.visitDropTable(ctx);
@@ -99,7 +99,7 @@ public Void visitDropTable(SqlBaseParser.DropTableContext ctx) {
public Void visitDescribeRelation(SqlBaseParser.DescribeRelationContext ctx) {
for (ParseTree parseTree : ctx.children) {
if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) {
- fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText());
+ fullyQualifiedTableNames.add(new FullyQualifiedTableName(parseTree.getText()));
}
}
return super.visitDescribeRelation(ctx);
@@ -110,7 +110,7 @@ public Void visitDescribeRelation(SqlBaseParser.DescribeRelationContext ctx) {
public Void visitCreateTableHeader(SqlBaseParser.CreateTableHeaderContext ctx) {
for (ParseTree parseTree : ctx.children) {
if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) {
- fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText());
+ fullyQualifiedTableNames.add(new FullyQualifiedTableName(parseTree.getText()));
}
}
return super.visitCreateTableHeader(ctx);
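The visitor now accumulates every identifier reference instead of overwriting a single table name, which is what resolves the removed TODO about queries referencing multiple tables. A sketch of the updated API on a two-table query (query text illustrative):

  // Expected to yield both mycatalog.mydb.orders and mycatalog.mydb.customers.
  List<FullyQualifiedTableName> tables =
      SQLQueryUtils.extractFullyQualifiedTableNames(
          "SELECT o.id FROM mycatalog.mydb.orders o "
              + "JOIN mycatalog.mydb.customers c ON o.id = c.id");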
AsyncQueryCoreIntegTest.java

@@ -136,7 +136,7 @@ public class AsyncQueryCoreIntegTest {
@BeforeEach
public void setUp() {
emrServerlessClientFactory =
- () -> new EmrServerlessClientImpl(awsemrServerless, metricsService);
+ (accountId) -> new EmrServerlessClientImpl(awsemrServerless, metricsService);
SparkParameterComposerCollection collection = new SparkParameterComposerCollection();
collection.register(
DataSourceType.S3GLUE,
EMRServerlessClientFactoryImplTest.java

@@ -21,6 +21,7 @@
@ExtendWith(MockitoExtension.class)
public class EMRServerlessClientFactoryImplTest {

+ public static final String ACCOUNT_ID = "accountId";
@Mock private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier;
@Mock private MetricsService metricsService;

@@ -30,7 +31,9 @@ public void testGetClient() {
.thenReturn(createSparkExecutionEngineConfig());
EMRServerlessClientFactory emrServerlessClientFactory =
new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier, metricsService);
- EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient();
+
+ EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient(ACCOUNT_ID);
+
Assertions.assertNotNull(emrserverlessClient);
}

@@ -41,16 +44,16 @@ public void testGetClientWithChangeInSetting() {
.thenReturn(sparkExecutionEngineConfig);
EMRServerlessClientFactory emrServerlessClientFactory =
new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier, metricsService);
- EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient();
+ EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient(ACCOUNT_ID);
Assertions.assertNotNull(emrserverlessClient);

- EMRServerlessClient emrServerlessClient1 = emrServerlessClientFactory.getClient();
+ EMRServerlessClient emrServerlessClient1 = emrServerlessClientFactory.getClient(ACCOUNT_ID);
Assertions.assertEquals(emrServerlessClient1, emrserverlessClient);

sparkExecutionEngineConfig.setRegion(TestConstants.US_WEST_REGION);
when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any()))
.thenReturn(sparkExecutionEngineConfig);
- EMRServerlessClient emrServerlessClient2 = emrServerlessClientFactory.getClient();
+ EMRServerlessClient emrServerlessClient2 = emrServerlessClientFactory.getClient(ACCOUNT_ID);
Assertions.assertNotEquals(emrServerlessClient2, emrserverlessClient);
Assertions.assertNotEquals(emrServerlessClient2, emrServerlessClient1);
}
@@ -60,9 +63,11 @@ public void testGetClientWithException() {
when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any())).thenReturn(null);
EMRServerlessClientFactory emrServerlessClientFactory =
new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier, metricsService);
+
IllegalArgumentException illegalArgumentException =
Assertions.assertThrows(
- IllegalArgumentException.class, emrServerlessClientFactory::getClient);
+ IllegalArgumentException.class, () -> emrServerlessClientFactory.getClient(ACCOUNT_ID));
+
Assertions.assertEquals(
"Async Query APIs are disabled. Please configure plugins.query.executionengine.spark.config"
+ " in cluster settings to enable them.",
@@ -77,9 +82,11 @@ public void testGetClientWithExceptionWithNullRegion() {
.thenReturn(sparkExecutionEngineConfig);
EMRServerlessClientFactory emrServerlessClientFactory =
new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier, metricsService);
+
IllegalArgumentException illegalArgumentException =
Assertions.assertThrows(
- IllegalArgumentException.class, emrServerlessClientFactory::getClient);
+ IllegalArgumentException.class, () -> emrServerlessClientFactory.getClient(ACCOUNT_ID));
+
Assertions.assertEquals(
"Async Query APIs are disabled. Please configure plugins.query.executionengine.spark.config"
+ " in cluster settings to enable them.",