Skip to content

Commit

Permalink
Handle NaN values in unwrap casts
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiafi authored and martint committed Jun 7, 2020
1 parent e3d798d commit 84f61f0
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.predicate.Utils;
import io.prestosql.spi.type.DecimalType;
import io.prestosql.spi.type.DoubleType;
import io.prestosql.spi.type.RealType;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.InterpretedFunctionInvoker;
import io.prestosql.sql.planner.ExpressionInterpreter;
Expand Down Expand Up @@ -54,6 +56,8 @@
import static io.prestosql.sql.tree.ComparisonExpression.Operator.LESS_THAN;
import static io.prestosql.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
import static io.prestosql.sql.tree.ComparisonExpression.Operator.NOT_EQUAL;
import static java.lang.Float.intBitsToFloat;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

/**
Expand Down Expand Up @@ -182,13 +186,40 @@ private Expression unwrapCast(ComparisonExpression expression)
return expression;
}

// Handle comparison against NaN.
// It must be done before source type range bounds are compared to target value.
if ((targetType instanceof DoubleType && Double.isNaN((double) right)) || (targetType instanceof RealType && Float.isNaN(intBitsToFloat(toIntExact((long) right))))) {
switch (operator) {
case EQUAL:
case GREATER_THAN:
case GREATER_THAN_OR_EQUAL:
case LESS_THAN:
case LESS_THAN_OR_EQUAL:
return falseIfNotNull(cast.getExpression());
case NOT_EQUAL:
return trueIfNotNull(cast.getExpression());
case IS_DISTINCT_FROM:
if (!typeHasNaN(sourceType)) {
return TRUE_LITERAL;
}
else {
// NaN on the right of comparison will be cast to source type later
break;
}
default:
throw new UnsupportedOperationException("Not yet implemented: " + operator);
}
}

ResolvedFunction sourceToTarget = metadata.getCoercion(sourceType, targetType);

Optional<Type.Range> sourceRange = sourceType.getRange();
if (sourceRange.isPresent()) {
Object max = sourceRange.get().getMax();
Object maxInTargetType = coerce(max, sourceToTarget);

// NaN values of `right` are excluded at this point. Otherwise, NaN would be recognized as
// greater than source type upper bound, and incorrect expression might be derived.
int upperBoundComparison = compare(targetType, right, maxInTargetType);
if (upperBoundComparison > 0) {
// larger than maximum representable value
Expand Down Expand Up @@ -383,6 +414,11 @@ private Object coerce(Object value, ResolvedFunction coercion)
{
return functionInvoker.invoke(coercion, session.toConnectorSession(), value);
}

private boolean typeHasNaN(Type type)
{
return type instanceof DoubleType || type instanceof RealType;
}
}

private static int compare(Type type, Object first, Object second)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,57 @@ public void testNull()
values("A"))));
}

@Test
public void testNaN()
{
assertPlan(
"SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a = nan()",
output(
filter("A IS NULL AND NULL",
values("A"))));

assertPlan(
"SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a < nan()",
output(
filter("A IS NULL AND NULL",
values("A"))));

assertPlan(
"SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a <> nan()",
output(
filter("NOT (A IS NULL) OR NULL",
values("A"))));

assertPlan(
"SELECT * FROM (VALUES SMALLINT '0') t(a) WHERE a IS DISTINCT FROM nan()",
output(
values("A")));

assertPlan(
"SELECT * FROM (VALUES REAL '0.0') t(a) WHERE a = nan()",
output(
filter("A IS NULL AND NULL",
values("A"))));

assertPlan(
"SELECT * FROM (VALUES REAL '0.0') t(a) WHERE a < nan()",
output(
filter("A IS NULL AND NULL",
values("A"))));

assertPlan(
"SELECT * FROM (VALUES REAL '0.0') t(a) WHERE a <> nan()",
output(
filter("NOT (A IS NULL) OR NULL",
values("A"))));

assertPlan(
"SELECT * FROM (VALUES REAL '0.0') t(a) WHERE a IS DISTINCT FROM nan()",
output(
filter("A IS DISTINCT FROM CAST(nan() AS REAL)",
values("A"))));
}

@Test
public void smokeTests()
{
Expand Down

0 comments on commit 84f61f0

Please sign in to comment.