Skip to content

Commit

Permalink
[SPARK-26352][SQL] join reorder should not change the order of output…
Browse files Browse the repository at this point in the history
… attributes

## What changes were proposed in this pull request?

The optimizer rule `org.apache.spark.sql.catalyst.optimizer.ReorderJoin` performs join reordering on inner joins. This was introduced from SPARK-12032 (apache#10073) in 2015-12.

After it had reordered the joins, though, it didn't check whether or not the output attribute order is still the same as before. Thus, it's possible to have a mismatch between the reordered output attributes order vs the schema that a DataFrame thinks it has.
The same problem exists in the CBO version of join reordering (`CostBasedJoinReorder`) too.

This can be demonstrated with the example:
```scala
spark.sql("create table table_a (x int, y int) using parquet")
spark.sql("create table table_b (i int, j int) using parquet")
spark.sql("create table table_c (a int, b int) using parquet")
val df = spark.sql("""
  with df1 as (select * from table_a cross join table_b)
  select * from df1 join table_c on a = x and b = i
""")
```
here's what the DataFrame thinks:
```
scala> df.printSchema
root
 |-- x: integer (nullable = true)
 |-- y: integer (nullable = true)
 |-- i: integer (nullable = true)
 |-- j: integer (nullable = true)
 |-- a: integer (nullable = true)
 |-- b: integer (nullable = true)
```
here's what the optimized plan thinks, after join reordering:
```
scala> df.queryExecution.optimizedPlan.output.foreach(a => println(s"|-- ${a.name}: ${a.dataType.typeName}"))
|-- x: integer
|-- y: integer
|-- a: integer
|-- b: integer
|-- i: integer
|-- j: integer
```

If we exclude the `ReorderJoin` rule (using Spark 2.4's optimizer rule exclusion feature), it's back to normal:
```
scala> spark.conf.set("spark.sql.optimizer.excludedRules", "org.apache.spark.sql.catalyst.optimizer.ReorderJoin")

scala> val df = spark.sql("with df1 as (select * from table_a cross join table_b) select * from df1 join table_c on a = x and b = i")
df: org.apache.spark.sql.DataFrame = [x: int, y: int ... 4 more fields]

scala> df.queryExecution.optimizedPlan.output.foreach(a => println(s"|-- ${a.name}: ${a.dataType.typeName}"))
|-- x: integer
|-- y: integer
|-- i: integer
|-- j: integer
|-- a: integer
|-- b: integer
```

Note that this output attribute ordering problem leads to data corruption, and can manifest itself in various symptoms:
* Silently corrupting data, if the reordered columns happen to either have matching types or have sufficiently-compatible types (e.g. all fixed length primitive types are considered as "sufficiently compatible" in an `UnsafeRow`), then only the resulting data is going to be wrong but it might not trigger any alarms immediately. Or
* Weird Java-level exceptions like `java.lang.NegativeArraySizeException`, or even SIGSEGVs.

## How was this patch tested?

Added new unit test in `JoinReorderSuite` and new end-to-end test in `JoinSuite`.
Also made `JoinReorderSuite` and `StarJoinReorderSuite` assert more strongly on maintaining output attribute order.

Closes apache#23303 from rednaxelafx/fix-join-reorder.

Authored-by: Kris Mok <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
rednaxelafx authored and cloud-fan committed Dec 17, 2018
1 parent db1c5b1 commit 56448c6
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
if projectList.forall(_.isInstanceOf[Attribute]) =>
reorder(p, p.output)
}

// After reordering is finished, convert OrderedJoin back to Join
result transformDown {
case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond)
Expand Down Expand Up @@ -175,11 +176,20 @@ object JoinReorderDP extends PredicateHelper with Logging {
assert(topOutputSet == p.outputSet)
// Keep the same order of final output attributes.
p.copy(projectList = output)
case finalPlan if !sameOutput(finalPlan, output) =>
Project(output, finalPlan)
case finalPlan =>
finalPlan
}
}

