Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-14032] [SQL] Eliminate Unnecessary Distinct/Aggregate #11854

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ package object dsl {

/** Builds a `Union` node over this plan and `otherPlan`. */
def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan)

/** Wraps this plan in a `Distinct` node. */
def distinct(): LogicalPlan = Distinct(logicalPlan)

def generate(
generator: Generator,
join: Boolean = false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
EliminateSorts,
SimplifyCasts,
SimplifyCaseConversionExpressions,
EliminateSerialization) ::
EliminateSerialization,
EliminateDistinct) ::
Batch("Decimal Optimizations", FixedPoint(100),
DecimalAggregates) ::
Batch("LocalRelation", FixedPoint(100),
Expand Down Expand Up @@ -1205,6 +1206,20 @@ object RemoveDispensableExpressions extends Rule[LogicalPlan] {
}
}

/**
 * Removes an [[Aggregate]] that only de-duplicates rows (its grouping expressions are identical
 * to its aggregate expressions) when the child already reports, through `distinctSet`, a
 * non-empty set of uniquely-identifying attributes that is contained in the aggregate's output.
 *
 * `Distinct` itself is not matched here: it has already been rewritten into such an
 * [[Aggregate]] by the rule `ReplaceDistinctWithAggregate`.
 */
object EliminateDistinct extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case a @ Aggregate(_, aggs, child)
        if a.isForDistinct &&
          child.distinctSet.nonEmpty && child.distinctSet.subsetOf(AttributeSet(aggs)) =>
      // The child may expose more columns than the aggregate, or the same columns in a
      // different order. Drop the operator outright only when the output is unchanged;
      // otherwise keep a Project so the schema seen by the parent stays the same.
      if (child.output.map(_.exprId) == a.output.map(_.exprId)) {
        child
      } else {
        Project(aggs, child)
      }
  }
}

/**
* Combines two adjacent [[Limit]] operators into one, merging the
* expressions into one single expression.
Expand Down Expand Up @@ -1303,7 +1318,7 @@ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] {
// Rewrites each Intersect into a Distinct placed over a left-semi join of its two children.
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
  case Intersect(left, right) =>
    assert(left.output.size == right.output.size)
    // Pair the columns of both sides by position and compare each pair with EqualNullSafe.
    val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) }
    // reduceLeftOption yields None when there are no columns, i.e. a join without a condition.
    Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And)))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
*/
protected def validConstraints: Set[Expression] = Set.empty

/**
 * The set of attributes whose combination can uniquely identify a row.
 *
 * This default implementation reports an empty set; subclasses override it when they can
 * state which attributes identify their output rows.
 */
def distinctSet: AttributeSet = AttributeSet.empty

/**
* Returns the set of attributes that are output by this node.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
!expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions
}

/**
 * Forwards the child's distinct set only when the child has at least one output attribute
 * and every one of them is also output by this projection; otherwise reports an empty set.
 */
override def distinctSet: AttributeSet = {
  val childAttrs = child.outputSet
  val keepsAllChildAttrs = childAttrs.nonEmpty && childAttrs.subsetOf(outputSet)
  if (keepsAllChildAttrs) child.distinctSet else AttributeSet.empty
}

/** The child's constraints combined with the aliased constraints derived from `projectList`. */
override def validConstraints: Set[Expression] = {
  val inherited = child.constraints
  inherited.union(getAliasedConstraints(projectList))
}
}
Expand Down Expand Up @@ -107,6 +115,8 @@ case class Filter(condition: Expression, child: LogicalPlan)

/** Forwards the child's row bound unchanged. */
override def maxRows: Option[Long] = child.maxRows

/** Forwards the child's distinct set unchanged. */
override def distinctSet: AttributeSet = child.distinctSet

/** The child's constraints plus each conjunct of the filter condition. */
override protected def validConstraints: Set[Expression] = {
  val inherited = child.constraints
  val fromCondition: Set[Expression] = splitConjunctivePredicates(condition).toSet
  inherited.union(fromCondition)
}
}
Expand Down Expand Up @@ -137,6 +147,8 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
}

// Reports all output attributes of the left child as the distinct set.
// NOTE(review): this presumes Intersect never emits duplicate rows — confirm against the
// operator's execution semantics.
override def distinctSet: AttributeSet = left.outputSet

// Combines the constraints of both sides (`leftConstraints` and `rightConstraints`).
override protected def validConstraints: Set[Expression] =
leftConstraints.union(rightConstraints)

Expand Down Expand Up @@ -168,6 +180,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
/** We don't use right.output because those rows get excluded from the set. */
override def output: Seq[Attribute] = left.output

// Reports all output attributes of the left child as the distinct set.
// NOTE(review): this presumes Except de-duplicates the rows it keeps from the left side —
// confirm against the operator's execution semantics.
override def distinctSet: AttributeSet = left.outputSet

// Only the left side's constraints are reported.
override protected def validConstraints: Set[Expression] = leftConstraints

override lazy val resolved: Boolean =
Expand Down Expand Up @@ -265,6 +279,9 @@ case class Join(
}
}

