Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding safe divide function #11904

Merged
merged 7 commits into from
Nov 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 57 additions & 5 deletions core/src/main/java/org/apache/druid/math/expr/Function.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
/**
* Base interface describing the mechanism used to evaluate a {@link FunctionExpr}. All {@link Function} implementations
* are immutable.
*
* <p>
* Do NOT remove "unused" members in this class. They are used by generated Antlr
*/
@SuppressWarnings("unused")
Expand Down Expand Up @@ -1165,6 +1165,51 @@ public <T> ExprVectorProcessor<T> asVectorProcessor(Expr.VectorInputBindingInspe
}
}

class SafeDivide extends BivariateMathFunction
{
public static final String NAME = "safe_divide";

@Override
public String name()
{
return NAME;
}

@Nullable
@Override
public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
{
return ExpressionTypeConversion.function(
args.get(0).getOutputType(inspector),
args.get(1).getOutputType(inspector)
);
}

@Override
protected ExprEval eval(final long x, final long y)
{
if (y == 0) {
if (x != 0) {
return ExprEval.ofLong(NullHandling.defaultLongValue());
}
return ExprEval.ofLong(0);
}
return ExprEval.ofLong(x / y);
}

@Override
protected ExprEval eval(final double x, final double y)
{
if (y == 0 || Double.isNaN(y)) {
if (x != 0) {
return ExprEval.ofDouble(NullHandling.defaultDoubleValue());
}
return ExprEval.ofDouble(0);
}
return ExprEval.ofDouble(x / y);
}
}

