From 7220cfeee3eb7383bd87466a01a848fc1f79c1a8 Mon Sep 17 00:00:00 2001 From: Matthew Wells Date: Wed, 15 Mar 2023 10:41:47 -0700 Subject: [PATCH] Added RINT function to V2 engine (#240) (#1433) Added RINT to V2 engine, updated documentation, added unit and IT tests Signed-off-by: Matthew Wells --- .../org/opensearch/sql/expression/DSL.java | 4 + .../function/BuiltinFunctionName.java | 1 + .../arthmetic/MathematicalFunction.java | 12 +++ .../arthmetic/MathematicalFunctionTest.java | 88 ++++++++++++++++++- docs/user/dql/functions.rst | 16 +++- .../sql/sql/MathematicalFunctionIT.java | 19 ++++ sql/src/main/antlr/OpenSearchSQLParser.g4 | 2 +- 7 files changed, 138 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index a0df685c41..8a19230e3a 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -218,6 +218,10 @@ public static FunctionExpression rand(Expression... expressions) { return compile(FunctionProperties.None, BuiltinFunctionName.RAND, expressions); } + public static FunctionExpression rint(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.RINT, expressions); + } + public static FunctionExpression round(Expression... expressions) { return compile(FunctionProperties.None, BuiltinFunctionName.ROUND, expressions); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 75c2628275..eca7f61ba6 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -39,6 +39,7 @@ public enum BuiltinFunctionName { POW(FunctionName.of("pow")), POWER(FunctionName.of("power")), RAND(FunctionName.of("rand")), + RINT(FunctionName.of("rint")), ROUND(FunctionName.of("round")), SIGN(FunctionName.of("sign")), SINH(FunctionName.of("sinh")), diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java index aafa5b7570..f81e775641 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java @@ -69,6 +69,7 @@ public static void register(BuiltinFunctionRepository repository) { repository.register(mod()); repository.register(pow()); repository.register(power()); + repository.register(rint()); repository.register(round()); repository.register(sign()); repository.register(sinh()); @@ -411,6 +412,17 @@ private static DefaultFunctionResolver rand() { ); } + /** + * Definition of rint(x) function. + * Returns the closest whole integer value to x + * The supported signature is + * BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver rint() { + return baseMathFunction(BuiltinFunctionName.RINT.getName(), + v -> new ExprDoubleValue(Math.rint(v.doubleValue())), DOUBLE); + } + /** * Definition of round(x)/round(x, d) function. * Rounds the argument x to d decimal places, d defaults to 0 if not specified. diff --git a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java index 4d79a3556a..5cf42eb868 100644 --- a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java @@ -642,7 +642,7 @@ public void expm1_short_value(Short value) { } /** - * Test expm1 with short value. + * Test expm1 with byte value. */ @ParameterizedTest(name = "expm1({0})") @ValueSource(bytes = { @@ -1570,6 +1570,92 @@ public void pow_null_missing() { assertTrue(power.valueOf(valueEnv()).isMissing()); } + /** + * Test rint with byte value. + */ + @ParameterizedTest(name = "rint({0})") + @ValueSource(bytes = { + -1, 0, 1, Byte.MAX_VALUE, Byte.MIN_VALUE}) + public void rint_byte_value(Byte value) { + FunctionExpression rint = DSL.rint(DSL.literal(value)); + assertThat( + rint.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.rint(value)))); + assertEquals(String.format("rint(%s)", value), rint.toString()); + } + + /** + * Test rint with short value. + */ + @ParameterizedTest(name = "rint({0})") + @ValueSource(shorts = { + -1, 0, 1, Short.MAX_VALUE, Short.MIN_VALUE}) + public void rint_short_value(Short value) { + FunctionExpression rint = DSL.rint(DSL.literal(value)); + assertThat( + rint.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.rint(value)))); + assertEquals(String.format("rint(%s)", value), rint.toString()); + } + + /** + * Test rint with integer value. + */ + @ParameterizedTest(name = "rint({0})") + @ValueSource(ints = { + -1, 0, 1, Integer.MAX_VALUE, Integer.MIN_VALUE}) + public void rint_int_value(Integer value) { + FunctionExpression rint = DSL.rint(DSL.literal(value)); + assertThat( + rint.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.rint(value)))); + assertEquals(String.format("rint(%s)", value), rint.toString()); + } + + /** + * Test rint with long value. + */ + @ParameterizedTest(name = "rint({0})") + @ValueSource(longs = { + -1L, 0L, 1L, Long.MAX_VALUE, Long.MIN_VALUE}) + public void rint_long_value(Long value) { + FunctionExpression rint = DSL.rint(DSL.literal(value)); + assertThat( + rint.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.rint(value)))); + assertEquals(String.format("rint(%s)", value), rint.toString()); + } + + /** + * Test rint with float value. + */ + @ParameterizedTest(name = "rint({0})") + @ValueSource(floats = { + -1F, -0.75F, -0.5F, 0F, 0.5F, 0.500000001F, + 0.75F, 1F, 1.9999F, 42.42F, Float.MAX_VALUE, Float.MIN_VALUE}) + public void rint_float_value(Float value) { + FunctionExpression rint = DSL.rint(DSL.literal(value)); + assertThat( + rint.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.rint(value)))); + assertEquals(String.format("rint(%s)", value), rint.toString()); + } + + /** + * Test rint with double value. + */ + @ParameterizedTest(name = "rint({0})") + @ValueSource(doubles = { + -1F, -0.75F, -0.5F, 0F, 0.5F, 0.500000001F, + 0.75F, 1F, 1.9999F, 42.42F, Double.MAX_VALUE, Double.MIN_VALUE}) + public void rint_double_value(Double value) { + FunctionExpression rint = DSL.rint(DSL.literal(value)); + assertThat( + rint.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.rint(value)))); + assertEquals(String.format("rint(%s)", value), rint.toString()); + } + /** * Test round with integer value. */ diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index 321f26bc24..131c55be8c 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -712,9 +712,21 @@ RINT Description >>>>>>>>>>> -Specifications: +Usage: RINT(NUMBER T) returns T rounded to the closest whole integer number + +Argument type: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE -1. RINT(NUMBER T) -> T +Return type: DOUBLE + +Example:: + + os> SELECT RINT(1.7); + fetched rows / total rows = 1/1 + +-------------+ + | RINT(1.7) | + |-------------| + | 2.0 | + +-------------+ ROUND diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java index 9a8362e201..fd063bff14 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java @@ -113,6 +113,25 @@ public void testMod() throws IOException { verifyDataRows(result, rows(1.1)); } + @Test + public void testRint() throws IOException { + JSONObject result = executeQuery("select rint(56.78)"); + verifySchema(result, schema("rint(56.78)", null, "double")); + verifyDataRows(result, rows(57.0)); + + result = executeQuery("select rint(-56)"); + verifySchema(result, schema("rint(-56)", null, "double")); + verifyDataRows(result, rows(-56.0)); + + result = executeQuery("select rint(3.5)"); + verifySchema(result, schema("rint(3.5)", null, "double")); + verifyDataRows(result, rows(4.0)); + + result = executeQuery("select rint(-3.5)"); + verifySchema(result, schema("rint(-3.5)", null, "double")); + verifyDataRows(result, rows(-4.0)); + } + @Test public void testRound() throws IOException { JSONObject result = executeQuery("select round(56.78)"); diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index f5a165b8b8..4e296091f4 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -419,7 +419,7 @@ aggregationFunctionName mathematicalFunctionName : ABS | CBRT | CEIL | CEILING | CONV | CRC32 | E | EXP | EXPM1 | FLOOR | LN | LOG | LOG10 | LOG2 | MOD | PI | POW | POWER - | RAND | ROUND | SIGN | SQRT | TRUNCATE + | RAND | RINT | ROUND | SIGN | SQRT | TRUNCATE | trigonometricFunctionName ;