Skip to content

Commit

Permalink
Simplify calling structure
Browse files Browse the repository at this point in the history
Returning a Predicate is unnecessary and makes the code harder to follow
  • Loading branch information
martint committed Sep 1, 2019
1 parent 2646039 commit ef0de2b
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ private ExpressionTreeUtils() {}

static List<FunctionCall> extractAggregateFunctions(Iterable<? extends Node> nodes, Metadata metadata)
{
return extractExpressions(nodes, FunctionCall.class, isAggregationPredicate(metadata));
return extractExpressions(nodes, FunctionCall.class, function -> isAggregation(function, metadata));
}

static List<FunctionCall> extractWindowFunctions(Iterable<? extends Node> nodes)
Expand All @@ -50,11 +50,11 @@ public static <T extends Expression> List<T> extractExpressions(
return extractExpressions(nodes, clazz, alwaysTrue());
}

private static Predicate<FunctionCall> isAggregationPredicate(Metadata metadata)
private static boolean isAggregation(FunctionCall functionCall, Metadata metadata)
{
return ((functionCall) -> (metadata.isAggregationFunction(functionCall.getName())
|| functionCall.getFilter().isPresent()) && !functionCall.getWindow().isPresent()
|| functionCall.getOrderBy().isPresent());
return ((metadata.isAggregationFunction(functionCall.getName()) || functionCall.getFilter().isPresent())
&& !functionCall.getWindow().isPresent())
|| functionCall.getOrderBy().isPresent();
}

private static boolean isWindowFunction(FunctionCall functionCall)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ private Expression rewriteExpression(Expression expression, Predicate<Symbol> sy
// larger subtrees over smaller subtrees
// TODO: this rewrite can probably be made more sophisticated
Expression rewritten = ExpressionTreeRewriter.rewriteWith(new ExpressionNodeInliner(expressionRemap.build()), expression);
if (!symbolToExpressionPredicate(symbolScope).apply(rewritten)) {
if (!isScoped(rewritten, symbolScope)) {
// If the rewritten is still not compliant with the symbol scope, just give up
return null;
}
Expand Down Expand Up @@ -244,32 +244,30 @@ Expression getScopedCanonical(Expression expression, Predicate<Symbol> symbolSco
if (canonicalIndex == null) {
return null;
}
return getCanonical(filter(equalitySets.get(canonicalIndex), symbolToExpressionPredicate(symbolScope)));
return getCanonical(filter(equalitySets.get(canonicalIndex), equivalentExpression -> isScoped(equivalentExpression, symbolScope)));
}

private static Predicate<Expression> symbolToExpressionPredicate(final Predicate<Symbol> symbolScope)
private static boolean isScoped(Expression expression, Predicate<Symbol> symbolScope)
{
return expression -> Iterables.all(SymbolsExtractor.extractUnique(expression), symbolScope);
return Iterables.all(SymbolsExtractor.extractUnique(expression), symbolScope);
}

/**
* Determines whether an Expression may be successfully applied to the equality inference
*/
public static Predicate<Expression> isInferenceCandidate()
public static boolean isInferenceCandidate(Expression expression)
{
return expression -> {
expression = normalizeInPredicateToEquality(expression);
if (expression instanceof ComparisonExpression &&
isDeterministic(expression) &&
!mayReturnNullOnNonNullInput(expression)) {
ComparisonExpression comparison = (ComparisonExpression) expression;
if (comparison.getOperator() == ComparisonExpression.Operator.EQUAL) {
// We should only consider equalities that have distinct left and right components
return !comparison.getLeft().equals(comparison.getRight());
}
expression = normalizeInPredicateToEquality(expression);
if (expression instanceof ComparisonExpression &&
isDeterministic(expression) &&
!mayReturnNullOnNonNullInput(expression)) {
ComparisonExpression comparison = (ComparisonExpression) expression;
if (comparison.getOperator() == ComparisonExpression.Operator.EQUAL) {
// We should only consider equalities that have distinct left and right components
return !comparison.getLeft().equals(comparison.getRight());
}
return false;
};
}
return false;
}

/**
Expand All @@ -294,7 +292,7 @@ private static Expression normalizeInPredicateToEquality(Expression expression)
*/
public static Iterable<Expression> nonInferrableConjuncts(Expression expression)
{
return filter(extractConjuncts(expression), not(isInferenceCandidate()));
return filter(extractConjuncts(expression), not(EqualityInference::isInferenceCandidate));
}

public static EqualityInference createEqualityInference(Expression... expressions)
Expand Down Expand Up @@ -342,7 +340,7 @@ public static class Builder

public Builder extractInferenceCandidates(Expression expression)
{
return addAllEqualities(filter(extractConjuncts(expression), isInferenceCandidate()));
return addAllEqualities(filter(extractConjuncts(expression), EqualityInference::isInferenceCandidate));
}

@VisibleForTesting
Expand All @@ -358,7 +356,7 @@ Builder addAllEqualities(Iterable<Expression> expressions)
Builder addEquality(Expression expression)
{
expression = normalizeInPredicateToEquality(expression);
checkArgument(isInferenceCandidate().apply(expression), "Expression must be a simple equality: " + expression);
checkArgument(isInferenceCandidate(expression), "Expression must be a simple equality: " + expression);
ComparisonExpression comparison = (ComparisonExpression) expression;
addEquality(comparison.getLeft(), comparison.getRight());
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext<Expression> context)
List<JoinNode.EquiJoinClause> equiJoinClauses = new ArrayList<>();
ImmutableList.Builder<Expression> joinFilterBuilder = ImmutableList.builder();
for (Expression conjunct : extractConjuncts(newJoinPredicate)) {
if (joinEqualityExpression(node.getLeft().getOutputSymbols()).test(conjunct)) {
if (joinEqualityExpression(conjunct, node.getLeft().getOutputSymbols())) {
ComparisonExpression equality = (ComparisonExpression) conjunct;

boolean alignedComparison = Iterables.all(SymbolsExtractor.extractUnique(equality.getLeft()), in(node.getLeft().getOutputSymbols()));
Expand Down Expand Up @@ -1023,24 +1023,22 @@ private Object nullInputEvaluator(final Collection<Symbol> nullSymbols, Expressi
.optimize(symbol -> nullSymbols.contains(symbol) ? null : symbol.toSymbolReference());
}

private static Predicate<Expression> joinEqualityExpression(final Collection<Symbol> leftSymbols)
private static boolean joinEqualityExpression(Expression expression, Collection<Symbol> leftSymbols)
{
return expression -> {
// At this point in time, our join predicates need to be deterministic
if (isDeterministic(expression) && expression instanceof ComparisonExpression) {
ComparisonExpression comparison = (ComparisonExpression) expression;
if (comparison.getOperator() == ComparisonExpression.Operator.EQUAL) {
Set<Symbol> symbols1 = SymbolsExtractor.extractUnique(comparison.getLeft());
Set<Symbol> symbols2 = SymbolsExtractor.extractUnique(comparison.getRight());
if (symbols1.isEmpty() || symbols2.isEmpty()) {
return false;
}
return (Iterables.all(symbols1, in(leftSymbols)) && Iterables.all(symbols2, not(in(leftSymbols)))) ||
(Iterables.all(symbols2, in(leftSymbols)) && Iterables.all(symbols1, not(in(leftSymbols))));
// At this point in time, our join predicates need to be deterministic
if (isDeterministic(expression) && expression instanceof ComparisonExpression) {
ComparisonExpression comparison = (ComparisonExpression) expression;
if (comparison.getOperator() == ComparisonExpression.Operator.EQUAL) {
Set<Symbol> symbols1 = SymbolsExtractor.extractUnique(comparison.getLeft());
Set<Symbol> symbols2 = SymbolsExtractor.extractUnique(comparison.getRight());
if (symbols1.isEmpty() || symbols2.isEmpty()) {
return false;
}
return (Iterables.all(symbols1, in(leftSymbols)) && Iterables.all(symbols2, not(in(leftSymbols)))) ||
(Iterables.all(symbols2, in(leftSymbols)) && Iterables.all(symbols1, not(in(leftSymbols))));
}
return false;
};
}
return false;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,17 +202,17 @@ public void testEqualityPartitionGeneration()
// There should be equalities in the scope, that only use c1 and are all inferrable equalities
assertFalse(equalityPartition.getScopeEqualities().isEmpty());
assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), matchesSymbolScope(matchesSymbols("c1"))));
assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), EqualityInference.isInferenceCandidate()));
assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), EqualityInference::isInferenceCandidate));

