diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java index 7130ac105a..1f4ac3943c 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java @@ -106,7 +106,8 @@ private static DefaultFunctionResolver addFunction() { private static DefaultFunctionResolver divideBase(FunctionName functionName) { return define(functionName, impl(nullMissingHandling( - (v1, v2) -> new ExprByteValue(v1.byteValue() / v2.byteValue())), + (v1, v2) -> v2.byteValue() == 0 ? ExprNullValue.of() : + new ExprByteValue(v1.byteValue() / v2.byteValue())), BYTE, BYTE, BYTE), impl(nullMissingHandling( (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : @@ -140,7 +141,7 @@ private static DefaultFunctionResolver divideFunction() { } /** - * Definition of modulo(x, y) function. + * Definition of modulus(x, y) function. * Returns the number x modulo by number y * The supported signature of modulo function is * (x: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE, y: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE) @@ -149,7 +150,8 @@ private static DefaultFunctionResolver divideFunction() { private static DefaultFunctionResolver modulusBase(FunctionName functionName) { return define(functionName, impl(nullMissingHandling( - (v1, v2) -> new ExprByteValue(v1.byteValue() % v2.byteValue())), + (v1, v2) -> v2.byteValue() == 0 ? ExprNullValue.of() : + new ExprByteValue(v1.byteValue() % v2.byteValue())), BYTE, BYTE, BYTE), impl(nullMissingHandling( (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : diff --git a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunctionTest.java index b28bea8b89..028ace6231 100644 --- a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunctionTest.java @@ -113,7 +113,7 @@ public void mod(ExprValue op1, ExprValue op2) { assertEquals(String.format("mod(%s, %s)", op1.toString(), op2.toString()), expression.toString()); - expression = DSL.mod(literal(op1), literal(new ExprShortValue(0))); + expression = DSL.mod(literal(op1), literal(new ExprByteValue(0))); assertTrue(expression.valueOf(valueEnv()).isNull()); assertEquals(String.format("mod(%s, 0)", op1.toString()), expression.toString()); } @@ -128,7 +128,7 @@ public void modulus(ExprValue op1, ExprValue op2) { assertEquals(String.format("%%(%s, %s)", op1.toString(), op2.toString()), expression.toString()); - expression = DSL.modulus(literal(op1), literal(new ExprShortValue(0))); + expression = DSL.modulus(literal(op1), literal(new ExprByteValue(0))); assertTrue(expression.valueOf(valueEnv()).isNull()); assertEquals(String.format("%%(%s, 0)", op1.toString()), expression.toString()); } @@ -144,7 +144,7 @@ public void modulusFunction(ExprValue op1, ExprValue op2) { assertEquals(String.format("modulus(%s, %s)", op1.toString(), op2.toString()), expression.toString()); - expression = DSL.modulusFunction(literal(op1), literal(new ExprShortValue(0))); + expression = DSL.modulusFunction(literal(op1), literal(new ExprByteValue(0))); assertTrue(expression.valueOf(valueEnv()).isNull()); assertEquals(String.format("modulus(%s, 0)", op1.toString()), expression.toString()); } @@ -183,7 +183,7 @@ public void divide(ExprValue op1, ExprValue op2) { assertEquals(String.format("/(%s, %s)", op1.toString(), op2.toString()), expression.toString()); - expression = DSL.divide(literal(op1), literal(new ExprShortValue(0))); + expression = DSL.divide(literal(op1), literal(new ExprByteValue(0))); assertTrue(expression.valueOf(valueEnv()).isNull()); assertEquals(String.format("/(%s, 0)", op1.toString()), expression.toString()); } @@ -199,7 +199,7 @@ public void divideFunction(ExprValue op1, ExprValue op2) { assertEquals(String.format("divide(%s, %s)", op1.toString(), op2.toString()), expression.toString()); - expression = DSL.divideFunction(literal(op1), literal(new ExprShortValue(0))); + expression = DSL.divideFunction(literal(op1), literal(new ExprByteValue(0))); assertTrue(expression.valueOf(valueEnv()).isNull()); assertEquals(String.format("divide(%s, 0)", op1.toString()), expression.toString()); }