diff --git a/core/src/main/scala/org/apache/spark/sql/TiStrategy.scala b/core/src/main/scala/org/apache/spark/sql/TiStrategy.scala
index 3e44cc4371..1176b159b0 100644
--- a/core/src/main/scala/org/apache/spark/sql/TiStrategy.scala
+++ b/core/src/main/scala/org/apache/spark/sql/TiStrategy.scala
@@ -15,10 +15,12 @@ package org.apache.spark.sql
 
-import com.pingcap.tikv.expression.scalar.TiScalarFunction
 import java.time.ZonedDateTime
 
+import scala.collection.JavaConverters._
+
 import com.pingcap.tikv.codec.IgnoreUnsupportedTypeException
+import com.pingcap.tikv.expression.scalar.TiScalarFunction
 import com.pingcap.tikv.expression.{aggregate => _, _}
 import com.pingcap.tikv.meta.TiDAGRequest
 import com.pingcap.tikv.meta.TiDAGRequest.PushDownType
@@ -26,8 +28,9 @@ import com.pingcap.tikv.predicates.ScanBuilder
 import com.pingcap.tispark.TiUtils._
 import com.pingcap.tispark.{BasicExpression, TiConfigConst, TiDBRelation, TiUtils}
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.NamedExpression.newExprId
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, _}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Cast, Divide, ExprId, Expression, IntegerLiteral, NamedExpression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Cast, Divide, Expression, IntegerLiteral, NamedExpression, SortOrder}
 import org.apache.spark.sql.catalyst.planning.{PhysicalAggregation, PhysicalOperation}
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -36,10 +39,6 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
-import scala.collection.JavaConversions._
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.{JavaConversions, mutable}
-
 // TODO: Too many hacks here since we hijack the planning
 // but we don't have full control over planning stage
 // We cannot pass context around during planning so
@@ -48,7 +47,8 @@ import scala.collection.{JavaConversions, mutable}
 // have multiple plan to pushdown
 class TiStrategy(context: SQLContext) extends Strategy with Logging {
   val sqlConf: SQLConf = context.conf
-  def blacklist: ExpressionBlacklist = {
+
+  private def blacklist: ExpressionBlacklist = {
     val blacklistString = sqlConf.getConfString(TiConfigConst.UNSUPPORTED_PUSHDOWN_EXPR, "")
     new ExpressionBlacklist(blacklistString)
   }
@@ -58,19 +58,19 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging {
     new TypeBlacklist(blacklistString)
   }
 
-  def allowAggregationPushDown(): Boolean = {
+  private def allowAggregationPushdown(): Boolean = {
     sqlConf.getConfString(TiConfigConst.ALLOW_AGG_PUSHDOWN, "true").toBoolean
   }
 
-  def allowIndexDoubleRead(): Boolean = {
+  private def allowIndexDoubleRead(): Boolean = {
     sqlConf.getConfString(TiConfigConst.ALLOW_INDEX_DOUBLE_READ, "false").toBoolean
   }
 
-  def useStreamingProcess(): Boolean = {
+  private def useStreamingProcess(): Boolean = {
     sqlConf.getConfString(TiConfigConst.COPROCESS_STREAMING, "false").toBoolean
   }
 
-  def timeZoneOffset(): Int = {
+  private def timeZoneOffset(): Int = {
     sqlConf
       .getConfString(
         TiConfigConst.KV_TIMEZONE_OFFSET,
@@ -79,7 +79,7 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging {
       .toInt
   }
 
-  def pushDownType(): PushDownType = {
+  private def pushDownType(): PushDownType = {
     if (useStreamingProcess()) {
       PushDownType.STREAMING
     } else {
@@ -97,18 +97,20 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging {
       .flatten
   }
 
-  private def toCoprocessorRDD(source: TiDBRelation,
-                               output: Seq[Attribute],
-                               dagRequest: TiDAGRequest): SparkPlan = {
+  private def toCoprocessorRDD(
+      source: TiDBRelation,
+      output: Seq[Attribute],
+      dagRequest: TiDAGRequest
+  ): SparkPlan = {
     val table = source.table
     dagRequest.setTableInfo(table)
-
     if (dagRequest.getFields.isEmpty) {
       dagRequest.addRequiredColumn(TiColumnRef.create(table.getColumns.get(0).getName))
     }
+    // Need to resolve column info after adding aggregation pushdowns
     dagRequest.resolve()
-    val notAllowPushDown = dagRequest.getFields
+    val notAllowPushDown = dagRequest.getFields.asScala
       .map { _.getColumnInfo.getType.simpleTypeName }
       .exists { typeBlackList.isUnsupportedType }
@@ -124,30 +126,30 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging {
     }
   }
 
-  def aggregationToDAGRequest(
+  private def aggregationToDAGRequest(
     groupByList: Seq[NamedExpression],
     aggregates: Seq[AggregateExpression],
     source: TiDBRelation,
     dagRequest: TiDAGRequest = new TiDAGRequest(pushDownType(), timeZoneOffset())
   ): TiDAGRequest = {
-    aggregates.foreach {
-      case AggregateExpression(_: Average, _, _, _) =>
+    aggregates.map { _.aggregateFunction }.foreach {
+      case _: Average =>
         throw new IllegalArgumentException("Should never be here")
 
-      case AggregateExpression(f @ Sum(BasicExpression(arg)), _, _, _) =>
+      case f @ Sum(BasicExpression(arg)) =>
         dagRequest.addAggregate(new TiSum(arg), fromSparkType(f.dataType))
 
-      case AggregateExpression(f @ Count(args), _, _, _) =>
+      case f @ Count(args) =>
         val tiArgs = args.flatMap(BasicExpression.convertToTiExpr)
         dagRequest.addAggregate(new TiCount(tiArgs: _*), fromSparkType(f.dataType))
 
-      case AggregateExpression(f @ Min(BasicExpression(arg)), _, _, _) =>
+      case f @ Min(BasicExpression(arg)) =>
         dagRequest.addAggregate(new TiMin(arg), fromSparkType(f.dataType))
 
-      case AggregateExpression(f @ Max(BasicExpression(arg)), _, _, _) =>
+      case f @ Max(BasicExpression(arg)) =>
         dagRequest.addAggregate(new TiMax(arg), fromSparkType(f.dataType))
 
-      case AggregateExpression(f @ First(BasicExpression(arg), _), _, _, _) =>
+      case f @ First(BasicExpression(arg), _) =>
         dagRequest.addAggregate(new TiFirst(arg), fromSparkType(f.dataType))
 
       case _ =>
@@ -163,16 +165,13 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging {
     dagRequest
   }
 
-  def extractColumnFromFilter(tiFilter: TiExpr, result: ArrayBuffer[TiColumnRef]): Unit =
-    tiFilter match {
-      case fun: TiScalarFunction =>
-        fun.getArgs.foreach(extractColumnFromFilter(_, result))
-      case col: TiColumnRef =>
-        result.add(col)
-      case _ =>
-    }
+  def referencedTiColumns(expression: TiExpr): Seq[TiColumnRef] = expression match {
+    case f: TiScalarFunction => f.getArgs.asScala.flatMap { referencedTiColumns }
+    case ref: TiColumnRef    => Seq(ref)
+    case _                   => Nil
+  }
 
-  def filterToDAGRequest(
+  private def filterToDAGRequest(
     filters: Seq[Expression],
     source: TiDBRelation,
     dagRequest: TiDAGRequest = new TiDAGRequest(pushDownType(), timeZoneOffset())
@@ -180,42 +179,39 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging {
     val tiFilters: Seq[TiExpr] = filters.collect { case BasicExpression(expr) => expr }
     val scanBuilder: ScanBuilder = new ScanBuilder
     val tableScanPlan =
-      scanBuilder.buildTableScan(JavaConversions.seqAsJavaList(tiFilters), source.table)
+      scanBuilder.buildTableScan(tiFilters.asJava, source.table)
     val scanPlan = if (allowIndexDoubleRead()) {
       // We need to prepare downgrade information in case an index scan downgrade happens.
-      tableScanPlan.getFilters.foreach(dagRequest.addDowngradeFilter)
-      scanBuilder.buildScan(JavaConversions.seqAsJavaList(tiFilters), source.table)
+      tableScanPlan.getFilters.asScala.foreach { dagRequest.addDowngradeFilter }
+      scanBuilder.buildScan(tiFilters.asJava, source.table)
     } else {
       tableScanPlan
     }
 
     dagRequest.addRanges(scanPlan.getKeyRanges)
-    scanPlan.getFilters.foreach(dagRequest.addFilter)
+    scanPlan.getFilters.asScala.foreach { dagRequest.addFilter }
     if (scanPlan.isIndexScan) {
       dagRequest.setIndexInfo(scanPlan.getIndex)
     }
     dagRequest
   }
 
-  def addSortOrder(request: TiDAGRequest, sortOrder: Seq[SortOrder]): Unit =
-    if (sortOrder != null) {
-      sortOrder.foreach(
-        (order: SortOrder) =>
-          request.addOrderByItem(
-            TiByItem.create(
-              BasicExpression.convertToTiExpr(order.child).get,
-              order.direction.sql.equalsIgnoreCase("DESC")
-            )
+  private def addSortOrder(request: TiDAGRequest, sortOrder: Seq[SortOrder]): Unit =
+    sortOrder.foreach { order: SortOrder =>
+      request.addOrderByItem(
+        TiByItem.create(
+          BasicExpression.convertToTiExpr(order.child).get,
+          order.direction.sql.equalsIgnoreCase("DESC")
         )
       )
     }
 
-  def pruneTopNFilterProject(
+  private def pruneTopNFilterProject(
     limit: Int,
     projectList: Seq[NamedExpression],
     filterPredicates: Seq[Expression],
     source: TiDBRelation,
-    sortOrder: Seq[SortOrder] = null
+    sortOrder: Seq[SortOrder]
   ): SparkPlan = {
     val request = new TiDAGRequest(pushDownType(), timeZoneOffset())
     request.setLimit(limit)
@@ -223,21 +219,21 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging {
     pruneFilterProject(projectList, filterPredicates, source, request)
   }
 
-  def collectLimit(limit: Int, child: LogicalPlan): SparkPlan = child match {
+  private def collectLimit(limit: Int, child: LogicalPlan): SparkPlan = child match {
     case PhysicalOperation(projectList, filters, LogicalRelation(source: TiDBRelation, _, _))
         if filters.forall(TiUtils.isSupportedFilter(_, source, blacklist)) =>
-      pruneTopNFilterProject(limit, projectList, filters, source, null)
+      pruneTopNFilterProject(limit, projectList, filters, source, Nil)
     case _ => planLater(child)
   }
 
-  def takeOrderedAndProject(
+  private def takeOrderedAndProject(
     limit: Int,
     sortOrder: Seq[SortOrder],
     child: LogicalPlan,
     project: Seq[NamedExpression]
   ): SparkPlan = {
-    // If sortOrder is not null, limit must be greater than 0
-    if (limit < 0 || (sortOrder == null && limit == 0)) {
+    // If sortOrder is empty, limit must be greater than 0
+    if (limit < 0 || (sortOrder.isEmpty && limit == 0)) {
       return execution.TakeOrderedAndProjectExec(limit, sortOrder, project, planLater(child))
     }
@@ -254,7 +250,7 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging {
     }
   }
 
-  def pruneFilterProject(
+  private def pruneFilterProject(
     projectList: Seq[NamedExpression],
     filterPredicates: Seq[Expression],
     source: TiDBRelation,
@@ -298,7 +294,7 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging {
     }
   }
 
-  def groupAggregateProjection(
+  private def groupAggregateProjection(
     filters: Seq[Expression],
     groupingExpressions: Seq[NamedExpression],
     aggregateExpressions: Seq[AggregateExpression],
@@ -307,126 +303,70 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging {
     source: TiDBRelation,
     dagReq: TiDAGRequest
   ): Seq[SparkPlan] = {
-    val aliasMap = mutable.HashMap[(Boolean, Expression), Alias]()
-    val avgPushdownRewriteMap = mutable.HashMap[ExprId, List[AggregateExpression]]()
-    val avgFinalRewriteMap =
-      mutable.HashMap[ExprId, List[AggregateExpression]]()
-
-    def newAggregate(aggFunc: AggregateFunction, originalAggExpr: AggregateExpression) =
-      AggregateExpression(
-        aggFunc,
-        originalAggExpr.mode,
-        originalAggExpr.isDistinct,
-        originalAggExpr.resultId
-      )
-
-    def newAggregateWithId(aggFunc: AggregateFunction, originalAggExpr: AggregateExpression) =
-      AggregateExpression(
-        aggFunc,
-        originalAggExpr.mode,
-        originalAggExpr.isDistinct,
-        NamedExpression.newExprId
-      )
+    val deterministicAggAliases = aggregateExpressions.collect {
+      case e if e.deterministic => e.canonicalized -> Alias(e, e.toString())()
+    }.toMap
 
-    def toAlias(expr: AggregateExpression) =
-      if (!expr.deterministic) {
-        Alias(expr, expr.toString())()
-      } else {
-        aliasMap.getOrElseUpdate(
-          (expr.deterministic, expr.canonicalized),
-          Alias(expr, expr.toString)()
-        )
-      }
+    def aliasPushedPartialResult(e: AggregateExpression): Alias = {
+      deterministicAggAliases.getOrElse(e.canonicalized, Alias(e, e.toString())())
+    }
 
     val residualAggregateExpressions = aggregateExpressions.map { aggExpr =>
+      // As `aggExpr` is being pushed down to TiKV, we need to replace the original Catalyst
+      // aggregate expressions with new ones that merge the partial aggregation results
+      // returned by TiKV.
+      //
+      // NOTE: Unlike simple aggregate functions (e.g., `Max`, `Min`, etc.), `Count` must be
+      // replaced with a `Sum` to sum up the partial counts returned by TiKV.
+      //
+      // NOTE: All `Average`s should have already been rewritten into `Sum`s and `Count`s by the
+      // `TiAggregation` pattern extractor.
+
+      // An attribute referring to the partial aggregation results returned by TiKV.
+      val partialResultRef = aliasPushedPartialResult(aggExpr).toAttribute
+
       aggExpr.aggregateFunction match {
-        // here aggExpr is the original AggregationExpression
-        // and will be pushed down to TiKV
-        case Max(_) => newAggregate(Max(toAlias(aggExpr).toAttribute), aggExpr)
-        case Min(_) => newAggregate(Min(toAlias(aggExpr).toAttribute), aggExpr)
-        case Count(_) => newAggregate(Sum(toAlias(aggExpr).toAttribute), aggExpr)
-        case Sum(_) => newAggregate(Sum(toAlias(aggExpr).toAttribute), aggExpr)
-        case First(_, ignoreNullsExpr) =>
-          newAggregate(First(toAlias(aggExpr).toAttribute, ignoreNullsExpr), aggExpr)
-        case _ => aggExpr
-      }
-    } flatMap { aggExpr =>
-      aggExpr match {
-        // We have to separate average into sum and count
-        // and for outside expression such as average(x) + 1,
-        // Spark has lift agg + 1 up to resultExpressions
-        // We need to modify the reference there as well to forge
-        // Divide(sum/count) + 1
-        case aggExpr @ AggregateExpression(Average(ref), _, _, _) =>
-          // Need a type promotion
-          val sumToPush = newAggregate(Sum(ref), aggExpr)
-          val countToPush = newAggregate(Count(ref), aggExpr)
-
-          // Need a new expression id since they are not simply rewrite as above
-          val sumFinal = newAggregateWithId(Sum(toAlias(sumToPush).toAttribute), aggExpr)
-          val countFinal = newAggregateWithId(Sum(toAlias(countToPush).toAttribute), aggExpr)
-
-          avgPushdownRewriteMap(aggExpr.resultId) = List(sumToPush, countToPush)
-          avgFinalRewriteMap(aggExpr.resultId) = List(sumFinal, countFinal)
-          List(sumFinal, countFinal)
-        case _ => aggExpr :: Nil
+        case e: Max => aggExpr.copy(aggregateFunction = e.copy(child = partialResultRef))
+        case e: Min => aggExpr.copy(aggregateFunction = e.copy(child = partialResultRef))
+        case e: Sum => aggExpr.copy(aggregateFunction = e.copy(child = partialResultRef))
+        case e: First => aggExpr.copy(aggregateFunction = e.copy(child = partialResultRef))
+        case _: Count => aggExpr.copy(aggregateFunction = Sum(partialResultRef))
+        case _: Average => throw new IllegalStateException("All AVGs should have been rewritten.")
+        case _ => aggExpr
       }
     }
 
-    val pushdownAggregates = aggregateExpressions.flatMap { aggExpr =>
-      avgPushdownRewriteMap
-        .getOrElse(aggExpr.resultId, List(aggExpr))
-    }.distinct
-
-    aggregationToDAGRequest(groupingExpressions, pushdownAggregates, source, dagReq)
-
-    val rewrittenResultExpression = resultExpressions.map(
-      expr =>
-        expr
-          .transformDown {
-            case aggExpr: AttributeReference if avgFinalRewriteMap.contains(aggExpr.exprId) =>
-              // Replace the original Average expression with Div of Alias
-              val sumCountPair = avgFinalRewriteMap(aggExpr.exprId)
-
-              // We missed the chance for auto-coerce already
-              // so manual cast needed
-              // Also, convert into resultAttribute since
-              // they are created by tiSpark without Spark conversion
-              // TODO: Is DoubleType a best target type for all?
-              Cast(
-                Divide(
-                  Cast(sumCountPair.head.resultAttribute, DoubleType),
-                  Cast(sumCountPair(1).resultAttribute, DoubleType)
-                ),
-                aggExpr.dataType
-              )
-            case other => other
-          }
-          .asInstanceOf[NamedExpression]
-    )
+    aggregationToDAGRequest(groupingExpressions, aggregateExpressions.distinct, source, dagReq)
 
-    val output = (pushdownAggregates.map(x => toAlias(x)) ++ groupingExpressions)
-      .map(_.toAttribute)
+    val projectionTiRefs = projects
+      .map { _.toAttribute.name }
+      .map { TiColumnRef.create }
 
-    val projectSeq: Seq[Attribute] = projects.asInstanceOf[Seq[Attribute]]
-    projectSeq.foreach(attr => dagReq.addRequiredColumn(TiColumnRef.create(attr.name)))
-    val pushDownCols = ArrayBuffer[TiColumnRef]()
-    val tiFilters: Seq[TiExpr] = filters.collect { case BasicExpression(expr) => expr }
-    tiFilters.foreach(extractColumnFromFilter(_, pushDownCols))
-    pushDownCols.foreach(dagReq.addRequiredColumn)
+    val filterTiRefs = filters
+      .collect { case BasicExpression(tiExpr) => tiExpr }
+      .flatMap { referencedTiColumns }
+
+    projectionTiRefs ++ filterTiRefs foreach { dagReq.addRequiredColumn }
+
+    val output = (aggregateExpressions.map(aliasPushedPartialResult) ++ groupingExpressions).map {
+      _.toAttribute
+    }
 
     aggregate.AggUtils.planAggregateWithoutDistinct(
       groupingExpressions,
       residualAggregateExpressions,
-      rewrittenResultExpression,
+      resultExpressions,
       toCoprocessorRDD(source, output, dagReq)
     )
   }
 
-  def isValidAggregates(groupingExpressions: Seq[NamedExpression],
-                        aggregateExpressions: Seq[AggregateExpression],
-                        filters: Seq[Expression],
-                        source: TiDBRelation): Boolean = {
-    allowAggregationPushDown &&
+  private def isValidAggregates(
+      groupingExpressions: Seq[NamedExpression],
+      aggregateExpressions: Seq[AggregateExpression],
+      filters: Seq[Expression],
+      source: TiDBRelation
+  ): Boolean = {
+    allowAggregationPushdown &&
       filters.forall(TiUtils.isSupportedFilter(_, source, blacklist)) &&
      groupingExpressions.forall(TiUtils.isSupportedGroupingExpr(_, source, blacklist)) &&
      aggregateExpressions.forall(TiUtils.isSupportedAggregate(_, source, blacklist)) &&
@@ -500,9 +440,43 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging {
 object TiAggregation {
   type ReturnType = PhysicalAggregation.ReturnType
 
-  def unapply(a: Any): Option[ReturnType] = a match {
+  def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
     case PhysicalAggregation(groupingExpressions, aggregateExpressions, resultExpressions, child) =>
-      Some(groupingExpressions, aggregateExpressions, resultExpressions, child)
+      // Rewrites all `Average`s into `Divide(Sum, Count)` form so that we can push the
+      // converted `Sum`s and `Count`s down to TiKV.
+      val (averages, averagesEliminated) = aggregateExpressions.partition {
+        case AggregateExpression(_: Average, _, _, _) => true
+        case _ => false
+      }
+
+      // An auxiliary map that maps result attribute IDs of all detected `Average`s to the
+      // corresponding converted `Sum`s and `Count`s.
+      val rewriteMap = averages.map {
+        case a @ AggregateExpression(Average(ref), _, _, _) =>
+          a.resultAttribute -> Seq(
+            a.copy(aggregateFunction = Sum(ref), resultId = newExprId),
+            a.copy(aggregateFunction = Count(ref), resultId = newExprId)
+          )
+      }.toMap
+
+      val rewrite: PartialFunction[Expression, Expression] = rewriteMap.map {
+        case (ref, Seq(sum, count)) =>
+          val castedSum = Cast(sum.resultAttribute, DoubleType)
+          val castedCount = Cast(count.resultAttribute, DoubleType)
+          val division = Cast(Divide(castedSum, castedCount), ref.dataType)
+          (ref: Expression) -> Alias(division, ref.name)(exprId = ref.exprId)
+      }
+
+      val rewrittenResultExpressions = resultExpressions
+        .map { _ transform rewrite }
+        .map { case e: NamedExpression => e }
+
+      val rewrittenAggregateExpressions = {
+        val extraSumsAndCounts = rewriteMap.values.reduceOption { _ ++ _ } getOrElse Nil
+        (averagesEliminated ++ extraSumsAndCounts).distinct
+      }
+
+      Some(groupingExpressions, rewrittenAggregateExpressions, rewrittenResultExpressions, child)
     case _ => Option.empty[ReturnType]
   }
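
A few notes on the techniques in this diff follow, each with a small self-contained Scala sketch. First, collection conversions: the patch drops the implicit `scala.collection.JavaConversions._` in favor of the explicit `scala.collection.JavaConverters._`, so every Java/Scala boundary is spelled out as an `.asJava` or `.asScala` call site (`tiFilters.asJava`, `dagRequest.getFields.asScala`, and so on). A minimal illustration of the explicit style, independent of the TiSpark types:

import scala.collection.JavaConverters._

object JavaConvertersSketch {
  def main(args: Array[String]): Unit = {
    // Scala -> Java, visible at the call site, as in
    // `scanBuilder.buildTableScan(tiFilters.asJava, source.table)`:
    val javaList: java.util.List[String] = Seq("a", "b").asJava

    // Java -> Scala, as in `dagRequest.getFields.asScala.map { ... }`:
    val upper: Seq[String] = javaList.asScala.map(_.toUpperCase).toSeq

    println(upper) // List(A, B)
  }
}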
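Second, the residual aggregates. As the new comments in `groupAggregateProjection` explain, each pushed-down aggregate is re-expressed over the partial results TiKV returns, and only `Count` changes shape: partial counts must be merged by summing. A toy sketch of that merge, using plain Scala collections and hypothetical per-region values rather than Spark or TiKV types:

object PartialAggregationSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical partial results returned by three TiKV regions for one query.
    val partialCounts = Seq(10L, 7L, 25L) // per-region COUNT(x)
    val partialMaxes  = Seq(3, 42, 17)    // per-region MAX(x)

    // MAX merges with another MAX...
    val mergedMax = partialMaxes.max
    // ...but COUNT merges by SUMMING the partial counts, which is why the residual
    // aggregate for a pushed-down Count is Sum(partialResultRef) rather than Count.
    val mergedCount = partialCounts.sum

    println(s"COUNT = $mergedCount, MAX = $mergedMax") // COUNT = 42, MAX = 42
  }
}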
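Third, the AVG rewrite. The reworked `TiAggregation` extractor splits every `Average` into a `Sum`/`Count` pair and patches references in the result expressions into `Cast(Divide(Cast(sum), Cast(count)), originalType)`. The sketch below models only the splitting step on a toy expression AST; the names (`Avg`, `Div`, `Col`, ...) are illustrative stand-ins, not the Catalyst classes:

// Toy expression AST -- illustrative stand-ins, not the Catalyst classes.
sealed trait Expr
case class Col(name: String)            extends Expr
case class Lit(value: Double)           extends Expr
case class Avg(child: Expr)             extends Expr
case class Sum(child: Expr)             extends Expr
case class Count(child: Expr)           extends Expr
case class Add(left: Expr, right: Expr) extends Expr
case class Div(left: Expr, right: Expr) extends Expr

object AvgRewriteSketch {
  // Rewrite every Avg(e) into Div(Sum(e), Count(e)). Recursing into Add/Div matters
  // because Spark lifts surrounding arithmetic (e.g. the `+ 1` of AVG(x) + 1) into the
  // result expressions, and the reference to the average must be patched there too.
  def rewriteAverages(e: Expr): Expr = e match {
    case Avg(child) => Div(Sum(child), Count(child))
    case Add(l, r)  => Add(rewriteAverages(l), rewriteAverages(r))
    case Div(l, r)  => Div(rewriteAverages(l), rewriteAverages(r))
    case other      => other
  }

  def main(args: Array[String]): Unit = {
    val rewritten = rewriteAverages(Add(Avg(Col("x")), Lit(1)))
    println(rewritten) // Add(Div(Sum(Col(x)),Count(Col(x))),Lit(1.0))
  }
}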
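Last, `referencedTiColumns`, which replaces the old `extractColumnFromFilter` by returning column references instead of appending to a mutable `ArrayBuffer`. Reusing the toy AST from the previous sketch, the same shape of recursion looks like this:

object ColumnCollectionSketch {
  // Collect every column referenced by an expression so that each one can be
  // registered on the DAG request via `dagReq.addRequiredColumn`.
  def referencedColumns(e: Expr): Seq[Col] = e match {
    case c: Col    => Seq(c)
    case Avg(c)    => referencedColumns(c)
    case Sum(c)    => referencedColumns(c)
    case Count(c)  => referencedColumns(c)
    case Add(l, r) => referencedColumns(l) ++ referencedColumns(r)
    case Div(l, r) => referencedColumns(l) ++ referencedColumns(r)
    case _: Lit    => Nil
  }

  def main(args: Array[String]): Unit = {
    println(referencedColumns(Add(Avg(Col("x")), Col("y")))) // List(Col(x), Col(y))
  }
}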