[SPARK-16844][SQL] Generate code for sort based aggregation #14481

Closed · wants to merge 2 commits
@@ -318,6 +318,20 @@ class CodegenContext {
}
}

/**
* Returns the code to update a variable for a given DataType.
*/
def setValue(target: String, dataType: DataType, value: String): String = {
val jt = javaType(dataType)
val codes = dataType match {
case _ if isPrimitiveType(jt) => value
case StringType => s"$value.clone()"
case _: StructType | _: ArrayType | _: MapType => s"$value.copy()"
case _ => value
}
s"$target = $codes"
}
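
A note on this helper, for reviewers: the returned string follows directly from the match arms above, and callers append the trailing semicolon themselves (see the buffer-update code in `SortAggregateExec` below). A minimal sketch of the expected outputs, assuming a `CodegenContext` instance `ctx` and `org.apache.spark.sql.types._` in scope:

```scala
// Illustrative only: expected return values, one per match arm above.
ctx.setValue("buf", IntegerType, "v")            // "buf = v"          (primitive Java type)
ctx.setValue("buf", StringType, "v")             // "buf = v.clone()"  (UTF8String is mutable)
ctx.setValue("buf", ArrayType(IntegerType), "v") // "buf = v.copy()"   (complex types are copied)
ctx.setValue("buf", BinaryType, "v")             // "buf = v"          (fallback arm)
```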

/**
* Returns the specialized code to set a given value in a column vector for a given `DataType`.
*/
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
@@ -38,7 +38,8 @@ trait CodegenSupport extends SparkPlan {

/** Prefix used in the current operator's variable names. */
private def variablePrefix: String = this match {
case _: HashAggregateExec => "agg"
case _: HashAggregateExec => "hagg"
case _: SortAggregateExec => "sagg"
case _: BroadcastHashJoinExec => "bhj"
case _: SortMergeJoinExec => "smj"
case _: RDDScanExec => "rdd"
@@ -22,9 +22,11 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.Utils

/**
@@ -38,8 +40,7 @@ case class SortAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
-extends UnaryExecNode {
-
+extends UnaryExecNode with CodegenSupport {
private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
}
@@ -104,6 +105,274 @@ case class SortAggregateExec(
}
}

override def usedInputs: AttributeSet = inputSet

override def supportCodegen: Boolean = {
// ImperativeAggregate is not supported right now
!aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
}
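
For context: only `DeclarativeAggregate` functions expose update/merge/evaluate expressions that can be compiled, which is why the `asInstanceOf[DeclarativeAggregate]` casts further down are guarded by this check. A hedged illustration (assuming a DataFrame `df` with columns `k` and `v`; falling back simply means this method returns false, not an error):

```scala
import org.apache.spark.sql.functions._

// Max, Min, Count, Sum are DeclarativeAggregate, so these stay on the codegen path:
df.groupBy("k").agg(max("v"), count("v"))

// HyperLogLogPlusPlus (behind approx_count_distinct) is an ImperativeAggregate,
// so this plan keeps the interpreted sort-aggregate path:
df.groupBy("k").agg(approx_count_distinct("v"))
```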

override def inputRDDs(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].inputRDDs()
}

override protected def doProduce(ctx: CodegenContext): String = {
if (groupingExpressions.isEmpty) {
doProduceWithoutKeys(ctx)
} else {
doProduceWithKeys(ctx)
}
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
if (groupingExpressions.isEmpty) {
doConsumeWithoutKeys(ctx, input)
} else {
doConsumeWithKeys(ctx, input)
}
}

private val modes = aggregateExpressions.map(_.mode).distinct
// The variables used as the aggregation buffer
private var bufVars: Seq[ExprCode] = _
private var bufVarsType: Seq[DataType] = _
private var currentGroupingKey: String = _
private val groupingAttributes = groupingExpressions.map(_.toAttribute)
private var initBufVarsCodes: String = _
private var numOutput: String = _

private def generateInitBufVarsCodes(ctx: CodegenContext): String = {
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
val initExpr = functions.flatMap(f => f.initialValues)
bufVars = initExpr.map { e =>
val isNull = ctx.freshName("bufIsNull")
val value = ctx.freshName("bufValue")
ctx.addMutableState("boolean", isNull, "")
ctx.addMutableState(ctx.javaType(e.dataType), value, "")
// The initial expression should not access any column
val ev = e.genCode(ctx)
val initVars = s"""
| $isNull = ${ev.isNull};
| $value = ${ev.value};
""".stripMargin
ExprCode(ev.code + initVars, isNull, value)
}
bufVarsType = initExpr.map { e =>
e.dataType
}
evaluateVariables(bufVars)
}

private def generateCalBufVarsCodes(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
val updateExpr = aggregateExpressions.flatMap { e =>
e.mode match {
case Partial | Complete =>
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
case PartialMerge | Final =>
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
}
}
ctx.currentVars = bufVars ++ input
val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExpr.map(_.genCode(ctx))
}
// All update expressions are evaluated before any buffer slot is overwritten,
// so the aggregation buffer is updated atomically.
val updates = aggVals.zipWithIndex.map { case (ev, i) =>
s"""
| ${bufVars(i).isNull} = ${ev.isNull};
| if (${bufVars(i).value} != ${ev.value})
| ${ctx.setValue(bufVars(i).value, bufVarsType(i), ev.value)};
""".stripMargin
}
s"""
| // common sub-expressions
| $effectiveCodes
| // evaluate aggregate function
| ${evaluateVariables(aggVals)}
| // update aggregation buffer
| ${updates.mkString("\n").trim}
""".stripMargin
}
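
The reference check before `ctx.setValue` deserves a note: when an update expression leaves a buffer slot untouched it can evaluate to the buffer variable itself, and the guard skips a pointless self-`clone()`/`copy()`. Roughly the shape this template expands to for `max(v)` over a string column (a hedged sketch; the real variable names come from `ctx.freshName`):

```scala
// Illustrative shape of the generated update for max(v) on a string column:
//
//   boolean agg_isNull = ...;            // evaluated update/merge expression
//   UTF8String agg_value = ...;
//   sagg_bufIsNull = agg_isNull;
//   if (sagg_bufValue != agg_value)
//     sagg_bufValue = agg_value.clone(); // ctx.setValue for StringType
```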

private def generateResultCodes(ctx: CodegenContext): String = {
if (modes.contains(Final) || modes.contains(Complete)) {
// generate output using resultExpressions
ctx.currentVars = null
ctx.INPUT_ROW = currentGroupingKey
val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
BoundReference(i, e.dataType, e.nullable).genCode(ctx)
}
val evaluateKeyVars = evaluateVariables(keyVars)
// evaluate the aggregation result
ctx.currentVars = bufVars
val functions = aggregateExpressions.map(
_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
val aggResults = functions.map(_.evaluateExpression).map { e =>
BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
}
val evaluateAggResults = evaluateVariables(aggResults)
// generate the final result
ctx.currentVars = keyVars ++ aggResults
val inputAttrs = groupingAttributes ++ aggregateAttributes
val resultVars = resultExpressions.map { e =>
BindReferences.bindReference(e, inputAttrs).genCode(ctx)
}
s"""
$evaluateKeyVars
$evaluateAggResults
${consume(ctx, resultVars)}
"""
} else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
// This should be the last operator in a stage; output UnsafeRow directly.
val allAttributes = groupingAttributes ++ aggregateBufferAttributes
ctx.currentVars = new Array[ExprCode](groupingAttributes.length) ++ bufVars
ctx.INPUT_ROW = currentGroupingKey
val unsafeRowProjection = GenerateUnsafeProjection.createCode(
ctx, allAttributes.map(e => BindReferences.bindReference[Expression](e, allAttributes)))
s"""
|${unsafeRowProjection.code.trim}
|${consume(ctx, null, unsafeRowProjection.value)}
""".stripMargin
} else {
// generate result based on grouping key
ctx.INPUT_ROW = currentGroupingKey
ctx.currentVars = null
val eval = resultExpressions.map { e =>
BindReferences.bindReference(e, groupingAttributes).genCode(ctx)
}
consume(ctx, eval)
}
}
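
A quick map from aggregation mode to branch, since all three appear above (hedged examples, assuming a SparkSession `spark` and a table `t` with columns `k` and `v`):

```scala
// Final/Complete: the post-shuffle stage evaluates resultExpressions (first branch);
// its map-side half runs the Partial branch, emitting grouping key + raw buffer.
spark.table("t").groupBy("k").agg(max("v"))

// A group-by with no aggregate functions (as e.g. distinct is planned) exercises
// only the third branch, which outputs the grouping keys.
spark.table("t").select("k").distinct()
```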

private def doProduceWithoutKeys(ctx: CodegenContext): String = {
val initAgg = ctx.freshName("initAgg")
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")

// generate variables for aggregation buffer
val initBufVarsCodes = generateInitBufVarsCodes(ctx)

// generate variables for output
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
// evaluate aggregate results
ctx.currentVars = bufVars
val aggResults = functions.map(_.evaluateExpression).map { e =>
BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
}
val evaluateAggResults = evaluateVariables(aggResults)
// evaluate result expressions
ctx.currentVars = aggResults
val resultVars = resultExpressions.map { e =>
BindReferences.bindReference(e, aggregateAttributes).genCode(ctx)
}
(resultVars, s"""
|$evaluateAggResults
|${evaluateVariables(resultVars)}
""".stripMargin)
} else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
// output the aggregate buffer directly
(bufVars, "")
} else {
// no aggregate function, the result should be literals
val resultVars = resultExpressions.map(_.genCode(ctx))
(resultVars, evaluateVariables(resultVars))
}

numOutput = metricTerm(ctx, "numOutputRows")

val doAgg = ctx.freshName("doAggregateWithoutKey")
ctx.addNewFunction(doAgg,
s"""
| private void $doAgg() throws java.io.IOException {
| // initialize aggregation buffer
| $initBufVarsCodes
|
| ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
| }
""".stripMargin)

s"""
| while (!$initAgg) {
| $initAgg = true;
| $doAgg();
|
| // output the result
| ${genResult.trim}
|
| $numOutput.add(1);
| ${consume(ctx, resultVars).trim}
| }
""".stripMargin
}

private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
generateCalBufVarsCodes(ctx, input)
}

private def doProduceWithKeys(ctx: CodegenContext): String = {
// init buffer vars
initBufVarsCodes = generateInitBufVarsCodes(ctx)
// grouping key
currentGroupingKey = ctx.freshName("currentGroupingKey")
ctx.addMutableState("UnsafeRow", currentGroupingKey, s"$currentGroupingKey = null;")
numOutput = metricTerm(ctx, "numOutputRows")
s"""
|${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
|// output the result for the last group
|if ($currentGroupingKey != null) {
| do {
| $numOutput.add(1);
| ${generateResultCodes(ctx)}
| } while (false);
| $currentGroupingKey = null;
|}
""".stripMargin
}

private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// grouping key
ctx.INPUT_ROW = null
ctx.currentVars = input
val groupingExprCode: ExprCode = GenerateUnsafeProjection.createCode(
ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
val groupingKey = groupingExprCode.value
// calculate buffer vars
val calBufVarsCodes: String = generateCalBufVarsCodes(ctx, input)
s"""
|// generate grouping key
|${groupingExprCode.code.trim}
|
|if ($currentGroupingKey == null) {
| $currentGroupingKey = $groupingKey.copy();
| // init aggregation buffer vars
| $initBufVarsCodes
| // do aggregation
| $calBufVarsCodes
| continue;
|} else {
| if ($currentGroupingKey.equals($groupingKey)) {
| // do aggregation
| $calBufVarsCodes
| continue;
| } else {
| do {
| $numOutput.add(1);
| ${generateResultCodes(ctx)}
| } while (false);
| // a new group starts
| $currentGroupingKey = $groupingKey.copy();
| $initBufVarsCodes
| $calBufVarsCodes
| }
|}
""".stripMargin
}
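
For readers less used to the codegen templates: with input sorted on the grouping key, one cached key plus one set of buffer variables is enough, and each row either extends the current group or closes it and opens a new one. A plain-Scala sketch of the same control flow (illustrative only; the real operator works on UnsafeRows and generated buffer variables):

```scala
// Sorted-input aggregation: the loop doConsumeWithKeys/doProduceWithKeys generate.
def sortedAggregate[K, V, B](rows: Iterator[(K, V)], init: () => B,
    update: (B, V) => B, emit: (K, B) => Unit): Unit = {
  var currentKey: Option[K] = None
  var buf: B = null.asInstanceOf[B]
  for ((k, v) <- rows) {
    if (currentKey.contains(k)) {
      buf = update(buf, v)               // same group: only update the buffer
    } else {
      currentKey.foreach(emit(_, buf))   // group boundary: emit the finished group
      currentKey = Some(k)               // start a new group, as the template does
      buf = update(init(), v)            //   with $currentGroupingKey = $groupingKey.copy()
    }
  }
  currentKey.foreach(emit(_, buf))       // trailing group, mirroring doProduceWithKeys
}
```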

override def simpleString: String = toString(verbose = false)

override def verboseString: String = toString(verbose = true)
19 changes: 19 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -355,6 +355,25 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}

test("sort based aggregation with codegen") {
checkAnswer(Seq(("a", "10")).toDF("k", "v").groupBy("k").agg(max("v")), Row("a", "10") :: Nil)
checkAnswer(Seq(("a", "10"), ("b", "1"), ("b", "2"), ("c", "5"), ("c", "3")).
toDF("k", "v").groupBy("k").agg(max("v")),
Row("a", "10") :: Row("b", "2") :: Row("c", "5") :: Nil)
checkAnswer(Seq(("a", "10", 2), ("b", "1", 3), ("b", "2", 4), ("c", "5", 1), ("c", "3", 5)).
toDF("k", "v1", "v2").groupBy("k").agg(max("v1"), min("v2"), count("v2")),
Row("a", "10", 2, 1) :: Row("b", "2", 3, 2) :: Row("c", "5", 1, 2) :: Nil)
checkAnswer(Seq(("a", "3"), ("b", "20"), ("b", "2")).toDF("k", "v").agg(max("v")),
Row("3") :: Nil)
checkAnswer(Seq(("a", "10", 2), ("b", "1", 3), ("b", "2", 4), ("c", "5", 1), ("c", "3", 5)).
toDF("k", "v1", "v2").agg(max("v1"), min("v2"), count("v2")),
Row("5", 1, 5) :: Nil)
checkAnswer(
sql("SELECT key, max(value) FROM testData GROUP BY key"),
(1 to 100).map(i => Row(i, i.toString)))
checkAnswer(sql("SELECT max(value) FROM testData"), Row("99") :: Nil)
}
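
One reason these tests aggregate string columns: `max`/`min` over a `StringType` column gives the aggregation buffer a string slot, which HashAggregateExec's UnsafeRow-based hash map cannot hold (it supports only fixed-width, mutable buffer types), so the planner picks SortAggregateExec and the new codegen path is actually exercised. A hedged sanity check along those lines:

```scala
import org.apache.spark.sql.execution.aggregate.SortAggregateExec

// Confirm the planner chose the sort-based operator for a string-typed buffer.
val plan = Seq(("a", "10")).toDF("k", "v").groupBy("k").agg(max("v"))
  .queryExecution.executedPlan
assert(plan.find(_.isInstanceOf[SortAggregateExec]).isDefined)
```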

test("Add Parser of SQL COALESCE()") {
checkAnswer(
sql("""SELECT COALESCE(1, 2)"""),