From 50d9747f9145ebd5117a3f733ebe29e05086d31a Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Tue, 11 Jun 2024 22:31:29 +0000
Subject: [PATCH] Handle create index with batch FlintJob (#2734)

* update grammar file

Signed-off-by: Sean Kao

* batch job for create manual refresh index

Signed-off-by: Sean Kao

* dispatcher test for index dml query

Signed-off-by: Sean Kao

* borrow lease for refresh query, not batch

Signed-off-by: Sean Kao

* spotlessApply

Signed-off-by: Sean Kao

* add release note

Signed-off-by: Sean Kao

* update comment

Signed-off-by: Sean Kao

---------

Signed-off-by: Sean Kao
(cherry picked from commit b959039bda2b2860656ffe1c698ae64c3861d6c4)
Signed-off-by: github-actions[bot]
---
 .../opensearch-sql.release-notes-2.15.0.0.md  |   5 +-
 spark/src/main/antlr/SqlBaseLexer.g4          |  44 ++++-
 spark/src/main/antlr/SqlBaseParser.g4         |  77 ++++++---
 .../spark/dispatcher/BatchQueryHandler.java   |   3 -
 .../spark/dispatcher/RefreshQueryHandler.java |   3 +
 .../dispatcher/SparkQueryDispatcher.java      |   6 +-
 .../spark/asyncquery/IndexQuerySpecTest.java  |   1 -
 .../dispatcher/SparkQueryDispatcherTest.java  | 162 +++++++++++++++++-
 8 files changed, 261 insertions(+), 40 deletions(-)

diff --git a/release-notes/opensearch-sql.release-notes-2.15.0.0.md b/release-notes/opensearch-sql.release-notes-2.15.0.0.md
index 9e06a8aa69..bde038f5e8 100644
--- a/release-notes/opensearch-sql.release-notes-2.15.0.0.md
+++ b/release-notes/opensearch-sql.release-notes-2.15.0.0.md
@@ -7,6 +7,9 @@ Compatible with OpenSearch and OpenSearch Dashboards Version 2.15.0
 * Add option to use LakeFormation in S3Glue data source ([#2624](https://github.com/opensearch-project/sql/pull/2624))
 * Remove direct ClusterState access in LocalClusterState ([#2717](https://github.com/opensearch-project/sql/pull/2717))
 
+### Bug Fixes
+* Handle create index with batch FlintJob ([#2734](https://github.com/opensearch-project/sql/pull/2734))
+
 ### Maintenance
 * Use EMR serverless bundled iceberg JAR ([#2632](https://github.com/opensearch-project/sql/pull/2632))
 * Update maintainers list ([#2663](https://github.com/opensearch-project/sql/pull/2663))
@@ -28,4 +31,4 @@ Compatible with OpenSearch and OpenSearch Dashboards Version 2.15.0
 * Abstract queryId generation ([#2695](https://github.com/opensearch-project/sql/pull/2695))
 * Introduce SessionConfigSupplier to abstract settings ([#2707](https://github.com/opensearch-project/sql/pull/2707))
 * Add accountId to data models ([#2709](https://github.com/opensearch-project/sql/pull/2709))
-* Pass down request context to data accessors ([#2715](https://github.com/opensearch-project/sql/pull/2715))
\ No newline at end of file
+* Pass down request context to data accessors ([#2715](https://github.com/opensearch-project/sql/pull/2715))
diff --git a/spark/src/main/antlr/SqlBaseLexer.g4 b/spark/src/main/antlr/SqlBaseLexer.g4
index 83e40c4a20..a9705c1733 100644
--- a/spark/src/main/antlr/SqlBaseLexer.g4
+++ b/spark/src/main/antlr/SqlBaseLexer.g4
@@ -69,6 +69,35 @@ lexer grammar SqlBaseLexer;
   public void markUnclosedComment() {
     has_unclosed_bracketed_comment = true;
   }
+
+  /**
+   * When greater than zero, the lexer is in the middle of parsing an ARRAY/MAP/STRUCT type.
+   */
+  public int complex_type_level_counter = 0;
+
+  /**
+   * Increase the counter by one when the lexer hits the keyword 'ARRAY', 'MAP', or 'STRUCT'.
+   */
+  public void incComplexTypeLevelCounter() {
+    complex_type_level_counter++;
+  }
+
+  /**
+   * Decrease the counter by one when the lexer hits the close tag '>' and the counter is
+   * greater than zero, which means we are in the middle of parsing a complex type. Otherwise,
+   * it's a dangling GT token and we do nothing.
+   */
+  public void decComplexTypeLevelCounter() {
+    if (complex_type_level_counter > 0) complex_type_level_counter--;
+  }
+
+  /**
+   * If the counter is zero, the '>>' is a shift right operator; otherwise it is the closing
+   * tags of a complex type definition, such as MAP<STRING, ARRAY<INT>>.
+   */
+  public boolean isShiftRightOperator() {
+    return complex_type_level_counter == 0;
+  }
 }
 
 SEMICOLON: ';';
@@ -100,7 +129,7 @@ ANTI: 'ANTI';
 ANY: 'ANY';
 ANY_VALUE: 'ANY_VALUE';
 ARCHIVE: 'ARCHIVE';
-ARRAY: 'ARRAY';
+ARRAY: 'ARRAY' {incComplexTypeLevelCounter();};
 AS: 'AS';
 ASC: 'ASC';
 AT: 'AT';
@@ -108,6 +137,7 @@ AUTHORIZATION: 'AUTHORIZATION';
 BETWEEN: 'BETWEEN';
 BIGINT: 'BIGINT';
 BINARY: 'BINARY';
+BINDING: 'BINDING';
 BOOLEAN: 'BOOLEAN';
 BOTH: 'BOTH';
 BUCKET: 'BUCKET';
@@ -137,6 +167,7 @@ COMMENT: 'COMMENT';
 COMMIT: 'COMMIT';
 COMPACT: 'COMPACT';
 COMPACTIONS: 'COMPACTIONS';
+COMPENSATION: 'COMPENSATION';
 COMPUTE: 'COMPUTE';
 CONCATENATE: 'CONCATENATE';
 CONSTRAINT: 'CONSTRAINT';
@@ -257,7 +288,7 @@ LOCKS: 'LOCKS';
 LOGICAL: 'LOGICAL';
 LONG: 'LONG';
 MACRO: 'MACRO';
-MAP: 'MAP';
+MAP: 'MAP' {incComplexTypeLevelCounter();};
 MATCHED: 'MATCHED';
 MERGE: 'MERGE';
 MICROSECOND: 'MICROSECOND';
@@ -298,8 +329,6 @@ OVERWRITE: 'OVERWRITE';
 PARTITION: 'PARTITION';
 PARTITIONED: 'PARTITIONED';
 PARTITIONS: 'PARTITIONS';
-PERCENTILE_CONT: 'PERCENTILE_CONT';
-PERCENTILE_DISC: 'PERCENTILE_DISC';
 PERCENTLIT: 'PERCENT';
 PIVOT: 'PIVOT';
 PLACING: 'PLACING';
@@ -362,7 +391,7 @@ STATISTICS: 'STATISTICS';
 STORED: 'STORED';
 STRATIFY: 'STRATIFY';
 STRING: 'STRING';
-STRUCT: 'STRUCT';
+STRUCT: 'STRUCT' {incComplexTypeLevelCounter();};
 SUBSTR: 'SUBSTR';
 SUBSTRING: 'SUBSTRING';
 SYNC: 'SYNC';
@@ -439,8 +468,11 @@ NEQ : '<>';
 NEQJ: '!=';
 LT : '<';
 LTE : '<=' | '!>';
-GT : '>';
+GT : '>' {decComplexTypeLevelCounter();};
 GTE : '>=' | '!<';
+SHIFT_LEFT: '<<';
+SHIFT_RIGHT: '>>' {isShiftRightOperator()}?;
+SHIFT_RIGHT_UNSIGNED: '>>>' {isShiftRightOperator()}?;
 
 PLUS: '+';
 MINUS: '-';
diff --git a/spark/src/main/antlr/SqlBaseParser.g4 b/spark/src/main/antlr/SqlBaseParser.g4
index 60b67b0802..4552c17e0c 100644
--- a/spark/src/main/antlr/SqlBaseParser.g4
+++ b/spark/src/main/antlr/SqlBaseParser.g4
@@ -77,7 +77,7 @@ statement
     | USE identifierReference                                          #use
     | USE namespace identifierReference                                #useNamespace
     | SET CATALOG (errorCapturingIdentifier | stringLit)               #setCatalog
-    | CREATE namespace (IF NOT EXISTS)? identifierReference
+    | CREATE namespace (IF errorCapturingNot EXISTS)? identifierReference
         (commentSpec |
          locationSpec |
          (WITH (DBPROPERTIES | PROPERTIES) propertyList))*             #createNamespace
@@ -92,7 +92,7 @@ statement
     | createTableHeader (LEFT_PAREN createOrReplaceTableColTypeList RIGHT_PAREN)? tableProvider?
         createTableClauses
         (AS? query)?                                                   #createTable
-    | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier
+    | CREATE TABLE (IF errorCapturingNot EXISTS)? target=tableIdentifier
         LIKE source=tableIdentifier
         (tableProvider |
          rowFormat |
@@ -141,7 +141,7 @@ statement
         SET SERDE stringLit (WITH SERDEPROPERTIES propertyList)?       #setTableSerDe
     | ALTER TABLE identifierReference (partitionSpec)?
         SET SERDEPROPERTIES propertyList                               #setTableSerDe
-    | ALTER (TABLE | VIEW) identifierReference ADD (IF NOT EXISTS)?
+    | ALTER (TABLE | VIEW) identifierReference ADD (IF errorCapturingNot EXISTS)?
         partitionSpecLocation+                                         #addTablePartition
     | ALTER TABLE identifierReference
         from=partitionSpec RENAME TO to=partitionSpec                  #renameTablePartition
@@ -153,9 +153,10 @@ statement
     | DROP TABLE (IF EXISTS)? identifierReference PURGE?               #dropTable
     | DROP VIEW (IF EXISTS)? identifierReference                       #dropView
     | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)?
-        VIEW (IF NOT EXISTS)? identifierReference
+        VIEW (IF errorCapturingNot EXISTS)? identifierReference
         identifierCommentList?
         (commentSpec |
+         schemaBinding |
          (PARTITIONED ON identifierList) |
          (TBLPROPERTIES propertyList))*
         AS query                                                       #createView
@@ -163,7 +164,8 @@ statement
         tableIdentifier (LEFT_PAREN colTypeList RIGHT_PAREN)? tableProvider
         (OPTIONS propertyList)?                                        #createTempViewUsing
     | ALTER VIEW identifierReference AS? query                         #alterViewQuery
-    | CREATE (OR REPLACE)? TEMPORARY? FUNCTION (IF NOT EXISTS)?
+    | ALTER VIEW identifierReference schemaBinding                     #alterViewSchemaBinding
+    | CREATE (OR REPLACE)? TEMPORARY? FUNCTION (IF errorCapturingNot EXISTS)?
         identifierReference AS className=stringLit
         (USING resource (COMMA resource)*)?                            #createFunction
     | DROP TEMPORARY? FUNCTION (IF EXISTS)? identifierReference        #dropFunction
@@ -224,7 +226,7 @@ statement
     | SET .*?                                                          #setConfiguration
     | RESET configKey                                                  #resetQuotedConfiguration
     | RESET .*?                                                        #resetConfiguration
-    | CREATE INDEX (IF NOT EXISTS)? identifier ON TABLE?
+    | CREATE INDEX (IF errorCapturingNot EXISTS)? identifier ON TABLE?
         identifierReference (USING indexType=identifier)?
         LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN
         (OPTIONS options=propertyList)?                                #createIndex
@@ -315,7 +317,7 @@ unsupportedHiveNativeCommands
     ;
 
 createTableHeader
-    : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? identifierReference
+    : CREATE TEMPORARY? EXTERNAL? TABLE (IF errorCapturingNot EXISTS)? identifierReference
     ;
 
 replaceTableHeader
@@ -342,6 +344,10 @@ locationSpec
     : LOCATION stringLit
     ;
 
+schemaBinding
+    : WITH SCHEMA (BINDING | COMPENSATION | EVOLUTION | TYPE EVOLUTION)
+    ;
+
 commentSpec
     : COMMENT stringLit
     ;
@@ -351,8 +357,8 @@ query
     ;
 
 insertInto
-    : INSERT OVERWRITE TABLE? identifierReference (partitionSpec (IF NOT EXISTS)?)?  ((BY NAME) | identifierList)? #insertOverwriteTable
-    | INSERT INTO TABLE? identifierReference partitionSpec? (IF NOT EXISTS)? ((BY NAME) | identifierList)? #insertIntoTable
+    : INSERT OVERWRITE TABLE? identifierReference (partitionSpec (IF errorCapturingNot EXISTS)?)?  ((BY NAME) | identifierList)? #insertOverwriteTable
+    | INSERT INTO TABLE? identifierReference partitionSpec? (IF errorCapturingNot EXISTS)? ((BY NAME) | identifierList)? #insertIntoTable
     | INSERT INTO TABLE? identifierReference REPLACE whereClause #insertIntoReplaceWhere
     | INSERT OVERWRITE LOCAL? DIRECTORY path=stringLit rowFormat? createFileFormat? #insertOverwriteHiveDir
     | INSERT OVERWRITE LOCAL? DIRECTORY (path=stringLit)? tableProvider (OPTIONS options=propertyList)? #insertOverwriteDir
@@ -389,6 +395,7 @@ describeFuncName
     | comparisonOperator
     | arithmeticOperator
     | predicateOperator
+    | shiftOperator
     | BANG
     ;
 
@@ -588,11 +595,11 @@ matchedClause
     : WHEN MATCHED (AND matchedCond=booleanExpression)? THEN matchedAction
     ;
 
 notMatchedClause
-    : WHEN NOT MATCHED (BY TARGET)? (AND notMatchedCond=booleanExpression)? THEN notMatchedAction
+    : WHEN errorCapturingNot MATCHED (BY TARGET)? (AND notMatchedCond=booleanExpression)? THEN notMatchedAction
     ;
 
 notMatchedBySourceClause
-    : WHEN NOT MATCHED BY SOURCE (AND notMatchedBySourceCond=booleanExpression)? THEN notMatchedBySourceAction
+    : WHEN errorCapturingNot MATCHED BY SOURCE (AND notMatchedBySourceCond=booleanExpression)? THEN notMatchedBySourceAction
     ;
 
 matchedAction
@@ -838,9 +845,11 @@ tableArgumentPartitioning
     : ((WITH SINGLE PARTITION)
         | ((PARTITION | DISTRIBUTE) BY
             (((LEFT_PAREN partition+=expression (COMMA partition+=expression)* RIGHT_PAREN))
+              | (expression (COMMA invalidMultiPartitionExpression=expression)+)
              | partition+=expression)))
       ((ORDER | SORT) BY
         (((LEFT_PAREN sortItem (COMMA sortItem)* RIGHT_PAREN)
+          | (sortItem (COMMA invalidMultiSortItem=sortItem)+)
          | sortItem)))?
     ;
 
@@ -956,15 +965,20 @@ booleanExpression
     ;
 
 predicate
-    : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression
-    | NOT? kind=IN LEFT_PAREN expression (COMMA expression)* RIGHT_PAREN
-    | NOT? kind=IN LEFT_PAREN query RIGHT_PAREN
-    | NOT? kind=RLIKE pattern=valueExpression
-    | NOT? kind=(LIKE | ILIKE) quantifier=(ANY | SOME | ALL) (LEFT_PAREN RIGHT_PAREN | LEFT_PAREN expression (COMMA expression)* RIGHT_PAREN)
-    | NOT? kind=(LIKE | ILIKE) pattern=valueExpression (ESCAPE escapeChar=stringLit)?
-    | IS NOT? kind=NULL
-    | IS NOT? kind=(TRUE | FALSE | UNKNOWN)
-    | IS NOT? kind=DISTINCT FROM right=valueExpression
+    : errorCapturingNot? kind=BETWEEN lower=valueExpression AND upper=valueExpression
+    | errorCapturingNot? kind=IN LEFT_PAREN expression (COMMA expression)* RIGHT_PAREN
+    | errorCapturingNot? kind=IN LEFT_PAREN query RIGHT_PAREN
+    | errorCapturingNot? kind=RLIKE pattern=valueExpression
+    | errorCapturingNot? kind=(LIKE | ILIKE) quantifier=(ANY | SOME | ALL) (LEFT_PAREN RIGHT_PAREN | LEFT_PAREN expression (COMMA expression)* RIGHT_PAREN)
+    | errorCapturingNot? kind=(LIKE | ILIKE) pattern=valueExpression (ESCAPE escapeChar=stringLit)?
+    | IS errorCapturingNot? kind=NULL
+    | IS errorCapturingNot? kind=(TRUE | FALSE | UNKNOWN)
+    | IS errorCapturingNot? kind=DISTINCT FROM right=valueExpression
+    ;
+
+errorCapturingNot
+    : NOT
+    | BANG
     ;
 
@@ -972,12 +986,19 @@ valueExpression
 valueExpression
     : primaryExpression                                                                      #valueExpressionDefault
     | operator=(MINUS | PLUS | TILDE) valueExpression                                        #arithmeticUnary
     | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary
     | left=valueExpression operator=(PLUS | MINUS | CONCAT_PIPE) right=valueExpression       #arithmeticBinary
+    | left=valueExpression shiftOperator right=valueExpression                               #shiftExpression
     | left=valueExpression operator=AMPERSAND right=valueExpression                          #arithmeticBinary
     | left=valueExpression operator=HAT right=valueExpression                                #arithmeticBinary
     | left=valueExpression operator=PIPE right=valueExpression                               #arithmeticBinary
     | left=valueExpression comparisonOperator right=valueExpression                          #comparison
     ;
 
+shiftOperator
+    : SHIFT_LEFT
+    | SHIFT_RIGHT
+    | SHIFT_RIGHT_UNSIGNED
+    ;
+
 datetimeUnit
     : YEAR | QUARTER | MONTH
     | WEEK | DAY | DAYOFYEAR
@@ -1143,7 +1164,7 @@ qualifiedColTypeWithPosition
     ;
 
 colDefinitionDescriptorWithPosition
-    : NOT NULL
+    : errorCapturingNot NULL
     | defaultExpression
     | commentSpec
     | colPosition
@@ -1162,7 +1183,7 @@ colTypeList
     ;
 
 colType
-    : colName=errorCapturingIdentifier dataType (NOT NULL)? commentSpec?
+    : colName=errorCapturingIdentifier dataType (errorCapturingNot NULL)? commentSpec?
     ;
 
 createOrReplaceTableColTypeList
@@ -1174,7 +1195,7 @@ createOrReplaceTableColType
     ;
 
 colDefinitionOption
-    : NOT NULL
+    : errorCapturingNot NULL
    | defaultExpression
     | generationExpression
     | commentSpec
@@ -1189,7 +1210,7 @@ complexColTypeList
     ;
 
 complexColType
-    : errorCapturingIdentifier COLON? dataType (NOT NULL)? commentSpec?
+    : errorCapturingIdentifier COLON? dataType (errorCapturingNot NULL)? commentSpec?
     ;
 
 whenClause
@@ -1296,7 +1317,7 @@ alterColumnAction
     : TYPE dataType
     | commentSpec
     | colPosition
-    | setOrDrop=(SET | DROP) NOT NULL
+    | setOrDrop=(SET | DROP) errorCapturingNot NULL
     | SET defaultExpression
     | dropDefault=DROP DEFAULT
     ;
@@ -1343,6 +1364,7 @@ ansiNonReserved
     | BIGINT
     | BINARY
     | BINARY_HEX
+    | BINDING
     | BOOLEAN
     | BUCKET
     | BUCKETS
@@ -1365,6 +1387,7 @@ ansiNonReserved
     | COMMIT
     | COMPACT
     | COMPACTIONS
+    | COMPENSATION
     | COMPUTE
     | CONCATENATE
     | COST
@@ -1643,6 +1666,7 @@ nonReserved
     | BIGINT
     | BINARY
     | BINARY_HEX
+    | BINDING
     | BOOLEAN
     | BOTH
     | BUCKET
@@ -1672,6 +1696,7 @@ nonReserved
     | COMMIT
     | COMPACT
     | COMPACTIONS
+    | COMPENSATION
     | COMPUTE
     | CONCATENATE
     | CONSTRAINT
@@ -1824,8 +1849,6 @@ nonReserved
     | PARTITION
     | PARTITIONED
     | PARTITIONS
-    | PERCENTILE_CONT
-    | PERCENTILE_DISC
     | PERCENTLIT
     | PIVOT
     | PLACING
diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java
index a88fe485fe..09d2dbd6c6 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java
@@ -25,7 +25,6 @@ import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
 import org.opensearch.sql.spark.dispatcher.model.JobType;
 import org.opensearch.sql.spark.leasemanager.LeaseManager;
-import org.opensearch.sql.spark.leasemanager.model.LeaseRequest;
 import org.opensearch.sql.spark.response.JobExecutionResponseReader;
 
 /**
@@ -69,8 +68,6 @@ public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
   @Override
   public DispatchQueryResponse submit(
       DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) {
-    leaseManager.borrow(new LeaseRequest(JobType.BATCH, dispatchQueryRequest.getDatasource()));
-
     String clusterName = dispatchQueryRequest.getClusterName();
     Map<String, String> tags = context.getTags();
     DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata();
diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java
index 69c21321a6..78a2651317 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java
@@ -18,6 +18,7 @@ import org.opensearch.sql.spark.flint.operation.FlintIndexOp;
 import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory;
 import org.opensearch.sql.spark.leasemanager.LeaseManager;
+import org.opensearch.sql.spark.leasemanager.model.LeaseRequest;
 import org.opensearch.sql.spark.response.JobExecutionResponseReader;
 
 /**
@@ -59,6 +60,8 @@ public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
   @Override
   public DispatchQueryResponse submit(
       DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) {
+    leaseManager.borrow(new LeaseRequest(JobType.BATCH, dispatchQueryRequest.getDatasource()));
+
     DispatchQueryResponse resp = super.submit(dispatchQueryRequest, context);
     DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata();
     return DispatchQueryResponse.builder()
diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
index 24950b5cfe..5facdee567 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
@@ -79,8 +79,12 @@ private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery(
       return queryHandlerFactory.getIndexDMLHandler();
     } else if (isEligibleForStreamingQuery(indexQueryDetails)) {
       return queryHandlerFactory.getStreamingQueryHandler();
+    } else if (IndexQueryActionType.CREATE.equals(indexQueryDetails.getIndexQueryActionType())) {
+      // Create should be handled by the batch handler. This is to avoid DROP INDEX incorrectly
+      // cancelling an interactive job.
+      return queryHandlerFactory.getBatchQueryHandler();
     } else if (IndexQueryActionType.REFRESH.equals(indexQueryDetails.getIndexQueryActionType())) {
-      // manual refresh should be handled by batch handler
+      // Manual refresh should be handled by the batch handler.
       return queryHandlerFactory.getRefreshQueryHandler();
     } else {
       return getDefaultAsyncQueryHandler();
diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java
index b4962240f5..2b6b1d2ba0 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java
@@ -864,7 +864,6 @@ public void concurrentRefreshJobLimitNotAppliedToDDL() {
         asyncQueryExecutorService.createAsyncQuery(
             new CreateAsyncQueryRequest(query, MYS3_DATASOURCE, LangType.SQL, null),
             asyncQueryRequestContext);
-    assertNotNull(asyncQueryResponse.getSessionId());
   }
 
   /** Cancel create flint index statement with auto_refresh=true, should throw exception. */
diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
index a9cfd19307..199582dde7 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
@@ -12,6 +12,7 @@ import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.times;
@@ -363,7 +364,7 @@ void testDispatchSelectQueryFailedCreateSession() {
   }
 
   @Test
-  void testDispatchIndexQuery() {
+  void testDispatchCreateAutoRefreshIndexQuery() {
     when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
     HashMap<String, String> tags = new HashMap<>();
     tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
@@ -407,6 +408,49 @@ void testDispatchIndexQuery() {
     verifyNoInteractions(flintIndexMetadataService);
   }
 
+  @Test
+  void testDispatchCreateManualRefreshIndexQuery() {
+    when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
+    HashMap<String, String> tags = new HashMap<>();
+    tags.put(DATASOURCE_TAG_KEY, "my_glue");
+    tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
+    tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
+    String query =
+        "CREATE INDEX elb_and_requestUri ON my_glue.default.http_logs(l_orderkey, l_quantity) WITH"
+            + " (auto_refresh = false)";
+    String sparkSubmitParameters =
+        constructExpectedSparkSubmitParameterString(
+            "sigv4",
+            new HashMap<>() {
+              {
+                put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1");
+              }
+            },
+            query);
+
+    StartJobRequest expected =
+        new StartJobRequest(
+            "TEST_CLUSTER:batch",
+            null,
+            EMRS_APPLICATION_ID,
+            EMRS_EXECUTION_ROLE,
+            sparkSubmitParameters,
+            tags,
+            false,
+            "query_execution_result_my_glue");
+    when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
+    DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
+    when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
+        .thenReturn(dataSourceMetadata);
+
+    DispatchQueryResponse dispatchQueryResponse =
+        sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);
+
+    verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
+    Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
+    Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
+    verifyNoInteractions(flintIndexMetadataService);
+  }
+
   @Test
   void testDispatchWithPPLQuery() {
     when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
@@ -704,6 +748,122 @@ void testDispatchDescribeIndexQuery() {
     verifyNoInteractions(flintIndexMetadataService);
   }
 
+  @Test
+  void testDispatchAlterToAutoRefreshIndexQuery() {
+    when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
+    HashMap<String, String> tags = new HashMap<>();
+    tags.put(DATASOURCE_TAG_KEY, "my_glue");
+    tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index");
+    tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
+    tags.put(JOB_TYPE_TAG_KEY, JobType.STREAMING.getText());
+    String query =
+        "ALTER INDEX elb_and_requestUri ON my_glue.default.http_logs WITH"
+            + " (auto_refresh = true)";
+    String sparkSubmitParameters =
+        withStructuredStreaming(
+            constructExpectedSparkSubmitParameterString(
+                "sigv4",
+                new HashMap<>() {
+                  {
+                    put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1");
+                  }
+                },
+                query));
+    StartJobRequest expected =
+        new StartJobRequest(
+            "TEST_CLUSTER:streaming:flint_my_glue_default_http_logs_elb_and_requesturi_index",
+            null,
+            EMRS_APPLICATION_ID,
+            EMRS_EXECUTION_ROLE,
+            sparkSubmitParameters,
+            tags,
+            true,
+            "query_execution_result_my_glue");
+    when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
+    DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
+    when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
+        .thenReturn(dataSourceMetadata);
+
+    DispatchQueryResponse dispatchQueryResponse =
+        sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);
+
+    verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
+    Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
+    Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
+    verifyNoInteractions(flintIndexMetadataService);
+  }
+
+  @Test
+  void testDispatchAlterToManualRefreshIndexQuery() {
+    QueryHandlerFactory queryHandlerFactory = mock(QueryHandlerFactory.class);
+    sparkQueryDispatcher =
+        new SparkQueryDispatcher(
+            dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider);
+
+    String query =
+        "ALTER INDEX elb_and_requestUri ON my_glue.default.http_logs WITH"
+            + " (auto_refresh = false)";
+    DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
+    when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
+        .thenReturn(dataSourceMetadata);
+    when(queryHandlerFactory.getIndexDMLHandler())
+        .thenReturn(
+            new IndexDMLHandler(
+                jobExecutionResponseReader,
+                flintIndexMetadataService,
+                indexDMLResultStorageService,
+                flintIndexOpFactory));
+
+    sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);
+    verify(queryHandlerFactory, times(1)).getIndexDMLHandler();
+  }
+
+  @Test
+  void testDispatchDropIndexQuery() {
+    QueryHandlerFactory queryHandlerFactory = mock(QueryHandlerFactory.class);
+    sparkQueryDispatcher =
+        new SparkQueryDispatcher(
+            dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider);
+
+    String query = "DROP INDEX elb_and_requestUri ON my_glue.default.http_logs";
+    DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
+    when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
+        .thenReturn(dataSourceMetadata);
+    when(queryHandlerFactory.getIndexDMLHandler())
+        .thenReturn(
+            new IndexDMLHandler(
+                jobExecutionResponseReader,
+                flintIndexMetadataService,
+                indexDMLResultStorageService,
+                flintIndexOpFactory));
+
+    sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);
+    verify(queryHandlerFactory, times(1)).getIndexDMLHandler();
+  }
+
+  @Test
+  void testDispatchVacuumIndexQuery() {
+    QueryHandlerFactory queryHandlerFactory = mock(QueryHandlerFactory.class);
+    sparkQueryDispatcher =
+        new SparkQueryDispatcher(
+            dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider);
+
+    String query = "VACUUM INDEX elb_and_requestUri ON my_glue.default.http_logs";
+    DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
+    when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
+        .thenReturn(dataSourceMetadata);
+    when(queryHandlerFactory.getIndexDMLHandler())
+        .thenReturn(
+            new IndexDMLHandler(
+                jobExecutionResponseReader,
+                flintIndexMetadataService,
+                indexDMLResultStorageService,
+                flintIndexOpFactory));
+
+    sparkQueryDispatcher.dispatch(getBaseDispatchQueryRequest(query), asyncQueryRequestContext);
+    verify(queryHandlerFactory, times(1)).getIndexDMLHandler();
+  }
+
   @Test
   void testDispatchWithWrongURI() {
     when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)