// There should be equalities in the inverse scope, that never use c1 and are all inferrable equalities
assertFalse(equalityPartition.getScopeComplementEqualities().isEmpty());
assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), matchesSymbolScope(not(matchesSymbols("c1")))));
assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), EqualityInference.isInferenceCandidate()));
assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), EqualityInference::isInferenceCandidate));

// There should be equalities in the straddling scope, that should use both c1 and not c1 symbols
assertFalse(equalityPartition.getScopeStraddlingEqualities().isEmpty());
assertTrue(Iterables.any(equalityPartition.getScopeStraddlingEqualities(), matchesStraddlingScope(matchesSymbols("c1"))));
assertTrue(Iterables.all(equalityPartition.getScopeStraddlingEqualities(), EqualityInference.isInferenceCandidate()));
assertTrue(Iterables.all(equalityPartition.getScopeStraddlingEqualities(), EqualityInference::isInferenceCandidate));

// There should be a "full cover" of all of the equalities used
// THUS, we should be able to plug the generated equalities back in and get an equivalent set of equalities back the next time around
Expand Down Expand Up @@ -249,17 +249,17 @@ public void testMultipleEqualitySetsPredicateGeneration()
// There should be equalities in the scope, that only use a* and b* symbols and are all inferrable equalities
assertFalse(equalityPartition.getScopeEqualities().isEmpty());
assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), matchesSymbolScope(symbolBeginsWith("a", "b"))));
assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), EqualityInference.isInferenceCandidate()));
assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), EqualityInference::isInferenceCandidate));

// There should be equalities in the inverse scope, that never use a* and b* symbols and are all inferrable equalities
assertFalse(equalityPartition.getScopeComplementEqualities().isEmpty());
assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), matchesSymbolScope(not(symbolBeginsWith("a", "b")))));
assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), EqualityInference.isInferenceCandidate()));
assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), EqualityInference::isInferenceCandidate));

// There should be equalities in the straddling scope, that should use both c1 and not c1 symbols
assertFalse(equalityPartition.getScopeStraddlingEqualities().isEmpty());
assertTrue(Iterables.any(equalityPartition.getScopeStraddlingEqualities(), matchesStraddlingScope(symbolBeginsWith("a", "b"))));
assertTrue(Iterables.all(equalityPartition.getScopeStraddlingEqualities(), EqualityInference.isInferenceCandidate()));
assertTrue(Iterables.all(equalityPartition.getScopeStraddlingEqualities(), EqualityInference::isInferenceCandidate));

// Again, there should be a "full cover" of all of the equalities used
// THUS, we should be able to plug the generated equalities back in and get an equivalent set of equalities back the next time around
Expand Down

0 comments on commit ef0de2b

Please sign in to comment.