diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index a8a81d6d6574e..f61db8594dab2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.{DecimalType, StructType} +import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.KVIterator case class TungstenAggregate( @@ -258,6 +258,7 @@ case class TungstenAggregate( // The name for HashMap private var hashMapTerm: String = _ + private var sorterTerm: String = _ /** * This is called by generated Java class, should be public. @@ -286,39 +287,98 @@ case class TungstenAggregate( GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) } - /** - * Update peak execution memory, called in generated Java class. + * Called by generated Java class to finish the aggregate and return a KVIterator. */ - def updatePeakMemory(hashMap: UnsafeFixedWidthAggregationMap): Unit = { + def finishAggregate( + hashMap: UnsafeFixedWidthAggregationMap, + sorter: UnsafeKVExternalSorter): KVIterator[UnsafeRow, UnsafeRow] = { + + // update peak execution memory val mapMemory = hashMap.getPeakMemoryUsedBytes + val sorterMemory = Option(sorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) + val peakMemory = Math.max(mapMemory, sorterMemory) val metrics = TaskContext.get().taskMetrics() - metrics.incPeakExecutionMemory(mapMemory) - } + metrics.incPeakExecutionMemory(peakMemory) - private def doProduceWithKeys(ctx: CodegenContext): String = { - val initAgg = ctx.freshName("initAgg") - ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + if (sorter == null) { + // not spilled + return hashMap.iterator() + } - // create hashMap - val thisPlan = ctx.addReferenceObj("plan", this) - hashMapTerm = ctx.freshName("hashMap") - val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName - ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") + // merge the final hashMap into sorter + sorter.merge(hashMap.destructAndCreateExternalSorter()) + hashMap.free() + val sortedIter = sorter.sortedIterator() + + // Create a KVIterator based on the sorted iterator. + new KVIterator[UnsafeRow, UnsafeRow] { + + // Create a MutableProjection to merge the rows of same key together + val mergeExpr = declFunctions.flatMap(_.mergeExpressions) + val mergeProjection = newMutableProjection( + mergeExpr, + bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes), + subexpressionEliminationEnabled)() + val joinedRow = new JoinedRow() + + var currentKey: UnsafeRow = null + var currentRow: UnsafeRow = null + var nextKey: UnsafeRow = if (sortedIter.next()) { + sortedIter.getKey + } else { + null + } - // Create a name for iterator from HashMap - val iterTerm = ctx.freshName("mapIter") - ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") + override def next(): Boolean = { + if (nextKey != null) { + currentKey = nextKey.copy() + currentRow = sortedIter.getValue.copy() + nextKey = null + // use the first row as aggregate buffer + mergeProjection.target(currentRow) + + // merge the following rows with same key together + var findNextGroup = false + while (!findNextGroup && sortedIter.next()) { + val key = sortedIter.getKey + if (currentKey.equals(key)) { + mergeProjection(joinedRow(currentRow, sortedIter.getValue)) + } else { + // We find a new group. + findNextGroup = true + nextKey = key + } + } + + true + } else { + false + } + } - // generate code for output - val keyTerm = ctx.freshName("aggKey") - val bufferTerm = ctx.freshName("aggBuffer") - val outputCode = if (modes.contains(Final) || modes.contains(Complete)) { + override def getKey: UnsafeRow = currentKey + override def getValue: UnsafeRow = currentRow + override def close(): Unit = { + sortedIter.close() + } + } + } + + /** + * Generate the code for output. + */ + private def generateResultCode( + ctx: CodegenContext, + keyTerm: String, + bufferTerm: String, + plan: String): String = { + if (modes.contains(Final) || modes.contains(Complete)) { // generate output using resultExpressions ctx.currentVars = null ctx.INPUT_ROW = keyTerm val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) => - BoundReference(i, e.dataType, e.nullable).gen(ctx) + BoundReference(i, e.dataType, e.nullable).gen(ctx) } ctx.INPUT_ROW = bufferTerm val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) => @@ -348,7 +408,7 @@ case class TungstenAggregate( // This should be the last operator in a stage, we should output UnsafeRow directly val joinerTerm = ctx.freshName("unsafeRowJoiner") ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm, - s"$joinerTerm = $thisPlan.createUnsafeJoiner();") + s"$joinerTerm = $plan.createUnsafeJoiner();") val resultRow = ctx.freshName("resultRow") s""" UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm); @@ -367,6 +427,23 @@ case class TungstenAggregate( ${consume(ctx, eval)} """ } + } + + private def doProduceWithKeys(ctx: CodegenContext): String = { + val initAgg = ctx.freshName("initAgg") + ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + + // create hashMap + val thisPlan = ctx.addReferenceObj("plan", this) + hashMapTerm = ctx.freshName("hashMap") + val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName + ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") + sorterTerm = ctx.freshName("sorter") + ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "") + + // Create a name for iterator from HashMap + val iterTerm = ctx.freshName("mapIter") + ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") val doAgg = ctx.freshName("doAggregateWithKeys") ctx.addNewFunction(doAgg, @@ -374,10 +451,15 @@ case class TungstenAggregate( private void $doAgg() throws java.io.IOException { ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} - $iterTerm = $hashMapTerm.iterator(); + $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm); } """) + // generate code for output + val keyTerm = ctx.freshName("aggKey") + val bufferTerm = ctx.freshName("aggBuffer") + val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan) + s""" if (!$initAgg) { $initAgg = true; @@ -391,8 +473,10 @@ case class TungstenAggregate( $outputCode } - $thisPlan.updatePeakMemory($hashMapTerm); - $hashMapTerm.free(); + $iterTerm.close(); + if ($sorterTerm == null) { + $hashMapTerm.free(); + } """ } @@ -425,14 +509,42 @@ case class TungstenAggregate( ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable) } + val (checkFallback, resetCoulter, incCounter) = if (testFallbackStartsAt.isDefined) { + val countTerm = ctx.freshName("fallbackCounter") + ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + (s"$countTerm < ${testFallbackStartsAt.get}", s"$countTerm = 0;", s"$countTerm += 1;") + } else { + ("true", "", "") + } + + // We try to do hash map based in-memory aggregation first. If there is not enough memory (the + // hash map will return null for new key), we spill the hash map to disk to free memory, then + // continue to do in-memory aggregation and spilling until all the rows had been processed. + // Finally, sort the spilled aggregate buffers by key, and merge them together for same key. s""" // generate grouping key ${keyCode.code} - UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); + UnsafeRow $buffer = null; + if ($checkFallback) { + // try to get the buffer from hash map + $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); + } if ($buffer == null) { - // failed to allocate the first page - throw new OutOfMemoryError("No enough memory for aggregation"); + if ($sorterTerm == null) { + $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); + } else { + $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); + } + $resetCoulter + // the hash map had be spilled, it should have enough memory now, + // try to allocate buffer again. + $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); + if ($buffer == null) { + // failed to allocate the first page + throw new OutOfMemoryError("No enough memory for aggregation"); + } } + $incCounter // evaluate aggregate function ${evals.map(_.code).mkString("\n")}