[SPARK-19691][SQL] Fix ClassCastException when calculating percentile of decimal column

## What changes were proposed in this pull request?
This PR fixes the `ClassCastException` shown below:
```
scala> spark.range(10).selectExpr("cast (id as decimal) as x").selectExpr("percentile(x, 0.5)").collect()
 java.lang.ClassCastException: org.apache.spark.sql.types.Decimal cannot be cast to java.lang.Number
	at org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.update(Percentile.scala:141)
	at org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.update(Percentile.scala:58)
	at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.update(interfaces.scala:514)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1$$anonfun$applyOrElse$1.apply(AggregationIterator.scala:171)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1$$anonfun$applyOrElse$1.apply(AggregationIterator.scala:171)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$generateProcessRow$1.apply(AggregationIterator.scala:187)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$generateProcessRow$1.apply(AggregationIterator.scala:181)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:151)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.<init>(ObjectAggregationIterator.scala:78)
	at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:109)
	at
```
This fix converts Catalyst values (i.e., `Decimal`) into Scala values by using `CatalystTypeConverters`.
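
The essence of the change is visible in the diff below: the aggregation buffer is widened from `OpenHashMap[Number, Long]` to `OpenHashMap[AnyRef, Long]`, and a small helper converts either a Catalyst `Decimal` or a `java.lang.Number` to `Double` before the percentile interpolation. A minimal standalone sketch of that conversion pattern (not the committed code verbatim):

```scala
import org.apache.spark.sql.types.Decimal

// Sketch of the conversion used by the patch: the evaluated child value may be a
// Catalyst Decimal (for DECIMAL columns) or a java.lang.Number (for other numeric
// columns), and the percentile math needs a plain Double either way.
def toDoubleValue(d: Any): Double = d match {
  case dec: Decimal => dec.toDouble
  case n: java.lang.Number => n.doubleValue
}

// Both kinds of input reduce to the same Double.
assert(toDoubleValue(Decimal(4.5)) == 4.5)
assert(toDoubleValue(java.lang.Long.valueOf(4L)) == 4.0)
```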

## How was this patch tested?
Added a test in `DataFrameSuite`.
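
With the fix in place, the query from the description also serves as a quick manual check in `spark-shell`; it should now return a row instead of throwing. A sketch of that check (the exact rendering of the median depends on the decimal output type):

```scala
// Manual re-check of the query from the description. Before this patch it threw a
// ClassCastException; afterwards it should return a single row with the median.
spark.range(10)
  .selectExpr("cast(id as decimal) as x")
  .selectExpr("percentile(x, 0.5)")
  .show()
```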

Author: Takeshi Yamamuro <[email protected]>

Closes apache#17028 from maropu/SPARK-19691.
maropu authored and Yun Ni committed Feb 27, 2017
1 parent 46f4a19 commit a0ce01e
Showing 3 changed files with 45 additions and 37 deletions.
Percentile.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.util

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions._
@@ -61,7 +61,7 @@ case class Percentile(
frequencyExpression : Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends TypedImperativeAggregate[OpenHashMap[Number, Long]] with ImplicitCastInputTypes {
extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] with ImplicitCastInputTypes {

def this(child: Expression, percentageExpression: Expression) = {
this(child, percentageExpression, Literal(1L), 0, 0)
@@ -130,15 +130,20 @@ case class Percentile(
}
}

override def createAggregationBuffer(): OpenHashMap[Number, Long] = {
private def toDoubleValue(d: Any): Double = d match {
case d: Decimal => d.toDouble
case n: Number => n.doubleValue
}

override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = {
// Initialize new counts map instance here.
new OpenHashMap[Number, Long]()
new OpenHashMap[AnyRef, Long]()
}

override def update(
buffer: OpenHashMap[Number, Long],
input: InternalRow): OpenHashMap[Number, Long] = {
val key = child.eval(input).asInstanceOf[Number]
buffer: OpenHashMap[AnyRef, Long],
input: InternalRow): OpenHashMap[AnyRef, Long] = {
val key = child.eval(input).asInstanceOf[AnyRef]
val frqValue = frequencyExpression.eval(input)

// Null values are ignored in counts map.
@@ -155,32 +160,32 @@ case class Percentile(
}

override def merge(
buffer: OpenHashMap[Number, Long],
other: OpenHashMap[Number, Long]): OpenHashMap[Number, Long] = {
buffer: OpenHashMap[AnyRef, Long],
other: OpenHashMap[AnyRef, Long]): OpenHashMap[AnyRef, Long] = {
other.foreach { case (key, count) =>
buffer.changeValue(key, count, _ + count)
}
buffer
}

override def eval(buffer: OpenHashMap[Number, Long]): Any = {
override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
generateOutput(getPercentiles(buffer))
}

private def getPercentiles(buffer: OpenHashMap[Number, Long]): Seq[Double] = {
private def getPercentiles(buffer: OpenHashMap[AnyRef, Long]): Seq[Double] = {
if (buffer.isEmpty) {
return Seq.empty
}

val sortedCounts = buffer.toSeq.sortBy(_._1)(
child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[Number]])
child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]])
val accumlatedCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) {
case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
}.tail
val maxPosition = accumlatedCounts.last._2 - 1

percentages.map { percentile =>
getPercentile(accumlatedCounts, maxPosition * percentile).doubleValue()
getPercentile(accumlatedCounts, maxPosition * percentile)
}
}

@@ -200,7 +205,7 @@ case class Percentile(
* This function has been based upon similar function from HIVE
* `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`.
*/
private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = {
private def getPercentile(aggreCounts: Seq[(AnyRef, Long)], position: Double): Double = {
// We may need to do linear interpolation to get the exact percentile
val lower = position.floor.toLong
val higher = position.ceil.toLong
@@ -213,18 +218,17 @@ case class Percentile(
val lowerKey = aggreCounts(lowerIndex)._1
if (higher == lower) {
// no interpolation needed because position does not have a fraction
return lowerKey
return toDoubleValue(lowerKey)
}

val higherKey = aggreCounts(higherIndex)._1
if (higherKey == lowerKey) {
// no interpolation needed because lower position and higher position has the same key
return lowerKey
return toDoubleValue(lowerKey)
}

// Linear interpolation to get the exact percentile
return (higher - position) * lowerKey.doubleValue() +
(position - lower) * higherKey.doubleValue()
(higher - position) * toDoubleValue(lowerKey) + (position - lower) * toDoubleValue(higherKey)
}

/**
@@ -238,7 +242,7 @@ case class Percentile(
}
}

override def serialize(obj: OpenHashMap[Number, Long]): Array[Byte] = {
override def serialize(obj: OpenHashMap[AnyRef, Long]): Array[Byte] = {
val buffer = new Array[Byte](4 << 10) // 4K
val bos = new ByteArrayOutputStream()
val out = new DataOutputStream(bos)
@@ -261,11 +265,11 @@ case class Percentile(
}
}

override def deserialize(bytes: Array[Byte]): OpenHashMap[Number, Long] = {
override def deserialize(bytes: Array[Byte]): OpenHashMap[AnyRef, Long] = {
val bis = new ByteArrayInputStream(bytes)
val ins = new DataInputStream(bis)
try {
val counts = new OpenHashMap[Number, Long]
val counts = new OpenHashMap[AnyRef, Long]
// Read unsafeRow size and content in bytes.
var sizeOfNextRow = ins.readInt()
while (sizeOfNextRow >= 0) {
@@ -274,7 +278,7 @@ case class Percentile(
val row = new UnsafeRow(2)
row.pointTo(bs, sizeOfNextRow)
// Insert the pairs into counts map.
val key = row.get(0, child.dataType).asInstanceOf[Number]
val key = row.get(0, child.dataType)
val count = row.get(1, LongType).asInstanceOf[Long]
counts.update(key, count)
sizeOfNextRow = ins.readInt()
PercentileSuite.scala
@@ -21,7 +21,6 @@ import org.apache.spark.SparkException
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types._
@@ -39,12 +38,12 @@ class PercentileSuite extends SparkFunSuite {
val agg = new Percentile(BoundReference(0, IntegerType, true), Literal(0.5))

// Check empty serialize and deserialize
val buffer = new OpenHashMap[Number, Long]()
val buffer = new OpenHashMap[AnyRef, Long]()
assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))

// Check non-empty buffer serializa and deserialize.
data.foreach { key =>
buffer.changeValue(key, 1L, _ + 1L)
buffer.changeValue(new Integer(key), 1L, _ + 1L)
}
assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
}
@@ -58,25 +57,25 @@ class PercentileSuite extends SparkFunSuite {
val agg = new Percentile(childExpression, percentageExpression)

// Test with rows without frequency
val rows = (1 to count).map( x => Seq(x))
runTest( agg, rows, expectedPercentiles)
val rows = (1 to count).map(x => Seq(x))
runTest(agg, rows, expectedPercentiles)

// Test with row with frequency. Second and third columns are frequency in Int and Long
val countForFrequencyTest = 1000
val rowsWithFrequency = (1 to countForFrequencyTest).map( x => Seq(x, x):+ x.toLong)
val rowsWithFrequency = (1 to countForFrequencyTest).map(x => Seq(x, x):+ x.toLong)
val expectedPercentilesWithFrquency = Seq(1.0, 500.0, 707.0, 866.0, 1000.0)

val frequencyExpressionInt = BoundReference(1, IntegerType, nullable = false)
val aggInt = new Percentile(childExpression, percentageExpression, frequencyExpressionInt)
runTest( aggInt, rowsWithFrequency, expectedPercentilesWithFrquency)
runTest(aggInt, rowsWithFrequency, expectedPercentilesWithFrquency)

val frequencyExpressionLong = BoundReference(2, LongType, nullable = false)
val aggLong = new Percentile(childExpression, percentageExpression, frequencyExpressionLong)
runTest( aggLong, rowsWithFrequency, expectedPercentilesWithFrquency)
runTest(aggLong, rowsWithFrequency, expectedPercentilesWithFrquency)

// Run test with Flatten data
val flattenRows = (1 to countForFrequencyTest).flatMap( current =>
(1 to current).map( y => current )).map( Seq(_))
val flattenRows = (1 to countForFrequencyTest).flatMap(current =>
(1 to current).map(y => current )).map(Seq(_))
runTest(agg, flattenRows, expectedPercentilesWithFrquency)
}

@@ -153,7 +152,7 @@ class PercentileSuite extends SparkFunSuite {
}

val validFrequencyTypes = Seq(ByteType, ShortType, IntegerType, LongType)
for ( dataType <- validDataTypes;
for (dataType <- validDataTypes;
frequencyType <- validFrequencyTypes) {
val child = AttributeReference("a", dataType)()
val frq = AttributeReference("frq", frequencyType)()
@@ -176,7 +175,7 @@ class PercentileSuite extends SparkFunSuite {
StringType, DateType, TimestampType,
CalendarIntervalType, NullType)

for( dataType <- invalidDataTypes;
for(dataType <- invalidDataTypes;
frequencyType <- validFrequencyTypes) {
val child = AttributeReference("a", dataType)()
val frq = AttributeReference("frq", frequencyType)()
@@ -186,7 +185,7 @@ class PercentileSuite extends SparkFunSuite {
s"'`a`' is of ${dataType.simpleString} type."))
}

for( dataType <- validDataTypes;
for(dataType <- validDataTypes;
frequencyType <- invalidFrequencyDataTypes) {
val child = AttributeReference("a", dataType)()
val frq = AttributeReference("frq", frequencyType)()
@@ -294,11 +293,11 @@ class PercentileSuite extends SparkFunSuite {
agg.update(buffer, InternalRow(1, -5))
agg.eval(buffer)
}
assert( caught.getMessage.startsWith("Negative values found in "))
assert(caught.getMessage.startsWith("Negative values found in "))
}

private def compareEquals(
left: OpenHashMap[Number, Long], right: OpenHashMap[Number, Long]): Boolean = {
left: OpenHashMap[AnyRef, Long], right: OpenHashMap[AnyRef, Long]): Boolean = {
left.size == right.size && left.forall { case (key, count) =>
right.apply(key) == count
}
DataFrameSuite.scala
@@ -1702,4 +1702,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val df = Seq(123L -> "123", 19157170390056973L -> "19157170390056971").toDF("i", "j")
checkAnswer(df.select($"i" === $"j"), Row(true) :: Row(false) :: Nil)
}

test("SPARK-19691 Calculating percentile of decimal column fails with ClassCastException") {
val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)")
checkAnswer(df, Row(BigDecimal(0.0)) :: Nil)
}
}
