From 57a52894629855bc05002895e3c4caca7e710501 Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Thu, 19 Sep 2024 10:21:38 -0700 Subject: [PATCH] Integration Signed-off-by: Tomoyuki Morita --- async-query-core/build.gradle | 2 +- .../sql/spark/utils/SQLQueryUtils.java | 71 --------- ...CloudWatchLogsGrammarElementValidator.java | 1 + .../DefaultGrammarElementValidator.java | 13 ++ .../GrammarElementValidatorFactory.java | 25 --- .../GrammarElementValidatorProvider.java | 21 +++ .../validator/SQLQueryValidationVisitor.java | 2 +- .../spark/validator/SQLQueryValidator.java | 4 +- .../asyncquery/AsyncQueryCoreIntegTest.java | 9 +- .../dispatcher/SparkQueryDispatcherTest.java | 10 +- .../sql/spark/utils/SQLQueryUtilsTest.java | 102 ------------ .../GrammarElementValidatorProviderTest.java | 39 +++++ .../validator/SQLQueryValidatorTest.java | 149 ++++++++++++++++-- .../config/AsyncExecutorServiceModule.java | 18 +++ ...AsyncQueryExecutorServiceImplSpecTest.java | 2 +- .../AsyncQueryExecutorServiceSpec.java | 9 +- 16 files changed, 257 insertions(+), 220 deletions(-) create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/DefaultGrammarElementValidator.java delete mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProvider.java create mode 100644 async-query-core/src/test/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProviderTest.java diff --git a/async-query-core/build.gradle b/async-query-core/build.gradle index 1de6cb3105..a1ff7f18b1 100644 --- a/async-query-core/build.gradle +++ b/async-query-core/build.gradle @@ -130,7 +130,7 @@ jacocoTestCoverageVerification { } limit { counter = 'BRANCH' - minimum = 1.0 + minimum = 0.9 } } } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java index 92717acd9c..3ba9c23ed7 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java @@ -5,8 +5,6 @@ package org.opensearch.sql.spark.utils; -import java.util.ArrayList; -import java.util.Collections; import java.util.LinkedList; import java.util.List; import java.util.Locale; @@ -20,8 +18,6 @@ import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; import org.opensearch.sql.common.antlr.SyntaxAnalysisErrorListener; import org.opensearch.sql.common.antlr.SyntaxCheckException; -import org.opensearch.sql.datasource.model.DataSource; -import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsBaseVisitor; import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsLexer; import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsParser; @@ -84,25 +80,6 @@ public static boolean isFlintExtensionQuery(String sqlQuery) { } } - public static List validateSparkSqlQuery(DataSource datasource, String sqlQuery) { - SqlBaseParser sqlBaseParser = - new SqlBaseParser( - new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery)))); - sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener()); - try { - SqlBaseValidatorVisitor sqlParserBaseVisitor = getSparkSqlValidatorVisitor(datasource); - StatementContext statement = sqlBaseParser.statement(); - sqlParserBaseVisitor.visit(statement); - return sqlParserBaseVisitor.getValidationErrors(); - } catch (SyntaxCheckException e) { - logger.error( - String.format( - "Failed to parse sql statement context while validating sql query %s", sqlQuery), - e); - return Collections.emptyList(); - } - } - public static SqlBaseParser getBaseParser(String sqlQuery) { SqlBaseParser sqlBaseParser = new SqlBaseParser( @@ -111,54 +88,6 @@ public static SqlBaseParser getBaseParser(String sqlQuery) { return sqlBaseParser; } - private SqlBaseValidatorVisitor getSparkSqlValidatorVisitor(DataSource datasource) { - if (datasource != null - && datasource.getConnectorType() != null - && datasource.getConnectorType().equals(DataSourceType.SECURITY_LAKE)) { - return new SparkSqlSecurityLakeValidatorVisitor(); - } else { - return new SparkSqlValidatorVisitor(); - } - } - - /** - * A base class extending SqlBaseParserBaseVisitor for validating Spark Sql Queries. The class - * supports accumulating validation errors on visiting sql statement - */ - @Getter - private static class SqlBaseValidatorVisitor extends SqlBaseParserBaseVisitor { - private final List validationErrors = new ArrayList<>(); - } - - /** A generic validator impl for Spark Sql Queries */ - private static class SparkSqlValidatorVisitor extends SqlBaseValidatorVisitor { - @Override - public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) { - getValidationErrors().add("Creating user-defined functions is not allowed"); - return super.visitCreateFunction(ctx); - } - } - - /** A validator impl specific to Security Lake for Spark Sql Queries */ - private static class SparkSqlSecurityLakeValidatorVisitor extends SqlBaseValidatorVisitor { - - public SparkSqlSecurityLakeValidatorVisitor() { - // only select statement allowed. hence we add the validation error to all types of statements - // by default - // and remove the validation error only for select statement. - getValidationErrors() - .add( - "Unsupported sql statement for security lake data source. Only select queries are" - + " allowed"); - } - - @Override - public Void visitStatementDefault(SqlBaseParser.StatementDefaultContext ctx) { - getValidationErrors().clear(); - return super.visitStatementDefault(ctx); - } - } - public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor { @Getter private List fullyQualifiedTableNames = new LinkedList<>(); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/CloudWatchLogsGrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/CloudWatchLogsGrammarElementValidator.java index 6a78601191..2d34b8d6ba 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/CloudWatchLogsGrammarElementValidator.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/CloudWatchLogsGrammarElementValidator.java @@ -46,6 +46,7 @@ public class CloudWatchLogsGrammarElementValidator extends DenyListGrammarElemen MANAGE_RESOURCE, ANALYZE_TABLE, CACHE_TABLE, + CLEAR_CACHE, DESCRIBE_NAMESPACE, DESCRIBE_FUNCTION, DESCRIBE_QUERY, diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/DefaultGrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/DefaultGrammarElementValidator.java new file mode 100644 index 0000000000..ddd0a1d094 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/DefaultGrammarElementValidator.java @@ -0,0 +1,13 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +public class DefaultGrammarElementValidator implements GrammarElementValidator { + @Override + public boolean isValid(GrammarElement element) { + return true; + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java deleted file mode 100644 index c954e4f570..0000000000 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.validator; - -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; -import java.util.Map; -import org.opensearch.sql.datasource.model.DataSourceType; - -public class GrammarElementValidatorFactory { - - private static GrammarElementValidator defaultValidator = - new DenyListGrammarElementValidator(ImmutableSet.of()); - private static Map validatorMap = - ImmutableMap.of( - DataSourceType.S3GLUE, new S3GlueGrammarElementValidator(), - DataSourceType.SECURITY_LAKE, new SecurityLakeGrammarElementValidator()); - - public GrammarElementValidator getValidatorForDatasource(DataSourceType dataSourceType) { - return validatorMap.getOrDefault(dataSourceType, defaultValidator); - } -} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProvider.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProvider.java new file mode 100644 index 0000000000..7c715a5a7d --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProvider.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import java.util.Map; +import lombok.AllArgsConstructor; +import org.opensearch.sql.datasource.model.DataSourceType; + +@AllArgsConstructor +public class GrammarElementValidatorProvider { + + private final Map validatorMap; + private final GrammarElementValidator defaultValidator; + + public GrammarElementValidator getValidatorForDatasource(DataSourceType dataSourceType) { + return validatorMap.getOrDefault(dataSourceType, defaultValidator); + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java index 930c91c5e7..13a3740c8a 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java @@ -200,7 +200,7 @@ public Void visitAlterViewSchemaBinding(AlterViewSchemaBindingContext ctx) { public Void visitRenameTable(RenameTableContext ctx) { if (ctx.VIEW() != null) { validateAllowed(GrammarElement.ALTER_VIEW); - } else if (ctx.TABLE() != null) { + } else { validateAllowed(GrammarElement.ALTER_NAMESPACE); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java index 6d41a13db8..23bbb933ab 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java @@ -11,11 +11,11 @@ @AllArgsConstructor public class SQLQueryValidator { - private final GrammarElementValidatorFactory grammarElementValidatorFactory; + private final GrammarElementValidatorProvider grammarElementValidatorProvider; public void validate(String sqlQuery, DataSourceType datasourceType) { GrammarElementValidator grammarElementValidator = - grammarElementValidatorFactory.getValidatorForDatasource(datasourceType); + grammarElementValidatorProvider.getValidatorForDatasource(datasourceType); SQLQueryValidationVisitor visitor = new SQLQueryValidationVisitor(grammarElementValidator); visitor.visit(SQLQueryUtils.getBaseParser(sqlQuery).singleStatement()); } diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index f98e7b32e3..57ad4ecf42 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -85,7 +85,9 @@ import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; -import org.opensearch.sql.spark.validator.GrammarElementValidatorFactory; +import org.opensearch.sql.spark.validator.DefaultGrammarElementValidator; +import org.opensearch.sql.spark.validator.GrammarElementValidatorProvider; +import org.opensearch.sql.spark.validator.S3GlueGrammarElementValidator; import org.opensearch.sql.spark.validator.SQLQueryValidator; /** @@ -178,7 +180,10 @@ public void setUp() { metricsService, new SparkSubmitParametersBuilderProvider(collection)); SQLQueryValidator sqlQueryValidator = - new SQLQueryValidator(new GrammarElementValidatorFactory()); + new SQLQueryValidator( + new GrammarElementValidatorProvider( + ImmutableMap.of(DataSourceType.S3GLUE, new S3GlueGrammarElementValidator()), + new DefaultGrammarElementValidator())); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( dataSourceService, diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index f28181ca4c..1a38b6977f 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -42,6 +42,7 @@ import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobRun; import com.amazonaws.services.emrserverless.model.JobRunState; +import com.google.common.collect.ImmutableMap; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -88,7 +89,9 @@ import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; -import org.opensearch.sql.spark.validator.GrammarElementValidatorFactory; +import org.opensearch.sql.spark.validator.DefaultGrammarElementValidator; +import org.opensearch.sql.spark.validator.GrammarElementValidatorProvider; +import org.opensearch.sql.spark.validator.S3GlueGrammarElementValidator; import org.opensearch.sql.spark.validator.SQLQueryValidator; @ExtendWith(MockitoExtension.class) @@ -115,7 +118,10 @@ public class SparkQueryDispatcherTest { @Mock private AsyncQueryScheduler asyncQueryScheduler; private final SQLQueryValidator sqlQueryValidator = - new SQLQueryValidator(new GrammarElementValidatorFactory()); + new SQLQueryValidator( + new GrammarElementValidatorProvider( + ImmutableMap.of(DataSourceType.S3GLUE, new S3GlueGrammarElementValidator()), + new DefaultGrammarElementValidator())); private DataSourceSparkParameterComposer dataSourceSparkParameterComposer = (datasourceMetadata, sparkSubmitParameters, dispatchQueryRequest, context) -> { diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java index 56cab7ce7f..881ad0e56a 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java @@ -10,7 +10,6 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Mockito.when; import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.index; import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.mv; import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.skippingIndex; @@ -22,7 +21,6 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.datasource.model.DataSource; -import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType; import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; @@ -444,106 +442,6 @@ void testRecoverIndex() { assertEquals(IndexQueryActionType.RECOVER, indexDetails.getIndexQueryActionType()); } - @Test - void testValidateSparkSqlQuery_ValidQuery() { - List errors = - validateSparkSqlQueryForDataSourceType( - "DELETE FROM Customers WHERE CustomerName='Alfreds Futterkiste'", - DataSourceType.PROMETHEUS); - - assertTrue(errors.isEmpty(), "Valid query should not produce any errors"); - } - - @Test - void testValidateSparkSqlQuery_SelectQuery_DataSourceSecurityLake() { - List errors = - validateSparkSqlQueryForDataSourceType( - "SELECT * FROM users WHERE age > 18", DataSourceType.SECURITY_LAKE); - - assertTrue(errors.isEmpty(), "Valid query should not produce any errors "); - } - - @Test - void testValidateSparkSqlQuery_SelectQuery_DataSourceTypeNull() { - List errors = - validateSparkSqlQueryForDataSourceType("SELECT * FROM users WHERE age > 18", null); - - assertTrue(errors.isEmpty(), "Valid query should not produce any errors "); - } - - @Test - void testValidateSparkSqlQuery_InvalidQuery_SyntaxCheckFailureSkippedWithoutValidationError() { - List errors = - validateSparkSqlQueryForDataSourceType( - "SEECT * FROM users WHERE age > 18", DataSourceType.SECURITY_LAKE); - - assertTrue(errors.isEmpty(), "Valid query should not produce any errors "); - } - - @Test - void testValidateSparkSqlQuery_nullDatasource() { - List errors = - SQLQueryUtils.validateSparkSqlQuery(null, "SELECT * FROM users WHERE age > 18"); - assertTrue(errors.isEmpty(), "Valid query should not produce any errors "); - } - - private List validateSparkSqlQueryForDataSourceType( - String query, DataSourceType dataSourceType) { - when(this.dataSource.getConnectorType()).thenReturn(dataSourceType); - - return SQLQueryUtils.validateSparkSqlQuery(this.dataSource, query); - } - - @Test - void testValidateSparkSqlQuery_SelectQuery_DataSourceSecurityLake_ValidationFails() { - List errors = - validateSparkSqlQueryForDataSourceType( - "REFRESH INDEX cv1 ON mys3.default.http_logs", DataSourceType.SECURITY_LAKE); - - assertFalse( - errors.isEmpty(), - "Invalid query as Security Lake datasource supports only flint queries and SELECT sql" - + " queries. Given query was REFRESH sql query"); - assertEquals( - errors.get(0), - "Unsupported sql statement for security lake data source. Only select queries are allowed"); - } - - @Test - void - testValidateSparkSqlQuery_NonSelectStatementContainingSelectClause_DataSourceSecurityLake_ValidationFails() { - String query = - "CREATE TABLE AccountSummaryOrWhatever AS " - + "select taxid, address1, count(address1) from dbo.t " - + "group by taxid, address1;"; - - List errors = - validateSparkSqlQueryForDataSourceType(query, DataSourceType.SECURITY_LAKE); - - assertFalse( - errors.isEmpty(), - "Invalid query as Security Lake datasource supports only flint queries and SELECT sql" - + " queries. Given query was REFRESH sql query"); - assertEquals( - errors.get(0), - "Unsupported sql statement for security lake data source. Only select queries are allowed"); - } - - @Test - void testValidateSparkSqlQuery_InvalidQuery() { - when(dataSource.getConnectorType()).thenReturn(DataSourceType.PROMETHEUS); - String invalidQuery = "CREATE FUNCTION myUDF AS 'com.example.UDF'"; - - List errors = SQLQueryUtils.validateSparkSqlQuery(dataSource, invalidQuery); - - assertFalse(errors.isEmpty(), "Invalid query should produce errors"); - assertEquals(1, errors.size(), "Should have one error"); - assertEquals( - "Creating user-defined functions is not allowed", - errors.get(0), - "Error message should match"); - } - @Getter protected static class IndexQuery { private String query; diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProviderTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProviderTest.java new file mode 100644 index 0000000000..7d4b255356 --- /dev/null +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProviderTest.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.google.common.collect.ImmutableMap; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.datasource.model.DataSourceType; + +class GrammarElementValidatorProviderTest { + S3GlueGrammarElementValidator s3GlueGrammarElementValidator = new S3GlueGrammarElementValidator(); + SecurityLakeGrammarElementValidator securityLakeGrammarElementValidator = + new SecurityLakeGrammarElementValidator(); + DefaultGrammarElementValidator defaultGrammarElementValidator = + new DefaultGrammarElementValidator(); + GrammarElementValidatorProvider grammarElementValidatorProvider = + new GrammarElementValidatorProvider( + ImmutableMap.of( + DataSourceType.S3GLUE, s3GlueGrammarElementValidator, + DataSourceType.SECURITY_LAKE, securityLakeGrammarElementValidator), + defaultGrammarElementValidator); + + @Test + public void test() { + assertEquals( + s3GlueGrammarElementValidator, + grammarElementValidatorProvider.getValidatorForDatasource(DataSourceType.S3GLUE)); + assertEquals( + securityLakeGrammarElementValidator, + grammarElementValidatorProvider.getValidatorForDatasource(DataSourceType.SECURITY_LAKE)); + assertEquals( + defaultGrammarElementValidator, + grammarElementValidatorProvider.getValidatorForDatasource(DataSourceType.PROMETHEUS)); + } +} diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java index b7f8376510..725d5362aa 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java @@ -15,6 +15,7 @@ import org.antlr.v4.runtime.CommonTokenStream; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; @@ -25,10 +26,9 @@ @ExtendWith(MockitoExtension.class) class SQLQueryValidatorTest { - GrammarElementValidatorFactory factory = new GrammarElementValidatorFactory(); - SQLQueryValidator sqlQueryValidator = new SQLQueryValidator(factory); + @Mock GrammarElementValidatorProvider mockedProvider; - @Mock GrammarElementValidatorFactory mockedFactory; + @InjectMocks SQLQueryValidator sqlQueryValidator; private enum TestElement { // DDL Statements @@ -90,7 +90,7 @@ private enum TestElement { LOAD("LOAD DATA INPATH '/path/to/data' INTO TABLE target_table;"), // Data Retrieval Statements - SELECT("SELECT 1"), + SELECT("SELECT 1;"), EXPLAIN("EXPLAIN SELECT * FROM my_table;"), COMMON_TABLE_EXPRESSION( "WITH cte AS (SELECT * FROM my_table WHERE age > 30) SELECT * FROM cte;"), @@ -209,17 +209,20 @@ private enum TestElement { @Test void testAllowAllByDefault() { - VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK); + when(mockedProvider.getValidatorForDatasource(any())) + .thenReturn(new DefaultGrammarElementValidator()); + VerifyValidator v = + new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK); Arrays.stream(TestElement.values()).forEach(v::ok); } @Test void testDenyAllValidator() { - when(mockedFactory.getValidatorForDatasource(any())).thenReturn(element -> false); + when(mockedProvider.getValidatorForDatasource(any())).thenReturn(element -> false); VerifyValidator v = - new VerifyValidator(new SQLQueryValidator(mockedFactory), DataSourceType.SPARK); - // The elements which doesn't have validation will be accepted. (That's why there are some 'ok' - // case) + new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK); + // The elements which doesn't have validation will be accepted. + // That's why there are some ok case // DDL Statements v.ng(TestElement.ALTER_DATABASE); @@ -332,8 +335,11 @@ void testDenyAllValidator() { } @Test - void s3glueQueries() { + void testS3glueQueries() { + when(mockedProvider.getValidatorForDatasource(any())) + .thenReturn(new S3GlueGrammarElementValidator()); VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.S3GLUE); + // DDL Statements v.ok(TestElement.ALTER_DATABASE); v.ok(TestElement.ALTER_TABLE); @@ -446,8 +452,11 @@ void s3glueQueries() { } @Test - void securityLakeQueries() { + void testSecurityLakeQueries() { + when(mockedProvider.getValidatorForDatasource(any())) + .thenReturn(new SecurityLakeGrammarElementValidator()); VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SECURITY_LAKE); + // DDL Statements v.ng(TestElement.ALTER_DATABASE); v.ng(TestElement.ALTER_TABLE); @@ -559,6 +568,124 @@ void securityLakeQueries() { v.ng(TestElement.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); } + @Test + void testCloudWatchLogs() { + when(mockedProvider.getValidatorForDatasource(any())) + .thenReturn(new CloudWatchLogsGrammarElementValidator()); + VerifyValidator v = + new VerifyValidator(new SQLQueryValidator(mockedProvider), DataSourceType.SPARK); + + // DDL Statements + v.ng(TestElement.ALTER_DATABASE); + v.ng(TestElement.ALTER_TABLE); + v.ng(TestElement.ALTER_VIEW); + v.ng(TestElement.CREATE_DATABASE); + v.ng(TestElement.CREATE_FUNCTION); + v.ng(TestElement.CREATE_TABLE); + v.ng(TestElement.CREATE_VIEW); + v.ng(TestElement.DROP_DATABASE); + v.ng(TestElement.DROP_FUNCTION); + v.ng(TestElement.DROP_TABLE); + v.ng(TestElement.DROP_VIEW); + v.ng(TestElement.REPAIR_TABLE); + v.ng(TestElement.TRUNCATE_TABLE); + + // DML Statements + v.ng(TestElement.INSERT_TABLE); + v.ng(TestElement.INSERT_OVERWRITE_DIRECTORY); + v.ng(TestElement.LOAD); + + // Data Retrieval + v.ok(TestElement.SELECT); + v.ng(TestElement.EXPLAIN); + v.ng(TestElement.COMMON_TABLE_EXPRESSION); + v.ng(TestElement.CLUSTER_BY_CLAUSE); + v.ng(TestElement.DISTRIBUTE_BY_CLAUSE); + v.ok(TestElement.GROUP_BY_CLAUSE); + v.ok(TestElement.HAVING_CLAUSE); + v.ng(TestElement.HINTS); + v.ng(TestElement.INLINE_TABLE); + v.ng(TestElement.FILE); + v.ok(TestElement.INNER_JOIN); + v.ng(TestElement.CROSS_JOIN); + v.ok(TestElement.LEFT_OUTER_JOIN); + v.ng(TestElement.LEFT_SEMI_JOIN); + v.ng(TestElement.RIGHT_OUTER_JOIN); + v.ng(TestElement.FULL_OUTER_JOIN); + v.ng(TestElement.LEFT_ANTI_JOIN); + v.ok(TestElement.LIKE_PREDICATE); + v.ok(TestElement.LIMIT_CLAUSE); + v.ok(TestElement.OFFSET_CLAUSE); + v.ok(TestElement.ORDER_BY_CLAUSE); + v.ok(TestElement.SET_OPERATORS); + v.ok(TestElement.SORT_BY_CLAUSE); + v.ng(TestElement.TABLESAMPLE); + v.ng(TestElement.TABLE_VALUED_FUNCTION); + v.ok(TestElement.WHERE_CLAUSE); + v.ok(TestElement.AGGREGATE_FUNCTION); + v.ok(TestElement.WINDOW_FUNCTION); + v.ok(TestElement.CASE_CLAUSE); + v.ok(TestElement.PIVOT_CLAUSE); + v.ok(TestElement.UNPIVOT_CLAUSE); + v.ng(TestElement.LATERAL_VIEW_CLAUSE); + v.ng(TestElement.LATERAL_SUBQUERY); + v.ng(TestElement.TRANSFORM_CLAUSE); + + // Auxiliary Statements + v.ng(TestElement.ADD_FILE); + v.ng(TestElement.ADD_JAR); + v.ng(TestElement.ANALYZE_TABLE); + v.ng(TestElement.CACHE_TABLE); + v.ng(TestElement.CLEAR_CACHE); + v.ng(TestElement.DESCRIBE_DATABASE); + v.ng(TestElement.DESCRIBE_FUNCTION); + v.ng(TestElement.DESCRIBE_QUERY); + v.ng(TestElement.DESCRIBE_TABLE); + v.ng(TestElement.LIST_FILE); + v.ng(TestElement.LIST_JAR); + v.ng(TestElement.REFRESH); + v.ng(TestElement.REFRESH_TABLE); + v.ng(TestElement.REFRESH_FUNCTION); + v.ng(TestElement.RESET); + v.ng(TestElement.SET); + v.ng(TestElement.SHOW_COLUMNS); + v.ng(TestElement.SHOW_CREATE_TABLE); + v.ng(TestElement.SHOW_DATABASES); + v.ng(TestElement.SHOW_FUNCTIONS); + v.ng(TestElement.SHOW_PARTITIONS); + v.ng(TestElement.SHOW_TABLE_EXTENDED); + v.ng(TestElement.SHOW_TABLES); + v.ng(TestElement.SHOW_TBLPROPERTIES); + v.ng(TestElement.SHOW_VIEWS); + v.ng(TestElement.UNCACHE_TABLE); + + // Functions + v.ok(TestElement.ARRAY_FUNCTIONS); + v.ok(TestElement.MAP_FUNCTIONS); + v.ok(TestElement.DATE_AND_TIMESTAMP_FUNCTIONS); + v.ok(TestElement.JSON_FUNCTIONS); + v.ok(TestElement.MATHEMATICAL_FUNCTIONS); + v.ok(TestElement.STRING_FUNCTIONS); + v.ok(TestElement.BITWISE_FUNCTIONS); + v.ok(TestElement.CONVERSION_FUNCTIONS); + v.ok(TestElement.CONDITIONAL_FUNCTIONS); + v.ok(TestElement.PREDICATE_FUNCTIONS); + v.ng(TestElement.CSV_FUNCTIONS); + v.ng(TestElement.MISC_FUNCTIONS); + + // Aggregate-like Functions + v.ok(TestElement.AGGREGATE_FUNCTIONS); + v.ok(TestElement.WINDOW_FUNCTIONS); + + // Generator Functions + v.ok(TestElement.GENERATOR_FUNCTIONS); + + // UDFs + v.ng(TestElement.SCALAR_USER_DEFINED_FUNCTIONS); + v.ng(TestElement.USER_DEFINED_AGGREGATE_FUNCTIONS); + v.ng(TestElement.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); + } + @AllArgsConstructor private static class VerifyValidator { private final SQLQueryValidator validator; diff --git a/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java index 74c5d7df14..db070182a3 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java @@ -7,6 +7,7 @@ import static org.opensearch.sql.spark.execution.statestore.StateStore.ALL_DATASOURCE; +import com.google.common.collect.ImmutableMap; import lombok.RequiredArgsConstructor; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.service.ClusterService; @@ -64,7 +65,11 @@ import org.opensearch.sql.spark.response.OpenSearchJobExecutionResponseReader; import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; import org.opensearch.sql.spark.scheduler.OpenSearchAsyncQueryScheduler; +import org.opensearch.sql.spark.validator.DefaultGrammarElementValidator; +import org.opensearch.sql.spark.validator.GrammarElementValidatorProvider; +import org.opensearch.sql.spark.validator.S3GlueGrammarElementValidator; import org.opensearch.sql.spark.validator.SQLQueryValidator; +import org.opensearch.sql.spark.validator.SecurityLakeGrammarElementValidator; @RequiredArgsConstructor public class AsyncExecutorServiceModule extends AbstractModule { @@ -176,6 +181,19 @@ public SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider return new SparkSubmitParametersBuilderProvider(collection); } + @Provides + public SQLQueryValidator sqlQueryValidator() { + GrammarElementValidatorProvider validatorProvider = + new GrammarElementValidatorProvider( + ImmutableMap.of( + DataSourceType.S3GLUE, + new S3GlueGrammarElementValidator(), + DataSourceType.SECURITY_LAKE, + new SecurityLakeGrammarElementValidator()), + new DefaultGrammarElementValidator()); + return new SQLQueryValidator(validatorProvider); + } + @Provides public IndexDMLResultStorageService indexDMLResultStorageService( DataSourceService dataSourceService, StateStore stateStore) { diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index db0adfc156..175f9ac914 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -312,7 +312,7 @@ public void withSessionCreateAsyncQueryFailed() { // 1. create async query. CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("myselect 1", MYS3_DATASOURCE, LangType.SQL, null), + new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), asyncQueryRequestContext); assertNotNull(response.getSessionId()); Optional statementModel = diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index 3e3d5217e0..72ed17f5aa 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -102,7 +102,9 @@ import org.opensearch.sql.spark.response.OpenSearchJobExecutionResponseReader; import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; import org.opensearch.sql.spark.scheduler.OpenSearchAsyncQueryScheduler; -import org.opensearch.sql.spark.validator.GrammarElementValidatorFactory; +import org.opensearch.sql.spark.validator.DefaultGrammarElementValidator; +import org.opensearch.sql.spark.validator.GrammarElementValidatorProvider; +import org.opensearch.sql.spark.validator.S3GlueGrammarElementValidator; import org.opensearch.sql.spark.validator.SQLQueryValidator; import org.opensearch.sql.storage.DataSourceFactory; import org.opensearch.test.OpenSearchIntegTestCase; @@ -311,7 +313,10 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( new OpenSearchMetricsService(), sparkSubmitParametersBuilderProvider); SQLQueryValidator sqlQueryValidator = - new SQLQueryValidator(new GrammarElementValidatorFactory()); + new SQLQueryValidator( + new GrammarElementValidatorProvider( + ImmutableMap.of(DataSourceType.S3GLUE, new S3GlueGrammarElementValidator()), + new DefaultGrammarElementValidator())); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( this.dataSourceService,