diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ExpressionVerifier.java b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ExpressionVerifier.java index 48af9f6d39380..9f25469f6bbc7 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ExpressionVerifier.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ExpressionVerifier.java @@ -14,6 +14,7 @@ package io.prestosql.sql.planner.assertions; import io.prestosql.sql.tree.ArithmeticBinaryExpression; +import io.prestosql.sql.tree.ArithmeticUnaryExpression; import io.prestosql.sql.tree.AstVisitor; import io.prestosql.sql.tree.BetweenPredicate; import io.prestosql.sql.tree.BooleanLiteral; @@ -21,6 +22,7 @@ import io.prestosql.sql.tree.CoalesceExpression; import io.prestosql.sql.tree.ComparisonExpression; import io.prestosql.sql.tree.DecimalLiteral; +import io.prestosql.sql.tree.DereferenceExpression; import io.prestosql.sql.tree.DoubleLiteral; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.FunctionCall; @@ -34,6 +36,7 @@ import io.prestosql.sql.tree.Node; import io.prestosql.sql.tree.NotExpression; import io.prestosql.sql.tree.NullLiteral; +import io.prestosql.sql.tree.Row; import io.prestosql.sql.tree.SimpleCaseExpression; import io.prestosql.sql.tree.StringLiteral; import io.prestosql.sql.tree.SymbolReference; @@ -88,238 +91,269 @@ protected Boolean visitNode(Node node, Node expectedExpression) } @Override - protected Boolean visitTryExpression(TryExpression actual, Node expectedExpression) + protected Boolean visitGenericLiteral(GenericLiteral actual, Node expectedExpression) { - if (!(expectedExpression instanceof TryExpression)) { + if (!(expectedExpression instanceof GenericLiteral)) { return false; } - TryExpression expected = (TryExpression) expectedExpression; - - return process(actual.getInnerExpression(), expected.getInnerExpression()); + return getValueFromLiteral(actual).equals(getValueFromLiteral(expectedExpression)); } @Override - protected Boolean visitCast(Cast actual, Node expectedExpression) + protected Boolean visitStringLiteral(StringLiteral actual, Node expectedExpression) { - if (!(expectedExpression instanceof Cast)) { + if (!(expectedExpression instanceof StringLiteral)) { return false; } - Cast expected = (Cast) expectedExpression; - - if (!actual.getType().equals(expected.getType())) { - return false; - } + StringLiteral expected = (StringLiteral) expectedExpression; - return process(actual.getExpression(), expected.getExpression()); + return actual.getValue().equals(expected.getValue()); } @Override - protected Boolean visitIsNullPredicate(IsNullPredicate actual, Node expectedExpression) + protected Boolean visitLongLiteral(LongLiteral actual, Node expectedExpression) { - if (!(expectedExpression instanceof IsNullPredicate)) { + if (!(expectedExpression instanceof LongLiteral)) { return false; } - IsNullPredicate expected = (IsNullPredicate) expectedExpression; - - return process(actual.getValue(), expected.getValue()); + return getValueFromLiteral(actual).equals(getValueFromLiteral(expectedExpression)); } @Override - protected Boolean visitIsNotNullPredicate(IsNotNullPredicate actual, Node expectedExpression) + protected Boolean visitDoubleLiteral(DoubleLiteral actual, Node expectedExpression) { - if (!(expectedExpression instanceof IsNotNullPredicate)) { + if (!(expectedExpression instanceof DoubleLiteral)) { return false; } - IsNotNullPredicate expected = (IsNotNullPredicate) expectedExpression; - - return process(actual.getValue(), expected.getValue()); + return getValueFromLiteral(actual).equals(getValueFromLiteral(expectedExpression)); } @Override - protected Boolean visitInPredicate(InPredicate actual, Node expectedExpression) + protected Boolean visitDecimalLiteral(DecimalLiteral actual, Node expectedExpression) { - if (!(expectedExpression instanceof InPredicate)) { + if (!(expectedExpression instanceof DecimalLiteral)) { return false; } - InPredicate expected = (InPredicate) expectedExpression; + return getValueFromLiteral(actual).equals(getValueFromLiteral(expectedExpression)); + } - if (actual.getValueList() instanceof InListExpression) { - return process(actual.getValue(), expected.getValue()) - && process(actual.getValueList(), expected.getValueList()); + @Override + protected Boolean visitBooleanLiteral(BooleanLiteral actual, Node expectedExpression) + { + if (!(expectedExpression instanceof BooleanLiteral)) { + return false; } - checkState(expected.getValueList() instanceof InListExpression, "ExpressionVerifier doesn't support unpacked expected values. Feel free to add support if needed"); - /* - * If the expected value is a value list, but the actual is e.g. a SymbolReference, - * we need to unpack the value from the list so that when we hit visitSymbolReference, the - * expected.toString() call returns something that the symbolAliases actually contains. - * For example, InListExpression.toString returns "(onlyitem)" rather than "onlyitem". - * - * This is required because actual passes through the analyzer, planner, and possibly optimizers, - * one of which sometimes takes the liberty of unpacking the InListExpression. - * - * Since the expected value doesn't go through all of that, we have to deal with the case - * of the actual value being unpacked, but the expected value being an InListExpression. - */ - List values = ((InListExpression) expected.getValueList()).getValues(); - checkState(values.size() == 1, "Multiple expressions in expected value list %s, but actual value is not a list", values, actual.getValue()); - Expression onlyExpectedExpression = values.get(0); - return process(actual.getValue(), expected.getValue()) - && process(actual.getValueList(), onlyExpectedExpression); + return getValueFromLiteral(actual).equals(getValueFromLiteral(expectedExpression)); } @Override - protected Boolean visitComparisonExpression(ComparisonExpression actual, Node expectedExpression) + protected Boolean visitNullLiteral(NullLiteral node, Node expectedExpression) { - if (!(expectedExpression instanceof ComparisonExpression)) { - return false; + return expectedExpression instanceof NullLiteral; + } + + private static String getValueFromLiteral(Node expression) + { + if (expression instanceof LongLiteral) { + return String.valueOf(((LongLiteral) expression).getValue()); } - ComparisonExpression expected = (ComparisonExpression) expectedExpression; + if (expression instanceof BooleanLiteral) { + return String.valueOf(((BooleanLiteral) expression).getValue()); + } - if (actual.getOperator() == expected.getOperator() && - process(actual.getLeft(), expected.getLeft()) && - process(actual.getRight(), expected.getRight())) { - return true; + if (expression instanceof DoubleLiteral) { + return String.valueOf(((DoubleLiteral) expression).getValue()); + } + + if (expression instanceof DecimalLiteral) { + return String.valueOf(((DecimalLiteral) expression).getValue()); + } + + if (expression instanceof GenericLiteral) { + return ((GenericLiteral) expression).getValue(); } - return actual.getOperator() == expected.getOperator().flip() - && process(actual.getLeft(), expected.getRight()) - && process(actual.getRight(), expected.getLeft()); + throw new IllegalArgumentException("Unsupported literal expression type: " + expression.getClass().getName()); } @Override - protected Boolean visitArithmeticBinary(ArithmeticBinaryExpression actual, Node expectedExpression) + protected Boolean visitSymbolReference(SymbolReference actual, Node expectedExpression) { - if (!(expectedExpression instanceof ArithmeticBinaryExpression)) { + if (!(expectedExpression instanceof SymbolReference)) { return false; } - ArithmeticBinaryExpression expected = (ArithmeticBinaryExpression) expectedExpression; + SymbolReference expected = (SymbolReference) expectedExpression; - return actual.getOperator() == expected.getOperator() - && process(actual.getLeft(), expected.getLeft()) - && process(actual.getRight(), expected.getRight()); + return symbolAliases.get(expected.getName()).equals(actual); } @Override - protected Boolean visitGenericLiteral(GenericLiteral actual, Node expectedExpression) + protected Boolean visitDereferenceExpression(DereferenceExpression actual, Node expectedExpression) { - if (!(expectedExpression instanceof GenericLiteral)) { + if (!(expectedExpression instanceof DereferenceExpression)) { return false; } - return getValueFromLiteral(actual).equals(getValueFromLiteral(expectedExpression)); + DereferenceExpression expected = (DereferenceExpression) expectedExpression; + + return actual.getField().equals(expected.getField()) && + process(actual.getBase(), expected.getBase()); } @Override - protected Boolean visitLongLiteral(LongLiteral actual, Node expectedExpression) + protected Boolean visitCast(Cast actual, Node expectedExpression) { - if (!(expectedExpression instanceof LongLiteral)) { + if (!(expectedExpression instanceof Cast)) { return false; } - return getValueFromLiteral(actual).equals(getValueFromLiteral(expectedExpression)); + Cast expected = (Cast) expectedExpression; + + if (!actual.getType().equals(expected.getType())) { + return false; + } + + return process(actual.getExpression(), expected.getExpression()); } @Override - protected Boolean visitDoubleLiteral(DoubleLiteral actual, Node expectedExpression) + protected Boolean visitIsNullPredicate(IsNullPredicate actual, Node expectedExpression) { - if (!(expectedExpression instanceof DoubleLiteral)) { + if (!(expectedExpression instanceof IsNullPredicate)) { return false; } - return getValueFromLiteral(actual).equals(getValueFromLiteral(expectedExpression)); + IsNullPredicate expected = (IsNullPredicate) expectedExpression; + + return process(actual.getValue(), expected.getValue()); } @Override - protected Boolean visitDecimalLiteral(DecimalLiteral actual, Node expectedExpression) + protected Boolean visitIsNotNullPredicate(IsNotNullPredicate actual, Node expectedExpression) { - if (!(expectedExpression instanceof DecimalLiteral)) { + if (!(expectedExpression instanceof IsNotNullPredicate)) { return false; } - return getValueFromLiteral(actual).equals(getValueFromLiteral(expectedExpression)); + IsNotNullPredicate expected = (IsNotNullPredicate) expectedExpression; + + return process(actual.getValue(), expected.getValue()); } @Override - protected Boolean visitBooleanLiteral(BooleanLiteral actual, Node expectedExpression) + protected Boolean visitInPredicate(InPredicate actual, Node expectedExpression) { - if (!(expectedExpression instanceof BooleanLiteral)) { + if (!(expectedExpression instanceof InPredicate)) { return false; } - return getValueFromLiteral(actual).equals(getValueFromLiteral(expectedExpression)); + InPredicate expected = (InPredicate) expectedExpression; + + if (actual.getValueList() instanceof InListExpression) { + return process(actual.getValue(), expected.getValue()) && + process(actual.getValueList(), expected.getValueList()); + } + + checkState(expected.getValueList() instanceof InListExpression, "ExpressionVerifier doesn't support unpacked expected values. Feel free to add support if needed"); + + /* + * If the expected value is a value list, but the actual is e.g. a SymbolReference, + * we need to unpack the value from the list so that when we hit visitSymbolReference, the + * expected.toString() call returns something that the symbolAliases actually contains. + * For example, InListExpression.toString returns "(onlyitem)" rather than "onlyitem". + * + * This is required because actual passes through the analyzer, planner, and possibly optimizers, + * one of which sometimes takes the liberty of unpacking the InListExpression. + * + * Since the expected value doesn't go through all of that, we have to deal with the case + * of the actual value being unpacked, but the expected value being an InListExpression. + */ + List values = ((InListExpression) expected.getValueList()).getValues(); + checkState(values.size() == 1, "Multiple expressions in expected value list %s, but actual value is not a list", values, actual.getValue()); + Expression onlyExpectedExpression = values.get(0); + return process(actual.getValue(), expected.getValue()) && + process(actual.getValueList(), onlyExpectedExpression); } - private static String getValueFromLiteral(Node expression) + @Override + protected Boolean visitInListExpression(InListExpression actual, Node expectedExpression) { - if (expression instanceof LongLiteral) { - return String.valueOf(((LongLiteral) expression).getValue()); + if (!(expectedExpression instanceof InListExpression)) { + return false; } - if (expression instanceof BooleanLiteral) { - return String.valueOf(((BooleanLiteral) expression).getValue()); - } + InListExpression expected = (InListExpression) expectedExpression; - if (expression instanceof DoubleLiteral) { - return String.valueOf(((DoubleLiteral) expression).getValue()); - } + return process(actual.getValues(), expected.getValues()); + } - if (expression instanceof DecimalLiteral) { - return String.valueOf(((DecimalLiteral) expression).getValue()); + @Override + protected Boolean visitComparisonExpression(ComparisonExpression actual, Node expectedExpression) + { + if (!(expectedExpression instanceof ComparisonExpression)) { + return false; } - if (expression instanceof GenericLiteral) { - return ((GenericLiteral) expression).getValue(); + ComparisonExpression expected = (ComparisonExpression) expectedExpression; + + if (actual.getOperator() == expected.getOperator() && + process(actual.getLeft(), expected.getLeft()) && + process(actual.getRight(), expected.getRight())) { + return true; } - throw new IllegalArgumentException("Unsupported literal expression type: " + expression.getClass().getName()); + return actual.getOperator() == expected.getOperator().flip() && + process(actual.getLeft(), expected.getRight()) && + process(actual.getRight(), expected.getLeft()); } @Override - protected Boolean visitStringLiteral(StringLiteral actual, Node expectedExpression) + protected Boolean visitBetweenPredicate(BetweenPredicate actual, Node expectedExpression) { - if (!(expectedExpression instanceof StringLiteral)) { + if (!(expectedExpression instanceof BetweenPredicate)) { return false; } - StringLiteral expected = (StringLiteral) expectedExpression; + BetweenPredicate expected = (BetweenPredicate) expectedExpression; - return actual.getValue().equals(expected.getValue()); + return process(actual.getValue(), expected.getValue()) && + process(actual.getMin(), expected.getMin()) && + process(actual.getMax(), expected.getMax()); } @Override - protected Boolean visitLogicalBinaryExpression(LogicalBinaryExpression actual, Node expectedExpression) + protected Boolean visitArithmeticUnary(ArithmeticUnaryExpression actual, Node expectedExpression) { - if (!(expectedExpression instanceof LogicalBinaryExpression)) { + if (!(expectedExpression instanceof ArithmeticUnaryExpression)) { return false; } - LogicalBinaryExpression expected = (LogicalBinaryExpression) expectedExpression; + ArithmeticUnaryExpression expected = (ArithmeticUnaryExpression) expectedExpression; - return actual.getOperator() == expected.getOperator() - && process(actual.getLeft(), expected.getLeft()) - && process(actual.getRight(), expected.getRight()); + return actual.getSign() == expected.getSign() && + process(actual.getValue(), expected.getValue()); } @Override - protected Boolean visitBetweenPredicate(BetweenPredicate actual, Node expectedExpression) + protected Boolean visitArithmeticBinary(ArithmeticBinaryExpression actual, Node expectedExpression) { - if (!(expectedExpression instanceof BetweenPredicate)) { + if (!(expectedExpression instanceof ArithmeticBinaryExpression)) { return false; } - BetweenPredicate expected = (BetweenPredicate) expectedExpression; + ArithmeticBinaryExpression expected = (ArithmeticBinaryExpression) expectedExpression; - return process(actual.getValue(), expected.getValue()) - && process(actual.getMin(), expected.getMin()) - && process(actual.getMax(), expected.getMax()); + return actual.getOperator() == expected.getOperator() && + process(actual.getLeft(), expected.getLeft()) && + process(actual.getRight(), expected.getRight()); } @Override @@ -335,15 +369,17 @@ protected Boolean visitNotExpression(NotExpression actual, Node expectedExpressi } @Override - protected Boolean visitSymbolReference(SymbolReference actual, Node expectedExpression) + protected Boolean visitLogicalBinaryExpression(LogicalBinaryExpression actual, Node expectedExpression) { - if (!(expectedExpression instanceof SymbolReference)) { + if (!(expectedExpression instanceof LogicalBinaryExpression)) { return false; } - SymbolReference expected = (SymbolReference) expectedExpression; + LogicalBinaryExpression expected = (LogicalBinaryExpression) expectedExpression; - return symbolAliases.get(expected.getName()).equals(actual); + return actual.getOperator() == expected.getOperator() && + process(actual.getLeft(), expected.getLeft()) && + process(actual.getRight(), expected.getRight()); } @Override @@ -390,8 +426,8 @@ protected Boolean visitWhenClause(WhenClause actual, Node expectedExpression) WhenClause expected = (WhenClause) expectedExpression; - return process(actual.getOperand(), expected.getOperand()) - && process(actual.getResult(), expected.getResult()); + return process(actual.getOperand(), expected.getOperand()) && + process(actual.getResult(), expected.getResult()); } @Override @@ -411,21 +447,27 @@ protected Boolean visitFunctionCall(FunctionCall actual, Node expectedExpression } @Override - protected Boolean visitNullLiteral(NullLiteral node, Node expectedExpression) + protected Boolean visitRow(Row actual, Node expectedExpression) { - return expectedExpression instanceof NullLiteral; + if (!(expectedExpression instanceof Row)) { + return false; + } + + Row expected = (Row) expectedExpression; + + return process(actual.getItems(), expected.getItems()); } @Override - protected Boolean visitInListExpression(InListExpression actual, Node expectedExpression) + protected Boolean visitTryExpression(TryExpression actual, Node expectedExpression) { - if (!(expectedExpression instanceof InListExpression)) { + if (!(expectedExpression instanceof TryExpression)) { return false; } - InListExpression expected = (InListExpression) expectedExpression; + TryExpression expected = (TryExpression) expectedExpression; - return process(actual.getValues(), expected.getValues()); + return process(actual.getInnerExpression(), expected.getInnerExpression()); } private boolean process(List actuals, List expecteds)