Skip to content

Commit

Permalink
Fix handler for existing query (#2968)
Browse files Browse the repository at this point in the history
Signed-off-by: Tomoyuki Morita <[email protected]>

(cherry picked from commit b14a8cb)
  • Loading branch information
ykmr1224 committed Sep 5, 2024
1 parent c6b329f commit 5bcc23b
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public String cancelJob(
@Override
public DispatchQueryResponse submit(
DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) {
leaseManager.borrow(new LeaseRequest(JobType.BATCH, dispatchQueryRequest.getDatasource()));
leaseManager.borrow(new LeaseRequest(JobType.REFRESH, dispatchQueryRequest.getDatasource()));

DispatchQueryResponse resp = super.submit(dispatchQueryRequest, context);
DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata();
Expand All @@ -83,7 +83,7 @@ public DispatchQueryResponse submit(
.resultIndex(resp.getResultIndex())
.sessionId(resp.getSessionId())
.datasourceName(dataSourceMetadata.getName())
.jobType(JobType.BATCH)
.jobType(JobType.REFRESH)
.indexName(context.getIndexQueryDetails().openSearchIndexName())
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery(
return queryHandlerFactory.getInteractiveQueryHandler();
} else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) {
return queryHandlerFactory.getIndexDMLHandler();
} else if (asyncQueryJobMetadata.getJobType() == JobType.BATCH) {
} else if (asyncQueryJobMetadata.getJobType() == JobType.REFRESH) {
return queryHandlerFactory.getRefreshQueryHandler(asyncQueryJobMetadata.getAccountId());
} else if (asyncQueryJobMetadata.getJobType() == JobType.STREAMING) {
return queryHandlerFactory.getStreamingQueryHandler(asyncQueryJobMetadata.getAccountId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
public enum JobType {
INTERACTIVE("interactive"),
STREAMING("streaming"),
REFRESH("refresh"),
BATCH("batch");

private String text;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ public void createAlterIndexQuery() {
assertFalse(flintIndexOptions.autoRefresh());
verifyCancelJobRunCalled();
verifyCreateIndexDMLResultCalled();
verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID);
verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, JobType.BATCH);
}

@Test
Expand All @@ -280,7 +280,7 @@ public void createStreamingQuery() {
verifyGetQueryIdCalled();
verify(leaseManager).borrow(any());
verifyStartJobRunCalled();
verifyStoreJobMetadataCalled(JOB_ID);
verifyStoreJobMetadataCalled(JOB_ID, JobType.STREAMING);
}

private void verifyStartJobRunCalled() {
Expand Down Expand Up @@ -315,7 +315,7 @@ public void createCreateIndexQuery() {
assertNull(response.getSessionId());
verifyGetQueryIdCalled();
verifyStartJobRunCalled();
verifyStoreJobMetadataCalled(JOB_ID);
verifyStoreJobMetadataCalled(JOB_ID, JobType.BATCH);
}

@Test
Expand All @@ -337,7 +337,7 @@ public void createRefreshQuery() {
verifyGetQueryIdCalled();
verify(leaseManager).borrow(any());
verifyStartJobRunCalled();
verifyStoreJobMetadataCalled(JOB_ID);
verifyStoreJobMetadataCalled(JOB_ID, JobType.REFRESH);
}

@Test
Expand All @@ -363,7 +363,7 @@ public void createInteractiveQuery() {
verifyGetSessionIdCalled();
verify(leaseManager).borrow(any());
verifyStartJobRunCalled();
verifyStoreJobMetadataCalled(JOB_ID);
verifyStoreJobMetadataCalled(JOB_ID, JobType.INTERACTIVE);
}

@Test
Expand Down Expand Up @@ -454,7 +454,7 @@ public void cancelIndexDMLQuery() {
@Test
public void cancelRefreshQuery() {
givenJobMetadataExists(
getBaseAsyncQueryJobMetadataBuilder().jobType(JobType.BATCH).indexName(INDEX_NAME));
getBaseAsyncQueryJobMetadataBuilder().jobType(JobType.REFRESH).indexName(INDEX_NAME));
when(flintIndexMetadataService.getFlintIndexMetadata(INDEX_NAME, asyncQueryRequestContext))
.thenReturn(
ImmutableMap.of(
Expand Down Expand Up @@ -507,7 +507,8 @@ private void givenSparkExecutionEngineConfigIsSupplied() {
.build());
}

private void givenFlintIndexMetadataExists(String indexName) {
private void givenFlintIndexMetadataExists(
String indexName, FlintIndexOptions flintIndexOptions) {
when(flintIndexMetadataService.getFlintIndexMetadata(indexName, asyncQueryRequestContext))
.thenReturn(
ImmutableMap.of(
Expand All @@ -516,9 +517,27 @@ private void givenFlintIndexMetadataExists(String indexName) {
.appId(APPLICATION_ID)
.jobId(JOB_ID)
.opensearchIndexName(indexName)
.flintIndexOptions(flintIndexOptions)
.build()));
}

// Overload method for default FlintIndexOptions usage
private void givenFlintIndexMetadataExists(String indexName) {
givenFlintIndexMetadataExists(indexName, new FlintIndexOptions());
}

// Method to set up FlintIndexMetadata with external scheduler
private void givenFlintIndexMetadataExistsWithExternalScheduler(String indexName) {
givenFlintIndexMetadataExists(indexName, createExternalSchedulerFlintIndexOptions());
}

// Helper method for creating FlintIndexOptions with external scheduler
private FlintIndexOptions createExternalSchedulerFlintIndexOptions() {
FlintIndexOptions options = new FlintIndexOptions();
options.setOption(FlintIndexOptions.SCHEDULER_MODE, "external");
return options;
}

private void givenValidDataSourceMetadataExist() {
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
DATASOURCE_NAME, asyncQueryRequestContext))
Expand Down Expand Up @@ -560,14 +579,15 @@ private void verifyGetSessionIdCalled() {
assertEquals(APPLICATION_ID, createSessionRequest.getApplicationId());
}

private void verifyStoreJobMetadataCalled(String jobId) {
private void verifyStoreJobMetadataCalled(String jobId, JobType jobType) {
verify(asyncQueryJobMetadataStorageService)
.storeJobMetadata(
asyncQueryJobMetadataArgumentCaptor.capture(), eq(asyncQueryRequestContext));
AsyncQueryJobMetadata asyncQueryJobMetadata = asyncQueryJobMetadataArgumentCaptor.getValue();
assertEquals(QUERY_ID, asyncQueryJobMetadata.getQueryId());
assertEquals(jobId, asyncQueryJobMetadata.getJobId());
assertEquals(DATASOURCE_NAME, asyncQueryJobMetadata.getDatasourceName());
assertEquals(jobType, asyncQueryJobMetadata.getJobType());
}

private void verifyCreateIndexDMLResultCalled() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ void testDispatchCreateManualRefreshIndexQuery() {
verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
Assertions.assertEquals(JobType.BATCH, dispatchQueryResponse.getJobType());
verifyNoInteractions(flintIndexMetadataService);
}

Expand Down Expand Up @@ -661,6 +662,7 @@ void testRefreshIndexQuery() {
verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
Assertions.assertEquals(JobType.REFRESH, dispatchQueryResponse.getJobType());
verifyNoInteractions(flintIndexMetadataService);
}

Expand Down Expand Up @@ -831,12 +833,7 @@ void testDispatchWithUnSupportedDataSourceType() {

@Test
void testCancelJob() {
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false))
.thenReturn(
new CancelJobRunResult()
.withJobRunId(EMR_JOB_ID)
.withApplicationId(EMRS_APPLICATION_ID));
givenCancelJobRunSucceed();

String queryId =
sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext);
Expand Down Expand Up @@ -897,17 +894,32 @@ void testCancelQueryWithInvalidStatementId() {

@Test
void testCancelQueryWithNoSessionId() {
givenCancelJobRunSucceed();

String queryId =
sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext);

Assertions.assertEquals(QUERY_ID, queryId);
}

@Test
void testCancelBatchJob() {
givenCancelJobRunSucceed();

String queryId =
sparkQueryDispatcher.cancelJob(
asyncQueryJobMetadata(JobType.BATCH), asyncQueryRequestContext);

Assertions.assertEquals(QUERY_ID, queryId);
}

private void givenCancelJobRunSucceed() {
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false))
.thenReturn(
new CancelJobRunResult()
.withJobRunId(EMR_JOB_ID)
.withApplicationId(EMRS_APPLICATION_ID));

String queryId =
sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext);

Assertions.assertEquals(QUERY_ID, queryId);
}

@Test
Expand Down Expand Up @@ -1154,11 +1166,16 @@ private DispatchQueryRequest dispatchQueryRequestWithSessionId(String query, Str
}

private AsyncQueryJobMetadata asyncQueryJobMetadata() {
return asyncQueryJobMetadata(JobType.INTERACTIVE);
}

private AsyncQueryJobMetadata asyncQueryJobMetadata(JobType jobType) {
return AsyncQueryJobMetadata.builder()
.queryId(QUERY_ID)
.applicationId(EMRS_APPLICATION_ID)
.jobId(EMR_JOB_ID)
.datasourceName(MY_GLUE)
.jobType(jobType)
.build();
}

Expand Down

0 comments on commit 5bcc23b

Please sign in to comment.