diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java index 486514f366..52cd863081 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java @@ -74,7 +74,7 @@ public String cancelJob( @Override public DispatchQueryResponse submit( DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) { - leaseManager.borrow(new LeaseRequest(JobType.BATCH, dispatchQueryRequest.getDatasource())); + leaseManager.borrow(new LeaseRequest(JobType.REFRESH, dispatchQueryRequest.getDatasource())); DispatchQueryResponse resp = super.submit(dispatchQueryRequest, context); DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata(); @@ -84,7 +84,7 @@ public DispatchQueryResponse submit( .resultIndex(resp.getResultIndex()) .sessionId(resp.getSessionId()) .datasourceName(dataSourceMetadata.getName()) - .jobType(JobType.BATCH) + .jobType(JobType.REFRESH) .indexName(context.getIndexQueryDetails().openSearchIndexName()) .status(QueryState.WAITING) .build(); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index a6fdd3f102..e9da322fd9 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -175,7 +175,7 @@ private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery( return queryHandlerFactory.getInteractiveQueryHandler(); } else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) { return queryHandlerFactory.getIndexDMLHandler(); - } else if (asyncQueryJobMetadata.getJobType() == JobType.BATCH) { + } else if (asyncQueryJobMetadata.getJobType() == JobType.REFRESH) { return queryHandlerFactory.getRefreshQueryHandler(asyncQueryJobMetadata.getAccountId()); } else if (asyncQueryJobMetadata.getJobType() == JobType.STREAMING) { return queryHandlerFactory.getStreamingQueryHandler(asyncQueryJobMetadata.getAccountId()); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java index 01f5f422e9..af1f69d74b 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java @@ -8,6 +8,7 @@ public enum JobType { INTERACTIVE("interactive"), STREAMING("streaming"), + REFRESH("refresh"), BATCH("batch"); private String text; diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index b00e1f3d34..221a12ac2e 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -421,6 +421,7 @@ void testDispatchCreateManualRefreshIndexQuery() { verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); assertEquals(expected, startJobRequestArgumentCaptor.getValue()); assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + assertEquals(JobType.BATCH, dispatchQueryResponse.getJobType()); verifyNoInteractions(flintIndexMetadataService); } @@ -757,6 +758,7 @@ void testRefreshIndexQuery() { verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); assertEquals(expected, startJobRequestArgumentCaptor.getValue()); assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + assertEquals(JobType.REFRESH, dispatchQueryResponse.getJobType()); verifyNoInteractions(flintIndexMetadataService); } @@ -932,12 +934,7 @@ void testDispatchWithUnSupportedDataSourceType() { @Test void testCancelJob() { - when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); - when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false)) - .thenReturn( - new CancelJobRunResult() - .withJobRunId(EMR_JOB_ID) - .withApplicationId(EMRS_APPLICATION_ID)); + givenCancelJobRunSucceed(); String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext); @@ -997,6 +994,25 @@ void testCancelQueryWithInvalidStatementId() { @Test void testCancelQueryWithNoSessionId() { + givenCancelJobRunSucceed(); + + String queryId = + sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext); + + Assertions.assertEquals(QUERY_ID, queryId); + } + + @Test + void testCancelBatchJob() { + givenCancelJobRunSucceed(); + + String queryId = + sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(JobType.BATCH), asyncQueryRequestContext); + + Assertions.assertEquals(QUERY_ID, queryId); + } + + private void givenCancelJobRunSucceed() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false)) .thenReturn( @@ -1273,11 +1289,16 @@ private DispatchQueryRequest dispatchQueryRequestWithSessionId(String query, Str } private AsyncQueryJobMetadata asyncQueryJobMetadata() { + return asyncQueryJobMetadata(JobType.INTERACTIVE); + } + + private AsyncQueryJobMetadata asyncQueryJobMetadata(JobType jobType) { return AsyncQueryJobMetadata.builder() .queryId(QUERY_ID) .applicationId(EMRS_APPLICATION_ID) .jobId(EMR_JOB_ID) .datasourceName(MY_GLUE) + .jobType(jobType) .build(); }