From d84304f2aadf35c112cfc35a629817d19d655b72 Mon Sep 17 00:00:00 2001
From: Tomoyuki Morita
Date: Tue, 17 Sep 2024 10:41:35 -0700
Subject: [PATCH] Implement SQL validation based on grammar element

Signed-off-by: Tomoyuki Morita
---
 .../DenyListGrammarElementValidator.java      |  22 +
 .../sql/spark/validator/GrammarElement.java   |  96 ++++
 .../validator/GrammarElementValidator.java    |  12 +
 .../GrammarElementValidatorFactory.java       |  84 +++
 .../spark/validator/SQLQueryValidator.java    | 508 ++++++++++++++++++
 .../validator/SQLQueryValidatorTest.java      | 301 +++++++++++
 6 files changed, 1023 insertions(+)
 create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/DenyListGrammarElementValidator.java
 create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java
 create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidator.java
 create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java
 create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java
 create mode 100644 async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java

diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/DenyListGrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/DenyListGrammarElementValidator.java
new file mode 100644
index 0000000000..514e2c8ad8
--- /dev/null
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/DenyListGrammarElementValidator.java
@@ -0,0 +1,22 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.validator;
+
+import java.util.Set;
+import lombok.RequiredArgsConstructor;
+
+/** {@link GrammarElementValidator} that rejects every grammar element found in a deny list. */
+@RequiredArgsConstructor
+public class DenyListGrammarElementValidator implements GrammarElementValidator {
+  /** Grammar elements that must not appear in a query. */
+  private final Set<GrammarElement> denyList;
+
+  /** Returns {@code true} unless {@code element} is contained in the deny list. */
+  @Override
+  public boolean isValid(GrammarElement element) {
+    return !denyList.contains(element);
+  }
+}
diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java
new file mode 100644
index 0000000000..562a83dcd4
--- /dev/null
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java
@@ -0,0 +1,96 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.validator;
+
+import lombok.AllArgsConstructor;
+
+/**
+ * Spark SQL grammar elements that can be individually allowed or denied when validating a query.
+ * Each constant carries a human-readable description, which is what {@link #toString()} returns.
+ */
+@AllArgsConstructor
+enum GrammarElement {
+  // DDL Statements
+  ALTER_NAMESPACE("ALTER DATABASE/TABLE/NAMESPACE"),
+  ALTER_VIEW("ALTER VIEW"),
+  CREATE_NAMESPACE("CREATE DATABASE/TABLE/NAMESPACE"),
+  CREATE_FUNCTION("CREATE FUNCTION"),
+  CREATE_VIEW("CREATE VIEW"),
+  DROP_NAMESPACE("DROP DATABASE/TABLE/NAMESPACE"),
+  DROP_FUNCTION("DROP FUNCTION"),
+  DROP_VIEW("DROP VIEW"),
+  // NOTE(review): DROP_TABLE may overlap with DROP_NAMESPACE, whose description already lists
+  // TABLE -- confirm which of the two is meant to cover DROP TABLE.
+  DROP_TABLE("DROP TABLE"),
+  REPAIR_TABLE("REPAIR TABLE"),
+  TRUNCATE_TABLE("TRUNCATE TABLE"),
+
+  // DML Statements
+  INSERT("INSERT"),
+  LOAD("LOAD"),
+
+  // Data Retrieval Statements
+  EXPLAIN("EXPLAIN"),
+  WITH("WITH"),
+  CLUSTER_BY("CLUSTER BY"),
+  DISTRIBUTE_BY("DISTRIBUTE BY"),
+  GROUP_BY("GROUP BY"),
+  HAVING("HAVING"),
+  HINTS("HINTS"),
+  INLINE_TABLE("Inline Table(VALUES)"),
+  INNER_JOIN("INNER JOIN"),
+  CROSS_JOIN("CROSS JOIN"),
+  LEFT_OUTER_JOIN("LEFT OUTER JOIN"),
+  LEFT_SEMI_JOIN("LEFT SEMI JOIN"),
+  RIGHT_OUTER_JOIN("RIGHT OUTER JOIN"),
+  FULL_OUTER_JOIN("FULL OUTER JOIN"),
+  LEFT_ANTI_JOIN("LEFT ANTI JOIN"),
+  TABLESAMPLE("TABLESAMPLE"),
+  TABLE_VALUED_FUNCTION("Table-valued function"),
+  LATERAL_VIEW("LATERAL VIEW"),
+  LATERAL_SUBQUERY("LATERAL SUBQUERY"),
+  TRANSFORM("TRANSFORM"),
+
+  // Auxiliary Statements
+  MANAGE_RESOURCE("Resource management statements"),
+  ANALYZE_TABLE("ANALYZE TABLE(S)"),
+  CACHE_TABLE("CACHE TABLE"),
+  CLEAR_CACHE("CLEAR CACHE"),
+  DESCRIBE_NAMESPACE("DESCRIBE (NAMESPACE|DATABASE|SCHEMA)"),
+  DESCRIBE_FUNCTION("DESCRIBE FUNCTION"),
+  DESCRIBE_QUERY("DESCRIBE QUERY"),
+  DESCRIBE_TABLE("DESCRIBE TABLE"),
+  REFRESH_RESOURCE("REFRESH"),
+  REFRESH_TABLE("REFRESH TABLE"),
+  REFRESH_FUNCTION("REFRESH FUNCTION"),
+  RESET("RESET"),
+  SET("SET"),
+  SHOW_COLUMNS("SHOW COLUMNS"),
+  SHOW_CREATE_TABLE("SHOW CREATE TABLE"),
+  SHOW_NAMESPACES("SHOW (DATABASES|SCHEMAS)"),
+  SHOW_FUNCTIONS("SHOW FUNCTIONS"),
+  SHOW_PARTITIONS("SHOW PARTITIONS"),
+  SHOW_TABLE_EXTENDED("SHOW TABLE EXTENDED"),
+  SHOW_TABLES("SHOW TABLES"),
+  SHOW_TBLPROPERTIES("SHOW TBLPROPERTIES"),
+  SHOW_VIEWS("SHOW VIEWS"),
+  UNCACHE_TABLE("UNCACHE TABLE"),
+
+  // Functions
+  MAP_FUNCTIONS("Map functions"),
+  CSV_FUNCTIONS("CSV functions"),
+  MISC_FUNCTIONS("Misc functions"),
+
+  SELECT("SELECT");
+
+  /** Human-readable name of the grammar element; returned by {@link #toString()}. */
+  private final String description;
+
+  @Override
+  public String toString() {
+    return description;
+  }
+}
diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidator.java
new file mode 100644
index 0000000000..b11999b5d1
--- /dev/null
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidator.java
@@ -0,0 +1,12 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.validator;
+
+/** Decides whether a {@link GrammarElement} may be used in a query. */
+public interface GrammarElementValidator {
+  /** Returns {@code true} if {@code element} is allowed, {@code false} otherwise. */
+  boolean isValid(GrammarElement element);
+}
diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java
new file mode 100644
index 0000000000..99cecf18ae
--- /dev/null
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.validator;
+
+import static org.opensearch.sql.spark.validator.GrammarElement.*;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import java.util.Map;
+import java.util.Set;
+import org.opensearch.sql.datasource.model.DataSourceType;
+
+/** Provides the {@link GrammarElementValidator} to apply for a given {@link DataSourceType}. */
+public class GrammarElementValidatorFactory {
+  /** Grammar elements denied in every deny list built below. */
+  private static final Set<GrammarElement> DEFAULT_DENY_LIST =
+      ImmutableSet.of(CREATE_FUNCTION, DROP_FUNCTION, INSERT, LOAD, HINTS, TABLESAMPLE);
+
+  // NOTE(review): not referenced by the validator map below -- presumably reserved for a future
+  // datasource type; confirm before removing.
+  private static final Set<GrammarElement> CWL_DENY_LIST =
+      copyBuilder(DEFAULT_DENY_LIST)
+          .add(
+              ALTER_NAMESPACE,
+              ALTER_VIEW,
+              CREATE_NAMESPACE,
+              CREATE_VIEW,
+              DROP_NAMESPACE,
+              DROP_VIEW,
+              REPAIR_TABLE,
+              TRUNCATE_TABLE)
+          .build();
+
+  /** Deny list for S3GLUE datasources: the default deny list plus the elements added here. */
+  private static final Set<GrammarElement> S3GLUE_DENY_LIST =
+      copyBuilder(DEFAULT_DENY_LIST)
+          .add(
+              ALTER_VIEW,
+              CREATE_VIEW,
+              DROP_VIEW,
+              REPAIR_TABLE,
+              INLINE_TABLE,
+              TRUNCATE_TABLE,
+              CLUSTER_BY,
+              DISTRIBUTE_BY,
+              CROSS_JOIN,
+              LEFT_SEMI_JOIN,
+              RIGHT_OUTER_JOIN,
+              FULL_OUTER_JOIN,
+              LEFT_ANTI_JOIN,
+              TABLE_VALUED_FUNCTION,
+              TRANSFORM,
+              MANAGE_RESOURCE,
+              DESCRIBE_FUNCTION,
+              REFRESH_RESOURCE,
+              REFRESH_FUNCTION,
+              RESET,
+              SET,
+              SHOW_FUNCTIONS,
+              SHOW_VIEWS,
+              MISC_FUNCTIONS)
+          .build();
+
+  private static final Map<DataSourceType, GrammarElementValidator> VALIDATOR_MAP =
+      ImmutableMap.of(DataSourceType.S3GLUE, new DenyListGrammarElementValidator(S3GLUE_DENY_LIST));
+
+  /**
+   * Looks up the validator registered for the given datasource type.
+   *
+   * @param dataSourceType datasource type the query targets
+   * @return the registered validator, or {@code null} when none is registered for the type
+   */
+  public GrammarElementValidator getValidatorForDatasource(DataSourceType dataSourceType) {
+    return VALIDATOR_MAP.get(dataSourceType);
+  }
+
+  /** Creates a builder pre-populated with {@code original} so more elements can be appended. */
+  private static ImmutableSet.Builder<GrammarElement> copyBuilder(Set<GrammarElement> original) {
+    return ImmutableSet.<GrammarElement>builder().addAll(original);
+  }
+}
diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java
new file mode 100644
index 0000000000..a737c62071
--- /dev/null
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java
@@ -0,0 +1,508 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.validator;
+
+import lombok.AllArgsConstructor;
+import org.antlr.v4.runtime.tree.TerminalNode;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AlterViewQueryContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AlterViewSchemaBindingContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AnalyzeContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AnalyzeTablesContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CacheTableContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ClearCacheContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ClusterBySpecContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CreateNamespaceContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CreateViewContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CtesContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeFunctionContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeNamespaceContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeQueryContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeRelationContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropFunctionContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropNamespaceContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropViewContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ExplainContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.FunctionIdentifierContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.HintContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InlineTableContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertIntoReplaceWhereContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertIntoTableContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertOverwriteDirContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertOverwriteHiveDirContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertOverwriteTableContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.JoinRelationContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.JoinTypeContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.LateralViewContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.LoadDataContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ManageResourceContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.QueryOrganizationContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RefreshFunctionContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RefreshResourceContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RefreshTableContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RelationContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RenameTableContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ResetConfigurationContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ResetQuotedConfigurationContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SampleContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SelectClauseContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetConfigurationContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetNamespaceLocationContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetNamespacePropertiesContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetQuantifierContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowColumnsContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowCreateTableContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowFunctionsContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowNamespacesContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowPartitionsContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowTableExtendedContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowTablesContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowTblPropertiesContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowViewsContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TableValuedFunctionContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TransformClauseContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.UncacheTableContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.UnsetNamespacePropertiesContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor;
+
+/**
+ * Parse-tree visitor that checks each grammar element it encounters against the configured
+ * {@link GrammarElementValidator} and throws {@link IllegalArgumentException} for the first
+ * element that is not allowed.
+ */
+@AllArgsConstructor
+public class SQLQueryValidator extends SqlBaseParserBaseVisitor<Void> {
+  private final GrammarElementValidator grammarElementValidator;
+
+  /**
+   * Validates the given statement.
+   *
+   * @param statement parsed statement to check
+   * @throws IllegalArgumentException if the statement uses a grammar element that is not allowed
+   */
+  public void validate(SqlBaseParser.SingleStatementContext statement) {
+    this.visit(statement);
+  }
+
+  @Override
+  public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) {
+    validateAllowed(GrammarElement.CREATE_FUNCTION);
+    return super.visitCreateFunction(ctx);
+  }
+
+  @Override
+  public Void visitSelectClause(SelectClauseContext ctx) {
+    validateAllowed(GrammarElement.SELECT);
+    return super.visitSelectClause(ctx);
+  }
+
+  @Override
+  public Void visitSetNamespaceProperties(SetNamespacePropertiesContext ctx) {
+    validateAllowed(GrammarElement.ALTER_NAMESPACE);
+    return super.visitSetNamespaceProperties(ctx);
+  }
+
+  @Override
+  public Void visitUnsetNamespaceProperties(UnsetNamespacePropertiesContext ctx) {
+    validateAllowed(GrammarElement.ALTER_NAMESPACE);
+    return super.visitUnsetNamespaceProperties(ctx);
+  }
+
+  @Override
+  public Void visitSetNamespaceLocation(SetNamespaceLocationContext ctx) {
+    validateAllowed(GrammarElement.ALTER_NAMESPACE);
+    return super.visitSetNamespaceLocation(ctx);
+  }
+
+  @Override
+  public Void visitAlterViewQuery(AlterViewQueryContext ctx) {
+    validateAllowed(GrammarElement.ALTER_VIEW);
+    return super.visitAlterViewQuery(ctx);
+  }
+
+  @Override
+  public Void visitAlterViewSchemaBinding(AlterViewSchemaBindingContext ctx) {
+    validateAllowed(GrammarElement.ALTER_VIEW);
+    return super.visitAlterViewSchemaBinding(ctx);
+  }
+
+  @Override
+  public Void visitRenameTable(RenameTableContext ctx) {
+    TerminalNode view = ctx.VIEW();
+    TerminalNode table = ctx.TABLE();
+    if (view != null) {
+      validateAllowed(GrammarElement.ALTER_VIEW);
+    } else if (table != null) {
+      validateAllowed(GrammarElement.ALTER_NAMESPACE);
+    }
+
+    return super.visitRenameTable(ctx);
+  }
+
+  @Override
+  public Void visitCreateNamespace(CreateNamespaceContext ctx) {
+    validateAllowed(GrammarElement.CREATE_NAMESPACE);
+    return super.visitCreateNamespace(ctx);
+  }
+
+  @Override
+  public Void visitDropNamespace(DropNamespaceContext ctx) {
+    validateAllowed(GrammarElement.DROP_NAMESPACE);
+    return super.visitDropNamespace(ctx);
+  }
+
+  @Override
+  public Void visitCreateView(CreateViewContext ctx) {
+    validateAllowed(GrammarElement.CREATE_VIEW);
+    return super.visitCreateView(ctx);
+  }
+
+  @Override
+  public Void visitDropView(DropViewContext ctx) {
+    validateAllowed(GrammarElement.DROP_VIEW);
+    return super.visitDropView(ctx);
+  }
+
+  @Override
+  public Void visitDropFunction(DropFunctionContext ctx) {
+    validateAllowed(GrammarElement.DROP_FUNCTION);
+    return super.visitDropFunction(ctx);
+  }
+
+  @Override
+  public Void visitInsertOverwriteTable(InsertOverwriteTableContext ctx) {
+    validateAllowed(GrammarElement.INSERT);
+    return super.visitInsertOverwriteTable(ctx);
+  }
+
+  @Override
+  public Void visitInsertIntoReplaceWhere(InsertIntoReplaceWhereContext ctx) {
+    validateAllowed(GrammarElement.INSERT);
+    return super.visitInsertIntoReplaceWhere(ctx);
+  }
+
+  @Override
+  public Void visitInsertIntoTable(InsertIntoTableContext ctx) {
+    validateAllowed(GrammarElement.INSERT);
+    return super.visitInsertIntoTable(ctx);
+  }
+
+  @Override
+  public Void visitInsertOverwriteDir(InsertOverwriteDirContext ctx) {
+    validateAllowed(GrammarElement.INSERT);
+    return super.visitInsertOverwriteDir(ctx);
+  }
+
+  @Override
+  public Void visitInsertOverwriteHiveDir(InsertOverwriteHiveDirContext ctx) {
+    validateAllowed(GrammarElement.INSERT);
+    return super.visitInsertOverwriteHiveDir(ctx);
+  }
+
+  @Override
+  public Void visitLoadData(LoadDataContext ctx) {
+    validateAllowed(GrammarElement.LOAD);
+    return super.visitLoadData(ctx);
+  }
+
+  @Override
+  public Void visitExplain(ExplainContext ctx) {
+    validateAllowed(GrammarElement.EXPLAIN);
+    return super.visitExplain(ctx);
+  }
+
+  @Override
+  public Void visitCtes(CtesContext ctx) {
+    validateAllowed(GrammarElement.WITH);
+    return super.visitCtes(ctx);
+  }
+
+  @Override
+  public Void visitClusterBySpec(ClusterBySpecContext ctx) {
+    validateAllowed(GrammarElement.CLUSTER_BY);
+    return super.visitClusterBySpec(ctx);
+  }
+
+  @Override
+  public Void visitQueryOrganization(QueryOrganizationContext ctx) {
+    // CLUSTER BY and DISTRIBUTE BY are checked independently so that neither check is skipped
+    // when the other clause is also present.
+    if (ctx.CLUSTER() != null) {
+      validateAllowed(GrammarElement.CLUSTER_BY);
+    }
+    if (ctx.DISTRIBUTE() != null) {
+      validateAllowed(GrammarElement.DISTRIBUTE_BY);
+    }
+    return super.visitQueryOrganization(ctx);
+  }
+
+  @Override
+  public Void visitHint(HintContext ctx) {
+    validateAllowed(GrammarElement.HINTS);
+    return super.visitHint(ctx);
+  }
+
+  @Override
+  public Void visitInlineTable(InlineTableContext ctx) {
+    validateAllowed(GrammarElement.INLINE_TABLE);
+    return super.visitInlineTable(ctx);
+  }
+
+  @Override
+  public Void visitJoinType(JoinTypeContext ctx) {
+    if (ctx.CROSS() != null) {
+      validateAllowed(GrammarElement.CROSS_JOIN);
+    } else if (ctx.SEMI() != null) {
+      // LEFT is optional in the grammar (LEFT? SEMI), so key on SEMI alone
+      validateAllowed(GrammarElement.LEFT_SEMI_JOIN);
+    } else if (ctx.ANTI() != null) {
+      validateAllowed(GrammarElement.LEFT_ANTI_JOIN);
+    } else if (ctx.LEFT() != null) {
+      validateAllowed(GrammarElement.LEFT_OUTER_JOIN);
+    } else if (ctx.RIGHT() != null) {
+      validateAllowed(GrammarElement.RIGHT_OUTER_JOIN);
+    } else if (ctx.FULL() != null) {
+      validateAllowed(GrammarElement.FULL_OUTER_JOIN);
+    } else {
+      validateAllowed(GrammarElement.INNER_JOIN);
+    }
+    return super.visitJoinType(ctx);
+  }
+
+  @Override
+  public Void visitSample(SampleContext ctx) {
+    validateAllowed(GrammarElement.TABLESAMPLE);
+    return super.visitSample(ctx);
+  }
+
+  @Override
+  public Void visitTableValuedFunction(TableValuedFunctionContext ctx) {
+    validateAllowed(GrammarElement.TABLE_VALUED_FUNCTION);
+    return super.visitTableValuedFunction(ctx);
+  }
+
+  @Override
+  public Void visitLateralView(LateralViewContext ctx) {
+    validateAllowed(GrammarElement.LATERAL_VIEW);
+    return super.visitLateralView(ctx);
+  }
+
+  @Override
+  public Void visitRelation(RelationContext ctx) {
+    if (ctx.LATERAL() != null) {
+      validateAllowed(GrammarElement.LATERAL_SUBQUERY);
+    }
+    return super.visitRelation(ctx);
+  }
+
+  @Override
+  public Void visitJoinRelation(JoinRelationContext ctx) {
+    if (ctx.LATERAL() != null) {
+      validateAllowed(GrammarElement.LATERAL_SUBQUERY);
+    }
+    return super.visitJoinRelation(ctx);
+  }
+
+  @Override
+  public Void visitTransformClause(TransformClauseContext ctx) {
+    if (ctx.TRANSFORM() != null) {
+      validateAllowed(GrammarElement.TRANSFORM);
+    }
+    return super.visitTransformClause(ctx);
+  }
+
+  @Override
+  public Void visitManageResource(ManageResourceContext ctx) {
+    validateAllowed(GrammarElement.MANAGE_RESOURCE);
+    return super.visitManageResource(ctx);
+  }
+
+  @Override
+  public Void visitAnalyze(AnalyzeContext ctx) {
+    validateAllowed(GrammarElement.ANALYZE_TABLE);
+    return super.visitAnalyze(ctx);
+  }
+
+  @Override
+  public Void visitAnalyzeTables(AnalyzeTablesContext ctx) {
+    validateAllowed(GrammarElement.ANALYZE_TABLE);
+    return super.visitAnalyzeTables(ctx);
+  }
+
+  @Override
+  public Void visitCacheTable(CacheTableContext ctx) {
+    validateAllowed(GrammarElement.CACHE_TABLE);
+    return super.visitCacheTable(ctx);
+  }
+
+  @Override
+  public Void visitClearCache(ClearCacheContext ctx) {
+    validateAllowed(GrammarElement.CLEAR_CACHE);
+    return super.visitClearCache(ctx);
+  }
+
+  @Override
+  public Void visitDescribeNamespace(DescribeNamespaceContext ctx) {
+    validateAllowed(GrammarElement.DESCRIBE_NAMESPACE);
+    return super.visitDescribeNamespace(ctx);
+  }
+
+  @Override
+  public Void visitDescribeFunction(DescribeFunctionContext ctx) {
+    validateAllowed(GrammarElement.DESCRIBE_FUNCTION);
+    return super.visitDescribeFunction(ctx);
+  }
+
+  @Override
+  public Void visitDescribeRelation(DescribeRelationContext ctx) {
+    validateAllowed(GrammarElement.DESCRIBE_TABLE);
+    return super.visitDescribeRelation(ctx);
+  }
+
+  @Override
+  public Void visitDescribeQuery(DescribeQueryContext ctx) {
+    validateAllowed(GrammarElement.DESCRIBE_QUERY);
+    return super.visitDescribeQuery(ctx);
+  }
+
+  @Override
+  public Void visitRefreshResource(RefreshResourceContext ctx) {
+    validateAllowed(GrammarElement.REFRESH_RESOURCE);
+    return super.visitRefreshResource(ctx);
+  }
+
+  @Override
+  public Void visitRefreshTable(RefreshTableContext ctx) {
+    validateAllowed(GrammarElement.REFRESH_TABLE);
+    return super.visitRefreshTable(ctx);
+  }
+
+  @Override
+  public Void visitRefreshFunction(RefreshFunctionContext ctx) {
+    validateAllowed(GrammarElement.REFRESH_FUNCTION);
+    return super.visitRefreshFunction(ctx);
+  }
+
+  @Override
+  public Void visitResetConfiguration(ResetConfigurationContext ctx) {
+    validateAllowed(GrammarElement.RESET);
+    return super.visitResetConfiguration(ctx);
+  }
+
+  @Override
+  public Void visitResetQuotedConfiguration(ResetQuotedConfigurationContext ctx) {
+    validateAllowed(GrammarElement.RESET);
+    return super.visitResetQuotedConfiguration(ctx);
+  }
+
+  @Override
+  public Void visitSetConfiguration(SetConfigurationContext ctx) {
+    validateAllowed(GrammarElement.SET);
+    return super.visitSetConfiguration(ctx);
+  }
+
+  @Override
+  public Void visitSetQuantifier(SetQuantifierContext ctx) {
+    // The set quantifier is the DISTINCT/ALL modifier of a query, not the SET statement, so it is
+    // not subject to GrammarElement.SET.
+    return super.visitSetQuantifier(ctx);
+  }
+
+  @Override
+  public Void visitShowColumns(ShowColumnsContext ctx) {
+    validateAllowed(GrammarElement.SHOW_COLUMNS);
+    return super.visitShowColumns(ctx);
+  }
+
+  @Override
+  public Void visitShowCreateTable(ShowCreateTableContext ctx) {
+    validateAllowed(GrammarElement.SHOW_CREATE_TABLE);
+    return super.visitShowCreateTable(ctx);
+  }
+
+  @Override
+  public Void visitShowNamespaces(ShowNamespacesContext ctx) {
+    validateAllowed(GrammarElement.SHOW_NAMESPACES);
+    return super.visitShowNamespaces(ctx);
+  }
+
+  @Override
+  public Void visitShowFunctions(ShowFunctionsContext ctx) {
+    validateAllowed(GrammarElement.SHOW_FUNCTIONS);
+    return super.visitShowFunctions(ctx);
+  }
+
+  @Override
+  public Void visitShowPartitions(ShowPartitionsContext ctx) {
+    validateAllowed(GrammarElement.SHOW_PARTITIONS);
+    return super.visitShowPartitions(ctx);
+  }
+
+  @Override
+  public Void visitShowTableExtended(ShowTableExtendedContext ctx) {
+    validateAllowed(GrammarElement.SHOW_TABLE_EXTENDED);
+    return super.visitShowTableExtended(ctx);
+  }
+
+  @Override
+  public Void visitShowTables(ShowTablesContext ctx) {
+    validateAllowed(GrammarElement.SHOW_TABLES);
+    return super.visitShowTables(ctx);
+  }
+
+  @Override
+  public Void visitShowTblProperties(ShowTblPropertiesContext ctx) {
+    validateAllowed(GrammarElement.SHOW_TBLPROPERTIES);
+    return super.visitShowTblProperties(ctx);
+  }
+
+  @Override
+  public Void visitShowViews(ShowViewsContext ctx) {
+    validateAllowed(GrammarElement.SHOW_VIEWS);
+    return super.visitShowViews(ctx);
+  }
+
+  @Override
+  public Void visitUncacheTable(UncacheTableContext ctx) {
+    validateAllowed(GrammarElement.UNCACHE_TABLE);
+    return super.visitUncacheTable(ctx);
+  }
+
+  @Override
+  public Void visitFunctionIdentifier(FunctionIdentifierContext ctx) {
+    String function = ctx.function.getText().toLowerCase();
+    if (isMapFunctions(function)) {
+      validateAllowed(GrammarElement.MAP_FUNCTIONS);
+    } else if (isCsvFunctions(function)) {
+      validateAllowed(GrammarElement.CSV_FUNCTIONS);
+    } else if (isMiscFunctions(function)) {
+      validateAllowed(GrammarElement.MISC_FUNCTIONS);
+    }
+    return super.visitFunctionIdentifier(ctx);
+  }
+
+  private boolean isMapFunctions(String function) {
+    // TODO: to be implemented
+    return false;
+  }
+
+  private boolean isCsvFunctions(String function) {
+    // TODO: to be implemented
+    return false;
+  }
+
+  private boolean isMiscFunctions(String function) {
+    // TODO: to be implemented
+    return false;
+  }
+
+  /** Throws {@link IllegalArgumentException} if the validator rejects {@code element}. */
+  private void validateAllowed(GrammarElement element) {
+    if (!grammarElementValidator.isValid(element)) {
+      throw new IllegalArgumentException(element + " is not allowed.");
+    }
+  }
+}
diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java
new file mode 100644
index 0000000000..85f9d0f284
--- /dev/null
+++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java
@@ -0,0 +1,301 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.validator;
+
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import lombok.AllArgsConstructor;
+import org.antlr.v4.runtime.CommonTokenStream;
+import org.junit.jupiter.api.Test;
+import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream;
+import org.opensearch.sql.datasource.model.DataSourceType;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseLexer;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SingleStatementContext;
+
+class SQLQueryValidatorTest {
+  GrammarElementValidatorFactory factory = new GrammarElementValidatorFactory();
+
+  @AllArgsConstructor
+  private enum TestQuery {
+    // DDL Statements
+    ALTER_DATABASE(
+        "ALTER DATABASE inventory SET DBPROPERTIES ('Edited-by' = 'John', 'Edit-date' ="
+            + " '01/01/2001');"),
+    ALTER_TABLE(
+        "ALTER TABLE default.StudentInfo PARTITION (age='10') RENAME TO PARTITION (age='15');"),
+    ALTER_VIEW("ALTER VIEW tempdb1.v1 RENAME TO tempdb1.v2;"),
+    CREATE_DATABASE("CREATE DATABASE IF NOT EXISTS customer_db;\n"),
+    CREATE_FUNCTION("CREATE FUNCTION simple_udf AS 'SimpleUdf' USING JAR '/tmp/SimpleUdf.jar';"),
+
CREATE_TABLE("CREATE TABLE Student_Dupli like Student;"), + CREATE_VIEW( + "CREATE OR REPLACE VIEW experienced_employee" + + " (ID COMMENT 'Unique identification number', Name)" + + " COMMENT 'View for experienced employees'" + + " AS SELECT id, name FROM all_employee" + + " WHERE working_years > 5;"), + DROP_DATABASE("DROP DATABASE inventory_db CASCADE;"), + DROP_FUNCTION("DROP FUNCTION test_avg;"), + DROP_TABLE("DROP TABLE employeetable;"), + DROP_VIEW("DROP VIEW employeeView;"), + REPAIR_TABLE("REPAIR TABLE t1;"), + TRUNCATE_TABLE("TRUNCATE TABLE Student partition(age=10);"), + + // DML Statements + INSERT_TABLE("INSERT INTO target_table SELECT * FROM source_table;"), + INSERT_OVERWRITE_DIRECTORY( + "INSERT OVERWRITE DIRECTORY '/path/to/output' SELECT * FROM source_table;"), + LOAD("LOAD DATA INPATH '/path/to/data' INTO TABLE target_table;"), + + // Data Retrieval Statements + SELECT("SELECT 1"), + EXPLAIN("EXPLAIN SELECT * FROM my_table;"), + COMMON_TABLE_EXPRESSION( + "WITH cte AS (SELECT * FROM my_table WHERE age > 30) SELECT * FROM cte;"), + CLUSTER_BY_CLAUSE("SELECT * FROM my_table CLUSTER BY age;"), + DISTRIBUTE_BY_CLAUSE("SELECT * FROM my_table DISTRIBUTE BY name;"), + GROUP_BY_CLAUSE("SELECT name, count(*) FROM my_table GROUP BY name;"), + HAVING_CLAUSE("SELECT name, count(*) FROM my_table GROUP BY name HAVING count(*) > 1;"), + HINTS("SELECT /*+ BROADCAST(my_table) */ * FROM my_table;"), + INLINE_TABLE("SELECT * FROM (VALUES (1, 'a'), (2, 'b')) AS inline_table(id, value);"), + FILE("SELECT * FROM text.`/path/to/file.txt`;"), + INNER_JOIN("SELECT t1.name, t2.age FROM table1 t1 INNER JOIN table2 t2 ON t1.id = t2.id;"), + CROSS_JOIN("SELECT t1.name, t2.age FROM table1 t1 CROSS JOIN table2 t2;"), + LEFT_OUTER_JOIN( + "SELECT t1.name, t2.age FROM table1 t1 LEFT OUTER JOIN table2 t2 ON t1.id = t2.id;"), + LEFT_SEMI_JOIN("SELECT t1.name FROM table1 t1 LEFT SEMI JOIN table2 t2 ON t1.id = t2.id;"), + RIGHT_OUTER_JOIN( + "SELECT t1.name, t2.age FROM table1 t1 
RIGHT OUTER JOIN table2 t2 ON t1.id = t2.id;"), + FULL_OUTER_JOIN( + "SELECT t1.name, t2.age FROM table1 t1 FULL OUTER JOIN table2 t2 ON t1.id = t2.id;"), + LEFT_ANTI_JOIN("SELECT t1.name FROM table1 t1 LEFT ANTI JOIN table2 t2 ON t1.id = t2.id;"), + LIKE_PREDICATE("SELECT * FROM my_table WHERE name LIKE 'A%';"), + LIMIT_CLAUSE("SELECT * FROM my_table LIMIT 10;"), + OFFSET_CLAUSE("SELECT * FROM my_table OFFSET 5 ROWS;"), + ORDER_BY_CLAUSE("SELECT * FROM my_table ORDER BY age DESC;"), + SET_OPERATORS("SELECT * FROM table1 UNION SELECT * FROM table2;"), + SORT_BY_CLAUSE("SELECT * FROM my_table SORT BY age DESC;"), + TABLESAMPLE("SELECT * FROM my_table TABLESAMPLE(10 PERCENT);"), + // TABLE_VALUED_FUNCTION("SELECT explode(array(10, 20));"), TODO: Need to handle this case + TABLE_VALUED_FUNCTION("SELECT * FROM explode(array(10, 20));"), + WHERE_CLAUSE("SELECT * FROM my_table WHERE age > 30;"), + AGGREGATE_FUNCTION("SELECT count(*) FROM my_table;"), + WINDOW_FUNCTION("SELECT name, age, rank() OVER (ORDER BY age DESC) FROM my_table;"), + CASE_CLAUSE("SELECT name, CASE WHEN age > 30 THEN 'Adult' ELSE 'Young' END FROM my_table;"), + PIVOT_CLAUSE( + "SELECT * FROM (SELECT name, age, gender FROM my_table) PIVOT (COUNT(*) FOR gender IN ('M'," + + " 'F'));"), + UNPIVOT_CLAUSE( + "SELECT name, value, category FROM (SELECT name, 'M' AS gender, age AS male_age, 0 AS" + + " female_age FROM my_table) UNPIVOT (value FOR category IN (male_age, female_age));"), + LATERAL_VIEW_CLAUSE( + "SELECT name, age, exploded_value FROM my_table LATERAL VIEW OUTER EXPLODE(split(comments," + + " ',')) exploded_table AS exploded_value;"), + LATERAL_SUBQUERY( + "SELECT name, age, (SELECT max(age) FROM my_table t2 WHERE t1.age < t2.age) AS next_age" + + " FROM my_table t1;"), + TRANSFORM_CLAUSE( + "SELECT transform(zip_code, name, age) USING 'cat' AS (a, b, c) FROM my_table;"), + + // Auxiliary Statements + ADD_FILE("ADD FILE /tmp/test.txt;"), + ADD_JAR("ADD JAR /path/to/my.jar;"), + 
ANALYZE_TABLE("ANALYZE TABLE my_table COMPUTE STATISTICS;"), + CACHE_TABLE("CACHE TABLE my_table;"), + CLEAR_CACHE("CLEAR CACHE;"), + DESCRIBE_DATABASE("DESCRIBE DATABASE my_db;"), + DESCRIBE_FUNCTION("DESCRIBE FUNCTION my_function;"), + DESCRIBE_QUERY("DESCRIBE QUERY SELECT * FROM my_table;"), + DESCRIBE_TABLE("DESCRIBE TABLE my_table;"), + LIST_FILE("LIST FILE '/path/to/files';"), + LIST_JAR("LIST JAR;"), + REFRESH("REFRESH;"), + REFRESH_TABLE("REFRESH TABLE my_table;"), + REFRESH_FUNCTION("REFRESH FUNCTION my_function;"), + RESET("RESET;"), + SET("SET spark.sql.shuffle.partitions=200;"), + SHOW_COLUMNS("SHOW COLUMNS FROM my_table;"), + SHOW_CREATE_TABLE("SHOW CREATE TABLE my_table;"), + SHOW_DATABASES("SHOW DATABASES;"), + SHOW_FUNCTIONS("SHOW FUNCTIONS;"), + SHOW_PARTITIONS("SHOW PARTITIONS my_table;"), + SHOW_TABLE_EXTENDED("SHOW TABLE EXTENDED LIKE 'my_table';"), + SHOW_TABLES("SHOW TABLES;"), + SHOW_TBLPROPERTIES("SHOW TBLPROPERTIES my_table;"), + SHOW_VIEWS("SHOW VIEWS;"), + UNCACHE_TABLE("UNCACHE TABLE my_table;"), + + // Functions + ARRAY_FUNCTIONS("SELECT array_contains(array(1, 2, 3), 2);"), + MAP_FUNCTIONS("SELECT map_keys(map('a', 1, 'b', 2));"), + DATE_AND_TIMESTAMP_FUNCTIONS("SELECT date_format(current_date(), 'yyyy-MM-dd');"), + JSON_FUNCTIONS("SELECT json_tuple('{\"a\":1, \"b\":2}', 'a', 'b');"), + MATHEMATICAL_FUNCTIONS("SELECT round(3.1415, 2);"), + STRING_FUNCTIONS("SELECT concat('Hello', ' ', 'World');"), + BITWISE_FUNCTIONS("SELECT bitwiseNOT(42);"), + CONVERSION_FUNCTIONS("SELECT cast('2023-04-01' as date);"), + CONDITIONAL_FUNCTIONS("SELECT if(1 > 0, 'true', 'false');"), + PREDICATE_FUNCTIONS("SELECT array_exists(array(1, 2, 3), x -> x > 2);"), + CSV_FUNCTIONS("SELECT csv_from_array(array('a', 'b', 'c'), ',');"), + MISC_FUNCTIONS("SELECT hash('Hello World');"), + + // Aggregate-like Functions + AGGREGATE_FUNCTIONS("SELECT count(*), max(age), min(age) FROM my_table;"), + WINDOW_FUNCTIONS("SELECT name, age, rank() OVER (ORDER BY age DESC) 
FROM my_table;"), + + // Generator Functions + GENERATOR_FUNCTIONS("SELECT explode(array(1, 2, 3));"), + + // UDFs (User-Defined Functions) + SCALAR_USER_DEFINED_FUNCTIONS("SELECT my_udf(name) FROM my_table;"), + USER_DEFINED_AGGREGATE_FUNCTIONS("SELECT my_udaf(age) FROM my_table GROUP BY name;"), + INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS("SELECT my_hive_udf(name) FROM my_table;"); + + private final String query; + + @Override + public String toString() { + return query; + } + } + + @Test + void s3glueQueries() { + SQLQueryValidator v = + new SQLQueryValidator(factory.getValidatorForDatasource(DataSourceType.S3GLUE)); + verifyValid(v, TestQuery.ALTER_DATABASE); + verifyValid(v, TestQuery.ALTER_TABLE); + verifyInvalid(v, TestQuery.ALTER_VIEW); + verifyValid(v, TestQuery.CREATE_DATABASE); + verifyInvalid(v, TestQuery.CREATE_FUNCTION); + verifyValid(v, TestQuery.CREATE_TABLE); + verifyInvalid(v, TestQuery.CREATE_VIEW); + verifyValid(v, TestQuery.DROP_DATABASE); + verifyInvalid(v, TestQuery.DROP_FUNCTION); + verifyValid(v, TestQuery.DROP_TABLE); + verifyInvalid(v, TestQuery.DROP_VIEW); + verifyValid(v, TestQuery.REPAIR_TABLE); + verifyValid(v, TestQuery.TRUNCATE_TABLE); + + // DML Statements + verifyInvalid(v, TestQuery.INSERT_TABLE); + verifyInvalid(v, TestQuery.INSERT_OVERWRITE_DIRECTORY); + verifyInvalid(v, TestQuery.LOAD); + + // Data Retrieval + verifyValid(v, TestQuery.SELECT); + verifyValid(v, TestQuery.EXPLAIN); + verifyValid(v, TestQuery.COMMON_TABLE_EXPRESSION); + verifyInvalid(v, TestQuery.CLUSTER_BY_CLAUSE); + verifyInvalid(v, TestQuery.DISTRIBUTE_BY_CLAUSE); + verifyValid(v, TestQuery.GROUP_BY_CLAUSE); + verifyValid(v, TestQuery.HAVING_CLAUSE); + verifyInvalid(v, TestQuery.HINTS); + verifyInvalid(v, TestQuery.INLINE_TABLE); + // verifyInvalid(v, TestQuery.FILE); TODO: need dive deep + verifyValid(v, TestQuery.INNER_JOIN); + verifyInvalid(v, TestQuery.CROSS_JOIN); + verifyValid(v, TestQuery.LEFT_OUTER_JOIN); + verifyInvalid(v, TestQuery.LEFT_SEMI_JOIN); + 
verifyInvalid(v, TestQuery.RIGHT_OUTER_JOIN); + verifyInvalid(v, TestQuery.FULL_OUTER_JOIN); + verifyInvalid(v, TestQuery.LEFT_ANTI_JOIN); + verifyValid(v, TestQuery.LIKE_PREDICATE); + verifyValid(v, TestQuery.LIMIT_CLAUSE); + verifyValid(v, TestQuery.OFFSET_CLAUSE); + verifyValid(v, TestQuery.ORDER_BY_CLAUSE); + verifyValid(v, TestQuery.SET_OPERATORS); + verifyValid(v, TestQuery.SORT_BY_CLAUSE); + verifyInvalid(v, TestQuery.TABLESAMPLE); + verifyInvalid(v, TestQuery.TABLE_VALUED_FUNCTION); + verifyValid(v, TestQuery.WHERE_CLAUSE); + verifyValid(v, TestQuery.AGGREGATE_FUNCTION); + verifyValid(v, TestQuery.WINDOW_FUNCTION); + verifyValid(v, TestQuery.CASE_CLAUSE); + verifyValid(v, TestQuery.PIVOT_CLAUSE); + verifyValid(v, TestQuery.UNPIVOT_CLAUSE); + verifyValid(v, TestQuery.LATERAL_VIEW_CLAUSE); + verifyValid(v, TestQuery.LATERAL_SUBQUERY); + verifyInvalid(v, TestQuery.TRANSFORM_CLAUSE); + + // Auxiliary Statements + verifyInvalid(v, TestQuery.ADD_FILE); + verifyInvalid(v, TestQuery.ADD_JAR); + verifyValid(v, TestQuery.ANALYZE_TABLE); + verifyValid(v, TestQuery.CACHE_TABLE); + verifyValid(v, TestQuery.CLEAR_CACHE); + verifyValid(v, TestQuery.DESCRIBE_DATABASE); + verifyInvalid(v, TestQuery.DESCRIBE_FUNCTION); + verifyValid(v, TestQuery.DESCRIBE_QUERY); + verifyValid(v, TestQuery.DESCRIBE_TABLE); + verifyInvalid(v, TestQuery.LIST_FILE); + verifyInvalid(v, TestQuery.LIST_JAR); + verifyInvalid(v, TestQuery.REFRESH); + // verifyValid(v, TestQuery.REFRESH_TABLE); TODO: refreshTable rule won't match (matches to + // refreshResource) + verifyInvalid(v, TestQuery.REFRESH_FUNCTION); + verifyInvalid(v, TestQuery.RESET); + verifyInvalid(v, TestQuery.SET); + verifyValid(v, TestQuery.SHOW_COLUMNS); + verifyValid(v, TestQuery.SHOW_CREATE_TABLE); + verifyValid(v, TestQuery.SHOW_DATABASES); + verifyInvalid(v, TestQuery.SHOW_FUNCTIONS); + verifyValid(v, TestQuery.SHOW_PARTITIONS); + verifyValid(v, TestQuery.SHOW_TABLE_EXTENDED); + verifyValid(v, TestQuery.SHOW_TABLES); + 
verifyValid(v, TestQuery.SHOW_TBLPROPERTIES); + verifyInvalid(v, TestQuery.SHOW_VIEWS); + verifyValid(v, TestQuery.UNCACHE_TABLE); + + // Functions + // verifyValid(v, TestQuery.ARRAY_FUNCTIONS); + // verifyValid(v, TestQuery.MAP_FUNCTIONS); + // verifyValid(v, TestQuery.DATE_AND_TIMESTAMP_FUNCTIONS); + // verifyValid(v, TestQuery.JSON_FUNCTIONS); + // verifyValid(v, TestQuery.MATHEMATICAL_FUNCTIONS); + // verifyValid(v, TestQuery.STRING_FUNCTIONS); + // verifyValid(v, TestQuery.BITWISE_FUNCTIONS); + // verifyValid(v, TestQuery.CONVERSION_FUNCTIONS); + // verifyValid(v, TestQuery.CONDITIONAL_FUNCTIONS); + // verifyValid(v, TestQuery.PREDICATE_FUNCTIONS); + // verifyValid(v, TestQuery.CSV_FUNCTIONS); + // verifyValid(v, TestQuery.MISC_FUNCTIONS); + + // Aggregate-like Functions + // verifyValid(v, TestQuery.AGGREGATE_FUNCTIONS); + // verifyValid(v, TestQuery.WINDOW_FUNCTIONS); + + // Generator Functions + // verifyValid(v, TestQuery.GENERATOR_FUNCTIONS); + + // UDFs + // verifyInvalid(v, TestQuery.SCALAR_USER_DEFINED_FUNCTIONS); + // verifyInvalid(v, TestQuery.USER_DEFINED_AGGREGATE_FUNCTIONS); + // verifyInvalid(v, TestQuery.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); + } + + void verifyValid(SQLQueryValidator validator, TestQuery query) { + runValidate(validator, query.toString()); + } + + void verifyInvalid(SQLQueryValidator validator, TestQuery query) { + assertThrows( + IllegalArgumentException.class, + () -> runValidate(validator, query.toString()), + "The query should throw: query=`" + query.toString() + "`"); + } + + void runValidate(SQLQueryValidator validator, String query) { + validator.validate(getParser(query)); + } + + SingleStatementContext getParser(String query) { + SqlBaseParser sqlBaseParser = + new SqlBaseParser( + new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(query)))); + return sqlBaseParser.singleStatement(); + } +}