Skip to content

Commit

Permalink
Change array_agg to return null on no input rather than empty list (
Browse files Browse the repository at this point in the history
#11299)

* change array agg semantic for empty result

Signed-off-by: jayzhan211 <[email protected]>

* return null

Signed-off-by: jayzhan211 <[email protected]>

* fix test

Signed-off-by: jayzhan211 <[email protected]>

* fix order sensitive

Signed-off-by: jayzhan211 <[email protected]>

* fix test

Signed-off-by: jayzhan211 <[email protected]>

* add more test

Signed-off-by: jayzhan211 <[email protected]>

* fix null

Signed-off-by: jayzhan211 <[email protected]>

* fix multi-phase case

Signed-off-by: jayzhan211 <[email protected]>

* add comment

Signed-off-by: jayzhan211 <[email protected]>

* cleanup

Signed-off-by: jayzhan211 <[email protected]>

* fix clone

Signed-off-by: jayzhan211 <[email protected]>

---------

Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 authored Jul 10, 2024
1 parent cc7484e commit d3f6372
Show file tree
Hide file tree
Showing 9 changed files with 161 additions and 54 deletions.
10 changes: 10 additions & 0 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1984,6 +1984,16 @@ impl ScalarValue {
Self::new_list(values, data_type, true)
}

/// Create ListArray with Null with specific data type
///
/// - new_null_list(i32, nullable, 1): `ListArray[NULL]`
pub fn new_null_list(data_type: DataType, nullable: bool, null_len: usize) -> Self {
let data_type = DataType::List(Field::new_list_field(data_type, nullable).into());
Self::List(Arc::new(ListArray::from(ArrayData::new_null(
&data_type, null_len,
))))
}

/// Converts `IntoIterator<Item = ScalarValue>` where each element has type corresponding to
/// `data_type`, to a [`ListArray`].
///
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1388,7 +1388,7 @@ async fn unnest_with_redundant_columns() -> Result<()> {
let expected = vec![
"Projection: shapes.shape_id [shape_id:UInt32]",
" Unnest: lists[shape_id2] structs[] [shape_id:UInt32, shape_id2:UInt32;N]",
" Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} })]",
" Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} });N]",
" TableScan: shapes projection=[shape_id] [shape_id:UInt32]",
];

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> {
Schema::new(vec![Field::new_list(
"ARRAY_AGG(DISTINCT aggregate_test_100.c2)",
Field::new("item", DataType::UInt32, false),
false
true
),])
);

Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ impl AggregateFunction {
pub fn nullable(&self) -> Result<bool> {
match self {
AggregateFunction::Max | AggregateFunction::Min => Ok(true),
AggregateFunction::ArrayAgg => Ok(false),
AggregateFunction::ArrayAgg => Ok(true),
}
}
}
Expand Down
17 changes: 11 additions & 6 deletions datafusion/physical-expr/src/aggregate/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ impl AggregateExpr for ArrayAgg {
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), self.nullable),
false,
true,
))
}

Expand All @@ -86,7 +86,7 @@ impl AggregateExpr for ArrayAgg {
Ok(vec![Field::new_list(
format_state_name(&self.name, "array_agg"),
Field::new("item", self.input_data_type.clone(), self.nullable),
false,
true,
)])
}

Expand Down Expand Up @@ -137,8 +137,11 @@ impl Accumulator for ArrayAggAccumulator {
return Ok(());
}
assert!(values.len() == 1, "array_agg can only take 1 param!");

let val = Arc::clone(&values[0]);
self.values.push(val);
if val.len() > 0 {
self.values.push(val);
}
Ok(())
}

Expand All @@ -162,13 +165,15 @@ impl Accumulator for ArrayAggAccumulator {

fn evaluate(&mut self) -> Result<ScalarValue> {
// Transform Vec<ListArr> to ListArr

let element_arrays: Vec<&dyn Array> =
self.values.iter().map(|a| a.as_ref()).collect();

if element_arrays.is_empty() {
let arr = ScalarValue::new_list(&[], &self.datatype, self.nullable);
return Ok(ScalarValue::List(arr));
return Ok(ScalarValue::new_null_list(
self.datatype.clone(),
self.nullable,
1,
));
}

let concated_array = arrow::compute::concat(&element_arrays)?;
Expand Down
11 changes: 9 additions & 2 deletions datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl AggregateExpr for DistinctArrayAgg {
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), self.nullable),
false,
true,
))
}

Expand All @@ -90,7 +90,7 @@ impl AggregateExpr for DistinctArrayAgg {
Ok(vec![Field::new_list(
format_state_name(&self.name, "distinct_array_agg"),
Field::new("item", self.input_data_type.clone(), self.nullable),
false,
true,
)])
}

Expand Down Expand Up @@ -165,6 +165,13 @@ impl Accumulator for DistinctArrayAggAccumulator {

fn evaluate(&mut self) -> Result<ScalarValue> {
let values: Vec<ScalarValue> = self.values.iter().cloned().collect();
if values.is_empty() {
return Ok(ScalarValue::new_null_list(
self.datatype.clone(),
self.nullable,
1,
));
}
let arr = ScalarValue::new_list(&values, &self.datatype, self.nullable);
Ok(ScalarValue::List(arr))
}
Expand Down
12 changes: 10 additions & 2 deletions datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), self.nullable),
false,
true,
))
}

