From 2f2b690025666fdf52d097afd75cc162e70f3da9 Mon Sep 17 00:00:00 2001 From: Costin Leau Date: Thu, 21 Jan 2021 15:02:02 +0200 Subject: [PATCH] QL: Refactor FunctionRegistry to make it pluggable (#67761) Break the FunctionRegistry monolith into a common QL base and a SQL specific registry that handles aspects such as distinct and extract. In the process clean-up the names and signature of internal interfaces. Most of the semantics were preserved however the error messages were slightly tweaked to make them more readable - this shouldn't be a problem as they are being used internally mainly in test assertions. --- .../function/EqlFunctionRegistry.java | 4 +- .../function/FunctionDefinition.java | 26 +- .../expression/function/FunctionRegistry.java | 454 ++++++------------ .../function/FunctionResolutionStrategy.java | 2 +- .../function/FunctionRegistryTests.java | 20 +- .../function/SqlFunctionDefinition.java | 41 ++ .../function/SqlFunctionRegistry.java | 224 +++++++-- .../function/SqlFunctionResolution.java | 12 +- .../function/SqlFunctionRegistryTests.java | 46 +- 9 files changed, 448 insertions(+), 381 deletions(-) create mode 100644 x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/SqlFunctionDefinition.java diff --git a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/expression/function/EqlFunctionRegistry.java b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/expression/function/EqlFunctionRegistry.java index 04c9f73213625..af8ac645d0ca0 100644 --- a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/expression/function/EqlFunctionRegistry.java +++ b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/expression/function/EqlFunctionRegistry.java @@ -31,10 +31,10 @@ public class EqlFunctionRegistry extends FunctionRegistry { public EqlFunctionRegistry() { - super(functions()); + register(functions()); } - private static FunctionDefinition[][] functions() { + private FunctionDefinition[][] functions() { return new FunctionDefinition[][] { // Scalar functions // String diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/function/FunctionDefinition.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/function/FunctionDefinition.java index f6edccc10e5ce..29b1a5d8025df 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/function/FunctionDefinition.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/function/FunctionDefinition.java @@ -14,10 +14,14 @@ public class FunctionDefinition { /** * Converts an {@link UnresolvedFunction} into the a proper {@link Function}. + *

+ * Provides the basic signature (unresolved function + runtime configuration object) while + * allowing extensions through the vararg extras which subclasses should expand for their + * own purposes. */ @FunctionalInterface public interface Builder { - Function build(UnresolvedFunction uf, boolean distinct, Configuration configuration); + Function build(UnresolvedFunction uf, Configuration configuration, Object... extras); } private final String name; @@ -25,20 +29,11 @@ public interface Builder { private final Class clazz; private final Builder builder; - /** - * Is this a datetime function compatible with {@code EXTRACT}. - */ - // TODO: needs refactoring so that specific function properties (per language) are isolated from QL - private final boolean extractViable; - - - protected FunctionDefinition(String name, List aliases, Class clazz, boolean dateTime, Builder builder) { + protected FunctionDefinition(String name, List aliases, Class clazz, Builder builder) { this.name = name; this.aliases = aliases; this.clazz = clazz; this.builder = builder; - - this.extractViable = dateTime; } public String name() { @@ -53,17 +48,10 @@ public Class clazz() { return clazz; } - public Builder builder() { + protected Builder builder() { return builder; } - /** - * Is this a datetime function compatible with {@code EXTRACT}. - */ - public boolean extractViable() { - return extractViable; - } - @Override public String toString() { return format(null, "{}({})", name, aliases.isEmpty() ? "" : aliases.size() == 1 ? aliases.get(0) : aliases); diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/function/FunctionRegistry.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/function/FunctionRegistry.java index 13c0c45173030..cf6f7e4fb9de9 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/function/FunctionRegistry.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/function/FunctionRegistry.java @@ -6,15 +6,14 @@ package org.elasticsearch.xpack.ql.expression.function; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.xpack.ql.ParsingException; import org.elasticsearch.xpack.ql.QlIllegalArgumentException; import org.elasticsearch.xpack.ql.expression.Expression; import org.elasticsearch.xpack.ql.session.Configuration; import org.elasticsearch.xpack.ql.tree.Source; -import org.elasticsearch.xpack.ql.type.DataType; import org.elasticsearch.xpack.ql.util.Check; -import java.time.ZoneId; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; @@ -59,6 +58,10 @@ public FunctionRegistry(FunctionDefinition... functions) { } public FunctionRegistry(FunctionDefinition[]... groupFunctions) { + register(groupFunctions); + } + + protected void register(FunctionDefinition[]... groupFunctions) { for (FunctionDefinition[] group : groupFunctions) { register(group); } @@ -73,7 +76,7 @@ protected void register(FunctionDefinition... functions) { Object old = batchMap.put(alias, f); if (old != null || defs.containsKey(alias)) { throw new QlIllegalArgumentException("alias [" + alias + "] is used by " - + "[" + (old != null ? old : defs.get(alias).name()) + "] and [" + f.name() + "]"); + + "[" + (old != null ? old : defs.get(alias).name()) + "] and [" + f.name() + "]"); } aliases.put(alias, f.name()); } @@ -89,9 +92,7 @@ FunctionDefinition, LinkedHashMap> toMap(Map.Entry:: public FunctionDefinition resolveFunction(String functionName) { FunctionDefinition def = defs.get(functionName); if (def == null) { - throw new QlIllegalArgumentException( - "Cannot find function {}; this should have been caught during analysis", - functionName); + throw new QlIllegalArgumentException("Cannot find function {}; this should have been caught during analysis", functionName); } return def; } @@ -124,304 +125,133 @@ public Collection listFunctions(String pattern) { } protected FunctionDefinition cloneDefinition(String name, FunctionDefinition definition) { - return new FunctionDefinition(name, emptyList(), definition.clazz(), definition.extractViable(), definition.builder()); + return new FunctionDefinition(name, emptyList(), definition.clazz(), definition.builder()); } - /** - * Build a {@linkplain FunctionDefinition} for a no-argument function that - * is not aware of time zone and does not support {@code DISTINCT}. - */ - public static FunctionDefinition def(Class function, - java.util.function.Function ctorRef, String... names) { - FunctionBuilder builder = (source, children, distinct, cfg) -> { - if (false == children.isEmpty()) { - throw new QlIllegalArgumentException("expects no arguments"); - } - if (distinct) { - throw new QlIllegalArgumentException("does not support DISTINCT yet it was specified"); - } - return ctorRef.apply(source); - }; - return def(function, builder, false, names); + protected interface FunctionBuilder { + Function build(Source source, List children, Configuration cfg); } /** - * Build a {@linkplain FunctionDefinition} for a no-argument function that - * is not aware of time zone, does not support {@code DISTINCT} and needs - * the cluster name (DATABASE()) or the user name (USER()). + * Main method to register a function. + * + * @param names Must always have at least one entry which is the method's primary name */ @SuppressWarnings("overloads") - protected static FunctionDefinition def(Class function, - ConfigurationAwareFunctionBuilder ctorRef, String... names) { - FunctionBuilder builder = (source, children, distinct, cfg) -> { - if (false == children.isEmpty()) { - throw new QlIllegalArgumentException("expects no arguments"); + protected static FunctionDefinition def(Class function, FunctionBuilder builder, String... names) { + Check.isTrue(names.length > 0, "At least one name must be provided for the function"); + String primaryName = names[0]; + List aliases = Arrays.asList(names).subList(1, names.length); + FunctionDefinition.Builder realBuilder = (uf, cfg, extras) -> { + if (CollectionUtils.isEmpty(extras) == false) { + throw new ParsingException(uf.source(), "Unused parameters {} detected when building [{}]", + Arrays.toString(extras), + primaryName); } - if (distinct) { - throw new QlIllegalArgumentException("does not support DISTINCT yet it was specified"); + try { + return builder.build(uf.source(), uf.children(), cfg); + } catch (QlIllegalArgumentException e) { + throw new ParsingException(e, uf.source(), "error building [{}]: {}", primaryName, e.getMessage()); } - return ctorRef.build(source, cfg); }; - return def(function, builder, false, names); - } - - protected interface ConfigurationAwareFunctionBuilder { - T build(Source source, Configuration configuration); + return new FunctionDefinition(primaryName, unmodifiableList(aliases), function, realBuilder); } /** - * Build a {@linkplain FunctionDefinition} for a one-argument function that - * is not aware of time zone, does not support {@code DISTINCT} and needs - * the configuration object. - */ - @SuppressWarnings("overloads") + * Build a {@linkplain FunctionDefinition} for a no-argument function. + */ protected static FunctionDefinition def(Class function, - UnaryConfigurationAwareFunctionBuilder ctorRef, String... names) { - FunctionBuilder builder = (source, children, distinct, cfg) -> { - if (children.size() > 1) { - throw new QlIllegalArgumentException("expects exactly one argument"); - } - if (distinct) { - throw new QlIllegalArgumentException("does not support DISTINCT yet it was specified"); + java.util.function.Function ctorRef, + String... names) { + FunctionBuilder builder = (source, children, cfg) -> { + if (false == children.isEmpty()) { + throw new QlIllegalArgumentException("expects no arguments"); } - Expression ex = children.size() == 1 ? children.get(0) : null; - return ctorRef.build(source, ex, cfg); + return ctorRef.apply(source); }; - return def(function, builder, false, names); - } - - protected interface UnaryConfigurationAwareFunctionBuilder { - T build(Source source, Expression exp, Configuration configuration); + return def(function, builder, names); } - /** - * Build a {@linkplain FunctionDefinition} for a unary function that is not - * aware of time zone and does not support {@code DISTINCT}. + * Build a {@linkplain FunctionDefinition} for a unary function. */ @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do - public static FunctionDefinition def(Class function, - BiFunction ctorRef, String... names) { - FunctionBuilder builder = (source, children, distinct, cfg) -> { + protected static FunctionDefinition def(Class function, + BiFunction ctorRef, + String... names) { + FunctionBuilder builder = (source, children, cfg) -> { if (children.size() != 1) { throw new QlIllegalArgumentException("expects exactly one argument"); } - if (distinct) { - throw new QlIllegalArgumentException("does not support DISTINCT yet it was specified"); - } return ctorRef.apply(source, children.get(0)); }; - return def(function, builder, false, names); + return def(function, builder, names); } /** - * Build a {@linkplain FunctionDefinition} for multi-arg function that - * is not aware of time zone and does not support {@code DISTINCT}. + * Build a {@linkplain FunctionDefinition} for multi-arg/n-ary function. */ @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do - public static FunctionDefinition def(Class function, - MultiFunctionBuilder ctorRef, String... names) { - FunctionBuilder builder = (source, children, distinct, cfg) -> { - if (distinct) { - throw new QlIllegalArgumentException("does not support DISTINCT yet it was specified"); - } + protected FunctionDefinition def(Class function, NaryBuilder ctorRef, String... names) { + FunctionBuilder builder = (source, children, cfg) -> { return ctorRef.build(source, children); }; - return def(function, builder, false, names); + return def(function, builder, names); } - protected interface MultiFunctionBuilder { + protected interface NaryBuilder { T build(Source source, List children); } /** - * Build a {@linkplain FunctionDefinition} for a unary function that is not - * aware of time zone but does support {@code DISTINCT}. - */ - @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do - public static FunctionDefinition def(Class function, - DistinctAwareUnaryFunctionBuilder ctorRef, String... names) { - FunctionBuilder builder = (source, children, distinct, cfg) -> { - if (children.size() != 1) { - throw new QlIllegalArgumentException("expects exactly one argument"); - } - return ctorRef.build(source, children.get(0), distinct); - }; - return def(function, builder, false, names); - } - - public interface DistinctAwareUnaryFunctionBuilder { - T build(Source source, Expression target, boolean distinct); - } - - /** - * Build a {@linkplain FunctionDefinition} for a unary function that - * operates on a datetime. - */ - @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do - public static FunctionDefinition def(Class function, - DatetimeUnaryFunctionBuilder ctorRef, String... names) { - FunctionBuilder builder = (source, children, distinct, cfg) -> { - if (children.size() != 1) { - throw new QlIllegalArgumentException("expects exactly one argument"); - } - if (distinct) { - throw new QlIllegalArgumentException("does not support DISTINCT yet it was specified"); - } - return ctorRef.build(source, children.get(0), cfg.zoneId()); - }; - return def(function, builder, true, names); - } - - public interface DatetimeUnaryFunctionBuilder { - T build(Source source, Expression target, ZoneId zi); - } - - /** - * Build a {@linkplain FunctionDefinition} for a binary function that - * requires a timezone. - */ - @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do - public static FunctionDefinition def(Class function, DatetimeBinaryFunctionBuilder ctorRef, - String... names) { - FunctionBuilder builder = (source, children, distinct, cfg) -> { - if (children.size() != 2) { - throw new QlIllegalArgumentException("expects exactly two arguments"); - } - if (distinct) { - throw new QlIllegalArgumentException("does not support DISTINCT yet it was specified"); - } - return ctorRef.build(source, children.get(0), children.get(1), cfg.zoneId()); - }; - return def(function, builder, false, names); - } - - protected interface DatetimeBinaryFunctionBuilder { - T build(Source source, Expression lhs, Expression rhs, ZoneId zi); - } - - /** - * Build a {@linkplain FunctionDefinition} for a three-args function that - * requires a timezone. - */ - @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do - public static FunctionDefinition def(Class function, DatetimeThreeArgsFunctionBuilder ctorRef, - String... names) { - FunctionBuilder builder = (source, children, distinct, cfg) -> { - if (children.size() != 3) { - throw new QlIllegalArgumentException("expects three arguments"); - } - if (distinct) { - throw new QlIllegalArgumentException("does not support DISTINCT yet it was specified"); - } - return ctorRef.build(source, children.get(0), children.get(1), children.get(2), cfg.zoneId()); - }; - return def(function, builder, false, names); - } - - protected interface DatetimeThreeArgsFunctionBuilder { - T build(Source source, Expression first, Expression second, Expression third, ZoneId zi); - } - - /** - * Build a {@linkplain FunctionDefinition} for a binary function that is - * not aware of time zone and does not support {@code DISTINCT}. + * Build a {@linkplain FunctionDefinition} for a binary function. */ @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do - public static FunctionDefinition def(Class function, - BinaryFunctionBuilder ctorRef, String... names) { - FunctionBuilder builder = (source, children, distinct, cfg) -> { + protected static FunctionDefinition def(Class function, BinaryBuilder ctorRef, String... names) { + FunctionBuilder builder = (source, children, cfg) -> { boolean isBinaryOptionalParamFunction = OptionalArgument.class.isAssignableFrom(function); if (isBinaryOptionalParamFunction && (children.size() > 2 || children.size() < 1)) { throw new QlIllegalArgumentException("expects one or two arguments"); - } else if (!isBinaryOptionalParamFunction && children.size() != 2) { + } else if (isBinaryOptionalParamFunction == false && children.size() != 2) { throw new QlIllegalArgumentException("expects exactly two arguments"); } - if (distinct) { - throw new QlIllegalArgumentException("does not support DISTINCT yet it was specified"); - } return ctorRef.build(source, children.get(0), children.size() == 2 ? children.get(1) : null); }; - return def(function, builder, false, names); + return def(function, builder, names); } - protected interface BinaryFunctionBuilder { - T build(Source source, Expression lhs, Expression rhs); + protected interface BinaryBuilder { + T build(Source source, Expression left, Expression right); } /** - * Main method to register a function/ - * @param names Must always have at least one entry which is the method's primary name + * Build a {@linkplain FunctionDefinition} for a ternary function. */ - @SuppressWarnings("overloads") - public static FunctionDefinition def(Class function, FunctionBuilder builder, - boolean datetime, String... names) { - Check.isTrue(names.length > 0, "At least one name must be provided for the function"); - String primaryName = names[0]; - List aliases = Arrays.asList(names).subList(1, names.length); - FunctionDefinition.Builder realBuilder = (uf, distinct, cfg) -> { - try { - return builder.build(uf.source(), uf.children(), distinct, cfg); - } catch (QlIllegalArgumentException e) { - throw new ParsingException(uf.source(), "error building [" + primaryName + "]: " + e.getMessage(), e); - } - }; - return new FunctionDefinition(primaryName, unmodifiableList(aliases), function, datetime, realBuilder); - } - - public interface FunctionBuilder { - Function build(Source source, List children, boolean distinct, Configuration cfg); - } - @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do - public static FunctionDefinition def(Class function, - ThreeParametersFunctionBuilder ctorRef, String... names) { - FunctionBuilder builder = (source, children, distinct, cfg) -> { + protected static FunctionDefinition def(Class function, TernaryBuilder ctorRef, String... names) { + FunctionBuilder builder = (source, children, cfg) -> { boolean hasMinimumTwo = OptionalArgument.class.isAssignableFrom(function); if (hasMinimumTwo && (children.size() > 3 || children.size() < 2)) { throw new QlIllegalArgumentException("expects two or three arguments"); - } else if (!hasMinimumTwo && children.size() != 3) { + } else if (hasMinimumTwo == false && children.size() != 3) { throw new QlIllegalArgumentException("expects exactly three arguments"); } - if (distinct) { - throw new QlIllegalArgumentException("does not support DISTINCT yet it was specified"); - } return ctorRef.build(source, children.get(0), children.get(1), children.size() == 3 ? children.get(2) : null); }; - return def(function, builder, false, names); - } - - protected interface ThreeParametersFunctionBuilder { - T build(Source source, Expression src, Expression exp1, Expression exp2); - } - - @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do - public static FunctionDefinition def(Class function, - ScalarTriFunctionConfigurationAwareBuilder ctorRef, String... names) { - FunctionBuilder builder = (source, children, distinct, cfg) -> { - boolean hasMinimumTwo = OptionalArgument.class.isAssignableFrom(function); - if (hasMinimumTwo && (children.size() > 3 || children.size() < 2)) { - throw new QlIllegalArgumentException("expects two or three arguments"); - } else if (!hasMinimumTwo && children.size() != 3) { - throw new QlIllegalArgumentException("expects exactly three arguments"); - } - if (distinct) { - throw new QlIllegalArgumentException("does not support DISTINCT yet it was specified"); - } - return ctorRef.build(source, children.get(0), children.get(1), children.size() == 3 ? children.get(2) : null, cfg); - }; - return def(function, builder, false, names); + return def(function, builder, names); } - protected interface ScalarTriFunctionConfigurationAwareBuilder { - T build(Source source, Expression exp1, Expression exp2, Expression exp3, Configuration configuration); + protected interface TernaryBuilder { + T build(Source source, Expression one, Expression two, Expression three); } + /** + * Build a {@linkplain FunctionDefinition} for a quaternary function. + */ @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do - public static FunctionDefinition def(Class function, - FourParametersFunctionBuilder ctorRef, String... names) { - FunctionBuilder builder = (source, children, distinct, cfg) -> { + protected static FunctionDefinition def(Class function, QuaternaryBuilder ctorRef, String... names) { + FunctionBuilder builder = (source, children, cfg) -> { if (OptionalArgument.class.isAssignableFrom(function)) { if (children.size() > 4 || children.size() < 3) { throw new QlIllegalArgumentException("expects three or four arguments"); @@ -433,110 +263,146 @@ public static FunctionDefinition def(Class function, } else if (children.size() != 4) { throw new QlIllegalArgumentException("expects exactly four arguments"); } - if (distinct) { - throw new QlIllegalArgumentException("does not support DISTINCT yet it was specified"); - } return ctorRef.build(source, children.get(0), children.get(1), children.size() > 2 ? children.get(2) : null, children.size() > 3 ? children.get(3) : null); }; - return def(function, builder, false, names); + return def(function, builder, names); } - protected interface FourParametersFunctionBuilder { - T build(Source source, Expression src, Expression exp1, Expression exp2, Expression exp3); + protected interface QuaternaryBuilder { + T build(Source source, Expression one, Expression two, Expression three, Expression four); } + /** + * Build a {@linkplain FunctionDefinition} for a quinary function. + */ @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do - public static FunctionDefinition def(Class function, - FiveParametersFunctionBuilder ctorRef, - int numOptionalParams, String... names) { - FunctionBuilder builder = (source, children, distinct, cfg) -> { + protected static FunctionDefinition def(Class function, + QuinaryBuilder ctorRef, + int numOptionalParams, + String... names) { + FunctionBuilder builder = (source, children, cfg) -> { final int NUM_TOTAL_PARAMS = 5; boolean hasOptionalParams = OptionalArgument.class.isAssignableFrom(function); if (hasOptionalParams && (children.size() > NUM_TOTAL_PARAMS || children.size() < NUM_TOTAL_PARAMS - numOptionalParams)) { throw new QlIllegalArgumentException("expects between " + NUM_NAMES[NUM_TOTAL_PARAMS - numOptionalParams] - + " and " + NUM_NAMES[NUM_TOTAL_PARAMS] + " arguments"); + + " and " + NUM_NAMES[NUM_TOTAL_PARAMS] + " arguments"); } else if (hasOptionalParams == false && children.size() != NUM_TOTAL_PARAMS) { throw new QlIllegalArgumentException("expects exactly " + NUM_NAMES[NUM_TOTAL_PARAMS] + " arguments"); } - if (distinct) { - throw new QlIllegalArgumentException("does not support DISTINCT yet it was specified"); - } return ctorRef.build(source, - children.size() > 0 ? children.get(0) : null, - children.size() > 1 ? children.get(1) : null, - children.size() > 2 ? children.get(2) : null, - children.size() > 3 ? children.get(3) : null, - children.size() > 4 ? children.get(4) : null); + children.size() > 0 ? children.get(0) : null, + children.size() > 1 ? children.get(1) : null, + children.size() > 2 ? children.get(2) : null, + children.size() > 3 ? children.get(3) : null, + children.size() > 4 ? children.get(4) : null); }; - return def(function, builder, false, names); + return def(function, builder, names); } - protected interface FiveParametersFunctionBuilder { - T build(Source source, Expression src, Expression exp1, Expression exp2, Expression exp3, Expression exp4); + protected interface QuinaryBuilder { + T build(Source source, Expression one, Expression two, Expression three, Expression four, Expression five); } /** - * Special method to create function definition for Cast as its - * signature is not compatible with {@link UnresolvedFunction} - * - * @return Cast function definition + * Build a {@linkplain FunctionDefinition} for functions with a mandatory argument followed by a varidic list. */ @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do - public static FunctionDefinition def(Class function, - CastFunctionBuilder ctorRef, - String... names) { - FunctionBuilder builder = (source, children, distinct, cfg) -> - ctorRef.build(source, children.get(0), children.get(0).dataType()); - return def(function, builder, false, names); - } - - protected interface CastFunctionBuilder { - T build(Source source, Expression expression, DataType dataType); - } - - @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do - public static FunctionDefinition def(Class function, - TwoParametersVariadicBuilder ctorRef, String... names) { - FunctionBuilder builder = (source, children, distinct, cfg) -> { + protected static FunctionDefinition def(Class function, UnaryVariadicBuilder ctorRef, String... names) { + FunctionBuilder builder = (source, children, cfg) -> { boolean hasMinimumOne = OptionalArgument.class.isAssignableFrom(function); if (hasMinimumOne && children.size() < 1) { throw new QlIllegalArgumentException("expects at least one argument"); } else if (!hasMinimumOne && children.size() < 2) { throw new QlIllegalArgumentException("expects at least two arguments"); } - if (distinct) { - throw new QlIllegalArgumentException("does not support DISTINCT yet it was specified"); - } return ctorRef.build(source, children.get(0), children.subList(1, children.size())); }; - return def(function, builder, false, names); + return def(function, builder, names); + } + + protected interface UnaryVariadicBuilder { + T build(Source source, Expression exp, List variadic); + } + + /** + * Build a {@linkplain FunctionDefinition} for a no-argument function that is configuration aware. + */ + @SuppressWarnings("overloads") + protected static FunctionDefinition def(Class function, ConfigurationAwareBuilder ctorRef, String... names) { + FunctionBuilder builder = (source, children, cfg) -> { + if (false == children.isEmpty()) { + throw new QlIllegalArgumentException("expects no arguments"); + } + return ctorRef.build(source, cfg); + }; + return def(function, builder, names); + } + + protected interface ConfigurationAwareBuilder { + T build(Source source, Configuration configuration); + } + + /** + * Build a {@linkplain FunctionDefinition} for a one-argument function that is configuration aware. + */ + @SuppressWarnings("overloads") + protected static FunctionDefinition def(Class function, + UnaryConfigurationAwareBuilder ctorRef, + String... names) { + FunctionBuilder builder = (source, children, cfg) -> { + if (children.size() > 1) { + throw new QlIllegalArgumentException("expects exactly one argument"); + } + Expression ex = children.size() == 1 ? children.get(0) : null; + return ctorRef.build(source, ex, cfg); + }; + return def(function, builder, names); } - protected interface TwoParametersVariadicBuilder { - T build(Source source, Expression src, List remaining); + protected interface UnaryConfigurationAwareBuilder { + T build(Source source, Expression exp, Configuration configuration); } /** - * Build a {@linkplain FunctionDefinition} for a binary function that is case sensitive aware. + * Build a {@linkplain FunctionDefinition} for a binary function that is configuration aware. */ @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do - public static FunctionDefinition def(Class function, - ScalarBiFunctionConfigurationAwareBuilder ctorRef, String... names) { - FunctionBuilder builder = (source, children, distinct, cfg) -> { + protected static FunctionDefinition def(Class function, + BinaryConfigurationAwareBuilder ctorRef, + String... names) { + FunctionBuilder builder = (source, children, cfg) -> { if (children.size() != 2) { throw new QlIllegalArgumentException("expects exactly two arguments"); } - if (distinct) { - throw new QlIllegalArgumentException("does not support DISTINCT yet it was specified"); - } return ctorRef.build(source, children.get(0), children.get(1), cfg); }; - return def(function, builder, true, names); + return def(function, builder, names); + } + + protected interface BinaryConfigurationAwareBuilder { + T build(Source source, Expression left, Expression right, Configuration configuration); + } + + /** + * Build a {@linkplain FunctionDefinition} for a ternary function that is configuration aware. + */ + @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do + protected FunctionDefinition def(Class function, TernaryConfigurationAwareBuilder ctorRef, String... names) { + FunctionBuilder builder = (source, children, cfg) -> { + boolean hasMinimumTwo = OptionalArgument.class.isAssignableFrom(function); + if (hasMinimumTwo && (children.size() > 3 || children.size() < 2)) { + throw new QlIllegalArgumentException("expects two or three arguments"); + } else if (!hasMinimumTwo && children.size() != 3) { + throw new QlIllegalArgumentException("expects exactly three arguments"); + } + return ctorRef.build(source, children.get(0), children.get(1), children.size() == 3 ? children.get(2) : null, cfg); + }; + return def(function, builder, names); } - protected interface ScalarBiFunctionConfigurationAwareBuilder { - T build(Source source, Expression e1, Expression e2, Configuration configuration); + protected interface TernaryConfigurationAwareBuilder { + T build(Source source, Expression one, Expression two, Expression three, Configuration configuration); } } diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/function/FunctionResolutionStrategy.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/function/FunctionResolutionStrategy.java index b6f207d938ab1..164a8dcc3efaf 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/function/FunctionResolutionStrategy.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/function/FunctionResolutionStrategy.java @@ -24,7 +24,7 @@ public interface FunctionResolutionStrategy { * Build the real function from this one and resolution metadata. */ default Function buildResolved(UnresolvedFunction uf, Configuration cfg, FunctionDefinition def) { - return def.builder().build(uf, false, cfg); + return def.builder().build(uf, cfg); } /** diff --git a/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/expression/function/FunctionRegistryTests.java b/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/expression/function/FunctionRegistryTests.java index 783345aa6f6bb..9273e8999e997 100644 --- a/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/expression/function/FunctionRegistryTests.java +++ b/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/expression/function/FunctionRegistryTests.java @@ -32,19 +32,19 @@ public class FunctionRegistryTests extends ESTestCase { public void testNoArgFunction() { UnresolvedFunction ur = uf(DEFAULT); - FunctionRegistry r = new FunctionRegistry(def(DummyFunction.class, DummyFunction::new, "DUMMY_FUNCTION")); + FunctionRegistry r = new FunctionRegistry(defineDummyNoArgFunction()); FunctionDefinition def = r.resolveFunction(ur.name()); assertEquals(ur.source(), ur.buildResolved(randomConfiguration(), def).source()); } + public static FunctionDefinition defineDummyNoArgFunction() { + return def(DummyFunction.class, DummyFunction::new, "DUMMY_FUNCTION"); + } + public void testUnaryFunction() { UnresolvedFunction ur = uf(DEFAULT, mock(Expression.class)); - FunctionRegistry r = new FunctionRegistry(def(DummyFunction.class, (Source l, Expression e) -> { - assertSame(e, ur.children().get(0)); - return new DummyFunction(l); - }, "DUMMY_FUNCTION")); + FunctionRegistry r = new FunctionRegistry(defineDummyUnaryFunction(ur)); FunctionDefinition def = r.resolveFunction(ur.name()); - assertFalse(def.extractViable()); assertEquals(ur.source(), ur.buildResolved(randomConfiguration(), def).source()); // No children aren't supported @@ -58,6 +58,13 @@ public void testUnaryFunction() { assertThat(e.getMessage(), endsWith("expects exactly one argument")); } + public static FunctionDefinition defineDummyUnaryFunction(UnresolvedFunction ur) { + return def(DummyFunction.class, (Source l, Expression e) -> { + assertSame(e, ur.children().get(0)); + return new DummyFunction(l); + }, "DUMMY_FUNCTION"); + } + public void testBinaryFunction() { UnresolvedFunction ur = uf(DEFAULT, mock(Expression.class), mock(Expression.class)); FunctionRegistry r = new FunctionRegistry(def(DummyFunction.class, (Source l, Expression lhs, Expression rhs) -> { @@ -67,7 +74,6 @@ public void testBinaryFunction() { }, "DUMMY_FUNCTION")); FunctionDefinition def = r.resolveFunction(ur.name()); assertEquals(ur.source(), ur.buildResolved(randomConfiguration(), def).source()); - assertFalse(def.extractViable()); // No children aren't supported ParsingException e = expectThrows(ParsingException.class, () -> diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/SqlFunctionDefinition.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/SqlFunctionDefinition.java new file mode 100644 index 0000000000000..0dc9ae76418a4 --- /dev/null +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/SqlFunctionDefinition.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.sql.expression.function; + +import org.elasticsearch.xpack.ql.expression.function.Function; +import org.elasticsearch.xpack.ql.expression.function.FunctionDefinition; + +import java.util.List; + +public class SqlFunctionDefinition extends FunctionDefinition { + + /** + * Is this a datetime function compatible with {@code EXTRACT}. + */ + private final boolean extractViable; + + protected SqlFunctionDefinition(String name, + List aliases, + Class clazz, + boolean dateTime, + Builder builder) { + super(name, aliases, clazz, builder); + this.extractViable = dateTime; + } + + /** + * Is this a datetime function compatible with {@code EXTRACT}. + */ + public boolean extractViable() { + return extractViable; + } + + @Override + protected Builder builder() { + return super.builder(); + } +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/SqlFunctionRegistry.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/SqlFunctionRegistry.java index f720ca231d1eb..851392f06c9e3 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/SqlFunctionRegistry.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/SqlFunctionRegistry.java @@ -5,10 +5,21 @@ */ package org.elasticsearch.xpack.sql.expression.function; +import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.xpack.ql.ParsingException; +import org.elasticsearch.xpack.ql.QlIllegalArgumentException; +import org.elasticsearch.xpack.ql.expression.Expression; +import org.elasticsearch.xpack.ql.expression.function.Function; import org.elasticsearch.xpack.ql.expression.function.FunctionDefinition; import org.elasticsearch.xpack.ql.expression.function.FunctionRegistry; +import org.elasticsearch.xpack.ql.expression.function.UnresolvedFunction; import org.elasticsearch.xpack.ql.expression.function.aggregate.Count; import org.elasticsearch.xpack.ql.expression.function.scalar.string.StartsWith; +import org.elasticsearch.xpack.ql.session.Configuration; +import org.elasticsearch.xpack.ql.tree.Source; +import org.elasticsearch.xpack.ql.type.DataType; +import org.elasticsearch.xpack.ql.util.Check; +import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; import org.elasticsearch.xpack.sql.expression.function.aggregate.Avg; import org.elasticsearch.xpack.sql.expression.function.aggregate.First; import org.elasticsearch.xpack.sql.expression.function.aggregate.Kurtosis; @@ -35,8 +46,8 @@ import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.CurrentTime; import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateAdd; import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateDiff; -import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DatePart; import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateParse; +import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DatePart; import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateTimeFormat; import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateTimeParse; import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateTrunc; @@ -122,16 +133,27 @@ import org.elasticsearch.xpack.sql.expression.predicate.conditional.NullIf; import org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic.Mod; +import java.time.ZoneId; +import java.util.Arrays; +import java.util.List; + +import static java.util.Collections.unmodifiableList; + public class SqlFunctionRegistry extends FunctionRegistry { public SqlFunctionRegistry() { - super(functions()); + register(functions()); } - private static FunctionDefinition[][] functions() { - return new FunctionDefinition[][] { - // Aggregate functions - new FunctionDefinition[] { + protected SqlFunctionRegistry(FunctionDefinition... functions) { + register(functions); + } + + + private FunctionDefinition[][] functions() { + return new FunctionDefinition[][]{ + // Aggregate functions + new FunctionDefinition[]{ def(Avg.class, Avg::new, "AVG"), def(Count.class, Count::new, "COUNT"), def(First.class, First::new, "FIRST", "FIRST_VALUE"), @@ -139,9 +161,9 @@ private static FunctionDefinition[][] functions() { def(Max.class, Max::new, "MAX"), def(Min.class, Min::new, "MIN"), def(Sum.class, Sum::new, "SUM") - }, - // Statistics - new FunctionDefinition[] { + }, + // Statistics + new FunctionDefinition[]{ def(Kurtosis.class, Kurtosis::new, "KURTOSIS"), def(MedianAbsoluteDeviation.class, MedianAbsoluteDeviation::new, "MAD"), def(Percentile.class, Percentile::new, "PERCENTILE"), @@ -152,14 +174,14 @@ private static FunctionDefinition[][] functions() { def(SumOfSquares.class, SumOfSquares::new, "SUM_OF_SQUARES"), def(VarPop.class, VarPop::new, "VAR_POP"), def(VarSamp.class, VarSamp::new, "VAR_SAMP") - }, - // histogram - new FunctionDefinition[] { + }, + // histogram + new FunctionDefinition[]{ def(Histogram.class, Histogram::new, "HISTOGRAM") - }, - // Scalar functions - // Conditional - new FunctionDefinition[] { + }, + // Scalar functions + // Conditional + new FunctionDefinition[]{ def(Case.class, Case::new, "CASE"), def(Coalesce.class, Coalesce::new, "COALESCE"), def(Iif.class, Iif::new, "IIF"), @@ -167,9 +189,9 @@ private static FunctionDefinition[][] functions() { def(NullIf.class, NullIf::new, "NULLIF"), def(Greatest.class, Greatest::new, "GREATEST"), def(Least.class, Least::new, "LEAST") - }, - // Date - new FunctionDefinition[] { + }, + // Date + new FunctionDefinition[]{ def(CurrentDate.class, CurrentDate::new, "CURRENT_DATE", "CURDATE", "TODAY"), def(CurrentTime.class, CurrentTime::new, "CURRENT_TIME", "CURTIME"), def(CurrentDateTime.class, CurrentDateTime::new, "CURRENT_TIMESTAMP", "NOW"), @@ -180,7 +202,7 @@ private static FunctionDefinition[][] functions() { def(DateAdd.class, DateAdd::new, "DATEADD", "DATE_ADD", "TIMESTAMPADD", "TIMESTAMP_ADD"), def(DateDiff.class, DateDiff::new, "DATEDIFF", "DATE_DIFF", "TIMESTAMPDIFF", "TIMESTAMP_DIFF"), def(DateParse.class, DateParse::new, "DATE_PARSE"), - def(DatePart.class, DatePart::new, "DATEPART", "DATE_PART"), + def(DatePart.class, DatePart::new, "DATEPART", "DATE_PART"), def(DateTimeFormat.class, DateTimeFormat::new, "DATETIME_FORMAT"), def(DateTimeParse.class, DateTimeParse::new, "DATETIME_PARSE"), def(DateTrunc.class, DateTrunc::new, "DATETRUNC", "DATE_TRUNC"), @@ -199,8 +221,8 @@ private static FunctionDefinition[][] functions() { def(Year.class, Year::new, "YEAR"), def(WeekOfYear.class, WeekOfYear::new, "WEEK_OF_YEAR", "WEEK") }, - // Math - new FunctionDefinition[] { + // Math + new FunctionDefinition[]{ def(Abs.class, Abs::new, "ABS"), def(ACos.class, ACos::new, "ACOS"), def(ASin.class, ASin::new, "ASIN"), @@ -232,8 +254,8 @@ private static FunctionDefinition[][] functions() { def(Tan.class, Tan::new, "TAN"), def(Truncate.class, Truncate::new, "TRUNCATE", "TRUNC") }, - // String - new FunctionDefinition[] { + // String + new FunctionDefinition[]{ def(Ascii.class, Ascii::new, "ASCII"), def(BitLength.class, BitLength::new, "BIT_LENGTH"), def(Char.class, Char::new, "CHAR"), @@ -257,17 +279,17 @@ private static FunctionDefinition[][] functions() { def(Trim.class, Trim::new, "TRIM"), def(UCase.class, UCase::new, "UCASE") }, - // DataType conversion - new FunctionDefinition[] { + // DataType conversion + new FunctionDefinition[]{ def(Cast.class, Cast::new, "CAST", "CONVERT") }, - // Scalar "meta" functions - new FunctionDefinition[] { + // Scalar "meta" functions + new FunctionDefinition[]{ def(Database.class, Database::new, "DATABASE"), def(User.class, User::new, "USER") }, - // Geo Functions - new FunctionDefinition[] { + // Geo Functions + new FunctionDefinition[]{ def(StAswkt.class, StAswkt::new, "ST_ASWKT", "ST_ASTEXT"), def(StDistance.class, StDistance::new, "ST_DISTANCE"), def(StWkttosql.class, StWkttosql::new, "ST_WKTTOSQL", "ST_GEOMFROMTEXT"), @@ -276,11 +298,149 @@ private static FunctionDefinition[][] functions() { def(StY.class, StY::new, "ST_Y"), def(StZ.class, StZ::new, "ST_Z") }, - // Special - new FunctionDefinition[] { + // Special + new FunctionDefinition[]{ def(Score.class, Score::new, "SCORE") } }; } + /** + * Builder for creating SQL-specific functions. + * All other definitions defined here end up being translated to this form. + */ + protected interface SqlFunctionBuilder { + Function build(Source source, List children, Configuration cfg, Boolean distinct); + } + + /** + * Main method to register a function. + */ + @SuppressWarnings("overloads") + protected static FunctionDefinition def(Class function, + SqlFunctionBuilder builder, + boolean datetime, + String... names) { + Check.isTrue(names.length > 0, "At least one name must be provided for the function"); + String primaryName = names[0]; + List aliases = Arrays.asList(names).subList(1, names.length); + FunctionDefinition.Builder realBuilder = (uf, cfg, extras) -> { + try { + return builder.build(uf.source(), uf.children(), cfg, asBool(extras)); + } catch (QlIllegalArgumentException e) { + throw new ParsingException(uf.source(), "error building [" + primaryName + "]: " + e.getMessage(), e); + } + }; + return new SqlFunctionDefinition(primaryName, unmodifiableList(aliases), function, datetime, realBuilder); + } + + private static Boolean asBool(Object[] extras) { + if (CollectionUtils.isEmpty(extras)) { + return null; + } + if (extras.length != 1 || (extras[0] instanceof Boolean) == false) { + throw new SqlIllegalArgumentException("Expected exactly one bool argument, found [{}], entry [{}]", extras.length, extras[0]); + } + return (Boolean) extras[0]; + } + + /** + * Build a {@linkplain FunctionDefinition} for a unary function that is not aware of time zone but does support {@code DISTINCT}. + */ + @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do + protected static FunctionDefinition def(Class function, UnaryDistinctAwareBuilder ctorRef, String... names) { + SqlFunctionBuilder builder = (source, children, cfg, distinct) -> { + if (children.size() != 1) { + throw new QlIllegalArgumentException("expects exactly one argument"); + } + return ctorRef.build(source, children.get(0), distinct == null ? Boolean.FALSE : distinct); + }; + return def(function, builder, false, names); + } + + protected interface UnaryDistinctAwareBuilder { + T build(Source source, Expression target, Boolean distinct); + } + + /** + * Build a {@linkplain FunctionDefinition} for a unary function that requires a timezone. + */ + @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do + protected static FunctionDefinition def(Class function, UnaryZoneIdAwareBuilder ctorRef, String... names) { + SqlFunctionBuilder builder = (source, children, cfg, distinct) -> { + if (children.size() != 1) { + throw new QlIllegalArgumentException("expects exactly one argument"); + } + forbidDistinct(source, distinct); + return ctorRef.build(source, children.get(0), cfg.zoneId()); + }; + return def(function, builder, true, names); + } + + protected interface UnaryZoneIdAwareBuilder { + T build(Source source, Expression exp, ZoneId zi); + } + + /** + * Build a {@linkplain FunctionDefinition} for a binary function that requires a timezone. + */ + @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do + protected static FunctionDefinition def(Class function, BinaryZoneIdAwareBuilder ctorRef, String... names) { + SqlFunctionBuilder builder = (source, children, cfg, distinct) -> { + if (children.size() != 2) { + throw new QlIllegalArgumentException("expects exactly two arguments"); + } + forbidDistinct(source, distinct); + return ctorRef.build(source, children.get(0), children.get(1), cfg.zoneId()); + }; + return def(function, builder, true, names); + } + + protected interface BinaryZoneIdAwareBuilder { + T build(Source source, Expression left, Expression right, ZoneId zi); + } + + /** + * Build a {@linkplain FunctionDefinition} for a three-args function that requires a timezone. + */ + @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do + protected static FunctionDefinition def(Class function, + TernaryZoneIdAwareBuilder ctorRef, + String... names) { + SqlFunctionBuilder builder = (source, children, cfg, distinct) -> { + if (children.size() != 3) { + throw new QlIllegalArgumentException("expects three arguments"); + } + forbidDistinct(source, distinct); + return ctorRef.build(source, children.get(0), children.get(1), children.get(2), cfg.zoneId()); + }; + return def(function, builder, true, names); + } + + protected interface TernaryZoneIdAwareBuilder { + T build(Source source, Expression first, Expression second, Expression third, ZoneId zi); + } + + + /** + * Special method to create function definition for Cast as its signature is not compatible with {@link UnresolvedFunction}. + */ + @SuppressWarnings("overloads") // These are ambiguous if you aren't using ctor references but we always do + protected static FunctionDefinition def(Class function, CastBuilder ctorRef, String... names) { + SqlFunctionBuilder builder = (source, children, cfg, distinct) -> { + forbidDistinct(source, distinct); + return ctorRef.build(source, children.get(0), children.get(0).dataType()); + }; + return def(function, builder, false, names); + } + + protected interface CastBuilder { + T build(Source source, Expression expression, DataType dataType); + } + + private static void forbidDistinct(Source source, Boolean distinct) { + if (distinct != null) { + throw new ParsingException(source, "does not support DISTINCT yet it was specified"); + } + } } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/SqlFunctionResolution.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/SqlFunctionResolution.java index dd6c6c15b0bed..aa1fdcc6e84be 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/SqlFunctionResolution.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/SqlFunctionResolution.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.sql.expression.function; +import org.elasticsearch.xpack.ql.ParsingException; import org.elasticsearch.xpack.ql.expression.function.Function; import org.elasticsearch.xpack.ql.expression.function.FunctionDefinition; import org.elasticsearch.xpack.ql.expression.function.FunctionResolutionStrategy; @@ -20,7 +21,10 @@ public enum SqlFunctionResolution implements FunctionResolutionStrategy { DISTINCT { @Override public Function buildResolved(UnresolvedFunction uf, Configuration cfg, FunctionDefinition def) { - return def.builder().build(uf, true, cfg); + if (def instanceof SqlFunctionDefinition) { + return ((SqlFunctionDefinition) def).builder().build(uf, cfg, true); + } + throw new ParsingException(uf.source(), "Cannot use {} on non-SQL function {}", name(), def); } @Override @@ -34,15 +38,15 @@ public boolean isValidAlternative(FunctionDefinition def) { EXTRACT { @Override public Function buildResolved(UnresolvedFunction uf, Configuration cfg, FunctionDefinition def) { - if (def.extractViable()) { - return def.builder().build(uf, false, cfg); + if (isValidAlternative(def)) { + return ((SqlFunctionDefinition) def).builder().build(uf, cfg); } return uf.withMessage("Invalid datetime field [" + uf.name() + "]. Use any datetime function."); } @Override public boolean isValidAlternative(FunctionDefinition def) { - return def.extractViable(); + return (def instanceof SqlFunctionDefinition) && ((SqlFunctionDefinition) def).extractViable(); } @Override diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/function/SqlFunctionRegistryTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/function/SqlFunctionRegistryTests.java index 60e443b08f84a..cdc98cdb27813 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/function/SqlFunctionRegistryTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/function/SqlFunctionRegistryTests.java @@ -11,7 +11,6 @@ import org.elasticsearch.xpack.ql.expression.Expression; import org.elasticsearch.xpack.ql.expression.function.FunctionDefinition; import org.elasticsearch.xpack.ql.expression.function.FunctionRegistry; -import org.elasticsearch.xpack.ql.expression.function.FunctionRegistryTests; import org.elasticsearch.xpack.ql.expression.function.FunctionRegistryTests.DummyFunction; import org.elasticsearch.xpack.ql.expression.function.UnresolvedFunction; import org.elasticsearch.xpack.ql.session.Configuration; @@ -20,9 +19,11 @@ import java.time.ZoneId; import static org.elasticsearch.xpack.ql.TestUtils.randomConfiguration; -import static org.elasticsearch.xpack.ql.expression.function.FunctionRegistry.def; +import static org.elasticsearch.xpack.ql.expression.function.FunctionRegistryTests.defineDummyNoArgFunction; +import static org.elasticsearch.xpack.ql.expression.function.FunctionRegistryTests.defineDummyUnaryFunction; import static org.elasticsearch.xpack.ql.expression.function.FunctionRegistryTests.uf; import static org.elasticsearch.xpack.ql.expression.function.FunctionResolutionStrategy.DEFAULT; +import static org.elasticsearch.xpack.sql.expression.function.SqlFunctionRegistry.def; import static org.elasticsearch.xpack.sql.expression.function.SqlFunctionResolution.DISTINCT; import static org.elasticsearch.xpack.sql.expression.function.SqlFunctionResolution.EXTRACT; import static org.hamcrest.Matchers.endsWith; @@ -32,13 +33,13 @@ public class SqlFunctionRegistryTests extends ESTestCase { public void testNoArgFunction() { UnresolvedFunction ur = uf(DEFAULT); - FunctionRegistry r = new FunctionRegistry(def(DummyFunction.class, DummyFunction::new, "DUMMY_FUNCTION")); + FunctionRegistry r = new SqlFunctionRegistry(defineDummyNoArgFunction()); FunctionDefinition def = r.resolveFunction(ur.name()); // Distinct isn't supported ParsingException e = expectThrows(ParsingException.class, () -> uf(DISTINCT).buildResolved(randomConfiguration(), def)); - assertThat(e.getMessage(), endsWith("does not support DISTINCT yet it was specified")); + assertThat(e.getMessage(), endsWith("Cannot use DISTINCT on non-SQL function DUMMY_FUNCTION()")); // Any children aren't supported e = expectThrows(ParsingException.class, () -> @@ -48,16 +49,13 @@ public void testNoArgFunction() { public void testUnaryFunction() { UnresolvedFunction ur = uf(DEFAULT, mock(Expression.class)); - FunctionRegistry r = new FunctionRegistry(def(DummyFunction.class, (Source l, Expression e) -> { - assertSame(e, ur.children().get(0)); - return new DummyFunction(l); - }, "DUMMY_FUNCTION")); + FunctionRegistry r = new SqlFunctionRegistry(defineDummyUnaryFunction(ur)); FunctionDefinition def = r.resolveFunction(ur.name()); // Distinct isn't supported ParsingException e = expectThrows(ParsingException.class, () -> uf(DISTINCT, mock(Expression.class)).buildResolved(randomConfiguration(), def)); - assertThat(e.getMessage(), endsWith("does not support DISTINCT yet it was specified")); + assertThat(e.getMessage(), endsWith("Cannot use DISTINCT on non-SQL function DUMMY_FUNCTION()")); // No children aren't supported e = expectThrows(ParsingException.class, () -> @@ -73,15 +71,17 @@ public void testUnaryFunction() { public void testUnaryDistinctAwareFunction() { boolean urIsDistinct = randomBoolean(); UnresolvedFunction ur = uf(urIsDistinct ? DISTINCT : DEFAULT, mock(Expression.class)); - FunctionRegistry r = new FunctionRegistry( - def(FunctionRegistryTests.DummyFunction.class, (Source l, Expression e, boolean distinct) -> { - assertEquals(urIsDistinct, distinct); - assertSame(e, ur.children().get(0)); - return new FunctionRegistryTests.DummyFunction(l); - }, "DUMMY_FUNCTION")); + FunctionDefinition definition = def(DummyFunction.class, (Source l, Expression e, Boolean distinct) -> { + assertEquals(urIsDistinct, distinct); + assertSame(e, ur.children().get(0)); + return new DummyFunction(l); + }, "DUMMY_FUNCTION"); + FunctionRegistry r = new SqlFunctionRegistry(definition); FunctionDefinition def = r.resolveFunction(ur.name()); assertEquals(ur.source(), ur.buildResolved(randomConfiguration(), def).source()); - assertFalse(def.extractViable()); + + assertEquals(SqlFunctionDefinition.class, def.getClass()); + assertFalse(((SqlFunctionDefinition) def).extractViable()); // No children aren't supported ParsingException e = expectThrows(ParsingException.class, () -> @@ -96,21 +96,23 @@ public void testUnaryDistinctAwareFunction() { public void testDateTimeFunction() { boolean urIsExtract = randomBoolean(); - UnresolvedFunction ur = uf(urIsExtract ? EXTRACT : DEFAULT, mock(Expression.class)); + Expression exprMock = mock(Expression.class); + UnresolvedFunction ur = uf(urIsExtract ? EXTRACT : DEFAULT, exprMock); ZoneId providedTimeZone = randomZone().normalized(); Configuration providedConfiguration = randomConfiguration(providedTimeZone); - FunctionRegistry r = new FunctionRegistry(def(FunctionRegistryTests.DummyFunction.class, (Source l, Expression e, ZoneId zi) -> { + FunctionRegistry r = new SqlFunctionRegistry(def(DummyFunction.class, (Source l, Expression e, ZoneId zi) -> { assertEquals(providedTimeZone, zi); assertSame(e, ur.children().get(0)); - return new FunctionRegistryTests.DummyFunction(l); + return new DummyFunction(l); }, "DUMMY_FUNCTION")); FunctionDefinition def = r.resolveFunction(ur.name()); assertEquals(ur.source(), ur.buildResolved(providedConfiguration, def).source()); - assertTrue(def.extractViable()); + assertEquals(SqlFunctionDefinition.class, def.getClass()); + assertTrue(((SqlFunctionDefinition) def).extractViable()); // Distinct isn't supported ParsingException e = expectThrows(ParsingException.class, () -> - uf(DISTINCT, mock(Expression.class)).buildResolved(randomConfiguration(), def)); + uf(DISTINCT, exprMock).buildResolved(providedConfiguration, def)); assertThat(e.getMessage(), endsWith("does not support DISTINCT yet it was specified")); // No children aren't supported @@ -120,7 +122,7 @@ public void testDateTimeFunction() { // Multiple children aren't supported e = expectThrows(ParsingException.class, () -> - uf(DEFAULT, mock(Expression.class), mock(Expression.class)).buildResolved(randomConfiguration(), def)); + uf(DEFAULT, exprMock, exprMock).buildResolved(randomConfiguration(), def)); assertThat(e.getMessage(), endsWith("expects exactly one argument")); } }