From 41100be087b016971250b686dff0d14aafa34b9f Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Mon, 29 Jul 2024 13:44:56 -0700 Subject: [PATCH 01/10] Fix Integ test for datasource enabled setting with security plugin (#2865) Signed-off-by: Vamsi Manohar --- .../sql/datasource/DataSourceEnabledIT.java | 39 +++++-------------- 1 file changed, 10 insertions(+), 29 deletions(-) diff --git a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceEnabledIT.java b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceEnabledIT.java index 480a6dc563..9c522134a4 100644 --- a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceEnabledIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceEnabledIT.java @@ -6,7 +6,6 @@ package org.opensearch.sql.datasource; import static org.opensearch.sql.legacy.TestUtils.getResponseBody; -import static org.opensearch.sql.legacy.TestsConstants.DATASOURCES; import lombok.SneakyThrows; import org.json.JSONObject; @@ -25,31 +24,25 @@ protected boolean preserveClusterUponCompletion() { } @Test - public void testDataSourceIndexIsCreatedByDefault() { - assertDataSourceCount(0); - assertSelectFromDataSourceReturnsDoesNotExist(); - assertDataSourceIndexCreated(true); - } - - @Test - public void testDataSourceIndexIsCreatedIfSettingIsEnabled() { - setDataSourcesEnabled("transient", true); + public void testAsyncQueryAPIFailureIfSettingIsDisabled() { + setDataSourcesEnabled("transient", false); assertDataSourceCount(0); assertSelectFromDataSourceReturnsDoesNotExist(); - assertDataSourceIndexCreated(true); + assertAsyncQueryApiDisabled(); } @Test - public void testDataSourceIndexIsNotCreatedIfSettingIsDisabled() { - setDataSourcesEnabled("transient", false); - assertDataSourceCount(0); - assertSelectFromDataSourceReturnsDoesNotExist(); - assertDataSourceIndexCreated(false); - assertAsyncQueryApiDisabled(); + public void testDataSourceCreationWithDefaultSettings() { + createOpenSearchDataSource(); + createIndex(); + assertDataSourceCount(1); + assertSelectFromDataSourceReturnsSuccess(); + assertSelectFromDummyIndexInValidDataSourceDataSourceReturnsDoesNotExist(); } @Test public void testAfterPreviousEnable() { + setDataSourcesEnabled("transient", true); createOpenSearchDataSource(); createIndex(); assertDataSourceCount(1); @@ -141,18 +134,6 @@ private void assertDataSourceCount(int expected) { Assert.assertEquals(expected, jsonBody.getJSONArray("datarows").length()); } - @SneakyThrows - private void assertDataSourceIndexCreated(boolean expected) { - Request request = new Request("GET", "/" + DATASOURCES); - Response response = performRequest(request); - String responseBody = getResponseBody(response); - boolean indexDoesExist = - response.getStatusLine().getStatusCode() == 200 - && responseBody.contains(DATASOURCES) - && responseBody.contains("mappings"); - Assert.assertEquals(expected, indexDoesExist); - } - @SneakyThrows private Response performRequest(Request request) { try { From 6b8ee3da41908d9b0e8987feb19d02790f546158 Mon Sep 17 00:00:00 2001 From: "Daniel (dB.) Doubrovkine" Date: Mon, 29 Jul 2024 16:51:32 -0500 Subject: [PATCH 02/10] Update PULL_REQUEST_TEMPLATE to include an API spec change in the checklist. (#2808) * Update PULL_REQUEST_TEMPLATE to include an API spec change in the checklist. Signed-off-by: dblock * Re-added sections. 
Signed-off-by: dblock --------- Signed-off-by: dblock --- .github/PULL_REQUEST_TEMPLATE.md | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 2e325678fe..c84ed5b13a 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,16 +1,18 @@ ### Description [Describe what this change achieves] - -### Issues Resolved -[List any issues this PR will resolve] - + +### Related Issues +Resolves #[Issue number to be closed when this PR is merged] + + ### Check List - [ ] New functionality includes testing. - - [ ] All tests pass, including unit test, integration test and doctest - [ ] New functionality has been documented. - - [ ] New functionality has javadoc added - - [ ] New functionality has user manual doc added -- [ ] Commits are signed per the DCO using --signoff + - [ ] New functionality has javadoc added. + - [ ] New functionality has a user manual doc added. +- [ ] API changes companion pull request [created](https://github.com/opensearch-project/opensearch-api-specification/blob/main/DEVELOPER_GUIDE.md). +- [ ] Commits are signed per the DCO using `--signoff`. +- [ ] Public documentation issue/PR [created](https://github.com/opensearch-project/documentation-website/issues/new/choose). By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license. -For more information on following Developer Certificate of Origin and signing off your commits, please check [here](https://github.com/opensearch-project/OpenSearch/blob/main/CONTRIBUTING.md#developer-certificate-of-origin). \ No newline at end of file +For more information on following Developer Certificate of Origin and signing off your commits, please check [here](https://github.com/opensearch-project/sql/blob/main/CONTRIBUTING.md#developer-certificate-of-origin). 
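The next patch (03/10) threads a new RequestContext argument through DataSourceService.verifyDataSourceAccessAndGetRawMetadata so that async-query-core can hand request attributes down to data-source implementations. A minimal caller-side sketch, assuming only what the patch itself introduces (the one-method RequestContext interface and the new two-argument signature); MapRequestContext and the "user" attribute below are illustrative, not part of the patch:

import java.util.HashMap;
import java.util.Map;
import org.opensearch.sql.datasource.RequestContext;

// Hypothetical map-backed context; RequestContext (added by patch 03)
// declares a single method: Object getAttribute(String name).
class MapRequestContext implements RequestContext {
  private final Map<String, Object> attributes = new HashMap<>();

  MapRequestContext with(String name, Object value) {
    attributes.put(name, value); // store an attribute under the given name
    return this;
  }

  @Override
  public Object getAttribute(String name) {
    return attributes.get(name); // null when the attribute was never set
  }
}

// Usage against the new signature introduced by the patch:
//   DataSourceMetadata metadata =
//       dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
//           "my_glue", new MapRequestContext().with("user", "alice"));

In practice a caller inside the plugin would pass the AsyncQueryRequestContext it already holds, since patch 03 makes that interface extend RequestContext.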
From ba82e1255b301d92eee9e1ad36e44e07afdb3839 Mon Sep 17 00:00:00 2001
From: Tomoyuki MORITA
Date: Tue, 30 Jul 2024 13:51:35 -0700
Subject: [PATCH 03/10] Add RequestContext parameter to
 verifyDataSourceAccessAndGetRawMetadata method (#2866)

* Add RequestContext parameter to verifyDataSourceAccessAndGetRawMetadata
 method

Signed-off-by: Tomoyuki Morita

* Add comments

Signed-off-by: Tomoyuki Morita

* Fix style

Signed-off-by: Tomoyuki Morita

---------

Signed-off-by: Tomoyuki Morita
---
 .../model/AsyncQueryRequestContext.java       |  6 +-
 .../dispatcher/SparkQueryDispatcher.java      |  2 +-
 .../asyncquery/AsyncQueryCoreIntegTest.java   |  3 +-
 .../dispatcher/SparkQueryDispatcherTest.java  | 72 ++++++++++++-------
 .../sql/datasource/DataSourceService.java     |  5 +-
 .../sql/datasource/RequestContext.java        | 15 ++++
 .../sql/analysis/AnalyzerTestBase.java        |  4 +-
 .../service/DataSourceServiceImpl.java        |  4 +-
 .../service/DataSourceServiceImplTest.java    |  8 ++-
 9 files changed, 84 insertions(+), 35 deletions(-)
 create mode 100644 core/src/main/java/org/opensearch/sql/datasource/RequestContext.java

diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryRequestContext.java b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryRequestContext.java
index 56176faefb..d5a478d592 100644
--- a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryRequestContext.java
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryRequestContext.java
@@ -5,7 +5,7 @@

 package org.opensearch.sql.spark.asyncquery.model;

+import org.opensearch.sql.datasource.RequestContext;
+
 /** Context interface to provide additional request related information */
-public interface AsyncQueryRequestContext {
-  Object getAttribute(String name);
-}
+public interface AsyncQueryRequestContext extends RequestContext {}

diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
index 0e871f9ddc..0061ea7179 100644
--- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
@@ -44,7 +44,7 @@ public DispatchQueryResponse dispatch(
       AsyncQueryRequestContext asyncQueryRequestContext) {
     DataSourceMetadata dataSourceMetadata =
         this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
-            dispatchQueryRequest.getDatasource());
+            dispatchQueryRequest.getDatasource(), asyncQueryRequestContext);

     if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) {
       String query = dispatchQueryRequest.getQuery();

diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java
index 99d4cc722e..34ededc74d 100644
--- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java
+++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java
@@ -512,7 +512,8 @@ private void givenFlintIndexMetadataExists(String indexName) {
   }

   private void givenValidDataSourceMetadataExist() {
-    when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(DATASOURCE_NAME))
+    when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
+            DATASOURCE_NAME, asyncQueryRequestContext))
.thenReturn( new DataSourceMetadata.Builder() .setName(DATASOURCE_NAME) diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index f9a83ef9f6..a7a79c758e 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -180,7 +180,8 @@ void testDispatchSelectQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -223,7 +224,8 @@ void testDispatchSelectQueryWithLakeFormation() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithLakeFormation(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -255,7 +257,8 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithBasicAuth(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -278,7 +281,8 @@ void testDispatchSelectQueryCreateNewSession() { doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any(), any()); when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -304,7 +308,8 @@ void testDispatchSelectQueryReuseSession() { when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID); when(session.isOperationalForDataSource(any())).thenReturn(true); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -324,7 +329,8 @@ void testDispatchSelectQueryFailedCreateSession() { doReturn(true).when(sessionManager).isEnabled(); doThrow(RuntimeException.class).when(sessionManager).createSession(any(), any()); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + 
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); Assertions.assertThrows( @@ -358,7 +364,8 @@ void testDispatchCreateAutoRefreshIndexQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -393,7 +400,8 @@ void testDispatchCreateManualRefreshIndexQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "my_glue", asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -426,7 +434,8 @@ void testDispatchWithPPLQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -450,7 +459,8 @@ void testDispatchWithSparkUDFQuery() { "CREATE TEMPORARY FUNCTION square AS 'org.apache.spark.sql.functions.expr(\"num * num\")'"); for (String query : udfQueries) { DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); IllegalArgumentException illegalArgumentException = @@ -489,7 +499,8 @@ void testInvalidSQLQueryDispatchToSpark() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -532,7 +543,8 @@ void testDispatchQueryWithoutATableAndDataSourceName() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -568,7 +580,8 @@ void testDispatchIndexQueryWithoutADatasourceName() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - 
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -589,8 +602,7 @@ void testDispatchMaterializedViewQuery() { tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.STREAMING.getText()); String query = - "CREATE MATERIALIZED VIEW mv_1 AS query=select * from my_glue.default.logs WITH" - + " (auto_refresh = true)"; + "CREATE MATERIALIZED VIEW mv_1 AS select * from logs WITH" + " (auto_refresh = true)"; String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming"); StartJobRequest expected = new StartJobRequest( @@ -604,7 +616,8 @@ void testDispatchMaterializedViewQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -637,7 +650,8 @@ void testDispatchShowMVQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -670,7 +684,8 @@ void testRefreshIndexQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -703,7 +718,8 @@ void testDispatchDescribeIndexQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -739,7 +755,8 @@ void testDispatchAlterToAutoRefreshIndexQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "my_glue", asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -762,7 +779,8 @@ void testDispatchAlterToManualRefreshIndexQuery() { "ALTER INDEX elb_and_requestUri ON my_glue.default.http_logs WITH" + " (auto_refresh = false)"; DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - 
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "my_glue", asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); when(queryHandlerFactory.getIndexDMLHandler()) .thenReturn( @@ -785,7 +803,8 @@ void testDispatchDropIndexQuery() { String query = "DROP INDEX elb_and_requestUri ON my_glue.default.http_logs"; DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "my_glue", asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); when(queryHandlerFactory.getIndexDMLHandler()) .thenReturn( @@ -808,7 +827,8 @@ void testDispatchVacuumIndexQuery() { String query = "VACUUM INDEX elb_and_requestUri ON my_glue.default.http_logs"; DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "my_glue", asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); when(queryHandlerFactory.getIndexDMLHandler()) .thenReturn( @@ -824,7 +844,8 @@ void testDispatchVacuumIndexQuery() { @Test void testDispatchWithUnSupportedDataSourceType() { - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_prometheus")) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "my_prometheus", asyncQueryRequestContext)) .thenReturn(constructPrometheusDataSourceType()); String query = "select * from my_prometheus.default.http_logs"; @@ -1018,7 +1039,8 @@ void testGetQueryResponseWithSuccess() { void testDispatchQueryWithExtraSparkSubmitParameters() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); String extraParameters = "--conf spark.dynamicAllocation.enabled=false"; diff --git a/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java b/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java index 6af5d19e5c..a8caa4719a 100644 --- a/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java +++ b/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java @@ -82,6 +82,9 @@ public interface DataSourceService { * Specifically for addressing use cases in SparkQueryDispatcher. * * @param dataSourceName of the {@link DataSource} + * @param context request context used by the implementation. It is passed by async-query-core. 
+ * refer {@link RequestContext} */ - DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(String dataSourceName); + DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata( + String dataSourceName, RequestContext context); } diff --git a/core/src/main/java/org/opensearch/sql/datasource/RequestContext.java b/core/src/main/java/org/opensearch/sql/datasource/RequestContext.java new file mode 100644 index 0000000000..199930d340 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/datasource/RequestContext.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.datasource; + +/** + * Context interface to provide additional request related information. It is introduced to allow + * async-query-core library user to pass request context information to implementations of data + * accessors. + */ +public interface RequestContext { + Object getAttribute(String name); +} diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java index b35cfbb5e1..0bf959a1b7 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java @@ -28,6 +28,7 @@ import org.opensearch.sql.config.TestConfig; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.RequestContext; import org.opensearch.sql.datasource.model.DataSource; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; @@ -236,7 +237,8 @@ public Boolean dataSourceExists(String dataSourceName) { } @Override - public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(String dataSourceName) { + public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata( + String dataSourceName, RequestContext requestContext) { return null; } } diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java b/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java index 61f3c8cd5d..81b6432891 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java @@ -11,6 +11,7 @@ import java.util.*; import java.util.stream.Collectors; import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.RequestContext; import org.opensearch.sql.datasource.model.DataSource; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceStatus; @@ -122,7 +123,8 @@ public Boolean dataSourceExists(String dataSourceName) { } @Override - public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(String dataSourceName) { + public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata( + String dataSourceName, RequestContext requestContext) { DataSourceMetadata dataSourceMetadata = getRawDataSourceMetadata(dataSourceName); verifyDataSourceAccess(dataSourceMetadata); return dataSourceMetadata; diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java index 5a94945e5b..9a1022706f 100644 --- 
a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java +++ b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java @@ -36,6 +36,7 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.RequestContext; import org.opensearch.sql.datasource.model.DataSource; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceStatus; @@ -52,6 +53,7 @@ class DataSourceServiceImplTest { @Mock private DataSourceFactory dataSourceFactory; @Mock private StorageEngine storageEngine; @Mock private DataSourceMetadataStorage dataSourceMetadataStorage; + @Mock private RequestContext requestContext; @Mock private DataSourceUserAuthorizationHelper dataSourceUserAuthorizationHelper; @@ -461,7 +463,9 @@ void testVerifyDataSourceAccessAndGetRawDataSourceMetadataWithDisabledData() { DatasourceDisabledException datasourceDisabledException = Assertions.assertThrows( DatasourceDisabledException.class, - () -> dataSourceService.verifyDataSourceAccessAndGetRawMetadata("testDS")); + () -> + dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "testDS", requestContext)); Assertions.assertEquals( "Datasource testDS is disabled.", datasourceDisabledException.getMessage()); } @@ -484,7 +488,7 @@ void testVerifyDataSourceAccessAndGetRawDataSourceMetadata() { when(dataSourceMetadataStorage.getDataSourceMetadata("testDS")) .thenReturn(Optional.of(dataSourceMetadata)); DataSourceMetadata dataSourceMetadata1 = - dataSourceService.verifyDataSourceAccessAndGetRawMetadata("testDS"); + dataSourceService.verifyDataSourceAccessAndGetRawMetadata("testDS", requestContext); assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.uri")); assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.type")); assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.username")); From 103c4160ae2a129284ff5be70bac40951e0c6a18 Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Tue, 30 Jul 2024 15:20:19 -0700 Subject: [PATCH 04/10] Fixed 2.16 integ test failures (#2871) Signed-off-by: Vamsi Manohar --- .../opensearch/sql/datasource/DataSourceEnabledIT.java | 10 ++++++++++ .../sql/legacy/OpenSearchSQLRestTestCase.java | 4 +++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceEnabledIT.java b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceEnabledIT.java index 9c522134a4..b0bc87a0c6 100644 --- a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceEnabledIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceEnabledIT.java @@ -38,6 +38,7 @@ public void testDataSourceCreationWithDefaultSettings() { assertDataSourceCount(1); assertSelectFromDataSourceReturnsSuccess(); assertSelectFromDummyIndexInValidDataSourceDataSourceReturnsDoesNotExist(); + deleteSelfDataSourceCreated(); } @Test @@ -52,6 +53,8 @@ public void testAfterPreviousEnable() { assertDataSourceCount(0); assertSelectFromDataSourceReturnsDoesNotExist(); assertAsyncQueryApiDisabled(); + setDataSourcesEnabled("transient", true); + deleteSelfDataSourceCreated(); } @SneakyThrows @@ -142,4 +145,11 @@ private Response performRequest(Request request) { return e.getResponse(); } } + + @SneakyThrows + private void deleteSelfDataSourceCreated() { + Request 
deleteRequest = getDeleteDataSourceRequest("self");
+    Response deleteResponse = client().performRequest(deleteRequest);
+    Assert.assertEquals(204, deleteResponse.getStatusLine().getStatusCode());
+  }
 }

diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/OpenSearchSQLRestTestCase.java b/integ-test/src/test/java/org/opensearch/sql/legacy/OpenSearchSQLRestTestCase.java
index d73e3468d4..ced69d54a0 100644
--- a/integ-test/src/test/java/org/opensearch/sql/legacy/OpenSearchSQLRestTestCase.java
+++ b/integ-test/src/test/java/org/opensearch/sql/legacy/OpenSearchSQLRestTestCase.java
@@ -195,7 +195,9 @@ protected static void wipeAllOpenSearchIndices(RestClient client) throws IOExcep
       try {
         // System index, mostly named .opensearch-xxx or .opendistro-xxx, are not allowed to
         // delete
-        if (!indexName.startsWith(".opensearch") && !indexName.startsWith(".opendistro")) {
+        if (!indexName.startsWith(".opensearch")
+            && !indexName.startsWith(".opendistro")
+            && !indexName.startsWith(".ql")) {
           client.performRequest(new Request("DELETE", "/" + indexName));
         }
       } catch (Exception e) {

From aa7a6902a8d03647eecf45564c488c936c53ee3f Mon Sep 17 00:00:00 2001
From: Lantao Jin
Date: Wed, 31 Jul 2024 17:05:28 +0800
Subject: [PATCH 05/10] Change the default value of plugins.query.size_limit to
 MAX_RESULT_WINDOW (10000) (#2860)

* Change the default value of plugins.query.size_limit to MAX_RESULT_WINDOW
 (10000)

Signed-off-by: Lantao Jin

* fix ut

Signed-off-by: Lantao Jin

* fix spotless

Signed-off-by: Lantao Jin

---------

Signed-off-by: Lantao Jin
---
 docs/user/admin/settings.rst                 |  2 +-
 docs/user/optimization/optimization.rst      | 22 +++++++++----------
 docs/user/ppl/admin/settings.rst             |  4 ++--
 docs/user/ppl/interfaces/endpoint.rst        |  2 +-
 .../org/opensearch/sql/legacy/ExplainIT.java |  2 +-
 .../setting/OpenSearchSettings.java          |  3 ++-
 .../setting/OpenSearchSettingsTest.java      |  6 ++---
 7 files changed, 21 insertions(+), 20 deletions(-)

diff --git a/docs/user/admin/settings.rst b/docs/user/admin/settings.rst
index 662d882745..6b24e41f87 100644
--- a/docs/user/admin/settings.rst
+++ b/docs/user/admin/settings.rst
@@ -202,7 +202,7 @@ plugins.query.size_limit
 Description
 -----------

-The new engine fetches a default size of index from OpenSearch set by this setting, the default value is 200. You can change the value to any value not greater than the max result window value in index level (10000 by default), here is an example::
+The new engine fetches a default size of index from OpenSearch set by this setting; the default value equals the index-level max result window (10000 by default).
You can change the value to any value not greater than the max result window value in index level (`index.max_result_window`), here is an example:: >> curl -H 'Content-Type: application/json' -X PUT localhost:9200/_plugins/_query/settings -d '{ "transient" : { diff --git a/docs/user/optimization/optimization.rst b/docs/user/optimization/optimization.rst index 8ab998309d..835fe96eba 100644 --- a/docs/user/optimization/optimization.rst +++ b/docs/user/optimization/optimization.rst @@ -44,7 +44,7 @@ The consecutive Filter operator will be merged as one Filter operator:: { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":200,\"timeout\":\"1m\",\"query\":{\"bool\":{\"filter\":[{\"range\":{\"age\":{\"from\":null,\"to\":20,\"include_lower\":true,\"include_upper\":false,\"boost\":1.0}}},{\"range\":{\"age\":{\"from\":10,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}}],\"adjust_pure_negative\":true,\"boost\":1.0}},\"_source\":{\"includes\":[\"age\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}]}, searchDone=false)" + "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"bool\":{\"filter\":[{\"range\":{\"age\":{\"from\":null,\"to\":20,\"include_lower\":true,\"include_upper\":false,\"boost\":1.0}}},{\"range\":{\"age\":{\"from\":10,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}}],\"adjust_pure_negative\":true,\"boost\":1.0}},\"_source\":{\"includes\":[\"age\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}]}, searchDone=false)" }, "children": [] } @@ -71,7 +71,7 @@ The Filter operator should be push down under Sort operator:: { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":200,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":null,\"to\":20,\"include_lower\":true,\"include_upper\":false,\"boost\":1.0}}},\"_source\":{\"includes\":[\"age\"],\"excludes\":[]},\"sort\":[{\"age\":{\"order\":\"asc\",\"missing\":\"_first\"}}]}, searchDone=false)" + "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":null,\"to\":20,\"include_lower\":true,\"include_upper\":false,\"boost\":1.0}}},\"_source\":{\"includes\":[\"age\"],\"excludes\":[]},\"sort\":[{\"age\":{\"order\":\"asc\",\"missing\":\"_first\"}}]}, searchDone=false)" }, "children": [] } @@ -102,7 +102,7 @@ The Project list will push down to Query DSL to `filter the source `_. +Without sort push down optimization, the sort operator will sort the result from child operator. By default, only 10000 docs will extracted from the source index, `you can change this value by using size_limit setting <../admin/settings.rst#opensearch-query-size-limit>`_. diff --git a/docs/user/ppl/admin/settings.rst b/docs/user/ppl/admin/settings.rst index ad56408693..28e6897d3d 100644 --- a/docs/user/ppl/admin/settings.rst +++ b/docs/user/ppl/admin/settings.rst @@ -125,9 +125,9 @@ plugins.query.size_limit Description ----------- -The size configure the maximum amount of documents to be pull from OpenSearch. The default value is: 200 +The size configure the maximum amount of documents to be pull from OpenSearch. 
The default value is: 10000 -Notes: This setting will impact the correctness of the aggregation operation, for example, there are 1000 docs in the index, by default, only 200 docs will be extract from index and do aggregation. +Notes: This setting will impact the correctness of the aggregation operation, for example, there are 1000 docs in the index, if you change the value to 200, only 200 docs will be extract from index and do aggregation. Example ------- diff --git a/docs/user/ppl/interfaces/endpoint.rst b/docs/user/ppl/interfaces/endpoint.rst index 793b94eb8d..fb931fb0ba 100644 --- a/docs/user/ppl/interfaces/endpoint.rst +++ b/docs/user/ppl/interfaces/endpoint.rst @@ -91,7 +91,7 @@ The following PPL query demonstrated that where and stats command were pushed do { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":200,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":10,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"avg(age)\":{\"avg\":{\"field\":\"age\"}}}}, searchDone=false)" + "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":10,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"avg(age)\":{\"avg\":{\"field\":\"age\"}}}}, searchDone=false)" }, "children": [] } diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/ExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/legacy/ExplainIT.java index b42e9f84f4..27f8eca3ef 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/ExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/ExplainIT.java @@ -185,7 +185,7 @@ public void orderByOnNestedFieldTest() throws Exception { Assert.assertThat( result.replaceAll("\\s+", ""), equalTo( - "{\"from\":0,\"size\":200,\"sort\":[{\"message.info\":" + "{\"from\":0,\"size\":10000,\"sort\":[{\"message.info\":" + "{\"order\":\"asc\",\"nested\":{\"path\":\"message\"}}}]}")); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java index b4ce82a828..475a584623 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java @@ -28,6 +28,7 @@ import org.opensearch.common.settings.Setting; import org.opensearch.common.unit.MemorySizeValue; import org.opensearch.common.unit.TimeValue; +import org.opensearch.index.IndexSettings; import org.opensearch.sql.common.setting.LegacySettings; import org.opensearch.sql.common.setting.Settings; @@ -90,7 +91,7 @@ public class OpenSearchSettings extends Settings { public static final Setting QUERY_SIZE_LIMIT_SETTING = Setting.intSetting( Key.QUERY_SIZE_LIMIT.getKeyValue(), - LegacyOpenDistroSettings.QUERY_SIZE_LIMIT_SETTING, + IndexSettings.MAX_RESULT_WINDOW_SETTING, 0, Setting.Property.NodeScope, Setting.Property.Dynamic); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/setting/OpenSearchSettingsTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/setting/OpenSearchSettingsTest.java index e99e5b360a..84fb705ae0 100644 --- 
a/opensearch/src/test/java/org/opensearch/sql/opensearch/setting/OpenSearchSettingsTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/setting/OpenSearchSettingsTest.java @@ -34,6 +34,7 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.core.common.unit.ByteSizeValue; +import org.opensearch.index.IndexSettings; import org.opensearch.monitor.jvm.JvmInfo; import org.opensearch.sql.common.setting.LegacySettings; import org.opensearch.sql.common.setting.Settings; @@ -132,8 +133,7 @@ void settingsFallback() { org.opensearch.common.settings.Settings.EMPTY)); assertEquals( settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT), - LegacyOpenDistroSettings.QUERY_SIZE_LIMIT_SETTING.get( - org.opensearch.common.settings.Settings.EMPTY)); + IndexSettings.MAX_RESULT_WINDOW_SETTING.get(org.opensearch.common.settings.Settings.EMPTY)); assertEquals( settings.getSettingValue(Settings.Key.METRICS_ROLLING_WINDOW), LegacyOpenDistroSettings.METRICS_ROLLING_WINDOW_SETTING.get( @@ -165,7 +165,7 @@ public void updateLegacySettingsFallback() { assertEquals( QUERY_MEMORY_LIMIT_SETTING.get(settings), new ByteSizeValue((int) (JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() * 0.2))); - assertEquals(QUERY_SIZE_LIMIT_SETTING.get(settings), 100); + assertEquals(QUERY_SIZE_LIMIT_SETTING.get(settings), 10000); assertEquals(METRICS_ROLLING_WINDOW_SETTING.get(settings), 2000L); assertEquals(METRICS_ROLLING_INTERVAL_SETTING.get(settings), 100L); } From 53bfeba8ffa0a79027c06fbb6157fa740333d5df Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Wed, 31 Jul 2024 15:08:35 -0700 Subject: [PATCH 06/10] Add AsyncQueryRequestContext to QueryIdProvider parameter (#2870) Signed-off-by: Tomoyuki Morita --- .../DatasourceEmbeddedQueryIdProvider.java | 5 ++- .../sql/spark/dispatcher/QueryIdProvider.java | 4 ++- .../dispatcher/SparkQueryDispatcher.java | 12 ++++--- .../asyncquery/AsyncQueryCoreIntegTest.java | 17 ++++----- ...DatasourceEmbeddedQueryIdProviderTest.java | 35 +++++++++++++++++++ 5 files changed, 59 insertions(+), 14 deletions(-) create mode 100644 async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProviderTest.java diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProvider.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProvider.java index c170040718..3564fa9552 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProvider.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProvider.java @@ -5,6 +5,7 @@ package org.opensearch.sql.spark.dispatcher; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.utils.IDUtils; @@ -12,7 +13,9 @@ public class DatasourceEmbeddedQueryIdProvider implements QueryIdProvider { @Override - public String getQueryId(DispatchQueryRequest dispatchQueryRequest) { + public String getQueryId( + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext asyncQueryRequestContext) { return IDUtils.encode(dispatchQueryRequest.getDatasource()); } } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java 
b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java index 2167eb6b7a..a108ca1209 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java @@ -5,9 +5,11 @@ package org.opensearch.sql.spark.dispatcher; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; /** Interface for extension point to specify queryId. Called when new query is executed. */ public interface QueryIdProvider { - String getQueryId(DispatchQueryRequest dispatchQueryRequest); + String getQueryId( + DispatchQueryRequest dispatchQueryRequest, AsyncQueryRequestContext asyncQueryRequestContext); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 0061ea7179..a424db4c34 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -69,7 +69,8 @@ private DispatchQueryResponse handleFlintExtensionQuery( DataSourceMetadata dataSourceMetadata) { IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest); DispatchQueryContext context = - getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) + getDefaultDispatchContextBuilder( + dispatchQueryRequest, dataSourceMetadata, asyncQueryRequestContext) .indexQueryDetails(indexQueryDetails) .asyncQueryRequestContext(asyncQueryRequestContext) .build(); @@ -84,7 +85,8 @@ private DispatchQueryResponse handleDefaultQuery( DataSourceMetadata dataSourceMetadata) { DispatchQueryContext context = - getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) + getDefaultDispatchContextBuilder( + dispatchQueryRequest, dataSourceMetadata, asyncQueryRequestContext) .asyncQueryRequestContext(asyncQueryRequestContext) .build(); @@ -93,11 +95,13 @@ private DispatchQueryResponse handleDefaultQuery( } private DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder( - DispatchQueryRequest dispatchQueryRequest, DataSourceMetadata dataSourceMetadata) { + DispatchQueryRequest dispatchQueryRequest, + DataSourceMetadata dataSourceMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { return DispatchQueryContext.builder() .dataSourceMetadata(dataSourceMetadata) .tags(getDefaultTagsForJobSubmission(dispatchQueryRequest)) - .queryId(queryIdProvider.getQueryId(dispatchQueryRequest)); + .queryId(queryIdProvider.getQueryId(dispatchQueryRequest, asyncQueryRequestContext)); } private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery( diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index 34ededc74d..d82d3bdab7 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -185,7 +185,7 @@ public void setUp() { public void createDropIndexQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); - 
when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); String indexName = "flint_datasource_name_table_name_index_name_index"; givenFlintIndexMetadataExists(indexName); givenCancelJobRunSucceed(); @@ -209,7 +209,7 @@ public void createDropIndexQuery() { public void createVacuumIndexQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); - when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); String indexName = "flint_datasource_name_table_name_index_name_index"; givenFlintIndexMetadataExists(indexName); @@ -231,7 +231,7 @@ public void createVacuumIndexQuery() { public void createAlterIndexQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); - when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); String indexName = "flint_datasource_name_table_name_index_name_index"; givenFlintIndexMetadataExists(indexName); givenCancelJobRunSucceed(); @@ -261,7 +261,7 @@ public void createAlterIndexQuery() { public void createStreamingQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); - when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); when(awsemrServerless.startJobRun(any())) .thenReturn(new StartJobRunResult().withApplicationId(APPLICATION_ID).withJobRunId(JOB_ID)); @@ -297,7 +297,7 @@ private void verifyStartJobRunCalled() { public void createCreateIndexQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); - when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); when(awsemrServerless.startJobRun(any())) .thenReturn(new StartJobRunResult().withApplicationId(APPLICATION_ID).withJobRunId(JOB_ID)); @@ -321,7 +321,7 @@ public void createCreateIndexQuery() { public void createRefreshQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); - when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); when(awsemrServerless.startJobRun(any())) .thenReturn(new StartJobRunResult().withApplicationId(APPLICATION_ID).withJobRunId(JOB_ID)); @@ -344,7 +344,7 @@ public void createInteractiveQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); givenSessionExists(); - when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); when(sessionIdProvider.getSessionId(any())).thenReturn(SESSION_ID); givenSessionExists(); // called twice when(awsemrServerless.startJobRun(any())) @@ -538,7 +538,8 @@ private void givenGetJobRunReturnJobRunWithState(String state) { } private void verifyGetQueryIdCalled() { - verify(queryIdProvider).getQueryId(dispatchQueryRequestArgumentCaptor.capture()); + verify(queryIdProvider) + .getQueryId(dispatchQueryRequestArgumentCaptor.capture(), eq(asyncQueryRequestContext)); DispatchQueryRequest dispatchQueryRequest = dispatchQueryRequestArgumentCaptor.getValue(); assertEquals(ACCOUNT_ID, 
dispatchQueryRequest.getAccountId()); assertEquals(APPLICATION_ID, dispatchQueryRequest.getApplicationId()); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProviderTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProviderTest.java new file mode 100644 index 0000000000..7f1c92dff3 --- /dev/null +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProviderTest.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.mockito.Mockito.verifyNoInteractions; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; + +@ExtendWith(MockitoExtension.class) +class DatasourceEmbeddedQueryIdProviderTest { + @Mock AsyncQueryRequestContext asyncQueryRequestContext; + + DatasourceEmbeddedQueryIdProvider datasourceEmbeddedQueryIdProvider = + new DatasourceEmbeddedQueryIdProvider(); + + @Test + public void test() { + String queryId = + datasourceEmbeddedQueryIdProvider.getQueryId( + DispatchQueryRequest.builder().datasource("DATASOURCE").build(), + asyncQueryRequestContext); + + assertNotNull(queryId); + verifyNoInteractions(asyncQueryRequestContext); + } +} From 1b17520d79ee6e90e3994298e3998adc263b02b0 Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Wed, 31 Jul 2024 15:59:37 -0700 Subject: [PATCH 07/10] Fixed integ test delete myindex issue and wipe All indices with security enabled domain (#2878) Signed-off-by: Vamsi Manohar --- .../sql/datasource/DataSourceEnabledIT.java | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceEnabledIT.java b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceEnabledIT.java index b0bc87a0c6..a53c04d871 100644 --- a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceEnabledIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceEnabledIT.java @@ -7,8 +7,10 @@ import static org.opensearch.sql.legacy.TestUtils.getResponseBody; +import java.io.IOException; import lombok.SneakyThrows; import org.json.JSONObject; +import org.junit.After; import org.junit.Assert; import org.junit.Test; import org.opensearch.client.Request; @@ -18,9 +20,9 @@ public class DataSourceEnabledIT extends PPLIntegTestCase { - @Override - protected boolean preserveClusterUponCompletion() { - return false; + @After + public void cleanUp() throws IOException { + wipeAllClusterSettings(); } @Test @@ -39,6 +41,7 @@ public void testDataSourceCreationWithDefaultSettings() { assertSelectFromDataSourceReturnsSuccess(); assertSelectFromDummyIndexInValidDataSourceDataSourceReturnsDoesNotExist(); deleteSelfDataSourceCreated(); + deleteIndex(); } @Test @@ -55,6 +58,7 @@ public void testAfterPreviousEnable() { assertAsyncQueryApiDisabled(); setDataSourcesEnabled("transient", true); deleteSelfDataSourceCreated(); + deleteIndex(); } @SneakyThrows @@ -98,6 +102,12 @@ private void createIndex() { Assert.assertEquals(200, response.getStatusLine().getStatusCode()); } + private void 
deleteIndex() { + Request request = new Request("DELETE", "/myindex"); + Response response = performRequest(request); + Assert.assertEquals(200, response.getStatusLine().getStatusCode()); + } + private void createOpenSearchDataSource() { Request request = new Request("POST", "/_plugins/_query/_datasources"); request.setJsonEntity( From 3daf64fbce5a8d29e846669689b3a7b12c5c7f07 Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Wed, 31 Jul 2024 16:10:13 -0700 Subject: [PATCH 08/10] [Feature] Flint query scheduler part1 - integrate job scheduler plugin (#2834) * [Feature] Flint query scheduler part1 - integrate job scheduler plugin Signed-off-by: Louis Chu * Add comments Signed-off-by: Louis Chu * Add unit test Signed-off-by: Louis Chu * Remove test rest API Signed-off-by: Louis Chu * Fix doc test Signed-off-by: Louis Chu * Add more tests Signed-off-by: Louis Chu * Fix IT Signed-off-by: Louis Chu * Fix IT with security Signed-off-by: Louis Chu * Improve test coverage Signed-off-by: Louis Chu * Fix integTest cluster Signed-off-by: Louis Chu * Fix UT Signed-off-by: Louis Chu * Update UT Signed-off-by: Louis Chu * Fix bwc test Signed-off-by: Louis Chu * Resolve comments Signed-off-by: Louis Chu * Fix bwc test Signed-off-by: Louis Chu * clean up doc test Signed-off-by: Louis Chu * Resolve comments Signed-off-by: Louis Chu * Fix UT Signed-off-by: Louis Chu --------- Signed-off-by: Louis Chu --- .gitignore | 1 + .../src/main/antlr/SqlBaseParser.g4 | 17 +- async-query/build.gradle | 3 + .../OpenSearchAsyncQueryScheduler.java | 197 ++++++++ ...penSearchRefreshIndexJobRequestParser.java | 71 +++ .../job/OpenSearchRefreshIndexJob.java | 93 ++++ .../OpenSearchRefreshIndexJobRequest.java | 108 +++++ .../async-query-scheduler-index-mapping.yml | 41 ++ .../async-query-scheduler-index-settings.yml | 11 + .../OpenSearchAsyncQuerySchedulerTest.java | 434 ++++++++++++++++++ .../job/OpenSearchRefreshIndexJobTest.java | 145 ++++++ .../OpenSearchRefreshIndexJobRequestTest.java | 81 ++++ build.gradle | 3 +- common/build.gradle | 4 +- core/build.gradle | 2 +- doctest/build.gradle | 53 +++ integ-test/build.gradle | 65 ++- legacy/build.gradle | 2 +- plugin/build.gradle | 11 +- .../org/opensearch/sql/plugin/SQLPlugin.java | 32 +- ...rch.jobscheduler.spi.JobSchedulerExtension | 6 + ppl/build.gradle | 2 +- protocol/build.gradle | 2 +- sql/build.gradle | 2 +- 24 files changed, 1357 insertions(+), 29 deletions(-) create mode 100644 async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQueryScheduler.java create mode 100644 async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchRefreshIndexJobRequestParser.java create mode 100644 async-query/src/main/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJob.java create mode 100644 async-query/src/main/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequest.java create mode 100644 async-query/src/main/resources/async-query-scheduler-index-mapping.yml create mode 100644 async-query/src/main/resources/async-query-scheduler-index-settings.yml create mode 100644 async-query/src/test/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQuerySchedulerTest.java create mode 100644 async-query/src/test/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJobTest.java create mode 100644 async-query/src/test/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequestTest.java create mode 100644 
plugin/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension diff --git a/.gitignore b/.gitignore index 1b892036dd..b9775dea04 100644 --- a/.gitignore +++ b/.gitignore @@ -49,4 +49,5 @@ gen .worktrees http-client.env.json /doctest/sql-cli/ +/doctest/opensearch-job-scheduler/ .factorypath diff --git a/async-query-core/src/main/antlr/SqlBaseParser.g4 b/async-query-core/src/main/antlr/SqlBaseParser.g4 index a50051715e..c7aa56cf92 100644 --- a/async-query-core/src/main/antlr/SqlBaseParser.g4 +++ b/async-query-core/src/main/antlr/SqlBaseParser.g4 @@ -66,8 +66,8 @@ compoundStatement ; setStatementWithOptionalVarKeyword - : SET (VARIABLE | VAR)? assignmentList #setVariableWithOptionalKeyword - | SET (VARIABLE | VAR)? LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ + : SET variable? assignmentList #setVariableWithOptionalKeyword + | SET variable? LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ LEFT_PAREN query RIGHT_PAREN #setVariableWithOptionalKeyword ; @@ -215,9 +215,9 @@ statement routineCharacteristics RETURN (query | expression) #createUserDefinedFunction | DROP TEMPORARY? FUNCTION (IF EXISTS)? identifierReference #dropFunction - | DECLARE (OR REPLACE)? VARIABLE? + | DECLARE (OR REPLACE)? variable? identifierReference dataType? variableDefaultExpression? #createVariable - | DROP TEMPORARY VARIABLE (IF EXISTS)? identifierReference #dropVariable + | DROP TEMPORARY variable (IF EXISTS)? identifierReference #dropVariable | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)? (statement|setResetStatement) #explain | SHOW TABLES ((FROM | IN) identifierReference)? @@ -272,8 +272,8 @@ setResetStatement | SET TIME ZONE interval #setTimeZone | SET TIME ZONE timezone #setTimeZone | SET TIME ZONE .*? #setTimeZone - | SET (VARIABLE | VAR) assignmentList #setVariable - | SET (VARIABLE | VAR) LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ + | SET variable assignmentList #setVariable + | SET variable LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ LEFT_PAREN query RIGHT_PAREN #setVariable | SET configKey EQ configValue #setQuotedConfiguration | SET configKey (EQ .*?)? 
#setConfiguration @@ -438,6 +438,11 @@ namespaces | SCHEMAS ; +variable + : VARIABLE + | VAR + ; + describeFuncName : identifierReference | stringLit diff --git a/async-query/build.gradle b/async-query/build.gradle index 5a4a0d729d..abda6161d3 100644 --- a/async-query/build.gradle +++ b/async-query/build.gradle @@ -16,6 +16,8 @@ repositories { dependencies { + implementation "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}" + api project(':core') api project(':async-query-core') implementation project(':protocol') @@ -97,6 +99,7 @@ jacocoTestCoverageVerification { // ignore because XContext IOException 'org.opensearch.sql.spark.execution.statestore.StateStore', 'org.opensearch.sql.spark.rest.*', + 'org.opensearch.sql.spark.scheduler.OpenSearchRefreshIndexJobRequestParser', 'org.opensearch.sql.spark.transport.model.*' ] limit { diff --git a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQueryScheduler.java b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQueryScheduler.java new file mode 100644 index 0000000000..c7a66fc6be --- /dev/null +++ b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQueryScheduler.java @@ -0,0 +1,197 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.scheduler; + +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import com.google.common.annotations.VisibleForTesting; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import org.apache.commons.io.IOUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.index.engine.DocumentMissingException; +import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.jobscheduler.spi.ScheduledJobRunner; +import org.opensearch.sql.spark.scheduler.job.OpenSearchRefreshIndexJob; +import org.opensearch.sql.spark.scheduler.model.OpenSearchRefreshIndexJobRequest; +import org.opensearch.threadpool.ThreadPool; + +/** Scheduler class for managing asynchronous query jobs. 
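+ *
+ * <p>Illustrative usage once the plugin has called {@code loadJobResource} (the job id below is
+ * a made-up example):
+ *
+ * <pre>{@code
+ * OpenSearchAsyncQueryScheduler scheduler = new OpenSearchAsyncQueryScheduler();
+ * scheduler.loadJobResource(client, clusterService, threadPool);
+ * scheduler.scheduleJob(request);   // index a new job document
+ * scheduler.unscheduleJob("myJob"); // mark the job as disabled
+ * scheduler.removeJob("myJob");     // delete the job document
+ * }</pre>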
*/ +public class OpenSearchAsyncQueryScheduler { + public static final String SCHEDULER_INDEX_NAME = ".async-query-scheduler"; + public static final String SCHEDULER_PLUGIN_JOB_TYPE = "async-query-scheduler"; + private static final String SCHEDULER_INDEX_MAPPING_FILE_NAME = + "async-query-scheduler-index-mapping.yml"; + private static final String SCHEDULER_INDEX_SETTINGS_FILE_NAME = + "async-query-scheduler-index-settings.yml"; + private static final Logger LOG = LogManager.getLogger(); + + private Client client; + private ClusterService clusterService; + + /** Loads job resources, setting up required services and job runner instance. */ + public void loadJobResource(Client client, ClusterService clusterService, ThreadPool threadPool) { + this.client = client; + this.clusterService = clusterService; + OpenSearchRefreshIndexJob openSearchRefreshIndexJob = + OpenSearchRefreshIndexJob.getJobRunnerInstance(); + openSearchRefreshIndexJob.setClusterService(clusterService); + openSearchRefreshIndexJob.setThreadPool(threadPool); + openSearchRefreshIndexJob.setClient(client); + } + + /** Schedules a new job by indexing it into the job index. */ + public void scheduleJob(OpenSearchRefreshIndexJobRequest request) { + if (!this.clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)) { + createAsyncQuerySchedulerIndex(); + } + IndexRequest indexRequest = new IndexRequest(SCHEDULER_INDEX_NAME); + indexRequest.id(request.getName()); + indexRequest.opType(DocWriteRequest.OpType.CREATE); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + IndexResponse indexResponse; + try { + indexRequest.source(request.toXContent(JsonXContent.contentBuilder(), EMPTY_PARAMS)); + ActionFuture indexResponseActionFuture = client.index(indexRequest); + indexResponse = indexResponseActionFuture.actionGet(); + } catch (VersionConflictEngineException exception) { + throw new IllegalArgumentException("A job already exists with name: " + request.getName()); + } catch (Throwable e) { + LOG.error("Failed to schedule job : {}", request.getName(), e); + throw new RuntimeException(e); + } + + if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { + LOG.debug("Job : {} successfully created", request.getName()); + } else { + throw new RuntimeException( + "Schedule job failed with result : " + indexResponse.getResult().getLowercase()); + } + } + + /** Unschedules a job by marking it as disabled and updating its last update time. */ + public void unscheduleJob(String jobId) throws IOException { + assertIndexExists(); + OpenSearchRefreshIndexJobRequest request = + OpenSearchRefreshIndexJobRequest.builder() + .jobName(jobId) + .enabled(false) + .lastUpdateTime(Instant.now()) + .build(); + updateJob(request); + } + + /** Updates an existing job with new parameters. 
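+ * Because the update is applied as a partial document update, fields omitted from the request
+ * keep their stored values. An illustrative call (the job name is hypothetical):
+ *
+ * <pre>{@code
+ * scheduler.updateJob(
+ *     OpenSearchRefreshIndexJobRequest.builder()
+ *         .jobName("myJob")
+ *         .enabled(true)
+ *         .lastUpdateTime(Instant.now())
+ *         .build());
+ * }</pre>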
*/ + public void updateJob(OpenSearchRefreshIndexJobRequest request) throws IOException { + assertIndexExists(); + UpdateRequest updateRequest = new UpdateRequest(SCHEDULER_INDEX_NAME, request.getName()); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + updateRequest.doc(request.toXContent(JsonXContent.contentBuilder(), EMPTY_PARAMS)); + UpdateResponse updateResponse; + try { + ActionFuture updateResponseActionFuture = client.update(updateRequest); + updateResponse = updateResponseActionFuture.actionGet(); + } catch (DocumentMissingException exception) { + throw new IllegalArgumentException("Job: " + request.getName() + " doesn't exist"); + } catch (Throwable e) { + LOG.error("Failed to update job : {}", request.getName(), e); + throw new RuntimeException(e); + } + + if (updateResponse.getResult().equals(DocWriteResponse.Result.UPDATED) + || updateResponse.getResult().equals(DocWriteResponse.Result.NOOP)) { + LOG.debug("Job : {} successfully updated", request.getName()); + } else { + throw new RuntimeException( + "Update job failed with result : " + updateResponse.getResult().getLowercase()); + } + } + + /** Removes a job by deleting its document from the index. */ + public void removeJob(String jobId) { + assertIndexExists(); + DeleteRequest deleteRequest = new DeleteRequest(SCHEDULER_INDEX_NAME, jobId); + deleteRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + ActionFuture deleteResponseActionFuture = client.delete(deleteRequest); + DeleteResponse deleteResponse = deleteResponseActionFuture.actionGet(); + + if (deleteResponse.getResult().equals(DocWriteResponse.Result.DELETED)) { + LOG.debug("Job : {} successfully deleted", jobId); + } else if (deleteResponse.getResult().equals(DocWriteResponse.Result.NOT_FOUND)) { + throw new IllegalArgumentException("Job : " + jobId + " doesn't exist"); + } else { + throw new RuntimeException( + "Remove job failed with result : " + deleteResponse.getResult().getLowercase()); + } + } + + /** Creates the async query scheduler index with specified mappings and settings. 
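+ * The mapping and settings are loaded from {@code async-query-scheduler-index-mapping.yml} and
+ * {@code async-query-scheduler-index-settings.yml} on the classpath, so the index layout can be
+ * changed by editing those resources rather than this class.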
*/ + @VisibleForTesting + void createAsyncQuerySchedulerIndex() { + try { + InputStream mappingFileStream = + OpenSearchAsyncQueryScheduler.class + .getClassLoader() + .getResourceAsStream(SCHEDULER_INDEX_MAPPING_FILE_NAME); + InputStream settingsFileStream = + OpenSearchAsyncQueryScheduler.class + .getClassLoader() + .getResourceAsStream(SCHEDULER_INDEX_SETTINGS_FILE_NAME); + CreateIndexRequest createIndexRequest = new CreateIndexRequest(SCHEDULER_INDEX_NAME); + createIndexRequest.mapping( + IOUtils.toString(mappingFileStream, StandardCharsets.UTF_8), XContentType.YAML); + createIndexRequest.settings( + IOUtils.toString(settingsFileStream, StandardCharsets.UTF_8), XContentType.YAML); + ActionFuture createIndexResponseActionFuture = + client.admin().indices().create(createIndexRequest); + CreateIndexResponse createIndexResponse = createIndexResponseActionFuture.actionGet(); + + if (createIndexResponse.isAcknowledged()) { + LOG.debug("Index: {} creation Acknowledged", SCHEDULER_INDEX_NAME); + } else { + throw new RuntimeException("Index creation is not acknowledged."); + } + } catch (Throwable e) { + LOG.error("Error creating index: {}", SCHEDULER_INDEX_NAME, e); + throw new RuntimeException( + "Internal server error while creating " + + SCHEDULER_INDEX_NAME + + " index: " + + e.getMessage(), + e); + } + } + + private void assertIndexExists() { + if (!this.clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)) { + throw new IllegalStateException("Job index does not exist."); + } + } + + /** Returns the job runner instance for the scheduler. */ + public static ScheduledJobRunner getJobRunner() { + return OpenSearchRefreshIndexJob.getJobRunnerInstance(); + } +} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchRefreshIndexJobRequestParser.java b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchRefreshIndexJobRequestParser.java new file mode 100644 index 0000000000..0422e7c015 --- /dev/null +++ b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchRefreshIndexJobRequestParser.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.scheduler; + +import java.io.IOException; +import java.time.Instant; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParserUtils; +import org.opensearch.jobscheduler.spi.ScheduledJobParser; +import org.opensearch.jobscheduler.spi.schedule.ScheduleParser; +import org.opensearch.sql.spark.scheduler.model.OpenSearchRefreshIndexJobRequest; + +public class OpenSearchRefreshIndexJobRequestParser { + + private static Instant parseInstantValue(XContentParser parser) throws IOException { + if (XContentParser.Token.VALUE_NULL.equals(parser.currentToken())) { + return null; + } + if (parser.currentToken().isValue()) { + return Instant.ofEpochMilli(parser.longValue()); + } + XContentParserUtils.throwUnknownToken(parser.currentToken(), parser.getTokenLocation()); + return null; + } + + public static ScheduledJobParser getJobParser() { + return (parser, id, jobDocVersion) -> { + OpenSearchRefreshIndexJobRequest.OpenSearchRefreshIndexJobRequestBuilder builder = + OpenSearchRefreshIndexJobRequest.builder(); + XContentParserUtils.ensureExpectedToken( + XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + + while (!parser.nextToken().equals(XContentParser.Token.END_OBJECT)) { + String fieldName = parser.currentName(); + parser.nextToken(); + 
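+          // Each recognized field maps to a builder setter; any other field falls through to
+          // throwUnknownToken below, so malformed job documents fail fast at parse time.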
switch (fieldName) { + case OpenSearchRefreshIndexJobRequest.JOB_NAME_FIELD: + builder.jobName(parser.text()); + break; + case OpenSearchRefreshIndexJobRequest.JOB_TYPE_FIELD: + builder.jobType(parser.text()); + break; + case OpenSearchRefreshIndexJobRequest.ENABLED_FIELD: + builder.enabled(parser.booleanValue()); + break; + case OpenSearchRefreshIndexJobRequest.ENABLED_TIME_FIELD: + builder.enabledTime(parseInstantValue(parser)); + break; + case OpenSearchRefreshIndexJobRequest.LAST_UPDATE_TIME_FIELD: + builder.lastUpdateTime(parseInstantValue(parser)); + break; + case OpenSearchRefreshIndexJobRequest.SCHEDULE_FIELD: + builder.schedule(ScheduleParser.parse(parser)); + break; + case OpenSearchRefreshIndexJobRequest.LOCK_DURATION_SECONDS: + builder.lockDurationSeconds(parser.longValue()); + break; + case OpenSearchRefreshIndexJobRequest.JITTER: + builder.jitter(parser.doubleValue()); + break; + default: + XContentParserUtils.throwUnknownToken(parser.currentToken(), parser.getTokenLocation()); + } + } + return builder.build(); + }; + } +} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJob.java b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJob.java new file mode 100644 index 0000000000..e465a8790f --- /dev/null +++ b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJob.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.scheduler.job; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.jobscheduler.spi.JobExecutionContext; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.jobscheduler.spi.ScheduledJobRunner; +import org.opensearch.plugins.Plugin; +import org.opensearch.sql.spark.scheduler.model.OpenSearchRefreshIndexJobRequest; +import org.opensearch.threadpool.ThreadPool; + +/** + * The job runner class for scheduling refresh index query. + * + *

<p>The job runner should be a singleton if it uses the OpenSearch client or other objects
+ * passed in from OpenSearch, because at the time the job runner is registered with the JobScheduler
+ * plugin, OpenSearch has not yet invoked the plugins' createComponents() method. That means the
+ * plugin is not completely initialized, and the OpenSearch {@link org.opensearch.client.Client},
+ * {@link ClusterService} and other objects are not yet available to the plugin or this job runner.
+ *
+ * <p>
So we have to move this job runner initialization to {@link Plugin} createComponents() method, + * and using singleton job runner to ensure we register a usable job runner instance to JobScheduler + * plugin. + */ +public class OpenSearchRefreshIndexJob implements ScheduledJobRunner { + + private static final Logger log = LogManager.getLogger(OpenSearchRefreshIndexJob.class); + + public static OpenSearchRefreshIndexJob INSTANCE = new OpenSearchRefreshIndexJob(); + + public static OpenSearchRefreshIndexJob getJobRunnerInstance() { + return INSTANCE; + } + + private ClusterService clusterService; + private ThreadPool threadPool; + private Client client; + + private OpenSearchRefreshIndexJob() { + // Singleton class, use getJobRunnerInstance method instead of constructor + } + + public void setClusterService(ClusterService clusterService) { + this.clusterService = clusterService; + } + + public void setThreadPool(ThreadPool threadPool) { + this.threadPool = threadPool; + } + + public void setClient(Client client) { + this.client = client; + } + + @Override + public void runJob(ScheduledJobParameter jobParameter, JobExecutionContext context) { + if (!(jobParameter instanceof OpenSearchRefreshIndexJobRequest)) { + throw new IllegalStateException( + "Job parameter is not instance of OpenSearchRefreshIndexJobRequest, type: " + + jobParameter.getClass().getCanonicalName()); + } + + if (this.clusterService == null) { + throw new IllegalStateException("ClusterService is not initialized."); + } + + if (this.threadPool == null) { + throw new IllegalStateException("ThreadPool is not initialized."); + } + + if (this.client == null) { + throw new IllegalStateException("Client is not initialized."); + } + + Runnable runnable = + () -> { + doRefresh(jobParameter.getName()); + }; + threadPool.generic().submit(runnable); + } + + void doRefresh(String refreshIndex) { + // TODO: add logic to refresh index + log.info("Scheduled refresh index job on : " + refreshIndex); + } +} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequest.java b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequest.java new file mode 100644 index 0000000000..7eaa4e2d29 --- /dev/null +++ b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequest.java @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.scheduler.model; + +import java.io.IOException; +import java.time.Instant; +import lombok.Builder; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.jobscheduler.spi.schedule.Schedule; + +/** Represents a job request to refresh index. 
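+ *
+ * <p>A minimal construction sketch (values are illustrative; {@code lastUpdateTime} must be set
+ * because {@code toXContent} always serializes it):
+ *
+ * <pre>{@code
+ * OpenSearchRefreshIndexJobRequest request =
+ *     OpenSearchRefreshIndexJobRequest.builder()
+ *         .jobName("flint-refresh-job")
+ *         .schedule(new IntervalSchedule(Instant.now(), 5, ChronoUnit.MINUTES))
+ *         .enabled(true)
+ *         .lastUpdateTime(Instant.now())
+ *         .build();
+ * }</pre>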
*/ +@Builder +public class OpenSearchRefreshIndexJobRequest implements ScheduledJobParameter { + // Constant fields for JSON serialization + public static final String JOB_NAME_FIELD = "jobName"; + public static final String JOB_TYPE_FIELD = "jobType"; + public static final String LAST_UPDATE_TIME_FIELD = "lastUpdateTime"; + public static final String LAST_UPDATE_TIME_FIELD_READABLE = "last_update_time_field"; + public static final String SCHEDULE_FIELD = "schedule"; + public static final String ENABLED_TIME_FIELD = "enabledTime"; + public static final String ENABLED_TIME_FIELD_READABLE = "enabled_time_field"; + public static final String LOCK_DURATION_SECONDS = "lockDurationSeconds"; + public static final String JITTER = "jitter"; + public static final String ENABLED_FIELD = "enabled"; + + // name is doc id + private final String jobName; + private final String jobType; + private final Schedule schedule; + private final boolean enabled; + private final Instant lastUpdateTime; + private final Instant enabledTime; + private final Long lockDurationSeconds; + private final Double jitter; + + @Override + public String getName() { + return jobName; + } + + public String getJobType() { + return jobType; + } + + @Override + public Schedule getSchedule() { + return schedule; + } + + @Override + public boolean isEnabled() { + return enabled; + } + + @Override + public Instant getLastUpdateTime() { + return lastUpdateTime; + } + + @Override + public Instant getEnabledTime() { + return enabledTime; + } + + @Override + public Long getLockDurationSeconds() { + return lockDurationSeconds; + } + + @Override + public Double getJitter() { + return jitter; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) + throws IOException { + builder.startObject(); + builder.field(JOB_NAME_FIELD, getName()).field(ENABLED_FIELD, isEnabled()); + if (getSchedule() != null) { + builder.field(SCHEDULE_FIELD, getSchedule()); + } + if (getJobType() != null) { + builder.field(JOB_TYPE_FIELD, getJobType()); + } + if (getEnabledTime() != null) { + builder.timeField( + ENABLED_TIME_FIELD, ENABLED_TIME_FIELD_READABLE, getEnabledTime().toEpochMilli()); + } + builder.timeField( + LAST_UPDATE_TIME_FIELD, + LAST_UPDATE_TIME_FIELD_READABLE, + getLastUpdateTime().toEpochMilli()); + if (this.lockDurationSeconds != null) { + builder.field(LOCK_DURATION_SECONDS, this.lockDurationSeconds); + } + if (this.jitter != null) { + builder.field(JITTER, this.jitter); + } + builder.endObject(); + return builder; + } +} diff --git a/async-query/src/main/resources/async-query-scheduler-index-mapping.yml b/async-query/src/main/resources/async-query-scheduler-index-mapping.yml new file mode 100644 index 0000000000..36bd1b873e --- /dev/null +++ b/async-query/src/main/resources/async-query-scheduler-index-mapping.yml @@ -0,0 +1,41 @@ +--- +## +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +## + +# Schema file for the .async-query-scheduler index +# Also "dynamic" is set to "false" so that other fields cannot be added. 
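+# A stored document, as serialized by OpenSearchRefreshIndexJobRequest.toXContent(), looks
+# roughly like the following (all values illustrative):
+#   { "jobName": "flint-refresh-job", "enabled": true,
+#     "schedule": { "interval": { "start_time": 1722470400000, "period": 5, "unit": "Minutes" } },
+#     "lastUpdateTime": 1722470400000, "lockDurationSeconds": 60, "jitter": 0.1 }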
+dynamic: false +properties: + name: + type: keyword + jobType: + type: keyword + lastUpdateTime: + type: date + format: epoch_millis + enabledTime: + type: date + format: epoch_millis + schedule: + properties: + initialDelay: + type: long + interval: + properties: + start_time: + type: date + format: "strict_date_time||epoch_millis" + period: + type: integer + unit: + type: keyword + enabled: + type: boolean + lockDurationSeconds: + type: long + null_value: -1 + jitter: + type: double + null_value: 0.0 \ No newline at end of file diff --git a/async-query/src/main/resources/async-query-scheduler-index-settings.yml b/async-query/src/main/resources/async-query-scheduler-index-settings.yml new file mode 100644 index 0000000000..386f1f4f34 --- /dev/null +++ b/async-query/src/main/resources/async-query-scheduler-index-settings.yml @@ -0,0 +1,11 @@ +--- +## +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +## + +# Settings file for the .async-query-scheduler index +index: + number_of_shards: "1" + auto_expand_replicas: "0-2" + number_of_replicas: "0" \ No newline at end of file diff --git a/async-query/src/test/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQuerySchedulerTest.java b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQuerySchedulerTest.java new file mode 100644 index 0000000000..de86f111f3 --- /dev/null +++ b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQuerySchedulerTest.java @@ -0,0 +1,434 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.scheduler; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.scheduler.OpenSearchAsyncQueryScheduler.SCHEDULER_INDEX_NAME; + +import java.io.IOException; +import java.time.Instant; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Answers; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.index.engine.DocumentMissingException; +import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.jobscheduler.spi.ScheduledJobRunner; +import org.opensearch.sql.spark.scheduler.model.OpenSearchRefreshIndexJobRequest; +import org.opensearch.threadpool.ThreadPool; + 
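+// The client, cluster service, and thread pool are mocked with RETURNS_DEEP_STUBS so that
+// chained calls such as client.admin().indices().create(...) can be stubbed in a single
+// when(...) clause instead of mocking every intermediate object.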
+public class OpenSearchAsyncQuerySchedulerTest { + + private static final String TEST_SCHEDULER_INDEX_NAME = "testQS"; + + private static final String TEST_JOB_ID = "testJob"; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private Client client; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private ClusterService clusterService; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private ThreadPool threadPool; + + @Mock private ActionFuture indexResponseActionFuture; + + @Mock private ActionFuture updateResponseActionFuture; + + @Mock private ActionFuture deleteResponseActionFuture; + + @Mock private ActionFuture createIndexResponseActionFuture; + + @Mock private IndexResponse indexResponse; + + @Mock private UpdateResponse updateResponse; + + private OpenSearchAsyncQueryScheduler scheduler; + + @BeforeEach + public void setup() { + MockitoAnnotations.openMocks(this); + scheduler = new OpenSearchAsyncQueryScheduler(); + scheduler.loadJobResource(client, clusterService, threadPool); + } + + @Test + public void testScheduleJob() { + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)) + .thenReturn(Boolean.FALSE); + when(client.admin().indices().create(any(CreateIndexRequest.class))) + .thenReturn(createIndexResponseActionFuture); + when(createIndexResponseActionFuture.actionGet()) + .thenReturn(new CreateIndexResponse(true, true, TEST_SCHEDULER_INDEX_NAME)); + when(client.index(any(IndexRequest.class))).thenReturn(indexResponseActionFuture); + + // Test the if case + when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse); + when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED); + + OpenSearchRefreshIndexJobRequest request = + OpenSearchRefreshIndexJobRequest.builder() + .jobName(TEST_JOB_ID) + .lastUpdateTime(Instant.now()) + .build(); + + scheduler.scheduleJob(request); + + // Verify index created + verify(client.admin().indices(), times(1)).create(ArgumentMatchers.any()); + + // Verify doc indexed + ArgumentCaptor captor = ArgumentCaptor.forClass(IndexRequest.class); + verify(client, times(1)).index(captor.capture()); + IndexRequest capturedRequest = captor.getValue(); + assertEquals(request.getName(), capturedRequest.id()); + assertEquals(WriteRequest.RefreshPolicy.IMMEDIATE, capturedRequest.getRefreshPolicy()); + } + + @Test + public void testScheduleJobWithExistingJob() { + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)) + .thenReturn(Boolean.TRUE); + + OpenSearchRefreshIndexJobRequest request = + OpenSearchRefreshIndexJobRequest.builder() + .jobName(TEST_JOB_ID) + .lastUpdateTime(Instant.now()) + .build(); + + when(client.index(any(IndexRequest.class))).thenThrow(VersionConflictEngineException.class); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> { + scheduler.scheduleJob(request); + }); + + verify(client, times(1)).index(ArgumentCaptor.forClass(IndexRequest.class).capture()); + assertEquals("A job already exists with name: testJob", exception.getMessage()); + } + + @Test + public void testScheduleJobWithExceptions() { + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)) + .thenReturn(Boolean.FALSE); + when(client.admin().indices().create(any(CreateIndexRequest.class))) + .thenReturn(createIndexResponseActionFuture); + when(createIndexResponseActionFuture.actionGet()) + .thenReturn(new CreateIndexResponse(true, true, TEST_SCHEDULER_INDEX_NAME)); + when(client.index(any(IndexRequest.class))).thenThrow(new 
RuntimeException("Test exception")); + + OpenSearchRefreshIndexJobRequest request = + OpenSearchRefreshIndexJobRequest.builder() + .jobName(TEST_JOB_ID) + .lastUpdateTime(Instant.now()) + .build(); + + assertThrows(RuntimeException.class, () -> scheduler.scheduleJob(request)); + + when(client.index(any(IndexRequest.class))).thenReturn(indexResponseActionFuture); + when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse); + when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.NOT_FOUND); + + RuntimeException exception = + assertThrows(RuntimeException.class, () -> scheduler.scheduleJob(request)); + assertEquals("Schedule job failed with result : not_found", exception.getMessage()); + } + + @Test + public void testUnscheduleJob() throws IOException { + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(true); + + when(updateResponseActionFuture.actionGet()).thenReturn(updateResponse); + when(updateResponse.getResult()).thenReturn(DocWriteResponse.Result.UPDATED); + + when(client.update(any(UpdateRequest.class))).thenReturn(updateResponseActionFuture); + + scheduler.unscheduleJob(TEST_JOB_ID); + + ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateRequest.class); + verify(client).update(captor.capture()); + + UpdateRequest capturedRequest = captor.getValue(); + assertEquals(TEST_JOB_ID, capturedRequest.id()); + assertEquals(WriteRequest.RefreshPolicy.IMMEDIATE, capturedRequest.getRefreshPolicy()); + + // Reset the captor for the next verification + captor = ArgumentCaptor.forClass(UpdateRequest.class); + + when(updateResponse.getResult()).thenReturn(DocWriteResponse.Result.NOOP); + scheduler.unscheduleJob(TEST_JOB_ID); + + verify(client, times(2)).update(captor.capture()); + capturedRequest = captor.getValue(); + assertEquals(TEST_JOB_ID, capturedRequest.id()); + assertEquals(WriteRequest.RefreshPolicy.IMMEDIATE, capturedRequest.getRefreshPolicy()); + } + + @Test + public void testUnscheduleJobWithIndexNotFound() { + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(false); + + assertThrows(IllegalStateException.class, () -> scheduler.unscheduleJob(TEST_JOB_ID)); + } + + @Test + public void testUpdateJob() throws IOException { + OpenSearchRefreshIndexJobRequest request = + OpenSearchRefreshIndexJobRequest.builder() + .jobName(TEST_JOB_ID) + .lastUpdateTime(Instant.now()) + .build(); + + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(true); + + when(updateResponseActionFuture.actionGet()).thenReturn(updateResponse); + when(updateResponse.getResult()).thenReturn(DocWriteResponse.Result.UPDATED); + + when(client.update(any(UpdateRequest.class))).thenReturn(updateResponseActionFuture); + + scheduler.updateJob(request); + + ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateRequest.class); + verify(client).update(captor.capture()); + + UpdateRequest capturedRequest = captor.getValue(); + assertEquals(request.getName(), capturedRequest.id()); + assertEquals(WriteRequest.RefreshPolicy.IMMEDIATE, capturedRequest.getRefreshPolicy()); + } + + @Test + public void testUpdateJobWithIndexNotFound() { + OpenSearchRefreshIndexJobRequest request = + OpenSearchRefreshIndexJobRequest.builder() + .jobName(TEST_JOB_ID) + .lastUpdateTime(Instant.now()) + .build(); + + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(false); + + assertThrows(IllegalStateException.class, () -> scheduler.updateJob(request)); + } + + @Test + public void 
testUpdateJobWithExceptions() { + OpenSearchRefreshIndexJobRequest request = + OpenSearchRefreshIndexJobRequest.builder() + .jobName(TEST_JOB_ID) + .lastUpdateTime(Instant.now()) + .build(); + + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(true); + when(client.update(any(UpdateRequest.class))) + .thenThrow(new DocumentMissingException(null, null)); + + IllegalArgumentException exception1 = + assertThrows( + IllegalArgumentException.class, + () -> { + scheduler.updateJob(request); + }); + + assertEquals("Job: testJob doesn't exist", exception1.getMessage()); + + when(client.update(any(UpdateRequest.class))).thenThrow(new RuntimeException("Test exception")); + + RuntimeException exception2 = + assertThrows( + RuntimeException.class, + () -> { + scheduler.updateJob(request); + }); + + assertEquals("java.lang.RuntimeException: Test exception", exception2.getMessage()); + + when(client.update(any(UpdateRequest.class))).thenReturn(updateResponseActionFuture); + when(updateResponseActionFuture.actionGet()).thenReturn(updateResponse); + when(updateResponse.getResult()).thenReturn(DocWriteResponse.Result.NOT_FOUND); + + RuntimeException exception = + assertThrows(RuntimeException.class, () -> scheduler.updateJob(request)); + assertEquals("Update job failed with result : not_found", exception.getMessage()); + } + + @Test + public void testRemoveJob() { + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(true); + + DeleteResponse deleteResponse = mock(DeleteResponse.class); + when(deleteResponseActionFuture.actionGet()).thenReturn(deleteResponse); + when(deleteResponse.getResult()).thenReturn(DocWriteResponse.Result.DELETED); + + when(client.delete(any(DeleteRequest.class))).thenReturn(deleteResponseActionFuture); + + scheduler.removeJob(TEST_JOB_ID); + + ArgumentCaptor captor = ArgumentCaptor.forClass(DeleteRequest.class); + verify(client).delete(captor.capture()); + + DeleteRequest capturedRequest = captor.getValue(); + assertEquals(TEST_JOB_ID, capturedRequest.id()); + assertEquals(WriteRequest.RefreshPolicy.IMMEDIATE, capturedRequest.getRefreshPolicy()); + } + + @Test + public void testRemoveJobWithIndexNotFound() { + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(false); + + assertThrows(IllegalStateException.class, () -> scheduler.removeJob(TEST_JOB_ID)); + } + + @Test + public void testCreateAsyncQuerySchedulerIndex() { + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(false); + + CreateIndexResponse createIndexResponse = mock(CreateIndexResponse.class); + when(createIndexResponseActionFuture.actionGet()).thenReturn(createIndexResponse); + when(createIndexResponse.isAcknowledged()).thenReturn(true); + + when(client.admin().indices().create(any(CreateIndexRequest.class))) + .thenReturn(createIndexResponseActionFuture); + + scheduler.createAsyncQuerySchedulerIndex(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(CreateIndexRequest.class); + verify(client.admin().indices()).create(captor.capture()); + + CreateIndexRequest capturedRequest = captor.getValue(); + assertEquals(SCHEDULER_INDEX_NAME, capturedRequest.index()); + } + + @Test + public void testCreateAsyncQuerySchedulerIndexFailure() { + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(false); + + when(client.admin().indices().create(any(CreateIndexRequest.class))) + .thenThrow(new RuntimeException("Error creating index")); + + RuntimeException 
exception = + assertThrows( + RuntimeException.class, + () -> { + scheduler.createAsyncQuerySchedulerIndex(); + }); + + assertEquals( + "Internal server error while creating .async-query-scheduler index: Error creating index", + exception.getMessage()); + + when(client.admin().indices().create(any(CreateIndexRequest.class))) + .thenReturn(createIndexResponseActionFuture); + Mockito.when(createIndexResponseActionFuture.actionGet()) + .thenReturn(new CreateIndexResponse(false, false, SCHEDULER_INDEX_NAME)); + + OpenSearchRefreshIndexJobRequest request = + OpenSearchRefreshIndexJobRequest.builder() + .jobName(TEST_JOB_ID) + .lastUpdateTime(Instant.now()) + .build(); + + RuntimeException runtimeException = + Assertions.assertThrows(RuntimeException.class, () -> scheduler.scheduleJob(request)); + Assertions.assertEquals( + "Internal server error while creating .async-query-scheduler index: Index creation is not" + + " acknowledged.", + runtimeException.getMessage()); + } + + @Test + public void testUpdateJobNotFound() { + OpenSearchRefreshIndexJobRequest request = + OpenSearchRefreshIndexJobRequest.builder() + .jobName(TEST_JOB_ID) + .lastUpdateTime(Instant.now()) + .build(); + + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(true); + + when(client.update(any(UpdateRequest.class))) + .thenThrow(new DocumentMissingException(null, null)); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> { + scheduler.updateJob(request); + }); + + assertEquals("Job: testJob doesn't exist", exception.getMessage()); + } + + @Test + public void testRemoveJobNotFound() { + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(true); + + DeleteResponse deleteResponse = mock(DeleteResponse.class); + when(deleteResponseActionFuture.actionGet()).thenReturn(deleteResponse); + when(deleteResponse.getResult()).thenReturn(DocWriteResponse.Result.NOT_FOUND); + + when(client.delete(any(DeleteRequest.class))).thenReturn(deleteResponseActionFuture); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> { + scheduler.removeJob(TEST_JOB_ID); + }); + + assertEquals("Job : testJob doesn't exist", exception.getMessage()); + } + + @Test + public void testRemoveJobWithExceptions() { + when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(true); + + when(client.delete(any(DeleteRequest.class))).thenThrow(new RuntimeException("Test exception")); + + assertThrows(RuntimeException.class, () -> scheduler.removeJob(TEST_JOB_ID)); + + DeleteResponse deleteResponse = mock(DeleteResponse.class); + when(client.delete(any(DeleteRequest.class))).thenReturn(deleteResponseActionFuture); + when(deleteResponseActionFuture.actionGet()).thenReturn(deleteResponse); + when(deleteResponse.getResult()).thenReturn(DocWriteResponse.Result.NOOP); + + RuntimeException runtimeException = + Assertions.assertThrows(RuntimeException.class, () -> scheduler.removeJob(TEST_JOB_ID)); + Assertions.assertEquals("Remove job failed with result : noop", runtimeException.getMessage()); + } + + @Test + public void testGetJobRunner() { + ScheduledJobRunner jobRunner = OpenSearchAsyncQueryScheduler.getJobRunner(); + assertNotNull(jobRunner); + } +} diff --git a/async-query/src/test/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJobTest.java b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJobTest.java new file mode 100644 
index 0000000000..cbf137997e --- /dev/null +++ b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJobTest.java @@ -0,0 +1,145 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.scheduler.job; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +import java.time.Instant; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Answers; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.jobscheduler.spi.JobExecutionContext; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.sql.spark.scheduler.model.OpenSearchRefreshIndexJobRequest; +import org.opensearch.threadpool.ThreadPool; + +public class OpenSearchRefreshIndexJobTest { + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private ClusterService clusterService; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private ThreadPool threadPool; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private Client client; + + @Mock private JobExecutionContext context; + + private OpenSearchRefreshIndexJob jobRunner; + + private OpenSearchRefreshIndexJob spyJobRunner; + + @BeforeEach + public void setup() { + MockitoAnnotations.openMocks(this); + jobRunner = OpenSearchRefreshIndexJob.getJobRunnerInstance(); + jobRunner.setClient(null); + jobRunner.setClusterService(null); + jobRunner.setThreadPool(null); + } + + @Test + public void testRunJobWithCorrectParameter() { + spyJobRunner = spy(jobRunner); + spyJobRunner.setClusterService(clusterService); + spyJobRunner.setThreadPool(threadPool); + spyJobRunner.setClient(client); + + OpenSearchRefreshIndexJobRequest jobParameter = + OpenSearchRefreshIndexJobRequest.builder() + .jobName("testJob") + .lastUpdateTime(Instant.now()) + .lockDurationSeconds(10L) + .build(); + + spyJobRunner.runJob(jobParameter, context); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class); + verify(threadPool.generic()).submit(captor.capture()); + + Runnable runnable = captor.getValue(); + runnable.run(); + + verify(spyJobRunner).doRefresh(eq(jobParameter.getName())); + } + + @Test + public void testRunJobWithIncorrectParameter() { + jobRunner = OpenSearchRefreshIndexJob.getJobRunnerInstance(); + jobRunner.setClusterService(clusterService); + jobRunner.setThreadPool(threadPool); + jobRunner.setClient(client); + + ScheduledJobParameter wrongParameter = mock(ScheduledJobParameter.class); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> jobRunner.runJob(wrongParameter, context), + "Expected IllegalStateException but no exception was thrown"); + + assertEquals( + "Job parameter is not instance of OpenSearchRefreshIndexJobRequest, type: " + + wrongParameter.getClass().getCanonicalName(), + exception.getMessage()); + } + + @Test + public void testRunJobWithUninitializedServices() { + OpenSearchRefreshIndexJobRequest jobParameter = + OpenSearchRefreshIndexJobRequest.builder() + .jobName("testJob") + .lastUpdateTime(Instant.now()) + 
.build(); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> jobRunner.runJob(jobParameter, context), + "Expected IllegalStateException but no exception was thrown"); + assertEquals("ClusterService is not initialized.", exception.getMessage()); + + jobRunner.setClusterService(clusterService); + + exception = + assertThrows( + IllegalStateException.class, + () -> jobRunner.runJob(jobParameter, context), + "Expected IllegalStateException but no exception was thrown"); + assertEquals("ThreadPool is not initialized.", exception.getMessage()); + + jobRunner.setThreadPool(threadPool); + + exception = + assertThrows( + IllegalStateException.class, + () -> jobRunner.runJob(jobParameter, context), + "Expected IllegalStateException but no exception was thrown"); + assertEquals("Client is not initialized.", exception.getMessage()); + } + + @Test + public void testGetJobRunnerInstanceMultipleCalls() { + OpenSearchRefreshIndexJob instance1 = OpenSearchRefreshIndexJob.getJobRunnerInstance(); + OpenSearchRefreshIndexJob instance2 = OpenSearchRefreshIndexJob.getJobRunnerInstance(); + OpenSearchRefreshIndexJob instance3 = OpenSearchRefreshIndexJob.getJobRunnerInstance(); + + assertSame(instance1, instance2); + assertSame(instance2, instance3); + } +} diff --git a/async-query/src/test/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequestTest.java b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequestTest.java new file mode 100644 index 0000000000..108f1acfd5 --- /dev/null +++ b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequestTest.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.scheduler.model; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import org.junit.jupiter.api.Test; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; + +public class OpenSearchRefreshIndexJobRequestTest { + + @Test + public void testBuilderAndGetterMethods() { + Instant now = Instant.now(); + IntervalSchedule schedule = new IntervalSchedule(now, 1, ChronoUnit.MINUTES); + + OpenSearchRefreshIndexJobRequest jobRequest = + OpenSearchRefreshIndexJobRequest.builder() + .jobName("testJob") + .jobType("testType") + .schedule(schedule) + .enabled(true) + .lastUpdateTime(now) + .enabledTime(now) + .lockDurationSeconds(60L) + .jitter(0.1) + .build(); + + assertEquals("testJob", jobRequest.getName()); + assertEquals("testType", jobRequest.getJobType()); + assertEquals(schedule, jobRequest.getSchedule()); + assertTrue(jobRequest.isEnabled()); + assertEquals(now, jobRequest.getLastUpdateTime()); + assertEquals(now, jobRequest.getEnabledTime()); + assertEquals(60L, jobRequest.getLockDurationSeconds()); + assertEquals(0.1, jobRequest.getJitter()); + } + + @Test + public void testToXContent() throws IOException { + Instant now = Instant.now(); + IntervalSchedule schedule = new IntervalSchedule(now, 1, ChronoUnit.MINUTES); + + OpenSearchRefreshIndexJobRequest jobRequest = + OpenSearchRefreshIndexJobRequest.builder() + 
.jobName("testJob") + .jobType("testType") + .schedule(schedule) + .enabled(true) + .lastUpdateTime(now) + .enabledTime(now) + .lockDurationSeconds(60L) + .jitter(0.1) + .build(); + + XContentBuilder builder = XContentFactory.jsonBuilder().prettyPrint(); + jobRequest.toXContent(builder, EMPTY_PARAMS); + String jsonString = builder.toString(); + + assertTrue(jsonString.contains("\"jobName\" : \"testJob\"")); + assertTrue(jsonString.contains("\"jobType\" : \"testType\"")); + assertTrue(jsonString.contains("\"start_time\" : " + now.toEpochMilli())); + assertTrue(jsonString.contains("\"period\" : 1")); + assertTrue(jsonString.contains("\"unit\" : \"Minutes\"")); + assertTrue(jsonString.contains("\"enabled\" : true")); + assertTrue(jsonString.contains("\"lastUpdateTime\" : " + now.toEpochMilli())); + assertTrue(jsonString.contains("\"enabledTime\" : " + now.toEpochMilli())); + assertTrue(jsonString.contains("\"lockDurationSeconds\" : 60")); + assertTrue(jsonString.contains("\"jitter\" : 0.1")); + } +} diff --git a/build.gradle b/build.gradle index b3e09d7b50..702d6f478a 100644 --- a/build.gradle +++ b/build.gradle @@ -50,6 +50,7 @@ buildscript { return "https://github.com/prometheus/prometheus/releases/download/v${prometheus_binary_version}/prometheus-${prometheus_binary_version}."+ getOSFamilyType() + "-" + getArchType() + ".tar.gz" } aws_java_sdk_version = "1.12.651" + guava_version = "32.1.3-jre" } repositories { @@ -192,7 +193,7 @@ configurations.all { exclude group: "commons-logging", module: "commons-logging" // enforce 1.1.3, https://www.whitesourcesoftware.com/vulnerability-database/WS-2019-0379 resolutionStrategy.force 'commons-codec:commons-codec:1.13' - resolutionStrategy.force 'com.google.guava:guava:32.0.1-jre' + resolutionStrategy.force "com.google.guava:guava:${guava_version}" } // updateVersion: Task to auto increment to the next development iteration diff --git a/common/build.gradle b/common/build.gradle index b4ee98a5b7..15c48dd6b3 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -34,7 +34,7 @@ repositories { dependencies { api "org.antlr:antlr4-runtime:4.7.1" - api group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' + api group: 'com.google.guava', name: 'guava', version: "${guava_version}" api group: 'org.apache.logging.log4j', name: 'log4j-core', version:"${versions.log4j}" api group: 'org.apache.commons', name: 'commons-lang3', version: '3.12.0' api group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' @@ -46,7 +46,7 @@ dependencies { testImplementation group: 'junit', name: 'junit', version: '4.13.2' testImplementation group: 'org.assertj', name: 'assertj-core', version: '3.9.1' - testImplementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' + testImplementation group: 'com.google.guava', name: 'guava', version: "${guava_version}" testImplementation group: 'org.hamcrest', name: 'hamcrest-library', version: '2.1' testImplementation('org.junit.jupiter:junit-jupiter:5.9.3') testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0' diff --git a/core/build.gradle b/core/build.gradle index 655e7d92c2..f36777030c 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -46,7 +46,7 @@ pitest { } dependencies { - api group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' + api group: 'com.google.guava', name: 'guava', version: "${guava_version}" api group: 'org.apache.commons', name: 'commons-lang3', version: '3.12.0' api group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' 
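    // Guava is now resolved through the shared guava_version property (32.1.3-jre), declared once
    // in the root build.gradle and forced across all modules.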
api group: 'com.facebook.presto', name: 'presto-matching', version: '0.240' diff --git a/doctest/build.gradle b/doctest/build.gradle index ec5a26b52b..a125a4f336 100644 --- a/doctest/build.gradle +++ b/doctest/build.gradle @@ -5,6 +5,8 @@ import org.opensearch.gradle.testclusters.RunTask +import java.util.concurrent.Callable + plugins { id 'base' id 'com.wiredforcode.spawn' @@ -109,6 +111,10 @@ if (version_tokens.length > 1) { String mlCommonsRemoteFile = 'https://ci.opensearch.org/ci/dbc/distribution-build-opensearch/' + opensearch_no_snapshot + '/latest/linux/x64/tar/builds/opensearch/plugins/opensearch-ml-' + opensearch_build + '.zip' String mlCommonsPlugin = 'opensearch-ml' +String bwcOpenSearchJSDownload = 'https://ci.opensearch.org/ci/dbc/distribution-build-opensearch/' + opensearch_no_snapshot + '/latest/linux/x64/tar/builds/' + + 'opensearch/plugins/opensearch-job-scheduler-' + opensearch_build + '.zip' +String jsPlugin = 'opensearch-job-scheduler' + testClusters { docTestCluster { // Disable loading of `ML-commons` plugin, because it might be unavailable (not released yet). @@ -133,6 +139,7 @@ testClusters { } })) */ + plugin(getJobSchedulerPlugin(jsPlugin, bwcOpenSearchJSDownload)) plugin ':opensearch-sql-plugin' testDistribution = 'archive' } @@ -159,3 +166,49 @@ spotless { googleJavaFormat('1.17.0').reflowLongStrings().groupArtifact('com.google.googlejavaformat:google-java-format') } } + +def getJobSchedulerPlugin(String jsPlugin, String bwcOpenSearchJSDownload) { + return provider(new Callable() { + @Override + RegularFile call() throws Exception { + return new RegularFile() { + @Override + File getAsFile() { + // Use absolute paths + String basePath = new File('.').getCanonicalPath() + File dir = new File(basePath + File.separator + 'doctest' + File.separator + jsPlugin) + + // Log the directory path for debugging + println("Creating directory: " + dir.getAbsolutePath()) + + // Create directory if it doesn't exist + if (!dir.exists()) { + if (!dir.mkdirs()) { + throw new IOException("Failed to create directory: " + dir.getAbsolutePath()) + } + } + + // Define the file path + File f = new File(dir, jsPlugin + '-' + opensearch_build + '.zip') + + // Download file if it doesn't exist + if (!f.exists()) { + println("Downloading file from: " + bwcOpenSearchJSDownload) + println("Saving to file: " + f.getAbsolutePath()) + + new URL(bwcOpenSearchJSDownload).withInputStream { ins -> + f.withOutputStream { it << ins } + } + } + + // Check if the file was created successfully + if (!f.exists()) { + throw new FileNotFoundException("File was not created: " + f.getAbsolutePath()) + } + + return fileTree(f.getParent()).matching { include f.getName() }.singleFile + } + } + } + }) +} diff --git a/integ-test/build.gradle b/integ-test/build.gradle index 93153cf737..1acacdb4a5 100644 --- a/integ-test/build.gradle +++ b/integ-test/build.gradle @@ -80,7 +80,6 @@ ext { var projectAbsPath = projectDir.getAbsolutePath() File downloadedSecurityPlugin = Paths.get(projectAbsPath, 'bin', 'opensearch-security-snapshot.zip').toFile() - configureSecurityPlugin = { OpenSearchCluster cluster -> cluster.getNodes().forEach { node -> @@ -138,6 +137,10 @@ ext { cluster.plugin provider((Callable) (() -> (RegularFile) (() -> downloadedSecurityPlugin))) } + + bwcOpenSearchJSDownload = 'https://ci.opensearch.org/ci/dbc/distribution-build-opensearch/' + baseVersion + '/latest/linux/x64/tar/builds/' + + 'opensearch/plugins/opensearch-job-scheduler-' + bwcVersion + '.zip' + bwcJobSchedulerPath = bwcFilePath + 
"job-scheduler/" } tasks.withType(licenseHeaders.class) { @@ -153,7 +156,6 @@ configurations.all { resolutionStrategy.force "commons-logging:commons-logging:1.2" // enforce 1.1.3, https://www.whitesourcesoftware.com/vulnerability-database/WS-2019-0379 resolutionStrategy.force 'commons-codec:commons-codec:1.13' - resolutionStrategy.force 'com.google.guava:guava:32.0.1-jre' resolutionStrategy.force "com.fasterxml.jackson.core:jackson-core:${versions.jackson}" resolutionStrategy.force "com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:${versions.jackson}" resolutionStrategy.force "com.fasterxml.jackson.dataformat:jackson-dataformat-smile:${versions.jackson}" @@ -166,6 +168,7 @@ configurations.all { resolutionStrategy.force "joda-time:joda-time:2.10.12" resolutionStrategy.force "org.slf4j:slf4j-api:1.7.36" resolutionStrategy.force "com.amazonaws:aws-java-sdk-core:${aws_java_sdk_version}" + resolutionStrategy.force "com.google.guava:guava:${guava_version}" } configurations { @@ -191,6 +194,7 @@ dependencies { testCompileOnly 'org.apiguardian:apiguardian-api:1.1.2' // Needed for BWC tests + zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}" zipArchive group: 'org.opensearch.plugin', name:'opensearch-sql-plugin', version: "${bwcVersion}-SNAPSHOT" } @@ -219,22 +223,42 @@ testClusters.all { } } +def getJobSchedulerPlugin() { + provider(new Callable() { + @Override + RegularFile call() throws Exception { + return new RegularFile() { + @Override + File getAsFile() { + return configurations.zipArchive.asFileTree.matching { + include '**/opensearch-job-scheduler*' + }.singleFile + } + } + } + }) +} + testClusters { integTest { testDistribution = 'archive' + plugin(getJobSchedulerPlugin()) plugin ":opensearch-sql-plugin" setting "plugins.query.datasources.encryption.masterkey", "1234567812345678" } remoteCluster { testDistribution = 'archive' + plugin(getJobSchedulerPlugin()) plugin ":opensearch-sql-plugin" } integTestWithSecurity { testDistribution = 'archive' + plugin(getJobSchedulerPlugin()) plugin ":opensearch-sql-plugin" } remoteIntegTestWithSecurity { testDistribution = 'archive' + plugin(getJobSchedulerPlugin()) plugin ":opensearch-sql-plugin" } } @@ -502,6 +526,24 @@ task comparisonTest(type: RestIntegTestTask) { testDistribution = "ARCHIVE" versions = [baseVersion, opensearch_version] numberOfNodes = 3 + plugin(provider(new Callable(){ + @Override + RegularFile call() throws Exception { + return new RegularFile() { + @Override + File getAsFile() { + if (new File("$project.rootDir/$bwcFilePath/job-scheduler/$bwcVersion").exists()) { + project.delete(files("$project.rootDir/$bwcFilePath/job-scheduler/$bwcVersion")) + } + project.mkdir bwcJobSchedulerPath + bwcVersion + ant.get(src: bwcOpenSearchJSDownload, + dest: bwcJobSchedulerPath + bwcVersion, + httpusecaches: false) + return fileTree(bwcJobSchedulerPath + bwcVersion).getSingleFile() + } + } + } + })) plugin(provider(new Callable(){ @Override RegularFile call() throws Exception { @@ -522,17 +564,18 @@ task comparisonTest(type: RestIntegTestTask) { } List> plugins = [ - provider(new Callable() { - @Override - RegularFile call() throws Exception { - return new RegularFile() { - @Override - File getAsFile() { - return fileTree(bwcFilePath + project.version).getSingleFile() + getJobSchedulerPlugin(), + provider(new Callable() { + @Override + RegularFile call() throws Exception { + return new RegularFile() { + @Override + File getAsFile() { + return fileTree(bwcFilePath + 
project.version).getSingleFile() + } } } - } - }) + }) ] // Creates 2 test clusters with 3 nodes of the old version. diff --git a/legacy/build.gradle b/legacy/build.gradle index 0467db183d..e3ddf27066 100644 --- a/legacy/build.gradle +++ b/legacy/build.gradle @@ -107,7 +107,7 @@ dependencies { because 'https://www.whitesourcesoftware.com/vulnerability-database/WS-2019-0379' } } - implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' + implementation group: 'com.google.guava', name: 'guava', version: "${guava_version}" implementation group: 'org.json', name: 'json', version:'20231013' implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.12.0' implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' diff --git a/plugin/build.gradle b/plugin/build.gradle index 710d81ed0a..7ebd0ad2d9 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -48,6 +48,7 @@ opensearchplugin { name 'opensearch-sql' description 'OpenSearch SQL' classname 'org.opensearch.sql.plugin.SQLPlugin' + extendedPlugins = ['opensearch-job-scheduler'] licenseFile rootProject.file("LICENSE.txt") noticeFile rootProject.file("NOTICE") } @@ -98,7 +99,8 @@ configurations.all { resolutionStrategy.force "com.fasterxml.jackson.core:jackson-core:${versions.jackson}" // enforce 1.1.3, https://www.whitesourcesoftware.com/vulnerability-database/WS-2019-0379 resolutionStrategy.force 'commons-codec:commons-codec:1.13' - resolutionStrategy.force 'com.google.guava:guava:32.0.1-jre' + resolutionStrategy.force "com.google.guava:guava:${guava_version}" + resolutionStrategy.force 'com.google.guava:failureaccess:1.0.2' resolutionStrategy.force "com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:${versions.jackson}" resolutionStrategy.force "com.fasterxml.jackson.dataformat:jackson-dataformat-smile:${versions.jackson}" resolutionStrategy.force "com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:${versions.jackson}" @@ -139,6 +141,10 @@ spotless { } dependencies { + compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}" + compileOnly "com.google.guava:guava:${guava_version}" + compileOnly 'com.google.guava:failureaccess:1.0.2' + api "com.fasterxml.jackson.core:jackson-core:${versions.jackson}" api "com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}" api "com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}" @@ -204,11 +210,10 @@ dependencyLicenses.enabled = false // enable testingConventions check will cause errors like: "Classes ending with [Tests] must subclass [LuceneTestCase]" testingConventions.enabled = false -// TODO: need to verify the thirdPartyAudi +// TODO: need to verify the thirdPartyAudit // currently it complains missing classes like ibatis, mysql etc, should not be a problem thirdPartyAudit.enabled = false - apply plugin: 'com.netflix.nebula.ospackage' validateNebulaPom.enabled = false diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index b86ab9218a..a1b1e32955 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -42,6 +42,9 @@ import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; import org.opensearch.indices.SystemIndexDescriptor; +import org.opensearch.jobscheduler.spi.JobSchedulerExtension; +import org.opensearch.jobscheduler.spi.ScheduledJobParser; +import 
org.opensearch.jobscheduler.spi.ScheduledJobRunner; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.plugins.ScriptPlugin; @@ -91,6 +94,9 @@ import org.opensearch.sql.spark.flint.FlintIndexMetadataServiceImpl; import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction; +import org.opensearch.sql.spark.scheduler.OpenSearchAsyncQueryScheduler; +import org.opensearch.sql.spark.scheduler.OpenSearchRefreshIndexJobRequestParser; +import org.opensearch.sql.spark.scheduler.job.OpenSearchRefreshIndexJob; import org.opensearch.sql.spark.storage.SparkStorageFactory; import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction; import org.opensearch.sql.spark.transport.TransportCreateAsyncQueryRequestAction; @@ -105,7 +111,8 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.watcher.ResourceWatcherService; -public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin, SystemIndexPlugin { +public class SQLPlugin extends Plugin + implements ActionPlugin, ScriptPlugin, SystemIndexPlugin, JobSchedulerExtension { private static final Logger LOGGER = LogManager.getLogger(SQLPlugin.class); @@ -116,6 +123,7 @@ public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin, Sys private NodeClient client; private DataSourceServiceImpl dataSourceService; + private OpenSearchAsyncQueryScheduler asyncQueryScheduler; private Injector injector; public String name() { @@ -208,6 +216,8 @@ public Collection createComponents( this.client = (NodeClient) client; this.dataSourceService = createDataSourceService(); dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata()); + this.asyncQueryScheduler = new OpenSearchAsyncQueryScheduler(); + this.asyncQueryScheduler.loadJobResource(client, clusterService, threadPool); LocalClusterState.state().setClusterService(clusterService); LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings); LocalClusterState.state().setClient(client); @@ -243,6 +253,26 @@ public Collection createComponents( pluginSettings); } + @Override + public String getJobType() { + return OpenSearchAsyncQueryScheduler.SCHEDULER_PLUGIN_JOB_TYPE; + } + + @Override + public String getJobIndex() { + return OpenSearchAsyncQueryScheduler.SCHEDULER_INDEX_NAME; + } + + @Override + public ScheduledJobRunner getJobRunner() { + return OpenSearchRefreshIndexJob.getJobRunnerInstance(); + } + + @Override + public ScheduledJobParser getJobParser() { + return OpenSearchRefreshIndexJobRequestParser.getJobParser(); + } + @Override public List> getExecutorBuilders(Settings settings) { return singletonList( diff --git a/plugin/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension b/plugin/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension new file mode 100644 index 0000000000..5337857c15 --- /dev/null +++ b/plugin/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension @@ -0,0 +1,6 @@ +# +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +# + +org.opensearch.sql.plugin.SQLPlugin \ No newline at end of file diff --git a/ppl/build.gradle b/ppl/build.gradle index d58882d5e8..2a3d6bdbf9 100644 --- a/ppl/build.gradle +++ b/ppl/build.gradle @@ -48,7 +48,7 @@ dependencies { runtimeOnly group: 'org.reflections', name: 'reflections', version: '0.9.12' implementation 
"org.antlr:antlr4-runtime:4.7.1" - implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' + implementation group: 'com.google.guava', name: 'guava', version: "${guava_version}" api group: 'org.json', name: 'json', version: '20231013' implementation group: 'org.apache.logging.log4j', name: 'log4j-core', version:"${versions.log4j}" api project(':common') diff --git a/protocol/build.gradle b/protocol/build.gradle index 5bbff68e51..b5d7929041 100644 --- a/protocol/build.gradle +++ b/protocol/build.gradle @@ -30,7 +30,7 @@ plugins { } dependencies { - implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' + implementation group: 'com.google.guava', name: 'guava', version: "${guava_version}" implementation group: 'com.fasterxml.jackson.core', name: 'jackson-core', version: "${versions.jackson}" implementation group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: "${versions.jackson_databind}" implementation group: 'com.fasterxml.jackson.dataformat', name: 'jackson-dataformat-cbor', version: "${versions.jackson}" diff --git a/sql/build.gradle b/sql/build.gradle index 81872e6035..10bb4b24bb 100644 --- a/sql/build.gradle +++ b/sql/build.gradle @@ -46,7 +46,7 @@ dependencies { antlr "org.antlr:antlr4:4.7.1" implementation "org.antlr:antlr4-runtime:4.7.1" - implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' + implementation group: 'com.google.guava', name: 'guava', version: "${guava_version}" implementation group: 'org.json', name: 'json', version:'20231013' implementation project(':common') implementation project(':core') From 82ef68e2b25c7c10740e74968bbe960c000c1cee Mon Sep 17 00:00:00 2001 From: panguixin Date: Thu, 1 Aug 2024 23:10:13 +0800 Subject: [PATCH 09/10] Support common format geo point (#2801) --------- Signed-off-by: panguixin --- .../sql/legacy/SQLIntegTestCase.java | 8 +- .../org/opensearch/sql/legacy/TestUtils.java | 5 ++ .../opensearch/sql/legacy/TestsConstants.java | 1 + .../opensearch/sql/sql/GeopointFormatsIT.java | 60 +++++++++++++ integ-test/src/test/resources/geopoints.json | 12 +++ .../geopoint_index_mapping.json | 9 ++ .../data/utils/OpenSearchJsonContent.java | 50 ++++------- .../value/OpenSearchExprValueFactory.java | 59 ++++++++++-- .../data/utils/OpenSearchJsonContentTest.java | 31 +++++++ .../value/OpenSearchExprValueFactoryTest.java | 89 +++++++++++++------ 10 files changed, 256 insertions(+), 68 deletions(-) create mode 100644 integ-test/src/test/java/org/opensearch/sql/sql/GeopointFormatsIT.java create mode 100644 integ-test/src/test/resources/geopoints.json create mode 100644 integ-test/src/test/resources/indexDefinitions/geopoint_index_mapping.json create mode 100644 opensearch/src/test/java/org/opensearch/sql/opensearch/data/utils/OpenSearchJsonContentTest.java diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java b/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java index 63c44bf831..c6d15a305d 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java @@ -20,6 +20,7 @@ import static org.opensearch.sql.legacy.TestUtils.getDogs3IndexMapping; import static org.opensearch.sql.legacy.TestUtils.getEmployeeNestedTypeIndexMapping; import static org.opensearch.sql.legacy.TestUtils.getGameOfThronesIndexMapping; +import static org.opensearch.sql.legacy.TestUtils.getGeopointIndexMapping; import static 
org.opensearch.sql.legacy.TestUtils.getJoinTypeIndexMapping; import static org.opensearch.sql.legacy.TestUtils.getLocationIndexMapping; import static org.opensearch.sql.legacy.TestUtils.getMappingFile; @@ -724,7 +725,12 @@ public enum Index { TestsConstants.TEST_INDEX_NESTED_WITH_NULLS, "multi_nested", getNestedTypeIndexMapping(), - "src/test/resources/nested_with_nulls.json"); + "src/test/resources/nested_with_nulls.json"), + GEOPOINTS( + TestsConstants.TEST_INDEX_GEOPOINT, + "dates", + getGeopointIndexMapping(), + "src/test/resources/geopoints.json"); private final String name; private final String type; diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/TestUtils.java b/integ-test/src/test/java/org/opensearch/sql/legacy/TestUtils.java index 65cacf16d2..195dda0cbd 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/TestUtils.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/TestUtils.java @@ -245,6 +245,11 @@ public static String getDataTypeNonnumericIndexMapping() { return getMappingFile(mappingFile); } + public static String getGeopointIndexMapping() { + String mappingFile = "geopoint_index_mapping.json"; + return getMappingFile(mappingFile); + } + public static void loadBulk(Client client, String jsonPath, String defaultIndex) throws Exception { System.out.println(String.format("Loading file %s into opensearch cluster", jsonPath)); diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/TestsConstants.java b/integ-test/src/test/java/org/opensearch/sql/legacy/TestsConstants.java index 29bc9813fa..73838feb4f 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/TestsConstants.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/TestsConstants.java @@ -57,6 +57,7 @@ public class TestsConstants { public static final String TEST_INDEX_WILDCARD = TEST_INDEX + "_wildcard"; public static final String TEST_INDEX_MULTI_NESTED_TYPE = TEST_INDEX + "_multi_nested"; public static final String TEST_INDEX_NESTED_WITH_NULLS = TEST_INDEX + "_nested_with_nulls"; + public static final String TEST_INDEX_GEOPOINT = TEST_INDEX + "_geopoint"; public static final String DATASOURCES = ".ql-datasources"; public static final String DATE_FORMAT = "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"; diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/GeopointFormatsIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/GeopointFormatsIT.java new file mode 100644 index 0000000000..f25eeec241 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/GeopointFormatsIT.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import static org.opensearch.sql.util.MatcherUtils.rows; +import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; +import static org.opensearch.sql.util.MatcherUtils.verifySchema; + +import java.io.IOException; +import java.util.Map; +import org.apache.commons.lang3.tuple.Pair; +import org.json.JSONArray; +import org.json.JSONObject; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.legacy.SQLIntegTestCase; + +public class GeopointFormatsIT extends SQLIntegTestCase { + + @Override + public void init() throws Exception { + loadIndex(Index.GEOPOINTS); + } + + @Test + public void testReadingGeopoints() throws IOException { + String query = String.format("SELECT point FROM %s LIMIT 5", Index.GEOPOINTS.getName()); + JSONObject result = executeJdbcRequest(query); + 
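+    // Documents 1-5 in geopoints.json encode the same coordinate in five of the
+    // formats accepted for a geo_point field, so every row returned below is
+    // expected to normalize to the same {lat: 40.71, lon: 74} value.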
+    verifySchema(result, schema("point", null, "geo_point"));
+    verifyDataRows(
+        result,
+        rows(Map.of("lon", 74, "lat", 40.71)),
+        rows(Map.of("lon", 74, "lat", 40.71)),
+        rows(Map.of("lon", 74, "lat", 40.71)),
+        rows(Map.of("lon", 74, "lat", 40.71)),
+        rows(Map.of("lon", 74, "lat", 40.71)));
+  }
+
+  private static final double TOLERANCE = 1E-5;
+
+  @Test
+  public void testReadingGeoHash() throws IOException {
+    String query = String.format("SELECT point FROM %s WHERE _id='6'", Index.GEOPOINTS.getName());
+    JSONObject result = executeJdbcRequest(query);
+    verifySchema(result, schema("point", null, "geo_point"));
+    Pair<Double, Double> point = getGeoValue(result);
+    assertEquals(40.71, point.getLeft(), TOLERANCE);
+    assertEquals(74, point.getRight(), TOLERANCE);
+  }
+
+  private Pair<Double, Double> getGeoValue(JSONObject result) {
+    JSONObject geoRaw =
+        (JSONObject) ((JSONArray) ((JSONArray) result.get("datarows")).get(0)).get(0);
+    double lat = geoRaw.getDouble("lat");
+    double lon = geoRaw.getDouble("lon");
+    return Pair.of(lat, lon);
+  }
+}
diff --git a/integ-test/src/test/resources/geopoints.json b/integ-test/src/test/resources/geopoints.json
new file mode 100644
index 0000000000..95900fe811
--- /dev/null
+++ b/integ-test/src/test/resources/geopoints.json
@@ -0,0 +1,12 @@
+{"index": {"_id": "1"}}
+{"point": {"lat": 40.71, "lon": 74.00}}
+{"index": {"_id": "2"}}
+{"point": "40.71,74.00"}
+{"index": {"_id": "3"}}
+{"point": [74.00, 40.71]}
+{"index": {"_id": "4"}}
+{"point": "POINT (74.00 40.71)"}
+{"index": {"_id": "5"}}
+{"point": {"type": "Point", "coordinates": [74.00, 40.71]}}
+{"index": {"_id": "6"}}
+{"point": "txhxegj0uyp3"}
diff --git a/integ-test/src/test/resources/indexDefinitions/geopoint_index_mapping.json b/integ-test/src/test/resources/indexDefinitions/geopoint_index_mapping.json
new file mode 100644
index 0000000000..61340530d8
--- /dev/null
+++ b/integ-test/src/test/resources/indexDefinitions/geopoint_index_mapping.json
@@ -0,0 +1,9 @@
+{
+  "mappings": {
+    "properties": {
+      "point": {
+        "type": "geo_point"
+      }
+    }
+  }
+}
diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/utils/OpenSearchJsonContent.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/utils/OpenSearchJsonContent.java
index bdb15428e1..4446c1f979 100644
--- a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/utils/OpenSearchJsonContent.java
+++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/utils/OpenSearchJsonContent.java
@@ -7,11 +7,19 @@
 import com.fasterxml.jackson.databind.JsonNode;
 import com.google.common.collect.Iterators;
+import java.io.IOException;
 import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.Map;
 import lombok.RequiredArgsConstructor;
 import org.apache.commons.lang3.tuple.Pair;
+import org.opensearch.OpenSearchParseException;
+import org.opensearch.common.geo.GeoPoint;
+import org.opensearch.common.geo.GeoUtils;
+import org.opensearch.common.xcontent.json.JsonXContentParser;
+import org.opensearch.core.xcontent.DeprecationHandler;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.core.xcontent.XContentParser;

 /** The Implementation of Content to represent {@link JsonNode}.
*/ @RequiredArgsConstructor @@ -122,25 +130,17 @@ public Object objectValue() { @Override public Pair geoValue() { final JsonNode value = value(); - if (value.has("lat") && value.has("lon")) { - Double lat = 0d; - Double lon = 0d; - try { - lat = extractDoubleValue(value.get("lat")); - } catch (Exception exception) { - throw new IllegalStateException( - "latitude must be number value, but got value: " + value.get("lat")); - } - try { - lon = extractDoubleValue(value.get("lon")); - } catch (Exception exception) { - throw new IllegalStateException( - "longitude must be number value, but got value: " + value.get("lon")); - } - return Pair.of(lat, lon); - } else { - throw new IllegalStateException( - "geo point must in format of {\"lat\": number, \"lon\": number}"); + try (XContentParser parser = + new JsonXContentParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + value.traverse())) { + parser.nextToken(); + GeoPoint point = new GeoPoint(); + GeoUtils.parseGeoPoint(parser, point, true); + return Pair.of(point.getLat(), point.getLon()); + } catch (IOException ex) { + throw new OpenSearchParseException("error parsing geo point", ex); } } @@ -148,16 +148,4 @@ public Pair geoValue() { private JsonNode value() { return value; } - - /** Get doubleValue from JsonNode if possible. */ - private Double extractDoubleValue(JsonNode node) { - if (node.isTextual()) { - return Double.valueOf(node.textValue()); - } - if (node.isNumber()) { - return node.doubleValue(); - } else { - throw new IllegalStateException("node must be a number"); - } - } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java index 3cb182de5b..417aaddaee 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java @@ -39,6 +39,7 @@ import java.util.function.BiFunction; import lombok.Getter; import lombok.Setter; +import org.opensearch.OpenSearchParseException; import org.opensearch.common.time.DateFormatter; import org.opensearch.common.time.DateFormatters; import org.opensearch.common.time.FormatNames; @@ -62,7 +63,6 @@ import org.opensearch.sql.opensearch.data.type.OpenSearchBinaryType; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; import org.opensearch.sql.opensearch.data.type.OpenSearchDateType; -import org.opensearch.sql.opensearch.data.type.OpenSearchGeoPointType; import org.opensearch.sql.opensearch.data.type.OpenSearchIpType; import org.opensearch.sql.opensearch.data.utils.Content; import org.opensearch.sql.opensearch.data.utils.ObjectContent; @@ -134,10 +134,6 @@ public void extendTypeMapping(Map typeMapping) { .put( OpenSearchDataType.of(OpenSearchDataType.MappingType.Ip), (c, dt) -> new OpenSearchExprIpValue(c.stringValue())) - .put( - OpenSearchDataType.of(OpenSearchDataType.MappingType.GeoPoint), - (c, dt) -> - new OpenSearchExprGeoPointValue(c.geoValue().getLeft(), c.geoValue().getRight())) .put( OpenSearchDataType.of(OpenSearchDataType.MappingType.Binary), (c, dt) -> new OpenSearchExprBinaryValue(c.stringValue())) @@ -193,8 +189,11 @@ private ExprValue parse( return ExprNullValue.of(); } - ExprType type = fieldType.get(); - if (type.equals(OpenSearchDataType.of(OpenSearchDataType.MappingType.Nested)) + final ExprType type = fieldType.get(); + + if 
(type.equals(OpenSearchDataType.of(OpenSearchDataType.MappingType.GeoPoint))) {
+      return parseGeoPoint(content, supportArrays);
+    } else if (type.equals(OpenSearchDataType.of(OpenSearchDataType.MappingType.Nested))
         || content.isArray()) {
       return parseArray(content, field, type, supportArrays);
     } else if (type.equals(OpenSearchDataType.of(OpenSearchDataType.MappingType.Object))
@@ -362,6 +361,49 @@ private ExprValue parseArray(
     return new ExprCollectionValue(result);
   }

+  /**
+   * Parse geo point content.
+   *
+   * @param content Content to parse.
+   * @param supportArrays Parsing the whole array or not
+   * @return Geo point value parsed from content.
+   */
+  private ExprValue parseGeoPoint(Content content, boolean supportArrays) {
+    // there is only one point in doc.
+    if (content.isArray() == false) {
+      final var pair = content.geoValue();
+      return new OpenSearchExprGeoPointValue(pair.getLeft(), pair.getRight());
+    }
+
+    var elements = content.array();
+    var first = elements.next();
+    // an array in the [longitude, latitude] format.
+    if (first.isNumber()) {
+      double lon = first.doubleValue();
+      var second = elements.next();
+      if (second.isNumber() == false) {
+        throw new OpenSearchParseException("lat must be a number, got " + second.objectValue());
+      }
+      return new OpenSearchExprGeoPointValue(second.doubleValue(), lon);
+    }
+
+    // there are multi points in doc
+    var pair = first.geoValue();
+    var firstPoint = new OpenSearchExprGeoPointValue(pair.getLeft(), pair.getRight());
+    if (supportArrays) {
+      List<ExprValue> result = new ArrayList<>();
+      result.add(firstPoint);
+      elements.forEachRemaining(
+          e -> {
+            var p = e.geoValue();
+            result.add(new OpenSearchExprGeoPointValue(p.getLeft(), p.getRight()));
+          });
+      return new ExprCollectionValue(result);
+    } else {
+      return firstPoint;
+    }
+  }
+
   /**
    * Parse inner array value. Can be object type and recurse continues.
* @@ -375,8 +417,7 @@ private ExprValue parseInnerArrayValue( Content content, String prefix, ExprType type, boolean supportArrays) { if (type instanceof OpenSearchIpType || type instanceof OpenSearchBinaryType - || type instanceof OpenSearchDateType - || type instanceof OpenSearchGeoPointType) { + || type instanceof OpenSearchDateType) { return parse(content, prefix, Optional.of(type), supportArrays); } else if (content.isString()) { return parse(content, prefix, Optional.of(OpenSearchDataType.of(STRING)), supportArrays); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/data/utils/OpenSearchJsonContentTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/data/utils/OpenSearchJsonContentTest.java new file mode 100644 index 0000000000..c2cf0328bd --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/data/utils/OpenSearchJsonContentTest.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.data.utils; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.JsonNode; +import java.io.IOException; +import org.junit.jupiter.api.Test; +import org.opensearch.OpenSearchParseException; + +public class OpenSearchJsonContentTest { + @Test + public void testGetValueWithIOException() throws IOException { + JsonNode jsonNode = mock(JsonNode.class); + JsonParser jsonParser = mock(JsonParser.class); + when(jsonNode.traverse()).thenReturn(jsonParser); + when(jsonParser.nextToken()).thenThrow(new IOException()); + OpenSearchJsonContent content = new OpenSearchJsonContent(jsonNode); + OpenSearchParseException exception = + assertThrows(OpenSearchParseException.class, content::geoValue); + assertTrue(exception.getMessage().contains("error parsing geo point")); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactoryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactoryTest.java index 83e26f85e4..6b4d825ab1 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactoryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactoryTest.java @@ -47,6 +47,8 @@ import lombok.EqualsAndHashCode; import lombok.ToString; import org.junit.jupiter.api.Test; +import org.opensearch.OpenSearchParseException; +import org.opensearch.geometry.utils.Geohash; import org.opensearch.sql.data.model.ExprCollectionValue; import org.opensearch.sql.data.model.ExprDateValue; import org.opensearch.sql.data.model.ExprTimeValue; @@ -597,6 +599,18 @@ public void constructArrayOfGeoPoints() { .get("geoV")); } + @Test + public void constructArrayOfGeoPointsReturnsFirstIndex() { + assertEquals( + new OpenSearchExprGeoPointValue(42.60355556, -97.25263889), + tupleValue( + "{\"geoV\":[" + + "{\"lat\":42.60355556,\"lon\":-97.25263889}," + + "{\"lat\":-33.6123556,\"lon\":66.287449}" + + "]}") + .get("geoV")); + } + @Test public void constructArrayOfIPsReturnsFirstIndex() { assertEquals( @@ -671,14 +685,50 @@ public void constructIP() { tupleValue("{\"ipV\":\"192.168.0.1\"}").get("ipV")); } + private static final double TOLERANCE = 1E-5; + @Test public void constructGeoPoint() { 
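+    // Each case below feeds a different supported encoding of the same coordinate
+    // through the value factory and expects the same OpenSearchExprGeoPointValue;
+    // only the geohash case is compared within TOLERANCE, since geohash encoding
+    // is lossy.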
+ final double lat = 42.60355556; + final double lon = -97.25263889; + final var expectedGeoPointValue = new OpenSearchExprGeoPointValue(lat, lon); + // An object with a latitude and longitude. assertEquals( - new OpenSearchExprGeoPointValue(42.60355556, -97.25263889), - tupleValue("{\"geoV\":{\"lat\":42.60355556,\"lon\":-97.25263889}}").get("geoV")); + expectedGeoPointValue, + tupleValue(String.format("{\"geoV\":{\"lat\":%.8f,\"lon\":%.8f}}", lat, lon)).get("geoV")); + + // A string in the “latitude,longitude” format. assertEquals( - new OpenSearchExprGeoPointValue(42.60355556, -97.25263889), - tupleValue("{\"geoV\":{\"lat\":\"42.60355556\",\"lon\":\"-97.25263889\"}}").get("geoV")); + expectedGeoPointValue, + tupleValue(String.format("{\"geoV\":\"%.8f,%.8f\"}", lat, lon)).get("geoV")); + + // A geohash. + var point = + (OpenSearchExprGeoPointValue.GeoPoint) + tupleValue(String.format("{\"geoV\":\"%s\"}", Geohash.stringEncode(lon, lat))) + .get("geoV") + .value(); + assertEquals(lat, point.getLat(), TOLERANCE); + assertEquals(lon, point.getLon(), TOLERANCE); + + // An array in the [longitude, latitude] format. + assertEquals( + expectedGeoPointValue, + tupleValue(String.format("{\"geoV\":[%.8f, %.8f]}", lon, lat)).get("geoV")); + + // A Well-Known Text POINT in the “POINT(longitude latitude)” format. + assertEquals( + expectedGeoPointValue, + tupleValue(String.format("{\"geoV\":\"POINT (%.8f %.8f)\"}", lon, lat)).get("geoV")); + + // GeoJSON format, where the coordinates are in the [longitude, latitude] format + assertEquals( + expectedGeoPointValue, + tupleValue( + String.format( + "{\"geoV\":{\"type\":\"Point\",\"coordinates\":[%.8f,%.8f]}}", lon, lat)) + .get("geoV")); + assertEquals( new OpenSearchExprGeoPointValue(42.60355556, -97.25263889), constructFromObject("geoV", "42.60355556,-97.25263889")); @@ -686,38 +736,23 @@ public void constructGeoPoint() { @Test public void constructGeoPointFromUnsupportedFormatShouldThrowException() { - IllegalStateException exception = + OpenSearchParseException exception = assertThrows( - IllegalStateException.class, - () -> tupleValue("{\"geoV\":[42.60355556,-97.25263889]}").get("geoV")); - assertEquals( - "geo point must in format of {\"lat\": number, \"lon\": number}", exception.getMessage()); + OpenSearchParseException.class, + () -> tupleValue("{\"geoV\": [42.60355556, false]}").get("geoV")); + assertEquals("lat must be a number, got false", exception.getMessage()); exception = assertThrows( - IllegalStateException.class, + OpenSearchParseException.class, () -> tupleValue("{\"geoV\":{\"lon\":-97.25263889}}").get("geoV")); - assertEquals( - "geo point must in format of {\"lat\": number, \"lon\": number}", exception.getMessage()); - - exception = - assertThrows( - IllegalStateException.class, - () -> tupleValue("{\"geoV\":{\"lat\":-97.25263889}}").get("geoV")); - assertEquals( - "geo point must in format of {\"lat\": number, \"lon\": number}", exception.getMessage()); + assertEquals("field [lat] missing", exception.getMessage()); exception = assertThrows( - IllegalStateException.class, + OpenSearchParseException.class, () -> tupleValue("{\"geoV\":{\"lat\":true,\"lon\":-97.25263889}}").get("geoV")); - assertEquals("latitude must be number value, but got value: true", exception.getMessage()); - - exception = - assertThrows( - IllegalStateException.class, - () -> tupleValue("{\"geoV\":{\"lat\":42.60355556,\"lon\":false}}").get("geoV")); - assertEquals("longitude must be number value, but got value: false", exception.getMessage()); + assertEquals("lat 
must be a number", exception.getMessage()); } @Test From 14a80a95fb5fa36781b46f28ba52d406927e21c0 Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Thu, 1 Aug 2024 09:48:33 -0700 Subject: [PATCH 10/10] Add AsyncQueryRequestContext to FlintIndexMetadataService/FlintIndexStateModelService (#2879) Signed-off-by: Tomoyuki Morita --- .../asyncquery/AsyncQueryExecutorService.java | 2 +- .../AsyncQueryExecutorServiceImpl.java | 4 +- .../spark/dispatcher/AsyncQueryHandler.java | 5 +- .../spark/dispatcher/BatchQueryHandler.java | 5 +- .../sql/spark/dispatcher/IndexDMLHandler.java | 16 +++- .../dispatcher/InteractiveQueryHandler.java | 5 +- .../spark/dispatcher/RefreshQueryHandler.java | 10 +- .../dispatcher/SparkQueryDispatcher.java | 6 +- .../dispatcher/StreamingQueryHandler.java | 5 +- .../flint/FlintIndexMetadataService.java | 11 ++- .../flint/FlintIndexStateModelService.java | 46 ++++++++- .../spark/flint/operation/FlintIndexOp.java | 47 ++++++---- .../flint/operation/FlintIndexOpAlter.java | 8 +- .../flint/operation/FlintIndexOpCancel.java | 6 +- .../flint/operation/FlintIndexOpDrop.java | 6 +- .../flint/operation/FlintIndexOpVacuum.java | 6 +- .../asyncquery/AsyncQueryCoreIntegTest.java | 19 ++-- .../AsyncQueryExecutorServiceImplTest.java | 8 +- .../spark/dispatcher/IndexDMLHandlerTest.java | 11 ++- .../dispatcher/SparkQueryDispatcherTest.java | 15 ++- .../flint/operation/FlintIndexOpTest.java | 37 ++++++-- .../operation/FlintIndexOpVacuumTest.java | 94 ++++++++++++++----- .../FlintStreamingJobHouseKeeperTask.java | 13 ++- .../flint/FlintIndexMetadataServiceImpl.java | 9 +- ...OpenSearchFlintIndexStateModelService.java | 13 ++- ...ransportCancelAsyncQueryRequestAction.java | 5 +- ...AsyncQueryExecutorServiceImplSpecTest.java | 6 +- .../spark/asyncquery/IndexQuerySpecTest.java | 23 +++-- .../asyncquery/model/MockFlintSparkJob.java | 11 ++- .../FlintStreamingJobHouseKeeperTaskTest.java | 8 +- .../FlintIndexMetadataServiceImplTest.java | 27 +++++- ...SearchFlintIndexStateModelServiceTest.java | 13 ++- ...portCancelAsyncQueryRequestActionTest.java | 16 +++- 33 files changed, 384 insertions(+), 132 deletions(-) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java index d38c8554ae..b0c339e93d 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java @@ -39,5 +39,5 @@ CreateAsyncQueryResponse createAsyncQuery( * @param queryId queryId. * @return {@link String} cancelledQueryId. 
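   * @param asyncQueryRequestContext request context passed through to the query handler.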
*/ - String cancelQuery(String queryId); + String cancelQuery(String queryId, AsyncQueryRequestContext asyncQueryRequestContext); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index 6d3d5b6765..d304766465 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -106,11 +106,11 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) { } @Override - public String cancelQuery(String queryId) { + public String cancelQuery(String queryId, AsyncQueryRequestContext asyncQueryRequestContext) { Optional asyncQueryJobMetadata = asyncQueryJobMetadataStorageService.getJobMetadata(queryId); if (asyncQueryJobMetadata.isPresent()) { - return sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata.get()); + return sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata.get(), asyncQueryRequestContext); } throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId)); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java index d61ac17aa3..2bafd88b85 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java @@ -12,6 +12,7 @@ import com.amazonaws.services.emrserverless.model.JobRunState; import org.json.JSONObject; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; @@ -54,7 +55,9 @@ protected abstract JSONObject getResponseFromResultIndex( protected abstract JSONObject getResponseFromExecutor( AsyncQueryJobMetadata asyncQueryJobMetadata); - public abstract String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata); + public abstract String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext); public abstract DispatchQueryResponse submit( DispatchQueryRequest request, DispatchQueryContext context); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java index 2654f83aad..661ebe27fc 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java @@ -16,6 +16,7 @@ import org.json.JSONObject; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; @@ -61,7 +62,9 @@ protected JSONObject 
getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJob } @Override - public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + public String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { emrServerlessClient.cancelJobRun( asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId(), false); return asyncQueryJobMetadata.getQueryId(); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java index e8413f469c..f8217142c3 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java @@ -62,9 +62,11 @@ public DispatchQueryResponse submit( long startTime = System.currentTimeMillis(); try { IndexQueryDetails indexDetails = context.getIndexQueryDetails(); - FlintIndexMetadata indexMetadata = getFlintIndexMetadata(indexDetails); + FlintIndexMetadata indexMetadata = + getFlintIndexMetadata(indexDetails, context.getAsyncQueryRequestContext()); - getIndexOp(dispatchQueryRequest, indexDetails).apply(indexMetadata); + getIndexOp(dispatchQueryRequest, indexDetails) + .apply(indexMetadata, context.getAsyncQueryRequestContext()); String asyncQueryId = storeIndexDMLResult( @@ -146,9 +148,11 @@ private FlintIndexOp getIndexOp( } } - private FlintIndexMetadata getFlintIndexMetadata(IndexQueryDetails indexDetails) { + private FlintIndexMetadata getFlintIndexMetadata( + IndexQueryDetails indexDetails, AsyncQueryRequestContext asyncQueryRequestContext) { Map indexMetadataMap = - flintIndexMetadataService.getFlintIndexMetadata(indexDetails.openSearchIndexName()); + flintIndexMetadataService.getFlintIndexMetadata( + indexDetails.openSearchIndexName(), asyncQueryRequestContext); if (!indexMetadataMap.containsKey(indexDetails.openSearchIndexName())) { throw new IllegalStateException( String.format( @@ -174,7 +178,9 @@ protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJob } @Override - public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + public String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { throw new IllegalArgumentException("can't cancel index DML query"); } } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java index ec43bccf11..9a9baedde2 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java @@ -16,6 +16,7 @@ import org.json.JSONObject; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; @@ -71,7 +72,9 @@ protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJob } @Override - public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) 
{ + public String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { String queryId = asyncQueryJobMetadata.getQueryId(); getStatementByQueryId( asyncQueryJobMetadata.getSessionId(), diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java index 99984ecc46..38145a143e 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java @@ -8,6 +8,7 @@ import java.util.Map; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; @@ -51,10 +52,13 @@ public RefreshQueryHandler( } @Override - public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + public String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { String datasourceName = asyncQueryJobMetadata.getDatasourceName(); Map indexMetadataMap = - flintIndexMetadataService.getFlintIndexMetadata(asyncQueryJobMetadata.getIndexName()); + flintIndexMetadataService.getFlintIndexMetadata( + asyncQueryJobMetadata.getIndexName(), asyncQueryRequestContext); if (!indexMetadataMap.containsKey(asyncQueryJobMetadata.getIndexName())) { throw new IllegalStateException( String.format( @@ -62,7 +66,7 @@ public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { } FlintIndexMetadata indexMetadata = indexMetadataMap.get(asyncQueryJobMetadata.getIndexName()); FlintIndexOp jobCancelOp = flintIndexOpFactory.getCancel(datasourceName); - jobCancelOp.apply(indexMetadata); + jobCancelOp.apply(indexMetadata, asyncQueryRequestContext); return asyncQueryJobMetadata.getQueryId(); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index a424db4c34..a6fdd3f102 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -162,9 +162,11 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) .getQueryResponse(asyncQueryJobMetadata); } - public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + public String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata) - .cancelJob(asyncQueryJobMetadata); + .cancelJob(asyncQueryJobMetadata, asyncQueryRequestContext); } private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery( diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java index 2fbf2466da..80d4be27cf 100644 --- 
a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java @@ -12,6 +12,7 @@ import java.util.Map; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; @@ -46,7 +47,9 @@ public StreamingQueryHandler( } @Override - public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + public String cancelJob( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { throw new IllegalArgumentException( "can't cancel index DML query, using ALTER auto_refresh=off statement to stop job, using" + " VACUUM statement to stop job and delete data"); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataService.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataService.java index ad274e429e..ece14c2a7b 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataService.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataService.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.flint; import java.util.Map; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; /** Interface for FlintIndexMetadataReader */ @@ -15,16 +16,22 @@ public interface FlintIndexMetadataService { * Retrieves a map of {@link FlintIndexMetadata} instances matching the specified index pattern. * * @param indexPattern indexPattern. + * @param asyncQueryRequestContext request context passed to AsyncQueryExecutorService * @return A map of {@link FlintIndexMetadata} instances against indexName, each providing * metadata access for a matched index. Returns an empty list if no indices match the pattern. */ - Map getFlintIndexMetadata(String indexPattern); + Map getFlintIndexMetadata( + String indexPattern, AsyncQueryRequestContext asyncQueryRequestContext); /** * Performs validation and updates flint index to manual refresh. * * @param indexName indexName. * @param flintIndexOptions flintIndexOptions. + * @param asyncQueryRequestContext request context passed to AsyncQueryExecutorService */ - void updateIndexToManualRefresh(String indexName, FlintIndexOptions flintIndexOptions); + void updateIndexToManualRefresh( + String indexName, + FlintIndexOptions flintIndexOptions, + AsyncQueryRequestContext asyncQueryRequestContext); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModelService.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModelService.java index 94647f4e07..3872f2d5a0 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModelService.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModelService.java @@ -6,20 +6,58 @@ package org.opensearch.sql.spark.flint; import java.util.Optional; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; /** * Abstraction over flint index state storage. 
Flint index state will maintain the status of each * flint index. */ public interface FlintIndexStateModelService { - FlintIndexStateModel createFlintIndexStateModel(FlintIndexStateModel flintIndexStateModel); - Optional getFlintIndexStateModel(String id, String datasourceName); + /** + * Create Flint index state record + * + * @param flintIndexStateModel the model to be saved + * @param asyncQueryRequestContext the request context passed to AsyncQueryExecutorService + * @return saved model + */ + FlintIndexStateModel createFlintIndexStateModel( + FlintIndexStateModel flintIndexStateModel, AsyncQueryRequestContext asyncQueryRequestContext); + /** + * Get Flint index state record + * + * @param id ID(latestId) of the Flint index state record + * @param datasourceName datasource name + * @param asyncQueryRequestContext the request context passed to AsyncQueryExecutorService + * @return retrieved model + */ + Optional getFlintIndexStateModel( + String id, String datasourceName, AsyncQueryRequestContext asyncQueryRequestContext); + + /** + * Update Flint index state record + * + * @param flintIndexStateModel the model to be updated + * @param flintIndexState new state + * @param datasourceName Datasource name + * @param asyncQueryRequestContext the request context passed to AsyncQueryExecutorService + * @return Updated model + */ FlintIndexStateModel updateFlintIndexState( FlintIndexStateModel flintIndexStateModel, FlintIndexState flintIndexState, - String datasourceName); + String datasourceName, + AsyncQueryRequestContext asyncQueryRequestContext); - boolean deleteFlintIndexStateModel(String id, String datasourceName); + /** + * Delete Flint index state record + * + * @param id ID(latestId) of the Flint index state record + * @param datasourceName datasource name + * @param asyncQueryRequestContext the request context passed to AsyncQueryExecutorService + * @return true if deleted, otherwise false + */ + boolean deleteFlintIndexStateModel( + String id, String datasourceName, AsyncQueryRequestContext asyncQueryRequestContext); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java index 244f4aee11..78d217b8dc 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java @@ -16,6 +16,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.jetbrains.annotations.NotNull; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.flint.FlintIndexMetadata; @@ -33,30 +34,33 @@ public abstract class FlintIndexOp { private final EMRServerlessClientFactory emrServerlessClientFactory; /** Apply operation on {@link FlintIndexMetadata} */ - public void apply(FlintIndexMetadata metadata) { + public void apply( + FlintIndexMetadata metadata, AsyncQueryRequestContext asyncQueryRequestContext) { // todo, remove this logic after IndexState feature is enabled in Flint. 
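    // Flint indexes created before the IndexState feature carry no latestId, so the
    // operation is applied directly; otherwise the optimistic-concurrency path below
    // validates the initial state, moves it to a transitioning state, runs the
    // operation, and commits, rolling back the transient log on failure.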
    Optional<String> latestId = metadata.getLatestId();
    if (latestId.isEmpty()) {
-      takeActionWithoutOCC(metadata);
+      takeActionWithoutOCC(metadata, asyncQueryRequestContext);
    } else {
-      FlintIndexStateModel initialFlintIndexStateModel = getFlintIndexStateModel(latestId.get());
+      FlintIndexStateModel initialFlintIndexStateModel =
+          getFlintIndexStateModel(latestId.get(), asyncQueryRequestContext);
      // 1.validate state.
      validFlintIndexInitialState(initialFlintIndexStateModel);

      // 2.begin, move to transitioning state
      FlintIndexStateModel transitionedFlintIndexStateModel =
-          moveToTransitioningState(initialFlintIndexStateModel);
+          moveToTransitioningState(initialFlintIndexStateModel, asyncQueryRequestContext);

      // 3.runOp
      try {
-        runOp(metadata, transitionedFlintIndexStateModel);
-        commit(transitionedFlintIndexStateModel);
+        runOp(metadata, transitionedFlintIndexStateModel, asyncQueryRequestContext);
+        commit(transitionedFlintIndexStateModel, asyncQueryRequestContext);
      } catch (Throwable e) {
        LOG.error("Rolling back transient log due to transaction operation failure", e);
        try {
          flintIndexStateModelService.updateFlintIndexState(
              transitionedFlintIndexStateModel,
              initialFlintIndexStateModel.getIndexState(),
-              datasourceName);
+              datasourceName,
+              asyncQueryRequestContext);
        } catch (Exception ex) {
          LOG.error("Failed to rollback transient log", ex);
        }
@@ -66,9 +70,11 @@ public void apply(FlintIndexMetadata metadata) {
  }

  @NotNull
-  private FlintIndexStateModel getFlintIndexStateModel(String latestId) {
+  private FlintIndexStateModel getFlintIndexStateModel(
+      String latestId, AsyncQueryRequestContext asyncQueryRequestContext) {
    Optional<FlintIndexStateModel> flintIndexOptional =
-        flintIndexStateModelService.getFlintIndexStateModel(latestId, datasourceName);
+        flintIndexStateModelService.getFlintIndexStateModel(
+            latestId, datasourceName, asyncQueryRequestContext);
    if (flintIndexOptional.isEmpty()) {
      String errorMsg = String.format(Locale.ROOT, "no state found. docId: %s", latestId);
      LOG.error(errorMsg);
@@ -77,7 +83,8 @@ private FlintIndexStateModel getFlintIndexStateModel(String latestId) {
    return flintIndexOptional.get();
  }

-  private void takeActionWithoutOCC(FlintIndexMetadata metadata) {
+  private void takeActionWithoutOCC(
+      FlintIndexMetadata metadata, AsyncQueryRequestContext asyncQueryRequestContext) {
    // take action without occ.
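    // No state document exists to load here, so a placeholder model is built purely
    // so that runOp() receives the same argument shape as in the OCC path.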
FlintIndexStateModel fakeModel = FlintIndexStateModel.builder() @@ -89,7 +96,7 @@ private void takeActionWithoutOCC(FlintIndexMetadata metadata) { .lastUpdateTime(System.currentTimeMillis()) .error("") .build(); - runOp(metadata, fakeModel); + runOp(metadata, fakeModel, asyncQueryRequestContext); } private void validFlintIndexInitialState(FlintIndexStateModel flintIndex) { @@ -103,13 +110,14 @@ private void validFlintIndexInitialState(FlintIndexStateModel flintIndex) { } } - private FlintIndexStateModel moveToTransitioningState(FlintIndexStateModel flintIndex) { + private FlintIndexStateModel moveToTransitioningState( + FlintIndexStateModel flintIndex, AsyncQueryRequestContext asyncQueryRequestContext) { LOG.debug("Moving to transitioning state before committing."); FlintIndexState transitioningState = transitioningState(); try { flintIndex = flintIndexStateModelService.updateFlintIndexState( - flintIndex, transitioningState(), datasourceName); + flintIndex, transitioningState(), datasourceName, asyncQueryRequestContext); } catch (Exception e) { String errorMsg = String.format(Locale.ROOT, "Moving to transition state:%s failed.", transitioningState); @@ -119,16 +127,18 @@ private FlintIndexStateModel moveToTransitioningState(FlintIndexStateModel flint return flintIndex; } - private void commit(FlintIndexStateModel flintIndex) { + private void commit( + FlintIndexStateModel flintIndex, AsyncQueryRequestContext asyncQueryRequestContext) { LOG.debug("Committing the transaction and moving to stable state."); FlintIndexState stableState = stableState(); try { if (stableState == FlintIndexState.NONE) { LOG.info("Deleting index state with docId: " + flintIndex.getLatestId()); flintIndexStateModelService.deleteFlintIndexStateModel( - flintIndex.getLatestId(), datasourceName); + flintIndex.getLatestId(), datasourceName, asyncQueryRequestContext); } else { - flintIndexStateModelService.updateFlintIndexState(flintIndex, stableState, datasourceName); + flintIndexStateModelService.updateFlintIndexState( + flintIndex, stableState, datasourceName, asyncQueryRequestContext); } } catch (Exception e) { String errorMsg = @@ -192,7 +202,10 @@ public void cancelStreamingJob(FlintIndexStateModel flintIndexStateModel) /** get transitioningState */ abstract FlintIndexState transitioningState(); - abstract void runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintIndex); + abstract void runOp( + FlintIndexMetadata flintIndexMetadata, + FlintIndexStateModel flintIndex, + AsyncQueryRequestContext asyncQueryRequestContext); /** get stableState */ abstract FlintIndexState stableState(); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java index 9955320253..4a00195ebf 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java @@ -8,6 +8,7 @@ import lombok.SneakyThrows; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; import org.opensearch.sql.spark.flint.FlintIndexMetadata; @@ -48,11 +49,14 @@ FlintIndexState transitioningState() { @SneakyThrows @Override - void 
runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintIndexStateModel) {
+  void runOp(
+      FlintIndexMetadata flintIndexMetadata,
+      FlintIndexStateModel flintIndexStateModel,
+      AsyncQueryRequestContext asyncQueryRequestContext) {
     LOG.debug(
         "Running alter index operation for index: {}",
         flintIndexMetadata.getOpensearchIndexName());
     this.flintIndexMetadataService.updateIndexToManualRefresh(
-        flintIndexMetadata.getOpensearchIndexName(), flintIndexOptions);
+        flintIndexMetadata.getOpensearchIndexName(), flintIndexOptions, asyncQueryRequestContext);
     cancelStreamingJob(flintIndexStateModel);
   }
diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java
index 02c8e39c66..504a8f93c9 100644
--- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java
@@ -8,6 +8,7 @@
 import lombok.SneakyThrows;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
 import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
 import org.opensearch.sql.spark.flint.FlintIndexMetadata;
 import org.opensearch.sql.spark.flint.FlintIndexState;
@@ -38,7 +39,10 @@ FlintIndexState transitioningState() {
   /** cancel EMR-S job, wait for cancelled state up to 15s. */
   @SneakyThrows
   @Override
-  void runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintIndexStateModel) {
+  void runOp(
+      FlintIndexMetadata flintIndexMetadata,
+      FlintIndexStateModel flintIndexStateModel,
+      AsyncQueryRequestContext asyncQueryRequestContext) {
     LOG.debug(
         "Performing drop index operation for index: {}",
         flintIndexMetadata.getOpensearchIndexName());
diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java
index 6613c29870..fc9b644fc7 100644
--- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java
@@ -8,6 +8,7 @@
 import lombok.SneakyThrows;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
 import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
 import org.opensearch.sql.spark.flint.FlintIndexMetadata;
 import org.opensearch.sql.spark.flint.FlintIndexState;
@@ -40,7 +41,10 @@ FlintIndexState transitioningState() {
   /** cancel EMR-S job, wait for cancelled state up to 15s. */
   @SneakyThrows
   @Override
-  void runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintIndexStateModel) {
+  void runOp(
+      FlintIndexMetadata flintIndexMetadata,
+      FlintIndexStateModel flintIndexStateModel,
+      AsyncQueryRequestContext asyncQueryRequestContext) {
     LOG.debug(
         "Performing drop index operation for index: {}",
         flintIndexMetadata.getOpensearchIndexName());
diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java
index a0ef955adf..06aaf8ef9f 100644
--- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java
@@ -7,6 +7,7 @@
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
 import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
 import org.opensearch.sql.spark.flint.FlintIndexClient;
 import org.opensearch.sql.spark.flint.FlintIndexMetadata;
@@ -42,7 +43,10 @@ FlintIndexState transitioningState() {
   }
 
   @Override
-  public void runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintIndex) {
+  public void runOp(
+      FlintIndexMetadata flintIndexMetadata,
+      FlintIndexStateModel flintIndex,
+      AsyncQueryRequestContext asyncQueryRequestContext) {
     LOG.info("Vacuuming Flint index {}", flintIndexMetadata.getOpensearchIndexName());
     flintIndexClient.deleteIndex(flintIndexMetadata.getOpensearchIndexName());
   }
diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java
index d82d3bdab7..ff92762a7c 100644
--- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java
+++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java
@@ -249,7 +249,8 @@ public void createAlterIndexQuery() {
     assertNull(response.getSessionId());
     verifyGetQueryIdCalled();
     verify(flintIndexMetadataService)
-        .updateIndexToManualRefresh(eq(indexName), flintIndexOptionsArgumentCaptor.capture());
+        .updateIndexToManualRefresh(
+            eq(indexName), flintIndexOptionsArgumentCaptor.capture(), eq(asyncQueryRequestContext));
     FlintIndexOptions flintIndexOptions = flintIndexOptionsArgumentCaptor.getValue();
     assertFalse(flintIndexOptions.autoRefresh());
     verifyCancelJobRunCalled();
@@ -430,7 +431,7 @@ public void cancelInteractiveQuery() {
     when(statementStorageService.updateStatementState(statementModel, StatementState.CANCELLED))
         .thenReturn(canceledStatementModel);
 
-    String result = asyncQueryExecutorService.cancelQuery(QUERY_ID);
+    String result = asyncQueryExecutorService.cancelQuery(QUERY_ID, asyncQueryRequestContext);
 
     assertEquals(QUERY_ID, result);
     verify(statementStorageService).updateStatementState(statementModel, StatementState.CANCELLED);
@@ -441,14 +442,15 @@ public void cancelIndexDMLQuery() {
     givenJobMetadataExists(getBaseAsyncQueryJobMetadataBuilder().jobId(DROP_INDEX_JOB_ID));
 
     assertThrows(
-        IllegalArgumentException.class, () -> asyncQueryExecutorService.cancelQuery(QUERY_ID));
+        IllegalArgumentException.class,
+        () -> asyncQueryExecutorService.cancelQuery(QUERY_ID, asyncQueryRequestContext));
   }
 
   @Test
   public void cancelRefreshQuery() {
     givenJobMetadataExists(
         getBaseAsyncQueryJobMetadataBuilder().jobType(JobType.BATCH).indexName(INDEX_NAME));
-    when(flintIndexMetadataService.getFlintIndexMetadata(INDEX_NAME))
+    when(flintIndexMetadataService.getFlintIndexMetadata(INDEX_NAME, asyncQueryRequestContext))
         .thenReturn(
             ImmutableMap.of(
                 INDEX_NAME,
@@ -463,7 +465,7 @@ public void cancelRefreshQuery() {
             new GetJobRunResult()
                 .withJobRun(new JobRun().withJobRunId(JOB_ID).withState("Cancelled")));
 
-    String result = asyncQueryExecutorService.cancelQuery(QUERY_ID);
+    String result = asyncQueryExecutorService.cancelQuery(QUERY_ID, asyncQueryRequestContext);
 
     assertEquals(QUERY_ID, result);
     verifyCancelJobRunCalled();
@@ -475,7 +477,8 @@ public void cancelStreamingQuery() {
     givenJobMetadataExists(getBaseAsyncQueryJobMetadataBuilder().jobType(JobType.STREAMING));
 
     assertThrows(
-        IllegalArgumentException.class, () -> asyncQueryExecutorService.cancelQuery(QUERY_ID));
+        IllegalArgumentException.class,
+        () -> asyncQueryExecutorService.cancelQuery(QUERY_ID, asyncQueryRequestContext));
   }
 
   @Test
@@ -483,7 +486,7 @@ public void cancelBatchQuery() {
     givenJobMetadataExists(getBaseAsyncQueryJobMetadataBuilder().jobId(JOB_ID));
     givenCancelJobRunSucceed();
 
-    String result = asyncQueryExecutorService.cancelQuery(QUERY_ID);
+    String result = asyncQueryExecutorService.cancelQuery(QUERY_ID, asyncQueryRequestContext);
 
     assertEquals(QUERY_ID, result);
     verifyCancelJobRunCalled();
@@ -500,7 +503,7 @@ private void givenSparkExecutionEngineConfigIsSupplied() {
   }
 
   private void givenFlintIndexMetadataExists(String indexName) {
-    when(flintIndexMetadataService.getFlintIndexMetadata(indexName))
+    when(flintIndexMetadataService.getFlintIndexMetadata(indexName, asyncQueryRequestContext))
         .thenReturn(
             ImmutableMap.of(
                 indexName,
diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java
index dbc51bb0ad..5d8d9a3b63 100644
--- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java
+++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java
@@ -206,7 +206,8 @@ void testCancelJobWithJobNotFound() {
 
     AsyncQueryNotFoundException asyncQueryNotFoundException =
         Assertions.assertThrows(
-            AsyncQueryNotFoundException.class, () -> jobExecutorService.cancelQuery(EMR_JOB_ID));
+            AsyncQueryNotFoundException.class,
+            () -> jobExecutorService.cancelQuery(EMR_JOB_ID, asyncQueryRequestContext));
 
     Assertions.assertEquals(
         "QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage());
@@ -218,9 +219,10 @@ void testCancelJobWithJobNotFound() {
   void testCancelJob() {
     when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID))
         .thenReturn(Optional.of(getAsyncQueryJobMetadata()));
-    when(sparkQueryDispatcher.cancelJob(getAsyncQueryJobMetadata())).thenReturn(EMR_JOB_ID);
+    when(sparkQueryDispatcher.cancelJob(getAsyncQueryJobMetadata(), asyncQueryRequestContext))
+        .thenReturn(EMR_JOB_ID);
 
-    String jobId = jobExecutorService.cancelQuery(EMR_JOB_ID);
+    String jobId = jobExecutorService.cancelQuery(EMR_JOB_ID, asyncQueryRequestContext);
 
     Assertions.assertEquals(EMR_JOB_ID, jobId);
     verifyNoInteractions(sparkExecutionEngineConfigSupplier);
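Taken together, the hunks above change the public cancellation path: AsyncQueryExecutorService.cancelQuery and SparkQueryDispatcher.cancelJob now take an AsyncQueryRequestContext in addition to the query identifier. A minimal caller-side sketch of the new signature (illustrative only; the service wiring and the queryId variable are assumed here and are not part of this patch):

    // Callers with no request-scoped data pass the null-object context,
    // mirroring what TransportCancelAsyncQueryRequestAction does later in
    // this patch.
    AsyncQueryRequestContext context = new NullAsyncQueryRequestContext();
    String cancelledQueryId = asyncQueryExecutorService.cancelQuery(queryId, context);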
diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java
index 877d6ec32b..9a3c4e663e 100644
--- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java
+++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java
@@ -7,6 +7,7 @@
 import static org.junit.jupiter.api.Assertions.*;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 import static org.opensearch.sql.datasource.model.DataSourceStatus.ACTIVE;
@@ -27,6 +28,7 @@
 import org.mockito.junit.jupiter.MockitoExtension;
 import org.opensearch.sql.datasource.model.DataSourceMetadata;
 import org.opensearch.sql.datasource.model.DataSourceType;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
 import org.opensearch.sql.spark.config.SparkSubmitParameterModifier;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
@@ -50,6 +52,7 @@ class IndexDMLHandlerTest {
   @Mock private IndexDMLResultStorageService indexDMLResultStorageService;
   @Mock private FlintIndexOpFactory flintIndexOpFactory;
   @Mock private SparkSubmitParameterModifier sparkSubmitParameterModifier;
+  @Mock private AsyncQueryRequestContext asyncQueryRequestContext;
 
   @InjectMocks IndexDMLHandler indexDMLHandler;
 
@@ -82,8 +85,10 @@ public void testWhenIndexDetailsAreNotFound() {
             .queryId(QUERY_ID)
             .dataSourceMetadata(metadata)
             .indexQueryDetails(indexQueryDetails)
+            .asyncQueryRequestContext(asyncQueryRequestContext)
             .build();
-    Mockito.when(flintIndexMetadataService.getFlintIndexMetadata(any()))
+    Mockito.when(
+            flintIndexMetadataService.getFlintIndexMetadata(any(), eq(asyncQueryRequestContext)))
         .thenReturn(new HashMap<>());
 
     DispatchQueryResponse dispatchQueryResponse =
@@ -107,10 +112,12 @@ public void testWhenIndexDetailsWithInvalidQueryActionType() {
             .queryId(QUERY_ID)
             .dataSourceMetadata(metadata)
             .indexQueryDetails(indexQueryDetails)
+            .asyncQueryRequestContext(asyncQueryRequestContext)
             .build();
     HashMap flintMetadataMap = new HashMap<>();
     flintMetadataMap.put(indexQueryDetails.openSearchIndexName(), flintIndexMetadata);
-    when(flintIndexMetadataService.getFlintIndexMetadata(indexQueryDetails.openSearchIndexName()))
+    when(flintIndexMetadataService.getFlintIndexMetadata(
+            indexQueryDetails.openSearchIndexName(), asyncQueryRequestContext))
         .thenReturn(flintMetadataMap);
 
     indexDMLHandler.submit(dispatchQueryRequest, dispatchQueryContext);
diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
index a7a79c758e..592309cb75 100644
--- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
+++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
@@ -871,7 +871,8 @@ void testCancelJob() {
                 .withJobRunId(EMR_JOB_ID)
                 .withApplicationId(EMRS_APPLICATION_ID));
 
-    String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata());
+    String queryId =
+        sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext);
 
     Assertions.assertEquals(QUERY_ID, queryId);
   }
@@ -884,7 +885,8 @@ void testCancelQueryWithSession() {
 
     String queryId =
         sparkQueryDispatcher.cancelJob(
-            asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID));
+            asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID),
+            asyncQueryRequestContext);
 
     verifyNoInteractions(emrServerlessClient);
     verify(statement, times(1)).cancel();
@@ -900,7 +902,8 @@ void testCancelQueryWithInvalidSession() {
             IllegalArgumentException.class,
             () ->
                 sparkQueryDispatcher.cancelJob(
-                    asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, "invalid")));
+                    asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, "invalid"),
+                    asyncQueryRequestContext));
 
     verifyNoInteractions(emrServerlessClient);
     verifyNoInteractions(session);
@@ -916,7 +919,8 @@ void testCancelQueryWithInvalidStatementId() {
             IllegalArgumentException.class,
             () ->
                 sparkQueryDispatcher.cancelJob(
-                    asyncQueryJobMetadataWithSessionId("invalid", MOCK_SESSION_ID)));
+                    asyncQueryJobMetadataWithSessionId("invalid", MOCK_SESSION_ID),
+                    asyncQueryRequestContext));
 
     verifyNoInteractions(emrServerlessClient);
     verifyNoInteractions(statement);
@@ -933,7 +937,8 @@ void testCancelQueryWithNoSessionId() {
                 .withJobRunId(EMR_JOB_ID)
                 .withApplicationId(EMRS_APPLICATION_ID));
 
-    String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata());
+    String queryId =
+        sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext);
 
     Assertions.assertEquals(QUERY_ID, queryId);
   }
diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java
index 0c82733ae6..8105629822 100644
--- a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java
+++ b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java
@@ -16,6 +16,7 @@
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
 import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
 import org.opensearch.sql.spark.execution.xcontent.XContentSerializerUtil;
 import org.opensearch.sql.spark.flint.FlintIndexMetadata;
@@ -28,21 +29,26 @@ public class FlintIndexOpTest {
 
   @Mock private FlintIndexStateModelService flintIndexStateModelService;
   @Mock private EMRServerlessClientFactory mockEmrServerlessClientFactory;
+  @Mock private AsyncQueryRequestContext asyncQueryRequestContext;
 
   @Test
   public void testApplyWithTransitioningStateFailure() {
     FlintIndexMetadata metadata = mock(FlintIndexMetadata.class);
     when(metadata.getLatestId()).thenReturn(Optional.of("latestId"));
     FlintIndexStateModel fakeModel = getFlintIndexStateModel(metadata);
-    when(flintIndexStateModelService.getFlintIndexStateModel(eq("latestId"), any()))
+    when(flintIndexStateModelService.getFlintIndexStateModel(
+            eq("latestId"), any(), eq(asyncQueryRequestContext)))
         .thenReturn(Optional.of(fakeModel));
-    when(flintIndexStateModelService.updateFlintIndexState(any(), any(), any()))
+    when(flintIndexStateModelService.updateFlintIndexState(
+            any(), any(), any(), eq(asyncQueryRequestContext)))
         .thenThrow(new RuntimeException("Transitioning state failed"));
     FlintIndexOp flintIndexOp =
         new TestFlintIndexOp(flintIndexStateModelService, "myS3", mockEmrServerlessClientFactory);
 
     IllegalStateException illegalStateException =
-        Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata));
+        Assertions.assertThrows(
+            IllegalStateException.class,
+            () -> flintIndexOp.apply(metadata, asyncQueryRequestContext));
 
     Assertions.assertEquals(
         "Moving to transition state:DELETING failed.", illegalStateException.getMessage());
@@ -53,9 +59,11 @@ public void testApplyWithCommitFailure() {
     FlintIndexMetadata metadata = mock(FlintIndexMetadata.class);
     when(metadata.getLatestId()).thenReturn(Optional.of("latestId"));
     FlintIndexStateModel fakeModel = getFlintIndexStateModel(metadata);
-    when(flintIndexStateModelService.getFlintIndexStateModel(eq("latestId"), any()))
+    when(flintIndexStateModelService.getFlintIndexStateModel(
+            eq("latestId"), any(), eq(asyncQueryRequestContext)))
        .thenReturn(Optional.of(fakeModel));
-    when(flintIndexStateModelService.updateFlintIndexState(any(), any(), any()))
+    when(flintIndexStateModelService.updateFlintIndexState(
+            any(), any(), any(), eq(asyncQueryRequestContext)))
         .thenReturn(
             FlintIndexStateModel.copy(fakeModel, XContentSerializerUtil.buildMetadata(1, 2)))
         .thenThrow(new RuntimeException("Commit state failed"))
@@ -65,7 +73,9 @@ public void testApplyWithCommitFailure() {
         new TestFlintIndexOp(flintIndexStateModelService, "myS3", mockEmrServerlessClientFactory);
 
     IllegalStateException illegalStateException =
-        Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata));
+        Assertions.assertThrows(
+            IllegalStateException.class,
+            () -> flintIndexOp.apply(metadata, asyncQueryRequestContext));
 
     Assertions.assertEquals(
         "commit failed. target stable state: [DELETED]", illegalStateException.getMessage());
@@ -76,9 +86,11 @@ public void testApplyWithRollBackFailure() {
     FlintIndexMetadata metadata = mock(FlintIndexMetadata.class);
     when(metadata.getLatestId()).thenReturn(Optional.of("latestId"));
     FlintIndexStateModel fakeModel = getFlintIndexStateModel(metadata);
-    when(flintIndexStateModelService.getFlintIndexStateModel(eq("latestId"), any()))
+    when(flintIndexStateModelService.getFlintIndexStateModel(
+            eq("latestId"), any(), eq(asyncQueryRequestContext)))
        .thenReturn(Optional.of(fakeModel));
-    when(flintIndexStateModelService.updateFlintIndexState(any(), any(), any()))
+    when(flintIndexStateModelService.updateFlintIndexState(
+            any(), any(), any(), eq(asyncQueryRequestContext)))
         .thenReturn(
             FlintIndexStateModel.copy(fakeModel, XContentSerializerUtil.buildMetadata(1, 2)))
         .thenThrow(new RuntimeException("Commit state failed"))
@@ -87,7 +99,9 @@ public void testApplyWithRollBackFailure() {
         new TestFlintIndexOp(flintIndexStateModelService, "myS3", mockEmrServerlessClientFactory);
 
     IllegalStateException illegalStateException =
-        Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata));
+        Assertions.assertThrows(
+            IllegalStateException.class,
+            () -> flintIndexOp.apply(metadata, asyncQueryRequestContext));
 
     Assertions.assertEquals(
         "commit failed. target stable state: [DELETED]", illegalStateException.getMessage());
@@ -125,7 +139,10 @@ FlintIndexState transitioningState() {
   }
 
   @Override
-  void runOp(FlintIndexMetadata flintIndexMetadata, FlintIndexStateModel flintIndex) {}
+  void runOp(
+      FlintIndexMetadata flintIndexMetadata,
+      FlintIndexStateModel flintIndex,
+      AsyncQueryRequestContext asyncQueryRequestContext) {}
 
   @Override
   FlintIndexState stableState() {
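A note on the Mockito churn in the test hunks above: once any argument of a stubbed call uses a matcher, Mockito requires every argument to be a matcher, so threading the new context parameter through forces previously raw arguments such as "latestId" into eq(...). A self-contained illustration of that rule (the LookupService interface here is hypothetical, not from this patch):

    import static org.mockito.ArgumentMatchers.any;
    import static org.mockito.ArgumentMatchers.eq;
    import static org.mockito.Mockito.mock;
    import static org.mockito.Mockito.when;

    interface LookupService {
      String find(String id, Object context);
    }

    class MatcherDemo {
      void demo() {
        LookupService service = mock(LookupService.class);
        // Valid: every argument is a matcher.
        when(service.find(eq("latestId"), any())).thenReturn("ok");
        // Invalid: a raw value mixed with a matcher makes Mockito throw
        // InvalidUseOfMatchersException when the stubbing runs:
        // when(service.find("latestId", any())).thenReturn("ok");
      }
    }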
diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuumTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuumTest.java
index 60fa13dc93..26858c18fe 100644
--- a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuumTest.java
+++ b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuumTest.java
@@ -16,6 +16,7 @@
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
 import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
 import org.opensearch.sql.spark.flint.FlintIndexClient;
 import org.opensearch.sql.spark.flint.FlintIndexMetadata;
@@ -38,6 +39,7 @@ class FlintIndexOpVacuumTest {
   @Mock EMRServerlessClientFactory emrServerlessClientFactory;
   @Mock FlintIndexStateModel flintIndexStateModel;
   @Mock FlintIndexStateModel transitionedFlintIndexStateModel;
+  @Mock AsyncQueryRequestContext asyncQueryRequestContext;
 
   RuntimeException testException = new RuntimeException("Test Exception");
 
@@ -55,110 +57,154 @@ public void setUp() {
 
   @Test
   public void testApplyWithEmptyLatestId() {
-    flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITHOUT_LATEST_ID);
+    flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITHOUT_LATEST_ID, asyncQueryRequestContext);
 
     verify(flintIndexClient).deleteIndex(INDEX_NAME);
   }
 
   @Test
   public void testApplyWithFlintIndexStateNotFound() {
-    when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME))
+    when(flintIndexStateModelService.getFlintIndexStateModel(
+            LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext))
        .thenReturn(Optional.empty());
 
     assertThrows(
         IllegalStateException.class,
-        () -> flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID));
+        () ->
+            flintIndexOpVacuum.apply(
+                FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext));
   }
 
   @Test
   public void testApplyWithNotDeletedState() {
-    when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME))
+    when(flintIndexStateModelService.getFlintIndexStateModel(
+            LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext))
         .thenReturn(Optional.of(flintIndexStateModel));
     when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.ACTIVE);
 
     assertThrows(
         IllegalStateException.class,
-        () -> flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID));
+        () ->
+            flintIndexOpVacuum.apply(
+                FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext));
   }
 
   @Test
   public void testApplyWithUpdateFlintIndexStateThrow() {
-    when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME))
+    when(flintIndexStateModelService.getFlintIndexStateModel(
+            LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext))
         .thenReturn(Optional.of(flintIndexStateModel));
     when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.DELETED);
     when(flintIndexStateModelService.updateFlintIndexState(
-            flintIndexStateModel, FlintIndexState.VACUUMING, DATASOURCE_NAME))
+            flintIndexStateModel,
+            FlintIndexState.VACUUMING,
+            DATASOURCE_NAME,
+            asyncQueryRequestContext))
        .thenThrow(testException);
 
     assertThrows(
         IllegalStateException.class,
-        () -> flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID));
+        () ->
+            flintIndexOpVacuum.apply(
+                FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext));
   }
 
   @Test
   public void testApplyWithRunOpThrow() {
-    when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME))
+    when(flintIndexStateModelService.getFlintIndexStateModel(
+            LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext))
         .thenReturn(Optional.of(flintIndexStateModel));
     when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.DELETED);
     when(flintIndexStateModelService.updateFlintIndexState(
-            flintIndexStateModel, FlintIndexState.VACUUMING, DATASOURCE_NAME))
+            flintIndexStateModel,
+            FlintIndexState.VACUUMING,
+            DATASOURCE_NAME,
+            asyncQueryRequestContext))
         .thenReturn(transitionedFlintIndexStateModel);
     doThrow(testException).when(flintIndexClient).deleteIndex(INDEX_NAME);
 
     assertThrows(
-        Exception.class, () -> flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID));
+        Exception.class,
+        () ->
+            flintIndexOpVacuum.apply(
+                FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext));
 
     verify(flintIndexStateModelService)
         .updateFlintIndexState(
-            transitionedFlintIndexStateModel, FlintIndexState.DELETED, DATASOURCE_NAME);
+            transitionedFlintIndexStateModel,
+            FlintIndexState.DELETED,
+            DATASOURCE_NAME,
+            asyncQueryRequestContext);
   }
 
   @Test
   public void testApplyWithRunOpThrowAndRollbackThrow() {
-    when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME))
+    when(flintIndexStateModelService.getFlintIndexStateModel(
+            LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext))
        .thenReturn(Optional.of(flintIndexStateModel));
     when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.DELETED);
     when(flintIndexStateModelService.updateFlintIndexState(
-            flintIndexStateModel, FlintIndexState.VACUUMING, DATASOURCE_NAME))
+            flintIndexStateModel,
+            FlintIndexState.VACUUMING,
+            DATASOURCE_NAME,
+            asyncQueryRequestContext))
         .thenReturn(transitionedFlintIndexStateModel);
     doThrow(testException).when(flintIndexClient).deleteIndex(INDEX_NAME);
     when(flintIndexStateModelService.updateFlintIndexState(
-            transitionedFlintIndexStateModel, FlintIndexState.DELETED, DATASOURCE_NAME))
+            transitionedFlintIndexStateModel,
+            FlintIndexState.DELETED,
+            DATASOURCE_NAME,
+            asyncQueryRequestContext))
        .thenThrow(testException);
 
     assertThrows(
-        Exception.class, () -> flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID));
+        Exception.class,
+        () ->
+            flintIndexOpVacuum.apply(
+                FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext));
   }
 
   @Test
   public void testApplyWithDeleteFlintIndexStateModelThrow() {
-    when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME))
+    when(flintIndexStateModelService.getFlintIndexStateModel(
+            LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext))
        .thenReturn(Optional.of(flintIndexStateModel));
     when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.DELETED);
     when(flintIndexStateModelService.updateFlintIndexState(
-            flintIndexStateModel, FlintIndexState.VACUUMING, DATASOURCE_NAME))
+            flintIndexStateModel,
+            FlintIndexState.VACUUMING,
+            DATASOURCE_NAME,
+            asyncQueryRequestContext))
         .thenReturn(transitionedFlintIndexStateModel);
-    when(flintIndexStateModelService.deleteFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME))
+    when(flintIndexStateModelService.deleteFlintIndexStateModel(
+            LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext))
        .thenThrow(testException);
 
     assertThrows(
         IllegalStateException.class,
-        () -> flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID));
+        () ->
+            flintIndexOpVacuum.apply(
+                FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext));
   }
 
   @Test
   public void testApplyHappyPath() {
-    when(flintIndexStateModelService.getFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME))
+    when(flintIndexStateModelService.getFlintIndexStateModel(
+            LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext))
        .thenReturn(Optional.of(flintIndexStateModel));
     when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.DELETED);
     when(flintIndexStateModelService.updateFlintIndexState(
-            flintIndexStateModel, FlintIndexState.VACUUMING, DATASOURCE_NAME))
+            flintIndexStateModel,
+            FlintIndexState.VACUUMING,
+            DATASOURCE_NAME,
+            asyncQueryRequestContext))
         .thenReturn(transitionedFlintIndexStateModel);
     when(transitionedFlintIndexStateModel.getLatestId()).thenReturn(LATEST_ID);
 
-    flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID);
+    flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext);
 
-    verify(flintIndexStateModelService).deleteFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME);
+    verify(flintIndexStateModelService)
+        .deleteFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext);
     verify(flintIndexClient).deleteIndex(INDEX_NAME);
   }
 }
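All of the vacuum tests above exercise FlintIndexOp.apply, which now carries the request context into every step. Judging from the expectations in FlintIndexOpTest and FlintIndexOpVacuumTest, apply behaves as a template method: read the state model, move it to a transitioning state, run the subclass hook, then commit the stable state, attempting a rollback when the hook fails. A condensed, self-contained sketch of that shape (hypothetical generic types, inferred from the tests; not the verbatim implementation):

    // Simplified template-method shape inferred from the tests above.
    abstract class StateOp<S, C> {
      final S apply(S current, C context) {
        S transitioning = transition(current, context); // e.g. DELETED -> VACUUMING
        try {
          runOp(transitioning, context);                // subclass hook, e.g. delete the index
          return commit(transitioning, context);        // e.g. delete the state document
        } catch (Exception e) {
          rollback(transitioning, context);             // best-effort restore
          throw new IllegalStateException("state operation failed", e);
        }
      }

      abstract S transition(S current, C context);

      abstract void runOp(S state, C context);

      abstract S commit(S state, C context);

      abstract void rollback(S state, C context);
    }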
LOGGER.info("Attempting to drop auto refresh index: {}", autoRefreshIndex); - flintIndexOpFactory.getDrop(datasourceName).apply(flintIndexMetadata); + flintIndexOpFactory + .getDrop(datasourceName) + .apply(flintIndexMetadata, nullAsyncQueryRequestContext); LOGGER.info("Successfully dropped index: {}", autoRefreshIndex); } @@ -100,7 +105,9 @@ private void alterAutoRefreshIndex( LOGGER.info("Attempting to alter index: {}", autoRefreshIndex); FlintIndexOptions flintIndexOptions = new FlintIndexOptions(); flintIndexOptions.setOption(FlintIndexOptions.AUTO_REFRESH, "false"); - flintIndexOpFactory.getAlter(flintIndexOptions, datasourceName).apply(flintIndexMetadata); + flintIndexOpFactory + .getAlter(flintIndexOptions, datasourceName) + .apply(flintIndexMetadata, nullAsyncQueryRequestContext); LOGGER.info("Successfully altered index: {}", autoRefreshIndex); } @@ -119,7 +126,7 @@ private String getDataSourceName(FlintIndexMetadata flintIndexMetadata) { private Map getAllAutoRefreshIndices() { Map flintIndexMetadataHashMap = - flintIndexMetadataService.getFlintIndexMetadata("flint_*"); + flintIndexMetadataService.getFlintIndexMetadata("flint_*", nullAsyncQueryRequestContext); return flintIndexMetadataHashMap.entrySet().stream() .filter(entry -> entry.getValue().getFlintIndexOptions().autoRefresh()) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); diff --git a/async-query/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImpl.java b/async-query/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImpl.java index 893b33b39d..b8352d15b2 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImpl.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImpl.java @@ -33,6 +33,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; import org.opensearch.client.Client; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; /** Implementation of {@link FlintIndexMetadataService} */ @@ -49,7 +50,8 @@ public class FlintIndexMetadataServiceImpl implements FlintIndexMetadataService Arrays.asList(AUTO_REFRESH, INCREMENTAL_REFRESH, WATERMARK_DELAY, CHECKPOINT_LOCATION)); @Override - public Map getFlintIndexMetadata(String indexPattern) { + public Map getFlintIndexMetadata( + String indexPattern, AsyncQueryRequestContext asyncQueryRequestContext) { GetMappingsResponse mappingsResponse = client.admin().indices().prepareGetMappings().setIndices(indexPattern).get(); Map indexMetadataMap = new HashMap<>(); @@ -73,7 +75,10 @@ public Map getFlintIndexMetadata(String indexPattern } @Override - public void updateIndexToManualRefresh(String indexName, FlintIndexOptions flintIndexOptions) { + public void updateIndexToManualRefresh( + String indexName, + FlintIndexOptions flintIndexOptions, + AsyncQueryRequestContext asyncQueryRequestContext) { GetMappingsResponse mappingsResponse = client.admin().indices().prepareGetMappings().setIndices(indexName).get(); Map flintMetadataMap = diff --git a/async-query/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java b/async-query/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java index 5781c3e44b..eba338e912 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java +++ 
diff --git a/async-query/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImpl.java b/async-query/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImpl.java
index 893b33b39d..b8352d15b2 100644
--- a/async-query/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImpl.java
+++ b/async-query/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImpl.java
@@ -33,6 +33,7 @@
 import org.apache.logging.log4j.Logger;
 import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse;
 import org.opensearch.client.Client;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
 import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions;
 
 /** Implementation of {@link FlintIndexMetadataService} */
@@ -49,7 +50,8 @@ public class FlintIndexMetadataServiceImpl implements FlintIndexMetadataService
       Arrays.asList(AUTO_REFRESH, INCREMENTAL_REFRESH, WATERMARK_DELAY, CHECKPOINT_LOCATION));
 
   @Override
-  public Map getFlintIndexMetadata(String indexPattern) {
+  public Map getFlintIndexMetadata(
+      String indexPattern, AsyncQueryRequestContext asyncQueryRequestContext) {
     GetMappingsResponse mappingsResponse =
         client.admin().indices().prepareGetMappings().setIndices(indexPattern).get();
     Map indexMetadataMap = new HashMap<>();
@@ -73,7 +75,10 @@ public Map getFlintIndexMetadata(String indexPattern
   }
 
   @Override
-  public void updateIndexToManualRefresh(String indexName, FlintIndexOptions flintIndexOptions) {
+  public void updateIndexToManualRefresh(
+      String indexName,
+      FlintIndexOptions flintIndexOptions,
+      AsyncQueryRequestContext asyncQueryRequestContext) {
     GetMappingsResponse mappingsResponse =
         client.admin().indices().prepareGetMappings().setIndices(indexName).get();
     Map flintMetadataMap =
diff --git a/async-query/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java b/async-query/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java
index 5781c3e44b..eba338e912 100644
--- a/async-query/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java
+++ b/async-query/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java
@@ -7,6 +7,7 @@
 
 import java.util.Optional;
 import lombok.RequiredArgsConstructor;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
 import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil;
 import org.opensearch.sql.spark.execution.statestore.StateStore;
 import org.opensearch.sql.spark.execution.xcontent.FlintIndexStateModelXContentSerializer;
@@ -20,7 +21,8 @@ public class OpenSearchFlintIndexStateModelService implements FlintIndexStateMod
   public FlintIndexStateModel updateFlintIndexState(
       FlintIndexStateModel flintIndexStateModel,
       FlintIndexState flintIndexState,
-      String datasourceName) {
+      String datasourceName,
+      AsyncQueryRequestContext asyncQueryRequestContext) {
     return stateStore.updateState(
         flintIndexStateModel,
         flintIndexState,
@@ -29,14 +31,16 @@ public FlintIndexStateModel updateFlintIndexState(
   }
 
   @Override
-  public Optional getFlintIndexStateModel(String id, String datasourceName) {
+  public Optional getFlintIndexStateModel(
+      String id, String datasourceName, AsyncQueryRequestContext asyncQueryRequestContext) {
     return stateStore.get(
         id, serializer::fromXContent, OpenSearchStateStoreUtil.getIndexName(datasourceName));
   }
 
   @Override
   public FlintIndexStateModel createFlintIndexStateModel(
-      FlintIndexStateModel flintIndexStateModel) {
+      FlintIndexStateModel flintIndexStateModel,
+      AsyncQueryRequestContext asyncQueryRequestContext) {
     return stateStore.create(
         flintIndexStateModel.getId(),
         flintIndexStateModel,
@@ -45,7 +49,8 @@ public FlintIndexStateModel createFlintIndexStateModel(
   }
 
   @Override
-  public boolean deleteFlintIndexStateModel(String id, String datasourceName) {
+  public boolean deleteFlintIndexStateModel(
+      String id, String datasourceName, AsyncQueryRequestContext asyncQueryRequestContext) {
     return stateStore.delete(id, OpenSearchStateStoreUtil.getIndexName(datasourceName));
   }
 }
diff --git a/async-query/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java b/async-query/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java
index 232a280db5..ce80351f70 100644
--- a/async-query/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java
+++ b/async-query/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java
@@ -13,6 +13,7 @@
 import org.opensearch.common.inject.Inject;
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
+import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext;
 import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest;
 import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse;
 import org.opensearch.tasks.Task;
@@ -41,7 +42,9 @@ protected void doExecute(
       CancelAsyncQueryActionRequest request,
       ActionListener listener) {
     try {
-      String jobId = asyncQueryExecutorService.cancelQuery(request.getQueryId());
+      String jobId =
+          asyncQueryExecutorService.cancelQuery(
+              request.getQueryId(), new NullAsyncQueryRequestContext());
       listener.onResponse(
           new CancelAsyncQueryActionResponse(
               String.format("Deleted async query with id: %s", jobId)));
diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java
index 3ff806bf50..ede8a348b4 100644
--- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java
+++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java
@@ -71,7 +71,8 @@ public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() {
     emrsClient.getJobRunResultCalled(1);
 
     // 3. cancel async query.
-    String cancelQueryId = asyncQueryExecutorService.cancelQuery(response.getQueryId());
+    String cancelQueryId =
+        asyncQueryExecutorService.cancelQuery(response.getQueryId(), asyncQueryRequestContext);
     assertEquals(response.getQueryId(), cancelQueryId);
     emrsClient.cancelJobRunCalled(1);
   }
@@ -163,7 +164,8 @@ public void withSessionCreateAsyncQueryThenGetResultThenCancel() {
     assertEquals(StatementState.WAITING.getState(), asyncQueryResults.getStatus());
 
     // 3. cancel async query.
-    String cancelQueryId = asyncQueryExecutorService.cancelQuery(response.getQueryId());
+    String cancelQueryId =
+        asyncQueryExecutorService.cancelQuery(response.getQueryId(), asyncQueryRequestContext);
     assertEquals(response.getQueryId(), cancelQueryId);
   }
diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java
index 2eed7b13a0..29c42446b3 100644
--- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java
+++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java
@@ -152,7 +152,9 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) {
               IllegalArgumentException exception =
                   assertThrows(
                       IllegalArgumentException.class,
-                      () -> asyncQueryExecutorService.cancelQuery(response.getQueryId()));
+                      () ->
+                          asyncQueryExecutorService.cancelQuery(
+                              response.getQueryId(), asyncQueryRequestContext));
               assertEquals("can't cancel index DML query", exception.getMessage());
             });
   }
@@ -326,7 +328,9 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) {
               IllegalArgumentException exception =
                   assertThrows(
                       IllegalArgumentException.class,
-                      () -> asyncQueryExecutorService.cancelQuery(response.getQueryId()));
+                      () ->
+                          asyncQueryExecutorService.cancelQuery(
+                              response.getQueryId(), asyncQueryRequestContext));
               assertEquals("can't cancel index DML query", exception.getMessage());
             });
   }
@@ -901,7 +905,9 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) {
               IllegalArgumentException exception =
                   assertThrows(
                       IllegalArgumentException.class,
-                      () -> asyncQueryExecutorService.cancelQuery(response.getQueryId()));
+                      () ->
+                          asyncQueryExecutorService.cancelQuery(
+                              response.getQueryId(), asyncQueryRequestContext));
               assertEquals(
                   "can't cancel index DML query, using ALTER auto_refresh=off statement to stop"
                       + " job, using VACUUM statement to stop job and delete data",
@@ -944,7 +950,9 @@ public GetJobRunResult getJobRunResult(
           flintIndexJob.refreshing();
 
           // 2. Cancel query
-          String cancelResponse = asyncQueryExecutorService.cancelQuery(response.getQueryId());
+          String cancelResponse =
+              asyncQueryExecutorService.cancelQuery(
+                  response.getQueryId(), asyncQueryRequestContext);
           assertNotNull(cancelResponse);
           assertTrue(clusterService.state().routingTable().hasIndex(mockDS.indexName));
@@ -992,7 +1000,9 @@ public GetJobRunResult getJobRunResult(
           IllegalStateException illegalStateException =
               Assertions.assertThrows(
                   IllegalStateException.class,
-                  () -> asyncQueryExecutorService.cancelQuery(response.getQueryId()));
+                  () ->
+                      asyncQueryExecutorService.cancelQuery(
+                          response.getQueryId(), asyncQueryRequestContext));
           Assertions.assertEquals(
               "Transaction failed as flint index is not in a valid state.",
               illegalStateException.getMessage());
@@ -1038,6 +1048,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) {
           // 2. Cancel query
           Assertions.assertThrows(
               IllegalStateException.class,
-              () -> asyncQueryExecutorService.cancelQuery(response.getQueryId()));
+              () ->
+                  asyncQueryExecutorService.cancelQuery(response.getQueryId(), asyncQueryRequestContext));
   }
 }
diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java
index 6c82188ee6..0dc8f02820 100644
--- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java
+++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java
@@ -18,6 +18,7 @@ public class MockFlintSparkJob {
   private FlintIndexStateModel stateModel;
   private FlintIndexStateModelService flintIndexStateModelService;
   private String datasource;
+  private AsyncQueryRequestContext asyncQueryRequestContext = new NullAsyncQueryRequestContext();
 
   public MockFlintSparkJob(
       FlintIndexStateModelService flintIndexStateModelService, String latestId, String datasource) {
@@ -34,12 +35,15 @@ public MockFlintSparkJob(
             .lastUpdateTime(System.currentTimeMillis())
             .error("")
             .build();
-    stateModel = flintIndexStateModelService.createFlintIndexStateModel(stateModel);
+    stateModel =
+        flintIndexStateModelService.createFlintIndexStateModel(
+            stateModel, asyncQueryRequestContext);
   }
 
   public void transition(FlintIndexState newState) {
     stateModel =
-        flintIndexStateModelService.updateFlintIndexState(stateModel, newState, datasource);
+        flintIndexStateModelService.updateFlintIndexState(
+            stateModel, newState, datasource, asyncQueryRequestContext);
   }
 
   public void refreshing() {
@@ -68,7 +72,8 @@ public void deleted() {
 
   public void assertState(FlintIndexState expected) {
     Optional stateModelOpt =
-        flintIndexStateModelService.getFlintIndexStateModel(stateModel.getId(), datasource);
+        flintIndexStateModelService.getFlintIndexStateModel(
+            stateModel.getId(), datasource, asyncQueryRequestContext);
     assertTrue(stateModelOpt.isPresent());
     assertEquals(expected, stateModelOpt.get().getIndexState());
   }
diff --git a/async-query/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java b/async-query/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java
index c5964a61e3..0a3a180932 100644
--- a/async-query/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java
+++ b/async-query/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java
@@ -20,6 +20,7 @@
 import org.opensearch.sql.legacy.metrics.MetricName;
 import org.opensearch.sql.legacy.metrics.Metrics;
 import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceSpec;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
 import org.opensearch.sql.spark.asyncquery.model.MockFlintIndex;
 import org.opensearch.sql.spark.asyncquery.model.MockFlintSparkJob;
 import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions;
@@ -393,13 +394,16 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) {
     FlintIndexMetadataService flintIndexMetadataService =
         new FlintIndexMetadataService() {
           @Override
-          public Map getFlintIndexMetadata(String indexPattern) {
+          public Map getFlintIndexMetadata(
+              String indexPattern, AsyncQueryRequestContext asyncQueryRequestContext) {
             throw new RuntimeException("Couldn't fetch details from ElasticSearch");
           }
 
           @Override
           public void updateIndexToManualRefresh(
-              String indexName, FlintIndexOptions flintIndexOptions) {}
+              String indexName,
+              FlintIndexOptions flintIndexOptions,
+              AsyncQueryRequestContext asyncQueryRequestContext) {}
         };
     FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask =
         new FlintStreamingJobHouseKeeperTask(
diff --git a/async-query/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImplTest.java b/async-query/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImplTest.java
index f6baa82dd2..b1321cc132 100644
--- a/async-query/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImplTest.java
+++ b/async-query/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImplTest.java
@@ -29,6 +29,7 @@
 import org.opensearch.core.xcontent.DeprecationHandler;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
 import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions;
 import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName;
 import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType;
@@ -39,6 +40,8 @@ public class FlintIndexMetadataServiceImplTest {
   @Mock(answer = RETURNS_DEEP_STUBS)
   private Client client;
 
+  @Mock private AsyncQueryRequestContext asyncQueryRequestContext;
+
   @SneakyThrows
   @Test
   void testGetJobIdFromFlintSkippingIndexMetadata() {
@@ -56,8 +59,11 @@ void testGetJobIdFromFlintSkippingIndexMetadata() {
             .indexQueryActionType(IndexQueryActionType.DROP)
             .indexType(FlintIndexType.SKIPPING)
             .build();
+
     Map indexMetadataMap =
-        flintIndexMetadataService.getFlintIndexMetadata(indexQueryDetails.openSearchIndexName());
+        flintIndexMetadataService.getFlintIndexMetadata(
+            indexQueryDetails.openSearchIndexName(), asyncQueryRequestContext);
+
     Assertions.assertEquals(
         "00fhelvq7peuao0",
         indexMetadataMap.get(indexQueryDetails.openSearchIndexName()).getJobId());
@@ -80,8 +86,11 @@ void testGetJobIdFromFlintSkippingIndexMetadataWithIndexState() {
             .indexQueryActionType(IndexQueryActionType.DROP)
             .indexType(FlintIndexType.SKIPPING)
             .build();
+
     Map indexMetadataMap =
-        flintIndexMetadataService.getFlintIndexMetadata(indexQueryDetails.openSearchIndexName());
+        flintIndexMetadataService.getFlintIndexMetadata(
+            indexQueryDetails.openSearchIndexName(), asyncQueryRequestContext);
+
     FlintIndexMetadata metadata = indexMetadataMap.get(indexQueryDetails.openSearchIndexName());
     Assertions.assertEquals("00fhelvq7peuao0", metadata.getJobId());
   }
@@ -103,8 +112,11 @@ void testGetJobIdFromFlintCoveringIndexMetadata() {
             .indexType(FlintIndexType.COVERING)
             .build();
     FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client);
+
     Map indexMetadataMap =
-        flintIndexMetadataService.getFlintIndexMetadata(indexQueryDetails.openSearchIndexName());
+        flintIndexMetadataService.getFlintIndexMetadata(
+            indexQueryDetails.openSearchIndexName(), asyncQueryRequestContext);
+
     Assertions.assertEquals(
         "00fdmvv9hp8u0o0q",
         indexMetadataMap.get(indexQueryDetails.openSearchIndexName()).getJobId());
@@ -126,8 +138,11 @@ void testGetJobIDWithNPEException() {
             .indexQueryActionType(IndexQueryActionType.DROP)
             .indexType(FlintIndexType.COVERING)
             .build();
+
     Map flintIndexMetadataMap =
-        flintIndexMetadataService.getFlintIndexMetadata(indexQueryDetails.openSearchIndexName());
+        flintIndexMetadataService.getFlintIndexMetadata(
+            indexQueryDetails.openSearchIndexName(), asyncQueryRequestContext);
+
     Assertions.assertFalse(
         flintIndexMetadataMap.containsKey("flint_mys3_default_http_logs_cv1_index"));
   }
@@ -148,8 +163,10 @@ void testGetJobIDWithNPEExceptionForMultipleIndices() {
     indexMappingsMap.put(indexName, mappings);
     mockNodeClientIndicesMappings("flint_mys3*", indexMappingsMap);
     FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client);
+
     Map flintIndexMetadataMap =
-        flintIndexMetadataService.getFlintIndexMetadata("flint_mys3*");
+        flintIndexMetadataService.getFlintIndexMetadata("flint_mys3*", asyncQueryRequestContext);
+
     Assertions.assertFalse(
         flintIndexMetadataMap.containsKey("flint_mys3_default_http_logs_cv1_index"));
     Assertions.assertTrue(
diff --git a/async-query/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelServiceTest.java b/async-query/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelServiceTest.java
index 977f77b397..4faff41fe6 100644
--- a/async-query/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelServiceTest.java
+++ b/async-query/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelServiceTest.java
@@ -16,6 +16,7 @@
 import org.mockito.InjectMocks;
 import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
 import org.opensearch.sql.spark.execution.statestore.StateStore;
 import org.opensearch.sql.spark.execution.xcontent.FlintIndexStateModelXContentSerializer;
 
@@ -30,6 +31,7 @@ public class OpenSearchFlintIndexStateModelServiceTest {
   @Mock FlintIndexState flintIndexState;
   @Mock FlintIndexStateModel responseFlintIndexStateModel;
   @Mock FlintIndexStateModelXContentSerializer flintIndexStateModelXContentSerializer;
+  @Mock AsyncQueryRequestContext asyncQueryRequestContext;
 
   @InjectMocks OpenSearchFlintIndexStateModelService openSearchFlintIndexStateModelService;
 
@@ -40,7 +42,7 @@ void updateFlintIndexState() {
 
     FlintIndexStateModel result =
         openSearchFlintIndexStateModelService.updateFlintIndexState(
-            flintIndexStateModel, flintIndexState, DATASOURCE);
+            flintIndexStateModel, flintIndexState, DATASOURCE, asyncQueryRequestContext);
 
     assertEquals(responseFlintIndexStateModel, result);
   }
@@ -51,7 +53,8 @@ void getFlintIndexStateModel() {
         .thenReturn(Optional.of(responseFlintIndexStateModel));
 
     Optional result =
-        openSearchFlintIndexStateModelService.getFlintIndexStateModel("ID", DATASOURCE);
+        openSearchFlintIndexStateModelService.getFlintIndexStateModel(
+            "ID", DATASOURCE, asyncQueryRequestContext);
 
     assertEquals(responseFlintIndexStateModel, result.get());
   }
@@ -63,7 +66,8 @@ void createFlintIndexStateModel() {
     when(flintIndexStateModel.getDatasourceName()).thenReturn(DATASOURCE);
 
     FlintIndexStateModel result =
-        openSearchFlintIndexStateModelService.createFlintIndexStateModel(flintIndexStateModel);
+        openSearchFlintIndexStateModelService.createFlintIndexStateModel(
+            flintIndexStateModel, asyncQueryRequestContext);
 
     assertEquals(responseFlintIndexStateModel, result);
   }
@@ -73,7 +77,8 @@ void deleteFlintIndexStateModel() {
     when(mockStateStore.delete(any(), any())).thenReturn(true);
 
     boolean result =
-        openSearchFlintIndexStateModelService.deleteFlintIndexStateModel(ID, DATASOURCE);
+        openSearchFlintIndexStateModelService.deleteFlintIndexStateModel(
+            ID, DATASOURCE, asyncQueryRequestContext);
 
     assertTrue(result);
   }
diff --git a/async-query/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java b/async-query/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java
index 2ff76b9b57..a2581fdea2 100644
--- a/async-query/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java
+++ b/async-query/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java
@@ -7,6 +7,8 @@
 
 package org.opensearch.sql.spark.transport;
 
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.when;
 import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID;
@@ -24,6 +26,7 @@
 import org.opensearch.action.support.ActionFilters;
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
+import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext;
 import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest;
 import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse;
 import org.opensearch.tasks.Task;
@@ -36,7 +39,6 @@ public class TransportCancelAsyncQueryRequestActionTest {
   @Mock private TransportCancelAsyncQueryRequestAction action;
   @Mock private Task task;
   @Mock private ActionListener actionListener;
-
   @Mock private AsyncQueryExecutorServiceImpl asyncQueryExecutorService;
 
   @Captor
@@ -54,8 +56,12 @@ public void setUp() {
   @Test
   public void testDoExecute() {
     CancelAsyncQueryActionRequest request = new CancelAsyncQueryActionRequest(EMR_JOB_ID);
-    when(asyncQueryExecutorService.cancelQuery(EMR_JOB_ID)).thenReturn(EMR_JOB_ID);
+    when(asyncQueryExecutorService.cancelQuery(
+            eq(EMR_JOB_ID), any(NullAsyncQueryRequestContext.class)))
+        .thenReturn(EMR_JOB_ID);
+
     action.doExecute(task, request, actionListener);
+
     Mockito.verify(actionListener).onResponse(deleteJobActionResponseArgumentCaptor.capture());
     CancelAsyncQueryActionResponse cancelAsyncQueryActionResponse =
         deleteJobActionResponseArgumentCaptor.getValue();
@@ -66,8 +72,12 @@ public void testDoExecute() {
   @Test
   public void testDoExecuteWithException() {
     CancelAsyncQueryActionRequest request = new CancelAsyncQueryActionRequest(EMR_JOB_ID);
-    doThrow(new RuntimeException("Error")).when(asyncQueryExecutorService).cancelQuery(EMR_JOB_ID);
+    doThrow(new RuntimeException("Error"))
+        .when(asyncQueryExecutorService)
+        .cancelQuery(eq(EMR_JOB_ID), any(NullAsyncQueryRequestContext.class));
+
     action.doExecute(task, request, actionListener);
+
     Mockito.verify(actionListener).onFailure(exceptionArgumentCaptor.capture());
     Exception exception = exceptionArgumentCaptor.getValue();
     Assertions.assertTrue(exception instanceof RuntimeException);