diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 1235204591bbd..77a0ff26bca8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -90,11 +90,12 @@ trait PredicateHelper { * Returns true iff `expr` could be evaluated as a condition within join. */ protected def canEvaluateWithinJoin(expr: Expression): Boolean = expr match { - case l: ListQuery => + case _: ListQuery | _: Exists => // A ListQuery defines the query which we want to search in an IN subquery expression. // Currently the only way to evaluate an IN subquery is to convert it to a // LeftSemi/LeftAnti/ExistenceJoin by `RewritePredicateSubquery` rule. // It cannot be evaluated as part of a Join operator. + // An Exists shouldn't be push into a Join operator too. false case e: SubqueryExpression => // non-correlated subquery will be replaced as literal diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index dbf479d215134..7e49a8cbca406 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -65,7 +65,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) Batch("Pullup Correlated Expressions", Once, PullupCorrelatedPredicates) :: Batch("Subquery", Once, - OptimizeSubqueries) :: + OptimizeSubqueries, + RewriteEmptyExists) :: Batch("Replace Operators", fixedPoint, ReplaceIntersectWithSemiJoin, ReplaceExceptWithAntiJoin, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 2a3e07aebe709..6e3882712481b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -498,3 +498,31 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { } } } + +/** + * This rule rewrites a EXISTS predicate sub-queries into an Aggregate with count. + * So it doesn't be converted to a JOIN later. + */ +object RewriteEmptyExists extends Rule[LogicalPlan] with PredicateHelper { + private def containsAgg(plan: LogicalPlan): Boolean = { + plan.collect { + case a: Aggregate => a + }.nonEmpty + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Filter(condition, child) => + val (withSubquery, withoutSubquery) = + splitConjunctivePredicates(condition).partition(SubqueryExpression.hasInOrExistsSubquery) + val newWithSubquery = withSubquery.map(_.transform { + case e @ Exists(sub, conditions, exprId) if conditions.isEmpty && !containsAgg(sub) => + val countExpr = Alias(Count(Literal(1)).toAggregateExpression(), "count")() + val expr = Alias(GreaterThan(countExpr.toAttribute, Literal(0)), e.toString)() + ScalarSubquery( + Project(Seq(expr), Aggregate(Nil, Seq(countExpr), sub)), + children = Seq.empty, + exprId = exprId) + }) + Filter((newWithSubquery ++ withoutSubquery).reduce(And), child) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 5fe6667ceca18..d342648fd844d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.expressions.{Alias, ScalarSubquery} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Join} import org.apache.spark.sql.test.SharedSQLContext class SubquerySuite extends QueryTest with SharedSQLContext { @@ -844,4 +847,33 @@ class SubquerySuite extends QueryTest with SharedSQLContext { Row(0) :: Row(1) :: Nil) } } + + test("ListQuery and Exists should work even no correlated references") { + checkAnswer( + sql("select * from l, r where l.a = r.c AND (r.d in (select d from r) OR l.a >= 1)"), + Row(2, 1.0, 2, 3.0) :: Row(2, 1.0, 2, 3.0) :: Row(2, 1.0, 2, 3.0) :: + Row(2, 1.0, 2, 3.0) :: Row(3.0, 3.0, 3, 2.0) :: Row(6, null, 6, null) :: Nil) + checkAnswer( + sql("select * from l, r where l.a = r.c + 1 AND (exists (select * from r) OR l.a = r.c)"), + Row(3, 3.0, 2, 3.0) :: Row(3, 3.0, 2, 3.0) :: Nil) + } + + test("Convert Exists without correlated references to aggregation with count") { + val df = + sql("select * from l, r where l.a = r.c + 1 AND (exists (select * from r) OR l.a = r.c)") + val joinPlan = df.queryExecution.optimizedPlan.asInstanceOf[Join] + val scalarSubquery = joinPlan.condition.get.collect { + case s: ScalarSubquery => s + } + assert(scalarSubquery.length == 1) + val aggPlan = scalarSubquery.head.plan.collect { + case a: Aggregate => a + } + assert(aggPlan.length == 1) + assert(aggPlan.head.aggregateExpressions.length == 1) + val countAggExpr = aggPlan.head.aggregateExpressions.collect { + case a @ Alias(AggregateExpression(_: Count, _, _, _), _) => a + } + assert(countAggExpr.length == 1) + } }