[SPARK-17641][SQL] Collect_list/Collect_set should not collect null values.

## What changes were proposed in this pull request?
We added native versions of `collect_set` and `collect_list` in Spark 2.0. These currently also (try to) collect null values, which differs from the original Hive implementation. This PR fixes this by adding a null check to the `Collect.update` method.
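
A minimal sketch of the behavior after this change, using the same data as the regression test below (assumes an active `SparkSession` named `spark`, e.g. in spark-shell, so its implicits are available):

```scala
import org.apache.spark.sql.functions.{collect_list, collect_set}
import spark.implicits._

// Same data as the regression test: column "a" contains a null value.
val df = Seq(("1", 2), (null, 2), ("1", 4)).toDF("a", "b")

// With the null check in place, the null in "a" is skipped, matching Hive:
// collect_list(a) => ["1", "1"]   collect_set(a) => ["1"]
df.select(collect_list($"a"), collect_set($"a")).show()
```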

## How was this patch tested?
Added a regression test to `DataFrameAggregateSuite`.

Author: Herman van Hovell <[email protected]>

Closes #15208 from hvanhovell/SPARK-17641.

(cherry picked from commit 7d09232)
Signed-off-by: Reynold Xin <[email protected]>
hvanhovell authored and rxin committed Sep 28, 2016
1 parent d358298 · commit 0a69477
Showing 2 changed files with 18 additions and 1 deletion.
@@ -65,7 +65,12 @@ abstract class Collect extends ImperativeAggregate {
   }
 
   override def update(b: MutableRow, input: InternalRow): Unit = {
-    buffer += child.eval(input)
+    // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
+    // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
+    val value = child.eval(input)
+    if (value != null) {
+      buffer += value
+    }
   }
 
   override def merge(buffer: MutableRow, input: InternalRow): Unit = {
@@ -477,6 +477,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     assert(error.message.contains("collect_set() cannot have map type data"))
   }
 
+  test("SPARK-17641: collect functions should not collect null values") {
+    val df = Seq(("1", 2), (null, 2), ("1", 4)).toDF("a", "b")
+    checkAnswer(
+      df.select(collect_list($"a"), collect_list($"b")),
+      Seq(Row(Seq("1", "1"), Seq(2, 2, 4)))
+    )
+    checkAnswer(
+      df.select(collect_set($"a"), collect_set($"b")),
+      Seq(Row(Seq("1"), Seq(2, 4)))
+    )
+  }
+
   test("SPARK-14664: Decimal sum/avg over window should work.") {
     checkAnswer(
       spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),
