From 20800d781057ac76c42da0e8bc10282d98ce655e Mon Sep 17 00:00:00 2001 From: Matthew Wells Date: Thu, 16 Mar 2023 14:00:13 -0700 Subject: [PATCH] Added Arithmetic functions to V2 engine (#1416) * Updated Arithmetic functions from old engine to new engine (#235) * Updated ADD, SUBTRACT, MULTIPLY, DIVIDE, MODULUS to V2 engine Signed-off-by: Matthew Wells (cherry picked from commit bc3934611b5a60ebc981c18b180b58883bf6be55) --- .../org/opensearch/sql/expression/DSL.java | 59 ++- .../function/BuiltinFunctionName.java | 13 +- .../arthmetic/ArithmeticFunction.java | 380 +++++++++++------- .../arthmetic/MathematicalFunction.java | 44 -- .../sql/expression/ExpressionTestBase.java | 4 +- .../arthmetic/ArithmeticFunctionTest.java | 254 +++++++++--- .../arthmetic/MathematicalFunctionTest.java | 161 -------- docs/user/dql/functions.rst | 104 ++++- .../sql/sql/ArithmeticFunctionIT.java | 246 ++++++++++++ .../sql/sql/MathematicalFunctionIT.java | 2 +- sql/src/main/antlr/OpenSearchSQLLexer.g4 | 3 +- sql/src/main/antlr/OpenSearchSQLParser.g4 | 7 +- 12 files changed, 832 insertions(+), 445 deletions(-) create mode 100644 integ-test/src/test/java/org/opensearch/sql/sql/ArithmeticFunctionIT.java diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index 98ebf428ba..2d08650019 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -150,6 +150,14 @@ public static FunctionExpression abs(Expression... expressions) { return compile(FunctionProperties.None, BuiltinFunctionName.ABS, expressions); } + public static FunctionExpression add(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.ADD, expressions); + } + + public static FunctionExpression addFunction(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.ADDFUNCTION, expressions); + } + public static FunctionExpression ceil(Expression... expressions) { return compile(FunctionProperties.None, BuiltinFunctionName.CEIL, expressions); } @@ -166,6 +174,14 @@ public static FunctionExpression crc32(Expression... expressions) { return compile(FunctionProperties.None, BuiltinFunctionName.CRC32, expressions); } + public static FunctionExpression divide(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.DIVIDE, expressions); + } + + public static FunctionExpression divideFunction(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.DIVIDEFUNCTION, expressions); + } + public static FunctionExpression euler(Expression... expressions) { return compile(FunctionProperties.None, BuiltinFunctionName.E, expressions); } @@ -202,6 +218,22 @@ public static FunctionExpression mod(Expression... expressions) { return compile(FunctionProperties.None, BuiltinFunctionName.MOD, expressions); } + public static FunctionExpression modulus(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.MODULUS, expressions); + } + + public static FunctionExpression modulusFunction(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.MODULUSFUNCTION, expressions); + } + + public static FunctionExpression multiply(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.MULTIPLY, expressions); + } + + public static FunctionExpression multiplyFunction(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.MULTIPLYFUNCTION, expressions); + } + public static FunctionExpression pi(Expression... expressions) { return compile(FunctionProperties.None, BuiltinFunctionName.PI, expressions); } @@ -286,20 +318,16 @@ public static FunctionExpression sin(Expression... expressions) { return compile(FunctionProperties.None, BuiltinFunctionName.SIN, expressions); } - public static FunctionExpression tan(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.TAN, expressions); - } - - public static FunctionExpression add(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.ADD, expressions); - } - public static FunctionExpression subtract(Expression... expressions) { return compile(FunctionProperties.None, BuiltinFunctionName.SUBTRACT, expressions); } - public static FunctionExpression multiply(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.MULTIPLY, expressions); + public static FunctionExpression subtractFunction(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.SUBTRACTFUNCTION, expressions); + } + + public static FunctionExpression tan(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.TAN, expressions); } public static FunctionExpression convert_tz(Expression... expressions) { @@ -480,20 +508,11 @@ public static FunctionExpression yearweek( return compile(functionProperties, BuiltinFunctionName.YEARWEEK, expressions); } - public static FunctionExpression divide(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.DIVIDE, expressions); - } - - public static FunctionExpression module(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.MODULES, expressions); - } - - public static FunctionExpression str_to_date(FunctionProperties functionProperties, Expression... expressions) { return compile(functionProperties, BuiltinFunctionName.STR_TO_DATE, expressions); } - + public static FunctionExpression sec_to_time(Expression... expressions) { return compile(FunctionProperties.None, BuiltinFunctionName.SEC_TO_TIME, expressions); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 62eba2497c..c5076f7e91 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -34,7 +34,6 @@ public enum BuiltinFunctionName { LOG(FunctionName.of("log")), LOG10(FunctionName.of("log10")), LOG2(FunctionName.of("log2")), - MOD(FunctionName.of("mod")), PI(FunctionName.of("pi")), POW(FunctionName.of("pow")), POWER(FunctionName.of("power")), @@ -138,10 +137,16 @@ public enum BuiltinFunctionName { * Arithmetic Operators. */ ADD(FunctionName.of("+")), - SUBTRACT(FunctionName.of("-")), - MULTIPLY(FunctionName.of("*")), + ADDFUNCTION(FunctionName.of("add")), DIVIDE(FunctionName.of("/")), - MODULES(FunctionName.of("%")), + DIVIDEFUNCTION(FunctionName.of("divide")), + MOD(FunctionName.of("mod")), + MODULUS(FunctionName.of("%")), + MODULUSFUNCTION(FunctionName.of("modulus")), + MULTIPLY(FunctionName.of("*")), + MULTIPLYFUNCTION(FunctionName.of("multiply")), + SUBTRACT(FunctionName.of("-")), + SUBTRACTFUNCTION(FunctionName.of("subtract")), /** * Boolean Operators. diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java index c4b106bbf4..cfa952ffe7 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java @@ -25,6 +25,7 @@ import org.opensearch.sql.expression.function.BuiltinFunctionRepository; import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionDSL; +import org.opensearch.sql.expression.function.FunctionName; /** * The definition of arithmetic function @@ -43,168 +44,255 @@ public class ArithmeticFunction { */ public static void register(BuiltinFunctionRepository repository) { repository.register(add()); - repository.register(subtract()); - repository.register(multiply()); + repository.register(addFunction()); repository.register(divide()); - repository.register(modules()); + repository.register(divideFunction()); + repository.register(mod()); + repository.register(modulus()); + repository.register(modulusFunction()); + repository.register(multiply()); + repository.register(multiplyFunction()); + repository.register(subtract()); + repository.register(subtractFunction()); } - private static DefaultFunctionResolver add() { - return FunctionDSL.define(BuiltinFunctionName.ADD.getName(), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprByteValue(v1.byteValue() + v2.byteValue())), - BYTE, BYTE, BYTE), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprShortValue(v1.shortValue() + v2.shortValue())), - SHORT, SHORT, SHORT), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprIntegerValue(Math.addExact(v1.integerValue(), - v2.integerValue()))), - INTEGER, INTEGER, INTEGER), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprLongValue(Math.addExact(v1.longValue(), v2.longValue()))), - LONG, LONG, LONG), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprFloatValue(v1.floatValue() + v2.floatValue())), - FLOAT, FLOAT, FLOAT), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprDoubleValue(v1.doubleValue() + v2.doubleValue())), - DOUBLE, DOUBLE, DOUBLE) + /** + * Definition of add(x, y) function. + * Returns the number x plus number y + * The supported signature of add function is + * (x: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE, y: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE) + * -> wider type between types of x and y + */ + private static DefaultFunctionResolver addBase(FunctionName functionName) { + return FunctionDSL.define(functionName, + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> new ExprByteValue(v1.byteValue() + v2.byteValue())), + BYTE, BYTE, BYTE), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> new ExprShortValue(v1.shortValue() + v2.shortValue())), + SHORT, SHORT, SHORT), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> new ExprIntegerValue(Math.addExact(v1.integerValue(), + v2.integerValue()))), + INTEGER, INTEGER, INTEGER), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> new ExprLongValue( + Math.addExact(v1.longValue(), v2.longValue()))), + LONG, LONG, LONG), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> new ExprFloatValue(v1.floatValue() + v2.floatValue())), + FLOAT, FLOAT, FLOAT), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> new ExprDoubleValue(v1.doubleValue() + v2.doubleValue())), + DOUBLE, DOUBLE, DOUBLE) ); } - private static DefaultFunctionResolver subtract() { - return FunctionDSL.define(BuiltinFunctionName.SUBTRACT.getName(), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprByteValue(v1.byteValue() - v2.byteValue())), - BYTE, BYTE, BYTE), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprShortValue(v1.shortValue() - v2.shortValue())), - SHORT, SHORT, SHORT), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprIntegerValue(Math.subtractExact(v1.integerValue(), - v2.integerValue()))), - INTEGER, INTEGER, INTEGER), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprLongValue(Math.subtractExact(v1.longValue(), v2.longValue()))), - LONG, LONG, LONG), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprFloatValue(v1.floatValue() - v2.floatValue())), - FLOAT, FLOAT, FLOAT), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprDoubleValue(v1.doubleValue() - v2.doubleValue())), - DOUBLE, DOUBLE, DOUBLE) - ); + private static DefaultFunctionResolver add() { + return addBase(BuiltinFunctionName.ADD.getName()); } - private static DefaultFunctionResolver multiply() { - return FunctionDSL.define(BuiltinFunctionName.MULTIPLY.getName(), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprByteValue(v1.byteValue() * v2.byteValue())), - BYTE, BYTE, BYTE), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprShortValue(v1.shortValue() * v2.shortValue())), - SHORT, SHORT, SHORT), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprIntegerValue(Math.multiplyExact(v1.integerValue(), - v2.integerValue()))), - INTEGER, INTEGER, INTEGER), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprLongValue(Math.multiplyExact(v1.longValue(), v2.longValue()))), - LONG, LONG, LONG), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprFloatValue(v1.floatValue() * v2.floatValue())), - FLOAT, FLOAT, FLOAT), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprDoubleValue(v1.doubleValue() * v2.doubleValue())), - DOUBLE, DOUBLE, DOUBLE) + private static DefaultFunctionResolver addFunction() { + return addBase(BuiltinFunctionName.ADDFUNCTION.getName()); + } + + /** + * Definition of divide(x, y) function. + * Returns the number x divided by number y + * The supported signature of divide function is + * (x: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE, y: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE) + * -> wider type between types of x and y + */ + private static DefaultFunctionResolver divideBase(FunctionName functionName) { + return FunctionDSL.define(functionName, + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> new ExprByteValue(v1.byteValue() / v2.byteValue())), + BYTE, BYTE, BYTE), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : + new ExprShortValue(v1.shortValue() / v2.shortValue())), + SHORT, SHORT, SHORT), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> v2.integerValue() == 0 ? ExprNullValue.of() : + new ExprIntegerValue(v1.integerValue() / v2.integerValue())), + INTEGER, INTEGER, INTEGER), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> v2.longValue() == 0 ? ExprNullValue.of() : + new ExprLongValue(v1.longValue() / v2.longValue())), + LONG, LONG, LONG), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> v2.floatValue() == 0 ? ExprNullValue.of() : + new ExprFloatValue(v1.floatValue() / v2.floatValue())), + FLOAT, FLOAT, FLOAT), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> v2.doubleValue() == 0 ? ExprNullValue.of() : + new ExprDoubleValue(v1.doubleValue() / v2.doubleValue())), + DOUBLE, DOUBLE, DOUBLE) ); } private static DefaultFunctionResolver divide() { - return FunctionDSL.define(BuiltinFunctionName.DIVIDE.getName(), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprByteValue(v1.byteValue() / v2.byteValue())), - BYTE, BYTE, BYTE), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : - new ExprShortValue(v1.shortValue() / v2.shortValue())), - SHORT, SHORT, SHORT), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> v2.integerValue() == 0 ? ExprNullValue.of() : - new ExprIntegerValue(v1.integerValue() / v2.integerValue())), - INTEGER, INTEGER, INTEGER), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> v2.longValue() == 0 ? ExprNullValue.of() : - new ExprLongValue(v1.longValue() / v2.longValue())), - LONG, LONG, LONG), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> v2.floatValue() == 0 ? ExprNullValue.of() : - new ExprFloatValue(v1.floatValue() / v2.floatValue())), - FLOAT, FLOAT, FLOAT), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> v2.doubleValue() == 0 ? ExprNullValue.of() : - new ExprDoubleValue(v1.doubleValue() / v2.doubleValue())), - DOUBLE, DOUBLE, DOUBLE) + return divideBase(BuiltinFunctionName.DIVIDE.getName()); + } + + private static DefaultFunctionResolver divideFunction() { + return divideBase(BuiltinFunctionName.DIVIDEFUNCTION.getName()); + } + + /** + * Definition of modulo(x, y) function. + * Returns the number x modulo by number y + * The supported signature of modulo function is + * (x: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE, y: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE) + * -> wider type between types of x and y + */ + private static DefaultFunctionResolver modulusBase(FunctionName functionName) { + return FunctionDSL.define(functionName, + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> new ExprByteValue(v1.byteValue() % v2.byteValue())), + BYTE, BYTE, BYTE), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : + new ExprShortValue(v1.shortValue() % v2.shortValue())), + SHORT, SHORT, SHORT), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> v2.integerValue() == 0 ? ExprNullValue.of() : + new ExprIntegerValue(v1.integerValue() % v2.integerValue())), + INTEGER, INTEGER, INTEGER), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> v2.longValue() == 0 ? ExprNullValue.of() : + new ExprLongValue(v1.longValue() % v2.longValue())), + LONG, LONG, LONG), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> v2.floatValue() == 0 ? ExprNullValue.of() : + new ExprFloatValue(v1.floatValue() % v2.floatValue())), + FLOAT, FLOAT, FLOAT), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> v2.doubleValue() == 0 ? ExprNullValue.of() : + new ExprDoubleValue(v1.doubleValue() % v2.doubleValue())), + DOUBLE, DOUBLE, DOUBLE) ); } + private static DefaultFunctionResolver mod() { + return modulusBase(BuiltinFunctionName.MOD.getName()); + } - private static DefaultFunctionResolver modules() { - return FunctionDSL.define(BuiltinFunctionName.MODULES.getName(), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprByteValue(v1.byteValue() % v2.byteValue())), - BYTE, BYTE, BYTE), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : - new ExprShortValue(v1.shortValue() % v2.shortValue())), - SHORT, SHORT, SHORT), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> v2.integerValue() == 0 ? ExprNullValue.of() : - new ExprIntegerValue(v1.integerValue() % v2.integerValue())), - INTEGER, INTEGER, INTEGER), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> v2.longValue() == 0 ? ExprNullValue.of() : - new ExprLongValue(v1.longValue() % v2.longValue())), - LONG, LONG, LONG), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> v2.floatValue() == 0 ? ExprNullValue.of() : - new ExprFloatValue(v1.floatValue() % v2.floatValue())), - FLOAT, FLOAT, FLOAT), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> v2.doubleValue() == 0 ? ExprNullValue.of() : - new ExprDoubleValue(v1.doubleValue() % v2.doubleValue())), - DOUBLE, DOUBLE, DOUBLE) + private static DefaultFunctionResolver modulus() { + return modulusBase(BuiltinFunctionName.MODULUS.getName()); + } + + private static DefaultFunctionResolver modulusFunction() { + return modulusBase(BuiltinFunctionName.MODULUSFUNCTION.getName()); + } + + /** + * Definition of multiply(x, y) function. + * Returns the number x multiplied by number y + * The supported signature of multiply function is + * (x: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE, y: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE) + * -> wider type between types of x and y + */ + private static DefaultFunctionResolver multiplyBase(FunctionName functionName) { + return FunctionDSL.define(functionName, + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> new ExprByteValue(v1.byteValue() * v2.byteValue())), + BYTE, BYTE, BYTE), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> new ExprShortValue(v1.shortValue() * v2.shortValue())), + SHORT, SHORT, SHORT), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> new ExprIntegerValue(Math.multiplyExact(v1.integerValue(), + v2.integerValue()))), + INTEGER, INTEGER, INTEGER), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> new ExprLongValue( + Math.multiplyExact(v1.longValue(), v2.longValue()))), + LONG, LONG, LONG), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> new ExprFloatValue(v1.floatValue() * v2.floatValue())), + FLOAT, FLOAT, FLOAT), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> new ExprDoubleValue(v1.doubleValue() * v2.doubleValue())), + DOUBLE, DOUBLE, DOUBLE) ); } + + private static DefaultFunctionResolver multiply() { + return multiplyBase(BuiltinFunctionName.MULTIPLY.getName()); + } + + private static DefaultFunctionResolver multiplyFunction() { + return multiplyBase(BuiltinFunctionName.MULTIPLYFUNCTION.getName()); + } + + /** + * Definition of subtract(x, y) function. + * Returns the number x minus number y + * The supported signature of subtract function is + * (x: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE, y: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE) + * -> wider type between types of x and y + */ + private static DefaultFunctionResolver subtractBase(FunctionName functionName) { + return FunctionDSL.define(functionName, + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> new ExprByteValue(v1.byteValue() - v2.byteValue())), + BYTE, BYTE, BYTE), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> new ExprShortValue(v1.shortValue() - v2.shortValue())), + SHORT, SHORT, SHORT), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> new ExprIntegerValue(Math.subtractExact(v1.integerValue(), + v2.integerValue()))), + INTEGER, INTEGER, INTEGER), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> new ExprLongValue( + Math.subtractExact(v1.longValue(), v2.longValue()))), + LONG, LONG, LONG), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> new ExprFloatValue(v1.floatValue() - v2.floatValue())), + FLOAT, FLOAT, FLOAT), + FunctionDSL.impl( + FunctionDSL.nullMissingHandling( + (v1, v2) -> new ExprDoubleValue(v1.doubleValue() - v2.doubleValue())), + DOUBLE, DOUBLE, DOUBLE) + ); + } + + private static DefaultFunctionResolver subtract() { + return subtractBase(BuiltinFunctionName.SUBTRACT.getName()); + } + + private static DefaultFunctionResolver subtractFunction() { + return subtractBase(BuiltinFunctionName.SUBTRACTFUNCTION.getName()); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java index 2fc4b9b608..929839be4d 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java @@ -66,7 +66,6 @@ public static void register(BuiltinFunctionRepository repository) { repository.register(log()); repository.register(log10()); repository.register(log2()); - repository.register(mod()); repository.register(pow()); repository.register(power()); repository.register(rint()); @@ -297,49 +296,6 @@ private static DefaultFunctionResolver log2() { new ExprDoubleValue(Math.log(v.doubleValue()) / Math.log(2)), DOUBLE); } - /** - * Definition of mod(x, y) function. - * Calculate the remainder of x divided by y - * The supported signature of mod function is - * (x: INTEGER/LONG/FLOAT/DOUBLE, y: INTEGER/LONG/FLOAT/DOUBLE) - * -> wider type between types of x and y - */ - private static DefaultFunctionResolver mod() { - return FunctionDSL.define(BuiltinFunctionName.MOD.getName(), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> v2.byteValue() == 0 ? ExprNullValue.of() : - new ExprByteValue(v1.byteValue() % v2.byteValue())), - BYTE, BYTE, BYTE), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : - new ExprShortValue(v1.shortValue() % v2.shortValue())), - SHORT, SHORT, SHORT), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : - new ExprIntegerValue(Math.floorMod(v1.integerValue(), - v2.integerValue()))), - INTEGER, INTEGER, INTEGER), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : - new ExprLongValue(Math.floorMod(v1.longValue(), v2.longValue()))), - LONG, LONG, LONG), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : - new ExprFloatValue(v1.floatValue() % v2.floatValue())), - FLOAT, FLOAT, FLOAT), - FunctionDSL.impl( - FunctionDSL.nullMissingHandling( - (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : - new ExprDoubleValue(v1.doubleValue() % v2.doubleValue())), - DOUBLE, DOUBLE, DOUBLE) - ); - } - /** * Definition of pi() function. * Get the value of pi. diff --git a/core/src/test/java/org/opensearch/sql/expression/ExpressionTestBase.java b/core/src/test/java/org/opensearch/sql/expression/ExpressionTestBase.java index 3d735d6762..8ce7a52394 100644 --- a/core/src/test/java/org/opensearch/sql/expression/ExpressionTestBase.java +++ b/core/src/test/java/org/opensearch/sql/expression/ExpressionTestBase.java @@ -95,8 +95,8 @@ protected Function, FunctionExpression> functionMapping( return (expressions) -> DSL.multiply(expressions.get(0), expressions.get(1)); case DIVIDE: return (expressions) -> DSL.divide(expressions.get(0), expressions.get(1)); - case MODULES: - return (expressions) -> DSL.module(expressions.get(0), expressions.get(1)); + case MODULUS: + return (expressions) -> DSL.modulus(expressions.get(0), expressions.get(1)); default: throw new RuntimeException(); } diff --git a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunctionTest.java index 7c0c6f8a82..fc3c0bce8f 100644 --- a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunctionTest.java @@ -7,6 +7,7 @@ package org.opensearch.sql.expression.operator.arthmetic; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.opensearch.sql.config.TestConfig.INT_TYPE_MISSING_VALUE_FIELD; import static org.opensearch.sql.config.TestConfig.INT_TYPE_NULL_VALUE_FIELD; @@ -14,7 +15,6 @@ import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_NULL; import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; -import static org.opensearch.sql.data.type.ExprCoreType.SHORT; import static org.opensearch.sql.expression.DSL.literal; import static org.opensearch.sql.expression.DSL.ref; @@ -39,6 +39,7 @@ import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.data.type.WideningTypeRule; +import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ExpressionTestBase; @@ -47,7 +48,6 @@ @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class ArithmeticFunctionTest extends ExpressionTestBase { - private static Stream arithmeticFunctionArguments() { List numberOp1 = Arrays.asList(new ExprByteValue(3), new ExprShortValue(3), new ExprIntegerValue(3), new ExprLongValue(3L), new ExprFloatValue(3f), @@ -66,16 +66,6 @@ private static Stream arithmeticOperatorArguments() { BuiltinFunctionName.DIVIDE, BuiltinFunctionName.DIVIDE).map(Arguments::of); } - @ParameterizedTest(name = "add({1}, {2})") - @MethodSource("arithmeticFunctionArguments") - public void add(ExprValue op1, ExprValue op2) { - FunctionExpression expression = DSL.add(literal(op1), literal(op2)); - ExprType expectedType = WideningTypeRule.max(op1.type(), op2.type()); - assertEquals(expectedType, expression.type()); - assertValueEqual(BuiltinFunctionName.ADD, expectedType, op1, op2, expression.valueOf()); - assertEquals(String.format("+(%s, %s)", op1.toString(), op2.toString()), expression.toString()); - } - @ParameterizedTest(name = "{0}(int,null)") @MethodSource("arithmeticOperatorArguments") public void arithmetic_int_null(BuiltinFunctionName builtinFunctionName) { @@ -141,6 +131,28 @@ public void arithmetic_null_missing(BuiltinFunctionName builtinFunctionName) { assertEquals(LITERAL_MISSING, functionExpression.valueOf(valueEnv())); } + @ParameterizedTest(name = "add({1}, {2})") + @MethodSource("arithmeticFunctionArguments") + public void add(ExprValue op1, ExprValue op2) { + FunctionExpression expression = DSL.add(literal(op1), literal(op2)); + ExprType expectedType = WideningTypeRule.max(op1.type(), op2.type()); + assertEquals(expectedType, expression.type()); + assertValueEqual(BuiltinFunctionName.ADD, expectedType, op1, op2, expression.valueOf()); + assertEquals(String.format("+(%s, %s)", op1.toString(), op2.toString()), expression.toString()); + } + + @ParameterizedTest(name = "addFunction({1}, {2})") + @MethodSource("arithmeticFunctionArguments") + public void addFunction(ExprValue op1, ExprValue op2) { + FunctionExpression expression = DSL.addFunction(literal(op1), literal(op2)); + ExprType expectedType = WideningTypeRule.max(op1.type(), op2.type()); + assertEquals(expectedType, expression.type()); + assertValueEqual(BuiltinFunctionName.ADDFUNCTION, + expectedType, op1, op2, expression.valueOf()); + assertEquals(String.format("add(%s, %s)", + op1.toString(), op2.toString()), expression.toString()); + } + @ParameterizedTest(name = "subtract({1}, {2})") @MethodSource("arithmeticFunctionArguments") public void subtract(ExprValue op1, ExprValue op2) { @@ -153,6 +165,64 @@ public void subtract(ExprValue op1, ExprValue op2) { expression.toString()); } + @ParameterizedTest(name = "subtractFunction({1}, {2})") + @MethodSource("arithmeticFunctionArguments") + public void subtractFunction(ExprValue op1, ExprValue op2) { + FunctionExpression expression = DSL.subtractFunction(literal(op1), literal(op2)); + ExprType expectedType = WideningTypeRule.max(op1.type(), op2.type()); + assertEquals(expectedType, expression.type()); + assertValueEqual(BuiltinFunctionName.SUBTRACTFUNCTION, expectedType, op1, op2, + expression.valueOf()); + assertEquals(String.format("subtract(%s, %s)", op1.toString(), op2.toString()), + expression.toString()); + } + + @ParameterizedTest(name = "mod({1}, {2})") + @MethodSource("arithmeticFunctionArguments") + public void mod(ExprValue op1, ExprValue op2) { + FunctionExpression expression = DSL.mod(literal(op1), literal(op2)); + ExprType expectedType = WideningTypeRule.max(op1.type(), op2.type()); + assertEquals(expectedType, expression.type()); + assertValueEqual(BuiltinFunctionName.MOD, expectedType, op1, op2, expression.valueOf()); + assertEquals(String.format("mod(%s, %s)", op1.toString(), op2.toString()), + expression.toString()); + + expression = DSL.mod(literal(op1), literal(new ExprShortValue(0))); + assertTrue(expression.valueOf(valueEnv()).isNull()); + assertEquals(String.format("mod(%s, 0)", op1.toString()), expression.toString()); + } + + @ParameterizedTest(name = "modulus({1}, {2})") + @MethodSource("arithmeticFunctionArguments") + public void modulus(ExprValue op1, ExprValue op2) { + FunctionExpression expression = DSL.modulus(literal(op1), literal(op2)); + ExprType expectedType = WideningTypeRule.max(op1.type(), op2.type()); + assertEquals(expectedType, expression.type()); + assertValueEqual(BuiltinFunctionName.MODULUS, expectedType, op1, op2, expression.valueOf()); + assertEquals(String.format("%%(%s, %s)", op1.toString(), op2.toString()), + expression.toString()); + + expression = DSL.modulus(literal(op1), literal(new ExprShortValue(0))); + assertTrue(expression.valueOf(valueEnv()).isNull()); + assertEquals(String.format("%%(%s, 0)", op1.toString()), expression.toString()); + } + + @ParameterizedTest(name = "modulusFunction({1}, {2})") + @MethodSource("arithmeticFunctionArguments") + public void modulusFunction(ExprValue op1, ExprValue op2) { + FunctionExpression expression = DSL.modulusFunction(literal(op1), literal(op2)); + ExprType expectedType = WideningTypeRule.max(op1.type(), op2.type()); + assertEquals(expectedType, expression.type()); + assertValueEqual(BuiltinFunctionName.MODULUSFUNCTION, + expectedType, op1, op2, expression.valueOf()); + assertEquals(String.format("modulus(%s, %s)", op1.toString(), op2.toString()), + expression.toString()); + + expression = DSL.modulusFunction(literal(op1), literal(new ExprShortValue(0))); + assertTrue(expression.valueOf(valueEnv()).isNull()); + assertEquals(String.format("modulus(%s, 0)", op1.toString()), expression.toString()); + } + @ParameterizedTest(name = "multiply({1}, {2})") @MethodSource("arithmeticFunctionArguments") public void multiply(ExprValue op1, ExprValue op2) { @@ -165,6 +235,18 @@ public void multiply(ExprValue op1, ExprValue op2) { expression.toString()); } + @ParameterizedTest(name = "multiplyFunction({1}, {2})") + @MethodSource("arithmeticFunctionArguments") + public void multiplyFunction(ExprValue op1, ExprValue op2) { + FunctionExpression expression = DSL.multiplyFunction(literal(op1), literal(op2)); + ExprType expectedType = WideningTypeRule.max(op1.type(), op2.type()); + assertEquals(expectedType, expression.type()); + assertValueEqual(BuiltinFunctionName.MULTIPLYFUNCTION, expectedType, op1, op2, + expression.valueOf()); + assertEquals(String.format("multiply(%s, %s)", op1.toString(), op2.toString()), + expression.toString()); + } + @ParameterizedTest(name = "divide({1}, {2})") @MethodSource("arithmeticFunctionArguments") public void divide(ExprValue op1, ExprValue op2) { @@ -176,27 +258,55 @@ public void divide(ExprValue op1, ExprValue op2) { expression.toString()); expression = DSL.divide(literal(op1), literal(new ExprShortValue(0))); - expectedType = WideningTypeRule.max(op1.type(), SHORT); - assertEquals(expectedType, expression.type()); assertTrue(expression.valueOf(valueEnv()).isNull()); assertEquals(String.format("/(%s, 0)", op1.toString()), expression.toString()); } - @ParameterizedTest(name = "module({1}, {2})") + @ParameterizedTest(name = "divideFunction({1}, {2})") @MethodSource("arithmeticFunctionArguments") - public void module(ExprValue op1, ExprValue op2) { - FunctionExpression expression = DSL.module(literal(op1), literal(op2)); + public void divideFunction(ExprValue op1, ExprValue op2) { + FunctionExpression expression = DSL.divideFunction(literal(op1), literal(op2)); ExprType expectedType = WideningTypeRule.max(op1.type(), op2.type()); assertEquals(expectedType, expression.type()); - assertValueEqual(BuiltinFunctionName.MODULES, expectedType, op1, op2, expression.valueOf()); - assertEquals(String.format("%%(%s, %s)", op1.toString(), op2.toString()), - expression.toString()); + assertValueEqual(BuiltinFunctionName.DIVIDEFUNCTION, + expectedType, op1, op2, expression.valueOf()); + assertEquals(String.format("divide(%s, %s)", op1.toString(), op2.toString()), + expression.toString()); - expression = DSL.module(literal(op1), literal(new ExprShortValue(0))); - expectedType = WideningTypeRule.max(op1.type(), SHORT); - assertEquals(expectedType, expression.type()); + expression = DSL.divideFunction(literal(op1), literal(new ExprShortValue(0))); assertTrue(expression.valueOf(valueEnv()).isNull()); - assertEquals(String.format("%%(%s, 0)", op1.toString()), expression.toString()); + assertEquals(String.format("divide(%s, 0)", op1.toString()), expression.toString()); + } + + @ParameterizedTest(name = "multipleParameters({1},{2})") + @MethodSource("arithmeticFunctionArguments") + public void multipleParameters(ExprValue op1) { + assertThrows(ExpressionEvaluationException.class, + () -> DSL.add(literal(op1), literal(op1), literal(op1))); + assertThrows(ExpressionEvaluationException.class, + () -> DSL.addFunction(literal(op1), literal(op1), literal(op1))); + + assertThrows(ExpressionEvaluationException.class, + () -> DSL.subtract(literal(op1), literal(op1), literal(op1))); + assertThrows(ExpressionEvaluationException.class, + () -> DSL.subtractFunction(literal(op1), literal(op1), literal(op1))); + + assertThrows(ExpressionEvaluationException.class, + () -> DSL.multiply(literal(op1), literal(op1), literal(op1))); + assertThrows(ExpressionEvaluationException.class, + () -> DSL.multiplyFunction(literal(op1), literal(op1), literal(op1))); + + assertThrows(ExpressionEvaluationException.class, + () -> DSL.divide(literal(op1), literal(op1), literal(op1))); + assertThrows(ExpressionEvaluationException.class, + () -> DSL.divideFunction(literal(op1), literal(op1), literal(op1))); + + assertThrows(ExpressionEvaluationException.class, + () -> DSL.mod(literal(op1), literal(op1), literal(op1))); + assertThrows(ExpressionEvaluationException.class, + () -> DSL.modulus(literal(op1), literal(op1), literal(op1))); + assertThrows(ExpressionEvaluationException.class, + () -> DSL.modulusFunction(literal(op1), literal(op1), literal(op1))); } protected void assertValueEqual(BuiltinFunctionName builtinFunctionName, ExprType type, @@ -210,19 +320,25 @@ protected void assertValueEqual(BuiltinFunctionName builtinFunctionName, ExprTyp Integer vbActual = actual.integerValue(); switch (builtinFunctionName) { case ADD: + case ADDFUNCTION: assertEquals(vb1 + vb2, vbActual); return; - case SUBTRACT: - assertEquals(vb1 - vb2, vbActual); - return; case DIVIDE: + case DIVIDEFUNCTION: assertEquals(vb1 / vb2, vbActual); return; + case MOD: + case MODULUS: + case MODULUSFUNCTION: + assertEquals(vb1 % vb2, vbActual); + return; case MULTIPLY: + case MULTIPLYFUNCTION: assertEquals(vb1 * vb2, vbActual); return; - case MODULES: - assertEquals(vb1 % vb2, vbActual); + case SUBTRACT: + case SUBTRACTFUNCTION: + assertEquals(vb1 - vb2, vbActual); return; default: throw new IllegalStateException("illegal function name: " + builtinFunctionName); @@ -233,19 +349,25 @@ protected void assertValueEqual(BuiltinFunctionName builtinFunctionName, ExprTyp Integer vsActual = actual.integerValue(); switch (builtinFunctionName) { case ADD: + case ADDFUNCTION: assertEquals(vs1 + vs2, vsActual); return; - case SUBTRACT: - assertEquals(vs1 - vs2, vsActual); - return; case DIVIDE: + case DIVIDEFUNCTION: assertEquals(vs1 / vs2, vsActual); return; + case MOD: + case MODULUS: + case MODULUSFUNCTION: + assertEquals(vs1 % vs2, vsActual); + return; case MULTIPLY: + case MULTIPLYFUNCTION: assertEquals(vs1 * vs2, vsActual); return; - case MODULES: - assertEquals(vs1 % vs2, vsActual); + case SUBTRACT: + case SUBTRACTFUNCTION: + assertEquals(vs1 - vs2, vsActual); return; default: throw new IllegalStateException("illegal function name " + builtinFunctionName); @@ -256,19 +378,25 @@ protected void assertValueEqual(BuiltinFunctionName builtinFunctionName, ExprTyp Integer viActual = ExprValueUtils.getIntegerValue(actual); switch (builtinFunctionName) { case ADD: + case ADDFUNCTION: assertEquals(vi1 + vi2, viActual); return; - case SUBTRACT: - assertEquals(vi1 - vi2, viActual); - return; case DIVIDE: + case DIVIDEFUNCTION: assertEquals(vi1 / vi2, viActual); return; + case MOD: + case MODULUS: + case MODULUSFUNCTION: + assertEquals(vi1 % vi2, viActual); + return; case MULTIPLY: + case MULTIPLYFUNCTION: assertEquals(vi1 * vi2, viActual); return; - case MODULES: - assertEquals(vi1 % vi2, viActual); + case SUBTRACT: + case SUBTRACTFUNCTION: + assertEquals(vi1 - vi2, viActual); return; default: throw new IllegalStateException("illegal function name " + builtinFunctionName); @@ -279,19 +407,25 @@ protected void assertValueEqual(BuiltinFunctionName builtinFunctionName, ExprTyp Long vlActual = ExprValueUtils.getLongValue(actual); switch (builtinFunctionName) { case ADD: + case ADDFUNCTION: assertEquals(vl1 + vl2, vlActual); return; - case SUBTRACT: - assertEquals(vl1 - vl2, vlActual); - return; case DIVIDE: + case DIVIDEFUNCTION: assertEquals(vl1 / vl2, vlActual); return; + case MOD: + case MODULUS: + case MODULUSFUNCTION: + assertEquals(vl1 % vl2, vlActual); + return; case MULTIPLY: + case MULTIPLYFUNCTION: assertEquals(vl1 * vl2, vlActual); return; - case MODULES: - assertEquals(vl1 % vl2, vlActual); + case SUBTRACT: + case SUBTRACTFUNCTION: + assertEquals(vl1 - vl2, vlActual); return; default: throw new IllegalStateException("illegal function name " + builtinFunctionName); @@ -302,19 +436,25 @@ protected void assertValueEqual(BuiltinFunctionName builtinFunctionName, ExprTyp Float vfActual = ExprValueUtils.getFloatValue(actual); switch (builtinFunctionName) { case ADD: + case ADDFUNCTION: assertEquals(vf1 + vf2, vfActual); return; - case SUBTRACT: - assertEquals(vf1 - vf2, vfActual); - return; case DIVIDE: + case DIVIDEFUNCTION: assertEquals(vf1 / vf2, vfActual); return; + case MODULUS: + case MODULUSFUNCTION: + case MOD: + assertEquals(vf1 % vf2, vfActual); + return; case MULTIPLY: + case MULTIPLYFUNCTION: assertEquals(vf1 * vf2, vfActual); return; - case MODULES: - assertEquals(vf1 % vf2, vfActual); + case SUBTRACT: + case SUBTRACTFUNCTION: + assertEquals(vf1 - vf2, vfActual); return; default: throw new IllegalStateException("illegal function name " + builtinFunctionName); @@ -325,19 +465,25 @@ protected void assertValueEqual(BuiltinFunctionName builtinFunctionName, ExprTyp Double vdActual = ExprValueUtils.getDoubleValue(actual); switch (builtinFunctionName) { case ADD: + case ADDFUNCTION: assertEquals(vd1 + vd2, vdActual); return; - case SUBTRACT: - assertEquals(vd1 - vd2, vdActual); - return; case DIVIDE: + case DIVIDEFUNCTION: assertEquals(vd1 / vd2, vdActual); return; + case MOD: + case MODULUS: + case MODULUSFUNCTION: + assertEquals(vd1 % vd2, vdActual); + return; case MULTIPLY: + case MULTIPLYFUNCTION: assertEquals(vd1 * vd2, vdActual); return; - case MODULES: - assertEquals(vd1 % vd2, vdActual); + case SUBTRACT: + case SUBTRACTFUNCTION: + assertEquals(vd1 - vd2, vdActual); return; default: throw new IllegalStateException("illegal function name " + builtinFunctionName); diff --git a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java index 55a5a4fd25..80b4aa7e6e 100644 --- a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java @@ -1216,167 +1216,6 @@ public void log2_missing_value() { assertTrue(log.valueOf(valueEnv()).isMissing()); } - /** - * Test mod with byte value. - */ - @ParameterizedTest(name = "mod({0}, {1})") - @MethodSource("testLogByteArguments") - public void mod_byte_value(Byte v1, Byte v2) { - FunctionExpression mod = DSL.mod(DSL.literal(v1), DSL.literal(v2)); - - assertThat( - mod.valueOf(valueEnv()), - allOf(hasType(BYTE), hasValue(Integer.valueOf(v1 % v2).byteValue()))); - assertEquals(String.format("mod(%s, %s)", v1, v2), mod.toString()); - - mod = DSL.mod(DSL.literal(v1), DSL.literal(new ExprByteValue(0))); - assertEquals(BYTE, mod.type()); - assertTrue(mod.valueOf(valueEnv()).isNull()); - } - - /** - * Test mod with short value. - */ - @ParameterizedTest(name = "mod({0}, {1})") - @MethodSource("testLogShortArguments") - public void mod_short_value(Short v1, Short v2) { - FunctionExpression mod = DSL.mod(DSL.literal(v1), DSL.literal(v2)); - - assertThat( - mod.valueOf(valueEnv()), - allOf(hasType(SHORT), - hasValue(Integer.valueOf(v1 % v2).shortValue()))); - assertEquals(String.format("mod(%s, %s)", v1, v2), mod.toString()); - - mod = DSL.mod(DSL.literal(v1), DSL.literal(new ExprShortValue(0))); - assertEquals(SHORT, mod.type()); - assertTrue(mod.valueOf(valueEnv()).isNull()); - } - - /** - * Test mod with integer value. - */ - @ParameterizedTest(name = "mod({0}, {1})") - @MethodSource("testLogIntegerArguments") - public void mod_int_value(Integer v1, Integer v2) { - FunctionExpression mod = DSL.mod(DSL.literal(v1), DSL.literal(v2)); - assertThat( - mod.valueOf(valueEnv()), - allOf(hasType(INTEGER), hasValue(v1 % v2))); - assertEquals(String.format("mod(%s, %s)", v1, v2), mod.toString()); - - mod = DSL.mod(DSL.literal(v1), DSL.literal(0)); - assertEquals(INTEGER, mod.type()); - assertTrue(mod.valueOf(valueEnv()).isNull()); - } - - /** - * Test mod with long value. - */ - @ParameterizedTest(name = "mod({0}, {1})") - @MethodSource("testLogLongArguments") - public void mod_long_value(Long v1, Long v2) { - FunctionExpression mod = DSL.mod(DSL.literal(v1), DSL.literal(v2)); - assertThat( - mod.valueOf(valueEnv()), - allOf(hasType(LONG), hasValue(v1 % v2))); - assertEquals(String.format("mod(%s, %s)", v1, v2), mod.toString()); - - mod = DSL.mod(DSL.literal(v1), DSL.literal(0)); - assertEquals(LONG, mod.type()); - assertTrue(mod.valueOf(valueEnv()).isNull()); - } - - /** - * Test mod with long value. - */ - @ParameterizedTest(name = "mod({0}, {1})") - @MethodSource("testLogFloatArguments") - public void mod_float_value(Float v1, Float v2) { - FunctionExpression mod = DSL.mod(DSL.literal(v1), DSL.literal(v2)); - assertThat( - mod.valueOf(valueEnv()), - allOf(hasType(FLOAT), hasValue(v1 % v2))); - assertEquals(String.format("mod(%s, %s)", v1, v2), mod.toString()); - - mod = DSL.mod(DSL.literal(v1), DSL.literal(0)); - assertEquals(FLOAT, mod.type()); - assertTrue(mod.valueOf(valueEnv()).isNull()); - } - - /** - * Test mod with double value. - */ - @ParameterizedTest(name = "mod({0}, {1})") - @MethodSource("testLogDoubleArguments") - public void mod_double_value(Double v1, Double v2) { - FunctionExpression mod = DSL.mod(DSL.literal(v1), DSL.literal(v2)); - assertThat( - mod.valueOf(valueEnv()), - allOf(hasType(DOUBLE), hasValue(v1 % v2))); - assertEquals(String.format("mod(%s, %s)", v1, v2), mod.toString()); - - mod = DSL.mod(DSL.literal(v1), DSL.literal(0)); - assertEquals(DOUBLE, mod.type()); - assertTrue(mod.valueOf(valueEnv()).isNull()); - } - - /** - * Test mod with null value. - */ - @Test - public void mod_null_value() { - FunctionExpression mod = DSL.mod(DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER), DSL.literal(1)); - assertEquals(INTEGER, mod.type()); - assertTrue(mod.valueOf(valueEnv()).isNull()); - - mod = DSL.mod(DSL.literal(1), DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); - assertEquals(INTEGER, mod.type()); - assertTrue(mod.valueOf(valueEnv()).isNull()); - - mod = DSL.mod( - DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER), DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); - assertEquals(INTEGER, mod.type()); - assertTrue(mod.valueOf(valueEnv()).isNull()); - } - - /** - * Test mod with missing value. - */ - @Test - public void mod_missing_value() { - FunctionExpression mod = - DSL.mod(DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER), DSL.literal(1)); - assertEquals(INTEGER, mod.type()); - assertTrue(mod.valueOf(valueEnv()).isMissing()); - - mod = DSL.mod(DSL.literal(1), DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); - assertEquals(INTEGER, mod.type()); - assertTrue(mod.valueOf(valueEnv()).isMissing()); - - mod = DSL.mod( - DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER), - DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); - assertEquals(INTEGER, mod.type()); - assertTrue(mod.valueOf(valueEnv()).isMissing()); - } - - /** - * Test mod with null and missing values. - */ - @Test - public void mod_null_missing() { - FunctionExpression mod = DSL.mod(DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER), - DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); - assertEquals(INTEGER, mod.type()); - assertTrue(mod.valueOf(valueEnv()).isMissing()); - - mod = DSL.mod(DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER), - DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); - assertEquals(INTEGER, mod.type()); - assertTrue(mod.valueOf(valueEnv()).isMissing()); - } - /** * Test pow/power with short value. */ diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index 37e6cb937c..178848149e 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -129,10 +129,23 @@ ADD Description >>>>>>>>>>> -Specifications: +Usage: add(x, y) calculates x plus y. + +Argument type: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE -1. ADD(NUMBER T, NUMBER) -> T +Return type: Wider number between x and y +Synonyms: Addition Symbol (+) + +Example:: + + os> SELECT ADD(2, 1), ADD(2.5, 3); + fetched rows / total rows = 1/1 + +-------------+---------------+ + | ADD(2, 1) | ADD(2.5, 3) | + |-------------+---------------| + | 3 | 5.5 | + +-------------+---------------+ ASIN ---- @@ -400,9 +413,23 @@ DIVIDE Description >>>>>>>>>>> -Specifications: +Usage: divide(x, y) calculates x divided by y. + +Argument type: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE -1. DIVIDE(NUMBER T, NUMBER) -> T +Return type: Wider number between x and y + +Synonyms: Division Symbol (/) + +Example:: + + os> SELECT DIVIDE(10, 2), DIVIDE(7.5, 3); + fetched rows / total rows = 1/1 + +-----------------+------------------+ + | DIVIDE(10, 2) | DIVIDE(7.5, 3) | + |-----------------+------------------| + | 5 | 2.5 | + +-----------------+------------------+ E @@ -550,11 +577,13 @@ MOD Description >>>>>>>>>>> -Usage: MOD(n, m) calculates the remainder of the number n divided by m. +Usage: MOD(x, y) calculates the remainder of the number x divided by y. -Argument type: INTEGER/LONG/FLOAT/DOUBLE +Argument type: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE -Return type: Wider type between types of n and m if m is nonzero value. If m equals to 0, then returns NULL. +Return type: Wider number between x and y. If y equals to 0, then returns NULL. + +Synonyms: Modulus Symbol (%), `MODULUS`_ Example:: @@ -566,6 +595,30 @@ Example:: | 1 | 1.1 | +-------------+---------------+ +MODULUS +------- + +Description +>>>>>>>>>>> + +Usage: MODULUS(x, y) calculates the remainder of the number x divided by y. + +Argument type: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE + +Return type: Wider number between x and y. If y equals to 0, then returns NULL. + +Synonyms: Modulus Symbol (%), `MOD`_ + +Example:: + + os> SELECT MODULUS(3, 2), MODULUS(3.1, 2) + fetched rows / total rows = 1/1 + +-----------------+-------------------+ + | MODULUS(3, 2) | MODULUS(3.1, 2) | + |-----------------+-------------------| + | 1 | 1.1 | + +-----------------+-------------------+ + MULTIPLY -------- @@ -573,9 +626,24 @@ MULTIPLY Description >>>>>>>>>>> -Specifications: +Usage: MULTIPLY(x, y) calculates the multiplication of x and y. + +Argument type: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE + +Return type: Wider number between x and y. If y equals to 0, then returns NULL. + +Synonyms: Multiplication Symbol (\*) + +Example:: + + os> SELECT MULTIPLY(1, 2), MULTIPLY(-2, 1), MULTIPLY(1.5, 2); + fetched rows / total rows = 1/1 + +------------------+-------------------+--------------------+ + | MULTIPLY(1, 2) | MULTIPLY(-2, 1) | MULTIPLY(1.5, 2) | + |------------------+-------------------+--------------------| + | 2 | -2 | 3.0 | + +------------------+-------------------+--------------------+ -1. MULTIPLY(NUMBER T, NUMBER) -> NUMBER PI -- @@ -879,9 +947,23 @@ SUBTRACT Description >>>>>>>>>>> -Specifications: +Usage: subtract(x, y) calculates x minus y. -1. SUBTRACT(NUMBER T, NUMBER) -> T +Argument type: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE + +Return type: Wider number between x and y + +Synonyms: Subtraction Symbol (-) + +Example:: + + os> SELECT SUBTRACT(2, 1), SUBTRACT(2.5, 3); + fetched rows / total rows = 1/1 + +------------------+--------------------+ + | SUBTRACT(2, 1) | SUBTRACT(2.5, 3) | + |------------------+--------------------| + | 1 | -0.5 | + +------------------+--------------------+ TAN diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/ArithmeticFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/ArithmeticFunctionIT.java new file mode 100644 index 0000000000..5b6c742e28 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/ArithmeticFunctionIT.java @@ -0,0 +1,246 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.sql; + +import static org.opensearch.sql.legacy.plugin.RestSqlAction.QUERY_API_ENDPOINT; +import static org.opensearch.sql.util.MatcherUtils.rows; +import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; +import static org.opensearch.sql.util.MatcherUtils.verifySchema; +import static org.opensearch.sql.util.TestUtils.getResponseBody; + +import java.io.IOException; +import java.util.Locale; + +import org.json.JSONObject; +import org.junit.jupiter.api.Test; +import org.opensearch.client.Request; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.Response; +import org.opensearch.sql.legacy.SQLIntegTestCase; + +public class ArithmeticFunctionIT extends SQLIntegTestCase { + + @Override + public void init() throws Exception { + super.init(); + loadIndex(Index.BANK); + } + + public void testAdd() throws IOException { + JSONObject result = executeQuery("select 3 + 2"); + verifySchema(result, schema("3 + 2", null, "integer")); + verifyDataRows(result, rows(3 + 2)); + + result = executeQuery("select 2.5 + 2"); + verifySchema(result, schema("2.5 + 2", null, "double")); + verifyDataRows(result, rows(2.5D + 2)); + + result = executeQuery("select 3000000000 + 2"); + verifySchema(result, schema("3000000000 + 2", null, "long")); + verifyDataRows(result, rows(3000000000L + 2)); + + result = executeQuery("select CAST(6.666666 AS FLOAT) + 2"); + verifySchema(result, schema("CAST(6.666666 AS FLOAT) + 2", null, "float")); + verifyDataRows(result, rows(6.666666 + 2)); + } + + @Test + public void testAddFunction() throws IOException { + JSONObject result = executeQuery("select add(3, 2)"); + verifySchema(result, schema("add(3, 2)", null, "integer")); + verifyDataRows(result, rows(3 + 2)); + + result = executeQuery("select add(2.5, 2)"); + verifySchema(result, schema("add(2.5, 2)", null, "double")); + verifyDataRows(result, rows(2.5D + 2)); + + result = executeQuery("select add(3000000000, 2)"); + verifySchema(result, schema("add(3000000000, 2)", null, "long")); + verifyDataRows(result, rows(3000000000L + 2)); + + result = executeQuery("select add(CAST(6.666666 AS FLOAT), 2)"); + verifySchema(result, schema("add(CAST(6.666666 AS FLOAT), 2)", null, "float")); + verifyDataRows(result, rows(6.666666 + 2)); + } + + public void testDivide() throws IOException { + JSONObject result = executeQuery("select 3 / 2"); + verifySchema(result, schema("3 / 2", null, "integer")); + verifyDataRows(result, rows(3 / 2)); + + result = executeQuery("select 2.5 / 2"); + verifySchema(result, schema("2.5 / 2", null, "double")); + verifyDataRows(result, rows(2.5D / 2)); + + result = executeQuery("select 6000000000 / 2"); + verifySchema(result, schema("6000000000 / 2", null, "long")); + verifyDataRows(result, rows(6000000000L / 2)); + + result = executeQuery("select cast(1.6 AS float) / 2"); + verifySchema(result, schema("cast(1.6 AS float) / 2", null, "float")); + verifyDataRows(result, rows(1.6 / 2)); + } + + public void testDivideFunction() throws IOException { + JSONObject result = executeQuery("select divide(3, 2)"); + verifySchema(result, schema("divide(3, 2)", null, "integer")); + verifyDataRows(result, rows(3 / 2)); + + result = executeQuery("select divide(2.5, 2)"); + verifySchema(result, schema("divide(2.5, 2)", null, "double")); + verifyDataRows(result, rows(2.5D / 2)); + + result = executeQuery("select divide(6000000000, 2)"); + verifySchema(result, schema("divide(6000000000, 2)", null, "long")); + verifyDataRows(result, rows(6000000000L / 2)); + + result = executeQuery("select divide(cast(1.6 AS float), 2)"); + verifySchema(result, schema("divide(cast(1.6 AS float), 2)", null, "float")); + verifyDataRows(result, rows(1.6 / 2)); + } + + public void testMod() throws IOException { + JSONObject result = executeQuery("select mod(3, 2)"); + verifySchema(result, schema("mod(3, 2)", null, "integer")); + verifyDataRows(result, rows(3 % 2)); + + result = executeQuery("select mod(2.5, 2)"); + verifySchema(result, schema("mod(2.5, 2)", null, "double")); + verifyDataRows(result, rows(2.5D % 2)); + + result = executeQuery("select mod(cast(300001 as long), 2)"); + verifySchema(result, schema("mod(cast(300001 as long), 2)", null, "long")); + verifyDataRows(result, rows(3000001 % 2)); + + result = executeQuery("select mod(cast(1.6 AS float), 2)"); + verifySchema(result, schema("mod(cast(1.6 AS float), 2)", null, "float")); + verifyDataRows(result, rows(1.6 % 2)); + } + + public void testModulus() throws IOException { + JSONObject result = executeQuery("select 3 % 2"); + verifySchema(result, schema("3 % 2", null, "integer")); + verifyDataRows(result, rows(3 % 2)); + + result = executeQuery("select 2.5 % 2"); + verifySchema(result, schema("2.5 % 2", null, "double")); + verifyDataRows(result, rows(2.5D % 2)); + + result = executeQuery("select cast(300001 as long) % 2"); + verifySchema(result, schema("cast(300001 as long) % 2", null, "long")); + verifyDataRows(result, rows(300001 % 2)); + + result = executeQuery("select cast(1.6 AS float) % 2"); + verifySchema(result, schema("cast(1.6 AS float) % 2", null, "float")); + verifyDataRows(result, rows(1.6 % 2)); + } + + public void testModulusFunction() throws IOException { + JSONObject result = executeQuery("select modulus(3, 2)"); + verifySchema(result, schema("modulus(3, 2)", null, "integer")); + verifyDataRows(result, rows(3 % 2)); + + result = executeQuery("select modulus(2.5, 2)"); + verifySchema(result, schema("modulus(2.5, 2)", null, "double")); + verifyDataRows(result, rows(2.5D % 2)); + + result = executeQuery("select modulus(cast(300001 as long), 2)"); + verifySchema(result, schema("modulus(cast(300001 as long), 2)", null, "long")); + verifyDataRows(result, rows(300001 % 2)); + + result = executeQuery("select modulus(cast(1.6 AS float), 2)"); + verifySchema(result, schema("modulus(cast(1.6 AS float), 2)", null, "float")); + verifyDataRows(result, rows(1.6 % 2)); + } + + public void testMultiply() throws IOException { + JSONObject result = executeQuery("select 3 * 2"); + verifySchema(result, schema("3 * 2", null, "integer")); + verifyDataRows(result, rows(3 * 2)); + + result = executeQuery("select 2.5 * 2"); + verifySchema(result, schema("2.5 * 2", null, "double")); + verifyDataRows(result, rows(2.5D * 2)); + + result = executeQuery("select 3000000000 * 2"); + verifySchema(result, schema("3000000000 * 2", null, "long")); + verifyDataRows(result, rows(3000000000L * 2)); + + result = executeQuery("select CAST(1.6 AS FLOAT) * 2"); + verifySchema(result, schema("CAST(1.6 AS FLOAT) * 2", null, "float")); + verifyDataRows(result, rows(1.6 * 2)); + } + + @Test + public void testMultiplyFunction() throws IOException { + JSONObject result = executeQuery("select multiply(3, 2)"); + verifySchema(result, schema("multiply(3, 2)", null, "integer")); + verifyDataRows(result, rows(3 * 2)); + + result = executeQuery("select multiply(2.5, 2)"); + verifySchema(result, schema("multiply(2.5, 2)", null, "double")); + verifyDataRows(result, rows(2.5D * 2)); + + result = executeQuery("select multiply(3000000000, 2)"); + verifySchema(result, schema("multiply(3000000000, 2)", null, "long")); + verifyDataRows(result, rows(3000000000L * 2)); + + result = executeQuery("select multiply(CAST(1.6 AS FLOAT), 2)"); + verifySchema(result, schema("multiply(CAST(1.6 AS FLOAT), 2)", null, "float")); + verifyDataRows(result, rows(1.6 * 2)); + } + + public void testSubtract() throws IOException { + JSONObject result = executeQuery("select 3 - 2"); + verifySchema(result, schema("3 - 2", null, "integer")); + verifyDataRows(result, rows(3 - 2)); + + result = executeQuery("select 2.5 - 2"); + verifySchema(result, schema("2.5 - 2", null, "double")); + verifyDataRows(result, rows(2.5D - 2)); + + result = executeQuery("select 3000000000 - 2"); + verifySchema(result, schema("3000000000 - 2", null, "long")); + verifyDataRows(result, rows(3000000000L - 2)); + + result = executeQuery("select CAST(6.666666 AS FLOAT) - 2"); + verifySchema(result, schema("CAST(6.666666 AS FLOAT) - 2", null, "float")); + verifyDataRows(result, rows(6.666666 - 2)); + } + + @Test + public void testSubtractFunction() throws IOException { + JSONObject result = executeQuery("select subtract(3, 2)"); + verifySchema(result, schema("subtract(3, 2)", null, "integer")); + verifyDataRows(result, rows(3 - 2)); + + result = executeQuery("select subtract(2.5, 2)"); + verifySchema(result, schema("subtract(2.5, 2)", null, "double")); + verifyDataRows(result, rows(2.5D - 2)); + + result = executeQuery("select subtract(3000000000, 2)"); + verifySchema(result, schema("subtract(3000000000, 2)", null, "long")); + verifyDataRows(result, rows(3000000000L - 2)); + + result = executeQuery("select cast(subtract(cast(6.666666 as float), 2) as float)"); + verifySchema(result, schema("cast(subtract(cast(6.666666 as float), 2) as float)", null, "float")); + verifyDataRows(result, rows(6.666666 - 2)); + } + + protected JSONObject executeQuery(String query) throws IOException { + Request request = new Request("POST", QUERY_API_ENDPOINT); + request.setJsonEntity(String.format(Locale.ROOT, "{\n" + " \"query\": \"%s\"\n" + "}", query)); + + RequestOptions.Builder restOptionsBuilder = RequestOptions.DEFAULT.toBuilder(); + restOptionsBuilder.addHeader("Content-Type", "application/json"); + request.setOptions(restOptionsBuilder); + + Response response = client().performRequest(request); + return new JSONObject(getResponseBody(response)); + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java index 26e482a0d8..b4fcf8e4f9 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java @@ -39,7 +39,7 @@ public void testPI() throws IOException { schema("PI()", null, "double")); verifyDataRows(result, rows(3.141592653589793)); } - + @Test public void testCeil() throws IOException { JSONObject result = executeQuery("select ceil(0)"); diff --git a/sql/src/main/antlr/OpenSearchSQLLexer.g4 b/sql/src/main/antlr/OpenSearchSQLLexer.g4 index e41851bd63..343cefb34b 100644 --- a/sql/src/main/antlr/OpenSearchSQLLexer.g4 +++ b/sql/src/main/antlr/OpenSearchSQLLexer.g4 @@ -207,6 +207,7 @@ DAYOFMONTH: 'DAYOFMONTH'; DAYOFWEEK: 'DAYOFWEEK'; DAYOFYEAR: 'DAYOFYEAR'; DEGREES: 'DEGREES'; +DIVIDE: 'DIVIDE'; E: 'E'; EXP: 'EXP'; EXPM1: 'EXPM1'; @@ -392,7 +393,7 @@ MATCH_BOOL_PREFIX: 'MATCH_BOOL_PREFIX'; // Operators. Arithmetics STAR: '*'; -DIVIDE: '/'; +SLASH: '/'; MODULE: '%'; PLUS: '+'; MINUS: '-'; diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index f132595352..937c213d4c 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -289,7 +289,7 @@ expressionAtom | functionCall #functionCallExpressionAtom | LR_BRACKET expression RR_BRACKET #nestedExpressionAtom | left=expressionAtom - mathOperator=(STAR | DIVIDE | MODULE) + mathOperator=(STAR | SLASH | MODULE) right=expressionAtom #mathExpressionAtom | left=expressionAtom mathOperator=(PLUS | MINUS) @@ -453,12 +453,17 @@ mathematicalFunctionName : ABS | CBRT | CEIL | CEILING | CONV | CRC32 | E | EXP | EXPM1 | FLOOR | LN | LOG | LOG10 | LOG2 | MOD | PI | POW | POWER | RAND | RINT | ROUND | SIGN | SQRT | TRUNCATE | trigonometricFunctionName + | arithmeticFunctionName ; trigonometricFunctionName : ACOS | ASIN | ATAN | ATAN2 | COS | COT | DEGREES | RADIANS | SIN | SINH | TAN ; +arithmeticFunctionName + : ADD | SUBTRACT | MULTIPLY | DIVIDE | MOD | MODULUS + ; + dateTimeFunctionName : datetimeConstantLiteral | ADDDATE