Skip to content

Commit

Permalink
Exists should not be evaluated in Join operator too.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Mar 31, 2017
1 parent 9712bd3 commit b012550
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,12 @@ trait PredicateHelper {
* Returns true iff `expr` could be evaluated as a condition within join.
*/
protected def canEvaluateWithinJoin(expr: Expression): Boolean = expr match {
case l: ListQuery =>
case _: ListQuery | _: Exists =>
// A ListQuery defines the query which we want to search in an IN subquery expression.
// Currently the only way to evaluate an IN subquery is to convert it to a
// LeftSemi/LeftAnti/ExistenceJoin by `RewritePredicateSubquery` rule.
// It cannot be evaluated as part of a Join operator.
// An Exists shouldn't be pushed into a Join operator either.
false
case e: SubqueryExpression =>
// non-correlated subquery will be replaced as literal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
Batch("Pullup Correlated Expressions", Once,
PullupCorrelatedPredicates) ::
Batch("Subquery", Once,
OptimizeSubqueries) ::
OptimizeSubqueries,
RewriteEmptyExists) ::
Batch("Replace Operators", fixedPoint,
ReplaceIntersectWithSemiJoin,
ReplaceExceptWithAntiJoin,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -498,3 +498,31 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
}
}
}

/**
 * Rewrites EXISTS predicate subqueries that carry no conditions into a scalar subquery that
 * evaluates `count(1) > 0` over the subquery plan, so that they are not converted into a JOIN
 * later.
 *
 * Only [[Filter]] conditions are inspected. An [[Exists]] is rewritten when its `conditions`
 * are empty and its subquery plan contains no [[Aggregate]]; everything else is left as is.
 */
object RewriteEmptyExists extends Rule[LogicalPlan] with PredicateHelper {
  /** Returns true iff `plan` contains at least one [[Aggregate]] operator. */
  private def containsAgg(plan: LogicalPlan): Boolean =
    // `find` stops at the first match instead of collecting every Aggregate in the tree.
    plan.find(_.isInstanceOf[Aggregate]).isDefined

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case filter @ Filter(condition, child) =>
      // Only conjuncts that hold an IN/EXISTS subquery are candidates for the rewrite.
      val (withSubquery, withoutSubquery) =
        splitConjunctivePredicates(condition).partition(SubqueryExpression.hasInOrExistsSubquery)
      val newWithSubquery = withSubquery.map(_.transform {
        case e @ Exists(sub, conditions, exprId) if conditions.isEmpty && !containsAgg(sub) =>
          // Global count over the subquery: Aggregate(Nil, count(1), sub).
          val countExpr = Alias(Count(Literal(1)).toAggregateExpression(), "count")()
          // Boolean output `count > 0`, named after the original Exists expression.
          val expr = Alias(GreaterThan(countExpr.toAttribute, Literal(0)), e.toString)()
          // The replacement keeps the exprId of the Exists it stands in for.
          ScalarSubquery(
            Project(Seq(expr), Aggregate(Nil, Seq(countExpr), sub)),
            children = Seq.empty,
            exprId = exprId)
      })
      if (newWithSubquery == withSubquery) {
        // Nothing was rewritten: keep the original Filter so the conjuncts are not reordered
        // and the rule stays a no-op on plans it does not apply to.
        filter
      } else {
        Filter((newWithSubquery ++ withoutSubquery).reduce(And), child)
      }
  }
}
32 changes: 32 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.{Alias, ScalarSubquery}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Join}
import org.apache.spark.sql.test.SharedSQLContext

class SubquerySuite extends QueryTest with SharedSQLContext {
Expand Down Expand Up @@ -844,4 +847,33 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
Row(0) :: Row(1) :: Nil)
}
}

// IN and EXISTS subqueries without outer references, placed under an OR next to a join
// predicate, must still return the expected rows.
test("ListQuery and Exists should work even no correlated references") {
  // IN subquery under an OR.
  checkAnswer(
    sql("select * from l, r where l.a = r.c AND (r.d in (select d from r) OR l.a >= 1)"),
    Row(2, 1.0, 2, 3.0) :: Row(2, 1.0, 2, 3.0) :: Row(2, 1.0, 2, 3.0) ::
    Row(2, 1.0, 2, 3.0) :: Row(3, 3.0, 3, 2.0) :: Row(6, null, 6, null) :: Nil)
  // EXISTS subquery under an OR.
  checkAnswer(
    sql("select * from l, r where l.a = r.c + 1 AND (exists (select * from r) OR l.a = r.c)"),
    Row(3, 3.0, 2, 3.0) :: Row(3, 3.0, 2, 3.0) :: Nil)
}

// Checks the optimized plan shape: the EXISTS predicate in the join condition must have been
// replaced by exactly one ScalarSubquery whose plan holds a single Aggregate computing Count.
test("Convert Exists without correlated references to aggregation with count") {
  val df =
    sql("select * from l, r where l.a = r.c + 1 AND (exists (select * from r) OR l.a = r.c)")
  val optimized = df.queryExecution.optimizedPlan
  // Fail with the offending plan instead of an opaque ClassCastException or
  // NoSuchElementException when the plan is not a Join with a condition.
  val joinCondition = optimized match {
    case j: Join =>
      j.condition.getOrElse(fail(s"Expected the Join to have a condition, but got:\n$j"))
    case other =>
      fail(s"Expected the optimized plan to be a Join, but got:\n$other")
  }
  val scalarSubquery = joinCondition.collect {
    case s: ScalarSubquery => s
  }
  assert(scalarSubquery.length == 1)
  val aggPlan = scalarSubquery.head.plan.collect {
    case a: Aggregate => a
  }
  assert(aggPlan.length == 1)
  assert(aggPlan.head.aggregateExpressions.length == 1)
  val countAggExpr = aggPlan.head.aggregateExpressions.collect {
    case a @ Alias(AggregateExpression(_: Count, _, _, _), _) => a
  }
  assert(countAggExpr.length == 1)
}
}

0 comments on commit b012550

Please sign in to comment.