Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-14032] [SQL] Eliminate Unnecessary Distinct/Aggregate #11854

Closed
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ package object dsl {

/** Builds a [[Union]] whose first input is this plan and whose second input is `otherPlan`. */
def unionAll(otherPlan: LogicalPlan): LogicalPlan = {
  Union(logicalPlan, otherPlan)
}

/** Wraps this plan in a [[Distinct]] node. */
def distinct(): LogicalPlan = {
  Distinct(logicalPlan)
}

def generate(
generator: Generator,
join: Boolean = false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
PruneFilters,
SimplifyCasts,
SimplifyCaseConversionExpressions,
EliminateSerialization) ::
EliminateSerialization,
EliminateDistinct) ::
Batch("Decimal Optimizations", FixedPoint(100),
DecimalAggregates) ::
Batch("LocalRelation", FixedPoint(100),
Expand Down Expand Up @@ -1193,6 +1194,41 @@ object RemoveDispensableExpressions extends Rule[LogicalPlan] {
}
}

/**
* Removes Distinct operators that are unnecessary because their input is already distinct.
*/
object EliminateDistinct extends Rule[LogicalPlan] {
/**
 * Drops an Aggregate that only de-duplicates its child when the child's rows are already
 * known to be distinct.
 *
 * @param plan the logical plan to rewrite
 * @return the plan with every such redundant Aggregate replaced by its child
 */
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
  // Distinct has been replaced by Aggregate(child.output, child.output, child) in the rule
  // ReplaceDistinctWithAggregate, so that exact shape is what is matched here.
  //
  // Requiring both lists to equal `child.output` matters: an Aggregate that merely has
  // `grouping == aggs` may still project to fewer columns than its child, and one that keeps
  // all child columns may add extra aggregate columns. Replacing either by the child would
  // change the output schema (and possibly the rows), so they must be left in place.
  case Aggregate(grouping, aggs, child)
      if grouping == child.output && aggs == child.output && isDistinct(child) =>
    child
}

// propagate the distinct property from the child
@tailrec
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another solution is to add an isDistinct property to LogicalPlan. However, computing it through ordinary recursive calls could be more expensive than this @tailrec helper. If the physical plan later needs the isDistinct property, we can rewrite it that way. This property is in fact critical for choosing algorithms at runtime. Thanks!

private def isDistinct(plan: LogicalPlan): Boolean = plan match {
  // Returns true only when the rows produced by `plan` can be shown to be distinct; any
  // operator not listed below is conservatively reported as not distinct.
  //
  // Distinct(left) or Aggregate(left.output, left.output, _) always returns distinct results
  case _: Distinct => true
  case Aggregate(grouping, aggs, _) if grouping == aggs => true
  // BinaryNode: distinctness is inherited from the left input only.
  case p @ Join(_, _, LeftSemi, _) => isDistinct(p.left)
  case p: Intersect => isDistinct(p.left)
  case p: Except => isDistinct(p.left)
  // UnaryNode: Project and Aggregate qualify only when they keep every child column, so that
  // two different child rows cannot collapse into the same output row.
  case p: Project if p.child.outputSet.subsetOf(p.outputSet) => isDistinct(p.child)
  case p: Aggregate if p.child.outputSet.subsetOf(p.outputSet) => isDistinct(p.child)
  case p: Filter => isDistinct(p.child)
  case p: GlobalLimit => isDistinct(p.child)
  case p: LocalLimit => isDistinct(p.child)
  case p: Sort => isDistinct(p.child)
  case p: BroadcastHint => isDistinct(p.child)
  // Sampling with replacement may emit the same input row more than once, so only a sample
  // taken without replacement can inherit distinctness from its child.
  case p: Sample if !p.withReplacement => isDistinct(p.child)
  case p: SubqueryAlias => isDistinct(p.child)
  // Others:
  case _ => false
}
}

/**
* Combines two adjacent [[Limit]] operators into one, merging the
* expressions into one single expression.
Expand Down Expand Up @@ -1291,7 +1327,7 @@ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] {
/**
 * Rewrites every Intersect as a Distinct over a left-semi join of its two inputs.
 *
 * @param plan the logical plan to rewrite
 * @return the plan with each Intersect replaced
 */
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
  case Intersect(left, right) =>
    assert(left.output.size == right.output.size)
    // Pair the columns positionally and compare them null-safely. `joinCond` is defined
    // exactly once; a second definition of the same name in this scope does not compile.
    val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) }
    // reduceLeftOption yields None when there are no columns, i.e. a join without condition.
    Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And)))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ class ReplaceOperatorSuite extends PlanTest {
/**
 * Rule executor used by this suite. A single fixed-point batch runs the three rules listed
 * below: Intersect replacement, Distinct replacement, then redundant-Distinct elimination.
 */
object Optimize extends RuleExecutor[LogicalPlan] {
  val batches =
    Batch("Replace Operators", FixedPoint(100),
      ReplaceIntersectWithSemiJoin,
      ReplaceDistinctWithAggregate,
      EliminateDistinct) :: Nil
}

test("replace Intersect with Left-semi Join") {
Expand All @@ -40,19 +41,76 @@ class ReplaceOperatorSuite extends PlanTest {
val optimized = Optimize.execute(query.analyze)

val correctAnswer =
Aggregate(table1.output, table1.output,
Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))).analyze
table1.join(table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd)).groupBy('a, 'b)('a, 'b).analyze

