Skip to content

Commit

Permalink
[SPARK-12951] [SQL] support spilling in generated aggregate
Browse files Browse the repository at this point in the history
This PR adds spilling support for the generated TungstenAggregate.

If spilling happens, it is acceptable to fall back to the iterator-based sort-merge aggregate (not generated).

The changes are covered by TungstenAggregationQueryWithControlledFallbackSuite.

Author: Davies Liu <[email protected]>

Closes #10998 from davies/gen_spilling.
  • Loading branch information
Davies Liu authored and davies committed Feb 3, 2016
1 parent ff71261 commit 99a6e3c
Showing 1 changed file with 142 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.{DecimalType, StructType}
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.KVIterator

case class TungstenAggregate(
Expand Down Expand Up @@ -258,6 +258,7 @@ case class TungstenAggregate(

// The name for HashMap
private var hashMapTerm: String = _
private var sorterTerm: String = _

/**
* This is called by generated Java class, should be public.
Expand Down Expand Up @@ -286,39 +287,98 @@ case class TungstenAggregate(
GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
}


/**
* Update peak execution memory, called in generated Java class.
* Called by generated Java class to finish the aggregate and return a KVIterator.
*/
def updatePeakMemory(hashMap: UnsafeFixedWidthAggregationMap): Unit = {
def finishAggregate(
hashMap: UnsafeFixedWidthAggregationMap,
sorter: UnsafeKVExternalSorter): KVIterator[UnsafeRow, UnsafeRow] = {

// update peak execution memory
val mapMemory = hashMap.getPeakMemoryUsedBytes
val sorterMemory = Option(sorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L)
val peakMemory = Math.max(mapMemory, sorterMemory)
val metrics = TaskContext.get().taskMetrics()
metrics.incPeakExecutionMemory(mapMemory)
}
metrics.incPeakExecutionMemory(peakMemory)

private def doProduceWithKeys(ctx: CodegenContext): String = {
val initAgg = ctx.freshName("initAgg")
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
if (sorter == null) {
// not spilled
return hashMap.iterator()
}

// create hashMap
val thisPlan = ctx.addReferenceObj("plan", this)
hashMapTerm = ctx.freshName("hashMap")
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
// merge the final hashMap into sorter
sorter.merge(hashMap.destructAndCreateExternalSorter())
hashMap.free()
val sortedIter = sorter.sortedIterator()

// Create a KVIterator based on the sorted iterator.
new KVIterator[UnsafeRow, UnsafeRow] {

// Create a MutableProjection to merge the rows of same key together
val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
val mergeProjection = newMutableProjection(
mergeExpr,
bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
subexpressionEliminationEnabled)()
val joinedRow = new JoinedRow()

var currentKey: UnsafeRow = null
var currentRow: UnsafeRow = null
var nextKey: UnsafeRow = if (sortedIter.next()) {
sortedIter.getKey
} else {
null
}

// Create a name for iterator from HashMap
val iterTerm = ctx.freshName("mapIter")
ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")
override def next(): Boolean = {
if (nextKey != null) {
currentKey = nextKey.copy()
currentRow = sortedIter.getValue.copy()
nextKey = null
// use the first row as aggregate buffer
mergeProjection.target(currentRow)

// merge the following rows with same key together
var findNextGroup = false
while (!findNextGroup && sortedIter.next()) {
val key = sortedIter.getKey
if (currentKey.equals(key)) {
mergeProjection(joinedRow(currentRow, sortedIter.getValue))
} else {
// We find a new group.
findNextGroup = true
nextKey = key
}
}

true
} else {
false
}
}

// generate code for output
val keyTerm = ctx.freshName("aggKey")
val bufferTerm = ctx.freshName("aggBuffer")
val outputCode = if (modes.contains(Final) || modes.contains(Complete)) {
override def getKey: UnsafeRow = currentKey
override def getValue: UnsafeRow = currentRow
override def close(): Unit = {
sortedIter.close()
}
}
}

/**
* Generate the code for output.
*/
private def generateResultCode(
ctx: CodegenContext,
keyTerm: String,
bufferTerm: String,
plan: String): String = {
if (modes.contains(Final) || modes.contains(Complete)) {
// generate output using resultExpressions
ctx.currentVars = null
ctx.INPUT_ROW = keyTerm
val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
BoundReference(i, e.dataType, e.nullable).gen(ctx)
BoundReference(i, e.dataType, e.nullable).gen(ctx)
}
ctx.INPUT_ROW = bufferTerm
val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) =>
Expand Down Expand Up @@ -348,7 +408,7 @@ case class TungstenAggregate(
// This should be the last operator in a stage, we should output UnsafeRow directly
val joinerTerm = ctx.freshName("unsafeRowJoiner")
ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm,
s"$joinerTerm = $thisPlan.createUnsafeJoiner();")
s"$joinerTerm = $plan.createUnsafeJoiner();")
val resultRow = ctx.freshName("resultRow")
s"""
UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm);
Expand All @@ -367,17 +427,39 @@ case class TungstenAggregate(
${consume(ctx, eval)}
"""
}
}

private def doProduceWithKeys(ctx: CodegenContext): String = {
val initAgg = ctx.freshName("initAgg")
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")

// create hashMap
val thisPlan = ctx.addReferenceObj("plan", this)
hashMapTerm = ctx.freshName("hashMap")
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
sorterTerm = ctx.freshName("sorter")
ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "")

// Create a name for iterator from HashMap
val iterTerm = ctx.freshName("mapIter")
ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")

val doAgg = ctx.freshName("doAggregateWithKeys")
ctx.addNewFunction(doAgg,
s"""
private void $doAgg() throws java.io.IOException {
${child.asInstanceOf[CodegenSupport].produce(ctx, this)}

$iterTerm = $hashMapTerm.iterator();
$iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm);
}
""")

// generate code for output
val keyTerm = ctx.freshName("aggKey")
val bufferTerm = ctx.freshName("aggBuffer")
val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan)