Expand All @@ -111,7 +111,7 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
let mut fields = vec![Field::new_list(
format_state_name(&self.name, "array_agg"),
Field::new("item", self.input_data_type.clone(), self.nullable),
false, // This should be the same as field()
true, // This should be the same as field()
)];
let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types);
fields.push(Field::new_list(
Expand Down Expand Up @@ -309,6 +309,14 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
}

fn evaluate(&mut self) -> Result<ScalarValue> {
if self.values.is_empty() {
return Ok(ScalarValue::new_null_list(
self.datatypes[0].clone(),
self.nullable,
1,
));
}

let values = self.values.clone();
let array = if self.reverse {
ScalarValue::new_list_from_iter(
Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ mod tests {
Field::new_list(
"c1",
Field::new("item", data_type.clone(), true),
false,
true,
),
result_agg_phy_exprs.field().unwrap()
);
Expand All @@ -167,7 +167,7 @@ mod tests {
Field::new_list(
"c1",
Field::new("item", data_type.clone(), true),
false,
true,
),
result_agg_phy_exprs.field().unwrap()
);
Expand Down
155 changes: 116 additions & 39 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1694,7 +1694,7 @@ SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT
query ?
SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 LIMIT 0) test
----
[]
NULL

# csv_query_array_agg_one
query ?
Expand Down Expand Up @@ -1753,31 +1753,12 @@ NULL 4 29 1.260869565217 123 -117 23
NULL 5 -194 -13.857142857143 118 -101 14
NULL NULL 781 7.81 125 -117 100

# TODO: array_agg_distinct output is non-deterministic -- rewrite with array_sort(list_sort)
# unnest is also not available, so manually unnesting via CROSS JOIN
# additional count(1) forces array_agg_distinct instead of array_agg over aggregated by c2 data
#
# select with count to forces array_agg_distinct function, since single distinct expression is converted to group by by optimizer
# csv_query_array_agg_distinct
query III
WITH indices AS (
SELECT 1 AS idx UNION ALL
SELECT 2 AS idx UNION ALL
SELECT 3 AS idx UNION ALL
SELECT 4 AS idx UNION ALL
SELECT 5 AS idx
)
SELECT data.arr[indices.idx] as element, array_length(data.arr) as array_len, dummy
FROM (
SELECT array_agg(distinct c2) as arr, count(1) as dummy FROM aggregate_test_100
) data
CROSS JOIN indices
ORDER BY 1
----
1 5 100
2 5 100
3 5 100
4 5 100
5 5 100
query ?I
SELECT array_sort(array_agg(distinct c2)), count(1) FROM aggregate_test_100
----
[1, 2, 3, 4, 5] 100

# aggregate_time_min_and_max
query TT
Expand Down Expand Up @@ -2732,6 +2713,16 @@ SELECT COUNT(DISTINCT c1) FROM test

# TODO: aggregate_with_alias

# test_approx_percentile_cont_decimal_support
query TI
SELECT c1, approx_percentile_cont(c2, cast(0.85 as decimal(10,2))) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1
----
a 4
b 5
c 4
d 4
e 4

# array_agg_zero
query ?
SELECT ARRAY_AGG([])
Expand All @@ -2744,28 +2735,114 @@ SELECT ARRAY_AGG([1])
----
[[1]]

# test_approx_percentile_cont_decimal_support
query TI
SELECT c1, approx_percentile_cont(c2, cast(0.85 as decimal(10,2))) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1
# test array_agg with no row qualified
statement ok
create table t(a int, b float, c bigint) as values (1, 1.2, 2);

# returns NULL, follows DuckDB's behaviour
query ?
select array_agg(a) from t where a > 2;
----
a 4
b 5
c 4
d 4
e 4
NULL

query ?
select array_agg(b) from t where b > 3.1;
----
NULL

# array_agg_zero
query ?
SELECT ARRAY_AGG([]);
select array_agg(c) from t where c > 3;
----
[[]]
NULL

# array_agg_one
query ?I
select array_agg(c), count(1) from t where c > 3;
----
NULL 0

# returns 0 rows if group by is applied, follows DuckDB's behaviour
query ?
SELECT ARRAY_AGG([1]);
select array_agg(a) from t where a > 3 group by a;
----
[[1]]

query ?I
select array_agg(a), count(1) from t where a > 3 group by a;
----

# returns NULL, follows DuckDB's behaviour
query ?
select array_agg(distinct a) from t where a > 3;
----
NULL

query ?I
select array_agg(distinct a), count(1) from t where a > 3;
----
NULL 0

# returns 0 rows if group by is applied, follows DuckDB's behaviour
query ?
select array_agg(distinct a) from t where a > 3 group by a;
----

query ?I
select array_agg(distinct a), count(1) from t where a > 3 group by a;
----

# test order sensitive array agg
query ?
select array_agg(a order by a) from t where a > 3;
----
NULL

query ?
select array_agg(a order by a) from t where a > 3 group by a;
----

query ?I
select array_agg(a order by a), count(1) from t where a > 3 group by a;
----

statement ok
drop table t;

# test with no values
statement ok
create table t(a int, b float, c bigint);

query ?
select array_agg(a) from t;
----
NULL

query ?
select array_agg(b) from t;
----
NULL

query ?
select array_agg(c) from t;
----
NULL

query ?I
select array_agg(distinct a), count(1) from t;
----
NULL 0

query ?I
select array_agg(distinct b), count(1) from t;
----
NULL 0

query ?I
select array_agg(distinct b), count(1) from t;
----
NULL 0

statement ok
drop table t;


# array_agg_i32
statement ok
Expand Down

0 comments on commit d3f6372

Please sign in to comment.