[SPARK-21180][SQL] Remove conf from stats functions since now we have conf in LogicalPlan

## What changes were proposed in this pull request?

After wiring `SQLConf` into the logical plan ([PR 18299](apache#18299)), we can remove the need to pass `conf` into `def stats` and `def computeStats`.
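
As a rough sketch (hypothetical, simplified names; not the exact Spark code), the resulting API shape looks as follows. The key assumption is that every plan node can reach its `SQLConf` through a member wired in by apache#18299, which makes the parameter redundant:

```scala
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.internal.SQLConf

abstract class SketchLogicalPlan {
  // Assumed to be available on every plan node since apache#18299.
  def conf: SQLConf

  private var statsCache: Option[Statistics] = None

  // Before: final def stats(conf: SQLConf): Statistics
  // After:  no parameter; the node reads its own conf when needed.
  final def stats: Statistics = statsCache.getOrElse {
    statsCache = Some(computeStats)
    statsCache.get
  }

  // Before: protected def computeStats(conf: SQLConf): Statistics
  protected def computeStats: Statistics
}
```

Call sites simplify the same way, e.g. `items.forall(_.stats(conf).rowCount.isDefined)` becomes `items.forall(_.stats.rowCount.isDefined)`; the real `LogicalPlan` additionally keeps `invalidateStatsCache()` for when the configuration changes.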

## How was this patch tested?

Covered by existing tests, some of which were modified for this change.

Author: wangzhenhua <[email protected]>
Author: Zhenhua Wang <[email protected]>

Closes apache#18391 from wzhfy/removeConf.
wzhfy authored and gatorsmile committed Jun 23, 2017
1 parent 07479b3 commit b803b66
Showing 38 changed files with 178 additions and 173 deletions.
@@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Attri
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.catalyst.util.quoteIdentifier
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType


@@ -436,7 +435,7 @@ case class CatalogRelation(
     createTime = -1
   ))

-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     // For data source tables, we will create a `LogicalRelation` and won't call this method, for
     // hive serde tables, we will always generate a statistics.
     // TODO: unify the table stats generation.
@@ -58,7 +58,7 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr
     // Do reordering if the number of items is appropriate and join conditions exist.
     // We also need to check if costs of all items can be evaluated.
     if (items.size > 2 && items.size <= conf.joinReorderDPThreshold && conditions.nonEmpty &&
-        items.forall(_.stats(conf).rowCount.isDefined)) {
+        items.forall(_.stats.rowCount.isDefined)) {
       JoinReorderDP.search(conf, items, conditions, output)
     } else {
       plan
@@ -322,7 +322,7 @@ object JoinReorderDP extends PredicateHelper with Logging {
   /** Get the cost of the root node of this plan tree. */
   def rootCost(conf: SQLConf): Cost = {
     if (itemIds.size > 1) {
-      val rootStats = plan.stats(conf)
+      val rootStats = plan.stats
       Cost(rootStats.rowCount.get, rootStats.sizeInBytes)
     } else {
       // If the plan is a leaf item, it has zero cost.
@@ -317,7 +317,7 @@ case class LimitPushDown(conf: SQLConf) extends Rule[LogicalPlan] {
         case FullOuter =>
           (left.maxRows, right.maxRows) match {
             case (None, None) =>
-              if (left.stats(conf).sizeInBytes >= right.stats(conf).sizeInBytes) {
+              if (left.stats.sizeInBytes >= right.stats.sizeInBytes) {
                 join.copy(left = maybePushLimit(exp, left))
               } else {
                 join.copy(right = maybePushLimit(exp, right))
@@ -82,7 +82,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
       // Find if the input plans are eligible for star join detection.
       // An eligible plan is a base table access with valid statistics.
       val foundEligibleJoin = input.forall {
-        case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined => true
+        case PhysicalOperation(_, _, t: LeafNode) if t.stats.rowCount.isDefined => true
         case _ => false
       }

@@ -181,7 +181,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
           val leafCol = findLeafNodeCol(column, plan)
           leafCol match {
             case Some(col) if t.outputSet.contains(col) =>
-              val stats = t.stats(conf)
+              val stats = t.stats
               stats.rowCount match {
                 case Some(rowCount) if rowCount >= 0 =>
                   if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) {
@@ -237,7 +237,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
           val leafCol = findLeafNodeCol(column, plan)
           leafCol match {
             case Some(col) if t.outputSet.contains(col) =>
-              val stats = t.stats(conf)
+              val stats = t.stats
               stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)
             case None => false
           }
@@ -296,11 +296,11 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
   */
  private def getTableAccessCardinality(
      input: LogicalPlan): Option[BigInt] = input match {
-    case PhysicalOperation(_, cond, t: LeafNode) if t.stats(conf).rowCount.isDefined =>
-      if (conf.cboEnabled && input.stats(conf).rowCount.isDefined) {
-        Option(input.stats(conf).rowCount.get)
+    case PhysicalOperation(_, cond, t: LeafNode) if t.stats.rowCount.isDefined =>
+      if (conf.cboEnabled && input.stats.rowCount.isDefined) {
+        Option(input.stats.rowCount.get)
       } else {
-        Option(t.stats(conf).rowCount.get)
+        Option(t.stats.rowCount.get)
       }
     case _ => None
   }
@@ -21,7 +21,6 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{StructField, StructType}

 object LocalRelation {
@@ -67,7 +66,7 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil)
     }
   }

-  override def computeStats(conf: SQLConf): Statistics =
+  override def computeStats: Statistics =
     Statistics(sizeInBytes =
       output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length)

@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType


@@ -90,8 +89,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai
   * first time. If the configuration changes, the cache can be invalidated by calling
   * [[invalidateStatsCache()]].
   */
-  final def stats(conf: SQLConf): Statistics = statsCache.getOrElse {
-    statsCache = Some(computeStats(conf))
+  final def stats: Statistics = statsCache.getOrElse {
+    statsCache = Some(computeStats)
     statsCache.get
   }

@@ -108,11 +107,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai
   *
   * [[LeafNode]]s must override this.
   */
-  protected def computeStats(conf: SQLConf): Statistics = {
+  protected def computeStats: Statistics = {
     if (children.isEmpty) {
       throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
     }
-    Statistics(sizeInBytes = children.map(_.stats(conf).sizeInBytes).product)
+    Statistics(sizeInBytes = children.map(_.stats.sizeInBytes).product)
   }

   override def verboseStringWithSuffix: String = {
@@ -333,21 +332,21 @@ abstract class UnaryNode extends LogicalPlan {

   override protected def validConstraints: Set[Expression] = child.constraints

-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     // There should be some overhead in Row object, the size should not be zero when there is
     // no columns, this help to prevent divide-by-zero error.
     val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8
     val outputRowSize = output.map(_.dataType.defaultSize).sum + 8
     // Assume there will be the same number of rows as child has.
-    var sizeInBytes = (child.stats(conf).sizeInBytes * outputRowSize) / childRowSize
+    var sizeInBytes = (child.stats.sizeInBytes * outputRowSize) / childRowSize
     if (sizeInBytes == 0) {
       // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
       // (product of children).
       sizeInBytes = 1
     }

     // Don't propagate rowCount and attributeStats, since they are not estimated here.
-    Statistics(sizeInBytes = sizeInBytes, hints = child.stats(conf).hints)
+    Statistics(sizeInBytes = sizeInBytes, hints = child.stats.hints)
   }
 }

@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.RandomSampler
@@ -65,11 +64,11 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
   override def validConstraints: Set[Expression] =
     child.constraints.union(getAliasedConstraints(projectList))

-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     if (conf.cboEnabled) {
-      ProjectEstimation.estimate(conf, this).getOrElse(super.computeStats(conf))
+      ProjectEstimation.estimate(this).getOrElse(super.computeStats)
     } else {
-      super.computeStats(conf)
+      super.computeStats
     }
   }
 }
@@ -139,11 +138,11 @@ case class Filter(condition: Expression, child: LogicalPlan)
     child.constraints.union(predicates.toSet)
   }

-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     if (conf.cboEnabled) {
-      FilterEstimation(this, conf).estimate.getOrElse(super.computeStats(conf))
+      FilterEstimation(this).estimate.getOrElse(super.computeStats)
     } else {
-      super.computeStats(conf)
+      super.computeStats
     }
   }
 }
@@ -192,13 +191,13 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
     }
   }

-  override def computeStats(conf: SQLConf): Statistics = {
-    val leftSize = left.stats(conf).sizeInBytes
-    val rightSize = right.stats(conf).sizeInBytes
+  override def computeStats: Statistics = {
+    val leftSize = left.stats.sizeInBytes
+    val rightSize = right.stats.sizeInBytes
     val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize
     Statistics(
       sizeInBytes = sizeInBytes,
-      hints = left.stats(conf).hints.resetForJoin())
+      hints = left.stats.hints.resetForJoin())
   }
 }

@@ -209,8 +208,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le

   override protected def validConstraints: Set[Expression] = leftConstraints

-  override def computeStats(conf: SQLConf): Statistics = {
-    left.stats(conf).copy()
+  override def computeStats: Statistics = {
+    left.stats.copy()
   }
 }

@@ -248,8 +247,8 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
     children.length > 1 && childrenResolved && allChildrenCompatible
   }

-  override def computeStats(conf: SQLConf): Statistics = {
-    val sizeInBytes = children.map(_.stats(conf).sizeInBytes).sum
+  override def computeStats: Statistics = {
+    val sizeInBytes = children.map(_.stats.sizeInBytes).sum
     Statistics(sizeInBytes = sizeInBytes)
   }

@@ -357,20 +356,20 @@ case class Join(
     case _ => resolvedExceptNatural
   }

-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     def simpleEstimation: Statistics = joinType match {
       case LeftAnti | LeftSemi =>
         // LeftSemi and LeftAnti won't ever be bigger than left
-        left.stats(conf)
+        left.stats
       case _ =>
         // Make sure we don't propagate isBroadcastable in other joins, because
         // they could explode the size.
-        val stats = super.computeStats(conf)
+        val stats = super.computeStats
         stats.copy(hints = stats.hints.resetForJoin())
     }

     if (conf.cboEnabled) {
-      JoinEstimation.estimate(conf, this).getOrElse(simpleEstimation)
+      JoinEstimation.estimate(this).getOrElse(simpleEstimation)
     } else {
       simpleEstimation
     }
@@ -523,7 +522,7 @@ case class Range(

   override def newInstance(): Range = copy(output = output.map(_.newInstance()))

-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     val sizeInBytes = LongType.defaultSize * numElements
     Statistics( sizeInBytes = sizeInBytes )
   }
@@ -556,20 +555,20 @@ case class Aggregate(
     child.constraints.union(getAliasedConstraints(nonAgg))
   }

-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     def simpleEstimation: Statistics = {
       if (groupingExpressions.isEmpty) {
         Statistics(
           sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1),
           rowCount = Some(1),
-          hints = child.stats(conf).hints)
+          hints = child.stats.hints)
       } else {
-        super.computeStats(conf)
+        super.computeStats
       }
     }

     if (conf.cboEnabled) {
-      AggregateEstimation.estimate(conf, this).getOrElse(simpleEstimation)
+      AggregateEstimation.estimate(this).getOrElse(simpleEstimation)
     } else {
       simpleEstimation
     }
@@ -672,8 +671,8 @@ case class Expand(
   override def references: AttributeSet =
     AttributeSet(projections.flatten.flatMap(_.references))

-  override def computeStats(conf: SQLConf): Statistics = {
-    val sizeInBytes = super.computeStats(conf).sizeInBytes * projections.length
+  override def computeStats: Statistics = {
+    val sizeInBytes = super.computeStats.sizeInBytes * projections.length
     Statistics(sizeInBytes = sizeInBytes)
   }

@@ -743,9 +742,9 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
       case _ => None
     }
   }
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     val limit = limitExpr.eval().asInstanceOf[Int]
-    val childStats = child.stats(conf)
+    val childStats = child.stats
     val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit)
     // Don't propagate column stats, because we don't know the distribution after a limit operation
     Statistics(
@@ -763,9 +762,9 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
       case _ => None
     }
   }
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     val limit = limitExpr.eval().asInstanceOf[Int]
-    val childStats = child.stats(conf)
+    val childStats = child.stats
     if (limit == 0) {
       // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
       // (product of children).
@@ -832,9 +831,9 @@ case class Sample(

   override def output: Seq[Attribute] = child.output

-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     val ratio = upperBound - lowerBound
-    val childStats = child.stats(conf)
+    val childStats = child.stats
     var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio)
     if (sizeInBytes == 0) {
       sizeInBytes = 1
@@ -898,7 +897,7 @@ case class RepartitionByExpression(
 case object OneRowRelation extends LeafNode {
   override def maxRows: Option[Long] = Some(1)
   override def output: Seq[Attribute] = Nil
-  override def computeStats(conf: SQLConf): Statistics = Statistics(sizeInBytes = 1)
+  override def computeStats: Statistics = Statistics(sizeInBytes = 1)
 }

 /** A logical plan for `dropDuplicates`. */
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.plans.logical

 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.internal.SQLConf

 /**
  * A general hint for the child that is not yet resolved. This node is generated by the parser and
@@ -44,8 +43,8 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo())

   override lazy val canonicalized: LogicalPlan = child.canonicalized

-  override def computeStats(conf: SQLConf): Statistics = {
-    val stats = child.stats(conf)
+  override def computeStats: Statistics = {
+    val stats = child.stats
     stats.copy(hints = hints)
   }
 }
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation

 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics}
-import org.apache.spark.sql.internal.SQLConf


 object AggregateEstimation {
@@ -29,13 +28,13 @@ object AggregateEstimation {
   * Estimate the number of output rows based on column stats of group-by columns, and propagate
   * column stats for aggregate expressions.
   */
-  def estimate(conf: SQLConf, agg: Aggregate): Option[Statistics] = {
-    val childStats = agg.child.stats(conf)
+  def estimate(agg: Aggregate): Option[Statistics] = {
+    val childStats = agg.child.stats
     // Check if we have column stats for all group-by columns.
     val colStatsExist = agg.groupingExpressions.forall { e =>
       e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute])
     }
-    if (rowCountsExist(conf, agg.child) && colStatsExist) {
+    if (rowCountsExist(agg.child) && colStatsExist) {
       // Multiply distinct counts of group-by columns. This is an upper bound, which assumes
       // the data contains all combinations of distinct values of group-by columns.
       var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))(