Skip to content

Commit

Permalink
Added RINT function to V2 engine (#240) (#1433)
Browse files Browse the repository at this point in the history
Added RINT to V2 engine, updated documentation, added unit and IT tests

Signed-off-by: Matthew Wells <[email protected]>
  • Loading branch information
matthewryanwells authored Mar 15, 2023
1 parent 8dad71f commit 7220cfe
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 4 deletions.
4 changes: 4 additions & 0 deletions core/src/main/java/org/opensearch/sql/expression/DSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ public static FunctionExpression rand(Expression... expressions) {
return compile(FunctionProperties.None, BuiltinFunctionName.RAND, expressions);
}

public static FunctionExpression rint(Expression... expressions) {
return compile(FunctionProperties.None, BuiltinFunctionName.RINT, expressions);
}

public static FunctionExpression round(Expression... expressions) {
return compile(FunctionProperties.None, BuiltinFunctionName.ROUND, expressions);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public enum BuiltinFunctionName {
POW(FunctionName.of("pow")),
POWER(FunctionName.of("power")),
RAND(FunctionName.of("rand")),
RINT(FunctionName.of("rint")),
ROUND(FunctionName.of("round")),
SIGN(FunctionName.of("sign")),
SINH(FunctionName.of("sinh")),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ public static void register(BuiltinFunctionRepository repository) {
repository.register(mod());
repository.register(pow());
repository.register(power());
repository.register(rint());
repository.register(round());
repository.register(sign());
repository.register(sinh());
Expand Down Expand Up @@ -411,6 +412,17 @@ private static DefaultFunctionResolver rand() {
);
}

/**
* Definition of rint(x) function.
* Returns the closest whole integer value to x
* The supported signature is
* BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE
*/
private static DefaultFunctionResolver rint() {
return baseMathFunction(BuiltinFunctionName.RINT.getName(),
v -> new ExprDoubleValue(Math.rint(v.doubleValue())), DOUBLE);
}

/**
* Definition of round(x)/round(x, d) function.
* Rounds the argument x to d decimal places, d defaults to 0 if not specified.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ public void expm1_short_value(Short value) {
}

/**
* Test expm1 with short value.
* Test expm1 with byte value.
*/
@ParameterizedTest(name = "expm1({0})")
@ValueSource(bytes = {
Expand Down Expand Up @@ -1570,6 +1570,92 @@ public void pow_null_missing() {
assertTrue(power.valueOf(valueEnv()).isMissing());
}

/**
* Test rint with byte value.
*/
@ParameterizedTest(name = "rint({0})")
@ValueSource(bytes = {
-1, 0, 1, Byte.MAX_VALUE, Byte.MIN_VALUE})
public void rint_byte_value(Byte value) {
FunctionExpression rint = DSL.rint(DSL.literal(value));
assertThat(
rint.valueOf(valueEnv()),
allOf(hasType(DOUBLE), hasValue(Math.rint(value))));
assertEquals(String.format("rint(%s)", value), rint.toString());
}

/**
* Test rint with short value.
*/
@ParameterizedTest(name = "rint({0})")
@ValueSource(shorts = {
-1, 0, 1, Short.MAX_VALUE, Short.MIN_VALUE})
public void rint_short_value(Short value) {
FunctionExpression rint = DSL.rint(DSL.literal(value));
assertThat(
rint.valueOf(valueEnv()),
allOf(hasType(DOUBLE), hasValue(Math.rint(value))));
assertEquals(String.format("rint(%s)", value), rint.toString());
}

/**
* Test rint with integer value.
*/
@ParameterizedTest(name = "rint({0})")
@ValueSource(ints = {
-1, 0, 1, Integer.MAX_VALUE, Integer.MIN_VALUE})
public void rint_int_value(Integer value) {
FunctionExpression rint = DSL.rint(DSL.literal(value));
assertThat(
rint.valueOf(valueEnv()),
allOf(hasType(DOUBLE), hasValue(Math.rint(value))));
assertEquals(String.format("rint(%s)", value), rint.toString());
}

/**
* Test rint with long value.
*/
@ParameterizedTest(name = "rint({0})")
@ValueSource(longs = {
-1L, 0L, 1L, Long.MAX_VALUE, Long.MIN_VALUE})
public void rint_long_value(Long value) {
FunctionExpression rint = DSL.rint(DSL.literal(value));
assertThat(
rint.valueOf(valueEnv()),
allOf(hasType(DOUBLE), hasValue(Math.rint(value))));
assertEquals(String.format("rint(%s)", value), rint.toString());
}

/**
* Test rint with float value.
*/
@ParameterizedTest(name = "rint({0})")
@ValueSource(floats = {
-1F, -0.75F, -0.5F, 0F, 0.5F, 0.500000001F,
0.75F, 1F, 1.9999F, 42.42F, Float.MAX_VALUE, Float.MIN_VALUE})
public void rint_float_value(Float value) {
FunctionExpression rint = DSL.rint(DSL.literal(value));
assertThat(
rint.valueOf(valueEnv()),
allOf(hasType(DOUBLE), hasValue(Math.rint(value))));
assertEquals(String.format("rint(%s)", value), rint.toString());
}

/**
* Test rint with double value.
*/
@ParameterizedTest(name = "rint({0})")
@ValueSource(doubles = {
-1F, -0.75F, -0.5F, 0F, 0.5F, 0.500000001F,
0.75F, 1F, 1.9999F, 42.42F, Double.MAX_VALUE, Double.MIN_VALUE})
public void rint_double_value(Double value) {
FunctionExpression rint = DSL.rint(DSL.literal(value));
assertThat(
rint.valueOf(valueEnv()),
allOf(hasType(DOUBLE), hasValue(Math.rint(value))));
assertEquals(String.format("rint(%s)", value), rint.toString());
}

/**
* Test round with integer value.
*/
Expand Down
16 changes: 14 additions & 2 deletions docs/user/dql/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -712,9 +712,21 @@ RINT
Description
>>>>>>>>>>>

Specifications:
Usage: RINT(NUMBER T) returns T rounded to the closest whole integer number

Argument type: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE

1. RINT(NUMBER T) -> T
Return type: DOUBLE

Example::

os> SELECT RINT(1.7);
fetched rows / total rows = 1/1
+-------------+
| RINT(1.7) |
|-------------|
| 2.0 |
+-------------+


ROUND
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,25 @@ public void testMod() throws IOException {
verifyDataRows(result, rows(1.1));
}

@Test
public void testRint() throws IOException {
JSONObject result = executeQuery("select rint(56.78)");
verifySchema(result, schema("rint(56.78)", null, "double"));
verifyDataRows(result, rows(57.0));

result = executeQuery("select rint(-56)");
verifySchema(result, schema("rint(-56)", null, "double"));
verifyDataRows(result, rows(-56.0));

result = executeQuery("select rint(3.5)");
verifySchema(result, schema("rint(3.5)", null, "double"));
verifyDataRows(result, rows(4.0));

result = executeQuery("select rint(-3.5)");
verifySchema(result, schema("rint(-3.5)", null, "double"));
verifyDataRows(result, rows(-4.0));
}

@Test
public void testRound() throws IOException {
JSONObject result = executeQuery("select round(56.78)");
Expand Down
2 changes: 1 addition & 1 deletion sql/src/main/antlr/OpenSearchSQLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ aggregationFunctionName

mathematicalFunctionName
: ABS | CBRT | CEIL | CEILING | CONV | CRC32 | E | EXP | EXPM1 | FLOOR | LN | LOG | LOG10 | LOG2 | MOD | PI | POW | POWER
| RAND | ROUND | SIGN | SQRT | TRUNCATE
| RAND | RINT | ROUND | SIGN | SQRT | TRUNCATE
| trigonometricFunctionName
;

Expand Down

0 comments on commit 7220cfe

Please sign in to comment.