From e1ee3b130ab19b99bffd2c80c5c9bb40ac9bb145 Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Thu, 29 Aug 2024 15:07:29 -0700 Subject: [PATCH 01/11] Add AsyncQueryRequestContext to update/get in StatementStorageService (#2943) Signed-off-by: Tomoyuki Morita --- .../asyncquery/AsyncQueryExecutorService.java | 3 +- .../AsyncQueryExecutorServiceImpl.java | 6 ++- .../spark/dispatcher/AsyncQueryHandler.java | 15 ++++-- .../spark/dispatcher/BatchQueryHandler.java | 8 +++- .../sql/spark/dispatcher/IndexDMLHandler.java | 8 +++- .../dispatcher/InteractiveQueryHandler.java | 22 ++++++--- .../dispatcher/SparkQueryDispatcher.java | 6 ++- .../execution/session/InteractiveSession.java | 6 ++- .../sql/spark/execution/session/Session.java | 2 +- .../spark/execution/statement/Statement.java | 3 +- .../statestore/StatementStorageService.java | 7 ++- .../asyncquery/AsyncQueryCoreIntegTest.java | 17 ++++--- .../AsyncQueryExecutorServiceImplTest.java | 14 ++++-- .../spark/dispatcher/IndexDMLHandlerTest.java | 4 +- .../dispatcher/SparkQueryDispatcherTest.java | 21 +++++---- .../OpenSearchStatementStorageService.java | 12 +++-- .../TransportGetAsyncQueryResultAction.java | 3 +- ...AsyncQueryExecutorServiceImplSpecTest.java | 24 ++++++---- .../AsyncQueryGetResultSpecTest.java | 13 ++++-- .../asyncquery/IndexQuerySpecAlterTest.java | 46 ++++++++++++------- .../spark/asyncquery/IndexQuerySpecTest.java | 39 ++++++++++------ .../asyncquery/IndexQuerySpecVacuumTest.java | 3 +- .../execution/statement/StatementTest.java | 19 +++++--- ...ransportGetAsyncQueryResultActionTest.java | 20 ++++++-- 24 files changed, 217 insertions(+), 104 deletions(-) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java index b0c339e93d..1240545acd 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java @@ -31,7 +31,8 @@ CreateAsyncQueryResponse createAsyncQuery( * @param queryId queryId. * @return {@link AsyncQueryExecutionResponse} */ - AsyncQueryExecutionResponse getAsyncQueryResults(String queryId); + AsyncQueryExecutionResponse getAsyncQueryResults( + String queryId, AsyncQueryRequestContext asyncQueryRequestContext); /** * Cancels running async query and returns the cancelled queryId. 
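A caller-side illustration of the signature change above: getAsyncQueryResults now takes an AsyncQueryRequestContext alongside the query id, and the context is threaded down through SparkQueryDispatcher, Session, and Statement into StatementStorageService. The sketch below is hypothetical and not part of the patch; it assumes the NullAsyncQueryRequestContext no-op implementation that the transport layer in this patch passes, and it assumes AsyncQueryRequestContext lives in the same model package as its siblings in the diff.

import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext;

class GetAsyncQueryResultsExample {
  private final AsyncQueryExecutorService asyncQueryExecutorService;

  GetAsyncQueryResultsExample(AsyncQueryExecutorService asyncQueryExecutorService) {
    this.asyncQueryExecutorService = asyncQueryExecutorService;
  }

  AsyncQueryExecutionResponse fetchResults(String queryId) {
    // Callers without a real request-scoped context can pass the no-op implementation,
    // mirroring what TransportGetAsyncQueryResultAction does later in this patch.
    AsyncQueryRequestContext context = new NullAsyncQueryRequestContext();
    return asyncQueryExecutorService.getAsyncQueryResults(queryId, context);
  }
}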
diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index d304766465..5933343ba4 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -74,12 +74,14 @@ public CreateAsyncQueryResponse createAsyncQuery( } @Override - public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) { + public AsyncQueryExecutionResponse getAsyncQueryResults( + String queryId, AsyncQueryRequestContext asyncQueryRequestContext) { Optional jobMetadata = asyncQueryJobMetadataStorageService.getJobMetadata(queryId); if (jobMetadata.isPresent()) { String sessionId = jobMetadata.get().getSessionId(); - JSONObject jsonObject = sparkQueryDispatcher.getQueryResponse(jobMetadata.get()); + JSONObject jsonObject = + sparkQueryDispatcher.getQueryResponse(jobMetadata.get(), asyncQueryRequestContext); if (JobRunState.SUCCESS.toString().equals(jsonObject.getString(STATUS_FIELD))) { DefaultSparkSqlFunctionResponseHandle sparkSqlFunctionResponseHandle = new DefaultSparkSqlFunctionResponseHandle(jsonObject); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java index 2bafd88b85..441846d678 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java @@ -21,8 +21,10 @@ /** Process async query request. 
*/ public abstract class AsyncQueryHandler { - public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) { - JSONObject result = getResponseFromResultIndex(asyncQueryJobMetadata); + public JSONObject getQueryResponse( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { + JSONObject result = getResponseFromResultIndex(asyncQueryJobMetadata, asyncQueryRequestContext); if (result.has(DATA_FIELD)) { JSONObject items = result.getJSONObject(DATA_FIELD); @@ -35,7 +37,8 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) result.put(ERROR_FIELD, error); return result; } else { - JSONObject statement = getResponseFromExecutor(asyncQueryJobMetadata); + JSONObject statement = + getResponseFromExecutor(asyncQueryJobMetadata, asyncQueryRequestContext); // Consider statement still running if state is success but query result unavailable if (isSuccessState(statement)) { @@ -50,10 +53,12 @@ private boolean isSuccessState(JSONObject statement) { } protected abstract JSONObject getResponseFromResultIndex( - AsyncQueryJobMetadata asyncQueryJobMetadata); + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext); protected abstract JSONObject getResponseFromExecutor( - AsyncQueryJobMetadata asyncQueryJobMetadata); + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext); public abstract String cancelJob( AsyncQueryJobMetadata asyncQueryJobMetadata, diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java index 661ebe27fc..bce1918631 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java @@ -41,7 +41,9 @@ public class BatchQueryHandler extends AsyncQueryHandler { protected final SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider; @Override - protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) { + protected JSONObject getResponseFromResultIndex( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { // either empty json when the result is not available or data with status // Fetch from Result Index return jobExecutionResponseReader.getResultWithJobId( @@ -49,7 +51,9 @@ protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQuery } @Override - protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJobMetadata) { + protected JSONObject getResponseFromExecutor( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { JSONObject result = new JSONObject(); // make call to EMR Serverless when related result index documents are not available GetJobRunResult getJobRunResult = diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java index f8217142c3..d5885e6f2a 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java @@ -162,14 +162,18 @@ private FlintIndexMetadata getFlintIndexMetadata( } @Override 
- protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) { + protected JSONObject getResponseFromResultIndex( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { String queryId = asyncQueryJobMetadata.getQueryId(); return jobExecutionResponseReader.getResultWithQueryId( queryId, asyncQueryJobMetadata.getResultIndex()); } @Override - protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJobMetadata) { + protected JSONObject getResponseFromExecutor( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { // Consider statement still running if result doc created in submit() is not available yet JSONObject result = new JSONObject(); result.put(STATUS_FIELD, StatementState.RUNNING.getState()); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java index 9a9baedde2..7be6809912 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java @@ -50,21 +50,26 @@ public class InteractiveQueryHandler extends AsyncQueryHandler { protected final SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider; @Override - protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) { + protected JSONObject getResponseFromResultIndex( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { String queryId = asyncQueryJobMetadata.getQueryId(); return jobExecutionResponseReader.getResultWithQueryId( queryId, asyncQueryJobMetadata.getResultIndex()); } @Override - protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJobMetadata) { + protected JSONObject getResponseFromExecutor( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { JSONObject result = new JSONObject(); String queryId = asyncQueryJobMetadata.getQueryId(); Statement statement = getStatementByQueryId( asyncQueryJobMetadata.getSessionId(), queryId, - asyncQueryJobMetadata.getDatasourceName()); + asyncQueryJobMetadata.getDatasourceName(), + asyncQueryRequestContext); StatementState statementState = statement.getStatementState(); result.put(STATUS_FIELD, statementState.getState()); result.put(ERROR_FIELD, Optional.of(statement.getStatementModel().getError()).orElse("")); @@ -79,7 +84,8 @@ public String cancelJob( getStatementByQueryId( asyncQueryJobMetadata.getSessionId(), queryId, - asyncQueryJobMetadata.getDatasourceName()) + asyncQueryJobMetadata.getDatasourceName(), + asyncQueryRequestContext) .cancel(); return queryId; } @@ -148,12 +154,16 @@ public DispatchQueryResponse submit( .build(); } - private Statement getStatementByQueryId(String sessionId, String queryId, String datasourceName) { + private Statement getStatementByQueryId( + String sessionId, + String queryId, + String datasourceName, + AsyncQueryRequestContext asyncQueryRequestContext) { Optional session = sessionManager.getSession(sessionId, datasourceName); if (session.isPresent()) { // todo, statementId == jobId if statement running in session. 
StatementId statementId = new StatementId(queryId); - Optional statement = session.get().get(statementId); + Optional statement = session.get().get(statementId, asyncQueryRequestContext); if (statement.isPresent()) { return statement.get(); } else { diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index a6fdd3f102..710f472acb 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -157,9 +157,11 @@ private boolean isEligibleForIndexDMLHandling(IndexQueryDetails indexQueryDetail && !indexQueryDetails.getFlintIndexOptions().autoRefresh())); } - public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) { + public JSONObject getQueryResponse( + AsyncQueryJobMetadata asyncQueryJobMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { return getAsyncQueryHandlerForExistingQuery(asyncQueryJobMetadata) - .getQueryResponse(asyncQueryJobMetadata); + .getQueryResponse(asyncQueryJobMetadata, asyncQueryRequestContext); } public String cancelJob( diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index aeedaef4e7..2915e2a3e1 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -121,9 +121,10 @@ public StatementId submit( } @Override - public Optional get(StatementId stID) { + public Optional get( + StatementId stID, AsyncQueryRequestContext asyncQueryRequestContext) { return statementStorageService - .getStatement(stID.getId(), sessionModel.getDatasourceName()) + .getStatement(stID.getId(), sessionModel.getDatasourceName(), asyncQueryRequestContext) .map( model -> Statement.builder() @@ -137,6 +138,7 @@ public Optional get(StatementId stID) { .queryId(model.getQueryId()) .statementStorageService(statementStorageService) .statementModel(model) + .asyncQueryRequestContext(asyncQueryRequestContext) .build()); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/Session.java b/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/Session.java index fad097ca1b..4c083d79c4 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/Session.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/execution/session/Session.java @@ -35,7 +35,7 @@ void open( * @param stID {@link StatementId} * @return {@link Statement} */ - Optional get(StatementId stID); + Optional get(StatementId stID, AsyncQueryRequestContext asyncQueryRequestContext); SessionModel getSessionModel(); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java b/async-query-core/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java index 3237a5d372..272f0edf4a 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java @@ -70,7 +70,8 @@ public void cancel() { throw new IllegalStateException(errorMsg); } 
this.statementModel = - statementStorageService.updateStatementState(statementModel, StatementState.CANCELLED); + statementStorageService.updateStatementState( + statementModel, StatementState.CANCELLED, asyncQueryRequestContext); } public StatementState getStatementState() { diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/execution/statestore/StatementStorageService.java b/async-query-core/src/main/java/org/opensearch/sql/spark/execution/statestore/StatementStorageService.java index 39f1ecf704..b9446809fb 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/execution/statestore/StatementStorageService.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/execution/statestore/StatementStorageService.java @@ -20,7 +20,10 @@ StatementModel createStatement( StatementModel statementModel, AsyncQueryRequestContext asyncQueryRequestContext); StatementModel updateStatementState( - StatementModel oldStatementModel, StatementState statementState); + StatementModel oldStatementModel, + StatementState statementState, + AsyncQueryRequestContext asyncQueryRequestContext); - Optional getStatement(String id, String datasourceName); + Optional getStatement( + String id, String datasourceName, AsyncQueryRequestContext asyncQueryRequestContext); } diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index ff92762a7c..feb8c8c0ac 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -377,7 +377,8 @@ public void getResultOfInteractiveQuery() { when(jobExecutionResponseReader.getResultWithQueryId(QUERY_ID, RESULT_INDEX)) .thenReturn(result); - AsyncQueryExecutionResponse response = asyncQueryExecutorService.getAsyncQueryResults(QUERY_ID); + AsyncQueryExecutionResponse response = + asyncQueryExecutorService.getAsyncQueryResults(QUERY_ID, asyncQueryRequestContext); assertEquals("SUCCESS", response.getStatus()); assertEquals(SESSION_ID, response.getSessionId()); @@ -395,7 +396,8 @@ public void getResultOfIndexDMLQuery() { when(jobExecutionResponseReader.getResultWithQueryId(QUERY_ID, RESULT_INDEX)) .thenReturn(result); - AsyncQueryExecutionResponse response = asyncQueryExecutorService.getAsyncQueryResults(QUERY_ID); + AsyncQueryExecutionResponse response = + asyncQueryExecutorService.getAsyncQueryResults(QUERY_ID, asyncQueryRequestContext); assertEquals("SUCCESS", response.getStatus()); assertNull(response.getSessionId()); @@ -413,7 +415,8 @@ public void getResultOfRefreshQuery() { JSONObject result = getValidExecutionResponse(); when(jobExecutionResponseReader.getResultWithJobId(JOB_ID, RESULT_INDEX)).thenReturn(result); - AsyncQueryExecutionResponse response = asyncQueryExecutorService.getAsyncQueryResults(QUERY_ID); + AsyncQueryExecutionResponse response = + asyncQueryExecutorService.getAsyncQueryResults(QUERY_ID, asyncQueryRequestContext); assertEquals("SUCCESS", response.getStatus()); assertNull(response.getSessionId()); @@ -428,13 +431,15 @@ public void cancelInteractiveQuery() { final StatementModel statementModel = givenStatementExists(); StatementModel canceledStatementModel = StatementModel.copyWithState(statementModel, StatementState.CANCELLED, ImmutableMap.of()); - when(statementStorageService.updateStatementState(statementModel, 
StatementState.CANCELLED)) + when(statementStorageService.updateStatementState( + statementModel, StatementState.CANCELLED, asyncQueryRequestContext)) .thenReturn(canceledStatementModel); String result = asyncQueryExecutorService.cancelQuery(QUERY_ID, asyncQueryRequestContext); assertEquals(QUERY_ID, result); - verify(statementStorageService).updateStatementState(statementModel, StatementState.CANCELLED); + verify(statementStorageService) + .updateStatementState(statementModel, StatementState.CANCELLED, asyncQueryRequestContext); } @Test @@ -596,7 +601,7 @@ private StatementModel givenStatementExists() { .statementId(new StatementId(QUERY_ID)) .statementState(StatementState.RUNNING) .build(); - when(statementStorageService.getStatement(QUERY_ID, DATASOURCE_NAME)) + when(statementStorageService.getStatement(QUERY_ID, DATASOURCE_NAME, asyncQueryRequestContext)) .thenReturn(Optional.of(statementModel)); return statementModel; } diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index 5d8d9a3b63..1491f0bd61 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -152,7 +152,7 @@ void testGetAsyncQueryResultsWithJobNotFoundException() { AsyncQueryNotFoundException asyncQueryNotFoundException = Assertions.assertThrows( AsyncQueryNotFoundException.class, - () -> jobExecutorService.getAsyncQueryResults(EMR_JOB_ID)); + () -> jobExecutorService.getAsyncQueryResults(EMR_JOB_ID, asyncQueryRequestContext)); Assertions.assertEquals( "QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage()); @@ -166,10 +166,12 @@ void testGetAsyncQueryResultsWithInProgressJob() { .thenReturn(Optional.of(getAsyncQueryJobMetadata())); JSONObject jobResult = new JSONObject(); jobResult.put("status", JobRunState.PENDING.toString()); - when(sparkQueryDispatcher.getQueryResponse(getAsyncQueryJobMetadata())).thenReturn(jobResult); + when(sparkQueryDispatcher.getQueryResponse( + getAsyncQueryJobMetadata(), asyncQueryRequestContext)) + .thenReturn(jobResult); AsyncQueryExecutionResponse asyncQueryExecutionResponse = - jobExecutorService.getAsyncQueryResults(EMR_JOB_ID); + jobExecutorService.getAsyncQueryResults(EMR_JOB_ID, asyncQueryRequestContext); Assertions.assertNull(asyncQueryExecutionResponse.getResults()); Assertions.assertNull(asyncQueryExecutionResponse.getSchema()); @@ -183,10 +185,12 @@ void testGetAsyncQueryResultsWithSuccessJob() throws IOException { .thenReturn(Optional.of(getAsyncQueryJobMetadata())); JSONObject jobResult = new JSONObject(getJson("select_query_response.json")); jobResult.put("status", JobRunState.SUCCESS.toString()); - when(sparkQueryDispatcher.getQueryResponse(getAsyncQueryJobMetadata())).thenReturn(jobResult); + when(sparkQueryDispatcher.getQueryResponse( + getAsyncQueryJobMetadata(), asyncQueryRequestContext)) + .thenReturn(jobResult); AsyncQueryExecutionResponse asyncQueryExecutionResponse = - jobExecutorService.getAsyncQueryResults(EMR_JOB_ID); + jobExecutorService.getAsyncQueryResults(EMR_JOB_ID, asyncQueryRequestContext); Assertions.assertEquals("SUCCESS", asyncQueryExecutionResponse.getStatus()); Assertions.assertEquals(1, asyncQueryExecutionResponse.getSchema().getColumns().size()); diff --git 
a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java index 9a3c4e663e..570a7cab7d 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java @@ -66,7 +66,9 @@ class IndexDMLHandlerTest { @Test public void getResponseFromExecutor() { - JSONObject result = new IndexDMLHandler(null, null, null, null).getResponseFromExecutor(null); + JSONObject result = + new IndexDMLHandler(null, null, null, null) + .getResponseFromExecutor(null, asyncQueryRequestContext); assertEquals("running", result.getString(STATUS_FIELD)); assertEquals("", result.getString(ERROR_FIELD)); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index ee840e8b4c..b6369292a6 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -847,7 +847,7 @@ void testCancelJob() { @Test void testCancelQueryWithSession() { doReturn(Optional.of(session)).when(sessionManager).getSession(MOCK_SESSION_ID, MY_GLUE); - doReturn(Optional.of(statement)).when(session).get(any()); + doReturn(Optional.of(statement)).when(session).get(any(), eq(asyncQueryRequestContext)); doNothing().when(statement).cancel(); String queryId = @@ -919,7 +919,8 @@ void testGetQueryResponse() { when(jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, null)) .thenReturn(new JSONObject()); - JSONObject result = sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata()); + JSONObject result = + sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata(), asyncQueryRequestContext); Assertions.assertEquals("PENDING", result.get("status")); } @@ -927,7 +928,7 @@ void testGetQueryResponse() { @Test void testGetQueryResponseWithSession() { doReturn(Optional.of(session)).when(sessionManager).getSession(MOCK_SESSION_ID, MY_GLUE); - doReturn(Optional.of(statement)).when(session).get(any()); + doReturn(Optional.of(statement)).when(session).get(any(), eq(asyncQueryRequestContext)); when(statement.getStatementModel().getError()).thenReturn("mock error"); doReturn(StatementState.WAITING).when(statement).getStatementState(); doReturn(new JSONObject()) @@ -936,7 +937,8 @@ void testGetQueryResponseWithSession() { JSONObject result = sparkQueryDispatcher.getQueryResponse( - asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID)); + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID), + asyncQueryRequestContext); verifyNoInteractions(emrServerlessClient); Assertions.assertEquals("waiting", result.get("status")); @@ -954,7 +956,8 @@ void testGetQueryResponseWithInvalidSession() { IllegalArgumentException.class, () -> sparkQueryDispatcher.getQueryResponse( - asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID))); + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID), + asyncQueryRequestContext)); verifyNoInteractions(emrServerlessClient); Assertions.assertEquals("no session found. 
" + MOCK_SESSION_ID, exception.getMessage()); @@ -963,7 +966,7 @@ void testGetQueryResponseWithInvalidSession() { @Test void testGetQueryResponseWithStatementNotExist() { doReturn(Optional.of(session)).when(sessionManager).getSession(MOCK_SESSION_ID, MY_GLUE); - doReturn(Optional.empty()).when(session).get(any()); + doReturn(Optional.empty()).when(session).get(any(), eq(asyncQueryRequestContext)); doReturn(new JSONObject()) .when(jobExecutionResponseReader) .getResultWithQueryId(eq(MOCK_STATEMENT_ID), any()); @@ -973,7 +976,8 @@ void testGetQueryResponseWithStatementNotExist() { IllegalArgumentException.class, () -> sparkQueryDispatcher.getQueryResponse( - asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID))); + asyncQueryJobMetadataWithSessionId(MOCK_STATEMENT_ID, MOCK_SESSION_ID), + asyncQueryRequestContext)); verifyNoInteractions(emrServerlessClient); Assertions.assertEquals( @@ -989,7 +993,8 @@ void testGetQueryResponseWithSuccess() { queryResult.put(DATA_FIELD, resultMap); when(jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, null)).thenReturn(queryResult); - JSONObject result = sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata()); + JSONObject result = + sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata(), asyncQueryRequestContext); verify(jobExecutionResponseReader, times(1)).getResultWithJobId(EMR_JOB_ID, null); Assertions.assertEquals( diff --git a/async-query/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchStatementStorageService.java b/async-query/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchStatementStorageService.java index 67d0609ca5..527cd24bc8 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchStatementStorageService.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/execution/statestore/OpenSearchStatementStorageService.java @@ -40,14 +40,17 @@ public StatementModel createStatement( } @Override - public Optional getStatement(String id, String datasourceName) { + public Optional getStatement( + String id, String datasourceName, AsyncQueryRequestContext asyncQueryRequestContext) { return stateStore.get( id, serializer::fromXContent, OpenSearchStateStoreUtil.getIndexName(datasourceName)); } @Override public StatementModel updateStatementState( - StatementModel oldStatementModel, StatementState statementState) { + StatementModel oldStatementModel, + StatementState statementState, + AsyncQueryRequestContext asyncQueryRequestContext) { try { return stateStore.updateState( oldStatementModel, @@ -63,7 +66,10 @@ public StatementModel updateStatementState( throw new IllegalStateException(errorMsg); } catch (VersionConflictEngineException e) { StatementModel statementModel = - getStatement(oldStatementModel.getId(), oldStatementModel.getDatasourceName()) + getStatement( + oldStatementModel.getId(), + oldStatementModel.getDatasourceName(), + asyncQueryRequestContext) .orElse(oldStatementModel); String errorMsg = String.format( diff --git a/async-query/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java b/async-query/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java index 0e9da0c13c..250837e0cd 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultAction.java @@ -16,6 +16,7 @@ import 
org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; +import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext; import org.opensearch.sql.spark.transport.format.AsyncQueryResultResponseFormatter; import org.opensearch.sql.spark.transport.model.AsyncQueryResult; import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionRequest; @@ -50,7 +51,7 @@ protected void doExecute( try { String jobId = request.getQueryId(); AsyncQueryExecutionResponse asyncQueryExecutionResponse = - asyncQueryExecutorService.getAsyncQueryResults(jobId); + asyncQueryExecutorService.getAsyncQueryResults(jobId, new NullAsyncQueryRequestContext()); ResponseFormatter formatter = new AsyncQueryResultResponseFormatter(JsonResponseFormatter.Style.PRETTY); String responseContent = diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index ede8a348b4..db0adfc156 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -66,7 +66,8 @@ public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { // 2. fetch async query result. AsyncQueryExecutionResponse asyncQueryResults = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("RUNNING", asyncQueryResults.getStatus()); emrsClient.getJobRunResultCalled(1); @@ -152,13 +153,15 @@ public void withSessionCreateAsyncQueryThenGetResultThenCancel() { asyncQueryRequestContext); assertNotNull(response.getSessionId()); Optional statementModel = - statementStorageService.getStatement(response.getQueryId(), MYS3_DATASOURCE); + statementStorageService.getStatement( + response.getQueryId(), MYS3_DATASOURCE, asyncQueryRequestContext); assertTrue(statementModel.isPresent()); assertEquals(StatementState.WAITING, statementModel.get().getStatementState()); // 2. fetch async query result. 
AsyncQueryExecutionResponse asyncQueryResults = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("", asyncQueryResults.getError()); assertTrue(Strings.isEmpty(asyncQueryResults.getError())); assertEquals(StatementState.WAITING.getState(), asyncQueryResults.getStatus()); @@ -211,13 +214,15 @@ public void reuseSessionWhenCreateAsyncQuery() { .must(QueryBuilders.termQuery(SESSION_ID, first.getSessionId())))); Optional firstModel = - statementStorageService.getStatement(first.getQueryId(), MYS3_DATASOURCE); + statementStorageService.getStatement( + first.getQueryId(), MYS3_DATASOURCE, asyncQueryRequestContext); assertTrue(firstModel.isPresent()); assertEquals(StatementState.WAITING, firstModel.get().getStatementState()); assertEquals(first.getQueryId(), firstModel.get().getStatementId().getId()); assertEquals(first.getQueryId(), firstModel.get().getQueryId()); Optional secondModel = - statementStorageService.getStatement(second.getQueryId(), MYS3_DATASOURCE); + statementStorageService.getStatement( + second.getQueryId(), MYS3_DATASOURCE, asyncQueryRequestContext); assertEquals(StatementState.WAITING, secondModel.get().getStatementState()); assertEquals(second.getQueryId(), secondModel.get().getStatementId().getId()); assertEquals(second.getQueryId(), secondModel.get().getQueryId()); @@ -311,7 +316,8 @@ public void withSessionCreateAsyncQueryFailed() { asyncQueryRequestContext); assertNotNull(response.getSessionId()); Optional statementModel = - statementStorageService.getStatement(response.getQueryId(), MYS3_DATASOURCE); + statementStorageService.getStatement( + response.getQueryId(), MYS3_DATASOURCE, asyncQueryRequestContext); assertTrue(statementModel.isPresent()); assertEquals(StatementState.WAITING, statementModel.get().getStatementState()); @@ -334,10 +340,12 @@ public void withSessionCreateAsyncQueryFailed() { .error("mock error") .metadata(submitted.getMetadata()) .build(); - statementStorageService.updateStatementState(mocked, StatementState.FAILED); + statementStorageService.updateStatementState( + mocked, StatementState.FAILED, asyncQueryRequestContext); AsyncQueryExecutionResponse asyncQueryResults = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals(StatementState.FAILED.getState(), asyncQueryResults.getStatus()); assertEquals("mock error", asyncQueryResults.getError()); } diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java index e0f04761c7..7ccbad969d 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java @@ -450,7 +450,8 @@ AssertionHelper withInteraction(Interaction interaction) { AssertionHelper assertQueryResults(String status, List data) { AsyncQueryExecutionResponse results = - queryService.getAsyncQueryResults(createQueryResponse.getQueryId()); + queryService.getAsyncQueryResults( + createQueryResponse.getQueryId(), asyncQueryRequestContext); assertEquals(status, results.getStatus()); assertEquals(data, results.getResults()); return this; @@ -458,7 +459,8 @@ AssertionHelper assertQueryResults(String 
status, List data) { AssertionHelper assertFormattedQueryResults(String expected) { AsyncQueryExecutionResponse results = - queryService.getAsyncQueryResults(createQueryResponse.getQueryId()); + queryService.getAsyncQueryResults( + createQueryResponse.getQueryId(), asyncQueryRequestContext); ResponseFormatter formatter = new AsyncQueryResultResponseFormatter(JsonResponseFormatter.Style.COMPACT); @@ -515,8 +517,11 @@ void emrJobWriteResultDoc(Map resultDoc) { /** Simulate EMR-S updates query_execution_request with state */ void emrJobUpdateStatementState(StatementState newState) { - StatementModel stmt = statementStorageService.getStatement(queryId, MYS3_DATASOURCE).get(); - statementStorageService.updateStatementState(stmt, newState); + StatementModel stmt = + statementStorageService + .getStatement(queryId, MYS3_DATASOURCE, asyncQueryRequestContext) + .get(); + statementStorageService.updateStatementState(stmt, newState, asyncQueryRequestContext); } void emrJobUpdateJobState(JobRunState jobState) { diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java index 70a43e42d5..d69c7d4864 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecAlterTest.java @@ -86,7 +86,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("SUCCESS", asyncQueryExecutionResponse.getStatus()); emrsClient.startJobRunCalled(0); emrsClient.cancelJobRunCalled(1); @@ -155,7 +156,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("SUCCESS", asyncQueryExecutionResponse.getStatus()); emrsClient.startJobRunCalled(0); emrsClient.cancelJobRunCalled(1); @@ -237,7 +239,8 @@ public CancelJobRunResult cancelJobRun( // 2. 
fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("SUCCESS", asyncQueryExecutionResponse.getStatus()); emrsClient.startJobRunCalled(0); emrsClient.cancelJobRunCalled(1); @@ -303,7 +306,7 @@ public void testAlterIndexQueryConvertingToAutoRefresh() { assertEquals( "RUNNING", asyncQueryExecutorService - .getAsyncQueryResults(response.getQueryId()) + .getAsyncQueryResults(response.getQueryId(), asyncQueryRequestContext) .getStatus()); flintIndexJob.assertState(FlintIndexState.ACTIVE); @@ -369,7 +372,7 @@ public void testAlterIndexQueryWithOutAnyAutoRefresh() { assertEquals( "RUNNING", asyncQueryExecutorService - .getAsyncQueryResults(response.getQueryId()) + .getAsyncQueryResults(response.getQueryId(), asyncQueryRequestContext) .getStatus()); flintIndexJob.assertState(FlintIndexState.ACTIVE); @@ -442,7 +445,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("FAILED", asyncQueryExecutionResponse.getStatus()); assertEquals( "Altering to full refresh only allows: [auto_refresh, incremental_refresh]" @@ -517,7 +521,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("FAILED", asyncQueryExecutionResponse.getStatus()); assertEquals( "Altering to incremental refresh only allows: [auto_refresh, incremental_refresh," @@ -586,7 +591,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("FAILED", asyncQueryExecutionResponse.getStatus()); assertEquals( "Conversion to incremental refresh index cannot proceed due to missing" @@ -648,7 +654,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("FAILED", asyncQueryExecutionResponse.getStatus()); assertEquals( "Conversion to incremental refresh index cannot proceed due to missing" @@ -712,7 +719,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { // 2. 
fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("FAILED", asyncQueryExecutionResponse.getStatus()); assertEquals( "Conversion to incremental refresh index cannot proceed due to missing" @@ -776,7 +784,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("SUCCESS", asyncQueryExecutionResponse.getStatus()); emrsClient.startJobRunCalled(0); emrsClient.getJobRunResultCalled(1); @@ -837,7 +846,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("SUCCESS", asyncQueryExecutionResponse.getStatus()); emrsClient.startJobRunCalled(0); emrsClient.getJobRunResultCalled(1); @@ -896,7 +906,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("FAILED", asyncQueryExecutionResponse.getStatus()); assertEquals( "Transaction failed as flint index is not in a valid state.", @@ -963,7 +974,8 @@ public CancelJobRunResult cancelJobRun( // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("SUCCESS", asyncQueryExecutionResponse.getStatus()); emrsClient.startJobRunCalled(0); emrsClient.cancelJobRunCalled(1); @@ -1028,7 +1040,8 @@ public CancelJobRunResult cancelJobRun( // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("FAILED", asyncQueryExecutionResponse.getStatus()); assertEquals("Internal Server Error.", asyncQueryExecutionResponse.getError()); emrsClient.startJobRunCalled(0); @@ -1094,7 +1107,8 @@ public CancelJobRunResult cancelJobRun( // 2. 
fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("FAILED", asyncQueryExecutionResponse.getStatus()); assertEquals("Internal Server Error.", asyncQueryExecutionResponse.getError()); emrsClient.startJobRunCalled(0); diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java index 29c42446b3..920981abf1 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java @@ -143,7 +143,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { // 2.fetch result AsyncQueryExecutionResponse asyncQueryResults = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("SUCCESS", asyncQueryResults.getStatus()); assertNull(asyncQueryResults.getError()); emrsClient.cancelJobRunCalled(1); @@ -193,7 +194,8 @@ public CancelJobRunResult cancelJobRun( // 2.fetch result. AsyncQueryExecutionResponse asyncQueryResults = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("SUCCESS", asyncQueryResults.getStatus()); assertNull(asyncQueryResults.getError()); }); @@ -233,7 +235,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { // 2. fetch result AsyncQueryExecutionResponse asyncQueryResults = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("FAILED", asyncQueryResults.getStatus()); assertEquals("Cancel job operation timed out.", asyncQueryResults.getError()); }); @@ -270,7 +273,8 @@ public CancelJobRunResult cancelJobRun( // 2.fetch result. AsyncQueryExecutionResponse asyncQueryResults = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("SUCCESS", asyncQueryResults.getStatus()); assertNull(asyncQueryResults.getError()); } @@ -319,7 +323,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { // 2.fetch result AsyncQueryExecutionResponse asyncQueryResults = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("SUCCESS", asyncQueryResults.getStatus()); assertNull(asyncQueryResults.getError()); emrsClient.cancelJobRunCalled(1); @@ -375,7 +380,8 @@ public CancelJobRunResult cancelJobRun( // 2.fetch result. 
AsyncQueryExecutionResponse asyncQueryResults = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("SUCCESS", asyncQueryResults.getStatus()); assertNull(asyncQueryResults.getError()); @@ -422,7 +428,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { // 2. fetch result AsyncQueryExecutionResponse asyncQueryResults = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("FAILED", asyncQueryResults.getStatus()); assertEquals("Cancel job operation timed out.", asyncQueryResults.getError()); flintIndexJob.assertState(FlintIndexState.REFRESHING); @@ -470,7 +477,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { assertEquals( "SUCCESS", asyncQueryExecutorService - .getAsyncQueryResults(response.getQueryId()) + .getAsyncQueryResults(response.getQueryId(), asyncQueryRequestContext) .getStatus()); flintIndexJob.assertState(FlintIndexState.DELETED); @@ -519,7 +526,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { // 2. fetch result AsyncQueryExecutionResponse asyncQueryExecutionResponse = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("SUCCESS", asyncQueryExecutionResponse.getStatus()); flintIndexJob.assertState(FlintIndexState.DELETED); emrsClient.startJobRunCalled(0); @@ -569,7 +577,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { assertEquals( "SUCCESS", asyncQueryExecutorService - .getAsyncQueryResults(response.getQueryId()) + .getAsyncQueryResults(response.getQueryId(), asyncQueryRequestContext) .getStatus()); flintIndexJob.assertState(FlintIndexState.DELETED); @@ -616,7 +624,7 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { assertEquals( "SUCCESS", asyncQueryExecutorService - .getAsyncQueryResults(response.getQueryId()) + .getAsyncQueryResults(response.getQueryId(), asyncQueryRequestContext) .getStatus()); flintIndexJob.assertState(FlintIndexState.DELETED); @@ -668,7 +676,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { asyncQueryRequestContext); AsyncQueryExecutionResponse asyncQueryExecutionResponse = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); // 2. fetch result assertEquals("FAILED", asyncQueryExecutionResponse.getStatus()); assertEquals( @@ -714,7 +723,8 @@ public CancelJobRunResult cancelJobRun( // 2.fetch result. AsyncQueryExecutionResponse asyncQueryResults = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("FAILED", asyncQueryResults.getStatus()); assertEquals("Internal Server Error.", asyncQueryResults.getError()); @@ -762,7 +772,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { // 2. 
fetch result AsyncQueryExecutionResponse asyncQueryResults = - asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); assertEquals("FAILED", asyncQueryResults.getStatus()); assertTrue(asyncQueryResults.getError().contains("no state found")); }); diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java index 439b2ed2d6..e62b60bfd2 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java @@ -174,7 +174,8 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { new CreateAsyncQueryRequest(mockDS.query, MYS3_DATASOURCE, LangType.SQL, null), asyncQueryRequestContext); - return asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + return asyncQueryExecutorService.getAsyncQueryResults( + response.getQueryId(), asyncQueryRequestContext); } private boolean flintIndexExists(String flintIndexName) { diff --git a/async-query/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/async-query/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index e76776e2fc..fe3d5f3177 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -144,7 +144,8 @@ public void cancelFailedBecauseOfConflict() { st.open(); StatementModel running = - statementStorageService.updateStatementState(st.getStatementModel(), CANCELLED); + statementStorageService.updateStatementState( + st.getStatementModel(), CANCELLED, asyncQueryRequestContext); assertEquals(StatementState.CANCELLED, running.getStatementState()); IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); @@ -265,7 +266,7 @@ public void newStatementFieldAssert() { Session session = sessionManager.createSession(createSessionRequest(), asyncQueryRequestContext); StatementId statementId = session.submit(queryRequest(), asyncQueryRequestContext); - Optional statement = session.get(statementId); + Optional statement = session.get(statementId, asyncQueryRequestContext); assertTrue(statement.isPresent()); assertEquals(session.getSessionId(), statement.get().getSessionId()); @@ -301,7 +302,7 @@ public void getStatementSuccess() { sessionStorageService.updateSessionState(session.getSessionModel(), SessionState.RUNNING); StatementId statementId = session.submit(queryRequest(), asyncQueryRequestContext); - Optional statement = session.get(statementId); + Optional statement = session.get(statementId, asyncQueryRequestContext); assertTrue(statement.isPresent()); assertEquals(WAITING, statement.get().getStatementState()); assertEquals(statementId, statement.get().getStatementId()); @@ -314,7 +315,8 @@ public void getStatementNotExist() { // App change state to running sessionStorageService.updateSessionState(session.getSessionModel(), SessionState.RUNNING); - Optional statement = session.get(StatementId.newStatementId("not-exist-id")); + Optional statement = + session.get(StatementId.newStatementId("not-exist-id"), asyncQueryRequestContext); assertFalse(statement.isPresent()); } @@ -332,7 +334,8 @@ public TestStatement 
assertSessionState(StatementState expected) { assertEquals(expected, st.getStatementModel().getStatementState()); Optional model = - statementStorageService.getStatement(st.getStatementId().getId(), TEST_DATASOURCE_NAME); + statementStorageService.getStatement( + st.getStatementId().getId(), TEST_DATASOURCE_NAME, st.getAsyncQueryRequestContext()); assertTrue(model.isPresent()); assertEquals(expected, model.get().getStatementState()); @@ -343,7 +346,8 @@ public TestStatement assertStatementId(StatementId expected) { assertEquals(expected, st.getStatementModel().getStatementId()); Optional model = - statementStorageService.getStatement(st.getStatementId().getId(), TEST_DATASOURCE_NAME); + statementStorageService.getStatement( + st.getStatementId().getId(), TEST_DATASOURCE_NAME, st.getAsyncQueryRequestContext()); assertTrue(model.isPresent()); assertEquals(expected, model.get().getStatementId()); return this; @@ -361,7 +365,8 @@ public TestStatement cancel() { public TestStatement run() { StatementModel model = - statementStorageService.updateStatementState(st.getStatementModel(), RUNNING); + statementStorageService.updateStatementState( + st.getStatementModel(), RUNNING, st.getAsyncQueryRequestContext()); st.setStatementModel(model); return this; } diff --git a/async-query/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java b/async-query/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java index 34f10b0083..475eceb37e 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/transport/TransportGetAsyncQueryResultActionTest.java @@ -7,6 +7,8 @@ package org.opensearch.sql.spark.transport; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -33,6 +35,7 @@ import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; +import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext; import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionRequest; import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionResponse; import org.opensearch.tasks.Task; @@ -64,8 +67,11 @@ public void testDoExecute() { GetAsyncQueryResultActionRequest request = new GetAsyncQueryResultActionRequest("jobId"); AsyncQueryExecutionResponse asyncQueryExecutionResponse = new AsyncQueryExecutionResponse("IN_PROGRESS", null, null, null, null); - when(jobExecutorService.getAsyncQueryResults("jobId")).thenReturn(asyncQueryExecutionResponse); + when(jobExecutorService.getAsyncQueryResults(eq("jobId"), any())) + .thenReturn(asyncQueryExecutionResponse); + action.doExecute(task, request, actionListener); + verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); GetAsyncQueryResultActionResponse getAsyncQueryResultActionResponse = createJobActionResponseArgumentCaptor.getValue(); @@ -91,8 +97,11 @@ public void testDoExecuteWithSuccessResponse() { tupleValue(ImmutableMap.of("name", "Smith", "age", 30))), null, null); - when(jobExecutorService.getAsyncQueryResults("jobId")).thenReturn(asyncQueryExecutionResponse); + 
when(jobExecutorService.getAsyncQueryResults(eq("jobId"), any())) + .thenReturn(asyncQueryExecutionResponse); + action.doExecute(task, request, actionListener); + verify(actionListener).onResponse(createJobActionResponseArgumentCaptor.capture()); GetAsyncQueryResultActionResponse getAsyncQueryResultActionResponse = createJobActionResponseArgumentCaptor.getValue(); @@ -130,9 +139,12 @@ public void testDoExecuteWithException() { GetAsyncQueryResultActionRequest request = new GetAsyncQueryResultActionRequest("123"); doThrow(new AsyncQueryNotFoundException("JobId 123 not found")) .when(jobExecutorService) - .getAsyncQueryResults("123"); + .getAsyncQueryResults(eq("123"), any()); + action.doExecute(task, request, actionListener); - verify(jobExecutorService, times(1)).getAsyncQueryResults("123"); + + verify(jobExecutorService, times(1)) + .getAsyncQueryResults(eq("123"), any(NullAsyncQueryRequestContext.class)); verify(actionListener).onFailure(exceptionArgumentCaptor.capture()); Exception exception = exceptionArgumentCaptor.getValue(); Assertions.assertTrue(exception instanceof RuntimeException); From d260e0e449051eac0e3bc73db04103badebeb848 Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Thu, 29 Aug 2024 21:10:34 -0700 Subject: [PATCH 02/11] Extract validation logic from FlintIndexMetadataServiceImpl (#2944) Signed-off-by: Tomoyuki Morita --- .../flint/FlintIndexMetadataValidator.java | 88 ++++++++++++++++++ .../FlintIndexMetadataValidatorTest.java | 90 +++++++++++++++++++ .../flint/FlintIndexMetadataServiceImpl.java | 69 +------------- 3 files changed, 179 insertions(+), 68 deletions(-) create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataValidator.java create mode 100644 async-query-core/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataValidatorTest.java diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataValidator.java new file mode 100644 index 0000000000..68ba34c476 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataValidator.java @@ -0,0 +1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.flint; + +import static org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions.AUTO_REFRESH; +import static org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions.CHECKPOINT_LOCATION; +import static org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions.INCREMENTAL_REFRESH; +import static org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions.WATERMARK_DELAY; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.commons.lang3.StringUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +public class FlintIndexMetadataValidator { + private static final Logger LOGGER = LogManager.getLogger(FlintIndexMetadataValidator.class); + + public static final Set<String> ALTER_TO_FULL_REFRESH_ALLOWED_OPTIONS = + new LinkedHashSet<>(Arrays.asList(AUTO_REFRESH, INCREMENTAL_REFRESH)); + public static final Set<String> ALTER_TO_INCREMENTAL_REFRESH_ALLOWED_OPTIONS = + new LinkedHashSet<>( + Arrays.asList(AUTO_REFRESH, INCREMENTAL_REFRESH, WATERMARK_DELAY, CHECKPOINT_LOCATION)); + + /** + * Validate if 
the flint index options contain valid key/value pairs. Throws + * IllegalArgumentException with description about invalid options. + */ + public static void validateFlintIndexOptions( + String kind, Map<String, Object> existingOptions, Map<String, String> newOptions) { + if ((newOptions.containsKey(INCREMENTAL_REFRESH) + && Boolean.parseBoolean(newOptions.get(INCREMENTAL_REFRESH))) + || ((!newOptions.containsKey(INCREMENTAL_REFRESH) + && Boolean.parseBoolean((String) existingOptions.get(INCREMENTAL_REFRESH))))) { + validateConversionToIncrementalRefresh(kind, existingOptions, newOptions); + } else { + validateConversionToFullRefresh(newOptions); + } + } + + private static void validateConversionToFullRefresh(Map<String, String> newOptions) { + if (!ALTER_TO_FULL_REFRESH_ALLOWED_OPTIONS.containsAll(newOptions.keySet())) { + throw new IllegalArgumentException( + String.format( + "Altering to full refresh only allows: %s options", + ALTER_TO_FULL_REFRESH_ALLOWED_OPTIONS)); + } + } + + private static void validateConversionToIncrementalRefresh( + String kind, Map<String, Object> existingOptions, Map<String, String> newOptions) { + if (!ALTER_TO_INCREMENTAL_REFRESH_ALLOWED_OPTIONS.containsAll(newOptions.keySet())) { + throw new IllegalArgumentException( + String.format( + "Altering to incremental refresh only allows: %s options", + ALTER_TO_INCREMENTAL_REFRESH_ALLOWED_OPTIONS)); + } + HashMap<String, Object> mergedOptions = new HashMap<>(); + mergedOptions.putAll(existingOptions); + mergedOptions.putAll(newOptions); + List<String> missingAttributes = new ArrayList<>(); + if (!mergedOptions.containsKey(CHECKPOINT_LOCATION) + || StringUtils.isEmpty((String) mergedOptions.get(CHECKPOINT_LOCATION))) { + missingAttributes.add(CHECKPOINT_LOCATION); + } + if (kind.equals("mv") + && (!mergedOptions.containsKey(WATERMARK_DELAY) + || StringUtils.isEmpty((String) mergedOptions.get(WATERMARK_DELAY)))) { + missingAttributes.add(WATERMARK_DELAY); + } + if (missingAttributes.size() > 0) { + String errorMessage = + "Conversion to incremental refresh index cannot proceed due to missing attributes: " + + String.join(", ", missingAttributes) + + "."; + LOGGER.error(errorMessage); + throw new IllegalArgumentException(errorMessage); + } + } +} diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataValidatorTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataValidatorTest.java new file mode 100644 index 0000000000..7a1e718c05 --- /dev/null +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataValidatorTest.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.flint; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions.AUTO_REFRESH; +import static org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions.CHECKPOINT_LOCATION; +import static org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions.INCREMENTAL_REFRESH; +import static org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions.WATERMARK_DELAY; + +import com.google.common.collect.ImmutableMap; +import java.util.Map; +import org.junit.jupiter.api.Test; + +class FlintIndexMetadataValidatorTest { + @Test + public void conversionToIncrementalRefreshWithValidOption() { + Map<String, Object> existingOptions = + ImmutableMap.<String, Object>builder().put(INCREMENTAL_REFRESH, "false").build(); + Map<String, String> newOptions = + ImmutableMap.<String, String>builder() + .put(INCREMENTAL_REFRESH, "true") + .put(CHECKPOINT_LOCATION, 
"checkpoint_location") + .put(WATERMARK_DELAY, "1") + .build(); + + FlintIndexMetadataValidator.validateFlintIndexOptions("mv", existingOptions, newOptions); + } + + @Test + public void conversionToIncrementalRefreshWithMissingOptions() { + Map<String, Object> existingOptions = + ImmutableMap.<String, Object>builder().put(AUTO_REFRESH, "true").build(); + Map<String, String> newOptions = + ImmutableMap.<String, String>builder().put(INCREMENTAL_REFRESH, "true").build(); + + assertThrows( + IllegalArgumentException.class, + () -> + FlintIndexMetadataValidator.validateFlintIndexOptions( + "mv", existingOptions, newOptions)); + } + + @Test + public void conversionToIncrementalRefreshWithInvalidOption() { + Map<String, Object> existingOptions = + ImmutableMap.<String, Object>builder().put(INCREMENTAL_REFRESH, "false").build(); + Map<String, String> newOptions = + ImmutableMap.<String, String>builder() + .put(INCREMENTAL_REFRESH, "true") + .put("INVALID_OPTION", "1") + .build(); + + assertThrows( + IllegalArgumentException.class, + () -> + FlintIndexMetadataValidator.validateFlintIndexOptions( + "mv", existingOptions, newOptions)); + } + + @Test + public void conversionToFullRefreshWithValidOption() { + Map<String, Object> existingOptions = + ImmutableMap.<String, Object>builder().put(AUTO_REFRESH, "false").build(); + Map<String, String> newOptions = + ImmutableMap.<String, String>builder().put(AUTO_REFRESH, "true").build(); + + FlintIndexMetadataValidator.validateFlintIndexOptions("mv", existingOptions, newOptions); + } + + @Test + public void conversionToFullRefreshWithInvalidOption() { + Map<String, Object> existingOptions = + ImmutableMap.<String, Object>builder().put(AUTO_REFRESH, "false").build(); + Map<String, String> newOptions = + ImmutableMap.<String, String>builder() + .put(AUTO_REFRESH, "true") + .put(WATERMARK_DELAY, "1") + .build(); + + assertThrows( + IllegalArgumentException.class, + () -> + FlintIndexMetadataValidator.validateFlintIndexOptions( + "mv", existingOptions, newOptions)); + } +} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImpl.java b/async-query/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImpl.java index b8352d15b2..38789dd796 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImpl.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataServiceImpl.java @@ -5,10 +5,6 @@ package org.opensearch.sql.spark.flint; -import static org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions.AUTO_REFRESH; -import static org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions.CHECKPOINT_LOCATION; -import static org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions.INCREMENTAL_REFRESH; -import static org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions.WATERMARK_DELAY; import static org.opensearch.sql.spark.flint.FlintIndexMetadata.APP_ID; import static org.opensearch.sql.spark.flint.FlintIndexMetadata.ENV_KEY; import static org.opensearch.sql.spark.flint.FlintIndexMetadata.KIND_KEY; @@ -20,15 +16,9 @@ import static org.opensearch.sql.spark.flint.FlintIndexMetadata.SERVERLESS_EMR_JOB_ID; import static org.opensearch.sql.spark.flint.FlintIndexMetadata.SOURCE_KEY; -import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; -import java.util.LinkedHashSet; -import java.util.List; import java.util.Map; -import java.util.Set; import lombok.AllArgsConstructor; -import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; @@ -43,11 +33,6 @@ public class FlintIndexMetadataServiceImpl implements 
FlintIndexMetadataService private static final Logger LOGGER = LogManager.getLogger(FlintIndexMetadataServiceImpl.class); private final Client client; - public static final Set ALTER_TO_FULL_REFRESH_ALLOWED_OPTIONS = - new LinkedHashSet<>(Arrays.asList(AUTO_REFRESH, INCREMENTAL_REFRESH)); - public static final Set ALTER_TO_INCREMENTAL_REFRESH_ALLOWED_OPTIONS = - new LinkedHashSet<>( - Arrays.asList(AUTO_REFRESH, INCREMENTAL_REFRESH, WATERMARK_DELAY, CHECKPOINT_LOCATION)); @Override public Map getFlintIndexMetadata( @@ -87,63 +72,11 @@ public void updateIndexToManualRefresh( String kind = (String) meta.get("kind"); Map options = (Map) meta.get("options"); Map newOptions = flintIndexOptions.getProvidedOptions(); - validateFlintIndexOptions(kind, options, newOptions); + FlintIndexMetadataValidator.validateFlintIndexOptions(kind, options, newOptions); options.putAll(newOptions); client.admin().indices().preparePutMapping(indexName).setSource(flintMetadataMap).get(); } - private void validateFlintIndexOptions( - String kind, Map existingOptions, Map newOptions) { - if ((newOptions.containsKey(INCREMENTAL_REFRESH) - && Boolean.parseBoolean(newOptions.get(INCREMENTAL_REFRESH))) - || ((!newOptions.containsKey(INCREMENTAL_REFRESH) - && Boolean.parseBoolean((String) existingOptions.get(INCREMENTAL_REFRESH))))) { - validateConversionToIncrementalRefresh(kind, existingOptions, newOptions); - } else { - validateConversionToFullRefresh(newOptions); - } - } - - private void validateConversionToFullRefresh(Map newOptions) { - if (!ALTER_TO_FULL_REFRESH_ALLOWED_OPTIONS.containsAll(newOptions.keySet())) { - throw new IllegalArgumentException( - String.format( - "Altering to full refresh only allows: %s options", - ALTER_TO_FULL_REFRESH_ALLOWED_OPTIONS)); - } - } - - private void validateConversionToIncrementalRefresh( - String kind, Map existingOptions, Map newOptions) { - if (!ALTER_TO_INCREMENTAL_REFRESH_ALLOWED_OPTIONS.containsAll(newOptions.keySet())) { - throw new IllegalArgumentException( - String.format( - "Altering to incremental refresh only allows: %s options", - ALTER_TO_INCREMENTAL_REFRESH_ALLOWED_OPTIONS)); - } - HashMap mergedOptions = new HashMap<>(); - mergedOptions.putAll(existingOptions); - mergedOptions.putAll(newOptions); - List missingAttributes = new ArrayList<>(); - if (!mergedOptions.containsKey(CHECKPOINT_LOCATION) - || StringUtils.isEmpty((String) mergedOptions.get(CHECKPOINT_LOCATION))) { - missingAttributes.add(CHECKPOINT_LOCATION); - } - if (kind.equals("mv") - && (!mergedOptions.containsKey(WATERMARK_DELAY) - || StringUtils.isEmpty((String) mergedOptions.get(WATERMARK_DELAY)))) { - missingAttributes.add(WATERMARK_DELAY); - } - if (missingAttributes.size() > 0) { - String errorMessage = - "Conversion to incremental refresh index cannot proceed due to missing attributes: " - + String.join(", ", missingAttributes) - + "."; - LOGGER.error(errorMessage); - throw new IllegalArgumentException(errorMessage); - } - } - private FlintIndexMetadata fromMetadata(String indexName, Map metaMap) { FlintIndexMetadata.FlintIndexMetadataBuilder flintIndexMetadataBuilder = FlintIndexMetadata.builder(); From c13f7705fca39c737b11fcf6fb0bf5ce9fb540ef Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Tue, 3 Sep 2024 22:13:57 -0700 Subject: [PATCH 03/11] Fix jobType for Batch and IndexDML query (#2955) Signed-off-by: Tomoyuki Morita --- .../sql/spark/dispatcher/BatchQueryHandler.java | 2 +- .../sql/spark/dispatcher/IndexDMLHandler.java | 4 ++-- .../asyncquery/AsyncQueryCoreIntegTest.java | 17 
+++++++++-------- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java index bce1918631..36e4c227b8 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java @@ -109,7 +109,7 @@ public DispatchQueryResponse submit( .jobId(jobId) .resultIndex(dataSourceMetadata.getResultIndex()) .datasourceName(dataSourceMetadata.getName()) - .jobType(JobType.INTERACTIVE) + .jobType(JobType.BATCH) .build(); } } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java index d5885e6f2a..4698bfcccc 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java @@ -82,7 +82,7 @@ public DispatchQueryResponse submit( .jobId(DML_QUERY_JOB_ID) .resultIndex(dataSourceMetadata.getResultIndex()) .datasourceName(dataSourceMetadata.getName()) - .jobType(JobType.INTERACTIVE) + .jobType(JobType.BATCH) .build(); } catch (Exception e) { LOG.error(e.getMessage()); @@ -100,7 +100,7 @@ public DispatchQueryResponse submit( .jobId(DML_QUERY_JOB_ID) .resultIndex(dataSourceMetadata.getResultIndex()) .datasourceName(dataSourceMetadata.getName()) - .jobType(JobType.INTERACTIVE) + .jobType(JobType.BATCH) .build(); } } diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index feb8c8c0ac..09767d16bd 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -202,7 +202,7 @@ public void createDropIndexQuery() { verifyGetQueryIdCalled(); verifyCancelJobRunCalled(); verifyCreateIndexDMLResultCalled(); - verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID); + verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, JobType.BATCH); } @Test @@ -224,7 +224,7 @@ public void createVacuumIndexQuery() { verifyGetQueryIdCalled(); verify(flintIndexClient).deleteIndex(indexName); verifyCreateIndexDMLResultCalled(); - verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID); + verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, JobType.BATCH); } @Test @@ -255,7 +255,7 @@ public void createAlterIndexQuery() { assertFalse(flintIndexOptions.autoRefresh()); verifyCancelJobRunCalled(); verifyCreateIndexDMLResultCalled(); - verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID); + verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, JobType.BATCH); } @Test @@ -280,7 +280,7 @@ public void createStreamingQuery() { verifyGetQueryIdCalled(); verify(leaseManager).borrow(any()); verifyStartJobRunCalled(); - verifyStoreJobMetadataCalled(JOB_ID); + verifyStoreJobMetadataCalled(JOB_ID, JobType.STREAMING); } private void verifyStartJobRunCalled() { @@ -315,7 +315,7 @@ public void createCreateIndexQuery() { assertNull(response.getSessionId()); verifyGetQueryIdCalled(); verifyStartJobRunCalled(); - verifyStoreJobMetadataCalled(JOB_ID); + verifyStoreJobMetadataCalled(JOB_ID, JobType.BATCH); } @Test @@ -337,7 +337,7 @@ 
public void createRefreshQuery() { verifyGetQueryIdCalled(); verify(leaseManager).borrow(any()); verifyStartJobRunCalled(); - verifyStoreJobMetadataCalled(JOB_ID); + verifyStoreJobMetadataCalled(JOB_ID, JobType.BATCH); } @Test @@ -363,7 +363,7 @@ public void createInteractiveQuery() { verifyGetSessionIdCalled(); verify(leaseManager).borrow(any()); verifyStartJobRunCalled(); - verifyStoreJobMetadataCalled(JOB_ID); + verifyStoreJobMetadataCalled(JOB_ID, JobType.INTERACTIVE); } @Test @@ -560,7 +560,7 @@ private void verifyGetSessionIdCalled() { assertEquals(APPLICATION_ID, createSessionRequest.getApplicationId()); } - private void verifyStoreJobMetadataCalled(String jobId) { + private void verifyStoreJobMetadataCalled(String jobId, JobType jobType) { verify(asyncQueryJobMetadataStorageService) .storeJobMetadata( asyncQueryJobMetadataArgumentCaptor.capture(), eq(asyncQueryRequestContext)); @@ -568,6 +568,7 @@ private void verifyStoreJobMetadataCalled(String jobId) { assertEquals(QUERY_ID, asyncQueryJobMetadata.getQueryId()); assertEquals(jobId, asyncQueryJobMetadata.getJobId()); assertEquals(DATASOURCE_NAME, asyncQueryJobMetadata.getDatasourceName()); + assertEquals(jobType, asyncQueryJobMetadata.getJobType()); } private void verifyCreateIndexDMLResultCalled() { From a83ab20d90009a9fd30580f2356f28e415d55752 Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Tue, 3 Sep 2024 22:21:07 -0700 Subject: [PATCH 04/11] Add queryId Spark parameter to batch query (#2952) Signed-off-by: Tomoyuki Morita --- .../spark/data/constants/SparkConstants.java | 1 + .../spark/dispatcher/BatchQueryHandler.java | 1 + .../dispatcher/StreamingQueryHandler.java | 1 + .../SparkSubmitParametersBuilder.java | 6 ++ .../dispatcher/SparkQueryDispatcherTest.java | 58 ++++++++++++++----- 5 files changed, 52 insertions(+), 15 deletions(-) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/async-query-core/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java index e87dbba03e..9b82022d8f 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java @@ -86,6 +86,7 @@ public class SparkConstants { "com.amazonaws.emr.AssumeRoleAWSCredentialsProvider"; public static final String JAVA_HOME_LOCATION = "/usr/lib/jvm/java-17-amazon-corretto.x86_64/"; public static final String FLINT_JOB_QUERY = "spark.flint.job.query"; + public static final String FLINT_JOB_QUERY_ID = "spark.flint.job.queryId"; public static final String FLINT_JOB_REQUEST_INDEX = "spark.flint.job.requestIndex"; public static final String FLINT_JOB_SESSION_ID = "spark.flint.job.sessionId"; diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java index 36e4c227b8..c693656150 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java @@ -91,6 +91,7 @@ public DispatchQueryResponse submit( sparkSubmitParametersBuilderProvider .getSparkSubmitParametersBuilder() .clusterName(clusterName) + .queryId(context.getQueryId()) .query(dispatchQueryRequest.getQuery()) .dataSource( context.getDataSourceMetadata(), diff --git 
a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java index 80d4be27cf..51e245b57c 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java @@ -82,6 +82,7 @@ public DispatchQueryResponse submit( sparkSubmitParametersBuilderProvider .getSparkSubmitParametersBuilder() .clusterName(clusterName) + .queryId(context.getQueryId()) .query(dispatchQueryRequest.getQuery()) .structuredStreaming(true) .dataSource( diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilder.java b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilder.java index d9d5859f64..db74d0a5a7 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilder.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/parameter/SparkSubmitParametersBuilder.java @@ -20,6 +20,7 @@ import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_PORT_KEY; import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_SCHEME_KEY; import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_QUERY; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_QUERY_ID; import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_REQUEST_INDEX; import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_SESSION_ID; import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_PPL_EXTENSION; @@ -108,6 +109,11 @@ public SparkSubmitParametersBuilder query(String query) { return this; } + public SparkSubmitParametersBuilder queryId(String queryId) { + setConfigItem(FLINT_JOB_QUERY_ID, queryId); + return this; + } + public SparkSubmitParametersBuilder dataSource( DataSourceMetadata metadata, DispatchQueryRequest dispatchQueryRequest, diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index b6369292a6..1587ce6638 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -162,12 +162,14 @@ void setUp() { @Test void testDispatchSelectQuery() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); + when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "select * from my_glue.default.http_logs"; - String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString(query, null, QUERY_ID); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -206,12 +208,14 @@ void testDispatchSelectQuery() { @Test void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); + 
when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "select * from my_glue.default.http_logs"; - String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString(query, null, QUERY_ID); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -310,6 +314,7 @@ void testDispatchSelectQueryFailedCreateSession() { @Test void testDispatchCreateAutoRefreshIndexQuery() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); + when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index"); @@ -318,7 +323,8 @@ void testDispatchCreateAutoRefreshIndexQuery() { String query = "CREATE INDEX elb_and_requestUri ON my_glue.default.http_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = true)"; - String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming"); + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString(query, "streaming", QUERY_ID); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index", @@ -347,6 +353,7 @@ void testDispatchCreateAutoRefreshIndexQuery() { @Test void testDispatchCreateManualRefreshIndexQuery() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); + when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); @@ -354,7 +361,8 @@ void testDispatchCreateManualRefreshIndexQuery() { String query = "CREATE INDEX elb_and_requestUri ON my_glue.default.http_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = false)"; - String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString(query, null, QUERY_ID); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -383,12 +391,14 @@ void testDispatchCreateManualRefreshIndexQuery() { @Test void testDispatchWithPPLQuery() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); + when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "source = my_glue.default.http_logs"; - String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString(query, null, QUERY_ID); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -448,12 +458,14 @@ void testDispatchWithSparkUDFQuery() { @Test void testInvalidSQLQueryDispatchToSpark() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); + when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); 
tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "myselect 1"; - String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString(query, null, QUERY_ID); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -492,12 +504,14 @@ void testInvalidSQLQueryDispatchToSpark() { @Test void testDispatchQueryWithoutATableAndDataSourceName() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); + when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "show tables"; - String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString(query, null, QUERY_ID); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -526,6 +540,7 @@ void testDispatchQueryWithoutATableAndDataSourceName() { @Test void testDispatchIndexQueryWithoutADatasourceName() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); + when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index"); @@ -534,7 +549,8 @@ void testDispatchIndexQueryWithoutADatasourceName() { String query = "CREATE INDEX elb_and_requestUri ON default.http_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = true)"; - String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming"); + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString(query, "streaming", QUERY_ID); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index", @@ -563,6 +579,7 @@ void testDispatchIndexQueryWithoutADatasourceName() { @Test void testDispatchMaterializedViewQuery() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); + when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(INDEX_TAG_KEY, "flint_mv_1"); @@ -570,7 +587,8 @@ void testDispatchMaterializedViewQuery() { tags.put(JOB_TYPE_TAG_KEY, JobType.STREAMING.getText()); String query = "CREATE MATERIALIZED VIEW mv_1 AS select * from logs WITH" + " (auto_refresh = true)"; - String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming"); + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString(query, "streaming", QUERY_ID); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:streaming:flint_mv_1", @@ -599,12 +617,14 @@ void testDispatchMaterializedViewQuery() { @Test void testDispatchShowMVQuery() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); + when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "SHOW MATERIALIZED VIEW IN mys3.default"; - String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); + 
String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString(query, null, QUERY_ID); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -633,12 +653,14 @@ void testDispatchShowMVQuery() { @Test void testRefreshIndexQuery() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); + when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "REFRESH SKIPPING INDEX ON my_glue.default.http_logs"; - String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString(query, null, QUERY_ID); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -667,12 +689,14 @@ void testRefreshIndexQuery() { @Test void testDispatchDescribeIndexQuery() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); + when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, MY_GLUE); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "DESCRIBE SKIPPING INDEX ON mys3.default.http_logs"; - String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query); + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString(query, null, QUERY_ID); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:batch", @@ -701,6 +725,7 @@ void testDispatchDescribeIndexQuery() { @Test void testDispatchAlterToAutoRefreshIndexQuery() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); + when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID); HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index"); @@ -709,7 +734,8 @@ void testDispatchAlterToAutoRefreshIndexQuery() { String query = "ALTER INDEX elb_and_requestUri ON my_glue.default.http_logs WITH" + " (auto_refresh = true)"; - String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming"); + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString(query, "streaming", QUERY_ID); StartJobRequest expected = new StartJobRequest( "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index", @@ -1048,10 +1074,11 @@ void testDispatchQueryWithExtraSparkSubmitParameters() { } private String constructExpectedSparkSubmitParameterString(String query) { - return constructExpectedSparkSubmitParameterString(query, null); + return constructExpectedSparkSubmitParameterString(query, null, null); } - private String constructExpectedSparkSubmitParameterString(String query, String jobType) { + private String constructExpectedSparkSubmitParameterString( + String query, String jobType, String queryId) { query = "\"" + query + "\""; return " --class org.apache.spark.sql.FlintJob " + getConfParam( @@ -1070,6 +1097,7 @@ private String constructExpectedSparkSubmitParameterString(String query, String "spark.datasource.flint.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider", "spark.sql.extensions=org.opensearch.flint.spark.FlintSparkExtensions,org.opensearch.flint.spark.FlintPPLSparkExtensions", 
"spark.hadoop.hive.metastore.client.factory.class=com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory") + + (queryId != null ? getConfParam("spark.flint.job.queryId=" + queryId) : "") + getConfParam("spark.flint.job.query=" + query) + (jobType != null ? getConfParam("spark.flint.job.type=" + jobType) : "") + getConfParam( From b4a6c60daf1a2c35a740d0245d7b3fb32fd8a0e0 Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Wed, 4 Sep 2024 14:04:11 -0700 Subject: [PATCH 05/11] Fix SparkExecutionEngineConfigClusterSetting deserialize issue (#2966) Signed-off-by: Tomoyuki Morita --- .../config/SparkExecutionEngineConfigClusterSetting.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSetting.java b/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSetting.java index adaaa57d31..f940680c06 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSetting.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSetting.java @@ -7,8 +7,10 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.google.gson.Gson; +import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; +import lombok.NoArgsConstructor; /** * This POJO is just for reading stringified json in `plugins.query.executionengine.spark.config` @@ -16,6 +18,8 @@ */ @Data @Builder +@AllArgsConstructor +@NoArgsConstructor @JsonIgnoreProperties(ignoreUnknown = true) public class SparkExecutionEngineConfigClusterSetting { // optional From 729bb13247f12c3b7b91a92276e79430e9477db3 Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Wed, 4 Sep 2024 15:32:22 -0700 Subject: [PATCH 06/11] Flint query scheduler part 2 (#2961) * Flint query scheduler part 2 Signed-off-by: Louis Chu * spotless apply Signed-off-by: Louis Chu * Add UT Signed-off-by: Louis Chu * Resolve comments Signed-off-by: Louis Chu * Add more UTs Signed-off-by: Louis Chu * Resolve comments Signed-off-by: Louis Chu * Use SQL thread pool Signed-off-by: Louis Chu --------- Signed-off-by: Louis Chu --- .../src/main/antlr/FlintSparkSqlExtensions.g4 | 5 +- .../src/main/antlr/SparkSqlBase.g4 | 1 + .../src/main/antlr/SqlBaseLexer.g4 | 2 + .../src/main/antlr/SqlBaseParser.g4 | 22 +- .../dispatcher/model/FlintIndexOptions.java | 6 + .../flint/operation/FlintIndexOpAlter.java | 12 +- .../flint/operation/FlintIndexOpDrop.java | 13 +- .../flint/operation/FlintIndexOpFactory.java | 13 +- .../flint/operation/FlintIndexOpVacuum.java | 11 +- .../spark/scheduler/AsyncQueryScheduler.java | 57 +++++ .../model/AsyncQuerySchedulerRequest.java | 31 +++ .../asyncquery/AsyncQueryCoreIntegTest.java | 110 ++++++++- .../dispatcher/SparkQueryDispatcherTest.java | 2 + .../operation/FlintIndexOpFactoryTest.java | 2 + .../operation/FlintIndexOpVacuumTest.java | 57 ++++- async-query/build.gradle | 2 +- .../OpenSearchAsyncQueryScheduler.java | 58 ++--- .../job/OpenSearchRefreshIndexJob.java | 93 -------- .../job/ScheduledAsyncQueryJobRunner.java | 116 ++++++++++ .../OpenSearchRefreshIndexJobRequest.java | 108 --------- .../model/ScheduledAsyncQueryJobRequest.java | 156 +++++++++++++ .../parser/IntervalScheduleParser.java | 100 +++++++++ ...nSearchScheduleQueryJobRequestParser.java} | 40 ++-- .../config/AsyncExecutorServiceModule.java | 16 +- .../async-query-scheduler-index-mapping.yml | 10 +- .../AsyncQueryExecutorServiceSpec.java | 10 +- 
.../OpenSearchAsyncQuerySchedulerTest.java | 63 +++--- .../job/OpenSearchRefreshIndexJobTest.java | 145 ------------ .../job/ScheduledAsyncQueryJobRunnerTest.java | 210 ++++++++++++++++++ .../OpenSearchRefreshIndexJobRequestTest.java | 81 ------- .../ScheduledAsyncQueryJobRequestTest.java | 210 ++++++++++++++++++ .../parser/IntervalScheduleParserTest.java | 122 ++++++++++ .../org/opensearch/sql/plugin/SQLPlugin.java | 23 +- 33 files changed, 1371 insertions(+), 536 deletions(-) create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/scheduler/AsyncQueryScheduler.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/scheduler/model/AsyncQuerySchedulerRequest.java delete mode 100644 async-query/src/main/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJob.java create mode 100644 async-query/src/main/java/org/opensearch/sql/spark/scheduler/job/ScheduledAsyncQueryJobRunner.java delete mode 100644 async-query/src/main/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequest.java create mode 100644 async-query/src/main/java/org/opensearch/sql/spark/scheduler/model/ScheduledAsyncQueryJobRequest.java create mode 100644 async-query/src/main/java/org/opensearch/sql/spark/scheduler/parser/IntervalScheduleParser.java rename async-query/src/main/java/org/opensearch/sql/spark/scheduler/{OpenSearchRefreshIndexJobRequestParser.java => parser/OpenSearchScheduleQueryJobRequestParser.java} (57%) delete mode 100644 async-query/src/test/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJobTest.java create mode 100644 async-query/src/test/java/org/opensearch/sql/spark/scheduler/job/ScheduledAsyncQueryJobRunnerTest.java delete mode 100644 async-query/src/test/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequestTest.java create mode 100644 async-query/src/test/java/org/opensearch/sql/spark/scheduler/model/ScheduledAsyncQueryJobRequestTest.java create mode 100644 async-query/src/test/java/org/opensearch/sql/spark/scheduler/parser/IntervalScheduleParserTest.java diff --git a/async-query-core/src/main/antlr/FlintSparkSqlExtensions.g4 b/async-query-core/src/main/antlr/FlintSparkSqlExtensions.g4 index 2e8d634dad..46e814e9f5 100644 --- a/async-query-core/src/main/antlr/FlintSparkSqlExtensions.g4 +++ b/async-query-core/src/main/antlr/FlintSparkSqlExtensions.g4 @@ -156,7 +156,10 @@ indexManagementStatement ; showFlintIndexStatement - : SHOW FLINT (INDEX | INDEXES) IN catalogDb=multipartIdentifier + : SHOW FLINT (INDEX | INDEXES) + IN catalogDb=multipartIdentifier #showFlintIndex + | SHOW FLINT (INDEX | INDEXES) EXTENDED + IN catalogDb=multipartIdentifier #showFlintIndexExtended ; indexJobManagementStatement diff --git a/async-query-core/src/main/antlr/SparkSqlBase.g4 b/async-query-core/src/main/antlr/SparkSqlBase.g4 index 283981e471..c53c61adfd 100644 --- a/async-query-core/src/main/antlr/SparkSqlBase.g4 +++ b/async-query-core/src/main/antlr/SparkSqlBase.g4 @@ -163,6 +163,7 @@ DESC: 'DESC'; DESCRIBE: 'DESCRIBE'; DROP: 'DROP'; EXISTS: 'EXISTS'; +EXTENDED: 'EXTENDED'; FALSE: 'FALSE'; FLINT: 'FLINT'; IF: 'IF'; diff --git a/async-query-core/src/main/antlr/SqlBaseLexer.g4 b/async-query-core/src/main/antlr/SqlBaseLexer.g4 index bde298c23e..acfc0011f5 100644 --- a/async-query-core/src/main/antlr/SqlBaseLexer.g4 +++ b/async-query-core/src/main/antlr/SqlBaseLexer.g4 @@ -212,6 +212,7 @@ DIRECTORY: 'DIRECTORY'; DISTINCT: 'DISTINCT'; DISTRIBUTE: 'DISTRIBUTE'; DIV: 'DIV'; +DO: 'DO'; DOUBLE: 'DOUBLE'; DROP: 
'DROP'; ELSE: 'ELSE'; @@ -467,6 +468,7 @@ WEEK: 'WEEK'; WEEKS: 'WEEKS'; WHEN: 'WHEN'; WHERE: 'WHERE'; +WHILE: 'WHILE'; WINDOW: 'WINDOW'; WITH: 'WITH'; WITHIN: 'WITHIN'; diff --git a/async-query-core/src/main/antlr/SqlBaseParser.g4 b/async-query-core/src/main/antlr/SqlBaseParser.g4 index c7aa56cf92..5b8805821b 100644 --- a/async-query-core/src/main/antlr/SqlBaseParser.g4 +++ b/async-query-core/src/main/antlr/SqlBaseParser.g4 @@ -63,6 +63,8 @@ compoundStatement : statement | setStatementWithOptionalVarKeyword | beginEndCompoundBlock + | ifElseStatement + | whileStatement ; setStatementWithOptionalVarKeyword @@ -71,6 +73,16 @@ setStatementWithOptionalVarKeyword LEFT_PAREN query RIGHT_PAREN #setVariableWithOptionalKeyword ; +whileStatement + : beginLabel? WHILE booleanExpression DO compoundBody END WHILE endLabel? + ; + +ifElseStatement + : IF booleanExpression THEN conditionalBodies+=compoundBody + (ELSE IF booleanExpression THEN conditionalBodies+=compoundBody)* + (ELSE elseBody=compoundBody)? END IF + ; + singleStatement : (statement|setResetStatement) SEMICOLON* EOF ; @@ -406,9 +418,9 @@ query ; insertInto - : INSERT OVERWRITE TABLE? identifierReference (partitionSpec (IF errorCapturingNot EXISTS)?)? ((BY NAME) | identifierList)? #insertOverwriteTable - | INSERT INTO TABLE? identifierReference partitionSpec? (IF errorCapturingNot EXISTS)? ((BY NAME) | identifierList)? #insertIntoTable - | INSERT INTO TABLE? identifierReference REPLACE whereClause #insertIntoReplaceWhere + : INSERT OVERWRITE TABLE? identifierReference optionsClause? (partitionSpec (IF errorCapturingNot EXISTS)?)? ((BY NAME) | identifierList)? #insertOverwriteTable + | INSERT INTO TABLE? identifierReference optionsClause? partitionSpec? (IF errorCapturingNot EXISTS)? ((BY NAME) | identifierList)? #insertIntoTable + | INSERT INTO TABLE? identifierReference optionsClause? REPLACE whereClause #insertIntoReplaceWhere | INSERT OVERWRITE LOCAL? DIRECTORY path=stringLit rowFormat? createFileFormat? #insertOverwriteHiveDir | INSERT OVERWRITE LOCAL? DIRECTORY (path=stringLit)? tableProvider (OPTIONS options=propertyList)? 
#insertOverwriteDir ; @@ -1522,6 +1534,7 @@ ansiNonReserved | DIRECTORY | DISTRIBUTE | DIV + | DO | DOUBLE | DROP | ESCAPED @@ -1723,6 +1736,7 @@ ansiNonReserved | VOID | WEEK | WEEKS + | WHILE | WINDOW | YEAR | YEARS @@ -1853,6 +1867,7 @@ nonReserved | DISTINCT | DISTRIBUTE | DIV + | DO | DOUBLE | DROP | ELSE @@ -2092,6 +2107,7 @@ nonReserved | VOID | WEEK | WEEKS + | WHILE | WHEN | WHERE | WINDOW diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/model/FlintIndexOptions.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/model/FlintIndexOptions.java index 79af1c91ab..6c7cc7c5fb 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/model/FlintIndexOptions.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/model/FlintIndexOptions.java @@ -19,6 +19,7 @@ public class FlintIndexOptions { public static final String INCREMENTAL_REFRESH = "incremental_refresh"; public static final String CHECKPOINT_LOCATION = "checkpoint_location"; public static final String WATERMARK_DELAY = "watermark_delay"; + public static final String SCHEDULER_MODE = "scheduler_mode"; private final Map options = new HashMap<>(); public void setOption(String key, String value) { @@ -33,6 +34,11 @@ public boolean autoRefresh() { return Boolean.parseBoolean(getOption(AUTO_REFRESH).orElse("false")); } + public boolean isExternalScheduler() { + // Default is false, which means using internal scheduler to refresh the index. + return getOption(SCHEDULER_MODE).map(mode -> "external".equals(mode)).orElse(false); + } + public Map getProvidedOptions() { return new HashMap<>(options); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java index 4a00195ebf..de34803823 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java @@ -16,6 +16,7 @@ import org.opensearch.sql.spark.flint.FlintIndexState; import org.opensearch.sql.spark.flint.FlintIndexStateModel; import org.opensearch.sql.spark.flint.FlintIndexStateModelService; +import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; /** * Index Operation for Altering the flint index. 
Only handles alter operation when @@ -25,16 +26,19 @@ public class FlintIndexOpAlter extends FlintIndexOp { private static final Logger LOG = LogManager.getLogger(FlintIndexOpAlter.class); private final FlintIndexMetadataService flintIndexMetadataService; private final FlintIndexOptions flintIndexOptions; + private final AsyncQueryScheduler asyncQueryScheduler; public FlintIndexOpAlter( FlintIndexOptions flintIndexOptions, FlintIndexStateModelService flintIndexStateModelService, String datasourceName, EMRServerlessClientFactory emrServerlessClientFactory, - FlintIndexMetadataService flintIndexMetadataService) { + FlintIndexMetadataService flintIndexMetadataService, + AsyncQueryScheduler asyncQueryScheduler) { super(flintIndexStateModelService, datasourceName, emrServerlessClientFactory); this.flintIndexMetadataService = flintIndexMetadataService; this.flintIndexOptions = flintIndexOptions; + this.asyncQueryScheduler = asyncQueryScheduler; } @Override @@ -57,7 +61,11 @@ void runOp( "Running alter index operation for index: {}", flintIndexMetadata.getOpensearchIndexName()); this.flintIndexMetadataService.updateIndexToManualRefresh( flintIndexMetadata.getOpensearchIndexName(), flintIndexOptions, asyncQueryRequestContext); - cancelStreamingJob(flintIndexStateModel); + if (flintIndexMetadata.getFlintIndexOptions().isExternalScheduler()) { + asyncQueryScheduler.unscheduleJob(flintIndexMetadata.getOpensearchIndexName()); + } else { + cancelStreamingJob(flintIndexStateModel); + } } @Override diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java index fc9b644fc7..3fa5423c10 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java @@ -14,16 +14,21 @@ import org.opensearch.sql.spark.flint.FlintIndexState; import org.opensearch.sql.spark.flint.FlintIndexStateModel; import org.opensearch.sql.spark.flint.FlintIndexStateModelService; +import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; /** Operation to drop Flint index */ public class FlintIndexOpDrop extends FlintIndexOp { private static final Logger LOG = LogManager.getLogger(); + private final AsyncQueryScheduler asyncQueryScheduler; + public FlintIndexOpDrop( FlintIndexStateModelService flintIndexStateModelService, String datasourceName, - EMRServerlessClientFactory emrServerlessClientFactory) { + EMRServerlessClientFactory emrServerlessClientFactory, + AsyncQueryScheduler asyncQueryScheduler) { super(flintIndexStateModelService, datasourceName, emrServerlessClientFactory); + this.asyncQueryScheduler = asyncQueryScheduler; } public boolean validate(FlintIndexState state) { @@ -48,7 +53,11 @@ void runOp( LOG.debug( "Performing drop index operation for index: {}", flintIndexMetadata.getOpensearchIndexName()); - cancelStreamingJob(flintIndexStateModel); + if (flintIndexMetadata.getFlintIndexOptions().isExternalScheduler()) { + asyncQueryScheduler.unscheduleJob(flintIndexMetadata.getOpensearchIndexName()); + } else { + cancelStreamingJob(flintIndexStateModel); + } } @Override diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactory.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactory.java index 14cf9fa7c9..9f925e0bcf 100644 --- 
a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactory.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactory.java @@ -11,6 +11,7 @@ import org.opensearch.sql.spark.flint.FlintIndexClient; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; import org.opensearch.sql.spark.flint.FlintIndexStateModelService; +import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; @RequiredArgsConstructor public class FlintIndexOpFactory { @@ -18,10 +19,11 @@ public class FlintIndexOpFactory { private final FlintIndexClient flintIndexClient; private final FlintIndexMetadataService flintIndexMetadataService; private final EMRServerlessClientFactory emrServerlessClientFactory; + private final AsyncQueryScheduler asyncQueryScheduler; public FlintIndexOpDrop getDrop(String datasource) { return new FlintIndexOpDrop( - flintIndexStateModelService, datasource, emrServerlessClientFactory); + flintIndexStateModelService, datasource, emrServerlessClientFactory, asyncQueryScheduler); } public FlintIndexOpAlter getAlter(FlintIndexOptions flintIndexOptions, String datasource) { @@ -30,12 +32,17 @@ public FlintIndexOpAlter getAlter(FlintIndexOptions flintIndexOptions, String da flintIndexStateModelService, datasource, emrServerlessClientFactory, - flintIndexMetadataService); + flintIndexMetadataService, + asyncQueryScheduler); } public FlintIndexOpVacuum getVacuum(String datasource) { return new FlintIndexOpVacuum( - flintIndexStateModelService, datasource, flintIndexClient, emrServerlessClientFactory); + flintIndexStateModelService, + datasource, + flintIndexClient, + emrServerlessClientFactory, + asyncQueryScheduler); } public FlintIndexOpCancel getCancel(String datasource) { diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java index 06aaf8ef9f..324ddb5720 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java @@ -14,12 +14,14 @@ import org.opensearch.sql.spark.flint.FlintIndexState; import org.opensearch.sql.spark.flint.FlintIndexStateModel; import org.opensearch.sql.spark.flint.FlintIndexStateModelService; +import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; /** Flint index vacuum operation. */ public class FlintIndexOpVacuum extends FlintIndexOp { - private static final Logger LOG = LogManager.getLogger(); + private final AsyncQueryScheduler asyncQueryScheduler; + /** OpenSearch client. 
*/ private final FlintIndexClient flintIndexClient; @@ -27,9 +29,11 @@ public FlintIndexOpVacuum( FlintIndexStateModelService flintIndexStateModelService, String datasourceName, FlintIndexClient flintIndexClient, - EMRServerlessClientFactory emrServerlessClientFactory) { + EMRServerlessClientFactory emrServerlessClientFactory, + AsyncQueryScheduler asyncQueryScheduler) { super(flintIndexStateModelService, datasourceName, emrServerlessClientFactory); this.flintIndexClient = flintIndexClient; + this.asyncQueryScheduler = asyncQueryScheduler; } @Override @@ -48,6 +52,9 @@ public void runOp( FlintIndexStateModel flintIndex, AsyncQueryRequestContext asyncQueryRequestContext) { LOG.info("Vacuuming Flint index {}", flintIndexMetadata.getOpensearchIndexName()); + if (flintIndexMetadata.getFlintIndexOptions().isExternalScheduler()) { + asyncQueryScheduler.removeJob(flintIndexMetadata.getOpensearchIndexName()); + } flintIndexClient.deleteIndex(flintIndexMetadata.getOpensearchIndexName()); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/scheduler/AsyncQueryScheduler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/scheduler/AsyncQueryScheduler.java new file mode 100644 index 0000000000..8ac499081e --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/scheduler/AsyncQueryScheduler.java @@ -0,0 +1,57 @@ +package org.opensearch.sql.spark.scheduler; + +import org.opensearch.sql.spark.scheduler.model.AsyncQuerySchedulerRequest; + +/** Scheduler interface for scheduling asynchronous query jobs. */ +public interface AsyncQueryScheduler { + + /** + * Schedules a new job in the system. This method creates a new job entry based on the provided + * request parameters. + * + *
<p>
Use cases: - Creating a new periodic query execution - Setting up a scheduled data refresh + * task + * + * @param asyncQuerySchedulerRequest The request containing job configuration details + * @throws IllegalArgumentException if a job with the same name already exists + * @throws RuntimeException if there's an error during job creation + */ + void scheduleJob(AsyncQuerySchedulerRequest asyncQuerySchedulerRequest); + + /** + * Updates an existing job with new parameters. This method modifies the configuration of an + * already scheduled job. + * + *
<p>
Use cases: - Changing the schedule of an existing job - Modifying query parameters of a + * scheduled job - Updating resource allocations for a job + * + * @param asyncQuerySchedulerRequest The request containing updated job configuration + * @throws IllegalArgumentException if the job to be updated doesn't exist + * @throws RuntimeException if there's an error during the update process + */ + void updateJob(AsyncQuerySchedulerRequest asyncQuerySchedulerRequest); + + /** + * Unschedules a job by marking it as disabled and updating its last update time. This method is + * used when you want to temporarily stop a job from running but keep its configuration and + * history in the system. + * + *
<p>
Use cases: - Pausing a job that's causing issues without losing its configuration - + * Temporarily disabling a job during maintenance or high-load periods - Allowing for easy + * re-enabling of the job in the future + * + * @param jobId The unique identifier of the job to unschedule + */ + void unscheduleJob(String jobId); + + /** + * Removes a job completely from the scheduler. This method permanently deletes the job and all + * its associated data from the system. + * + *
<p>
Use cases: - Cleaning up jobs that are no longer needed - Removing obsolete or erroneously + * created jobs - Freeing up resources by deleting unused job configurations + * + * @param jobId The unique identifier of the job to remove + */ + void removeJob(String jobId); +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/scheduler/model/AsyncQuerySchedulerRequest.java b/async-query-core/src/main/java/org/opensearch/sql/spark/scheduler/model/AsyncQuerySchedulerRequest.java new file mode 100644 index 0000000000..b54e5b30ce --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/scheduler/model/AsyncQuerySchedulerRequest.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.scheduler.model; + +import java.time.Instant; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import org.opensearch.sql.spark.rest.model.LangType; + +/** Represents a job request for a scheduled task. */ +@Data +@NoArgsConstructor +@AllArgsConstructor +public class AsyncQuerySchedulerRequest { + protected String accountId; + // Scheduler jobid is the opensearch index name until we support multiple jobs per index + protected String jobId; + protected String dataSource; + protected String scheduledQuery; + protected LangType queryLang; + protected Object schedule; + protected boolean enabled; + protected Instant lastUpdateTime; + protected Instant enabledTime; + protected Long lockDurationSeconds; + protected Double jitter; +} diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index 09767d16bd..226e0ff5eb 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -83,6 +83,7 @@ import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; import org.opensearch.sql.spark.rest.model.LangType; +import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; /** * This tests async-query-core library end-to-end using mocked implementation of extension points. 
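For context on how the new contract is meant to be driven, here is a minimal, illustrative usage sketch. The account id, index name, datasource, query, interval, lock duration, and jitter values below are assumptions for illustration only; the `schedule` field is declared as `Object`, so its concrete type is left to the implementation (this series parses interval strings via IntervalScheduleParser).

import java.time.Instant;
import org.opensearch.sql.spark.rest.model.LangType;
import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler;
import org.opensearch.sql.spark.scheduler.model.AsyncQuerySchedulerRequest;

class AsyncQuerySchedulerUsageSketch {
  void scheduleIndexRefresh(AsyncQueryScheduler scheduler) {
    // Argument order follows the @AllArgsConstructor field order on AsyncQuerySchedulerRequest.
    AsyncQuerySchedulerRequest request =
        new AsyncQuerySchedulerRequest(
            "account-id",                                     // accountId (assumed value)
            "flint_my_glue_default_http_logs_skipping_index", // jobId: the OpenSearch index name
            "my_glue",                                        // dataSource (assumed)
            "REFRESH SKIPPING INDEX ON my_glue.default.http_logs", // scheduledQuery (assumed)
            LangType.SQL,                                     // queryLang
            "5 minutes",                                      // schedule: Object, concrete type is implementation-defined
            true,                                             // enabled
            Instant.now(),                                    // lastUpdateTime
            Instant.now(),                                    // enabledTime
            600L,                                             // lockDurationSeconds (assumed)
            0.1);                                             // jitter (assumed)

    scheduler.scheduleJob(request);              // creates the job; throws IllegalArgumentException if jobId exists
    scheduler.unscheduleJob(request.getJobId()); // pause: disables the job but keeps config and history
    scheduler.removeJob(request.getJobId());     // permanently deletes the job and its data
  }
}

Note that `jobId` doubles as the OpenSearch index name (one scheduled job per index, per the comment on `AsyncQuerySchedulerRequest`), which is why the drop, vacuum, and alter operations above pass `flintIndexMetadata.getOpensearchIndexName()` into `unscheduleJob` and `removeJob`.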
@@ -112,6 +113,7 @@ public class AsyncQueryCoreIntegTest { @Mock FlintIndexClient flintIndexClient; @Mock AsyncQueryRequestContext asyncQueryRequestContext; @Mock MetricsService metricsService; + @Mock AsyncQueryScheduler asyncQueryScheduler; @Mock SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider; // storage services @@ -159,7 +161,8 @@ public void setUp() { flintIndexStateModelService, flintIndexClient, flintIndexMetadataService, - emrServerlessClientFactory); + emrServerlessClientFactory, + asyncQueryScheduler); QueryHandlerFactory queryHandlerFactory = new QueryHandlerFactory( jobExecutionResponseReader, @@ -205,6 +208,30 @@ public void createDropIndexQuery() { verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, JobType.BATCH); } + @Test + public void createDropIndexQueryWithScheduler() { + givenSparkExecutionEngineConfigIsSupplied(); + givenValidDataSourceMetadataExist(); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); + + String indexName = "flint_datasource_name_table_name_index_name_index"; + givenFlintIndexMetadataExistsWithExternalScheduler(indexName); + + CreateAsyncQueryResponse response = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "DROP INDEX index_name ON table_name", DATASOURCE_NAME, LangType.SQL), + asyncQueryRequestContext); + + assertEquals(QUERY_ID, response.getQueryId()); + assertNull(response.getSessionId()); + verifyGetQueryIdCalled(); + verifyCreateIndexDMLResultCalled(); + verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID); + + verify(asyncQueryScheduler).unscheduleJob(indexName); + } + @Test public void createVacuumIndexQuery() { givenSparkExecutionEngineConfigIsSupplied(); @@ -227,6 +254,32 @@ public void createVacuumIndexQuery() { verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, JobType.BATCH); } + @Test + public void createVacuumIndexQueryWithScheduler() { + givenSparkExecutionEngineConfigIsSupplied(); + givenValidDataSourceMetadataExist(); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); + + String indexName = "flint_datasource_name_table_name_index_name_index"; + givenFlintIndexMetadataExistsWithExternalScheduler(indexName); + + CreateAsyncQueryResponse response = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "VACUUM INDEX index_name ON table_name", DATASOURCE_NAME, LangType.SQL), + asyncQueryRequestContext); + + assertEquals(QUERY_ID, response.getQueryId()); + assertNull(response.getSessionId()); + verifyGetQueryIdCalled(); + + verify(flintIndexClient).deleteIndex(indexName); + verifyCreateIndexDMLResultCalled(); + verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID); + + verify(asyncQueryScheduler).removeJob(indexName); + } + @Test public void createAlterIndexQuery() { givenSparkExecutionEngineConfigIsSupplied(); @@ -258,6 +311,40 @@ public void createAlterIndexQuery() { verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, JobType.BATCH); } + @Test + public void createAlterIndexQueryWithScheduler() { + givenSparkExecutionEngineConfigIsSupplied(); + givenValidDataSourceMetadataExist(); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); + + String indexName = "flint_datasource_name_table_name_index_name_index"; + givenFlintIndexMetadataExistsWithExternalScheduler(indexName); + + CreateAsyncQueryResponse response = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "ALTER INDEX index_name ON table_name WITH (auto_refresh = false)", + 
DATASOURCE_NAME, + LangType.SQL), + asyncQueryRequestContext); + + assertEquals(QUERY_ID, response.getQueryId()); + assertNull(response.getSessionId()); + verifyGetQueryIdCalled(); + + verify(flintIndexMetadataService) + .updateIndexToManualRefresh( + eq(indexName), flintIndexOptionsArgumentCaptor.capture(), eq(asyncQueryRequestContext)); + + FlintIndexOptions flintIndexOptions = flintIndexOptionsArgumentCaptor.getValue(); + assertFalse(flintIndexOptions.autoRefresh()); + + verify(asyncQueryScheduler).unscheduleJob(indexName); + + verifyCreateIndexDMLResultCalled(); + verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID); + } + @Test public void createStreamingQuery() { givenSparkExecutionEngineConfigIsSupplied(); @@ -507,7 +594,8 @@ private void givenSparkExecutionEngineConfigIsSupplied() { .build()); } - private void givenFlintIndexMetadataExists(String indexName) { + private void givenFlintIndexMetadataExists( + String indexName, FlintIndexOptions flintIndexOptions) { when(flintIndexMetadataService.getFlintIndexMetadata(indexName, asyncQueryRequestContext)) .thenReturn( ImmutableMap.of( @@ -516,9 +604,27 @@ private void givenFlintIndexMetadataExists(String indexName) { .appId(APPLICATION_ID) .jobId(JOB_ID) .opensearchIndexName(indexName) + .flintIndexOptions(flintIndexOptions) .build())); } + // Overload method for default FlintIndexOptions usage + private void givenFlintIndexMetadataExists(String indexName) { + givenFlintIndexMetadataExists(indexName, new FlintIndexOptions()); + } + + // Method to set up FlintIndexMetadata with external scheduler + private void givenFlintIndexMetadataExistsWithExternalScheduler(String indexName) { + givenFlintIndexMetadataExists(indexName, createExternalSchedulerFlintIndexOptions()); + } + + // Helper method for creating FlintIndexOptions with external scheduler + private FlintIndexOptions createExternalSchedulerFlintIndexOptions() { + FlintIndexOptions options = new FlintIndexOptions(); + options.setOption(FlintIndexOptions.SCHEDULER_MODE, "external"); + return options; + } + private void givenValidDataSourceMetadataExist() { when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( DATASOURCE_NAME, asyncQueryRequestContext)) diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 1587ce6638..d040db24b2 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -86,6 +86,7 @@ import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.LangType; +import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; @ExtendWith(MockitoExtension.class) public class SparkQueryDispatcherTest { @@ -108,6 +109,7 @@ public class SparkQueryDispatcherTest { @Mock private QueryIdProvider queryIdProvider; @Mock private AsyncQueryRequestContext asyncQueryRequestContext; @Mock private MetricsService metricsService; + @Mock private AsyncQueryScheduler asyncQueryScheduler; private DataSourceSparkParameterComposer dataSourceSparkParameterComposer = (datasourceMetadata, sparkSubmitParameters, dispatchQueryRequest, context) -> { sparkSubmitParameters.setConfigItem(FLINT_INDEX_STORE_AUTH_KEY, "basic"); diff --git 
a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactoryTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactoryTest.java index 3bf438aeb9..62ac98f1a2 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactoryTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactoryTest.java @@ -17,6 +17,7 @@ import org.opensearch.sql.spark.flint.FlintIndexClient; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; import org.opensearch.sql.spark.flint.FlintIndexStateModelService; +import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; @ExtendWith(MockitoExtension.class) class FlintIndexOpFactoryTest { @@ -26,6 +27,7 @@ class FlintIndexOpFactoryTest { @Mock private FlintIndexClient flintIndexClient; @Mock private FlintIndexMetadataService flintIndexMetadataService; @Mock private EMRServerlessClientFactory emrServerlessClientFactory; + @Mock private AsyncQueryScheduler asyncQueryScheduler; @InjectMocks FlintIndexOpFactory flintIndexOpFactory; diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuumTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuumTest.java index 26858c18fe..08f8efd488 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuumTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuumTest.java @@ -7,6 +7,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -18,11 +19,13 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; +import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; import org.opensearch.sql.spark.flint.FlintIndexClient; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexState; import org.opensearch.sql.spark.flint.FlintIndexStateModel; import org.opensearch.sql.spark.flint.FlintIndexStateModelService; +import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; @ExtendWith(MockitoExtension.class) class FlintIndexOpVacuumTest { @@ -30,16 +33,20 @@ class FlintIndexOpVacuumTest { public static final String DATASOURCE_NAME = "DATASOURCE_NAME"; public static final String LATEST_ID = "LATEST_ID"; public static final String INDEX_NAME = "INDEX_NAME"; + public static final FlintIndexMetadata FLINT_INDEX_METADATA_WITH_LATEST_ID = - FlintIndexMetadata.builder().latestId(LATEST_ID).opensearchIndexName(INDEX_NAME).build(); + createFlintIndexMetadataWithLatestId(); + public static final FlintIndexMetadata FLINT_INDEX_METADATA_WITHOUT_LATEST_ID = - FlintIndexMetadata.builder().opensearchIndexName(INDEX_NAME).build(); + createFlintIndexMetadataWithoutLatestId(); + @Mock FlintIndexClient flintIndexClient; @Mock FlintIndexStateModelService flintIndexStateModelService; @Mock EMRServerlessClientFactory emrServerlessClientFactory; @Mock FlintIndexStateModel flintIndexStateModel; @Mock FlintIndexStateModel transitionedFlintIndexStateModel; @Mock AsyncQueryRequestContext asyncQueryRequestContext; + @Mock AsyncQueryScheduler 
asyncQueryScheduler; RuntimeException testException = new RuntimeException("Test Exception"); @@ -52,7 +59,33 @@ public void setUp() { flintIndexStateModelService, DATASOURCE_NAME, flintIndexClient, - emrServerlessClientFactory); + emrServerlessClientFactory, + asyncQueryScheduler); + } + + private static FlintIndexMetadata createFlintIndexMetadataWithLatestId() { + return FlintIndexMetadata.builder() + .latestId(LATEST_ID) + .opensearchIndexName(INDEX_NAME) + .flintIndexOptions(new FlintIndexOptions()) + .build(); + } + + private static FlintIndexMetadata createFlintIndexMetadataWithoutLatestId() { + return FlintIndexMetadata.builder() + .opensearchIndexName(INDEX_NAME) + .flintIndexOptions(new FlintIndexOptions()) + .build(); + } + + private FlintIndexMetadata createFlintIndexMetadataWithExternalScheduler() { + FlintIndexOptions flintIndexOptions = new FlintIndexOptions(); + flintIndexOptions.setOption(FlintIndexOptions.SCHEDULER_MODE, "external"); + + return FlintIndexMetadata.builder() + .opensearchIndexName(INDEX_NAME) + .flintIndexOptions(flintIndexOptions) + .build(); } @Test @@ -207,4 +240,22 @@ public void testApplyHappyPath() { .deleteFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext); verify(flintIndexClient).deleteIndex(INDEX_NAME); } + + @Test + public void testRunOpWithExternalScheduler() { + FlintIndexMetadata flintIndexMetadata = createFlintIndexMetadataWithExternalScheduler(); + flintIndexOpVacuum.runOp(flintIndexMetadata, flintIndexStateModel, asyncQueryRequestContext); + + verify(asyncQueryScheduler).removeJob(INDEX_NAME); + verify(flintIndexClient).deleteIndex(INDEX_NAME); + } + + @Test + public void testRunOpWithoutExternalScheduler() { + FlintIndexMetadata flintIndexMetadata = FLINT_INDEX_METADATA_WITHOUT_LATEST_ID; + flintIndexOpVacuum.runOp(flintIndexMetadata, flintIndexStateModel, asyncQueryRequestContext); + + verify(asyncQueryScheduler, never()).removeJob(INDEX_NAME); + verify(flintIndexClient).deleteIndex(INDEX_NAME); + } } diff --git a/async-query/build.gradle b/async-query/build.gradle index abda6161d3..53fdcbe292 100644 --- a/async-query/build.gradle +++ b/async-query/build.gradle @@ -99,7 +99,7 @@ jacocoTestCoverageVerification { // ignore because XContext IOException 'org.opensearch.sql.spark.execution.statestore.StateStore', 'org.opensearch.sql.spark.rest.*', - 'org.opensearch.sql.spark.scheduler.OpenSearchRefreshIndexJobRequestParser', + 'org.opensearch.sql.spark.scheduler.parser.OpenSearchScheduleQueryJobRequestParser', 'org.opensearch.sql.spark.transport.model.*' ] limit { diff --git a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQueryScheduler.java b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQueryScheduler.java index c7a66fc6be..9ebde4fe83 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQueryScheduler.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQueryScheduler.java @@ -8,10 +8,11 @@ import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import com.google.common.annotations.VisibleForTesting; -import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.time.Instant; +import lombok.RequiredArgsConstructor; +import lombok.SneakyThrows; import org.apache.commons.io.IOUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -34,12 +35,13 @@ import 
org.opensearch.index.engine.DocumentMissingException; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.jobscheduler.spi.ScheduledJobRunner; -import org.opensearch.sql.spark.scheduler.job.OpenSearchRefreshIndexJob; -import org.opensearch.sql.spark.scheduler.model.OpenSearchRefreshIndexJobRequest; -import org.opensearch.threadpool.ThreadPool; +import org.opensearch.sql.spark.scheduler.job.ScheduledAsyncQueryJobRunner; +import org.opensearch.sql.spark.scheduler.model.AsyncQuerySchedulerRequest; +import org.opensearch.sql.spark.scheduler.model.ScheduledAsyncQueryJobRequest; /** Scheduler class for managing asynchronous query jobs. */ -public class OpenSearchAsyncQueryScheduler { +@RequiredArgsConstructor +public class OpenSearchAsyncQueryScheduler implements AsyncQueryScheduler { public static final String SCHEDULER_INDEX_NAME = ".async-query-scheduler"; public static final String SCHEDULER_PLUGIN_JOB_TYPE = "async-query-scheduler"; private static final String SCHEDULER_INDEX_MAPPING_FILE_NAME = @@ -48,22 +50,14 @@ public class OpenSearchAsyncQueryScheduler { "async-query-scheduler-index-settings.yml"; private static final Logger LOG = LogManager.getLogger(); - private Client client; - private ClusterService clusterService; - - /** Loads job resources, setting up required services and job runner instance. */ - public void loadJobResource(Client client, ClusterService clusterService, ThreadPool threadPool) { - this.client = client; - this.clusterService = clusterService; - OpenSearchRefreshIndexJob openSearchRefreshIndexJob = - OpenSearchRefreshIndexJob.getJobRunnerInstance(); - openSearchRefreshIndexJob.setClusterService(clusterService); - openSearchRefreshIndexJob.setThreadPool(threadPool); - openSearchRefreshIndexJob.setClient(client); - } + private final Client client; + private final ClusterService clusterService; /** Schedules a new job by indexing it into the job index. */ + @Override - public void scheduleJob(OpenSearchRefreshIndexJobRequest request) { + public void scheduleJob(AsyncQuerySchedulerRequest asyncQuerySchedulerRequest) { + ScheduledAsyncQueryJobRequest request = + ScheduledAsyncQueryJobRequest.fromAsyncQuerySchedulerRequest(asyncQuerySchedulerRequest); if (!this.clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)) { createAsyncQuerySchedulerIndex(); } @@ -92,19 +86,28 @@ public void scheduleJob(OpenSearchRefreshIndexJobRequest request) { } /** Unschedules a job by marking it as disabled and updating its last update time. */ - public void unscheduleJob(String jobId) throws IOException { - assertIndexExists(); - OpenSearchRefreshIndexJobRequest request = - OpenSearchRefreshIndexJobRequest.builder() - .jobName(jobId) + @Override + public void unscheduleJob(String jobId) { + ScheduledAsyncQueryJobRequest request = + ScheduledAsyncQueryJobRequest.builder() + .jobId(jobId) .enabled(false) .lastUpdateTime(Instant.now()) .build(); - updateJob(request); + try { + updateJob(request); + LOG.info("Unscheduled job for jobId: {}", jobId); + } catch (IllegalStateException | DocumentMissingException e) { + LOG.error("Failed to unschedule job: {}", jobId, e); + } } /** Updates an existing job with new parameters.
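+ * <p>The incoming {@link AsyncQuerySchedulerRequest} is first converted to a
+ * {@link ScheduledAsyncQueryJobRequest}, and assertIndexExists() rejects the call when the
+ * scheduler index is missing. A minimal usage sketch (builder values are illustrative only):
+ *
+ * <pre>{@code
+ * scheduler.updateJob(
+ *     ScheduledAsyncQueryJobRequest.builder()
+ *         .jobId("flint_mys3_default_http_logs_index") // hypothetical job id
+ *         .enabled(true)
+ *         .lastUpdateTime(Instant.now())
+ *         .build());
+ * }</pre>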
*/ - public void updateJob(OpenSearchRefreshIndexJobRequest request) throws IOException { + @Override + @SneakyThrows + public void updateJob(AsyncQuerySchedulerRequest asyncQuerySchedulerRequest) { + ScheduledAsyncQueryJobRequest request = + ScheduledAsyncQueryJobRequest.fromAsyncQuerySchedulerRequest(asyncQuerySchedulerRequest); assertIndexExists(); UpdateRequest updateRequest = new UpdateRequest(SCHEDULER_INDEX_NAME, request.getName()); updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); @@ -130,6 +133,7 @@ public void updateJob(OpenSearchRefreshIndexJobRequest request) throws IOExcepti } /** Removes a job by deleting its document from the index. */ + @Override public void removeJob(String jobId) { assertIndexExists(); DeleteRequest deleteRequest = new DeleteRequest(SCHEDULER_INDEX_NAME, jobId); @@ -192,6 +196,6 @@ private void assertIndexExists() { /** Returns the job runner instance for the scheduler. */ public static ScheduledJobRunner getJobRunner() { - return OpenSearchRefreshIndexJob.getJobRunnerInstance(); + return ScheduledAsyncQueryJobRunner.getJobRunnerInstance(); } } diff --git a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJob.java b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJob.java deleted file mode 100644 index e465a8790f..0000000000 --- a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJob.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.scheduler.job; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.jobscheduler.spi.JobExecutionContext; -import org.opensearch.jobscheduler.spi.ScheduledJobParameter; -import org.opensearch.jobscheduler.spi.ScheduledJobRunner; -import org.opensearch.plugins.Plugin; -import org.opensearch.sql.spark.scheduler.model.OpenSearchRefreshIndexJobRequest; -import org.opensearch.threadpool.ThreadPool; - -/** - * The job runner class for scheduling refresh index query. - * - *

The job runner should be a singleton class if it uses OpenSearch client or other objects - * passed from OpenSearch. Because when registering the job runner to JobScheduler plugin, - * OpenSearch has not invoked plugins' createComponents() method. That is saying the plugin is not - * completely initialized, and the OpenSearch {@link org.opensearch.client.Client}, {@link - * ClusterService} and other objects are not available to plugin and this job runner. - * - *

So we have to move this job runner initialization to {@link Plugin} createComponents() method, - * and using singleton job runner to ensure we register a usable job runner instance to JobScheduler - * plugin. - */ -public class OpenSearchRefreshIndexJob implements ScheduledJobRunner { - - private static final Logger log = LogManager.getLogger(OpenSearchRefreshIndexJob.class); - - public static OpenSearchRefreshIndexJob INSTANCE = new OpenSearchRefreshIndexJob(); - - public static OpenSearchRefreshIndexJob getJobRunnerInstance() { - return INSTANCE; - } - - private ClusterService clusterService; - private ThreadPool threadPool; - private Client client; - - private OpenSearchRefreshIndexJob() { - // Singleton class, use getJobRunnerInstance method instead of constructor - } - - public void setClusterService(ClusterService clusterService) { - this.clusterService = clusterService; - } - - public void setThreadPool(ThreadPool threadPool) { - this.threadPool = threadPool; - } - - public void setClient(Client client) { - this.client = client; - } - - @Override - public void runJob(ScheduledJobParameter jobParameter, JobExecutionContext context) { - if (!(jobParameter instanceof OpenSearchRefreshIndexJobRequest)) { - throw new IllegalStateException( - "Job parameter is not instance of OpenSearchRefreshIndexJobRequest, type: " - + jobParameter.getClass().getCanonicalName()); - } - - if (this.clusterService == null) { - throw new IllegalStateException("ClusterService is not initialized."); - } - - if (this.threadPool == null) { - throw new IllegalStateException("ThreadPool is not initialized."); - } - - if (this.client == null) { - throw new IllegalStateException("Client is not initialized."); - } - - Runnable runnable = - () -> { - doRefresh(jobParameter.getName()); - }; - threadPool.generic().submit(runnable); - } - - void doRefresh(String refreshIndex) { - // TODO: add logic to refresh index - log.info("Scheduled refresh index job on : " + refreshIndex); - } -} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/job/ScheduledAsyncQueryJobRunner.java b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/job/ScheduledAsyncQueryJobRunner.java new file mode 100644 index 0000000000..3652acf295 --- /dev/null +++ b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/job/ScheduledAsyncQueryJobRunner.java @@ -0,0 +1,116 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.scheduler.job; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.jobscheduler.spi.JobExecutionContext; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.jobscheduler.spi.ScheduledJobRunner; +import org.opensearch.plugins.Plugin; +import org.opensearch.sql.legacy.executor.AsyncRestExecutor; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; +import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; +import org.opensearch.sql.spark.scheduler.model.ScheduledAsyncQueryJobRequest; +import org.opensearch.threadpool.ThreadPool; + +/** + * The job runner class for scheduling async query. + * + *
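+ * <p>On each trigger, the runner converts the persisted {@link ScheduledAsyncQueryJobRequest}
+ * into a {@link CreateAsyncQueryRequest} and submits it through the AsyncQueryExecutorService on
+ * the shared SQL worker thread pool, as doRefresh below shows.
+ *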

The job runner should be a singleton class if it uses the OpenSearch client or other objects + * passed in from OpenSearch. When the job runner is registered with the JobScheduler plugin, + * OpenSearch has not yet invoked the plugins' createComponents() method, so the plugin is not + * fully initialized and the OpenSearch {@link org.opensearch.client.Client}, {@link + * ClusterService} and other objects are not yet available to the plugin or to this job runner. + * + *
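+ * <p>A minimal wiring sketch, assuming the initialization happens in the SQL plugin's
+ * createComponents() (the concrete plugin class is outside this patch):
+ *
+ * <pre>{@code
+ * // inside createComponents(), once client, clusterService, threadPool and
+ * // asyncQueryExecutorService are available:
+ * ScheduledAsyncQueryJobRunner.getJobRunnerInstance()
+ *     .loadJobResource(client, clusterService, threadPool, asyncQueryExecutorService);
+ * }</pre>
+ *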

So we have to move this job runner initialization to {@link Plugin} createComponents() method, + * and using singleton job runner to ensure we register a usable job runner instance to JobScheduler + * plugin. + */ +public class ScheduledAsyncQueryJobRunner implements ScheduledJobRunner { + // Share SQL plugin thread pool + private static final String ASYNC_QUERY_THREAD_POOL_NAME = + AsyncRestExecutor.SQL_WORKER_THREAD_POOL_NAME; + private static final Logger LOGGER = LogManager.getLogger(ScheduledAsyncQueryJobRunner.class); + + private static ScheduledAsyncQueryJobRunner INSTANCE = new ScheduledAsyncQueryJobRunner(); + + public static ScheduledAsyncQueryJobRunner getJobRunnerInstance() { + return INSTANCE; + } + + private ClusterService clusterService; + private ThreadPool threadPool; + private Client client; + private AsyncQueryExecutorService asyncQueryExecutorService; + + private ScheduledAsyncQueryJobRunner() { + // Singleton class, use getJobRunnerInstance method instead of constructor + } + + /** Loads job resources, setting up required services and job runner instance. */ + public void loadJobResource( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + AsyncQueryExecutorService asyncQueryExecutorService) { + this.client = client; + this.clusterService = clusterService; + this.threadPool = threadPool; + this.asyncQueryExecutorService = asyncQueryExecutorService; + } + + @Override + public void runJob(ScheduledJobParameter jobParameter, JobExecutionContext context) { + // Parser will convert jobParameter to ScheduledAsyncQueryJobRequest + if (!(jobParameter instanceof ScheduledAsyncQueryJobRequest)) { + throw new IllegalStateException( + "Job parameter is not instance of ScheduledAsyncQueryJobRequest, type: " + + jobParameter.getClass().getCanonicalName()); + } + + if (this.clusterService == null) { + throw new IllegalStateException("ClusterService is not initialized."); + } + + if (this.threadPool == null) { + throw new IllegalStateException("ThreadPool is not initialized."); + } + + if (this.client == null) { + throw new IllegalStateException("Client is not initialized."); + } + + if (this.asyncQueryExecutorService == null) { + throw new IllegalStateException("AsyncQueryExecutorService is not initialized."); + } + + Runnable runnable = + () -> { + try { + doRefresh((ScheduledAsyncQueryJobRequest) jobParameter); + } catch (Throwable throwable) { + LOGGER.error(throwable); + } + }; + threadPool.executor(ASYNC_QUERY_THREAD_POOL_NAME).submit(runnable); + } + + void doRefresh(ScheduledAsyncQueryJobRequest request) { + LOGGER.info("Scheduled refresh index job on job: " + request.getName()); + CreateAsyncQueryRequest createAsyncQueryRequest = + new CreateAsyncQueryRequest( + request.getScheduledQuery(), request.getDataSource(), request.getQueryLang()); + CreateAsyncQueryResponse createAsyncQueryResponse = + asyncQueryExecutorService.createAsyncQuery( + createAsyncQueryRequest, new NullAsyncQueryRequestContext()); + LOGGER.info("Created async query with queryId: " + createAsyncQueryResponse.getQueryId()); + } +} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequest.java b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequest.java deleted file mode 100644 index 7eaa4e2d29..0000000000 --- a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequest.java +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright OpenSearch 
Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.scheduler.model; - -import java.io.IOException; -import java.time.Instant; -import lombok.Builder; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.jobscheduler.spi.ScheduledJobParameter; -import org.opensearch.jobscheduler.spi.schedule.Schedule; - -/** Represents a job request to refresh index. */ -@Builder -public class OpenSearchRefreshIndexJobRequest implements ScheduledJobParameter { - // Constant fields for JSON serialization - public static final String JOB_NAME_FIELD = "jobName"; - public static final String JOB_TYPE_FIELD = "jobType"; - public static final String LAST_UPDATE_TIME_FIELD = "lastUpdateTime"; - public static final String LAST_UPDATE_TIME_FIELD_READABLE = "last_update_time_field"; - public static final String SCHEDULE_FIELD = "schedule"; - public static final String ENABLED_TIME_FIELD = "enabledTime"; - public static final String ENABLED_TIME_FIELD_READABLE = "enabled_time_field"; - public static final String LOCK_DURATION_SECONDS = "lockDurationSeconds"; - public static final String JITTER = "jitter"; - public static final String ENABLED_FIELD = "enabled"; - - // name is doc id - private final String jobName; - private final String jobType; - private final Schedule schedule; - private final boolean enabled; - private final Instant lastUpdateTime; - private final Instant enabledTime; - private final Long lockDurationSeconds; - private final Double jitter; - - @Override - public String getName() { - return jobName; - } - - public String getJobType() { - return jobType; - } - - @Override - public Schedule getSchedule() { - return schedule; - } - - @Override - public boolean isEnabled() { - return enabled; - } - - @Override - public Instant getLastUpdateTime() { - return lastUpdateTime; - } - - @Override - public Instant getEnabledTime() { - return enabledTime; - } - - @Override - public Long getLockDurationSeconds() { - return lockDurationSeconds; - } - - @Override - public Double getJitter() { - return jitter; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) - throws IOException { - builder.startObject(); - builder.field(JOB_NAME_FIELD, getName()).field(ENABLED_FIELD, isEnabled()); - if (getSchedule() != null) { - builder.field(SCHEDULE_FIELD, getSchedule()); - } - if (getJobType() != null) { - builder.field(JOB_TYPE_FIELD, getJobType()); - } - if (getEnabledTime() != null) { - builder.timeField( - ENABLED_TIME_FIELD, ENABLED_TIME_FIELD_READABLE, getEnabledTime().toEpochMilli()); - } - builder.timeField( - LAST_UPDATE_TIME_FIELD, - LAST_UPDATE_TIME_FIELD_READABLE, - getLastUpdateTime().toEpochMilli()); - if (this.lockDurationSeconds != null) { - builder.field(LOCK_DURATION_SECONDS, this.lockDurationSeconds); - } - if (this.jitter != null) { - builder.field(JITTER, this.jitter); - } - builder.endObject(); - return builder; - } -} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/model/ScheduledAsyncQueryJobRequest.java b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/model/ScheduledAsyncQueryJobRequest.java new file mode 100644 index 0000000000..9b85a11888 --- /dev/null +++ b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/model/ScheduledAsyncQueryJobRequest.java @@ -0,0 +1,156 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package 
org.opensearch.sql.spark.scheduler.model; + +import java.io.IOException; +import java.time.Instant; +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.ToString; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.jobscheduler.spi.schedule.Schedule; +import org.opensearch.sql.spark.rest.model.LangType; +import org.opensearch.sql.spark.scheduler.parser.IntervalScheduleParser; + +/** Represents a job request to refresh index. */ +@Data +@EqualsAndHashCode(callSuper = true) +@ToString(callSuper = true) +public class ScheduledAsyncQueryJobRequest extends AsyncQuerySchedulerRequest + implements ScheduledJobParameter { + // Constant fields for JSON serialization + public static final String ACCOUNT_ID_FIELD = "accountId"; + public static final String JOB_ID_FIELD = "jobId"; + public static final String DATA_SOURCE_NAME_FIELD = "dataSource"; + public static final String SCHEDULED_QUERY_FIELD = "scheduledQuery"; + public static final String QUERY_LANG_FIELD = "queryLang"; + public static final String LAST_UPDATE_TIME_FIELD = "lastUpdateTime"; + public static final String SCHEDULE_FIELD = "schedule"; + public static final String ENABLED_TIME_FIELD = "enabledTime"; + public static final String LOCK_DURATION_SECONDS = "lockDurationSeconds"; + public static final String JITTER = "jitter"; + public static final String ENABLED_FIELD = "enabled"; + private final Schedule schedule; + + @Builder + public ScheduledAsyncQueryJobRequest( + String accountId, + String jobId, + String dataSource, + String scheduledQuery, + LangType queryLang, + Schedule schedule, // Use the OpenSearch Schedule type + boolean enabled, + Instant lastUpdateTime, + Instant enabledTime, + Long lockDurationSeconds, + Double jitter) { + super( + accountId, + jobId, + dataSource, + scheduledQuery, + queryLang, + schedule, + enabled, + lastUpdateTime, + enabledTime, + lockDurationSeconds, + jitter); + this.schedule = schedule; + } + + @Override + public String getName() { + return getJobId(); + } + + @Override + public boolean isEnabled() { + return enabled; + } + + @Override + public Instant getLastUpdateTime() { + return lastUpdateTime; + } + + @Override + public Instant getEnabledTime() { + return enabledTime; + } + + @Override + public Schedule getSchedule() { + return schedule; + } + + @Override + public Long getLockDurationSeconds() { + return lockDurationSeconds; + } + + @Override + public Double getJitter() { + return jitter; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) + throws IOException { + builder.startObject(); + if (getAccountId() != null) { + builder.field(ACCOUNT_ID_FIELD, getAccountId()); + } + builder.field(JOB_ID_FIELD, getJobId()).field(ENABLED_FIELD, isEnabled()); + if (getDataSource() != null) { + builder.field(DATA_SOURCE_NAME_FIELD, getDataSource()); + } + if (getScheduledQuery() != null) { + builder.field(SCHEDULED_QUERY_FIELD, getScheduledQuery()); + } + if (getQueryLang() != null) { + builder.field(QUERY_LANG_FIELD, getQueryLang()); + } + if (getSchedule() != null) { + builder.field(SCHEDULE_FIELD, getSchedule()); + } + if (getEnabledTime() != null) { + builder.field(ENABLED_TIME_FIELD, getEnabledTime().toEpochMilli()); + } + builder.field(LAST_UPDATE_TIME_FIELD, getLastUpdateTime().toEpochMilli()); + if (this.lockDurationSeconds != null) { + 
builder.field(LOCK_DURATION_SECONDS, this.lockDurationSeconds); + } + if (this.jitter != null) { + builder.field(JITTER, this.jitter); + } + builder.endObject(); + return builder; + } + + public static ScheduledAsyncQueryJobRequest fromAsyncQuerySchedulerRequest( + AsyncQuerySchedulerRequest request) { + Instant updateTime = + request.getLastUpdateTime() != null ? request.getLastUpdateTime() : Instant.now(); + return ScheduledAsyncQueryJobRequest.builder() + .accountId(request.getAccountId()) + .jobId(request.getJobId()) + .dataSource(request.getDataSource()) + .scheduledQuery(request.getScheduledQuery()) + .queryLang(request.getQueryLang()) + .enabled(request.isEnabled()) + .lastUpdateTime(updateTime) + .enabledTime(request.getEnabledTime()) + .lockDurationSeconds(request.getLockDurationSeconds()) + .jitter(request.getJitter()) + .schedule(IntervalScheduleParser.parse(request.getSchedule(), updateTime)) + .build(); + } +} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/parser/IntervalScheduleParser.java b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/parser/IntervalScheduleParser.java new file mode 100644 index 0000000000..2d5a1b332f --- /dev/null +++ b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/parser/IntervalScheduleParser.java @@ -0,0 +1,100 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.scheduler.parser; + +import com.google.common.annotations.VisibleForTesting; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.jobscheduler.spi.schedule.Schedule; + +/** Parse string raw schedule into job scheduler IntervalSchedule */ +public class IntervalScheduleParser { + + private static final Pattern DURATION_PATTERN = + Pattern.compile( + "^(\\d+)\\s*(years?|months?|weeks?|days?|hours?|minutes?|minute|mins?|seconds?|secs?|milliseconds?|millis?|microseconds?|microsecond|micros?|micros|nanoseconds?|nanos?)$", + Pattern.CASE_INSENSITIVE); + + public static Schedule parse(Object schedule, Instant startTime) { + if (schedule == null) { + return null; + } + + if (schedule instanceof Schedule) { + return (Schedule) schedule; + } + + if (!(schedule instanceof String)) { + throw new IllegalArgumentException("Schedule must be a String object for parsing."); + } + + String intervalStr = ((String) schedule).trim().toLowerCase(); + + Matcher matcher = DURATION_PATTERN.matcher(intervalStr); + if (!matcher.matches()) { + throw new IllegalArgumentException("Invalid interval format: " + intervalStr); + } + + long value = Long.parseLong(matcher.group(1)); + String unitStr = matcher.group(2).toLowerCase(); + + // Convert to a supported unit or directly return an IntervalSchedule + long intervalInMinutes = convertToSupportedUnit(value, unitStr); + + return new IntervalSchedule(startTime, (int) intervalInMinutes, ChronoUnit.MINUTES); + } + + @VisibleForTesting + protected static long convertToSupportedUnit(long value, String unitStr) { + switch (unitStr) { + case "years": + case "year": + throw new IllegalArgumentException("Years cannot be converted to minutes accurately."); + case "months": + case "month": + throw new IllegalArgumentException("Months cannot be converted to minutes accurately."); + case "weeks": + case "week": + return value * 7 * 24 * 60; // Convert weeks to minutes + case "days": + case "day": 
+ return value * 24 * 60; // Convert days to minutes + case "hours": + case "hour": + return value * 60; // Convert hours to minutes + case "minutes": + case "minute": + case "mins": + case "min": + return value; // Already in minutes + case "seconds": + case "second": + case "secs": + case "sec": + return value / 60; // Convert seconds to minutes + case "milliseconds": + case "millisecond": + case "millis": + case "milli": + return value / (60 * 1000); // Convert milliseconds to minutes + case "microseconds": + case "microsecond": + case "micros": + case "micro": + return value / (60 * 1000 * 1000); // Convert microseconds to minutes + case "nanoseconds": + case "nanosecond": + case "nanos": + case "nano": + return value / (60 * 1000 * 1000 * 1000L); // Convert nanoseconds to minutes + default: + throw new IllegalArgumentException("Unsupported time unit: " + unitStr); + } + } +} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchRefreshIndexJobRequestParser.java b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/parser/OpenSearchScheduleQueryJobRequestParser.java similarity index 57% rename from async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchRefreshIndexJobRequestParser.java rename to async-query/src/main/java/org/opensearch/sql/spark/scheduler/parser/OpenSearchScheduleQueryJobRequestParser.java index 0422e7c015..9e33ef0248 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/OpenSearchRefreshIndexJobRequestParser.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/parser/OpenSearchScheduleQueryJobRequestParser.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.sql.spark.scheduler; +package org.opensearch.sql.spark.scheduler.parser; import java.io.IOException; import java.time.Instant; @@ -11,9 +11,10 @@ import org.opensearch.core.xcontent.XContentParserUtils; import org.opensearch.jobscheduler.spi.ScheduledJobParser; import org.opensearch.jobscheduler.spi.schedule.ScheduleParser; -import org.opensearch.sql.spark.scheduler.model.OpenSearchRefreshIndexJobRequest; +import org.opensearch.sql.spark.rest.model.LangType; +import org.opensearch.sql.spark.scheduler.model.ScheduledAsyncQueryJobRequest; -public class OpenSearchRefreshIndexJobRequestParser { +public class OpenSearchScheduleQueryJobRequestParser { private static Instant parseInstantValue(XContentParser parser) throws IOException { if (XContentParser.Token.VALUE_NULL.equals(parser.currentToken())) { @@ -28,8 +29,8 @@ private static Instant parseInstantValue(XContentParser parser) throws IOExcepti public static ScheduledJobParser getJobParser() { return (parser, id, jobDocVersion) -> { - OpenSearchRefreshIndexJobRequest.OpenSearchRefreshIndexJobRequestBuilder builder = - OpenSearchRefreshIndexJobRequest.builder(); + ScheduledAsyncQueryJobRequest.ScheduledAsyncQueryJobRequestBuilder builder = + ScheduledAsyncQueryJobRequest.builder(); XContentParserUtils.ensureExpectedToken( XContentParser.Token.START_OBJECT, parser.nextToken(), parser); @@ -37,28 +38,37 @@ public static ScheduledJobParser getJobParser() { String fieldName = parser.currentName(); parser.nextToken(); switch (fieldName) { - case OpenSearchRefreshIndexJobRequest.JOB_NAME_FIELD: - builder.jobName(parser.text()); + case ScheduledAsyncQueryJobRequest.ACCOUNT_ID_FIELD: + builder.accountId(parser.text()); break; - case OpenSearchRefreshIndexJobRequest.JOB_TYPE_FIELD: - builder.jobType(parser.text()); + case 
ScheduledAsyncQueryJobRequest.JOB_ID_FIELD: + builder.jobId(parser.text()); break; - case OpenSearchRefreshIndexJobRequest.ENABLED_FIELD: + case ScheduledAsyncQueryJobRequest.DATA_SOURCE_NAME_FIELD: + builder.dataSource(parser.text()); + break; + case ScheduledAsyncQueryJobRequest.SCHEDULED_QUERY_FIELD: + builder.scheduledQuery(parser.text()); + break; + case ScheduledAsyncQueryJobRequest.QUERY_LANG_FIELD: + builder.queryLang(LangType.fromString(parser.text())); + break; + case ScheduledAsyncQueryJobRequest.ENABLED_FIELD: builder.enabled(parser.booleanValue()); break; - case OpenSearchRefreshIndexJobRequest.ENABLED_TIME_FIELD: + case ScheduledAsyncQueryJobRequest.ENABLED_TIME_FIELD: builder.enabledTime(parseInstantValue(parser)); break; - case OpenSearchRefreshIndexJobRequest.LAST_UPDATE_TIME_FIELD: + case ScheduledAsyncQueryJobRequest.LAST_UPDATE_TIME_FIELD: builder.lastUpdateTime(parseInstantValue(parser)); break; - case OpenSearchRefreshIndexJobRequest.SCHEDULE_FIELD: + case ScheduledAsyncQueryJobRequest.SCHEDULE_FIELD: builder.schedule(ScheduleParser.parse(parser)); break; - case OpenSearchRefreshIndexJobRequest.LOCK_DURATION_SECONDS: + case ScheduledAsyncQueryJobRequest.LOCK_DURATION_SECONDS: builder.lockDurationSeconds(parser.longValue()); break; - case OpenSearchRefreshIndexJobRequest.JITTER: + case ScheduledAsyncQueryJobRequest.JITTER: builder.jitter(parser.doubleValue()); break; default: diff --git a/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java index 9cc69b2fb7..52ffda483c 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java @@ -61,6 +61,8 @@ import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.response.OpenSearchJobExecutionResponseReader; +import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; +import org.opensearch.sql.spark.scheduler.OpenSearchAsyncQueryScheduler; @RequiredArgsConstructor public class AsyncExecutorServiceModule extends AbstractModule { @@ -136,12 +138,14 @@ public FlintIndexOpFactory flintIndexOpFactory( FlintIndexStateModelService flintIndexStateModelService, FlintIndexClient flintIndexClient, FlintIndexMetadataServiceImpl flintIndexMetadataService, - EMRServerlessClientFactory emrServerlessClientFactory) { + EMRServerlessClientFactory emrServerlessClientFactory, + AsyncQueryScheduler asyncQueryScheduler) { return new FlintIndexOpFactory( flintIndexStateModelService, flintIndexClient, flintIndexMetadataService, - emrServerlessClientFactory); + emrServerlessClientFactory, + asyncQueryScheduler); } @Provides @@ -245,6 +249,14 @@ public SessionConfigSupplier sessionConfigSupplier(Settings settings) { return new OpenSearchSessionConfigSupplier(settings); } + @Provides + @Singleton + public AsyncQueryScheduler asyncQueryScheduler(NodeClient client, ClusterService clusterService) { + OpenSearchAsyncQueryScheduler scheduler = + new OpenSearchAsyncQueryScheduler(client, clusterService); + return scheduler; + } + private void registerStateStoreMetrics(StateStore stateStore) { GaugeMetric activeSessionMetric = new GaugeMetric<>( diff --git a/async-query/src/main/resources/async-query-scheduler-index-mapping.yml 
b/async-query/src/main/resources/async-query-scheduler-index-mapping.yml index 36bd1b873e..1aa90e8ed8 100644 --- a/async-query/src/main/resources/async-query-scheduler-index-mapping.yml +++ b/async-query/src/main/resources/async-query-scheduler-index-mapping.yml @@ -8,9 +8,15 @@ # Also "dynamic" is set to "false" so that other fields cannot be added. dynamic: false properties: - name: + accountId: type: keyword - jobType: + jobId: + type: keyword + dataSource: + type: keyword + scheduledQuery: + type: text + queryLang: type: keyword lastUpdateTime: type: date diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index 641b083d53..9b897d36b4 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -100,6 +100,8 @@ import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.response.OpenSearchJobExecutionResponseReader; +import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; +import org.opensearch.sql.spark.scheduler.OpenSearchAsyncQueryScheduler; import org.opensearch.sql.storage.DataSourceFactory; import org.opensearch.test.OpenSearchIntegTestCase; @@ -124,6 +126,7 @@ public class AsyncQueryExecutorServiceSpec extends OpenSearchIntegTestCase { protected StateStore stateStore; protected SessionStorageService sessionStorageService; protected StatementStorageService statementStorageService; + protected AsyncQueryScheduler asyncQueryScheduler; protected AsyncQueryRequestContext asyncQueryRequestContext; protected SessionIdProvider sessionIdProvider = new DatasourceEmbeddedSessionIdProvider(); @@ -204,6 +207,7 @@ public void setup() { new OpenSearchSessionStorageService(stateStore, new SessionModelXContentSerializer()); statementStorageService = new OpenSearchStatementStorageService(stateStore, new StatementModelXContentSerializer()); + asyncQueryScheduler = new OpenSearchAsyncQueryScheduler(client, clusterService); } protected FlintIndexOpFactory getFlintIndexOpFactory( @@ -212,7 +216,8 @@ protected FlintIndexOpFactory getFlintIndexOpFactory( flintIndexStateModelService, flintIndexClient, flintIndexMetadataService, - emrServerlessClientFactory); + emrServerlessClientFactory, + asyncQueryScheduler); } @After @@ -298,7 +303,8 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( flintIndexStateModelService, flintIndexClient, new FlintIndexMetadataServiceImpl(client), - emrServerlessClientFactory), + emrServerlessClientFactory, + asyncQueryScheduler), emrServerlessClientFactory, new OpenSearchMetricsService(), sparkSubmitParametersBuilderProvider); diff --git a/async-query/src/test/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQuerySchedulerTest.java b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQuerySchedulerTest.java index de86f111f3..a4a6eb6471 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQuerySchedulerTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/OpenSearchAsyncQuerySchedulerTest.java @@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static 
org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -42,8 +43,7 @@ import org.opensearch.index.engine.DocumentMissingException; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.jobscheduler.spi.ScheduledJobRunner; -import org.opensearch.sql.spark.scheduler.model.OpenSearchRefreshIndexJobRequest; -import org.opensearch.threadpool.ThreadPool; +import org.opensearch.sql.spark.scheduler.model.ScheduledAsyncQueryJobRequest; public class OpenSearchAsyncQuerySchedulerTest { @@ -57,9 +57,6 @@ public class OpenSearchAsyncQuerySchedulerTest { @Mock(answer = Answers.RETURNS_DEEP_STUBS) private ClusterService clusterService; - @Mock(answer = Answers.RETURNS_DEEP_STUBS) - private ThreadPool threadPool; - @Mock private ActionFuture indexResponseActionFuture; @Mock private ActionFuture updateResponseActionFuture; @@ -77,8 +74,7 @@ public class OpenSearchAsyncQuerySchedulerTest { @BeforeEach public void setup() { MockitoAnnotations.openMocks(this); - scheduler = new OpenSearchAsyncQueryScheduler(); - scheduler.loadJobResource(client, clusterService, threadPool); + scheduler = new OpenSearchAsyncQueryScheduler(client, clusterService); } @Test @@ -95,9 +91,9 @@ public void testScheduleJob() { when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse); when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED); - OpenSearchRefreshIndexJobRequest request = - OpenSearchRefreshIndexJobRequest.builder() - .jobName(TEST_JOB_ID) + ScheduledAsyncQueryJobRequest request = + ScheduledAsyncQueryJobRequest.builder() + .jobId(TEST_JOB_ID) .lastUpdateTime(Instant.now()) .build(); @@ -119,9 +115,9 @@ public void testScheduleJobWithExistingJob() { when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)) .thenReturn(Boolean.TRUE); - OpenSearchRefreshIndexJobRequest request = - OpenSearchRefreshIndexJobRequest.builder() - .jobName(TEST_JOB_ID) + ScheduledAsyncQueryJobRequest request = + ScheduledAsyncQueryJobRequest.builder() + .jobId(TEST_JOB_ID) .lastUpdateTime(Instant.now()) .build(); @@ -148,9 +144,9 @@ public void testScheduleJobWithExceptions() { .thenReturn(new CreateIndexResponse(true, true, TEST_SCHEDULER_INDEX_NAME)); when(client.index(any(IndexRequest.class))).thenThrow(new RuntimeException("Test exception")); - OpenSearchRefreshIndexJobRequest request = - OpenSearchRefreshIndexJobRequest.builder() - .jobName(TEST_JOB_ID) + ScheduledAsyncQueryJobRequest request = + ScheduledAsyncQueryJobRequest.builder() + .jobId(TEST_JOB_ID) .lastUpdateTime(Instant.now()) .build(); @@ -199,14 +195,17 @@ public void testUnscheduleJob() throws IOException { public void testUnscheduleJobWithIndexNotFound() { when(clusterService.state().routingTable().hasIndex(SCHEDULER_INDEX_NAME)).thenReturn(false); - assertThrows(IllegalStateException.class, () -> scheduler.unscheduleJob(TEST_JOB_ID)); + scheduler.unscheduleJob(TEST_JOB_ID); + + // Verify that no update operation was performed + verify(client, never()).update(any(UpdateRequest.class)); } @Test public void testUpdateJob() throws IOException { - OpenSearchRefreshIndexJobRequest request = - OpenSearchRefreshIndexJobRequest.builder() - .jobName(TEST_JOB_ID) + ScheduledAsyncQueryJobRequest request = + ScheduledAsyncQueryJobRequest.builder() + .jobId(TEST_JOB_ID) .lastUpdateTime(Instant.now()) .build(); @@ -229,9 +228,9 @@ public void testUpdateJob() throws 
IOException { @Test public void testUpdateJobWithIndexNotFound() { - OpenSearchRefreshIndexJobRequest request = - OpenSearchRefreshIndexJobRequest.builder() - .jobName(TEST_JOB_ID) + ScheduledAsyncQueryJobRequest request = + ScheduledAsyncQueryJobRequest.builder() + .jobId(TEST_JOB_ID) .lastUpdateTime(Instant.now()) .build(); @@ -242,9 +241,9 @@ public void testUpdateJobWithIndexNotFound() { @Test public void testUpdateJobWithExceptions() { - OpenSearchRefreshIndexJobRequest request = - OpenSearchRefreshIndexJobRequest.builder() - .jobName(TEST_JOB_ID) + ScheduledAsyncQueryJobRequest request = + ScheduledAsyncQueryJobRequest.builder() + .jobId(TEST_JOB_ID) .lastUpdateTime(Instant.now()) .build(); @@ -351,9 +350,9 @@ public void testCreateAsyncQuerySchedulerIndexFailure() { Mockito.when(createIndexResponseActionFuture.actionGet()) .thenReturn(new CreateIndexResponse(false, false, SCHEDULER_INDEX_NAME)); - OpenSearchRefreshIndexJobRequest request = - OpenSearchRefreshIndexJobRequest.builder() - .jobName(TEST_JOB_ID) + ScheduledAsyncQueryJobRequest request = + ScheduledAsyncQueryJobRequest.builder() + .jobId(TEST_JOB_ID) .lastUpdateTime(Instant.now()) .build(); @@ -367,9 +366,9 @@ public void testCreateAsyncQuerySchedulerIndexFailure() { @Test public void testUpdateJobNotFound() { - OpenSearchRefreshIndexJobRequest request = - OpenSearchRefreshIndexJobRequest.builder() - .jobName(TEST_JOB_ID) + ScheduledAsyncQueryJobRequest request = + ScheduledAsyncQueryJobRequest.builder() + .jobId(TEST_JOB_ID) .lastUpdateTime(Instant.now()) .build(); diff --git a/async-query/src/test/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJobTest.java b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJobTest.java deleted file mode 100644 index cbf137997e..0000000000 --- a/async-query/src/test/java/org/opensearch/sql/spark/scheduler/job/OpenSearchRefreshIndexJobTest.java +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.scheduler.job; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertSame; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; - -import java.time.Instant; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Answers; -import org.mockito.ArgumentCaptor; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.jobscheduler.spi.JobExecutionContext; -import org.opensearch.jobscheduler.spi.ScheduledJobParameter; -import org.opensearch.sql.spark.scheduler.model.OpenSearchRefreshIndexJobRequest; -import org.opensearch.threadpool.ThreadPool; - -public class OpenSearchRefreshIndexJobTest { - - @Mock(answer = Answers.RETURNS_DEEP_STUBS) - private ClusterService clusterService; - - @Mock(answer = Answers.RETURNS_DEEP_STUBS) - private ThreadPool threadPool; - - @Mock(answer = Answers.RETURNS_DEEP_STUBS) - private Client client; - - @Mock private JobExecutionContext context; - - private OpenSearchRefreshIndexJob jobRunner; - - private OpenSearchRefreshIndexJob spyJobRunner; - - @BeforeEach - public void setup() { - 
MockitoAnnotations.openMocks(this); - jobRunner = OpenSearchRefreshIndexJob.getJobRunnerInstance(); - jobRunner.setClient(null); - jobRunner.setClusterService(null); - jobRunner.setThreadPool(null); - } - - @Test - public void testRunJobWithCorrectParameter() { - spyJobRunner = spy(jobRunner); - spyJobRunner.setClusterService(clusterService); - spyJobRunner.setThreadPool(threadPool); - spyJobRunner.setClient(client); - - OpenSearchRefreshIndexJobRequest jobParameter = - OpenSearchRefreshIndexJobRequest.builder() - .jobName("testJob") - .lastUpdateTime(Instant.now()) - .lockDurationSeconds(10L) - .build(); - - spyJobRunner.runJob(jobParameter, context); - - ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class); - verify(threadPool.generic()).submit(captor.capture()); - - Runnable runnable = captor.getValue(); - runnable.run(); - - verify(spyJobRunner).doRefresh(eq(jobParameter.getName())); - } - - @Test - public void testRunJobWithIncorrectParameter() { - jobRunner = OpenSearchRefreshIndexJob.getJobRunnerInstance(); - jobRunner.setClusterService(clusterService); - jobRunner.setThreadPool(threadPool); - jobRunner.setClient(client); - - ScheduledJobParameter wrongParameter = mock(ScheduledJobParameter.class); - - IllegalStateException exception = - assertThrows( - IllegalStateException.class, - () -> jobRunner.runJob(wrongParameter, context), - "Expected IllegalStateException but no exception was thrown"); - - assertEquals( - "Job parameter is not instance of OpenSearchRefreshIndexJobRequest, type: " - + wrongParameter.getClass().getCanonicalName(), - exception.getMessage()); - } - - @Test - public void testRunJobWithUninitializedServices() { - OpenSearchRefreshIndexJobRequest jobParameter = - OpenSearchRefreshIndexJobRequest.builder() - .jobName("testJob") - .lastUpdateTime(Instant.now()) - .build(); - - IllegalStateException exception = - assertThrows( - IllegalStateException.class, - () -> jobRunner.runJob(jobParameter, context), - "Expected IllegalStateException but no exception was thrown"); - assertEquals("ClusterService is not initialized.", exception.getMessage()); - - jobRunner.setClusterService(clusterService); - - exception = - assertThrows( - IllegalStateException.class, - () -> jobRunner.runJob(jobParameter, context), - "Expected IllegalStateException but no exception was thrown"); - assertEquals("ThreadPool is not initialized.", exception.getMessage()); - - jobRunner.setThreadPool(threadPool); - - exception = - assertThrows( - IllegalStateException.class, - () -> jobRunner.runJob(jobParameter, context), - "Expected IllegalStateException but no exception was thrown"); - assertEquals("Client is not initialized.", exception.getMessage()); - } - - @Test - public void testGetJobRunnerInstanceMultipleCalls() { - OpenSearchRefreshIndexJob instance1 = OpenSearchRefreshIndexJob.getJobRunnerInstance(); - OpenSearchRefreshIndexJob instance2 = OpenSearchRefreshIndexJob.getJobRunnerInstance(); - OpenSearchRefreshIndexJob instance3 = OpenSearchRefreshIndexJob.getJobRunnerInstance(); - - assertSame(instance1, instance2); - assertSame(instance2, instance3); - } -} diff --git a/async-query/src/test/java/org/opensearch/sql/spark/scheduler/job/ScheduledAsyncQueryJobRunnerTest.java b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/job/ScheduledAsyncQueryJobRunnerTest.java new file mode 100644 index 0000000000..cba8d43a2a --- /dev/null +++ b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/job/ScheduledAsyncQueryJobRunnerTest.java @@ -0,0 +1,210 @@ +/* + * 
Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.scheduler.job; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.time.Instant; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.core.Appender; +import org.apache.logging.log4j.core.LogEvent; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Answers; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.jobscheduler.spi.JobExecutionContext; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.sql.legacy.executor.AsyncRestExecutor; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; +import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext; +import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; +import org.opensearch.sql.spark.rest.model.LangType; +import org.opensearch.sql.spark.scheduler.model.ScheduledAsyncQueryJobRequest; +import org.opensearch.threadpool.ThreadPool; + +public class ScheduledAsyncQueryJobRunnerTest { + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private ClusterService clusterService; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private ThreadPool threadPool; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private Client client; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private AsyncQueryExecutorService asyncQueryExecutorService; + + @Mock private JobExecutionContext context; + + private ScheduledAsyncQueryJobRunner jobRunner; + + private ScheduledAsyncQueryJobRunner spyJobRunner; + + @BeforeEach + public void setup() { + MockitoAnnotations.openMocks(this); + jobRunner = ScheduledAsyncQueryJobRunner.getJobRunnerInstance(); + jobRunner.loadJobResource(null, null, null, null); + } + + @Test + public void testRunJobWithCorrectParameter() { + spyJobRunner = spy(jobRunner); + spyJobRunner.loadJobResource(client, clusterService, threadPool, asyncQueryExecutorService); + + ScheduledAsyncQueryJobRequest request = + ScheduledAsyncQueryJobRequest.builder() + .jobId("testJob") + .lastUpdateTime(Instant.now()) + .lockDurationSeconds(10L) + .scheduledQuery("REFRESH INDEX testIndex") + .dataSource("testDataSource") + .queryLang(LangType.SQL) + .build(); + + CreateAsyncQueryRequest createAsyncQueryRequest = + new CreateAsyncQueryRequest( + request.getScheduledQuery(), request.getDataSource(), request.getQueryLang()); + spyJobRunner.runJob(request, context); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class); + verify(threadPool.executor(AsyncRestExecutor.SQL_WORKER_THREAD_POOL_NAME)) + .submit(captor.capture()); + + Runnable runnable = captor.getValue(); + runnable.run(); + + verify(spyJobRunner).doRefresh(eq(request)); + verify(asyncQueryExecutorService) + .createAsyncQuery(eq(createAsyncQueryRequest), 
any(NullAsyncQueryRequestContext.class)); + } + + @Test + public void testRunJobWithIncorrectParameter() { + jobRunner = ScheduledAsyncQueryJobRunner.getJobRunnerInstance(); + jobRunner.loadJobResource(client, clusterService, threadPool, asyncQueryExecutorService); + + ScheduledJobParameter wrongParameter = mock(ScheduledJobParameter.class); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> jobRunner.runJob(wrongParameter, context), + "Expected IllegalStateException but no exception was thrown"); + + assertEquals( + "Job parameter is not instance of ScheduledAsyncQueryJobRequest, type: " + + wrongParameter.getClass().getCanonicalName(), + exception.getMessage()); + } + + @Test + public void testDoRefreshThrowsException() { + spyJobRunner = spy(jobRunner); + spyJobRunner.loadJobResource(client, clusterService, threadPool, asyncQueryExecutorService); + + ScheduledAsyncQueryJobRequest request = + ScheduledAsyncQueryJobRequest.builder() + .jobId("testJob") + .lastUpdateTime(Instant.now()) + .lockDurationSeconds(10L) + .scheduledQuery("REFRESH INDEX testIndex") + .dataSource("testDataSource") + .queryLang(LangType.SQL) + .build(); + + doThrow(new RuntimeException("Test exception")).when(spyJobRunner).doRefresh(request); + + Logger logger = LogManager.getLogger(ScheduledAsyncQueryJobRunner.class); + Appender mockAppender = mock(Appender.class); + when(mockAppender.getName()).thenReturn("MockAppender"); + when(mockAppender.isStarted()).thenReturn(true); + when(mockAppender.isStopped()).thenReturn(false); + ((org.apache.logging.log4j.core.Logger) logger) + .addAppender((org.apache.logging.log4j.core.Appender) mockAppender); + + spyJobRunner.runJob(request, context); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class); + verify(threadPool.executor(AsyncRestExecutor.SQL_WORKER_THREAD_POOL_NAME)) + .submit(captor.capture()); + + Runnable runnable = captor.getValue(); + runnable.run(); + + verify(spyJobRunner).doRefresh(eq(request)); + verify(mockAppender).append(any(LogEvent.class)); + } + + @Test + public void testRunJobWithUninitializedServices() { + ScheduledAsyncQueryJobRequest jobParameter = + ScheduledAsyncQueryJobRequest.builder() + .jobId("testJob") + .lastUpdateTime(Instant.now()) + .build(); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> jobRunner.runJob(jobParameter, context), + "Expected IllegalStateException but no exception was thrown"); + assertEquals("ClusterService is not initialized.", exception.getMessage()); + + jobRunner.loadJobResource(null, clusterService, null, null); + + exception = + assertThrows( + IllegalStateException.class, + () -> jobRunner.runJob(jobParameter, context), + "Expected IllegalStateException but no exception was thrown"); + assertEquals("ThreadPool is not initialized.", exception.getMessage()); + + jobRunner.loadJobResource(null, clusterService, threadPool, null); + + exception = + assertThrows( + IllegalStateException.class, + () -> jobRunner.runJob(jobParameter, context), + "Expected IllegalStateException but no exception was thrown"); + assertEquals("Client is not initialized.", exception.getMessage()); + + jobRunner.loadJobResource(client, clusterService, threadPool, null); + + exception = + assertThrows( + IllegalStateException.class, + () -> jobRunner.runJob(jobParameter, context), + "Expected IllegalStateException but no exception was thrown"); + assertEquals("AsyncQueryExecutorService is not initialized.", exception.getMessage()); + } + + @Test + 
public void testGetJobRunnerInstanceMultipleCalls() { + ScheduledAsyncQueryJobRunner instance1 = ScheduledAsyncQueryJobRunner.getJobRunnerInstance(); + ScheduledAsyncQueryJobRunner instance2 = ScheduledAsyncQueryJobRunner.getJobRunnerInstance(); + ScheduledAsyncQueryJobRunner instance3 = ScheduledAsyncQueryJobRunner.getJobRunnerInstance(); + + assertSame(instance1, instance2); + assertSame(instance2, instance3); + } +} diff --git a/async-query/src/test/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequestTest.java b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequestTest.java deleted file mode 100644 index 108f1acfd5..0000000000 --- a/async-query/src/test/java/org/opensearch/sql/spark/scheduler/model/OpenSearchRefreshIndexJobRequestTest.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.scheduler.model; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - -import java.io.IOException; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import org.junit.jupiter.api.Test; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; - -public class OpenSearchRefreshIndexJobRequestTest { - - @Test - public void testBuilderAndGetterMethods() { - Instant now = Instant.now(); - IntervalSchedule schedule = new IntervalSchedule(now, 1, ChronoUnit.MINUTES); - - OpenSearchRefreshIndexJobRequest jobRequest = - OpenSearchRefreshIndexJobRequest.builder() - .jobName("testJob") - .jobType("testType") - .schedule(schedule) - .enabled(true) - .lastUpdateTime(now) - .enabledTime(now) - .lockDurationSeconds(60L) - .jitter(0.1) - .build(); - - assertEquals("testJob", jobRequest.getName()); - assertEquals("testType", jobRequest.getJobType()); - assertEquals(schedule, jobRequest.getSchedule()); - assertTrue(jobRequest.isEnabled()); - assertEquals(now, jobRequest.getLastUpdateTime()); - assertEquals(now, jobRequest.getEnabledTime()); - assertEquals(60L, jobRequest.getLockDurationSeconds()); - assertEquals(0.1, jobRequest.getJitter()); - } - - @Test - public void testToXContent() throws IOException { - Instant now = Instant.now(); - IntervalSchedule schedule = new IntervalSchedule(now, 1, ChronoUnit.MINUTES); - - OpenSearchRefreshIndexJobRequest jobRequest = - OpenSearchRefreshIndexJobRequest.builder() - .jobName("testJob") - .jobType("testType") - .schedule(schedule) - .enabled(true) - .lastUpdateTime(now) - .enabledTime(now) - .lockDurationSeconds(60L) - .jitter(0.1) - .build(); - - XContentBuilder builder = XContentFactory.jsonBuilder().prettyPrint(); - jobRequest.toXContent(builder, EMPTY_PARAMS); - String jsonString = builder.toString(); - - assertTrue(jsonString.contains("\"jobName\" : \"testJob\"")); - assertTrue(jsonString.contains("\"jobType\" : \"testType\"")); - assertTrue(jsonString.contains("\"start_time\" : " + now.toEpochMilli())); - assertTrue(jsonString.contains("\"period\" : 1")); - assertTrue(jsonString.contains("\"unit\" : \"Minutes\"")); - assertTrue(jsonString.contains("\"enabled\" : true")); - assertTrue(jsonString.contains("\"lastUpdateTime\" : " + now.toEpochMilli())); - assertTrue(jsonString.contains("\"enabledTime\" : " + now.toEpochMilli())); 
- assertTrue(jsonString.contains("\"lockDurationSeconds\" : 60")); - assertTrue(jsonString.contains("\"jitter\" : 0.1")); - } -} diff --git a/async-query/src/test/java/org/opensearch/sql/spark/scheduler/model/ScheduledAsyncQueryJobRequestTest.java b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/model/ScheduledAsyncQueryJobRequestTest.java new file mode 100644 index 0000000000..85d1948dc3 --- /dev/null +++ b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/model/ScheduledAsyncQueryJobRequestTest.java @@ -0,0 +1,210 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.scheduler.model; + +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import org.junit.jupiter.api.Test; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.sql.spark.rest.model.LangType; + +public class ScheduledAsyncQueryJobRequestTest { + + @Test + public void testBuilderAndGetterMethods() { + Instant now = Instant.now(); + IntervalSchedule schedule = new IntervalSchedule(now, 1, ChronoUnit.MINUTES); + + ScheduledAsyncQueryJobRequest jobRequest = + ScheduledAsyncQueryJobRequest.builder() + .accountId("testAccount") + .jobId("testJob") + .dataSource("testDataSource") + .scheduledQuery("SELECT * FROM test") + .queryLang(LangType.SQL) + .schedule(schedule) + .enabled(true) + .lastUpdateTime(now) + .enabledTime(now) + .lockDurationSeconds(60L) + .jitter(0.1) + .build(); + + assertEquals("testAccount", jobRequest.getAccountId()); + assertEquals("testJob", jobRequest.getJobId()); + assertEquals("testJob", jobRequest.getName()); + assertEquals("testDataSource", jobRequest.getDataSource()); + assertEquals("SELECT * FROM test", jobRequest.getScheduledQuery()); + assertEquals(LangType.SQL, jobRequest.getQueryLang()); + assertEquals(schedule, jobRequest.getSchedule()); + assertTrue(jobRequest.isEnabled()); + assertEquals(now, jobRequest.getLastUpdateTime()); + assertEquals(now, jobRequest.getEnabledTime()); + assertEquals(60L, jobRequest.getLockDurationSeconds()); + assertEquals(0.1, jobRequest.getJitter()); + } + + @Test + public void testToXContent() throws IOException { + Instant now = Instant.now(); + IntervalSchedule schedule = new IntervalSchedule(now, 1, ChronoUnit.MINUTES); + + ScheduledAsyncQueryJobRequest request = + ScheduledAsyncQueryJobRequest.builder() + .accountId("testAccount") + .jobId("testJob") + .dataSource("testDataSource") + .scheduledQuery("SELECT * FROM test") + .queryLang(LangType.SQL) + .schedule(schedule) + .enabled(true) + .enabledTime(now) + .lastUpdateTime(now) + .lockDurationSeconds(60L) + .jitter(0.1) + .build(); + + XContentBuilder builder = XContentFactory.jsonBuilder().prettyPrint(); + request.toXContent(builder, EMPTY_PARAMS); + String jsonString = builder.toString(); + + assertTrue(jsonString.contains("\"accountId\" : \"testAccount\"")); + assertTrue(jsonString.contains("\"jobId\" : \"testJob\"")); + assertTrue(jsonString.contains("\"dataSource\" : \"testDataSource\"")); + assertTrue(jsonString.contains("\"scheduledQuery\" : 
\"SELECT * FROM test\"")); + assertTrue(jsonString.contains("\"queryLang\" : \"SQL\"")); + assertTrue(jsonString.contains("\"start_time\" : " + now.toEpochMilli())); + assertTrue(jsonString.contains("\"period\" : 1")); + assertTrue(jsonString.contains("\"unit\" : \"Minutes\"")); + assertTrue(jsonString.contains("\"enabled\" : true")); + assertTrue(jsonString.contains("\"lastUpdateTime\" : " + now.toEpochMilli())); + assertTrue(jsonString.contains("\"enabledTime\" : " + now.toEpochMilli())); + assertTrue(jsonString.contains("\"lockDurationSeconds\" : 60")); + assertTrue(jsonString.contains("\"jitter\" : 0.1")); + } + + @Test + public void testFromAsyncQuerySchedulerRequest() { + Instant now = Instant.now(); + AsyncQuerySchedulerRequest request = new AsyncQuerySchedulerRequest(); + request.setJobId("testJob"); + request.setAccountId("testAccount"); + request.setDataSource("testDataSource"); + request.setScheduledQuery("SELECT * FROM test"); + request.setQueryLang(LangType.SQL); + request.setSchedule("1 minutes"); + request.setEnabled(true); + request.setLastUpdateTime(now); + request.setLockDurationSeconds(60L); + request.setJitter(0.1); + + ScheduledAsyncQueryJobRequest jobRequest = + ScheduledAsyncQueryJobRequest.fromAsyncQuerySchedulerRequest(request); + + assertEquals("testJob", jobRequest.getJobId()); + assertEquals("testAccount", jobRequest.getAccountId()); + assertEquals("testDataSource", jobRequest.getDataSource()); + assertEquals("SELECT * FROM test", jobRequest.getScheduledQuery()); + assertEquals(LangType.SQL, jobRequest.getQueryLang()); + assertEquals(new IntervalSchedule(now, 1, ChronoUnit.MINUTES), jobRequest.getSchedule()); + assertTrue(jobRequest.isEnabled()); + assertEquals(60L, jobRequest.getLockDurationSeconds()); + assertEquals(0.1, jobRequest.getJitter()); + } + + @Test + public void testFromAsyncQuerySchedulerRequestWithInvalidSchedule() { + AsyncQuerySchedulerRequest request = new AsyncQuerySchedulerRequest(); + request.setJobId("testJob"); + request.setSchedule(new Object()); // Set schedule to a non-String object + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> { + ScheduledAsyncQueryJobRequest.fromAsyncQuerySchedulerRequest(request); + }); + + assertEquals("Schedule must be a String object for parsing.", exception.getMessage()); + } + + @Test + public void testEqualsAndHashCode() { + Instant now = Instant.now(); + IntervalSchedule schedule = new IntervalSchedule(now, 1, ChronoUnit.MINUTES); + + ScheduledAsyncQueryJobRequest request1 = + ScheduledAsyncQueryJobRequest.builder() + .accountId("testAccount") + .jobId("testJob") + .dataSource("testDataSource") + .scheduledQuery("SELECT * FROM test") + .queryLang(LangType.SQL) + .schedule(schedule) + .enabled(true) + .enabledTime(now) + .lastUpdateTime(now) + .lockDurationSeconds(60L) + .jitter(0.1) + .build(); + + // Test toString + String toString = request1.toString(); + assertTrue(toString.contains("accountId=testAccount")); + assertTrue(toString.contains("jobId=testJob")); + assertTrue(toString.contains("dataSource=testDataSource")); + assertTrue(toString.contains("scheduledQuery=SELECT * FROM test")); + assertTrue(toString.contains("queryLang=SQL")); + assertTrue(toString.contains("enabled=true")); + assertTrue(toString.contains("lockDurationSeconds=60")); + assertTrue(toString.contains("jitter=0.1")); + + ScheduledAsyncQueryJobRequest request2 = + ScheduledAsyncQueryJobRequest.builder() + .accountId("testAccount") + .jobId("testJob") + .dataSource("testDataSource") + 
.scheduledQuery("SELECT * FROM test") + .queryLang(LangType.SQL) + .schedule(schedule) + .enabled(true) + .enabledTime(now) + .lastUpdateTime(now) + .lockDurationSeconds(60L) + .jitter(0.1) + .build(); + + assertEquals(request1, request2); + assertEquals(request1.hashCode(), request2.hashCode()); + + ScheduledAsyncQueryJobRequest request3 = + ScheduledAsyncQueryJobRequest.builder() + .accountId("differentAccount") + .jobId("testJob") + .dataSource("testDataSource") + .scheduledQuery("SELECT * FROM test") + .queryLang(LangType.SQL) + .schedule(schedule) + .enabled(true) + .enabledTime(now) + .lastUpdateTime(now) + .lockDurationSeconds(60L) + .jitter(0.1) + .build(); + + assertNotEquals(request1, request3); + assertNotEquals(request1.hashCode(), request3.hashCode()); + } +} diff --git a/async-query/src/test/java/org/opensearch/sql/spark/scheduler/parser/IntervalScheduleParserTest.java b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/parser/IntervalScheduleParserTest.java new file mode 100644 index 0000000000..b119c345b9 --- /dev/null +++ b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/parser/IntervalScheduleParserTest.java @@ -0,0 +1,122 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.scheduler.parser; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.jobscheduler.spi.schedule.Schedule; + +public class IntervalScheduleParserTest { + + private Instant startTime; + + @BeforeEach + public void setup() { + startTime = Instant.now(); + } + + @Test + public void testParseValidScheduleString() { + verifyParseSchedule(5, "5 minutes"); + } + + @Test + public void testParseValidScheduleStringWithDifferentUnits() { + verifyParseSchedule(120, "2 hours"); + verifyParseSchedule(1440, "1 day"); + verifyParseSchedule(30240, "3 weeks"); + } + + @Test + public void testParseNullSchedule() { + Schedule schedule = IntervalScheduleParser.parse(null, startTime); + assertNull(schedule); + } + + @Test + public void testParseScheduleObject() { + IntervalSchedule expectedSchedule = new IntervalSchedule(startTime, 10, ChronoUnit.MINUTES); + Schedule schedule = IntervalScheduleParser.parse(expectedSchedule, startTime); + assertEquals(expectedSchedule, schedule); + } + + @Test + public void testParseInvalidScheduleString() { + String scheduleStr = "invalid schedule"; + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> IntervalScheduleParser.parse(scheduleStr, startTime), + "Expected IllegalArgumentException but no exception was thrown"); + + assertEquals("Invalid interval format: " + scheduleStr.toLowerCase(), exception.getMessage()); + } + + @Test + public void testParseUnsupportedUnits() { + assertThrows( + IllegalArgumentException.class, + () -> IntervalScheduleParser.parse("1 year", startTime), + "Expected IllegalArgumentException but no exception was thrown"); + + assertThrows( + IllegalArgumentException.class, + () -> IntervalScheduleParser.parse("1 month", startTime), + "Expected IllegalArgumentException but no exception was thrown"); + } + + @Test + public void testParseNonStringSchedule() { + 
Object nonStringSchedule = 12345; + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> IntervalScheduleParser.parse(nonStringSchedule, startTime), + "Expected IllegalArgumentException but no exception was thrown"); + + assertEquals("Schedule must be a String object for parsing.", exception.getMessage()); + } + + @Test + public void testParseScheduleWithNanoseconds() { + verifyParseSchedule(1, "60000000000 nanoseconds"); + } + + @Test + public void testParseScheduleWithMilliseconds() { + verifyParseSchedule(1, "60000 milliseconds"); + } + + @Test + public void testParseScheduleWithMicroseconds() { + verifyParseSchedule(1, "60000000 microseconds"); + } + + @Test + public void testUnsupportedTimeUnit() { + assertThrows( + IllegalArgumentException.class, + () -> IntervalScheduleParser.convertToSupportedUnit(10, "unsupportedunit"), + "Expected IllegalArgumentException but no exception was thrown"); + } + + @Test + public void testParseScheduleWithSeconds() { + verifyParseSchedule(2, "120 seconds"); + } + + private void verifyParseSchedule(int expectedMinutes, String scheduleStr) { + Schedule schedule = IntervalScheduleParser.parse(scheduleStr, startTime); + assertEquals(new IntervalSchedule(startTime, expectedMinutes, ChronoUnit.MINUTES), schedule); + } +} diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index 971ef5e928..560c5edadd 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -96,8 +96,8 @@ import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction; import org.opensearch.sql.spark.scheduler.OpenSearchAsyncQueryScheduler; -import org.opensearch.sql.spark.scheduler.OpenSearchRefreshIndexJobRequestParser; -import org.opensearch.sql.spark.scheduler.job.OpenSearchRefreshIndexJob; +import org.opensearch.sql.spark.scheduler.job.ScheduledAsyncQueryJobRunner; +import org.opensearch.sql.spark.scheduler.parser.OpenSearchScheduleQueryJobRequestParser; import org.opensearch.sql.spark.storage.SparkStorageFactory; import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction; import org.opensearch.sql.spark.transport.TransportCreateAsyncQueryRequestAction; @@ -217,8 +217,6 @@ public Collection createComponents( this.client = (NodeClient) client; this.dataSourceService = createDataSourceService(); dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata()); - this.asyncQueryScheduler = new OpenSearchAsyncQueryScheduler(); - this.asyncQueryScheduler.loadJobResource(client, clusterService, threadPool); LocalClusterState.state().setClusterService(clusterService); LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings); LocalClusterState.state().setClient(client); @@ -247,11 +245,13 @@ public Collection createComponents( dataSourceService, injector.getInstance(FlintIndexMetadataServiceImpl.class), injector.getInstance(FlintIndexOpFactory.class)); + AsyncQueryExecutorService asyncQueryExecutorService = + injector.getInstance(AsyncQueryExecutorService.class); + ScheduledAsyncQueryJobRunner.getJobRunnerInstance() + .loadJobResource(client, clusterService, threadPool, asyncQueryExecutorService); + return ImmutableList.of( - dataSourceService, - injector.getInstance(AsyncQueryExecutorService.class), - clusterManagerEventListener, - pluginSettings); + 
dataSourceService, asyncQueryExecutorService, clusterManagerEventListener, pluginSettings); } @Override @@ -266,12 +266,12 @@ public String getJobIndex() { @Override public ScheduledJobRunner getJobRunner() { - return OpenSearchRefreshIndexJob.getJobRunnerInstance(); + return ScheduledAsyncQueryJobRunner.getJobRunnerInstance(); } @Override public ScheduledJobParser getJobParser() { - return OpenSearchRefreshIndexJobRequestParser.getJobParser(); + return OpenSearchScheduleQueryJobRequestParser.getJobParser(); } @Override @@ -342,6 +342,9 @@ public Collection getSystemIndexDescriptors(Settings sett systemIndexDescriptors.add( new SystemIndexDescriptor( SPARK_REQUEST_BUFFER_INDEX_NAME + "*", "SQL Spark Request Buffer index pattern")); + systemIndexDescriptors.add( + new SystemIndexDescriptor( + OpenSearchAsyncQueryScheduler.SCHEDULER_INDEX_NAME, "SQL Scheduler job index")); return systemIndexDescriptors; } } From 6c5c68597c77a9402bc680a45f95f19f5da995fe Mon Sep 17 00:00:00 2001 From: Surya Sashank Nistala Date: Wed, 4 Sep 2024 15:39:19 -0700 Subject: [PATCH 07/11] Adds validation to allow only flint queries and sql SELECT queries to security lake type datasource (#2959) * allows only flint queries and select sql queries to security lake datasource Signed-off-by: Surya Sashank Nistala * add sql validator for security lake and refactor validateSparkSqlQuery class Signed-off-by: Surya Sashank Nistala * spotless fixes Signed-off-by: Surya Sashank Nistala * address review comments. Signed-off-by: Surya Sashank Nistala * address comment to extract validate logic into a separate method in tests Signed-off-by: Surya Sashank Nistala * add more tests to get more code coverage Signed-off-by: Surya Sashank Nistala --------- Signed-off-by: Surya Sashank Nistala --- .../dispatcher/SparkQueryDispatcher.java | 4 +- .../sql/spark/utils/SQLQueryUtils.java | 67 +++++++++++-- .../sql/spark/utils/SQLQueryUtilsTest.java | 93 ++++++++++++++++++- .../sql/datasource/model/DataSourceType.java | 2 + 4 files changed, 152 insertions(+), 14 deletions(-) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 710f472acb..c4b5c89540 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -54,7 +54,9 @@ public DispatchQueryResponse dispatch( dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata); } - List validationErrors = SQLQueryUtils.validateSparkSqlQuery(query); + List validationErrors = + SQLQueryUtils.validateSparkSqlQuery( + dataSourceService.getDataSource(dispatchQueryRequest.getDatasource()), query); if (!validationErrors.isEmpty()) { throw new IllegalArgumentException( "Query is not allowed: " + String.join(", ", validationErrors)); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java index ff08a8f41e..ce3bcab06b 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java @@ -15,9 +15,13 @@ import org.antlr.v4.runtime.CommonTokenStream; import org.antlr.v4.runtime.misc.Interval; import org.antlr.v4.runtime.tree.ParseTree; +import 
org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; import org.opensearch.sql.common.antlr.SyntaxAnalysisErrorListener; import org.opensearch.sql.common.antlr.SyntaxCheckException; +import org.opensearch.sql.datasource.model.DataSource; +import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsBaseVisitor; import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsLexer; import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsParser; @@ -25,6 +29,7 @@ import org.opensearch.sql.spark.antlr.parser.SqlBaseLexer; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.IdentifierReferenceContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.StatementContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; @@ -38,13 +43,14 @@ */ @UtilityClass public class SQLQueryUtils { + private static final Logger logger = LogManager.getLogger(SQLQueryUtils.class); public static List extractFullyQualifiedTableNames(String sqlQuery) { SqlBaseParser sqlBaseParser = new SqlBaseParser( new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery)))); sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener()); - SqlBaseParser.StatementContext statement = sqlBaseParser.statement(); + StatementContext statement = sqlBaseParser.statement(); SparkSqlTableNameVisitor sparkSqlTableNameVisitor = new SparkSqlTableNameVisitor(); statement.accept(sparkSqlTableNameVisitor); return sparkSqlTableNameVisitor.getFullyQualifiedTableNames(); @@ -77,32 +83,73 @@ public static boolean isFlintExtensionQuery(String sqlQuery) { } } - public static List validateSparkSqlQuery(String sqlQuery) { - SparkSqlValidatorVisitor sparkSqlValidatorVisitor = new SparkSqlValidatorVisitor(); + public static List validateSparkSqlQuery(DataSource datasource, String sqlQuery) { SqlBaseParser sqlBaseParser = new SqlBaseParser( new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery)))); sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener()); try { - SqlBaseParser.StatementContext statement = sqlBaseParser.statement(); - sparkSqlValidatorVisitor.visit(statement); - return sparkSqlValidatorVisitor.getValidationErrors(); - } catch (SyntaxCheckException syntaxCheckException) { + SqlBaseValidatorVisitor sqlParserBaseVisitor = getSparkSqlValidatorVisitor(datasource); + StatementContext statement = sqlBaseParser.statement(); + sqlParserBaseVisitor.visit(statement); + return sqlParserBaseVisitor.getValidationErrors(); + } catch (SyntaxCheckException e) { + logger.error( + String.format( + "Failed to parse sql statement context while validating sql query %s", sqlQuery), + e); return Collections.emptyList(); } } - private static class SparkSqlValidatorVisitor extends SqlBaseParserBaseVisitor { + private SqlBaseValidatorVisitor getSparkSqlValidatorVisitor(DataSource datasource) { + if (datasource != null + && datasource.getConnectorType() != null + && datasource.getConnectorType().equals(DataSourceType.SECURITY_LAKE)) { + return new SparkSqlSecurityLakeValidatorVisitor(); + } else { + return new SparkSqlValidatorVisitor(); + } + } - @Getter private final List validationErrors = new 
ArrayList<>(); + /** + * A base class extending SqlBaseParserBaseVisitor for validating Spark SQL queries. The class + * accumulates validation errors while visiting a SQL statement. + */ + @Getter + private static class SqlBaseValidatorVisitor extends SqlBaseParserBaseVisitor<Void> { + private final List<String> validationErrors = new ArrayList<>(); + } + /** A generic validator implementation for Spark SQL queries. */ + private static class SparkSqlValidatorVisitor extends SqlBaseValidatorVisitor { @Override public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) { - validationErrors.add("Creating user-defined functions is not allowed"); + getValidationErrors().add("Creating user-defined functions is not allowed"); return super.visitCreateFunction(ctx); } } + /** A validator implementation specific to Security Lake for Spark SQL queries. */ + private static class SparkSqlSecurityLakeValidatorVisitor extends SqlBaseValidatorVisitor { + + public SparkSqlSecurityLakeValidatorVisitor() { + // Only SELECT statements are allowed, so the validation error is added for all statement + // types by default + // and removed only when a SELECT statement is visited. + getValidationErrors() + .add( + "Unsupported sql statement for security lake data source. Only select queries are" + + " allowed"); + } + + @Override + public Void visitStatementDefault(SqlBaseParser.StatementDefaultContext ctx) { + getValidationErrors().clear(); + return super.visitStatementDefault(ctx); + } + } + public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor<Void> { @Getter private List<FullyQualifiedTableName> fullyQualifiedTableNames = new LinkedList<>(); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java index fe7777606c..235fe84c70 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java @@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.when; import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.index; import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.mv; import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.skippingIndex; @@ -18,7 +19,10 @@ import lombok.Getter; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.datasource.model.DataSource; +import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType; import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; @@ -27,6 +31,8 @@ @ExtendWith(MockitoExtension.class) public class SQLQueryUtilsTest { + @Mock private DataSource dataSource; + @Test void testExtractionOfTableNameFromSQLQueries() { String sqlQuery = "select * from my_glue.default.http_logs"; @@ -404,15 +410,96 @@ void testAutoRefresh() { @Test void testValidateSparkSqlQuery_ValidQuery() { - String validQuery = "SELECT * FROM users WHERE age > 18"; - List<String> errors = SQLQueryUtils.validateSparkSqlQuery(validQuery); + List<String> errors = + 
validateSparkSqlQueryForDataSourceType( + "DELETE FROM Customers WHERE CustomerName='Alfreds Futterkiste'", + DataSourceType.PROMETHEUS); + + assertTrue(errors.isEmpty(), "Valid query should not produce any errors"); } + + @Test + void testValidateSparkSqlQuery_SelectQuery_DataSourceSecurityLake() { + List<String> errors = + validateSparkSqlQueryForDataSourceType( + "SELECT * FROM users WHERE age > 18", DataSourceType.SECURITY_LAKE); + + assertTrue(errors.isEmpty(), "Valid query should not produce any errors"); + } + + @Test + void testValidateSparkSqlQuery_SelectQuery_DataSourceTypeNull() { + List<String> errors = + validateSparkSqlQueryForDataSourceType("SELECT * FROM users WHERE age > 18", null); + + assertTrue(errors.isEmpty(), "Valid query should not produce any errors"); + } + + @Test + void testValidateSparkSqlQuery_InvalidQuery_SyntaxCheckFailureSkippedWithoutValidationError() { + List<String> errors = + validateSparkSqlQueryForDataSourceType( + "SEECT * FROM users WHERE age > 18", DataSourceType.SECURITY_LAKE); + + assertTrue(errors.isEmpty(), "Syntax check failure should be skipped without producing validation errors"); + } + + @Test + void testValidateSparkSqlQuery_nullDatasource() { + List<String> errors = + SQLQueryUtils.validateSparkSqlQuery(null, "SELECT * FROM users WHERE age > 18"); + assertTrue(errors.isEmpty(), "Valid query should not produce any errors"); + } + + private List<String> validateSparkSqlQueryForDataSourceType( + String query, DataSourceType dataSourceType) { + when(this.dataSource.getConnectorType()).thenReturn(dataSourceType); + + return SQLQueryUtils.validateSparkSqlQuery(this.dataSource, query); + } + + @Test + void testValidateSparkSqlQuery_SelectQuery_DataSourceSecurityLake_ValidationFails() { + List<String> errors = + validateSparkSqlQueryForDataSourceType( + "REFRESH INDEX cv1 ON mys3.default.http_logs", DataSourceType.SECURITY_LAKE); + + assertFalse( + errors.isEmpty(), + "Invalid query as Security Lake datasource supports only flint queries and SELECT sql" + + " queries. Given query was a REFRESH sql query"); + assertEquals( + errors.get(0), + "Unsupported sql statement for security lake data source. Only select queries are allowed"); + } + + @Test + void + testValidateSparkSqlQuery_NonSelectStatementContainingSelectClause_DataSourceSecurityLake_ValidationFails() { + String query = + "CREATE TABLE AccountSummaryOrWhatever AS " + + "select taxid, address1, count(address1) from dbo.t " + + "group by taxid, address1;"; + + List<String> errors = + validateSparkSqlQueryForDataSourceType(query, DataSourceType.SECURITY_LAKE); + + assertFalse( + errors.isEmpty(), + "Invalid query as Security Lake datasource supports only flint queries and SELECT sql" + + " queries. Given query was a CREATE TABLE AS SELECT sql query"); + assertEquals( + errors.get(0), + "Unsupported sql statement for security lake data source. 
Only select queries are allowed"); + } + @Test void testValidateSparkSqlQuery_InvalidQuery() { + when(dataSource.getConnectorType()).thenReturn(DataSourceType.PROMETHEUS); String invalidQuery = "CREATE FUNCTION myUDF AS 'com.example.UDF'"; - List errors = SQLQueryUtils.validateSparkSqlQuery(invalidQuery); + + List errors = SQLQueryUtils.validateSparkSqlQuery(dataSource, invalidQuery); + assertFalse(errors.isEmpty(), "Invalid query should produce errors"); assertEquals(1, errors.size(), "Should have one error"); assertEquals( diff --git a/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java b/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java index c74964fc00..ac8ae1a5e1 100644 --- a/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java +++ b/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java @@ -7,9 +7,11 @@ import java.util.HashMap; import java.util.Map; +import lombok.EqualsAndHashCode; import lombok.RequiredArgsConstructor; @RequiredArgsConstructor +@EqualsAndHashCode public class DataSourceType { public static DataSourceType PROMETHEUS = new DataSourceType("PROMETHEUS"); public static DataSourceType OPENSEARCH = new DataSourceType("OPENSEARCH"); From b14a8cb7eeb5bd74eb1aebfe7cd8d56bfe2c88d8 Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Wed, 4 Sep 2024 16:09:28 -0700 Subject: [PATCH 08/11] Fix handler for existing query (#2968) Signed-off-by: Tomoyuki Morita --- .../spark/dispatcher/RefreshQueryHandler.java | 4 +- .../dispatcher/SparkQueryDispatcher.java | 2 +- .../sql/spark/dispatcher/model/JobType.java | 1 + .../asyncquery/AsyncQueryCoreIntegTest.java | 4 +- .../dispatcher/SparkQueryDispatcherTest.java | 39 +++++++++++++------ 5 files changed, 34 insertions(+), 16 deletions(-) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java index 38145a143e..cf5a0c6c59 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java @@ -73,7 +73,7 @@ public String cancelJob( @Override public DispatchQueryResponse submit( DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) { - leaseManager.borrow(new LeaseRequest(JobType.BATCH, dispatchQueryRequest.getDatasource())); + leaseManager.borrow(new LeaseRequest(JobType.REFRESH, dispatchQueryRequest.getDatasource())); DispatchQueryResponse resp = super.submit(dispatchQueryRequest, context); DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata(); @@ -83,7 +83,7 @@ public DispatchQueryResponse submit( .resultIndex(resp.getResultIndex()) .sessionId(resp.getSessionId()) .datasourceName(dataSourceMetadata.getName()) - .jobType(JobType.BATCH) + .jobType(JobType.REFRESH) .indexName(context.getIndexQueryDetails().openSearchIndexName()) .build(); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index c4b5c89540..4df2b5450d 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -179,7 +179,7 @@ private AsyncQueryHandler 
getAsyncQueryHandlerForExistingQuery( return queryHandlerFactory.getInteractiveQueryHandler(); } else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) { return queryHandlerFactory.getIndexDMLHandler(); - } else if (asyncQueryJobMetadata.getJobType() == JobType.BATCH) { + } else if (asyncQueryJobMetadata.getJobType() == JobType.REFRESH) { return queryHandlerFactory.getRefreshQueryHandler(asyncQueryJobMetadata.getAccountId()); } else if (asyncQueryJobMetadata.getJobType() == JobType.STREAMING) { return queryHandlerFactory.getStreamingQueryHandler(asyncQueryJobMetadata.getAccountId()); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java index 01f5f422e9..af1f69d74b 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java @@ -8,6 +8,7 @@ public enum JobType { INTERACTIVE("interactive"), STREAMING("streaming"), + REFRESH("refresh"), BATCH("batch"); private String text; diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index 226e0ff5eb..e1c9bb6f39 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -424,7 +424,7 @@ public void createRefreshQuery() { verifyGetQueryIdCalled(); verify(leaseManager).borrow(any()); verifyStartJobRunCalled(); - verifyStoreJobMetadataCalled(JOB_ID, JobType.BATCH); + verifyStoreJobMetadataCalled(JOB_ID, JobType.REFRESH); } @Test @@ -541,7 +541,7 @@ public void cancelIndexDMLQuery() { @Test public void cancelRefreshQuery() { givenJobMetadataExists( - getBaseAsyncQueryJobMetadataBuilder().jobType(JobType.BATCH).indexName(INDEX_NAME)); + getBaseAsyncQueryJobMetadataBuilder().jobType(JobType.REFRESH).indexName(INDEX_NAME)); when(flintIndexMetadataService.getFlintIndexMetadata(INDEX_NAME, asyncQueryRequestContext)) .thenReturn( ImmutableMap.of( diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index d040db24b2..5154b71574 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -387,6 +387,7 @@ void testDispatchCreateManualRefreshIndexQuery() { verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertEquals(JobType.BATCH, dispatchQueryResponse.getJobType()); verifyNoInteractions(flintIndexMetadataService); } @@ -685,6 +686,7 @@ void testRefreshIndexQuery() { verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertEquals(JobType.REFRESH, 
dispatchQueryResponse.getJobType()); verifyNoInteractions(flintIndexMetadataService); } @@ -859,12 +861,7 @@ void testDispatchWithUnSupportedDataSourceType() { @Test void testCancelJob() { - when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); - when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false)) - .thenReturn( - new CancelJobRunResult() - .withJobRunId(EMR_JOB_ID) - .withApplicationId(EMRS_APPLICATION_ID)); + givenCancelJobRunSucceed(); String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext); @@ -925,17 +922,32 @@ void testCancelQueryWithInvalidStatementId() { @Test void testCancelQueryWithNoSessionId() { + givenCancelJobRunSucceed(); + + String queryId = + sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext); + + Assertions.assertEquals(QUERY_ID, queryId); + } + + @Test + void testCancelBatchJob() { + givenCancelJobRunSucceed(); + + String queryId = + sparkQueryDispatcher.cancelJob( + asyncQueryJobMetadata(JobType.BATCH), asyncQueryRequestContext); + + Assertions.assertEquals(QUERY_ID, queryId); + } + + private void givenCancelJobRunSucceed() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false)) .thenReturn( new CancelJobRunResult() .withJobRunId(EMR_JOB_ID) .withApplicationId(EMRS_APPLICATION_ID)); - - String queryId = - sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext); - - Assertions.assertEquals(QUERY_ID, queryId); } @Test @@ -1184,11 +1196,16 @@ private DispatchQueryRequest dispatchQueryRequestWithSessionId(String query, Str } private AsyncQueryJobMetadata asyncQueryJobMetadata() { + return asyncQueryJobMetadata(JobType.INTERACTIVE); + } + + private AsyncQueryJobMetadata asyncQueryJobMetadata(JobType jobType) { return AsyncQueryJobMetadata.builder() .queryId(QUERY_ID) .applicationId(EMRS_APPLICATION_ID) .jobId(EMR_JOB_ID) .datasourceName(MY_GLUE) + .jobType(jobType) .build(); } From da622ebd6206b5215f0eceffbbe10218853d6d6d Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Wed, 4 Sep 2024 19:10:59 -0700 Subject: [PATCH 09/11] Add feature flag for async query scheduler (#2973) * Add feature flag for async query scheduler Signed-off-by: Louis Chu * Fix Jacoco verification Signed-off-by: Louis Chu --------- Signed-off-by: Louis Chu --- .../spark/data/constants/SparkConstants.java | 4 ++ .../asyncquery/AsyncQueryCoreIntegTest.java | 6 +- ...archAsyncQuerySchedulerConfigComposer.java | 36 ++++++++++ .../parser/IntervalScheduleParser.java | 1 - .../config/AsyncExecutorServiceModule.java | 2 + ...AsyncQuerySchedulerConfigComposerTest.java | 68 ++++++++++++++++++ .../parser/IntervalScheduleParserTest.java | 8 +++ .../sql/common/setting/Settings.java | 4 ++ docs/user/admin/settings.rst | 69 +++++++++++++++++++ .../setting/OpenSearchSettings.java | 27 ++++++++ .../setting/OpenSearchSettingsTest.java | 20 ++++++ 11 files changed, 241 insertions(+), 4 deletions(-) create mode 100644 async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchAsyncQuerySchedulerConfigComposer.java create mode 100644 async-query/src/test/java/org/opensearch/sql/spark/config/OpenSearchAsyncQuerySchedulerConfigComposerTest.java diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/async-query-core/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java index 
9b82022d8f..43815a9904 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java @@ -87,6 +87,10 @@ public class SparkConstants { public static final String JAVA_HOME_LOCATION = "/usr/lib/jvm/java-17-amazon-corretto.x86_64/"; public static final String FLINT_JOB_QUERY = "spark.flint.job.query"; public static final String FLINT_JOB_QUERY_ID = "spark.flint.job.queryId"; + public static final String FLINT_JOB_EXTERNAL_SCHEDULER_ENABLED = + "spark.flint.job.externalScheduler.enabled"; + public static final String FLINT_JOB_EXTERNAL_SCHEDULER_INTERVAL = + "spark.flint.job.externalScheduler.interval"; public static final String FLINT_JOB_REQUEST_INDEX = "spark.flint.job.requestIndex"; public static final String FLINT_JOB_SESSION_ID = "spark.flint.job.sessionId"; diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index e1c9bb6f39..ca4a8736d2 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -227,7 +227,7 @@ public void createDropIndexQueryWithScheduler() { assertNull(response.getSessionId()); verifyGetQueryIdCalled(); verifyCreateIndexDMLResultCalled(); - verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID); + verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, JobType.BATCH); verify(asyncQueryScheduler).unscheduleJob(indexName); } @@ -275,7 +275,7 @@ public void createVacuumIndexQueryWithScheduler() { verify(flintIndexClient).deleteIndex(indexName); verifyCreateIndexDMLResultCalled(); - verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID); + verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, JobType.BATCH); verify(asyncQueryScheduler).removeJob(indexName); } @@ -342,7 +342,7 @@ public void createAlterIndexQueryWithScheduler() { verify(asyncQueryScheduler).unscheduleJob(indexName); verifyCreateIndexDMLResultCalled(); - verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID); + verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, JobType.BATCH); } @Test diff --git a/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchAsyncQuerySchedulerConfigComposer.java b/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchAsyncQuerySchedulerConfigComposer.java new file mode 100644 index 0000000000..6dce09a406 --- /dev/null +++ b/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchAsyncQuerySchedulerConfigComposer.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.config; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_EXTERNAL_SCHEDULER_ENABLED; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_JOB_EXTERNAL_SCHEDULER_INTERVAL; + +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; +import org.opensearch.sql.spark.parameter.GeneralSparkParameterComposer; +import org.opensearch.sql.spark.parameter.SparkSubmitParameters; + +@RequiredArgsConstructor +public class OpenSearchAsyncQuerySchedulerConfigComposer 
implements GeneralSparkParameterComposer { + private final Settings settings; + + @Override + public void compose( + SparkSubmitParameters sparkSubmitParameters, + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext context) { + String externalSchedulerEnabled = + settings.getSettingValue(Settings.Key.ASYNC_QUERY_EXTERNAL_SCHEDULER_ENABLED); + String externalSchedulerInterval = + settings.getSettingValue(Settings.Key.ASYNC_QUERY_EXTERNAL_SCHEDULER_INTERVAL); + sparkSubmitParameters.setConfigItem( + FLINT_JOB_EXTERNAL_SCHEDULER_ENABLED, externalSchedulerEnabled); + sparkSubmitParameters.setConfigItem( + FLINT_JOB_EXTERNAL_SCHEDULER_INTERVAL, externalSchedulerInterval); + } +} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/parser/IntervalScheduleParser.java b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/parser/IntervalScheduleParser.java index 2d5a1b332f..47e652c570 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/scheduler/parser/IntervalScheduleParser.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/scheduler/parser/IntervalScheduleParser.java @@ -15,7 +15,6 @@ /** Parse string raw schedule into job scheduler IntervalSchedule */ public class IntervalScheduleParser { - private static final Pattern DURATION_PATTERN = Pattern.compile( "^(\\d+)\\s*(years?|months?|weeks?|days?|hours?|minutes?|minute|mins?|seconds?|secs?|milliseconds?|millis?|microseconds?|microsecond|micros?|micros|nanoseconds?|nanos?)$", diff --git a/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java index 52ffda483c..c6f6ffcd81 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java @@ -24,6 +24,7 @@ import org.opensearch.sql.spark.asyncquery.OpenSearchAsyncQueryJobMetadataStorageService; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.client.EMRServerlessClientFactoryImpl; +import org.opensearch.sql.spark.config.OpenSearchAsyncQuerySchedulerConfigComposer; import org.opensearch.sql.spark.config.OpenSearchExtraParameterComposer; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigClusterSettingLoader; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; @@ -168,6 +169,7 @@ public SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider collection.register( DataSourceType.SECURITY_LAKE, new S3GlueDataSourceSparkParameterComposer(clusterSettingLoader)); + collection.register(new OpenSearchAsyncQuerySchedulerConfigComposer(settings)); collection.register(new OpenSearchExtraParameterComposer(clusterSettingLoader)); return new SparkSubmitParametersBuilderProvider(collection); } diff --git a/async-query/src/test/java/org/opensearch/sql/spark/config/OpenSearchAsyncQuerySchedulerConfigComposerTest.java b/async-query/src/test/java/org/opensearch/sql/spark/config/OpenSearchAsyncQuerySchedulerConfigComposerTest.java new file mode 100644 index 0000000000..7836c63b7a --- /dev/null +++ b/async-query/src/test/java/org/opensearch/sql/spark/config/OpenSearchAsyncQuerySchedulerConfigComposerTest.java @@ -0,0 +1,68 @@ +package org.opensearch.sql.spark.config; + +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + 
+import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; +import org.opensearch.sql.spark.parameter.SparkSubmitParameters; + +@ExtendWith(MockitoExtension.class) +public class OpenSearchAsyncQuerySchedulerConfigComposerTest { + + @Mock private Settings settings; + @Mock private SparkSubmitParameters sparkSubmitParameters; + @Mock private DispatchQueryRequest dispatchQueryRequest; + @Mock private AsyncQueryRequestContext context; + + private OpenSearchAsyncQuerySchedulerConfigComposer composer; + + @BeforeEach + public void setUp() { + composer = new OpenSearchAsyncQuerySchedulerConfigComposer(settings); + } + + @Test + public void testCompose() { + when(settings.getSettingValue(Settings.Key.ASYNC_QUERY_EXTERNAL_SCHEDULER_ENABLED)) + .thenReturn("true"); + when(settings.getSettingValue(Settings.Key.ASYNC_QUERY_EXTERNAL_SCHEDULER_INTERVAL)) + .thenReturn("10 minutes"); + + composer.compose(sparkSubmitParameters, dispatchQueryRequest, context); + + verify(sparkSubmitParameters) + .setConfigItem("spark.flint.job.externalScheduler.enabled", "true"); + verify(sparkSubmitParameters) + .setConfigItem("spark.flint.job.externalScheduler.interval", "10 minutes"); + } + + @Test + public void testComposeWithDisabledScheduler() { + when(settings.getSettingValue(Settings.Key.ASYNC_QUERY_EXTERNAL_SCHEDULER_ENABLED)) + .thenReturn("false"); + + composer.compose(sparkSubmitParameters, dispatchQueryRequest, context); + + verify(sparkSubmitParameters) + .setConfigItem("spark.flint.job.externalScheduler.enabled", "false"); + } + + @Test + public void testComposeWithMissingInterval() { + when(settings.getSettingValue(Settings.Key.ASYNC_QUERY_EXTERNAL_SCHEDULER_ENABLED)) + .thenReturn("true"); + when(settings.getSettingValue(Settings.Key.ASYNC_QUERY_EXTERNAL_SCHEDULER_INTERVAL)) + .thenReturn(""); + + composer.compose(sparkSubmitParameters, dispatchQueryRequest, context); + + verify(sparkSubmitParameters).setConfigItem("spark.flint.job.externalScheduler.interval", ""); + } +} diff --git a/async-query/src/test/java/org/opensearch/sql/spark/scheduler/parser/IntervalScheduleParserTest.java b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/parser/IntervalScheduleParserTest.java index b119c345b9..f211548c7c 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/scheduler/parser/IntervalScheduleParserTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/scheduler/parser/IntervalScheduleParserTest.java @@ -5,6 +5,7 @@ package org.opensearch.sql.spark.scheduler.parser; +import static org.junit.Assert.assertNotNull; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -25,6 +26,13 @@ public void setup() { startTime = Instant.now(); } + @Test + public void testConstructor() { + // Test that the constructor of IntervalScheduleParser can be invoked + IntervalScheduleParser parser = new IntervalScheduleParser(); + assertNotNull(parser); + } + @Test public void testParseValidScheduleString() { verifyParseSchedule(5, "5 minutes"); diff --git a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java 
b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java index b6643f3209..0037032d22 100644 --- a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java +++ b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java @@ -51,6 +51,10 @@ public enum Key { /** Async query Settings * */ ASYNC_QUERY_ENABLED("plugins.query.executionengine.async_query.enabled"), + ASYNC_QUERY_EXTERNAL_SCHEDULER_ENABLED( + "plugins.query.executionengine.async_query.external_scheduler.enabled"), + ASYNC_QUERY_EXTERNAL_SCHEDULER_INTERVAL( + "plugins.query.executionengine.async_query.external_scheduler.interval"), STREAMING_JOB_HOUSEKEEPER_INTERVAL( "plugins.query.executionengine.spark.streamingjobs.housekeeper.interval"); diff --git a/docs/user/admin/settings.rst b/docs/user/admin/settings.rst index 236406e2c7..71718d1726 100644 --- a/docs/user/admin/settings.rst +++ b/docs/user/admin/settings.rst @@ -639,6 +639,75 @@ Request:: } } +plugins.query.executionengine.async_query.external_scheduler.enabled +===================================================================== + +Description +----------- +This setting controls whether the external scheduler is enabled for async queries. + +* Default Value: true +* Scope: Node-level +* Dynamic Update: Yes, this setting can be updated dynamically. + +To disable the external scheduler, use the following command: + +Request :: + + sh$ curl -sS -H 'Content-Type: application/json' -X PUT localhost:9200/_cluster/settings \ + ... -d '{"transient":{"plugins.query.executionengine.async_query.external_scheduler.enabled":"false"}}' + { + "acknowledged": true, + "persistent": {}, + "transient": { + "plugins": { + "query": { + "executionengine": { + "async_query": { + "external_scheduler": { + "enabled": "false" + } + } + } + } + } + } + } + +plugins.query.executionengine.async_query.external_scheduler.interval +====================================================================== + +Description +----------- +This setting defines the interval at which the external scheduler applies to auto-refresh queries. It allows Spark applications to automatically decide whether to use the Spark scheduler or the external scheduler. + +* Default Value: None (must be explicitly set) +* Format: A string representing a time duration that follows the Spark `CalendarInterval `__ format (e.g., ``10 minutes`` for 10 minutes, ``1 hour`` for 1 hour). + +To set the interval to 10 minutes, for example, use this command: + +Request :: + + sh$ curl -sS -H 'Content-Type: application/json' -X PUT localhost:9200/_cluster/settings \ + ... 
-d '{"transient":{"plugins.query.executionengine.async_query.external_scheduler.interval":"10 minutes"}}' + { + "acknowledged": true, + "persistent": {}, + "transient": { + "plugins": { + "query": { + "executionengine": { + "async_query": { + "external_scheduler": { + "interval": "10 minutes" + } + } + } + } + } + } + } + plugins.query.executionengine.spark.streamingjobs.housekeeper.interval ====================================================================== diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java index 494b906b55..1083dbd836 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java @@ -154,6 +154,19 @@ public class OpenSearchSettings extends Settings { Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting ASYNC_QUERY_EXTERNAL_SCHEDULER_ENABLED_SETTING = + Setting.boolSetting( + Key.ASYNC_QUERY_EXTERNAL_SCHEDULER_ENABLED.getKeyValue(), + true, + Setting.Property.NodeScope, + Setting.Property.Dynamic); + + public static final Setting ASYNC_QUERY_EXTERNAL_SCHEDULER_INTERVAL_SETTING = + Setting.simpleString( + Key.ASYNC_QUERY_EXTERNAL_SCHEDULER_INTERVAL.getKeyValue(), + Setting.Property.NodeScope, + Setting.Property.Dynamic); + public static final Setting SPARK_EXECUTION_ENGINE_CONFIG = Setting.simpleString( Key.SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue(), @@ -298,6 +311,18 @@ public OpenSearchSettings(ClusterSettings clusterSettings) { Key.ASYNC_QUERY_ENABLED, ASYNC_QUERY_ENABLED_SETTING, new Updater(Key.ASYNC_QUERY_ENABLED)); + register( + settingBuilder, + clusterSettings, + Key.ASYNC_QUERY_EXTERNAL_SCHEDULER_ENABLED, + ASYNC_QUERY_EXTERNAL_SCHEDULER_ENABLED_SETTING, + new Updater(Key.ASYNC_QUERY_EXTERNAL_SCHEDULER_ENABLED)); + register( + settingBuilder, + clusterSettings, + Key.ASYNC_QUERY_EXTERNAL_SCHEDULER_INTERVAL, + ASYNC_QUERY_EXTERNAL_SCHEDULER_INTERVAL_SETTING, + new Updater(Key.ASYNC_QUERY_EXTERNAL_SCHEDULER_INTERVAL)); register( settingBuilder, clusterSettings, @@ -419,6 +444,8 @@ public static List> pluginSettings() { .add(DATASOURCE_URI_HOSTS_DENY_LIST) .add(DATASOURCE_ENABLED_SETTING) .add(ASYNC_QUERY_ENABLED_SETTING) + .add(ASYNC_QUERY_EXTERNAL_SCHEDULER_ENABLED_SETTING) + .add(ASYNC_QUERY_EXTERNAL_SCHEDULER_INTERVAL_SETTING) .add(SPARK_EXECUTION_ENGINE_CONFIG) .add(SPARK_EXECUTION_SESSION_LIMIT_SETTING) .add(SPARK_EXECUTION_REFRESH_JOB_LIMIT_SETTING) diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/setting/OpenSearchSettingsTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/setting/OpenSearchSettingsTest.java index 84fb705ae0..026f0c6218 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/setting/OpenSearchSettingsTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/setting/OpenSearchSettingsTest.java @@ -15,6 +15,8 @@ import static org.mockito.Mockito.when; import static org.opensearch.common.unit.TimeValue.timeValueMinutes; import static org.opensearch.sql.opensearch.setting.LegacyOpenDistroSettings.legacySettings; +import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.ASYNC_QUERY_EXTERNAL_SCHEDULER_ENABLED_SETTING; +import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.ASYNC_QUERY_EXTERNAL_SCHEDULER_INTERVAL_SETTING; import static 
org.opensearch.sql.opensearch.setting.OpenSearchSettings.METRICS_ROLLING_INTERVAL_SETTING; import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.METRICS_ROLLING_WINDOW_SETTING; import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.PPL_ENABLED_SETTING; @@ -195,4 +197,22 @@ void getSparkExecutionEngineConfigSetting() { .put(SPARK_EXECUTION_ENGINE_CONFIG.getKey(), sparkConfig) .build())); } + + @Test + void getAsyncQueryExternalSchedulerEnabledSetting() { + // Default is true + assertEquals( + true, + ASYNC_QUERY_EXTERNAL_SCHEDULER_ENABLED_SETTING.get( + org.opensearch.common.settings.Settings.builder().build())); + } + + @Test + void getAsyncQueryExternalSchedulerIntervalSetting() { + // Default is empty string + assertEquals( + "", + ASYNC_QUERY_EXTERNAL_SCHEDULER_INTERVAL_SETTING.get( + org.opensearch.common.settings.Settings.builder().build())); + } } From 06c56e7d758a2ba8df9be852f33ce182b7fcb352 Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Thu, 5 Sep 2024 15:27:42 -0700 Subject: [PATCH 10/11] Fix type mismatch (#2990) Signed-off-by: Louis Chu --- .../config/OpenSearchAsyncQuerySchedulerConfigComposer.java | 4 ++-- .../OpenSearchAsyncQuerySchedulerConfigComposerTest.java | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchAsyncQuerySchedulerConfigComposer.java b/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchAsyncQuerySchedulerConfigComposer.java index 6dce09a406..f791b050a1 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchAsyncQuerySchedulerConfigComposer.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/config/OpenSearchAsyncQuerySchedulerConfigComposer.java @@ -24,12 +24,12 @@ public void compose( SparkSubmitParameters sparkSubmitParameters, DispatchQueryRequest dispatchQueryRequest, AsyncQueryRequestContext context) { - String externalSchedulerEnabled = + Boolean externalSchedulerEnabled = settings.getSettingValue(Settings.Key.ASYNC_QUERY_EXTERNAL_SCHEDULER_ENABLED); String externalSchedulerInterval = settings.getSettingValue(Settings.Key.ASYNC_QUERY_EXTERNAL_SCHEDULER_INTERVAL); sparkSubmitParameters.setConfigItem( - FLINT_JOB_EXTERNAL_SCHEDULER_ENABLED, externalSchedulerEnabled); + FLINT_JOB_EXTERNAL_SCHEDULER_ENABLED, String.valueOf(externalSchedulerEnabled)); sparkSubmitParameters.setConfigItem( FLINT_JOB_EXTERNAL_SCHEDULER_INTERVAL, externalSchedulerInterval); } diff --git a/async-query/src/test/java/org/opensearch/sql/spark/config/OpenSearchAsyncQuerySchedulerConfigComposerTest.java b/async-query/src/test/java/org/opensearch/sql/spark/config/OpenSearchAsyncQuerySchedulerConfigComposerTest.java index 7836c63b7a..1556d4db3f 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/config/OpenSearchAsyncQuerySchedulerConfigComposerTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/config/OpenSearchAsyncQuerySchedulerConfigComposerTest.java @@ -31,7 +31,7 @@ public void setUp() { @Test public void testCompose() { when(settings.getSettingValue(Settings.Key.ASYNC_QUERY_EXTERNAL_SCHEDULER_ENABLED)) - .thenReturn("true"); + .thenReturn(true); when(settings.getSettingValue(Settings.Key.ASYNC_QUERY_EXTERNAL_SCHEDULER_INTERVAL)) .thenReturn("10 minutes"); @@ -46,7 +46,7 @@ public void testCompose() { @Test public void testComposeWithDisabledScheduler() { when(settings.getSettingValue(Settings.Key.ASYNC_QUERY_EXTERNAL_SCHEDULER_ENABLED)) - .thenReturn("false"); + .thenReturn(false); 
     composer.compose(sparkSubmitParameters, dispatchQueryRequest, context);
 
@@ -57,7 +57,7 @@ public void testComposeWithDisabledScheduler() {
   @Test
   public void testComposeWithMissingInterval() {
     when(settings.getSettingValue(Settings.Key.ASYNC_QUERY_EXTERNAL_SCHEDULER_ENABLED))
-        .thenReturn("true");
+        .thenReturn(true);
     when(settings.getSettingValue(Settings.Key.ASYNC_QUERY_EXTERNAL_SCHEDULER_INTERVAL))
         .thenReturn("");

From 83e89fb0f6a659b6cf5877d14ef260438b459c61 Mon Sep 17 00:00:00 2001
From: Chen Dai
Date: Thu, 5 Sep 2024 18:50:16 -0700
Subject: [PATCH 11/11] Delegate Flint index vacuum operation to Spark (#2985)

* Remove vacuum dispatch and update UT

Signed-off-by: Chen Dai

* Remove unused code and test

Signed-off-by: Chen Dai

* Fix jacoco test

Signed-off-by: Chen Dai

---------

Signed-off-by: Chen Dai
---
 .../sql/spark/dispatcher/IndexDMLHandler.java |   2 -
 .../dispatcher/SparkQueryDispatcher.java      |   1 -
 .../flint/operation/FlintIndexOpFactory.java  |   9 -
 .../flint/operation/FlintIndexOpVacuum.java   |  66 ----
 .../sql/spark/utils/SQLQueryUtils.java        |  25 --
 .../asyncquery/AsyncQueryCoreIntegTest.java   |  42 +--
 .../dispatcher/SparkQueryDispatcherTest.java  | 311 +++--------------
 .../operation/FlintIndexOpFactoryTest.java    |   5 -
 .../operation/FlintIndexOpVacuumTest.java     | 261 ---------------
 .../sql/spark/utils/SQLQueryUtilsTest.java    |   6 +-
 .../asyncquery/IndexQuerySpecVacuumTest.java  | 218 ------------
 .../flint/OpenSearchFlintIndexClientTest.java |  42 +++
 12 files changed, 98 insertions(+), 890 deletions(-)
 delete mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java
 delete mode 100644 async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuumTest.java
 delete mode 100644 async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java
 create mode 100644 async-query/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexClientTest.java

diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java
index 4698bfcccc..7211da0941 100644
--- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java
@@ -138,8 +138,6 @@ private FlintIndexOp getIndexOp(
       case ALTER:
         return flintIndexOpFactory.getAlter(
             indexQueryDetails.getFlintIndexOptions(), dispatchQueryRequest.getDatasource());
-      case VACUUM:
-        return flintIndexOpFactory.getVacuum(dispatchQueryRequest.getDatasource());
       default:
         throw new IllegalStateException(
             String.format(
diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
index 4df2b5450d..50e8403d36 100644
--- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
@@ -150,7 +150,6 @@ private boolean isEligibleForStreamingQuery(IndexQueryDetails indexQueryDetails)
 
   private boolean isEligibleForIndexDMLHandling(IndexQueryDetails indexQueryDetails) {
     return IndexQueryActionType.DROP.equals(indexQueryDetails.getIndexQueryActionType())
-        || IndexQueryActionType.VACUUM.equals(indexQueryDetails.getIndexQueryActionType())
         || (IndexQueryActionType.ALTER.equals(indexQueryDetails.getIndexQueryActionType())
             && (indexQueryDetails
                     .getFlintIndexOptions()
diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactory.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactory.java
index 9f925e0bcf..d82b29e928 100644
--- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactory.java
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactory.java
@@ -36,15 +36,6 @@ public FlintIndexOpAlter getAlter(FlintIndexOptions flintIndexOptions, String da
         asyncQueryScheduler);
   }
 
-  public FlintIndexOpVacuum getVacuum(String datasource) {
-    return new FlintIndexOpVacuum(
-        flintIndexStateModelService,
-        datasource,
-        flintIndexClient,
-        emrServerlessClientFactory,
-        asyncQueryScheduler);
-  }
-
   public FlintIndexOpCancel getCancel(String datasource) {
     return new FlintIndexOpCancel(
         flintIndexStateModelService, datasource, emrServerlessClientFactory);
diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java b/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java
deleted file mode 100644
index 324ddb5720..0000000000
--- a/async-query-core/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Copyright OpenSearch Contributors
- * SPDX-License-Identifier: Apache-2.0
- */
-
-package org.opensearch.sql.spark.flint.operation;
-
-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
-import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
-import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
-import org.opensearch.sql.spark.flint.FlintIndexClient;
-import org.opensearch.sql.spark.flint.FlintIndexMetadata;
-import org.opensearch.sql.spark.flint.FlintIndexState;
-import org.opensearch.sql.spark.flint.FlintIndexStateModel;
-import org.opensearch.sql.spark.flint.FlintIndexStateModelService;
-import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler;
-
-/** Flint index vacuum operation. */
-public class FlintIndexOpVacuum extends FlintIndexOp {
-  private static final Logger LOG = LogManager.getLogger();
-
-  private final AsyncQueryScheduler asyncQueryScheduler;
-
-  /** OpenSearch client. */
-  private final FlintIndexClient flintIndexClient;
-
-  public FlintIndexOpVacuum(
-      FlintIndexStateModelService flintIndexStateModelService,
-      String datasourceName,
-      FlintIndexClient flintIndexClient,
-      EMRServerlessClientFactory emrServerlessClientFactory,
-      AsyncQueryScheduler asyncQueryScheduler) {
-    super(flintIndexStateModelService, datasourceName, emrServerlessClientFactory);
-    this.flintIndexClient = flintIndexClient;
-    this.asyncQueryScheduler = asyncQueryScheduler;
-  }
-
-  @Override
-  boolean validate(FlintIndexState state) {
-    return state == FlintIndexState.DELETED;
-  }
-
-  @Override
-  FlintIndexState transitioningState() {
-    return FlintIndexState.VACUUMING;
-  }
-
-  @Override
-  public void runOp(
-      FlintIndexMetadata flintIndexMetadata,
-      FlintIndexStateModel flintIndex,
-      AsyncQueryRequestContext asyncQueryRequestContext) {
-    LOG.info("Vacuuming Flint index {}", flintIndexMetadata.getOpensearchIndexName());
-    if (flintIndexMetadata.getFlintIndexOptions().isExternalScheduler()) {
-      asyncQueryScheduler.removeJob(flintIndexMetadata.getOpensearchIndexName());
-    }
-    flintIndexClient.deleteIndex(flintIndexMetadata.getOpensearchIndexName());
-  }
-
-  @Override
-  FlintIndexState stableState() {
-    // Instruct StateStore to purge the index state doc
-    return FlintIndexState.NONE;
-  }
-}
diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java
index ce3bcab06b..b1a8c3d4f6 100644
--- a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java
@@ -267,31 +267,6 @@ public Void visitDropMaterializedViewStatement(
     return super.visitDropMaterializedViewStatement(ctx);
   }
 
-  @Override
-  public Void visitVacuumSkippingIndexStatement(
-      FlintSparkSqlExtensionsParser.VacuumSkippingIndexStatementContext ctx) {
-    indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.VACUUM);
-    indexQueryDetailsBuilder.indexType(FlintIndexType.SKIPPING);
-    return super.visitVacuumSkippingIndexStatement(ctx);
-  }
-
-  @Override
-  public Void visitVacuumCoveringIndexStatement(
-      FlintSparkSqlExtensionsParser.VacuumCoveringIndexStatementContext ctx) {
-    indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.VACUUM);
-    indexQueryDetailsBuilder.indexType(FlintIndexType.COVERING);
-    return super.visitVacuumCoveringIndexStatement(ctx);
-  }
-
-  @Override
-  public Void visitVacuumMaterializedViewStatement(
-      FlintSparkSqlExtensionsParser.VacuumMaterializedViewStatementContext ctx) {
-    indexQueryDetailsBuilder.indexQueryActionType(IndexQueryActionType.VACUUM);
-    indexQueryDetailsBuilder.indexType(FlintIndexType.MATERIALIZED_VIEW);
-    indexQueryDetailsBuilder.mvName(ctx.mvName.getText());
-    return super.visitVacuumMaterializedViewStatement(ctx);
-  }
-
   @Override
   public Void visitDescribeCoveringIndexStatement(
       FlintSparkSqlExtensionsParser.DescribeCoveringIndexStatementContext ctx) {
diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java
index ca4a8736d2..1214935dc6 100644
--- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java
+++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java
@@ -236,32 +236,12 @@ public void createDropIndexQueryWithScheduler() {
 public void createVacuumIndexQuery() {
     givenSparkExecutionEngineConfigIsSupplied();
     givenValidDataSourceMetadataExist();
+    givenSessionExists();
     when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID);
-
-    String indexName = "flint_datasource_name_table_name_index_name_index";
-    givenFlintIndexMetadataExists(indexName);
-
-    CreateAsyncQueryResponse response =
-        asyncQueryExecutorService.createAsyncQuery(
-            new CreateAsyncQueryRequest(
-                "VACUUM INDEX index_name ON table_name", DATASOURCE_NAME, LangType.SQL),
-            asyncQueryRequestContext);
-
-    assertEquals(QUERY_ID, response.getQueryId());
-    assertNull(response.getSessionId());
-    verifyGetQueryIdCalled();
-    verify(flintIndexClient).deleteIndex(indexName);
-    verifyCreateIndexDMLResultCalled();
-    verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, JobType.BATCH);
-  }
-
-  @Test
-  public void createVacuumIndexQueryWithScheduler() {
-    givenSparkExecutionEngineConfigIsSupplied();
-    givenValidDataSourceMetadataExist();
-    when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID);
-
-    String indexName = "flint_datasource_name_table_name_index_name_index";
-    givenFlintIndexMetadataExistsWithExternalScheduler(indexName);
+    when(sessionIdProvider.getSessionId(any())).thenReturn(SESSION_ID);
+    givenSessionExists(); // called twice
+    when(awsemrServerless.startJobRun(any()))
+        .thenReturn(new StartJobRunResult().withApplicationId(APPLICATION_ID).withJobRunId(JOB_ID));
 
     CreateAsyncQueryResponse response =
         asyncQueryExecutorService.createAsyncQuery(
@@ -270,14 +250,12 @@ public void createVacuumIndexQueryWithScheduler() {
             asyncQueryRequestContext);
 
     assertEquals(QUERY_ID, response.getQueryId());
-    assertNull(response.getSessionId());
+    assertEquals(SESSION_ID, response.getSessionId());
     verifyGetQueryIdCalled();
-
-    verify(flintIndexClient).deleteIndex(indexName);
-    verifyCreateIndexDMLResultCalled();
-    verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, JobType.BATCH);
-
-    verify(asyncQueryScheduler).removeJob(indexName);
+    verifyGetSessionIdCalled();
+    verify(leaseManager).borrow(any());
+    verifyStartJobRunCalled();
+    verifyStoreJobMetadataCalled(JOB_ID, JobType.INTERACTIVE);
   }
 
   @Test
diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
index 5154b71574..8b855c190c 100644
--- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
+++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
@@ -163,84 +163,12 @@ void setUp() {
 
   @Test
   void testDispatchSelectQuery() {
-    when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
-    when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
-    HashMap<String, String> tags = new HashMap<>();
-    tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
-    tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
-    tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
-    String query = "select * from my_glue.default.http_logs";
-    String sparkSubmitParameters =
-        constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
-    StartJobRequest expected =
-        new StartJobRequest(
-            "TEST_CLUSTER:batch",
-            null,
-            EMRS_APPLICATION_ID,
-            EMRS_EXECUTION_ROLE,
-            sparkSubmitParameters,
-            tags,
-            false,
-            "query_execution_result_my_glue");
-    when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
-    DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
-    when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
-            MY_GLUE, asyncQueryRequestContext))
-        .thenReturn(dataSourceMetadata);
-
-    DispatchQueryResponse dispatchQueryResponse =
-        sparkQueryDispatcher.dispatch(
-            DispatchQueryRequest.builder()
-                .applicationId(EMRS_APPLICATION_ID)
-                .query(query)
-                .datasource(MY_GLUE)
-                .langType(LangType.SQL)
-                .executionRoleARN(EMRS_EXECUTION_ROLE)
-                .clusterName(TEST_CLUSTER_NAME)
-                .sparkSubmitParameterModifier(sparkSubmitParameterModifier)
-                .build(),
-            asyncQueryRequestContext);
-
-    verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
-    Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
-    Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
-    verifyNoInteractions(flintIndexMetadataService);
+    testDispatchBatchQuery("select * from my_glue.default.http_logs");
   }
 
   @Test
   void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() {
-    when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
-    when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
-    HashMap<String, String> tags = new HashMap<>();
-    tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
-    tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
-    tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
-    String query = "select * from my_glue.default.http_logs";
-    String sparkSubmitParameters =
-        constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
-    StartJobRequest expected =
-        new StartJobRequest(
-            "TEST_CLUSTER:batch",
-            null,
-            EMRS_APPLICATION_ID,
-            EMRS_EXECUTION_ROLE,
-            sparkSubmitParameters,
-            tags,
-            false,
-            "query_execution_result_my_glue");
-    when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
-    DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithBasicAuth();
-    when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
-            MY_GLUE, asyncQueryRequestContext))
-        .thenReturn(dataSourceMetadata);
-
-    DispatchQueryResponse dispatchQueryResponse =
-        sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);
-
-    verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
-    Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
-    Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
-    verifyNoInteractions(flintIndexMetadataService);
+    testDispatchBatchQuery("select * from my_glue.default.http_logs");
   }
 
   @Test
@@ -354,41 +282,9 @@ void testDispatchCreateAutoRefreshIndexQuery() {
 
   @Test
   void testDispatchCreateManualRefreshIndexQuery() {
-    when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
-    when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
-    HashMap<String, String> tags = new HashMap<>();
-    tags.put(DATASOURCE_TAG_KEY, "my_glue");
-    tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
-    tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
-    String query =
+    testDispatchBatchQuery(
         "CREATE INDEX elb_and_requestUri ON my_glue.default.http_logs(l_orderkey, l_quantity) WITH"
-            + " (auto_refresh = false)";
-    String sparkSubmitParameters =
-        constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
-    StartJobRequest expected =
-        new StartJobRequest(
-            "TEST_CLUSTER:batch",
-            null,
-            EMRS_APPLICATION_ID,
-            EMRS_EXECUTION_ROLE,
-            sparkSubmitParameters,
-            tags,
-            false,
-            "query_execution_result_my_glue");
-    when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
-    DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
-    when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
-            "my_glue", asyncQueryRequestContext))
-        .thenReturn(dataSourceMetadata);
-
-    DispatchQueryResponse dispatchQueryResponse =
-        sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);
-
-    verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
-    Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
-    Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
-    Assertions.assertEquals(JobType.BATCH, dispatchQueryResponse.getJobType());
-    verifyNoInteractions(flintIndexMetadataService);
+            + " (auto_refresh = false)");
   }
 
   @Test
@@ -460,84 +356,12 @@ void testDispatchWithSparkUDFQuery() {
 
   @Test
   void testInvalidSQLQueryDispatchToSpark() {
-    when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
-    when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
-    HashMap<String, String> tags = new HashMap<>();
-    tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
-    tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
-    tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
-    String query = "myselect 1";
-    String sparkSubmitParameters =
-        constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
-    StartJobRequest expected =
-        new StartJobRequest(
-            "TEST_CLUSTER:batch",
-            null,
-            EMRS_APPLICATION_ID,
-            EMRS_EXECUTION_ROLE,
-            sparkSubmitParameters,
-            tags,
-            false,
-            "query_execution_result_my_glue");
-    when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
-    DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
-    when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
-            MY_GLUE, asyncQueryRequestContext))
-        .thenReturn(dataSourceMetadata);
-
-    DispatchQueryResponse dispatchQueryResponse =
-        sparkQueryDispatcher.dispatch(
-            DispatchQueryRequest.builder()
-                .applicationId(EMRS_APPLICATION_ID)
-                .query(query)
-                .datasource(MY_GLUE)
-                .langType(LangType.SQL)
-                .executionRoleARN(EMRS_EXECUTION_ROLE)
-                .clusterName(TEST_CLUSTER_NAME)
-                .sparkSubmitParameterModifier(sparkSubmitParameterModifier)
-                .build(),
-            asyncQueryRequestContext);
-
-    verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
-    Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
-    Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
-    verifyNoInteractions(flintIndexMetadataService);
+    testDispatchBatchQuery("myselect 1");
  }
 
   @Test
   void testDispatchQueryWithoutATableAndDataSourceName() {
-    when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
-    when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
-    HashMap<String, String> tags = new HashMap<>();
-    tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
-    tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
-    tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
-    String query = "show tables";
-    String sparkSubmitParameters =
-        constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
-    StartJobRequest expected =
-        new StartJobRequest(
-            "TEST_CLUSTER:batch",
-            null,
-            EMRS_APPLICATION_ID,
-            EMRS_EXECUTION_ROLE,
-            sparkSubmitParameters,
-            tags,
-            false,
-            "query_execution_result_my_glue");
-    when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
-    DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
-    when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
-            MY_GLUE, asyncQueryRequestContext))
-        .thenReturn(dataSourceMetadata);
-
-    DispatchQueryResponse dispatchQueryResponse =
-        sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);
-
-    verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
-    Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
-    Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
-    verifyNoInteractions(flintIndexMetadataService);
+    testDispatchBatchQuery("show tables");
   }
 
   @Test
@@ -619,38 +443,7 @@ void testDispatchMaterializedViewQuery() {
 
   @Test
   void testDispatchShowMVQuery() {
-    when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
-    when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
-    HashMap<String, String> tags = new HashMap<>();
-    tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
-    tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
-    tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
-    String query = "SHOW MATERIALIZED VIEW IN mys3.default";
-    String sparkSubmitParameters =
-        constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
-    StartJobRequest expected =
-        new StartJobRequest(
-            "TEST_CLUSTER:batch",
-            null,
-            EMRS_APPLICATION_ID,
-            EMRS_EXECUTION_ROLE,
-            sparkSubmitParameters,
-            tags,
-            false,
-            "query_execution_result_my_glue");
-    when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
-    DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
-    when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
-            MY_GLUE, asyncQueryRequestContext))
-        .thenReturn(dataSourceMetadata);
-
-    DispatchQueryResponse dispatchQueryResponse =
-        sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);
-
-    verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
-    Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
-    Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
-    verifyNoInteractions(flintIndexMetadataService);
+    testDispatchBatchQuery("SHOW MATERIALIZED VIEW IN mys3.default");
   }
 
   @Test
@@ -692,38 +485,7 @@ void testRefreshIndexQuery() {
 
   @Test
   void testDispatchDescribeIndexQuery() {
-    when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
-    when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
-    HashMap<String, String> tags = new HashMap<>();
-    tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
-    tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
-    tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
-    String query = "DESCRIBE SKIPPING INDEX ON mys3.default.http_logs";
-    String sparkSubmitParameters =
-        constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
-    StartJobRequest expected =
-        new StartJobRequest(
-            "TEST_CLUSTER:batch",
-            null,
-            EMRS_APPLICATION_ID,
-            EMRS_EXECUTION_ROLE,
-            sparkSubmitParameters,
-            tags,
-            false,
-            "query_execution_result_my_glue");
-    when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
-    DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
-    when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
-            MY_GLUE, asyncQueryRequestContext))
-        .thenReturn(dataSourceMetadata);
-
-    DispatchQueryResponse dispatchQueryResponse =
-        sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);
-
-    verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
-    Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
-    Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
-    verifyNoInteractions(flintIndexMetadataService);
+    testDispatchBatchQuery("DESCRIBE SKIPPING INDEX ON mys3.default.http_logs");
   }
 
   @Test
@@ -817,26 +579,7 @@ void testDispatchDropIndexQuery() {
 
   @Test
   void testDispatchVacuumIndexQuery() {
-    QueryHandlerFactory queryHandlerFactory = mock(QueryHandlerFactory.class);
-    sparkQueryDispatcher =
-        new SparkQueryDispatcher(
-            dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider);
-
-    String query = "VACUUM INDEX elb_and_requestUri ON my_glue.default.http_logs";
-    DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
-    when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
-            "my_glue", asyncQueryRequestContext))
-        .thenReturn(dataSourceMetadata);
-    when(queryHandlerFactory.getIndexDMLHandler())
-        .thenReturn(
-            new IndexDMLHandler(
-                jobExecutionResponseReader,
-                flintIndexMetadataService,
-                indexDMLResultStorageService,
-                flintIndexOpFactory));
-
-    sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);
-    verify(queryHandlerFactory, times(1)).getIndexDMLHandler();
+    testDispatchBatchQuery("VACUUM INDEX elb_and_requestUri ON my_glue.default.http_logs");
   }
 
   @Test
@@ -1087,6 +830,42 @@ void testDispatchQueryWithExtraSparkSubmitParameters() {
     }
   }
 
+  private void testDispatchBatchQuery(String query) {
+    when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+    when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
+    HashMap<String, String> tags = new HashMap<>();
+    tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
+    tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
+    tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
+
+    String sparkSubmitParameters =
+        constructExpectedSparkSubmitParameterString(query, null, QUERY_ID);
+    StartJobRequest expected =
+        new StartJobRequest(
+            "TEST_CLUSTER:batch",
+            null,
+            EMRS_APPLICATION_ID,
+            EMRS_EXECUTION_ROLE,
+            sparkSubmitParameters,
+            tags,
+            false,
+            "query_execution_result_my_glue");
+    when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
+    DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
+    when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
+            MY_GLUE, asyncQueryRequestContext))
+        .thenReturn(dataSourceMetadata);
+
+    DispatchQueryResponse dispatchQueryResponse =
+        sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);
+
+    verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
+    Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
+    Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
+    Assertions.assertEquals(JobType.BATCH, dispatchQueryResponse.getJobType());
+    verifyNoInteractions(flintIndexMetadataService);
+  }
+
   private String constructExpectedSparkSubmitParameterString(String query) {
     return constructExpectedSparkSubmitParameterString(query, null, null);
   }
diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactoryTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactoryTest.java
index 62ac98f1a2..e73c5614ae 100644
--- a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactoryTest.java
+++ b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpFactoryTest.java
@@ -41,11 +41,6 @@ void getAlter() {
     assertNotNull(flintIndexOpFactory.getAlter(new FlintIndexOptions(), DATASOURCE_NAME));
   }
 
-  @Test
-  void getVacuum() {
-    assertNotNull(flintIndexOpFactory.getDrop(DATASOURCE_NAME));
-  }
-
   @Test
   void getCancel() {
     assertNotNull(flintIndexOpFactory.getDrop(DATASOURCE_NAME));
diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuumTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuumTest.java
deleted file mode 100644
index 08f8efd488..0000000000
--- a/async-query-core/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuumTest.java
+++ /dev/null
@@ -1,261 +0,0 @@
-/*
- * Copyright OpenSearch Contributors
- * SPDX-License-Identifier: Apache-2.0
- */
-
-package org.opensearch.sql.spark.flint.operation;
-
-import static org.junit.jupiter.api.Assertions.assertThrows;
-import static org.mockito.Mockito.doThrow;
-import static org.mockito.Mockito.never;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
-
-import java.util.Optional;
-import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.extension.ExtendWith;
-import org.mockito.Mock;
-import org.mockito.junit.jupiter.MockitoExtension;
-import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
-import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
-import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions;
-import org.opensearch.sql.spark.flint.FlintIndexClient;
-import org.opensearch.sql.spark.flint.FlintIndexMetadata;
-import org.opensearch.sql.spark.flint.FlintIndexState;
-import org.opensearch.sql.spark.flint.FlintIndexStateModel;
-import org.opensearch.sql.spark.flint.FlintIndexStateModelService;
-import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler;
-
-@ExtendWith(MockitoExtension.class)
-class FlintIndexOpVacuumTest {
-
-  public static final String DATASOURCE_NAME = "DATASOURCE_NAME";
-  public static final String LATEST_ID = "LATEST_ID";
-  public static final String INDEX_NAME = "INDEX_NAME";
-
-  public static final FlintIndexMetadata FLINT_INDEX_METADATA_WITH_LATEST_ID =
-      createFlintIndexMetadataWithLatestId();
-
-  public static final FlintIndexMetadata FLINT_INDEX_METADATA_WITHOUT_LATEST_ID =
-      createFlintIndexMetadataWithoutLatestId();
-
-  @Mock FlintIndexClient flintIndexClient;
-  @Mock FlintIndexStateModelService flintIndexStateModelService;
-  @Mock EMRServerlessClientFactory emrServerlessClientFactory;
-  @Mock FlintIndexStateModel flintIndexStateModel;
-  @Mock FlintIndexStateModel transitionedFlintIndexStateModel;
-  @Mock AsyncQueryRequestContext asyncQueryRequestContext;
-  @Mock AsyncQueryScheduler asyncQueryScheduler;
-
-  RuntimeException testException = new RuntimeException("Test Exception");
-
-  FlintIndexOpVacuum flintIndexOpVacuum;
-
-  @BeforeEach
-  public void setUp() {
-    flintIndexOpVacuum =
-        new FlintIndexOpVacuum(
-            flintIndexStateModelService,
-            DATASOURCE_NAME,
-            flintIndexClient,
-            emrServerlessClientFactory,
-            asyncQueryScheduler);
-  }
-
-  private static FlintIndexMetadata createFlintIndexMetadataWithLatestId() {
-    return FlintIndexMetadata.builder()
-        .latestId(LATEST_ID)
-        .opensearchIndexName(INDEX_NAME)
-        .flintIndexOptions(new FlintIndexOptions())
-        .build();
-  }
-
-  private static FlintIndexMetadata createFlintIndexMetadataWithoutLatestId() {
-    return FlintIndexMetadata.builder()
-        .opensearchIndexName(INDEX_NAME)
-        .flintIndexOptions(new FlintIndexOptions())
-        .build();
-  }
-
-  private FlintIndexMetadata createFlintIndexMetadataWithExternalScheduler() {
-    FlintIndexOptions flintIndexOptions = new FlintIndexOptions();
-    flintIndexOptions.setOption(FlintIndexOptions.SCHEDULER_MODE, "external");
-
-    return FlintIndexMetadata.builder()
-        .opensearchIndexName(INDEX_NAME)
-        .flintIndexOptions(flintIndexOptions)
-        .build();
-  }
-
-  @Test
-  public void testApplyWithEmptyLatestId() {
-    flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITHOUT_LATEST_ID, asyncQueryRequestContext);
-
-    verify(flintIndexClient).deleteIndex(INDEX_NAME);
-  }
-
-  @Test
-  public void testApplyWithFlintIndexStateNotFound() {
-    when(flintIndexStateModelService.getFlintIndexStateModel(
-            LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext))
-        .thenReturn(Optional.empty());
-
-    assertThrows(
-        IllegalStateException.class,
-        () ->
-            flintIndexOpVacuum.apply(
-                FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext));
-  }
-
-  @Test
-  public void testApplyWithNotDeletedState() {
-    when(flintIndexStateModelService.getFlintIndexStateModel(
-            LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext))
-        .thenReturn(Optional.of(flintIndexStateModel));
-    when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.ACTIVE);
-
-    assertThrows(
-        IllegalStateException.class,
-        () ->
-            flintIndexOpVacuum.apply(
-                FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext));
-  }
-
-  @Test
-  public void testApplyWithUpdateFlintIndexStateThrow() {
-    when(flintIndexStateModelService.getFlintIndexStateModel(
-            LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext))
-        .thenReturn(Optional.of(flintIndexStateModel));
-    when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.DELETED);
-    when(flintIndexStateModelService.updateFlintIndexState(
-            flintIndexStateModel,
-            FlintIndexState.VACUUMING,
-            DATASOURCE_NAME,
-            asyncQueryRequestContext))
-        .thenThrow(testException);
-
-    assertThrows(
-        IllegalStateException.class,
-        () ->
-            flintIndexOpVacuum.apply(
-                FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext));
-  }
-
-  @Test
-  public void testApplyWithRunOpThrow() {
-    when(flintIndexStateModelService.getFlintIndexStateModel(
-            LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext))
-        .thenReturn(Optional.of(flintIndexStateModel));
-    when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.DELETED);
-    when(flintIndexStateModelService.updateFlintIndexState(
-            flintIndexStateModel,
-            FlintIndexState.VACUUMING,
-            DATASOURCE_NAME,
-            asyncQueryRequestContext))
-        .thenReturn(transitionedFlintIndexStateModel);
-    doThrow(testException).when(flintIndexClient).deleteIndex(INDEX_NAME);
-
-    assertThrows(
-        Exception.class,
-        () ->
-            flintIndexOpVacuum.apply(
-                FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext));
-
-    verify(flintIndexStateModelService)
-        .updateFlintIndexState(
-            transitionedFlintIndexStateModel,
-            FlintIndexState.DELETED,
-            DATASOURCE_NAME,
-            asyncQueryRequestContext);
-  }
-
-  @Test
-  public void testApplyWithRunOpThrowAndRollbackThrow() {
-    when(flintIndexStateModelService.getFlintIndexStateModel(
-            LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext))
-        .thenReturn(Optional.of(flintIndexStateModel));
-    when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.DELETED);
-    when(flintIndexStateModelService.updateFlintIndexState(
-            flintIndexStateModel,
-            FlintIndexState.VACUUMING,
-            DATASOURCE_NAME,
-            asyncQueryRequestContext))
-        .thenReturn(transitionedFlintIndexStateModel);
-    doThrow(testException).when(flintIndexClient).deleteIndex(INDEX_NAME);
-    when(flintIndexStateModelService.updateFlintIndexState(
-            transitionedFlintIndexStateModel,
-            FlintIndexState.DELETED,
-            DATASOURCE_NAME,
-            asyncQueryRequestContext))
-        .thenThrow(testException);
-
-    assertThrows(
-        Exception.class,
-        () ->
-            flintIndexOpVacuum.apply(
-                FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext));
-  }
-
-  @Test
-  public void testApplyWithDeleteFlintIndexStateModelThrow() {
-    when(flintIndexStateModelService.getFlintIndexStateModel(
-            LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext))
-        .thenReturn(Optional.of(flintIndexStateModel));
-    when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.DELETED);
-    when(flintIndexStateModelService.updateFlintIndexState(
-            flintIndexStateModel,
-            FlintIndexState.VACUUMING,
-            DATASOURCE_NAME,
-            asyncQueryRequestContext))
-        .thenReturn(transitionedFlintIndexStateModel);
-    when(flintIndexStateModelService.deleteFlintIndexStateModel(
-            LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext))
-        .thenThrow(testException);
-
-    assertThrows(
-        IllegalStateException.class,
-        () ->
-            flintIndexOpVacuum.apply(
-                FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext));
-  }
-
-  @Test
-  public void testApplyHappyPath() {
-    when(flintIndexStateModelService.getFlintIndexStateModel(
-            LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext))
-        .thenReturn(Optional.of(flintIndexStateModel));
-    when(flintIndexStateModel.getIndexState()).thenReturn(FlintIndexState.DELETED);
-    when(flintIndexStateModelService.updateFlintIndexState(
-            flintIndexStateModel,
-            FlintIndexState.VACUUMING,
-            DATASOURCE_NAME,
-            asyncQueryRequestContext))
-        .thenReturn(transitionedFlintIndexStateModel);
-    when(transitionedFlintIndexStateModel.getLatestId()).thenReturn(LATEST_ID);
-
-    flintIndexOpVacuum.apply(FLINT_INDEX_METADATA_WITH_LATEST_ID, asyncQueryRequestContext);
-
-    verify(flintIndexStateModelService)
-        .deleteFlintIndexStateModel(LATEST_ID, DATASOURCE_NAME, asyncQueryRequestContext);
-    verify(flintIndexClient).deleteIndex(INDEX_NAME);
-  }
-
-  @Test
-  public void testRunOpWithExternalScheduler() {
-    FlintIndexMetadata flintIndexMetadata = createFlintIndexMetadataWithExternalScheduler();
-    flintIndexOpVacuum.runOp(flintIndexMetadata, flintIndexStateModel, asyncQueryRequestContext);
-
-    verify(asyncQueryScheduler).removeJob(INDEX_NAME);
-    verify(flintIndexClient).deleteIndex(INDEX_NAME);
-  }
-
-  @Test
-  public void testRunOpWithoutExternalScheduler() {
-    FlintIndexMetadata flintIndexMetadata = FLINT_INDEX_METADATA_WITHOUT_LATEST_ID;
-    flintIndexOpVacuum.runOp(flintIndexMetadata, flintIndexStateModel, asyncQueryRequestContext);
-
-    verify(asyncQueryScheduler, never()).removeJob(INDEX_NAME);
-    verify(flintIndexClient).deleteIndex(INDEX_NAME);
-  }
-}
diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java
index 235fe84c70..f1853f2c1e 100644
--- a/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java
+++ b/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java
@@ -142,7 +142,6 @@ void testExtractionFromFlintSkippingIndexQueries() {
           + " WHERE elb_status_code = 500 "
           + " WITH (auto_refresh = true)",
       "DROP SKIPPING INDEX ON myS3.default.alb_logs",
-      "VACUUM SKIPPING INDEX ON myS3.default.alb_logs",
       "ALTER SKIPPING INDEX ON myS3.default.alb_logs WITH (auto_refresh = false)",
     };
 
@@ -171,7 +170,6 @@ void testExtractionFromFlintCoveringIndexQueries() {
           + " WHERE elb_status_code = 500 "
           + " WITH (auto_refresh = true)",
       "DROP INDEX elb_and_requestUri ON myS3.default.alb_logs",
-      "VACUUM INDEX elb_and_requestUri ON myS3.default.alb_logs",
       "ALTER INDEX elb_and_requestUri ON myS3.default.alb_logs WITH (auto_refresh = false)"
     };
 
@@ -203,9 +201,7 @@ void testExtractionFromCreateMVQuery() {
 
   @Test
   void testExtractionFromFlintMVQuery() {
     String[] mvQueries = {
-      "DROP MATERIALIZED VIEW mv_1",
-      "VACUUM MATERIALIZED VIEW mv_1",
-      "ALTER MATERIALIZED VIEW mv_1 WITH (auto_refresh = false)",
+      "DROP MATERIALIZED VIEW mv_1", "ALTER MATERIALIZED VIEW mv_1 WITH (auto_refresh = false)",
     };
 
     for (String query : mvQueries) {
diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java
deleted file mode 100644
index e62b60bfd2..0000000000
--- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecVacuumTest.java
+++ /dev/null
@@ -1,218 +0,0 @@
-/*
- * Copyright OpenSearch Contributors
- * SPDX-License-Identifier: Apache-2.0
- */
-
-package org.opensearch.sql.spark.asyncquery;
-
-import static org.opensearch.sql.spark.flint.FlintIndexState.ACTIVE;
-import static org.opensearch.sql.spark.flint.FlintIndexState.CREATING;
-import static org.opensearch.sql.spark.flint.FlintIndexState.DELETED;
-import static org.opensearch.sql.spark.flint.FlintIndexState.EMPTY;
-import static org.opensearch.sql.spark.flint.FlintIndexState.REFRESHING;
-import static org.opensearch.sql.spark.flint.FlintIndexState.VACUUMING;
-import static org.opensearch.sql.spark.flint.FlintIndexType.COVERING;
-import static org.opensearch.sql.spark.flint.FlintIndexType.MATERIALIZED_VIEW;
-import static org.opensearch.sql.spark.flint.FlintIndexType.SKIPPING;
-
-import com.amazonaws.services.emrserverless.model.CancelJobRunResult;
-import com.amazonaws.services.emrserverless.model.GetJobRunResult;
-import com.amazonaws.services.emrserverless.model.JobRun;
-import com.google.common.collect.Lists;
-import java.util.Base64;
-import java.util.List;
-import java.util.function.BiConsumer;
-import org.apache.commons.lang3.tuple.Pair;
-import org.junit.Test;
-import org.opensearch.action.admin.indices.exists.indices.IndicesExistsRequest;
-import org.opensearch.action.delete.DeleteRequest;
-import org.opensearch.action.get.GetRequest;
-import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse;
-import org.opensearch.sql.spark.asyncquery.model.MockFlintSparkJob;
-import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
-import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil;
-import org.opensearch.sql.spark.flint.FlintIndexState;
-import org.opensearch.sql.spark.flint.FlintIndexType;
-import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest;
-import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse;
-import org.opensearch.sql.spark.rest.model.LangType;
-
-@SuppressWarnings({"unchecked", "rawtypes"})
-public class IndexQuerySpecVacuumTest extends AsyncQueryExecutorServiceSpec {
-
-  private static final EMRApiCall DEFAULT_OP = () -> null;
-
-  private final List<FlintDatasetMock> FLINT_TEST_DATASETS =
-      List.of(
-          mockDataset(
-              "VACUUM SKIPPING INDEX ON mys3.default.http_logs",
-              SKIPPING,
-              "flint_mys3_default_http_logs_skipping_index"),
-          mockDataset(
-              "VACUUM INDEX covering ON mys3.default.http_logs",
-              COVERING,
-              "flint_mys3_default_http_logs_covering_index"),
-          mockDataset(
-              "VACUUM MATERIALIZED VIEW mys3.default.http_logs_metrics",
-              MATERIALIZED_VIEW,
-              "flint_mys3_default_http_logs_metrics"),
-          mockDataset(
-              "VACUUM SKIPPING INDEX ON mys3.default.`test ,:\"+/\\|?#><`",
-              SKIPPING,
-              "flint_mys3_default_test%20%2c%3a%22%2b%2f%5c%7c%3f%23%3e%3c_skipping_index")
-              .isSpecialCharacter(true));
-
-  @Test
-  public void shouldVacuumIndexInDeletedState() {
-    List<List<Object>> testCases =
-        Lists.cartesianProduct(
-            FLINT_TEST_DATASETS,
-            List.of(DELETED),
-            List.of(
-                Pair.of(
-                    DEFAULT_OP,
-                    () -> new GetJobRunResult().withJobRun(new JobRun().withState("Cancelled")))));
-
-    runVacuumTestSuite(
-        testCases,
-        (mockDS, response) -> {
-          assertEquals("SUCCESS", response.getStatus());
-          assertFalse(flintIndexExists(mockDS.indexName));
-          assertFalse(indexDocExists(mockDS.latestId));
-        });
-  }
-
-  @Test
-  public void shouldNotVacuumIndexInOtherStates() {
-    List<List<Object>> testCases =
-        Lists.cartesianProduct(
-            FLINT_TEST_DATASETS,
-            List.of(EMPTY, CREATING, ACTIVE, REFRESHING, VACUUMING),
-            List.of(
-                Pair.of(
-                    () -> {
-                      throw new AssertionError("should not call cancelJobRun");
-                    },
-                    () -> {
-                      throw new AssertionError("should not call getJobRunResult");
-                    })));
-
-    runVacuumTestSuite(
-        testCases,
-        (mockDS, response) -> {
-          assertEquals("FAILED", response.getStatus());
-          assertTrue(flintIndexExists(mockDS.indexName));
-          assertTrue(indexDocExists(mockDS.latestId));
-        });
-  }
-
-  private void runVacuumTestSuite(
-      List<List<Object>> testCases,
-      BiConsumer<FlintDatasetMock, AsyncQueryExecutionResponse> assertion) {
-    testCases.forEach(
-        params -> {
-          FlintDatasetMock mockDS = (FlintDatasetMock) params.get(0);
-          try {
-            FlintIndexState state = (FlintIndexState) params.get(1);
-            EMRApiCall cancelJobRun = ((Pair<EMRApiCall, EMRApiCall>) params.get(2)).getLeft();
-            EMRApiCall getJobRunResult = ((Pair<EMRApiCall, EMRApiCall>) params.get(2)).getRight();
-
-            AsyncQueryExecutionResponse response =
-                runVacuumTest(mockDS, state, cancelJobRun, getJobRunResult);
-            assertion.accept(mockDS, response);
-          } finally {
-            // Clean up because we simulate parameterized test in single unit test method
-            if (flintIndexExists(mockDS.indexName)) {
-              mockDS.deleteIndex();
-            }
-            if (indexDocExists(mockDS.latestId)) {
-              deleteIndexDoc(mockDS.latestId);
-            }
-          }
-        });
-  }
-
-  private AsyncQueryExecutionResponse runVacuumTest(
-      FlintDatasetMock mockDS,
-      FlintIndexState state,
-      EMRApiCall<CancelJobRunResult> cancelJobRun,
-      EMRApiCall<GetJobRunResult> getJobRunResult) {
-    LocalEMRSClient emrsClient =
-        new LocalEMRSClient() {
-          @Override
-          public CancelJobRunResult cancelJobRun(
-              String applicationId, String jobId, boolean allowExceptionPropagation) {
-            if (cancelJobRun == DEFAULT_OP) {
-              return super.cancelJobRun(applicationId, jobId, allowExceptionPropagation);
-            }
-            return cancelJobRun.call();
-          }
-
-          @Override
-          public GetJobRunResult getJobRunResult(String applicationId, String jobId) {
-            if (getJobRunResult == DEFAULT_OP) {
-              return super.getJobRunResult(applicationId, jobId);
-            }
-            return getJobRunResult.call();
-          }
-        };
-    EMRServerlessClientFactory emrServerlessClientFactory = (accountId) -> emrsClient;
-    AsyncQueryExecutorService asyncQueryExecutorService =
-        createAsyncQueryExecutorService(emrServerlessClientFactory);
-
-    // Mock Flint index
-    mockDS.createIndex();
-
-    // Mock index state doc
-    MockFlintSparkJob flintIndexJob =
-        new MockFlintSparkJob(flintIndexStateModelService, mockDS.latestId, "mys3");
-    flintIndexJob.transition(state);
-
-    // Vacuum index
-    CreateAsyncQueryResponse response =
-        asyncQueryExecutorService.createAsyncQuery(
-            new CreateAsyncQueryRequest(mockDS.query, MYS3_DATASOURCE, LangType.SQL, null),
-            asyncQueryRequestContext);
-
-    return asyncQueryExecutorService.getAsyncQueryResults(
-        response.getQueryId(), asyncQueryRequestContext);
-  }
-
-  private boolean flintIndexExists(String flintIndexName) {
-    return client
-        .admin()
-        .indices()
-        .exists(new IndicesExistsRequest(flintIndexName))
-        .actionGet()
-        .isExists();
-  }
-
-  private boolean indexDocExists(String docId) {
-    return client
-        .get(new GetRequest(OpenSearchStateStoreUtil.getIndexName("mys3"), docId))
-        .actionGet()
-        .isExists();
-  }
-
-  private void deleteIndexDoc(String docId) {
-    client
-        .delete(new DeleteRequest(OpenSearchStateStoreUtil.getIndexName("mys3"), docId))
-        .actionGet();
-  }
-
-  private FlintDatasetMock mockDataset(String query, FlintIndexType indexType, String indexName) {
-    FlintDatasetMock dataset = new FlintDatasetMock(query, "", indexType, indexName);
-    dataset.latestId(Base64.getEncoder().encodeToString(indexName.getBytes()));
-    return dataset;
-  }
-
-  /**
-   * EMR API call mock interface.
-   *
-   * @param <V> API call response type
-   */
-  @FunctionalInterface
-  public interface EMRApiCall<V> {
-    V call();
-  }
-}
diff --git a/async-query/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexClientTest.java b/async-query/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexClientTest.java
new file mode 100644
index 0000000000..d9f2e58dba
--- /dev/null
+++ b/async-query/src/test/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexClientTest.java
@@ -0,0 +1,42 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.flint;
+
+import static org.mockito.Answers.RETURNS_DEEP_STUBS;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.InjectMocks;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.action.admin.indices.delete.DeleteIndexRequest;
+import org.opensearch.action.support.master.AcknowledgedResponse;
+import org.opensearch.client.Client;
+
+@ExtendWith(MockitoExtension.class)
+public class OpenSearchFlintIndexClientTest {
+
+  @Mock(answer = RETURNS_DEEP_STUBS)
+  private Client client;
+
+  @Mock private AcknowledgedResponse acknowledgedResponse;
+
+  @InjectMocks private OpenSearchFlintIndexClient openSearchFlintIndexClient;
+
+  @Test
+  public void testDeleteIndex() {
+    when(client.admin().indices().delete(any(DeleteIndexRequest.class)).actionGet())
+        .thenReturn(acknowledgedResponse);
+    when(acknowledgedResponse.isAcknowledged()).thenReturn(true);
+
+    openSearchFlintIndexClient.deleteIndex("test-index");
+    verify(client.admin().indices()).delete(any(DeleteIndexRequest.class));
+    verify(acknowledgedResponse).isAcknowledged();
+  }
+}
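
The production class exercised by this new test is not itself shown in the patch. For reference, the following is a minimal sketch of OpenSearchFlintIndexClient inferred from the mocked interactions verified above; the constructor shape and the handling of an unacknowledged delete are assumptions, not necessarily the shipped implementation.

    /*
     * Reference sketch only -- inferred from OpenSearchFlintIndexClientTest above.
     */
    package org.opensearch.sql.spark.flint;

    import org.opensearch.action.admin.indices.delete.DeleteIndexRequest;
    import org.opensearch.action.support.master.AcknowledgedResponse;
    import org.opensearch.client.Client;

    public class OpenSearchFlintIndexClient implements FlintIndexClient {
      private final Client client;

      public OpenSearchFlintIndexClient(Client client) {
        this.client = client;
      }

      @Override
      public void deleteIndex(String indexName) {
        // Blocking delete of the OpenSearch index backing the Flint index,
        // mirroring the client.admin().indices().delete(...).actionGet()
        // call chain stubbed in the test.
        DeleteIndexRequest request = new DeleteIndexRequest(indexName);
        AcknowledgedResponse response = client.admin().indices().delete(request).actionGet();
        if (!response.isAcknowledged()) {
          // Assumption: an unacknowledged delete is surfaced as a runtime failure.
          throw new IllegalStateException("Failed to delete index: " + indexName);
        }
      }
    }

The @InjectMocks annotation in the test implies the class exposes a constructor (or field) accepting the Client, which is why the sketch uses constructor injection.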