/**
 * A left-semi join forwards the left child's distinct set; every other join type reports
 * an empty set.
 */
override def distinctSet: AttributeSet = joinType match {
  case LeftSemi => left.distinctSet
  case _ => AttributeSet.empty
}

override protected def validConstraints: Set[Expression] = {
joinType match {
case Inner if condition.isDefined =>
Expand Down Expand Up @@ -312,6 +329,7 @@ case class Join(
*/
case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
// Output and distinct set are both taken from the child unchanged.
override def output: Seq[Attribute] = child.output
override def distinctSet: AttributeSet = child.distinctSet

// We manually set statistics of BroadcastHint to smallest value to make sure
// the plan wrapped by BroadcastHint will be considered to broadcast later.
Expand Down Expand Up @@ -367,6 +385,7 @@ case class Sort(
child: LogicalPlan) extends UnaryNode {
// Output, row bound and distinct set are all taken from the child unchanged.
override def output: Seq[Attribute] = child.output
override def maxRows: Option[Long] = child.maxRows
override def distinctSet: AttributeSet = child.distinctSet
}

/** Factory for constructing new `Range` nodes. */
Expand Down Expand Up @@ -422,6 +441,19 @@ case class Aggregate(
/** One output attribute per aggregate expression, in declaration order. */
override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)

/** Forwards the child's row bound unchanged. */
override def maxRows: Option[Long] = child.maxRows

/**
 * For a de-duplicating aggregate (see `isForDistinct`) the aggregate expressions form the
 * distinct set. Otherwise the child's distinct set is forwarded, but only when the child has
 * at least one output attribute and all of them are output by this node; in every other case
 * the set is empty.
 */
override def distinctSet: AttributeSet = {
  def keepsAllChildAttrs: Boolean =
    child.outputSet.nonEmpty && child.outputSet.subsetOf(outputSet)

  if (isForDistinct) AttributeSet(aggregateExpressions)
  else if (keepsAllChildAttrs) child.distinctSet
  else AttributeSet.empty
}

/** True when there is at least one grouping expression and they equal the aggregate list. */
def isForDistinct: Boolean =
  if (groupingExpressions.isEmpty) false
  else groupingExpressions == aggregateExpressions

/** The child's constraints combined with the aliased constraints of the aggregate list. */
override def validConstraints: Set[Expression] = {
  val inherited = child.constraints
  inherited.union(getAliasedConstraints(aggregateExpressions))
}

Expand All @@ -443,6 +475,8 @@ case class Window(
// The child's output followed by one attribute per window expression.
override def output: Seq[Attribute] =
child.output ++ windowExpressions.map(_.toAttribute)

// Forwards the child's distinct set unchanged.
override def distinctSet: AttributeSet = child.distinctSet

// The attributes produced by the window expressions alone.
def windowOutputSet: AttributeSet = AttributeSet(windowExpressions.map(_.toAttribute))
}

Expand Down Expand Up @@ -585,6 +619,7 @@ object Limit {

case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
// Output and distinct set are both taken from the child unchanged.
override def output: Seq[Attribute] = child.output
override def distinctSet: AttributeSet = child.distinctSet
override def maxRows: Option[Long] = {
limitExpr match {
case IntegerLiteral(limit) => Some(limit)
Expand All @@ -600,6 +635,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN

case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
// Output and distinct set are both taken from the child unchanged.
override def output: Seq[Attribute] = child.output
override def distinctSet: AttributeSet = child.distinctSet
override def maxRows: Option[Long] = {
limitExpr match {
case IntegerLiteral(limit) => Some(limit)
Expand All @@ -615,6 +651,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo

case class SubqueryAlias(alias: String, child: LogicalPlan) extends UnaryNode {

  // Same attributes as the child, each re-qualified with this node's alias.
  override def output: Seq[Attribute] = {
    val qualifier: Option[String] = Some(alias)
    child.output.map(attr => attr.withQualifier(qualifier))
  }

  // Forwards the child's distinct set unchanged.
  override def distinctSet: AttributeSet = child.distinctSet
}

Expand All @@ -638,6 +675,7 @@ case class Sample(
val isTableSample: java.lang.Boolean = false) extends UnaryNode {

override def output: Seq[Attribute] = child.output

// Sampling with replacement may emit the same input row more than once, so the child's
// uniqueness guarantee is only forwarded when sampling without replacement.
override def distinctSet: AttributeSet =
  if (withReplacement) AttributeSet.empty else child.distinctSet

override def statistics: Statistics = {
val ratio = upperBound - lowerBound
Expand All @@ -658,6 +696,7 @@ case class Sample(
case class Distinct(child: LogicalPlan) extends UnaryNode {
  // Schema and row bound are taken from the child unchanged.
  override def output: Seq[Attribute] = child.output
  override def maxRows: Option[Long] = child.maxRows

  // Every attribute the child outputs is part of the distinct set.
  override def distinctSet: AttributeSet = child.outputSet
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ class ReplaceOperatorSuite extends PlanTest {
// Rule executor used by this suite: a single fixed-point batch that rewrites Intersect and
// Distinct and then removes de-duplication that has become redundant.
object Optimize extends RuleExecutor[LogicalPlan] {
  val batches =
    Batch("Replace Operators", FixedPoint(100),
      ReplaceIntersectWithSemiJoin,
      ReplaceDistinctWithAggregate,
      EliminateDistinct) :: Nil
}

test("replace Intersect with Left-semi Join") {
Expand All @@ -40,19 +41,90 @@ class ReplaceOperatorSuite extends PlanTest {
val optimized = Optimize.execute(query.analyze)

val correctAnswer =
Aggregate(table1.output, table1.output,
Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))).analyze
table1.join(table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd)).groupBy('a, 'b)('a, 'b).analyze

comparePlans(optimized, correctAnswer)
}

test("replace Intersect with Left-semi Join whose left is Distinct") {
  val left = LocalRelation('a.int, 'b.int)
  val right = LocalRelation('c.int, 'd.int)

  val actual = Optimize.execute(left.distinct().intersect(right).analyze)

  // Expected plan: the only Aggregate left is the one on the left input, below the join.
  val expected =
    left.groupBy('a, 'b)('a, 'b)
      .join(right, LeftSemi, Option('a <=> 'c && 'b <=> 'd))
      .analyze

  comparePlans(actual, expected)
}

test("continuous Intersect whose children containing Distinct") {
  val first = LocalRelation('a.int, 'b.int)
  val second = LocalRelation('c.int, 'd.int)
  val third = LocalRelation('e.int, 'f.int)

  // The de-duplicating operator sits directly on the left-most input.
  val chained = first.distinct().intersect(second).intersect(third)

  // Expected plan: one Aggregate on the left-most input followed by two semi joins.
  val expected =
    first.groupBy('a, 'b)('a, 'b)
      .join(second, LeftSemi, Option('a <=> 'c && 'b <=> 'd))
      .join(third, LeftSemi, Option('a <=> 'e && 'b <=> 'f))
      .analyze

  comparePlans(Optimize.execute(chained.analyze), expected)
}

test("continuous Intersect whose children do not contain Distinct") {
  val first = LocalRelation('a.int, 'b.int)
  val second = LocalRelation('c.int, 'd.int)
  val third = LocalRelation('e.int, 'f.int)

  // No input is de-duplicated up front.
  val chained = first.intersect(second).intersect(third)

  // Expected plan: a single Aggregate, placed between the first and the second semi join.
  val expected =
    first
      .join(second, LeftSemi, Option('a <=> 'c && 'b <=> 'd))
      .groupBy('a, 'b)('a, 'b)
      .join(third, LeftSemi, Option('a <=> 'e && 'b <=> 'f))
      .analyze

  comparePlans(Optimize.execute(chained.analyze), expected)
}

test("replace Intersect with Left-semi Join whose left is inferred to have distinct values") {
  val first = LocalRelation('a.int)
  val second = LocalRelation('c.int, 'd.int)
  val third = LocalRelation('e.int, 'f.int)

  // The Distinct is several operators below the Intersect: filter, limit, project and sort
  // sit in between.
  val chained =
    first.distinct()
      .where('a > 3)
      .limit(5)
      .select('a.attr, ('a + 1).as("b"))
      .orderBy('a.asc, 'b.desc)
      .intersect(second)
      .intersect(third)

  // Expected plan: only the Aggregate that replaces the original Distinct remains.
  val expected =
    first.groupBy('a)('a)
      .where('a > 3)
      .limit(5)
      .select('a.attr, ('a + 1).as("b"))
      .orderBy('a.asc, 'b.desc)
      .join(second, LeftSemi, Option('a <=> 'c && 'b <=> 'd))
      .join(third, LeftSemi, Option('a <=> 'e && 'b <=> 'f))
      .analyze

  comparePlans(Optimize.execute(chained.analyze), expected)
}

test("replace Intersect with Left-semi Join whose left is the Distinct") {
  val left = LocalRelation('a.int, 'b.int)
  val right = LocalRelation('c.int, 'd.int)

  // The left input is written directly as a group-by over all of its columns.
  val dedupedLeft = left.groupBy('a, 'b)('a, 'b)
  val actual = Optimize.execute(dedupedLeft.intersect(right).analyze)

  val expected =
    left.groupBy('a, 'b)('a, 'b)
      .join(right, LeftSemi, Option('a <=> 'c && 'b <=> 'd))
      .analyze

  comparePlans(actual, expected)
}

test("replace Distinct with Aggregate") {
  val input = LocalRelation('a.int, 'b.int)

  val query = input.distinct()
  val optimized = Optimize.execute(query.analyze)

  // A Distinct over a plain relation becomes a group-by on every column of that relation.
  val correctAnswer = input.groupBy('a, 'b)('a, 'b).analyze

  comparePlans(optimized, correctAnswer)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Row("id1", 1) ::
Row("id", 1) ::
Row("id1", 2) :: Nil)

checkAnswer(
df.distinct().intersect(df),
Row("id1", 1) ::
Row("id", 1) ::
Row("id1", 2) :: Nil)
}

test("intersect - nullability") {
Expand Down