private def sameOutput(plan: LogicalPlan, expectedOutput: Seq[Attribute]): Boolean = {
val thisOutput = plan.output
thisOutput.length == expectedOutput.length && thisOutput.zip(expectedOutput).forall {
case (a1, a2) => a1.semanticEquals(a2)
}
}

/** Find all possible plans at the next level, based on existing levels. */
private def searchLevel(
existingLevels: Seq[JoinPlanMap],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case ExtractFiltersAndInnerJoins(input, conditions)
case p @ ExtractFiltersAndInnerJoins(input, conditions)
if input.size > 2 && conditions.nonEmpty =>
if (SQLConf.get.starSchemaDetection && !SQLConf.get.cboEnabled) {
val reordered = if (SQLConf.get.starSchemaDetection && !SQLConf.get.cboEnabled) {
val starJoinPlan = StarSchemaDetection.reorderStarJoins(input, conditions)
if (starJoinPlan.nonEmpty) {
val rest = input.filterNot(starJoinPlan.contains(_))
Expand All @@ -99,6 +99,14 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
} else {
createOrderedJoin(input, conditions)
}

if (p.sameOutput(reordered)) {
reordered
} else {
// Reordering the joins have changed the order of the columns.
// Inject a projection to make sure we restore to the expected ordering.
Project(p.output, reordered)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,19 @@ class JoinOptimizationSuite extends PlanTest {
x.join(y).join(z).where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)),
x.join(z, condition = Some("x.b".attr === "z.b".attr))
.join(y, condition = Some("y.d".attr === "z.a".attr))
.select(Seq("x.a", "x.b", "x.c", "y.d", "z.a", "z.b", "z.c").map(_.attr): _*)
),
(
x.join(y, Cross).join(z, Cross)
.where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)),
x.join(z, Cross, Some("x.b".attr === "z.b".attr))
.join(y, Cross, Some("y.d".attr === "z.a".attr))
.select(Seq("x.a", "x.b", "x.c", "y.d", "z.a", "z.b", "z.c").map(_.attr): _*)
),
(
x.join(y, Inner).join(z, Cross).where("x.b".attr === "z.a".attr),
x.join(z, Cross, Some("x.b".attr === "z.a".attr)).join(y, Inner)
.select(Seq("x.a", "x.b", "x.c", "y.d", "z.a", "z.b", "z.c").map(_.attr): _*)
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.{Cross, Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan}
import org.apache.spark.sql.internal.SQLConf.{CBO_ENABLED, JOIN_REORDER_ENABLED}
Expand Down Expand Up @@ -124,7 +124,8 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
// the original order (t1 J t2) J t3.
val bestPlan =
t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.select(outputsOf(t1, t2, t3): _*)

assertEqualPlans(originalPlan, bestPlan)
}
Expand All @@ -139,7 +140,9 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
val bestPlan =
t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.select(outputsOf(t1, t2, t3): _*) // this is redundant but we'll take it for now
.join(t4)
.select(outputsOf(t1, t2, t4, t3): _*)

assertEqualPlans(originalPlan, bestPlan)
}
Expand Down Expand Up @@ -202,6 +205,7 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.join(t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
.select(outputsOf(t1, t4, t2, t3): _*)

assertEqualPlans(originalPlan, bestPlan)
}
Expand All @@ -219,6 +223,23 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
}
}

test("SPARK-26352: join reordering should not change the order of attributes") {
// This test case does not rely on CBO.
// It's similar to the test case above, but catches a reordering bug that the one above doesn't
val tab1 = LocalRelation('x.int, 'y.int)
val tab2 = LocalRelation('i.int, 'j.int)
val tab3 = LocalRelation('a.int, 'b.int)
val original =
tab1.join(tab2, Cross)
.join(tab3, Inner, Some('a === 'x && 'b === 'i))
val expected =
tab1.join(tab3, Inner, Some('a === 'x))
.join(tab2, Cross, Some('b === 'i))
.select(outputsOf(tab1, tab2, tab3): _*)

assertEqualPlans(original, expected)
}