class Div extends BivariateMathFunction
{
@Override
Expand Down Expand Up @@ -1932,7 +1977,9 @@ protected ExprEval eval(ExprEval x, ExprEval y)
public Set<Expr> getScalarInputs(List<Expr> args)
{
if (args.get(1).isLiteral()) {
ExpressionType castTo = ExpressionType.fromString(StringUtils.toUpperCase(args.get(1).getLiteralValue().toString()));
ExpressionType castTo = ExpressionType.fromString(StringUtils.toUpperCase(args.get(1)
.getLiteralValue()
.toString()));
switch (castTo.getType()) {
case ARRAY:
return Collections.emptySet();
Expand All @@ -1948,7 +1995,9 @@ public Set<Expr> getScalarInputs(List<Expr> args)
public Set<Expr> getArrayInputs(List<Expr> args)
{
if (args.get(1).isLiteral()) {
ExpressionType castTo = ExpressionType.fromString(StringUtils.toUpperCase(args.get(1).getLiteralValue().toString()));
ExpressionType castTo = ExpressionType.fromString(StringUtils.toUpperCase(args.get(1)
.getLiteralValue()
.toString()));
switch (castTo.getType()) {
case LONG:
case DOUBLE:
Expand Down Expand Up @@ -3237,7 +3286,9 @@ ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr)
break;
}
}
return index < 0 ? ExprEval.ofLong(NullHandling.replaceWithDefault() ? -1 : null) : ExprEval.ofLong(index + 1);
return index < 0
? ExprEval.ofLong(NullHandling.replaceWithDefault() ? -1 : null)
: ExprEval.ofLong(index + 1);
default:
throw new IAE("Function[%s] 2nd argument must be a a scalar type", name());
}
Expand Down Expand Up @@ -3591,7 +3642,8 @@ public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
name()
);
}
ExpressionType complexType = ExpressionTypeFactory.getInstance().ofComplex((String) args.get(0).getLiteralValue());
ExpressionType complexType = ExpressionTypeFactory.getInstance()
.ofComplex((String) args.get(0).getLiteralValue());
ObjectByteStrategy strategy = Types.getStrategy(complexType.getComplexTypeName());
if (strategy == null) {
throw new IAE(
Expand Down
141 changes: 103 additions & 38 deletions core/src/test/java/org/apache/druid/math/expr/FunctionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,30 +52,36 @@ public class FunctionTest extends InitializedNullHandlingTest
@BeforeClass
public static void setupClass()
{
Types.registerStrategy(TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(), new TypesTest.PairObjectByteStrategy());
Types.registerStrategy(
TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(),
new TypesTest.PairObjectByteStrategy()
);
}

@Before
public void setup()
{
ImmutableMap.Builder<String, Object> builder = ImmutableMap.<String, Object>builder()
.put("x", "foo")
.put("y", 2)
.put("z", 3.1)
.put("d", 34.56D)
.put("maxLong", Long.MAX_VALUE)
.put("minLong", Long.MIN_VALUE)
.put("f", 12.34F)
.put("nan", Double.NaN)
.put("inf", Double.POSITIVE_INFINITY)
.put("-inf", Double.NEGATIVE_INFINITY)
.put("o", 0)
.put("od", 0D)
.put("of", 0F)
.put("a", new String[] {"foo", "bar", "baz", "foobar"})
.put("b", new Long[] {1L, 2L, 3L, 4L, 5L})
.put("c", new Double[] {3.1, 4.2, 5.3})
.put("someComplex", new TypesTest.NullableLongPair(1L, 2L));
.put("x", "foo")
.put("y", 2)
.put("z", 3.1)
.put("d", 34.56D)
.put("maxLong", Long.MAX_VALUE)
.put("minLong", Long.MIN_VALUE)
.put("f", 12.34F)
.put("nan", Double.NaN)
.put("inf", Double.POSITIVE_INFINITY)
.put("-inf", Double.NEGATIVE_INFINITY)
.put("o", 0)
.put("od", 0D)
.put("of", 0F)
.put("a", new String[]{"foo", "bar", "baz", "foobar"})
.put("b", new Long[]{1L, 2L, 3L, 4L, 5L})
.put("c", new Double[]{3.1, 4.2, 5.3})
.put(
"someComplex",
new TypesTest.NullableLongPair(1L, 2L)
);
bindings = InputBindings.withMap(builder.build());
}

Expand Down Expand Up @@ -350,17 +356,20 @@ public void testArrayCast()
assertArrayExpr("cast([1, 2, 3], 'STRING_ARRAY')", new String[]{"1", "2", "3"});
assertArrayExpr("cast([1, 2, 3], 'DOUBLE_ARRAY')", new Double[]{1.0, 2.0, 3.0});
assertArrayExpr("cast(c, 'LONG_ARRAY')", new Long[]{3L, 4L, 5L});
assertArrayExpr("cast(string_to_array(array_to_string(b, ','), ','), 'LONG_ARRAY')", new Long[]{1L, 2L, 3L, 4L, 5L});
assertArrayExpr(
"cast(string_to_array(array_to_string(b, ','), ','), 'LONG_ARRAY')",
new Long[]{1L, 2L, 3L, 4L, 5L}
);
assertArrayExpr("cast(['1.0', '2.0', '3.0'], 'LONG_ARRAY')", new Long[]{1L, 2L, 3L});
}

@Test
public void testArraySlice()
{
assertArrayExpr("array_slice([1, 2, 3, 4], 1, 3)", new Long[] {2L, 3L});
assertArrayExpr("array_slice([1.0, 2.1, 3.2, 4.3], 2)", new Double[] {3.2, 4.3});
assertArrayExpr("array_slice(['a', 'b', 'c', 'd'], 4, 6)", new String[] {null, null});
assertArrayExpr("array_slice([1, 2, 3, 4], 2, 2)", new Long[] {});
assertArrayExpr("array_slice([1, 2, 3, 4], 1, 3)", new Long[]{2L, 3L});
assertArrayExpr("array_slice([1.0, 2.1, 3.2, 4.3], 2)", new Double[]{3.2, 4.3});
assertArrayExpr("array_slice(['a', 'b', 'c', 'd'], 4, 6)", new String[]{null, null});
assertArrayExpr("array_slice([1, 2, 3, 4], 2, 2)", new Long[]{});
assertArrayExpr("array_slice([1, 2, 3, 4], 5, 7)", null);
assertArrayExpr("array_slice([1, 2, 3, 4], 2, 1)", null);
}
Expand Down Expand Up @@ -438,12 +447,24 @@ public void testRoundWithExtremeNumbers()
assertExpr("round(maxLong)", BigDecimal.valueOf(Long.MAX_VALUE).setScale(0, RoundingMode.HALF_UP).longValue());
assertExpr("round(minLong)", BigDecimal.valueOf(Long.MIN_VALUE).setScale(0, RoundingMode.HALF_UP).longValue());
// overflow
assertExpr("round(maxLong + 1, 1)", BigDecimal.valueOf(Long.MIN_VALUE).setScale(1, RoundingMode.HALF_UP).longValue());
assertExpr(
"round(maxLong + 1, 1)",
BigDecimal.valueOf(Long.MIN_VALUE).setScale(1, RoundingMode.HALF_UP).longValue()
);
// underflow
assertExpr("round(minLong - 1, -2)", BigDecimal.valueOf(Long.MAX_VALUE).setScale(-2, RoundingMode.HALF_UP).longValue());
assertExpr(
"round(minLong - 1, -2)",
BigDecimal.valueOf(Long.MAX_VALUE).setScale(-2, RoundingMode.HALF_UP).longValue()
);

assertExpr("round(CAST(maxLong, 'DOUBLE') + 1, 1)", BigDecimal.valueOf(((double) Long.MAX_VALUE) + 1).setScale(1, RoundingMode.HALF_UP).doubleValue());
assertExpr("round(CAST(minLong, 'DOUBLE') - 1, -2)", BigDecimal.valueOf(((double) Long.MIN_VALUE) - 1).setScale(-2, RoundingMode.HALF_UP).doubleValue());
assertExpr(
"round(CAST(maxLong, 'DOUBLE') + 1, 1)",
BigDecimal.valueOf(((double) Long.MAX_VALUE) + 1).setScale(1, RoundingMode.HALF_UP).doubleValue()
);
assertExpr(
"round(CAST(minLong, 'DOUBLE') - 1, -2)",
BigDecimal.valueOf(((double) Long.MIN_VALUE) - 1).setScale(-2, RoundingMode.HALF_UP).doubleValue()
);
}

@Test
Expand Down Expand Up @@ -643,7 +664,10 @@ public void testSizeForatInvalidArgumentType()
Assert.assertTrue(NullHandling.sqlCompatible() ? true : false);
}
catch (IAE e) {
Assert.assertEquals("Function[human_readable_binary_byte_format] needs a number as its first argument", e.getMessage());
Assert.assertEquals(
"Function[human_readable_binary_byte_format] needs a number as its first argument",
e.getMessage()
);
}

