diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 33b9b804fc601..cd9dcb15a2d1f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -318,6 +318,20 @@ class CodegenContext {
     }
   }
 
+  /**
+   * Returns the code to update a variable for a given DataType.
+   */
+  def setValue(target: String, dataType: DataType, value: String): String = {
+    val jt = javaType(dataType)
+    val codes = dataType match {
+      case _ if isPrimitiveType(jt) => value
+      case StringType => s"$value.clone()"
+      case _: StructType | _: ArrayType | _: MapType => s"$value.copy()"
+      case _ => value
+    }
+    s"$target = $codes"
+  }
+
   /**
    * Returns the specialized code to set a given value in a column vector for a given `DataType`.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index fb57ed7692de4..24f94f460d8d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.internal.SQLConf
@@ -38,7 +38,8 @@ trait CodegenSupport extends SparkPlan {
 
   /** Prefix used in the current operator's variable names. */
   private def variablePrefix: String = this match {
-    case _: HashAggregateExec => "agg"
+    case _: HashAggregateExec => "hagg"
+    case _: SortAggregateExec => "sagg"
     case _: BroadcastHashJoinExec => "bhj"
     case _: SortMergeJoinExec => "smj"
     case _: RDDScanExec => "rdd"
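
Reviewer note on `setValue`: primitives can be assigned directly, but string, struct, array, and map values reaching the generated code may point into buffers that the upstream operator reuses for the next input row, so they need a defensive `clone()`/`copy()` before being stored across rows (here, in the aggregation buffer). Below is a standalone sketch of the same dispatch, not part of the patch; `isPrimitive` is a simplified stand-in for the real `isPrimitiveType(javaType(dt))` check:

```scala
import org.apache.spark.sql.types._

object SetValueSketch {
  // Simplified stand-in for CodegenContext.isPrimitiveType(javaType(dt)).
  private def isPrimitive(dt: DataType): Boolean = dt match {
    case BooleanType | ByteType | ShortType | IntegerType | LongType |
         FloatType | DoubleType => true
    case _ => false
  }

  // Same dispatch as the patched CodegenContext.setValue.
  def setValue(target: String, dataType: DataType, value: String): String = {
    val rhs = dataType match {
      case dt if isPrimitive(dt) => value                                // plain assignment
      case StringType => s"$value.clone()"                               // UTF8String may alias reused memory
      case _: StructType | _: ArrayType | _: MapType => s"$value.copy()" // deep copy of row/array/map
      case _ => value
    }
    s"$target = $rhs"
  }

  def main(args: Array[String]): Unit = {
    println(setValue("buf", LongType, "v"))   // buf = v
    println(setValue("buf", StringType, "v")) // buf = v.clone()
  }
}
```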
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index 2a81a823c44b3..3b648b437e1f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -22,9 +22,11 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.Utils
 
 /**
@@ -38,8 +40,7 @@ case class SortAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode {
-
+  extends UnaryExecNode with CodegenSupport {
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
   }
@@ -104,6 +105,274 @@ case class SortAggregateExec(
     }
   }
 
+  override def usedInputs: AttributeSet = inputSet
+
+  override def supportCodegen: Boolean = {
+    // ImperativeAggregate is not supported right now
+    !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
+  }
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].inputRDDs()
+  }
+
+  override protected def doProduce(ctx: CodegenContext): String = {
+    if (groupingExpressions.isEmpty) {
+      doProduceWithoutKeys(ctx)
+    } else {
+      doProduceWithKeys(ctx)
+    }
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    if (groupingExpressions.isEmpty) {
+      doConsumeWithoutKeys(ctx, input)
+    } else {
+      doConsumeWithKeys(ctx, input)
+    }
+  }
+
+  private val modes = aggregateExpressions.map(_.mode).distinct
+  // The variables used as the aggregation buffer
+  private var bufVars: Seq[ExprCode] = _
+  private var bufVarsType: Seq[DataType] = _
+  private var currentGroupingKey: String = _
+  private val groupingAttributes = groupingExpressions.map(_.toAttribute)
+  private var initBufVarsCodes: String = _
+  private var numOutput: String = _
+
+  private def generateInitBufVarsCodes(ctx: CodegenContext): String = {
+    val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+    val initExpr = functions.flatMap(f => f.initialValues)
+    bufVars = initExpr.map { e =>
+      val isNull = ctx.freshName("bufIsNull")
+      val value = ctx.freshName("bufValue")
+      ctx.addMutableState("boolean", isNull, "")
+      ctx.addMutableState(ctx.javaType(e.dataType), value, "")
+      // The initial expression should not access any column
+      val ev = e.genCode(ctx)
+      val initVars = s"""
+         | $isNull = ${ev.isNull};
+         | $value = ${ev.value};
+       """.stripMargin
+      ExprCode(ev.code + initVars, isNull, value)
+    }
+    bufVarsType = initExpr.map(_.dataType)
+    evaluateVariables(bufVars)
+  }
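
Reviewer note: unlike `HashAggregateExec`, which keeps grouped buffers in a hash map of UnsafeRows, this operator holds only the in-flight group's buffer, as pairs of generated mutable fields (`bufIsNull`, `bufValue`) that survive across `consume()` calls for consecutive rows of the same group. A minimal sketch of how one slot is wired up, assuming the 2.x-era codegen API used in this patch and using `Literal(0L)` as a stand-in for an aggregate function's initial value:

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types.LongType

object BufVarSketch {
  def main(args: Array[String]): Unit = {
    val ctx = new CodegenContext
    val isNull = ctx.freshName("bufIsNull")
    val value = ctx.freshName("bufValue")
    // The slots become fields of the generated class, so they keep their
    // contents between consume() calls for rows of the same group.
    ctx.addMutableState("boolean", isNull, "")
    ctx.addMutableState(ctx.javaType(LongType), value, "")
    // genCode emits the code that evaluates the expression; the init snippet
    // then stores the result into the two slots, as generateInitBufVarsCodes does.
    val ev = Literal(0L).genCode(ctx)
    println(s"${ev.code}\n$isNull = ${ev.isNull};\n$value = ${ev.value};")
  }
}
```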
+
+  private def generateCalBufVarsCodes(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+    val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
+    val updateExpr = aggregateExpressions.flatMap { e =>
+      e.mode match {
+        case Partial | Complete =>
+          e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
+        case PartialMerge | Final =>
+          e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
+      }
+    }
+    ctx.currentVars = bufVars ++ input
+    val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
+    val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
+    val effectiveCodes = subExprs.codes.mkString("\n")
+    val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
+      boundUpdateExpr.map(_.genCode(ctx))
+    }
+    // The aggregation buffer should be updated atomically
+    val updates = aggVals.zipWithIndex.map { case (ev, i) =>
+      s"""
+         | ${bufVars(i).isNull} = ${ev.isNull};
+         | if (${bufVars(i).value} != ${ev.value})
+         |   ${ctx.setValue(bufVars(i).value, bufVarsType(i), ev.value)};
+       """.stripMargin
+    }
+    s"""
+       | // common sub-expressions
+       | $effectiveCodes
+       | // evaluate aggregate function
+       | ${evaluateVariables(aggVals)}
+       | // update aggregation buffer
+       | ${updates.mkString("\n").trim}
+     """.stripMargin
+  }
+
+  private def generateResultCodes(ctx: CodegenContext): String = {
+    if (modes.contains(Final) || modes.contains(Complete)) {
+      // generate output using resultExpressions
+      ctx.currentVars = null
+      ctx.INPUT_ROW = currentGroupingKey
+      val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
+        BoundReference(i, e.dataType, e.nullable).genCode(ctx)
+      }
+      val evaluateKeyVars = evaluateVariables(keyVars)
+      // evaluate the aggregation result
+      ctx.currentVars = bufVars
+      val functions = aggregateExpressions.map(
+        _.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+      val aggResults = functions.map(_.evaluateExpression).map { e =>
+        BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
+      }
+      val evaluateAggResults = evaluateVariables(aggResults)
+      // generate the final result
+      ctx.currentVars = keyVars ++ aggResults
+      val inputAttrs = groupingAttributes ++ aggregateAttributes
+      val resultVars = resultExpressions.map { e =>
+        BindReferences.bindReference(e, inputAttrs).genCode(ctx)
+      }
+      s"""
+        $evaluateKeyVars
+        $evaluateAggResults
+        ${consume(ctx, resultVars)}
+      """
+    } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
+      // This should be the last operator in the stage, so output UnsafeRow directly
+      val allAttributes = groupingAttributes ++ aggregateBufferAttributes
+      ctx.currentVars = new Array[ExprCode](groupingAttributes.length) ++ bufVars
+      ctx.INPUT_ROW = currentGroupingKey
+      val unsafeRowProjection = GenerateUnsafeProjection.createCode(
+        ctx, allAttributes.map(e => BindReferences.bindReference[Expression](e, allAttributes)))
+      s"""
+         |${unsafeRowProjection.code.trim}
+         |${consume(ctx, null, unsafeRowProjection.value)}
+       """.stripMargin
+    } else {
+      // no aggregate function: generate the result based on the grouping key
+      ctx.INPUT_ROW = currentGroupingKey
+      ctx.currentVars = null
+      val eval = resultExpressions.map { e =>
+        BindReferences.bindReference(e, groupingAttributes).genCode(ctx)
+      }
+      consume(ctx, eval)
+    }
+  }
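
Reviewer note: the three branches of `generateResultCodes` correspond to where the operator sits in the plan. For `SELECT k, max(v) ... GROUP BY k`, planning normally produces a pair of aggregates: a Partial one below the exchange (middle branch, emitting the raw buffer as an UnsafeRow) and a Final one above it (first branch, evaluating `resultExpressions` from the merged buffer). A hypothetical smoke test; string-typed values are used because a string-typed `max` buffer is not a mutable fixed-width type, so hash aggregation is not applicable and planning falls back to the sort-based operator:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.max

object SortAggPlanSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    import spark.implicits._

    val df = Seq(("a", "10"), ("b", "1"), ("b", "2")).toDF("k", "v")
      .groupBy("k").agg(max($"v"))
    // Expect two SortAggregate operators in the plan: partial_max below the
    // exchange and the final max above it.
    df.explain()
    df.show()
    spark.stop()
  }
}
```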
+
+  private def doProduceWithoutKeys(ctx: CodegenContext): String = {
+    val initAgg = ctx.freshName("initAgg")
+    ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
+
+    // generate variables for the aggregation buffer
+    val initBufVarsCodes = generateInitBufVarsCodes(ctx)
+
+    // generate variables for output
+    val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+    val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
+      // evaluate aggregate results
+      ctx.currentVars = bufVars
+      val aggResults = functions.map(_.evaluateExpression).map { e =>
+        BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
+      }
+      val evaluateAggResults = evaluateVariables(aggResults)
+      // evaluate result expressions
+      ctx.currentVars = aggResults
+      val resultVars = resultExpressions.map { e =>
+        BindReferences.bindReference(e, aggregateAttributes).genCode(ctx)
+      }
+      (resultVars, s"""
+        |$evaluateAggResults
+        |${evaluateVariables(resultVars)}
+      """.stripMargin)
+    } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
+      // output the aggregate buffer directly
+      (bufVars, "")
+    } else {
+      // no aggregate function, the result should be literals
+      val resultVars = resultExpressions.map(_.genCode(ctx))
+      (resultVars, evaluateVariables(resultVars))
+    }
+
+    numOutput = metricTerm(ctx, "numOutputRows")
+
+    val doAgg = ctx.freshName("doAggregateWithoutKey")
+    ctx.addNewFunction(doAgg,
+      s"""
+         | private void $doAgg() throws java.io.IOException {
+         |   // initialize aggregation buffer
+         |   $initBufVarsCodes
+         |
+         |   ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+         | }
+       """.stripMargin)
+
+    s"""
+       | while (!$initAgg) {
+       |   $initAgg = true;
+       |   $doAgg();
+       |
+       |   // output the result
+       |   ${genResult.trim}
+       |
+       |   $numOutput.add(1);
+       |   ${consume(ctx, resultVars).trim}
+       | }
+     """.stripMargin
+  }
+
+  private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    generateCalBufVarsCodes(ctx, input)
+  }
+
+  private def doProduceWithKeys(ctx: CodegenContext): String = {
+    // initialize the aggregation buffer variables
+    initBufVarsCodes = generateInitBufVarsCodes(ctx)
+    // grouping key
+    currentGroupingKey = ctx.freshName("currentGroupingKey")
+    ctx.addMutableState("UnsafeRow", currentGroupingKey, s"$currentGroupingKey = null;")
+    numOutput = metricTerm(ctx, "numOutputRows")
+    s"""
+       |${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+       |// emit the result for the last group
+       |if ($currentGroupingKey != null) {
+       |  do {
+       |    $numOutput.add(1);
+       |    ${generateResultCodes(ctx)}
+       |  } while (false);
+       |  $currentGroupingKey = null;
+       |}
+     """.stripMargin
+  }
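
Reviewer note: `doProduceWithKeys` above and `doConsumeWithKeys` below together emit a classic single-pass group-by over sorted input: on a key change, flush the finished group and reset the buffer; after the child's produce loop ends, flush the last group (the trailing `if` in `doProduceWithKeys`). A plain-Scala distillation of that control flow (illustrative names, not generated code):

```scala
import scala.collection.mutable.ArrayBuffer

object SortedAggSketch {
  // Single pass over input sorted by key, keeping one group's buffer at a time.
  def sortedAggregate[K, B, R](
      sortedInput: Iterator[(K, Long)],
      init: () => B,
      update: (B, Long) => B,
      emit: (K, B) => R): Seq[R] = {
    val out = ArrayBuffer.empty[R]
    var currentKey: Option[K] = None
    var buf: B = null.asInstanceOf[B]
    for ((key, value) <- sortedInput) {
      if (!currentKey.contains(key)) {
        // Key changed: emit the finished group (cf. the else branch in doConsumeWithKeys).
        currentKey.foreach(k => out += emit(k, buf))
        currentKey = Some(key)
        buf = init()
      }
      buf = update(buf, value)
    }
    // Flush the last group (cf. the trailing block in doProduceWithKeys).
    currentKey.foreach(k => out += emit(k, buf))
    out
  }

  def main(args: Array[String]): Unit = {
    val res = sortedAggregate[String, Long, (String, Long)](
      Iterator("a" -> 1L, "a" -> 2L, "b" -> 3L),
      () => 0L,
      (b, v) => math.max(b, v),
      (k, b) => (k, b))
    println(res) // ArrayBuffer((a,2), (b,3))
  }
}
```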
+
+  private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    // generate the grouping key from the input variables
+    ctx.INPUT_ROW = null
+    ctx.currentVars = input
+    val groupingExprCode: ExprCode = GenerateUnsafeProjection.createCode(
+      ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+    val groupingKey = groupingExprCode.value
+    // code that updates the aggregation buffer variables
+    val calBufVarsCodes: String = generateCalBufVarsCodes(ctx, input)
+    s"""
+       |// generate grouping key
+       |${groupingExprCode.code.trim}
+       |
+       |if ($currentGroupingKey == null) {
+       |  $currentGroupingKey = $groupingKey.copy();
+       |  // initialize the aggregation buffer vars
+       |  $initBufVarsCodes
+       |  // do aggregation
+       |  $calBufVarsCodes
+       |  continue;
+       |} else {
+       |  if ($currentGroupingKey.equals($groupingKey)) {
+       |    // do aggregation
+       |    $calBufVarsCodes
+       |    continue;
+       |  } else {
+       |    do {
+       |      $numOutput.add(1);
+       |      ${generateResultCodes(ctx)}
+       |    } while (false);
+       |    // a new group starts
+       |    $currentGroupingKey = $groupingKey.copy();
+       |    $initBufVarsCodes
+       |    $calBufVarsCodes
+       |  }
+       |}
+     """.stripMargin
+  }
+
   override def simpleString: String = toString(verbose = false)
 
   override def verboseString: String = toString(verbose = true)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 0ee8c959eeb4d..efa8d204a1652 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -355,6 +355,25 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("sort based aggregation with codegen") {
+    checkAnswer(Seq(("a", "10")).toDF("k", "v").groupBy("k").agg(max("v")), Row("a", "10") :: Nil)
+    checkAnswer(Seq(("a", "10"), ("b", "1"), ("b", "2"), ("c", "5"), ("c", "3")).
+      toDF("k", "v").groupBy("k").agg(max("v")),
+      Row("a", "10") :: Row("b", "2") :: Row("c", "5") :: Nil)
+    checkAnswer(Seq(("a", "10", 2), ("b", "1", 3), ("b", "2", 4), ("c", "5", 1), ("c", "3", 5)).
+      toDF("k", "v1", "v2").groupBy("k").agg(max("v1"), min("v2"), count("v2")),
+      Row("a", "10", 2, 1) :: Row("b", "2", 3, 2) :: Row("c", "5", 1, 2) :: Nil)
+    checkAnswer(Seq(("a", "3"), ("b", "20"), ("b", "2")).toDF("k", "v").agg(max("v")),
+      Row("3") :: Nil)
+    checkAnswer(Seq(("a", "10", 2), ("b", "1", 3), ("b", "2", 4), ("c", "5", 1), ("c", "3", 5)).
+      toDF("k", "v1", "v2").agg(max("v1"), min("v2"), count("v2")),
+      Row("5", 1, 5) :: Nil)
+    checkAnswer(
+      sql("SELECT key, max(value) FROM testData GROUP BY key"),
+      (1 to 100).map(i => Row(i, i.toString)))
+    checkAnswer(sql("SELECT max(value) FROM testData"), Row("99") :: Nil)
+  }
+
   test("Add Parser of SQL COALESCE()") {
     checkAnswer(
       sql("""SELECT COALESCE(1, 2)"""),
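
A note on the expected answers in the new test: `v` and `value` are string columns, so `max`/`min` compare lexicographically. That is why `SELECT max(value) FROM testData` over the strings "1".."100" expects "99" rather than "100", and why `agg(max("v"))` over "3", "20", "2" expects `Row("3")`. A quick check:

```scala
println(Seq("1", "99", "100").max) // "99": string ordering is lexicographic
```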