[SPARK-13749][SQL] Faster pivot implementation for many distinct values with two phase aggregation

## What changes were proposed in this pull request?

The existing implementation of pivot translates into a single aggregation with one aggregate per distinct pivot value. When the number of distinct pivot values is large (say 1000+), this can get extremely slow, since every input value is evaluated against every aggregate even though it affects the value of only one of them.

I'm proposing an alternate strategy for when there are 10+ (somewhat arbitrary threshold) distinct pivot values. We do two phases of aggregation. In the first, we group by the grouping columns plus the pivot column and perform the specified aggregations (one or sometimes more). In the second aggregation, we group by the grouping columns and use the new (non-public) PivotFirst aggregate, which rearranges the outputs of the first aggregation into an array indexed by the pivot value. Finally, we do a project to extract the array entries into the appropriate output columns.

## How was this patch tested?

Additional unit tests in DataFramePivotSuite and manual larger scale testing.

Author: Andrew Ray <[email protected]>

Closes #11583 from aray/fast-pivot.
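For context, a minimal sketch of the kind of query this change targets, assuming an existing DataFrame `df`; the column names and pivot values are illustrative only, not taken from this patch:

```scala
import org.apache.spark.sql.functions.sum

// `df` is assumed to have columns (year, course, earnings).
val pivoted = df
  .groupBy("year")
  .pivot("course", Seq("dotNET", "Java")) // explicit list of pivot values
  .agg(sum("earnings"))
```

Per the description above, once the number of distinct pivot values reaches the (somewhat arbitrary) threshold of 10, such a query is now planned as two aggregation phases instead of one aggregate per pivot value.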
Showing 3 changed files with 296 additions and 33 deletions.
...alyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala (152 additions, 0 deletions)
@@ -0,0 +1,152 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions.aggregate

import scala.collection.immutable.HashMap

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._

object PivotFirst {

  def supportsDataType(dataType: DataType): Boolean = updateFunction.isDefinedAt(dataType)

  // Currently UnsafeRow does not support the generic update method (throws
  // UnsupportedOperationException), so we need to explicitly support each DataType.
  private val updateFunction: PartialFunction[DataType, (MutableRow, Int, Any) => Unit] = {
    case DoubleType =>
      (row, offset, value) => row.setDouble(offset, value.asInstanceOf[Double])
    case IntegerType =>
      (row, offset, value) => row.setInt(offset, value.asInstanceOf[Int])
    case LongType =>
      (row, offset, value) => row.setLong(offset, value.asInstanceOf[Long])
    case FloatType =>
      (row, offset, value) => row.setFloat(offset, value.asInstanceOf[Float])
    case BooleanType =>
      (row, offset, value) => row.setBoolean(offset, value.asInstanceOf[Boolean])
    case ShortType =>
      (row, offset, value) => row.setShort(offset, value.asInstanceOf[Short])
    case ByteType =>
      (row, offset, value) => row.setByte(offset, value.asInstanceOf[Byte])
    case d: DecimalType =>
      (row, offset, value) => row.setDecimal(offset, value.asInstanceOf[Decimal], d.precision)
  }
}

/**
 * PivotFirst is an aggregate function used in the second phase of a two phase pivot to do the
 * required rearrangement of values into pivoted form.
 *
 * For example on an input of
 * A | B
 * --+--
 * x | 1
 * y | 2
 * z | 3
 *
 * with pivotColumn=A, valueColumn=B, and pivotColumnValues=[z,y] the output is [3,2].
 *
 * @param pivotColumn column that determines which output position to put valueColumn in.
 * @param valueColumn the column that is being rearranged.
 * @param pivotColumnValues the list of pivotColumn values in the order of desired output. Values
 *                          not listed here will be ignored.
 */
case class PivotFirst(
    pivotColumn: Expression,
    valueColumn: Expression,
    pivotColumnValues: Seq[Any],
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends ImperativeAggregate {

  override val children: Seq[Expression] = pivotColumn :: valueColumn :: Nil

  override lazy val inputTypes: Seq[AbstractDataType] = children.map(_.dataType)

  override val nullable: Boolean = false

  val valueDataType = valueColumn.dataType

  override val dataType: DataType = ArrayType(valueDataType)

  val pivotIndex = HashMap(pivotColumnValues.zipWithIndex: _*)

  val indexSize = pivotIndex.size

  private val updateRow: (MutableRow, Int, Any) => Unit = PivotFirst.updateFunction(valueDataType)

  override def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit = {
    // In the two phase pivot, the first aggregation has already grouped by the grouping columns
    // plus the pivot column, so each pivot value occurs at most once per group here.
    val pivotColValue = pivotColumn.eval(inputRow)
    if (pivotColValue != null) {
      // We ignore rows whose pivot column value is not in the list of pivot column values.
      val index = pivotIndex.getOrElse(pivotColValue, -1)
      if (index >= 0) {
        val value = valueColumn.eval(inputRow)
        if (value != null) {
          updateRow(mutableAggBuffer, mutableAggBufferOffset + index, value)
        }
      }
    }
  }

  override def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit = {
    for (i <- 0 until indexSize) {
      if (!inputAggBuffer.isNullAt(inputAggBufferOffset + i)) {
        val value = inputAggBuffer.get(inputAggBufferOffset + i, valueDataType)
        updateRow(mutableAggBuffer, mutableAggBufferOffset + i, value)
      }
    }
  }

  override def initialize(mutableAggBuffer: MutableRow): Unit = valueDataType match {
    case d: DecimalType =>
      // Per doc of setDecimal we need to do this instead of setNullAt for DecimalType.
      for (i <- 0 until indexSize) {
        mutableAggBuffer.setDecimal(mutableAggBufferOffset + i, null, d.precision)
      }
    case _ =>
      for (i <- 0 until indexSize) {
        mutableAggBuffer.setNullAt(mutableAggBufferOffset + i)
      }
  }

  override def eval(input: InternalRow): Any = {
    val result = new Array[Any](indexSize)
    for (i <- 0 until indexSize) {
      result(i) = input.get(mutableAggBufferOffset + i, valueDataType)
    }
    new GenericArrayData(result)
  }

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override lazy val aggBufferAttributes: Seq[AttributeReference] =
    pivotIndex.toList.sortBy(_._2).map(kv => AttributeReference(kv._1.toString, valueDataType)())

  override lazy val aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)

  override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
    aggBufferAttributes.map(_.newInstance())
}
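To make the rearrangement semantics concrete, here is a minimal, Spark-free model of what PivotFirst computes, written against plain Scala collections; the object and method names are hypothetical and none of this is Spark internals:

```scala
// A toy model of PivotFirst: rearrange (pivotValue, value) pairs into an
// array ordered by pivotColumnValues, ignoring unlisted pivot values and nulls.
object PivotFirstModel {
  def pivot(rows: Seq[(String, Option[Int])], pivotColumnValues: Seq[String]): Array[Option[Int]] = {
    val pivotIndex = pivotColumnValues.zipWithIndex.toMap
    // initialize(): every output slot starts out null (modeled here as None).
    val buffer = Array.fill[Option[Int]](pivotColumnValues.size)(None)
    for ((pivotValue, value) <- rows) {
      pivotIndex.get(pivotValue) match {
        case Some(i) if value.isDefined => buffer(i) = value // update(): non-null value lands in its slot
        case _ => // pivot values not in pivotColumnValues, and null values, are ignored
      }
    }
    buffer
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the scaladoc example: (x,1), (y,2), (z,3) with pivotColumnValues=[z,y].
    val rows = Seq(("x", Some(1)), ("y", Some(2)), ("z", Some(3)))
    val out = pivot(rows, Seq("z", "y"))
    println(out.map(_.map(_.toString).getOrElse("null")).mkString("[", ",", "]")) // prints [3,2]
  }
}
```

The real aggregate does the same bookkeeping against a mutable aggregation buffer, with one type-specialized setter per supported DataType.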