try {
Expand All @@ -655,7 +679,10 @@ public void testSizeForatInvalidArgumentType()
Assert.assertTrue(false);
}
catch (IAE e) {
Assert.assertEquals("Function[human_readable_binary_byte_format] needs an integer as its second argument", e.getMessage());
Assert.assertEquals(
"Function[human_readable_binary_byte_format] needs an integer as its second argument",
e.getMessage()
);
}

try {
Expand All @@ -667,7 +694,10 @@ public void testSizeForatInvalidArgumentType()
Assert.assertTrue(false);
}
catch (IAE e) {
Assert.assertEquals("Function[human_readable_binary_byte_format] needs an integer as its second argument", e.getMessage());
Assert.assertEquals(
"Function[human_readable_binary_byte_format] needs an integer as its second argument",
e.getMessage()
);
}

try {
Expand All @@ -679,7 +709,10 @@ public void testSizeForatInvalidArgumentType()
Assert.assertTrue(false);
}
catch (IAE e) {
Assert.assertEquals("Function[human_readable_binary_byte_format] needs an integer as its second argument", e.getMessage());
Assert.assertEquals(
"Function[human_readable_binary_byte_format] needs an integer as its second argument",
e.getMessage()
);
}
}

Expand All @@ -692,7 +725,10 @@ public void testSizeFormatInvalidPrecision()
Assert.assertTrue(false);
}
catch (IAE e) {
Assert.assertEquals("Given precision[9223372036854775807] of Function[human_readable_binary_byte_format] must be in the range of [0,3]", e.getMessage());
Assert.assertEquals(
"Given precision[9223372036854775807] of Function[human_readable_binary_byte_format] must be in the range of [0,3]",
e.getMessage()
);
}

