diff --git a/src/main/scala/org/apache/spark/sql/TiStrategy.scala b/src/main/scala/org/apache/spark/sql/TiStrategy.scala index df31eb4d2a..2f7ebe22e3 100644 --- a/src/main/scala/org/apache/spark/sql/TiStrategy.scala +++ b/src/main/scala/org/apache/spark/sql/TiStrategy.scala @@ -15,6 +15,7 @@ package org.apache.spark.sql +import com.pingcap.tikv.expression.scalar.TiScalarFunction import java.time.ZonedDateTime import com.pingcap.tikv.expression.{ExpressionBlacklist, TiByItem, TiColumnRef, TiExpr} @@ -35,6 +36,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer import scala.collection.{JavaConversions, mutable} // TODO: Too many hacks here since we hijack the planning @@ -141,6 +143,15 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging { dagRequest } + def extractColumnFromFilter(tiFilter: TiExpr, result: ArrayBuffer[TiColumnRef]): Unit = + tiFilter match { + case fun: TiScalarFunction => + fun.getArgs.foreach(extractColumnFromFilter(_, result)) + case col: TiColumnRef => + result.add(col) + case _ => + } + def filterToDAGRequest( filters: Seq[Expression], source: TiDBRelation, @@ -264,6 +275,7 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging { } def groupAggregateProjection( + filters: Seq[Expression], groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], resultExpressions: Seq[NamedExpression], @@ -373,6 +385,10 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging { val projectSeq: Seq[Attribute] = projects.asInstanceOf[Seq[Attribute]] projectSeq.foreach(attr => dagReq.addRequiredColumn(TiColumnRef.create(attr.name))) + val pushDownCols = ArrayBuffer[TiColumnRef]() + val tiFilters: Seq[TiExpr] = filters.collect { case BasicExpression(expr) => expr } + tiFilters.foreach(extractColumnFromFilter(_, pushDownCols)) + pushDownCols.foreach(dagReq.addRequiredColumn) aggregate.AggUtils.planAggregateWithoutDistinct( groupingExpressions, @@ -444,6 +460,7 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging { ) if isValidAggregates(groupingExpressions, aggregateExpressions, filters, source) => val dagReq: TiDAGRequest = filterToDAGRequest(filters, source) groupAggregateProjection( + filters, groupingExpressions, aggregateExpressions, resultExpressions, diff --git a/tikv-client-lib-java b/tikv-client-lib-java index 54924f38c7..bd6375f6df 160000 --- a/tikv-client-lib-java +++ b/tikv-client-lib-java @@ -1 +1 @@ -Subproject commit 54924f38c72c11fa4e357bb0f56e145592d29cf7 +Subproject commit bd6375f6dfb09916006593a961d304e4969a33ea