From be8271455f210c148f6202672471aa47a3daaccc Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Thu, 21 Sep 2023 15:32:09 -0700 Subject: [PATCH] Cancel Job API (#2126) Signed-off-by: Vamsi Manohar --- .../rest/RestDataSourceQueryAction.java | 6 +- docs/user/interfaces/asyncqueryinterface.rst | 27 ++++--- .../asyncquery/AsyncQueryExecutorService.java | 8 ++ .../AsyncQueryExecutorServiceImpl.java | 11 +++ .../spark/client/EmrServerlessClientImpl.java | 20 +++++ .../sql/spark/client/SparkJobClient.java | 3 + .../dispatcher/SparkQueryDispatcher.java | 6 ++ .../rest/RestAsyncQueryManagementAction.java | 2 +- ...ransportCancelAsyncQueryRequestAction.java | 17 ++++- .../model/CancelAsyncQueryActionRequest.java | 2 + .../AsyncQueryExecutorServiceImplTest.java | 30 ++++++++ .../client/EmrServerlessClientImplTest.java | 29 +++++++ .../dispatcher/SparkQueryDispatcherTest.java | 76 +++++++++++-------- ...portCancelAsyncQueryRequestActionTest.java | 29 ++++++- 14 files changed, 213 insertions(+), 53 deletions(-) diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java b/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java index b5929d0f20..2947afc5b9 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java @@ -88,8 +88,7 @@ public List routes() { new Route(GET, BASE_DATASOURCE_ACTION_URL), /* - * GET datasources - * Request URL: GET + * PUT datasources * Request body: * Ref * [org.opensearch.sql.plugin.transport.datasource.model.UpdateDataSourceActionRequest] @@ -100,8 +99,7 @@ public List routes() { new Route(PUT, BASE_DATASOURCE_ACTION_URL), /* - * GET datasources - * Request URL: GET + * DELETE datasources * Request body: Ref * [org.opensearch.sql.plugin.transport.datasource.model.DeleteDataSourceActionRequest] * Response body: Ref diff --git a/docs/user/interfaces/asyncqueryinterface.rst b/docs/user/interfaces/asyncqueryinterface.rst index 98990b795b..f59afe8180 100644 --- a/docs/user/interfaces/asyncqueryinterface.rst +++ b/docs/user/interfaces/asyncqueryinterface.rst @@ -32,16 +32,13 @@ We make use of default aws credentials chain to make calls to the emr serverless have pass role permissions for emr-job-execution-role mentioned in the engine configuration. - Async Query Creation API ====================================== If security plugin is enabled, this API can only be invoked by users with permission ``cluster:admin/opensearch/ql/async_query/create``. -HTTP URI: _plugins/_query/_async_query +HTTP URI: _plugins/_async_query HTTP VERB: POST - - Sample Request:: curl --location 'http://localhost:9200/_plugins/_async_query' \ @@ -57,23 +54,19 @@ Sample Response:: "queryId": "00fd796ut1a7eg0q" } + Async Query Result API ====================================== If security plugin is enabled, this API can only be invoked by users with permission ``cluster:admin/opensearch/ql/async_query/result``. Async Query Creation and Result Query permissions are orthogonal, so any user with result api permissions and queryId can query the corresponding query results irrespective of the user who created the async query. - -HTTP URI: _plugins/_query/_async_query/{queryId} +HTTP URI: _plugins/_async_query/{queryId} HTTP VERB: GET - Sample Request BODY:: curl --location --request GET 'http://localhost:9200/_plugins/_async_query/00fd796ut1a7eg0q' \ --header 'Content-Type: application/json' \ - --data '{ - "query" : "select * from default.http_logs limit 1" - }' Sample Response if the Query is in Progress :: @@ -106,3 +99,17 @@ Sample Response If the Query is successful :: "total": 1, "size": 1 } + + +Async Query Cancellation API +====================================== +If security plugin is enabled, this API can only be invoked by users with permission ``cluster:admin/opensearch/ql/jobs/delete``. + +HTTP URI: _plugins/_async_query/{queryId} +HTTP VERB: DELETE + +Sample Request Body :: + + curl --location --request DELETE 'http://localhost:9200/_plugins/_async_query/00fdalrvgkbh2g0q' \ + --header 'Content-Type: application/json' \ + diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java index df13daa2a2..7caa69293a 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java @@ -29,4 +29,12 @@ public interface AsyncQueryExecutorService { * @return {@link AsyncQueryExecutionResponse} */ AsyncQueryExecutionResponse getAsyncQueryResults(String queryId); + + /** + * Cancels running async query and returns the cancelled queryId. + * + * @param queryId queryId. + * @return {@link String} cancelledQueryId. + */ + String cancelQuery(String queryId); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index e5ed65920e..efc23e08b5 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -95,6 +95,17 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) { throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId)); } + @Override + public String cancelQuery(String queryId) { + Optional asyncQueryJobMetadata = + asyncQueryJobMetadataStorageService.getJobMetadata(queryId); + if (asyncQueryJobMetadata.isPresent()) { + return sparkQueryDispatcher.cancelJob( + asyncQueryJobMetadata.get().getApplicationId(), queryId); + } + throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId)); + } + private void validateSparkExecutionEngineSettings() { if (!isSparkJobExecutionEnabled) { throw new IllegalArgumentException( diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java index b554c4cd23..2377b2f5da 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java @@ -9,12 +9,15 @@ import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_SQL_APPLICATION_JAR; import com.amazonaws.services.emrserverless.AWSEMRServerless; +import com.amazonaws.services.emrserverless.model.CancelJobRunRequest; +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunRequest; import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobDriver; import com.amazonaws.services.emrserverless.model.SparkSubmit; import com.amazonaws.services.emrserverless.model.StartJobRunRequest; import com.amazonaws.services.emrserverless.model.StartJobRunResult; +import com.amazonaws.services.emrserverless.model.ValidationException; import java.security.AccessController; import java.security.PrivilegedAction; import org.apache.logging.log4j.LogManager; @@ -65,4 +68,21 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { logger.info("Job Run state: " + getJobRunResult.getJobRun().getState()); return getJobRunResult; } + + @Override + public CancelJobRunResult cancelJobRun(String applicationId, String jobId) { + CancelJobRunRequest cancelJobRunRequest = + new CancelJobRunRequest().withJobRunId(jobId).withApplicationId(applicationId); + try { + CancelJobRunResult cancelJobRunResult = + AccessController.doPrivileged( + (PrivilegedAction) + () -> emrServerless.cancelJobRun(cancelJobRunRequest)); + logger.info(String.format("Job : %s cancelled", cancelJobRunResult.getJobRunId())); + return cancelJobRunResult; + } catch (ValidationException e) { + throw new IllegalArgumentException( + String.format("Couldn't cancel the queryId: %s due to %s", jobId, e.getMessage())); + } + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java b/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java index ff9f4acedd..c6b3059c77 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java @@ -7,6 +7,7 @@ package org.opensearch.sql.spark.client; +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; public interface SparkJobClient { @@ -19,4 +20,6 @@ String startJobRun( String sparkSubmitParams); GetJobRunResult getJobRunResult(String applicationId, String jobId); + + CancelJobRunResult cancelJobRun(String applicationId, String jobId); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index f632ceaf6a..442838331f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -15,6 +15,7 @@ import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_SCHEME_KEY; import static org.opensearch.sql.spark.data.constants.SparkConstants.HIVE_METASTORE_GLUE_ARN_KEY; +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobRunState; import java.net.URI; @@ -64,6 +65,11 @@ public JSONObject getQueryResponse(String applicationId, String queryId) { return result; } + public String cancelJob(String applicationId, String jobId) { + CancelJobRunResult cancelJobRunResult = sparkJobClient.cancelJobRun(applicationId, jobId); + return cancelJobRunResult.getJobRunId(); + } + // TODO: Analyze given query // Extract datasourceName // Apply Authorizaiton. diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java index 56484688dc..741501cd18 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java @@ -194,7 +194,7 @@ public void onResponse( CancelAsyncQueryActionResponse cancelAsyncQueryActionResponse) { restChannel.sendResponse( new BytesRestResponse( - RestStatus.OK, + RestStatus.NO_CONTENT, "application/json; charset=UTF-8", cancelAsyncQueryActionResponse.getResult())); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java index 990dbccd0b..232a280db5 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java @@ -12,6 +12,7 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; import org.opensearch.tasks.Task; @@ -21,13 +22,17 @@ public class TransportCancelAsyncQueryRequestAction extends HandledTransportAction { public static final String NAME = "cluster:admin/opensearch/ql/async_query/delete"; + private final AsyncQueryExecutorServiceImpl asyncQueryExecutorService; public static final ActionType ACTION_TYPE = new ActionType<>(NAME, CancelAsyncQueryActionResponse::new); @Inject public TransportCancelAsyncQueryRequestAction( - TransportService transportService, ActionFilters actionFilters) { + TransportService transportService, + ActionFilters actionFilters, + AsyncQueryExecutorServiceImpl asyncQueryExecutorService) { super(NAME, transportService, actionFilters, CancelAsyncQueryActionRequest::new); + this.asyncQueryExecutorService = asyncQueryExecutorService; } @Override @@ -35,7 +40,13 @@ protected void doExecute( Task task, CancelAsyncQueryActionRequest request, ActionListener listener) { - String responseContent = "deleted_job"; - listener.onResponse(new CancelAsyncQueryActionResponse(responseContent)); + try { + String jobId = asyncQueryExecutorService.cancelQuery(request.getQueryId()); + listener.onResponse( + new CancelAsyncQueryActionResponse( + String.format("Deleted async query with id: %s", jobId))); + } catch (Exception e) { + listener.onFailure(e); + } } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java index e12f184efe..0065b575ed 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java @@ -9,11 +9,13 @@ import java.io.IOException; import lombok.AllArgsConstructor; +import lombok.Getter; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.StreamInput; @AllArgsConstructor +@Getter public class CancelAsyncQueryActionRequest extends ActionRequest { private String queryId; diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index cf04278892..5e832777fc 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -142,4 +142,34 @@ void testGetAsyncQueryResultsWithDisabledExecutionEngine() { + " to enable Async Query APIs", illegalArgumentException.getMessage()); } + + @Test + void testCancelJobWithJobNotFound() { + AsyncQueryExecutorService asyncQueryExecutorService = + new AsyncQueryExecutorServiceImpl( + asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); + when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) + .thenReturn(Optional.empty()); + AsyncQueryNotFoundException asyncQueryNotFoundException = + Assertions.assertThrows( + AsyncQueryNotFoundException.class, + () -> asyncQueryExecutorService.cancelQuery(EMR_JOB_ID)); + Assertions.assertEquals( + "QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage()); + verifyNoInteractions(sparkQueryDispatcher); + verifyNoInteractions(settings); + } + + @Test + void testCancelJob() { + AsyncQueryExecutorService asyncQueryExecutorService = + new AsyncQueryExecutorServiceImpl( + asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); + when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) + .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID))); + when(sparkQueryDispatcher.cancelJob(EMRS_APPLICATION_ID, EMR_JOB_ID)).thenReturn(EMR_JOB_ID); + String jobId = asyncQueryExecutorService.cancelQuery(EMR_JOB_ID); + Assertions.assertEquals(EMR_JOB_ID, jobId); + verifyNoInteractions(settings); + } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java index 36f10cd08b..925ee73bcd 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java @@ -5,17 +5,22 @@ package org.opensearch.sql.spark.client; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.when; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_JOB_NAME; +import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; import static org.opensearch.sql.spark.constants.TestConstants.QUERY; import static org.opensearch.sql.spark.constants.TestConstants.SPARK_SUBMIT_PARAMETERS; import com.amazonaws.services.emrserverless.AWSEMRServerless; +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobRun; import com.amazonaws.services.emrserverless.model.StartJobRunResult; +import com.amazonaws.services.emrserverless.model.ValidationException; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -45,4 +50,28 @@ void testGetJobRunState() { EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, "123"); } + + @Test + void testCancelJobRun() { + when(emrServerless.cancelJobRun(any())) + .thenReturn(new CancelJobRunResult().withJobRunId(EMR_JOB_ID)); + EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); + CancelJobRunResult cancelJobRunResult = + emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID); + Assertions.assertEquals(EMR_JOB_ID, cancelJobRunResult.getJobRunId()); + } + + @Test + void testCancelJobRunWithValidationException() { + doThrow(new ValidationException("Error")).when(emrServerless).cancelJobRun(any()); + EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); + IllegalArgumentException illegalArgumentException = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID)); + Assertions.assertEquals( + "Couldn't cancel the queryId: job-123xxx due to Error (Service: null; Status Code: 0; Error" + + " Code: null; Request ID: null; Proxy: null)", + illegalArgumentException.getMessage()); + } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 800bd59b72..2000eeefed 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -14,6 +14,7 @@ import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; import static org.opensearch.sql.spark.constants.TestConstants.QUERY; +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobRun; import com.amazonaws.services.emrserverless.model.JobRunState; @@ -79,36 +80,17 @@ void testDispatchWithWrongURI() { illegalArgumentException.getMessage()); } - private DataSourceMetadata constructMyGlueDataSourceMetadata() { - DataSourceMetadata dataSourceMetadata = new DataSourceMetadata(); - dataSourceMetadata.setName("my_glue"); - dataSourceMetadata.setConnector(DataSourceType.S3GLUE); - Map properties = new HashMap<>(); - properties.put("glue.auth.type", "iam_role"); - properties.put( - "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole"); - properties.put( - "glue.indexstore.opensearch.uri", - "https://search-flint-dp-benchmark-cf5crj5mj2kfzvgwdeynkxnefy.eu-west-1.es.amazonaws.com"); - properties.put("glue.indexstore.opensearch.auth", "sigv4"); - properties.put("glue.indexstore.opensearch.region", "eu-west-1"); - dataSourceMetadata.setProperties(properties); - return dataSourceMetadata; - } - - private DataSourceMetadata constructMyGlueDataSourceMetadataWithBadURISyntax() { - DataSourceMetadata dataSourceMetadata = new DataSourceMetadata(); - dataSourceMetadata.setName("my_glue"); - dataSourceMetadata.setConnector(DataSourceType.S3GLUE); - Map properties = new HashMap<>(); - properties.put("glue.auth.type", "iam_role"); - properties.put( - "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole"); - properties.put("glue.indexstore.opensearch.uri", "http://localhost:9090? param"); - properties.put("glue.indexstore.opensearch.auth", "sigv4"); - properties.put("glue.indexstore.opensearch.region", "eu-west-1"); - dataSourceMetadata.setProperties(properties); - return dataSourceMetadata; + @Test + void testCancelJob() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher(sparkJobClient, dataSourceService, jobExecutionResponseReader); + when(sparkJobClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID)) + .thenReturn( + new CancelJobRunResult() + .withJobRunId(EMR_JOB_ID) + .withApplicationId(EMRS_APPLICATION_ID)); + String jobId = sparkQueryDispatcher.cancelJob(EMRS_APPLICATION_ID, EMR_JOB_ID); + Assertions.assertEquals(EMR_JOB_ID, jobId); } @Test @@ -140,7 +122,7 @@ void testGetQueryResponseWithSuccess() { Assertions.assertEquals("SUCCESS", result.get("status")); } - String constructExpectedSparkSubmitParameterString() { + private String constructExpectedSparkSubmitParameterString() { return " --class org.opensearch.sql.FlintJob --conf" + " spark.hadoop.fs.s3.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider" + " --conf" @@ -171,4 +153,36 @@ String constructExpectedSparkSubmitParameterString() { + " spark.hive.metastore.glue.role.arn=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" + " --conf spark.sql.catalog.my_glue=org.opensearch.sql.FlintDelegateCatalog "; } + + private DataSourceMetadata constructMyGlueDataSourceMetadata() { + DataSourceMetadata dataSourceMetadata = new DataSourceMetadata(); + dataSourceMetadata.setName("my_glue"); + dataSourceMetadata.setConnector(DataSourceType.S3GLUE); + Map properties = new HashMap<>(); + properties.put("glue.auth.type", "iam_role"); + properties.put( + "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole"); + properties.put( + "glue.indexstore.opensearch.uri", + "https://search-flint-dp-benchmark-cf5crj5mj2kfzvgwdeynkxnefy.eu-west-1.es.amazonaws.com"); + properties.put("glue.indexstore.opensearch.auth", "sigv4"); + properties.put("glue.indexstore.opensearch.region", "eu-west-1"); + dataSourceMetadata.setProperties(properties); + return dataSourceMetadata; + } + + private DataSourceMetadata constructMyGlueDataSourceMetadataWithBadURISyntax() { + DataSourceMetadata dataSourceMetadata = new DataSourceMetadata(); + dataSourceMetadata.setName("my_glue"); + dataSourceMetadata.setConnector(DataSourceType.S3GLUE); + Map properties = new HashMap<>(); + properties.put("glue.auth.type", "iam_role"); + properties.put( + "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole"); + properties.put("glue.indexstore.opensearch.uri", "http://localhost:9090? param"); + properties.put("glue.indexstore.opensearch.auth", "sigv4"); + properties.put("glue.indexstore.opensearch.region", "eu-west-1"); + dataSourceMetadata.setProperties(properties); + return dataSourceMetadata; + } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java index c560c882c0..2ff76b9b57 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java @@ -7,6 +7,10 @@ package org.opensearch.sql.spark.transport; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; + import java.util.HashSet; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; @@ -19,6 +23,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.support.ActionFilters; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; import org.opensearch.tasks.Task; @@ -32,24 +37,40 @@ public class TransportCancelAsyncQueryRequestActionTest { @Mock private Task task; @Mock private ActionListener actionListener; + @Mock private AsyncQueryExecutorServiceImpl asyncQueryExecutorService; + @Captor private ArgumentCaptor deleteJobActionResponseArgumentCaptor; + @Captor private ArgumentCaptor exceptionArgumentCaptor; + @BeforeEach public void setUp() { action = new TransportCancelAsyncQueryRequestAction( - transportService, new ActionFilters(new HashSet<>())); + transportService, new ActionFilters(new HashSet<>()), asyncQueryExecutorService); } @Test public void testDoExecute() { - CancelAsyncQueryActionRequest request = new CancelAsyncQueryActionRequest("jobId"); - + CancelAsyncQueryActionRequest request = new CancelAsyncQueryActionRequest(EMR_JOB_ID); + when(asyncQueryExecutorService.cancelQuery(EMR_JOB_ID)).thenReturn(EMR_JOB_ID); action.doExecute(task, request, actionListener); Mockito.verify(actionListener).onResponse(deleteJobActionResponseArgumentCaptor.capture()); CancelAsyncQueryActionResponse cancelAsyncQueryActionResponse = deleteJobActionResponseArgumentCaptor.getValue(); - Assertions.assertEquals("deleted_job", cancelAsyncQueryActionResponse.getResult()); + Assertions.assertEquals( + "Deleted async query with id: " + EMR_JOB_ID, cancelAsyncQueryActionResponse.getResult()); + } + + @Test + public void testDoExecuteWithException() { + CancelAsyncQueryActionRequest request = new CancelAsyncQueryActionRequest(EMR_JOB_ID); + doThrow(new RuntimeException("Error")).when(asyncQueryExecutorService).cancelQuery(EMR_JOB_ID); + action.doExecute(task, request, actionListener); + Mockito.verify(actionListener).onFailure(exceptionArgumentCaptor.capture()); + Exception exception = exceptionArgumentCaptor.getValue(); + Assertions.assertTrue(exception instanceof RuntimeException); + Assertions.assertEquals("Error", exception.getMessage()); } }