From 84f61f0ad1e55845ac6b7bb184943b51a0e6707a Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Sun, 7 Jun 2020 00:06:16 +0200 Subject: [PATCH] Handle NaN values in unwrap casts --- .../rule/UnwrapCastInComparison.java | 36 +++++++++++++ .../planner/TestUnwrapCastInComparison.java | 51 +++++++++++++++++++ 2 files changed, 87 insertions(+) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/UnwrapCastInComparison.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/UnwrapCastInComparison.java index d4a9fce3029ab..9b1ae732df449 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/UnwrapCastInComparison.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/UnwrapCastInComparison.java @@ -21,6 +21,8 @@ import io.prestosql.spi.PrestoException; import io.prestosql.spi.predicate.Utils; import io.prestosql.spi.type.DecimalType; +import io.prestosql.spi.type.DoubleType; +import io.prestosql.spi.type.RealType; import io.prestosql.spi.type.Type; import io.prestosql.sql.InterpretedFunctionInvoker; import io.prestosql.sql.planner.ExpressionInterpreter; @@ -54,6 +56,8 @@ import static io.prestosql.sql.tree.ComparisonExpression.Operator.LESS_THAN; import static io.prestosql.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; import static io.prestosql.sql.tree.ComparisonExpression.Operator.NOT_EQUAL; +import static java.lang.Float.intBitsToFloat; +import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; /** @@ -182,6 +186,31 @@ private Expression unwrapCast(ComparisonExpression expression) return expression; } + // Handle comparison against NaN. + // It must be done before source type range bounds are compared to target value. + if ((targetType instanceof DoubleType && Double.isNaN((double) right)) || (targetType instanceof RealType && Float.isNaN(intBitsToFloat(toIntExact((long) right))))) { + switch (operator) { + case EQUAL: + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + return falseIfNotNull(cast.getExpression()); + case NOT_EQUAL: + return trueIfNotNull(cast.getExpression()); + case IS_DISTINCT_FROM: + if (!typeHasNaN(sourceType)) { + return TRUE_LITERAL; + } + else { + // NaN on the right of comparison will be cast to source type later + break; + } + default: + throw new UnsupportedOperationException("Not yet implemented: " + operator); + } + } + ResolvedFunction sourceToTarget = metadata.getCoercion(sourceType, targetType); Optional sourceRange = sourceType.getRange(); @@ -189,6 +218,8 @@ private Expression unwrapCast(ComparisonExpression expression) Object max = sourceRange.get().getMax(); Object maxInTargetType = coerce(max, sourceToTarget); + // NaN values of `right` are excluded at this point. Otherwise, NaN would be recognized as + // greater than source type upper bound, and incorrect expression might be derived. int upperBoundComparison = compare(targetType, right, maxInTargetType); if (upperBoundComparison > 0) { // larger than maximum representable value @@ -383,6 +414,11 @@ private Object coerce(Object value, ResolvedFunction coercion) { return functionInvoker.invoke(coercion, session.toConnectorSession(), value); } + + private boolean typeHasNaN(Type type) + { + return type instanceof DoubleType || type instanceof RealType; + } } private static int compare(Type type, Object first, Object second) diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestUnwrapCastInComparison.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestUnwrapCastInComparison.java index 194779b0304d7..bacbd502f34fb 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestUnwrapCastInComparison.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestUnwrapCastInComparison.java @@ -625,6 +625,57 @@ public void testNull() values("A")))); } + @Test + public void testNaN() + { + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a = nan()", + output( + filter("A IS NULL AND NULL", + values("A")))); + + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a < nan()", + output( + filter("A IS NULL AND NULL", + values("A")))); + + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <> nan()", + output( + filter("NOT (A IS NULL) OR NULL", + values("A")))); + + assertPlan( + "SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a IS DISTINCT FROM nan()", + output( + values("A"))); + + assertPlan( + "SELECT * FROM (VALUES REAL '0.0') t(a) WHERE a = nan()", + output( + filter("A IS NULL AND NULL", + values("A")))); + + assertPlan( + "SELECT * FROM (VALUES REAL '0.0') t(a) WHERE a < nan()", + output( + filter("A IS NULL AND NULL", + values("A")))); + + assertPlan( + "SELECT * FROM (VALUES REAL '0.0') t(a) WHERE a <> nan()", + output( + filter("NOT (A IS NULL) OR NULL", + values("A")))); + + assertPlan( + "SELECT * FROM (VALUES REAL '0.0') t(a) WHERE a IS DISTINCT FROM nan()", + output( + filter("A IS DISTINCT FROM CAST(nan() AS REAL)", + values("A")))); + } + @Test public void smokeTests() {