[SPARK-26366][SQL][BACKPORT-2.3] ReplaceExceptWithFilter should consider NULL as False

## What changes were proposed in this pull request?

In `ReplaceExceptWithFilter` we do not properly handle the case in which the condition evaluates to NULL. Since negating NULL still returns NULL, the assumption that negating the condition selects all the rows which did not satisfy it does not hold: rows for which the condition evaluates to NULL may be missing from the result. This happens when the constraints inferred by `InferFiltersFromConstraints` are not sufficient, as with `OR` conditions.
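
To make the three-valued logic concrete, here is a minimal standalone sketch (plain Scala, not Spark/Catalyst code) that models NULL as `None`: negating a NULL condition still yields NULL, so a naive `Filter(Not(condition))` drops the row, whereas wrapping the condition in `Coalesce(condition, False)` keeps it.

```scala
// Minimal sketch of SQL three-valued logic, modelling TRUE/FALSE/NULL as Option[Boolean].
// Plain Scala for illustration only; not part of the patch.
object ThreeValuedLogicSketch extends App {
  type SqlBool = Option[Boolean]

  def not(b: SqlBool): SqlBool = b.map(!_)                 // NOT NULL == NULL
  def coalesceFalse(b: SqlBool): Boolean = b.getOrElse(false)

  // An OR condition where one operand is FALSE and the other NULL evaluates to NULL.
  val condition: SqlBool = None

  // Naive rewrite Filter(Not(condition)): NULL is not TRUE, so the row is silently dropped.
  println(not(condition))                                  // prints None

  // Fixed rewrite Filter(Not(Coalesce(condition, False))): the row is returned.
  println(!coalesceFalse(condition))                       // prints true
}
```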

The rule also had problems with non-deterministic conditions: in such a scenario, the rewrite would change the probability of the output.

The PR fixes these problems by:
 - treating the condition as False when it evaluates to NULL (so that all the rows which did not satisfy it are returned), as in the sketch after this list;
 - avoiding any transformation when the condition is non-deterministic.
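
A rough sketch of the plan shape the rule now produces, written with the same Catalyst DSL used by the updated `ReplaceOperatorSuite` tests below; the relation and condition here are illustrative and not taken from the patch:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Coalesce, Literal, Not}
import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, LocalRelation}

// Illustrative relation and right-side filter condition (not from the patch).
object ReplaceExceptSketch {
  val attrA = 'a.int
  val attrB = 'b.int
  val left  = LocalRelation(Seq(attrA, attrB))
  val cond  = attrA >= 2 && attrB < 1

  // Except(left, left.where(cond)) is rewritten, when cond is deterministic and
  // resolvable against left's output, into roughly:
  val rewritten =
    Distinct(Filter(Not(Coalesce(Seq(cond, Literal.FalseLiteral))), left))

  // A row for which cond evaluates to NULL now passes the filter, since
  // Not(Coalesce(NULL, false)) == Not(false) == true, matching EXCEPT semantics.
}
```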

## How was this patch tested?

Added unit tests.

Closes apache#23350 from mgaido91/SPARK-26366_2.3.

Authored-by: Marco Gaido <[email protected]>
Signed-off-by: gatorsmile <[email protected]>
mgaido91 authored and gatorsmile committed Dec 21, 2018
1 parent b4aeb81 commit a7d50ae
Showing 4 changed files with 101 additions and 24 deletions.
@@ -36,7 +36,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
* Note:
* Before flipping the filter condition of the right node, we should:
* 1. Combine all it's [[Filter]].
* 2. Apply InferFiltersFromConstraints rule (to take into account of NULL values in the condition).
* 2. Update the attribute references to the left node;
* 3. Add a Coalesce(condition, False) (to take into account of NULL values in the condition).
*/
object ReplaceExceptWithFilter extends Rule[LogicalPlan] {

@@ -47,23 +48,28 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] {

plan.transform {
case e @ Except(left, right) if isEligible(left, right) =>
val newCondition = transformCondition(left, skipProject(right))
newCondition.map { c =>
Distinct(Filter(Not(c), left))
}.getOrElse {
val filterCondition = combineFilters(skipProject(right)).asInstanceOf[Filter].condition
if (filterCondition.deterministic) {
transformCondition(left, filterCondition).map { c =>
Distinct(Filter(Not(c), left))
}.getOrElse {
e
}
} else {
e
}
}
}

private def transformCondition(left: LogicalPlan, right: LogicalPlan): Option[Expression] = {
val filterCondition =
InferFiltersFromConstraints(combineFilters(right)).asInstanceOf[Filter].condition

val attributeNameMap: Map[String, Attribute] = left.output.map(x => (x.name, x)).toMap

if (filterCondition.references.forall(r => attributeNameMap.contains(r.name))) {
Some(filterCondition.transform { case a: AttributeReference => attributeNameMap(a.name) })
private def transformCondition(plan: LogicalPlan, condition: Expression): Option[Expression] = {
val attributeNameMap: Map[String, Attribute] = plan.output.map(x => (x.name, x)).toMap
if (condition.references.forall(r => attributeNameMap.contains(r.name))) {
val rewrittenCondition = condition.transform {
case a: AttributeReference => attributeNameMap(a.name)
}
// We need to consider as False when the condition is NULL, otherwise we do not return those
// rows containing NULL which are instead filtered in the Except right plan
Some(Coalesce(Seq(rewrittenCondition, Literal.FalseLiteral)))
} else {
None
}
@@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, Not}
import org.apache.spark.sql.catalyst.expressions.{Alias, Coalesce, If, Literal, Not}
import org.apache.spark.sql.catalyst.expressions.aggregate.First
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.BooleanType

class ReplaceOperatorSuite extends PlanTest {

@@ -65,8 +66,7 @@ class ReplaceOperatorSuite extends PlanTest {

val correctAnswer =
Aggregate(table1.output, table1.output,
Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
(attributeA >= 2 && attributeB < 1)),
Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze

comparePlans(optimized, correctAnswer)
@@ -84,8 +84,8 @@

val correctAnswer =
Aggregate(table1.output, table1.output,
Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
(attributeA >= 2 && attributeB < 1)), table1)).analyze
Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
table1)).analyze

comparePlans(optimized, correctAnswer)
}
@@ -104,8 +104,7 @@

val correctAnswer =
Aggregate(table1.output, table1.output,
Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
(attributeA >= 2 && attributeB < 1)),
Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
Project(Seq(attributeA, attributeB), table1))).analyze

comparePlans(optimized, correctAnswer)
@@ -125,8 +124,7 @@

val correctAnswer =
Aggregate(table1.output, table1.output,
Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
(attributeA >= 2 && attributeB < 1)),
Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze

comparePlans(optimized, correctAnswer)
@@ -146,8 +144,7 @@

val correctAnswer =
Aggregate(table1.output, table1.output,
Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
(attributeA === 1 && attributeB === 2)),
Filter(Not(Coalesce(Seq(attributeA === 1 && attributeB === 2, Literal.FalseLiteral))),
Project(Seq(attributeA, attributeB),
Filter(attributeB < 1, Filter(attributeA >= 2, table1))))).analyze

@@ -229,4 +226,29 @@

comparePlans(optimized, query)
}

test("SPARK-26366: ReplaceExceptWithFilter should handle properly NULL") {
val basePlan = LocalRelation(Seq('a.int, 'b.int))
val otherPlan = basePlan.where('a.in(1, 2) || 'b.in())
val except = Except(basePlan, otherPlan)
val result = OptimizeIn(Optimize.execute(except.analyze))
val correctAnswer = Aggregate(basePlan.output, basePlan.output,
Filter(!Coalesce(Seq(
'a.in(1, 2) || If('b.isNotNull, Literal.FalseLiteral, Literal(null, BooleanType)),
Literal.FalseLiteral)),
basePlan)).analyze
comparePlans(result, correctAnswer)
}

test("SPARK-26366: ReplaceExceptWithFilter should not transform non-detrministic") {
val basePlan = LocalRelation(Seq('a.int, 'b.int))
val otherPlan = basePlan.where('a > rand(1L))
val except = Except(basePlan, otherPlan)
val result = Optimize.execute(except.analyze)
val condition = basePlan.output.zip(otherPlan.output).map { case (a1, a2) =>
a1 <=> a2 }.reduce( _ && _)
val correctAnswer = Aggregate(basePlan.output, otherPlan.output,
Join(basePlan, otherPlan, LeftAnti, Option(condition))).analyze
comparePlans(result, correctAnswer)
}
}
11 changes: 11 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1467,6 +1467,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkAnswer(df.groupBy(col("a")).agg(first(col("b"))),
Seq(Row("0", BigDecimal.valueOf(0.1111)), Row("1", BigDecimal.valueOf(1.1111))))
}

test("SPARK-26366: return nulls which are not filtered in except") {
val inputDF = sqlContext.createDataFrame(
sparkContext.parallelize(Seq(Row("0", "a"), Row("1", null))),
StructType(Seq(
StructField("a", StringType, nullable = true),
StructField("b", StringType, nullable = true))))

val exceptDF = inputDF.filter(col("a").isin("0") or col("b") > "c")
checkAnswer(inputDF.except(exceptDF), Seq(Row("1", null)))
}
}

case class TestDataUnion(x: Int, y: Int, z: Int)
38 changes: 38 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2831,6 +2831,44 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
checkAnswer(sql("select 26393499451 / (1e6 * 1000)"), Row(BigDecimal("26.3934994510000")))
}
}

test("SPARK-26366: verify ReplaceExceptWithFilter") {
Seq(true, false).foreach { enabled =>
withSQLConf(SQLConf.REPLACE_EXCEPT_WITH_FILTER.key -> enabled.toString) {
val df = spark.createDataFrame(
sparkContext.parallelize(Seq(Row(0, 3, 5),
Row(0, 3, null),
Row(null, 3, 5),
Row(0, null, 5),
Row(0, null, null),
Row(null, null, 5),
Row(null, 3, null),
Row(null, null, null))),
StructType(Seq(StructField("c1", IntegerType),
StructField("c2", IntegerType),
StructField("c3", IntegerType))))
val where = "c2 >= 3 OR c1 >= 0"
val whereNullSafe =
"""
|(c2 IS NOT NULL AND c2 >= 3)
|OR (c1 IS NOT NULL AND c1 >= 0)
""".stripMargin

val df_a = df.filter(where)
val df_b = df.filter(whereNullSafe)
checkAnswer(df.except(df_a), df.except(df_b))

val whereWithIn = "c2 >= 3 OR c1 in (2)"
val whereWithInNullSafe =
"""
|(c2 IS NOT NULL AND c2 >= 3)
""".stripMargin
val dfIn_a = df.filter(whereWithIn)
val dfIn_b = df.filter(whereWithInNullSafe)
checkAnswer(df.except(dfIn_a), df.except(dfIn_b))
}
}
}
}

case class Foo(bar: Option[String])
