Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-12951] [SQL] support spilling in generated aggregate #10998

Closed
wants to merge 5 commits into from
Closed
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.{DecimalType, StructType}
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.KVIterator

case class TungstenAggregate(
Expand Down Expand Up @@ -259,6 +259,7 @@ case class TungstenAggregate(

// The name for HashMap
private var hashMapTerm: String = _
private var sorterTerm: String = _

/**
* This is called by generated Java class, should be public.
Expand Down Expand Up @@ -287,39 +288,98 @@ case class TungstenAggregate(
GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
}


/**
* Update peak execution memory, called in generated Java class.
* Called by generated Java class to finish the aggregate and return a KVIterator.
*/
def updatePeakMemory(hashMap: UnsafeFixedWidthAggregationMap): Unit = {
def finishAggregate(
hashMap: UnsafeFixedWidthAggregationMap,
sorter: UnsafeKVExternalSorter): KVIterator[UnsafeRow, UnsafeRow] = {

// update peak execution memory
val mapMemory = hashMap.getPeakMemoryUsedBytes
val sorterMemory = Option(sorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L)
val peakMemory = Math.max(mapMemory, sorterMemory)
val metrics = TaskContext.get().taskMetrics()
metrics.incPeakExecutionMemory(mapMemory)
}
metrics.incPeakExecutionMemory(peakMemory)

private def doProduceWithKeys(ctx: CodegenContext): String = {
val initAgg = ctx.freshName("initAgg")
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
if (sorter == null) {
// not spilled
return hashMap.iterator()
}

// create hashMap
val thisPlan = ctx.addReferenceObj("plan", this)
hashMapTerm = ctx.freshName("hashMap")
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
// merge the final hashMap into sorter
sorter.merge(hashMap.destructAndCreateExternalSorter())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here you call free() after destructAndCreateExternalSorter(), but in the other places you don't. Do you need to?

hashMap.free()
val sortedIter = sorter.sortedIterator()

// Create a KVIterator based on the sorted iterator.
new KVIterator[UnsafeRow, UnsafeRow] {

// Create a MutableProjection to merge the rows of same key together
val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
val mergeProjection = newMutableProjection(
mergeExpr,
bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
subexpressionEliminationEnabled)()
val joinedRow = new JoinedRow()

var currentKey: UnsafeRow = null
var currentRow: UnsafeRow = null
var nextKey: UnsafeRow = if (sortedIter.next()) {
sortedIter.getKey
} else {
null
}

// Create a name for iterator from HashMap
val iterTerm = ctx.freshName("mapIter")
ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")
override def next(): Boolean = {
if (nextKey != null) {
currentKey = nextKey.copy()
currentRow = sortedIter.getValue.copy()
nextKey = null
// use the first row as aggregate buffer
mergeProjection.target(currentRow)

// merge the following rows with same key together
var findNextGroup = false
while (!findNextGroup && sortedIter.next()) {
val key = sortedIter.getKey
if (currentKey.equals(key)) {
mergeProjection(joinedRow(currentRow, sortedIter.getValue))
} else {
// We find a new group.
findNextGroup = true
nextKey = key
}
}

true
} else {
false
}
}

// generate code for output
val keyTerm = ctx.freshName("aggKey")
val bufferTerm = ctx.freshName("aggBuffer")
val outputCode = if (modes.contains(Final) || modes.contains(Complete)) {
override def getKey: UnsafeRow = currentKey
override def getValue: UnsafeRow = currentRow
override def close(): Unit = {
sortedIter.close()
}
}
}

/**
* Generate the code for output.
*/
private def generateResultCode(
ctx: CodegenContext,
keyTerm: String,
bufferTerm: String,
plan: String): String = {
if (modes.contains(Final) || modes.contains(Complete)) {
// generate output using resultExpressions
ctx.currentVars = null
ctx.INPUT_ROW = keyTerm
val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
BoundReference(i, e.dataType, e.nullable).gen(ctx)
BoundReference(i, e.dataType, e.nullable).gen(ctx)
}
ctx.INPUT_ROW = bufferTerm
val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) =>
Expand Down Expand Up @@ -349,7 +409,7 @@ case class TungstenAggregate(
// This should be the last operator in a stage, we should output UnsafeRow directly
val joinerTerm = ctx.freshName("unsafeRowJoiner")
ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm,
s"$joinerTerm = $thisPlan.createUnsafeJoiner();")
s"$joinerTerm = $plan.createUnsafeJoiner();")
val resultRow = ctx.freshName("resultRow")
s"""
UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm);
Expand All @@ -368,17 +428,39 @@ case class TungstenAggregate(
${consume(ctx, eval)}
"""
}
}

private def doProduceWithKeys(ctx: CodegenContext): String = {
val initAgg = ctx.freshName("initAgg")
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")

// create hashMap
val thisPlan = ctx.addReferenceObj("plan", this)
hashMapTerm = ctx.freshName("hashMap")
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
sorterTerm = ctx.freshName("sorter")
ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "")

// Create a name for iterator from HashMap
val iterTerm = ctx.freshName("mapIter")
ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")

val doAgg = ctx.freshName("doAggregateWithKeys")
ctx.addNewFunction(doAgg,
s"""
private void $doAgg() throws java.io.IOException {
${child.asInstanceOf[CodegenSupport].produce(ctx, this)}

$iterTerm = $hashMapTerm.iterator();
$iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm);
}
""")

// generate code for output
val keyTerm = ctx.freshName("aggKey")
val bufferTerm = ctx.freshName("aggBuffer")
val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan)

s"""
if (!$initAgg) {
$initAgg = true;
Expand All @@ -392,8 +474,10 @@ case class TungstenAggregate(
$outputCode
}

$thisPlan.updatePeakMemory($hashMapTerm);
$hashMapTerm.free();
$iterTerm.close();
if ($sorterTerm == null) {
$hashMapTerm.free();
}
"""
}

Expand Down Expand Up @@ -426,14 +510,39 @@ case class TungstenAggregate(
ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable)
}

val countTerm = ctx.freshName("count")
ctx.addMutableState("int", countTerm, s"$countTerm = 0;")
val checkFallback = if (testFallbackStartsAt.isDefined) {
s"$countTerm < ${testFallbackStartsAt.get}"
} else {
"true"
}

// We try to do hash map based in-memory aggregation first. If there is not enough memory (the
// hash map will return null for new key), we spill the hash map to disk to free memory, then
// continue to do in-memory aggregation and spilling until all the rows had been processed.
// Finally, sort the spilled aggregate buffers by key, and merge them together for same key.
s"""
// generate grouping key
${keyCode.code}
UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
UnsafeRow $buffer = null;
if ($checkFallback) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't get this. You seem to do a lookup twice, here and at line 539. Is that intentional?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nvm. I get it now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it is confusing, we should consider adding documentation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add comments.

$buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
}
if ($buffer == null) {
// failed to allocate the first page
throw new OutOfMemoryError("No enough memory for aggregation");
if ($sorterTerm == null) {
$sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
} else {
$sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
}
$countTerm = 0;
$buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
if ($buffer == null) {
// failed to allocate the first page
throw new OutOfMemoryError("No enough memory for aggregation");
}
}
$countTerm += 1;

// evaluate aggregate function
${evals.map(_.code).mkString("\n")}
Expand Down