test("reorder recursively") {
// Original order:
// Join
Expand Down Expand Up @@ -266,8 +287,17 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
private def assertEqualPlans(
originalPlan: LogicalPlan,
groundTruthBestPlan: LogicalPlan): Unit = {
val optimized = Optimize.execute(originalPlan.analyze)
val analyzed = originalPlan.analyze
val optimized = Optimize.execute(analyzed)
val expected = groundTruthBestPlan.analyze

assert(analyzed.sameOutput(expected)) // if this fails, the expected plan itself is incorrect
assert(analyzed.sameOutput(optimized))

compareJoinOrder(optimized, expected)
}

private def outputsOf(plans: LogicalPlan*): Seq[Attribute] = {
plans.map(_.output).reduce(_ ++ _)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas
.join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk")))
.join(t2, Inner, Some(nameToAttr("f1_c2") === nameToAttr("t2_c1")))
.join(t1, Inner, Some(nameToAttr("f1_c1") === nameToAttr("t1_c1")))
.select(outputsOf(f1, t1, t2, d1, d2): _*)

assertEqualPlans(query, expected)
}
Expand Down Expand Up @@ -256,6 +257,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas
.join(t3.join(t2, Inner, Some(nameToAttr("t2_c2") === nameToAttr("t3_c1"))), Inner,
Some(nameToAttr("d1_c2") === nameToAttr("t2_c1")))
.join(t1, Inner, Some(nameToAttr("t1_c1") === nameToAttr("f1_c1")))
.select(outputsOf(d1, t1, t2, f1, d2, t3): _*)

assertEqualPlans(query, expected)
}
Expand Down Expand Up @@ -297,6 +299,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas
Some(nameToAttr("t3_c1") === nameToAttr("t4_c1")))
.join(t1.join(t2, Inner, Some(nameToAttr("t1_c1") === nameToAttr("t2_c1"))), Inner,
Some(nameToAttr("t1_c2") === nameToAttr("t4_c2")))
.select(outputsOf(d1, t1, t2, t3, t4, f1, d2): _*)

assertEqualPlans(query, expected)
}
Expand Down Expand Up @@ -347,6 +350,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas
Some(nameToAttr("d3_c2") === nameToAttr("t1_c1")))
.join(t5.join(t6, Inner, Some(nameToAttr("t5_c2") === nameToAttr("t6_c2"))), Inner,
Some(nameToAttr("d2_c2") === nameToAttr("t5_c1")))
.select(outputsOf(d1, t3, t4, f1, d2, t5, t6, d3, t1, t2): _*)

assertEqualPlans(query, expected)
}
Expand Down Expand Up @@ -375,6 +379,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas
f1.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk")))
.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk")))
.join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk")))
.select(outputsOf(d1, d2, f1, d3): _*)

assertEqualPlans(query, expected)
}
Expand All @@ -400,13 +405,27 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas
f1.join(t3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("t3_c1")))
.join(t2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("t2_c1")))
.join(t1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("t1_c1")))
.select(outputsOf(t1, f1, t2, t3): _*)

assertEqualPlans(query, expected)
}

private def assertEqualPlans( plan1: LogicalPlan, plan2: LogicalPlan): Unit = {
val optimized = Optimize.execute(plan1.analyze)
val analyzed = plan1.analyze
val optimized = Optimize.execute(analyzed)
val expected = plan2.analyze

assert(equivalentOutput(analyzed, expected)) // if this fails, the expected itself is incorrect
assert(equivalentOutput(analyzed, optimized))

compareJoinOrder(optimized, expected)
}

private def outputsOf(plans: LogicalPlan*): Seq[Attribute] = {
plans.map(_.output).reduce(_ ++ _)
}