try {
Expand All @@ -701,7 +737,10 @@ public void testSizeFormatInvalidPrecision()
Assert.assertTrue(false);
}
catch (IAE e) {
Assert.assertEquals("Given precision[-9223372036854775808] of Function[human_readable_binary_byte_format] must be in the range of [0,3]", e.getMessage());
Assert.assertEquals(
"Given precision[-9223372036854775808] of Function[human_readable_binary_byte_format] must be in the range of [0,3]",
e.getMessage()
);
}

try {
Expand All @@ -710,7 +749,10 @@ public void testSizeFormatInvalidPrecision()
Assert.assertTrue(false);
}
catch (IAE e) {
Assert.assertEquals("Given precision[-1] of Function[human_readable_binary_byte_format] must be in the range of [0,3]", e.getMessage());
Assert.assertEquals(
"Given precision[-1] of Function[human_readable_binary_byte_format] must be in the range of [0,3]",
e.getMessage()
);
}

try {
Expand All @@ -719,7 +761,10 @@ public void testSizeFormatInvalidPrecision()
Assert.assertTrue(false);
}
catch (IAE e) {
Assert.assertEquals("Given precision[4] of Function[human_readable_binary_byte_format] must be in the range of [0,3]", e.getMessage());
Assert.assertEquals(
"Given precision[4] of Function[human_readable_binary_byte_format] must be in the range of [0,3]",
e.getMessage()
);
}
}

Expand All @@ -732,6 +777,21 @@ public void testSizeFormatInvalidArgumentSize()
.eval(bindings);
}

@Test
public void testSafeDivide()
{
// happy path maths
assertExpr("safe_divide(3, 1)", 3L);
assertExpr("safe_divide(4.5, 2)", 2.25);
assertExpr("safe_divide(3, 0)", NullHandling.defaultLongValue());
assertExpr("safe_divide(1, 0.0)", NullHandling.defaultDoubleValue());
// NaN and Infinity cases
assertExpr("safe_divide(NaN, 0.0)", NullHandling.defaultDoubleValue());
assertExpr("safe_divide(0, NaN)", 0.0);
assertExpr("safe_divide(0, POSITIVE_INFINITY)", NullHandling.defaultLongValue());
assertExpr("safe_divide(POSITIVE_INFINITY,0)", NullHandling.defaultLongValue());
}

@Test
public void testBitwise()
{
Expand Down Expand Up @@ -763,7 +823,10 @@ public void testBitwise()
Assert.fail("Did not throw IllegalArgumentException");
}
catch (IllegalArgumentException e) {
Assert.assertEquals("Possible data truncation, param [461168601842738800000000000000.000000] is out of long value range", e.getMessage());
Assert.assertEquals(
"Possible data truncation, param [461168601842738800000000000000.000000] is out of long value range",
e.getMessage()
);
}

// doubles are cast
Expand Down Expand Up @@ -845,7 +908,8 @@ public void testComplexDecodeBaseWrongArgCount()
public void testComplexDecodeBaseArg0BadType()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Function[complex_decode_base64] first argument must be constant 'STRING' expression containing a valid complex type name");
expectedException.expectMessage(
"Function[complex_decode_base64] first argument must be constant 'STRING' expression containing a valid complex type name");
assertExpr(
"complex_decode_base64(1, string)",
null
Expand All @@ -856,7 +920,8 @@ public void testComplexDecodeBaseArg0BadType()
public void testComplexDecodeBaseArg0Unknown()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Function[complex_decode_base64] first argument must be a valid complex type name, unknown complex type [COMPLEX<unknown>]");
expectedException.expectMessage(
"Function[complex_decode_base64] first argument must be a valid complex type name, unknown complex type [COMPLEX<unknown>]");
assertExpr(
"complex_decode_base64('unknown', string)",
null
Expand Down
Loading