From 6c91e11eb4035b9d02847fc789bfcc48d1fb1fad Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Tue, 17 Sep 2024 15:01:29 -0700 Subject: [PATCH] Add function types Signed-off-by: Tomoyuki Morita --- .../sql/spark/validator/FunctionType.java | 432 ++++++++++++++++++ .../sql/spark/validator/GrammarElement.java | 3 +- .../GrammarElementValidatorFactory.java | 47 +- .../spark/validator/SQLQueryValidator.java | 47 +- .../sql/spark/validator/FunctionTypeTest.java | 46 ++ .../validator/SQLQueryValidatorTest.java | 46 +- 6 files changed, 570 insertions(+), 51 deletions(-) create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java create mode 100644 async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java new file mode 100644 index 0000000000..0a821a7a8c --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java @@ -0,0 +1,432 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import com.google.common.collect.ImmutableMap; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.AllArgsConstructor; + +@AllArgsConstructor +public enum FunctionType { + AGGREGATE("Aggregate"), + WINDOW("Window"), + ARRAY("Array"), + MAP("Map"), + DATE_TIMESTAMP("Date and Timestamp"), + JSON("JSON"), + MATH("Math"), + STRING("String"), + CONDITIONAL("Conditional"), + BITWISE("Bitwise"), + CONVERSION("Conversion"), + PREDICATE("Predicate"), + CSV("CSV"), + MISC("Misc"), + GENERATOR("Generator"), + UDF("User Defined Function"); + + private final String name; + + private static final Map> FUNCTION_TYPE_TO_FUNCTION_NAMES_MAP = + ImmutableMap.>builder() + .put( + AGGREGATE, + Set.of( + "any", + "any_value", + "approx_count_distinct", + "approx_percentile", + "array_agg", + "avg", + "bit_and", + "bit_or", + "bit_xor", + "bitmap_construct_agg", + "bitmap_or_agg", + "bool_and", + "bool_or", + "collect_list", + "collect_set", + "corr", + "count", + "count_if", + "count_min_sketch", + "covar_pop", + "covar_samp", + "every", + "first", + "first_value", + "grouping", + "grouping_id", + "histogram_numeric", + "hll_sketch_agg", + "hll_union_agg", + "kurtosis", + "last", + "last_value", + "max", + "max_by", + "mean", + "median", + "min", + "min_by", + "mode", + "percentile", + "percentile_approx", + "regr_avgx", + "regr_avgy", + "regr_count", + "regr_intercept", + "regr_r2", + "regr_slope", + "regr_sxx", + "regr_sxy", + "regr_syy", + "skewness", + "some", + "std", + "stddev", + "stddev_pop", + "stddev_samp", + "sum", + "try_avg", + "try_sum", + "var_pop", + "var_samp", + "variance")) + .put( + WINDOW, + Set.of( + "cume_dist", + "dense_rank", + "lag", + "lead", + "nth_value", + "ntile", + "percent_rank", + "rank", + "row_number")) + .put( + ARRAY, + Set.of( + "array", + "array_append", + "array_compact", + "array_contains", + "array_distinct", + "array_except", + "array_insert", + "array_intersect", + "array_join", + "array_max", + "array_min", + "array_position", + "array_prepend", + "array_remove", + "array_repeat", + "array_union", + "arrays_overlap", + "arrays_zip", + "flatten", + "get", + "sequence", + "shuffle", + "slice", + "sort_array")) + .put( + MAP, + Set.of( + "element_at", + "map", + "map_concat", + "map_contains_key", + "map_entries", + "map_from_arrays", + "map_from_entries", + "map_keys", + "map_values", + "str_to_map", + "try_element_at")) + .put( + DATE_TIMESTAMP, + Set.of( + "add_months", + "convert_timezone", + "curdate", + "current_date", + "current_timestamp", + "current_timezone", + "date_add", + "date_diff", + "date_format", + "date_from_unix_date", + "date_part", + "date_sub", + "date_trunc", + "dateadd", + "datediff", + "datepart", + "day", + "dayofmonth", + "dayofweek", + "dayofyear", + "extract", + "from_unixtime", + "from_utc_timestamp", + "hour", + "last_day", + "localtimestamp", + "make_date", + "make_dt_interval", + "make_interval", + "make_timestamp", + "make_timestamp_ltz", + "make_timestamp_ntz", + "make_ym_interval", + "minute", + "month", + "months_between", + "next_day", + "now", + "quarter", + "second", + "session_window", + "timestamp_micros", + "timestamp_millis", + "timestamp_seconds", + "to_date", + "to_timestamp", + "to_timestamp_ltz", + "to_timestamp_ntz", + "to_unix_timestamp", + "to_utc_timestamp", + "trunc", + "try_to_timestamp", + "unix_date", + "unix_micros", + "unix_millis", + "unix_seconds", + "unix_timestamp", + "weekday", + "weekofyear", + "window", + "window_time", + "year")) + .put( + JSON, + Set.of( + "from_json", + "get_json_object", + "json_array_length", + "json_object_keys", + "json_tuple", + "schema_of_json", + "to_json")) + .put( + MATH, + Set.of( + "abs", + "acos", + "acosh", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bin", + "bround", + "cbrt", + "ceil", + "ceiling", + "conv", + "cos", + "cosh", + "cot", + "csc", + "degrees", + "e", + "exp", + "expm1", + "factorial", + "floor", + "greatest", + "hex", + "hypot", + "least", + "ln", + "log", + "log10", + "log1p", + "log2", + "negative", + "pi", + "pmod", + "positive", + "pow", + "power", + "radians", + "rand", + "randn", + "random", + "rint", + "round", + "sec", + "shiftleft", + "sign", + "signum", + "sin", + "sinh", + "sqrt", + "tan", + "tanh", + "try_add", + "try_divide", + "try_multiply", + "try_subtract", + "unhex", + "width_bucket")) + .put( + STRING, + Set.of( + "ascii", + "base64", + "bit_length", + "btrim", + "char", + "char_length", + "character_length", + "chr", + "concat", + "concat_ws", + "contains", + "decode", + "elt", + "encode", + "endswith", + "find_in_set", + "format_number", + "format_string", + "initcap", + "instr", + "lcase", + "left", + "len", + "length", + "levenshtein", + "locate", + "lower", + "lpad", + "ltrim", + "luhn_check", + "mask", + "octet_length", + "overlay", + "position", + "printf", + "regexp_count", + "regexp_extract", + "regexp_extract_all", + "regexp_instr", + "regexp_replace", + "regexp_substr", + "repeat", + "replace", + "right", + "rpad", + "rtrim", + "sentences", + "soundex", + "space", + "split", + "split_part", + "startswith", + "substr", + "substring", + "substring_index", + "to_binary", + "to_char", + "to_number", + "to_varchar", + "translate", + "trim", + "try_to_binary", + "try_to_number", + "ucase", + "unbase64", + "upper")) + .put(CONDITIONAL, Set.of("coalesce", "if", "ifnull", "nanvl", "nullif", "nvl", "nvl2")) + .put( + BITWISE, Set.of("bit_count", "bit_get", "getbit", "shiftright", "shiftrightunsigned")) + .put( + CONVERSION, + Set.of( + "bigint", + "binary", + "boolean", + "cast", + "date", + "decimal", + "double", + "float", + "int", + "smallint", + "string", + "timestamp", + "tinyint")) + .put(PREDICATE, Set.of("isnan", "isnotnull", "isnull", "regexp", "regexp_like", "rlike")) + .put(CSV, Set.of("from_csv","schema_of_csv","to_csv")) + .put( + MISC, + Set.of( + "aes_decrypt", + "aes_encrypt", + "assert_true", + "bitmap_bit_position", + "bitmap_bucket_number", + "bitmap_count", + "current_catalog", + "current_database", + "current_schema", + "current_user", + "equal_null", + "hll_sketch_estimate", + "hll_union", + "input_file_block_length", + "input_file_block_start", + "input_file_name", + "java_method", + "monotonically_increasing_id", + "reflect", + "spark_partition_id", + "try_aes_decrypt", + "typeof", + "user", + "uuid", + "version")) + .put( + GENERATOR, + Set.of( + "explode", + "explode_outer", + "inline", + "inline_outer", + "posexplode", + "posexplode_outer", + "stack")) + .build(); + + private static final Map FUNCTION_NAME_TO_FUNCTION_TYPE_MAP = + FUNCTION_TYPE_TO_FUNCTION_NAMES_MAP.entrySet().stream() + .flatMap( + entry -> entry.getValue().stream().map(value -> Map.entry(value, entry.getKey()))) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + public static FunctionType fromFunctionName(String functionName) { + return FUNCTION_NAME_TO_FUNCTION_TYPE_MAP.getOrDefault(functionName.toLowerCase(), UDF); + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java index 562a83dcd4..05d878e9dc 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java @@ -76,7 +76,8 @@ enum GrammarElement { CSV_FUNCTIONS("CSV functions"), MISC_FUNCTIONS("Misc functions"), - SELECT("SELECT"); + // UDF + UDF("User Defined functions"); String description; diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java index 99cecf18ae..57374b852c 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorFactory.java @@ -27,7 +27,49 @@ public class GrammarElementValidatorFactory { DROP_NAMESPACE, DROP_VIEW, REPAIR_TABLE, - TRUNCATE_TABLE) + TRUNCATE_TABLE, + EXPLAIN, + WITH, + CLUSTER_BY, + DISTRIBUTE_BY, + HINTS, + INLINE_TABLE, + CROSS_JOIN, + LEFT_SEMI_JOIN, + RIGHT_OUTER_JOIN, + FULL_OUTER_JOIN, + LEFT_ANTI_JOIN, + TABLESAMPLE, + TABLE_VALUED_FUNCTION, + LATERAL_VIEW, + LATERAL_SUBQUERY, + TRANSFORM, + MANAGE_RESOURCE, + ANALYZE_TABLE, + CACHE_TABLE, + DESCRIBE_NAMESPACE, + DESCRIBE_FUNCTION, + DESCRIBE_QUERY, + DESCRIBE_TABLE, + REFRESH_RESOURCE, + REFRESH_TABLE, + REFRESH_FUNCTION, + RESET, + SET, + SHOW_COLUMNS, + SHOW_CREATE_TABLE, + SHOW_NAMESPACES, + SHOW_FUNCTIONS, + SHOW_PARTITIONS, + SHOW_TABLE_EXTENDED, + SHOW_TABLES, + SHOW_TBLPROPERTIES, + SHOW_VIEWS, + UNCACHE_TABLE, + CSV_FUNCTIONS, + MISC_FUNCTIONS, + UDF + ) .build(); private static final Set S3GLUE_DENY_LIST = @@ -58,7 +100,8 @@ public class GrammarElementValidatorFactory { SET, SHOW_FUNCTIONS, SHOW_VIEWS, - MISC_FUNCTIONS) + MISC_FUNCTIONS, + UDF) .build(); private static Map validatorMap = diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java index a737c62071..fba973f930 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java @@ -27,6 +27,7 @@ import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropViewContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ExplainContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.FunctionIdentifierContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.FunctionNameContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.HintContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InlineTableContext; import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertIntoReplaceWhereContext; @@ -82,12 +83,6 @@ public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) { return super.visitCreateFunction(ctx); } - @Override - public Void visitSelectClause(SelectClauseContext ctx) { - validateAllowed(GrammarElement.SELECT); - return super.visitSelectClause(ctx); - } - @Override public Void visitSetNamespaceProperties(SetNamespacePropertiesContext ctx) { validateAllowed(GrammarElement.ALTER_NAMESPACE); @@ -457,30 +452,32 @@ public Void visitUncacheTable(UncacheTableContext ctx) { @Override public Void visitFunctionIdentifier(FunctionIdentifierContext ctx) { - String function = ctx.function.getText().toLowerCase(); - if (isMapFunctions(function)) { - validateAllowed(GrammarElement.MAP_FUNCTIONS); - } else if (isCsvFunctions(function)) { - validateAllowed(GrammarElement.CSV_FUNCTIONS); - } else if (isMiscFunctions(function)) { - validateAllowed(GrammarElement.MISC_FUNCTIONS); - } + validateFunctionAllowed(ctx.function.getText()); return super.visitFunctionIdentifier(ctx); } - private boolean isMapFunctions(String function) { - // TODO: to be implemented - return false; - } - - private boolean isCsvFunctions(String function) { - // TODO: to be implemented - return false; + @Override + public Void visitFunctionName(FunctionNameContext ctx) { + validateFunctionAllowed(ctx.qualifiedName().getText()); + return super.visitFunctionName(ctx); } - private boolean isMiscFunctions(String function) { - // TODO: to be implemented - return false; + private void validateFunctionAllowed(String function) { + FunctionType type = FunctionType.fromFunctionName(function.toLowerCase()); + switch(type) { + case MAP: + validateAllowed(GrammarElement.MAP_FUNCTIONS); + break; + case CSV: + validateAllowed(GrammarElement.CSV_FUNCTIONS); + break; + case MISC: + validateAllowed(GrammarElement.MISC_FUNCTIONS); + break; + case UDF: + validateAllowed(GrammarElement.UDF); + break; + } } private void validateAllowed(GrammarElement element) { diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java new file mode 100644 index 0000000000..920d35df2f --- /dev/null +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +class FunctionTypeTest { + @Test + public void test() { + assertEquals(FunctionType.AGGREGATE, FunctionType.fromFunctionName("any")); + assertEquals(FunctionType.AGGREGATE, FunctionType.fromFunctionName("variance")); + assertEquals(FunctionType.WINDOW, FunctionType.fromFunctionName("cume_dist")); + assertEquals(FunctionType.WINDOW, FunctionType.fromFunctionName("row_number")); + assertEquals(FunctionType.ARRAY, FunctionType.fromFunctionName("array")); + assertEquals(FunctionType.ARRAY, FunctionType.fromFunctionName("sort_array")); + assertEquals(FunctionType.MAP, FunctionType.fromFunctionName("element_at")); + assertEquals(FunctionType.MAP, FunctionType.fromFunctionName("try_element_at")); + assertEquals(FunctionType.DATE_TIMESTAMP, FunctionType.fromFunctionName("add_months")); + assertEquals(FunctionType.DATE_TIMESTAMP, FunctionType.fromFunctionName("year")); + assertEquals(FunctionType.JSON, FunctionType.fromFunctionName("from_json")); + assertEquals(FunctionType.JSON, FunctionType.fromFunctionName("to_json")); + assertEquals(FunctionType.MATH, FunctionType.fromFunctionName("abs")); + assertEquals(FunctionType.MATH, FunctionType.fromFunctionName("width_bucket")); + assertEquals(FunctionType.STRING, FunctionType.fromFunctionName("ascii")); + assertEquals(FunctionType.STRING, FunctionType.fromFunctionName("upper")); + assertEquals(FunctionType.CONDITIONAL, FunctionType.fromFunctionName("coalesce")); + assertEquals(FunctionType.CONDITIONAL, FunctionType.fromFunctionName("nvl2")); + assertEquals(FunctionType.BITWISE, FunctionType.fromFunctionName("bit_count")); + assertEquals(FunctionType.BITWISE, FunctionType.fromFunctionName("shiftrightunsigned")); + assertEquals(FunctionType.CONVERSION, FunctionType.fromFunctionName("bigint")); + assertEquals(FunctionType.CONVERSION, FunctionType.fromFunctionName("tinyint")); + assertEquals(FunctionType.PREDICATE, FunctionType.fromFunctionName("isnan")); + assertEquals(FunctionType.PREDICATE, FunctionType.fromFunctionName("rlike")); + assertEquals(FunctionType.CSV, FunctionType.fromFunctionName("from_csv")); + assertEquals(FunctionType.CSV, FunctionType.fromFunctionName("to_csv")); + assertEquals(FunctionType.MISC, FunctionType.fromFunctionName("aes_decrypt")); + assertEquals(FunctionType.MISC, FunctionType.fromFunctionName("version")); + assertEquals(FunctionType.GENERATOR, FunctionType.fromFunctionName("explode")); + assertEquals(FunctionType.GENERATOR, FunctionType.fromFunctionName("stack")); + } +} diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java index 85f9d0f284..88fe273aec 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java @@ -134,13 +134,13 @@ private enum TestQuery { DATE_AND_TIMESTAMP_FUNCTIONS("SELECT date_format(current_date(), 'yyyy-MM-dd');"), JSON_FUNCTIONS("SELECT json_tuple('{\"a\":1, \"b\":2}', 'a', 'b');"), MATHEMATICAL_FUNCTIONS("SELECT round(3.1415, 2);"), - STRING_FUNCTIONS("SELECT concat('Hello', ' ', 'World');"), - BITWISE_FUNCTIONS("SELECT bitwiseNOT(42);"), + STRING_FUNCTIONS("SELECT map_concat('Hello', ' ', 'World');"), + BITWISE_FUNCTIONS("SELECT bit_count(42);"), CONVERSION_FUNCTIONS("SELECT cast('2023-04-01' as date);"), CONDITIONAL_FUNCTIONS("SELECT if(1 > 0, 'true', 'false');"), - PREDICATE_FUNCTIONS("SELECT array_exists(array(1, 2, 3), x -> x > 2);"), - CSV_FUNCTIONS("SELECT csv_from_array(array('a', 'b', 'c'), ',');"), - MISC_FUNCTIONS("SELECT hash('Hello World');"), + PREDICATE_FUNCTIONS("SELECT isnotnull(1);"), + CSV_FUNCTIONS("SELECT from_csv(array('a', 'b', 'c'), ',');"), + MISC_FUNCTIONS("SELECT current_user();"), // Aggregate-like Functions AGGREGATE_FUNCTIONS("SELECT count(*), max(age), min(age) FROM my_table;"), @@ -251,30 +251,30 @@ void s3glueQueries() { verifyValid(v, TestQuery.UNCACHE_TABLE); // Functions - // verifyValid(v, TestQuery.ARRAY_FUNCTIONS); - // verifyValid(v, TestQuery.MAP_FUNCTIONS); - // verifyValid(v, TestQuery.DATE_AND_TIMESTAMP_FUNCTIONS); - // verifyValid(v, TestQuery.JSON_FUNCTIONS); - // verifyValid(v, TestQuery.MATHEMATICAL_FUNCTIONS); - // verifyValid(v, TestQuery.STRING_FUNCTIONS); - // verifyValid(v, TestQuery.BITWISE_FUNCTIONS); - // verifyValid(v, TestQuery.CONVERSION_FUNCTIONS); - // verifyValid(v, TestQuery.CONDITIONAL_FUNCTIONS); - // verifyValid(v, TestQuery.PREDICATE_FUNCTIONS); - // verifyValid(v, TestQuery.CSV_FUNCTIONS); - // verifyValid(v, TestQuery.MISC_FUNCTIONS); + verifyValid(v, TestQuery.ARRAY_FUNCTIONS); + verifyValid(v, TestQuery.MAP_FUNCTIONS); + verifyValid(v, TestQuery.DATE_AND_TIMESTAMP_FUNCTIONS); + verifyValid(v, TestQuery.JSON_FUNCTIONS); + verifyValid(v, TestQuery.MATHEMATICAL_FUNCTIONS); + verifyValid(v, TestQuery.STRING_FUNCTIONS); + verifyValid(v, TestQuery.BITWISE_FUNCTIONS); + verifyValid(v, TestQuery.CONVERSION_FUNCTIONS); + verifyValid(v, TestQuery.CONDITIONAL_FUNCTIONS); + verifyValid(v, TestQuery.PREDICATE_FUNCTIONS); + verifyValid(v, TestQuery.CSV_FUNCTIONS); + verifyInvalid(v, TestQuery.MISC_FUNCTIONS); // Aggregate-like Functions - // verifyValid(v, TestQuery.AGGREGATE_FUNCTIONS); - // verifyValid(v, TestQuery.WINDOW_FUNCTIONS); + verifyValid(v, TestQuery.AGGREGATE_FUNCTIONS); + verifyValid(v, TestQuery.WINDOW_FUNCTIONS); // Generator Functions - // verifyValid(v, TestQuery.GENERATOR_FUNCTIONS); + verifyValid(v, TestQuery.GENERATOR_FUNCTIONS); // UDFs - // verifyInvalid(v, TestQuery.SCALAR_USER_DEFINED_FUNCTIONS); - // verifyInvalid(v, TestQuery.USER_DEFINED_AGGREGATE_FUNCTIONS); - // verifyInvalid(v, TestQuery.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); + verifyInvalid(v, TestQuery.SCALAR_USER_DEFINED_FUNCTIONS); + verifyInvalid(v, TestQuery.USER_DEFINED_AGGREGATE_FUNCTIONS); + verifyInvalid(v, TestQuery.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); } void verifyValid(SQLQueryValidator validator, TestQuery query) {