Skip to content

Commit

Permalink
Optimize aggregation push down column logic (pingcap#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
Novemser authored Dec 27, 2017
1 parent fe033c1 commit 8bbec4f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
17 changes: 17 additions & 0 deletions src/main/scala/org/apache/spark/sql/TiStrategy.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package org.apache.spark.sql

import com.pingcap.tikv.expression.scalar.TiScalarFunction
import java.time.ZonedDateTime

import com.pingcap.tikv.expression.{ExpressionBlacklist, TiByItem, TiColumnRef, TiExpr}
Expand All @@ -35,6 +36,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.collection.{JavaConversions, mutable}

// TODO: Too many hacks here since we hijack the planning
Expand Down Expand Up @@ -143,6 +145,15 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging {
dagRequest
}

def extractColumnFromFilter(tiFilter: TiExpr, result: ArrayBuffer[TiColumnRef]): Unit =
tiFilter match {
case fun: TiScalarFunction =>
fun.getArgs.foreach(extractColumnFromFilter(_, result))
case col: TiColumnRef =>
result.add(col)
case _ =>
}

def filterToDAGRequest(
filters: Seq[Expression],
source: TiDBRelation,
Expand Down Expand Up @@ -266,6 +277,7 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging {
}

def groupAggregateProjection(
filters: Seq[Expression],
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
resultExpressions: Seq[NamedExpression],
Expand Down Expand Up @@ -375,6 +387,10 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging {

val projectSeq: Seq[Attribute] = projects.asInstanceOf[Seq[Attribute]]
projectSeq.foreach(attr => dagReq.addRequiredColumn(TiColumnRef.create(attr.name)))
val pushDownCols = ArrayBuffer[TiColumnRef]()
val tiFilters: Seq[TiExpr] = filters.collect { case BasicExpression(expr) => expr }
tiFilters.foreach(extractColumnFromFilter(_, pushDownCols))
pushDownCols.foreach(dagReq.addRequiredColumn)

aggregate.AggUtils.planAggregateWithoutDistinct(
groupingExpressions,
Expand Down Expand Up @@ -446,6 +462,7 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging {
) if isValidAggregates(groupingExpressions, aggregateExpressions, filters, source) =>
val dagReq: TiDAGRequest = filterToDAGRequest(filters, source)
groupAggregateProjection(
filters,
groupingExpressions,
aggregateExpressions,
resultExpressions,
Expand Down

0 comments on commit 8bbec4f

Please sign in to comment.