diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 415ce46788119..9d7e2a93ac2c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -185,7 +185,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
       // Find equi-join predicates that can be evaluated before the join, and thus can be used
       // as join keys.
       val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil)
-      val joinKeys = predicates.flatMap {
+      val explicitJoinKeys = predicates.flatMap {
         case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => None
         case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r))
         case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l))
@@ -203,6 +203,27 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
          )
         case other => None
       }
+
+      val literalEqualities = predicates.collect {
+        case EqualTo(l, r: Literal) if canEvaluate(l, left) && l.deterministic =>
+          r -> (Some(l), None)
+        case EqualTo(l, r: Literal) if canEvaluate(l, right) && l.deterministic =>
+          r -> (None, Some(l))
+        case EqualTo(l: Literal, r) if canEvaluate(r, left) && r.deterministic =>
+          l -> (Some(r), None)
+        case EqualTo(l: Literal, r) if canEvaluate(r, right) && r.deterministic =>
+          l -> (None, Some(r))
+      }.groupBy(_._1).mapValues { v =>
+        val (l, r) = v.map(_._2).unzip
+        (l.flatten, r.flatten)
+      }
+
+      val implicitJoinKeys = literalEqualities.values.flatMap {
+        case (xs, ys) => for { x <- xs; y <- ys } yield (x, y)
+      }
+
+      val joinKeys = (explicitJoinKeys.toSet ++ implicitJoinKeys).toList
+
       val otherPredicates = predicates.filterNot {
         case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => false
         case Equality(l, r) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index f45bd950040ce..80ffe16041bbc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -1082,4 +1082,33 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
       assert(df2.join(df1, "id").collect().isEmpty)
     }
   }
+
+  test("Detect equijoins better") {
+    val df1 = Seq((1, 1), (2, 2)).toDF("c1", "c2")
+    val df2 = Seq((2, 2), (3, 3)).toDF("c1", "c2")
+
+    val explicitConstraints = df1("c1") === 2 && df2("c1") === 2
+    val implicitConstraints = df1("c1") === df2("c1")
+
+    val explicitDF = df1.join(df2, explicitConstraints && implicitConstraints, "FullOuter")
+    val implicitDF = df1.join(df2, explicitConstraints, "FullOuter")
+
+    checkAnswer(explicitDF, implicitDF)
+    assert(
+      explicitDF.queryExecution.sparkPlan === implicitDF.queryExecution.sparkPlan,
+      "Explicit and implicit plans should match.")
+
+    val explicitConstraints2 =
+      df1("c1") === 2 && df1("c2") === 2 && df2("c1") === 2 && df2("c2") === 2
+    val implicitConstraints2 = df1("c1") === df2("c1") && df1("c1") === df2("c2") &&
+      df1("c2") === df2("c1") && df1("c2") === df2("c2")
+
+    val explicitDF2 = df1.join(df2, explicitConstraints2 && implicitConstraints2, "FullOuter")
+    val implicitDF2 = df1.join(df2, explicitConstraints2, "FullOuter")
+
+    checkAnswer(explicitDF2, implicitDF2)
+    assert(
+      explicitDF2.queryExecution.sparkPlan === implicitDF2.queryExecution.sparkPlan,
+      "Explicit and implicit plans should match.")
+  }
 }
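
The core idea in the `patterns.scala` change is that two single-side equalities against the same literal, such as `left.c1 = 2` and `right.c1 = 2`, together imply the equi-join key pair `(left.c1, right.c1)`, which lets the planner pick an equi-join strategy instead of a cartesian product with a filter. Below is a minimal, standalone sketch of that inference on plain strings; `ImplicitKeySketch`, `Side`, and `LiteralEq` are illustrative names, not Spark API, and the real patch pattern-matches Catalyst `EqualTo`/`Literal` nodes and guards with `canEvaluate` and `deterministic`:

```scala
// Standalone sketch of the implicit join-key inference in the patch.
// Strings stand in for Catalyst Expressions; Side records which join side
// the non-literal operand of a "column = literal" predicate belongs to.
object ImplicitKeySketch extends App {
  sealed trait Side
  case object LeftSide extends Side
  case object RightSide extends Side

  // One "column = literal" predicate: the literal's value, the side of the
  // join the column comes from, and the column itself.
  final case class LiteralEq(literal: Int, side: Side, column: String)

  def implicitJoinKeys(preds: Seq[LiteralEq]): Seq[(String, String)] =
    preds
      .groupBy(_.literal) // bucket predicates that share the same literal
      .values
      .flatMap { group =>
        val lefts = group.collect { case LiteralEq(_, LeftSide, c) => c }
        val rights = group.collect { case LiteralEq(_, RightSide, c) => c }
        // cross product per literal, mirroring the patch's for-comprehension
        for (x <- lefts; y <- rights) yield (x, y)
      }
      .toSeq

  // left.c1 = 2 and right.c1 = 2 imply the key pair (left.c1, right.c1).
  val inferred = implicitJoinKeys(
    Seq(LiteralEq(2, LeftSide, "left.c1"), LiteralEq(2, RightSide, "right.c1")))
  println(inferred) // List((left.c1,right.c1))
}
```

Note that an inferred pair can coincide with an explicitly written one (the first test case has both `c1 = 2` constraints and `df1("c1") === df2("c1")`), which is why the patch merges them through `explicitJoinKeys.toSet ++ implicitJoinKeys` before converting back to a list.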