private def equivalentOutput(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
normalizeExprIds(plan1).output == normalizeExprIds(plan2).output
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1")))
.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
.select(outputsOf(d1, d2, f1, d3, s3): _*)

assertEqualPlans(query, expected)
}
Expand Down Expand Up @@ -220,6 +221,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
.join(d2, Inner, Some(nameToAttr("f1_fk2") < nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
.select(outputsOf(d1, f1, d2, s3, d3): _*)

assertEqualPlans(query, expected)
}
Expand Down Expand Up @@ -255,7 +257,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("s3_c2")))

.select(outputsOf(d1, f1, d2, s3, d3): _*)

assertEqualPlans(query, expected)
}
Expand Down Expand Up @@ -292,6 +294,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_c2")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") < nameToAttr("s3_pk1")))
.select(outputsOf(d1, f1, d2, s3, d3): _*)

assertEqualPlans(query, expected)
}
Expand Down Expand Up @@ -395,6 +398,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d2.where(nameToAttr("d2_c2") === 2), Inner,
Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("f11_fk1") === nameToAttr("s3_pk1")))
.select(outputsOf(d1, f11, f1, d2, s3): _*)

assertEqualPlans(query, equivQuery)
}
Expand Down Expand Up @@ -430,6 +434,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d2.where(nameToAttr("d2_c2") === 2), Inner,
Some(nameToAttr("f1_fk2") === nameToAttr("d2_c4")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
.select(outputsOf(d1, d3, f1, d2, s3): _*)

assertEqualPlans(query, expected)
}
Expand Down Expand Up @@ -465,6 +470,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d2.where(nameToAttr("d2_c2") === 2), Inner,
Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
.select(outputsOf(d1, d3, f1, d2, s3): _*)

assertEqualPlans(query, expected)
}
Expand Down Expand Up @@ -499,6 +505,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d2.where(nameToAttr("d2_c2") === 2),
Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
.select(outputsOf(d1, d3, f1, d2, s3): _*)

assertEqualPlans(query, expected)
}
Expand Down Expand Up @@ -532,6 +539,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d3, Inner, Some(nameToAttr("f1_fk3") < nameToAttr("d3_pk1")))
.join(d2, Inner, Some(nameToAttr("f1_fk2") < nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") < nameToAttr("s3_pk1")))
.select(outputsOf(d1, d3, f1, d2, s3): _*)

assertEqualPlans(query, expected)
}
Expand Down Expand Up @@ -565,13 +573,27 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
.select(outputsOf(d1, d3, f1, d2, s3): _*)

assertEqualPlans(query, expected)
}

private def assertEqualPlans( plan1: LogicalPlan, plan2: LogicalPlan): Unit = {
val optimized = Optimize.execute(plan1.analyze)
private def assertEqualPlans(plan1: LogicalPlan, plan2: LogicalPlan): Unit = {
val analyzed = plan1.analyze
val optimized = Optimize.execute(analyzed)
val expected = plan2.analyze

assert(equivalentOutput(analyzed, expected)) // if this fails, the expected itself is incorrect
assert(equivalentOutput(analyzed, optimized))

compareJoinOrder(optimized, expected)
}

private def outputsOf(plans: LogicalPlan*): Seq[Attribute] = {
plans.map(_.output).reduce(_ ++ _)
}

private def equivalentOutput(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
normalizeExprIds(plan1).output == normalizeExprIds(plan2).output
}
}
14 changes: 14 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -895,4 +895,18 @@ class JoinSuite extends QueryTest with SharedSQLContext {
checkAnswer(res, Row(0, 0, 0))
}
}

test("SPARK-26352: join reordering should not change the order of columns") {
withTable("tab1", "tab2", "tab3") {
spark.sql("select 1 as x, 100 as y").write.saveAsTable("tab1")
spark.sql("select 42 as i, 200 as j").write.saveAsTable("tab2")
spark.sql("select 1 as a, 42 as b").write.saveAsTable("tab3")

val df = spark.sql("""
with tmp as (select * from tab1 cross join tab2)
select * from tmp join tab3 on a = x and b = i
""")
checkAnswer(df, Row(1, 100, 42, 200, 1, 42))
}
}
}

0 comments on commit 56448c6

Please sign in to comment.