[SPARK-13749][SQL] Faster pivot implementation for many distinct values with two phase aggregation #11583
Analyzer.scala:

```diff
@@ -363,43 +363,68 @@ class Analyzer(

 object ResolvePivot extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved) => p
+    case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved)
+      | !p.groupByExprs.forall(_.resolved) | !p.pivotColumn.resolved => p
     case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
       val singleAgg = aggregates.size == 1
-      val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
-        def ifExpr(expr: Expression) = {
-          If(EqualTo(pivotColumn, value), expr, Literal(null))
-        }
-        aggregates.map { aggregate =>
-          val filteredAggregate = aggregate.transformDown {
-            // Assumption is the aggregate function ignores nulls. This is true for all current
-            // AggregateFunction's with the exception of First and Last in their default mode
-            // (which we handle) and possibly some Hive UDAF's.
-            case First(expr, _) =>
-              First(ifExpr(expr), Literal(true))
-            case Last(expr, _) =>
-              Last(ifExpr(expr), Literal(true))
-            case a: AggregateFunction =>
-              a.withNewChildren(a.children.map(ifExpr))
-          }.transform {
-            // We are duplicating aggregates that are now computing a different value for each
-            // pivot value.
-            // TODO: Don't construct the physical container until after analysis.
-            case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
-          }
-          if (filteredAggregate.fastEquals(aggregate)) {
-            throw new AnalysisException(
-              s"Aggregate expression required for pivot, found '$aggregate'")
-          }
-          val name = if (singleAgg) value.toString else value + "_" + aggregate.sql
-          Alias(filteredAggregate, name)()
-        }
-      }
-      Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
+      def outputName(value: Literal, aggregate: Expression): String = {
+        if (singleAgg) value.toString else value + "_" + aggregate.sql
+      }
+      if (aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) {
+        // Since evaluating |pivotValues| if statements for each input row can get slow this is an
+        // alternate plan that instead uses two steps of aggregation.
+        val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)())
+        val namedPivotCol = pivotColumn match {
+          case n: NamedExpression => n
+          case _ => Alias(pivotColumn, "__pivot_col")()
+        }
+        val bigGroup = groupByExprs :+ namedPivotCol
+        val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child)
+        val castPivotValues = pivotValues.map(Cast(_, pivotColumn.dataType).eval(EmptyRow))
+        val pivotAggs = namedAggExps.map { a =>
+          Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, castPivotValues)
+            .toAggregateExpression()
+          , "__pivot_" + a.sql)()
+        }
+        val secondAgg = Aggregate(groupByExprs, groupByExprs ++ pivotAggs, firstAgg)
+        val pivotAggAttribute = pivotAggs.map(_.toAttribute)
+        val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, i) =>
+          aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt) =>
+            Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))()
+          }
+        }
+        Project(groupByExprs ++ pivotOutputs, secondAgg)
+      } else {
+        val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
+          def ifExpr(expr: Expression) = {
+            If(EqualTo(pivotColumn, value), expr, Literal(null))
+          }
+          aggregates.map { aggregate =>
+            val filteredAggregate = aggregate.transformDown {
+              // Assumption is the aggregate function ignores nulls. This is true for all current
+              // AggregateFunction's with the exception of First and Last in their default mode
+              // (which we handle) and possibly some Hive UDAF's.
+              case First(expr, _) =>
+                First(ifExpr(expr), Literal(true))
+              case Last(expr, _) =>
+                Last(ifExpr(expr), Literal(true))
+              case a: AggregateFunction =>
+                a.withNewChildren(a.children.map(ifExpr))
+            }.transform {
+              // We are duplicating aggregates that are now computing a different value for each
+              // pivot value.
+              // TODO: Don't construct the physical container until after analysis.
+              case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
+            }
+            if (filteredAggregate.fastEquals(aggregate)) {
+              throw new AnalysisException(
+                s"Aggregate expression required for pivot, found '$aggregate'")
+            }
+            Alias(filteredAggregate, outputName(value, aggregate))()
+          }
+        }
+        val newGroupByExprs = groupByExprs.map {
+          case UnresolvedAlias(e, _) => e
+          case e => e
+        }
+        Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child)
+      }
   }
 }
```

Review comment (on the single-phase code retained in the else branch): This map is not needed anymore?
Reply: Nope, added a check for [...]
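To make the rewrite concrete, here is a hedged usage sketch of a pivot that would take the new two-phase path, plus the rough plan shape ResolvePivot now produces. The `courseSales` DataFrame, the column names, and the pivot values are invented for illustration; `groupBy`, `pivot`, and `agg` are the existing public DataFrame APIs.

```scala
import org.apache.spark.sql.functions.sum

// Hypothetical input: one row per sale, with many distinct values in "course".
val pivoted = courseSales
  .groupBy("year")
  .pivot("course", Seq("dotNET", "Java"))
  .agg(sum("earnings"))

// Rough shape of the resolved plan (simplified):
//
//   Project(year, __pivot_sum(earnings)[0] AS dotNET, __pivot_sum(earnings)[1] AS Java)
//     Aggregate(groupBy = [year], aggs = [PivotFirst(course, sum(earnings), [dotNET, Java])])
//       Aggregate(groupBy = [year, course], aggs = [sum(earnings)])
//         <child>
//
// The first Aggregate collapses the input to one row per (year, course), so PivotFirst
// only has to drop each pre-aggregated value into the slot for its pivot value, rather
// than evaluating one If(EqualTo(course, value), ...) per pivot value for every input row.
```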
PivotFirst.scala (new file, @@ -0,0 +1,152 @@):

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions.aggregate

import scala.collection.immutable.HashMap

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._

object PivotFirst {

  def supportsDataType(dataType: DataType): Boolean = updateFunction.isDefinedAt(dataType)

  // Currently UnsafeRow does not support the generic update method (throws
  // UnsupportedOperationException), so we need to explicitly support each DataType.
  private val updateFunction: PartialFunction[DataType, (MutableRow, Int, Any) => Unit] = {
    case DoubleType =>
      (row, offset, value) => row.setDouble(offset, value.asInstanceOf[Double])
    case IntegerType =>
      (row, offset, value) => row.setInt(offset, value.asInstanceOf[Int])
    case LongType =>
      (row, offset, value) => row.setLong(offset, value.asInstanceOf[Long])
    case FloatType =>
      (row, offset, value) => row.setFloat(offset, value.asInstanceOf[Float])
    case BooleanType =>
      (row, offset, value) => row.setBoolean(offset, value.asInstanceOf[Boolean])
    case ShortType =>
      (row, offset, value) => row.setShort(offset, value.asInstanceOf[Short])
    case ByteType =>
      (row, offset, value) => row.setByte(offset, value.asInstanceOf[Byte])
    case d: DecimalType =>
      (row, offset, value) => row.setDecimal(offset, value.asInstanceOf[Decimal], d.precision)
  }
}

/**
 * PivotFirst is an aggregate function used in the second phase of a two phase pivot to do the
 * required rearrangement of values into pivoted form.
 *
 * For example on an input of
 * A | B
 * --+--
 * x | 1
 * y | 2
 * z | 3
 *
 * with pivotColumn=A, valueColumn=B, and pivotColumnValues=[z,y] the output is [3,2].
 *
 * @param pivotColumn column that determines which output position to put valueColumn in.
 * @param valueColumn the column that is being rearranged.
 * @param pivotColumnValues the list of pivotColumn values in the order of desired output.
 *                          Values not listed here will be ignored.
 */
case class PivotFirst(
    pivotColumn: Expression,
    valueColumn: Expression,
    pivotColumnValues: Seq[Any],
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends ImperativeAggregate {

  override val children: Seq[Expression] = pivotColumn :: valueColumn :: Nil

  override lazy val inputTypes: Seq[AbstractDataType] = children.map(_.dataType)

  override val nullable: Boolean = false

  val valueDataType = valueColumn.dataType

  override val dataType: DataType = ArrayType(valueDataType)

  val pivotIndex = HashMap(pivotColumnValues.zipWithIndex: _*)

  val indexSize = pivotIndex.size

  private val updateRow: (MutableRow, Int, Any) => Unit = PivotFirst.updateFunction(valueDataType)

  override def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit = {
    val pivotColValue = pivotColumn.eval(inputRow)
    if (pivotColValue != null) {
      // We ignore rows whose pivot column value is not in the list of pivot column values.
      val index = pivotIndex.getOrElse(pivotColValue, -1)
      if (index >= 0) {
        val value = valueColumn.eval(inputRow)
        if (value != null) {
          updateRow(mutableAggBuffer, mutableAggBufferOffset + index, value)
        }
      }
    }
  }

  override def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit = {
    for (i <- 0 until indexSize) {
      if (!inputAggBuffer.isNullAt(inputAggBufferOffset + i)) {
        val value = inputAggBuffer.get(inputAggBufferOffset + i, valueDataType)
        updateRow(mutableAggBuffer, mutableAggBufferOffset + i, value)
      }
    }
  }

  override def initialize(mutableAggBuffer: MutableRow): Unit = valueDataType match {
    case d: DecimalType =>
      // Per doc of setDecimal we need to do this instead of setNullAt for DecimalType.
      for (i <- 0 until indexSize) {
        mutableAggBuffer.setDecimal(mutableAggBufferOffset + i, null, d.precision)
      }
    case _ =>
      for (i <- 0 until indexSize) {
        mutableAggBuffer.setNullAt(mutableAggBufferOffset + i)
      }
  }

  override def eval(input: InternalRow): Any = {
    val result = new Array[Any](indexSize)
    for (i <- 0 until indexSize) {
      result(i) = input.get(mutableAggBufferOffset + i, valueDataType)
    }
    new GenericArrayData(result)
  }

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override lazy val aggBufferAttributes: Seq[AttributeReference] =
    pivotIndex.toList.sortBy(_._2).map(kv => AttributeReference(kv._1.toString, valueDataType)())

  override lazy val aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)

  override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
    aggBufferAttributes.map(_.newInstance())
}
```

Review comments on this file:

- On `update` (the `if (index >= 0)` check): Can we add a comment to explain when [...]? Also, for two different inputRows, we should not get the same index, right?
- On `initialize`: Let's add a comment to explain why we need special care for DecimalType.
- On `aggBufferAttributes`: How about we avoid using a lazy val for `aggBufferAttributes`?
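To make the buffer layout concrete, here is a minimal standalone sketch of the same slot discipline in plain Scala. It is a simplified, hypothetical model (the `ToyPivotFirst` name is invented, and an `Array[Option[Long]]` stands in for the mutable aggregation buffer), not the production code path:

```scala
// Toy model of PivotFirst: each requested pivot value owns one fixed buffer slot.
class ToyPivotFirst(pivotValues: Seq[String]) {
  private val pivotIndex: Map[String, Int] = pivotValues.zipWithIndex.toMap

  // Mirrors initialize(): every slot starts out empty ("null").
  def initialize(): Array[Option[Long]] = Array.fill[Option[Long]](pivotValues.size)(None)

  // Mirrors update(): a pivot value that was not requested has no slot, so the row
  // is silently ignored, just like the index < 0 case above.
  def update(buffer: Array[Option[Long]], pivotValue: String, value: Long): Unit =
    pivotIndex.get(pivotValue).foreach(i => buffer(i) = Some(value))

  // Mirrors merge(): copy over every slot the other partial buffer has filled.
  def merge(buffer: Array[Option[Long]], other: Array[Option[Long]]): Unit =
    for (i <- other.indices if other(i).isDefined) buffer(i) = other(i)
}

val agg = new ToyPivotFirst(Seq("z", "y"))
val buf = agg.initialize()
agg.update(buf, "z", 3L)
agg.update(buf, "y", 2L)
agg.update(buf, "x", 9L) // "x" was not requested, so this write is dropped
// buf.toSeq == Seq(Some(3L), Some(2L)), matching the [3, 2] example in the scaladoc.
```

Because the first aggregation phase already produces at most one row per (group, pivot value), a plain overwrite in `update` and `merge` suffices; that appears to be the invariant the duplicate-index review question above is probing.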
Review comment (on the `else` branch in ResolvePivot): Since we will decide which branch to use based on the datatypes, do we still have enough test coverage for this else branch?
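On the coverage question, here is a hedged sketch of a query that should still take the `else` branch: `PivotFirst.supportsDataType` only accepts the types enumerated in `updateFunction`, so an aggregate whose result type is an array falls back to the single-phase `If`-based plan. The `courseSales` DataFrame is the same invented example as above, and `collect_list` is assumed to be available as an aggregate function here.

```scala
import org.apache.spark.sql.functions.collect_list

// collect_list("earnings") produces an ArrayType result, which updateFunction does
// not handle, so supportsDataType returns false and ResolvePivot keeps the old plan.
val fallback = courseSales
  .groupBy("year")
  .pivot("course", Seq("dotNET", "Java"))
  .agg(collect_list("earnings"))
```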