diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java index c186cec0a3..1d467c24f2 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java @@ -342,11 +342,15 @@ FunctionBuilder>>> powerFunctionImpl() { DOUBLE, LONG, LONG), FunctionDSL.impl( FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprDoubleValue(Math.pow(v1.floatValue(), v2.floatValue()))), + (v1, v2) -> v1.floatValue() <= 0 && v2.floatValue() + != Math.floor(v2.floatValue()) ? ExprNullValue.of() : + new ExprDoubleValue(Math.pow(v1.floatValue(), v2.floatValue()))), DOUBLE, FLOAT, FLOAT), FunctionDSL.impl( FunctionDSL.nullMissingHandling( - (v1, v2) -> new ExprDoubleValue(Math.pow(v1.doubleValue(), v2.doubleValue()))), + (v1, v2) -> v1.doubleValue() <= 0 && v2.doubleValue() + != Math.floor(v2.doubleValue()) ? ExprNullValue.of() : + new ExprDoubleValue(Math.pow(v1.doubleValue(), v2.doubleValue()))), DOUBLE, DOUBLE, DOUBLE)); } diff --git a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java index 7099475b64..47cb92c5b8 100644 --- a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java @@ -1409,6 +1409,56 @@ public void pow_null_missing() { assertTrue(power.valueOf(valueEnv()).isMissing()); } + /** + * Test pow/power with null output. + */ + @Test + public void pow_null_output() { + FunctionExpression pow = DSL.pow(DSL.literal((double) -2), DSL.literal(1.5)); + assertEquals(pow.type(), DOUBLE); + assertEquals(String.format("pow(%s, %s)", (double) -2, 1.5), pow.toString()); + assertTrue(pow.valueOf(valueEnv()).isNull()); + + pow = DSL.pow(DSL.literal((float) -2), DSL.literal((float) 1.5)); + assertEquals(pow.type(), DOUBLE); + assertEquals(String.format("pow(%s, %s)", (float) -2, (float) 1.5), pow.toString()); + assertTrue(pow.valueOf(valueEnv()).isNull()); + } + + /** + * Test pow/power with edge cases. + */ + @Test + public void pow_edge_cases() { + FunctionExpression pow = DSL.pow(DSL.literal((double) -2), DSL.literal((double) 2)); + assertEquals(pow.type(), DOUBLE); + assertEquals(String.format("pow(%s, %s)",(double) -2, (double) 2), pow.toString()); + assertThat( + pow.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.pow(-2, 2)))); + + pow = DSL.pow(DSL.literal((double) 2), DSL.literal((double) 1.5)); + assertEquals(pow.type(), DOUBLE); + assertEquals(String.format("pow(%s, %s)", (double) 2, (double) 1.5), pow.toString()); + assertThat( + pow.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.pow(2, 1.5)))); + + pow = DSL.pow(DSL.literal((float) -2), DSL.literal((float) 2)); + assertEquals(pow.type(), DOUBLE); + assertEquals(String.format("pow(%s, %s)", (float) -2, (float) 2), pow.toString()); + assertThat( + pow.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.pow((float) -2, (float) 2)))); + + pow = DSL.pow(DSL.literal((float) 2), DSL.literal((float) 1.5)); + assertEquals(pow.type(), DOUBLE); + assertEquals(String.format("pow(%s, %s)", (float) 2, (float) 1.5), pow.toString()); + assertThat( + pow.valueOf(valueEnv()), + allOf(hasType(DOUBLE), hasValue(Math.pow((float) 2, (float) 1.5)))); + } + /** * Test rint with byte value. */ diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java index c1aedc459f..f1b52900a1 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java @@ -98,6 +98,68 @@ public void testMod() throws IOException { verifyDataRows(result, rows(1.1)); } + @Test + public void testPow() throws IOException { + JSONObject result = executeQuery("select pow(3, 2)"); + verifySchema(result, schema("pow(3, 2)", null, "double")); + verifyDataRows(result, rows(9.0)); + + result = executeQuery("select pow(0, 2)"); + verifySchema(result, schema("pow(0, 2)", null, "double")); + verifyDataRows(result, rows(0.0)); + + result = executeQuery("select pow(3, 0)"); + verifySchema(result, schema("pow(3, 0)", null, "double")); + verifyDataRows(result, rows(1.0)); + + result = executeQuery("select pow(-2, 3)"); + verifySchema(result, schema("pow(-2, 3)", null, "double")); + verifyDataRows(result, rows(-8.0)); + + result = executeQuery("select pow(2, -2)"); + verifySchema(result, schema("pow(2, -2)", null, "double")); + verifyDataRows(result, rows(0.25)); + + result = executeQuery("select pow(-2, -3)"); + verifySchema(result, schema("pow(-2, -3)", null, "double")); + verifyDataRows(result, rows(-0.125)); + + result = executeQuery("select pow(-1, 0.5)"); + verifySchema(result, schema("pow(-1, 0.5)", null, "double")); + verifyDataRows(result, rows((Object) null)); + } + + @Test + public void testPower() throws IOException { + JSONObject result = executeQuery("select power(3, 2)"); + verifySchema(result, schema("power(3, 2)", null, "double")); + verifyDataRows(result, rows(9.0)); + + result = executeQuery("select power(0, 2)"); + verifySchema(result, schema("power(0, 2)", null, "double")); + verifyDataRows(result, rows(0.0)); + + result = executeQuery("select power(3, 0)"); + verifySchema(result, schema("power(3, 0)", null, "double")); + verifyDataRows(result, rows(1.0)); + + result = executeQuery("select power(-2, 3)"); + verifySchema(result, schema("power(-2, 3)", null, "double")); + verifyDataRows(result, rows(-8.0)); + + result = executeQuery("select power(2, -2)"); + verifySchema(result, schema("power(2, -2)", null, "double")); + verifyDataRows(result, rows(0.25)); + + result = executeQuery("select power(2, -2)"); + verifySchema(result, schema("power(2, -2)", null, "double")); + verifyDataRows(result, rows(0.25)); + + result = executeQuery("select power(-2, -3)"); + verifySchema(result, schema("power(-2, -3)", null, "double")); + verifyDataRows(result, rows(-0.125)); + } + @Test public void testRint() throws IOException { JSONObject result = executeQuery("select rint(56.78)");