s"""
if (!$initAgg) {
$initAgg = true;
Expand All @@ -391,8 +473,10 @@ case class TungstenAggregate(
$outputCode
}

$thisPlan.updatePeakMemory($hashMapTerm);
$hashMapTerm.free();
$iterTerm.close();
if ($sorterTerm == null) {
$hashMapTerm.free();
}
"""
}

Expand Down Expand Up @@ -425,14 +509,42 @@ case class TungstenAggregate(
ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable)
}

val (checkFallback, resetCoulter, incCounter) = if (testFallbackStartsAt.isDefined) {
val countTerm = ctx.freshName("fallbackCounter")
ctx.addMutableState("int", countTerm, s"$countTerm = 0;")
(s"$countTerm < ${testFallbackStartsAt.get}", s"$countTerm = 0;", s"$countTerm += 1;")
} else {
("true", "", "")
}

// We try to do hash map based in-memory aggregation first. If there is not enough memory (the
// hash map will return null for a new key), we spill the hash map to disk to free memory, then
// continue to do in-memory aggregation and spilling until all the rows have been processed.
// Finally, sort the spilled aggregate buffers by key, and merge them together for the same key.
s"""
// generate grouping key
${keyCode.code}
UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
UnsafeRow $buffer = null;
if ($checkFallback) {
// try to get the buffer from hash map
$buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
}
if ($buffer == null) {
// failed to allocate the first page
throw new OutOfMemoryError("No enough memory for aggregation");
if ($sorterTerm == null) {
$sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
} else {
$sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
}
$resetCoulter
// the hash map had be spilled, it should have enough memory now,
// try to allocate buffer again.
$buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
if ($buffer == null) {
// failed to allocate the first page
throw new OutOfMemoryError("No enough memory for aggregation");
}
}
$incCounter

// evaluate aggregate function
${evals.map(_.code).mkString("\n")}
Expand Down

0 comments on commit 99a6e3c

Please sign in to comment.