comparePlans(optimized, correctAnswer)
}

test("replace Intersect with Left-semi Join whose left is Distinct") {
  val left = LocalRelation('a.int, 'b.int)
  val right = LocalRelation('c.int, 'd.int)

  // The left input is explicitly de-duplicated before being intersected.
  val intersected = left.distinct().intersect(right)
  val actual = Optimize.execute(intersected.analyze)

  // Expected: the left-side Aggregate stays and no further Aggregate is added above the join.
  val expected =
    left.groupBy('a, 'b)('a, 'b)
      .join(right, LeftSemi, Option('a <=> 'c && 'b <=> 'd))
      .analyze

  comparePlans(actual, expected)
}

test("continuous Intersect whose children containing Distinct") {
  val first = LocalRelation('a.int, 'b.int)
  val second = LocalRelation('c.int, 'd.int)
  val third = LocalRelation('e.int, 'f.int)

  // The DISTINCT (an AGGREGATE once rewritten) is the left-most input of the chained
  // intersects.
  val chained = first.distinct().intersect(second).intersect(third)

  // Expected: one Aggregate on the left-most relation followed by two left-semi joins.
  val expected =
    first.groupBy('a, 'b)('a, 'b)
      .join(second, LeftSemi, Option('a <=> 'c && 'b <=> 'd))
      .join(third, LeftSemi, Option('a <=> 'e && 'b <=> 'f))
      .analyze

  val actual = Optimize.execute(chained.analyze)
  comparePlans(actual, expected)
}

test("replace Intersect with Left-semi Join whose left is inferred to have distinct values") {
  val base = LocalRelation('a.int)
  val second = LocalRelation('c.int, 'd.int)
  val third = LocalRelation('e.int, 'f.int)

  // The DISTINCT sits several operators below the intersect: filter, limit, project and
  // sort are stacked on top of it.
  val leftSide = base.distinct()
    .where('a > 3).limit(5)
    .select('a.attr, ('a + 1).as("b")).orderBy('a.asc, 'b.desc)
  val chained = leftSide.intersect(second).intersect(third)

  // Expected: the same operator stack over a single Aggregate, followed by two left-semi
  // joins and no additional Aggregate.
  val expected =
    base.groupBy('a)('a).where('a > 3).limit(5)
      .select('a.attr, ('a + 1).as("b")).orderBy('a.asc, 'b.desc)
      .join(second, LeftSemi, Option('a <=> 'c && 'b <=> 'd))
      .join(third, LeftSemi, Option('a <=> 'e && 'b <=> 'f))
      .analyze

  val actual = Optimize.execute(chained.analyze)
  comparePlans(actual, expected)
}

test("replace Intersect with Left-semi Join whose left is the Distinct") {
  val left = LocalRelation('a.int, 'b.int)
  val right = LocalRelation('c.int, 'd.int)

  // Here the left input is written directly as the group-by form rather than via distinct().
  val intersected = left.groupBy('a, 'b)('a, 'b).intersect(right)
  val actual = Optimize.execute(intersected.analyze)

  // Expected: the existing Aggregate feeds the left-semi join with nothing added on top.
  val expected =
    left.groupBy('a, 'b)('a, 'b)
      .join(right, LeftSemi, Option('a <=> 'c && 'b <=> 'd))
      .analyze

  comparePlans(actual, expected)
}

test("replace Distinct with Aggregate") {
  val input = LocalRelation('a.int, 'b.int)

  // `query` and `correctAnswer` are each defined exactly once; redefining a val with the
  // same name in this scope does not compile.
  val query = input.distinct()
  val optimized = Optimize.execute(query.analyze)

  // A Distinct over (a, b) is expected to become an Aggregate grouping by, and emitting,
  // both columns.
  val correctAnswer = input.groupBy('a, 'b)('a, 'b).analyze

  comparePlans(optimized, correctAnswer)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Row("id1", 1) ::
Row("id", 1) ::
Row("id1", 2) :: Nil)

checkAnswer(
df.distinct().intersect(df),
Row("id1", 1) ::
Row("id", 1) ::
Row("id1", 2) :: Nil)
}

test("intersect - nullability") {
Expand Down