-
Notifications
You must be signed in to change notification settings - Fork 28.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-12951] [SQL] support spilling in generated aggregate #10998
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.expressions._ | |
import org.apache.spark.sql.catalyst.expressions.aggregate._ | ||
import org.apache.spark.sql.catalyst.expressions.codegen._ | ||
import org.apache.spark.sql.catalyst.plans.physical._ | ||
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} | ||
import org.apache.spark.sql.execution._ | ||
import org.apache.spark.sql.execution.metric.SQLMetrics | ||
import org.apache.spark.sql.types.{DecimalType, StructType} | ||
import org.apache.spark.sql.types.StructType | ||
import org.apache.spark.unsafe.KVIterator | ||
|
||
case class TungstenAggregate( | ||
|
@@ -259,6 +259,7 @@ case class TungstenAggregate( | |
|
||
// The name for HashMap | ||
private var hashMapTerm: String = _ | ||
private var sorterTerm: String = _ | ||
|
||
/** | ||
* This is called by generated Java class, should be public. | ||
|
@@ -287,39 +288,98 @@ case class TungstenAggregate( | |
GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) | ||
} | ||
|
||
|
||
/** | ||
* Update peak execution memory, called in generated Java class. | ||
* Called by generated Java class to finish the aggregate and return a KVIterator. | ||
*/ | ||
def updatePeakMemory(hashMap: UnsafeFixedWidthAggregationMap): Unit = { | ||
def finishAggregate( | ||
hashMap: UnsafeFixedWidthAggregationMap, | ||
sorter: UnsafeKVExternalSorter): KVIterator[UnsafeRow, UnsafeRow] = { | ||
|
||
// update peak execution memory | ||
val mapMemory = hashMap.getPeakMemoryUsedBytes | ||
val sorterMemory = Option(sorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) | ||
val peakMemory = Math.max(mapMemory, sorterMemory) | ||
val metrics = TaskContext.get().taskMetrics() | ||
metrics.incPeakExecutionMemory(mapMemory) | ||
} | ||
metrics.incPeakExecutionMemory(peakMemory) | ||
|
||
private def doProduceWithKeys(ctx: CodegenContext): String = { | ||
val initAgg = ctx.freshName("initAgg") | ||
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") | ||
if (sorter == null) { | ||
// not spilled | ||
return hashMap.iterator() | ||
} | ||
|
||
// create hashMap | ||
val thisPlan = ctx.addReferenceObj("plan", this) | ||
hashMapTerm = ctx.freshName("hashMap") | ||
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName | ||
ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") | ||
// merge the final hashMap into sorter | ||
sorter.merge(hashMap.destructAndCreateExternalSorter()) | ||
hashMap.free() | ||
val sortedIter = sorter.sortedIterator() | ||
|
||
// Create a KVIterator based on the sorted iterator. | ||
new KVIterator[UnsafeRow, UnsafeRow] { | ||
|
||
// Create a MutableProjection to merge the rows of same key together | ||
val mergeExpr = declFunctions.flatMap(_.mergeExpressions) | ||
val mergeProjection = newMutableProjection( | ||
mergeExpr, | ||
bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes), | ||
subexpressionEliminationEnabled)() | ||
val joinedRow = new JoinedRow() | ||
|
||
var currentKey: UnsafeRow = null | ||
var currentRow: UnsafeRow = null | ||
var nextKey: UnsafeRow = if (sortedIter.next()) { | ||
sortedIter.getKey | ||
} else { | ||
null | ||
} | ||
|
||
// Create a name for iterator from HashMap | ||
val iterTerm = ctx.freshName("mapIter") | ||
ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") | ||
override def next(): Boolean = { | ||
if (nextKey != null) { | ||
currentKey = nextKey.copy() | ||
currentRow = sortedIter.getValue.copy() | ||
nextKey = null | ||
// use the first row as aggregate buffer | ||
mergeProjection.target(currentRow) | ||
|
||
// merge the following rows with same key together | ||
var findNextGroup = false | ||
while (!findNextGroup && sortedIter.next()) { | ||
val key = sortedIter.getKey | ||
if (currentKey.equals(key)) { | ||
mergeProjection(joinedRow(currentRow, sortedIter.getValue)) | ||
} else { | ||
// We find a new group. | ||
findNextGroup = true | ||
nextKey = key | ||
} | ||
} | ||
|
||
true | ||
} else { | ||
false | ||
} | ||
} | ||
|
||
// generate code for output | ||
val keyTerm = ctx.freshName("aggKey") | ||
val bufferTerm = ctx.freshName("aggBuffer") | ||
val outputCode = if (modes.contains(Final) || modes.contains(Complete)) { | ||
override def getKey: UnsafeRow = currentKey | ||
override def getValue: UnsafeRow = currentRow | ||
override def close(): Unit = { | ||
sortedIter.close() | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Generate the code for output. | ||
*/ | ||
private def generateResultCode( | ||
ctx: CodegenContext, | ||
keyTerm: String, | ||
bufferTerm: String, | ||
plan: String): String = { | ||
if (modes.contains(Final) || modes.contains(Complete)) { | ||
// generate output using resultExpressions | ||
ctx.currentVars = null | ||
ctx.INPUT_ROW = keyTerm | ||
val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) => | ||
BoundReference(i, e.dataType, e.nullable).gen(ctx) | ||
BoundReference(i, e.dataType, e.nullable).gen(ctx) | ||
} | ||
ctx.INPUT_ROW = bufferTerm | ||
val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) => | ||
|
@@ -349,7 +409,7 @@ case class TungstenAggregate( | |
// This should be the last operator in a stage, we should output UnsafeRow directly | ||
val joinerTerm = ctx.freshName("unsafeRowJoiner") | ||
ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm, | ||
s"$joinerTerm = $thisPlan.createUnsafeJoiner();") | ||
s"$joinerTerm = $plan.createUnsafeJoiner();") | ||
val resultRow = ctx.freshName("resultRow") | ||
s""" | ||
UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm); | ||
|
@@ -368,17 +428,39 @@ case class TungstenAggregate( | |
${consume(ctx, eval)} | ||
""" | ||
} | ||
} | ||
|
||
private def doProduceWithKeys(ctx: CodegenContext): String = { | ||
val initAgg = ctx.freshName("initAgg") | ||
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") | ||
|
||
// create hashMap | ||
val thisPlan = ctx.addReferenceObj("plan", this) | ||
hashMapTerm = ctx.freshName("hashMap") | ||
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName | ||
ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") | ||
sorterTerm = ctx.freshName("sorter") | ||
ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "") | ||
|
||
// Create a name for iterator from HashMap | ||
val iterTerm = ctx.freshName("mapIter") | ||
ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") | ||
|
||
val doAgg = ctx.freshName("doAggregateWithKeys") | ||
ctx.addNewFunction(doAgg, | ||
s""" | ||
private void $doAgg() throws java.io.IOException { | ||
${child.asInstanceOf[CodegenSupport].produce(ctx, this)} | ||
|
||
$iterTerm = $hashMapTerm.iterator(); | ||
$iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm); | ||
} | ||
""") | ||
|
||
// generate code for output | ||
val keyTerm = ctx.freshName("aggKey") | ||
val bufferTerm = ctx.freshName("aggBuffer") | ||
val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan) | ||
|
||
s""" | ||
if (!$initAgg) { | ||
$initAgg = true; | ||
|
@@ -392,8 +474,10 @@ case class TungstenAggregate( | |
$outputCode | ||
} | ||
|
||
$thisPlan.updatePeakMemory($hashMapTerm); | ||
$hashMapTerm.free(); | ||
$iterTerm.close(); | ||
if ($sorterTerm == null) { | ||
$hashMapTerm.free(); | ||
} | ||
""" | ||
} | ||
|
||
|
@@ -426,14 +510,39 @@ case class TungstenAggregate( | |
ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable) | ||
} | ||
|
||
val countTerm = ctx.freshName("count") | ||
ctx.addMutableState("int", countTerm, s"$countTerm = 0;") | ||
val checkFallback = if (testFallbackStartsAt.isDefined) { | ||
s"$countTerm < ${testFallbackStartsAt.get}" | ||
} else { | ||
"true" | ||
} | ||
|
||
// We try to do hash map based in-memory aggregation first. If there is not enough memory (the | ||
// hash map will return null for new key), we spill the hash map to disk to free memory, then | ||
// continue to do in-memory aggregation and spilling until all the rows had been processed. | ||
// Finally, sort the spilled aggregate buffers by key, and merge them together for same key. | ||
s""" | ||
// generate grouping key | ||
${keyCode.code} | ||
UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); | ||
UnsafeRow $buffer = null; | ||
if ($checkFallback) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't get this. You seem to do a look up twice and line (539) . Is that intentional? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nvm. I get it now. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if it is confusing we should consider adding documentations There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add comments. |
||
$buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); | ||
} | ||
if ($buffer == null) { | ||
// failed to allocate the first page | ||
throw new OutOfMemoryError("No enough memory for aggregation"); | ||
if ($sorterTerm == null) { | ||
$sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); | ||
} else { | ||
$sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); | ||
} | ||
$countTerm = 0; | ||
$buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); | ||
if ($buffer == null) { | ||
// failed to allocate the first page | ||
throw new OutOfMemoryError("No enough memory for aggregation"); | ||
} | ||
} | ||
$countTerm += 1; | ||
|
||
// evaluate aggregate function | ||
${evals.map(_.code).mkString("\n")} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here you call free after destructAndCreate() but in the other places you don't. Do you need to?